def main(args):
    """Adapt a source-trained GTA classifier to the target domain with WDGRL.

    Alternates between (a) training a critic to estimate the Wasserstein
    distance between source and target feature distributions (with gradient
    penalty) and (b) training the feature extractor / classifier on source
    cross-entropy plus the weighted distance estimate.  The adapted model is
    checkpointed to ``out_file`` every epoch.
    """
    # Pick checkpoint paths and split the network into a feature extractor
    # plus label head ("discriminator" here is the label classifier, not a
    # domain discriminator) for the selected architecture.
    if args.model == 'gta':
        model_file = './trained_models/gta_source.pt'
        out_file = './trained_models/gta_wdgrl.pt'
        out_ftrs = 4375  # flattened feature width fed to the critic
        clf_model = GTANet().to(device)
        clf_model.load_state_dict(torch.load(model_file))
        feature_extractor = clf_model.feature_extractor
        discriminator = clf_model.classifier
    elif args.model == 'gta-res':
        model_file = './trained_models/gta_res_source.pt'
        out_file = './trained_models/gta_res_wdgrl.pt'
        clf_model = GTARes18Net(9, pretrained=False).to(device)
        out_ftrs = clf_model.fc.in_features
        clf_model.load_state_dict(torch.load(model_file))
        feature_extractor = clf_model.feature_extractor
        discriminator = clf_model.fc
    elif args.model == 'gta-vgg':
        model_file = './trained_models/gta_vgg_source.pt'
        out_file = './trained_models/gta_vgg_wdgrl.pt'
        clf_model = GTAVGG11Net(9, pretrained=False).to(device)
        out_ftrs = clf_model.classifier[0].in_features  # should be 512 * 7 * 7
        clf_model.load_state_dict(torch.load(model_file))
        # NOTE(review): unlike the 'gta'/'gta-res' branches this freezes the
        # whole model.  The loop below re-enables only feature_extractor, so
        # the VGG classifier head is never updated by clf_optim even though
        # clf_loss is computed from it — confirm this asymmetry is intended.
        set_requires_grad(clf_model, False)
        feature_extractor = clf_model.feature_extractor
        discriminator = clf_model.classifier
    else:
        raise ValueError(f'Unknown model type {args.model}')
    # WGAN-style critic: flattened features -> scalar score.
    critic = nn.Sequential(
        nn.Linear(out_ftrs, 64),
        nn.ReLU(),
        nn.Linear(64, 16),
        nn.ReLU(),
        nn.Linear(16, 1),
    ).to(device)
    # Each domain contributes half of the effective batch.
    half_batch = args.batch_size // 2
    target_dataset = ImageFolder('./data', transform=Compose([
        Resize((398, 224)),
        RandomCrop(224),
        RandomHorizontalFlip(),
        ToTensor(),
        # ImageNet normalisation stats.
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]))
    target_loader = DataLoader(target_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)
    source_dataset = ImageFolder('./t_data', transform=Compose([
        RandomCrop(224, pad_if_needed=True, padding_mode='reflect'),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]))
    source_loader = DataLoader(source_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)
    critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-4)
    clf_optim = torch.optim.Adam(clf_model.parameters(), lr=1e-4)
    clf_criterion = nn.CrossEntropyLoss()
    for epoch in range(1, args.epochs + 1):
        # Endless paired stream of (source, target) batches.
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))
        total_loss = 0
        for _ in trange(args.iterations, leave=False):
            (source_x, source_y), (target_x, _) = next(batch_iterator)
            # Train critic
            set_requires_grad(feature_extractor, requires_grad=False)
            set_requires_grad(critic, requires_grad=True)
            source_x, target_x = source_x.to(device), target_x.to(device)
            source_y = source_y.to(device)
            # Features are held fixed while the critic trains.
            with torch.no_grad():
                h_s = feature_extractor(source_x).data.view(
                    source_x.shape[0], -1)
                h_t = feature_extractor(target_x).data.view(
                    target_x.shape[0], -1)
            for _ in range(args.k_critic):
                gp = gradient_penalty(critic, h_s, h_t)
                critic_s = critic(h_s)
                critic_t = critic(h_t)
                # Critic maximises the score gap, so minimise its negation.
                wasserstein_distance = critic_s.mean() - critic_t.mean()
                critic_cost = -wasserstein_distance + args.gamma * gp
                critic_optim.zero_grad()
                critic_cost.backward()
                critic_optim.step()
                total_loss += critic_cost.item()
            # Train classifier
            set_requires_grad(feature_extractor, requires_grad=True)
            set_requires_grad(critic, requires_grad=False)
            for _ in range(args.k_clf):
                source_features = feature_extractor(source_x).view(
                    source_x.shape[0], -1)
                target_features = feature_extractor(target_x).view(
                    target_x.shape[0], -1)
                source_preds = discriminator(source_features)
                clf_loss = clf_criterion(source_preds, source_y)
                wasserstein_distance = critic(source_features).mean() - critic(
                    target_features).mean()
                # Source CE loss plus weighted domain distance.
                loss = clf_loss + args.wd_clf * wasserstein_distance
                clf_optim.zero_grad()
                loss.backward()
                clf_optim.step()
        mean_loss = total_loss / (args.iterations * args.k_critic)
        tqdm.write(f'EPOCH {epoch:03d}: critic_loss={mean_loss:.4f}')
        # Checkpoint every epoch (overwrites the same file).
        torch.save(clf_model.state_dict(), out_file)
