Esempio n. 1
0
 def test_forward_multi_input(self, multi_input_dnn):
     """Run one CPU forward pass of the multi-input network and check its output shape.

     Two random inputs sharing a batch dimension of 10 are fed as a list.
     The network is expected to return a tensor of shape (10, 50).
     """
     master_device_setter(multi_input_dnn, 'cpu')
     batch_size = 10
     # Inputs are created in the same order as before (3-D tensor first, then 4-D).
     first_input = torch.rand(size=[batch_size, 1, 28])
     second_input = torch.rand(size=[batch_size, 1, 28, 28])
     prediction = multi_input_dnn([first_input, second_input])
     assert prediction.shape == (batch_size, 50)
Esempio n. 2
0
    def test_fail_mixed_devices(self, multi_input_cnn, conv3D_net,
                                multi_input_dnn, conv1D_net,
                                multi_input_dnn_data, multi_input_cnn_data):
        """Test training throws ValueError when network has mixed devices.

        The whole multi-input CNN is first moved to ``cuda:0`` and trained
        for one epoch on a consistent placement. A single sub-network
        (``conv3D_net``) is then moved back to CPU. The next ``fit`` call
        must raise ``ValueError`` whose message ends with the offending
        sub-network and its device.

        Note: ``multi_input_dnn_data`` is requested as a fixture but is not
        used in the body.
        """
        # Every network involved must expose a ``device`` attribute.
        assert hasattr(conv1D_net, 'device')
        assert hasattr(conv3D_net, 'device')
        assert hasattr(multi_input_dnn, 'device')
        assert hasattr(multi_input_cnn, 'device')

        master_device_setter(multi_input_cnn, 'cuda:0')
        # Sanity-check that the sub-network fixtures are the ones registered
        # in the CNN's ``input_networks``.
        assert conv3D_net == multi_input_cnn.input_networks['conv3D_net']
        assert multi_input_dnn == multi_input_cnn.input_networks[
            'multi_input_dnn']

        # First half of the dataset for training, second half for validation.
        data_len = len(multi_input_cnn_data)
        train_loader = DataLoader(
            Subset(multi_input_cnn_data, range(data_len // 2)))
        valid_loader = DataLoader(
            Subset(multi_input_cnn_data, range(data_len // 2, data_len)))

        # With all networks on the same device, training should succeed.
        multi_input_cnn.fit(train_loader=train_loader,
                            val_loader=valid_loader,
                            epochs=1,
                            plot=False)

        # Put one sub-network on a different device. This is done outside
        # ``pytest.raises`` so that a ValueError from the assignment itself
        # cannot satisfy the expectation; only ``fit`` is under test.
        multi_input_cnn.input_networks['conv3D_net'].device = 'cpu'
        with pytest.raises(ValueError) as e_info:
            multi_input_cnn.fit(train_loader=train_loader,
                                val_loader=valid_loader,
                                epochs=1,
                                plot=False)

        assert str(e_info.value).endswith("{'conv3D_net': device(type='cpu')}")
Esempio n. 3
0
def test_master_device_setter(multi_input_cnn):
    """If CUDA is available, test master_device_setter usage.

    Moves ``multi_input_cnn`` to ``cuda:0``. It then checks that three
    networks report ``cuda:0`` as their device:

    - the top-level network;
    - its first input network;
    - the first input network of its third input network.
    """
    # Make sure the network is in cpu first
    assert str(multi_input_cnn.device) == 'cpu'
    master_device_setter(multi_input_cnn, device='cuda:0')
    assert str(multi_input_cnn.device) == 'cuda:0'
    # Compare each sub-network's ``device`` (not the network object itself).
    # Apply ``str`` to the device, not to the comparison result:
    # ``str(a == b)`` is 'True' or 'False', both truthy, so such an assert
    # could never fail.
    input_nets = list(multi_input_cnn.input_networks.values())
    assert str(input_nets[0].device) == 'cuda:0'
    # NOTE(review): index 2 assumes the fixture registers at least three
    # input networks and that the third one has its own ``input_networks``;
    # confirm against the fixture definition.
    nested_nets = list(input_nets[2].input_networks.values())
    assert str(nested_nets[0].device) == 'cuda:0'
Esempio n. 4
0
 def test_master_net_device_set_to_cuda(self, multi_input_cnn):
     """Check that master_device_setter moves every nested network to cuda:0.

     The checks cover:

     - the top-level CNN;
     - its ``conv3D_net`` and ``multi_input_dnn`` inputs;
     - the ``conv1D_net`` and ``conv2D_net`` inputs nested inside
       ``multi_input_dnn``.
     """
     assert hasattr(multi_input_cnn, 'device')
     master_device_setter(multi_input_cnn, 'cuda:0')
     cuda0 = torch.device(type='cuda', index=0)
     assert multi_input_cnn.device == cuda0
     top_inputs = multi_input_cnn.input_networks
     assert top_inputs['conv3D_net'].device == cuda0
     assert top_inputs['multi_input_dnn'].device == cuda0
     # Nested lookup happens only after the outer checks pass, as before.
     nested_inputs = top_inputs['multi_input_dnn'].input_networks
     assert nested_inputs['conv1D_net'].device == cuda0
     assert nested_inputs['conv2D_net'].device == cuda0