예제 #1
0
def veval_ast_unary_op(astc: 'AstContext', local_field: 'values.Field',
                       graph: 'graphs.Graph'):
    """
    Evaluate a unary operation (Ex. -xx) and record it in ``graph``.

    The operator is mapped to a ``nodes.UnaryOpType`` member (``Unknown``
    when it is none of ``+``, ``-`` or ``not``). The operand is evaluated,
    a ``NodeUnaryOp`` is created with the result of ``veval_unary.veval`` as
    its only output, and the node is added to ``graph``.

    Returns a ``values.ValueRef`` wrapping the result value.
    """
    unary_ast = astc.nast
    assert isinstance(unary_ast, gast.gast.UnaryOp)
    lineprop = utils.LineProperty(astc.lineno)

    # (gast operator class, graph operator type) pairs.
    op_table = (
        (gast.UAdd, nodes.UnaryOpType.UAdd),
        (gast.USub, nodes.UnaryOpType.USub),
        (gast.Not, nodes.UnaryOpType.Not),
    )
    unaryop = nodes.UnaryOpType.Unknown
    for ast_class, op_type in op_table:
        if isinstance(unary_ast.op, ast_class):
            unaryop = op_type

    operand_ref = veval_ast(astc.c(unary_ast.operand), local_field, graph)
    operand_value = try_get_value(operand_ref, 'unary', lineprop)

    unary_node = nodes.NodeUnaryOp(operand_value, unaryop)
    result = veval_unary.veval(unaryop, operand_value)
    unary_node.set_outputs([result])
    graph.add_node(unary_node)

    return values.ValueRef(result)
예제 #2
0
def veval_ast_tuple(astc: 'AstContext',
                    local_field: 'values.Field',
                    graph: 'Graph',
                    option: 'VEvalOption' = None):
    """
    Evaluate a tuple expression.

    When ``option.eval_as_written_target`` is set, every element is evaluated
    and the raw results are returned as a plain list. Otherwise each element
    is resolved to a reference, flagged with ``in_container = True``, and a
    'Tuple' ``NodeGenerate`` producing a ``TupleValue`` is added to ``graph``;
    the ``TupleValue`` is returned wrapped in a ``values.ValueRef``.
    """
    assert isinstance(astc.nast, gast.gast.Tuple)
    lineprop = utils.LineProperty(astc.lineno)

    as_written_target = option is not None and option.eval_as_written_target
    if as_written_target:
        return [
            veval_ast(astc.c(elt), local_field, graph, option=option)
            for elt in astc.nast.elts
        ]

    elt_refs = []
    elt_values = []
    for elt in astc.nast.elts:
        evaluated = veval_ast(astc.c(elt), local_field, graph, option=option)
        elt_ref = try_get_ref(evaluated, 'tuple', lineprop)
        elt_refs.append(elt_ref)
        elt_values.append(elt_ref.get_value())
        elt_ref.in_container = True

    tuple_value = values.TupleValue(elt_refs)

    tuple_node = nodes.NodeGenerate('Tuple', elt_values, line=lineprop)
    tuple_node.set_outputs([tuple_value])
    graph.add_node(tuple_node)

    return values.ValueRef(tuple_value)
예제 #3
0
    def vcall(self,
              module: 'Field',
              graph: 'Graph',
              inst: 'values.ValueRef',
              args: 'functions.FunctionArgInput',
              line=-1):
        """
        Record a 'zeros' generation node and return its tensor output.

        The call arguments are merged with this function's declared
        arguments; the second merged argument is treated as the dtype.
        When it is given, its ``internal_value`` is converted with
        ``utils.int_2_numpy_type``. When it is missing (``None`` or a
        ``NoneValue``) the dtype falls back to float64, the default that
        ``numpy.zeros`` documents.

        Args:
            module: unused here.
            graph: graph that receives the generated node.
            inst: must be ``None`` (this is not a bound method).
            args: call-site arguments.
            line: line information forwarded to the node and used in the
                output value's name.

        Returns:
            ``values.ValueRef`` wrapping the new ``values.TensorValue``.
        """
        assert (inst is None)

        funcArgs = self.args.merge_inputs(inst, args)
        vargs = funcArgs.get_value().inputs

        dtype_value = vargs[1]
        if dtype_value is not None and not isinstance(dtype_value,
                                                      values.NoneValue):
            # TODO : make better
            dtype = utils.int_2_numpy_type(dtype_value.internal_value)
        else:
            # No dtype was supplied. The previous fallback inspected the
            # missing dtype argument itself, which fails for None and cannot
            # describe the default element type.
            dtype = np.dtype(np.float64)

        node = nodes.NodeGenerate('zeros', funcArgs, line)
        graph.add_node(node)
        value = values.TensorValue()
        value.dtype = dtype
        value.name = '@F.{}.{}'.format(line, self.name)
        node.set_outputs([value])
        return values.ValueRef(value)
예제 #4
0
def veval_ast_bin_op(astc: 'AstContext', local_field: 'values.Field',
                     graph: 'Graph'):
    """
    Evaluate a binary operation and record it in ``graph``.
    Ex. a + b, b // c, etc

    Both operands are evaluated (left first) and resolved to values. Only
    ``+``, ``-`` and ``*`` are mapped to a specific ``nodes.BinOpType``; any
    other operator is recorded as ``nodes.BinOpType.Unknown``. A
    ``NodeBinOp`` whose output is the result of ``veval_bin.veval`` is added
    to ``graph``.

    Returns a ``values.ValueRef`` wrapping the result value.
    """
    assert (isinstance(astc.nast, gast.gast.BinOp))
    lineprop = utils.LineProperty(astc.lineno)

    left = veval_ast(astc.c(astc.nast.left), local_field, graph)
    right = veval_ast(astc.c(astc.nast.right), local_field, graph)

    # The label identifies this construct in diagnostics; it used to be the
    # copy-pasted 'compare', unlike every sibling evaluator.
    left_value = try_get_value(left, 'binop', lineprop)
    right_value = try_get_value(right, 'binop', lineprop)

    op = astc.nast.op
    if isinstance(op, gast.Add):
        binop = nodes.BinOpType.Add
    elif isinstance(op, gast.Sub):
        binop = nodes.BinOpType.Sub
    elif isinstance(op, gast.Mult):
        binop = nodes.BinOpType.Mul
    else:
        binop = nodes.BinOpType.Unknown

    node_bin_op = nodes.NodeBinOp(left_value, right_value, binop, astc.lineno)

    ret_value = veval_bin.veval(binop, left_value, right_value)

    node_bin_op.set_outputs([ret_value])
    graph.add_node(node_bin_op)

    return values.ValueRef(ret_value)
 def vcall(self, module: 'Field', graph: 'Graph', inst: 'values.ValueRef', args: 'functions.FunctionArgInput', line=-1):
     """
     Record a 'range' generation node and return its output.

     The values behind ``args.inputs`` become the node's inputs; the output
     is a new ``values.RangeValue`` named after ``line`` and ``self.name``,
     returned wrapped in a ``values.ValueRef``.
     """
     input_values = []
     for arg_ref in args.inputs:
         input_values.append(arg_ref.get_value())

     range_node = nodes.NodeGenerate('range', input_values, line)
     graph.add_node(range_node)

     range_value = values.RangeValue()
     range_value.name = '@F.{}.{}'.format(line, self.name)
     range_node.set_outputs([range_value])
     return values.ValueRef(range_value)
