Example #1
0
    def test_make_parallel(self, output_spec):
        """Check ``make_parallel`` against ``NaiveParallelNetwork``.

        Builds a base ``EncodingNetwork``, replicates it, and verifies the
        parallel network's class, parameter count, output shapes, and naming
        conventions for both the fast and the naive parallel variants.
        """
        batch_size = 128
        input_spec = TensorSpec((1, 10, 10), torch.float32)

        conv_layer_params = ((2, 3, 2), (5, 3, 1))
        fc_layer_params = (256, 256)
        network = EncodingNetwork(input_tensor_spec=input_spec,
                                  output_tensor_spec=output_spec,
                                  conv_layer_params=conv_layer_params,
                                  fc_layer_params=fc_layer_params,
                                  activation=torch.relu_,
                                  last_layer_size=1,
                                  last_activation=math_ops.identity,
                                  name='base_encoding_network')
        replicas = 2
        # conv layers + fc layers + the final (last) layer
        num_layers = len(conv_layer_params) + len(fc_layer_params) + 1

        def _benchmark(pnet, name):
            # Time 1000 forward passes and also validate output shapes.
            start = time.time()
            outputs = []
            for _ in range(1000):
                embedding = input_spec.randn(outer_dims=(batch_size, ))
                output, _ = pnet(embedding)
                outputs.append(output)
            o = math_ops.add_n(outputs).sum()
            logging.info("%s time=%s %s" % (name, time.time() - start, float(o)))

            # When no output spec is given, the trailing dim comes from
            # ``last_layer_size=1``.
            if output_spec is None:
                tail = (1, )
            else:
                tail = tuple(output_spec.shape)
            self.assertEqual(output.shape, (batch_size, replicas) + tail)
            self.assertEqual(pnet.output_spec.shape, (replicas, ) + tail)

        pnet = network.make_parallel(replicas)
        self.assertIsInstance(pnet, ParallelEncodingNetwork)
        # each layer contributes a weight and a bias
        self.assertEqual(len(list(pnet.parameters())), num_layers * 2)
        _benchmark(pnet, "ParallelEncodingNetwork")
        self.assertEqual(pnet.name, "parallel_" + network.name)

        pnet = alf.networks.network.NaiveParallelNetwork(network, replicas)
        _benchmark(pnet, "NaiveParallelNetwork")

        # default name is derived from the wrapped network's name
        self.assertEqual(pnet.name, "naive_parallel_" + network.name)

        # a user-supplied name overrides the default
        pnet = alf.networks.network.NaiveParallelNetwork(network,
                                                         replicas,
                                                         name="pnet")
        self.assertEqual(pnet.name, "pnet")
Example #2
0
    def test_parallel_network_output_size(self, replicas):
        """Verify that the fast parallel network and the naive one agree on
        output shapes, for both shared and per-replica (non-shared) inputs.
        """
        batch_size = 128
        input_spec = TensorSpec((100, ), torch.float32)

        # a dummy encoding network that simply passes its input through
        network = EncodingNetwork(input_tensor_spec=input_spec)

        pnet = network.make_parallel(replicas)
        nnet = alf.networks.network.NaiveParallelNetwork(network, replicas)

        def _check_output_size(embedding):
            p_out, _ = pnet(embedding)
            n_out, _ = nnet(embedding)
            self.assertTrue(p_out.shape == n_out.shape)
            # NOTE(review): reaches into the private ``_output_spec`` attr
            self.assertTrue(p_out.shape[1:] == pnet._output_spec.shape)

        # shared input: a single embedding broadcast to every replica
        _check_output_size(input_spec.randn(outer_dims=(batch_size, )))

        # non-shared input: a separate embedding per replica
        _check_output_size(input_spec.randn(outer_dims=(batch_size, replicas)))
Example #3
0
    def test_make_parallel_warning_on_using_naive_parallel(self):
        """``make_parallel()`` should fall back to ``NaiveParallelNetwork``
        and log a warning when the base network uses a feature
        (``input_preprocessors``) that ``ParallelEncodingNetwork`` does not
        support.
        """
        input_spec = TensorSpec((256, ))
        fc_layer_params = (32, 32)

        pre_encoding_net = EncodingNetwork(input_tensor_spec=input_spec,
                                           fc_layer_params=fc_layer_params)

        network = EncodingNetwork(input_tensor_spec=input_spec,
                                  fc_layer_params=fc_layer_params,
                                  input_preprocessors=pre_encoding_net)

        replicas = 2

        # Create a parallel network via the ``make_parallel()`` interface.
        # As now ``input_preprocessors`` is not supported in
        # ``ParallelEncodingNetwork``, ``make_parallel()`` will return an
        # instance of the ``NaiveParallelNetwork``
        expected_warning_message = ("``NaiveParallelNetwork`` is used by "
                                    "``make_parallel()`` !")
        with self.assertLogs() as ctx:
            pnet = network.make_parallel(replicas)
        # Check the fallback actually happened, not just the warning.
        self.assertIsInstance(pnet, alf.networks.network.NaiveParallelNetwork)
        # Use ``getMessage()`` for the fully formatted text — ``str(record)``
        # returns the LogRecord repr with the unformatted msg. Use a unittest
        # assertion instead of a bare ``assert``, which is stripped under
        # ``python -O`` and gives no diagnostic on failure.
        self.assertIn(expected_warning_message, ctx.records[0].getMessage())