Ejemplo n.º 1
0
    def test_multi_gpus(self):
        """
        make sure we don't count the models replicated on different GPUs
        """
        nb_cuda_devices = torch.cuda.device_count()
        if nb_cuda_devices < 2:
            # we do not have enough GPUs, abot the test
            warnings.warn(f'This test can\'t be run. Requires CUDA devices=2, got={nb_cuda_devices}', ResourceWarning)
            return

        model = nn.Sequential(
            nn.Linear(10, 100),
            nn.Linear(100, 5)
        )

        device = torch.device('cuda:0')
        model = trw.train.DataParallelExtended(model).to(device)
        batch = torch.zeros([11, 10], device=device)

        summary, total_output_size, total_params_size, total_params, trainable_params = model_summary_base(model, batch)

        # 2 denses + 2 biases
        expected_trainables = 10 * 100 + 100 * 5 + 100 + 5
        assert expected_trainables == trainable_params
        assert total_params == trainable_params
Ejemplo n.º 2
0
    def test_rnn_model(self):
        """
        Summarize a ``ModelRNN(28, 256, 1, 10)`` fed with a batch of 32 images of 1x28x28.

        Checks the rank of the input/output shapes recorded in the first summary
        entry, and that the trainable-parameter count equals the expected value and
        the total-parameter count.
        """
        model = ModelRNN(28, 256, 1, 10)
        batch = {
            'images': torch.zeros([32, 1, 28, 28])
        }

        summary, _, _, total_params, trainable_params = model_summary_base(model, batch)

        # the first summary entry must record rank-3 input and output shapes
        # (presumably ModelRNN reshapes the 4D images to (batch, 28, 28) before its
        # recurrent layer — TODO confirm against ModelRNN)
        first_entry = list(summary.values())[0]
        assert len(first_entry['input_shape']) == 3
        assert len(first_entry['output_shape']) == 3

        # NOTE(review): 295434 == 4 * (28 * 256 + 256 * 256 + 2 * 256) + 256 * 10 + 10,
        # i.e. a single-layer LSTM (28 -> 256) followed by a 256 -> 10 linear head;
        # this assumes ModelRNN is built that way — confirm against its definition.
        expected_trainables = 295434
        assert expected_trainables == trainable_params
        # the test expects every parameter of this model to be trainable
        assert total_params == trainable_params
Ejemplo n.º 3
0
    def test_simple_model_sequential(self):
        """
        Count the parameters of a plain two-layer ``nn.Sequential`` (10 -> 100 -> 5).

        The trainable count must match both dense layers (weights and biases), and
        it must also equal the total count.
        """
        layers = [nn.Linear(10, 100), nn.Linear(100, 5)]
        model = nn.Sequential(*layers)
        inputs = torch.zeros([11, 10])

        _, _, _, n_total, n_trainable = model_summary_base(model, inputs)

        # weights of both dense layers, then their biases
        weights = 10 * 100 + 100 * 5
        biases = 100 + 5
        assert n_trainable == weights + biases
        assert n_trainable == n_total
Ejemplo n.º 4
0
    def test_simple_model_nested_2heads(self):
        """
        Count the parameters of ``ModelDense_2_inputs_head`` fed with an ``'input'`` dict batch.

        The expected count, 2 * (10 * 100 + 100), matches two 10 -> 100 dense layers
        with their biases. This assumes that is the model's layout; confirm against
        its definition. The trainable and total counts must also agree.
        """
        model = ModelDense_2_inputs_head()
        batch = dict(input=torch.zeros([11, 10]))

        _, _, _, n_total, n_trainable = model_summary_base(model, batch)

        nb_heads = 2
        # per head: a 10x100 weight matrix plus a 100-element bias
        per_head = 10 * 100 + 100
        assert n_trainable == nb_heads * per_head
        assert n_trainable == n_total
Ejemplo n.º 5
0
    def test_simple_model_internal_sequential(self):
        """
        Count the parameters of ``ModelDenseSequential`` fed with an ``'input'`` dict batch.

        The expected count corresponds to dense layers 10 -> 100 -> 5 with biases.
        Presumably these live in a sequential container inside the model — TODO
        confirm against its definition. The trainable and total counts must agree.
        """
        model = ModelDenseSequential()
        batch = dict(input=torch.zeros([11, 10]))

        outputs = model_summary_base(model, batch)
        _, _, _, n_total, n_trainable = outputs

        weights = 10 * 100 + 100 * 5
        biases = 100 + 5
        assert n_trainable == weights + biases
        assert n_trainable == n_total