예제 #6
0
    def return_value_or_ref(obj: 'value.Object'):
        """
        Return a fresh reference for simple values, otherwise ``obj`` itself.

        When the value behind ``obj`` is a number, string, bool, none or
        tuple value, a new ``values.ValueRef`` around that same value is
        returned; any other kind of value is returned through the original
        ``obj`` unchanged.
        """
        rewrapped_types = (
            values.NumberValue,
            values.StrValue,
            values.BoolValue,
            values.NoneValue,
            values.TupleValue,
        )
        if isinstance(obj.get_value(), rewrapped_types):
            return values.ValueRef(obj.get_value())
        return obj
예제 #7
0
    def apply_to_object(self, obj: 'values.ValueRef'):
        """
        Install the ``children`` function and the forward entry points on ``obj``.

        After the base-class setup, the ``children`` attribute is revised to
        a function value built from ``ChainerChainListChildrenFunction``.
        If ``obj`` provides a ``forward`` object, both ``__call__`` and
        ``forward`` attributes are revised to point at it.
        """
        super().apply_to_object(obj)

        children_func = values.FuncValue(
            ChainerChainListChildrenFunction(self), obj)
        obj.get_field().get_attribute('children').revise(
            values.ValueRef(children_func))

        forward_func = obj.try_get_and_store_obj('forward', None)
        if forward_func is None:
            return

        for attr_name in ('__call__', 'forward'):
            obj.get_field().get_attribute(attr_name).revise(forward_func)
예제 #8
0
def veval_ast_name_constant(astc: 'AstContext', local_field: 'values.Field',
                            graph: 'Graph'):
    '''
    Evaluate a name constant (Ex. True, False, None).

    ``True`` / ``False`` become a ``values.BoolValue`` and ``None`` becomes
    a ``values.NoneValue``. The returned ``values.ValueRef`` and the value it
    wraps both receive the name produced by
    ``values.create_ref_value_name_with_constant``.

    ``local_field`` and ``graph`` are not used.

    Raises:
        ValueError: if the node carries a value other than the three
            singletons (previously this surfaced as an error on ``None``).
    '''
    assert (isinstance(astc.nast, gast.gast.NameConstant))

    constant = astc.nast.value
    # Singletons are compared by identity, not equality.
    if constant is None:
        ret = values.ValueRef(values.NoneValue())
    elif constant is True or constant is False:
        ret = values.ValueRef(values.BoolValue(constant))
    else:
        raise ValueError(
            'Unsupported name constant {!r} in L.{}'.format(
                constant, astc.lineno))

    name = values.create_ref_value_name_with_constant(ret)
    ret.name = name
    ret.get_value().name = name
    return ret
예제 #9
0
    def add_arg(self, name, value):
        """
        Declare a new argument called ``name`` with default ``value``.

        A bare ``values.Value`` is wrapped in a ``values.ValueRef`` first.
        The resulting ``FunctionArg`` is appended to ``self.args_list`` and
        stored in ``self.args`` under its name. Declaring the same name
        twice trips an assertion.
        """
        default_ref = (values.ValueRef(value)
                       if isinstance(value, values.Value) else value)

        assert name not in self.args.keys()

        new_arg = FunctionArg(name, default_ref)
        self.args_list.append(new_arg)
        self.args[new_arg.name] = new_arg
    def vcall(self, module: 'Field', graph: 'Graph', inst: 'values.ValueRef', args: 'functions.FunctionArgInput', line=-1):
        """
        Record a call node for this function and return its output.

        The call-site arguments are merged with the declared ones and handed
        to a ``NodeCall``. The output placeholder is produced by
        ``self.ret_value_func()`` and named after ``line`` and ``self.name``.

        Returns a ``values.ValueRef`` wrapping that output value.
        """
        merged_args = self.args.merge_inputs(inst, args)

        call_node = nodes.NodeCall(self, merged_args, line)
        graph.add_node(call_node)

        result = self.ret_value_func()
        result.name = '@F.{}.{}'.format(line, self.name)
        call_node.set_outputs([result])
        return values.ValueRef(result)
예제 #11
0
        def add_chainer_funtion(name:'str', func, ret_value_func = None):
            """
            Register ``func`` under ``name`` in the enclosing ``f_dict``.

            ``func`` is wrapped in a ``functions_builtin.ChainerFunction``
            (receiving ``ret_value_func`` only when one is given), stored as
            attribute ``name`` of ``f_dict`` and recorded in
            ``values.function_converters`` keyed by ``func``.
            """
            # Forward ret_value_func only when supplied so that the
            # ChainerFunction default applies otherwise.
            extra_kwargs = {}
            if ret_value_func is not None:
                extra_kwargs['ret_value_func'] = ret_value_func

            func_value = values.FuncValue(
                functions_builtin.ChainerFunction(func, **extra_kwargs), None)
            f_dict.get_field().get_attribute(name).revise(
                values.ValueRef(func_value))

            values.function_converters[func] = func_value
예제 #12
0
    def vcall(self,
              module: 'Field',
              graph: 'Graph',
              inst: 'values.ValueRef',
              args: 'functions.FunctionArgInput',
              line=-1):
        """
        Return the owner's children as a list value.

        A fresh argument set holding only ``inst`` (positionally and as
        ``self``) is built in place of the incoming ``args``; it is not used
        any further here, and no node is added to ``graph``.

        Returns a ``values.ValueRef`` wrapping a ``values.ListValue`` created
        from ``self.owner.children``.
        """
        self_only_args = functions.FunctionArgInput()
        self_only_args.inputs.append(inst)
        self_only_args.keywords['self'] = inst

        children_value = values.ListValue(self.owner.children)
        return values.ValueRef(children_value)
예제 #13
0
def veval_ast_str(astc: 'AstContext', local_field: 'values.Field',
                  graph: 'Graph'):
    '''
    Evaluate a string literal (Ex. "str").

    The literal is wrapped in a ``values.StrValue``; the returned
    ``values.ValueRef`` and its value both get the name produced by
    ``values.create_ref_value_name_with_constant``. ``local_field`` and
    ``graph`` are not used.
    '''
    assert isinstance(astc.nast, gast.gast.Str)
    lineprop = utils.LineProperty(astc.lineno)

    str_ref = values.ValueRef(values.StrValue(astc.nast.s))

    constant_name = values.create_ref_value_name_with_constant(str_ref)
    str_ref.name = constant_name
    str_ref.get_value().name = constant_name
    return str_ref
