Example #1
0
def validate(val_loader: DataLoader, model: Regressor,
             args: argparse.Namespace, factors) -> float:
    """Evaluate a multi-output regressor; return the mean MAE over ``factors``.

    The model is switched to eval mode and run under ``torch.no_grad()`` over
    every batch of ``val_loader``. One mean-absolute-error meter is kept per
    factor, comparing ``output[:, j]`` with ``target[:, j]`` for the ``j``-th
    factor. A per-factor summary is printed at the end.

    Args:
        val_loader: yields ``(images, target)`` batches.
        model: called as ``output, _ = model(images)``.
        args: must provide ``print_freq``; progress is displayed every
            ``print_freq`` batches.
        factors: names of the regression targets, one per output column.
            It is iterated more than once and passed to ``len``, so it must
            be a sized, re-iterable collection (e.g. a list).

    Returns:
        A single float: the unweighted mean, over factors, of each factor's
        average MAE.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    mae_losses = [
        AverageMeter('mae {}'.format(factor), ':6.3f') for factor in factors
    ]
    progress = ProgressMeter(len(val_loader), [batch_time] + mae_losses,
                             prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output, _ = model(images)
            # column j of output/target corresponds to factors[j]
            for j, mae_meter in enumerate(mae_losses):
                mae_loss = F.l1_loss(output[:, j], target[:, j])
                mae_meter.update(mae_loss.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        for factor, mae_meter in zip(factors, mae_losses):
            print("{} MAE {mae.avg:6.3f}".format(factor, mae=mae_meter))
        mean_mae = sum(meter.avg for meter in mae_losses) / len(factors)
    return mean_mae
def validate(val_loader: DataLoader, model: Classifier,
             args: argparse.Namespace) -> float:
    """Evaluate an open-set classifier and return its H-score.

    The last entry of ``val_loader.dataset.classes`` is treated as the
    "unknown" class. After the softmax, the last output column is overwritten
    with ``args.threshold``. A sample is therefore predicted as unknown exactly
    when every other column's probability is strictly below the threshold
    (on ties ``argmax`` returns the first maximal index).

    Args:
        val_loader: yields ``(images, target)`` batches; its dataset must
            expose ``classes``.
        model: called as ``output, _ = model(images)``.
        args: must provide ``threshold``, ``print_freq`` and
            ``per_class_eval`` (print the full confusion matrix when truthy).

    Returns:
        The H-score: the harmonic mean of the mean known-class accuracy and
        the unknown-class accuracy, both in percent. Returns ``0.0`` when both
        accuracies are zero, instead of raising ``ZeroDivisionError``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    classes = val_loader.dataset.classes
    confmat = ConfusionMatrix(len(classes))
    progress = ProgressMeter(len(val_loader), [batch_time], prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output, _ = model(images)
            softmax_output = F.softmax(output, dim=1)
            # replace the unknown-class score by the fixed rejection threshold
            softmax_output[:, -1] = args.threshold

            # record predictions
            confmat.update(target, softmax_output.argmax(1))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        _, accs, _ = confmat.compute()
        all_acc = torch.mean(accs).item() * 100
        known = torch.mean(accs[:-1]).item() * 100
        unknown = accs[-1].item() * 100
        denom = known + unknown
        # Both accuracies zero -> define the harmonic mean as 0 rather than
        # dividing by zero. A NaN denominator is not equal to 0 and still
        # propagates.
        h_score = 0.0 if denom == 0 else 2 * known * unknown / denom
        if args.per_class_eval:
            print(confmat.format(classes))
        print(
            ' * All {all:.3f} Known {known:.3f} Unknown {unknown:.3f} H-score {h_score:.3f}'
            .format(all=all_acc, known=known, unknown=unknown,
                    h_score=h_score))

    return h_score
def extract_reid_feature(data_loader,
                         model,
                         device,
                         normalize,
                         print_freq=200):
    """Run ``model`` over ``data_loader`` and collect one feature per filename.

    Each batch is unpacked as ``(images, filenames, _, _)``. When ``normalize``
    is truthy, the batch features go through ``F.normalize`` (cosine distance
    is then the intended metric); otherwise raw features are kept (euclidean
    distance).

    Args:
        data_loader: yields 4-tuples whose first two items are the image batch
            and the matching filenames.
        model: feature extractor, called as ``model(images)``.
        device: device the image batches are moved to.
        normalize: whether to L2-normalize features via ``F.normalize``.
        print_freq: progress is displayed every ``print_freq`` batches.

    Returns:
        dict mapping each filename to its feature tensor (left on ``device``).
        A filename that appears twice keeps the later feature.
    """
    timer = AverageMeter('Time', ':6.3f')
    progress = ProgressMeter(len(data_loader), [timer],
                             prefix='Collect feature: ')

    model.eval()
    features_by_name = {}

    with torch.no_grad():
        tick = time.time()
        for step, (images, names, _, _) in enumerate(data_loader):
            feats = model(images.to(device))
            if normalize:
                feats = F.normalize(feats)

            # pair each filename with the matching row of the feature batch
            features_by_name.update(zip(names, feats))

            timer.update(time.time() - tick)
            tick = time.time()

            if step % print_freq == 0:
                progress.display(step)

    return features_by_name
Example #4
0
def validate(val_loader, model, criterion, visualize,
             args: argparse.Namespace):
    """Evaluate a keypoint model on ``val_loader``.

    Each batch ``(x, target, target_weight, meta)`` is scored with
    ``criterion(output, target, target_weight)`` and with ``accuracy`` on the
    NumPy copies of output and target. The per-keypoint accuracies are grouped
    with ``val_loader.dataset.group_accuracy``.

    Args:
        val_loader: loader whose dataset exposes ``keypoints_group`` and
            ``group_accuracy``.
        model: called as ``model(x)``.
        criterion: called as ``criterion(output, target, target_weight)``.
        visualize: optional callable ``(image, keypoints, filename)``. When not
            None, it is called on the first sample of every
            ``args.print_freq``-th batch, once with the prediction scaled by
            ``args.image_size / args.heatmap_size`` and once with
            ``meta['keypoint2d']``.
        args: needs ``print_freq``, plus ``image_size``/``heatmap_size``
            when ``visualize`` is given.

    Returns:
        ``acc.average()`` of the grouped accuracy meter.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.2e')
    acc = AverageMeterDict(val_loader.dataset.keypoints_group.keys(), ":3.2f")
    progress = ProgressMeter(len(val_loader), [batch_time, losses, acc['all']],
                             prefix='Test: ')

    model.eval()

    with torch.no_grad():
        tick = time.time()
        for step, (images, target, target_weight, meta) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)
            target_weight = target_weight.to(device)
            batch_size = images.size(0)

            output = model(images)
            loss = criterion(output, target, target_weight)
            losses.update(loss.item(), batch_size)

            per_point_acc, _, _, keypoints_pred = accuracy(
                output.cpu().numpy(), target.cpu().numpy())
            acc.update(val_loader.dataset.group_accuracy(per_point_acc),
                       batch_size)

            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq != 0:
                continue
            progress.display(step)
            if visualize is None:
                continue
            visualize(images[0],
                      keypoints_pred[0] * args.image_size / args.heatmap_size,
                      "val_{}_pred.jpg".format(step))
            visualize(images[0], meta['keypoint2d'][0],
                      "val_{}_label.jpg".format(step))

    return acc.average()
Example #5
0
def validate(val_loader: DataLoader, model, interp, criterion, visualize, args: argparse.Namespace):
    """Evaluate a segmentation model; return the accumulated confusion matrix.

    ``interp`` is applied to the raw model output before the loss and the
    argmax (presumably resizing logits to label resolution — TODO confirm).
    After every batch the confusion matrix accumulated so far is recomputed.
    Its class-mean accuracy and class-mean IoU are pushed into the ``Acc`` and
    ``IoU`` meters.

    Args:
        val_loader: yields ``(x, label)``; labels are cast with ``.long()``.
        model: segmentation network exposing ``num_classes``.
        interp: callable applied to ``model(x)``.
        criterion: called as ``criterion(output, label)``.
        visualize: optional callable ``(image, output, label, name)``, called
            on the first sample every ``args.print_freq`` batches.
        args: must provide ``print_freq``.

    Returns:
        The ``ConfusionMatrix`` accumulated over the whole loader.
    """
    timer = AverageMeter('Time', ':6.3f')
    loss_meter = AverageMeter('Loss', ':.4e')
    acc_meter = Meter('Acc', ':3.2f')
    iou_meter = Meter('IoU', ':3.2f')
    progress = ProgressMeter(len(val_loader),
                             [timer, loss_meter, acc_meter, iou_meter],
                             prefix='Test: ')

    model.eval()
    confmat = ConfusionMatrix(model.num_classes)

    with torch.no_grad():
        tick = time.time()
        for step, (images, labels) in enumerate(val_loader):
            images = images.to(device)
            labels = labels.long().to(device)

            logits = interp(model(images))
            loss = criterion(logits, labels)
            loss_meter.update(loss.item(), images.size(0))

            confmat.update(labels.flatten(), logits.argmax(1).flatten())
            _, class_acc, class_iou = confmat.compute()
            acc_meter.update(class_acc.mean().item())
            iou_meter.update(class_iou.mean().item())

            timer.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                progress.display(step)
                if visualize is not None:
                    visualize(images[0], logits[0], labels[0],
                              "val_{}".format(step))

    return confmat
Example #6
0
def pretrain(train_source_iter, model, criterion, optimizer, epoch: int,
             args: argparse.Namespace):
    """Train ``model`` for one epoch on the labelled source domain only.

    Each of the ``args.iters_per_epoch`` iterations draws
    ``(x_s, label_s, weight_s, meta_s)`` from ``train_source_iter`` with
    ``next`` and takes one optimizer step on
    ``criterion(y_s, label_s, weight_s)``. It then records the loss and the
    keypoint accuracy returned by ``accuracy``.

    Args:
        train_source_iter: iterator consumed with ``next`` once per iteration
            (assumed not to be exhausted within the epoch — TODO confirm).
        model: called as ``model(x_s)``.
        criterion: called as ``criterion(y_s, label_s, weight_s)``.
        optimizer: stepped once per iteration.
        epoch: only used in the progress prefix.
        args: must provide ``iters_per_epoch`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ":.2e")
    acc_s = AverageMeter("Acc (s)", ":3.2f")

    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses_s, acc_s],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, label_s, weight_s, _ = next(train_source_iter)

        x_s = x_s.to(device)
        label_s = label_s.to(device)
        weight_s = weight_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s = model(x_s)
        loss_s = criterion(y_s, label_s, weight_s)

        # compute gradient and do SGD step
        loss_s.backward()
        optimizer.step()

        # measure accuracy and record loss; both meters are weighted by the
        # count ``cnt_s`` returned by ``accuracy``
        _, avg_acc_s, cnt_s, _ = accuracy(y_s.detach().cpu().numpy(),
                                          label_s.detach().cpu().numpy())
        acc_s.update(avg_acc_s, cnt_s)
        # Record a Python float rather than the loss tensor. The tensor still
        # requires grad, so storing it in the meter keeps it (and anything the
        # meter builds from it) alive for the rest of the epoch.
        losses_s.update(loss_s.item(), cnt_s)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
Example #7
0
def train(train_source_iter, train_target_iter, model, criterion,
          regression_disparity, optimizer_f, optimizer_h, optimizer_h_adv,
          lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv, epoch: int,
          visualize, args: argparse.Namespace):
    """Train one epoch of regression-disparity keypoint adaptation.

    ``model(x)`` returns a pair ``(y, y_adv)`` of predictions from the main
    and the adversarial regressor. Every iteration draws one source and one
    target batch ``(x, label, weight, meta)`` and performs three updates:

    A. all three optimizers minimise
       ``criterion(y_s, label_s, weight_s)
       + args.margin * args.trade_off * regression_disparity(..., mode='min')``
       on the source batch;
    B. only ``optimizer_h_adv`` steps, on
       ``args.trade_off * regression_disparity(y_t, y_t_adv, weight_t, mode='max')``;
    C. only ``optimizer_f`` steps, on the same target disparity with
       ``mode='min'``.

    Afterwards ``model.step()`` and the three LR schedulers are stepped once.
    Source and target accuracies of both regressors are tracked with
    ``accuracy``; the target labels are used for this monitoring only.

    Args:
        visualize: optional callable ``(image, keypoints, name)``. When not
            None, it is called every ``args.print_freq`` iterations on the
            first sample of each batch.
        args: must provide ``iters_per_epoch``, ``print_freq``, ``margin``,
            ``trade_off`` and, when ``visualize`` is given, ``image_size``
            and ``heatmap_size``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ":.2e")
    losses_gf = AverageMeter('Loss (t, false)', ":.2e")
    losses_gt = AverageMeter('Loss (t, truth)', ":.2e")
    acc_s = AverageMeter("Acc (s)", ":3.2f")
    acc_t = AverageMeter("Acc (t)", ":3.2f")
    acc_s_adv = AverageMeter("Acc (s, adv)", ":3.2f")
    acc_t_adv = AverageMeter("Acc (t, adv)", ":3.2f")

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_s, losses_gf, losses_gt, acc_s, acc_t,
         acc_s_adv, acc_t_adv],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, label_s, weight_s, meta_s = next(train_source_iter)
        x_t, label_t, weight_t, meta_t = next(train_target_iter)

        x_s = x_s.to(device)
        label_s = label_s.to(device)
        weight_s = weight_s.to(device)

        x_t = x_t.to(device)
        label_t = label_t.to(device)
        weight_t = weight_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Step A train all networks to minimize loss on source domain
        optimizer_f.zero_grad()
        optimizer_h.zero_grad()
        optimizer_h_adv.zero_grad()

        y_s, y_s_adv = model(x_s)
        loss_s = criterion(y_s, label_s, weight_s) + \
                 args.margin * args.trade_off * regression_disparity(y_s, y_s_adv, weight_s, mode='min')
        loss_s.backward()
        optimizer_f.step()
        optimizer_h.step()
        optimizer_h_adv.step()

        # Step B train adv regressor to maximize regression disparity
        optimizer_h_adv.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_false = args.trade_off * regression_disparity(
            y_t, y_t_adv, weight_t, mode='max')
        loss_ground_false.backward()
        optimizer_h_adv.step()

        # Step C train feature extractor to minimize regression disparity
        optimizer_f.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_truth = args.trade_off * regression_disparity(
            y_t, y_t_adv, weight_t, mode='min')
        loss_ground_truth.backward()
        optimizer_f.step()

        # do update step
        model.step()
        lr_scheduler_f.step()
        lr_scheduler_h.step()
        lr_scheduler_h_adv.step()

        # measure accuracy and record loss
        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),
                                               label_s.detach().cpu().numpy())
        acc_s.update(avg_acc_s, cnt_s)
        _, avg_acc_t, cnt_t, pred_t = accuracy(y_t.detach().cpu().numpy(),
                                               label_t.detach().cpu().numpy())
        acc_t.update(avg_acc_t, cnt_t)
        _, avg_acc_s_adv, _, pred_s_adv = accuracy(
            y_s_adv.detach().cpu().numpy(),
            label_s.detach().cpu().numpy())
        acc_s_adv.update(avg_acc_s_adv, cnt_s)
        _, avg_acc_t_adv, _, pred_t_adv = accuracy(
            y_t_adv.detach().cpu().numpy(),
            label_t.detach().cpu().numpy())
        acc_t_adv.update(avg_acc_t_adv, cnt_t)
        # Record Python floats rather than the loss tensors themselves. The
        # tensors still require grad, so handing them to the meters keeps
        # them (and anything the meters build from them) alive all epoch.
        # NOTE(review): the two target losses are weighted by the source
        # count cnt_s; confirm that is intended rather than cnt_t.
        losses_s.update(loss_s.item(), cnt_s)
        losses_gf.update(loss_ground_false.item(), cnt_s)
        losses_gt.update(loss_ground_truth.item(), cnt_s)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                visualize(x_s[0],
                          pred_s[0] * args.image_size / args.heatmap_size,
                          "source_{}_pred".format(i))
                visualize(x_s[0], meta_s['keypoint2d'][0],
                          "source_{}_label".format(i))
                visualize(x_t[0],
                          pred_t[0] * args.image_size / args.heatmap_size,
                          "target_{}_pred".format(i))
                visualize(x_t[0], meta_t['keypoint2d'][0],
                          "target_{}_label".format(i))
                visualize(x_s[0],
                          pred_s_adv[0] * args.image_size / args.heatmap_size,
                          "source_adv_{}_pred".format(i))
                visualize(x_t[0],
                          pred_t_adv[0] * args.image_size / args.heatmap_size,
                          "target_adv_{}_pred".format(i))
Example #8
0
def train(train_iter: ForeverDataIterator, model: Classifier,
          backbone_regularization: nn.Module, head_regularization: nn.Module,
          target_getter: IntermediateLayerGetter,
          source_getter: IntermediateLayerGetter, optimizer: SGD, epoch: int,
          args: argparse.Namespace):
    """Fine-tune for one epoch with backbone and head regularization.

    ``source_getter`` and ``target_getter`` are both applied to the same batch.
    Each returns ``(intermediate_outputs, output)``, and the target getter's
    ``output`` is unpacked as ``(logits, _)``. The optimised loss is::

        cross_entropy(logits, label)
        + args.trade_off_backbone * loss_reg_backbone
        + args.trade_off_head * loss_reg_head

    For ``args.regularization_type`` equal to ``'feature_map'`` or
    ``'attention_feature_map'``, the backbone regularizer receives both
    getters' intermediate outputs; otherwise it is called with no arguments.
    ``head_regularization`` is always called with no arguments.

    NOTE(review): ``model`` is only switched to train mode here; presumably
    the getters wrap it (target) and a frozen copy (source) — TODO confirm.

    Args:
        train_iter: yields ``(x, labels)`` via ``next``.
        optimizer: stepped once per iteration.
        epoch: only used in the progress prefix.
        args: must provide ``iters_per_epoch``, ``print_freq``,
            ``regularization_type``, ``trade_off_backbone`` and
            ``trade_off_head``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    losses_reg_head = AverageMeter('Loss (reg, head)', ':3.2f')
    losses_reg_backbone = AverageMeter('Loss (reg, backbone)', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, losses_reg_head, losses_reg_backbone,
         cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x, label = next(train_iter)
        x = x.to(device)
        label = label.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        intermediate_output_s, _ = source_getter(x)
        intermediate_output_t, output_t = target_getter(x)
        y, _ = output_t

        # measure accuracy and record loss
        cls_acc = accuracy(y, label)[0]
        cls_loss = F.cross_entropy(y, label)
        # Feature-map based regularizers compare source vs. target
        # intermediate activations; the others take no inputs.
        if args.regularization_type in ('feature_map',
                                        'attention_feature_map'):
            loss_reg_backbone = backbone_regularization(
                intermediate_output_s, intermediate_output_t)
        else:
            loss_reg_backbone = backbone_regularization()
        loss_reg_head = head_regularization()
        loss = cls_loss + args.trade_off_backbone * loss_reg_backbone + args.trade_off_head * loss_reg_head

        batch_size = x.size(0)
        losses_reg_backbone.update(
            loss_reg_backbone.item() * args.trade_off_backbone, batch_size)
        losses_reg_head.update(loss_reg_head.item() * args.trade_off_head,
                               batch_size)
        losses.update(loss.item(), batch_size)
        cls_accs.update(cls_acc.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
Example #9
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model,
          mdd: MarginDisparityDiscrepancy, optimizer: Adam,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Train one epoch with a margin-disparity-discrepancy objective.

    Source and target batches are concatenated and sent through ``model`` in a
    single forward pass. ``model`` returns ``(outputs, outputs_adv)``, which are
    split back into source/target halves. The loss is
    ``cross_entropy(y_s, labels_s) - args.trade_off * mdd(y_s, y_s_adv, y_t, y_t_adv)``.
    ``model.step()`` is called once per iteration, before the backward pass.
    Target labels are only used for accuracy monitoring.

    NOTE(review): ``lr_scheduler`` is accepted but never stepped in this
    function — confirm that is intentional.

    Args:
        args: must provide ``iters_per_epoch``, ``print_freq`` and
            ``trade_off``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    cls_adv_accs = AverageMeter('Cls Adv Acc', ':3.1f')
    tgt_adv_accs = AverageMeter('Tgt Adv Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses,
         cls_accs, tgt_accs, cls_adv_accs, tgt_adv_accs],
        prefix="Epoch: [{}]".format(epoch))

    model.train()
    mdd.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)
        x_s, x_t = x_s.to(device), x_t.to(device)
        labels_s, labels_t = labels_s.to(device), labels_t.to(device)
        data_time.update(time.time() - tick)

        # one forward pass over both domains, then split back per domain
        logits, logits_adv = model(torch.cat((x_s, x_t), dim=0))
        y_s, y_t = logits.chunk(2, dim=0)
        y_s_adv, y_t_adv = logits_adv.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        # minimizing -mdd maximizes the discrepancy for the adversarial head
        transfer_loss = -mdd(y_s, y_s_adv, y_t, y_t_adv)
        loss = cls_loss + transfer_loss * args.trade_off
        model.step()

        n_s, n_t = x_s.size(0), x_t.size(0)
        losses.update(loss.item(), n_s)
        trans_losses.update(transfer_loss.item(), n_s)
        for meter, scores, targets, count in (
                (cls_accs, y_s, labels_s, n_s),
                (tgt_accs, y_t, labels_t, n_t),
                (cls_adv_accs, y_s_adv, labels_s, n_s),
                (tgt_adv_accs, y_t_adv, labels_t, n_t)):
            meter.update(accuracy(scores, targets)[0].item(), count)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
Example #10
0
def train(train_iter: ForeverDataIterator, model: Classifier, optimizer,
          lr_scheduler: CosineAnnealingLR, n_domains_per_batch: int,
          epoch: int, args: argparse.Namespace):
    """Train one epoch with a variance-of-risks penalty across domains.

    Each batch from ``train_iter`` is split along dim 0 into (up to)
    ``n_domains_per_batch`` chunks, and one cross-entropy risk is computed per
    chunk. The loss is ``mean(risks) + trade_off * mean((risks - mean(risks))**2)``,
    where ``trade_off`` is 1 until the global iteration
    ``epoch * args.iters_per_epoch + i`` reaches ``args.anneal_iters`` and
    ``args.trade_off`` afterwards. The optimizer and ``lr_scheduler`` are
    stepped every iteration.

    NOTE(review): assumes each batch is laid out as ``n_domains_per_batch``
    contiguous, equally sized per-domain slices — confirm against the sampler.

    Args:
        train_iter: yields ``(x, labels, _)`` via ``next``.
        model: called as ``y, _ = model(x)``.
        args: must provide ``iters_per_epoch``, ``print_freq``,
            ``anneal_iters`` and ``trade_off``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    losses_ce = AverageMeter('CELoss', ':3.2f')
    losses_penalty = AverageMeter('Penalty Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, losses_ce, losses_penalty, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_all, labels_all, _ = next(train_iter)
        x_all = x_all.to(device)
        labels_all = labels_all.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_all, _ = model(x_all)

        # One cross-entropy risk per domain chunk. ``torch.chunk`` may return
        # fewer than ``n_domains_per_batch`` chunks (e.g. a batch smaller than
        # the number of domains). Stacking only the chunks actually produced
        # avoids the zero-filled slots a fixed-size buffer would leave, which
        # would drag down the mean risk and inflate the penalty.
        loss_ce_per_domain = torch.stack([
            F.cross_entropy(y_per_domain, labels_per_domain)
            for y_per_domain, labels_per_domain in zip(
                y_all.chunk(n_domains_per_batch, dim=0),
                labels_all.chunk(n_domains_per_batch, dim=0))
        ])

        # cls loss
        loss_ce = loss_ce_per_domain.mean()
        # penalty loss: variance of the per-domain risks
        loss_penalty = ((loss_ce_per_domain - loss_ce)**2).mean()

        # keep the penalty weight at 1 until the annealing point is reached
        global_iter = epoch * args.iters_per_epoch + i
        trade_off = args.trade_off if global_iter >= args.anneal_iters else 1

        loss = loss_ce + loss_penalty * trade_off
        cls_acc = accuracy(y_all, labels_all)[0]

        batch_size = x_all.size(0)
        losses.update(loss.item(), batch_size)
        losses_ce.update(loss_ce.item(), batch_size)
        losses_penalty.update(loss_penalty.item(), batch_size)
        cls_accs.update(cls_acc.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
Example #11
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          teacher: EmaTeacher, consistent_loss, class_balance_loss,
          optimizer: Adam, lr_scheduler: LambdaLR, epoch: int,
          args: argparse.Namespace):
    """Train one epoch of student/teacher consistency adaptation.

    The target iterator yields two views ``(x_t1, x_t2)`` of each target
    batch: the first goes to ``model`` (student), the second to ``teacher``.
    Teacher predictions whose max softmax probability exceeds
    ``args.threshold`` form a float mask. The loss is::

        cross_entropy(y_s, labels_s)
        + args.trade_off_cons * consistent_loss(p_t, p_t_teacher, mask)
        + args.trade_off_balance * class_balance_loss(p_t) * mask.mean()

    ``teacher.update()`` is called after each optimizer/scheduler step.
    Target labels are used only for accuracy monitoring.

    Args:
        args: must provide ``iters_per_epoch``, ``print_freq``,
            ``threshold``, ``trade_off_cons`` and ``trade_off_balance``.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    cons_losses = AverageMeter('Cons Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, cls_losses, cons_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    for net in (model, teacher):
        net.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        (x_t_student, x_t_teacher), labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t_student = x_t_student.to(device)
        x_t_teacher = x_t_teacher.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)
        data_time.update(time.time() - tick)

        logits_s, _ = model(x_s)
        logits_t, _ = model(x_t_student)
        logits_t_teacher, _ = teacher(x_t_teacher)

        cls_loss = F.cross_entropy(logits_s, labels_s)

        prob_t = F.softmax(logits_t, dim=1)
        prob_t_teacher = F.softmax(logits_t_teacher, dim=1)
        # only confidently-predicted teacher samples contribute
        confident = (prob_t_teacher.max(dim=1)[0] > args.threshold).float()

        cons_loss = consistent_loss(prob_t, prob_t_teacher, confident)
        balance_loss = class_balance_loss(prob_t) * confident.mean()
        loss = (cls_loss
                + args.trade_off_cons * cons_loss
                + args.trade_off_balance * balance_loss)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        teacher.update()

        n = x_s.size(0)
        cls_losses.update(cls_loss.item(), n)
        cons_losses.update(cons_loss.item(), n)
        cls_accs.update(accuracy(logits_s, labels_s)[0].item(), n)
        tgt_accs.update(accuracy(prob_t, labels_t)[0].item(), n)

        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
Example #12
0
def validate(val_loader: DataLoader, model: ImageClassifier,
             args: argparse.Namespace) -> float:
    """Evaluate a closed-set classifier and return its top-1 accuracy.

    Tracks cross-entropy loss, Acc@1 and Acc@5 over ``val_loader``. When
    ``args.per_class_eval`` is truthy, a ``ConfusionMatrix`` over
    ``val_loader.dataset.classes`` is also accumulated and printed at the end.

    NOTE(review): ``topk=(1, 5)`` assumes the model outputs at least 5
    classes — confirm for small label spaces.

    Args:
        val_loader: yields ``(images, target)`` batches.
        model: called as ``output, _ = model(images)``.
        args: must provide ``per_class_eval`` and ``print_freq``.

    Returns:
        Average top-1 accuracy (``top1.avg``).
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(val_loader), [batch_time, losses, top1, top5],
                             prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    if args.per_class_eval:
        classes = val_loader.dataset.classes
        confmat = ConfusionMatrix(len(classes))
    else:
        classes = None
        confmat = None

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output, _ = model(images)
            loss = F.cross_entropy(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # explicit None check: don't rely on the truthiness of the
            # ConfusionMatrix object itself
            if confmat is not None:
                confmat.update(target, output.argmax(1))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1,
                                                                    top5=top5))
        if confmat is not None:
            print(confmat.format(classes))

    return top1.avg
def validate(val_loader, model, args, device, visualize=None) -> float:
    """Compute top-1 accuracy and cross-entropy of ``model`` on ``val_loader``.

    Args:
        val_loader: yields ``(images, target)`` batches.
        model: called as ``model(images)`` and expected to return logits
            directly.
        args: must provide ``print_freq``.
        device: device the batches are moved to.
        visualize: optional callable ``(image, filename)``. When given, it is
            called with the first image of every ``args.print_freq``-th batch.

    Returns:
        Average top-1 accuracy over the loader (``top1.avg``).
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(val_loader), [batch_time, losses, top1],
                             prefix='Test: ')

    model.eval()

    with torch.no_grad():
        tick = time.time()
        for step, (inputs, labels) in enumerate(val_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            n = inputs.size(0)

            logits = model(inputs)
            ce = F.cross_entropy(logits, labels)
            (top1_acc,) = accuracy(logits, labels, topk=(1, ))
            losses.update(ce.item(), n)
            top1.update(top1_acc.item(), n)

            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq != 0:
                continue
            progress.display(step)
            if visualize is not None:
                visualize(inputs[0], "val_{}".format(step))

        print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))

    return top1.avg
Example #14
0
def train(train_source_iter: ForeverDataIterator, model, interp, criterion, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, visualize, args: argparse.Namespace):
    """Train a segmentation model for one epoch on source data only.

    ``interp`` is applied to the raw model output before the loss and argmax.
    One confusion matrix is accumulated across the whole epoch; after every
    iteration its class-mean accuracy and IoU are pushed into the
    ``Acc (s)`` / ``IoU (s)`` meters. The optimizer and ``lr_scheduler`` are
    stepped every iteration.

    Args:
        train_source_iter: yields ``(x, label)`` via ``next``; labels are cast
            with ``.long()``.
        model: segmentation network exposing ``num_classes``.
        visualize: optional callable ``(image, prediction, label, name)``,
            called on the first sample every ``args.print_freq`` iterations.
        args: must provide ``iters_per_epoch`` and ``print_freq``.
    """
    timer = AverageMeter('Time', ':4.2f')
    loader_timer = AverageMeter('Data', ':3.1f')
    loss_meter = AverageMeter('Loss (s)', ':3.2f')
    acc_meter = Meter('Acc (s)', ':3.2f')
    iou_meter = Meter('IoU (s)', ':3.2f')

    confmat = ConfusionMatrix(model.num_classes)
    progress = ProgressMeter(
        args.iters_per_epoch,
        [timer, loader_timer, loss_meter, acc_meter, iou_meter],
        prefix="Epoch: [{}]".format(epoch))

    model.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        optimizer.zero_grad()

        images, labels = next(train_source_iter)
        images = images.to(device)
        labels = labels.long().to(device)
        loader_timer.update(time.time() - tick)

        logits = interp(model(images))
        loss = criterion(logits, labels)
        loss.backward()

        optimizer.step()
        lr_scheduler.step()

        loss_meter.update(loss.item(), images.size(0))
        confmat.update(labels.flatten(), logits.argmax(1).flatten())
        _, class_acc, class_iou = confmat.compute()
        acc_meter.update(class_acc.mean().item())
        iou_meter.update(class_iou.mean().item())

        timer.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
            if visualize is not None:
                visualize(images[0], logits[0], labels[0],
                          "source_{}".format(step))
Example #15
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          adaptive_feature_norm: AdaptiveFeatureNorm, optimizer: SGD,
          epoch: int, args: argparse.Namespace):
    """Train one epoch with an adaptive-feature-norm penalty.

    ``model(x)`` returns ``(logits, features)`` for both domains. The loss is
    ``cross_entropy(logits_s, labels_s)
    + args.trade_off_norm * (afn(features_s) + afn(features_t))``. When
    ``args.trade_off_entropy`` is truthy, the target logits are softmaxed and
    ``args.trade_off_entropy * entropy(p_t, reduction='mean')`` is added.
    Mean L2 feature norms of both domains are tracked for monitoring; target
    labels are used only for accuracy monitoring.

    Args:
        args: must provide ``iters_per_epoch``, ``print_freq``,
            ``trade_off_norm`` and ``trade_off_entropy``.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    norm_losses = AverageMeter('Norm Loss', ':3.2f')
    src_feature_norm = AverageMeter('Source Feature Norm', ':3.2f')
    tgt_feature_norm = AverageMeter('Target Feature Norm', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, cls_losses, norm_losses, src_feature_norm,
         tgt_feature_norm, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    model.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)
        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)
        data_time.update(time.time() - tick)

        logits_s, feat_s = model(x_s)
        scores_t, feat_t = model(x_t)

        cls_loss = F.cross_entropy(logits_s, labels_s)
        norm_loss = adaptive_feature_norm(feat_s) + adaptive_feature_norm(feat_t)
        loss = cls_loss + norm_loss * args.trade_off_norm

        if args.trade_off_entropy:
            # from here on scores_t holds probabilities, not logits
            scores_t = F.softmax(scores_t, dim=1)
            loss += entropy(scores_t, reduction='mean') * args.trade_off_entropy

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n = x_s.size(0)
        cls_losses.update(cls_loss.item(), n)
        norm_losses.update(norm_loss.item(), n)
        src_feature_norm.update(feat_s.norm(p=2, dim=1).mean().item(), n)
        tgt_feature_norm.update(feat_t.norm(p=2, dim=1).mean().item(), n)
        cls_accs.update(accuracy(logits_s, labels_s)[0].item(), n)
        tgt_accs.update(accuracy(scores_t, labels_t)[0].item(), n)

        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
Example #16
0
def train(train_iter: ForeverDataIterator, model: Classifier, optimizer: SGD,
          epoch: int, relationship, co_tuning_loss, args: argparse.Namespace):
    """Train ``model`` for one epoch with the Co-Tuning objective.

    Every iteration draws a batch ``(x, label_t)``, looks up source-category
    targets ``label_s = relationship[label_t]`` and minimises::

        cross_entropy(y_t, label_t) + args.trade_off * co_tuning_loss(y_s, label_s)

    where ``(y_s, y_t) = model(x)``.

    Args:
        train_iter: iterator yielding ``(x, label_t)`` batches.
        model: network whose forward returns ``(y_s, y_t)``.
        optimizer: optimizer stepped once per iteration.
        epoch: epoch index, used only in the progress-bar prefix.
        relationship: NumPy array indexed by target labels; the lookup result
            is passed to ``torch.from_numpy``.
        co_tuning_loss: callable applied to ``(y_s, label_s)``.
        args: must provide ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')

    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses, cls_accs],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x, label_t = next(train_iter)

        x = x.to(device)
        # The relationship lookup runs before label_t is moved to the device.
        # Move the result with ``device`` like every other tensor here; the
        # former hard-coded ``.cuda()`` crashed on CPU-only runs and ignored
        # the configured device.
        label_s = torch.from_numpy(relationship[label_t]).to(device).float()
        label_t = label_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, y_t = model(x)
        tgt_loss = F.cross_entropy(y_t, label_t)
        src_loss = co_tuning_loss(y_s, label_s)
        loss = tgt_loss + args.trade_off * src_loss

        # measure accuracy and record loss
        losses.update(loss.item(), x.size(0))
        cls_acc = accuracy(y_t, label_t)[0]
        cls_accs.update(cls_acc.item(), x.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S,
          netD_T, criterion_gan, criterion_cycle, criterion_identity,
          optimizer_G, optimizer_D, fake_S_pool, fake_T_pool, epoch: int,
          visualize, args: argparse.Namespace):
    """Train a CycleGAN-style source<->target image translator for one epoch.

    Each iteration has two phases:

    1. Generator phase. Both discriminators are frozen, and both generators
       are updated with the adversarial, cycle-consistency and identity losses.
    2. Discriminator phase. Both discriminators are unfrozen and trained on
       real images versus fakes returned by the pools. Each discriminator loss
       is the mean of its real and fake terms.

    Labels yielded by the iterators are ignored.

    Args:
        train_source_iter: iterator yielding ``(image, _)`` source pairs.
        train_target_iter: iterator yielding ``(image, _)`` target pairs.
        netG_S2T: generator applied to source images to produce fake targets.
        netG_T2S: generator applied to target images to produce fake sources.
        netD_S: discriminator scoring source-domain images.
        netD_T: discriminator scoring target-domain images.
        criterion_gan: adversarial loss, called as ``criterion_gan(pred, real)``.
        criterion_cycle: reconstruction loss, called as ``(rec, real)``.
        criterion_identity: identity loss, called as ``(output, real)``.
        optimizer_G: optimizer stepped after the generator loss backward.
        optimizer_D: optimizer stepped after both discriminator backwards.
        fake_S_pool: object whose ``query`` maps the detached fake batch to
            the batch shown to ``netD_S`` (presumably an image history
            buffer; TODO confirm).
        fake_T_pool: the same role, for ``netD_T``.
        epoch: epoch index, used only in the progress-bar prefix.
        visualize: callback ``visualize(image, name)``, invoked every
            ``args.print_freq`` iterations.
        args: must provide ``iters_per_epoch``, ``trade_off_cycle``,
            ``trade_off_identity`` and ``print_freq``.

    NOTE(review): unlike the other loops here, this function never calls
    ``.train()`` on any network. The caller must set the mode.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_G_S2T = AverageMeter('G_S2T', ':3.2f')
    losses_G_T2S = AverageMeter('G_T2S', ':3.2f')
    losses_D_S = AverageMeter('D_S', ':3.2f')
    losses_D_T = AverageMeter('D_T', ':3.2f')
    losses_cycle_S = AverageMeter('cycle_S', ':3.2f')
    losses_cycle_T = AverageMeter('cycle_T', ':3.2f')
    losses_identity_S = AverageMeter('idt_S', ':3.2f')
    losses_identity_T = AverageMeter('idt_T', ':3.2f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_G_S2T, losses_G_T2S, losses_D_S,
        losses_D_T, losses_cycle_S, losses_cycle_T, losses_identity_S,
        losses_identity_T
    ],
                             prefix="Epoch: [{}]".format(epoch))

    end = time.time()

    for i in range(args.iters_per_epoch):
        real_S, _ = next(train_source_iter)
        real_T, _ = next(train_target_iter)

        real_S = real_S.to(device)
        real_T = real_T.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Compute fake images and reconstruction images.
        fake_T = netG_S2T(real_S)
        rec_S = netG_T2S(fake_T)
        fake_S = netG_T2S(real_T)
        rec_T = netG_S2T(fake_S)

        # Optimizing generators
        # discriminators require no gradients, so the generator backward
        # does not accumulate gradients into their parameters
        set_requires_grad(netD_S, False)
        set_requires_grad(netD_T, False)

        optimizer_G.zero_grad()
        # GAN loss D_T(G_S2T(S)): fake targets should be scored as real
        loss_G_S2T = criterion_gan(netD_T(fake_T), real=True)
        # GAN loss D_S(G_T2S(T)): fake sources should be scored as real
        loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)
        # Cycle loss || G_T2S(G_S2T(S)) - S||
        loss_cycle_S = criterion_cycle(rec_S, real_S) * args.trade_off_cycle
        # Cycle loss || G_S2T(G_T2S(T)) - T||
        loss_cycle_T = criterion_cycle(rec_T, real_T) * args.trade_off_cycle
        # Identity loss
        # G_S2T should be identity if real_T is fed: ||G_S2T(real_T) - real_T||
        identity_T = netG_S2T(real_T)
        loss_identity_T = criterion_identity(identity_T,
                                             real_T) * args.trade_off_identity
        # G_T2S should be identity if real_S is fed: ||G_T2S(real_S) - real_S||
        identity_S = netG_T2S(real_S)
        loss_identity_S = criterion_identity(identity_S,
                                             real_S) * args.trade_off_identity
        # combined loss and calculate gradients
        loss_G = loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T + loss_identity_S + loss_identity_T
        loss_G.backward()
        optimizer_G.step()

        # Optimize discriminator
        set_requires_grad(netD_S, True)
        set_requires_grad(netD_T, True)
        optimizer_D.zero_grad()
        # Calculate GAN loss for discriminator D_S. The fakes are detached so
        # this backward does not reach the generators.
        fake_S_ = fake_S_pool.query(fake_S.detach())
        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) +
                          criterion_gan(netD_S(fake_S_), False))
        loss_D_S.backward()
        # Calculate GAN loss for discriminator D_T
        fake_T_ = fake_T_pool.query(fake_T.detach())
        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) +
                          criterion_gan(netD_T(fake_T_), False))
        loss_D_T.backward()
        # single step after both discriminator backwards
        optimizer_D.step()

        # update statistics (every meter is weighted by the source batch size)
        losses_G_S2T.update(loss_G_S2T.item(), real_S.size(0))
        losses_G_T2S.update(loss_G_T2S.item(), real_S.size(0))
        losses_D_S.update(loss_D_S.item(), real_S.size(0))
        losses_D_T.update(loss_D_T.item(), real_S.size(0))
        losses_cycle_S.update(loss_cycle_S.item(), real_S.size(0))
        losses_cycle_T.update(loss_cycle_T.item(), real_S.size(0))
        losses_identity_S.update(loss_identity_S.item(), real_S.size(0))
        losses_identity_T.update(loss_identity_T.item(), real_S.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

            # hand the first sample of each intermediate batch to the
            # visualisation callback, tagged "<iteration>_<name>"
            for tensor, name in zip([
                    real_S, real_T, fake_S, fake_T, rec_S, rec_T, identity_S,
                    identity_T
            ], [
                    "real_S", "real_T", "fake_S", "fake_T", "rec_S", "rec_T",
                    "identity_S", "identity_T"
            ]):
                visualize(tensor[0], "{}_{}".format(i, name))
Exemple #18
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          teacher: EmaTeacher, cluster_assignment: ASoftmax,
          cluster_distribution: ClusterDistribution, consistent_loss,
          class_balance_loss, kl_loss: nn.KLDivLoss, conditional_loss,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int,
          args: argparse.Namespace):
    """Train for one epoch with source cross-entropy plus clustering terms.

    Source and target batches share one forward pass. The loss is the sum of:

    * cross-entropy on the source logits,
    * ``conditional_loss`` on the target logits,
    * ``kl_loss(cluster_assignment(f), cluster_distribution.calculate(f))``
      over the features of both domains.

    Target labels are used only to report accuracy.

    ``teacher``, ``consistent_loss`` and ``class_balance_loss`` are accepted
    for interface compatibility but are not used by this loop.

    Args:
        train_source_iter: iterator yielding labelled source batches.
        train_target_iter: iterator yielding target batches; their labels are
            used for monitoring only.
        model: network returning ``(logits, features)``.
        cluster_assignment: module applied to the features.
        cluster_distribution: object whose ``calculate(f)`` gives the KL
            target.
        kl_loss: KL-divergence criterion.
        conditional_loss: callable applied to the target logits.
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: scheduler stepped once per iteration.
        epoch: epoch index, used only in the progress-bar prefix.
        args: must provide ``iters_per_epoch`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    # No domain-accuracy meter: this loop has no domain discriminator. The
    # previous reference to an undefined ``domain_accs`` raised NameError.
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output for both domains in a single pass
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        cdt_loss = conditional_loss(y_t)
        # clustering KL uses the features of both domains together
        kls_loss = kl_loss(cluster_assignment(f),
                           cluster_distribution.calculate(f))

        loss = cls_loss + cdt_loss + kls_loss

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: Regressor, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of source-only regression training.

    The gradient comes solely from the MSE on labelled source batches. MAE on
    both the source and the target batch is tracked for monitoring; target
    labels never enter the loss.

    Args:
        train_source_iter: iterator yielding ``(images, targets)`` source
            batches.
        train_target_iter: iterator yielding target batches (monitoring only).
        model: regressor whose forward returns ``(prediction, _)``.
        optimizer: optimizer zeroed at the start of each iteration and
            stepped once.
        lr_scheduler: scheduler stepped once per iteration.
        epoch: epoch index, used only in the progress-bar prefix.
        args: must provide ``iters_per_epoch`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    mse_meter = AverageMeter('MSE Loss', ':6.3f')
    src_mae_meter = AverageMeter('MAE Loss (s)', ':6.3f')
    tgt_mae_meter = AverageMeter('MAE Loss (t)', ':6.3f')
    tracked = [batch_time, data_time, mse_meter, src_mae_meter, tgt_mae_meter]
    progress = ProgressMeter(args.iters_per_epoch, tracked,
                             prefix="Epoch: [{}]".format(epoch))

    model.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        optimizer.zero_grad()

        src_images, src_targets = next(train_source_iter)
        src_images = src_images.to(device)
        src_targets = src_targets.to(device).float()
        tgt_images, tgt_targets = next(train_target_iter)
        tgt_images = tgt_images.to(device)
        tgt_targets = tgt_targets.to(device).float()

        data_time.update(time.time() - tic)

        src_pred, _ = model(src_images)
        tgt_pred, _ = model(tgt_images)

        mse = F.mse_loss(src_pred, src_targets)
        src_mae = F.l1_loss(src_pred, src_targets)
        tgt_mae = F.l1_loss(tgt_pred, tgt_targets)

        weight = src_images.size(0)
        mse_meter.update(mse.item(), weight)
        src_mae_meter.update(src_mae.item(), weight)
        tgt_mae_meter.update(tgt_mae.item(), weight)

        # only the source MSE drives the parameter update
        mse.backward()
        optimizer.step()
        lr_scheduler.step()

        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
Exemple #20
0
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, domain_adv: DomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of domain-adversarial training.

    Source and target batches are concatenated into a single forward pass.
    The objective is the source cross-entropy plus ``args.trade_off`` times
    the loss that ``domain_adv`` returns for the split features. Target
    labels are used only to report accuracy.

    Args:
        train_source_iter: iterator yielding labelled source batches.
        train_target_iter: iterator yielding target batches (labels used only
            for monitoring).
        model: classifier returning ``(logits, features)``.
        domain_adv: adversarial loss module exposing
            ``domain_discriminator_accuracy``.
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: scheduler stepped once per iteration.
        epoch: epoch index, used only in the progress-bar prefix.
        args: must provide ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    tracked = [batch_time, data_time, losses, cls_accs, tgt_accs, domain_accs]
    progress = ProgressMeter(args.iters_per_epoch, tracked,
                             prefix="Epoch: [{}]".format(epoch))

    for module in (model, domain_adv):
        module.train()

    last = time.time()
    for step in range(args.iters_per_epoch):
        src_x, src_y = next(train_source_iter)
        tgt_x, tgt_y = next(train_target_iter)
        src_x, tgt_x = src_x.to(device), tgt_x.to(device)
        src_y, tgt_y = src_y.to(device), tgt_y.to(device)

        data_time.update(time.time() - last)

        logits, feats = model(torch.cat((src_x, tgt_x), dim=0))
        src_logits, tgt_logits = logits.chunk(2, dim=0)
        src_feats, tgt_feats = feats.chunk(2, dim=0)

        cls_loss = F.cross_entropy(src_logits, src_y)
        transfer_loss = domain_adv(src_feats, tgt_feats)
        # read right after the adversarial forward above
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        src_acc = accuracy(src_logits, src_y)[0]
        tgt_acc = accuracy(tgt_logits, tgt_y)[0]

        weight = src_x.size(0)
        losses.update(loss.item(), weight)
        cls_accs.update(src_acc.item(), weight)
        tgt_accs.update(tgt_acc.item(), weight)
        domain_accs.update(domain_acc.item(), weight)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        batch_time.update(time.time() - last)
        last = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
Exemple #21
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_adv_D: DomainAdversarialLoss,
          domain_adv_D_0: DomainAdversarialLoss, importance_weight_module,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int,
          args: argparse.Namespace):
    """Train for one epoch with importance-weighted adversarial adaptation.

    Two adversarial losses are combined:

    * ``domain_adv_D`` on *detached* features, so it does not back-propagate
      into ``model``;
    * ``domain_adv_D_0`` on live features, with per-sample source weights
      ``w_s`` from ``importance_weight_module.get_importance_weight``.

    The objective is::

        cls_loss + 1.5 * trade_off * adv_loss_D + trade_off * adv_loss_D_0
                 + gamma * entropy(softmax(y_t))

    Args:
        train_source_iter: iterator yielding labelled source batches.
        train_target_iter: iterator yielding target batches (labels used only
            for monitoring).
        model: classifier returning ``(logits, features)``.
        domain_adv_D: adversarial loss on detached features.
        domain_adv_D_0: importance-weighted adversarial loss.
        importance_weight_module: provides ``get_importance_weight`` and
            ``get_partial_classes_weight``.
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: scheduler stepped once per iteration.
        epoch: epoch index, used only in the progress-bar prefix.
        args: must provide ``iters_per_epoch``, ``trade_off``, ``gamma`` and
            ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    domain_accs_D = AverageMeter('Domain Acc for D', ':3.1f')
    domain_accs_D_0 = AverageMeter('Domain Acc for D_0', ':3.1f')
    partial_classes_weights = AverageMeter('Partial Weight', ':3.2f')
    non_partial_classes_weights = AverageMeter('Non-Partial Weight', ':3.2f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses, cls_accs, tgt_accs, domain_accs_D,
        domain_accs_D_0, partial_classes_weights, non_partial_classes_weights
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv_D.train()
    domain_adv_D_0.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        # classification loss
        cls_loss = F.cross_entropy(y_s, labels_s)

        # domain adversarial loss for D (features detached, so this term
        # produces no gradient for the feature extractor)
        adv_loss_D = domain_adv_D(f_s.detach(), f_t.detach())

        # get importance weights
        w_s = importance_weight_module.get_importance_weight(f_s)
        # domain adversarial loss for D_0
        adv_loss_D_0 = domain_adv_D_0(f_s, f_t, w_s=w_s)

        # entropy loss on target predictions
        y_t = F.softmax(y_t, dim=1)
        entropy_loss = entropy(y_t, reduction='mean')

        loss = cls_loss + 1.5 * args.trade_off * adv_loss_D + \
               args.trade_off * adv_loss_D_0 + args.gamma * entropy_loss

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # softmax preserves the argmax, so accuracy on probabilities equals
        # accuracy on logits
        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))
        # ``domain_discriminator_accuracy`` may be a one-element tensor rather
        # than a Python number. Convert it so these meters accumulate plain
        # floats like every other meter in this loop; ``float()`` accepts both.
        domain_accs_D.update(float(domain_adv_D.domain_discriminator_accuracy),
                             x_s.size(0))
        domain_accs_D_0.update(
            float(domain_adv_D_0.domain_discriminator_accuracy), x_s.size(0))

        # debug: output class weight averaged on the partial classes and non-partial classes respectively
        partial_class_weight, non_partial_classes_weight = \
            importance_weight_module.get_partial_classes_weight(w_s, labels_s)
        partial_classes_weights.update(partial_class_weight.item(),
                                       x_s.size(0))
        non_partial_classes_weights.update(non_partial_classes_weight.item(),
                                           x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def train(train_iter: ForeverDataIterator, bituning: BiTuning, optimizer: SGD,
          epoch: int, args: argparse.Namespace):
    """Fine-tune with Bi-Tuning for one epoch.

    Each batch from ``train_iter`` carries two views, ``x[0]`` (query) and
    ``x[1]`` (key). The loss is cross-entropy on the classifier output plus
    ``args.trade_off`` times the sum of two batch-mean KL-divergence terms
    between the contrastive logits ``bituning`` returns and its target
    distribution.

    Args:
        train_iter: iterator yielding ``(views, labels)``.
        bituning: module returning ``(y, logits_z, logits_y, bituning_labels)``.
        optimizer: optimizer stepped once per iteration.
        epoch: epoch index, used only in the progress-bar prefix.
        args: must provide ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    contrastive_losses = AverageMeter('Contrastive Loss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, contrastive_losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    ce = torch.nn.CrossEntropyLoss().to(device)
    kl = torch.nn.KLDivLoss(reduction='batchmean').to(device)

    bituning.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        views, labels = next(train_iter)
        query = views[0].to(device)
        key = views[1].to(device)
        labels = labels.to(device)

        data_time.update(time.time() - tick)

        logits, logits_z, logits_y, bituning_labels = bituning(query, key, labels)
        cls_loss = ce(logits, labels)
        contrastive_loss = kl(logits_z, bituning_labels) + kl(logits_y, bituning_labels)
        loss = cls_loss + contrastive_loss * args.trade_off

        weight = views[0].size(0)
        losses.update(loss.item(), weight)
        cls_losses.update(cls_loss.item(), weight)
        contrastive_losses.update(contrastive_loss.item(), weight)
        cls_accs.update(accuracy(logits, labels)[0].item(), weight)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_discri: DomainDiscriminator,
          domain_adv: DomainAdversarialLoss, gl, optimizer: SGD,
          lr_scheduler: LambdaLR, optimizer_d: SGD, lr_scheduler_d: LambdaLR,
          epoch: int, args: argparse.Namespace):
    """Alternate classifier and discriminator updates for one epoch.

    Each iteration has two steps:

    1. Train ``model`` with the discriminator frozen. The loss is the source
       cross-entropy plus ``args.trade_off`` times an adversarial term that
       scores the discriminator against *swapped* domain tags (source ->
       'target', target -> 'source'), so minimising it fools the discriminator.
    2. Train ``domain_discri`` with the model frozen, on detached features and
       the correct domain tags.

    Each step has its own optimizer and LR scheduler.

    Args:
        train_source_iter: iterator yielding labelled source batches.
        train_target_iter: iterator yielding target batches; labels ignored.
        model: classifier returning ``(logits, features)``.
        domain_discri: discriminator applied to features.
        domain_adv: criterion called as ``domain_adv(pred, 'source'|'target')``.
        gl: callable applied to the features before the discriminator in
            step 1 only. Its exact role is defined elsewhere; TODO confirm.
        optimizer, lr_scheduler: step 1 (model) optimizer and scheduler.
        optimizer_d, lr_scheduler_d: step 2 (discriminator) optimizer and
            scheduler.
        epoch: epoch index, used only in the progress-bar prefix.
        args: must provide ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.

    NOTE(review): step 2 calls ``model.eval()`` and nothing restores train
    mode after the loop, so ``model`` is left in eval mode when this returns.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses_s = AverageMeter('Cls Loss', ':6.2f')
    losses_transfer = AverageMeter('Transfer Loss', ':6.2f')
    losses_discriminator = AverageMeter('Discriminator Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_s, losses_transfer, losses_discriminator,
        cls_accs, domain_accs
    ],
                             prefix="Epoch: [{}]".format(epoch))

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Step 1: Train the classifier, freeze the discriminator
        model.train()
        domain_discri.eval()
        set_requires_grad(model, True)
        set_requires_grad(domain_discri, False)
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        loss_s = F.cross_entropy(y_s, labels_s)

        # adversarial training to fool the discriminator: the domain tags are
        # deliberately swapped relative to step 2
        d = domain_discri(gl(f))
        d_s, d_t = d.chunk(2, dim=0)
        loss_transfer = 0.5 * (domain_adv(d_s, 'target') +
                               domain_adv(d_t, 'source'))

        optimizer.zero_grad()
        (loss_s + loss_transfer * args.trade_off).backward()
        optimizer.step()
        lr_scheduler.step()

        # Step 2: Train the discriminator on detached features (no gradient
        # reaches the model), using the true domain tags
        model.eval()
        domain_discri.train()
        set_requires_grad(model, False)
        set_requires_grad(domain_discri, True)
        d = domain_discri(f.detach())
        d_s, d_t = d.chunk(2, dim=0)
        loss_discriminator = 0.5 * (domain_adv(d_s, 'source') +
                                    domain_adv(d_t, 'target'))

        optimizer_d.zero_grad()
        loss_discriminator.backward()
        optimizer_d.step()
        lr_scheduler_d.step()

        losses_s.update(loss_s.item(), x_s.size(0))
        losses_transfer.update(loss_transfer.item(), x_s.size(0))
        losses_discriminator.update(loss_discriminator.item(), x_s.size(0))

        cls_acc = accuracy(y_s, labels_s)[0]
        cls_accs.update(cls_acc.item(), x_s.size(0))
        # d_s / d_t here are the step-2 outputs; accuracy treats source as 1
        # and target as 0
        domain_acc = 0.5 * (binary_accuracy(d_s, torch.ones_like(d_s)) +
                            binary_accuracy(d_t, torch.zeros_like(d_t)))
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model, optimizer: Adam,
          xbm: XBM, epoch: int, args: argparse.Namespace):
    """Train a re-identification model for one epoch with bridging losses.

    Setup:

    * Source and target batches are regrouped into ``device_num`` chunks, each
      holding its share of source samples followed by its share of target
      samples, before a single forward pass (for domain-specific BN).
    * The model's features are split per chunk into three equal parts:
      source, target and mixed.

    The loss combines cross-entropy, a bridge probability loss, in-batch and
    memory-bank (``xbm``) triplet losses, a bridge feature loss and a
    diversity loss, weighted by ``args.mu1`` / ``mu2`` / ``mu3``.

    Args:
        train_source_iter: iterator yielding ``(x, _, labels, _)`` source
            tuples.
        train_target_iter: iterator yielding ``(x, _, labels, _)`` target
            tuples. Presumably these are pseudo labels; TODO confirm.
        model: network called as ``model(x, stage=args.stage)`` returning
            ``(logits, features, attention_lam)``.
        optimizer: optimizer stepped once per iteration.
        xbm: memory bank updated with ``enqueue_dequeue`` and read with
            ``get``.
        epoch: epoch index, used only in the progress-bar prefix.
        args: must provide ``n_s_classes``, ``n_t_classes``, ``margin``,
            ``stage``, ``mu1``, ``mu2``, ``mu3``, ``iters_per_epoch`` and
            ``print_freq``.
    """
    # define loss function
    criterion_ce = BridgeProbLoss(args.n_s_classes +
                                  args.n_t_classes).to(device)
    criterion_triplet = TripletLoss(margin=args.margin).to(device)
    criterion_triplet_xbm = TripletLossXBM(margin=args.margin).to(device)
    criterion_bridge_feat = BridgeFeatLoss().to(device)
    criterion_diverse = DivLoss().to(device)

    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_ce = AverageMeter('CeLoss', ':3.2f')
    losses_triplet = AverageMeter('TripletLoss', ':3.2f')
    losses_triplet_xbm = AverageMeter('XBMTripletLoss', ':3.2f')
    losses_bridge_prob = AverageMeter('BridgeProbLoss', ':3.2f')
    losses_bridge_feat = AverageMeter('BridgeFeatLoss', ':3.2f')
    losses_diverse = AverageMeter('DiverseLoss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')

    cls_accs_s = AverageMeter('Src Cls Acc', ':3.1f')
    cls_accs_t = AverageMeter('Tgt Cls Acc', ':3.1f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_ce, losses_triplet, losses_triplet_xbm,
        losses_bridge_prob, losses_bridge_feat, losses_diverse, losses,
        cls_accs_s, cls_accs_t
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    # Number of chunks the batch is regrouped into for domain-specific BN.
    # ``torch.cuda.device_count()`` is 0 on CPU-only machines, which made
    # ``view(0, -1, ...)`` below fail, so treat that case as a single chunk.
    # The value cannot change between iterations, so it is computed once.
    device_num = max(1, torch.cuda.device_count())

    end = time.time()

    for i in range(args.iters_per_epoch):
        x_s, _, labels_s, _ = next(train_source_iter)
        x_t, _, labels_t, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # arrange batch for domain-specific BN; assumes x_t has the same shape
        # as x_s and B is divisible by device_num
        B, C, H, W = x_s.size()

        def reshape(tensor):
            return tensor.view(device_num, -1, C, H, W)

        x_s, x_t = reshape(x_s), reshape(x_t)
        x = torch.cat((x_s, x_t), 1).view(-1, C, H, W)

        labels = torch.cat(
            (labels_s.view(device_num, -1), labels_t.view(device_num, -1)), 1)
        labels = labels.view(-1)

        # compute output
        y, f, attention_lam = model(x, stage=args.stage)
        y = y[:, 0:args.n_s_classes + args.
              n_t_classes]  # only (n_s_classes + n_t_classes) classes are meaningful

        # split feats into the per-chunk source / target / mixed thirds
        ori_f = f.view(device_num, -1, f.size(-1))
        f_s, f_t, f_mixed = ori_f.split(ori_f.size(1) // 3, dim=1)
        ori_f = torch.cat((f_s, f_t), 1).view(-1, ori_f.size(-1))

        # cross entropy loss
        loss_ce, loss_bridge_prob = criterion_ce(y, labels,
                                                 attention_lam[:, 0].detach(),
                                                 device_num)
        # triplet loss
        loss_triplet = criterion_triplet(ori_f, labels)
        # diverse loss
        loss_diverse = criterion_diverse(attention_lam)
        # bridge feature loss
        f_s = f_s.contiguous().view(-1, f.size(-1))
        f_t = f_t.contiguous().view(-1, f.size(-1))
        f_mixed = f_mixed.contiguous().view(-1, f.size(-1))
        loss_bridge_feat = criterion_bridge_feat(f_s, f_t, f_mixed,
                                                 attention_lam)
        # xbm triplet loss: push the current (detached) features into the
        # memory bank, then compare live features against its contents
        xbm.enqueue_dequeue(ori_f.detach(), labels.detach())
        xbm_f, xbm_labels = xbm.get()
        loss_triplet_xbm = criterion_triplet_xbm(ori_f, labels, xbm_f,
                                                 xbm_labels)

        loss = (1. - args.mu1) * loss_ce + loss_triplet + loss_triplet_xbm + \
               args.mu1 * loss_bridge_prob + args.mu2 * loss_bridge_feat + args.mu3 * loss_diverse

        # compute gradient and take an optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        ori_y = y.view(device_num, -1, y.size(-1))
        y_s, y_t, _ = ori_y.split(ori_y.size(1) // 3, dim=1)
        cls_acc_s = accuracy(y_s.reshape(-1, y_s.size(-1)), labels_s)[0]
        cls_acc_t = accuracy(y_t.reshape(-1, y_t.size(-1)), labels_t)[0]

        # update statistics, weighted by the per-domain batch size ``B``;
        # after the regrouping above ``x_s.size(0)`` is ``device_num``, not
        # the number of samples
        losses_ce.update(loss_ce.item(), B)
        losses_triplet.update(loss_triplet.item(), B)
        losses_triplet_xbm.update(loss_triplet_xbm.item(), B)
        losses_bridge_prob.update(loss_bridge_prob.item(), B)
        losses_bridge_feat.update(loss_bridge_feat.item(), B)
        losses_diverse.update(loss_diverse.item(), B)
        losses.update(loss.item(), B)

        cls_accs_s.update(cls_acc_s.item(), B)
        cls_accs_t.update(cls_acc_t.item(), B)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
Exemple #25
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model,
          mdd: MarginDisparityDiscrepancy, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of margin-disparity-discrepancy training for regression.

    Source and target images go through ``model`` in one concatenated pass,
    which returns main-head and adversarial-head predictions. The objective is
    the source MSE minus ``args.trade_off`` times the discrepancy returned by
    ``mdd``. MAE on both domains is logged for monitoring only.

    Args:
        train_source_iter: iterator yielding labelled source batches.
        train_target_iter: iterator yielding target batches (labels used only
            for monitoring).
        model: network returning ``(outputs, outputs_adv)`` and exposing
            ``step()``.
        mdd: discrepancy module called on the four prediction tensors.
        optimizer: optimizer zeroed at the start of each iteration and
            stepped once.
        lr_scheduler: scheduler stepped once per iteration.
        epoch: epoch index, used only in the progress-bar prefix.
        args: must provide ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    src_mse_meter = AverageMeter('Source Loss', ':6.3f')
    transfer_meter = AverageMeter('Trans Loss', ':6.3f')
    src_mae_meter = AverageMeter('MAE Loss (s)', ':6.3f')
    tgt_mae_meter = AverageMeter('MAE Loss (t)', ':6.3f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, src_mse_meter, transfer_meter,
         src_mae_meter, tgt_mae_meter],
        prefix="Epoch: [{}]".format(epoch))

    model.train()
    mdd.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        optimizer.zero_grad()

        src_img, src_target = next(train_source_iter)
        src_img = src_img.to(device)
        src_target = src_target.to(device).float()
        tgt_img, tgt_target = next(train_target_iter)
        tgt_img = tgt_img.to(device)
        tgt_target = tgt_target.to(device).float()

        data_time.update(time.time() - tic)

        main_out, adv_out = model(torch.cat([src_img, tgt_img], dim=0))
        src_pred, tgt_pred = main_out.chunk(2, dim=0)
        src_pred_adv, tgt_pred_adv = adv_out.chunk(2, dim=0)

        mse = F.mse_loss(src_pred, src_target)
        discrepancy = mdd(src_pred, src_pred_adv, tgt_pred, tgt_pred_adv)
        # subtracting the discrepancy: minimising its negative maximises it
        # for the adversarial head
        loss = mse - discrepancy * args.trade_off
        # called once per iteration after the forward pass; what it advances
        # is defined on the model (TODO confirm)
        model.step()

        weight = src_img.size(0)
        src_mse_meter.update(mse.item(), weight)
        transfer_meter.update(discrepancy.item(), weight)
        src_mae_meter.update(F.l1_loss(src_pred, src_target).item(), weight)
        tgt_mae_meter.update(F.l1_loss(tgt_pred, tgt_target).item(), weight)

        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
Exemple #26
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          mkmmd_loss: MultipleKernelMaximumMeanDiscrepancy, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of source cross-entropy plus MK-MMD feature alignment.

    Source and target batches go through ``model`` in two separate forward
    passes, source first. Their features are aligned with ``mkmmd_loss``,
    weighted by ``args.trade_off``. Target labels are used only to report
    accuracy.

    Args:
        train_source_iter: iterator yielding labelled source batches.
        train_target_iter: iterator yielding target batches (labels used only
            for monitoring).
        model: classifier returning ``(logits, features)``.
        mkmmd_loss: discrepancy module applied to ``(f_s, f_t)``.
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: scheduler stepped once per iteration.
        epoch: epoch index, used only in the progress-bar prefix.
        args: must provide ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':5.4f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    model.train()
    mkmmd_loss.train()

    stamp = time.time()
    for step in range(args.iters_per_epoch):
        src_x, src_y = next(train_source_iter)
        tgt_x, tgt_y = next(train_target_iter)
        src_x, tgt_x = src_x.to(device), tgt_x.to(device)
        src_y, tgt_y = src_y.to(device), tgt_y.to(device)

        data_time.update(time.time() - stamp)

        src_logits, src_feat = model(src_x)
        tgt_logits, tgt_feat = model(tgt_x)

        cls_loss = F.cross_entropy(src_logits, src_y)
        transfer_loss = mkmmd_loss(src_feat, tgt_feat)
        loss = cls_loss + transfer_loss * args.trade_off

        src_acc = accuracy(src_logits, src_y)[0]
        tgt_acc = accuracy(tgt_logits, tgt_y)[0]

        losses.update(loss.item(), src_x.size(0))
        cls_accs.update(src_acc.item(), src_x.size(0))
        tgt_accs.update(tgt_acc.item(), tgt_x.size(0))
        trans_losses.update(transfer_loss.item(), src_x.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        batch_time.update(time.time() - stamp)
        stamp = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
def train(train_source_iter: ForeverDataIterator, model: Classifier,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int,
          args: argparse.Namespace):
    """Run one epoch of plain supervised (source-only) training.

    ``args.iters_per_epoch`` batches are drawn from ``train_source_iter``; for
    each one the cross-entropy loss of ``model`` is minimised with
    ``optimizer`` and ``lr_scheduler`` is stepped once.  Progress is printed
    every ``args.print_freq`` iterations.

    Args:
        train_source_iter: iterator whose ``next`` yields ``(inputs, labels)``.
        model: network returning a 2-tuple; the first item is used as logits,
            the second is ignored here.
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: scheduler stepped after every optimizer step.
        epoch: epoch index, used only in the progress prefix.
        args: must provide ``iters_per_epoch`` and ``print_freq``.
    """
    time_meter = AverageMeter('Time', ':4.2f')
    load_meter = AverageMeter('Data', ':3.1f')
    loss_meter = AverageMeter('Loss', ':3.2f')
    acc_meter = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch,
                             [time_meter, load_meter, loss_meter, acc_meter],
                             prefix="Epoch: [{}]".format(epoch))

    model.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        inputs, targets = next(train_source_iter)
        inputs = inputs.to(device)
        targets = targets.to(device)
        # time spent waiting for (and transferring) the batch
        load_meter.update(time.time() - tick)

        logits, _ = model(inputs)
        loss = F.cross_entropy(logits, targets)
        top1 = accuracy(logits, targets)[0]

        batch_size = inputs.size(0)
        loss_meter.update(loss.item(), batch_size)
        acc_meter.update(top1.item(), batch_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        time_meter.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
Exemple #28
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model, interp, criterion,
          dann, optimizer: SGD, lr_scheduler: LambdaLR, optimizer_d: SGD,
          lr_scheduler_d: LambdaLR, epoch: int, visualize,
          args: argparse.Namespace):
    """Train ``model`` for one epoch with source supervision plus an
    adversarial alignment loss computed on target predictions.

    Each iteration:
      1. with ``dann`` switched to eval mode, backpropagates the supervised
         ``criterion`` loss on the source batch and the adversarial loss
         ``dann(pred_t, 'source')`` (scaled by ``args.trade_off``);
      2. with ``dann`` back in train mode, backpropagates the discriminator
         loss on *detached* source/target predictions;
      3. steps both optimizers and both learning-rate schedulers.
    Pixel accuracy / IoU are logged for both domains.

    Args:
        train_source_iter: iterator yielding ``(image, label)`` source batches.
        train_target_iter: iterator yielding ``(image, label)`` target batches;
            target labels are used only for logging metrics.
        model: network being trained; must expose ``num_classes``.
        interp: applied to the raw output of ``model`` before any loss —
            presumably resizes logits to label resolution (TODO confirm).
        criterion: supervised loss on source predictions.
        dann: adversarial loss module, called as ``dann(pred, 'source')`` or
            ``dann(pred, 'target')``.
        optimizer, lr_scheduler: presumably update ``model`` (confirm with caller).
        optimizer_d, lr_scheduler_d: presumably update the discriminator held
            by ``dann`` (confirm with caller).
        epoch: epoch index, used only in the progress prefix.
        visualize: optional callable ``(image, pred, label, name)``; skipped
            when ``None``.
        args: must provide ``iters_per_epoch``, ``trade_off`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_transfer = AverageMeter('Loss (transfer)', ':3.2f')
    losses_discriminator = AverageMeter('Loss (discriminator)', ':3.2f')
    # NOTE(review): metrics use ``Meter`` rather than ``AverageMeter`` —
    # presumably it shows the latest value; confirm in its definition.
    accuracies_s = Meter('Acc (s)', ':3.2f')
    accuracies_t = Meter('Acc (t)', ':3.2f')
    iou_s = Meter('IoU (s)', ':3.2f')
    iou_t = Meter('IoU (t)', ':3.2f')

    # Created once per epoch and updated every iteration, so (assuming
    # ``update`` accumulates) the logged metrics cover all batches so far.
    confmat_s = ConfusionMatrix(model.num_classes)
    confmat_t = ConfusionMatrix(model.num_classes)
    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_s, losses_transfer, losses_discriminator,
        accuracies_s, accuracies_t, iou_s, iou_t
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()

    for i in range(args.iters_per_epoch):
        x_s, label_s = next(train_source_iter)
        x_t, label_t = next(train_target_iter)

        x_s = x_s.to(device)
        label_s = label_s.long().to(device)
        x_t = x_t.to(device)
        label_t = label_t.long().to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Gradients from both steps below accumulate; the optimizers only
        # step once all backward passes are done.
        optimizer.zero_grad()
        optimizer_d.zero_grad()

        # Step 1: Train the segmentation network, freeze the discriminator
        # NOTE(review): relies on ``dann.eval()`` stopping the discriminator's
        # parameters from receiving gradients here — confirm in its definition.
        dann.eval()
        y_s = model(x_s)
        pred_s = interp(y_s)
        loss_cls_s = criterion(pred_s, label_s)
        loss_cls_s.backward()

        # adversarial training to fool the discriminator
        # (target predictions are scored against the 'source' label)
        y_t = model(x_t)
        pred_t = interp(y_t)
        loss_transfer = dann(pred_t, 'source')
        (loss_transfer * args.trade_off).backward()

        # Step 2: Train the discriminator
        # detach() keeps this loss from flowing back into the segmentation net
        dann.train()
        loss_discriminator = 0.5 * (dann(pred_s.detach(), 'source') +
                                    dann(pred_t.detach(), 'target'))
        loss_discriminator.backward()

        # compute gradient and do SGD step
        optimizer.step()
        optimizer_d.step()
        lr_scheduler.step()
        lr_scheduler_d.step()

        # measure accuracy and record loss
        losses_s.update(loss_cls_s.item(), x_s.size(0))
        losses_transfer.update(loss_transfer.item(), x_s.size(0))
        losses_discriminator.update(loss_discriminator.item(), x_s.size(0))

        # argmax over dim 1 turns per-class scores into predicted class ids
        confmat_s.update(label_s.flatten(), pred_s.argmax(1).flatten())
        confmat_t.update(label_t.flatten(), pred_t.argmax(1).flatten())
        acc_global_s, acc_s, iu_s = confmat_s.compute()
        acc_global_t, acc_t, iu_t = confmat_t.compute()
        accuracies_s.update(acc_s.mean().item())
        accuracies_t.update(acc_t.mean().item())
        iou_s.update(iu_s.mean().item())
        iou_t.update(iu_t.mean().item())

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

            # only the first sample of each batch is visualized
            if visualize is not None:
                visualize(x_s[0], pred_s[0], label_s[0], "source_{}".format(i))
                visualize(x_t[0], pred_t[0], label_t[0], "target_{}".format(i))
def train(train_iter: ForeverDataIterator, model,
          criterion_ce: CrossEntropyLossWithLabelSmooth,
          criterion_triplet: SoftTripletLoss, optimizer: Adam, epoch: int,
          args: argparse.Namespace):
    """Run one epoch of training on a cross-entropy + triplet objective.

    For each of ``args.iters_per_epoch`` iterations a batch is drawn from
    ``train_iter`` and ``model`` is updated on
    ``criterion_ce(logits, labels) + criterion_triplet(f, f, labels) * args.trade_off``.
    No learning-rate scheduler is stepped inside this function.

    Args:
        train_iter: iterator whose ``next`` yields a 4-tuple; only the first
            (inputs) and third (labels) items are used.
        model: network returning a 2-tuple; the first item is scored by
            ``criterion_ce`` / ``accuracy``, the second by ``criterion_triplet``.
        criterion_ce: classification loss.
        criterion_triplet: metric loss; receives the same feature batch as
            both of its first two arguments.
        optimizer: optimizer stepped once per iteration.
        epoch: epoch index, used only in the progress prefix.
        args: must provide ``iters_per_epoch``, ``trade_off`` and ``print_freq``.
    """
    meter_time = AverageMeter('Time', ':4.2f')
    meter_data = AverageMeter('Data', ':3.1f')
    meter_ce = AverageMeter('CeLoss', ':3.2f')
    meter_triplet = AverageMeter('TripletLoss', ':3.2f')
    meter_total = AverageMeter('Loss', ':3.2f')
    meter_acc = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [meter_time, meter_data, meter_ce, meter_triplet, meter_total,
         meter_acc],
        prefix="Epoch: [{}]".format(epoch))

    model.train()

    last = time.time()
    for step in range(args.iters_per_epoch):
        inputs, _, targets, _ = next(train_iter)
        inputs = inputs.to(device)
        targets = targets.to(device)
        # time spent waiting for (and transferring) the batch
        meter_data.update(time.time() - last)

        logits, embeddings = model(inputs)
        ce_loss = criterion_ce(logits, targets)
        triplet_loss = criterion_triplet(embeddings, embeddings, targets)
        total_loss = ce_loss + triplet_loss * args.trade_off

        top1 = accuracy(logits, targets)[0]
        n = inputs.size(0)
        meter_ce.update(ce_loss.item(), n)
        meter_triplet.update(triplet_loss.item(), n)
        meter_total.update(total_loss.item(), n)
        meter_acc.update(top1.item(), n)

        # one step of the supplied optimizer (annotated as Adam)
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        meter_time.update(time.time() - last)
        last = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
