Example #1
0
    def test_transform_nest(self):
        """Test ``transform_nest`` and ``nest.py_map_structure_with_path``.

        ``transform_nest`` is exercised three ways:

        1. ``field='a.x'``: ``func`` is applied to a tensor stored in a dict
           inside an ``NTuple``; all other leaves must compare equal to the
           input's leaves.
        2. ``field='b.b.b'``: ``func`` replaces a list leaf reached through
           nested ``NTuple``s.
        3. ``field=None``: the whole nest is handed to ``func`` (here
           ``NestSum``) and ``func``'s result is returned as-is.

        Finally, ``nest.py_map_structure_with_path`` must call the callback
        with the path of every leaf and return a nest of the same structure.
        """
        # 1) Dict leaf inside an NTuple. The expectation is built by
        # overwriting ``a.x`` in the input itself, then every leaf of the
        # transformed nest is compared against it.
        ntuple = NTuple(
            a=dict(x=torch.zeros(()), y=torch.zeros((2, 4))),
            b=torch.zeros((4, )))
        transformed_ntuple = transform_nest(
            ntuple, field='a.x', func=lambda x: x + 1.0)
        ntuple.a.update({'x': torch.ones(())})
        nest.map_structure(self.assertEqual, transformed_ntuple, ntuple)

        # 2) Leaf reached through nested NTuples; the expectation is built
        # with ``_replace`` at each level.
        ntuple = NTuple(
            a=dict(x=torch.zeros(()), y=torch.zeros((2, 4))),
            b=NTuple(a=torch.zeros((4, )), b=NTuple(a=[1], b=[1])))
        transformed_ntuple = transform_nest(
            ntuple, field='b.b.b', func=lambda _: [2])
        ntuple = ntuple._replace(
            b=ntuple.b._replace(b=ntuple.b.b._replace(b=[2])))
        nest.map_structure(self.assertEqual, transformed_ntuple, ntuple)

        # 3) ``field=None``: func receives the entire nest, so the result
        # is NestSum's output (1 + 2) rather than an NTuple.
        ntuple = NTuple(a=1, b=2)
        transformed_ntuple = transform_nest(ntuple, None, NestSum())
        self.assertEqual(transformed_ntuple, 3)

        # 4) py_map_structure_with_path. The callback records each path and
        # returns its leaf. Without this, the assertions inside the callback
        # could silently never run and the test would still pass, because
        # assert_same_structure only compares the structure.
        tuples = [("a", 12), ("b", 13)]
        nested = collections.OrderedDict(tuples)
        visited_paths = []

        def _check_path(path, e):
            visited_paths.append(path)
            self.assertEqual(nested[path], e)
            return e

        res = nest.py_map_structure_with_path(_check_path, nested)
        nest.assert_same_structure(nested, res)
        # Each top-level key must have been visited exactly once; sorting
        # keeps the check independent of traversal order.
        self.assertEqual(sorted(visited_paths), ["a", "b"])
        # The callback returned every leaf unchanged, so the mapped nest
        # must equal the input.
        self.assertEqual(res, nested)
Example #2
0
    def test_encoding_network_preprocessing_combiner(self):
        """EncodingNetwork merges a nested input with NestSum(average=True).

        The input nest holds a (3, 80, 80) spec, an (80, 80) spec and a
        scalar spec. The processed input spec must be (3, 80, 80), and
        all-zero inputs must yield an output equal to
        ``torch.zeros((40 * 40, ))``.
        """
        nested_spec = dict(
            a=TensorSpec((3, 80, 80)),
            b=[TensorSpec((80, 80)), TensorSpec(())],
        )
        zero_inputs = common.zero_tensor_from_nested_spec(
            nested_spec, batch_size=1)

        # NOTE(review): (1, 2, 2, 0) presumably means 1 filter, kernel 2,
        # stride 2, padding 0, which fits the 80 -> 40 spatial size the
        # final assertion expects. Confirm against EncodingNetwork docs.
        encoder = EncodingNetwork(
            input_tensor_spec=nested_spec,
            preprocessing_combiner=NestSum(average=True),
            conv_layer_params=((1, 2, 2, 0), ))

        expected_spec = TensorSpec((3, 80, 80))
        self.assertEqual(encoder._processed_input_tensor_spec, expected_spec)

        encoded, _ = encoder(zero_inputs)
        self.assertTensorEqual(encoded, torch.zeros((40 * 40, )))
Example #3
0
    def test_q_value_network(self, lstm_hidden_size):
        """A Q network over one (3, 20, 20) input yields a single row of Q-values.

        ``lstm_hidden_size`` is passed to ``self._init``, which returns the
        network constructor and the initial state fed to the network.
        """
        conv_params = ((8, 3, 1), (16, 3, 2, 1))
        observation_spec = [TensorSpec((3, 20, 20), torch.float32)]
        zero_obs = common.zero_tensor_from_nested_spec(
            observation_spec, batch_size=1)

        make_network, state = self._init(lstm_hidden_size)
        q_network = make_network(
            observation_spec,
            self._action_spec,
            input_preprocessors=[torch.relu],
            preprocessing_combiner=NestSum(),
            conv_layer_params=conv_params)

        q_values, state = q_network(zero_obs, state)

        # Expected layout: (batch_size, num_actions).
        expected_shape = (1, self._num_actions)
        self.assertEqual(q_values.shape, expected_shape)
Example #4
0
 def test_nest_sum_specs(self):
     """NestSum over a nest of TensorSpecs returns the broadcast spec.

     The leaf shapes (), (2, 4) and (4,) broadcast to (2, 4).
     """
     spec_nest = NTuple(
         a=dict(x=TensorSpec(()), y=TensorSpec((2, 4))),
         b=TensorSpec((4, )),
     )
     combiner = NestSum()
     summed_spec = combiner(spec_nest)
     self.assertEqual(summed_spec, TensorSpec((2, 4)))
Example #5
0
 def test_nest_sum_tensors(self):
     """NestSum over a nest of zero tensors returns a broadcast zero tensor.

     The leaf shapes (), (2, 4) and (4,) broadcast to (2, 4).
     """
     tensor_nest = NTuple(
         a=dict(x=torch.zeros(()), y=torch.zeros((2, 4))),
         b=torch.zeros((4, )),
     )
     combiner = NestSum()
     summed = combiner(tensor_nest)
     self.assertTensorEqual(summed, torch.zeros((2, 4)))