def test_concatenate(concatenate_variables):
    """Concatenate fixture tensors and check both the value and the gradient.

    ``concatenate_variables`` unpacks to (ngraph tensors, matching numpy
    arrays, position of the axis to join along). The derivative of the
    concatenation with respect to the first tensor, seeded with an all-ones
    error, is checked against an all-ones array in that tensor's shape.
    """
    tensors, arrays, axis_pos = concatenate_variables
    with ExecutorFactory() as ex:
        join_axis = tensors[0].axes[axis_pos]
        joined = ng.concat_along_axis(tensors, join_axis)
        ones_error = ng.constant(np.ones(joined.axes.lengths), axes=joined.axes)
        grad = ng.deriv(joined, tensors[0], error=ones_error)
        computation = ex.executor([joined, grad])
        got_joined, got_grad = computation()

        expected_joined = np.concatenate(arrays, axis=axis_pos)
        ng.testing.assert_allclose(got_joined.copy(), expected_joined)
        # Only the slice belonging to tensors[0] flows back, so the gradient is all ones.
        expected_grad = np.ones(tensors[0].axes.lengths)
        ng.testing.assert_allclose(got_grad.copy(), expected_grad)
def test_concatenate():
    """Join a -1 tensor and a +1 tensor of shape (3, 4) along axis A.

    NOTE(review): if this shares a module with the fixture-based
    ``test_concatenate``, this later definition shadows it and only one of
    the two is collected by pytest -- confirm and rename if so.
    """
    with ExecutorFactory() as ex:
        ax_a = ng.make_axis(name='A', length=3)
        ax_b = ng.make_axis(name='B', length=4)
        shape = (ax_a.length, ax_b.length)
        neg_np = -np.ones(shape)
        pos_np = np.ones(shape)

        neg_ng = ng.persistent_tensor([ax_a, ax_b], initial_value=neg_np).named('x0')
        pos_ng = ng.persistent_tensor([ax_a, ax_b], initial_value=pos_np).named('x1')
        joined = ng.concat_along_axis([neg_ng, pos_ng], ax_a)

        # Axis A is the first axis, so the numpy reference joins along axis 0.
        expected = np.concatenate([neg_np, pos_np], axis=0)
        run = ex.executor(joined)
        ng.testing.assert_allclose(run(), expected)
def __call__(self, in_obj):
    """Feed ``in_obj`` through all four branches and join the results.

    ``branch_1`` is a single callable. ``branch_2`` through ``branch_4`` are
    each indexed as a two-stage sequence, applied as ``[1](...)`` after
    ``[0](in_obj)``. The outputs are concatenated along the channel axis of
    the first branch's output.
    """
    outputs = [self.branch_1(in_obj)]
    for stages in (self.branch_2, self.branch_3, self.branch_4):
        hidden = stages[0](in_obj)
        outputs.append(stages[1](hidden))

    # This does the equivalent of neon's merge-broadcast
    merge_axis = outputs[0].axes.channel_axis()
    return ng.concat_along_axis(outputs, merge_axis)
def __call__(self, in_obj, merge_axis=None):
    """Apply every branch to ``in_obj`` and optionally concatenate.

    Args:
        in_obj: input passed unchanged to each callable in ``self.branches``.
        merge_axis: axis to concatenate along when ``self.mode == 'concat'``.
            A string is turned into an axis via ``ng.make_axis(name=...)``.
            If omitted, the channel axis of the first branch output is used.

    Returns:
        The concatenated tensor in 'concat' mode, otherwise the list of
        per-branch outputs.
    """
    results = [layer(in_obj) for layer in self.branches]
    if isinstance(merge_axis, str):
        merge_axis = ng.make_axis(name=merge_axis)

    if self.mode != 'concat':
        # mode None returns the raw outputs.
        # NOTE(review): any other unrecognised mode also falls through here
        # silently; consider raising instead.
        return results

    if merge_axis is None:
        merge_axis = results[0].axes.channel_axis()
    return ng.concat_along_axis(results, merge_axis)
def test_concat_different_axis_lengths():
    """Concatenate along axes that share a name but differ in length (3 vs 2)."""
    concat_a = ng.make_axis(length=3, name="concat")
    concat_b = ng.make_axis(length=2, name="concat")
    shared = ng.make_axis(length=10, name="other")

    x_ph = ng.placeholder(axes=[concat_a, shared])
    y_ph = ng.placeholder(axes=[concat_b, shared])
    zeros_x = np.zeros(x_ph.axes.lengths)
    zeros_y = np.zeros(y_ph.axes.lengths)

    # ax1 and ax2 have same name, so this should work
    joined = ng.concat_along_axis([x_ph, y_ph], concat_a)

    with ExecutorFactory() as ex:
        run = ex.executor(joined, x_ph, y_ph)
        result = run(zeros_x, zeros_y)
        expected = np.concatenate([zeros_x, zeros_y], axis=0)
        ng.testing.assert_allclose(result.copy(), expected)