# Example #1
# 0
def veval_ast_listcomp(astc: 'AstContext', local_field: 'values.Field',
                       graph: 'Graph'):
    '''
    Evaluate a list comprehension and add a NodeListcomp to ``graph``.

    Ex. [x for x in xx]
    [elt for target in iter]

    Only the first generator is evaluated. Additional generators and ``if``
    clauses are not supported; a warning is printed when
    ``config.show_warnings`` is set.

    Args:
        astc: AST context whose ``nast`` is a ``gast.ListComp``.
        local_field: field used to look up and revise variables.
        graph: graph that receives the generated nodes.

    Returns:
        The reference to the generated list. Returns None when the target of
        the generator is not a plain name.
    '''
    assert (isinstance(astc.nast, gast.gast.ListComp))
    lineprop = utils.LineProperty(astc.lineno, astc.filename)

    listcomp_guid = str(utils.get_guid())
    listcomp_id = 'listcomp_' + listcomp_guid
    internal_counter_id = '@internal/listcomp_counter_' + listcomp_guid
    internal_list_id = '@internal/listcomp_list_' + listcomp_guid
    internal_cond_id = '@internal/listcomp_cond_' + listcomp_guid

    generator = astc.nast.generators[0]
    if (len(astc.nast.generators) > 1 or generator.ifs) and config.show_warnings:
        print('Only a single generator without if-clauses is supported in '
              'listcomp. in L.{}'.format(astc.lineno))

    iter_value = try_get_value(
        veval_ast(astc.c(generator.iter), local_field, graph), 'generator',
        lineprop)

    # Only a plain name is supported as the target. Bail out before the list
    # node is added so that no orphan node is left in graph.
    if not isinstance(generator.target, gast.gast.Name):
        if config.show_warnings:
            print('This for is not supported. in L.{}'.format(astc.lineno))
        return None
    target_name = generator.target.id

    list_value = values.ListValue()
    list_obj = values.ValueRef(list_value)

    node_generate_list = nodes.NodeGenerate('List', [], lineprop)
    node_generate_list.set_outputs([list_value])
    graph.add_node(node_generate_list)

    # body
    counter_value = values.NumberValue(None)
    counter_value.dtype = np.array(0).dtype
    counter_value.name = internal_counter_id

    cond_value = values.BoolValue(None)
    cond_value.name = internal_cond_id

    # Store the list under an internal name so the body can append to it.
    local_field.get_attribute(internal_list_id).revise(list_obj)

    values.push_history(listcomp_id)

    body_graph = Graph()
    body_graph.root_graph = graph.root_graph
    body_graph.name = 'Body_' + listcomp_guid

    node_forgen = nodes.NodeForGenerator(counter_value, iter_value)

    target_ref = iter_value.get_iterator()
    if target_ref is None:
        target_ref = values.ValueRef(values.UnknownValue())
        if config.show_warnings:
            print('unknown iteratable type in L.{}'.format(astc.lineno))
    target_value = target_ref.get_value()

    node_forgen.set_outputs([target_value])
    local_field.get_attribute(target_name).revise(target_ref)

    body_graph.add_node(node_forgen)

    elt = veval_ast(astc.c(astc.nast.elt), local_field, body_graph)
    elt_obj = try_get_ref(elt, 'listcomp', lineprop)

    finput = functions.FunctionArgInput()
    finput.inputs.append(elt_obj)

    # Emit list.append(elt) into the body graph.
    append_value = local_field.get_attribute(internal_list_id).get_ref(
    ).get_field().get_attribute('append').get_ref().get_value()
    append_value.func.vcall(
        None, body_graph,
        local_field.get_attribute(internal_list_id).get_ref(), finput,
        lineprop)

    value_inputs = values.get_inputs()
    value_outputs = values.get_outputs()

    values.pop_history()

    inputs = []
    outputs = []

    # default input for subgraph's input
    body_graph.add_input_value(counter_value)
    body_graph.add_input_value(cond_value)
    body_graph.add_input_value(iter_value)

    # default output for subgraph's output
    body_graph.add_output_value(cond_value)
    body_graph.add_output_value(iter_value)

    # default output
    outputs.append(functions.generate_value_with_same_type(iter_value))

    # Pair the values read and written in the body by "<field id>_<name>".
    value_pairs = {}
    for v in value_inputs:
        pair = value_pairs.setdefault(str(v.field.id) + '_' + v.name, {})
        pair['field'] = v.field
        pair['name'] = v.name
        pair['input_value'] = v.input_value
        pair['input_body_value'] = v.value

    for v in value_outputs:
        pair = value_pairs.setdefault(str(v.field.id) + '_' + v.name, {})
        pair['field'] = v.field
        pair['name'] = v.name
        pair['output_body_value'] = v.value
        pair['output_obj'] = v.obj

    # NodeForGenerator produces the iterator, so it is not carried as a loop
    # value. The entry may be absent, so do not raise if it is missing.
    value_pairs.pop(str(local_field.id) + '_' + target_value.name, None)

    for v in value_pairs.values():
        name = v['name']
        field = v['field']

        if 'input_body_value' in v:
            inputs.append(v['input_value'])
            body_graph.add_input_value(v['input_body_value'])
        else:
            # The value is only written in the body, so feed placeholders of
            # the same type as the loop-carried input.
            temp_value1 = functions.generate_value_with_same_type(
                v['output_body_value'])
            temp_value2 = functions.generate_value_with_same_type(
                v['output_body_value'])
            inputs.append(temp_value1)
            body_graph.add_input_value(temp_value2)

        if 'output_body_value' in v:
            body_graph.add_output_value(v['output_body_value'])
            output_value = functions.generate_value_with_same_type(
                v['output_body_value'])
            outputs.append(output_value)
            if 'output_obj' in v:
                obj = v['output_obj']
                obj.revise(output_value)
                field.get_attribute(name).revise(obj)
            elif field.get_attribute(name).has_obj():
                field.get_attribute(name).get_ref().revise(output_value)
            else:
                field.get_attribute(name).revise(values.ValueRef(output_value))
        else:
            # Read-only in the body: pass the input through unchanged.
            temp_value1 = v['input_body_value']
            temp_value2 = functions.generate_value_with_same_type(
                v['input_body_value'])
            body_graph.add_output_value(temp_value1)
            outputs.append(temp_value2)

    node = nodes.NodeListcomp(iter_value, inputs, body_graph, astc.lineno)
    node.set_outputs(outputs)

    graph.add_node(node)

    return local_field.get_attribute(internal_list_id).get_ref()
