def test_make_parallel(self, output_spec):
    """Compare ``make_parallel()`` with ``NaiveParallelNetwork``.

    Builds a conv + fc ``EncodingNetwork``, parallelizes it both through
    ``make_parallel()`` and through ``NaiveParallelNetwork``, times 1000
    forward passes of each, and checks output shapes, output specs, the
    parameter count of the parallel network and the resulting network names.

    Args:
        output_spec: passed as ``output_tensor_spec`` to the base network.
            When ``None``, each replica's output is expected to have shape
            ``(1, )``; otherwise it is expected to have ``output_spec.shape``.
    """
    batch_size = 128
    input_spec = TensorSpec((1, 10, 10), torch.float32)
    conv_layer_params = ((2, 3, 2), (5, 3, 1))
    fc_layer_params = (256, 256)
    network = EncodingNetwork(
        input_tensor_spec=input_spec,
        output_tensor_spec=output_spec,
        conv_layer_params=conv_layer_params,
        fc_layer_params=fc_layer_params,
        activation=torch.relu_,
        last_layer_size=1,
        last_activation=math_ops.identity,
        name='base_encoding_network')
    replicas = 2
    # conv layers + fc layers + the last layer; the parameter-count check
    # below expects two parameter tensors per layer (presumably weight and
    # bias).
    num_layers = len(conv_layer_params) + len(fc_layer_params) + 1

    if output_spec is None:
        expected_shape = (1, )
    else:
        expected_shape = tuple(output_spec.shape)

    def _benchmark(pnet, name):
        # Time repeated forward passes, then check the shape of the last
        # output and of the parallel network's output spec.
        t0 = time.perf_counter()
        outputs = []
        for _ in range(1000):
            embedding = input_spec.randn(outer_dims=(batch_size, ))
            output, _ = pnet(embedding)
            outputs.append(output)
        # Materialize the reduction *before* reading the clock so that any
        # asynchronously queued computation is included in the timing.
        total = float(math_ops.add_n(outputs).sum())
        elapsed = time.perf_counter() - t0
        logging.info("%s time=%s %s", name, elapsed, total)
        self.assertEqual(output.shape,
                         (batch_size, replicas, *expected_shape))
        self.assertEqual(pnet.output_spec.shape, (replicas, *expected_shape))

    pnet = network.make_parallel(replicas)
    self.assertIsInstance(pnet, ParallelEncodingNetwork)
    self.assertEqual(len(list(pnet.parameters())), num_layers * 2)
    _benchmark(pnet, "ParallelEncodingNetwork")
    self.assertEqual(pnet.name, "parallel_" + network.name)

    pnet = alf.networks.network.NaiveParallelNetwork(network, replicas)
    _benchmark(pnet, "NaiveParallelNetwork")
    # test on default network name
    self.assertEqual(pnet.name, "naive_parallel_" + network.name)

    # test on user-defined network name
    pnet = alf.networks.network.NaiveParallelNetwork(
        network, replicas, name="pnet")
    self.assertEqual(pnet.name, "pnet")
def test_parallel_network_output_size(self, replicas):
    """Check that both parallel wrappers produce consistent output shapes.

    The same dummy ``EncodingNetwork`` is parallelized with
    ``make_parallel()`` and with ``NaiveParallelNetwork``. For both shared
    inputs (one batch fed to every replica) and non-shared inputs (a
    separate slice per replica), the two outputs must have equal shapes, and
    the non-batch part of the shape must equal the parallel network's
    output spec.

    Args:
        replicas: number of replicas to create.
    """
    batch_size = 128
    input_spec = TensorSpec((100, ), torch.float32)
    # a dummy encoding network which outputs the input
    network = EncodingNetwork(input_tensor_spec=input_spec)
    pnet = network.make_parallel(replicas)
    nnet = alf.networks.network.NaiveParallelNetwork(network, replicas)

    def _check_output_size(embedding):
        p_output, _ = pnet(embedding)
        n_output, _ = nnet(embedding)
        # assertEqual reports both shapes on failure, unlike
        # assertTrue(a == b).
        self.assertEqual(p_output.shape, n_output.shape)
        # Use the public ``output_spec`` property rather than the private
        # ``_output_spec`` attribute.
        self.assertEqual(p_output.shape[1:], pnet.output_spec.shape)

    # the case with shared inputs
    embedding = input_spec.randn(outer_dims=(batch_size, ))
    _check_output_size(embedding)

    # the case with non-shared inputs
    embedding = input_spec.randn(outer_dims=(batch_size, replicas))
    _check_output_size(embedding)
def test_make_parallel_warning_on_using_naive_parallel(self):
    """Check that ``make_parallel()`` logs a warning when it falls back to
    ``NaiveParallelNetwork``.

    An ``EncodingNetwork`` configured with ``input_preprocessors`` is
    parallelized; the captured log output must contain the fallback message
    and the returned network must be a ``NaiveParallelNetwork``.
    """
    input_spec = TensorSpec((256, ))
    fc_layer_params = (32, 32)
    pre_encoding_net = EncodingNetwork(
        input_tensor_spec=input_spec, fc_layer_params=fc_layer_params)
    network = EncodingNetwork(
        input_tensor_spec=input_spec,
        fc_layer_params=fc_layer_params,
        input_preprocessors=pre_encoding_net)
    replicas = 2
    # Create a parallel network via the ``make_parallel()`` interface.
    # As now ``input_preprocessors`` is not supported in
    # ``ParallelEncodingNetwork``, ``make_parallel()`` will return an
    # instance of the ``NaiveParallelNetwork``
    expected_warning_message = ("``NaiveParallelNetwork`` is used by "
                                "``make_parallel()`` !")
    with self.assertLogs() as ctx:
        pnet = network.make_parallel(replicas)
    self.assertIsInstance(pnet, alf.networks.network.NaiveParallelNetwork)
    # Search every captured record and compare against the formatted message
    # (``getMessage()``) instead of ``str(record)``: the check then neither
    # depends on the warning being the first record logged nor on how the
    # message arguments are interpolated.
    messages = [record.getMessage() for record in ctx.records]
    self.assertTrue(
        any(expected_warning_message in message for message in messages),
        "expected warning not found in captured logs: %s" % messages)