Beispiel #1
0
def dygraph2program(layer,
                    inputs,
                    feed_prefix='feed_',
                    fetch_prefix='fetch_',
                    tmp_prefix='t_',
                    extract_inputs_fn=None,
                    extract_outputs_fn=None,
                    dtypes=None):
    """Trace a dygraph ``layer`` on ``inputs`` and return it as a static ``Program``.

    When the environment variable ``FLAGS_enable_eager_mode`` equals ``"1"``
    the work is delegated to ``_dy2prog`` with the same arguments. Otherwise
    the layer is called once under ``program_desc_tracing_guard(True)``, the
    recorded program desc is obtained from the program-desc tracer, and it is
    wrapped in a new ``Program`` (built under ``_dygraph_guard(None)``).

    Args:
        layer: The ``Layer`` instance to trace.
        inputs: One of:
            * a single shape (as recognised by ``_is_shape``),
            * a collection of shapes (as recognised by ``_is_shapes``),
            * actual input data, converted with ``to_variables``.
            Shapes are materialised with ``_create_tensors``. The resulting
            ``inputs`` are unpacked positionally: ``layer(*inputs)``.
        feed_prefix: Prefix forwarded to ``tracer.create_program_desc``
            (presumably used to name feed variables).
        fetch_prefix: Prefix forwarded to ``tracer.create_program_desc``
            (presumably used to name fetch variables).
        tmp_prefix: Prefix forwarded to ``tracer.create_program_desc``
            (presumably used to name temporary variables).
        extract_inputs_fn: Callable that picks the input variables out of the
            converted ``inputs``. Only used when ``inputs`` are not shapes.
            Defaults to ``extract_vars``.
        extract_outputs_fn: Callable that picks the output variables out of
            the layer's return value. Defaults to ``extract_vars``.
        dtypes: Forwarded to ``_create_tensors`` when ``inputs`` are shapes.

    Returns:
        Program: a program whose ``desc`` is the traced program desc.

    Raises:
        AssertionError: if ``layer`` is not a ``Layer``.
    """
    # The type is reported in the assertion message instead of being
    # printed unconditionally on every call.
    assert isinstance(layer, Layer), \
        "dygraph2program expects a Layer, got {}".format(type(layer))
    if extract_inputs_fn is None:
        extract_inputs_fn = extract_vars
    if extract_outputs_fn is None:
        extract_outputs_fn = extract_vars

    if os.environ.get("FLAGS_enable_eager_mode") == "1":
        return _dy2prog(layer, inputs, feed_prefix, fetch_prefix, tmp_prefix,
                        extract_inputs_fn, extract_outputs_fn, dtypes)

    tracer = _dygraph_tracer()._get_program_desc_tracer()

    with program_desc_tracing_guard(True):
        try:
            if _is_shape(inputs):
                # A single shape: wrap it so _create_tensors receives a
                # list of shapes.
                inputs = _create_tensors([inputs], dtypes=dtypes)
                input_var_list = inputs
            elif _is_shapes(inputs):
                inputs = _create_tensors(inputs, dtypes=dtypes)
                input_var_list = inputs
            else:
                inputs = to_variables(inputs)
                input_var_list = extract_inputs_fn(inputs)

            original_outputs = layer(*inputs)
            # 'original_outputs' may be a dict, so convert it to a list of
            # variables. 'extract_outputs_fn' should not create new variables
            # (presumably because they would be traced as well).
            out_var_list = extract_outputs_fn(original_outputs)
            program_desc, _feed_names, _fetch_names, _parameters = \
                tracer.create_program_desc(input_var_list, feed_prefix,
                                           out_var_list, fetch_prefix,
                                           tmp_prefix)
        finally:
            # Reset on both the success and the failure path so recorded
            # tracer state is not left behind after an exception.
            tracer.reset()

    with _dygraph_guard(None):
        program = Program()
        program.desc = program_desc
        # Rebuild the Python-side block list for block index 0, then let the
        # Program sync its wrappers from the desc.
        program.blocks = [Block(program, 0)]
        program._sync_with_cpp()
    return program
Beispiel #2
0
def create_program_from_desc(program_desc):
    """Wrap an existing program desc in a freshly constructed ``Program``.

    Args:
        program_desc: The desc object to install as the new program's
            ``desc`` attribute.

    Returns:
        Program: a new program whose ``desc`` is ``program_desc`` and whose
        ``blocks`` list holds a single ``Block`` for index 0.
    """
    wrapped = Program()
    wrapped.desc = program_desc
    # Replace the Python-side block list with one Block at index 0 that
    # refers to the new program.
    first_block = Block(wrapped, 0)
    wrapped.blocks = [first_block]
    # Presumably refreshes the Python wrappers from the installed desc.
    wrapped._sync_with_cpp()
    return wrapped