예제 #14
0
 def vcall(self,
           module: 'Field',
           graph: 'Graph',
           inst: 'values.ValueRef',
           args: 'functions.FunctionArgInput',
           line=-1):
     """
     Record a length node for the first argument and return its output.

     The output is a ``values.NumberValue`` with no known value, named
     after ``line`` and ``self.name`` and returned in a ``values.ValueRef``.
     """
     # TODO: Check this.
     target_value = args.inputs[0].get_value()

     len_node = nodes.NodeLen(target_value, line)
     graph.add_node(len_node)

     length = values.NumberValue(None)
     length.name = '@F.{}.{}'.format(line, self.name)
     len_node.set_outputs([length])
     return values.ValueRef(length)
    def vcall(self, module: 'Field', graph: 'Graph', inst: 'values.ValueRef', args: 'functions.FunctionArgInput', line=-1):
        """
        Record the creation of a list and return the new list value.

        With no source argument (the first merged argument is a
        ``NoneValue``) an empty 'List' ``NodeGenerate`` is recorded;
        otherwise a 'List' ``NodeConvert`` of that argument is recorded.

        Returns a ``values.ValueRef`` wrapping the new ``values.ListValue``.
        """
        assert inst is None

        merged_args = self.args.merge_inputs(inst, args)
        source = merged_args.get_value().inputs[0]
        list_value = values.ListValue()

        if isinstance(source, values.NoneValue):
            list_node = nodes.NodeGenerate('List', [], line)
        else:
            list_node = nodes.NodeConvert('List', source, line)
        graph.add_node(list_node)

        list_value.name = '@F.{}.{}'.format(line, self.name)
        list_node.set_outputs([list_value])
        return values.ValueRef(list_value)
예제 #16
0
    def vcall(self,
              module: 'values.Field',
              graph: 'Graph',
              inst: 'values.ValueRef',
              args: 'functions.FunctionArgInput',
              line=-1):
        """
        Record a call node for the owning link and return its tensor output.

        If the ``chainer_links`` entry for the type of ``self.owner.inst``
        provides an ``estimate_shape`` callable, it is used to fill in the
        output's ``shape``.

        Returns a ``values.ValueRef`` wrapping the new ``values.TensorValue``.
        """
        merged_args = self.args.merge_inputs(inst, args)

        call_node = nodes.NodeCall(self, merged_args, line)
        graph.add_node(call_node)
        output = values.TensorValue()

        link_inst = self.owner.inst
        shape_estimator = chainer_links[type(link_inst)].estimate_shape
        if shape_estimator is not None:
            output.shape = shape_estimator(self.owner.inst, merged_args)

        call_node.set_outputs([output])
        return values.ValueRef(output)
예제 #17
0
    def vcall(self,
              module: 'Field',
              graph: 'Graph',
              inst: 'values.ValueRef',
              args: 'functions.FunctionArgInput',
              line=-1):
        """
        Record a call on ``inst`` that yields a list value.

        The incoming ``args`` are replaced by a fresh argument set that
        carries only ``inst`` (positionally and as ``self``). The call node's
        output is a new ``values.ListValue`` named after ``line`` and
        ``self.name``.

        Returns a ``values.ValueRef`` wrapping that list value.
        """
        self_only_args = functions.FunctionArgInput()
        self_only_args.inputs.append(inst)
        self_only_args.keywords['self'] = inst

        call_node = nodes.NodeCall(self, self_only_args, line)

        result = values.ListValue()
        result.name = '@F.{}.{}'.format(line, self.name)
        call_node.set_outputs([result])

        # TODO should make tuple
        graph.add_node(call_node)
        return values.ValueRef(result)
예제 #18
0
    def vcall(self,
              module: 'Field',
              graph: 'Graph',
              inst: 'values.ValueRef',
              args: 'functions.FunctionArgInput',
              line=-1):
        """
        Record a call on ``inst`` that yields a number.

        The incoming ``args`` are replaced by a fresh argument set that
        carries only ``inst`` (positionally and as ``self``). The call node's
        output is a ``values.NumberValue`` with no known value whose dtype is
        the one NumPy assigns to the integer literal ``0``.

        Returns a ``values.ValueRef`` wrapping that number value.
        """
        self_only_args = functions.FunctionArgInput()
        self_only_args.inputs.append(inst)
        self_only_args.keywords['self'] = inst

        call_node = nodes.NodeCall(self, self_only_args, line)

        result = values.NumberValue(None)
        result.dtype = np.array(0).dtype
        result.name = '@F.{}.{}'.format(line, self.name)
        call_node.set_outputs([result])

        graph.add_node(call_node)
        return values.ValueRef(result)
예제 #19
0
def veval_ast_compare(astc: 'AstContext', local_field: 'values.Field',
                      graph: 'Graph'):
    """
    Evaluate a comparison and record it in ``graph``.
    Ex. a >= b, a != b, a is b, etc

    Only the first operator and the first comparator of the expression are
    considered. Operators without an entry in the table below are recorded
    as ``nodes.CompareType.unknown``.

    Returns a ``values.ValueRef`` wrapping a ``values.BoolValue`` with no
    known value, which is the output of the added ``NodeCompare``.
    """
    assert isinstance(astc.nast, gast.gast.Compare)
    lineprop = utils.LineProperty(astc.lineno)

    lhs = veval_ast(astc.c(astc.nast.left), local_field, graph)
    rhs = veval_ast(astc.c(astc.nast.comparators[0]), local_field, graph)

    lhs_value = try_get_value(lhs, 'compare', lineprop)
    rhs_value = try_get_value(rhs, 'compare', lineprop)

    # (gast operator class, graph comparison type) pairs.
    compare_table = (
        (gast.Eq, nodes.CompareType.Eq),
        (gast.NotEq, nodes.CompareType.NotEq),
        (gast.Is, nodes.CompareType.Is),
        (gast.IsNot, nodes.CompareType.IsNot),
        (gast.Gt, nodes.CompareType.Gt),
        (gast.GtE, nodes.CompareType.GtE),
        (gast.Lt, nodes.CompareType.Lt),
        (gast.LtE, nodes.CompareType.LtE),
    )
    first_op = astc.nast.ops[0]
    compare = nodes.CompareType.unknown
    for ast_class, compare_type in compare_table:
        if isinstance(first_op, ast_class):
            compare = compare_type

    compare_node = nodes.NodeCompare(lhs_value, rhs_value, compare,
                                     astc.lineno)

    result = values.BoolValue(None)
    result.name = '@{}'.format(lineprop)
    compare_node.set_outputs([result])
    graph.add_node(compare_node)

    return values.ValueRef(result)