# Example #2
# 0
def veval_ast_if(astc: 'AstContext', local_field: 'values.Field',
                 graph: 'Graph'):
    '''
    Evaluate an ``if`` statement and add a NodeIf to ``graph``.

    Each branch is evaluated in its own history and subgraph. Values read or
    written in either branch are paired by ``<field id>_<name>``, so both
    subgraphs expose the same inputs and outputs. A branch that does not
    write a value either passes its input through or produces a dummy value
    of the same type.

    Args:
        astc: AST context whose ``nast`` is a ``gast.If``.
        local_field: field used to look up and revise variables.
        graph: graph that receives the NodeIf.

    Returns:
        Always None.
    '''
    assert (isinstance(astc.nast, gast.gast.If))
    lineprop = utils.LineProperty(astc.lineno, astc.filename)

    # if condition
    test = veval_ast(astc.c(astc.nast.test), local_field, graph)
    test_value = try_get_value(test, 'if', lineprop)

    id_str = str(utils.get_guid())
    true_id = 'true_' + id_str
    false_id = 'false_' + id_str

    def eval_branch(history_id, graph_name, branch_body):
        # Evaluate one branch in a fresh history and subgraph. Return the
        # subgraph together with the inputs and outputs recorded in the history.
        values.push_history(history_id)

        branch_graph = Graph()
        branch_graph.root_graph = graph.root_graph
        branch_graph.name = graph_name
        veval_ast(astc.c(branch_body), local_field, branch_graph)

        branch_inputs = values.get_inputs()
        branch_outputs = values.get_outputs()

        values.pop_history()
        return branch_graph, branch_inputs, branch_outputs

    true_graph, true_value_inputs, true_value_outputs = eval_branch(
        true_id, 'True', astc.nast.body)
    false_graph, false_value_inputs, false_value_outputs = eval_branch(
        false_id, 'False', astc.nast.orelse)

    # generate pairs
    value_pairs = {}

    def collect(entries, prefix, is_input):
        # Merge the history entries of one branch into value_pairs.
        for v in entries:
            pair = value_pairs.setdefault(str(v.field.id) + '_' + v.name, {})
            pair['field'] = v.field
            pair['name'] = v.name
            if is_input:
                pair[prefix + '_input_value'] = v.input_value
                pair[prefix + '_input_body_value'] = v.value
                pair[prefix + '_input_obj'] = v.obj
            else:
                pair[prefix + '_output_body_value'] = v.value
                pair[prefix + '_output_obj'] = v.obj

    collect(true_value_inputs, 'true', True)
    collect(true_value_outputs, 'true', False)
    collect(false_value_inputs, 'false', True)
    collect(false_value_outputs, 'false', False)

    inputs = []
    outputs = []

    for k, v in value_pairs.items():
        name = v['name']
        field = v['field']

        # search input value
        input_value = None
        if 'true_input_value' in v:
            input_value = v['true_input_value']
        elif 'false_input_value' in v:
            input_value = v['false_input_value']

        true_input_body_value = None
        false_input_body_value = None
        if input_value is not None:
            if 'true_input_body_value' in v:
                true_input_body_value = v['true_input_body_value']
            else:
                true_input_body_value = functions.generate_value_with_same_type(
                    input_value)

            if 'false_input_body_value' in v:
                false_input_body_value = v['false_input_body_value']
            else:
                false_input_body_value = functions.generate_value_with_same_type(
                    input_value)

        # search output value
        true_output_body_value = v.get('true_output_body_value')
        false_output_body_value = v.get('false_output_body_value')

        if true_output_body_value is not None or false_output_body_value is not None:
            if true_output_body_value is None:
                if true_input_body_value is not None:
                    # e.g. not changed in the true branch
                    true_output_body_value = true_input_body_value
                else:
                    # e.g. the value is made only in the false branch
                    true_output_body_value = functions.generate_value_with_same_type(
                        false_output_body_value, is_dummy_value=True)

            if false_output_body_value is None:
                if false_input_body_value is not None:
                    # e.g. not changed in the false branch
                    false_output_body_value = false_input_body_value
                else:
                    # e.g. the value is made only in the true branch
                    false_output_body_value = functions.generate_value_with_same_type(
                        true_output_body_value, is_dummy_value=True)

        # check types between true and false
        true_output_body_value_type = None
        false_output_body_value_type = None

        if (true_output_body_value is not None
                and true_output_body_value.is_not_none_or_any_value()):
            true_output_body_value_type = true_output_body_value

        if (false_output_body_value is not None
                and false_output_body_value.is_not_none_or_any_value()):
            false_output_body_value_type = false_output_body_value

        if (true_output_body_value_type is not None
                and false_output_body_value_type is not None
                and type(true_output_body_value_type) is not type(
                    false_output_body_value_type)):
            utils.print_warning(
                'Values with different types were generated {} between true and false'
                .format(k), lineprop)

        # Prefer a concretely typed branch value as the template for the output.
        output_value = None
        if true_output_body_value_type is not None:
            output_value = functions.generate_value_with_same_type(
                true_output_body_value_type)
        elif false_output_body_value_type is not None:
            output_value = functions.generate_value_with_same_type(
                false_output_body_value_type)
        elif true_output_body_value is not None:
            output_value = functions.generate_value_with_same_type(
                true_output_body_value)
        elif false_output_body_value is not None:
            output_value = functions.generate_value_with_same_type(
                false_output_body_value)

        if input_value is not None:
            inputs.append(input_value)

            true_graph.add_input_value(true_input_body_value)
            false_graph.add_input_value(false_input_body_value)

        if output_value is not None:
            outputs.append(output_value)
            true_graph.add_output_value(true_output_body_value)
            false_graph.add_output_value(false_output_body_value)

            # An output exists only if at least one branch wrote the value.
            has_true_obj = 'true_output_obj' in v
            has_false_obj = 'false_output_obj' in v
            assert has_true_obj or has_false_obj
            if has_true_obj and has_false_obj:
                # Both branches wrote it: there is no single object to reuse.
                obj = None
            elif has_true_obj:
                obj = v['true_output_obj']
            else:
                obj = v['false_output_obj']

            if obj is not None:
                obj.revise(output_value)
                field.get_attribute(name).revise(obj)
            elif field.get_attribute(name).has_obj():
                field.get_attribute(name).get_ref().revise(output_value)
            else:
                field.get_attribute(name).revise(values.ValueRef(output_value))

    node = nodes.NodeIf(test_value, inputs, true_graph, false_graph,
                        astc.lineno)
    node.set_outputs(outputs)

    graph.add_node(node)

    return None
