import os

import torch
from ignite.engine import Engine, Events

# The remaining names used below (get_cifar10_loaders, the model, updater, evaluator, and
# logger classes, and the upper-case training constants) come from project-local modules
# that are not shown in this excerpt.


def main(width, depth, max_epochs, state_dict_path, device, data_dir, num_workers):
    """
    This function constructs and trains a model from scratch, without any knowledge
    transfer method applied.

    :param int width: factor for controlling the width of the model.
    :param int depth: factor for controlling the depth of the model.
    :param int max_epochs: maximum number of epochs for training the model.
    :param string state_dict_path: path to save the trained model to.
    :param string device: device to use for training the model (e.g. "cuda:0").
    :param string data_dir: directory to save and load the dataset.
    :param int num_workers: number of workers to use for loading the dataset.
    """
    # Define the device for training the model.
    device = torch.device(device)

    # Get data loaders for the CIFAR-10 dataset.
    train_loader, validation_loader, test_loader = get_cifar10_loaders(
        data_dir, batch_size=BATCH_SIZE, num_workers=num_workers
    )

    # Construct the model to be trained.
    model = WideResidualNetwork(depth=depth, width=width)
    model = model.to(device)

    # Define the optimizer and the learning rate scheduler.
    optimizer = torch.optim.SGD(
        model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY
    )
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=LEARNING_RATE_DECAY_MILESTONES, gamma=LEARNING_RATE_DECAY_FACTOR
    )

    # Construct the loss function to be used for training.
    criterion = torch.nn.CrossEntropyLoss()

    # Define the ignite engines for training and evaluation.
    batch_updater = BatchUpdaterWithoutTransfer(
        model=model, optimizer=optimizer, criterion=criterion, device=device
    )
    batch_evaluator = BatchEvaluator(model=model, device=device)
    trainer = Engine(batch_updater)
    evaluator = Engine(batch_evaluator)

    # Define and attach the progress bar, the loss metric, and the accuracy metrics.
    attach_pbar_and_metrics(trainer, evaluator)

    # The training engine updates the learning rate schedule at the end of each epoch.
    lr_updater = LearningRateUpdater(lr_scheduler=lr_scheduler)
    trainer.on(Events.EPOCH_COMPLETED(every=1))(lr_updater)

    # The training engine logs the training and evaluation metrics at the end of each epoch.
    metric_logger = MetricLogger(evaluator=evaluator, eval_loader=validation_loader)
    trainer.on(Events.EPOCH_COMPLETED(every=1))(metric_logger)

    # Train the model.
    trainer.run(train_loader, max_epochs=max_epochs)

    # Save the model to the pre-defined path. The model is moved to the CPU first, which is
    # desirable as the default device for loading it later.
    model.cpu()
    state_dict_dir = os.path.dirname(state_dict_path)
    if state_dict_dir:
        os.makedirs(state_dict_dir, exist_ok=True)
    torch.save(model.state_dict(), state_dict_path)
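
# A hypothetical invocation of the training entry point above. This is a minimal usage
# sketch: the concrete values (model size, epoch budget, paths) are illustrative
# assumptions, not values taken from the source.
if __name__ == "__main__":
    main(
        width=2,  # e.g. a WRN-16-2 configuration (assumed)
        depth=16,
        max_epochs=200,  # assumed epoch budget
        state_dict_path="models/wrn_16_2.pt",  # hypothetical output path
        device="cuda:0",
        data_dir="data",
        num_workers=4,
    )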
def main(
    student_depth,
    student_width,
    teacher_depth,
    teacher_width,
    max_epochs,
    variational_information_distillation_factor,
    knowledge_distillation_factor,
    knowledge_distillation_temperature,
    state_dict_path,
    teacher_state_dict_path,
    device,
    data_dir,
    num_workers,
):
    """
    This function constructs and trains a student model with knowledge transfer from
    the pretrained teacher model.

    :param int student_depth: factor for controlling the depth of the student model.
    :param int student_width: factor for controlling the width of the student model.
    :param int teacher_depth: factor for controlling the depth of the teacher model.
    :param int teacher_width: factor for controlling the width of the teacher model.
    :param int max_epochs: maximum number of epochs for training the student model.
    :param float variational_information_distillation_factor: scaling factor for
        variational information distillation.
    :param float knowledge_distillation_factor: scaling factor for knowledge distillation.
    :param float knowledge_distillation_temperature: degree of smoothing applied to the
        distributions when computing the Kullback-Leibler divergence for knowledge
        distillation.
    :param string state_dict_path: path to save the student model to.
    :param string teacher_state_dict_path: path to load the teacher model from.
    :param string device: device to use for training the model (e.g. "cuda:0").
    :param string data_dir: directory to save and load the dataset.
    :param int num_workers: number of workers to use for loading the dataset.
    """
    # Define the device for training the model.
    device = torch.device(device)

    # Get data loaders for the CIFAR-10 dataset.
    train_loader, validation_loader, test_loader = get_cifar10_loaders(
        data_dir, batch_size=BATCH_SIZE, num_workers=num_workers
    )

    # Construct the student model to be trained.
    model = StudentWideResidualNetwork(
        depth=student_depth, width=student_width, teacher_width=teacher_width
    )
    model = model.to(device)

    # Construct and load the teacher model for guiding the student model.
    teacher_model = TeacherWideResidualNetwork(
        depth=teacher_depth, width=teacher_width, load_path=teacher_state_dict_path
    )
    teacher_model = teacher_model.to(device)

    # Define the optimizer and the learning rate scheduler.
    optimizer = torch.optim.SGD(
        model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY
    )
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=LEARNING_RATE_DECAY_MILESTONES, gamma=LEARNING_RATE_DECAY_FACTOR
    )

    # Construct the loss function to be used for training.
    label_criterion = torch.nn.CrossEntropyLoss()
    teacher_logit_criterion = TemperatureScaledKLDivLoss(
        temperature=knowledge_distillation_temperature
    )
    teacher_feature_criterion = GaussianLoss()
    criterion = EnsembleKnowledgeTransferLoss(
        label_criterion=label_criterion,
        teacher_logit_criterion=teacher_logit_criterion,
        teacher_feature_criterion=teacher_feature_criterion,
        teacher_logit_factor=knowledge_distillation_factor,
        teacher_feature_factor=variational_information_distillation_factor,
    )

    # Define the ignite engines for training and evaluation.
    batch_updater = BatchUpdaterWithTransfer(
        model=model,
        teacher_model=teacher_model,
        optimizer=optimizer,
        criterion=criterion,
        device=device,
    )
    batch_evaluator = BatchEvaluator(model=model, device=device)
    trainer = Engine(batch_updater)
    evaluator = Engine(batch_evaluator)

    # Define and attach the progress bar, the loss metric, and the accuracy metrics.
    attach_pbar_and_metrics(trainer, evaluator)

    # The training engine updates the learning rate schedule at the end of each epoch.
    lr_updater = LearningRateUpdater(lr_scheduler=lr_scheduler)
    trainer.on(Events.EPOCH_COMPLETED(every=1))(lr_updater)

    # The training engine logs the training and evaluation metrics at the end of each epoch.
    metric_logger = MetricLogger(evaluator=evaluator, eval_loader=validation_loader)
    trainer.on(Events.EPOCH_COMPLETED(every=1))(metric_logger)

    # Train the model.
    trainer.run(train_loader, max_epochs=max_epochs)

    # Save the model to the pre-defined path. The model is moved to the CPU first, which is
    # desirable as the default device for loading it later.
    model.cpu()
    state_dict_dir = os.path.dirname(state_dict_path)
    if state_dict_dir:
        os.makedirs(state_dict_dir, exist_ok=True)
    torch.save(model.state_dict(), state_dict_path)
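
# For context, minimal sketches of the two transfer criteria used above. These follow the
# standard knowledge distillation objective (Hinton et al., 2015) and the variational
# information distillation objective (Ahn et al., 2019); the constructor and forward
# signatures are assumptions, not the project's actual implementations, which would live
# in a separate losses module. EnsembleKnowledgeTransferLoss presumably combines the
# label, logit, and feature terms as a weighted sum using the two scaling factors.
import torch.nn.functional as F


class TemperatureScaledKLDivLoss(torch.nn.Module):
    """Kullback-Leibler divergence between temperature-softened class distributions."""

    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature

    def forward(self, student_logits, teacher_logits):
        # Soften both distributions with the same temperature.
        log_p = F.log_softmax(student_logits / self.temperature, dim=1)
        q = F.softmax(teacher_logits / self.temperature, dim=1)
        # The T^2 factor keeps gradient magnitudes comparable across temperatures.
        return F.kl_div(log_p, q, reduction="batchmean") * self.temperature ** 2


class GaussianLoss(torch.nn.Module):
    """Negative Gaussian log-likelihood of a teacher feature map, given a mean and a
    variance predicted from the student (assumed signature)."""

    def forward(self, mean, variance, target):
        # -log N(target | mean, variance), up to an additive constant.
        return (0.5 * (torch.log(variance) + (target - mean) ** 2 / variance)).mean()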