def test_backprop():
    """Check the forward/backward split of the single-Linear backprop model.

    Loads the exported ONNX model and splits it with ``graph.backward()``
    into a forward graph and a backward graph. It compiles and runs both,
    then compares them with values computed directly in ChainerX:

    * the forward output ``x @ W.T + b`` for layer ``/l1``;
    * the gradients ``grad_out@/l1/W`` and ``grad_out@/l1/b`` for a given
      upstream gradient on the model output.
    """
    model = chainer_compiler_core.load(
        'out/ch2o_node_Linear_backprop/model.onnx')
    params = model.params()
    in_names = model.input_names()
    out_names = model.output_names()
    assert len(in_names) == 1
    assert len(out_names) == 1

    # backward() yields (forward graph, backward graph). The forward graph
    # has 3 outputs. The loop below forwards the non-model ones straight
    # into the backward graph, which takes 3 inputs and yields 2 gradients.
    fwd_graph, bwd_graph = model.backward()
    assert len(fwd_graph.input_names()) == 1
    assert len(fwd_graph.output_names()) == 3
    assert len(bwd_graph.input_names()) == 3
    assert len(bwd_graph.output_names()) == 2

    fwd = fwd_graph.compile()
    bwd = bwd_graph.compile()

    x = aranges(5, 7)
    fwd_inputs = dict(params)
    fwd_inputs[in_names[0]] = chainer_compiler_core.value(x)

    # Reference forward result for y = x @ W.T + b.
    weight = params['/l1/W'].array()
    bias = params['/l1/b'].array()
    expected_y = chainerx.dot(x, weight.T) + bias

    fwd_outputs = fwd.run(fwd_inputs)
    assert len(fwd_outputs) == 3
    chainerx.testing.assert_allclose(
        expected_y, fwd_outputs[out_names[0]].array())

    # Upstream gradient with the same shape as the model output.
    grad_y = aranges(*expected_y.shape) + 4.2

    # Model outputs are replaced by their upstream gradient under the key
    # 'grad_in@<name>'; every other forward output is passed through as-is.
    bwd_inputs = {}
    for out_name in fwd_graph.output_names():
        fwd_value = fwd_outputs[out_name]
        if out_name not in out_names:
            bwd_inputs[out_name] = fwd_value
            continue
        bwd_inputs['grad_in@' + out_name] = chainer_compiler_core.value(grad_y)

    bwd_outputs = bwd.run(bwd_inputs)

    # For y = x @ W.T + b: dL/dW = gy.T @ x and dL/db = sum(gy, axis=0).
    chainerx.testing.assert_allclose(
        chainerx.dot(grad_y.T, x), bwd_outputs['grad_out@/l1/W'].array())
    chainerx.testing.assert_allclose(
        chainerx.sum(grad_y, axis=0), bwd_outputs['grad_out@/l1/b'].array())
Ejemplo n.º 2
0
 def forward_chainerx(self, x):
     """Multiply the two inputs with ``chainerx.dot`` when supported.

     Args:
         x: A pair ``(a, b)`` of input arrays.

     Returns:
         ``(chainerx.dot(a, b),)`` in the supported case. Otherwise returns
         ``chainer.Fallback``, which presumably tells Chainer to use its
         non-ChainerX forward (confirm against the Chainer docs). That
         happens when any of ``transa``/``transb``/``transc`` is set, the
         two input dtypes differ, either input is not 2-D, or ``self.dtype``
         is set and differs from ``a``'s dtype.
     """
     lhs, rhs = x
     # Each clause is an unsupported case. ``or`` short-circuits, so the
     # checks run in the same order as a series of early-return guards.
     if (
         # TODO(sonots): Support transa and transb in ChainerX
         self.transa or self.transb or self.transc
         # TODO(sonots): Support dtype promotion in ChainerX
         or lhs.dtype != rhs.dtype
         # TODO(sonots): Support ndim > 2 in ChainerX
         or lhs.ndim != 2 or rhs.ndim != 2
         # TODO(niboshi): Support it
         or (self.dtype is not None and self.dtype != lhs.dtype)
     ):
         return chainer.Fallback
     return (chainerx.dot(lhs, rhs),)
Ejemplo n.º 3
0
 def forward_chainerx(self, x):
     """Compute the product of the two inputs via ``chainerx.dot``.

     ``x`` is unpacked as ``(a, b)``. The result is ``(chainerx.dot(a, b),)``,
     or ``chainer.Fallback`` in any of these cases:

     * any of ``transa``/``transb``/``transc`` is set;
     * the input dtypes differ;
     * either input is not 2-D;
     * ``self.dtype`` is set to something other than ``a``'s dtype.

     (Presumably ``chainer.Fallback`` makes Chainer run its non-ChainerX
     forward instead; confirm against the Chainer docs.)
     """
     left, right = x
     if self.transa or self.transb or self.transc:
         # TODO(sonots): Support transa and transb in ChainerX
         result = chainer.Fallback
     elif left.dtype != right.dtype:
         # TODO(sonots): Support dtype promotion in ChainerX
         result = chainer.Fallback
     elif left.ndim != 2 or right.ndim != 2:
         # TODO(sonots): Support ndim > 2 in ChainerX
         result = chainer.Fallback
     elif self.dtype is not None and self.dtype != left.dtype:
         # TODO(niboshi): Support it
         result = chainer.Fallback
     else:
         result = (chainerx.dot(left, right),)
     return result
def test_inference():
    """Run the two-output Linear model and compare against ChainerX.

    Loads and compiles ``out/ch2o_node_Linear/model.onnx`` and feeds its
    parameters plus a single input. It checks both outputs against reference
    values computed with ``chainerx.dot``. It also checks that the dumped
    graph contains a ``ChainerLinear`` node after compilation.
    """
    model = chainer_compiler_core.load('out/ch2o_node_Linear/model.onnx')
    params = model.params()
    in_names = model.input_names()
    out_names = model.output_names()
    assert len(in_names) == 1
    assert len(out_names) == 2

    runner = model.compile()

    x = aranges(5, 7)
    feed = dict(params)
    feed[in_names[0]] = chainer_compiler_core.value(x)

    # Reference results, in output order. The /l2 reference has no bias term.
    expected = [
        chainerx.dot(x, params['/l1/W'].array().T) + params['/l1/b'].array(),
        chainerx.dot(x, params['/l2/W'].array().T),
    ]

    outputs = runner.run(feed)
    assert len(outputs) == 2
    for want, name in zip(expected, out_names):
        chainerx.testing.assert_allclose(want, outputs[name].array())

    # After compile(), the dumped graph must contain a ChainerLinear node.
    assert 'op_type: "ChainerLinear"' in model.dump()