# Example #3
# 0
def convert_model(model: 'chainer.Chain', args=None):
    '''
    Convert a chainer model into a Graph by abstractly evaluating ``forward``.

    All converter registries are cleared and repopulated before evaluation.

    Args:
        model: chainer model whose ``forward`` method is evaluated.
        args: example arguments for ``forward``. Their ``internal_value`` is
            cleared, so their concrete contents are not used. Defaults to no
            arguments.

    Returns:
        ``(value_args, ret_, graph)``: the input values, the list of output
        values (None when ``forward`` returns None) and the generated graph.
    '''
    if args is None:
        args = []

    # reset values
    values.reset_field_and_attributes()
    utils.reset_guid()

    values.function_converters.clear()
    values.builtin_function_converters.clear()
    values.instance_converters.clear()

    def instance_converter(m, i):
        # Map chainer links and chains to their evaluator-side instances.
        if links_builtin.is_builtin_chainer_link(i):
            return links_builtin.ChainerLinkInstance(m, i)

        if isinstance(i, chainer.ChainList):
            module = values.Object(values.ModuleValue(sys.modules[i.__module__]))
            return links_builtin.ChainerChainListInstance(module, i)

        if isinstance(i, chainer.Link):
            module = values.Object(values.ModuleValue(sys.modules[i.__module__]))
            return links_builtin.ChainerChainInstance(module, i)

        return None

    values.instance_converters.append(instance_converter)

    custom_functions_module = values.Object(values.ModuleValue(custom_functions))

    # onnx
    functions_onnx_module = values.Object(values.ModuleValue(functions_onnx))

    def ret_same(funcArgs):
        return functions.generate_value_with_same_type(funcArgs.keywords['x'].get_value())

    values.function_converters[functions_onnx.onnx_abs] = values.FuncValue(
        functions_builtin.ChainerFunction(functions_onnx.onnx_abs, ret_value_func=ret_same),
        None, module=functions_onnx_module)

    # chainer
    c_variable = values.FuncValue(functions_ndarray.NDArrayFunction(), None)
    values.function_converters[chainer.Variable] = c_variable

    # chainer.functions
    def add_chainer_function(func, ret_value_func=None):
        if ret_value_func is None:
            chainer_func = functions_builtin.ChainerFunction(func)
        else:
            chainer_func = functions_builtin.ChainerFunction(
                func, ret_value_func=ret_value_func)
        values.function_converters[func] = values.FuncValue(chainer_func, None)

    def ret_tuple(funcArgs=None):
        ret = values.TupleValue()
        ret.vtype = values.TensorValue
        return ret

    # Register every function in chainer.functions as unimplemented first, so
    # that calling an unsupported one reports an error. Supported functions
    # overwrite these entries below.
    for member in F.__dict__.values():
        if inspect.isfunction(member):
            values.function_converters[member] = values.FuncValue(
                functions.UnimplementedFunction(member), None)

    # activation
    add_chainer_function(F.elu)
    add_chainer_function(F.leaky_relu)
    add_chainer_function(F.log_softmax)
    add_chainer_function(F.relu)
    add_chainer_function(F.selu)
    add_chainer_function(F.sigmoid)
    add_chainer_function(F.softmax)
    add_chainer_function(F.tanh)

    add_chainer_function(F.softmax_cross_entropy)
    add_chainer_function(F.pad_sequence)
    add_chainer_function(F.average_pooling_2d)
    add_chainer_function(F.unpooling_2d)
    add_chainer_function(F.reshape)
    add_chainer_function(F.transpose)
    add_chainer_function(F.split_axis, ret_value_func=ret_tuple)
    add_chainer_function(F.hstack)
    add_chainer_function(F.vstack)
    add_chainer_function(F.stack)
    add_chainer_function(F.separate, ret_value_func=ret_tuple)
    add_chainer_function(F.squeeze)
    add_chainer_function(F.swapaxes)
    add_chainer_function(F.dropout)
    add_chainer_function(F.concat)
    add_chainer_function(F.matmul)
    add_chainer_function(F.max_pooling_2d)
    add_chainer_function(F.resize_images)
    add_chainer_function(F.broadcast_to)
    add_chainer_function(F.expand_dims)
    add_chainer_function(F.local_response_normalization)
    add_chainer_function(F.mean)
    add_chainer_function(F.average)
    add_chainer_function(F.sum)
    add_chainer_function(F.maximum)
    add_chainer_function(F.minimum)
    add_chainer_function(F.max)
    add_chainer_function(F.min)

    values.function_converters[F.absolute] = values.FuncValue(functions.UserDefinedFunction(custom_functions.chainer_absolute), None, module=custom_functions_module)

    # math (F.tanh is already registered with the activations above)
    add_chainer_function(F.sin)
    add_chainer_function(F.sinh)
    add_chainer_function(F.sign)
    add_chainer_function(F.cos)
    add_chainer_function(F.cosh)
    add_chainer_function(F.tan)
    add_chainer_function(F.arcsin)
    add_chainer_function(F.arccos)
    add_chainer_function(F.arctan)
    add_chainer_function(F.exp)
    add_chainer_function(F.log)
    add_chainer_function(F.sqrt)

    add_chainer_function(F.clip)

    values.function_converters[F.argmax] = values.FuncValue(functions_builtin.ChainerArgminmaxFunction(F.argmax), None)
    values.function_converters[F.argmin] = values.FuncValue(functions_builtin.ChainerArgminmaxFunction(F.argmin), None)

    values.function_converters[F.clipped_relu] = values.FuncValue(functions.UserDefinedFunction(custom_functions.chainer_clipped_relu), None, module=custom_functions_module)

    # Compare the major version numerically. Taking only the first character
    # would misread e.g. '10.0.0' as major version 1.
    if int(chainer.__version__.split('.')[0]) >= 6:
        add_chainer_function(F.roi_max_pooling_2d)
        add_chainer_function(F.roi_average_pooling_2d)
        add_chainer_function(F.roi_max_align_2d)

    add_chainer_function(F.roi_average_align_2d)

    # numpy
    f_array = values.FuncValue(functions_ndarray.NDArrayFunction(), None)
    f_zeros = values.FuncValue(functions_ndarray.NDArrayZerosFunction(), None)
    f_full = values.FuncValue(functions_ndarray.NDArrayFullFunction(), None)
    f_ceil = values.FuncValue(functions_ndarray.NDArrayCeilFunction(), None)
    f_cumsum = values.FuncValue(functions_ndarray.NDArrayCumsumFunction(), None)
    f_maximum = values.FuncValue(functions_ndarray.NDArrayChainerFunction(functions_ndarray.dummy_maximum), None)
    f_minimum = values.FuncValue(functions_ndarray.NDArrayChainerFunction(functions_ndarray.dummy_minimum), None)
    f_argmax = values.FuncValue(functions_ndarray.NDarrayArgminmaxFunction(functions_ndarray.dummy_argmax), None)
    f_argmin = values.FuncValue(functions_ndarray.NDarrayArgminmaxFunction(functions_ndarray.dummy_argmin), None)
    f_round = values.FuncValue(functions_ndarray.NDarrayRoundFunction(functions_ndarray.dummy_round), None)
    f_sqrt = values.FuncValue(functions_ndarray.NDarraySqrtFunction(functions_ndarray.dummy_sqrt), None)
    f_stack = values.FuncValue(functions_ndarray.NDarrayStackFunction(functions_ndarray.dummy_stack), None)
    f_reshape = values.FuncValue(functions_ndarray.NDarrayReshapeFunction(functions_ndarray.dummy_reshape), None)
    f_transpose = values.FuncValue(functions_ndarray.NDarrayTransposeFunction(functions_ndarray.dummy_transpose), None)

    f_int32 = values.FuncValue(functions_ndarray.NDArrayInt32(), None)
    f_float32 = values.FuncValue(functions_ndarray.NDArrayFloat32(), None)

    values.function_converters[np.array] = f_array
    values.function_converters[np.zeros] = f_zeros
    values.function_converters[np.full] = f_full
    values.function_converters[np.ceil] = f_ceil
    values.function_converters[np.cumsum] = f_cumsum
    values.function_converters[np.int32] = f_int32
    values.function_converters[np.float32] = f_float32
    values.function_converters[np.maximum] = f_maximum
    values.function_converters[np.minimum] = f_minimum
    values.function_converters[np.argmax] = f_argmax
    values.function_converters[np.argmin] = f_argmin
    values.function_converters[np.round] = f_round
    values.function_converters[np.sqrt] = f_sqrt
    values.function_converters[np.stack] = f_stack
    values.function_converters[np.reshape] = f_reshape
    values.function_converters[np.transpose] = f_transpose

    values.function_converters[np.clip] = values.FuncValue(functions.UserDefinedFunction(custom_functions.numpy_clip), None, module=custom_functions_module)
    values.function_converters[np.absolute] = values.FuncValue(functions.UserDefinedFunction(custom_functions.numpy_absolute), None, module=custom_functions_module)

    values.function_converters[custom_functions.check_attribute_value] = values.FuncValue(functions.CheckAttributeValueFunction(), None, module=custom_functions_module)

    values.function_converters[custom_functions.check_attribute_scalar] = values.FuncValue(functions.CheckAttributeScalarFunction(), None, module=custom_functions_module)

    values.builtin_function_converters['abs'] = values.FuncValue(functions.UserDefinedFunction(custom_functions.builtin_absolute), None, module=custom_functions_module)

    m_range = values.FuncValue(functions_builtin.RangeFunction(), None)
    values.builtin_function_converters['range'] = m_range

    m_len = values.FuncValue(functions_builtin.LenFunction(), None)
    values.builtin_function_converters['len'] = m_len

    values.function_converters[six.moves.range] = m_range

    m_list = values.FuncValue(functions_builtin.ListFunction(), None)
    values.builtin_function_converters['list'] = m_list

    m_print = values.FuncValue(functions_builtin.PrintFunction(), None)
    values.builtin_function_converters['print'] = m_print

    m_getattr = values.FuncValue(functions_builtin.GetAttrFunction(), None)
    values.builtin_function_converters['getattr'] = m_getattr

    m_hasattr = values.FuncValue(functions_builtin.HasAttrFunction(), None)
    values.builtin_function_converters['hasattr'] = m_hasattr

    m_to_gpu = values.FuncValue(functions_builtin.CopyFunction(cuda.to_gpu), None)
    values.function_converters[cuda.to_gpu] = m_to_gpu

    m_to_cpu = values.FuncValue(functions_builtin.CopyFunction(cuda.to_cpu), None)
    values.function_converters[cuda.to_cpu] = m_to_cpu

    # generate VEvalFlag functions
    def add_veval_flag_function(name: 'str', func):
        f = values.FuncValue(functions_builtin.VEvalContextFunction(func), None)
        values.builtin_function_converters[name] = f

    add_veval_flag_function('eval_as_written_target', flags.eval_as_written_target)
    add_veval_flag_function('ignore_branch', flags.ignore_branch)
    add_veval_flag_function('for_unroll', flags.for_unroll)

    # generate default module
    default_module = values.Object(values.ModuleValue(sys.modules[model.__module__]))

    model_inst = values.parse_instance(default_module, '', model)
    forward_func = model_inst.try_get_and_store_obj('forward', None)

    # convert args
    finput = functions.FunctionArgInput()

    value_args = []

    node_input = nodes.NodeInput('input')

    for ind, arg in enumerate(args):
        varg = values.parse_instance(default_module, '', arg, None)
        varg.name = 'in_' + str(ind)
        varg.get_value().name = 'in_' + str(ind)

        # make value unknown
        varg.get_value().internal_value = None

        finput.inputs.append(varg)
        value_args.append(varg.get_value())

    node_input.set_outputs(value_args)

    graph = Graph()
    graph.root_graph = graph
    graph.add_node(node_input)

    forward_func_value = forward_func.get_value()
    ret = forward_func_value.func.vcall(
        default_module, graph, forward_func_value.obj, finput).get_obj()
    assert(ret is None or isinstance(ret, values.Object))

    def try_get_value(value) -> 'values.Value':
        if isinstance(value, values.Value):
            return value

        if isinstance(value, values.Object):
            return value.get_value()

        if isinstance(value, values.Attribute):
            return value.get_obj().get_value()

    # ret is an Object (asserted above), so check the value it wraps for None.
    if ret is None or isinstance(ret.get_value(), values.NoneValue):
        if config.show_warnings:
            print('Failed to compile. output is None.')
        return (value_args, None, graph)

    ret_value = ret.get_value()
    if isinstance(ret_value, values.TupleValue) and ret_value.internal_value is not None:
        # Expand a tuple with known elements into several graph outputs.
        ret_ = []
        for v in ret_value.internal_value:
            assert(v is not None)
            ret_.append(try_get_value(v))
    else:
        ret_ = [ret_value]

    for v in value_args:
        graph.add_input_value(v)

    for v in ret_:
        graph.add_output_value(v)

    return (value_args, ret_, graph)
