def _add_subscribers(trainer, dataset_train, dataset_test, batch_size, options):
    '''Attach every subscriber requested in ``options`` to ``trainer``.

    Subscribers are added in a fixed order: hessian, spectral norm, variance, test accuracy, train
    accuracy, ratio, histogram, norm, svcca.

    Parameters
    ----------
    trainer
        Trainer object receiving the subscribers via ``add_subscriber()``
    dataset_train
        Training dataset; its ``dataset`` attribute feeds the loader of the hessian subscriber
    dataset_test
        Test dataset handed to the test-accuracy and SVCCA subscribers
    batch_size : int
    options : dict
        Option mapping. The keys ``subsample``, ``visualisation``, ``hessian``, ``spectral_norm``,
        ``variance``, ``test_accuracy``, ``train_accuracy``, ``ratio``, ``histogram``, ``norm`` and
        ``svcca`` must all be present (a missing key raises :class:`KeyError`).

    Returns
    -------
    bool
        ``True`` if at least one subscriber was added
    '''
    subsample = options['subsample']
    backend = options['visualisation']
    added = False

    if options['hessian']:
        from torch.utils.data import DataLoader
        from ikkuna.export.subscriber import HessianEigenSubscriber
        loader = DataLoader(dataset_train.dataset, batch_size=batch_size, shuffle=True)
        trainer.add_subscriber(HessianEigenSubscriber(trainer.model.forward, trainer.loss, loader,
                                                      batch_size,
                                                      frequency=trainer.batches_per_epoch,
                                                      num_eig=1, power_steps=25,
                                                      backend=backend))
        # presumably required so second-order gradients can be computed -- TODO confirm
        trainer.create_graph = True
        added = True

    if options['spectral_norm']:
        for kind in options['spectral_norm']:
            trainer.add_subscriber(SpectralNormSubscriber(kind, backend=backend))
        added = True

    if options['variance']:
        for kind in options['variance']:
            trainer.add_subscriber(VarianceSubscriber(kind, backend=backend))
        added = True

    if options['test_accuracy']:
        trainer.add_subscriber(TestAccuracySubscriber(dataset_test, trainer.model.forward,
                                                      frequency=trainer.batches_per_epoch,
                                                      batch_size=batch_size, backend=backend))
        added = True

    if options['train_accuracy']:
        trainer.add_subscriber(TrainAccuracySubscriber(subsample=subsample, backend=backend))
        added = True

    if options['ratio']:
        for kind1, kind2 in options['ratio']:
            ratio_subscriber = RatioSubscriber([kind1, kind2], subsample=subsample,
                                               backend=backend)
            trainer.add_subscriber(ratio_subscriber)
            # There can be multiple publications per type, but we know the RatioSubscriber only
            # publishes one.
            # NOTE(review): popitem() removes the entry from the mapping returned by
            # ``publications``; confirm that this mutation is harmless or intended.
            _pub_type, topics = ratio_subscriber.publications.popitem()
            trainer.add_subscriber(MessageMeanSubscriber(topics[0]))
        added = True

    if options['histogram']:
        for kind in options['histogram']:
            trainer.add_subscriber(HistogramSubscriber(kind, backend=backend))
        added = True

    if options['norm']:
        for kind in options['norm']:
            trainer.add_subscriber(NormSubscriber(kind, backend=backend))
        added = True

    if options['svcca']:
        trainer.add_subscriber(SVCCASubscriber(dataset_test, 500, trainer.model.forward,
                                               subsample=trainer.batches_per_epoch,
                                               backend=backend))
        added = True

    return added


def _run_training(trainer, epochs, batches_per_epoch, verbose):
    '''Call ``trainer.train_batch()`` ``batches_per_epoch`` times for each of ``epochs`` epochs.

    Parameters
    ----------
    trainer
        Object exposing ``train_batch()``
    epochs : int
    batches_per_epoch : int
    verbose : bool
        If truthy, show ``tqdm`` progress bars for epochs and batches
    '''
    epoch_iter = range(epochs)
    if verbose:
        epoch_iter = tqdm(epoch_iter, desc='Epoch')

    # NOTE: progressively freezing modules via ``exporter.freeze_module()`` as the epochs advance
    # was tried at this point at some time and is currently disabled.
    for _ in epoch_iter:
        batch_iter = range(batches_per_epoch)
        if verbose:
            # A tqdm bar closes itself once exhausted and stays silent when iterated again, so a
            # fresh bar is needed per epoch. ``leave=False`` clears finished inner bars.
            batch_iter = tqdm(batch_iter, desc='Batch', leave=False)
        for _ in batch_iter:
            trainer.train_batch()


def _main(dataset_str, model_str, batch_size, epochs, optimizer, **kwargs):
    '''Run the training procedure.

    Parameters
    ----------
    dataset_str : str
        Name of the dataset to use
    model_str : str
        Unqualified name of the model class to use
    batch_size : int
    epochs : int
    optimizer : str
        Name of the optimizer to use
    **kwargs
        Further options. Required keys (a missing one raises :class:`KeyError`): ``depth``
        (forwarded to the exporter), ``subsample``, ``visualisation`` (subscriber backend),
        ``verbose`` and the subscriber switches ``hessian``, ``spectral_norm``, ``variance``,
        ``test_accuracy``, ``train_accuracy``, ``ratio``, ``histogram``, ``norm``, ``svcca``.
        Optional keys: ``learning_rate`` (defaults to ``0.01``) and ``exponential_decay`` (if
        present and not ``None``, an ``ExponentialLR`` schedule is installed with this value).
    '''
    dataset_train, dataset_test = load_dataset(dataset_str, train_transforms=[ToTensor()],
                                               test_transforms=[ToTensor()])

    # NOTE: torch is deliberately imported inside the function. The original author observed
    # "torch referenced before assignment" when relying on a module-level import only; presumably
    # some function-scope import made ``torch`` a local name -- TODO confirm before hoisting.
    import torch

    bus = MessageBus('main')
    exporter = Exporter(depth=kwargs['depth'], module_filter=[torch.nn.Conv2d], message_bus=bus)
    trainer = Trainer(dataset_train, batch_size=batch_size, exporter=exporter)
    trainer.set_model(model_str)
    trainer.optimize(name=optimizer, lr=kwargs.get('learning_rate', 0.01))

    decay = kwargs.get('exponential_decay')
    if decay is not None:
        trainer.set_schedule(torch.optim.lr_scheduler.ExponentialLR, decay)

    if not _add_subscribers(trainer, dataset_train, dataset_test, batch_size, kwargs):
        warnings.warn('No subscriber was added, there will be no visualisation.')

    batches_per_epoch = trainer.batches_per_epoch
    print(f'Batches per epoch: {batches_per_epoch}')

    _run_training(trainer, epochs, batches_per_epoch, kwargs['verbose'])