def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model, interp, criterion,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, visualize,
          args: argparse.Namespace):
    """Train ``model`` for one epoch with a source loss plus target entropy.

    Each iteration draws one source and one target batch. The source
    prediction is penalised with ``criterion``; the target branch is
    penalised only through ``robust_entropy`` scaled by
    ``args.entropy_weight``. ``criterion`` is also evaluated on the target
    batch, but that value is only logged and never back-propagated.

    Args:
        train_source_iter: iterator yielding ``(input, label)`` source batches.
        train_target_iter: iterator yielding ``(input, label)`` target batches.
        model: network being optimised; must expose ``num_classes``.
        interp: callable applied to the raw model output before the loss
            (presumably an upsampling layer -- confirm against the caller).
        criterion: loss called as ``criterion(prediction, label)``.
        optimizer: stepped once per iteration.
        lr_scheduler: stepped once per iteration, after the optimizer.
        epoch: epoch index, used only in the progress prefix.
        visualize: optional callable ``(input, prediction, label, name)``;
            skipped when ``None``.
        args: namespace providing ``iters_per_epoch``, ``ita``,
            ``entropy_weight`` and ``print_freq``.
    """
    time_meter = AverageMeter('Time', ':4.2f')
    load_meter = AverageMeter('Data', ':3.1f')
    src_loss_meter = AverageMeter('Loss (s)', ':3.2f')
    tgt_loss_meter = AverageMeter('Loss (t)', ':3.2f')
    tgt_entropy_meter = AverageMeter('Entropy (t)', ':3.2f')
    src_acc_meter = Meter('Acc (s)', ':3.2f')
    tgt_acc_meter = Meter('Acc (t)', ':3.2f')
    src_iou_meter = Meter('IoU (s)', ':3.2f')
    tgt_iou_meter = Meter('IoU (t)', ':3.2f')

    src_confmat = ConfusionMatrix(model.num_classes)
    tgt_confmat = ConfusionMatrix(model.num_classes)
    tracked = [
        time_meter, load_meter, src_loss_meter, tgt_loss_meter,
        tgt_entropy_meter, src_acc_meter, tgt_acc_meter, src_iou_meter,
        tgt_iou_meter
    ]
    progress = ProgressMeter(args.iters_per_epoch,
                             tracked,
                             prefix="Epoch: [{}]".format(epoch))

    # put the network into training mode
    model.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        # gradients of both branches accumulate until the single step below
        optimizer.zero_grad()

        x_s, label_s = next(train_source_iter)
        x_t, label_t = next(train_target_iter)

        x_s, label_s = x_s.to(device), label_s.long().to(device)
        x_t, label_t = x_t.to(device), label_t.long().to(device)

        # time spent fetching and transferring the batches
        load_meter.update(time.time() - tick)

        # source branch: supervised loss
        pred_s = interp(model(x_s))
        src_cls_loss = criterion(pred_s, label_s)
        src_cls_loss.backward()

        # target branch: only the weighted entropy term is back-propagated
        y_t = model(x_t)
        pred_t = interp(y_t)
        tgt_cls_loss = criterion(pred_t, label_t)
        tgt_entropy = robust_entropy(y_t, args.ita)
        weighted_entropy = args.entropy_weight * tgt_entropy
        weighted_entropy.backward()

        # one parameter update for the accumulated gradients
        optimizer.step()
        lr_scheduler.step()

        # bookkeeping; every loss meter is weighted by the source batch size
        n_src = x_s.size(0)
        src_loss_meter.update(src_cls_loss.item(), n_src)
        tgt_loss_meter.update(tgt_cls_loss.item(), n_src)
        tgt_entropy_meter.update(tgt_entropy.item(), n_src)

        src_confmat.update(label_s.flatten(), pred_s.argmax(1).flatten())
        tgt_confmat.update(label_t.flatten(), pred_t.argmax(1).flatten())
        _, class_acc_s, class_iou_s = src_confmat.compute()
        _, class_acc_t, class_iou_t = tgt_confmat.compute()
        src_acc_meter.update(class_acc_s.mean().item())
        tgt_acc_meter.update(class_acc_t.mean().item())
        src_iou_meter.update(class_iou_s.mean().item())
        tgt_iou_meter.update(class_iou_t.mean().item())

        # wall-clock time of the whole iteration
        time_meter.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq != 0:
            continue

        progress.display(step)
        if visualize is not None:
            visualize(x_s[0], pred_s[0], label_s[0], "source_{}".format(step))
            visualize(x_t[0], pred_t[0], label_t[0], "target_{}".format(step))
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          mkmmd_loss: MultipleKernelMaximumMeanDiscrepancy, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of classification training with an MK-MMD transfer term.

    The optimised loss is the source cross-entropy plus
    ``args.trade_off`` times ``mkmmd_loss`` evaluated on the second model
    output of the source and target batches.

    Args:
        train_source_iter: iterator yielding ``(input, label)`` source batches.
        train_target_iter: iterator yielding ``(input, label)`` target batches;
            the target labels are used only to report accuracy.
        model: network called as ``model(x)`` and returning two values.
        mkmmd_loss: transfer loss module called on the two feature outputs.
        optimizer: stepped once per iteration.
        lr_scheduler: stepped once per iteration, after the optimizer.
        epoch: epoch index, used only in the progress prefix.
        args: namespace providing ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.
    """
    time_meter = AverageMeter('Time', ':4.2f')
    load_meter = AverageMeter('Data', ':3.1f')
    total_loss_meter = AverageMeter('Loss', ':3.2f')
    transfer_loss_meter = AverageMeter('Trans Loss', ':5.4f')
    src_acc_meter = AverageMeter('Cls Acc', ':3.1f')
    tgt_acc_meter = AverageMeter('Tgt Acc', ':3.1f')

    tracked = [
        time_meter, load_meter, total_loss_meter, transfer_loss_meter,
        src_acc_meter, tgt_acc_meter
    ]
    progress = ProgressMeter(args.iters_per_epoch,
                             tracked,
                             prefix="Epoch: [{}]".format(epoch))

    # both the classifier and the transfer loss go into training mode
    model.train()
    mkmmd_loss.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s, x_t = x_s.to(device), x_t.to(device)
        labels_s, labels_t = labels_s.to(device), labels_t.to(device)

        # time spent fetching and transferring the batches
        load_meter.update(time.time() - tick)

        # forward both domains separately
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = mkmmd_loss(f_s, f_t)
        loss = cls_loss + transfer_loss * args.trade_off

        src_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        n_src = x_s.size(0)
        total_loss_meter.update(loss.item(), n_src)
        src_acc_meter.update(src_acc.item(), n_src)
        tgt_acc_meter.update(tgt_acc.item(), x_t.size(0))
        transfer_loss_meter.update(transfer_loss.item(), n_src)

        # back-propagate and update parameters and learning rate
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # wall-clock time of the whole iteration
        time_meter.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
def train(train_source_iter, train_target_iter, model, criterion,
          regression_disparity, optimizer_f, optimizer_h, optimizer_h_adv,
          lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv, epoch: int,
          visualize, args: argparse.Namespace):
    """Run one epoch of three-step adversarial regression training.

    Every iteration performs, in order:

    * Step A -- all three optimizers minimise the source ``criterion`` plus
      ``args.margin * args.trade_off`` times the ``'min'``-mode regression
      disparity between the two model outputs on the source batch.
    * Step B -- only ``optimizer_h_adv`` is stepped on the ``'max'``-mode
      disparity of the target batch.
    * Step C -- only ``optimizer_f`` is stepped on the ``'min'``-mode
      disparity of the target batch (recomputed with a fresh forward pass).

    Args:
        train_source_iter: iterator yielding ``(input, label, weight, meta)``.
        train_target_iter: iterator yielding ``(input, label, weight, meta)``.
        model: network returning ``(output, adversarial_output)``; must also
            provide a ``step()`` method, called once per iteration.
        criterion: loss called as ``criterion(output, label, weight)``.
        regression_disparity: callable
            ``(output, adversarial_output, weight, mode=...)``.
        optimizer_f, optimizer_h, optimizer_h_adv: optimizers stepped as
            described above (presumably feature extractor, main head and
            adversarial head -- confirm against the caller).
        lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv: schedulers,
            each stepped once per iteration.
        epoch: epoch index, used only in the progress prefix.
        visualize: optional callable ``(input, keypoints, name)``; skipped
            when ``None``.
        args: namespace providing ``iters_per_epoch``, ``margin``,
            ``trade_off``, ``print_freq``, ``image_size`` and
            ``heatmap_size``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ":.2e")
    losses_gf = AverageMeter('Loss (t, false)', ":.2e")
    losses_gt = AverageMeter('Loss (t, truth)', ":.2e")
    acc_s = AverageMeter("Acc (s)", ":3.2f")
    acc_t = AverageMeter("Acc (t)", ":3.2f")
    acc_s_adv = AverageMeter("Acc (s, adv)", ":3.2f")
    acc_t_adv = AverageMeter("Acc (t, adv)", ":3.2f")

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_s, losses_gf, losses_gt, acc_s, acc_t,
        acc_s_adv, acc_t_adv
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, label_s, weight_s, meta_s = next(train_source_iter)
        x_t, label_t, weight_t, meta_t = next(train_target_iter)

        x_s = x_s.to(device)
        label_s = label_s.to(device)
        weight_s = weight_s.to(device)

        x_t = x_t.to(device)
        label_t = label_t.to(device)
        weight_t = weight_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Step A train all networks to minimize loss on source domain
        optimizer_f.zero_grad()
        optimizer_h.zero_grad()
        optimizer_h_adv.zero_grad()

        y_s, y_s_adv = model(x_s)
        loss_s = criterion(y_s, label_s, weight_s) + \
                 args.margin * args.trade_off * regression_disparity(y_s, y_s_adv, weight_s, mode='min')
        loss_s.backward()
        optimizer_f.step()
        optimizer_h.step()
        optimizer_h_adv.step()

        # Step B train adv regressor to maximize regression disparity
        optimizer_h_adv.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_false = args.trade_off * regression_disparity(
            y_t, y_t_adv, weight_t, mode='max')
        loss_ground_false.backward()
        optimizer_h_adv.step()

        # Step C train feature extractor to minimize regression disparity
        optimizer_f.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_truth = args.trade_off * regression_disparity(
            y_t, y_t_adv, weight_t, mode='min')
        loss_ground_truth.backward()
        optimizer_f.step()

        # do update step
        model.step()
        lr_scheduler_f.step()
        lr_scheduler_h.step()
        lr_scheduler_h_adv.step()

        # measure accuracy and record loss
        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),
                                               label_s.detach().cpu().numpy())
        acc_s.update(avg_acc_s, cnt_s)
        _, avg_acc_t, cnt_t, pred_t = accuracy(y_t.detach().cpu().numpy(),
                                               label_t.detach().cpu().numpy())
        acc_t.update(avg_acc_t, cnt_t)
        _, avg_acc_s_adv, _, pred_s_adv = accuracy(
            y_s_adv.detach().cpu().numpy(),
            label_s.detach().cpu().numpy())
        acc_s_adv.update(avg_acc_s_adv, cnt_s)
        _, avg_acc_t_adv, _, pred_t_adv = accuracy(
            y_t_adv.detach().cpu().numpy(),
            label_t.detach().cpu().numpy())
        acc_t_adv.update(avg_acc_t_adv, cnt_t)
        # Record plain Python floats: handing the live loss tensors to the
        # meters would make their running sums autograd-tracked tensors that
        # stay alive (and on the training device) for the whole epoch.
        losses_s.update(loss_s.item(), cnt_s)
        losses_gf.update(loss_ground_false.item(), cnt_s)
        losses_gt.update(loss_ground_truth.item(), cnt_s)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                # predictions are rescaled by image_size / heatmap_size
                # before being drawn on the input image
                visualize(x_s[0],
                          pred_s[0] * args.image_size / args.heatmap_size,
                          "source_{}_pred".format(i))
                visualize(x_s[0], meta_s['keypoint2d'][0],
                          "source_{}_label".format(i))
                visualize(x_t[0],
                          pred_t[0] * args.image_size / args.heatmap_size,
                          "target_{}_pred".format(i))
                visualize(x_t[0], meta_t['keypoint2d'][0],
                          "target_{}_label".format(i))
                visualize(x_s[0],
                          pred_s_adv[0] * args.image_size / args.heatmap_size,
                          "source_adv_{}_pred".format(i))
                visualize(x_t[0],
                          pred_t_adv[0] * args.image_size / args.heatmap_size,
                          "target_adv_{}_pred".format(i))
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model,
          criterion_ce: CrossEntropyLossWithLabelSmooth,
          criterion_triplet: SoftTripletLoss, optimizer: Adam, epoch: int,
          args: argparse.Namespace):
    """Run one epoch of source-only training with CE and triplet losses.

    The optimised loss is ``criterion_ce`` on the source predictions plus
    ``args.trade_off`` times ``criterion_triplet`` on the source features.
    A target batch is also forwarded through the model every iteration, but
    its outputs do not enter the loss.

    Args:
        train_source_iter: iterator yielding 4-tuples; the first element is
            the input and the third the label.
        train_target_iter: iterator yielding 4-tuples; only the first
            element (the input) is used.
        model: network called as ``model(x)`` and returning two values.
        criterion_ce: loss called as ``criterion_ce(prediction, label)``.
        criterion_triplet: loss called as
            ``criterion_triplet(feature, feature, label)``.
        optimizer: stepped once per iteration.
        epoch: epoch index, used only in the progress prefix.
        args: namespace providing ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.
    """
    time_meter = AverageMeter('Time', ':4.2f')
    load_meter = AverageMeter('Data', ':3.1f')
    ce_meter = AverageMeter('CeLoss', ':3.2f')
    triplet_meter = AverageMeter('TripletLoss', ':3.2f')
    total_meter = AverageMeter('Loss', ':3.2f')
    acc_meter = AverageMeter('Cls Acc', ':3.1f')

    tracked = [
        time_meter, load_meter, ce_meter, triplet_meter, total_meter,
        acc_meter
    ]
    progress = ProgressMeter(args.iters_per_epoch,
                             tracked,
                             prefix="Epoch: [{}]".format(epoch))

    # put the network into training mode
    model.train()

    tick = time.time()

    for step in range(args.iters_per_epoch):
        x_s, _, labels_s, _ = next(train_source_iter)
        x_t, _, _, _ = next(train_target_iter)

        x_s, x_t = x_s.to(device), x_t.to(device)
        labels_s = labels_s.to(device)

        # time spent fetching and transferring the batches
        load_meter.update(time.time() - tick)

        y_s, f_s = model(x_s)
        # The target outputs are discarded; presumably the pass exists so
        # layers with running statistics also see target data -- confirm.
        _, _ = model(x_t)

        ce_term = criterion_ce(y_s, labels_s)
        triplet_term = criterion_triplet(f_s, f_s, labels_s)
        loss = ce_term + triplet_term * args.trade_off

        src_acc = accuracy(y_s, labels_s)[0]
        n_src = x_s.size(0)
        ce_meter.update(ce_term.item(), n_src)
        triplet_meter.update(triplet_term.item(), n_src)
        total_meter.update(loss.item(), n_src)
        acc_meter.update(src_acc.item(), n_src)

        # back-propagate and update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # wall-clock time of the whole iteration
        time_meter.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