예제 #20
0
def veval_ast_list(astc: 'AstContext', local_field: 'values.Field',
                   graph: 'Graph'):
    '''
    Evaluate a list literal and record its creation in ``graph``.
    Ex. [],[x,y,z]

    Every element is evaluated and resolved to a reference. A 'List'
    ``NodeGenerate`` taking the element values as inputs is added to
    ``graph``; its output is a ``values.ListValue`` holding the element
    references, returned wrapped in a ``values.ValueRef``.

    TODO : Initializer
    '''
    assert isinstance(astc.nast, gast.gast.List)
    lineprop = utils.LineProperty(astc.lineno)

    elt_refs = [
        try_get_ref(veval_ast(astc.c(item), local_field, graph), 'list',
                    lineprop)
        for item in astc.nast.elts
    ]
    elt_values = [elt_ref.get_value() for elt_ref in elt_refs]

    list_node = nodes.NodeGenerate('List', elt_values, lineprop)
    graph.add_node(list_node)

    list_value = values.ListValue(elt_refs)
    list_node.set_outputs([list_value])

    return values.ValueRef(list_value)
예제 #21
0
    def vcall(self,
              module: 'values.Field',
              graph: 'graphs.Graph',
              inst: 'values.ValueRef',
              args: 'FunctionArgInput',
              line=-1):
        """
        Instantiate the user-defined class and evaluate this function's body.

        A new ``values.UserDefinedInstance`` reference is created and used as
        the instance argument in place of the incoming ``inst``. The merged
        keyword arguments are bound in a fresh field attached to ``module``,
        and ``self.ast.body`` is evaluated against that field and ``graph``.

        Returns the reference to the new instance.
        """
        instance_ref = values.ValueRef(
            values.UserDefinedInstance(module, None, self.classinfo))

        body_field = values.Field()
        body_field.set_module(module)

        # Bind every merged argument by keyword inside the body's field.
        merged_args = self.args.merge_inputs(instance_ref, args)
        for arg_name, arg_ref in merged_args.keywords.items():
            body_field.get_field().get_attribute(arg_name).revise(arg_ref)

        body_context = vevaluator.AstContext(self.ast.body, self.lineno - 1)
        vevaluator.veval_ast(body_context, body_field, graph)

        return instance_ref
예제 #22
0
def veval_ast_for(astc: 'AstContext', local_field: 'values.Field',
                  graph: 'Graph'):
    '''
    Evaluate a ``for`` statement and record it as a ``NodeFor`` in ``graph``.

    for target in iter:
        ...

    The loop body is evaluated into a separate body graph. Values read or
    written by the body (as reported by ``values.get_inputs()`` and
    ``values.get_outputs()``) become inputs and outputs of both the body
    graph and the ``NodeFor``.

    Only a plain name is supported as the loop target; otherwise a warning is
    optionally printed and nothing is recorded. A list with constant contents
    and no dtype is handed to ``veval_ast_for_unroll`` instead.

    Returns ``None``, or whatever ``veval_ast_for_unroll`` returns when the
    loop is unrolled.
    '''
    assert (isinstance(astc.nast, gast.gast.For))
    lineprop = utils.LineProperty(astc.lineno)

    # Evaluate the iterated expression and create its counterpart value for
    # use inside the body graph.
    iter_ = veval_ast(astc.c(astc.nast.iter), local_field, graph)
    input_iter_value = try_get_value(iter_, 'for', lineprop)
    body_iter_value = functions.generate_value_with_same_type(
        input_iter_value, suffix_type=functions.SuffixType.Input)

    # Resolve the target name; anything but a simple name is unsupported.
    target_name = ''
    if isinstance(astc.nast.target, gast.gast.Name):
        target_name = astc.nast.target.id
    else:
        if config.show_warnings:
            print('This for is not supported. in L.{}'.format(astc.lineno))
        return None

    # Lists with constant contents and no dtype are unrolled instead of
    # being turned into a loop node.
    if isinstance(input_iter_value,
                  values.ListValue) and input_iter_value.has_constant_value(
                  ) and input_iter_value.dtype is None:
        return veval_ast_for_unroll(astc, target_name, input_iter_value,
                                    local_field, graph)

    for_guid = utils.get_guid()
    for_id = 'for_' + str(for_guid)
    # NOTE(review): body_id is not used anywhere below.
    body_id = 'body_' + str(for_guid)

    # Presumably starts tracking value accesses for this loop; the tracked
    # inputs/outputs are read back before pop_history() below -- confirm.
    values.push_history(for_id)

    # Graph that receives the nodes of the loop body.
    body_graph = Graph()
    body_graph.root_graph = graph.root_graph
    body_graph.name = 'Body_' + str(for_guid)

    # Input node of the body graph; its outputs are assigned at the very end.
    node_input = nodes.NodeInput('input')
    body_graph.add_node(node_input)

    # Loop counter and loop condition as seen by the body graph.
    body_counter_value = values.NumberValue(None)
    body_counter_value.dtype = np.array(0).dtype
    body_counter_value.name = 'for_counter_' + str(for_guid)

    body_cond_value = values.BoolValue(None)
    body_cond_value.name = 'for_cond_' + str(for_guid)

    # Node that looks up the current element from the sequence.
    node_forgen = nodes.NodeForGenerator(body_counter_value, body_iter_value)

    # Ask the iterated value for its element reference; fall back to an
    # unknown value when it cannot provide one.
    target_ref = input_iter_value.get_iterator()
    if target_ref is None:
        target_ref = values.ValueRef(values.UnknownValue())
        if config.show_warnings:
            print('unknown iteratable type in L.{}'.format(astc.lineno))
    # NOTE(review): target_value is not used anywhere below.
    target_value = target_ref.get_value()

    node_forgen.set_outputs([target_ref.get_value()])

    # Bind the loop target name to the element reference.
    target_attribute = local_field.get_attribute(target_name)
    target_attribute.revise(target_ref)
    body_graph.add_node(node_forgen)

    # Evaluate the loop body into the body graph (the result is not used).
    body = veval_ast(astc.c(astc.nast.body), local_field, body_graph)

    # Collect what was recorded for the body before leaving its history.
    value_inputs = values.get_inputs()
    value_outputs = values.get_outputs()

    values.pop_history()

    inputs = []
    outputs = []
    node_input_outputs = []

    # The body graph always takes counter, condition and sequence as inputs.
    body_graph.add_input_value(body_counter_value)
    body_graph.add_input_value(body_cond_value)
    body_graph.add_input_value(body_iter_value)

    # The body graph always yields condition and sequence as outputs.
    body_graph.add_output_value(body_cond_value)
    body_graph.add_output_value(body_iter_value)

    # The first output of the loop node mirrors the iterated value.
    outputs.append(functions.generate_value_with_same_type(input_iter_value))

    # Group recorded inputs and outputs per (field id, name) so that each
    # entry describes one variable.
    value_pairs = {}
    for v in value_inputs:
        key = str(v.field.id) + '_' + v.name
        if not (key in value_pairs.keys()):
            value_pairs[key] = {}

        value_pairs[key]['field'] = v.field
        value_pairs[key]['name'] = v.name
        value_pairs[key]['input_value'] = v.input_value
        value_pairs[key]['input_body_value'] = v.value

    for v in value_outputs:
        key = str(v.field.id) + '_' + v.name
        if not (key in value_pairs.keys()):
            value_pairs[key] = {}

        value_pairs[key]['field'] = v.field
        value_pairs[key]['name'] = v.name
        value_pairs[key]['output_body_value'] = v.value

    for k, v in value_pairs.items():
        name = v['name']
        field = v['field']

        if 'input_body_value' in v:
            # The body reads this variable: feed the outer value in.
            inputs.append(v['input_value'])

            body_graph.add_input_value(v['input_body_value'])
        else:
            # Output-only variable: dummy values stand in as inputs of the
            # loop node and of the body graph.
            temp_value1 = functions.generate_value_with_same_type(
                v['output_body_value'],
                is_dummy_value=True,
                suffix_type=functions.SuffixType.Dummy)
            temp_value2 = functions.generate_value_with_same_type(
                v['output_body_value'], suffix_type=functions.SuffixType.Dummy)
            inputs.append(temp_value1)

            body_graph.add_input_value(temp_value2)
            node_input_outputs.append(temp_value2)

        if 'output_body_value' in v:
            # The body writes this variable: expose it and rebind the
            # attribute to a fresh outer value.
            body_graph.add_output_value(v['output_body_value'])
            output_value = functions.generate_value_with_same_type(
                v['output_body_value'])
            outputs.append(output_value)
            if field.get_attribute(name).has_obj():
                field.get_attribute(name).get_ref().revise(output_value)
            else:
                field.get_attribute(name).revise(values.ValueRef(output_value))
        else:
            # Input-only variable: pass the body input straight through so
            # that inputs and outputs stay paired.
            temp_value1 = v['input_body_value']
            temp_value2 = functions.generate_value_with_same_type(
                v['input_body_value'])
            body_graph.add_output_value(temp_value1)
            outputs.append(temp_value2)

    # Record the loop itself in the enclosing graph.
    node = nodes.NodeFor(input_iter_value, inputs, body_graph, astc.lineno)
    node.set_outputs(outputs)
    node_input.set_outputs(node_input_outputs)

    graph.add_node(node)

    return None
