def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    if args.center_crop:
        train_transform = T.Compose([
            ResizeImage(256),
            T.CenterCrop(224),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize
        ])
    else:
        train_transform = T.Compose([
            ResizeImage(256),
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize
        ])
    val_transform = T.Compose([
        ResizeImage(256),
        T.CenterCrop(224),
        T.ToTensor(),
        normalize
    ])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root, task=args.source, download=True, transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_dataset = dataset(root=args.root, task=args.target, download=True, transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_dataset = dataset(root=args.root, task=args.target, download=True, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    if args.data == 'DomainNet':
        test_dataset = dataset(root=args.root, task=args.target, split='test', download=True,
                               transform=val_transform)
        test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                                 num_workers=args.workers)
    else:
        test_loader = val_loader

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](pretrained=True)
    num_classes = train_source_dataset.num_classes
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum,
                    weight_decay=args.wd, nesterov=True)
    lr_scheduler = LambdaLR(optimizer,
                            lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss function
    mkmmd_loss = MultipleKernelMaximumMeanDiscrepancy(
        kernels=[GaussianKernel(alpha=2 ** k) for k in range(-3, 2)],
        linear=not args.non_linear
    )

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analyze the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure of distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = validate(test_loader, classifier, args)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, mkmmd_loss, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, classifier, args)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = validate(test_loader, classifier, args)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
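
# A minimal sketch of the quantity that `MultipleKernelMaximumMeanDiscrepancy` above
# estimates: the (biased) squared MMD between source and target features under a sum
# of Gaussian kernels. The names `gaussian_kernel` and `mk_mmd` are illustrative, not
# the library's API; tllib's `GaussianKernel` tracks a data-dependent bandwidth scaled
# by `alpha` (fixed sigmas here are a simplification), and the `linear` flag seen above
# selects a linear-time estimator instead of this quadratic one.
import torch


def gaussian_kernel(x, y, sigma):
    # k(a, b) = exp(-||a - b||^2 / (2 * sigma^2)) for every pair (a, b)
    dist2 = torch.cdist(x, y) ** 2
    return torch.exp(-dist2 / (2 * sigma ** 2))


def mk_mmd(feat_s, feat_t, sigmas=(0.5, 1.0, 2.0, 4.0, 8.0)):
    # sum the per-kernel biased MMD^2 estimates over the kernel family
    mmd2 = 0.
    for sigma in sigmas:
        k_ss = gaussian_kernel(feat_s, feat_s, sigma).mean()
        k_tt = gaussian_kernel(feat_t, feat_t, sigma).mean()
        k_st = gaussian_kernel(feat_s, feat_t, sigma).mean()
        mmd2 = mmd2 + k_ss + k_tt - 2 * k_st
    return mmd2
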
def da_methods(discr, frame, ag, dim_x, dim_y, device, ckpt, discr_src=None):
    if ag.mode not in MODES_GEN:
        celossobj, celossfn = get_ce_or_bce_loss(discr, dim_y, ag.reduction)
        if ag.mode == "dann":
            domdisc = DomainDiscriminator(in_feature=discr.dim_s,
                                          hidden_size=ag.domdisc_dimh).to(device)
            dalossobj = DomainAdversarialLoss(domdisc, reduction=ag.reduction).to(device)
            auto_load(locals(), ['domdisc', 'dalossobj'], ckpt)

            def lossfn(x, y, xt):
                logit, feat = discr.ys1x(x)
                featt = discr.s1x(xt)
                # truncate to a common batch size before the adversarial loss
                if feat.shape[0] < featt.shape[0]:
                    featt = featt[:feat.shape[0]]
                elif feat.shape[0] > featt.shape[0]:
                    feat = feat[:featt.shape[0]]
                return celossobj(logit, y.float() if dim_y == 1 else y) + ag.wda * dalossobj(feat, featt)
        elif ag.mode == "cdan":
            # In both the randomized and non-randomized versions, the upstream code has problems.
            # For the randomized version, dim_s * num_cls is fed to domdisc.
            # For the non-randomized version, `tc.mm` receives the wrong input order in
            # `RandomizedMultiLinearMap.forward()`.
            num_classes = 2 if dim_y == 1 else dim_y
            domdisc = DomainDiscriminator(
                in_feature=discr.dim_s * (1 if ag.cdan_rand else num_classes),  # confusing `in_feature`
                hidden_size=ag.domdisc_dimh).to(device)
            dalossobj = ConditionalDomainAdversarialLoss(
                domdisc, reduction=ag.reduction, randomized=ag.cdan_rand,
                num_classes=num_classes, features_dim=discr.dim_s,
                randomized_dim=discr.dim_s).to(device)
            auto_load(locals(), ['domdisc', 'dalossobj'], ckpt)

            def lossfn(x, y, xt):
                logit, feat = discr.ys1x(x)
                logitt, featt = discr.ys1x(xt)
                # expand binary logits to two-class form for the conditional loss
                logit_stack = tc.stack([tc.zeros_like(logit), logit], dim=-1) if dim_y == 1 else logit
                logitt_stack = tc.stack([tc.zeros_like(logitt), logitt], dim=-1) if dim_y == 1 else logitt
                return (celossobj(logit, y.float() if dim_y == 1 else y)
                        + ag.wda * dalossobj(logit_stack, feat, logitt_stack, featt))
        elif ag.mode == "dan":
            domdisc = None
            dalossobj = MultipleKernelMaximumMeanDiscrepancy(
                [GaussianKernel(alpha=alpha) for alpha in ag.ker_alphas]).to(device)

            def lossfn(x, y, xt):
                logit, feat = discr.ys1x(x)
                featt = discr.s1x(xt)
                if feat.shape[0] < featt.shape[0]:
                    featt = featt[:feat.shape[0]]
                elif feat.shape[0] > featt.shape[0]:
                    feat = feat[:featt.shape[0]]
                return celossobj(logit, y.float() if dim_y == 1 else y) + ag.wda * dalossobj(feat, featt)
        elif ag.mode == "mdd":
            num_classes = 2 if dim_y == 1 else dim_y
            # Not actually a domain discriminator but an auxiliary (adversarial) classifier.
            domdisc = mlp.MLP([dim_x, ag.domdisc_dimh, ag.domdisc_dimh, num_classes]).to(device)
            dalossobj = MarginDisparityDiscrepancy(margin=ag.mdd_margin,
                                                   reduction=ag.reduction).to(device)
            auto_load(locals(), ['domdisc', 'dalossobj'], ckpt)

            def lossfn(x, y, xt):
                logit, logitt = discr(x), discr(xt)
                logit_adv, logitt_adv = domdisc(x.reshape(-1, dim_x)), domdisc(xt.reshape(-1, dim_x))
                logit_stack = tc.stack([tc.zeros_like(logit), logit], dim=-1) if dim_y == 1 else logit
                logitt_stack = tc.stack([tc.zeros_like(logitt), logitt], dim=-1) if dim_y == 1 else logitt
                return (celossobj(logit, y.float() if dim_y == 1 else y)
                        + ag.wda * dalossobj(logit_stack, logit_adv, logitt_stack, logitt_adv))
        elif ag.mode == "bnm":
            domdisc = None
            dalossobj = None

            def lossfn(x, y, xt):
                logit, logitt = discr(x), discr(xt)
                logitt_stack = tc.stack([tc.zeros_like(logitt), logitt], dim=-1) if dim_y == 1 else logitt
                softmax_tgt = logitt_stack.softmax(dim=1)
                _, s_tgt, _ = tc.svd(softmax_tgt)
                # if config["method"] == "BNM":
                transfer_loss = -tc.mean(s_tgt)
                # elif config["method"] == "BFM":
                #     transfer_loss = -tc.sqrt(tc.sum(s_tgt * s_tgt) / s_tgt.shape[0])
                # elif config["method"] == "ENT":
                #     transfer_loss = -tc.mean(tc.sum(softmax_tgt * tc.log(softmax_tgt + 1e-8), dim=1)) / tc.log(softmax_tgt.shape[1])
                return celossobj(logit, y.float() if dim_y == 1 else y) + ag.wda * transfer_loss
        else:
            pass
        for obj in [dalossobj, domdisc]:
            if obj is not None:
                obj.train()
    else:
        if ag.mode.endswith("-da2") and discr_src is not None:
            true_discr = discr_src
        elif ag.mode in MODES_TWIST and ag.true_sup:
            true_discr = partial(frame.logit_y1x_src, n_mc_q=ag.n_mc_q)
        else:
            true_discr = discr
        celossfn = get_ce_or_bce_loss(true_discr, dim_y, ag.reduction)[1]
        lossobj = frame.get_lossfn(ag.n_mc_q, ag.reduction, "defl",
                                   weight_da=ag.wda / ag.wgen, wlogpi=ag.wlogpi / ag.wgen)
        lossfn = add_ce_loss(lossobj, celossfn, ag)
        domdisc, dalossobj = None, None
    return lossfn, domdisc, dalossobj
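
# The "dann" and "cdan" branches above hinge on a gradient reversal layer (GRL)
# inside `DomainAdversarialLoss`: the domain discriminator learns to tell source
# from target, while the reversed gradient pushes the feature extractor to make
# the two indistinguishable. A minimal sketch of such a layer; the names
# `GradReverse`, `grad_reverse`, and `coeff` are illustrative, not this repo's API.
import torch as tc


class GradReverse(tc.autograd.Function):
    @staticmethod
    def forward(ctx, x, coeff):
        ctx.coeff = coeff
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # identity on the forward pass, negated (scaled) gradient on the backward pass
        return -ctx.coeff * grad_output, None


def grad_reverse(x, coeff=1.0):
    return GradReverse.apply(x, coeff)
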
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing,
                                                random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False,
                                                resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)

    # define loss function
    if args.adversarial:
        thetas = [Theta(dim).to(device) for dim in (classifier.features_dim, num_classes)]
    else:
        thetas = None
    jmmd_loss = JointMultipleKernelMaximumMeanDiscrepancy(
        kernels=(
            [GaussianKernel(alpha=2 ** k) for k in range(-3, 2)],
            (GaussianKernel(sigma=0.92, track_running_stats=False),)
        ),
        linear=args.linear, thetas=thetas
    ).to(device)

    parameters = classifier.get_parameters()
    if thetas is not None:
        parameters += [{"params": theta.parameters(), 'lr': 0.1} for theta in thetas]

    # define optimizer and lr scheduler
    optimizer = SGD(parameters, args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_scheduler = LambdaLR(optimizer,
                            lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analyze the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer,
                                          classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure of distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, jmmd_loss, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
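
# A minimal sketch of what `JointMultipleKernelMaximumMeanDiscrepancy` above computes:
# unlike MK-MMD, JMMD matches the *joint* distribution over several layers (here the
# bottleneck features and the classifier softmax) by taking the elementwise product of
# per-layer kernel matrices before the MMD estimate. Names (`joint_mmd`, `_gauss`) and
# the fixed per-layer bandwidths are illustrative, not the library's API.
import torch


def _gauss(x, y, sigma):
    return torch.exp(-torch.cdist(x, y) ** 2 / (2 * sigma ** 2))


def joint_mmd(layers_s, layers_t, sigmas):
    # layers_s / layers_t: per-layer activations, one (batch, dim) tensor per layer;
    # sigmas: one kernel bandwidth per layer (same length as the layer lists)
    n = layers_s[0].shape[0]
    all_feats = [torch.cat([s, t], dim=0) for s, t in zip(layers_s, layers_t)]
    # Hadamard product of the per-layer Gram matrices
    kernel = torch.ones(2 * n, 2 * n, device=all_feats[0].device)
    for feat, sigma in zip(all_feats, sigmas):
        kernel = kernel * _gauss(feat, feat, sigma)
    k_ss, k_tt, k_st = kernel[:n, :n], kernel[n:, n:], kernel[:n, n:]
    return k_ss.mean() + k_tt.mean() - 2 * k_st.mean()
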
def main(args: argparse.Namespace):
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        ResizeImage(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    val_transform = transforms.Compose([
        ResizeImage(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root, task=args.source, download=True, transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_dataset = dataset(root=args.root, task=args.target, download=True, transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_dataset = dataset(root=args.root, task=args.target, download=True, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    if args.data == 'DomainNet':
        test_dataset = dataset(root=args.root, task=args.target, evaluate=True, download=True,
                               transform=val_transform)
        test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                                 num_workers=args.workers)
    else:
        test_loader = val_loader

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](pretrained=True)
    num_classes = train_source_dataset.num_classes
    classifier = ImageClassifier(backbone, num_classes).to(device)

    # define loss function
    if args.adversarial:
        thetas = [Theta(dim).to(device) for dim in (classifier.features_dim, num_classes)]
    else:
        thetas = None
    jmmd_loss = JointMultipleKernelMaximumMeanDiscrepancy(
        kernels=(
            [GaussianKernel(alpha=2 ** k) for k in range(-3, 2)],
            (GaussianKernel(sigma=0.92, track_running_stats=False),)
        ),
        linear=args.linear, thetas=thetas
    ).to(device)

    parameters = classifier.get_parameters()
    if thetas is not None:
        parameters += [{"params": theta.parameters(), 'lr_mult': 0.1} for theta in thetas]

    # define optimizer and lr scheduler
    optimizer = SGD(parameters, args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_scheduler = StepwiseLR(optimizer, init_lr=args.lr, gamma=0.0003, decay_rate=0.75)

    # start training
    best_acc1 = 0.
    best_model = copy.deepcopy(classifier.state_dict())  # guard against an empty training loop
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, jmmd_loss, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, classifier, args)

        # remember best acc@1 and save checkpoint
        if acc1 > best_acc1:
            best_model = copy.deepcopy(classifier.state_dict())
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(best_model)
    acc1 = validate(test_loader, classifier, args)
    print("test_acc1 = {:3.1f}".format(acc1))
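
# Both the `LambdaLR` lambdas in the earlier scripts and this `StepwiseLR` implement
# the same inverse-decay annealing schedule: lr(i) = init_lr * (1 + gamma * i) ** (-decay_rate),
# with i counting optimizer steps. A minimal sketch of such a scheduler; the class name
# and the per-group 'lr_mult' handling are assumptions for illustration, not tllib's
# exact `StepwiseLR` implementation.
class InverseDecayLR:
    def __init__(self, optimizer, init_lr, gamma=0.0003, decay_rate=0.75):
        self.optimizer = optimizer
        self.init_lr = init_lr
        self.gamma = gamma
        self.decay_rate = decay_rate
        self.iter_num = 0

    def step(self):
        # anneal the base learning rate, then apply any per-group multiplier
        lr = self.init_lr * (1 + self.gamma * self.iter_num) ** (-self.decay_rate)
        for group in self.optimizer.param_groups:
            group['lr'] = lr * group.get('lr_mult', 1.)
        self.iter_num += 1
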