示例#1
0
 def load_data(self):
     """Create the MNIST dataset object and its train/test data loaders.

     The dataset is built from ``self.subset``, ``self.fine``,
     ``self.permutation``, ``self.train_indices`` and ``self.test_indices``.
     Both loaders use the same ``self.batch_size`` and ``self.num_workers``.
     """
     print("=> loading data")
     self.data = MNISTData(self.subset, self.fine, self.permutation,
                           self.train_indices, self.test_indices)
     dataset = self.data
     self.train_loader = dataset.get_train_loader(self.batch_size,
                                                  self.num_workers)
     self.test_loader = dataset.get_test_loader(self.batch_size,
                                                self.num_workers)
def run(**kwargs):
    """Load MNIST, then build, train and test the classifier network.

    The network is built twice via ``setup_network``: once in ``'train'``
    mode (trained for 10001 iterations) and once in ``'inference'`` mode
    (tested). If plotting is enabled, the predictions for 8 test samples are
    passed to ``plot_samples``.

    Args:
        **kwargs: Optional overrides of the module-level defaults. Recognized
            keys are ``allow_plots`` (``ALLOW_PLOTS``), ``continue_training``
            (``CONTINUE_TRAINING``), ``inference_cpu_only``
            (``INFERENCE_CPU_ONLY``) and ``use_batchnorm``
            (``USE_BATCHNORM``). Unknown keys are logged as a warning and
            otherwise ignored.
    """
    options = {
        'allow_plots': ALLOW_PLOTS,
        'continue_training': CONTINUE_TRAINING,
        'inference_cpu_only': INFERENCE_CPU_ONLY,
        'use_batchnorm': USE_BATCHNORM,
    }
    for key, value in kwargs.items():
        if key in options:
            options[key] = value
        else:
            # `Logger.warn` is a deprecated alias; use `warning` with lazy
            # %-style arguments.
            logger.warning('Keyword \'%s\' is unknown.', key)

    allow_plots = options['allow_plots']
    continue_training = options['continue_training']
    inference_cpu_only = options['inference_cpu_only']
    use_batchnorm = options['use_batchnorm']

    logger.info('### Loading dataset ...')

    data = MNISTData(config.dataset_path, use_one_hot=True)

    # Important! Let the network know, which dataset to use.
    shared.data = data

    logger.info('### Loading dataset ... Done')

    logger.info('### Build, train and test network ...')

    train_net = setup_network(allow_plots,
                              continue_training,
                              inference_cpu_only,
                              use_batchnorm,
                              mode='train')

    train_net.train(num_iter=10001)

    test_net = setup_network(allow_plots,
                             continue_training,
                             inference_cpu_only,
                             use_batchnorm,
                             mode='inference')
    test_net.test()

    if allow_plots:
        # Example test samples: inputs, true outputs and network predictions.
        sample_batch = data.next_test_batch(8)
        predictions = test_net.run(sample_batch[0])
        shared.data.plot_samples('Example MNIST Predictions',
                                 sample_batch[0],
                                 outputs=sample_batch[1],
                                 predictions=predictions,
                                 interactive=True)

    logger.info('### Build, train and test network ... Done')
