# Example #1
def hinton_train(model, student_model, T, alpha, optimizer, device, train_loader, is_debug=False):
    """Run one epoch of Hinton knowledge distillation.

    Teacher (``model``) and student logits are produced via their custom
    ``nforward`` method and combined by ``hinton_distillation`` with
    temperature ``T`` and mixing weight ``alpha``; ``optimizer`` is stepped
    once per batch.

    Args:
        model: teacher network (possibly wrapped in ``DataParallel``).
        student_model: student network (possibly wrapped in ``DataParallel``).
        T: distillation temperature.
        alpha: weight between the distillation and hard-label loss terms.
        optimizer: optimizer updating the trained parameters.
        device: device the batches are moved to.
        train_loader: iterable yielding ``(data, target)`` batches.
        is_debug: if True, process only the first batch (loss is still
            averaged over ``len(train_loader)``, matching original behavior).

    Returns:
        float: total loss divided by ``len(train_loader)``; 0.0 for an
        empty loader.
    """
    total_loss = 0.

    # Guard: with an empty loader the trailing `del loss` would raise
    # NameError and the return would divide by zero.
    if len(train_loader) == 0:
        return 0.

    # One epoch step gradient for target
    optimizer.zero_grad()
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        # Under DataParallel the custom `nforward` lives on the wrapped
        # `.module`, not on the wrapper itself.
        if torch.cuda.device_count() > 1:
            teacher_logits = model.module.nforward(data)
            student_logits = student_model.module.nforward(data)
        else:
            teacher_logits = model.nforward(data)
            student_logits = student_model.nforward(data)

        loss = hinton_distillation(teacher_logits, student_logits, target, T, alpha)
        total_loss += float(loss.item())
        loss.backward()
        # torch.nn.utils.clip_grad_value_(model.parameters(), 10)
        optimizer.step()
        if is_debug:
            break

    # Drop references to the last batch's graph tensors so memory can be
    # reclaimed before the caller's next step.
    del loss
    del teacher_logits
    del student_logits
    # torch.cuda.empty_cache()
    return total_loss / len(train_loader)