# Example #4
# 0
def veval_ast_for(astc: 'AstContext', local_field: 'values.Field',
                  graph: 'Graph'):
    '''
    Evaluate a ``for`` statement and add a NodeFor to ``graph``.

    for target in iter:
        ...

    If the iterable is a ListValue with a constant value and no dtype, the
    loop is delegated to ``veval_ast_for_unroll``. Otherwise the body is
    evaluated once into a subgraph. Values read or written in the body become
    loop-carried inputs and outputs. The ``else`` clause of a ``for`` is not
    evaluated on this path; a warning is printed when
    ``config.show_warnings`` is set.

    Args:
        astc: AST context whose ``nast`` is a ``gast.For``.
        local_field: field used to look up and revise variables.
        graph: graph that receives the generated nodes.

    Returns:
        The result of ``veval_ast_for_unroll`` when the loop is unrolled.
        Otherwise None.
    '''
    assert (isinstance(astc.nast, gast.gast.For))
    lineprop = utils.LineProperty(astc.lineno, astc.filename)

    # for target in iter:
    iter_ = veval_ast(astc.c(astc.nast.iter), local_field, graph)
    input_iter_value = try_get_value(iter_, 'for', lineprop)
    body_iter_value = functions.generate_value_with_same_type(
        input_iter_value, suffix_type=functions.SuffixType.Input)

    # get target name
    if not isinstance(astc.nast.target, gast.gast.Name):
        if config.show_warnings:
            print('This for is not supported. in L.{}'.format(astc.lineno))
        return None
    target_name = astc.nast.target.id

    # unroll?
    if (isinstance(input_iter_value, values.ListValue)
            and input_iter_value.has_constant_value()
            and input_iter_value.dtype is None):
        return veval_ast_for_unroll(astc, target_name, input_iter_value,
                                    local_field, graph)

    if astc.nast.orelse and config.show_warnings:
        print('for-else is not supported (else clause is ignored). in L.{}'.format(
            astc.lineno))

    for_guid = utils.get_guid()
    for_id = 'for_' + str(for_guid)

    values.push_history(for_id)

    # body
    body_graph = Graph()
    body_graph.root_graph = graph.root_graph
    body_graph.name = 'Body_' + str(for_guid)

    # generate a node for input
    node_input = nodes.NodeInput('input')
    body_graph.add_node(node_input)

    body_counter_value = values.NumberValue(None)
    body_counter_value.dtype = np.array(0).dtype
    body_counter_value.name = 'for_counter_' + str(for_guid)

    body_cond_value = values.BoolValue(None)
    body_cond_value.name = 'for_cond_' + str(for_guid)

    # create a node to lookup a value from sequence
    node_forgen = nodes.NodeForGenerator(body_counter_value, body_iter_value)

    # generate iterator
    target_ref = input_iter_value.get_iterator()
    if target_ref is None:
        target_ref = values.ValueRef(values.UnknownValue())
        if config.show_warnings:
            print('unknown iteratable type in L.{}'.format(astc.lineno))
    target_value = target_ref.get_value()

    node_forgen.set_outputs([target_value])

    target_attribute = local_field.get_attribute(target_name)
    target_attribute.revise(target_ref)
    body_graph.add_node(node_forgen)

    # veval body
    veval_ast(astc.c(astc.nast.body), local_field, body_graph)

    value_inputs = values.get_inputs()
    value_outputs = values.get_outputs()

    # A '#keepgoing' attribute set in the body (presumably by break handling)
    # replaces the default condition as the loop's condition output.
    break_attribute = local_field.get_attribute('#keepgoing')
    if break_attribute.has_obj():
        break_attribute_ref = break_attribute.get_ref()
        break_attribute_value = break_attribute_ref.get_value()
    else:
        break_attribute_value = body_cond_value

    values.pop_history()

    inputs = []
    outputs = []
    node_input_outputs = []

    # default input for subgraph's input
    body_graph.add_input_value(body_counter_value)
    body_graph.add_input_value(body_cond_value)
    body_graph.add_input_value(body_iter_value)

    # default output for subgraph's output
    body_graph.add_output_value(break_attribute_value)
    body_graph.add_output_value(body_iter_value)

    # default output
    outputs.append(functions.generate_value_with_same_type(input_iter_value))

    # Pair the values read and written in the body by "<field id>_<name>".
    value_pairs = {}
    for v in value_inputs:
        pair = value_pairs.setdefault(str(v.field.id) + '_' + v.name, {})
        pair['field'] = v.field
        pair['name'] = v.name
        pair['input_value'] = v.input_value
        pair['input_body_value'] = v.value

    for v in value_outputs:
        pair = value_pairs.setdefault(str(v.field.id) + '_' + v.name, {})
        pair['field'] = v.field
        pair['name'] = v.name
        pair['output_body_value'] = v.value
        pair['output_obj'] = v.obj

    for v in value_pairs.values():
        name = v['name']
        field = v['field']

        if 'input_body_value' in v:
            inputs.append(v['input_value'])

            body_graph.add_input_value(v['input_body_value'])
        else:
            # The value is only written in the body. Feed a dummy value from
            # outside, and a placeholder that the body's input node produces.
            temp_value1 = functions.generate_value_with_same_type(
                v['output_body_value'],
                is_dummy_value=True,
                suffix_type=functions.SuffixType.Dummy)
            temp_value2 = functions.generate_value_with_same_type(
                v['output_body_value'], suffix_type=functions.SuffixType.Dummy)
            inputs.append(temp_value1)

            body_graph.add_input_value(temp_value2)
            node_input_outputs.append(temp_value2)

        if 'output_body_value' in v:
            body_graph.add_output_value(v['output_body_value'])
            output_value = functions.generate_value_with_same_type(
                v['output_body_value'])
            outputs.append(output_value)

            if 'output_obj' in v:
                obj = v['output_obj']
                obj.revise(output_value)
                field.get_attribute(name).revise(obj)
            elif field.get_attribute(name).has_obj():
                field.get_attribute(name).get_ref().revise(output_value)
            else:
                field.get_attribute(name).revise(values.ValueRef(output_value))
        else:
            # Read-only in the body: pass the input through unchanged.
            temp_value1 = v['input_body_value']
            temp_value2 = functions.generate_value_with_same_type(
                v['input_body_value'])
            body_graph.add_output_value(temp_value1)
            outputs.append(temp_value2)

    node = nodes.NodeFor(input_iter_value, inputs, body_graph, body_cond_value,
                         astc.lineno)
    node.set_outputs(outputs)
    node_input.set_outputs(node_input_outputs)

    graph.add_node(node)

    return None