def main(args):
    """Adapt an MNIST-trained classifier to MNIST-M with WDGRL.

    Loads the source checkpoint from ``args.MODEL_FILE``, then alternates
    between training a critic (Wasserstein distance estimate with gradient
    penalty) and training the feature extractor / classifier on source
    cross-entropy plus the weighted distance.  Checkpoints every epoch.
    """
    clf_model = Net().to(device)
    clf_model.load_state_dict(torch.load(args.MODEL_FILE))
    # Split into feature extractor and label head.
    feature_extractor = clf_model.feature_extractor
    discriminator = clf_model.classifier
    # WGAN-style critic over flattened 320-d features -> scalar score.
    critic = nn.Sequential(nn.Linear(320, 50), nn.ReLU(), nn.Linear(50, 20),
                           nn.ReLU(), nn.Linear(20, 1)).to(device)
    # Each domain contributes half of the effective batch; drop_last keeps
    # source/target batch sizes equal for the critic.
    half_batch = args.batch_size // 2
    source_dataset = MNIST(config.DATA_DIR / 'mnist', train=True,
                           download=True,
                           transform=Compose([GrayscaleToRgb(), ToTensor()]))
    source_loader = DataLoader(source_dataset, batch_size=half_batch,
                               drop_last=True, shuffle=True, num_workers=0,
                               pin_memory=True)
    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset, batch_size=half_batch,
                               drop_last=True, shuffle=True, num_workers=0,
                               pin_memory=True)
    critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-4)
    clf_optim = torch.optim.Adam(clf_model.parameters(), lr=1e-4)
    clf_criterion = nn.CrossEntropyLoss()
    for epoch in range(1, args.epochs + 1):
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))
        total_loss = 0
        total_accuracy = 0  # NOTE(review): never accumulated below — unused
        for _ in trange(args.iterations, leave=False):
            (source_x, source_y), (target_x, _) = next(batch_iterator)
            # Train critic
            set_requires_grad(feature_extractor, requires_grad=False)
            set_requires_grad(critic, requires_grad=True)
            source_x, target_x = source_x.to(device), target_x.to(device)
            source_y = source_y.to(device)
            # Features stay fixed while the critic trains.
            with torch.no_grad():
                h_s = feature_extractor(source_x).data.view(
                    source_x.shape[0], -1)
                h_t = feature_extractor(target_x).data.view(
                    target_x.shape[0], -1)
            for _ in range(args.k_critic):
                gp = gradient_penalty(critic, h_s, h_t)
                critic_s = critic(h_s)
                critic_t = critic(h_t)
                wasserstein_distance = critic_s.mean() - critic_t.mean()
                critic_cost = -wasserstein_distance + args.gamma * gp
                critic_optim.zero_grad()
                critic_cost.backward()
                critic_optim.step()
                total_loss += critic_cost.item()
            # Train classifier
            set_requires_grad(feature_extractor, requires_grad=True)
            set_requires_grad(critic, requires_grad=False)
            for _ in range(args.k_clf):
                source_features = feature_extractor(source_x).view(
                    source_x.shape[0], -1)
                target_features = feature_extractor(target_x).view(
                    target_x.shape[0], -1)
                source_preds = discriminator(source_features)
                clf_loss = clf_criterion(source_preds, source_y)
                wasserstein_distance = critic(source_features).mean() - critic(
                    target_features).mean()
                # Source CE loss plus weighted domain distance.
                loss = clf_loss + args.wd_clf * wasserstein_distance
                clf_optim.zero_grad()
                loss.backward()
                clf_optim.step()
        mean_loss = total_loss / (args.iterations * args.k_critic)
        tqdm.write(f'EPOCH {epoch:03d}: critic_loss={mean_loss:.4f}')
        # Checkpoint every epoch (overwrites the same file).
        torch.save(clf_model.state_dict(), 'trained_models/wdgrl.pt')