예제 #23
0
 def apply_to_object(self, obj: 'values.ValueRef'):
     """
     Bind a link function to ``obj`` and expose it as ``forward``.

     The created reference is kept in ``self.func`` and the ``forward``
     attribute of ``obj`` is revised to it.
     """
     link_func = values.FuncValue(ChainerLinkFunction(self), obj)
     self.func = values.ValueRef(link_func)

     forward_attr = obj.get_field().get_attribute('forward')
     forward_attr.revise(self.func)
예제 #24
0
 def __init__(self):
     """Set the function name to 'list' and declare its single ``value`` argument, defaulting to a none value."""
     super().__init__()
     self.name = 'list'

     default_value = values.ValueRef(values.NoneValue())
     self.args.add_arg('value', default_value)
예제 #25
0
def convert_model(model: 'chainer.Chain', args=None):
    """
    Evaluate ``model``'s ``forward`` on ``args`` and build the resulting graph.

    Global converter state is reset first. A default module is created from
    the Python module that defines ``model`` and populated with substitutes
    for ``chainer``, ``chainer.functions`` and ``numpy`` (each only when
    ``get_module_name`` reports a non-empty name for it), plus ``range`` and
    ``list``. Every element of ``args`` becomes a graph input named
    ``in_<index>`` whose concrete value is cleared, and ``forward`` is then
    evaluated through its ``vcall``.

    Args:
        model: model instance whose ``forward`` is converted.
        args: iterable of example inputs for ``forward``. ``None`` (the
            default) means no inputs.

    Returns:
        A tuple ``(value_args, outputs, graph)``: the input values, the list
        of output values (``None`` when ``forward`` produced no result) and
        the populated graph.
    """
    if args is None:
        args = []

    # reset values
    values.reset_field_and_attributes()
    utils.reset_guid()

    values.instance_converters.clear()

    def instance_converter(m, i):
        """Wrap builtin chainer links; return None for anything else."""
        if links_builtin.is_builtin_chainer_link(i):
            return links_builtin.ChainerLinkInstance(m, i)
        return None

    values.instance_converters.append(instance_converter)

    # generate default module
    default_module = values.Module(sys.modules[model.__module__])

    # chainer
    chainer_module_name = get_module_name(
        C, default_module.internal_module)

    if chainer_module_name != '':
        c_dict = values.ValueRef(values.ModuleValue())

        # a substitute of Variable
        c_variable = values.FuncValue(functions_ndarray.NDArrayFunction(), None)
        c_dict.get_field().get_attribute('Variable').revise(values.ValueRef(c_variable))

        default_module.set_default_value(chainer_module_name, c_dict)

    # chainer.functions
    chainer_functions_module_name = get_module_name(
        F, default_module.internal_module)

    if chainer_functions_module_name != '':
        f_dict = values.ValueRef(values.ModuleValue())

        def add_chainer_funtion(name:'str', func, ret_value_func = None):
            """Register ``func`` as attribute ``name`` and as a converter."""
            if ret_value_func is None:
                f = values.FuncValue(
                    functions_builtin.ChainerFunction(func), None)
            else:
                f = values.FuncValue(
                    functions_builtin.ChainerFunction(func, ret_value_func=ret_value_func), None)
            f_dict.get_field().get_attribute(name).revise(values.ValueRef(f))

            values.function_converters[func] = f

        def ret_tuple():
            """Return the placeholder used for tuple-returning functions."""
            return values.TupleValue()

        add_chainer_funtion('relu', F.relu)
        add_chainer_funtion('softmax', F.softmax)
        add_chainer_funtion('softmax_cross_entropy', F.softmax_cross_entropy)
        add_chainer_funtion('pad_sequence', F.pad_sequence)
        add_chainer_funtion('average_pooling_2d', F.average_pooling_2d)
        add_chainer_funtion('unpooling_2d', F.unpooling_2d)
        add_chainer_funtion('reshape', F.reshape)
        add_chainer_funtion('split_axis', F.split_axis, ret_value_func=ret_tuple)
        add_chainer_funtion('swapaxes', F.swapaxes)
        add_chainer_funtion('dropout', F.dropout)
        add_chainer_funtion('concat', F.concat)
        add_chainer_funtion('matmul', F.matmul)
        add_chainer_funtion('max_pooling_2d', F.max_pooling_2d)
        add_chainer_funtion('resize_images', F.resize_images)

        # Parse the whole major component; looking only at the first
        # character would misread a two-digit major version.
        chainer_major_version = int(chainer.__version__.split('.')[0])
        if chainer_major_version >= 6:
            add_chainer_funtion('roi_max_pooling_2d', F.roi_max_pooling_2d)
            add_chainer_funtion('roi_average_pooling_2d', F.roi_average_pooling_2d)
            add_chainer_funtion('roi_max_align_2d', F.roi_max_align_2d)

        # This one is registered regardless of the version check above, so
        # guard the attribute access instead of failing on a Chainer build
        # that does not provide it.
        if hasattr(F, 'roi_average_align_2d'):
            add_chainer_funtion('roi_average_align_2d', F.roi_average_align_2d)

        default_module.set_default_value(chainer_functions_module_name, f_dict)

    # numpy
    numpy_module_name = get_module_name(np, default_module.internal_module)
    if numpy_module_name != '':
        f_dict = values.ValueRef(values.ModuleValue())

        f_array = values.FuncValue(functions_ndarray.NDArrayFunction(), None)
        f_dict.get_field().get_attribute('array').revise(values.ValueRef(f_array))

        f_zeros = values.FuncValue(functions_ndarray.NDArrayZerosFunction(), None)
        f_dict.get_field().get_attribute('zeros').revise(values.ValueRef(f_zeros))

        f_full = values.FuncValue(functions_ndarray.NDArrayFullFunction(), None)
        f_dict.get_field().get_attribute('full').revise(values.ValueRef(f_full))

        f_ceil = values.FuncValue(functions_ndarray.NDArrayCeilFunction(), None)
        f_dict.get_field().get_attribute('ceil').revise(values.ValueRef(f_ceil))

        # dtypes are exposed as numbers encoded by utils.numpy_type_2_int
        f_dict.get_field().get_attribute('int32').revise(
            values.ValueRef(values.NumberValue(utils.numpy_type_2_int(np.int32))))
        f_dict.get_field().get_attribute('float32').revise(
            values.ValueRef(values.NumberValue(utils.numpy_type_2_int(np.float32))))

        default_module.set_default_value(numpy_module_name, f_dict)

    m_range = values.FuncValue(functions_builtin.RangeFunction(), None)
    default_module.set_default_value('range', values.ValueRef(m_range))

    m_list = values.FuncValue(functions_builtin.ListFunction(), None)
    default_module.set_default_value('list', values.ValueRef(m_list))

    model_inst = values.parse_instance(default_module, '', model)
    forward_func = model_inst.try_get_and_store_obj('forward')

    # convert args
    finput = functions.FunctionArgInput()

    value_args = []

    node_input = nodes.NodeInput('input')

    for ind, arg in enumerate(args):
        varg = values.parse_instance(default_module, '', arg, None, True)
        input_name = 'in_' + str(ind)
        varg.name = input_name
        varg.get_value().name = input_name

        # make value unknown: the graph must not depend on the example data
        varg.get_value().internal_value = None

        finput.inputs.append(varg)
        value_args.append(varg.get_value())

    node_input.set_outputs(value_args)

    graph = Graph()
    graph.add_node(node_input)

    forward_func_value = forward_func.get_value()
    ret = forward_func_value.func.vcall(
        default_module, graph, forward_func_value.obj, finput)
    assert(ret is None or isinstance(ret, values.ValueRef))

    def try_get_value(value) -> 'values.Value':
        """Unwrap a Value / ValueRef / Attribute; None for anything else."""
        if isinstance(value, values.Value):
            return value

        if isinstance(value, values.ValueRef):
            return value.get_value()

        if isinstance(value, values.Attribute):
            return value.get_ref().get_value()

        return None

    # NOTE(review): ret is asserted to be a ValueRef, so the NoneValue test
    # on the reference itself looks like it can never match -- confirm
    # whether ret.get_value() was meant.
    if ret is None or isinstance(ret, values.NoneValue):
        if config.show_warnings:
            print('Failed to compile. output is None.')
        return (value_args, None, graph)

    ret_value = ret.get_value()
    if (isinstance(ret_value, values.TupleValue)
            and ret_value.internal_value is not None):
        # A tuple with known elements yields one output per element.
        ret_ = []
        for v in ret_value.internal_value:
            assert(v is not None)
            ret_.append(try_get_value(v))
    else:
        ret_ = [ret_value]

    for v in value_args:
        graph.add_input_value(v)

    for v in ret_:
        graph.add_output_value(v)

    return (value_args, ret_, graph)
