Example #1
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_adv: DomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: StepwiseLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of domain-adversarial training for an image classifier.

    Each of ``args.iters_per_epoch`` iterations draws one labelled source
    batch and one target batch (whose labels are discarded). Both batches go
    through ``model`` in a single concatenated forward pass, and the loop
    minimises::

        cross_entropy(source logits, source labels)
        + args.trade_off * domain_adv(source features, target features)

    Timing, total loss, source accuracy and discriminator accuracy are
    printed every ``args.print_freq`` iterations.

    Args:
        train_source_iter: yields ``(images, labels)`` source batches.
        train_target_iter: yields ``(images, _)`` target batches; the second
            element is ignored.
        model: returns ``(logits, features)`` for a batch of images.
        domain_adv: adversarial loss on ``(source features, target
            features)``. Its ``domain_discriminator_accuracy`` is read after
            each call and is expected to support ``.item()``.
        optimizer: zeroed and stepped once per iteration.
        lr_scheduler: stepped at the *start* of every iteration, before the
            optimizer update.
        epoch: epoch index, used only in the progress prefix.
        args: must provide ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.
    """
    timer_batch = AverageMeter('Time', ':5.2f')
    timer_data = AverageMeter('Data', ':5.2f')
    meter_loss = AverageMeter('Loss', ':6.2f')
    meter_cls_acc = AverageMeter('Cls Acc', ':3.1f')
    meter_dom_acc = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [timer_batch, timer_data, meter_loss, meter_cls_acc, meter_dom_acc],
        prefix="Epoch: [{}]".format(epoch))

    # both the classifier and the domain discriminator are trained
    model.train()
    domain_adv.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        # the LR schedule is advanced before this iteration's update
        lr_scheduler.step()

        # time spent waiting since the previous iteration finished
        timer_data.update(time.time() - tick)

        src_images, src_labels = next(train_source_iter)
        tgt_images, _ = next(train_target_iter)

        src_images = src_images.to(device)
        tgt_images = tgt_images.to(device)
        src_labels = src_labels.to(device)
        n_src = src_images.size(0)

        # joint forward pass over source + target, then split back in halves
        joint_batch = torch.cat((src_images, tgt_images), dim=0)
        logits, feats = model(joint_batch)
        src_logits, _ = logits.chunk(2, dim=0)
        src_feats, tgt_feats = feats.chunk(2, dim=0)

        cls_loss = F.cross_entropy(src_logits, src_labels)
        adv_loss = domain_adv(src_feats, tgt_feats)
        dom_acc = domain_adv.domain_discriminator_accuracy
        total_loss = cls_loss + adv_loss * args.trade_off

        src_acc = accuracy(src_logits, src_labels)[0]

        meter_loss.update(total_loss.item(), n_src)
        meter_cls_acc.update(src_acc.item(), n_src)
        meter_dom_acc.update(dom_acc.item(), n_src)

        # backward pass and parameter update
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # wall-clock duration of the whole iteration
        timer_batch.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
Example #2
0
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          source_model: ImageClassifier, target_model: ImageClassifier, domain_adv: DomainAdversarialLoss,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of adversarial feature alignment between two networks.

    Source images are encoded by ``source_model`` and target images by
    ``target_model``. Only the second element of each model's output (the
    features) is used. The sole objective is ``domain_adv(f_s, f_t)``: no
    classification loss is computed, and the labels from both iterators are
    ignored.

    Only ``target_model`` and ``domain_adv`` are switched to training mode
    here; ``source_model``'s mode is left as the caller set it. Which
    parameters actually change depends on how ``optimizer`` was built, which
    is not visible here.

    NOTE(review): gradients are still back-propagated through
    ``source_model``. If it is meant to stay frozen, confirm that its
    parameters are excluded from ``optimizer``.

    Args:
        train_source_iter: yields ``(images, _)`` source batches.
        train_target_iter: yields ``(images, _)`` target batches.
        source_model: encoder for source images; returns ``(_, features)``.
        target_model: encoder for target images; returns ``(_, features)``.
        domain_adv: adversarial loss on ``(f_s, f_t)``. Its
            ``domain_discriminator_accuracy`` is read after the update.
        optimizer: zeroed and stepped once per iteration.
        lr_scheduler: stepped once per iteration, right after the optimizer.
        epoch: epoch index, used only in the progress prefix.
        args: must provide ``iters_per_epoch`` and ``print_freq``.
    """
    timer_batch = AverageMeter('Time', ':5.2f')
    timer_data = AverageMeter('Data', ':5.2f')
    meter_transfer = AverageMeter('Transfer Loss', ':6.2f')
    meter_dom_acc = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [timer_batch, timer_data, meter_transfer, meter_dom_acc],
        prefix="Epoch: [{}]".format(epoch))

    # train the target encoder and the discriminator
    target_model.train()
    domain_adv.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        src_images, _ = next(train_source_iter)
        tgt_images, _ = next(train_target_iter)

        src_images = src_images.to(device)
        tgt_images = tgt_images.to(device)

        # "data" time here also covers the host-to-device copies above
        timer_data.update(time.time() - tick)

        _, src_feats = source_model(src_images)
        _, tgt_feats = target_model(tgt_images)
        adv_loss = domain_adv(src_feats, tgt_feats)

        # backward pass, parameter update, then advance the LR schedule
        optimizer.zero_grad()
        adv_loss.backward()
        optimizer.step()
        lr_scheduler.step()

        n_src = src_images.size(0)
        meter_transfer.update(adv_loss.item(), n_src)
        dom_acc = domain_adv.domain_discriminator_accuracy
        meter_dom_acc.update(dom_acc.item(), n_src)

        # wall-clock duration of the whole iteration
        timer_batch.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
