def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, domain_adv: DomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: StepwiseLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        lr_scheduler.step()

        # measure data loading time
        data_time.update(time.time() - end)

        x_s, labels_s = next(train_source_iter)
        x_t, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = domain_adv(f_s, f_t)
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
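# --- Illustrative aside (not part of the training script) ---------------------
# A minimal sketch of what DomainAdversarialLoss computes in the usual DANN
# formulation: the discriminator learns to separate source from target
# features, while a gradient-reversal layer flips gradients so the feature
# extractor is pushed toward domain-invariant features. Names below are
# illustrative, not the library's implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F

class GradientReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, coeff):
        ctx.coeff = coeff
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # identity in the forward pass, negated (scaled) gradient backward
        return -ctx.coeff * grad_output, None

def sketch_domain_adversarial_loss(discriminator: nn.Module, f_s, f_t, coeff=1.0):
    f = GradientReverse.apply(torch.cat((f_s, f_t), dim=0), coeff)
    d = discriminator(f)          # assumed sigmoid outputs in (0, 1)
    d_s, d_t = d.chunk(2, dim=0)
    # source features are labeled 1, target features 0
    return 0.5 * (F.binary_cross_entropy(d_s, torch.ones_like(d_s))
                  + F.binary_cross_entropy(d_t, torch.zeros_like(d_t)))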
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          source_model: ImageClassifier, target_model: ImageClassifier,
          domain_adv: DomainAdversarialLoss, optimizer: SGD, lr_scheduler: LambdaLR,
          epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses_transfer = AverageMeter('Transfer Loss', ':6.2f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_transfer, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode (the source model stays frozen)
    target_model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, _ = next(train_source_iter)
        x_t, _ = next(train_target_iter)
        x_s = x_s.to(device)
        x_t = x_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        _, f_s = source_model(x_s)
        _, f_t = target_model(x_t)
        loss_transfer = domain_adv(f_s, f_t)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss_transfer.backward()
        optimizer.step()
        lr_scheduler.step()

        losses_transfer.update(loss_transfer.item(), x_s.size(0))
        domain_acc = domain_adv.domain_discriminator_accuracy
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
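# --- Illustrative aside -------------------------------------------------------
# In this asymmetric setup only the target feature extractor and the domain
# discriminator receive gradients; the matching main() further below freezes
# the source classifier via set_requires_grad. A one-line sketch of what such
# a helper does (the real one may come from the library's utils):
def sketch_set_requires_grad(net, requires_grad: bool = False):
    for param in net.parameters():
        param.requires_grad = requires_grad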
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        # warnings.warn('You have chosen to seed training. '
        #               'This will turn on the CUDNN deterministic setting, '
        #               'which can slow down your training considerably! '
        #               'You may see unexpected behavior when restarting '
        #               'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    if args.center_crop:
        train_transform = T.Compose([
            ResizeImage(256),
            T.CenterCrop(224),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize
        ])
    else:
        train_transform = T.Compose([
            ResizeImage(256),
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize
        ])
    val_transform = T.Compose([
        ResizeImage(256),
        T.CenterCrop(224),
        T.ToTensor(),
        normalize
    ])

    dataset = datasets.__dict__[args.data]
    source_dataset = open_set(dataset, source=True)
    target_dataset = open_set(dataset, source=False)
    train_source_dataset = source_dataset(root=args.root, task=args.source,
                                          download=True, transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_dataset = target_dataset(root=args.root, task=args.target,
                                          download=True, transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_dataset = target_dataset(root=args.root, task=args.target,
                                 download=True, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                            shuffle=False, num_workers=args.workers)
    if args.data == 'DomainNet':
        test_dataset = target_dataset(root=args.root, task=args.target, split='test',
                                      download=True, transform=val_transform)
        test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                                 shuffle=False, num_workers=args.workers)
    else:
        test_loader = val_loader

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    num_classes = train_source_dataset.num_classes
    backbone = models.__dict__[args.arch](pretrained=True)
    classifier = ImageClassifier(backbone, num_classes,
                                 bottleneck_dim=args.bottleneck_dim).to(device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim,
                                        hidden_size=1024).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(), args.lr,
                    momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    lr_scheduler = LambdaLR(optimizer,
                            lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss function
    domain_adv = DomainAdversarialLoss(domain_discri).to(device)

    # analyze the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure of distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = validate(test_loader, classifier, args)
        print(acc1)
        return

    # start training
    best_h_score = 0.
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_adv,
              optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        h_score = validate(val_loader, classifier, args)

        # remember best h-score and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if h_score > best_h_score:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_h_score = max(h_score, best_h_score)
        print("best_h_score = {:3.1f}".format(best_h_score))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    h_score = validate(test_loader, classifier, args)
    print("test_h_score = {:3.1f}".format(h_score))

    logger.close()
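# --- Illustrative aside -------------------------------------------------------
# LambdaLR multiplies each parameter group's *initial* lr by whatever the
# lambda returns at each scheduler step; get_parameters() in these examples
# typically assigns smaller relative rates to the pre-trained backbone than
# to the newly initialized bottleneck and head. A quick look at the decay
# curve (values illustrative):
def _sketch_show_lr_decay(lr=0.01, lr_gamma=0.001, lr_decay=0.75):
    for step in (0, 1000, 10000):
        factor = lr * (1. + lr_gamma * step) ** (-lr_decay)
        print(step, factor)  # the multiplier decays monotonically from lr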
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, domain_adv_D: DomainAdversarialLoss,
          domain_adv_D_0: DomainAdversarialLoss, importance_weight_module,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    domain_accs_D = AverageMeter('Domain Acc for D', ':3.1f')
    domain_accs_D_0 = AverageMeter('Domain Acc for D_0', ':3.1f')
    partial_classes_weights = AverageMeter('Partial Weight', ':3.2f')
    non_partial_classes_weights = AverageMeter('Non-Partial Weight', ':3.2f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, tgt_accs, domain_accs_D,
         domain_accs_D_0, partial_classes_weights, non_partial_classes_weights],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv_D.train()
    domain_adv_D_0.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)
        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        # classification loss
        cls_loss = F.cross_entropy(y_s, labels_s)
        # domain adversarial loss for D (features detached: D does not train the backbone)
        adv_loss_D = domain_adv_D(f_s.detach(), f_t.detach())
        # get importance weights
        w_s = importance_weight_module.get_importance_weight(f_s)
        # domain adversarial loss for D_0
        adv_loss_D_0 = domain_adv_D_0(f_s, f_t, w_s=w_s)
        # entropy loss
        y_t = F.softmax(y_t, dim=1)
        entropy_loss = entropy(y_t, reduction='mean')
        loss = cls_loss + 1.5 * args.trade_off * adv_loss_D + \
            args.trade_off * adv_loss_D_0 + args.gamma * entropy_loss

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))
        domain_accs_D.update(domain_adv_D.domain_discriminator_accuracy, x_s.size(0))
        domain_accs_D_0.update(domain_adv_D_0.domain_discriminator_accuracy, x_s.size(0))

        # debug: output class weight averaged on the partial classes and non-partial classes respectively
        partial_class_weight, non_partial_classes_weight = \
            importance_weight_module.get_partial_classes_weight(w_s, labels_s)
        partial_classes_weights.update(partial_class_weight.item(), x_s.size(0))
        non_partial_classes_weights.update(non_partial_classes_weight.item(), x_s.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
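# --- Illustrative aside -------------------------------------------------------
# A sketch of the importance-weighting idea used above (not the library
# code): the auxiliary discriminator D scores how source-like a source
# feature is, and 1 - D(f_s) down-weights source samples from classes that
# are likely absent from the (partial) target label space. The weights are
# detached so they act as constants in the weighted adversarial loss.
@torch.no_grad()
def sketch_importance_weight(discriminator, f_s, eps=1e-8):
    d = discriminator(f_s)      # in (0, 1); high means "looks source-like"
    w = 1. - d                  # target-like source samples get larger weight
    return (w / (w.mean() + eps)).view(-1)   # normalize to mean ~1 per batch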
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=False)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                            shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    if args.data == 'ImageNetCaltech':
        classifier = Classifier(backbone, num_classes, head=backbone.copy_head(),
                                pool_layer=pool_layer).to(device)
    else:
        classifier = ImageClassifier(backbone, num_classes, args.bottleneck_dim,
                                     pool_layer=pool_layer).to(device)

    # define domain classifiers D and D_0
    D = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024,
                            batch_norm=False).to(device)
    D_0 = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024,
                              batch_norm=False).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters() + D.get_parameters() + D_0.get_parameters(),
                    args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    lr_scheduler = LambdaLR(optimizer,
                            lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss functions
    domain_adv_D = DomainAdversarialLoss(D).to(device)
    domain_adv_D_0 = DomainAdversarialLoss(D_0).to(device)

    # define importance weight module
    importance_weight_module = ImportanceWeightModule(D, train_target_dataset.partial_classes_idx)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analyze the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer,
                                          classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure of distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_adv_D, domain_adv_D_0,
              importance_weight_module, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
        print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
def main(args: argparse.Namespace):
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        ResizeImage(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    val_transform = transforms.Compose([
        ResizeImage(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root, task=args.source,
                                   download=True, transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_dataset = dataset(root=args.root, task=args.target,
                                   download=True, transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_dataset = dataset(root=args.root, task=args.target,
                          download=True, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                            shuffle=False, num_workers=args.workers)
    if args.data == 'DomainNet':
        test_dataset = dataset(root=args.root, task=args.target, evaluate=True,
                               download=True, transform=val_transform)
        test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                                 shuffle=False, num_workers=args.workers)
    else:
        test_loader = val_loader

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](pretrained=True)
    classifier = ImageClassifier(backbone, train_source_dataset.num_classes).to(device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim,
                                        hidden_size=1024).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(), args.lr,
                    momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    lr_scheduler = StepwiseLR(optimizer, init_lr=args.lr, gamma=0.001, decay_rate=0.75)

    # define loss function
    domain_adv = DomainAdversarialLoss(domain_discri).to(device)

    # start training
    best_acc1 = 0.
    best_model = copy.deepcopy(classifier.state_dict())
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_adv,
              optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, classifier, args)

        # remember best acc@1 and save checkpoint
        if acc1 > best_acc1:
            best_model = copy.deepcopy(classifier.state_dict())
        best_acc1 = max(acc1, best_acc1)
        print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(best_model)
    acc1 = validate(test_loader, classifier, args)
    print("test_acc1 = {:3.1f}".format(acc1))
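# --- Illustrative aside -------------------------------------------------------
# StepwiseLR (from the older dalib examples) anneals the learning rate on
# every call to step(); a minimal re-implementation sketch of the schedule,
# assuming per-group relative rates are stored under 'lr_mult' (an
# assumption, the real class may store them differently):
class SketchStepwiseLR:
    def __init__(self, optimizer, init_lr, gamma, decay_rate):
        self.optimizer, self.init_lr = optimizer, init_lr
        self.gamma, self.decay_rate = gamma, decay_rate
        self.iter_num = 0

    def step(self):
        lr = self.init_lr * (1 + self.gamma * self.iter_num) ** (-self.decay_rate)
        for group in self.optimizer.param_groups:
            group['lr'] = lr * group.get('lr_mult', 1.)
        self.iter_num += 1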
def da_methods(discr, frame, ag, dim_x, dim_y, device, ckpt, discr_src=None):
    if ag.mode not in MODES_GEN:
        celossobj, celossfn = get_ce_or_bce_loss(discr, dim_y, ag.reduction)

        if ag.mode == "dann":
            domdisc = DomainDiscriminator(in_feature=discr.dim_s,
                                          hidden_size=ag.domdisc_dimh).to(device)
            dalossobj = DomainAdversarialLoss(domdisc, reduction=ag.reduction).to(device)
            auto_load(locals(), ['domdisc', 'dalossobj'], ckpt)

            def lossfn(x, y, xt):
                logit, feat = discr.ys1x(x)
                featt = discr.s1x(xt)
                # truncate to a common batch size
                if feat.shape[0] < featt.shape[0]:
                    featt = featt[:feat.shape[0]]
                elif feat.shape[0] > featt.shape[0]:
                    feat = feat[:featt.shape[0]]
                return celossobj(logit, y.float() if dim_y == 1 else y) \
                    + ag.wda * dalossobj(feat, featt)

        elif ag.mode == "cdan":
            # In both the randomized and non-randomized versions, the upstream code
            # has problems. For randomized, dim_s * num_cls is fed to domdisc.
            # For non-randomized, `tc.mm` receives the wrong input order in
            # `RandomizedMultiLinearMap.forward()`.
            num_classes = (2 if dim_y == 1 else dim_y)
            domdisc = DomainDiscriminator(
                in_feature=discr.dim_s * (1 if ag.cdan_rand else num_classes),  # confusing `in_feature`
                hidden_size=ag.domdisc_dimh).to(device)
            dalossobj = ConditionalDomainAdversarialLoss(
                domdisc, reduction=ag.reduction, randomized=ag.cdan_rand,
                num_classes=num_classes, features_dim=discr.dim_s,
                randomized_dim=discr.dim_s).to(device)
            auto_load(locals(), ['domdisc', 'dalossobj'], ckpt)

            def lossfn(x, y, xt):
                logit, feat = discr.ys1x(x)
                logitt, featt = discr.ys1x(xt)
                logit_stack = tc.stack([tc.zeros_like(logit), logit], dim=-1) if dim_y == 1 else logit
                logitt_stack = tc.stack([tc.zeros_like(logitt), logitt], dim=-1) if dim_y == 1 else logitt
                return celossobj(logit, y.float() if dim_y == 1 else y) \
                    + ag.wda * dalossobj(logit_stack, feat, logitt_stack, featt)

        elif ag.mode == "dan":
            domdisc = None
            dalossobj = MultipleKernelMaximumMeanDiscrepancy(
                [GaussianKernel(alpha=alpha) for alpha in ag.ker_alphas]).to(device)

            def lossfn(x, y, xt):
                logit, feat = discr.ys1x(x)
                featt = discr.s1x(xt)
                if feat.shape[0] < featt.shape[0]:
                    featt = featt[:feat.shape[0]]
                elif feat.shape[0] > featt.shape[0]:
                    feat = feat[:featt.shape[0]]
                return celossobj(logit, y.float() if dim_y == 1 else y) \
                    + ag.wda * dalossobj(feat, featt)

        elif ag.mode == "mdd":
            num_classes = (2 if dim_y == 1 else dim_y)
            # actually not a domain discriminator but an auxiliary (adversarial) classifier
            domdisc = mlp.MLP([dim_x, ag.domdisc_dimh, ag.domdisc_dimh, num_classes]).to(device)
            dalossobj = MarginDisparityDiscrepancy(margin=ag.mdd_margin,
                                                   reduction=ag.reduction).to(device)
            auto_load(locals(), ['domdisc', 'dalossobj'], ckpt)

            def lossfn(x, y, xt):
                logit, logitt = discr(x), discr(xt)
                logit_adv, logitt_adv = domdisc(x.reshape(-1, dim_x)), domdisc(xt.reshape(-1, dim_x))
                logit_stack = tc.stack([tc.zeros_like(logit), logit], dim=-1) if dim_y == 1 else logit
                logitt_stack = tc.stack([tc.zeros_like(logitt), logitt], dim=-1) if dim_y == 1 else logitt
                return celossobj(logit, y.float() if dim_y == 1 else y) \
                    + ag.wda * dalossobj(logit_stack, logit_adv, logitt_stack, logitt_adv)

        elif ag.mode == "bnm":
            domdisc = None
            dalossobj = None

            def lossfn(x, y, xt):
                logit, logitt = discr(x), discr(xt)
                logitt_stack = tc.stack([tc.zeros_like(logitt), logitt], dim=-1) if dim_y == 1 else logitt
                softmax_tgt = logitt_stack.softmax(dim=1)
                _, s_tgt, _ = tc.svd(softmax_tgt)
                # if config["method"] == "BNM":
                transfer_loss = -tc.mean(s_tgt)
                # elif config["method"] == "BFM":
                #     transfer_loss = -tc.sqrt(tc.sum(s_tgt * s_tgt) / s_tgt.shape[0])
                # elif config["method"] == "ENT":
                #     transfer_loss = -tc.mean(tc.sum(softmax_tgt * tc.log(softmax_tgt + 1e-8), dim=1)) / tc.log(softmax_tgt.shape[1])
                return celossobj(logit, y.float() if dim_y == 1 else y) \
                    + ag.wda * transfer_loss
        else:
            pass

        for obj in [dalossobj, domdisc]:
            if obj is not None:
                obj.train()
    else:
        if ag.mode.endswith("-da2") and discr_src is not None:
            true_discr = discr_src
        elif ag.mode in MODES_TWIST and ag.true_sup:
            true_discr = partial(frame.logit_y1x_src, n_mc_q=ag.n_mc_q)
        else:
            true_discr = discr
        celossfn = get_ce_or_bce_loss(true_discr, dim_y, ag.reduction)[1]
        lossobj = frame.get_lossfn(ag.n_mc_q, ag.reduction, "defl",
                                   weight_da=ag.wda / ag.wgen, wlogpi=ag.wlogpi / ag.wgen)
        lossfn = add_ce_loss(lossobj, celossfn, ag)
        domdisc, dalossobj = None, None

    return lossfn, domdisc, dalossobj
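# --- Illustrative aside -------------------------------------------------------
# The "bnm" branch above implements Batch Nuclear-norm Maximization: pushing
# up the singular values of the target softmax matrix encourages predictions
# that are both confident and diverse. A standalone sketch (tc is torch):
def sketch_bnm_loss(logits_t):
    """Negative mean singular value of the (batch x classes) target softmax matrix."""
    p = logits_t.softmax(dim=1)
    s = tc.linalg.svdvals(p)   # singular values in descending order
    return -s.mean()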
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    train_transform = T.Compose([T.Resize(args.resize_size), T.ToTensor(), normalize])
    val_transform = T.Compose([T.Resize(args.resize_size), T.ToTensor(), normalize])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root, task=args.source, split='train',
                                   download=True, transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_dataset = dataset(root=args.root, task=args.target, split='train',
                                   download=True, transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_dataset = dataset(root=args.root, task=args.target, split='test',
                          download=True, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                            shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](pretrained=True)
    if args.normalization == 'IN':
        backbone = convert_model(backbone)
    num_factors = train_source_dataset.num_factors
    bottleneck = nn.Sequential(
        nn.AdaptiveAvgPool2d(output_size=(1, 1)),
        nn.Flatten(),
        nn.Linear(backbone.out_features, 256),
        nn.ReLU()
    )
    regressor = Regressor(backbone=backbone, num_factors=num_factors,
                          bottleneck=bottleneck, bottleneck_dim=256).to(device)
    print(regressor)
    domain_discri = DomainDiscriminator(in_feature=regressor.features_dim,
                                        hidden_size=1024).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(regressor.get_parameters() + domain_discri.get_parameters(), args.lr,
                    momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_scheduler = LambdaLR(optimizer,
                            lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss function
    dann = DomainAdversarialLoss(domain_discri).to(device)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        regressor.load_state_dict(checkpoint)

    # analyze the model
    if args.phase == 'analysis':
        train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                         shuffle=True, num_workers=args.workers, drop_last=True)
        train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                         shuffle=True, num_workers=args.workers, drop_last=True)
        # extract features from both domains
        feature_extractor = nn.Sequential(regressor.backbone, regressor.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure of distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        mae = validate(val_loader, regressor, args, train_source_dataset.factors, device)
        print(mae)
        return

    # start training
    best_mae = 100000.
    for epoch in range(args.epochs):
        # train for one epoch
        print("lr", lr_scheduler.get_lr())
        train(train_source_iter, train_target_iter, regressor, dann,
              optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        mae = validate(val_loader, regressor, args, train_source_dataset.factors, device)

        # remember best MAE and save checkpoint
        torch.save(regressor.state_dict(), logger.get_checkpoint_path('latest'))
        if mae < best_mae:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_mae = min(mae, best_mae)
        print("mean MAE {:6.3f} best MAE {:6.3f}".format(mae, best_mae))

    print("best_mae = {:6.3f}".format(best_mae))
    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: Regressor, domain_adv: DomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    mse_losses = AverageMeter('MSE Loss', ':6.3f')
    dann_losses = AverageMeter('DANN Loss', ':6.3f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    mae_losses_s = AverageMeter('MAE Loss (s)', ':6.3f')
    mae_losses_t = AverageMeter('MAE Loss (t)', ':6.3f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, mse_losses, dann_losses, mae_losses_s, mae_losses_t, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, labels_s = next(train_source_iter)
        x_s = x_s.to(device)
        labels_s = labels_s.to(device).float()
        x_t, labels_t = next(train_target_iter)
        x_t = x_t.to(device)
        labels_t = labels_t.to(device).float()

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        mse_loss = F.mse_loss(y_s, labels_s)
        mae_loss_s = F.l1_loss(y_s, labels_s)
        mae_loss_t = F.l1_loss(y_t, labels_t)
        transfer_loss = domain_adv(f_s, f_t)
        loss = mse_loss + transfer_loss * args.trade_off
        domain_acc = domain_adv.domain_discriminator_accuracy

        mse_losses.update(mse_loss.item(), x_s.size(0))
        dann_losses.update(transfer_loss.item(), x_s.size(0))
        mae_losses_s.update(mae_loss_s.item(), x_s.size(0))
        mae_losses_t.update(mae_loss_t.item(), x_s.size(0))
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
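# --- Illustrative aside -------------------------------------------------------
# Only the source MSE and the adversarial term above receive gradients; the
# target MAE meter is logged purely for monitoring. A hypothetical helper for
# breaking the MAE down per regression factor (the factor names are an
# assumption, not part of the original script):
def sketch_per_factor_mae(y_pred, y_true, factor_names):
    mae = (y_pred - y_true).abs().mean(dim=0)   # one value per factor column
    return {name: mae[i].item() for i, name in enumerate(factor_names)}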
def main(args: argparse.Namespace):
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        ResizeImage(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    val_transform = transforms.Compose([
        ResizeImage(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root, task=args.source,
                                   download=True, transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_dataset = dataset(root=args.root, task=args.target,
                                   download=True, transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_dataset = dataset(root=args.root, task=args.target,
                          download=True, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                            shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](pretrained=True)
    classifier = ImageClassifier(backbone, train_source_dataset.num_classes).to(device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim,
                                        hidden_size=1024).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(), args.lr,
                    momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    lr_scheduler = StepwiseLR(optimizer, init_lr=args.lr, gamma=0.001, decay_rate=0.75)

    # define loss function
    domain_adv = DomainAdversarialLoss(domain_discri).to(device)

    # start training
    best_acc1 = 0.
    best_model = classifier.state_dict()
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_adv,
              optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, classifier, args)

        # remember best acc@1 and save checkpoint
        if acc1 > best_acc1:
            best_model = classifier.state_dict()
            torch.save(best_model, 'best_model.pth.tar')
        best_acc1 = max(acc1, best_acc1)
        print("best_acc1 = {:3.1f}".format(best_acc1))

    # visualize the results using t-SNE
    classifier.load_state_dict(best_model)
    classifier.eval()
    features, labels, domains = [], [], []
    source_val_dataset = dataset(root=args.root, task=args.source,
                                 download=True, transform=val_transform)
    source_val_loader = DataLoader(source_val_dataset, batch_size=args.batch_size,
                                   shuffle=False, num_workers=args.workers)
    with torch.no_grad():
        for loader in [source_val_loader, val_loader]:
            for i, (images, target) in enumerate(loader):
                images = images.to(device)
                # compute features (keep labels on CPU so they convert to numpy)
                _, f = classifier(images)
                features.extend(f.cpu().numpy().tolist())
                labels.extend(target.numpy().tolist())
    domains = np.concatenate((np.ones(len(source_val_dataset)), np.zeros(len(val_dataset))))
    features, labels = np.array(features), np.array(labels)
    print("source:", len(source_val_dataset), "target:", len(val_dataset))
    X_tsne = TSNE(n_components=2, random_state=33).fit_transform(features)
    plt.figure(figsize=(10, 10))
    plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=domains,
                cmap=col.ListedColormap(["r", "b"]), s=2)
    plt.savefig(os.path.join('{}_{}2{}.pdf'.format("dann", args.source, args.target)))
def train(feature_extractor: FeatureExtractor, domain_adv: DomainAdversarialLoss,
          src_iter: ForeverDataIterator, tar_iter: ForeverDataIterator,
          src_val_loader, tar_val_loader):
    optimizer = Adam(itertools.chain(feature_extractor.parameters(), domain_adv.parameters()),
                     lr=lr, weight_decay=weight_decay)
    npair_loss = NPairsLoss()  # n-pair loss

    # loss meters
    loss_rec = AverageMeter('tot_loss', tb_tag='Loss/tot', writer=writer)
    loss_lb_rec = AverageMeter('lb_loss', tb_tag='Loss/lb', writer=writer)
    loss_lb_g_rec = AverageMeter('lb_g_loss', tb_tag='Loss/lb_g', writer=writer)
    # loss_ulb_rec = AverageMeter('ulb_loss', tb_tag='Loss/ulb')
    loss_da_rec = AverageMeter('da_loss', tb_tag='Loss/da', writer=writer)
    # accuracy meter
    da_acc_rec = AverageMeter('da_acc', tb_tag='Acc/da', writer=writer)

    n_iter = 0
    best_nmi = 0
    for e_i in range(epoch):
        feature_extractor.train()
        domain_adv.train()
        progress = ProgressMeter(
            iter_per_epoch,
            [loss_lb_g_rec, loss_lb_rec, loss_da_rec, da_acc_rec],
            prefix="Epoch: [{}]".format(e_i), logger=global_logger)
        for i in range(iter_per_epoch):
            x_s, l_s = next(src_iter)
            x_t, l_t = next(tar_iter)
            # for obj in [x_s, x_t, l_s, l_t]:   # to device
            #     obj = obj.to(device)
            x_s, l_s, x_t, l_t = x_s.to(device), l_s.to(device), x_t.to(device), l_t.to(device)

            x = torch.cat((x_s, x_t), dim=0)
            f, g = feature_extractor(x)
            f_s, f_t = f.chunk(2, dim=0)
            g_s, g_t = g.chunk(2, dim=0)

            # source-only part: n-pair losses on both embedding heads
            loss_s = npair_loss(f_s, l_s)
            loss_s_g = npair_loss(g_s, l_s)
            loss_lb_rec.update(loss_s.item(), x_s.size(0), iter=n_iter)
            loss_lb_g_rec.update(loss_s_g.item(), x_s.size(0), iter=n_iter)

            # dann
            # da_loss = domain_adv(f_s, f_t)
            da_loss = domain_adv(g_s, f_t)
            domain_acc = domain_adv.domain_discriminator_accuracy
            loss_da_rec.update(da_loss.item(), f.size(0), iter=n_iter)
            da_acc_rec.update(domain_acc.item(), f.size(0), iter=n_iter)

            loss = 0.5 * (loss_s + loss_s_g) + w_da * da_loss
            # loss = loss_s

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            n_iter += 1

            if i % print_freq == 0:
                progress.display(i)

        if e_i % 5 == 0:
            # global_logger.info(f"saving embedding in epoch{e_i}")
            # # show embedding
            # show_embedding(backbone, [src_val_loader], tag=f'src_{e_i}', epoch=e_i, writer, device)
            # show_embedding(backbone, [tar_val_loader], tag=f'tar_{e_i}', epoch=e_i, writer, device)
            nmi = NMI_eval(feature_extractor, src_val_loader, 5, device, type='src')
            global_logger.info(f'test on train set nmi: {nmi}')
            nmi = NMI_eval(feature_extractor, tar_val_loader, 5, device, type='tar')
            global_logger.info(f'test on test set nmi: {nmi}')
            if nmi > best_nmi:
                global_logger.info(f"save best model to {model_dir}")
                torch.save(feature_extractor.state_dict(),
                           os.path.join(model_dir, 'mnist_best_model.pth'))
                best_nmi = nmi
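# --- Illustrative aside -------------------------------------------------------
# NMI_eval above clusters the learned embeddings and scores the cluster
# assignments against the true labels; a minimal sketch of that metric with
# scikit-learn (assumed available; NMI_eval's real signature may differ):
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score

def sketch_nmi(features, labels, n_clusters):
    pred = KMeans(n_clusters=n_clusters, n_init=10).fit_predict(features)
    return normalized_mutual_info_score(labels, pred)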
if __name__ == "__main__":
    # setup model
    backbone = FeatureExtractor().to(device)
    domain_discri = DomainDiscriminator(in_feature=128, hidden_size=256).to(device)
    domain_adv = DomainAdversarialLoss(domain_discri).to(device)
    # TODO feature reconstruction loss
    # TODO feature transfer module

    src_domain_class = [0, 1, 2, 3, 4]
    tar_domain_class = [5, 6, 7, 8, 9]

    # setup dataloaders
    src_train_loader = get_mnist_m_loader(
        dataset_root='./dataset/MNIST-M',
        label_filter=lambda x: x in src_domain_class,
        sample_per_cls=sample_num_per_cls, cls_num=cls_num)
    tar_train_loader = get_mnist_loader(
        dataset_root='./dataset/MNIST',
        label_filter=lambda x: x in tar_domain_class,
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing,
                                                random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False,
                                                resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                            shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    source_classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                        pool_layer=pool_layer, finetune=not args.scratch).to(device)

    if args.phase == 'train' and args.pretrain is None:
        # first pretrain the classifier with source data
        print("Pretraining the model on source domain.")
        args.pretrain = logger.get_checkpoint_path('pretrain')
        pretrain_model = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                         pool_layer=pool_layer, finetune=not args.scratch).to(device)
        pretrain_optimizer = SGD(pretrain_model.get_parameters(), args.pretrain_lr,
                                 momentum=args.momentum, weight_decay=args.weight_decay,
                                 nesterov=True)
        pretrain_lr_scheduler = LambdaLR(pretrain_optimizer,
                                         lambda x: args.pretrain_lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

        # start pretraining
        for epoch in range(args.pretrain_epochs):
            print("lr:", pretrain_lr_scheduler.get_lr())
            # pretrain for one epoch
            utils.pretrain(train_source_iter, pretrain_model, pretrain_optimizer,
                           pretrain_lr_scheduler, epoch, args, device)
            # validate to show the pretraining progress
            utils.validate(val_loader, pretrain_model, args, device)

        torch.save(pretrain_model.state_dict(), args.pretrain)
        print("Pretraining process is done.")

    checkpoint = torch.load(args.pretrain, map_location='cpu')
    source_classifier.load_state_dict(checkpoint)
    target_classifier = copy.deepcopy(source_classifier)

    # freeze the source classifier
    set_requires_grad(source_classifier, False)
    source_classifier.freeze_bn()

    domain_discri = DomainDiscriminator(in_feature=source_classifier.features_dim,
                                        hidden_size=1024).to(device)

    # define loss function
    grl = WarmStartGradientReverseLayer(alpha=1., lo=0., hi=2., max_iters=1000, auto_step=True)
    domain_adv = DomainAdversarialLoss(domain_discri, grl=grl).to(device)

    # define optimizer and lr scheduler
    # note that we only optimize the target feature extractor
    optimizer = SGD(target_classifier.get_parameters(optimize_head=False) + domain_discri.get_parameters(),
                    args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    lr_scheduler = LambdaLR(optimizer,
                            lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        target_classifier.load_state_dict(checkpoint)

    # analyze the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(target_classifier.backbone, target_classifier.pool_layer,
                                          target_classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure of distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, target_classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, source_classifier, target_classifier,
              domain_adv, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, target_classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(target_classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
        print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    target_classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, target_classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
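# --- Illustrative aside -------------------------------------------------------
# WarmStartGradientReverseLayer ramps the reversal coefficient from lo to hi
# over max_iters steps, so the adversarial signal starts weak and grows. A
# sketch of the sigmoid-shaped warm-start schedule such layers typically use
# (assumed to match the constructor arguments above; verify against the
# library before relying on it):
import math

def sketch_grl_coeff(step, alpha=1., lo=0., hi=2., max_iters=1000):
    return 2. * (hi - lo) / (1. + math.exp(-alpha * step / max_iters)) - (hi - lo) + lo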