def main(args):
    """Run ADDA independently for ten source splits and ensemble the results.

    Builds ten frozen source models (``clfs``), ten trainable target feature
    extractors initialised from the same checkpoint, and ten domain
    discriminators.  Each unit alternates discriminator training (source=1,
    target=0) with adversarial target-extractor training (flipped labels).
    After every epoch each adapted classifier is evaluated with ``test`` and
    the voting ensemble with ``test_all``; per-epoch voting accuracies are
    dumped to ``<seed>/acc<person>.json``.
    """
    final_accs = []
    # Frozen source models; clfs keeps the full classifiers while
    # source_models is narrowed to their feature extractors.
    source_models = [Net().to(device) for _ in range(10)]
    for idx in range(len(source_models)):
        source_models[idx].load_state_dict(torch.load(args.MODEL_FILE))
        source_models[idx].eval()
        set_requires_grad(source_models[idx], requires_grad=False)
    clfs = [source_model for source_model in source_models]
    source_models = [
        source_model.feature_extractor for source_model in source_models
    ]
    # Target extractors start from the same source weights.
    target_models = [Net().to(device) for _ in range(10)]
    for idx in range(len(target_models)):
        target_models[idx].load_state_dict(torch.load(args.MODEL_FILE))
        target_models[idx] = target_models[idx].feature_extractor
    # Discriminators end in Sigmoid, i.e. they output probabilities in (0,1).
    discriminators = [
        nn.Sequential(nn.Linear(EXTRACTED_FEATURE_DIM, 64), nn.ReLU(),
                      nn.BatchNorm1d(64), nn.Linear(64, 1),
                      nn.Sigmoid()).to(device) for _ in range(10)
    ]
    batch_size = args.batch_size
    discriminator_optims = [
        torch.optim.Adam(discriminators[idx].parameters(), lr=1e-5)
        for idx in range(10)
    ]
    target_optims = [
        torch.optim.Adam(target_models[idx].parameters(), lr=1e-5)
        for idx in range(10)
    ]
    # FIX: the discriminators already apply Sigmoid, so the loss must be
    # BCELoss.  The previous BCEWithLogitsLoss applied a second sigmoid
    # inside the loss (double squashing -> vanishing gradients) and was
    # inconsistent with the `preds >= 0.5` accuracy thresholds below.
    criterion = nn.BCELoss()
    source_loaders = []
    target_loaders = []
    for idx in range(10):
        X_source, y_source = preprocess_train_single(idx)
        source_dataset = torch.utils.data.TensorDataset(X_source, y_source)
        source_loader = DataLoader(source_dataset, batch_size=batch_size,
                                   shuffle=False, num_workers=1,
                                   pin_memory=True)
        source_loaders.append(source_loader)
        X_target, y_target = preprocess_test(args.person)
        target_dataset = torch.utils.data.TensorDataset(X_target, y_target)
        target_loader = DataLoader(target_dataset, batch_size=batch_size,
                                   shuffle=False, num_workers=1,
                                   pin_memory=True)
        target_loaders.append(target_loader)
    # Baseline ensemble accuracy before any adaptation.
    best_voting_acc = test_all(clfs)
    best_tar_accs = [0.0] * 10
    for epoch in range(1, args.epochs + 1):
        # Re-wrap the datasets with shuffling enabled each epoch.
        source_loaders = [
            DataLoader(source_loaders[idx].dataset, batch_size=batch_size,
                       shuffle=True) for idx in range(10)
        ]
        target_loaders = [
            DataLoader(target_loaders[idx].dataset, batch_size=batch_size,
                       shuffle=True) for idx in range(10)
        ]
        for idx in range(10):
            source_loader = source_loaders[idx]
            target_loader = target_loaders[idx]
            batch_iterator = zip(loop_iterable(source_loader),
                                 loop_iterable(target_loader))
            target_model = target_models[idx]
            discriminator = discriminators[idx]
            source_model = source_models[idx]
            clf = clfs[idx]
            total_loss = 0
            adv_loss = 0
            total_accuracy = 0
            second_acc = 0
            for _ in trange(args.iterations, leave=False):
                # Train discriminator
                set_requires_grad(target_model, requires_grad=False)
                set_requires_grad(discriminator, requires_grad=True)
                discriminator.train()
                for _ in range(args.k_disc):
                    (source_x, _), (target_x, _) = next(batch_iterator)
                    source_x, target_x = source_x.to(device), target_x.to(
                        device)
                    source_features = source_model(source_x).view(
                        source_x.shape[0], -1)
                    target_features = target_model(target_x).view(
                        target_x.shape[0], -1)
                    # Source features labelled 1, target labelled 0.
                    discriminator_x = torch.cat(
                        [source_features, target_features])
                    discriminator_y = torch.cat([
                        torch.ones(source_x.shape[0], device=device),
                        torch.zeros(target_x.shape[0], device=device)
                    ])
                    preds = discriminator(discriminator_x).squeeze()
                    loss = criterion(preds, discriminator_y)
                    discriminator_optims[idx].zero_grad()
                    loss.backward()
                    discriminator_optims[idx].step()
                    total_loss += loss.item()
                    total_accuracy += ((preds >= 0.5).long(
                    ) == discriminator_y.long()).float().mean().item()
                # Train classifier (target extractor fools the discriminator)
                set_requires_grad(target_model, requires_grad=True)
                set_requires_grad(discriminator, requires_grad=False)
                target_model.train()
                for _ in range(args.k_clf):
                    _, (target_x, _) = next(batch_iterator)
                    target_x = target_x.to(device)
                    target_features = target_model(target_x).view(
                        target_x.shape[0], -1)
                    # flipped labels
                    discriminator_y = torch.ones(target_x.shape[0],
                                                 device=device)
                    preds = discriminator(target_features).squeeze()
                    second_acc += ((preds >= 0.5).long() == discriminator_y.
                                   long()).float().mean().item()
                    loss = criterion(preds, discriminator_y)
                    adv_loss += loss.item()
                    target_optims[idx].zero_grad()
                    loss.backward()
                    target_optims[idx].step()
            mean_loss = total_loss / (args.iterations * args.k_disc)
            mean_adv_loss = adv_loss / (args.iterations * args.k_clf)
            dis_accuracy = total_accuracy / (args.iterations * args.k_disc)
            sec_acc = second_acc / (args.iterations * args.k_clf)
            # Evaluate this unit with the adapted extractor plugged in; save
            # its best checkpoint so far.
            clf.feature_extractor = target_model
            tar_accuarcy = test(args, clf)
            if tar_accuarcy > best_tar_accs[idx]:
                best_tar_accs[idx] = tar_accuarcy
                torch.save(clf.state_dict(),
                           'trained_models/adda' + str(idx) + '.pt')
            tqdm.write(
                f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, adv_loss = {mean_adv_loss:.4f}, '
                f'discriminator_accuracy={dis_accuracy:.4f}, tar_accuary = {tar_accuarcy:.4f}, best_accuracy = {best_tar_accs[idx]:.4f}, sec_acc = {sec_acc:.4f}'
            )
        # Create the full target model and save it
        clf.feature_extractor = target_model
        #torch.save(clf.state_dict(), 'trained_models/adda.pt')
        acc = test_all(clfs)
        final_accs.append(acc)
        if acc > best_voting_acc:
            best_voting_acc = acc
        print("In epoch %d, voting_acc: %.4f, best_voting_acc: %.4f" %
              (epoch, acc, best_voting_acc))
    jd = {"test_acc": final_accs}
    with open(str(args.seed) + '/acc' + str(args.person) + '.json', 'w') as f:
        json.dump(jd, f)