예제 #26
0
def veval_ast_if(astc: 'AstContext', local_field: 'values.Field',
                 graph: 'Graph'):
    """
    Evaluate an ``if`` statement and record it as a ``NodeIf`` in ``graph``.

    The condition is evaluated in ``graph``. The body and the ``orelse``
    part are each evaluated into their own graph ('True' / 'False'), each
    inside its own history scope. The values read and written by the two
    branches (as reported by ``values.get_inputs()`` and
    ``values.get_outputs()``) are merged per variable and become the inputs
    and outputs of the ``NodeIf`` and of both branch graphs.

    Returns ``None``.
    """
    assert (isinstance(astc.nast, gast.gast.If))
    lineprop = utils.LineProperty(astc.lineno)

    # if condition
    test = veval_ast(astc.c(astc.nast.test), local_field, graph)
    test_value = try_get_value(test, 'if', lineprop)

    id_str = str(utils.get_guid())
    # NOTE(review): if_id is not used anywhere below.
    if_id = 'if_' + id_str
    true_id = 'true_' + id_str
    false_id = 'false_' + id_str

    # True branch: evaluate the body into its own graph and collect what was
    # recorded before leaving its history scope.
    values.push_history(true_id)

    true_graph = Graph()
    true_graph.root_graph = graph.root_graph
    true_graph.name = 'True'
    body = veval_ast(astc.c(astc.nast.body), local_field, true_graph)

    true_value_inputs = values.get_inputs()
    true_value_outputs = values.get_outputs()

    values.pop_history()

    # False branch: same procedure for the orelse part.
    values.push_history(false_id)

    false_graph = Graph()
    false_graph.root_graph = graph.root_graph
    false_graph.name = 'False'
    body = veval_ast(astc.c(astc.nast.orelse), local_field, false_graph)

    false_value_inputs = values.get_inputs()
    false_value_outputs = values.get_outputs()

    values.pop_history()

    # Group everything recorded by both branches per (field id, name) so
    # that each entry describes one variable.
    value_pairs = {}
    for v in true_value_inputs:
        key = str(v.field.id) + '_' + v.name
        if not (key in value_pairs.keys()):
            value_pairs[key] = {}

        value_pairs[key]['field'] = v.field
        value_pairs[key]['name'] = v.name
        value_pairs[key]['true_input_value'] = v.input_value
        value_pairs[key]['true_input_body_value'] = v.value

    for v in true_value_outputs:
        key = str(v.field.id) + '_' + v.name
        if not (key in value_pairs.keys()):
            value_pairs[key] = {}

        value_pairs[key]['field'] = v.field
        value_pairs[key]['name'] = v.name
        value_pairs[key]['true_output_body_value'] = v.value

    for v in false_value_inputs:
        key = str(v.field.id) + '_' + v.name
        if not (key in value_pairs.keys()):
            value_pairs[key] = {}

        value_pairs[key]['field'] = v.field
        value_pairs[key]['name'] = v.name
        value_pairs[key]['false_input_value'] = v.input_value
        value_pairs[key]['false_input_body_value'] = v.value

    for v in false_value_outputs:
        key = str(v.field.id) + '_' + v.name
        if not (key in value_pairs.keys()):
            value_pairs[key] = {}

        value_pairs[key]['field'] = v.field
        value_pairs[key]['name'] = v.name
        value_pairs[key]['false_output_body_value'] = v.value

    inputs = []
    outputs = []

    for k, v in value_pairs.items():
        name = v['name']
        field = v['field']

        input_value = None
        true_input_body_value = None
        false_input_body_value = None

        # The outer input is taken from whichever branch read the variable,
        # preferring the true branch.
        if 'true_input_value' in v:
            input_value = v['true_input_value']
        elif 'false_input_value' in v:
            input_value = v['false_input_value']

        if input_value is not None:
            # A branch that did not read the variable still gets a
            # placeholder input so both graphs take the same inputs.
            if 'true_input_body_value' in v:
                true_input_body_value = v['true_input_body_value']
            else:
                true_input_body_value = functions.generate_value_with_same_type(
                    input_value)

            if 'false_input_body_value' in v:
                false_input_body_value = v['false_input_body_value']
            else:
                false_input_body_value = functions.generate_value_with_same_type(
                    input_value)

        true_output_body_value = None
        false_output_body_value = None
        output_value = None

        # A branch that did not write the variable passes its input through.
        if 'true_output_body_value' in v:
            true_output_body_value = v['true_output_body_value']
        else:
            true_output_body_value = true_input_body_value

        if 'false_output_body_value' in v:
            false_output_body_value = v['false_output_body_value']
        else:
            false_output_body_value = false_input_body_value

        # TODO check types between true and false

        # NOTE(review): the outer output is always derived from the true
        # branch value, which is None when only the false branch produced
        # one -- confirm generate_value_with_same_type tolerates None.
        if true_output_body_value is not None or false_output_body_value is not None:
            output_value = functions.generate_value_with_same_type(
                true_output_body_value)

        if input_value is not None:
            inputs.append(input_value)

            true_graph.add_input_value(true_input_body_value)
            false_graph.add_input_value(false_input_body_value)

        if output_value is not None:
            outputs.append(output_value)
            true_graph.add_output_value(true_output_body_value)
            false_graph.add_output_value(false_output_body_value)

            # Rebind the variable to the merged output of the if node.
            if field.get_attribute(name).has_obj():
                field.get_attribute(name).get_ref().revise(output_value)
            else:
                field.get_attribute(name).revise(values.ValueRef(output_value))

    # Record the conditional itself in the enclosing graph.
    node = nodes.NodeIf(test_value, inputs, true_graph, false_graph,
                        astc.lineno)
    node.set_outputs(outputs)

    graph.add_node(node)

    return None