# Example #5
# 0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_adv: DomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of classification training with a domain-adversarial term.

    Source and target inputs are concatenated, forwarded in a single pass
    and split back into two halves. The optimised loss is the source
    cross-entropy plus ``args.trade_off`` times ``domain_adv`` on the two
    feature halves.

    Args:
        train_source_iter: iterator yielding ``(input, label)`` source batches.
        train_target_iter: iterator yielding ``(input, label)`` target batches;
            the target labels are used only to report accuracy.
        model: network called as ``model(x)`` and returning two values.
        domain_adv: adversarial loss module; its
            ``domain_discriminator_accuracy`` attribute is read after each
            call for logging.
        optimizer: stepped once per iteration.
        lr_scheduler: stepped once per iteration, after the optimizer.
        epoch: epoch index, used only in the progress prefix.
        args: namespace providing ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.
    """
    time_meter = AverageMeter('Time', ':5.2f')
    load_meter = AverageMeter('Data', ':5.2f')
    loss_meter = AverageMeter('Loss', ':6.2f')
    src_acc_meter = AverageMeter('Cls Acc', ':3.1f')
    tgt_acc_meter = AverageMeter('Tgt Acc', ':3.1f')
    domain_acc_meter = AverageMeter('Domain Acc', ':3.1f')
    tracked = [
        time_meter, load_meter, loss_meter, src_acc_meter, tgt_acc_meter,
        domain_acc_meter
    ]
    progress = ProgressMeter(args.iters_per_epoch,
                             tracked,
                             prefix="Epoch: [{}]".format(epoch))

    # both the classifier and the adversarial loss go into training mode
    model.train()
    domain_adv.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s, x_t = x_s.to(device), x_t.to(device)
        labels_s, labels_t = labels_s.to(device), labels_t.to(device)

        # time spent fetching and transferring the batches
        load_meter.update(time.time() - tick)

        # one joint forward pass, then split into source/target halves
        logits, feats = model(torch.cat((x_s, x_t), dim=0))
        y_s, y_t = logits.chunk(2, dim=0)
        f_s, f_t = feats.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = domain_adv(f_s, f_t)
        # read after the call above, as in the original statement order
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        src_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        # every meter is weighted by the source batch size
        n_src = x_s.size(0)
        loss_meter.update(loss.item(), n_src)
        src_acc_meter.update(src_acc.item(), n_src)
        tgt_acc_meter.update(tgt_acc.item(), n_src)
        domain_acc_meter.update(domain_acc.item(), n_src)

        # back-propagate and update parameters and learning rate
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # wall-clock time of the whole iteration
        time_meter.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
# Example #6
# 0
def train(train_iter: ForeverDataIterator, model: Classifier,
          backbone_regularization: nn.Module, head_regularization: nn.Module,
          target_getter: IntermediateLayerGetter,
          source_getter: IntermediateLayerGetter, optimizer: SGD, epoch: int,
          args: argparse.Namespace):
    """Run one epoch of fine-tuning with backbone and head regularisation.

    The optimised loss is the cross-entropy of the first ``target_getter``
    output plus ``args.trade_off_backbone`` times the backbone regulariser
    plus ``args.trade_off_head`` times the head regulariser. For the
    ``'feature_map'`` and ``'attention_feature_map'`` regularisation types
    the backbone regulariser receives the intermediate outputs of both
    getters; for any other type it is called without arguments.

    Args:
        train_iter: iterator yielding ``(input, label)`` batches.
        model: classifier being trained; only switched to train mode here.
        backbone_regularization: backbone regulariser module.
        head_regularization: head regulariser module, called without
            arguments.
        target_getter: callable returning ``(intermediate_outputs, output)``
            where ``output`` unpacks into two values.
        source_getter: callable returning ``(intermediate_outputs, output)``;
            only the intermediate outputs are used.
        optimizer: stepped once per iteration.
        epoch: epoch index, used only in the progress prefix.
        args: namespace providing ``iters_per_epoch``,
            ``regularization_type``, ``trade_off_backbone``,
            ``trade_off_head`` and ``print_freq``.
    """
    time_meter = AverageMeter('Time', ':4.2f')
    load_meter = AverageMeter('Data', ':3.1f')
    loss_meter = AverageMeter('Loss', ':3.2f')
    head_reg_meter = AverageMeter('Loss (reg, head)', ':3.2f')
    backbone_reg_meter = AverageMeter('Loss (reg, backbone)', ':3.2f')
    acc_meter = AverageMeter('Cls Acc', ':3.1f')

    tracked = [
        time_meter, load_meter, loss_meter, head_reg_meter,
        backbone_reg_meter, acc_meter
    ]
    progress = ProgressMeter(args.iters_per_epoch,
                             tracked,
                             prefix="Epoch: [{}]".format(epoch))

    # put the network into training mode
    model.train()

    # regularisation types whose backbone term compares intermediate outputs
    feature_based_types = ('feature_map', 'attention_feature_map')

    tick = time.time()
    for step in range(args.iters_per_epoch):
        x, labels = next(train_iter)
        x = x.to(device)
        label = labels.to(device)

        # time spent fetching and transferring the batch
        load_meter.update(time.time() - tick)

        # forward the same input through both getters
        feats_src, _ = source_getter(x)
        feats_tgt, output_t = target_getter(x)
        y, _ = output_t

        cls_acc = accuracy(y, label)[0]
        cls_loss = F.cross_entropy(y, label)
        if args.regularization_type in feature_based_types:
            backbone_term = backbone_regularization(feats_src, feats_tgt)
        else:
            backbone_term = backbone_regularization()
        head_term = head_regularization()
        loss = cls_loss + args.trade_off_backbone * backbone_term + args.trade_off_head * head_term

        # the regularisers are logged already scaled by their trade-offs
        n = x.size(0)
        backbone_reg_meter.update(
            backbone_term.item() * args.trade_off_backbone, n)
        head_reg_meter.update(head_term.item() * args.trade_off_head, n)
        loss_meter.update(loss.item(), n)
        acc_meter.update(cls_acc.item(), n)

        # back-propagate and update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # wall-clock time of the whole iteration
        time_meter.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
def train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S,
          netD_T, criterion_gan, criterion_cycle, criterion_identity,
          optimizer_G, optimizer_D, fake_S_pool, fake_T_pool, epoch: int,
          visualize, args: argparse.Namespace):
    """Run one epoch of cycle-consistent GAN training.

    Each iteration first updates both generators on the sum of two GAN
    losses, two cycle losses (scaled by ``args.trade_off_cycle``) and two
    identity losses (scaled by ``args.trade_off_identity``) while the
    discriminators have ``requires_grad`` disabled, and then updates both
    discriminators on real images versus fakes drawn from the image pools.

    Args:
        train_source_iter: iterator yielding ``(image, _)`` source batches.
        train_target_iter: iterator yielding ``(image, _)`` target batches.
        netG_S2T: generator applied to source images (and to target images
            for the identity loss).
        netG_T2S: generator applied to target images (and to source images
            for the identity loss).
        netD_S: discriminator fed real and fake source images.
        netD_T: discriminator fed real and fake target images.
        criterion_gan: loss called as ``criterion_gan(prediction, real)``.
        criterion_cycle: reconstruction loss ``(reconstruction, original)``.
        criterion_identity: identity loss ``(output, input)``.
        optimizer_G: optimizer stepped once per iteration for the generators.
        optimizer_D: optimizer stepped once per iteration for the
            discriminators.
        fake_S_pool: pool whose ``query`` supplies fake source images to
            ``netD_S``.
        fake_T_pool: pool whose ``query`` supplies fake target images to
            ``netD_T``.
        epoch: epoch index, used only in the progress prefix.
        visualize: optional callable ``(image, name)``; skipped when ``None``.
        args: namespace providing ``iters_per_epoch``, ``trade_off_cycle``,
            ``trade_off_identity`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_G_S2T = AverageMeter('G_S2T', ':3.2f')
    losses_G_T2S = AverageMeter('G_T2S', ':3.2f')
    losses_D_S = AverageMeter('D_S', ':3.2f')
    losses_D_T = AverageMeter('D_T', ':3.2f')
    losses_cycle_S = AverageMeter('cycle_S', ':3.2f')
    losses_cycle_T = AverageMeter('cycle_T', ':3.2f')
    losses_identity_S = AverageMeter('idt_S', ':3.2f')
    losses_identity_T = AverageMeter('idt_T', ':3.2f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_G_S2T, losses_G_T2S, losses_D_S,
        losses_D_T, losses_cycle_S, losses_cycle_T, losses_identity_S,
        losses_identity_T
    ],
                             prefix="Epoch: [{}]".format(epoch))

    end = time.time()

    for i in range(args.iters_per_epoch):
        real_S, _ = next(train_source_iter)
        real_T, _ = next(train_target_iter)

        real_S = real_S.to(device)
        real_T = real_T.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Compute fake images and reconstruction images.
        fake_T = netG_S2T(real_S)
        rec_S = netG_T2S(fake_T)
        fake_S = netG_T2S(real_T)
        rec_T = netG_S2T(fake_S)

        # Optimizing generators
        # discriminators require no gradients
        set_requires_grad(netD_S, False)
        set_requires_grad(netD_T, False)

        optimizer_G.zero_grad()
        # GAN loss D_T(G_S2T(S))
        loss_G_S2T = criterion_gan(netD_T(fake_T), real=True)
        # GAN loss D_S(G_T2S(T))
        loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)
        # Cycle loss || G_T2S(G_S2T(S)) - S||
        loss_cycle_S = criterion_cycle(rec_S, real_S) * args.trade_off_cycle
        # Cycle loss || G_S2T(G_T2S(T)) - T||
        loss_cycle_T = criterion_cycle(rec_T, real_T) * args.trade_off_cycle
        # Identity loss
        # G_S2T should be identity if real_T is fed: ||G_S2T(real_T) - real_T||
        identity_T = netG_S2T(real_T)
        loss_identity_T = criterion_identity(identity_T,
                                             real_T) * args.trade_off_identity
        # G_T2S should be identity if real_S is fed: ||G_T2S(real_S) - real_S||
        identity_S = netG_T2S(real_S)
        loss_identity_S = criterion_identity(identity_S,
                                             real_S) * args.trade_off_identity
        # combined loss and calculate gradients
        loss_G = loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T + loss_identity_S + loss_identity_T
        loss_G.backward()
        optimizer_G.step()

        # Optimize discriminator
        set_requires_grad(netD_S, True)
        set_requires_grad(netD_T, True)
        optimizer_D.zero_grad()
        # Calculate GAN loss for discriminator D_S; fakes are detached so no
        # gradient flows back into the generators
        fake_S_ = fake_S_pool.query(fake_S.detach())
        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) +
                          criterion_gan(netD_S(fake_S_), False))
        loss_D_S.backward()
        # Calculate GAN loss for discriminator D_T
        fake_T_ = fake_T_pool.query(fake_T.detach())
        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) +
                          criterion_gan(netD_T(fake_T_), False))
        loss_D_T.backward()
        optimizer_D.step()

        # record losses; every meter is weighted by the source batch size
        losses_G_S2T.update(loss_G_S2T.item(), real_S.size(0))
        losses_G_T2S.update(loss_G_T2S.item(), real_S.size(0))
        losses_D_S.update(loss_D_S.item(), real_S.size(0))
        losses_D_T.update(loss_D_T.item(), real_S.size(0))
        losses_cycle_S.update(loss_cycle_S.item(), real_S.size(0))
        losses_cycle_T.update(loss_cycle_T.item(), real_S.size(0))
        losses_identity_S.update(loss_identity_S.item(), real_S.size(0))
        losses_identity_T.update(loss_identity_T.item(), real_S.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

            # visualisation is optional; a None callback is simply skipped
            if visualize is not None:
                for tensor, name in zip([
                        real_S, real_T, fake_S, fake_T, rec_S, rec_T,
                        identity_S, identity_T
                ], [
                        "real_S", "real_T", "fake_S", "fake_T", "rec_S",
                        "rec_T", "identity_S", "identity_T"
                ]):
                    visualize(tensor[0], "{}_{}".format(i, name))
# Example #8
# 0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          adaptive_feature_norm: AdaptiveFeatureNorm, optimizer: SGD,
          epoch: int, args: argparse.Namespace):
    """Run one epoch of classification training with a feature-norm penalty.

    The optimised loss is the source cross-entropy plus
    ``args.trade_off_norm`` times the sum of ``adaptive_feature_norm`` on
    the source and target features. When ``args.trade_off_entropy`` is
    truthy, the mean entropy of the target softmax output, scaled by that
    value, is added as well.

    Args:
        train_source_iter: iterator yielding ``(input, label)`` source batches.
        train_target_iter: iterator yielding ``(input, label)`` target batches;
            the target labels are used only to report accuracy.
        model: network called as ``model(x)`` and returning two values.
        adaptive_feature_norm: penalty module called on one feature batch.
        optimizer: stepped once per iteration.
        epoch: epoch index, used only in the progress prefix.
        args: namespace providing ``iters_per_epoch``, ``trade_off_norm``,
            ``trade_off_entropy`` and ``print_freq``.
    """
    time_meter = AverageMeter('Time', ':3.1f')
    load_meter = AverageMeter('Data', ':3.1f')
    cls_loss_meter = AverageMeter('Cls Loss', ':3.2f')
    norm_loss_meter = AverageMeter('Norm Loss', ':3.2f')
    src_norm_meter = AverageMeter('Source Feature Norm', ':3.2f')
    tgt_norm_meter = AverageMeter('Target Feature Norm', ':3.2f')
    src_acc_meter = AverageMeter('Cls Acc', ':3.1f')
    tgt_acc_meter = AverageMeter('Tgt Acc', ':3.1f')

    tracked = [
        time_meter, load_meter, cls_loss_meter, norm_loss_meter,
        src_norm_meter, tgt_norm_meter, src_acc_meter, tgt_acc_meter
    ]
    progress = ProgressMeter(args.iters_per_epoch,
                             tracked,
                             prefix="Epoch: [{}]".format(epoch))

    # put the network into training mode
    model.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s, x_t = x_s.to(device), x_t.to(device)
        labels_s, labels_t = labels_s.to(device), labels_t.to(device)

        # time spent fetching and transferring the batches
        load_meter.update(time.time() - tick)

        # forward both domains separately
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        cls_loss = F.cross_entropy(y_s, labels_s)
        norm_loss = adaptive_feature_norm(f_s) + adaptive_feature_norm(f_t)

        loss = cls_loss + norm_loss * args.trade_off_norm

        # Optional entropy minimisation. Note that ``y_t`` is rebound to the
        # softmax probabilities here, so the accuracy below is computed on
        # them rather than on the raw outputs whenever this branch runs.
        if args.trade_off_entropy:
            y_t = F.softmax(y_t, dim=1)
            loss += entropy(y_t, reduction='mean') * args.trade_off_entropy

        # back-propagate and update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # statistics; every meter is weighted by the source batch size
        src_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]
        n_src = x_s.size(0)

        cls_loss_meter.update(cls_loss.item(), n_src)
        norm_loss_meter.update(norm_loss.item(), n_src)
        src_norm_meter.update(f_s.norm(p=2, dim=1).mean().item(), n_src)
        tgt_norm_meter.update(f_t.norm(p=2, dim=1).mean().item(), n_src)
        src_acc_meter.update(src_acc.item(), n_src)
        tgt_acc_meter.update(tgt_acc.item(), n_src)

        # wall-clock time of the whole iteration
        time_meter.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
# Example #9
# 0
def calculate_channel_attention(dataset, return_layers, args):
    """Estimate per-channel attention weights for the given layers.

    A classifier built on a pretrained ``args.arch`` backbone is first
    trained on ``dataset`` for ``args.attention_epochs`` epochs. Then, for
    each batch and each output channel of every layer in ``return_layers``,
    that channel's weight slice is temporarily zeroed, the loss is
    re-evaluated, and its change relative to the unmodified loss is folded
    into a running mean over batches. Finally each layer's means are
    standardised, divided by 5 and passed through a softmax.

    Args:
        dataset: dataset exposing ``num_classes`` and usable with
            ``DataLoader``; items unpack into ``(input, label)``.
        return_layers: iterable of attribute paths inside the classifier;
            each resolved layer must expose ``out_channels`` and have a
            ``<name>.weight`` entry in the classifier's state dict.
        args: namespace providing ``arch``, ``lr``, ``momentum``, ``wd``,
            ``attention_batch_size``, ``workers``,
            ``attention_lr_decay_epochs``, ``attention_epochs``,
            ``attention_iteration_limit`` and ``print_freq``.

    Returns:
        A list with one 1-D float tensor per entry of ``return_layers`` (in
        the same order), located on ``device``.
    """
    backbone = models.__dict__[args.arch](pretrained=True)
    classifier = Classifier(backbone, dataset.num_classes).to(device)
    optimizer = SGD(classifier.get_parameters(args.lr),
                    momentum=args.momentum,
                    weight_decay=args.wd,
                    nesterov=True)
    data_loader = DataLoader(dataset,
                             batch_size=args.attention_batch_size,
                             shuffle=True,
                             num_workers=args.workers,
                             drop_last=False)
    # gamma is chosen so the learning rate shrinks by 10x every
    # ``attention_lr_decay_epochs`` epochs
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer,
        gamma=math.exp(math.log(0.1) / args.attention_lr_decay_epochs))
    criterion = nn.CrossEntropyLoss()

    # one running-mean slot per output channel of every requested layer
    channel_weights = [[0] * get_attribute(classifier, name).out_channels
                       for name in return_layers]

    # train the classifier
    classifier.train()
    # NOTE(review): this only sets a plain attribute on the module; it does
    # not change ``requires_grad`` of the backbone parameters, so the
    # backbone is still updated below. Left as-is because actually freezing
    # it would change the resulting attention values -- confirm the intent.
    classifier.backbone.requires_grad = False
    print("Pretrain a classifier to calculate channel attention.")
    for epoch in range(args.attention_epochs):
        losses = AverageMeter('Loss', ':3.2f')
        cls_accs = AverageMeter('Cls Acc', ':3.1f')
        progress = ProgressMeter(len(data_loader), [losses, cls_accs],
                                 prefix="Epoch: [{}]".format(epoch))

        for i, data in enumerate(data_loader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs, _ = classifier(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            cls_acc = accuracy(outputs, labels)[0]

            losses.update(loss.item(), inputs.size(0))
            cls_accs.update(cls_acc.item(), inputs.size(0))

            if i % args.print_freq == 0:
                progress.display(i)
        lr_scheduler.step()

    # calculate the channel attention
    print('Calculating channel attention.')
    classifier.eval()
    if args.attention_iteration_limit > 0:
        total_iteration = min(len(data_loader), args.attention_iteration_limit)
    else:
        # no limit requested: use every batch of the loader
        total_iteration = len(data_loader)

    progress = ProgressMeter(total_iteration, [], prefix="Iteration: ")

    # Only loss values are needed from here on, so skip graph construction.
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            if i >= total_iteration:
                break
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs, _ = classifier(inputs)
            loss_0 = criterion(outputs, labels)
            progress.display(i)
            for layer_id, name in enumerate(tqdm(return_layers)):
                layer = get_attribute(classifier, name)
                # Fetch the weight once per layer; the in-place writes below
                # rely on the state-dict tensor sharing storage with the
                # live parameter, exactly as the per-channel lookups did.
                layer_weight = classifier.state_dict()[name + '.weight']
                for j in range(layer.out_channels):
                    saved = layer_weight[j, ].clone()
                    layer_weight[j, ] = 0.0
                    outputs, _ = classifier(inputs)
                    loss_1 = criterion(outputs, labels)
                    difference = (loss_1 - loss_0).item()
                    # incremental mean over the batches seen so far
                    history_value = channel_weights[layer_id][j]
                    channel_weights[layer_id][j] = 1.0 * (
                        i * history_value + difference) / (i + 1)
                    # restore the ablated channel
                    layer_weight[j, ] = saved

    channel_attention = []
    for weight in channel_weights:
        weight = np.array(weight)
        weight = (weight - np.mean(weight)) / np.std(weight)
        weight = torch.from_numpy(weight).float().to(device)
        # the tensor is 1-D, so dim=0 is what the implicit choice resolved to
        channel_attention.append(F.softmax(weight / 5, dim=0).detach())
    return channel_attention
# Example #10
# 0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: Regressor, rsd,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int,
          args: argparse.Namespace):
    """Run one epoch of regression training with an ``rsd`` transfer term.

    The optimised loss is the source mean-squared error plus
    ``args.trade_off`` times ``rsd`` evaluated on the source and target
    features. L1 losses on both domains are computed for logging only.

    Args:
        train_source_iter: iterator yielding ``(input, label)`` source batches.
        train_target_iter: iterator yielding ``(input, label)`` target batches;
            the target labels are used only for the logged L1 loss.
        model: network called as ``model(x)`` and returning two values.
        rsd: transfer loss called as ``rsd(source_feature, target_feature)``.
        optimizer: stepped once per iteration.
        lr_scheduler: stepped once per iteration, after the optimizer.
        epoch: epoch index, used only in the progress prefix.
        args: namespace providing ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.
    """
    time_meter = AverageMeter('Time', ':4.2f')
    load_meter = AverageMeter('Data', ':3.1f')
    mse_meter = AverageMeter('MSE Loss', ':6.3f')
    rsd_meter = AverageMeter('RSD Loss', ':6.3f')
    src_mae_meter = AverageMeter('MAE Loss (s)', ':6.3f')
    tgt_mae_meter = AverageMeter('MAE Loss (t)', ':6.3f')

    tracked = [
        time_meter, load_meter, mse_meter, rsd_meter, src_mae_meter,
        tgt_mae_meter
    ]
    progress = ProgressMeter(args.iters_per_epoch,
                             tracked,
                             prefix="Epoch: [{}]".format(epoch))

    # put the network into training mode
    model.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, labels_s = next(train_source_iter)
        x_s, labels_s = x_s.to(device), labels_s.to(device).float()
        x_t, labels_t = next(train_target_iter)
        x_t, labels_t = x_t.to(device), labels_t.to(device).float()

        # time spent fetching and transferring the batches
        load_meter.update(time.time() - tick)

        # forward both domains separately
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        mse_loss = F.mse_loss(y_s, labels_s)
        src_mae = F.l1_loss(y_s, labels_s)
        tgt_mae = F.l1_loss(y_t, labels_t)
        rsd_loss = rsd(f_s, f_t)
        loss = mse_loss + rsd_loss * args.trade_off

        # every meter is weighted by the source batch size
        n_src = x_s.size(0)
        mse_meter.update(mse_loss.item(), n_src)
        rsd_meter.update(rsd_loss.item(), n_src)
        src_mae_meter.update(src_mae.item(), n_src)
        tgt_mae_meter.update(tgt_mae.item(), n_src)

        # back-propagate and update parameters and learning rate
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # wall-clock time of the whole iteration
        time_meter.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
# Example #11
# 0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model,
          mdd: MarginDisparityDiscrepancy, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of regression training with a margin-disparity term.

    Source and target inputs are concatenated, forwarded in a single pass
    and split back into two halves. The optimised loss is the source
    mean-squared error minus ``args.trade_off`` times ``mdd`` of the main
    and adversarial outputs. L1 losses on both domains are computed for
    logging only.

    Args:
        train_source_iter: iterator yielding ``(input, label)`` source batches.
        train_target_iter: iterator yielding ``(input, label)`` target batches;
            the target labels are used only for the logged L1 loss.
        model: network returning ``(outputs, adversarial_outputs)``; must
            also provide a ``step()`` method, called once per iteration.
        mdd: discrepancy module called as
            ``mdd(y_s, y_s_adv, y_t, y_t_adv)``.
        optimizer: stepped once per iteration.
        lr_scheduler: stepped once per iteration, after the optimizer.
        epoch: epoch index, used only in the progress prefix.
        args: namespace providing ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.
    """
    time_meter = AverageMeter('Time', ':4.2f')
    load_meter = AverageMeter('Data', ':3.1f')
    src_loss_meter = AverageMeter('Source Loss', ':6.3f')
    transfer_meter = AverageMeter('Trans Loss', ':6.3f')
    src_mae_meter = AverageMeter('MAE Loss (s)', ':6.3f')
    tgt_mae_meter = AverageMeter('MAE Loss (t)', ':6.3f')

    tracked = [
        time_meter, load_meter, src_loss_meter, transfer_meter,
        src_mae_meter, tgt_mae_meter
    ]
    progress = ProgressMeter(args.iters_per_epoch,
                             tracked,
                             prefix="Epoch: [{}]".format(epoch))

    # both the network and the discrepancy module go into training mode
    model.train()
    mdd.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, labels_s = next(train_source_iter)
        x_s, labels_s = x_s.to(device), labels_s.to(device).float()
        x_t, labels_t = next(train_target_iter)
        x_t, labels_t = x_t.to(device), labels_t.to(device).float()

        # time spent fetching and transferring the batches
        load_meter.update(time.time() - tick)

        # one joint forward pass, then split into source/target halves
        outputs, outputs_adv = model(torch.cat([x_s, x_t], dim=0))
        y_s, y_t = outputs.chunk(2, dim=0)
        y_s_adv, y_t_adv = outputs_adv.chunk(2, dim=0)

        # supervised loss on the source half
        mse_loss = F.mse_loss(y_s, labels_s)

        # The discrepancy enters with a negative sign: for the adversarial
        # branch, minimising the negative term equals maximising it.
        transfer_loss = mdd(y_s, y_s_adv, y_t, y_t_adv)
        loss = mse_loss - transfer_loss * args.trade_off
        # advance the model's internal step before the backward pass, as in
        # the original statement order
        model.step()

        src_mae = F.l1_loss(y_s, labels_s)
        tgt_mae = F.l1_loss(y_t, labels_t)

        # every meter is weighted by the source batch size
        n_src = x_s.size(0)
        src_loss_meter.update(mse_loss.item(), n_src)
        transfer_meter.update(transfer_loss.item(), n_src)
        src_mae_meter.update(src_mae.item(), n_src)
        tgt_mae_meter.update(tgt_mae.item(), n_src)

        # back-propagate and update parameters and learning rate
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # wall-clock time of the whole iteration
        time_meter.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
# Example #12
# 0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model, interp, criterion,
          dann, optimizer: SGD, lr_scheduler: LambdaLR, optimizer_d: SGD,
          lr_scheduler_d: LambdaLR, epoch: int, visualize,
          args: argparse.Namespace):
    """Run one epoch of adversarial training for a segmentation model.

    Every iteration draws one source and one target batch, back-propagates
    the supervised loss on the source predictions and the weighted
    adversarial loss on the target predictions, then back-propagates the
    loss of ``dann`` on detached predictions of both domains. Both
    optimizers and both schedulers are stepped once per iteration.

    Args:
        train_source_iter: iterator yielding ``(image, label)`` source batches.
        train_target_iter: iterator yielding ``(image, label)`` target batches.
        model: network being trained; must expose ``num_classes``.
        interp: callable applied to the raw network output before any loss
            (presumably an upsampling layer -- confirm with the caller).
        criterion: supervised loss, called as ``criterion(pred, label)``.
        dann: adversarial loss module, called as ``dann(pred, domain_name)``
            with ``'source'`` or ``'target'``.
        optimizer: optimizer stepped after the backward passes (presumably
            holds the parameters of ``model`` -- confirm).
        lr_scheduler: scheduler stepped once per iteration.
        optimizer_d: second optimizer stepped after the backward passes
            (presumably holds the discriminator parameters -- confirm).
        lr_scheduler_d: scheduler stepped once per iteration.
        epoch: epoch index, only used in the progress prefix.
        visualize: optional callable ``visualize(image, pred, label, name)``;
            nothing is drawn when it is ``None``.
        args: namespace providing ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.
    """
    # meters printed on every progress line
    step_timer = AverageMeter('Time', ':4.2f')
    load_timer = AverageMeter('Data', ':3.1f')
    seg_loss_meter = AverageMeter('Loss (s)', ':3.2f')
    adv_loss_meter = AverageMeter('Loss (transfer)', ':3.2f')
    disc_loss_meter = AverageMeter('Loss (discriminator)', ':3.2f')
    src_acc_meter = Meter('Acc (s)', ':3.2f')
    tgt_acc_meter = Meter('Acc (t)', ':3.2f')
    src_iou_meter = Meter('IoU (s)', ':3.2f')
    tgt_iou_meter = Meter('IoU (t)', ':3.2f')
    shown_meters = [
        step_timer, load_timer, seg_loss_meter, adv_loss_meter,
        disc_loss_meter, src_acc_meter, tgt_acc_meter, src_iou_meter,
        tgt_iou_meter
    ]

    # the confusion matrices are created once and updated on every iteration
    src_confmat = ConfusionMatrix(model.num_classes)
    tgt_confmat = ConfusionMatrix(model.num_classes)
    progress = ProgressMeter(args.iters_per_epoch, shown_meters,
                             prefix="Epoch: [{}]".format(epoch))

    model.train()

    tick = time.time()

    for step in range(args.iters_per_epoch):
        x_s, label_s = next(train_source_iter)
        x_t, label_t = next(train_target_iter)

        x_s, x_t = x_s.to(device), x_t.to(device)
        label_s = label_s.long().to(device)
        label_t = label_t.long().to(device)

        # time spent fetching the two batches and moving them to the device
        load_timer.update(time.time() - tick)

        optimizer.zero_grad()
        optimizer_d.zero_grad()

        # Stage 1: losses for the segmentation network, with dann in eval
        # mode (the original comment describes this as freezing the
        # discriminator)
        dann.eval()
        pred_s = interp(model(x_s))
        loss_cls_s = criterion(pred_s, label_s)
        loss_cls_s.backward()

        # target predictions are scored against the 'source' label so that
        # the network is pushed to fool the discriminator
        pred_t = interp(model(x_t))
        loss_transfer = dann(pred_t, 'source')
        weighted_transfer = loss_transfer * args.trade_off
        weighted_transfer.backward()

        # Stage 2: discriminator loss on detached predictions, so that no
        # gradient flows back into the segmentation network
        dann.train()
        disc_on_source = dann(pred_s.detach(), 'source')
        disc_on_target = dann(pred_t.detach(), 'target')
        loss_discriminator = 0.5 * (disc_on_source + disc_on_target)
        loss_discriminator.backward()

        # apply the accumulated gradients and advance both schedules
        optimizer.step()
        optimizer_d.step()
        lr_scheduler.step()
        lr_scheduler_d.step()

        # record losses, all weighted by the source batch size
        n_source = x_s.size(0)
        seg_loss_meter.update(loss_cls_s.item(), n_source)
        adv_loss_meter.update(loss_transfer.item(), n_source)
        disc_loss_meter.update(loss_discriminator.item(), n_source)

        # record accuracy / IoU derived from the confusion matrices
        src_confmat.update(label_s.flatten(), pred_s.argmax(1).flatten())
        tgt_confmat.update(label_t.flatten(), pred_t.argmax(1).flatten())
        _, acc_s, iu_s = src_confmat.compute()
        _, acc_t, iu_t = tgt_confmat.compute()
        src_acc_meter.update(acc_s.mean().item())
        tgt_acc_meter.update(acc_t.mean().item())
        src_iou_meter.update(iu_s.mean().item())
        tgt_iou_meter.update(iu_t.mean().item())

        # wall-clock time of the whole iteration
        step_timer.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq != 0:
            continue
        progress.display(step)

        if visualize is None:
            continue
        # draw the first sample of each batch
        visualize(x_s[0], pred_s[0], label_s[0], "source_{}".format(step))
        visualize(x_t[0], pred_t[0], label_t[0], "target_{}".format(step))
示例#13
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, G: nn.Module,
          F1: ImageClassifierHead, F2: ImageClassifierHead, optimizer_g: SGD,
          optimizer_f: SGD, epoch: int, args: argparse.Namespace):
    """Run one epoch of three-step training with a generator and two heads.

    Each iteration works on the concatenation of one source and one target
    batch and performs:

    * Step A: one step of ``optimizer_g`` and ``optimizer_f`` on the source
      cross-entropy of both heads plus ``0.01`` times the target entropy.
    * Step B: one step of ``optimizer_f`` only, on the same loss minus
      ``classifier_discrepancy`` of the target outputs times
      ``args.trade_off``.
    * Step C: ``args.num_k`` steps of ``optimizer_g`` only, on the weighted
      discrepancy.

    Args:
        train_source_iter: iterator yielding ``(x, labels)`` source batches.
        train_target_iter: iterator yielding ``(x, labels)`` target batches;
            the labels are used only for the reported target accuracy.
        G: feature generator applied to the concatenated batch.
        F1: first classifier head applied to the output of ``G``.
        F2: second classifier head applied to the output of ``G``.
        optimizer_g: optimizer stepped in steps A and C.
        optimizer_f: optimizer stepped in steps A and B.
        epoch: epoch index, only used in the progress prefix.
        args: namespace providing ``iters_per_epoch``, ``trade_off``,
            ``num_k`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    def forward_heads(inputs):
        """Return source logits and target softmax outputs of both heads."""
        features = G(inputs)
        out_1 = F1(features)
        out_2 = F2(features)
        # first half of the batch is source, second half is target
        src_1, tgt_1 = out_1.chunk(2, dim=0)
        src_2, tgt_2 = out_2.chunk(2, dim=0)
        return src_1, src_2, F.softmax(tgt_1, dim=1), F.softmax(tgt_2, dim=1)

    # switch to train mode
    G.train()
    F1.train()
    F2.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)
        x = torch.cat((x_s, x_t), dim=0)
        # sanity check: the input batch itself must not track gradients
        assert x.requires_grad is False

        # measure data loading time
        data_time.update(time.time() - end)

        # Step A train all networks to minimize loss on source domain
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

        y1_s, y2_s, y1_t, y2_t = forward_heads(x)
        loss = F.cross_entropy(y1_s, labels_s) + F.cross_entropy(y2_s, labels_s) + \
               0.01 * (entropy(y1_t) + entropy(y2_t))
        loss.backward()
        optimizer_g.step()
        optimizer_f.step()

        # Step B train classifier to maximize discrepancy
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

        y1_s, y2_s, y1_t, y2_t = forward_heads(x)
        discrepancy = classifier_discrepancy(y1_t, y2_t) * args.trade_off
        loss = F.cross_entropy(y1_s, labels_s) + F.cross_entropy(y2_s, labels_s) + \
               0.01 * (entropy(y1_t) + entropy(y2_t)) - discrepancy
        loss.backward()
        # only the heads are stepped here; gradients left on G are cleared by
        # the zero_grad() at the start of Step C or of the next iteration
        optimizer_f.step()

        # Fallback value for the logged transfer loss: without it, running
        # with args.num_k == 0 left mcd_loss unbound and crashed below.
        mcd_loss = discrepancy.detach()

        # Step C train generator to minimize discrepancy
        for k in range(args.num_k):
            optimizer_g.zero_grad()
            y1_s, y2_s, y1_t, y2_t = forward_heads(x)
            mcd_loss = classifier_discrepancy(y1_t, y2_t) * args.trade_off
            mcd_loss.backward()
            optimizer_g.step()

        # accuracies use the first head's outputs from the most recent forward
        cls_acc = accuracy(y1_s, labels_s)[0]
        tgt_acc = accuracy(y1_t, labels_t)[0]

        # `loss` is the Step B loss; `mcd_loss` is the last Step C discrepancy
        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        trans_losses.update(mcd_loss.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
示例#14
0
def train(train_iter: ForeverDataIterator, model: Classifier, optimizer,
          lr_scheduler: CosineAnnealingLR, epoch: int,
          n_domains_per_batch: int, args: argparse.Namespace):
    """Run one epoch of inner/outer-loop training on multi-domain batches.

    Each batch is cut into ``n_domains_per_batch`` equal chunks which
    ``random_split`` divides into support and query domains. A differentiable
    copy of the model (from ``higher.innerloop_ctx``) takes
    ``args.inner_iters`` optimisation steps on the support domains. The outer
    loss adds the support loss of the original model to the query loss of
    the adapted copy, scaled by ``args.trade_off``, and is back-propagated
    into the original model.

    Args:
        train_iter: iterator yielding ``(x, labels, _)`` triples; the third
            element is ignored.
        model: classifier whose forward pass returns a pair; only the first
            element is used here.
        optimizer: optimizer of ``model``; also handed to ``higher``.
        lr_scheduler: scheduler stepped once per iteration.
        epoch: epoch index, only used in the progress prefix.
        n_domains_per_batch: number of chunks each batch is split into.
        args: namespace providing ``iters_per_epoch``, ``n_support_domains``,
            ``n_query_domains``, ``inner_iters``, ``trade_off``,
            ``batch_size`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    shown_meters = [batch_time, data_time, losses, cls_accs]

    progress = ProgressMeter(args.iters_per_epoch,
                             shown_meters,
                             prefix="Epoch: [{}]".format(epoch))

    model.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        x, labels, _ = next(train_iter)
        x, labels = x.to(device), labels.to(device)

        # time spent fetching the batch and moving it to the device
        data_time.update(time.time() - tick)

        # one chunk per domain, then a random support / query partition
        support_domain_list, query_domain_list = random_split(
            x.chunk(n_domains_per_batch, dim=0),
            labels.chunk(n_domains_per_batch, dim=0),
            n_domains_per_batch, args.n_support_domains)

        optimizer.zero_grad()

        with higher.innerloop_ctx(
                model, optimizer,
                copy_initial_weights=False) as (fast_model, fast_optimizer):
            # inner loop: adapt the differentiable copy on the support domains
            for _ in range(args.inner_iters):
                loss_inner = 0
                for support_x, support_labels in support_domain_list:
                    support_logits, _ = fast_model(support_x)
                    # each support domain contributes an equal share
                    loss_inner += F.cross_entropy(
                        support_logits,
                        support_labels) / args.n_support_domains

                fast_optimizer.step(loss_inner)

            loss_outer = 0
            cls_acc = 0

            # outer loss, part 1: the original model on the support domains
            for support_x, support_labels in support_domain_list:
                support_logits, _ = model(support_x)
                loss_outer += F.cross_entropy(
                    support_logits, support_labels) / args.n_support_domains

            # outer loss, part 2: the adapted copy on the query domains
            for query_x, query_labels in query_domain_list:
                query_logits, _ = fast_model(query_x)
                query_loss = F.cross_entropy(query_logits, query_labels)
                loss_outer += query_loss * args.trade_off / args.n_query_domains
                # reported accuracy is the mean over the query domains
                cls_acc += accuracy(query_logits,
                                    query_labels)[0] / args.n_query_domains

        # record statistics before the parameters change
        losses.update(loss_outer.item(), args.batch_size)
        cls_accs.update(cls_acc.item(), args.batch_size)

        # back-propagate the outer loss and update the original model
        loss_outer.backward()
        optimizer.step()
        lr_scheduler.step()

        # wall-clock time of the whole iteration
        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq != 0:
            continue
        progress.display(step)
示例#15
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, classifier: ImageClassifier,
          mdd: MarginDisparityDiscrepancy, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of training with a margin disparity discrepancy term.

    Source and target batches are concatenated and passed through
    ``classifier``, which returns main and adversarial outputs. The loss is
    the source cross-entropy plus ``args.trade_off`` times the negated
    ``mdd`` value.

    Args:
        train_source_iter: iterator yielding ``(x, labels)`` source batches.
        train_target_iter: iterator yielding ``(x, labels)`` target batches;
            the labels are used only for the reported target accuracy.
        classifier: model returning ``(outputs, outputs_adv)`` and exposing
            a ``step()`` method.
        mdd: discrepancy module called as
            ``mdd(y_s, y_s_adv, y_t, y_t_adv)``.
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: scheduler stepped once per iteration.
        epoch: epoch index, only used in the progress prefix.
        args: namespace providing ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    shown_meters = [
        batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs
    ]

    progress = ProgressMeter(args.iters_per_epoch,
                             shown_meters,
                             prefix="Epoch: [{}]".format(epoch))

    classifier.train()
    mdd.train()

    criterion = nn.CrossEntropyLoss().to(device)

    tick = time.time()
    for step in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s, x_t = x_s.to(device), x_t.to(device)
        labels_s, labels_t = labels_s.to(device), labels_t.to(device)

        # time spent fetching the two batches and moving them to the device
        data_time.update(time.time() - tick)

        # a single forward pass over both domains; source comes first
        outputs, outputs_adv = classifier(torch.cat((x_s, x_t), dim=0))
        y_s, y_t = outputs.chunk(2, dim=0)
        y_s_adv, y_t_adv = outputs_adv.chunk(2, dim=0)

        # supervised loss on the source half
        cls_loss = criterion(y_s, labels_s)
        # the discrepancy enters negated: for the adversarial classifier,
        # minimising the negative mdd equals maximising mdd
        transfer_loss = -mdd(y_s, y_s_adv, y_t, y_t_adv)
        loss = cls_loss + transfer_loss * args.trade_off
        # presumably advances an internal schedule of the classifier -- confirm
        classifier.step()

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        n_source, n_target = x_s.size(0), x_t.size(0)
        losses.update(loss.item(), n_source)
        cls_accs.update(cls_acc.item(), n_source)
        tgt_accs.update(tgt_acc.item(), n_target)
        trans_losses.update(transfer_loss.item(), n_source)

        # back-propagate and update the parameters and learning rate
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # wall-clock time of the whole iteration
        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq != 0:
            continue
        progress.display(step)
示例#16
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          teacher: EmaTeacher, consistent_loss, class_balance_loss,
          optimizer: Adam, lr_scheduler: LambdaLR, epoch: int,
          args: argparse.Namespace):
    """Run one epoch of student/teacher consistency training.

    Each iteration classifies a source batch and the first view of a target
    batch with ``model``, and the second target view with ``teacher``. The
    loss is the source cross-entropy plus a weighted consistency loss
    between student and teacher target probabilities and a weighted class
    balance loss. Target samples count as confident when the teacher's
    maximum probability exceeds ``args.threshold``.

    Args:
        train_source_iter: iterator yielding ``(x, labels)`` source batches.
        train_target_iter: iterator yielding ``((x_1, x_2), labels)`` target
            batches (presumably two augmented views of the same images --
            confirm); the labels are used only for the reported accuracy.
        model: student classifier whose forward pass returns a pair; only
            the first element is used here.
        teacher: teacher model with the same call convention and an
            ``update()`` method, called after every optimizer step.
        consistent_loss: callable ``consistent_loss(y, y_teacher, mask)``.
        class_balance_loss: callable ``class_balance_loss(y)``.
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: scheduler stepped once per iteration.
        epoch: epoch index, only used in the progress prefix.
        args: namespace providing ``iters_per_epoch``, ``threshold``,
            ``trade_off_cons``, ``trade_off_balance`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    cons_losses = AverageMeter('Cons Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, cls_losses, cons_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    teacher.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        (x_t1, x_t2), labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t1 = x_t1.to(device)
        x_t2 = x_t2.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output: student sees x_s and x_t1, teacher sees x_t2
        y_s, _ = model(x_s)
        y_t, _ = model(x_t1)
        y_t_teacher, _ = teacher(x_t2)

        # classification loss
        cls_loss = F.cross_entropy(y_s, labels_s)
        # turn target logits into probabilities and build the confidence mask
        y_t = F.softmax(y_t, dim=1)
        y_t_teacher = F.softmax(y_t_teacher, dim=1)
        max_prob, _ = y_t_teacher.max(dim=1)
        mask = (max_prob > args.threshold).float()

        # consistent loss
        cons_loss = consistent_loss(y_t, y_t_teacher, mask)
        # balance loss, scaled by the fraction of confident target samples
        balance_loss = class_balance_loss(y_t) * mask.mean()

        loss = cls_loss + args.trade_off_cons * cons_loss + args.trade_off_balance * balance_loss

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # update teacher
        teacher.update()

        # update statistics
        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]
        cls_losses.update(cls_loss.item(), x_s.size(0))
        cons_losses.update(cons_loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        # the target accuracy is computed on the target batch, so it is
        # weighted by the target batch size (previously x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t1.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)