Example #1
0
    def _test_replace_func(self, fn, xs, set_grad=False):
        """Check that ``fn`` wrapped by ``as_funcnode`` matches plain ``fn``.

        ``fn`` is called directly and through ``as_funcnode('fn')(fn)`` with
        the same inputs; the outputs and the gradients with respect to every
        ``chainer.Variable`` in ``xs`` must agree.

        Args:
            fn: Callable under test; invoked as ``fn(*xs)``. It may return a
                single output or a list/tuple of outputs.
            xs: Positional arguments for ``fn``. Only the entries that are
                ``chainer.Variable`` instances are differentiated against.
            set_grad: Forwarded unchanged to ``chainer.grad`` for both the
                expected and the actual gradient computation.

        Raises:
            AssertionError: If the number of outputs/gradients differs, if
                the values are not close, or if one side yields a ``None``
                gradient where the other does not.
        """
        def make_list(v):
            # Normalise a single output or a list/tuple of outputs to a list.
            if isinstance(v, (list, tuple)):
                return list(v)
            else:
                return [v]

        xvs = [x for x in xs if isinstance(x, chainer.Variable)]
        rfn = as_funcnode('fn')(fn)

        # Expected values: the unwrapped function.
        eys = make_list(fn(*xs))
        egxs = chainer.grad(eys, xvs, set_grad=set_grad)
        # Actual values: the same function routed through as_funcnode.
        ays = make_list(rfn(*xs))
        agxs = chainer.grad(ays, xvs, set_grad=set_grad)

        assert len(eys) == len(ays)
        for ay, ey in zip(ays, eys):
            np.testing.assert_allclose(ay.array, ey.array)

        assert len(egxs) == len(agxs)
        for agx, egx in zip(agxs, egxs):
            if egx is None:
                # The wrapped function must also report "no gradient" here.
                # (Previously this re-asserted ``egx is None``, which could
                # never fail and left ``agx`` unchecked.)
                assert agx is None
            else:
                np.testing.assert_allclose(agx.array, egx.array)
Example #2
0
import chainer
import chainer.functions as F
import numpy as np
import onnx
import onnx_chainer
from onnx_chainer.replace_func import as_funcnode


class Sign(chainer.Chain):
    """Toy chain that computes ``relu(sign(relu(x)))``."""

    def forward(self, x):
        """Apply relu, then sign, then relu to ``x`` and return the result."""
        # ``F.sign`` is resolved at call time, so a replacement assigned to
        # ``F.sign`` after this class is defined is the one that runs here.
        rectified = F.relu(x)
        signed = F.sign(rectified)
        return F.relu(signed)


# Rebind ``F.sign`` on the ``chainer.functions`` module to a version wrapped
# by ``as_funcnode('Sign')``. ``Sign.forward`` looks ``F.sign`` up when it is
# called, so it picks up this wrapped version.
# NOTE(review): presumably the wrapper makes the exporter see ``F.sign`` as a
# single node named 'Sign', which is the key used in ``external_converters``
# below -- confirm against the onnx_chainer ``as_funcnode`` documentation.
F.sign = as_funcnode('Sign')(F.sign)


def convert_sign(param):
    """Build the ONNX ``Sign`` node for an exported function call.

    Args:
        param: Converter argument supplied by the exporter; only its
            ``input_names`` and ``output_names`` attributes are read.

    Returns:
        A one-element tuple containing the created ONNX node.
    """
    sign_node = onnx.helper.make_node(
        'Sign', param.input_names, param.output_names)
    # The result is deliberately a tuple, not a bare node.
    return (sign_node,)


# Map the custom node name 'Sign' to the converter that emits its ONNX node.
external_converters = {'Sign': convert_sign}

# Instantiate the chain and export it with a single scalar sample input.
# NOTE(review): 'sign' is presumably the output directory/name for the
# generated test case -- confirm against onnx_chainer.export_testcase docs.
model = Sign()
onnx_chainer.export_testcase(
    model,
    [np.array(3.14)],
    'sign',
    external_converters=external_converters,
)