Example #1
import os

import torch
from ignite.engine import Engine, Events

# Note: the project-specific helpers used below (get_cifar10_loaders, WideResidualNetwork,
# BatchUpdaterWithoutTransfer, BatchEvaluator, attach_pbar_and_metrics, LearningRateUpdater,
# MetricLogger) and the training constants (BATCH_SIZE, LEARNING_RATE, MOMENTUM, WEIGHT_DECAY,
# LEARNING_RATE_DECAY_MILESTONES, LEARNING_RATE_DECAY_FACTOR) are assumed to be imported from
# the surrounding project; their import paths are not shown in this snippet.

def main(width, depth, max_epochs, state_dict_path, device, data_dir, num_workers):
    """
    This function constructs and trains a model from scratch, without any knowledge transfer method applied. 

    :param int width: factor for controlling the width of the model.
    :param int depth: factor for controlling the depth of the model.
    :param int max_epochs: maximum number of epochs for training the model.
    :param string state_dict_path: path to save the trained model.
    :param string device: device to use for training the model.
    :param string data_dir: directory to save and load the dataset.
    :param int num_workers: number of workers to use for loading the dataset.
    """

    # Define the device for training the model.
    device = torch.device(device)

    # Get data loaders for the CIFAR-10 dataset.
    train_loader, validation_loader, test_loader = get_cifar10_loaders(
        data_dir, batch_size=BATCH_SIZE, num_workers=num_workers
    )

    # Construct the model to be trained.
    model = WideResidualNetwork(depth=depth, width=width)
    model = model.to(device)

    # Define optimizer and learning rate scheduler.
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=LEARNING_RATE_DECAY_MILESTONES, gamma=LEARNING_RATE_DECAY_FACTOR
    )

    # Construct the loss function to be used for training.
    criterion = torch.nn.CrossEntropyLoss()

    # Define the ignite engines for training and evaluation.
    batch_updater = BatchUpdaterWithoutTransfer(model=model, optimizer=optimizer, criterion=criterion, device=device)
    batch_evaluator = BatchEvaluator(model=model, device=device)
    trainer = Engine(batch_updater)
    evaluator = Engine(batch_evaluator)

    # Define and attach the progress bar, loss metric, and the accuracy metrics.
    attach_pbar_and_metrics(trainer, evaluator)

    # The training engine updates the learning rate schedule at the end of each epoch.
    lr_updater = LearningRateUpdater(lr_scheduler=lr_scheduler)
    trainer.on(Events.EPOCH_COMPLETED(every=1))(lr_updater)

    # The training engine logs the training and the evaluation metrics at the end of each epoch.
    metric_logger = MetricLogger(evaluator=evaluator, eval_loader=validation_loader)
    trainer.on(Events.EPOCH_COMPLETED(every=1))(metric_logger)

    # Train the model.
    trainer.run(train_loader, max_epochs=max_epochs)

    # Save the model to the pre-defined path. The model is moved to the CPU first so that it is
    # loaded onto the CPU by default.
    model.cpu()
    state_dict_dir = os.path.dirname(state_dict_path)
    os.makedirs(state_dict_dir, exist_ok=True)
    torch.save(model.state_dict(), state_dict_path)
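
The helper classes passed to the ignite engines above (BatchUpdaterWithoutTransfer, BatchEvaluator, LearningRateUpdater, MetricLogger) are defined elsewhere in the project and are not shown in this snippet. The following is a minimal sketch of how two of them could be implemented, based only on how they are used above; the bodies are illustrative assumptions, not the project's actual code.

class BatchUpdaterWithoutTransfer:
    # Process function for the training Engine: runs one optimization step per batch
    # and returns the batch loss, which an attached loss metric can consume.
    def __init__(self, model, optimizer, criterion, device):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device

    def __call__(self, engine, batch):
        self.model.train()
        inputs, labels = batch
        inputs, labels = inputs.to(self.device), labels.to(self.device)
        self.optimizer.zero_grad()
        loss = self.criterion(self.model(inputs), labels)
        loss.backward()
        self.optimizer.step()
        return loss.item()


class LearningRateUpdater:
    # Handler attached to Events.EPOCH_COMPLETED: advances the MultiStepLR schedule once per epoch.
    def __init__(self, lr_scheduler):
        self.lr_scheduler = lr_scheduler

    def __call__(self, engine):
        self.lr_scheduler.step()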
Example #2
import os

import torch
from ignite.engine import Engine, Events

# Note: the project-specific helpers used below (get_cifar10_loaders, StudentWideResidualNetwork,
# TeacherWideResidualNetwork, TemperatureScaledKLDivLoss, GaussianLoss,
# EnsembleKnowledgeTransferLoss, BatchUpdaterWithTransfer, BatchEvaluator,
# attach_pbar_and_metrics, LearningRateUpdater, MetricLogger) and the training constants
# (BATCH_SIZE, LEARNING_RATE, MOMENTUM, WEIGHT_DECAY, LEARNING_RATE_DECAY_MILESTONES,
# LEARNING_RATE_DECAY_FACTOR) are assumed to be imported from the surrounding project;
# their import paths are not shown in this snippet.

def main(
    student_depth,
    student_width,
    teacher_depth,
    teacher_width,
    max_epochs,
    variational_information_distillation_factor,
    knowledge_distillation_factor,
    knowledge_distillation_temperature,
    state_dict_path,
    teacher_state_dict_path,
    device,
    data_dir,
    num_workers,
):
    """
    This function constructs and trains a student model with knowledge transfer from the pretrained teacher model. 
    
    :param int student_depth: factor for controlling the depth of the student model.
    :param int student_width: factor for controlling the width of the student model.
    :param int teacher_depth: factor for controlling the depth of the teacher model. 
    :param int teacher_width: factor for controlling the width of the teacher model.
    :param int max_epochs: maximum number of epochs for training the student model.
    :param float variational_information_distillation_factor: scaling factor for variational information distillation.
    :param float knowledge_distillation_factor: scaling factor for knowledge distillation.
    :param float knowledge_distillation_temperature: temperature controlling the degree of smoothing applied to the
    distributions when computing the Kullback-Leibler divergence for knowledge distillation.
    :param string state_dict_path: path to save the student model.
    :param string teacher_state_dict_path: path to load the teacher model from.
    :param string device: device to use for training the model.
    :param string data_dir: directory to save and load the dataset.
    :param int num_workers: number of workers to use for loading the dataset.
    """

    # Define the device for training the model.
    device = torch.device(device)

    # Get data loaders for the CIFAR-10 dataset.
    train_loader, validation_loader, test_loader = get_cifar10_loaders(
        data_dir, batch_size=BATCH_SIZE, num_workers=num_workers)

    # Construct the student model to be trained.
    model = StudentWideResidualNetwork(depth=student_depth,
                                       width=student_width,
                                       teacher_width=teacher_width)
    model = model.to(device)

    # Construct and load the teacher model for guiding the student model.
    teacher_model = TeacherWideResidualNetwork(
        depth=teacher_depth,
        width=teacher_width,
        load_path=teacher_state_dict_path)
    teacher_model = teacher_model.to(device)

    # Define optimizer and learning rate scheduler.
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=LEARNING_RATE,
                                momentum=MOMENTUM,
                                weight_decay=WEIGHT_DECAY)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=LEARNING_RATE_DECAY_MILESTONES,
        gamma=LEARNING_RATE_DECAY_FACTOR)

    # Construct the loss function to be used for training.
    label_criterion = torch.nn.CrossEntropyLoss()
    teacher_logit_criterion = TemperatureScaledKLDivLoss(
        temperature=knowledge_distillation_temperature)
    teacher_feature_criterion = GaussianLoss()
    criterion = EnsembleKnowledgeTransferLoss(
        label_criterion=label_criterion,
        teacher_logit_criterion=teacher_logit_criterion,
        teacher_feature_criterion=teacher_feature_criterion,
        teacher_logit_factor=knowledge_distillation_factor,
        teacher_feature_factor=variational_information_distillation_factor,
    )

    # Define the ignite engines for training and evaluation.
    batch_updater = BatchUpdaterWithTransfer(model=model,
                                             teacher_model=teacher_model,
                                             optimizer=optimizer,
                                             criterion=criterion,
                                             device=device)
    batch_evaluator = BatchEvaluator(model=model, device=device)
    trainer = Engine(batch_updater)
    evaluator = Engine(batch_evaluator)

    # Define and attach the progress bar, loss metric, and the accuracy metrics.
    attach_pbar_and_metrics(trainer, evaluator)

    # The training engine updates the learning rate schedule at the end of each epoch.
    lr_updater = LearningRateUpdater(lr_scheduler=lr_scheduler)
    trainer.on(Events.EPOCH_COMPLETED(every=1))(lr_updater)

    # The training engine logs the training and the evaluation metrics at the end of each epoch.
    metric_logger = MetricLogger(evaluator=evaluator,
                                 eval_loader=validation_loader)
    trainer.on(Events.EPOCH_COMPLETED(every=1))(metric_logger)

    # Train the model.
    trainer.run(train_loader, max_epochs=max_epochs)

    # Save the model to the pre-defined path. The model is moved to the CPU first so that it is
    # loaded onto the CPU by default.
    model.cpu()
    state_dict_dir = os.path.dirname(state_dict_path)
    os.makedirs(state_dict_dir, exist_ok=True)
    torch.save(model.state_dict(), state_dict_path)
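
The knowledge-transfer losses used above (TemperatureScaledKLDivLoss, GaussianLoss, EnsembleKnowledgeTransferLoss) come from the surrounding project. Below is a minimal sketch of the temperature-scaled KL term and of how the three terms might be combined, assuming a standard Hinton-style distillation loss; the forward signatures and the T^2 scaling are assumptions, not the project's actual interface.

import torch
import torch.nn.functional as F


class TemperatureScaledKLDivLoss(torch.nn.Module):
    # KL divergence between temperature-softened student and teacher logit distributions.
    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature

    def forward(self, student_logits, teacher_logits):
        log_p = F.log_softmax(student_logits / self.temperature, dim=1)
        q = F.softmax(teacher_logits / self.temperature, dim=1)
        # The T^2 factor (standard in Hinton-style distillation) keeps gradient
        # magnitudes comparable across temperatures.
        return F.kl_div(log_p, q, reduction="batchmean") * self.temperature ** 2


class EnsembleKnowledgeTransferLoss(torch.nn.Module):
    # Weighted sum of the cross-entropy label loss, the logit-matching (knowledge
    # distillation) loss, and the feature-matching (variational information
    # distillation) loss.
    def __init__(self, label_criterion, teacher_logit_criterion, teacher_feature_criterion,
                 teacher_logit_factor, teacher_feature_factor):
        super().__init__()
        self.label_criterion = label_criterion
        self.teacher_logit_criterion = teacher_logit_criterion
        self.teacher_feature_criterion = teacher_feature_criterion
        self.teacher_logit_factor = teacher_logit_factor
        self.teacher_feature_factor = teacher_feature_factor

    def forward(self, logits, labels, teacher_logits, features, teacher_features):
        loss = self.label_criterion(logits, labels)
        loss = loss + self.teacher_logit_factor * self.teacher_logit_criterion(logits, teacher_logits)
        # Sum the feature-matching loss over the pairs of student/teacher feature maps.
        loss = loss + self.teacher_feature_factor * sum(
            self.teacher_feature_criterion(feature, teacher_feature)
            for feature, teacher_feature in zip(features, teacher_features)
        )
        return loss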