# Example #5
# 0
def convert_model(model: 'chainer.Chain', args=None):
    '''Evaluate ``model.forward`` on ``args`` and build the root graph.

    Before evaluation the global converter registries in ``values`` are
    cleared and repopulated with converters for chainer links, a list of
    ``chainer.functions``, a subset of numpy, a few Python builtins and
    ``cuda.to_gpu`` / ``cuda.to_cpu``.

    Args:
        model: chainer model whose ``forward`` method is converted.
        args: example inputs passed to ``forward``. Each one is parsed into a
            value named ``in_<index>`` whose ``internal_value`` is cleared so
            it is treated as unknown during evaluation. ``None`` (the
            default) means no inputs.

    Returns:
        ``(value_args, outputs, graph)``: the list of input values, the list
        of output values (``None`` when ``forward`` produced no output), and
        the root ``Graph``.
    '''
    # Avoid a shared mutable default; an explicit None means "no inputs".
    if args is None:
        args = []

    # reset values
    values.reset_field_and_attributes()
    utils.reset_guid()

    values.function_converters.clear()
    values.builtin_function_converters.clear()
    values.instance_converters.clear()

    def instance_converter(m, i):
        # Wrap chainer objects in the matching instance type; None means
        # "not handled here".
        if links_builtin.is_builtin_chainer_link(i):
            return links_builtin.ChainerLinkInstance(m, i)

        if isinstance(i, chainer.ChainList):
            module = values.ValueRef(
                values.ModuleValue(sys.modules[i.__module__]))
            return links_builtin.ChainerChainListInstance(module, i)

        if isinstance(i, chainer.Link):
            module = values.ValueRef(
                values.ModuleValue(sys.modules[i.__module__]))
            return links_builtin.ChainerChainInstance(module, i)

        return None

    values.instance_converters.append(instance_converter)

    # chainer
    c_variable = values.FuncValue(functions_ndarray.NDArrayFunction(), None)
    values.function_converters[chainer.Variable] = c_variable

    # chainer.functions
    def add_chainer_function(name: 'str', func, ret_value_func=None):
        # ``name`` is informational only; registration is keyed by ``func``.
        if ret_value_func is None:
            f = values.FuncValue(functions_builtin.ChainerFunction(func), None)
        else:
            f = values.FuncValue(
                functions_builtin.ChainerFunction(
                    func, ret_value_func=ret_value_func), None)

        values.function_converters[func] = f

    def ret_tuple():
        # Return-value factory for functions yielding a tuple of tensors.
        ret = values.TupleValue()
        ret.vtype = values.TensorValue
        return ret

    add_chainer_function('relu', F.relu)
    add_chainer_function('softmax', F.softmax)
    add_chainer_function('softmax_cross_entropy', F.softmax_cross_entropy)
    add_chainer_function('pad_sequence', F.pad_sequence)
    add_chainer_function('average_pooling_2d', F.average_pooling_2d)
    add_chainer_function('unpooling_2d', F.unpooling_2d)
    add_chainer_function('reshape', F.reshape)
    add_chainer_function('split_axis', F.split_axis, ret_value_func=ret_tuple)
    add_chainer_function('hstack', F.hstack)
    add_chainer_function('vstack', F.vstack)
    add_chainer_function('stack', F.stack)
    add_chainer_function('separate', F.separate, ret_value_func=ret_tuple)
    add_chainer_function('squeeze', F.squeeze)
    add_chainer_function('swapaxes', F.swapaxes)
    add_chainer_function('dropout', F.dropout)
    add_chainer_function('concat', F.concat)
    add_chainer_function('matmul', F.matmul)
    add_chainer_function('max_pooling_2d', F.max_pooling_2d)
    add_chainer_function('resize_images', F.resize_images)
    add_chainer_function('tanh', F.tanh)
    add_chainer_function('sigmoid', F.sigmoid)
    add_chainer_function('broadcast_to', F.broadcast_to)
    add_chainer_function('expand_dims', F.expand_dims)
    add_chainer_function('local_response_normalization',
                         F.local_response_normalization)
    add_chainer_function('mean', F.mean)
    add_chainer_function('average', F.average)
    add_chainer_function('sum', F.sum)

    # Parse the whole major-version component: indexing only the first
    # character would misread a two-digit major version (e.g. '10.0.0').
    chainer_major_version = int(chainer.__version__.split('.')[0])
    if chainer_major_version >= 6:
        add_chainer_function('roi_max_pooling_2d', F.roi_max_pooling_2d)
        add_chainer_function('roi_average_pooling_2d',
                             F.roi_average_pooling_2d)
        add_chainer_function('roi_max_align_2d', F.roi_max_align_2d)

    add_chainer_function('roi_average_align_2d', F.roi_average_align_2d)

    # numpy
    f_array = values.FuncValue(functions_ndarray.NDArrayFunction(), None)
    f_zeros = values.FuncValue(functions_ndarray.NDArrayZerosFunction(), None)
    f_full = values.FuncValue(functions_ndarray.NDArrayFullFunction(), None)
    f_ceil = values.FuncValue(functions_ndarray.NDArrayCeilFunction(), None)
    f_cumsum = values.FuncValue(functions_ndarray.NDArrayCumsumFunction(),
                                None)

    f_int32 = values.FuncValue(functions_ndarray.NDArrayInt32(), None)
    f_float32 = values.FuncValue(functions_ndarray.NDArrayFloat32(), None)

    values.function_converters[np.array] = f_array
    values.function_converters[np.zeros] = f_zeros
    values.function_converters[np.full] = f_full
    values.function_converters[np.ceil] = f_ceil
    values.function_converters[np.cumsum] = f_cumsum
    values.function_converters[np.int32] = f_int32
    values.function_converters[np.float32] = f_float32

    # builtins (registered both by name and, for range, by object)
    m_range = values.FuncValue(functions_builtin.RangeFunction(), None)
    values.builtin_function_converters['range'] = m_range

    m_len = values.FuncValue(functions_builtin.LenFunction(), None)
    values.builtin_function_converters['len'] = m_len

    values.function_converters[six.moves.range] = m_range

    m_list = values.FuncValue(functions_builtin.ListFunction(), None)
    values.builtin_function_converters['list'] = m_list

    m_to_gpu = values.FuncValue(functions_builtin.CopyFunction(cuda.to_gpu),
                                None)
    values.function_converters[cuda.to_gpu] = m_to_gpu

    m_to_cpu = values.FuncValue(functions_builtin.CopyFunction(cuda.to_cpu),
                                None)
    values.function_converters[cuda.to_cpu] = m_to_cpu

    # generate default module
    default_module = values.ValueRef(
        values.ModuleValue(sys.modules[model.__module__]))

    model_inst = values.parse_instance(default_module, '', model)
    forward_func = model_inst.try_get_and_store_obj('forward', None)

    # convert args
    finput = functions.FunctionArgInput()

    value_args = []

    node_input = nodes.NodeInput('input')

    for ind, arg in enumerate(args):
        arg_name = 'in_' + str(ind)
        varg = values.parse_instance(default_module, '', arg, None)
        varg.name = arg_name
        varg.get_value().name = arg_name

        # make value unknown
        varg.get_value().internal_value = None

        finput.inputs.append(varg)
        value_args.append(varg.get_value())

    node_input.set_outputs(value_args)

    graph = Graph()
    graph.root_graph = graph
    graph.add_node(node_input)

    forward_func_value = forward_func.get_value()
    ret = forward_func_value.func.vcall(default_module, graph,
                                        forward_func_value.obj, finput)
    assert (ret is None or isinstance(ret, values.ValueRef))

    def try_get_value(value) -> 'values.Value':
        # Unwrap a Value / ValueRef / Attribute to the underlying Value;
        # anything else yields None.
        if isinstance(value, values.Value):
            return value

        if isinstance(value, values.ValueRef):
            return value.get_value()

        if isinstance(value, values.Attribute):
            return value.get_ref().get_value()

    # ``ret`` is a ValueRef here (see the assert above), so a None result
    # must also be detected on the referenced value, not only on ``ret``.
    if (ret is None or isinstance(ret, values.NoneValue)
            or isinstance(ret.get_value(), values.NoneValue)):
        if config.show_warnings:
            print('Failed to compile. output is None.')
        return (value_args, None, graph)

    ret_value = ret.get_value()
    if (isinstance(ret_value, values.TupleValue)
            and ret_value.internal_value is not None):
        # A tuple with known elements becomes one graph output per element.
        ret_ = []
        for v in ret_value.internal_value:
            assert (v is not None)
            ret_.append(try_get_value(v))
    else:
        ret_ = [ret_value]

    for v in value_args:
        graph.add_input_value(v)

    for v in ret_:
        graph.add_output_value(v)

    return (value_args, ret_, graph)