예제 #1
0
    def compile(self, inputs):
        """Compile ``self.mc`` into forward and backward executables.

        The model is translated to ONNX with ``ch2o``, handed to
        ``chainer_compiler_core`` through a temporary file, split into a
        forward graph and a backward graph with respect to all graph inputs,
        and both graphs are compiled.

        Args:
            inputs: Example inputs passed to ``ch2o.compile_model``. Forward
                graph inputs after the first ``len(inputs)`` entries are
                recorded as ``self.param_names``.

        Side effects:
            Sets ``orig_output_names``, ``fwd_input_names``,
            ``fwd_output_names``, ``bwd_input_names``, ``bwd_output_names``,
            ``fwd``, ``bwd``, ``param_names`` and sets ``compiled`` to True.
        """
        xmodel = ch2o.compile_model(self.mc, inputs)
        # delete=False keeps the file after it is closed so the loader can
        # read it by name; we unlink it ourselves in ``finally`` so a failed
        # serialization or load does not leave a stray file behind.
        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            with f:
                f.write(xmodel.SerializeToString())
            # Release the protobuf before loading (presumably to avoid holding
            # two copies of the model in memory -- confirm).
            del xmodel
            graph = chainer_compiler_core.load(f.name)
        finally:
            os.unlink(f.name)

        self.orig_output_names = graph.output_names()

        fwd_graph, bwd_graph = graph.backward_to(graph.input_names())
        if self.dump_onnx:
            sys.stderr.write('=== vvv forward vvv ===\n' + fwd_graph.dump() +
                             '\n=== ^^^ forward ^^^ ===\n')
            sys.stderr.write('=== vvv backward vvv ===\n' + bwd_graph.dump() +
                             '\n=== ^^^ backward ^^^ ===\n')

        # Internal invariant: the forward graph keeps the original inputs.
        assert graph.input_names() == fwd_graph.input_names()
        self.fwd_input_names = fwd_graph.input_names()
        self.fwd_output_names = fwd_graph.output_names()
        self.bwd_input_names = bwd_graph.input_names()
        self.bwd_output_names = bwd_graph.output_names()
        # TODO(hamaji): Revive shape inference.
        self.fwd = fwd_graph.compile(skip_inference=True)
        self.bwd = bwd_graph.compile(skip_inference=True)
        # Inputs beyond the user-supplied ones are treated as parameters.
        self.param_names = self.fwd_input_names[len(inputs):]

        self.compiled = True
예제 #2
0
def create_backprop_test(test_name, model, input_values):
    """Write an ONNX backprop test case for ``model`` to disk.

    Runs ``model`` on ``input_values`` with ``chainer.config.train`` set to
    True, back-propagates an all-ones gradient from every output, and writes
    the test case to ``out/backprop_test_pc_<test_name>``. The written data
    consists of:

    * ``model.onnx``: the ch2o-compiled model.
    * ``test_data_set_0/``: inputs and expected outputs, including one
      ``grad_out@<param name>`` entry per named parameter.

    Note: ``chainer.config.train`` is set globally and is not restored.
    """
    chainer.config.train = True
    model.cleargrads()
    result = model(*map(chainer.variable.Variable, input_values))

    out_dir = 'out/backprop_test_pc_%s' % test_name
    data_dir = os.path.join(out_dir, 'test_data_set_0')
    os.makedirs(data_dir, exist_ok=True)

    xmodel = ch2o.compile_model(model, input_values)
    # Grab the graph's input/output containers now, before
    # edit_onnx_protobuf() gets a chance to touch the graph.
    graph_inputs = xmodel.graph.input
    graph_outputs = xmodel.graph.output

    results = result if isinstance(result, (list, tuple)) else (result, )
    for var in results:
        var.grad = np.ones(var.shape, var.dtype)
        var.backward()

    ch2o.testcasegen.edit_onnx_protobuf(xmodel, model)

    # Graph inputs backed by an initializer are parameters, not test inputs.
    initialized = {init.name for init in xmodel.graph.initializer}
    feed_tensors = [t for t in graph_inputs if t.name not in initialized]

    assert len(feed_tensors) == len(input_values)
    assert len(graph_outputs) == len(results)

    expected = [(t, v.array) for t, v in zip(graph_outputs, results)]
    for pname, param in sorted(model.namedparams()):
        grad_info = onnx.helper.make_tensor_value_info(
            'grad_out@' + pname, onnx.TensorProto.FLOAT, ())
        expected.append((grad_info, param.grad))

    ch2o.testcasegen.dump_test_inputs_outputs(
        list(zip(feed_tensors, input_values)), expected, data_dir)

    with open(os.path.join(out_dir, 'model.onnx'), 'wb') as fp:
        fp.write(xmodel.SerializeToString())