def run(**kwargs):
    """Load MNIST, then build, train and test the ``SimpleAE`` network.

    A ``SimpleAE`` is built and trained in ``'train'`` mode, and a second one
    is built and tested in ``'inference'`` mode. If plotting is enabled, the
    next test sample and the network's output for it are shown side by side.

    Args:
        **kwargs: Optional overrides. The only recognized key is
            ``allow_plots`` (default ``ALLOW_PLOTS``). Unknown keys are logged
            as a warning and otherwise ignored.
    """
    allow_plots = ALLOW_PLOTS

    for key, value in kwargs.items():
        if key == 'allow_plots':
            allow_plots = value
        else:
            # `Logger.warn` is a deprecated alias; use `warning` with lazy
            # %-style arguments.
            logger.warning('Keyword \'%s\' is unknown.', key)

    logger.info('### Loading dataset ...')

    data = MNISTData(config.dataset_path)

    # Important! Let the network know, which dataset to use.
    shared.data = data

    logger.info('### Loading dataset ... Done')

    logger.info('### Build, train and test network ...')

    # Train the network.
    train_net = SimpleAE(mode='train')
    train_net.allow_plots = allow_plots
    train_net.build()
    train_net.train()

    # Test the network.
    test_net = SimpleAE(mode='inference')
    test_net.allow_plots = allow_plots
    test_net.build()
    test_net.test()

    if allow_plots:
        # Feed one test sample through the network and display input and
        # output next to each other for the user.
        sample = data.next_test_batch(1)
        net_out = test_net.run(sample[0])

        fig = plt.figure()
        plt.ion()
        plt.suptitle('Sample Image')
        panels = (('Input', sample[0]), ('Output', net_out))
        for position, (title, image) in enumerate(panels, start=1):
            ax = fig.add_subplot(1, 2, position)
            ax.set_axis_off()
            # Both panels share a fixed colour scale of [-1, 1].
            ax.imshow(np.squeeze(image.reshape(data.in_shape)),
                      vmin=-1.0, vmax=1.0)
            ax.set_title(title)
        plt.show()

    logger.info('### Build, train and test network ... Done')
示例#4
0
def run(**kwargs):
    """Load MNIST, then build, train and test the GAN network.

    The network is built twice via ``setup_network``: once in ``'train'``
    mode (trained for 50000 iterations) and once in ``'inference'`` mode
    (tested). If plotting is enabled, 8 latent samples are drawn and the
    generated images are plotted with ``dplt.plot_gan_images``.

    Args:
        **kwargs: Optional overrides of the module-level defaults. Recognized
            keys are ``allow_plots`` (``ALLOW_PLOTS``), ``continue_training``
            (``CONTINUE_TRAINING``), ``inference_cpu_only``
            (``INFERENCE_CPU_ONLY``) and ``use_biases`` (``USE_BIASES``).
            Unknown keys are logged as a warning and otherwise ignored.
    """
    options = {
        'allow_plots': ALLOW_PLOTS,
        'continue_training': CONTINUE_TRAINING,
        'inference_cpu_only': INFERENCE_CPU_ONLY,
        'use_biases': USE_BIASES,
    }
    for key, value in kwargs.items():
        if key in options:
            options[key] = value
        else:
            # `Logger.warn` is a deprecated alias; use `warning` with lazy
            # %-style arguments.
            logger.warning('Keyword \'%s\' is unknown.', key)

    allow_plots = options['allow_plots']
    continue_training = options['continue_training']
    inference_cpu_only = options['inference_cpu_only']
    use_biases = options['use_biases']

    logger.info('### Loading dataset ...')

    data = MNISTData(config.dataset_path)

    # Important! Let the network know, which dataset to use.
    shared.data = data

    logger.info('### Loading dataset ... Done')

    logger.info('### Build, train and test network ...')

    train_net = setup_network(allow_plots,
                              continue_training,
                              inference_cpu_only,
                              use_biases,
                              mode='train')
    train_net.train(num_iter=50000)

    test_net = setup_network(allow_plots,
                             continue_training,
                             inference_cpu_only,
                             use_biases,
                             mode='inference')
    test_net.test()

    if allow_plots:
        # Generate some fake images from random latent inputs. An empty array
        # is passed where real images would go.
        latent_inputs = test_net.sample_latent(8)
        _, fake_dis_outs, fake_imgs = test_net.run(np.empty((0, 0)),
                                                   latent_inputs=latent_inputs)
        dplt.plot_gan_images('Generator Samples',
                             np.empty((0, 0)),
                             fake_imgs,
                             fake_dis_outputs=fake_dis_outs,
                             shuffle=True,
                             interactive=True)

    logger.info('### Build, train and test network ... Done')