예제 #27
0
def veval_ast_listcomp(astc: 'AstContext', local_field: 'values.Field',
                       graph: 'Graph'):
    """
    Evaluate a list comprehension and add it to ``graph`` as a listcomp node.

    Ex. [x for x in xx]
    [elt for target in iter]

    Only the first generator (``astc.nast.generators[0]``) is evaluated and
    its ``ifs`` are not consulted here. The generator target must be a plain
    name; any other target is reported (when warnings are enabled) and the
    function returns None.

    The comprehension body is evaluated into a separate body graph while a
    value history is active, so that the values read and written inside the
    body can be wired up as inputs/outputs of the resulting node.

    Args:
        astc: AST context whose ``nast`` is a ``gast`` ListComp node.
        local_field: field used to resolve and revise local attributes.
        graph: graph that receives the list-generation and listcomp nodes.

    Returns:
        The ref stored in the internal list attribute of ``local_field``,
        or None when the generator target is not supported.
    """
    assert (isinstance(astc.nast, gast.gast.ListComp))
    lineprop = utils.LineProperty(astc.lineno)

    # every internal name shares one guid so that nested/sibling
    # comprehensions do not collide
    listcomp_guid = str(utils.get_guid())
    listcomp_id = 'listcomp_' + listcomp_guid
    internal_counter_id = '@internal/listcomp_counter_' + listcomp_guid
    internal_list_id = '@internal/listcomp_list_' + listcomp_guid
    internal_cond_id = '@internal/listcomp_cond_' + listcomp_guid

    generator = astc.nast.generators[0]
    iter_value = try_get_value(
        veval_ast(astc.c(generator.iter), local_field, graph), 'generator',
        lineprop)
    list_value = values.ListValue()
    list_obj = values.ValueRef(list_value)

    node_generate_list = nodes.NodeGenerate('List', [], lineprop)
    node_generate_list.set_outputs([list_value])
    graph.add_node(node_generate_list)

    # body
    if not isinstance(generator.target, gast.gast.Name):
        if config.show_warnings:
            print('This for is not supported. in L.{}'.format(astc.lineno))
        return None
    target_name = generator.target.id

    counter_value = values.NumberValue(None)
    counter_value.dtype = np.array(0).dtype
    counter_value.name = internal_counter_id

    cond_value = values.BoolValue(None)
    cond_value.name = internal_cond_id

    # set values with internal name
    local_field.get_attribute(internal_list_id).revise(list_obj)

    # everything between push_history and pop_history is recorded and later
    # retrieved through values.get_inputs() / values.get_outputs()
    values.push_history(listcomp_id)

    body_graph = Graph()
    body_graph.root_graph = graph.root_graph
    body_graph.name = 'Body_' + listcomp_guid

    node_forgen = nodes.NodeForGenerator(counter_value, iter_value)

    target_ref = iter_value.get_iterator()
    if target_ref is None:
        # fall back to an unknown value so evaluation of the body can go on
        target_ref = values.ValueRef(values.UnknownValue())
        if config.show_warnings:
            print('unknown iteratable type in L.{}'.format(astc.lineno))
    target_value = target_ref.get_value()

    node_forgen.set_outputs([target_ref.get_value()])
    local_field.get_attribute(target_name,
                              from_module=False).revise(target_ref)

    body_graph.add_node(node_forgen)

    elt = veval_ast(astc.c(astc.nast.elt), local_field, body_graph)
    elt_obj = try_get_ref(elt, 'listcomp', lineprop)

    # append the evaluated element to the internal list inside the body
    finput = functions.FunctionArgInput()
    finput.inputs.append(elt_obj)

    append_value = local_field.get_attribute(internal_list_id).get_ref(
    ).get_field().get_attribute('append').get_ref().get_value()
    append_value.func.vcall(
        local_field.module, body_graph,
        local_field.get_attribute(internal_list_id).get_ref(), finput,
        lineprop)

    value_inputs = values.get_inputs()
    value_outputs = values.get_outputs()

    values.pop_history()

    inputs = []
    outputs = []

    # default input for subgraph's input
    body_graph.add_input_value(counter_value)
    body_graph.add_input_value(cond_value)
    body_graph.add_input_value(iter_value)

    # default output for subgraph's output
    body_graph.add_output_value(cond_value)
    body_graph.add_output_value(iter_value)

    # default output
    outputs.append(functions.generate_value_with_same_type(iter_value))

    # generate pairs: one entry per (field, attribute name) touched in the body
    value_pairs = {}
    for v in value_inputs:
        pair = value_pairs.setdefault(str(v.field.id) + '_' + v.name, {})
        pair['field'] = v.field
        pair['name'] = v.name
        pair['input_value'] = v.input_value
        pair['input_body_value'] = v.value

    for v in value_outputs:
        pair = value_pairs.setdefault(str(v.field.id) + '_' + v.name, {})
        pair['field'] = v.field
        pair['name'] = v.name
        pair['output_body_value'] = v.value

    # remove iterator; tolerate the case where it was never recorded instead
    # of failing with KeyError
    removed_name = str(local_field.id) + '_' + target_value.name
    value_pairs.pop(removed_name, None)

    for v in value_pairs.values():
        name = v['name']
        field = v['field']

        if 'input_body_value' in v:
            inputs.append(v['input_value'])
            body_graph.add_input_value(v['input_body_value'])

        else:
            # written but never read in the body: create placeholder inputs
            # of the same type as the written value
            temp_value1 = functions.generate_value_with_same_type(
                v['output_body_value'])
            temp_value2 = functions.generate_value_with_same_type(
                v['output_body_value'])
            inputs.append(temp_value1)
            body_graph.add_input_value(temp_value2)

        if 'output_body_value' in v:
            body_graph.add_output_value(v['output_body_value'])
            output_value = functions.generate_value_with_same_type(
                v['output_body_value'])
            outputs.append(output_value)
            # make the attribute refer to the value produced by the node
            if field.get_attribute(name).has_obj():
                field.get_attribute(name).get_ref().revise(output_value)
            else:
                field.get_attribute(name).revise(values.ValueRef(output_value))
        else:
            # read but never written in the body: pass the input through
            temp_value1 = v['input_body_value']
            temp_value2 = functions.generate_value_with_same_type(
                v['input_body_value'])
            body_graph.add_output_value(temp_value1)
            outputs.append(temp_value2)

    node = nodes.NodeListcomp(iter_value, inputs, body_graph, astc.lineno)
    node.set_outputs(outputs)

    graph.add_node(node)

    return local_field.get_attribute(internal_list_id).get_ref()