예제 #3
0
    def compile(self, inputs):
        """Compile ``self.mc`` into forward and backward executables.

        The model is exported to ONNX with the translator named by
        ``self.translator`` (``'ch2o'`` or ``'onnx_chainer'``). The export is
        loaded into ``chainer_compiler_core`` through a temporary file and
        split into a forward graph and a backward graph:

        * If ``self.computation_order`` is None, the graph is split with
          respect to all graph inputs and parameters.
        * Otherwise, it is split with ``backward_to_with_order``.

        Both graphs are then compiled.

        Args:
            inputs: Example inputs passed to the translator.

        Raises:
            NotImplementedError: If ``self.translator`` is not supported.

        Side effects:
            Sets ``orig_output_names``, ``fwd_input_names``,
            ``fwd_output_names``, ``bwd_input_names``, ``bwd_output_names``,
            ``fwd``, ``bwd``, ``param_names`` and sets ``compiled`` to True.
        """
        # delete=False keeps the file after it is closed so the loader can
        # read it by name; it is unlinked in ``finally`` so that a failing
        # export or load does not leave a stray file behind.
        f = None
        try:
            if self.translator == 'ch2o':
                xmodel = ch2o.compile_model(self.mc, inputs)
                f = tempfile.NamedTemporaryFile(delete=False)
                with f:
                    f.write(xmodel.SerializeToString())
                del xmodel
            elif self.translator == 'onnx_chainer':
                import onnx_chainer
                f = tempfile.NamedTemporaryFile(delete=False)
                with f:
                    onnx_chainer.export(self.mc, inputs, filename=f)
            else:
                # Previously two positional args were passed, which rendered
                # the message as a tuple repr.
                raise NotImplementedError(
                    'Unsupported translator: %s' % self.translator)

            graph = chainer_compiler_core.load(f.name)
        finally:
            if f is not None:
                os.unlink(f.name)

        self.orig_output_names = graph.output_names()

        if self.computation_order is None:
            fwd_graph, bwd_graph = graph.backward_to(graph.input_names() +
                                                     graph.param_names())
        else:
            fwd_graph, bwd_graph = graph.backward_to_with_order(
                self.computation_order)
        if self.dump_onnx:
            sys.stderr.write('=== vvv forward vvv ===\n' + fwd_graph.dump() +
                             '\n=== ^^^ forward ^^^ ===\n')
            sys.stderr.write('=== vvv backward vvv ===\n' + bwd_graph.dump() +
                             '\n=== ^^^ backward ^^^ ===\n')

        # Internal invariant: the forward graph keeps the original inputs.
        assert graph.input_names() == fwd_graph.input_names()
        self.fwd_input_names = fwd_graph.input_names()
        self.fwd_output_names = fwd_graph.output_names()
        self.bwd_input_names = bwd_graph.input_names()
        self.bwd_output_names = bwd_graph.output_names()
        # TODO(hamaji): Revive shape inference.
        self.fwd = fwd_graph.compile(skip_inference=True)
        self.bwd = bwd_graph.compile(skip_inference=True)
        self.param_names = fwd_graph.param_names()

        self.compiled = True