Ejemplo n.º 1
0
def eval_for(nast, env):
    """Evaluate a ``for`` statement, emitting an ONNX ``Loop`` node into *env*.

    Two paths exist:

    * Special case (a hack for ResNet50): when the iterable's value is a
      Python generator whose string form contains ``ChainList.children``, the
      loop is unrolled at compile time by iterating the generator in Python
      and evaluating the body once per element directly in *env*.
    * Otherwise the body is evaluated once in a child block of *env* to build
      a ``Loop`` subgraph. Inside it the loop variable is bound to
      ``ChainerSequenceLookup(sequence, counter)``. Entries reported by
      ``_find_in_out`` that have an output value become loop-carried values.
      Recorded attribute writes are replayed with ``setattr`` after the
      ``Loop`` node is added.

    Args:
        nast: The ``for`` AST node. ``orelse`` must be empty, and on the graph
            path ``target`` must be a ``gast.Name``.
        env: The environment to evaluate in. Nodes are added to it, and its
            variables are rebound to the loop's outputs.

    Returns:
        None.
    """
    assert nast.orelse == []
    ite = eval_ast(nast.iter, env)

    # A hack for ResNet50.
    # TODO(hamaji): Come up with a sophisticated way.
    # TODO(hamaji): This code doesn't handle scope properly, I think.
    if (isinstance(ite.value, types.GeneratorType)
            and 'ChainList.children' in str(ite.value)):
        # For now, actually run the for loop in Python (i.e. unroll it).
        tg = nast.target.id
        env.set_var(tg, Value(None))
        for v in ite.value:
            env.set_var(tg, _value(v))
            eval_ast(nast.body, env)

        env.pop_var(tg)
        return None

    if ite.is_py:
        # Wrap each Python element so the list can be lowered to a sequence.
        ite = Value([Value(v) for v in ite.value])

    assert isinstance(nast.target, gast.Name)
    x = nast.target.id

    # Create a new block env; the nodes generated while evaluating the body
    # in it form the Loop subgraph.
    localenv = env.new_block()

    # Subgraph inputs: iteration counter and the sequence being iterated.
    cnt = new_tensor()
    gtx = new_sequence()
    localenv.set_var(
        x,
        _value(
            localenv.calc(
                "ChainerSequenceLookup",
                inputs=[gtx.name, cnt.name],
            )))
    ty = eval_ast(nast.body, localenv)
    assert ty.is_none()

    in_out = _find_in_out(localenv, env)

    input_values = []
    output_values = []
    final_outputs = []
    final_setattrs = []
    for key, (iv, ov, setattr_info) in in_out.items():
        if ov is None:
            # No output value for this key (presumably only read in the
            # body), so nothing needs to be carried through the loop.
            continue
        if iv is None:
            # No initial value (presumably not defined before the loop):
            # feed a dummy placeholder as the initial loop-carried value.
            iv = Value(False)
        out = ov.copy(env, name=key)
        final_outputs.append((key, out.value))
        if setattr_info is not None:
            # setattr_info is unpacked below as (object, attribute name).
            final_setattrs.append(tuple(list(setattr_info) + [out]))
        input_values.append(iv.to_value_info(env))
        output_values.append(ov.to_value_info(env))

    cond = new_tensor(name='loop_cond')
    localgraph = make_graph(localenv.nodes, "Loop_subgraph",
                            [cnt, cond, gtx] + input_values,
                            [cond, gtx] + output_values)

    # Convert the iterable once and reuse the result for both the length
    # computation and the Loop input, instead of emitting the conversion
    # twice.
    seq = ite.to_sequence(env)
    mtc = env.calc(
        "ChainerGenericLen",
        inputs=[seq.name],
    )

    # Loop inputs: trip count, omitted ("") termination condition, the
    # sequence (carried through as ``gtx``), then the loop-carried values.
    env.addnode('Loop',
                inputs=([mtc.name, "", seq.name] +
                        [i.name for i in input_values]),
                outputs=([new_tensor('out_generator').name] +
                         [o.name for _, o in final_outputs]),
                body=localgraph)

    for k, o in final_outputs:
        # Keys containing '.' or '/' are not bound as plain variables;
        # presumably they name attribute writes handled via final_setattrs.
        if '.' not in k and '/' not in k:
            env.set_var(k, _value(o))

    for var, key, value in final_setattrs:
        setattr(var.value, key, value)

    return None
