
Python: Optimizing a tree evaluator

Developer https://www.devze.com 2022-12-18 19:38 Source: web

I know trees are a well-studied structure.

I'm writing a program that randomly generates many expression trees and then sorts and selects by a fitness attribute.

I have a class MakeTreeInOrder() that turns the tree into a string that 'eval' can evaluate.

But it gets called many, many times, so it should be optimized for speed.

Below is a test function that builds a tree which adds successive numbers.

I was wondering if there is an optimized way to evaluate an expression that is stored in a tree structure. I figured that it's used quite a bit and somebody has already done this.

import itertools
from collections import namedtuple

# Further developing Torsten Marek's second suggestion

KS = itertools.count()
Node = namedtuple("Node", ["cargo", "args"])

def build_nodes(depth=5):
    if depth <= 0:
        # Leaf: the next integer from KS, stored as a string, with no children.
        return Node(str(next(KS)), [None, None])
    else:
        # '+' node whose two subtrees have depth - 1 and depth - 2.
        this_node = Node('+', [])
        this_node.args.extend(
            build_nodes(depth=depth - (i + 1))
            for i in range(2))
        return this_node

The following code is what I think can be made a lot faster, and I was hoping for some ideas.

class MakeTreeInOrder(object):
    def __init__(self, node):
        self.node = node
        self.str = ''
    def makeit(self, nnode=''):
        # '' is a sentinel meaning "start at the root"; None marks a missing child.
        if nnode == '':
            nnode = self.node
        if nnode is None:
            return
        self.str += '('
        self.makeit(nnode.args[0])
        self.str += nnode.cargo
        self.makeit(nnode.args[1])
        self.str += ')'
        return self.str

def Main():
    this_tree = build_nodes()
    expression_generator = MakeTreeInOrder(this_tree)
    this_expression = expression_generator.makeit()
    print(this_expression)
    print(eval(this_expression))

if __name__ == '__main__':
    Main()


Adding a touch of object orientation here makes things simpler. Have subclasses of Node for each thing in your tree, and use an 'eval' method to evaluate them.

import random

class ArithmeticOperatorNode(object):
    def __init__(self, operator, *args):
        self.operator = operator
        self.children = args
    def eval(self):
        if self.operator == '+':
            return sum(x.eval() for x in self.children)
        raise ValueError('Unknown arithmetic operator ' + self.operator)
    def __str__(self):
        return '(%s)' % (' ' + self.operator + ' ').join(str(x) for x in self.children)

class ConstantNode(object):
    def __init__(self, constant):
        self.constant = constant
    def eval(self):
        return self.constant
    def __str__(self):
        return str(self.constant)

def build_tree(n):
    if n == 0:
        return ConstantNode(random.randrange(100))
    else:
        left = build_tree(n - 1)
        right = build_tree(n - 1)
        return ArithmeticOperatorNode('+', left, right)

node = build_tree(5)
print(node)
print(node.eval())

To evaluate the tree, just call .eval() on the top-level node.

node = build_tree(5)
print(node.eval())

I also added a __str__ method to convert the tree to a string so you can see how this generalizes to other tree functions. It just does '+' at the moment, but hopefully it's clear how to extend this to the full range of arithmetic operations.
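To make that extension concrete, here is a minimal sketch (not part of the original answer) that replaces the hard-coded '+' check with a lookup table. The `OPERATORS` dict and the `symbols` parameter of `build_tree` are hypothetical names introduced here; the binary functions come from the standard-library `operator` module, and `functools.reduce` folds the children left to right, so `(10 - 3 - 2)` is evaluated the same way Python would evaluate it. Division is left out of the default symbols because a subtree can evaluate to 0.

```python
import operator
import random
from functools import reduce

# Hypothetical dispatch table: operator symbol -> binary function.
OPERATORS = {
    '+': operator.add,
    '-': operator.sub,
    '*': operator.mul,
    '/': operator.truediv,
}


class ArithmeticOperatorNode(object):
    def __init__(self, symbol, *args):
        self.operator = symbol
        self.children = args

    def eval(self):
        func = OPERATORS.get(self.operator)
        if func is None:
            raise ValueError('Unknown arithmetic operator ' + self.operator)
        # Fold the children left to right: ((c0 op c1) op c2) ...
        return reduce(func, (x.eval() for x in self.children))

    def __str__(self):
        return '(%s)' % (' ' + self.operator + ' ').join(str(x) for x in self.children)


class ConstantNode(object):
    def __init__(self, constant):
        self.constant = constant

    def eval(self):
        return self.constant

    def __str__(self):
        return str(self.constant)


def build_tree(n, symbols=('+', '-', '*')):
    # '/' is left out by default because a subtree may evaluate to 0.
    if n == 0:
        return ConstantNode(random.randrange(100))
    symbol = random.choice(symbols)
    left = build_tree(n - 1, symbols)
    right = build_tree(n - 1, symbols)
    return ArithmeticOperatorNode(symbol, left, right)


tree = build_tree(3)
print(tree)
print(tree.eval())
```

Because every operator node is parenthesized by `__str__`, `eval(str(tree))` should agree with `tree.eval()` for trees built this way, which makes a handy sanity check.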


Your example imports numpy and random but never uses them. It also has a "for i in range(2))" with no body. This is clearly not valid Python code.

You don't define what 'cargo' and the nodes are supposed to contain. It appears that 'cargo' is a number, since it comes from itertools.count().next(). But that makes no sense, since you want the result to be an eval'able Python string.
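If an eval'able string is still the goal, one low-risk tweak to the question's `MakeTreeInOrder.makeit()` is to collect the pieces in a list and join them once, instead of growing the `self.str` attribute with `+=`. The sketch below assumes the question's `Node` layout (leaves have `args == [None, None]`); `to_infix` and `leaf` are hypothetical names. As a side effect it is stateless: with the original class, calling `makeit()` a second time on the same instance appends the new result to the old one.

```python
from collections import namedtuple

# Same layout as the question: internal nodes carry an operator symbol in
# `cargo`; leaves carry a number string and have args == [None, None].
Node = namedtuple("Node", ["cargo", "args"])


def to_infix(node):
    """Build the same string as MakeTreeInOrder.makeit(), joining once at the end."""
    parts = []

    def walk(n):
        if n is None:  # missing child: contributes nothing
            return
        parts.append('(')
        walk(n.args[0])
        parts.append(n.cargo)
        walk(n.args[1])
        parts.append(')')

    walk(node)
    return ''.join(parts)


def leaf(text):
    return Node(text, [None, None])


tree = Node('+', [leaf('0'), Node('+', [leaf('1'), leaf('2')])])
print(to_infix(tree))        # ((0)+((1)+(2)))
print(eval(to_infix(tree)))  # 3
```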

If you are doing a one-time evaluation of the tree then the fastest solution would be to evaluate it directly in-place, but without an actual example of the data you're working with, I can't show an example.
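Taking the question's `Node` layout as the data (an assumption, since the exact cargo values are not pinned down), evaluating the tree directly could look like the sketch below: recurse over the nodes and apply each operator as you go, so there is no string to build and no `eval` call to parse it. `BINARY_OPS`, `evaluate` and `leaf` are hypothetical names, and leaves are assumed to hold integer strings as produced by `build_nodes`.

```python
import operator
from collections import namedtuple

Node = namedtuple("Node", ["cargo", "args"])

# Hypothetical mapping from the operator symbol stored in `cargo` to a function.
BINARY_OPS = {
    '+': operator.add,
    '-': operator.sub,
    '*': operator.mul,
}


def evaluate(node):
    """Evaluate a question-style Node tree recursively, without eval()."""
    left, right = node.args
    if left is None and right is None:
        # Leaf: cargo is a number stored as a string, e.g. '7'.
        return int(node.cargo)
    return BINARY_OPS[node.cargo](evaluate(left), evaluate(right))


def leaf(value):
    return Node(str(value), [None, None])


tree = Node('+', [leaf(3), Node('*', [leaf(4), leaf(5)])])
print(evaluate(tree))  # 23
```

A dict lookup plus one function call per node avoids generating source text for `eval` to compile, though the actual speedup is worth measuring with `timeit` on real trees.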

If you want to make it even faster, look further upstream. Why do you generate the tree and then evaluate it? Can't you evaluate the components directly in the code that currently generates the tree structure? If you have operators like "+" and "*", consider using operator.add and operator.mul instead; they work directly on the data without an intermediate step.
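A sketch of that "skip the tree" idea with more than one operator: the expression is generated and evaluated in the same recursive pass, picking `operator.add` or `operator.mul` at random for each internal node. `OPERATOR_FUNCS` and `random_value` are hypothetical names; the `rng` parameter is there so a seeded `random.Random` can be passed in for reproducible runs.

```python
import operator
import random

# Hypothetical operator set; extend with operator.sub etc. as needed.
OPERATOR_FUNCS = [operator.add, operator.mul]


def random_value(n, rng=random):
    """Generate and evaluate a random depth-n expression in one pass; no tree is stored."""
    if n == 0:
        return rng.randrange(100)
    op = rng.choice(OPERATOR_FUNCS)
    # Python evaluates arguments left to right, so the left subtree always
    # consumes random numbers first and a seeded rng gives a repeatable value.
    return op(random_value(n - 1, rng), random_value(n - 1, rng))


print(random_value(5, random.Random(42)))
```

Passing the same seed twice yields the same value; the trade-off is that the expression itself is gone once the value has been computed.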

==update==

This builds on Paul Hankin's answer. What I've done is remove the intermediate tree structure and evaluate the expression directly.

def build_tree2(n):
    if n == 0:
        return random.randrange(100)
    else:
        left = build_tree2(n-1)
        right = build_tree2(n-1)
        return left+right

That clocks in at about 5 times faster than Paul's solution.

It may be that you need to know the actual tree structure of the best solution, or the top k of N, where k << N. If that's the case then you can post-hoc regenerate those trees if you also keep track of the RNG state used to generate the results. For example:

def build_tree3(n, rng=random._inst):
    state = rng.getstate()
    return _build_tree3(n, rng.randrange), state

def _build_tree3(n, randrange):
    if n == 0:
        return randrange(100)
    else:
        left = _build_tree3(n-1, randrange)
        right = _build_tree3(n-1, randrange)
        return left+right

Once you've found the best values, use the saved RNG state as the key to regenerate the tree:

# Build Paul's tree data structure given a specific RNG
def build_tree4(n, rng):
    if n == 0:
        return ConstantNode(rng.randrange(100))
    else:
        left = build_tree4(n-1, rng)
        right = build_tree4(n-1, rng)
        return ArithmeticOperatorNode("+", left, right)

# This is an O(n log(n)) way to get the best k.
# An O(n log(k)) solution is possible, e.g. with heapq.nsmallest.
rng = random.Random()
best_5 = sorted(build_tree3(8, rng) for i in range(10000))[:5]
for value, state in best_5:
    rng.setstate(state)
    tree = build_tree4(8, rng)
    print(tree.eval(), "should be", value)
    print("  ", str(tree)[:50] + " ...")

Here's what it looks like when I run it:

10793 should be 10793
   ((((((((92 + 75) + (35 + 69)) + ((39 + 79) + (6 +  ...
10814 should be 10814
   ((((((((50 + 63) + (6 + 21)) + ((75 + 98) + (76 +  ...
10892 should be 10892
   ((((((((51 + 25) + (5 + 32)) + ((40 + 71) + (17 +  ...
11070 should be 11070
   ((((((((7 + 83) + (77 + 56)) + ((16 + 29) + (2 + 1 ...
11125 should be 11125
   ((((((((69 + 80) + (11 + 64)) + ((33 + 21) + (95 + ...
