In [1]:
import ast

In [2]:
def _is_number(node):
    """Return True if *node* is a numeric literal such as ``2``, ``1.5`` or ``1j``.

    Since Python 3.8 ``ast.parse`` emits ``ast.Constant`` for every literal,
    and the deprecated ``ast.Num`` alias no longer exists in Python 3.14.
    Booleans are rejected, matching what the old ``ast.Num`` check accepted.
    """
    if isinstance(node, ast.Constant):
        value = node.value
        return isinstance(value, (int, float, complex)) and not isinstance(value, bool)
    # Python < 3.8 parses numbers into ast.Num, which is not an ast.Constant.
    # Compare by name so that ast.Num is never touched on newer versions.
    return type(node).__name__ == 'Num'


def translate(node):
    """Translate a Python expression AST node into a LaTeX string.

    Dispatches on the node type to the matching ``translate_*`` helper;
    sub-expressions are translated recursively by those helpers.

    Raises:
        ValueError: if the node type is not supported (e.g. calls, strings,
            booleans).
    """
    if isinstance(node, ast.BinOp):
        return translate_bin_op(node)
    if isinstance(node, ast.UnaryOp):
        return translate_unary_op(node)
    if _is_number(node):
        return translate_num(node)
    if isinstance(node, ast.Name):
        return translate_name(node)
    if isinstance(node, ast.BoolOp):
        return translate_bool_op(node)
    if isinstance(node, ast.Compare):
        return translate_compare(node)
    raise ValueError('Node {} is not supported'.format(type(node)))

In [3]:
def translate_name(node):
    """Render an ``ast.Name`` node as its bare identifier.

    The identifier is emitted verbatim. Characters with a special meaning in
    LaTeX (for example ``_``, which starts a subscript) are not escaped.
    """
    identifier = node.id
    return identifier

In [4]:
def translate_num(node):
    """Render a numeric literal node as the ``str()`` of its value.

    ``ast.Constant`` (Python 3.8+) stores the number in ``.value``. The old
    ``.n`` alias was deprecated in 3.8 and removed in 3.14, so it is only used
    for the legacy ``ast.Num`` nodes produced by Python < 3.8.
    """
    value = node.value if isinstance(node, ast.Constant) else node.n
    return str(value)

In [5]:
def add_brackets(expression):
    """Wrap *expression* in auto-sized LaTeX parentheses (left/right pair)."""
    template = '\\left( {} \\right)'
    return template.format(expression)

In [6]:
def is_add(node):
    """Return True if *node* is a binary addition or subtraction.

    Callers use this to decide when an operand needs explicit brackets,
    because ``+``/``-`` bind more loosely than the surrounding operator.
    """
    if not isinstance(node, ast.BinOp):
        return False
    return isinstance(node.op, (ast.Add, ast.Sub))

In [7]:
def translate_unary_op(node):
    """Translate a unary operation node into LaTeX.

    ``not x`` becomes an overline over the operand. ``+x`` and ``-x`` become a
    leading sign followed by the operand. The operand is bracketed (via
    ``add_brackets``) when it is a sum/difference or is itself signed, so
    ``-(-x)`` does not render as two bare minus signs in a row.

    Raises:
        ValueError: if the operator is not ``not``, ``+`` or ``-``, or if the
            operand contains an unsupported node (raised by ``translate``).
    """
    op = node.op
    if isinstance(op, ast.Not):
        return '\\overline{{{expression}}}'.format(
            expression=translate(node.operand)
        )

    if isinstance(op, ast.UAdd):
        sign = '+'
    elif isinstance(op, ast.USub):
        sign = '-'
    else:
        # Name the offending operator (e.g. ``~``) rather than the generic
        # UnaryOp node, matching compare_op_sign / bool_op_sign, and fail
        # before recursing into the operand.
        raise ValueError('Operator {} is not supported'.format(type(op)))

    operand = node.operand
    expression = translate(operand)
    operand_is_signed = isinstance(operand, ast.UnaryOp) and isinstance(
        operand.op, (ast.UAdd, ast.USub))
    if is_add(operand) or operand_is_signed:
        expression = add_brackets(expression)

    return '{sign} {expression}'.format(
        sign=sign,
        expression=expression
    )

In [8]:
def _is_signed(node):
    """Return True if *node* is a unary plus/minus expression such as ``-x``."""
    return isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub))


