def test_flatten_inputs_to_1d_tensor(self):
    """Compare the numpy, tf and torch flatten utilities on `self.struct`.

    Each variant is called with `spaces_struct=self.spaces` and its output
    is checked against the same expected 2D array (one row per batch item).
    """
    # B=3; no time axis.
    expected_rows = [
        [
            0.0, 1.0, 0.0, 0.0, 1.0, 2.0, 3.0, 8.0, 7.0, 6.0, 0.0,
            1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0,
        ],
        [
            0.0, 0.0, 0.0, 1.0, 4.0, 5.0, 6.0, 5.0, 4.0, 3.0, 0.0,
            0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0,
        ],
        [
            0.0, 0.0, 1.0, 0.0, 7.0, 8.0, 9.0, 2.0, 1.0, 0.0, 1.0,
            0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0,
        ],
    ]

    # numpy: the fixture is passed through unchanged.
    flat_np = flatten_np(self.struct, spaces_struct=self.spaces)
    check(flat_np, np.array(expected_rows))

    # tf: convert every leaf of the fixture to a tf tensor first.
    tf_inputs = tree.map_structure(tf.convert_to_tensor, self.struct)
    flat_tf = flatten_tf(tf_inputs, spaces_struct=self.spaces)
    check(flat_tf, np.array(expected_rows))

    # torch: convert every leaf of the fixture to a torch tensor first.
    torch_inputs = tree.map_structure(torch.from_numpy, self.struct)
    flat_torch = flatten_torch(torch_inputs, spaces_struct=self.spaces)
    check(flat_torch, np.array(expected_rows))
def test_flatten_inputs_to_1d_tensor_w_time_axis(self):
    """Compare the flatten utilities on `self.struct_w_time_axis`.

    Each variant (numpy, tf, torch) is called with
    `spaces_struct=self.spaces` and `time_axis=True`, and its output is
    checked against the same expected nested array.
    """
    # B=2; T=1
    expected_rows = [
        [[
            0.0, 1.0, 0.0, 0.0, 1.0, 2.0, 3.0, 8.0, 7.0, 6.0, 0.0,
            1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0,
        ]],
        [[
            0.0, 0.0, 0.0, 1.0, 4.0, 5.0, 6.0, 5.0, 4.0, 3.0, 0.0,
            0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0,
        ]],
    ]

    # numpy: the fixture is passed through unchanged.
    flat_np = flatten_np(
        self.struct_w_time_axis,
        spaces_struct=self.spaces,
        time_axis=True,
    )
    check(flat_np, np.array(expected_rows))

    # tf: convert every leaf of the fixture to a tf tensor first.
    tf_inputs = tree.map_structure(
        tf.convert_to_tensor, self.struct_w_time_axis
    )
    flat_tf = flatten_tf(
        tf_inputs,
        spaces_struct=self.spaces,
        time_axis=True,
    )
    check(flat_tf, np.array(expected_rows))

    # torch: convert every leaf of the fixture to a torch tensor first.
    torch_inputs = tree.map_structure(
        torch.from_numpy, self.struct_w_time_axis
    )
    flat_torch = flatten_torch(
        torch_inputs,
        spaces_struct=self.spaces,
        time_axis=True,
    )
    check(flat_torch, np.array(expected_rows))