def train(train_source_iter, train_target_iter, model, criterion, optimizer,
          epoch: int, visualize, args: argparse.Namespace):
    """Train ``model`` for one epoch on the source domain only.

    Args:
        train_source_iter: iterator yielding ``(x, label, weight, meta)``
            batches from the source domain.
        train_target_iter: unused in this source-only loop; kept for
            signature parity with the adversarial training routine.
        model: network mapping images to heatmap predictions.
        criterion: loss callable taking ``(output, label, weight)``.
        optimizer: optimizer updating ``model``'s parameters.
        epoch: current epoch index (used only in the progress prefix).
        visualize: optional callable ``(image, keypoints, filename)`` that
            dumps qualitative results; skipped when ``None``.
        args: must provide ``iters_per_epoch``, ``print_freq``,
            ``image_size`` and ``heatmap_size``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ":.2e")
    acc_s = AverageMeter("Acc (s)", ":3.2f")

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_s, acc_s],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, label_s, weight_s, meta_s = next(train_source_iter)
        x_s = x_s.to(device)
        label_s = label_s.to(device)
        weight_s = weight_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s = model(x_s)
        loss_s = criterion(y_s, label_s, weight_s)

        # compute gradient and do SGD step
        loss_s.backward()
        optimizer.step()

        # measure accuracy and record loss
        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),
                                               label_s.detach().cpu().numpy())
        acc_s.update(avg_acc_s, cnt_s)
        # FIX: record the Python scalar, not the tensor — storing the tensor
        # keeps it on the GPU and makes the meter's running average a tensor.
        losses_s.update(loss_s.item(), cnt_s)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                visualize(x_s[0],
                          pred_s[0] * args.image_size / args.heatmap_size,
                          "source_{}_pred.jpg".format(i))
                visualize(x_s[0], meta_s['keypoint2d'][0],
                          "source_{}_label.jpg".format(i))
def validate(val_loader, model, criterion, visualize, args: argparse.Namespace):
    """Run one evaluation pass over ``val_loader``.

    Computes the criterion loss and keypoint accuracy (aggregated by the
    dataset's keypoint groups), optionally dumping prediction/label
    visualizations every ``args.print_freq`` batches, and returns the
    per-group accuracy averages.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.2e')
    acc = AverageMeterDict(val_loader.dataset.keypoints_group.keys(), ":3.2f")
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, acc['all']],
        prefix='Test: ')

    # evaluation mode: freeze dropout / batch-norm statistics
    model.eval()

    with torch.no_grad():
        tic = time.time()
        for batch_idx, (images, targets, target_weight, meta) in enumerate(val_loader):
            images = images.to(device)
            targets = targets.to(device)
            target_weight = target_weight.to(device)

            # forward pass and loss
            outputs = model(images)
            batch_loss = criterion(outputs, targets, target_weight)

            # record loss, then grouped keypoint accuracy
            losses.update(batch_loss.item(), images.size(0))
            acc_per_points, avg_acc, cnt, pred = accuracy(
                outputs.cpu().numpy(), targets.cpu().numpy())
            group_acc = val_loader.dataset.group_accuracy(acc_per_points)
            acc.update(group_acc, images.size(0))

            # timing bookkeeping
            batch_time.update(time.time() - tic)
            tic = time.time()

            if batch_idx % args.print_freq == 0:
                progress.display(batch_idx)
                if visualize is not None:
                    visualize(images[0],
                              pred[0] * args.image_size / args.heatmap_size,
                              "val_{}_pred.jpg".format(batch_idx))
                    visualize(images[0], meta['keypoint2d'][0],
                              "val_{}_label.jpg".format(batch_idx))

    return acc.average()
