Example #1
def _main(dataset_str, model_str, batch_size, epochs, optimizer, **kwargs):
    '''Run the training procedure.

    Parameters
    ----------
    dataset_str :   str
                    Name of the dataset to use
    model_str   :   str
                    Unqualified name of the model class to use
    batch_size  :   int
                    Number of samples per training batch
    epochs      :   int
                    Number of epochs to train for
    optimizer   :   str
                    Name of the optimizer to use
    **kwargs
                    Additional settings: exporter depth, learning rate and decay,
                    subsample rate, visualisation backend and the flags selecting
                    the individual subscribers (see the function body)
    '''
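    # NOTE: this excerpt assumes the original script's top-level imports
    # (warnings, tqdm, ToTensor, load_dataset, Trainer, Exporter, MessageBus and
    # the ikkuna subscriber classes used below), which are not shown here.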

    dataset_train, dataset_test = load_dataset(dataset_str, train_transforms=[ToTensor()],
                                               test_transforms=[ToTensor()])

    # importing torch at module level triggers a spurious 'torch referenced before
    # assignment' error here, so the import is deferred to this point as a workaround
    import torch
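    # the Exporter hooks into the model's Conv2d modules down to the requested
    # depth and publishes their training artefacts on the message bus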
    bus = MessageBus('main')
    trainer = Trainer(dataset_train, batch_size=batch_size,
                      exporter=Exporter(depth=kwargs['depth'],
                                        module_filter=[torch.nn.Conv2d],
                                        message_bus=bus))
    trainer.set_model(model_str)
    trainer.optimize(name=optimizer, lr=kwargs.get('learning_rate', 0.01))
    if 'exponential_decay' in kwargs:
        decay = kwargs['exponential_decay']
        if decay is not None:
            trainer.set_schedule(torch.optim.lr_scheduler.ExponentialLR, decay)

    subsample = kwargs['subsample']
    backend   = kwargs['visualisation']
    subscriber_added = False
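
    # each of the following options attaches a subscriber which consumes messages
    # from the exporter and feeds the chosen visualisation backend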

    if kwargs['hessian']:
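        # track the top Hessian eigenvalue (num_eig=1) once per epoch, since the
        # frequency equals batches_per_epoch; the second-order gradients this needs
        # are why create_graph is enabled below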
        from torch.utils.data import DataLoader
        from ikkuna.export.subscriber import HessianEigenSubscriber
        loader = DataLoader(dataset_train.dataset, batch_size=batch_size, shuffle=True)
        trainer.add_subscriber(HessianEigenSubscriber(trainer.model.forward, trainer.loss, loader,
                                                      batch_size,
                                                      frequency=trainer.batches_per_epoch,
                                                      num_eig=1, power_steps=25,
                                                      backend=backend))
        trainer.create_graph = True
        subscriber_added = True

    if kwargs['spectral_norm']:
        for kind in kwargs['spectral_norm']:
            spectral_norm_subscriber = SpectralNormSubscriber(kind, backend=backend)
            trainer.add_subscriber(spectral_norm_subscriber)
        subscriber_added = True

    if kwargs['variance']:
        for kind in kwargs['variance']:
            var_sub = VarianceSubscriber(kind, backend=backend)
            trainer.add_subscriber(var_sub)
        subscriber_added = True

    if kwargs['test_accuracy']:
        test_accuracy_subscriber = TestAccuracySubscriber(dataset_test, trainer.model.forward,
                                                          frequency=trainer.batches_per_epoch,
                                                          batch_size=batch_size,
                                                          backend=backend)
        trainer.add_subscriber(test_accuracy_subscriber)
        subscriber_added = True

    if kwargs['train_accuracy']:
        train_accuracy_subscriber = TrainAccuracySubscriber(subsample=subsample,
                                                            backend=backend)
        trainer.add_subscriber(train_accuracy_subscriber)
        subscriber_added = True

    if kwargs['ratio']:
        for kind1, kind2 in kwargs['ratio']:
            ratio_subscriber = RatioSubscriber([kind1, kind2],
                                               subsample=subsample,
                                               backend=backend)
            trainer.add_subscriber(ratio_subscriber)
            pubs = ratio_subscriber.publications
            _, topics = pubs.popitem()
            # there can be multiple publications per type, but we know the RatioSubscriber only
            # publishes one
            trainer.add_subscriber(MessageMeanSubscriber(topics[0]))
        subscriber_added = True

    if kwargs['histogram']:
        for kind in kwargs['histogram']:
            histogram_subscriber = HistogramSubscriber(kind, backend=backend)
            trainer.add_subscriber(histogram_subscriber)
        subscriber_added = True

    if kwargs['norm']:
        for kind in kwargs['norm']:
            norm_subscriber = NormSubscriber(kind, backend=backend)
            trainer.add_subscriber(norm_subscriber)
        subscriber_added = True

    if kwargs['svcca']:
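        # compare the evolution of layer representations via SVCCA, sampled once
        # per epoch (subsample equals batches_per_epoch)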
        svcca_subscriber = SVCCASubscriber(dataset_test, 500, trainer.model.forward,
                                           subsample=trainer.batches_per_epoch, backend=backend)
        trainer.add_subscriber(svcca_subscriber)
        subscriber_added = True

    if not subscriber_added:
        warnings.warn('No subscriber was added, there will be no visualisation.')
    batches_per_epoch = trainer.batches_per_epoch
    print(f'Batches per epoch: {batches_per_epoch}')

    # exporter = trainer.exporter
    # modules = exporter.modules
    # n_modules = len(modules)

    epoch_range = range(epochs)
    if kwargs['verbose']:
        epoch_range = tqdm(epoch_range, desc='Epoch')

    for e in epoch_range:

        # freeze_idx = int(e/epochs * n_modules) - 1
        # if freeze_idx >= 0:
        #     exporter.freeze_module(modules[freeze_idx])

        # recreate the batch progress bar every epoch; a tqdm iterator closes its
        # bar after one full pass and cannot be meaningfully reused
        batch_range = range(batches_per_epoch)
        if kwargs['verbose']:
            batch_range = tqdm(batch_range, desc='Batch')
        for batch_idx in batch_range:
            trainer.train_batch()
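
# A minimal, hypothetical invocation of _main; dataset/model names, the backend
# string and the message kinds are illustrative only (the real script fills the
# keyword arguments from its command-line parser):
#
#     _main('MNIST', 'AlexNetMini', batch_size=128, epochs=10, optimizer='Adam',
#           depth=-1, learning_rate=0.01, exponential_decay=None, subsample=1,
#           visualisation='tb', hessian=False, spectral_norm=['weight_updates'],
#           variance=None, test_accuracy=True, train_accuracy=False,
#           ratio=[('weight_updates', 'weights')], histogram=None, norm=None,
#           svcca=False, verbose=True)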