예제 #1
0
def _split_tensorwise(graph: Graph, op: Operator, v: Variable,
                      v_pair: Sequence[Variable], axis: Axis):
    """Replace ``op`` by two copies, each working on one half of ``v``.

    ``v_pair`` holds the two pieces of ``v`` along ``axis`` (their sizes along
    ``axis`` are ``s1`` and ``s2``). ``op`` is detached from all of its
    variables and two copies, ``op_0`` and ``op_1``, are wired up instead:

    * an input equal to ``v`` is replaced by ``v_pair[0]`` / ``v_pair[1]``;
    * any other input that has ``axis`` with size != 1 is cut with
      ``SplitAxis`` at offset ``s1``;
    * any other input without ``axis``, or broadcast along it (size 1), is
      shared by both copies unchanged;
    * an output equal to ``v`` is replaced by ``v_pair[0]`` / ``v_pair[1]``;
    * any other output is replaced by two new half-sized variables whose
      ``Concat`` takes the original output's place in ``graph``.

    Args:
        graph: graph that contains ``op``; passed to
            ``OptimizeRule.replace_variable``.
        op: operator to split. It is disconnected via ``op.remove_all()``.
        v: the variable that has been split.
        v_pair: the two pieces of ``v`` along ``axis``.
        axis: axis along which ``v`` was split.

    Raises:
        UnexpectedAndPleaseReportError: if an output other than ``v`` does
            not have ``axis``.
    """
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]
    xs = dict(op.inputs)
    ys = dict(op.outputs)
    op.remove_all()

    op_0 = op.copy()
    op_1 = op.copy()

    for key, x in xs.items():
        if x == v:
            x_0, x_1 = v_pair

        elif axis in x.order.axes and x.shape_dict[axis] != 1:
            # Cut ``x`` at the same offset as ``v`` so each copy receives the
            # matching slice.
            x_0, x_1 = SplitAxis(None, axis=axis, sections=[s1])(x)

        else:
            # ``x`` has no ``axis`` or is broadcast along it (size 1). A size-1
            # axis cannot be cut into two non-empty parts, so no split happens
            # and both copies share ``x``.
            x_0 = x_1 = x

        op_0.append_input(key, x_0)
        op_1.append_input(key, x_1)

    for key, y in ys.items():
        if y == v:
            y_0, y_1 = v_pair

        elif axis in y.order.axes:
            # TODO (Kiikurage)
            # Attributes attached to "y" are copied to neither "y_0" nor "y_1".
            y_0 = Variable([
                s1 if a == axis else y.shape_dict[a] for a in y.order.axes
            ], y.order)
            y_1 = Variable([
                s2 if a == axis else y.shape_dict[a] for a in y.order.axes
            ], y.order)
            # Swap ``y`` for the concatenation of the two halves in ``graph``
            # (presumably so former consumers of ``y`` read ``y_new``).
            y_new, = Concat(None, axis=axis)(y_0, y_1)
            OptimizeRule.replace_variable(graph, y, y_new)

        else:
            raise UnexpectedAndPleaseReportError

        op_0.append_output(key, y_0)
        op_1.append_output(key, y_1)
예제 #2
0
def _split_tensorwise(graph: Graph, op: Operator, v: Variable,
                      v_pair: Sequence[Variable], axis: Axis):
    """Split ``op`` into two operators that each handle one half of ``v``.

    The original operator is disconnected from its variables and copied
    twice. Inputs are handed out to the copies as follows:

    * ``v`` itself is replaced by its pieces ``v_pair``;
    * other inputs having ``axis`` with a size other than 1 are cut with
      ``SplitAxis`` at the size of ``v_pair[0]`` along ``axis``;
    * the remaining inputs (no ``axis``, or size 1 along it) are shared.

    Both copies are then executed and their outputs are wired back into
    ``graph`` with ``OptimizeRule.replace_variable``. Outputs equal to ``v`` are
    matched against ``v_pair``; all other outputs are rebuilt by concatenating
    the two copies' results along ``axis``.

    Args:
        graph: graph containing ``op``.
        op: operator to split; detached via ``op.remove_all()``.
        v: the variable that has been split.
        v_pair: the two pieces of ``v`` along ``axis``.
        axis: axis along which ``v`` was split.
    """
    head_size = v_pair[0].shape_dict[axis]
    original_inputs = dict(op.inputs)
    original_outputs = dict(op.outputs)
    op.remove_all()

    head_op = op.copy()
    tail_op = op.copy()

    for name, x in original_inputs.items():
        if x == v:
            head_x, tail_x = v_pair
        elif axis in x.order.axes and x.shape_dict[axis] != 1:
            # Cut x at the same offset where v was cut.
            head_x, tail_x = SplitAxis(None, axis=axis, sections=[head_size])(x)
        else:
            # Broadcasting: x lacks ``axis`` or has size 1 along it.
            head_x = tail_x = x

        head_op.append_input(name, head_x)
        tail_op.append_input(name, tail_x)

    # exec() is expected to materialise the outputs read below.
    head_op.exec()
    tail_op.exec()

    for name, y in original_outputs.items():
        if y == v:
            OptimizeRule.replace_variable(
                graph, head_op.outputs[name].transpose_like(v_pair[0]), v_pair[0])
            OptimizeRule.replace_variable(
                graph, tail_op.outputs[name].transpose_like(v_pair[1]), v_pair[1])
        else:
            head_y = head_op.outputs[name]
            tail_y = tail_op.outputs[name]
            merged, = Concat(None, axis=axis)(head_y, tail_y)
            OptimizeRule.replace_variable(graph, merged.transpose_like(y), y)