예제 #28
0
def veval_ast_subscript(astc: 'AstContext', local_field: 'values.Field',
                        graph: 'Graph'):
    """
    Evaluate a subscript expression and add the matching node to ``graph``.

    Ex. x[1], x[y,z]

    Three kinds of ``astc.nast.slice`` are handled:

    * ``Index``: a ``NodeGetItem`` is added (or a ``NodeInvalid`` when a
      tuple index has no constant value) and a fresh ``values.Value`` is
      returned.
    * ``Slice``: a ``NodeSlice`` is added and a value of the same type as
      the subscripted value is returned.
    * ``ExtSlice``: every dimension is evaluated and a single ``NodeSlice``
      is added; the return value is built as for ``Slice``.

    Args:
        astc: AST context whose ``nast`` is a ``gast`` Subscript node.
        local_field: field used while evaluating sub-expressions.
        graph: graph that receives the generated node.

    Returns:
        A ``values.ValueRef`` wrapping the node's output, or None when the
        slice is none of the kinds above.

    Raises:
        AssertionError: if an ``ExtSlice`` dimension is neither ``Index``
            nor ``Slice``.
    """
    assert (isinstance(astc.nast, gast.gast.Subscript))
    lineprop = utils.LineProperty(astc.lineno)

    def veval_with_default(nast, default_value):
        # An omitted bound is represented by a constant number value
        if nast is None:
            ret = values.NumberValue(default_value)
            ret.name = '@SliceDefault'
            return ret
        obj = veval_ast(astc.c(nast), local_field, graph)
        return try_get_value(obj, 'subscript', lineprop)

    def get_slice_indices(slice_node):
        # A bare ``:`` contributes no indices; otherwise lower and upper are
        # always emitted and step only when it was written
        if (slice_node.lower is None and slice_node.upper is None
                and slice_node.step is None):
            return []
        indices = [
            veval_with_default(slice_node.lower, 0),
            veval_with_default(slice_node.upper, utils.slice_int_max)
        ]
        if slice_node.step is not None:
            indices.append(veval_with_default(slice_node.step, 1))
        return indices

    value = veval_ast(astc.c(astc.nast.value), local_field, graph)
    value_value = try_get_value(value, 'subscript', lineprop)

    if isinstance(astc.nast.slice, gast.gast.Index):
        slice_ = veval_ast(astc.c(astc.nast.slice.value), local_field, graph)
        slice_value = try_get_value(slice_, 'subscript', lineprop)

        if isinstance(slice_value, values.TupleValue):
            # ex. x[1,2]
            if slice_value.has_constant_value():
                values_ = [
                    try_get_value(x, 'subscript', lineprop)
                    for x in slice_value.get_constant_value()
                ]
                node = nodes.NodeGetItem(value_value, values_, line=lineprop)
            else:
                if config.show_warnings:
                    print('This subscript is not supported. in L.{}'.format(
                        astc.lineno))
                node = nodes.NodeInvalid(line=lineprop)
        else:
            # ex. x[1]
            # pass the line property like the tuple-index branch does
            node = nodes.NodeGetItem(value_value, [slice_value],
                                     line=lineprop)
        ret_value = values.Value()
        node.set_outputs([ret_value])
        graph.add_node(node)
        return values.ValueRef(ret_value)

    elif isinstance(astc.nast.slice, gast.gast.Slice):

        indices = get_slice_indices(astc.nast.slice)

        # slice_specs holds, per dimension, how many entries of ``indices``
        # belong to it
        node = nodes.NodeSlice(value_value, indices, [len(indices)])
        ret_value = functions.generate_value_with_same_type(value_value)
        node.set_outputs([ret_value])
        graph.add_node(node)
        return values.ValueRef(ret_value)

    elif isinstance(astc.nast.slice, gast.gast.ExtSlice):
        indices = []
        slice_specs = []
        for dim in astc.nast.slice.dims:
            if isinstance(dim, gast.gast.Index):
                indices.append(
                    try_get_value(
                        veval_ast(astc.c(dim.value), local_field, graph),
                        'subscript', lineprop))
                slice_specs.append(1)
            elif isinstance(dim, gast.gast.Slice):
                ni = get_slice_indices(dim)
                indices.extend(ni)
                slice_specs.append(len(ni))
            else:
                # raised explicitly so the check is not stripped under -O
                raise AssertionError('Unknown slice: %s in %s' %
                                     (dim, astc.nast.slice))

        node = nodes.NodeSlice(value_value, indices, slice_specs)
        ret_value = functions.generate_value_with_same_type(value_value)
        node.set_outputs([ret_value])
        graph.add_node(node)
        return values.ValueRef(ret_value)

    return None