Ejemplo n.º 2
0
def eval_if(nast, env):
    """Evaluate an ``if`` statement, emitting an ONNX ``If`` node when needed.

    If the condition evaluates to a Python ``True``/``False`` constant, only
    the taken branch is evaluated, directly in *env*. Otherwise both branches
    are evaluated in separate child blocks and turned into ``If_then`` /
    ``If_else`` subgraphs that share the same inputs. An ``If`` node is then
    added to *env*. Keys with an output value in either branch are rebound
    in *env* to the node's outputs. Recorded attribute writes are replayed
    with ``setattr``.

    Args:
        nast: The ``if`` AST node.
        env: The environment to evaluate in. Nodes and variables are added
            to it.

    Returns:
        The result of ``eval_ast`` on the taken branch when the condition is
        a Python bool constant; otherwise None.
    """
    cond = eval_ast(nast.test, env)
    if cond.is_py and cond.value is True:
        return eval_ast(nast.body, env)
    elif cond.is_py and cond.value is False:
        return eval_ast(nast.orelse, env)

    then_env = env.new_block()
    ty = eval_ast(nast.body, then_env)
    assert ty.is_none()

    else_env = env.new_block()
    ty = eval_ast(nast.orelse, else_env)
    assert ty.is_none()

    then_in_out = _find_in_out(then_env, env)
    else_in_out = _find_in_out(else_env, env)
    # Merge the keys of both branches without duplicates, preserving the
    # order in which _find_in_out reports them. Iterating a set of strings
    # would depend on hash randomization and reorder the If node's
    # inputs/outputs from run to run.
    keys = list(dict.fromkeys(list(then_in_out) + list(else_in_out)))

    input_values = []
    then_outputs = []
    else_outputs = []
    final_outputs = []
    final_setattrs = []

    def set_final_output(key, out, setattr_info):
        # Copy a branch output into the outer env under *key* and remember
        # the attribute write (if any) it corresponds to.
        out = out.copy(env, name=key)
        final_outputs.append((key, out.value))
        if setattr_info is not None:
            # setattr_info is unpacked below as (object, attribute name).
            final_setattrs.append(tuple(list(setattr_info) + [out]))

    for key in keys:
        then_iv, then_ov, then_setattr_info = then_in_out.get(
            key, (None, None, None))
        else_iv, else_ov, else_setattr_info = else_in_out.get(
            key, (None, None, None))

        # Both branches must agree on the attribute target, if both have one.
        if then_setattr_info is None:
            setattr_info = else_setattr_info
        else:
            if else_setattr_info is not None:
                assert then_setattr_info == else_setattr_info
            setattr_info = then_setattr_info

        iv = else_iv if then_iv is None else then_iv
        if iv is None:
            # No input value from either branch: use a dummy placeholder.
            iv = Value(False)
        input_values.append(iv.to_value_info(env))

        if then_ov is None and else_ov is None:
            continue
        if then_ov is None:
            # Then-branch leaves the value untouched: pass the input through.
            then_outputs.append(iv.to_value_info(env))
            else_outputs.append(else_ov.to_value_info(else_env))
            set_final_output(key, else_ov, setattr_info)
        elif else_ov is None:
            # Else-branch leaves the value untouched: pass the input through.
            then_outputs.append(then_ov.to_value_info(then_env))
            else_outputs.append(iv.to_value_info(env))
            set_final_output(key, then_ov, setattr_info)
        else:
            then_outputs.append(then_ov.to_value_info(then_env))
            else_outputs.append(else_ov.to_value_info(else_env))
            set_final_output(key, then_ov, setattr_info)

    then_graph = make_graph(
        then_env.nodes,
        "If_then",
        input_values,
        then_outputs,
    )

    else_graph = make_graph(
        else_env.nodes,
        "If_else",
        input_values,
        else_outputs,
    )

    env.addnode(
        'If',
        inputs=[cond.to_value_info(env).name] + [i.name for i in input_values],
        outputs=[o.name for _, o in final_outputs],
        then_branch=then_graph,
        else_branch=else_graph,
    )

    # NOTE(review): unlike eval_for, keys containing '.' or '/' are also
    # bound as variables here -- confirm this difference is intended.
    for k, o in final_outputs:
        env.set_var(k, _value(o))

    for var, key, value in final_setattrs:
        setattr(var.value, key, value)

    return None