Exemple #30
0
def calculate_channel_attention(dataset, return_layers, args):
    """Estimate per-channel attention weights for the given layers.

    A classifier (pretrained ``args.arch`` backbone, frozen, plus a new head)
    is first trained on ``dataset`` for ``args.attention_epochs`` epochs.
    Then, for up to ``args.attention_iteration_limit`` batches (all batches
    when the limit is not positive), every output channel of every layer in
    ``return_layers`` is zeroed in turn.  The resulting increase in
    cross-entropy loss is recorded.  Per-channel increases are averaged over
    batches, standardised within each layer, divided by 5 and passed through a
    softmax.

    Args:
        dataset: dataset with a ``num_classes`` attribute yielding
            ``(input, label)`` pairs.
        return_layers: dotted attribute paths (relative to the classifier) of
            layers that expose ``out_channels`` and a ``weight`` parameter.
        args: namespace providing ``arch``, ``lr``, ``momentum``, ``wd``,
            ``attention_batch_size``, ``workers``,
            ``attention_lr_decay_epochs``, ``attention_epochs``,
            ``print_freq`` and ``attention_iteration_limit``.

    Returns:
        list with one 1-D float tensor on ``device`` per entry of
        ``return_layers``; each tensor sums to 1.
    """
    backbone = models.__dict__[args.arch](pretrained=True)
    classifier = Classifier(backbone, dataset.num_classes).to(device)
    optimizer = SGD(classifier.get_parameters(args.lr),
                    momentum=args.momentum,
                    weight_decay=args.wd,
                    nesterov=True)
    data_loader = DataLoader(dataset,
                             batch_size=args.attention_batch_size,
                             shuffle=True,
                             num_workers=args.workers,
                             drop_last=False)
    # gamma chosen so the lr shrinks 10x every `attention_lr_decay_epochs` epochs
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer,
        gamma=math.exp(math.log(0.1) / args.attention_lr_decay_epochs))
    criterion = nn.CrossEntropyLoss()

    layers = [get_attribute(classifier, name) for name in return_layers]
    # running mean of the loss increase caused by removing each channel
    channel_weights = [[0.0] * layer.out_channels for layer in layers]

    # train the classifier
    classifier.train()
    # Assigning `module.requires_grad = False` only creates a plain attribute
    # and freezes nothing; `requires_grad_` flags every parameter.
    classifier.backbone.requires_grad_(False)
    print("Pretrain a classifier to calculate channel attention.")
    for epoch in range(args.attention_epochs):
        losses = AverageMeter('Loss', ':3.2f')
        cls_accs = AverageMeter('Cls Acc', ':3.1f')
        progress = ProgressMeter(len(data_loader), [losses, cls_accs],
                                 prefix="Epoch: [{}]".format(epoch))

        for i, (inputs, labels) in enumerate(data_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs, _ = classifier(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            cls_acc = accuracy(outputs, labels)[0]

            losses.update(loss.item(), inputs.size(0))
            cls_accs.update(cls_acc.item(), inputs.size(0))

            if i % args.print_freq == 0:
                progress.display(i)
        lr_scheduler.step()

    # calculate the channel attention
    print('Calculating channel attention.')
    classifier.eval()
    if args.attention_iteration_limit > 0:
        total_iteration = min(len(data_loader), args.attention_iteration_limit)
    else:
        total_iteration = len(data_loader)

    progress = ProgressMeter(total_iteration, [], prefix="Iteration: ")

    # No gradients are needed below.  no_grad avoids building autograd graphs
    # for every forward pass and permits in-place edits of trainable weights.
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(data_loader):
            if i >= total_iteration:
                break
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs, _ = classifier(inputs)
            loss_0 = criterion(outputs, labels)
            progress.display(i)
            for layer, layer_weights in zip(tqdm(layers), channel_weights):
                # same tensor as state_dict()[name + '.weight'], fetched once
                weight = layer.weight
                for j in range(layer.out_channels):
                    # temporarily remove channel j and measure the loss change
                    saved = weight[j].clone()
                    weight[j] = 0.0
                    outputs, _ = classifier(inputs)
                    loss_1 = criterion(outputs, labels)
                    weight[j] = saved
                    difference = (loss_1 - loss_0).item()
                    # incremental mean over the i + 1 batches seen so far
                    layer_weights[j] = (i * layer_weights[j] +
                                        difference) / (i + 1)

    channel_attention = []
    for layer_weights in channel_weights:
        weight = np.array(layer_weights, dtype=np.float64)
        std = np.std(weight)
        weight = weight - np.mean(weight)
        # A zero spread (e.g. a single channel) would otherwise give NaNs;
        # identical channels then simply receive uniform attention.
        if std > 0:
            weight = weight / std
        weight = torch.from_numpy(weight).float().to(device)
        # dim=0 matches the implicit choice for 1-D input and silences the
        # deprecation warning; the division by 5 softens the distribution.
        channel_attention.append(F.softmax(weight / 5, dim=0))
    return channel_attention