def convert_local_response_normalization(onnx_graph, node):
    """Emit an ONNX ``LRN`` node for a Chainer local response normalization.

    The Chainer keyword attributes ``n``, ``k``, ``alpha`` and ``beta`` are
    mapped onto the ONNX ``size``, ``bias``, ``alpha`` and ``beta``
    attributes. ``bias``, ``alpha`` and ``beta`` are coerced to float.
    """
    keywords = node.attribute_args.keywords

    size = oc.try_get_attribute(keywords['n'])
    bias = float(oc.try_get_attribute(keywords['k']))
    # The ONNX LRN operator scales alpha by 1/size, while Chainer applies
    # alpha directly; pre-multiplying by size makes the two formulas agree.
    alpha = float(oc.try_get_attribute(keywords['alpha']) * size)
    beta = float(oc.try_get_attribute(keywords['beta']))

    onnx_graph.add_node(
        "LRN",
        [node.inputs[0]],
        node.outputs,
        str(node.lineprop),
        size=size,
        bias=bias,
        alpha=alpha,
        beta=beta,
    )
def convert_unpooling_2d(onnx_graph, node: 'nodes.NodeCall'):
    """Emit an ONNX ``Upsample`` node for ``chainer.functions.unpooling_2d``.

    Only the configuration ``stride=None``, ``pad=0``, ``outsize=None`` and
    ``cover_all=False`` is supported. In that case the operation is a plain
    upsampling of the two spatial axes by ``ksize``, expressed as a constant
    ``scales`` tensor ``[1, 1, kh, kw]``.

    Raises:
        AssertionError: if an unsupported argument combination is used.
    """
    import collections.abc

    def _pair(x):
        # Expand a scalar into an (h, w) pair; pass iterables through as-is.
        # collections.abc is used because the collections.Iterable alias was
        # removed in Python 3.10.
        if isinstance(x, collections.abc.Iterable):
            return x
        return (x, x)

    ksize = oc.try_get_attribute(node.attribute_args.keywords['ksize'])
    stride = oc.try_get_attribute(node.attribute_args.keywords['stride'])
    pad = oc.try_get_attribute(node.attribute_args.keywords['pad'])
    outsize = oc.try_get_attribute(node.attribute_args.keywords['outsize'])
    cover_all = oc.try_get_attribute(node.attribute_args.keywords['cover_all'])

    assert (stride is None)  # TODO(hamaji): Not supported yet.
    assert (pad == 0)  # TODO(hamaji): Not supported yet.
    assert (outsize is None)  # TODO(hamaji): Not supported yet.
    assert (cover_all is False)  # TODO(hamaji): Not supported yet.

    # Batch and channel axes are left unscaled; spatial axes scale by ksize.
    scales = np.array([1, 1] + list(_pair(ksize)), dtype=np.float32)
    scales_ = oc.ONNXValue(onnx_graph,
                           scales, [node, '/Scale'],
                           is_constant=True)
    onnx_graph.add_node("Upsample", [node.inputs[0], scales_],
                        [node.outputs[0]],
                        name=str(node.lineprop))
def convert_concat(onnx_graph, node):
    """Emit a ``ChainerSequenceConcat`` node for ``chainer.functions.concat``.

    The ``xs`` argument is turned into an ONNX sequence value, and the
    constant ``axis`` attribute is forwarded unchanged.
    """
    inputs_value = oc.ONNXValue(onnx_graph, node.args.keywords['xs'])
    concat_axis = oc.try_get_attribute(node.attribute_args.keywords['axis'])

    # Build the sequence before the concat node so it can be consumed by it.
    sequence = inputs_value.create_sequence()
    onnx_graph.add_node("ChainerSequenceConcat", [sequence], node.outputs,
                        str(node.lineprop), axis=concat_axis)
def convert_dropout(onnx_graph, node):
    """Emit an ONNX ``Dropout`` node for ``chainer.functions.dropout``.

    The ``ratio`` keyword is resolved through ``oc.try_get_attribute`` and
    passed through as the node's ``ratio`` attribute.
    """
    x_value = oc.ONNXValue(onnx_graph, node.args.keywords['x'])
    drop_ratio = oc.try_get_attribute(node.attribute_args.keywords['ratio'])

    onnx_graph.add_node("Dropout", [x_value], node.outputs,
                        str(node.lineprop), ratio=drop_ratio)
