Esempio n. 1
0
    def test_flatten_inputs_to_1d_tensor(self):
        """Flatten the batched fixture struct (no time axis) with np, tf and torch.

        Every framework-specific flattener must turn ``self.struct`` (under
        ``self.spaces``) into the same (3, 21) array, which is defined once
        below and reused for all three checks.
        """
        # B=3; no time axis.
        expected = np.array([
            [
                0.0, 1.0, 0.0, 0.0, 1.0, 2.0, 3.0, 8.0, 7.0, 6.0, 0.0, 1.0,
                0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0
            ],
            [
                0.0, 0.0, 0.0, 1.0, 4.0, 5.0, 6.0, 5.0, 4.0, 3.0, 0.0, 0.0,
                0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0
            ],
            [
                0.0, 0.0, 1.0, 0.0, 7.0, 8.0, 9.0, 2.0, 1.0, 0.0, 1.0, 0.0,
                0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0
            ],
        ])

        # NumPy: the fixture struct is flattened as-is.
        np_flat = flatten_np(self.struct, spaces_struct=self.spaces)
        check(np_flat, expected)

        # TensorFlow: convert every leaf to a tensor first.
        tf_inputs = tree.map_structure(tf.convert_to_tensor, self.struct)
        tf_flat = flatten_tf(tf_inputs, spaces_struct=self.spaces)
        check(tf_flat, expected)

        # PyTorch: wrap every numpy leaf with torch.from_numpy first.
        torch_inputs = tree.map_structure(torch.from_numpy, self.struct)
        torch_flat = flatten_torch(torch_inputs, spaces_struct=self.spaces)
        check(torch_flat, expected)
Esempio n. 2
0
    def test_flatten_inputs_to_1d_tensor_w_time_axis(self):
        """Flatten the time-axis fixture struct with np, tf and torch.

        With ``time_axis=True``, every framework-specific flattener must turn
        ``self.struct_w_time_axis`` (under ``self.spaces``) into the same
        (2, 1, 21) array, which is defined once below and reused for all
        three checks.
        """
        # B=2; T=1
        expected = np.array([
            [[
                0.0, 1.0, 0.0, 0.0, 1.0, 2.0, 3.0, 8.0, 7.0, 6.0, 0.0, 1.0,
                0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0
            ]],
            [[
                0.0, 0.0, 0.0, 1.0, 4.0, 5.0, 6.0, 5.0, 4.0, 3.0, 0.0, 0.0,
                0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0
            ]],
        ])

        # NumPy: the fixture struct is flattened as-is.
        np_flat = flatten_np(self.struct_w_time_axis,
                             spaces_struct=self.spaces,
                             time_axis=True)
        check(np_flat, expected)

        # TensorFlow: convert every leaf to a tensor first.
        tf_inputs = tree.map_structure(tf.convert_to_tensor,
                                       self.struct_w_time_axis)
        tf_flat = flatten_tf(tf_inputs,
                             spaces_struct=self.spaces,
                             time_axis=True)
        check(tf_flat, expected)

        # PyTorch: wrap every numpy leaf with torch.from_numpy first.
        torch_inputs = tree.map_structure(torch.from_numpy,
                                          self.struct_w_time_axis)
        torch_flat = flatten_torch(torch_inputs,
                                   spaces_struct=self.spaces,
                                   time_axis=True)
        check(torch_flat, expected)