def train(train_source_iter, train_target_iter, model, criterion, optimizer,
          epoch: int, visualize, args: argparse.Namespace):
    """Run one epoch of supervised training on source-domain batches.

    Args:
        train_source_iter: iterator yielding ``(x, label, weight, meta)``
            batches; ``meta['keypoint2d']`` is read for visualization.
        train_target_iter: not used by this function.
        model: network invoked as ``model(x)``.
        criterion: loss invoked as ``criterion(y, label, weight)``.
        optimizer: optimizer zeroed and stepped once per iteration.
        epoch: epoch index, used only in the progress prefix.
        visualize: optional callable ``visualize(image, keypoints, name)``;
            nothing is drawn when it is ``None``.
        args: namespace providing ``iters_per_epoch``, ``print_freq``,
            ``image_size`` and ``heatmap_size``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ":.2e")
    acc_s = AverageMeter("Acc (s)", ":3.2f")

    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses_s, acc_s],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, label_s, weight_s, meta_s = next(train_source_iter)

        x_s = x_s.to(device)
        label_s = label_s.to(device)
        weight_s = weight_s.to(device)

        # measure data loading time (includes the host-to-device copies above)
        data_time.update(time.time() - end)

        # compute output
        y_s = model(x_s)
        loss_s = criterion(y_s, label_s, weight_s)

        # compute gradient and do SGD step
        loss_s.backward()
        optimizer.step()

        # measure accuracy and record loss
        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),
                                               label_s.detach().cpu().numpy())
        acc_s.update(avg_acc_s, cnt_s)
        # Record a plain Python float: handing the meter the loss tensor would
        # make its running statistics hold tensor/autograd references across
        # iterations.
        losses_s.update(loss_s.item(), cnt_s)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                # Predictions are scaled by image_size / heatmap_size before
                # drawing (presumably heatmap -> image coordinates).
                visualize(x_s[0],
                          pred_s[0] * args.image_size / args.heatmap_size,
                          "source_{}_pred.jpg".format(i))
                visualize(x_s[0], meta_s['keypoint2d'][0],
                          "source_{}_label.jpg".format(i))
Exemplo n.º 2
0
def validate(val_loader, model, criterion, visualize,
             args: argparse.Namespace):
    """Evaluate ``model`` over every batch of ``val_loader`` without gradients.

    Args:
        val_loader: iterable of ``(x, label, weight, meta)`` batches whose
            ``dataset`` exposes ``keypoints_group`` and ``group_accuracy``.
        model: network invoked as ``model(x)``; put into eval mode here.
        criterion: loss invoked as ``criterion(y, label, weight)``.
        visualize: optional callable ``visualize(image, keypoints, name)``;
            nothing is drawn when it is ``None``.
        args: namespace providing ``print_freq``, ``image_size`` and
            ``heatmap_size``.

    Returns:
        The result of ``average()`` on the per-group accuracy meter.
    """
    dataset = val_loader.dataset

    timer = AverageMeter('Time', ':6.3f')
    loss_meter = AverageMeter('Loss', ':.2e')
    group_meter = AverageMeterDict(dataset.keypoints_group.keys(), ":3.2f")
    progress = ProgressMeter(len(val_loader),
                             [timer, loss_meter, group_meter['all']],
                             prefix='Test: ')

    # Evaluation mode for the whole pass.
    model.eval()

    with torch.no_grad():
        end = time.time()
        for step, batch in enumerate(val_loader):
            inputs, targets, target_weights, meta = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
            target_weights = target_weights.to(device)
            batch_size = inputs.size(0)

            # Forward pass and loss bookkeeping.
            outputs = model(inputs)
            batch_loss = criterion(outputs, targets, target_weights)
            loss_meter.update(batch_loss.item(), batch_size)

            # Per-keypoint accuracy, folded into the dataset's groups.
            per_point_acc, _, _, predicted = accuracy(outputs.cpu().numpy(),
                                                      targets.cpu().numpy())
            group_meter.update(dataset.group_accuracy(per_point_acc),
                               batch_size)

            # Wall-clock time for this batch.
            timer.update(time.time() - end)
            end = time.time()

            if step % args.print_freq != 0:
                continue

            progress.display(step)
            if visualize is None:
                continue

            # Draw the first sample: prediction first, then its label.
            first_image = inputs[0]
            visualize(first_image,
                      predicted[0] * args.image_size / args.heatmap_size,
                      "val_{}_pred.jpg".format(step))
            visualize(first_image, meta['keypoint2d'][0],
                      "val_{}_label.jpg".format(step))

    return group_meter.average()
Exemplo n.º 3
0
def train(train_source_iter, train_target_iter, model, criterion,
          regression_disparity, optimizer_f, optimizer_h, optimizer_h_adv,
          lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv, epoch: int,
          visualize, args: argparse.Namespace):
    """Run one epoch of three-step adversarial training on source and target.

    Each iteration performs:
      A. one update of all three optimizers on the supervised source loss plus
         a ``mode='min'`` disparity term between the two model outputs;
      B. one update of ``optimizer_h_adv`` on the ``mode='max'`` disparity
         computed on the target batch;
      C. one update of ``optimizer_f`` on the ``mode='min'`` disparity
         computed on the target batch.
    Afterwards ``model.step()`` and all three LR schedulers are advanced.

    Args:
        train_source_iter: iterator yielding ``(x, label, weight, meta)``
            source batches.
        train_target_iter: iterator yielding ``(x, label, weight, meta)``
            target batches; target labels are used only for accuracy logging.
        model: network invoked as ``model(x)`` and returning a pair
            ``(y, y_adv)``; must also provide ``step()``.
        criterion: loss invoked as ``criterion(y, label, weight)``.
        regression_disparity: callable
            ``regression_disparity(y, y_adv, weight, mode=...)``.
        optimizer_f: optimizer stepped in steps A and C.
        optimizer_h: optimizer stepped in step A only.
        optimizer_h_adv: optimizer stepped in steps A and B.
        lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv: schedulers
            stepped once per iteration.
        epoch: epoch index, used only in the progress prefix.
        visualize: optional callable ``visualize(image, keypoints, name)``;
            nothing is drawn when it is ``None``.
        args: namespace providing ``iters_per_epoch``, ``print_freq``,
            ``margin``, ``trade_off``, ``image_size`` and ``heatmap_size``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ":.2e")
    losses_gf = AverageMeter('Loss (t, false)', ":.2e")
    losses_gt = AverageMeter('Loss (t, truth)', ":.2e")
    acc_s = AverageMeter("Acc (s)", ":3.2f")
    acc_t = AverageMeter("Acc (t)", ":3.2f")
    acc_s_adv = AverageMeter("Acc (s, adv)", ":3.2f")
    acc_t_adv = AverageMeter("Acc (t, adv)", ":3.2f")

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_s, losses_gf, losses_gt, acc_s, acc_t,
        acc_s_adv, acc_t_adv
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, label_s, weight_s, meta_s = next(train_source_iter)
        x_t, label_t, weight_t, meta_t = next(train_target_iter)

        x_s = x_s.to(device)
        label_s = label_s.to(device)
        weight_s = weight_s.to(device)

        x_t = x_t.to(device)
        label_t = label_t.to(device)
        weight_t = weight_t.to(device)

        # measure data loading time (includes the host-to-device copies above)
        data_time.update(time.time() - end)

        # Step A train all networks to minimize loss on source domain
        optimizer_f.zero_grad()
        optimizer_h.zero_grad()
        optimizer_h_adv.zero_grad()

        y_s, y_s_adv = model(x_s)
        loss_s = criterion(y_s, label_s, weight_s) + \
                 args.margin * args.trade_off * regression_disparity(y_s, y_s_adv, weight_s, mode='min')
        loss_s.backward()
        optimizer_f.step()
        optimizer_h.step()
        optimizer_h_adv.step()

        # Step B train adv regressor to maximize regression disparity
        optimizer_h_adv.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_false = args.trade_off * regression_disparity(
            y_t, y_t_adv, weight_t, mode='max')
        loss_ground_false.backward()
        optimizer_h_adv.step()

        # Step C train feature extractor to minimize regression disparity
        # (fresh forward pass: parameters changed in step B)
        optimizer_f.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_truth = args.trade_off * regression_disparity(
            y_t, y_t_adv, weight_t, mode='min')
        loss_ground_truth.backward()
        optimizer_f.step()

        # do update step
        model.step()
        lr_scheduler_f.step()
        lr_scheduler_h.step()
        lr_scheduler_h_adv.step()

        # measure accuracy; every meter is weighted by the count returned from
        # its own accuracy() call
        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),
                                               label_s.detach().cpu().numpy())
        acc_s.update(avg_acc_s, cnt_s)
        _, avg_acc_t, cnt_t, pred_t = accuracy(y_t.detach().cpu().numpy(),
                                               label_t.detach().cpu().numpy())
        acc_t.update(avg_acc_t, cnt_t)
        _, avg_acc_s_adv, cnt_s_adv, pred_s_adv = accuracy(
            y_s_adv.detach().cpu().numpy(),
            label_s.detach().cpu().numpy())
        acc_s_adv.update(avg_acc_s_adv, cnt_s_adv)
        _, avg_acc_t_adv, cnt_t_adv, pred_t_adv = accuracy(
            y_t_adv.detach().cpu().numpy(),
            label_t.detach().cpu().numpy())
        acc_t_adv.update(avg_acc_t_adv, cnt_t_adv)

        # record losses as plain Python floats so the meters never hold
        # tensor/autograd references across iterations; the two disparity
        # losses come from the target batch, so they use the target count
        losses_s.update(loss_s.item(), cnt_s)
        losses_gf.update(loss_ground_false.item(), cnt_t)
        losses_gt.update(loss_ground_truth.item(), cnt_t)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                # Predictions are scaled by image_size / heatmap_size before
                # drawing (presumably heatmap -> image coordinates).
                visualize(x_s[0],
                          pred_s[0] * args.image_size / args.heatmap_size,
                          "source_{}_pred".format(i))
                visualize(x_s[0], meta_s['keypoint2d'][0],
                          "source_{}_label".format(i))
                visualize(x_t[0],
                          pred_t[0] * args.image_size / args.heatmap_size,
                          "target_{}_pred".format(i))
                visualize(x_t[0], meta_t['keypoint2d'][0],
                          "target_{}_label".format(i))
                visualize(x_s[0],
                          pred_s_adv[0] * args.image_size / args.heatmap_size,
                          "source_adv_{}_pred".format(i))
                visualize(x_t[0],
                          pred_t_adv[0] * args.image_size / args.heatmap_size,
                          "target_adv_{}_pred".format(i))