def convert_roi_average_pooling_2d(onnx_graph, node):
    """Emit a ``ChainerROIAveragePool2D`` node for ROI average pooling.

    The tensors ``x``, ``rois`` and ``roi_indices`` become node inputs.
    ``outsize`` (a scalar or a pair) is expanded to an ``output_shape`` pair,
    and ``spatial_scale`` is passed as an attribute.
    """
    import collections.abc

    x = oc.ONNXValue(onnx_graph, node.args.keywords['x'])
    rois = oc.ONNXValue(onnx_graph, node.args.keywords['rois'])
    roi_indices = oc.ONNXValue(onnx_graph, node.args.keywords['roi_indices'])
    outsize = oc.ONNXValue(onnx_graph, node.args.keywords['outsize'])
    spatial_scale = oc.ONNXValue(onnx_graph,
                                 node.args.keywords['spatial_scale'])

    def _pair(x):
        # Expand a scalar into an (h, w) pair; pass iterables through as-is.
        # collections.Iterable was removed in Python 3.10, so use the abc one.
        if isinstance(x, collections.abc.Iterable):
            return x
        return (x, x)

    onnx_graph.add_node(
        "ChainerROIAveragePool2D", [
            x.create_tensor(node.lineprop),
            rois.create_tensor(node.lineprop),
            roi_indices.create_tensor(node.lineprop)
        ],
        node.outputs,
        str(node.lineprop),
        output_shape=_pair(oc.try_get_attribute(outsize.value)),
        spatial_scale=oc.try_get_attribute(spatial_scale.value))
def convert_max_pooling_2d(onnx_graph, node):
    """Emit an ONNX ``MaxPool`` node for ``chainer.functions.max_pooling_2d``.

    ``ksize``, ``stride`` and ``pad`` may each be a scalar or a pair.
    ``stride`` defaults to ``ksize`` when it is ``None``, and ``pad``
    defaults to 0. ``cover_all`` is forwarded through the custom
    ``chainer_cover_all`` attribute.

    Raises:
        AssertionError: if ``return_indices`` is requested (not implemented).
    """
    import collections.abc

    def _pair(x):
        # Expand a scalar into an (h, w) pair; pass iterables through as-is.
        # collections.Iterable was removed in Python 3.10, so use the abc one.
        if isinstance(x, collections.abc.Iterable):
            return x
        return (x, x)

    ksize = oc.try_get_attribute(node.attribute_args.keywords['ksize'])
    stride = oc.try_get_attribute(node.attribute_args.keywords['stride'])
    pad = oc.try_get_attribute(node.attribute_args.keywords['pad'])
    cover_all = oc.try_get_attribute(node.attribute_args.keywords['cover_all'])
    return_indices = oc.try_get_attribute(
        node.attribute_args.keywords['return_indices'])

    assert not return_indices  # TODO(hamaji): Not implemented yet.

    kwargs = {}
    kwargs['kernel_shape'] = _pair(ksize)

    # When stride is omitted, the pooling window size is used as the stride.
    if stride is not None:
        kwargs['strides'] = _pair(stride)
    else:
        kwargs['strides'] = _pair(ksize)

    # ONNX pads list the begin values and then the end values for every
    # spatial axis (4 entries for 2-D). Padding is symmetric, so the (h, w)
    # pair is repeated. The default must also have 4 entries; the previous
    # `_pair(0)` produced only 2.
    if pad is not None:
        kwargs['pads'] = _pair(pad) * 2
    else:
        kwargs['pads'] = _pair(0) * 2

    onnx_graph.add_node(
        "MaxPool",
        [node.inputs[0]],
        [node.outputs[0]],
        name=str(node.lineprop),
        chainer_cover_all=cover_all,
        **kwargs,
    )
def convert_softmax_cross_entropy(onnx_graph, node):
    """Emit a ``ChainerSoftmaxCrossEntropy`` node.

    Only one configuration is accepted: ``normalize`` and ``cache_score``
    are truthy, ``class_weight`` is None, ``ignore_label`` is -1,
    ``reduce`` is ``'mean'`` and ``enable_double_backprop`` is falsy. Any
    other configuration fails an assertion. The first two inputs (scores
    and labels) feed the node.
    """
    keywords = node.attribute_args.keywords
    attr_names = ('normalize', 'cache_score', 'class_weight', 'ignore_label',
                  'reduce', 'enable_double_backprop')
    attrs = {name: oc.try_get_attribute(keywords[name]) for name in attr_names}

    # TODO(hamaji): None of the non-default settings are supported yet.
    assert attrs['normalize']
    assert attrs['cache_score']
    assert attrs['class_weight'] is None
    assert attrs['ignore_label'] == -1
    assert attrs['reduce'] == 'mean'
    assert not attrs['enable_double_backprop']

    loss_inputs = node.inputs[:2]
    onnx_graph.add_node("ChainerSoftmaxCrossEntropy", loss_inputs,
                        node.outputs, str(node.lineprop))