Example #3
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_adv_D: DomainAdversarialLoss,
          domain_adv_D_0: DomainAdversarialLoss, importance_weight_module,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int,
          args: argparse.Namespace):
    """Run one epoch of importance-weighted adversarial training.

    Per iteration the minimised objective is::

        cross_entropy(y_s, labels_s)
        + 1.5 * args.trade_off * D(f_s.detach(), f_t.detach())
        + args.trade_off * D_0(f_s, f_t, w_s=w_s)
        + args.gamma * entropy(softmax(y_t))

    Here ``w_s = importance_weight_module.get_importance_weight(f_s)`` holds
    per-sample source weights. ``D`` only sees detached features, so its term
    trains that discriminator but sends no gradient into ``model``.

    Target labels are moved to the device only for monitoring (``Tgt Acc``);
    they never enter the loss.

    Args:
        train_source_iter: yields ``(images, labels)`` source batches.
        train_target_iter: yields ``(images, labels)`` target batches.
        model: returns ``(logits, features)`` for a batch of images.
        domain_adv_D: adversarial loss trained on detached features.
        domain_adv_D_0: adversarial loss that accepts per-sample ``w_s``.
        importance_weight_module: provides ``get_importance_weight`` and
            ``get_partial_classes_weight``.
        optimizer: zeroed and stepped once per iteration.
        lr_scheduler: stepped once per iteration, after the optimizer.
        epoch: epoch index, used only in the progress prefix.
        args: must provide ``iters_per_epoch``, ``trade_off``, ``gamma`` and
            ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    domain_accs_D = AverageMeter('Domain Acc for D', ':3.1f')
    domain_accs_D_0 = AverageMeter('Domain Acc for D_0', ':3.1f')
    partial_classes_weights = AverageMeter('Partial Weight', ':3.2f')
    non_partial_classes_weights = AverageMeter('Non-Partial Weight', ':3.2f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, tgt_accs, domain_accs_D,
         domain_accs_D_0, partial_classes_weights, non_partial_classes_weights],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv_D.train()
    domain_adv_D_0.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)  # monitoring only, never in the loss
        n_s = x_s.size(0)
        n_t = x_t.size(0)

        # measure data loading time (includes the device copies above)
        data_time.update(time.time() - end)

        # one joint forward pass, then split into source / target halves
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        # classification loss
        cls_loss = F.cross_entropy(y_s, labels_s)

        # domain adversarial loss for D: detached features, so this term
        # trains D but sends no gradient back into `model`
        adv_loss_D = domain_adv_D(f_s.detach(), f_t.detach())

        # per-sample importance weights for the source batch
        w_s = importance_weight_module.get_importance_weight(f_s)
        # importance-weighted domain adversarial loss for D_0
        adv_loss_D_0 = domain_adv_D_0(f_s, f_t, w_s=w_s)

        # entropy of target predictions; keep the logits `y_t` untouched
        p_t = F.softmax(y_t, dim=1)
        entropy_loss = entropy(p_t, reduction='mean')

        loss = cls_loss + 1.5 * args.trade_off * adv_loss_D + \
               args.trade_off * adv_loss_D_0 + args.gamma * entropy_loss

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        cls_acc = accuracy(y_s, labels_s)[0]
        # softmax preserves ranking, so accuracy on logits matches accuracy on p_t
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), n_s)
        cls_accs.update(cls_acc.item(), n_s)
        # target accuracy is averaged over target samples, not source samples
        tgt_accs.update(tgt_acc.item(), n_t)
        # float() accepts both a plain number and a 0-dim tensor, so the
        # meters accumulate Python floats rather than (device) tensors
        domain_accs_D.update(float(domain_adv_D.domain_discriminator_accuracy),
                             n_s)
        domain_accs_D_0.update(
            float(domain_adv_D_0.domain_discriminator_accuracy), n_s)

        # debug: class weight averaged over partial and non-partial classes
        partial_class_weight, non_partial_classes_weight = \
            importance_weight_module.get_partial_classes_weight(w_s, labels_s)
        partial_classes_weights.update(partial_class_weight.item(), n_s)
        non_partial_classes_weights.update(non_partial_classes_weight.item(),
                                           n_s)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: Regressor,
          domain_adv: DomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of domain-adversarial training for a regression model.

    Source and target batches go through ``model`` in two separate forward
    passes; each call returns ``(predictions, features)``. The minimised
    objective is::

        mse_loss(y_s, labels_s) + args.trade_off * domain_adv(f_s, f_t)

    Mean absolute error on the source and on the target batch is tracked for
    monitoring only. Target labels never enter the loss.

    Args:
        train_source_iter: yields ``(inputs, targets)`` source batches; the
            targets are cast to float.
        train_target_iter: yields ``(inputs, targets)`` target batches; the
            targets feed only the ``MAE Loss (t)`` meter.
        model: regressor returning ``(predictions, features)``.
        domain_adv: adversarial loss on ``(f_s, f_t)``. Its
            ``domain_discriminator_accuracy`` is read after each call.
        optimizer: zeroed at the start of each iteration and stepped once.
        lr_scheduler: stepped once per iteration, after ``optimizer.step()``.
        epoch: epoch index, used only in the progress prefix.
        args: must provide ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    mse_losses = AverageMeter('MSE Loss', ':6.3f')
    dann_losses = AverageMeter('DANN Loss', ':6.3f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    mae_losses_s = AverageMeter('MAE Loss (s)', ':6.3f')
    mae_losses_t = AverageMeter('MAE Loss (t)', ':6.3f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, mse_losses, dann_losses, mae_losses_s,
         mae_losses_t, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, labels_s = next(train_source_iter)
        x_s = x_s.to(device)
        labels_s = labels_s.to(device).float()
        x_t, labels_t = next(train_target_iter)
        x_t = x_t.to(device)
        labels_t = labels_t.to(device).float()  # monitoring only
        n_s = x_s.size(0)
        n_t = x_t.size(0)

        # measure data loading time (includes the device copies above)
        data_time.update(time.time() - end)

        # compute output: separate forward passes for each domain
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        mse_loss = F.mse_loss(y_s, labels_s)
        transfer_loss = domain_adv(f_s, f_t)
        loss = mse_loss + transfer_loss * args.trade_off
        domain_acc = domain_adv.domain_discriminator_accuracy

        # MAE values are diagnostics only (not part of `loss`), so there is
        # no reason to record them in the autograd graph
        with torch.no_grad():
            mae_loss_s = F.l1_loss(y_s, labels_s)
            mae_loss_t = F.l1_loss(y_t, labels_t)

        mse_losses.update(mse_loss.item(), n_s)
        dann_losses.update(transfer_loss.item(), n_s)
        mae_losses_s.update(mae_loss_s.item(), n_s)
        # the target MAE is averaged over target samples
        mae_losses_t.update(mae_loss_t.item(), n_t)
        domain_accs.update(domain_acc.item(), n_s)

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
Example #5
0
def train(feature_extractor: FeatureExtractor,
          domain_adv: DomainAdversarialLoss, src_iter: ForeverDataIterator,
          tar_iter: ForeverDataIterator, src_val_loader, tar_val_loader):
    """Train ``feature_extractor`` with N-pair losses plus a domain-adversarial loss.

    For each of ``epoch`` epochs (a module-level global), the function runs
    ``iter_per_epoch`` iterations. Each iteration:

    * draws one source batch and one target batch and moves them to ``device``;
    * embeds their concatenation with ``feature_extractor``, which returns
      two outputs ``(f, g)``, each split back into source/target halves;
    * computes an N-pair loss on the source half of ``f`` and of ``g``;
    * computes ``domain_adv(g_s, f_t)``;
    * minimises ``0.5 * (loss_s + loss_s_g) + w_da * da_loss`` with Adam over
      the parameters of both ``feature_extractor`` and ``domain_adv``.

    On epochs where ``e_i % 5 == 0`` (starting with epoch 0), NMI is
    evaluated on ``src_val_loader`` and on ``tar_val_loader``. If the target
    NMI beats the best seen so far, ``backbone.state_dict()`` is saved to
    ``os.path.join(model_dir, 'minst_best_model.pth')``. Epochs not
    divisible by 5, including possibly the final one, are never evaluated
    or checkpointed.

    The function relies on module-level globals that are not passed as
    arguments: ``lr``, ``weight_decay``, ``writer``, ``epoch``,
    ``iter_per_epoch``, ``print_freq``, ``device``, ``w_da``,
    ``global_logger``, ``model_dir`` and ``backbone``.

    Args:
        feature_extractor: model returning two outputs ``(f, g)`` per batch.
        domain_adv: adversarial loss; its ``domain_discriminator_accuracy``
            is read after each call.
        src_iter: source iterator yielding ``(inputs, labels)``.
        tar_iter: target iterator yielding ``(inputs, labels)``; the labels
            are moved to the device but not otherwise used.
        src_val_loader: loader passed to ``NMI_eval`` with ``type='src'``.
        tar_val_loader: loader passed to ``NMI_eval`` with ``type='tar'``;
            its NMI drives model selection.
    """
    # Adam over both the feature extractor and the domain discriminator;
    # `lr` and `weight_decay` are module-level globals.
    optimizer = Adam(itertools.chain(feature_extractor.parameters(),
                                     domain_adv.parameters()),
                     lr=lr,
                     weight_decay=weight_decay)

    npair_loss = NPairsLoss()  # n pair loss

    # loss
    # NOTE(review): these meters take `tb_tag`/`writer`, presumably to also
    # log to TensorBoard -- confirm against this AverageMeter's definition.
    # `loss_rec` is never updated or displayed below.
    loss_rec = AverageMeter('tot_loss', tb_tag='Loss/tot', writer=writer)
    loss_lb_rec = AverageMeter('lb_loss', tb_tag='Loss/lb', writer=writer)
    loss_lb_g_rec = AverageMeter('lb_g_loss',
                                 tb_tag='Loss/lb_g',
                                 writer=writer)
    # loss_ulb_rec = AverageMeter('ulb_loss', tb_tag='Loss/ulb')
    loss_da_rec = AverageMeter('da_loss', tb_tag='Loss/da', writer=writer)

    # acc
    da_acc_rec = AverageMeter('da_acc', tb_tag='Acc/da', writer=writer)

    n_iter = 0  # iteration counter across all epochs, passed to the meters
    best_nmi = 0  # best target-set NMI seen so far
    for e_i in range(epoch):  # `epoch` (global) is the number of epochs
        # set train mode at the start of every epoch, since the evaluation
        # at the end of the previous epoch may have changed it (unverified)
        feature_extractor.train()
        domain_adv.train()
        progress = ProgressMeter(
            iter_per_epoch,
            [loss_lb_g_rec, loss_lb_rec, loss_da_rec, da_acc_rec],
            prefix="Epoch: [{}]".format(e_i),
            logger=global_logger)
        for i in range(iter_per_epoch):
            x_s, l_s = next(src_iter)
            x_t, l_t = next(tar_iter)
            # for obj in [x_s, x_t, l_s, l_t]: # to device
            # obj = obj.to(device)

            # NOTE(review): l_t is moved to the device but never used below.
            x_s, l_s, x_t, l_t = x_s.to(device), l_s.to(device), x_t.to(
                device), l_t.to(device)

            # one joint forward pass; both outputs are split back into
            # source / target halves
            x = torch.cat((x_s, x_t), dim=0)
            f, g = feature_extractor(x)
            f_s, f_t = f.chunk(2, dim=0)
            g_s, g_t = g.chunk(2, dim=0)

            # source only part
            loss_s = npair_loss(f_s, l_s)  # get n-pair loss on source domain
            loss_s_g = npair_loss(g_s, l_s)  # get n-pair loss on source domain
            loss_lb_rec.update(loss_s.item(), x_s.size(0), iter=n_iter)
            loss_lb_g_rec.update(loss_s_g.item(), x_s.size(0), iter=n_iter)

            # dann
            # da_loss = domain_adv(f_s,f_t)
            # NOTE(review): this aligns the source `g` output against the
            # target `f` output (the commented-out line above used f_s vs
            # f_t) -- confirm that mixing the two outputs is intended.
            da_loss = domain_adv(g_s, f_t)
            domain_acc = domain_adv.domain_discriminator_accuracy
            # weighted by the combined source+target batch size f.size(0)
            loss_da_rec.update(da_loss.item(), f.size(0), iter=n_iter)
            da_acc_rec.update(domain_acc.item(), f.size(0), iter=n_iter)

            # average of the two N-pair losses plus the adversarial term,
            # scaled by the global `w_da`
            loss = 0.5 * (loss_s + loss_s_g) + w_da * da_loss
            # loss = loss_s
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            n_iter += 1
            if i % print_freq == 0:
                progress.display(i)

        # evaluate (and possibly checkpoint) every 5th epoch, from epoch 0
        if e_i % 5 == 0:
            # global_logger.info(f"saving embedding in epoch{e_i}")
            # # show embedding
            # show_embedding(backbone, [src_val_loader], tag=f'src_{e_i}', epoch=e_i, writer, device)
            # show_embedding(backbone, [tar_val_loader], tag=f'tar_{e_i}', epoch=e_i, writer, device)

            # NOTE(review): the literal 5 passed to NMI_eval is presumably a
            # cluster/class count -- confirm against NMI_eval's signature.
            nmi = NMI_eval(feature_extractor,
                           src_val_loader,
                           5,
                           device,
                           type='src')
            global_logger.info(f'test on train set nmi: {nmi}')
            nmi = NMI_eval(feature_extractor,
                           tar_val_loader,
                           5,
                           device,
                           type='tar')
            global_logger.info(f'test on test set nmi: {nmi}')
            # model selection uses only the target (tar_val_loader) NMI
            if nmi > best_nmi:
                global_logger.info(f"save best model to {model_dir}")
                # NOTE(review): this saves the global `backbone`, not the
                # `feature_extractor` that was trained and evaluated above --
                # confirm `backbone` is the intended (sub)module.
                torch.save(backbone.state_dict(),
                           os.path.join(model_dir, 'minst_best_model.pth'))
                best_nmi = nmi