def test_transform_nest(self):
    """Check ``transform_nest`` on several nests, then a path-aware map.

    Covers, in order: a dotted path into a dict held by a namedtuple
    (``'a.x'``), a dotted path through nested namedtuples (``'b.b.b'``),
    a ``None`` field with ``NestSum()`` as the function, and finally
    ``nest.py_map_structure_with_path`` over an ``OrderedDict``.
    """

    def _add_one(value):
        return value + 1.0

    # Case 1: transform a dict entry reached through a namedtuple field.
    source = NTuple(
        a={'x': torch.zeros(()), 'y': torch.zeros((2, 4))},
        b=torch.zeros((4, )))
    actual = transform_nest(source, field='a.x', func=_add_one)
    # The expectation is built by editing the source's own dict in place
    # *after* the transform has been applied.
    source.a['x'] = torch.ones(())
    nest.map_structure(self.assertEqual, actual, source)

    # Case 2: transform a leaf three namedtuple levels deep.
    source = NTuple(
        a={'x': torch.zeros(()), 'y': torch.zeros((2, 4))},
        b=NTuple(a=torch.zeros((4, )), b=NTuple(a=[1], b=[1])))
    actual = transform_nest(source, field='b.b.b', func=lambda _: [2])
    # Rebuild the expected nest from the inside out with ``_replace``.
    innermost = source.b.b._replace(b=[2])
    middle = source.b._replace(b=innermost)
    expected = source._replace(b=middle)
    nest.map_structure(self.assertEqual, actual, expected)

    # Case 3: ``field=None`` with ``NestSum()``; the test expects 1 + 2 = 3.
    pair = NTuple(a=1, b=2)
    total = transform_nest(pair, None, NestSum())
    self.assertEqual(total, 3)

    # Case 4: the callback receives (path, element) and the path is used
    # directly as a key into the original OrderedDict.
    nested = collections.OrderedDict([("a", 12), ("b", 13)])

    def _assert_value_at(path, element):
        self.assertEqual(nested[path], element)

    mapped = nest.py_map_structure_with_path(_assert_value_at, nested)
    nest.assert_same_structure(nested, mapped)
def test_encoding_network_preprocessing_combiner(self):
    """Check ``EncodingNetwork`` with ``NestSum(average=True)`` as combiner.

    The nested input spec mixes shapes ``(3, 80, 80)``, ``(80, 80)`` and
    ``()``. The test asserts that the network's processed input spec is
    ``TensorSpec((3, 80, 80))`` and that the output for all-zero inputs
    equals a zero tensor with ``40 * 40`` elements.
    """
    nested_spec = {
        'a': TensorSpec((3, 80, 80)),
        'b': [TensorSpec((80, 80)), TensorSpec(())],
    }
    zero_inputs = common.zero_tensor_from_nested_spec(
        nested_spec, batch_size=1)

    combiner = NestSum(average=True)
    encoder = EncodingNetwork(
        input_tensor_spec=nested_spec,
        preprocessing_combiner=combiner,
        conv_layer_params=((1, 2, 2, 0), ))

    expected_spec = TensorSpec((3, 80, 80))
    self.assertEqual(encoder._processed_input_tensor_spec, expected_spec)

    # The network returns a pair; only the first element is checked here.
    encoded, _ = encoder(zero_inputs)
    expected_output = torch.zeros((40 * 40, ))
    self.assertTensorEqual(encoded, expected_output)
def test_q_value_network(self, lstm_hidden_size):
    """Check the output shape of the Q-value network built by ``self._init``.

    Args:
        lstm_hidden_size: forwarded unchanged to ``self._init``, which
            returns the network constructor and the initial state.
    """
    image_spec = [TensorSpec((3, 20, 20), torch.float32)]
    conv_params = ((8, 3, 1), (16, 3, 2, 1))
    zero_image = common.zero_tensor_from_nested_spec(
        image_spec, batch_size=1)

    make_network, initial_state = self._init(lstm_hidden_size)
    network = make_network(
        image_spec,
        self._action_spec,
        input_preprocessors=[torch.relu],
        preprocessing_combiner=NestSum(),
        conv_layer_params=conv_params)

    # The returned state is not inspected by this test.
    q_value, _ = network(zero_image, initial_state)

    # (batch_size, num_actions)
    expected_shape = (1, self._num_actions)
    self.assertEqual(q_value.shape, expected_shape)
def test_nest_sum_specs(self):
    """Check ``NestSum()`` applied to a nest of ``TensorSpec`` objects.

    The leaves have shapes ``()``, ``(2, 4)`` and ``(4,)``; the test
    expects the result to equal ``TensorSpec((2, 4))``.
    """
    inner_specs = {'x': TensorSpec(()), 'y': TensorSpec((2, 4))}
    nested_specs = NTuple(a=inner_specs, b=TensorSpec((4, )))

    combiner = NestSum()
    combined_spec = combiner(nested_specs)

    # broadcasting
    self.assertEqual(combined_spec, TensorSpec((2, 4)))
def test_nest_sum_tensors(self):
    """Check ``NestSum()`` applied to a nest of zero tensors.

    The leaves have shapes ``()``, ``(2, 4)`` and ``(4,)``; the test
    expects the result to equal a ``(2, 4)`` tensor of zeros.
    """
    inner_tensors = {'x': torch.zeros(()), 'y': torch.zeros((2, 4))}
    nested_tensors = NTuple(a=inner_tensors, b=torch.zeros((4, )))

    combiner = NestSum()
    combined = combiner(nested_tensors)

    # broadcasting
    self.assertTensorEqual(combined, torch.zeros((2, 4)))