# Example #2
def mmd_hinton_train_alt(current_epoch,
                         epochs,
                         teacher_model,
                         student_model,
                         optimizer_da,
                         optimizer_kd,
                         device,
                         source_dataloader,
                         target_dataloader,
                         T,
                         alpha,
                         beta,
                         kd_loss_fn,
                         is_debug=False,
                         **kwargs):
    """One epoch of alternating MMD domain adaptation and distillation.

    Per batch, two separate optimization steps are taken:
      1. ``optimizer_da`` minimizes the teacher's classification + MMD loss
         (scaled by ``1 - beta`` and a ramped ``gamma``).
      2. ``optimizer_kd`` minimizes the student's Hinton distillation loss
         on both source (with labels) and target (without) logits, scaled
         by ``beta``.

    Args:
        current_epoch, epochs: kept for interface compatibility; not used
            inside this function.
        teacher_model, student_model: models returning
            ``(source_logits, mmd_loss, target_logits)`` when called with
            ``(data_source, data_target)``.
        optimizer_da: optimizer for the domain-adaptation step.
        optimizer_kd: optimizer for the distillation step.
        device: device batches are moved to.
        source_dataloader, target_dataloader: labeled source / unlabeled
            target loaders; the target iterator is restarted when exhausted.
        T: distillation temperature.
        alpha: distillation mixing weight (forwarded to
            ``hinton_distillation``).
        beta: balance between the DA loss (``1 - beta``) and KD loss
            (``beta``).
        kd_loss_fn: divergence used by the distillation helpers.
        is_debug: if True, process only the first batch.
        **kwargs: must contain ``logger`` (may be None); optional
            ``logger_id`` suffix for the logged scalar names.

    Returns:
        Tuple of five floats, each averaged over ``len(source_dataloader)``:
        (total loss, teacher DA loss, total KD loss, source KD loss,
        target KD loss).
    """
    logger = kwargs["logger"]
    logger_id = kwargs.get("logger_id", "")

    #teacher_model.train()
    #student_model.train()

    total_loss = 0.
    teacher_da_temp_loss = 0.
    kd_temp_loss = 0.
    kd_target_loss = 0.
    kd_source_loss = 0.

    iter_source = iter(source_dataloader)
    iter_target = iter(target_dataloader)

    for i in range(1, len(source_dataloader) + 1):

        # Bug fix: `iterator.next()` is Python-2 syntax and raises
        # AttributeError on Python 3 / modern DataLoader iterators; the
        # built-in next() is the correct call.
        data_source, label_source = next(iter_source)
        data_target, _ = next(iter_target)

        # Keep source/target batch sizes aligned: restart the (shorter)
        # target loader on its partial final batch, and truncate the target
        # batch when the *source* final batch is partial.
        if data_source.shape[0] != data_target.shape[0]:
            if data_target.shape[0] < source_dataloader.batch_size:
                iter_target = iter(target_dataloader)
                data_target, _ = next(iter_target)

            if data_source.shape[0] < source_dataloader.batch_size:
                data_target = data_target[:data_source.shape[0]]

        data_source, label_source = data_source.to(device), label_source.to(
            device)
        data_target = data_target.to(device)

        # Teacher domain adaptation
        optimizer_da.zero_grad()
        teacher_label_source_pred, teacher_loss_mmd, _ = teacher_model(
            data_source, data_target)
        teacher_source_loss_cls = F.nll_loss(
            F.log_softmax(teacher_label_source_pred, dim=1), label_source)
        # gamma ramps 0 -> 1 over the epoch (DANN-style schedule) so the
        # MMD term is phased in gradually.
        gamma = 2 / (1 + np.exp(-10 * (i) / len(source_dataloader))) - 1
        teacher_da_mmd_loss = (1 - beta) * (teacher_source_loss_cls +
                                            gamma * teacher_loss_mmd)
        teacher_da_temp_loss += teacher_da_mmd_loss.mean().item()

        # Possible to do end2end or alternative here: For now it's alternative
        teacher_da_mmd_loss.mean().backward()
        optimizer_da.step()  # May need to have 2 optimizers
        optimizer_da.zero_grad()

        #Knowledge distillation: We only learn on target logits now

        optimizer_kd.zero_grad()
        # Fresh forward pass after the DA update so distillation sees the
        # updated teacher.
        teacher_source_logits, teacher_loss_mmd, teacher_target_logits = teacher_model(
            data_source, data_target)
        student_source_logits, student_loss_mmd, student_target_logits = student_model(
            data_source, data_target)

        source_kd_loss = hinton_distillation(teacher_source_logits,
                                             student_source_logits,
                                             label_source, T, alpha,
                                             kd_loss_fn).abs()
        target_kd_loss = hinton_distillation_wo_ce(teacher_target_logits,
                                                   student_target_logits, T,
                                                   kd_loss_fn).abs()

        kd_source_loss += source_kd_loss.mean().item()
        kd_target_loss += target_kd_loss.mean().item()

        kd_loss = beta * (target_kd_loss + source_kd_loss)
        kd_temp_loss += kd_loss.mean().item()
        total_loss += teacher_da_mmd_loss.mean().item() + kd_loss.mean().item()

        kd_loss.mean().backward()
        optimizer_kd.step()
        optimizer_kd.zero_grad()

        if logger is not None:
            # Bug fix: the scalar names had no "{}" placeholder, so
            # `.format(logger_id)` silently discarded the id; also use
            # .mean().item() to match the accumulation above and stay safe
            # for non-scalar losses.
            logger.log_scalar("iter_total_training_loss{}".format(logger_id),
                              teacher_da_mmd_loss.mean().item() +
                              kd_loss.mean().item(), i)
            logger.log_scalar("iter_total_da_loss{}".format(logger_id),
                              teacher_da_mmd_loss.mean().item(), i)
            logger.log_scalar("iter_total_kd_loss{}".format(logger_id),
                              kd_loss.mean().item(), i)

        if is_debug:
            break

    del kd_loss
    del teacher_da_mmd_loss
    # torch.cuda.empty_cache()
    return total_loss / len(source_dataloader), teacher_da_temp_loss / len(source_dataloader), \
           kd_temp_loss / len(source_dataloader), kd_source_loss / len(source_dataloader), kd_target_loss / len(source_dataloader)