def main(args):
    """Adapt an MNIST classifier to MNIST-M with ADDA.

    Keeps a frozen source feature extractor, trains a domain discriminator
    to separate source from target features, and trains a target feature
    extractor (initialised from the same checkpoint) with flipped labels to
    fool it.  The source label head is reused on top of the adapted
    extractor, and the full adapted classifier is saved at the end.
    """
    source_model = Net().to(device)
    source_model.load_state_dict(torch.load(args.MODEL_FILE))
    source_model.eval()
    set_requires_grad(source_model, requires_grad=False)
    clf = source_model
    # Frozen source feature extractor.
    source_model = source_model.feature_extractor
    target_model = Net().to(device)
    target_model.load_state_dict(torch.load(args.MODEL_FILE))
    # Trainable target feature extractor, same initial weights.
    target_model = target_model.feature_extractor
    # Source label head, used below only for monitoring target accuracy.
    target_clf = clf.classifier
    # Discriminator outputs raw logits (no Sigmoid) — consistent with
    # BCEWithLogitsLoss and the `preds > 0` accuracy threshold below.
    discriminator = nn.Sequential(nn.Linear(320, 50), nn.ReLU(),
                                  nn.Linear(50, 20), nn.ReLU(),
                                  nn.Linear(20, 1)).to(device)
    half_batch = args.batch_size // 2
    source_dataset = MNIST(config.DATA_DIR / 'mnist', train=True,
                           download=True,
                           transform=Compose([GrayscaleToRgb(), ToTensor()]))
    source_loader = DataLoader(source_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)
    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)
    discriminator_optim = torch.optim.Adam(discriminator.parameters())
    target_optim = torch.optim.Adam(target_model.parameters())
    criterion = nn.BCEWithLogitsLoss()
    for epoch in range(1, args.epochs + 1):
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))
        total_loss = 0
        total_accuracy = 0
        target_label_accuracy = 0
        for _ in trange(args.iterations, leave=False):
            # Train discriminator
            set_requires_grad(target_model, requires_grad=False)
            set_requires_grad(discriminator, requires_grad=True)
            for _ in range(args.k_disc):
                (source_x, _), (target_x, _) = next(batch_iterator)
                source_x, target_x = source_x.to(device), target_x.to(device)
                source_features = source_model(source_x).view(
                    source_x.shape[0], -1)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)
                # Source features labelled 1, target features labelled 0.
                discriminator_x = torch.cat([source_features, target_features])
                discriminator_y = torch.cat([
                    torch.ones(source_x.shape[0], device=device),
                    torch.zeros(target_x.shape[0], device=device)
                ])
                preds = discriminator(discriminator_x).squeeze()
                loss = criterion(preds, discriminator_y)
                discriminator_optim.zero_grad()
                loss.backward()
                discriminator_optim.step()
                total_loss += loss.item()
                # logit > 0  <=>  sigmoid(logit) > 0.5
                total_accuracy += ((preds > 0).long() == discriminator_y.long()
                                   ).float().mean().item()
            # Train classifier
            set_requires_grad(target_model, requires_grad=True)
            set_requires_grad(discriminator, requires_grad=False)
            for _ in range(args.k_clf):
                _, (target_x, target_labels) = next(batch_iterator)
                target_x = target_x.to(device)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)
                # flipped labels
                discriminator_y = torch.ones(target_x.shape[0], device=device)
                preds = discriminator(target_features).squeeze()
                loss = criterion(preds, discriminator_y)
                target_optim.zero_grad()
                loss.backward()
                target_optim.step()
                # Monitor target label accuracy via the frozen source head.
                target_label_preds = target_clf(target_features)
                target_label_accuracy += (target_label_preds.cpu().max(1)[1] ==
                                          target_labels).float().mean().item()
        mean_loss = total_loss / (args.iterations * args.k_disc)
        mean_accuracy = total_accuracy / (args.iterations * args.k_disc)
        target_mean_accuracy = target_label_accuracy / (args.iterations *
                                                        args.k_clf)
        tqdm.write(
            f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, '
            f'discriminator_accuracy={mean_accuracy:.4f}, target_accuracy={target_mean_accuracy:.4f}'
        )
    # Create the full target model and save it
    clf.feature_extractor = target_model
    torch.save(clf.state_dict(), 'trained_models/adda.pt')