def train(train_source_iter, train_target_iter, model, criterion,
          regression_disparity, optimizer_f, optimizer_h, optimizer_h_adv,
          lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv, epoch: int,
          visualize, args: argparse.Namespace):
    """Train one epoch of adversarial regression domain adaptation.

    Each iteration runs three optimization steps:
      A. minimize supervised loss (plus a margin term of the regression
         disparity) on the source domain, updating all networks;
      B. update only the adversarial regressor to *maximize* the regression
         disparity on the target domain;
      C. update only the feature extractor to *minimize* the regression
         disparity on the target domain.

    Args:
        train_source_iter / train_target_iter: iterators yielding
            ``(x, label, weight, meta)`` batches for each domain.
        model: network returning a pair ``(y, y_adv)`` of heatmap
            predictions (main and adversarial heads).
        criterion: supervised loss ``(output, label, weight)``.
        regression_disparity: disparity loss between the two heads, with a
            ``mode`` of ``'min'`` or ``'max'``.
        optimizer_f / optimizer_h / optimizer_h_adv: optimizers for the
            feature extractor, main head, and adversarial head.
        lr_scheduler_f / lr_scheduler_h / lr_scheduler_h_adv: per-iteration
            LR schedulers matching the optimizers above.
        epoch: current epoch index (progress display only).
        visualize: optional callable ``(image, keypoints, filename)``;
            skipped when ``None``.
        args: must provide ``iters_per_epoch``, ``print_freq``, ``margin``,
            ``trade_off``, ``image_size`` and ``heatmap_size``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ":.2e")
    losses_gf = AverageMeter('Loss (t, false)', ":.2e")
    losses_gt = AverageMeter('Loss (t, truth)', ":.2e")
    acc_s = AverageMeter("Acc (s)", ":3.2f")
    acc_t = AverageMeter("Acc (t)", ":3.2f")
    acc_s_adv = AverageMeter("Acc (s, adv)", ":3.2f")
    acc_t_adv = AverageMeter("Acc (t, adv)", ":3.2f")

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_s, losses_gf, losses_gt,
         acc_s, acc_t, acc_s_adv, acc_t_adv],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, label_s, weight_s, meta_s = next(train_source_iter)
        x_t, label_t, weight_t, meta_t = next(train_target_iter)

        x_s = x_s.to(device)
        label_s = label_s.to(device)
        weight_s = weight_s.to(device)

        x_t = x_t.to(device)
        label_t = label_t.to(device)
        weight_t = weight_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Step A: train all networks to minimize loss on source domain
        optimizer_f.zero_grad()
        optimizer_h.zero_grad()
        optimizer_h_adv.zero_grad()
        y_s, y_s_adv = model(x_s)
        loss_s = criterion(y_s, label_s, weight_s) + \
            args.margin * args.trade_off * regression_disparity(y_s, y_s_adv, weight_s, mode='min')
        loss_s.backward()
        optimizer_f.step()
        optimizer_h.step()
        optimizer_h_adv.step()

        # Step B: train adversarial regressor to maximize regression disparity
        optimizer_h_adv.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_false = args.trade_off * regression_disparity(
            y_t, y_t_adv, weight_t, mode='max')
        loss_ground_false.backward()
        optimizer_h_adv.step()

        # Step C: train feature extractor to minimize regression disparity
        optimizer_f.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_truth = args.trade_off * regression_disparity(
            y_t, y_t_adv, weight_t, mode='min')
        loss_ground_truth.backward()
        optimizer_f.step()

        # do update step (model-internal scheduling, e.g. GRL coefficient)
        model.step()
        lr_scheduler_f.step()
        lr_scheduler_h.step()
        lr_scheduler_h_adv.step()

        # measure accuracy and record loss
        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),
                                               label_s.detach().cpu().numpy())
        acc_s.update(avg_acc_s, cnt_s)
        _, avg_acc_t, cnt_t, pred_t = accuracy(y_t.detach().cpu().numpy(),
                                               label_t.detach().cpu().numpy())
        acc_t.update(avg_acc_t, cnt_t)
        _, avg_acc_s_adv, cnt_s_adv, pred_s_adv = accuracy(
            y_s_adv.detach().cpu().numpy(), label_s.detach().cpu().numpy())
        acc_s_adv.update(avg_acc_s_adv, cnt_s)
        _, avg_acc_t_adv, cnt_t_adv, pred_t_adv = accuracy(
            y_t_adv.detach().cpu().numpy(), label_t.detach().cpu().numpy())
        acc_t_adv.update(avg_acc_t_adv, cnt_t)
        # FIX: record Python scalars, not tensors — passing the loss tensors
        # keeps them (and GPU memory) alive in the meters and makes the
        # running averages tensors instead of floats.
        losses_s.update(loss_s.item(), cnt_s)
        losses_gf.update(loss_ground_false.item(), cnt_s)
        losses_gt.update(loss_ground_truth.item(), cnt_s)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                visualize(x_s[0],
                          pred_s[0] * args.image_size / args.heatmap_size,
                          "source_{}_pred".format(i))
                visualize(x_s[0], meta_s['keypoint2d'][0],
                          "source_{}_label".format(i))
                visualize(x_t[0],
                          pred_t[0] * args.image_size / args.heatmap_size,
                          "target_{}_pred".format(i))
                visualize(x_t[0], meta_t['keypoint2d'][0],
                          "target_{}_label".format(i))
                visualize(x_s[0],
                          pred_s_adv[0] * args.image_size / args.heatmap_size,
                          "source_adv_{}_pred".format(i))
                visualize(x_t[0],
                          pred_t_adv[0] * args.image_size / args.heatmap_size,
                          "target_adv_{}_pred".format(i))