Code Example #1
File: kd_train.py  Project: imatif17/KD-UDA
def hinton_train_without_label(teacher_model, student_model, T, optimizer, device, train_loader, is_debug=False):
    total_loss = 0.

    # One epoch of gradient steps distilling the teacher into the student on this loader
    optimizer.zero_grad()
    teacher_model.train()
    student_model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
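        # When wrapped in a multi-GPU wrapper such as nn.DataParallel, the raw network sits under .module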
        if torch.cuda.device_count() > 1:
            teacher_logits = teacher_model.module.nforward(data)
            student_logits = student_model.module.nforward(data)
        else:
            teacher_logits = teacher_model.nforward(data)
            student_logits = student_model.nforward(data)

        loss = hinton_distillation_wo_ce(teacher_logits, student_logits, T)
        total_loss += float(loss.item())
        loss.backward()
        # torch.nn.utils.clip_grad_value_(model.parameters(), 10)
        optimizer.step()
        if is_debug:
            break

    del loss
    del teacher_logits
    del student_logits
    # torch.cuda.empty_cache()
    return total_loss / len(train_loader)
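
The loss above comes from hinton_distillation_wo_ce, which is defined elsewhere in the repository and is not shown on this page. For orientation only, a minimal sketch of a temperature-scaled distillation loss without a cross-entropy term, following the standard Hinton et al. (2015) formulation (the actual KD-UDA implementation may differ in reduction or scaling), could look like this:

import torch
import torch.nn.functional as F

def hinton_distillation_wo_ce_sketch(teacher_logits, student_logits, T):
    # Soften both distributions with temperature T and compare them with KL divergence.
    # The T*T factor keeps the gradient scale roughly independent of the chosen temperature.
    soft_teacher = F.softmax(teacher_logits / T, dim=1)
    log_soft_student = F.log_softmax(student_logits / T, dim=1)
    return F.kl_div(log_soft_student, soft_teacher, reduction="batchmean") * (T * T)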
Code Example #2
def grl_multi_target_hinton_train_alt(current_ep,
                                      epochs,
                                      teacher_models,
                                      student_model,
                                      optimizer_das,
                                      optimizer_kd,
                                      device,
                                      source_dataloader,
                                      targets_dataloader,
                                      T,
                                      alpha,
                                      beta,
                                      gamma,
                                      batch_norm,
                                      is_cst,
                                      is_debug=False,
                                      **kwargs):

    logger = kwargs["logger"]
    if "logger_id" not in kwargs:
        logger_id = ""
    else:
        logger_id = kwargs["logger_id"]

    if batch_norm:
        for teacher_model in teacher_models:
            teacher_model.train()
        student_model.train()

    total_losses = torch.zeros(len(teacher_models))
    teacher_da_temp_losses = torch.zeros(len(teacher_models))
    kd_temp_losses = torch.zeros(len(teacher_models))
    kd_target_loss = 0.
    kd_source_loss = 0.

    iter_targets = [0] * len(targets_dataloader)
    for i, d in enumerate(targets_dataloader):
        iter_targets[i] = iter(d)

    iter_source = iter(source_dataloader)

    for i in range(1, len(source_dataloader) + 1):

        data_source, label_source = next(iter_source)
        data_source = data_source.to(device)
        label_source = label_source.to(device)

        for ix, it in enumerate(iter_targets):
            try:
                data_target, _ = next(it)
            except StopIteration:
                # Restart an exhausted target loader and keep the fresh iterator for later batches
                iter_targets[ix] = iter(targets_dataloader[ix])
                data_target, _ = next(iter_targets[ix])

            if data_target.shape[0] != data_source.shape[0]:
                data_target = data_target[:data_source.shape[0]]
            data_target = data_target.to(device)
            optimizer_das[ix].zero_grad()
            p = float(i + (current_ep - 1) *
                      len(source_dataloader)) / epochs / len(source_dataloader)
            delta = 2. / (1. + np.exp(-10 * p)) - 1
            teacher_label_source_pred, teacher_source_loss_adv = teacher_models[
                ix](data_source, delta)
            # Source classification loss: nll_loss expects log-probabilities
            teacher_source_loss_cls = F.nll_loss(
                F.log_softmax(teacher_label_source_pred, dim=1), label_source)

            _, teacher_target_loss_adv = teacher_models[ix](data_target,
                                                            delta,
                                                            source=False)
            teacher_loss_adv = teacher_source_loss_adv + teacher_target_loss_adv

            teacher_da_grl_loss = (1 - beta) * (teacher_source_loss_cls +
                                                gamma * teacher_loss_adv)
            teacher_da_temp_losses[ix] += teacher_da_grl_loss.mean().item()

            teacher_da_grl_loss.mean().backward()
            optimizer_das[ix].step()  # May need to have 2 optimizers
            optimizer_das[ix].zero_grad()

            optimizer_kd.zero_grad()
            teacher_source_logits, _ = teacher_models[ix](data_source,
                                                          delta,
                                                          source=True)
            teacher_target_logits, _ = teacher_models[ix](data_target,
                                                          delta,
                                                          source=True)

            student_source_logits, _ = student_model(data_source,
                                                     delta,
                                                     source=True)
            student_target_logits, student_target_loss_adv = student_model(
                data_target, delta, source=False)

            source_kd_loss = hinton_distillation_sw(teacher_source_logits,
                                                    student_source_logits,
                                                    label_source, T,
                                                    alpha).abs()
            if is_cst:
                target_kd_loss = hinton_distillation_wo_ce(
                    teacher_target_logits, student_target_logits,
                    T).abs() + alpha * student_target_loss_adv
            else:
                target_kd_loss = hinton_distillation_wo_ce(
                    teacher_target_logits, student_target_logits, T).abs()

            kd_source_loss += source_kd_loss.mean().item()
            kd_target_loss += target_kd_loss.mean().item()

            kd_loss = beta * (target_kd_loss + source_kd_loss)
            kd_temp_losses[ix] += kd_loss.mean().item()
            total_losses[ix] += teacher_da_grl_loss.mean().item() + kd_loss.mean().item()

            kd_loss.mean().backward()
            optimizer_kd.step()
            optimizer_kd.zero_grad()

        if is_debug:
            break

    del kd_loss
    del teacher_da_grl_loss
    # torch.cuda.empty_cache()
    return total_losses / len(source_dataloader), teacher_da_temp_losses / len(source_dataloader), \
           kd_temp_losses / len(source_dataloader)
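
The delta passed into the teacher and student models follows the usual DANN schedule 2 / (1 + exp(-10 * p)) - 1, which typically acts as the coefficient of a gradient reversal layer between the feature extractor and a domain classifier. The models' internals are not part of this excerpt, so the following is only an assumption about how such a delta-scaled adversarial branch is commonly wired in PyTorch, not the repository's actual model code:

import torch

class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; flips and scales the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, delta):
        ctx.delta = delta
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse the gradient reaching the feature extractor and scale it by delta.
        return grad_output.neg() * ctx.delta, None

def grad_reverse(x, delta=1.0):
    return GradReverse.apply(x, delta)

With such a layer, a domain classifier learns to separate source from target features while the reversed gradient pushes the feature extractor toward domain-invariant representations, which is presumably where the teacher_*_loss_adv terms above come from.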
Code Example #3
File: kd_da_alt.py  Project: imatif17/KD-UDA
def mmd_hinton_train_alt(current_epoch,
                         epochs,
                         teacher_model,
                         student_model,
                         optimizer_da,
                         optimizer_kd,
                         device,
                         source_dataloader,
                         target_dataloader,
                         T,
                         alpha,
                         beta,
                         kd_loss_fn,
                         is_debug=False,
                         **kwargs):

    logger = kwargs["logger"]
    if "logger_id" not in kwargs:
        logger_id = ""
    else:
        logger_id = kwargs["logger_id"]

    #teacher_model.train()
    #student_model.train()

    total_loss = 0.
    teacher_da_temp_loss = 0.
    kd_temp_loss = 0.
    kd_target_loss = 0.
    kd_source_loss = 0.

    iter_source = iter(source_dataloader)
    iter_target = iter(target_dataloader)

    for i in range(1, len(source_dataloader) + 1):

        data_source, label_source = next(iter_source)
        data_target, _ = next(iter_target)

        if data_source.shape[0] != data_target.shape[0]:
            if data_target.shape[0] < source_dataloader.batch_size:
                iter_target = iter(target_dataloader)
                data_target, _ = next(iter_target)

            if data_source.shape[0] < source_dataloader.batch_size:
                data_target = data_target[:data_source.shape[0]]

        data_source, label_source = data_source.to(device), label_source.to(
            device)
        data_target = data_target.to(device)

        # Teacher domain adaptation
        optimizer_da.zero_grad()
        teacher_label_source_pred, teacher_loss_mmd, _ = teacher_model(
            data_source, data_target)
        teacher_source_loss_cls = F.nll_loss(
            F.log_softmax(teacher_label_source_pred, dim=1), label_source)
        gamma = 2 / (1 + np.exp(-10 * (i) / len(source_dataloader))) - 1
        teacher_da_mmd_loss = (1 - beta) * (teacher_source_loss_cls +
                                            gamma * teacher_loss_mmd)
        teacher_da_temp_loss += teacher_da_mmd_loss.mean().item()

        # Could be trained end-to-end or in alternating steps; here the DA and KD updates alternate
        teacher_da_mmd_loss.mean().backward()
        optimizer_da.step()  # May need to have 2 optimizers
        optimizer_da.zero_grad()

        # Knowledge distillation: distill the teacher's source and target logits into the student

        optimizer_kd.zero_grad()
        teacher_source_logits, teacher_loss_mmd, teacher_target_logits = teacher_model(
            data_source, data_target)
        student_source_logits, student_loss_mmd, student_target_logits = student_model(
            data_source, data_target)

        source_kd_loss = hinton_distillation(teacher_source_logits,
                                             student_source_logits,
                                             label_source, T, alpha,
                                             kd_loss_fn).abs()
        target_kd_loss = hinton_distillation_wo_ce(teacher_target_logits,
                                                   student_target_logits, T,
                                                   kd_loss_fn).abs()

        kd_source_loss += source_kd_loss.mean().item()
        kd_target_loss += target_kd_loss.mean().item()

        kd_loss = beta * (target_kd_loss + source_kd_loss)
        kd_temp_loss += kd_loss.mean().item()
        total_loss += teacher_da_mmd_loss.mean().item() + kd_loss.mean().item()

        kd_loss.mean().backward()
        optimizer_kd.step()
        optimizer_kd.zero_grad()

        if logger is not None:
            logger.log_scalar("iter_total_training_loss".format(logger_id),
                              teacher_da_mmd_loss.item() + kd_loss.item(), i)
            logger.log_scalar("iter_total_da_loss".format(logger_id),
                              teacher_da_mmd_loss.item(), i)
            logger.log_scalar("iter_total_kd_loss".format(logger_id),
                              kd_loss.item(), i)

        if is_debug:
            break

    del kd_loss
    del teacher_da_mmd_loss
    # torch.cuda.empty_cache()
    return total_loss / len(source_dataloader), teacher_da_temp_loss / len(source_dataloader), \
           kd_temp_loss / len(source_dataloader), kd_source_loss / len(source_dataloader), kd_target_loss / len(source_dataloader)
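
teacher_loss_mmd is produced inside the model's forward pass, which again is not shown on this page. As a rough idea of what a maximum mean discrepancy term between source and target feature batches can look like, here is a single-RBF-kernel sketch; the actual KD-UDA model may use a multi-kernel variant with different bandwidths:

import torch

def rbf_mmd2(source_feats, target_feats, sigma=1.0):
    # Biased estimate of squared MMD between two batches of feature vectors.
    def rbf_kernel(a, b):
        # Pairwise squared Euclidean distances fed through a Gaussian kernel.
        sq_dists = torch.cdist(a, b) ** 2
        return torch.exp(-sq_dists / (2 * sigma ** 2))

    k_ss = rbf_kernel(source_feats, source_feats).mean()
    k_tt = rbf_kernel(target_feats, target_feats).mean()
    k_st = rbf_kernel(source_feats, target_feats).mean()
    return k_ss + k_tt - 2 * k_st

A small MMD means the two feature distributions are hard to distinguish under the chosen kernel, which is why it is minimized together with the source classification loss in teacher_da_mmd_loss above.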