def convert_softmax(onnx_graph, node):
    """Emit a ``Softmax`` node for ``chainer.functions.softmax``.

    The axis comes from the call's second input. The node is tagged with
    ``chainer_is_onnx_semantics=False``, presumably so that the backend
    applies Chainer's single-axis softmax rather than the ONNX
    flatten-then-softmax definition (TODO confirm against the backend).
    """
    softmax_axis = oc.try_get_attribute(node.inputs[1])
    onnx_graph.add_node(
        "Softmax",
        [node.inputs[0]],
        [node.outputs[0]],
        str(node.lineprop),
        axis=softmax_axis,
        chainer_is_onnx_semantics=False,
    )
def convert_onnx_chainer_linear(onnx_graph: 'ONNXGraph',
                                node: 'nodes.NodeCall'):
    """Lower a call to a ``chainer.links.Linear`` instance into ONNX nodes.

    When ``n_batch_axes`` is not 1, a single custom ``ChainerLinear`` node
    is emitted with inputs ``[x, W]`` (plus ``b`` when the link has a bias)
    and an ``n_batch_axes`` attribute.

    Otherwise, ``x`` is reshaped to ``[x.shape[0], -1]`` using the chain
    Shape -> Gather(0) -> Unsqueeze -> Concat(with [-1]) -> Reshape. The
    result is then multiplied by the transposed weight. With a bias this is
    ``Gemm(transB=1)``; without one it is ``Transpose(perm=[1, 0])``
    followed by ``MatMul``.

    Args:
        onnx_graph: Graph that the emitted nodes are added to.
        node: Call node. ``node.func.owner.inst`` is the Linear link, and
            ``node.outputs[0]`` receives the result.
    """
    chainer_inst = node.func.owner.inst  # type: chainer.links.Linear
    # Name prefix for the constants and temporaries created below.
    onnx_name = oc.node2onnx_parameter[node].onnx_name

    x = oc.ONNXValue(onnx_graph, node.args.get_value('x'))
    axes = oc.try_get_attribute(node.args.get_value('n_batch_axes'), node)
    o = oc.ONNXValue(onnx_graph, node.outputs[0])

    # This only prints a warning; conversion continues with an uninitialized W.
    # NOTE(review): presumably this happens for a Linear whose input size was
    # left to be inferred and which was never run forward — confirm.
    if chainer_inst.W.data is None:
        print("W is unknown. Please infer this model.")

    w = oc.ONNXValue(onnx_graph, chainer_inst.W)

    if axes != 1:
        # General case: delegate to the custom ChainerLinear op, which takes
        # n_batch_axes as an attribute.
        inputs = [x, w]

        if chainer_inst.b is not None:
            b = oc.ONNXValue(onnx_graph, chainer_inst.b)
            inputs.append(b)

        onnx_graph.add_node('ChainerLinear',
                            inputs, [o],
                            str(node.lineprop),
                            n_batch_axes=axes)
        return

    # n_batch_axes == 1: compute the runtime shape of x so that every axis
    # after the first can be flattened into one.
    (x_shape, ) = onnx_graph.add_node('Shape', [x], [None], str(node.lineprop))

    # Scalar x.shape[0] (the batch dimension).
    (batch_size_1, ) = onnx_graph.add_node('Gather', [
        x_shape,
        oc.ONNXValue(onnx_graph, np.array(0, dtype=np.int64),
                     [onnx_name, '/Zero'])
    ], [None], str(node.lineprop))

    # Turn the scalar into a 1-element vector so it can be concatenated.
    (batch_size_2, ) = onnx_graph.add_node('Unsqueeze', [batch_size_1], [None],
                                           str(node.lineprop),
                                           axes=[0])

    # Target shape [batch, -1].
    (mat_shape, ) = onnx_graph.add_node('Concat', [
        batch_size_2,
        oc.ONNXValue(onnx_graph, np.array([-1], dtype=np.int64),
                     [onnx_name, '/Minus1'])
    ], [None],
                                        str(node.lineprop),
                                        axis=0)

    (x_reshape, ) = onnx_graph.add_node('Reshape', [x, mat_shape], [None],
                                        str(node.lineprop))

    if chainer_inst.b is not None:
        b = oc.ONNXValue(onnx_graph, chainer_inst.b)

        # o = x_reshape @ W^T + b  (transB=1 transposes W).
        onnx_graph.add_node('Gemm', [x_reshape, w, b], [o],
                            str(node.lineprop),
                            transA=0,
                            transB=1)
    else:
        # No bias: transpose W explicitly into a temporary, then MatMul.
        # NOTE(review): np.float32 here appears to select the temporary's
        # dtype — confirm against oc.ONNXValue.
        temp = oc.ONNXValue(onnx_graph, np.float32, [onnx_name, '/Temp'])

        onnx_graph.add_node('Transpose', [w], [temp],
                            str(node.lineprop),
                            perm=[1, 0])

        onnx_graph.add_node('MatMul', [x_reshape, temp], [o],
                            str(node.lineprop))