def main(args):
    """Adapt a source-trained GTA classifier with ADDA.

    A frozen copy of the source network provides source features; a second
    copy (``model_2``) is trained so that its features fool a domain
    discriminator (flipped labels).  After training, the adapted trunk is
    copied back into the classifier and saved to ``out_file``.
    """
    # Build the frozen source model, the trainable target copy, and the
    # feature-extraction callables for the chosen architecture.
    if args.model == 'gta':
        model_file = './trained_models/gta_source.pt'
        out_file = './trained_models/gta_adda.pt'
        out_ftrs = 4375  # flattened feature width fed to the discriminator
        model = GTANet().to(device)
        model.load_state_dict(torch.load(model_file))
        model.eval()
        set_requires_grad(model, False)
        source_model = model.feature_extractor
        clf = model
        model_2 = GTANet().to(device)
        model_2.load_state_dict(torch.load(model_file))
        target_model = model_2.feature_extractor
    elif args.model == 'gta-res':
        model_file = './trained_models/gta_res_source.pt'
        out_file = './trained_models/gta_res_adda.pt'
        model = GTARes18Net(9, pretrained=False).to(device)
        out_ftrs = model.fc.in_features
        model.load_state_dict(torch.load(model_file))
        model.eval()
        set_requires_grad(model, False)
        source_model = model.feature_extractor
        clf = model
        model_2 = GTARes18Net(9, pretrained=False).to(device)
        model_2.load_state_dict(torch.load(model_file))
        target_model = model_2.feature_extractor
    elif args.model == 'gta-vgg':
        model_file = './trained_models/gta_vgg_source.pt'
        out_file = './trained_models/gta_vgg_adda.pt'
        model = GTAVGG11Net(9, pretrained=False).to(device)
        out_ftrs = model.classifier[0].in_features  # should be 512 * 7 * 7
        model.load_state_dict(torch.load(model_file))
        model.eval()
        set_requires_grad(model, False)

        # The VGG wrapper exposes no single feature_extractor module, so
        # wrap conv trunk + avgpool + flatten in closures instead.
        def source_model(x):
            x = model.features(x)
            x = model.avgpool(x)
            x = torch.flatten(x, 1)
            return x

        clf = model
        model_2 = GTAVGG11Net(9, pretrained=False).to(device)
        model_2.load_state_dict(torch.load(model_file))

        def target_model(x):
            x = model_2.features(x)
            x = model_2.avgpool(x)
            x = torch.flatten(x, 1)
            return x
    else:
        raise ValueError(f'Unknown model type {args.model}')
    # Discriminator outputs raw logits (no Sigmoid) — consistent with
    # BCEWithLogitsLoss and the `preds > 0` accuracy threshold below.
    discriminator = nn.Sequential(
        nn.Linear(out_ftrs, 64),
        nn.ReLU(),
        nn.Linear(64, 16),
        nn.ReLU(),
        nn.Linear(16, 1),
    ).to(device)
    # Each domain contributes half of the effective batch.
    half_batch = args.batch_size // 2
    target_dataset = ImageFolder('./data', transform=Compose([
        Resize((398, 224)),
        RandomCrop(224),
        RandomHorizontalFlip(),
        ToTensor(),
        # ImageNet normalisation stats.
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]))
    target_loader = DataLoader(target_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)
    source_dataset = ImageFolder('./t_data', transform=Compose([
        RandomCrop(224, pad_if_needed=True, padding_mode='reflect'),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]))
    source_loader = DataLoader(source_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)
    discriminator_optim = torch.optim.Adam(discriminator.parameters())
    target_optim = torch.optim.Adam(model_2.parameters())
    criterion = nn.BCEWithLogitsLoss()
    for epoch in range(1, args.epochs + 1):
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))
        total_loss = 0
        total_accuracy = 0
        for _ in trange(args.iterations, leave=False):
            # Train discriminator
            set_requires_grad(model_2, requires_grad=False)
            set_requires_grad(discriminator, requires_grad=True)
            for _ in range(args.k_disc):
                (source_x, _), (target_x, _) = next(batch_iterator)
                source_x, target_x = source_x.to(device), target_x.to(device)
                source_features = source_model(source_x).view(
                    source_x.shape[0], -1)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)
                # Source features labelled 1, target features labelled 0.
                discriminator_x = torch.cat([source_features, target_features])
                discriminator_y = torch.cat([
                    torch.ones(source_x.shape[0], device=device),
                    torch.zeros(target_x.shape[0], device=device)
                ])
                preds = discriminator(discriminator_x).squeeze()
                loss = criterion(preds, discriminator_y)
                discriminator_optim.zero_grad()
                loss.backward()
                discriminator_optim.step()
                total_loss += loss.item()
                total_accuracy += ((preds > 0).long() == discriminator_y.long()
                                   ).float().mean().item()
            # Train classifier (target extractor fools the discriminator)
            set_requires_grad(model_2, requires_grad=True)
            set_requires_grad(discriminator, requires_grad=False)
            for _ in range(args.k_clf):
                _, (target_x, _) = next(batch_iterator)
                target_x = target_x.to(device)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)
                # flipped labels
                discriminator_y = torch.ones(target_x.shape[0], device=device)
                preds = discriminator(target_features).squeeze()
                loss = criterion(preds, discriminator_y)
                target_optim.zero_grad()
                loss.backward()
                target_optim.step()
        mean_loss = total_loss / (args.iterations * args.k_disc)
        mean_accuracy = total_accuracy / (args.iterations * args.k_disc)
        tqdm.write(f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, '
                   f'discriminator_accuracy={mean_accuracy:.4f}')
    # Create the full target model and save it
    if args.model == 'gta':
        clf.feature_extractor = target_model
    elif args.model == 'gta-res':
        clf.conv1 = model_2.conv1
        clf.bn1 = model_2.bn1
        clf.relu = model_2.relu
        clf.maxpool = model_2.maxpool
        clf.layer1 = model_2.layer1
        clf.layer2 = model_2.layer2
        clf.layer3 = model_2.layer3
        clf.layer4 = model_2.layer4
        clf.avgpool = model_2.avgpool
    elif args.model == 'gta-vgg':
        # FIX: this branch was missing, so for 'gta-vgg' the checkpoint
        # contained only the frozen source weights and the adaptation was
        # silently discarded.  Copy the adapted trunk back (mirrors the
        # 'gta-res' branch above).
        clf.features = model_2.features
        clf.avgpool = model_2.avgpool
    torch.save(clf.state_dict(), out_file)
