Example no. 1
def _setup_defaults(dataset, data_path, train_transform, validation_transform,
                    fast):
    """Helper function to setup default path and transforms when creating data loaders.

    If any of :param:`data_path`, :param:`train_transform`, or :param:`validation_transform` is None, it will be
    replaced by a default value.

    If :param:`fast` is set, transforms compatible with :class:`FastLoader` will be used.

    :param dataset: name of the dataset, (MNIST, FashionMNIST, CIFAR-10, CIFAR-100, ImageNette, ImageWoof, ImageNet)
                    are available.
    :type dataset: str
    :param data_path: path to the folder containing the dataset.
    :type data_path: str
    :param train_transform: PyTorch transform to apply to images for training.
    :type train_transform: torchvision.transforms.Compose
    :param validation_transform: PyTorch transform to apply to images for validation.
    :type validation_transform: torchvision.transforms.Compose
    :param fast: whether fast loaders are used.
    :type fast: bool
    :return: path to data, train transform, and test transform.
    :rtype: (str, torchvision.transforms.Compose, torchvision.transforms.Compose)
    """
    # Setup the path to the dataset.
    if data_path is None:
        if dataset in ['ImageNette', 'ImageWoof', 'ImageNet']:
            # For these datasets, we cannot rely on torchvision for automatic downloading.
            # TODO: Implement automatic downloading of ImageNette, Imagewoof, and ImageNet.
            log.log(
                "Auto-download of dataset {0} is not currently supported. "
                "Specify a path containing the 'train' and 'val' folders of the dataset."
                .format(dataset), LOGTAG, log.Level.ERROR)
            raise NotImplementedError(
                "Auto-download of dataset {0} is not currently supported, select a path."
                .format(dataset))
        data_path = dataset  # Default to putting the dataset in a folder named after the dataset in the working directory.

    # Setup the train and validation/test transforms.
    if fast:
        # Use the fast default transforms base instead:
        transforms_base = fast_default_transforms
    else:
        transforms_base = default_transforms
    # Currently, the same train and validation/test transforms are used.
    # If default data augmentation is implemented, it should be on the training set and not on the validation/test one.
    if train_transform is None:
        train_transform = transforms_base[DATASETS[dataset]]
    if validation_transform is None:
        validation_transform = transforms_base[DATASETS[dataset]]

    return data_path, train_transform, validation_transform
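
A short usage sketch of how the defaults get filled in (the call is hypothetical; the function is assumed to be used from this same module, where `default_transforms` and `DATASETS` are defined):

# Hypothetical call: let every argument fall back to its default.
path, train_tf, val_tf = _setup_defaults('CIFAR-10',
                                         data_path=None,
                                         train_transform=None,
                                         validation_transform=None,
                                         fast=False)
# path == 'CIFAR-10' (a folder in the working directory); both transforms come
# from default_transforms[DATASETS['CIFAR-10']].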
Example no. 2
def _get_data(dataset, data_path, transform, test):
    """Helper function to retrieve training/test data associated with a dataset.

    :param dataset: name of the dataset, (MNIST, FashionMNIST, CIFAR-10, CIFAR-100, ImageNette, ImageWoof, ImageNet)
                    are available.
    :type dataset: str
    :param data_path: path to a folder containing the dataset.
    :type data_path: str
    :param transform: PyTorch transform to apply to the data.
    :type transform: torchvision.transforms.Compose
    :param test: if true, return test data instead of training data.
    :type test: bool
    :return: training or test data from the dataset, with the transform applied.
    :rtype: torch.utils.data.Dataset
    """
    if dataset in ['ImageNet', 'ImageNette', 'ImageWoof']:
        data_path = os.path.join(data_path, FOLDER_NAME[dataset])
    if dataset == 'MNIST':
        data = datasets.MNIST(data_path,
                              train=not test,
                              download=True,
                              transform=transform)
    elif dataset == 'FashionMNIST':
        data = datasets.FashionMNIST(data_path,
                                     train=not test,
                                     download=True,
                                     transform=transform)
    elif dataset == 'CIFAR-10':
        data = datasets.CIFAR10(data_path,
                                train=not test,
                                download=True,
                                transform=transform)
    elif dataset == 'CIFAR-100':
        data = datasets.CIFAR100(data_path,
                                 train=not test,
                                 download=True,
                                 transform=transform)
    elif dataset in ['ImageNette', 'ImageWoof', 'ImageNet']:
        # These datasets are not available in torchvision, so we build them from an image folder:
        data_directory = os.path.join(data_path, 'val' if test else 'train')
        data = datasets.ImageFolder(data_directory, transform)
    else:
        log.log(
            "Dataset {0} is not available ! Choose from (MNIST, FashionMNIST, CIFAR-10, CIFAR-100, "
            "ImageNette, ImageWoof, ImageNet).".format(dataset), LOGTAG,
            log.Level.ERROR)
        raise NotImplementedError(
            "Dataset {0} is not available!".format(dataset))
    return data
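
A short usage sketch (the local path is a placeholder; for CIFAR-10 torchvision will download the data if it is missing):

import torchvision.transforms as transforms

# Hypothetical call: fetch the CIFAR-10 training and test splits with a basic transform.
basic_transform = transforms.Compose([transforms.ToTensor()])
train_set = _get_data('CIFAR-10', './CIFAR-10', basic_transform, test=False)
test_set = _get_data('CIFAR-10', './CIFAR-10', basic_transform, test=True)
print(len(train_set), len(test_set))  # 50000 10000 for CIFAR-10.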
Example no. 3
    def _end_epoch(self,
                   epoch,
                   epoch_metrics,
                   epoch_validation_metrics,
                   optimizer,
                   scheduler,
                   verbose=True,
                   loss_criterion=None):
        """Wrap-up the epoch by stepping the scheduler if there is one.

        Override to alter operations done at the very end of every epoch, such as stepping the scheduler, or
        printing some information about the performance of the network.

        :param epoch: current epoch number.
        :type epoch: int
        :param optimizer: optimizer of the model parameters.
        :type optimizer: torch.optim.Optimizer
        :param scheduler: scheduler of the model parameters.
        :type scheduler: torch.optim.Scheduler
        """
        if scheduler is not None:
            scheduler.step()
        if self.tb_writer is not None:
            self.tb_writer.add_scalar(tag='train_loss',
                                      scalar_value=epoch_metrics['loss'],
                                      global_step=epoch)
            self.tb_writer.add_scalar(
                tag='val_loss',
                scalar_value=epoch_validation_metrics['loss'],
                global_step=epoch)
            for name, param in self.model.named_parameters():
                self.tb_writer.add_histogram(
                    tag=name,
                    values=param.clone().cpu().data.numpy(),
                    global_step=epoch)
                if param.grad is not None:
                    self.tb_writer.add_histogram(
                        tag='{0}_grad'.format(name),
                        values=param.grad.clone().cpu().data.numpy(),
                        global_step=epoch)
        if verbose:
            log.log(
                "   Training loss: {0} -- Validation loss: {1}.".format(
                    epoch_metrics['loss'], epoch_validation_metrics['loss']),
                LOGTAG, log.Level.INFO)
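
For reference, `self.tb_writer` is used like a standard TensorBoard writer; a minimal sketch of the assumed interface (how the attribute is actually populated is outside this snippet):

from torch.utils.tensorboard import SummaryWriter

# Hypothetical writer: anything exposing add_scalar/add_histogram with these keywords works.
writer = SummaryWriter(log_dir='runs/example')
writer.add_scalar(tag='train_loss', scalar_value=0.42, global_step=1)
writer.close()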
Example no. 4
    def _end_epoch(self,
                   epoch,
                   epoch_metrics,
                   epoch_validation_metrics,
                   optimizer,
                   scheduler,
                   verbose=True,
                   loss_criterion=None):
        super(Classifier,
              self)._end_epoch(epoch, epoch_metrics, epoch_validation_metrics,
                               optimizer, scheduler, False)
        if self.tb_writer is not None:
            for i, kacc in enumerate(epoch_metrics['topk']):
                self.tb_writer.add_scalar(tag='train_top{0}acc'.format(
                    self.top_predictions[i]),
                                          scalar_value=kacc,
                                          global_step=epoch)
                self.tb_writer.add_scalar(
                    tag='val_top{0}acc'.format(self.top_predictions[i]),
                    scalar_value=epoch_validation_metrics['topk'][i],
                    global_step=epoch)
        if verbose:
            train_accuracy = ', '.join([
                'top-{0}:{1:.1f}%'.format(k, epoch_metrics['topk'][i])
                for i, k in enumerate(self.top_predictions)
            ])
            validation_accuracy = ', '.join([
                'top-{0}:{1:.1f}%'.format(k,
                                          epoch_validation_metrics['topk'][i])
                for i, k in enumerate(self.top_predictions)
            ])
            log.log(
                "   Training loss: {0} ({2}) -- Validation loss: {1} ({3}).".
                format(epoch_metrics['loss'], epoch_validation_metrics['loss'],
                       train_accuracy, validation_accuracy), LOGTAG,
                log.Level.INFO)
Example no. 5
def run_once(data_path,
             save_path,
             test=False,
             use_dfa=True,
             use_feedback_normalization=True,
             use_PAI=False,
             use_conv=False,
             dropout_rate=None,
             use_batchnorm=False,
             activation=nn.Tanh,
             dataset='CIFAR-10',
             epochs=50,
             batch_size=128,
             past_state_path=None,
             seed=0,
             gpu_id=0):
    base_name = f"3_{'test' if test else 'eval'}_{'DFA' if use_dfa else 'BP'}" \
        f"_{'FBnorm' if use_feedback_normalization else 'nonorm'}" \
        f"_{'initPAI' if use_PAI else 'initSTD'}_{'CONV' if use_conv else 'FC'}" \
        f"_{'ND' if dropout_rate is None else f'e{dropout_rate[0]}_fc{dropout_rate[1]}_conv{dropout_rate[2]}'}" \
        f"_{'BN' if use_batchnorm else 'X'}_{activation().__str__()}_{dataset}_e{epochs}_bs{batch_size}" \
        f"_{'retrieved' if past_state_path is not None else 'from_scratch'}_s{seed}"

    # Set-up a logging file.
    if save_path is not None:
        log_file_name = "log_{0}.txt".format(base_name)
        log_save_path = os.path.join(save_path, log_file_name)
        log.setup_logging(log.Level.INFO, log_save_path)
    else:
        log.setup_logging(log.Level.INFO)
        log.log(
            "You have not specified a save file, no data will be kept from the run!",
            LOGTAG, log.Level.ERROR)

    log.log(
        "<b><u>Establishing Baselines for Direct Feedback Alignment</u></b>",
        LOGTAG, log.Level.WARNING)
    log.log("<b>Section 3 -- Establishing Best Practices for DFA</b>", LOGTAG,
            log.Level.WARNING)

    log.log("Setting-up processing back-end and seeds...", LOGTAG,
            log.Level.INFO)
    # For larger architectures that have high memory needs, the feedback matrix can be kept on another GPU (rp_device)
    # and the BP model used for angle calculations as well (bp_device). Implementation is not fully tested for BP on
    # a separate device, and some tensors may need to be moved around. This code has also not been tested on CPU only.
    device = proc.enable_cuda(gpu_id, seed)
    rp_device, bp_device = device, device

    # Setting-up random number generation.
    log.log(f"Seeding with <b>{seed}</b>.", LOGTAG, log.Level.DEBUG)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Setting-up data: transforms and data loaders.
    log.log(
        f"Preparing data for dataset <b>{dataset}</b> with batch size {batch_size}...",
        LOGTAG, log.Level.INFO)

    train_loader, validation_loader = data.get_loaders(dataset, batch_size,
                                                       test, data_path)
    if use_PAI:
        # Prepare a separate loader for prealignment. This allows experimenting with a different batch size and
        # transforms for prealignment.
        PAI_loader, _ = data.get_loaders(dataset, batch_size, test, data_path)

    # Setting-up model.
    log.log("Creating model...", LOGTAG, log.Level.INFO)
    # We create a model description with all the features possible (batchnorm and dropout) and then remove them if
    # they are not required.
    keep_dropout = True
    if dropout_rate is None:
        keep_dropout = False
        dropout_rate = [0, 0, 0]
    if use_conv:
        log.log("Using a <b>convolutional</b> architecture.", LOGTAG,
                log.Level.DEBUG)
        model_description = OrderedDict([
            ('drop1', nn.Dropout2d(dropout_rate[0])),
            ('conv1', nn.Conv2d(3, 32, 5, padding=2)),
            ('drop2', nn.Dropout2d(dropout_rate[2])), ('act1', activation()),
            ('bn1', nn.BatchNorm2d(32)), ('maxp1', nn.MaxPool2d(3, 2)),
            ('conv2', nn.Conv2d(32, 64, 5, padding=2)),
            ('drop3', nn.Dropout2d(dropout_rate[2])), ('act2', activation()),
            ('bn2', nn.BatchNorm2d(64)), ('maxp2', nn.MaxPool2d(3, 2)),
            ('conv3', nn.Conv2d(64, 64, 5, padding=2)),
            ('drop4', nn.Dropout2d(dropout_rate[2])), ('act3', activation()),
            ('bn3', nn.BatchNorm2d(64)), ('maxp3', nn.MaxPool2d(3, 2)),
            ('flat', util.Flatten()), ('lin1', nn.Linear(576, 128)),
            ('drop5', nn.Dropout(dropout_rate[1])), ('act4', activation()),
            ('bn4', nn.BatchNorm1d(128)), ('lin2', nn.Linear(128, 10))
        ])
    else:
        log.log("Using a <b>fully-connected</b> architecture.", LOGTAG,
                log.Level.DEBUG)
        model_description = OrderedDict([
            ('flat', util.Flatten()), ('drop1', nn.Dropout(dropout_rate[0])),
            ('lin1', nn.Linear(3072, 800)),
            ('drop2', nn.Dropout(dropout_rate[1])), ('act1', activation()),
            ('bn1', nn.BatchNorm1d(800)), ('lin2', nn.Linear(800, 800)),
            ('drop3', nn.Dropout(dropout_rate[1])), ('act2', activation()),
            ('bn2', nn.BatchNorm1d(800)), ('lin3', nn.Linear(800, 800)),
            ('drop4', nn.Dropout(dropout_rate[1])), ('act3', activation()),
            ('bn3', nn.BatchNorm1d(800)), ('lin4', nn.Linear(800, 10))
        ])

    if not keep_dropout:
        model_description = remove_dropout(model_description)
    else:
        log.log(
            f"Using <b>dropout</b> with rate {dropout_rate[0]} for input, {dropout_rate[2]} for conv. layers,"
            f" {dropout_rate[2]} for fully-connected layers.", LOGTAG,
            log.Level.DEBUG)
    if not use_batchnorm:
        model_description = remove_batchnorm(model_description)
    else:
        log.log("Using <b>batch normalization</b>.", LOGTAG, log.Level.DEBUG)

    if use_dfa or use_PAI:
        # Create the DFA model; even when running BP, it is needed to compute the pre-alignment initialization.
        model_dfa = dfamodels.DFAClassifier(device,
                                            device,
                                            model_description,
                                            train_loader,
                                            validation_loader,
                                            saving=(base_name, save_path, 5))
        model_dfa = model_dfa.to(device)
        if not use_feedback_normalization:
            log.log("<b>Not</b> using feedback normalization.", LOGTAG,
                    log.Level.INFO)
            model_dfa.model.feedback_normalization = False
        model_dfa.initialize()
        if use_PAI:
            log.log("Using <b>PAI</b> (Pre-Alignment Initialization).", LOGTAG,
                    log.Level.DEBUG)
            model_dfa.prealignment(PAI_loader, nn.CrossEntropyLoss())
        if use_dfa:
            log.log("Will be using <b>DFA</b> for training!", LOGTAG,
                    log.Level.WARNING)
            model = model_dfa
    if not use_dfa:
        log.log("Will be using <b>BP</b> for training!", LOGTAG,
                log.Level.WARNING)
        model = models.Classifier(model_description,
                                  train_loader,
                                  validation_loader,
                                  saving=(base_name, save_path, 5))
        model = model.to(device)
        if use_PAI:
            model.model.load_state_dict(model_dfa.model.state_dict())

    if past_state_path is not None:
        with open(past_state_path, 'rb') as state_file:
            model.model.load_state_dict(torch.load(state_file))

    # Setting-up optimizer.
    optimizer_description = (opt.SGD, {'lr': 5 * 1e-4})

    # Train the model.
    log.log(
        f"Training with {'DFA' if use_dfa else 'BP'} initialized by {'PAI' if use_PAI else 'STD'} "
        f"for {epochs} epochs on dataset {dataset} with a batch size of {batch_size}:",
        LOGTAG, log.Level.INFO)
    log.log(f"Model: {model}", LOGTAG, log.Level.DEBUG)
    model.train(epochs,
                optimizer_description,
                loss_criterion=nn.CrossEntropyLoss())

    # Validate the model one last time.
    log.log(
        f"Final validation with {'DFA' if use_dfa else 'BP'} initialized by {'PAI' if use_PAI else 'STD'} "
        f"for {epochs} epochs on dataset {dataset} with a batch size of {batch_size}:",
        LOGTAG, log.Level.INFO)
    model.validate(loss_criterion=nn.CrossEntropyLoss())

    # Save the model training log and the weights (message log and angles are already saved dynamically).
    if save_path is not None:
        training_log_file_name = "training_log_{0}.tl".format(base_name)
        training_log_file_path = os.path.join(save_path,
                                              training_log_file_name)
        log.log(
            "Finishing up: saving training log to {0}...".format(
                training_log_file_path), LOGTAG, log.Level.INFO)
        with open(training_log_file_path, 'wb') as training_log_file:
            pk.dump(model.training_log, training_log_file)

        state_dict_file_name = "state_{0}.pt".format(base_name)
        state_dict_file_path = os.path.join(save_path, state_dict_file_name)
        log.log(
            "Finishing up: saving model state to {0}...".format(
                state_dict_file_path), LOGTAG, log.Level.INFO)
        torch.save(model.model.state_dict(), state_dict_file_path)

    return model.training_log
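
An example invocation of `run_once` (paths are placeholders; the values shown are illustrative, not the settings used in the original experiments):

# Hypothetical quick run: fully-connected DFA model on CIFAR-10.
training_log = run_once(data_path='./datasets',
                        save_path='./results',
                        use_dfa=True,
                        use_conv=False,
                        dataset='CIFAR-10',
                        epochs=5,
                        batch_size=128,
                        seed=0)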
Example no. 6
def get_fast_loaders(dataset,
                     batch_size,
                     test,
                     device,
                     data_path=None,
                     train_transform=None,
                     validation_transform=None,
                     train_percentage=0.85,
                     workers=4):
    """Return :class:`FastLoader` for training and validation, outfitted with a random sampler.

    If set to run on the test set, :param:`train_percentage` will be ignored and set to 1.

    The transforms should only include operations on PIL images and should not convert the images to a tensor, nor
    handle normalization of the tensors. This is handled at runtime by the fast loaders.

    If you are not looking for high-performance, prefer :func:`get_loaders`.

    :param dataset: name of the dataset, (MNIST, FashionMNIST, CIFAR-10, CIFAR-100, ImageNette, ImageWoof, ImageNet)
                    are available.
    :type dataset: str
    :param batch_size: batch size for training and validation.
    :type batch_size: int
    :param test: run validation on the test set.
    :type test: bool
    :param data_path: path to folder containing dataset.
    :type data_path: str
    :param train_transform: PyTorch transform to apply to images for training.
    :type train_transform: torchvision.transforms.Compose
    :param validation_transform: PyTorch transform to apply to images for validation.
    :type validation_transform: torchvision.transforms.Compose
    :param train_percentage: percentage of the data in the training set.
    :type train_percentage: float
    :param workers: number of subprocesses to use for data loading. Use 0 for loading in the main process.
    :type workers: int
    :return: training and validation fast data loaders.
    :rtype: (FastLoader, FastLoader)
    """
    # Check if any parameter has been left at its default value and, if so, set up the defaults.
    data_path, train_transform, validation_transform = _setup_defaults(
        dataset, data_path, train_transform, validation_transform, fast=True)

    # Get all of the training data available.
    train_data = _get_train_data(dataset, data_path, train_transform)
    log.log("Training data succesfully fetched!", LOGTAG, log.Level.DEBUG)

    if not test:
        # Perform a train/validation split on the training data available:
        # For performance reasons, the train/validation split will always be the same.
        # TODO: Implement random train/validation split with fast loading and distributed training.
        log.log("Running in standard training/validation mode.", LOGTAG,
                log.Level.INFO)
        dataset_size = len(train_data)
        split_index = int(dataset_size * train_percentage)
        log.log("{0}:{1}".format(dataset_size, split_index), LOGTAG,
                log.Level.HIGHLIGHT)
        validation_data = train_data[split_index:]
        train_data = train_data[:split_index]
        log.log("Validation data succesfully fetched!", LOGTAG,
                log.Level.DEBUG)
    else:
        # Fetch the test data:
        log.log(
            "Running in <b>test</b> mode. All training data available will be used, and "
            "validation will be done on the test set. Are you really ready to publish?",
            LOGTAG, log.Level.WARNING)
        validation_data = _get_test_data(dataset, data_path,
                                         validation_transform)
        log.log("Test data succesfully fetched!", LOGTAG, log.Level.DEBUG)

    if distributed.is_initialized():
        # If running in distributed mode, use a DistributedSampler:
        log.log(
            "Running in <b>distributed</b> mode. This hasn't been thoroughly tested, beware!",
            LOGTAG, log.Level.WARNING)
        train_sampler = data_utils.distributed.DistributedSampler(train_data)
    else:
        # Otherwise, default to a RandomSampler:
        train_sampler = data_utils.RandomSampler(train_data)

    # Build the train and validation loaders, using pinned memory and a custom collate function to build the batches.
    train_loader = data_utils.DataLoader(train_data,
                                         batch_size=batch_size,
                                         num_workers=workers,
                                         pin_memory=True,
                                         sampler=train_sampler,
                                         collate_fn=_fast_collate,
                                         drop_last=True)
    log.log("Train loader succesfully created!", LOGTAG, log.Level.DEBUG)
    validation_loader = data_utils.DataLoader(validation_data,
                                              batch_size=batch_size,
                                              num_workers=workers,
                                              pin_memory=True,
                                              collate_fn=_fast_collate)
    log.log("Validation loader succesfully created!", LOGTAG, log.Level.DEBUG)

    # Wrap the PyTorch loaders in the custom FastLoader class and feed it the normalization parameters associated
    # with the dataset.
    return FastLoader(train_loader, device, *NORMALIZATION[DATASETS[dataset]]), \
           FastLoader(validation_loader, device, *NORMALIZATION[DATASETS[dataset]])
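
A short usage sketch (the dataset path is a placeholder, and iterating the returned loader is assumed to behave like a standard PyTorch loader):

import torch

# Hypothetical usage: build fast loaders for CIFAR-10 and pull one batch.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
train_loader, validation_loader = get_fast_loaders('CIFAR-10',
                                                   batch_size=128,
                                                   test=False,
                                                   device=device,
                                                   data_path='./CIFAR-10')
inputs, targets = next(iter(train_loader))  # Assumes FastLoader yields normalized batches on `device`.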
Example no. 7
    def train(self,
              epochs,
              optimizer_description=(opt.Adam, {
                  'lr': 1e-4
              }),
              scheduler_description=None,
              loss_criterion=nn.CrossEntropyLoss()):
        """Train the model for a fixed number of epochs, using the optimizer, scheduler, and loss criterion provided,
        and running evaluation according to the frequency selected.

        Optimizer and scheduler are created from their 'descriptions'. A description is a tuple containing first the
        base class from which to build the optimizer/scheduler, then a dictionary of keyword arguments that will be
        unpacked and passed to that class. The model parameters are always passed as the first argument.

        The general organization is as follows:
            - :func:`_setup_optimization` is used to create the optimizer and scheduler from their descriptions;
            - For each epoch,
                - For each batch,
                    - We get the (input, target_output) batch from the train_loader and push it to the device;
                    - :func:`process` is used to infer model_output from input;
                    - :func:`_train` can be overridden to perform additional backward operations;
                    - :func:`_evaluate_loss` is called to evaluate the batch_loss using loss_criterion;
                    - :func:`_batch_metrics` is used to calculate the batch metrics from the input, output, and loss;
                    - :func:`_end_batch` is called at the very end of the batch to actually perform the backward step.
                - :func:`_epoch_metrics` combines the metrics from the batches into metrics for the epoch;
                - The epoch metrics are appended to the training_log;
                - Validation is run once on the validation loader;
                - :func:`_end_epoch` is called at the very end of the epoch to step the scheduler and log results.

        :param epochs: how many epochs to run.
        :type epochs: int
        :param optimizer_description: tuple of base class to use for optimizer and arguments to pass.
        :type optimizer_description: (torch.optim.Optimizer, dict)
        :param scheduler_description: tuple of base class to use for scheduler and arguments to pass.
        :type scheduler_description: (torch.optim.lr_scheduler._LRScheduler, dict)
        :param loss_criterion: loss criterion to use to evaluate loss.
        :type loss_criterion: torch.nn.modules.loss._Loss
        """
        # Setup optimizer and scheduler.
        optimizer, scheduler = self._setup_optimization(
            optimizer_description, scheduler_description)

        epoch_processing = [time.time(), None]
        for e in range(1, epochs + 1):
            if e != 1:
                log.log(
                    "   EPOCH {0}/{1}, last epoch processed in {2}s:".format(
                        e, epochs, epoch_processing[1] - epoch_processing[0]),
                    LOGTAG, log.Level.INFO)
            else:
                log.log("   EPOCH {0}/{1}:".format(e, epochs), LOGTAG,
                        log.Level.INFO)

            epoch_processing[0] = time.time()

            batches_metrics = []  # Metrics of every batch in the epoch.
            self.model.train()
            data_loading = [time.time(), None, None]
            for i, (input, target_output) in enumerate(self.train_loader):
                # Do one batch:
                data_loading[2] = time.time()
                # Push the data to GPU asynchronously.
                input, target_output = input.to(self.device, non_blocking=True), \
                                       target_output.to(self.device, non_blocking=True)

                input, target_output = self._process_data(input, target_output)
                if i != 0:
                    log.log(
                        '       Batch {0}/{1} (loaded in {2:.4f}s, '
                        'last batch forward in {3:.4f}s, and backward in {4:.4f}s)'
                        .format(i, len(self.train_loader),
                                data_loading[2] - data_loading[1],
                                forward_pass[1] - forward_pass[0],
                                backward_pass[1] - backward_pass[0]),
                        LOGTAG,
                        log.Level.INFO,
                        temporary=True)
                else:
                    log.log('       Batch {0}/{1} (loaded in {2:.4f}s)'.format(
                        i, len(self.train_loader),
                        data_loading[2] - data_loading[0]),
                            LOGTAG,
                            log.Level.INFO,
                            temporary=True)

                # Forward pass: infer output from input.
                forward_pass = [time.time(), None]
                model_output = self.infer(input)
                forward_pass[1] = time.time()

                # Backward pass: calculate loss and other metrics, and descend gradient.
                backward_pass = [time.time(), None]
                self._train(
                    input, target_output, model_output
                )  # Perform additional computations on input/output.
                batch_loss = self._evaluate_loss(input, model_output,
                                                 target_output, loss_criterion)
                batches_metrics.append(
                    self._batch_metrics(input, model_output, target_output,
                                        batch_loss))
                self._end_batch(batch_loss, optimizer, input, model_output,
                                target_output)  # Descend the gradient (optimizer step).
                backward_pass[1] = time.time()
                data_loading[1] = time.time()

            # Compute epoch metrics from batches metrics and log them.
            epoch_metrics = self._epoch_metrics('train', e, batches_metrics)
            self.epochs_metrics.append(epoch_metrics)
            self.training_log.append(('TRAIN', e, epoch_metrics))

            # Wrap-up everything, check validation performance, and step scheduler.
            epoch_validation_metrics = self.validate(
                loss_criterion=loss_criterion, epoch=e)
            self._end_epoch(e,
                            epoch_metrics,
                            epoch_validation_metrics,
                            optimizer,
                            scheduler,
                            loss_criterion=loss_criterion)

            if self.saving_frequency is not None and e % self.saving_frequency == 0:
                state_dict_file_name = "state_{0}.pt".format(self.model_name)
                state_dict_file_path = os.path.join(self.saving_path,
                                                    state_dict_file_name)
                log.log(
                    "Checkpoint! saving model state to {0}...".format(
                        state_dict_file_path), LOGTAG, log.Level.INFO)
                torch.save(self.model.state_dict(), state_dict_file_path)

            epoch_processing[1] = time.time()
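
A hedged sketch of the 'description' tuples this method consumes (the scheduler settings are illustrative, and it is assumed that `_setup_optimization` passes the optimizer as the scheduler's first argument):

import torch.nn as nn
import torch.optim as opt

# Hypothetical descriptions: (base class, keyword arguments) unpacked by _setup_optimization.
optimizer_description = (opt.SGD, {'lr': 5e-4, 'momentum': 0.9})
scheduler_description = (opt.lr_scheduler.StepLR, {'step_size': 10, 'gamma': 0.5})

# `model` stands for any instance of the class defining train() above.
model.train(50,
            optimizer_description=optimizer_description,
            scheduler_description=scheduler_description,
            loss_criterion=nn.CrossEntropyLoss())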
Example no. 8
def run_once(data_path,
             save_path,
             test=False,
             use_PAI=False,
             dataset='CIFAR-10',
             bottleneck_rate=0,
             epochs=50,
             batch_size=128,
             past_state_path=None,
             seed=0,
             gpu_id=0):
    base_name = f"3_{'test' if test else 'eval'}" \
        f"_{'initPAI' if use_PAI else 'initSTD'}" \
        f"_{dataset}_e{epochs}_bs{batch_size}" \
        f'_b{bottleneck_rate}' \
        f"_{'retrieved' if past_state_path is not None else 'from_scratch'}_s{seed}"

    # Set-up a logging file.
    if save_path is not None:
        log_file_name = "log_{0}.txt".format(base_name)
        log_save_path = os.path.join(save_path, log_file_name)
        log.setup_logging(log.Level.INFO, log_save_path)
    else:
        log.setup_logging(log.Level.INFO)
        log.log(
            "You have not specified a save file, no data will be kept from the run!",
            LOGTAG, log.Level.ERROR)

    log.log(
        "<b><u>Establishing Baselines for Direct Feedback Alignment</u></b>",
        LOGTAG, log.Level.WARNING)
    log.log("<b>Section 3 -- Establishing Best Practices for DFA</b>", LOGTAG,
            log.Level.WARNING)

    log.log("Setting-up processing back-end and seeds...", LOGTAG,
            log.Level.INFO)
    # For larger architectures that have high memory needs, the feedback matrix can be kept on another GPU (rp_device)
    # and the BP model used for angle calculations as well (bp_device). Implementation is not fully tested for BP on
    # a separate device, and some tensors may need to be moved around. This code has also not been tested on CPU only.
    device = proc.enable_cuda(gpu_id, seed)
    rp_device, bp_device = device, device

    # Setting-up random number generation.
    log.log(f"Seeding with <b>{seed}</b>.", LOGTAG, log.Level.DEBUG)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Setting-up data: transforms and data loaders.
    log.log(
        f"Preparing data for dataset <b>{dataset}</b> with batch size {batch_size}...",
        LOGTAG, log.Level.INFO)

    train_loader, validation_loader = data.get_loaders(dataset, batch_size,
                                                       test, data_path)
    if use_PAI:
        # Prepare a separate loader for prealignment. This allows experimenting with a different batch size and
        # transforms for prealignment.
        PAI_loader, _ = data.get_loaders(dataset, batch_size, test, data_path)

    # Setting-up model.
    log.log(
        f"Creating model with <b>bottlenecking rate {bottleneck_rate}</b>...",
        LOGTAG, log.Level.INFO)
    # Build a fully-connected model description with a Bottleneck module inserted at the chosen bottlenecking rate.
    model_description = OrderedDict([
        ('flat', util.Flatten()), ('lin1', nn.Linear(3072, 800)),
        ('tanh1', nn.Tanh()),
        ('bot2', dfamodels.Bottleneck(bottleneck_rate, 800, 800)),
        ('tanh2', nn.Tanh()), ('lin3', nn.Linear(800, 800)),
        ('tanh3', nn.Tanh()), ('lin4', nn.Linear(800, 10))
    ])

    model = dfamodels.DFAClassifier(device,
                                    device,
                                    model_description,
                                    train_loader,
                                    validation_loader,
                                    saving=(base_name, save_path, 5))
    model = model.to(device)
    model.initialize()
    if use_PAI:
        log.log("Using <b>PAI</b> (Pre-Alignment Initialization).", LOGTAG,
                log.Level.DEBUG)
        model.prealignment(PAI_loader, nn.CrossEntropyLoss())

    if past_state_path is not None:
        with open(past_state_path, 'rb') as state_file:
            model.model.load_state_dict(torch.load(state_file))

    # Setting-up optimizer.
    optimizer_description = (opt.SGD, {'lr': 5 * 1e-4})

    # Train the model.
    log.log(
        f"Training with a bottlenecking rate {bottleneck_rate} initialized by {'PAI' if use_PAI else 'STD'} "
        f"for {epochs} epochs on dataset {dataset} with a batch size of {batch_size}:",
        LOGTAG, log.Level.INFO)
    log.log(f"Model: {model}", LOGTAG, log.Level.DEBUG)
    model.train(epochs,
                optimizer_description,
                loss_criterion=nn.CrossEntropyLoss())

    # Validate the model one last time.
    log.log(
        f"Final with a bottlenecking rate {bottleneck_rate} initialized by {'PAI' if use_PAI else 'STD'} "
        f"for {epochs} epochs on dataset {dataset} with a batch size of {batch_size}:",
        LOGTAG, log.Level.INFO)
    model.validate(loss_criterion=nn.CrossEntropyLoss())

    # Save the model training log and the weights (message log and angles are already saved dynamically).
    if save_path is not None:
        training_log_file_name = "training_log_{0}.tl".format(base_name)
        training_log_file_path = os.path.join(save_path,
                                              training_log_file_name)
        log.log(
            "Finishing up: saving training log to {0}...".format(
                training_log_file_path), LOGTAG, log.Level.INFO)
        with open(training_log_file_path, 'wb') as training_log_file:
            pk.dump(model.training_log, training_log_file)

        state_dict_file_name = "state_{0}.pt".format(base_name)
        state_dict_file_path = os.path.join(save_path, state_dict_file_name)
        log.log(
            "Finishing up: saving model state to {0}...".format(
                state_dict_file_path), LOGTAG, log.Level.INFO)
        torch.save(model.model.state_dict(), state_dict_file_path)

    return model.training_log
Example no. 9
    def _end_epoch(self, epoch, epoch_metrics, epoch_validation_metrics, optimizer, scheduler, verbose=True, loss_criterion=None):
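        # Alignment measurement: copy the current weights into a BP twin, run one validation batch through both
        # models, backpropagate both losses, and compare the layer-wise gradients by per-sample cosine similarity.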
        if self.bp_device is not None:
            self.bp_model.model.load_state_dict(self.model.state_dict())
            input_sample, output_sample = next(iter(self.validation_loader))
            input_sample, output_sample = self._process_data(input_sample, output_sample)
            input_sample_dfa, output_sample_dfa = input_sample.to(self.device, non_blocking=True), \
                                                  output_sample.to(self.device, non_blocking=True)
            input_sample_bp, output_sample_bp = input_sample.to(self.bp_device, non_blocking=True), \
                                                output_sample.to(self.bp_device, non_blocking=True)

            model_output_dfa = self.model(input_sample_dfa)
            model_output_bp = self.bp_model.model(input_sample_bp)

            error = self._compute_error(model_output_dfa, output_sample_dfa)

            dfa_loss_crit = loss_criterion.to(self.device)
            bp_loss_crit = loss_criterion.to(self.bp_device)
            dfa_loss = dfa_loss_crit(model_output_dfa, output_sample_dfa)
            bp_loss = bp_loss_crit(model_output_bp, output_sample_bp)
            self.model.zero_grad()
            self.bp_model.model.zero_grad()
            dfa_loss.backward()
            self.model.backward(error)
            bp_loss.backward()

            self.alignments.append([])
            self.angles.append([])

            for bp_module, bp_gradient in self.bp_model.gradient_helper.gradients.items():
                module_id = self.bp_model.gradient_helper.architecture.index(bp_module)
                dfa_module = self.gradient_helper.architecture[module_id]
                bp_gradient = bp_gradient.view(input_sample.shape[0], -1)
                dfa_gradient = self.gradient_helper.gradients[dfa_module].view(input_sample.shape[0], -1)
                cosine_similarity = nn.CosineSimilarity(dim=1)
                angles = cosine_similarity(bp_gradient.to(self.device), dfa_gradient)
                self.angles[-1].append([bp_module, angles])
                angles_stats = [float(angles.mean()), float(angles.std())]
                self.alignments[-1].append([bp_module, angles_stats])

            self.alignments[-1].reverse()
            self.angles[-1].reverse()

            first_layer, offset = True, 0
            for i, module in enumerate(self.bp_model.gradient_helper.architecture[1:]):
                if isinstance(module, (nn.Dropout, nn.Dropout2d)) and first_layer:
                    offset += 1
                else:
                    first_layer = False
                self.alignments[-1][i - offset][0] = f'{i}_{module.__str__()}'
                self.angles[-1][i - offset][0] = f'{i}_{module.__str__()}'

            if verbose:
                for layer_alignment in self.alignments[-1]:
                    log.log(f'{layer_alignment[0]} -- mean:{layer_alignment[1][0]:.4f}, std:{layer_alignment[1][1]:.4f}', LOGTAG, log.Level.INFO)

            if self.tb_writer is not None:
                for layer_alignment in self.alignments[-1]:
                    self.tb_writer.add_scalar(tag='{0}_angle_mean'.format(layer_alignment[0]), scalar_value=layer_alignment[1][0], global_step=epoch)
                    self.tb_writer.add_scalar(tag='{0}_angle_std'.format(layer_alignment[0]), scalar_value=layer_alignment[1][1], global_step=epoch)

            if self.saving_path is not None:
                alignment_file_path = os.path.join(self.saving_path, f'{self.model_name}_alignment.al')
                angles_file_path = os.path.join(self.saving_path, f'{self.model_name}_angle.ang')
                with open(alignment_file_path, 'wb') as alignment_file:
                    pk.dump(self.alignments, alignment_file)
                with open(angles_file_path, 'wb') as angles_file:
                    pk.dump(self.angles, angles_file)

        super(DFAModel, self)._end_epoch(epoch, epoch_metrics, epoch_validation_metrics, optimizer, scheduler, verbose)
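
For intuition, the per-sample cosine-similarity measurement above in isolation (shapes are illustrative):

import torch
import torch.nn as nn

# Hypothetical flattened per-sample gradients for a batch of 4 samples.
bp_gradient = torch.randn(4, 128)
dfa_gradient = torch.randn(4, 128)

cosine_similarity = nn.CosineSimilarity(dim=1)
angles = cosine_similarity(bp_gradient, dfa_gradient)  # One value in [-1, 1] per sample.
print(float(angles.mean()), float(angles.std()))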
Example no. 10
    def prealignment(self, alignment_loader, loss):
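        # Pre-Alignment Initialization (PAI): walk the layers from the output backwards and, for each weighted
        # module, re-initialize its weights so that the signal it backpropagates on one batch matches the DFA
        # feedback (solved with a pseudo-inverse and normalized below).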
        modules_dfa = list(self.model.modules())[1:]
        modules_bp = list(self.bp_model.model.modules())[1:]
        print("DFA Arch.", modules_dfa)
        print("BP Arch.", modules_bp)
        modules_dfa.reverse()
        modules_bp.reverse()
        layer_index, layer_offset, last_layer, first_layer = -1, 0, True, False
        first_init = True
        for module in modules_dfa:
            layer_index += 1
            if isinstance(module, dfa.AsymmetricFeedback):
                layer_offset += 1
                last_layer = False
            if isinstance(module, util.WEIGHT_MODULES):
                log.log(f"Prealigning {module}...", LOGTAG, log.Level.INFO)
                next_module_bp = modules_bp[layer_index - layer_offset]
                next_module_dfa = modules_dfa[layer_index + 1]
                print("bp", next_module_bp)
                print("dfa", next_module_dfa)

                self.bp_model.model.load_state_dict(self.model.state_dict())

                input_sample, output_sample = next(iter(alignment_loader))
                input_sample, output_sample = self._process_data(input_sample, output_sample)
                input_sample_dfa, output_sample_dfa = input_sample.to(self.device, non_blocking=True), \
                                                      output_sample.to(self.device, non_blocking=True)
                input_sample_bp, output_sample_bp = input_sample.to(self.bp_device, non_blocking=True), \
                                                    output_sample.to(self.bp_device, non_blocking=True)

                model_output_dfa = self.model(input_sample_dfa)
                model_output_bp = self.bp_model.model(input_sample_bp)

                error = self._compute_error(model_output_dfa, output_sample_dfa)

                dfa_loss_crit = loss.to(self.device)
                bp_loss_crit = loss.to(self.bp_device)
                dfa_loss = dfa_loss_crit(model_output_dfa, output_sample_dfa)
                bp_loss = bp_loss_crit(model_output_bp, output_sample_bp)
                self.model.zero_grad()
                self.bp_model.model.zero_grad()
                dfa_loss.backward()
                self.model.backward(error)
                bp_loss.backward()

                if isinstance(next_module_dfa, dfa.AsymmetricFeedback):
                    self.gradient_helper.gradients[next_module_dfa] = next_module_dfa.rp

                try:
                    print(self.bp_model.gradient_helper.gradients.keys())
                    gradient_dfa = self.gradient_helper.gradients[next_module_dfa]
                    gradient_bp = self.bp_model.gradient_helper.gradients[next_module_bp]
                    print("Be", gradient_dfa.shape)
                    print("W", module.weight.shape)
                    print("grad", gradient_bp.shape)
                    gradient_dfa = torch.t(gradient_dfa)
                    gradient_bp = torch.t(gradient_bp)

                    weight_transpose = torch.mm(gradient_dfa, torch.pinverse(gradient_bp))
                    print("norm pre-normalization", weight_transpose.norm())
                    weight_transpose = weight_transpose / weight_transpose.norm()
                    print("norm post-normalization", weight_transpose.norm())
                    module.weight.data = torch.t(weight_transpose)
                except Exception:
                    log.log(f"Prealignment of module {module} failed!", LOGTAG, log.Level.ERROR)