def test_combine2():
    """Check that combining four unnamed dims along axis 1 yields shape [1, 17000, 2]."""
    # Sizes of axis 1 for each input; axes 0 and 2 are 1 and 2 for every input.
    axis1_sizes = (12800, 3200, 800, 200)
    parts = tuple(Dim.unnamed((1, size, 2)) for size in axis1_sizes)

    res = Dim.combine(parts, 1)

    assert res.shape == [1, 17000, 2]
def test_combine1():
    """Check combining two named dims along 'c' and that the inputs stay untouched."""
    # Keyword order (a, c, b) is kept because the dims are built with named_ordered.
    spec = dict(a=1, c=3, b=2)
    dim1 = Dim.named_ordered(**spec)
    dim2 = Dim.named_ordered(**spec)

    combined = Dim.combine((dim1, dim2), 'c')
    assert combined.shape == [1, 6, 2]

    # Changing the combined dim must not alter either source dim.
    combined.c = 4
    assert dim1.c == 3
    assert dim2.c == 3
def get_output_size(self, in_dims):
    """Return the output dimension obtained by combining ``in_dims``.

    When ``self.transpose_in`` is set, every input is cloned and, where the
    matching ``transpose_in`` entry is not None, transposed with it before
    combining. When the first input is named and an axis hint is set,
    ``self._axis`` is updated from that hint (a side effect on ``self``).

    Args:
        in_dims: Indexable collection of input dims.

    Returns:
        A single-element list holding the combined dim.
    """
    if self.transpose_in:
        prepared = []
        for idx, in_dim in enumerate(in_dims):
            order = self.transpose_in[idx]
            duplicate = in_dim.clone()
            prepared.append(duplicate if order is None else duplicate.transpose(order))
        in_dims = prepared

    first = in_dims[0]
    if first.is_named and self._axis_hint:
        # Resolve the axis hint against the first input's ordering.
        self._axis = first.get_order_idx(self._axis_hint)

    out_dim = Dim.combine(list(in_dims), self.axis)

    if self.transpose_out:
        # The return value is discarded, so this presumably transposes in place.
        # NOTE(review): unlike transpose_in entries, transpose_out[0] is not
        # checked for None before use - confirm that is intended.
        out_dim.transpose(self.transpose_out[0])
    return [out_dim]
def get_output_size(self, in_dims):
    """Return the output dimension obtained by combining ``in_dims``.

    When the first input is named and an axis hint is set, ``self._axis`` is
    updated from that hint (a side effect on ``self``) before combining along
    ``self.axis``.

    Args:
        in_dims: Indexable collection of input dims.

    Returns:
        A single-element list holding the combined dim.
    """
    first = in_dims[0]
    if first.is_named and self._axis_hint:
        # Resolve the axis hint against the first input's ordering.
        self._axis = first.get_order_idx(self._axis_hint)

    combined = Dim.combine(list(in_dims), self.axis)
    return [combined]