def main(args):
    """Adapt a single source model to one target subject with ADDA.

    A frozen source extractor provides source features; the target extractor
    (same initial weights) is trained adversarially against a sigmoid-output
    domain discriminator.  The target classifier is evaluated with ``test``
    after every epoch, the best checkpoint is saved, and the per-epoch
    accuracy curve is dumped to ``<seed>/acc<person>.json``.
    """
    source_model = Net().to(device)
    source_model.load_state_dict(torch.load(args.MODEL_FILE))
    source_model.eval()
    set_requires_grad(source_model, requires_grad=False)
    clf = source_model
    # Frozen source feature extractor.
    source_model = source_model.feature_extractor
    target_model = Net().to(device)
    target_model.load_state_dict(torch.load(args.MODEL_FILE))
    # Trainable target feature extractor, same initial weights.
    target_model = target_model.feature_extractor
    # Frozen source label head (used for the optional class loss below).
    classifier = clf.classifier
    # Discriminator ends in Sigmoid, i.e. it outputs probabilities.
    discriminator = nn.Sequential(nn.Linear(EXTRACTED_FEATURE_DIM, 64),
                                  nn.ReLU(), nn.BatchNorm1d(64),
                                  nn.Linear(64, 1), nn.Sigmoid()).to(device)
    #half_batch = args.batch_size // 2
    batch_size = args.batch_size
    # X_source, y_source = preprocess_train()
    X_source, y_source = preprocess_train_single(1)
    source_dataset = torch.utils.data.TensorDataset(X_source, y_source)
    source_loader = DataLoader(source_dataset, batch_size=batch_size,
                               shuffle=False, num_workers=1, pin_memory=True)
    X_target, y_target = preprocess_test(args.person)
    target_dataset = torch.utils.data.TensorDataset(X_target, y_target)
    target_loader = DataLoader(target_dataset, batch_size=batch_size,
                               shuffle=False, num_workers=1, pin_memory=True)
    discriminator_optim = torch.optim.Adam(discriminator.parameters())
    target_optim = torch.optim.Adam(target_model.parameters(), lr=3e-6)
    # FIX: the discriminator already applies Sigmoid, so the adversarial
    # loss must be BCELoss.  The previous BCEWithLogitsLoss applied a second
    # sigmoid inside the loss (double squashing -> vanishing gradients) and
    # was inconsistent with the `preds >= 0.5` accuracy thresholds below.
    criterion = nn.BCELoss()
    criterion_class = nn.CrossEntropyLoss()
    # Baseline target accuracy before adaptation.
    best_tar_acc = test(args, clf)
    final_accs = []
    for epoch in range(1, args.epochs + 1):
        # Re-wrap the datasets with shuffling enabled each epoch.
        source_loader = DataLoader(source_loader.dataset,
                                   batch_size=batch_size, shuffle=True)
        target_loader = DataLoader(target_loader.dataset,
                                   batch_size=batch_size, shuffle=True)
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))
        total_loss = 0
        adv_loss = 0
        total_accuracy = 0
        second_acc = 0
        total_class_loss = 0
        for _ in trange(args.iterations, leave=False):
            # Train discriminator
            set_requires_grad(target_model, requires_grad=False)
            set_requires_grad(discriminator, requires_grad=True)
            discriminator.train()
            for _ in range(args.k_disc):
                (source_x, source_y), (target_x, _) = next(batch_iterator)
                source_y = source_y.to(device).view(-1)
                source_x, target_x = source_x.to(device), target_x.to(device)
                source_features = source_model(source_x).view(
                    source_x.shape[0], -1)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)
                # Source features labelled 1, target features labelled 0.
                discriminator_x = torch.cat([source_features, target_features])
                discriminator_y = torch.cat([
                    torch.ones(source_x.shape[0], device=device),
                    torch.zeros(target_x.shape[0], device=device)
                ])
                preds = discriminator(discriminator_x).squeeze()
                loss = criterion(preds, discriminator_y)
                discriminator_optim.zero_grad()
                loss.backward()
                discriminator_optim.step()
                total_loss += loss.item()
                total_accuracy += ((preds >= 0.5).long() == discriminator_y.
                                   long()).float().mean().item()
            # Train feature extractor
            set_requires_grad(target_model, requires_grad=True)
            set_requires_grad(discriminator, requires_grad=False)
            target_model.train()
            for _ in range(args.k_clf):
                _, (target_x, _) = next(batch_iterator)
                target_x = target_x.to(device)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)
                # Source batch re-encoded with the *target* extractor so the
                # (currently disabled) class loss tracks label consistency.
                source_features = target_model(source_x).view(
                    source_x.shape[0], -1)
                source_pred = classifier(source_features)  # (batch_size, 4)
                # flipped labels
                discriminator_y = torch.ones(target_x.shape[0], device=device)
                preds = discriminator(target_features).squeeze()
                second_acc += ((preds >= 0.5).long() == discriminator_y.long()
                               ).float().mean().item()
                loss_adv = criterion(preds, discriminator_y)
                adv_loss += loss_adv.item()
                loss_class = criterion_class(source_pred, source_y)
                total_class_loss += loss_class.item()
                # Class loss is monitored but intentionally kept out of the
                # optimised objective.
                loss = loss_adv  #+ 0.001*loss_class
                target_optim.zero_grad()
                loss.backward()
                target_optim.step()
        mean_loss = total_loss / (args.iterations * args.k_disc)
        mean_adv_loss = adv_loss / (args.iterations * args.k_clf)
        total_class_loss = total_class_loss / (args.iterations * args.k_clf)
        dis_accuracy = total_accuracy / (args.iterations * args.k_disc)
        sec_acc = second_acc / (args.iterations * args.k_clf)
        # Evaluate with the adapted extractor plugged in; keep the best.
        clf.feature_extractor = target_model
        tar_accuarcy = test(args, clf)
        final_accs.append(tar_accuarcy)
        if tar_accuarcy > best_tar_acc:
            best_tar_acc = tar_accuarcy
            torch.save(clf.state_dict(), 'trained_models/adda.pt')
        tqdm.write(
            f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, adv_loss = {mean_adv_loss:.4f}, '
            f'discriminator_accuracy={dis_accuracy:.4f}, tar_accuary = {tar_accuarcy:.4f}, best_accuracy = {best_tar_acc:.4f}, '
            f'sec_acc = {sec_acc:.4f}, total_class_loss: {total_class_loss:.4f}'
        )
    # Create the full target model and save it
    clf.feature_extractor = target_model
    #torch.save(clf.state_dict(), 'trained_models/adda.pt')
    jd = {"test_acc": final_accs}
    with open(str(args.seed) + '/acc' + str(args.person) + '.json', 'w') as f:
        json.dump(jd, f)