def translate_bin_op(node):
    """Translate an arithmetic ``ast.BinOp`` node into LaTeX.

    Supported operators: ``+``, ``-``, ``*`` (rendered with a centred dot),
    ``/`` (rendered as a fraction) and ``**`` (rendered as a superscript).
    Operands are wrapped in visible brackets (via ``add_brackets``) wherever
    dropping Python's implicit grouping would change how the formula reads:

    * a sum/difference inside a product: ``(a + b) * c``;
    * a sum/difference that is subtracted: ``a - (b + c)``;
    * any compound base of a power: ``(a * b) ** 2``, ``(-x) ** 2``;
    * a signed right operand of ``+``, ``-`` or ``*``: ``a - (-b)``.

    Raises:
        ValueError: if the operator is unsupported (e.g. ``%``, ``//``), or if
            an operand contains an unsupported node (raised by ``translate``).
    """
    op = node.op
    if not isinstance(op, (ast.Add, ast.Sub, ast.Mult, ast.Div, ast.Pow)):
        # Name the operator rather than the generic BinOp node, matching
        # compare_op_sign / bool_op_sign, and fail before recursing.
        raise ValueError('Operator {} is not supported'.format(type(op)))

    left = translate(node.left)
    right = translate(node.right)

    if isinstance(op, ast.Div):
        # The fraction layout already groups both operands.
        return '\\frac{{{left}}}{{{right}}}'.format(left=left, right=right)

    if isinstance(op, ast.Pow):
        # Braces are invisible once rendered, so without brackets
        # (a * b) ** 2 would read as a * b^2 and (-x) ** 2 as -x^2.
        if isinstance(node.left, (ast.BinOp, ast.UnaryOp)):
            left = add_brackets(left)
        return '{{{left}}} ^ {{{right}}}'.format(left=left, right=right)

    if isinstance(op, ast.Mult):
        if is_add(node.left):
            left = add_brackets(left)
        if is_add(node.right) or _is_signed(node.right):
            right = add_brackets(right)
        return '{left} \\cdot {right}'.format(left=left, right=right)

    if isinstance(op, ast.Sub):
        # a - (b + c) must not render as a - b + c, which means (a - b) + c.
        if is_add(node.right) or _is_signed(node.right):
            right = add_brackets(right)
        return '{left} - {right}'.format(left=left, right=right)

    # Only Add remains; a sum on the right needs no brackets.
    if _is_signed(node.right):
        right = add_brackets(right)
    return '{left} + {right}'.format(left=left, right=right)
        

In [9]:
# LaTeX symbol for each supported comparison operator, keyed by AST node type.
# Looked up by exact type in compare_op_sign; operators missing here
# (``is``, ``in`` ...) are rejected with ValueError.
COMPARE_OPS = {
    ast.Lt: '<',
    ast.LtE: '\\leq',
    ast.Gt: '>',
    ast.GtE: '\\geq',
    ast.Eq: '=',
    ast.NotEq: '\\neq',
}

In [10]:
def compare_op_sign(op):
    """Return the LaTeX symbol for a comparison operator node.

    The lookup uses the exact type of *op* as a key into ``COMPARE_OPS``.

    Raises:
        ValueError: if the operator type has no entry in ``COMPARE_OPS``.
    """
    try:
        return COMPARE_OPS[type(op)]
    except KeyError:
        # ``from None`` drops the internal KeyError from the traceback; the
        # ValueError alone describes the problem.
        raise ValueError('Operator {} is not supported'.format(type(op))) from None

In [11]:
def translate_compare(node):
    """Translate a (possibly chained) comparison such as ``0 < C <= 1``.

    Emits the left operand, then each operator symbol followed by its
    right-hand operand, all separated by single spaces.
    """
    pieces = [translate(node.left)]
    for operator, comparator in zip(node.ops, node.comparators):
        pieces += [compare_op_sign(operator), translate(comparator)]
    return ' '.join(pieces)

In [12]:
def bool_op_sign(op):
    """Return the LaTeX connective for a boolean operator node.

    ``and`` maps to the wedge symbol and ``or`` to the vee symbol.

    Raises:
        ValueError: for any other operator.
    """
    for op_type, latex in ((ast.And, '\\wedge'), (ast.Or, '\\vee')):
        if isinstance(op, op_type):
            return latex
    raise ValueError('Operator {} is not supported'.format(type(op)))

In [13]:
def translate_bool_op(node):
    """Translate an ``and`` / ``or`` chain into LaTeX.

    Operands are joined by the connective from ``bool_op_sign``. Names,
    number literals and ``not``-expressions (already grouped by their
    overline) are emitted bare. Every other operand is wrapped in brackets
    via ``add_brackets``.

    Raises:
        ValueError: for an unsupported boolean operator or operand node.
    """
    sign = ' {sign} '.format(
        sign=bool_op_sign(node.op)
    )
    args = []
    for arg in node.values:
        expression = translate(arg)
        # Numbers are ast.Constant on Python 3.8+ (ast.Num was removed in
        # 3.14); non-numeric constants never get here because translate()
        # rejects them. The name check covers ast.Num on Python < 3.8.
        is_atom = isinstance(arg, (ast.Name, ast.Constant)) or type(arg).__name__ == 'Num'
        is_negation = isinstance(arg, ast.UnaryOp) and isinstance(arg.op, ast.Not)
        if not (is_atom or is_negation):
            expression = add_brackets(expression)
        args.append(expression)
    return sign.join(args)

In [14]:
# Arithmetic example: unary minus over a sum, chained products, a quotient and a power.
tree = ast.parse('-(x + 2) * y * z + 3 * (2 + v) ** 2 / u')
latex_arith = translate(tree.body[0].value)
print(latex_arith)

- \left( x + 2 \right) \cdot y \cdot z + \frac{3 \cdot {\left( 2 + v \right)} ^ {2}}{u}


$$- \left( x + 2 \right) \cdot y \cdot z + \frac{3 \cdot {\left( 2 + v \right)} ^ {2}}{u}$$

In [15]:
# Logic example: a negated comparison joined by `and` to a chained comparison.
tree = ast.parse('not A >= B and 0 < C <= 1')
latex_logic = translate(tree.body[0].value)
print(latex_logic)

\overline{A \geq B} \wedge \left( 0 < C \leq 1 \right)


$$\overline{A \geq B} \wedge \left( 0 < C \leq 1 \right)$$