示例#5
0
class MNISTAgent(Agent):
    '''
    Agent that trains a classifier on MNIST or Fashion-MNIST.

    For ``fine == 'MNIST'`` a random permutation of the 28 * 28 pixel
    positions is drawn once at construction time and handed to the dataset;
    for any other ``fine`` value no permutation is used.
    '''
    def __init__(self,
                 global_args,
                 subset=tuple(range(10)),
                 fine='MNIST',
                 train_indices=None,
                 test_indices=None):

        super().__init__(global_args, subset, fine, train_indices,
                         test_indices)
        # `self.fine` is read after the base initializer, which presumably
        # stores the `fine` argument -- TODO confirm in `Agent`.
        self.permutation = None
        if self.fine == 'MNIST':
            self.permutation = np.random.permutation(28 * 28)

    def load_data(self):
        """Create the dataset object and its train/test data loaders."""
        print("=> loading data")
        self.data = MNISTData(self.subset, self.fine, self.permutation,
                              self.train_indices, self.test_indices)
        dataset = self.data
        self.train_loader = dataset.get_train_loader(self.batch_size,
                                                     self.num_workers)
        self.test_loader = dataset.get_test_loader(self.batch_size,
                                                   self.num_workers)

    def build_model(self):
        """Build the model, loss and SGD optimizer on ``self.device``.

        ``self.fusion == 'none'`` selects the plain ``MNISTModel``; any other
        value selects ``MNISTwithAttn``. For the ``'multi'`` and ``'single'``
        fusion modes a zero ``shadow`` tensor shaped like the attention
        ``gamma`` parameter is also allocated.
        """
        print("=> building model")
        if self.fusion == 'none':
            network = MNISTModel()
        else:
            network = MNISTwithAttn(self.fusion)
        self.model = network.to(self.device)

        if self.fusion in ('multi', 'single'):
            gamma_shape = self.model.attn.gamma.size()
            self.shadow = torch.zeros(gamma_shape, device=self.device)

        self.criterion = nn.CrossEntropyLoss().to(self.device)
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=self.lr,
                                         momentum=0.9,
                                         weight_decay=5e-4)