def main(args):
    """Run WDGRL adaptation for one of several source/target settings.

    Picks the dataset pair from ``args.adapt_setting`` (mnist2mnistm,
    svhn2mnist, or mnist2usps), then alternates critic training (Wasserstein
    estimate with gradient penalty) with feature-extractor/classifier
    training.  Per-epoch metrics go to stdout and to a log file under
    ``logs/``; a checkpoint is written every epoch.
    """
    clf_model = Net().to(device)
    clf_model.load_state_dict(torch.load(args.MODEL_FILE))
    feature_extractor = clf_model.feature_extractor
    # Label head of the classifier (not a domain discriminator).
    discriminator = clf_model.classifier
    # WGAN-style critic over flattened 320-d features -> scalar score.
    critic = nn.Sequential(nn.Linear(320, 50), nn.ReLU(), nn.Linear(50, 20),
                           nn.ReLU(), nn.Linear(20, 1)).to(device)
    half_batch = args.batch_size // 2
    if args.adapt_setting == 'mnist2mnistm':
        source_dataset = MNIST(config.DATA_DIR / 'mnist', train=True,
                               download=True,
                               transform=Compose(
                                   [GrayscaleToRgb(), ToTensor()]))
        target_dataset = MNISTM(train=False)
    elif args.adapt_setting == 'svhn2mnist':
        source_dataset = ImageClassdata(txt_file=args.src_list,
                                        root_dir=args.src_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.Resize(28),
                                            transforms.ToTensor(),
                                        ]))
        target_dataset = ImageClassdata(txt_file=args.tar_list,
                                        root_dir=args.tar_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.ToTensor(),
                                        ]))
    elif args.adapt_setting == 'mnist2usps':
        source_dataset = ImageClassdata(txt_file=args.src_list,
                                        root_dir=args.src_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.ToTensor(),
                                        ]))
        target_dataset = ImageClassdata(txt_file=args.tar_list,
                                        root_dir=args.tar_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.Resize(28),
                                            transforms.ToTensor(),
                                        ]))
    else:
        raise NotImplementedError
    # drop_last keeps source/target batch sizes equal for the critic.
    source_loader = DataLoader(source_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True,
                               drop_last=True)
    target_loader = DataLoader(target_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True,
                               drop_last=True)
    critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-4)
    clf_optim = torch.optim.Adam(clf_model.parameters(), lr=1e-4)
    clf_criterion = nn.CrossEntropyLoss()
    if not os.path.exists('logs'):
        os.makedirs('logs')
    # FIX: use a context manager so the log file is closed even if training
    # raises, and terminate each record with a newline (previously every
    # epoch's record ran together on a single line).
    with open(f'logs/{args.adapt_setting}_{args.name}.txt', 'w+') as f:
        for epoch in range(1, args.epochs + 1):
            batch_iterator = zip(loop_iterable(source_loader),
                                 loop_iterable(target_loader))
            total_loss = 0
            total_accuracy = 0
            target_label_accuracy = 0
            for _ in trange(args.iterations, leave=False):
                (source_x, source_y), (target_x,
                                       target_y) = next(batch_iterator)
                # Train critic
                set_requires_grad(feature_extractor, requires_grad=False)
                set_requires_grad(critic, requires_grad=True)
                source_x, target_x = source_x.to(device), target_x.to(device)
                source_y = source_y.to(device)
                # Features stay fixed while the critic trains.
                with torch.no_grad():
                    h_s = feature_extractor(source_x).data.view(
                        source_x.shape[0], -1)
                    h_t = feature_extractor(target_x).data.view(
                        target_x.shape[0], -1)
                for _ in range(args.k_critic):
                    gp = gradient_penalty(critic, h_s, h_t)
                    critic_s = critic(h_s)
                    critic_t = critic(h_t)
                    # Critic maximises the score gap -> minimise negation.
                    wasserstein_distance = critic_s.mean() - critic_t.mean()
                    critic_cost = -wasserstein_distance + args.gamma * gp
                    critic_optim.zero_grad()
                    critic_cost.backward()
                    critic_optim.step()
                    total_loss += critic_cost.item()
                # Train classifier
                set_requires_grad(feature_extractor, requires_grad=True)
                set_requires_grad(critic, requires_grad=False)
                for _ in range(args.k_clf):
                    source_features = feature_extractor(source_x).view(
                        source_x.shape[0], -1)
                    target_features = feature_extractor(target_x).view(
                        target_x.shape[0], -1)
                    source_preds = discriminator(source_features)
                    clf_loss = clf_criterion(source_preds, source_y)
                    wasserstein_distance = critic(
                        source_features).mean() - critic(
                            target_features).mean()
                    # Source CE loss plus weighted domain distance.
                    loss = clf_loss + args.wd_clf * wasserstein_distance
                    clf_optim.zero_grad()
                    loss.backward()
                    clf_optim.step()
                    # Monitor target label accuracy with the current head.
                    target_preds = discriminator(target_features)
                    target_label_accuracy += (target_preds.cpu().max(1)[1] ==
                                              target_y).float().mean().item()
            mean_loss = total_loss / (args.iterations * args.k_critic)
            # mean_accuracy = total_accuracy / (args.iterations * args.k_disc)
            target_mean_accuracy = target_label_accuracy / (args.iterations *
                                                            args.k_clf)
            tqdm.write(
                f'EPOCH {epoch:03d}: critic_loss={mean_loss:.4f}, target_accuracy={target_mean_accuracy:.4f}'
            )
            f.write(f'EPOCH {epoch:03d}: critic_loss={mean_loss:.4f}, '
                    f'target_accuracy={target_mean_accuracy:.4f}\n')
            torch.save(
                clf_model.state_dict(),
                f'trained_models/{args.adapt_setting}_{args.name}_ep{epoch}.pt'
            )