def test_concatenate(concatenate_variables):
    """Check concat_along_axis and its gradient against NumPy.

    The fixture supplies ng tensors, the matching NumPy arrays, and the index
    of the axis to join along. The forward result must equal
    ``np.concatenate``. With an all-ones upstream error, the derivative with
    respect to the first input must be all ones.
    """
    tensors, arrays, axis_pos = concatenate_variables
    first = tensors[0]

    with ExecutorFactory() as ex:
        joined = ng.concat_along_axis(tensors, first.axes[axis_pos])
        # Seed the backward pass with ones shaped like the concatenated result.
        upstream = ng.constant(np.ones(joined.axes.lengths), axes=joined.axes)
        grad = ng.deriv(joined, first, error=upstream)

        run = ex.executor([joined, grad])
        joined_val, grad_val = run()

        expected_joined = np.concatenate(arrays, axis=axis_pos)
        ng.testing.assert_allclose(joined_val.copy(), expected_joined)
        ng.testing.assert_allclose(grad_val.copy(), np.ones(first.axes.lengths))
Пример #2
0
def test_concatenate():
    """Joining a -1 tensor and a +1 tensor along axis A matches np.concatenate."""
    with ExecutorFactory() as ex:
        A = ng.make_axis(name='A', length=3)
        B = ng.make_axis(name='B', length=4)
        shape = (A.length, B.length)

        # Opposite signs make the two halves of the result distinguishable.
        neg_vals = -np.ones(shape)
        pos_vals = np.ones(shape)
        neg_t = ng.persistent_tensor([A, B], initial_value=neg_vals).named('x0')
        pos_t = ng.persistent_tensor([A, B], initial_value=pos_vals).named('x1')

        # A is the leading axis, so it corresponds to NumPy axis 0.
        expected = np.concatenate([neg_vals, pos_vals], axis=0)
        joined = ng.concat_along_axis([neg_t, pos_t], A)

        run = ex.executor(joined)
        actual = run()
        ng.testing.assert_allclose(actual, expected)
Пример #3
0
    def __call__(self, in_obj):
        """Apply the four parallel branches to ``in_obj`` and concatenate them.

        ``self.branch_1`` is applied directly to the input. ``self.branch_2``,
        ``self.branch_3`` and ``self.branch_4`` are indexed as two-stage
        pipelines: stage ``[0]`` takes the input and stage ``[1]`` takes the
        result of stage ``[0]``. The branches are evaluated in that order, and
        their outputs are joined along the channel axis of the first branch's
        output.
        """
        outputs = [self.branch_1(in_obj)]
        for pipeline in (self.branch_2, self.branch_3, self.branch_4):
            intermediate = pipeline[0](in_obj)
            outputs.append(pipeline[1](intermediate))

        # This does the equivalent of neon's merge-broadcast
        return ng.concat_along_axis(outputs, outputs[0].axes.channel_axis())
Пример #4
0
    def __call__(self, in_obj, merge_axis=None):
        """Run every branch in ``self.branches`` on ``in_obj`` and combine them.

        Args:
            in_obj: input passed unchanged to each branch.
            merge_axis: axis to concatenate along when ``self.mode`` is
                ``'concat'``. It may also be given as a ``str``, which is turned
                into an axis via ``ng.make_axis(name=...)``. When omitted, the
                channel axis of the first branch's output is used.

        Returns:
            The concatenated tensor when ``self.mode == 'concat'``. Otherwise
            (``self.mode`` is ``None`` or any other value) the list of branch
            outputs, in branch order.
        """
        results = [branch(in_obj) for branch in self.branches]
        if isinstance(merge_axis, str):
            merge_axis = ng.make_axis(name=merge_axis)

        if self.mode != 'concat':
            # NOTE(review): modes other than 'concat' and None also end up here
            # silently; consider rejecting unknown modes explicitly.
            return results

        # Concatenate along the given axis
        if merge_axis is None:
            merge_axis = results[0].axes.channel_axis()
        return ng.concat_along_axis(results, merge_axis)
def test_concat_different_axis_lengths():
    """Axes sharing a name can be concatenated even when their lengths differ.

    Two "concat" axes, of length 3 and 2, are joined along the first one. The
    result is compared with ``np.concatenate`` along axis 0. Each input is
    filled with distinct values (x non-negative, y negative), so the
    comparison also catches swapped or misplaced data. With all-zero inputs
    the test only checked shape.
    """
    ax1 = ng.make_axis(length=3, name="concat")
    ax2 = ng.make_axis(length=2, name="concat")
    ax3 = ng.make_axis(length=10, name="other")

    x = ng.placeholder(axes=[ax1, ax3])
    y = ng.placeholder(axes=[ax2, ax3])

    x_shape = tuple(x.axes.lengths)
    y_shape = tuple(y.axes.lengths)
    # Non-overlapping value ranges: x holds 0, 1, 2, ... and y holds -1, -2, ...
    np_x = np.arange(np.prod(x_shape), dtype=float).reshape(x_shape)
    np_y = -1.0 - np.arange(np.prod(y_shape), dtype=float).reshape(y_shape)

    # ax1 and ax2 have same name, so this should work
    v = ng.concat_along_axis([x, y], ax1)
    with ExecutorFactory() as ex:
        f = ex.executor(v, x, y)
        e_v = f(np_x, np_y)
        np_v = np.concatenate([np_x, np_y], axis=0)
        ng.testing.assert_allclose(e_v.copy(), np_v)