示例#6
0
    def input_to_torch_tensor(self,
                              x,
                              device,
                              mode='inference',
                              force_no_preprocessing=False,
                              sample_ids=None):
        """Map the internal numpy input array to a PyTorch tensor.

        Note, this method overrides the base-class implementation. Unless
        ``force_no_preprocessing`` is set, every image in the batch is passed
        through ``self._transform`` (described as applying zero padding and
        pixel permutations). The result is returned flattened to shape
        ``(batch, (28 + 2 * self._padding)**2)``.

        Args:
            (....): See docstring of method
                :meth:`data.dataset.Dataset.input_to_torch_tensor`.

        Returns:
            (torch.Tensor): The given input ``x`` as PyTorch tensor.
        """
        if force_no_preprocessing:
            # No transformation requested: defer to the plain MNIST handling.
            return MNISTData.input_to_torch_tensor(
                self,
                x,
                device,
                mode=mode,
                force_no_preprocessing=force_no_preprocessing,
                sample_ids=sample_ids)

        assert (len(x.shape) == 2)  # batch size plus flattened image.

        from torch import stack

        side = 28 + 2 * self._padding

        # Bring the data into the uint8 HWC layout expected by the
        # ToPILImage transformation.
        images = (x * 255.0).astype('uint8').reshape(-1, 28, 28, 1)

        batch = stack([self._transform(image) for image in images])
        batch = batch.to(device)

        # Transform the tensor back to the flattened numpy-style layout.
        # FIXME This is a horrible solution, but at least we ensure that the
        # user gets a tensor in the same shape as always and does not have to
        # deal with cases.
        return batch.permute(0, 2, 3, 1).contiguous().view(-1, side**2)
    def __init__(self,
                 data_path,
                 use_one_hot=False,
                 validation_size=0,
                 use_torch_augmentation=False):
        """Load Fashion-MNIST and store it in the internal dataset layout.

        Train (60000) and test (10000) images are concatenated into one array,
        scaled by 1/255 and flattened per image. The index ranges
        ``train_inds``/``test_inds`` (and optionally ``val_inds``) select
        the splits.

        Args:
            data_path: Root directory passed to ``FashionMNIST``; the data is
                requested with ``download=True``.
            use_one_hot (bool): Store labels one-hot encoded (10 classes)
                instead of as a single class index.
            validation_size (int): Number of leading training samples moved
                into a validation split. ``0`` (or less) disables the split.
            use_torch_augmentation (bool): If set, PyTorch train/test input
                transforms with random horizontal flips are created.

        Raises:
            ValueError: If ``validation_size`` is not smaller than the number
                of training samples.
        """
        super().__init__()

        fmnist_train = FashionMNIST(data_path, train=True, download=True)
        fmnist_test = FashionMNIST(data_path, train=False, download=True)
        assert np.all(np.equal(fmnist_train.data.shape, [60000, 28, 28]))
        assert np.all(np.equal(fmnist_test.data.shape, [10000, 28, 28]))

        n_train, n_test = 60000, 10000
        train_x = fmnist_train.data.numpy().reshape(n_train, -1)
        test_x = fmnist_test.data.numpy().reshape(n_test, -1)
        train_y = fmnist_train.targets.numpy().reshape(n_train, 1)
        test_y = fmnist_test.targets.numpy().reshape(n_test, 1)

        # Scale images into [0, 1], identical to the default MNIST scale in
        # `data.dataset.mnist_data`.
        images = np.concatenate([train_x, test_x], axis=0) / 255
        labels = np.concatenate([train_y, test_y], axis=0)

        num_train_labels = train_y.size
        train_inds = np.arange(num_train_labels)
        test_inds = np.arange(num_train_labels, num_train_labels + test_y.size)
        val_inds = None

        if validation_size > 0:
            if validation_size >= train_inds.size:
                raise ValueError('Validation set must contain less than %d '
                                 'samples!' % train_inds.size)
            # The validation split is carved off the front of the train set.
            val_inds = np.arange(validation_size)
            train_inds = np.arange(validation_size, train_inds.size)

        # Bring everything into the internal structure of the Dataset class.
        self._data['classification'] = True
        self._data['sequence'] = False
        self._data['num_classes'] = 10
        self._data['is_one_hot'] = use_one_hot
        self._data['in_data'] = images
        self._data['in_shape'] = [28, 28, 1]
        self._data['out_shape'] = [10 if use_one_hot else 1]
        self._data['val_inds'] = val_inds
        self._data['train_inds'] = train_inds
        self._data['test_inds'] = test_inds

        # One-hot conversion happens after the metadata above is stored.
        self._data['out_data'] = (self._to_one_hot(labels) if use_one_hot
                                  else labels)

        # Information specific to this dataset: train and test class names
        # must agree position by position.
        assert np.all([fmnist_train.classes[idx] == name for idx, name in
                       enumerate(fmnist_test.classes)])
        self._data['fmnist'] = dict()
        self._data['fmnist']['classes'] = fmnist_train.classes

        # Initialize PyTorch data augmentation.
        self._augment_inputs = bool(use_torch_augmentation)
        if self._augment_inputs:
            self._train_transform, self._test_transform = \
                MNISTData.torch_input_transforms(use_random_hflips=True)
    # Import the equations module selected by the configuration.
    config._equation_module = importlib.import_module(
        'equations.' + config.equation_module)

    # Make all random processes predictable.
    np.random.seed(config.random_seed)
    random.seed(config.random_seed)
    bseed(config.random_seed)  # See brian docu for explanation.

    # Determine maximum number of threads to use (default: all CPU cores).
    if config.num_threads is None:
        config.num_threads = multiprocessing.cpu_count()

    # Read the chosen dataset.
    logger.info('### Data preparation ...')
    if config.dataset == 'mnist':
        data = MNISTData()
    elif config.dataset == '7segment':
        data = SevenSegmentData()
    else:
        # Quote 'dataset' with escapes: writing ''dataset'' inside adjacent
        # literals concatenates them and silently drops the quotes.
        raise ConfigException('The chosen dataset \'%s\' is unknown. Please '
                              'reconsider the \'dataset\' option of the '
                              'configuration file.' % config.dataset)
    logger.info('### Data preparation ... Done')

    # Assemble the network.
    logger.info('### Building Network ...')
    network = NetworkModel(data)
    logger.info('### Building Network ... Done')

    # Visualize just assembled network.
    if config.plot_network or config.save_network_plot: