def create_dataloaders(batch_size):
    dataset = MNIST(config.DATA_DIR / 'mnist', train=True, download=True,
                    transform=Compose([GrayscaleToRgb(), ToTensor()]))
    shuffled_indices = np.random.permutation(len(dataset))
    train_idx = shuffled_indices[:int(0.8 * len(dataset))]
    val_idx = shuffled_indices[int(0.8 * len(dataset)):]

    train_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True,
                              sampler=SubsetRandomSampler(train_idx),
                              num_workers=1, pin_memory=True)
    val_loader = DataLoader(dataset, batch_size=batch_size, drop_last=False,
                            sampler=SubsetRandomSampler(val_idx),
                            num_workers=1, pin_memory=True)
    return train_loader, val_loader
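
# The scripts below assume a `Net` whose `feature_extractor` produces
# 320-dimensional features (matching the nn.Linear(320, 50) discriminators)
# and whose `classifier` maps those features to class logits. A minimal sketch
# consistent with those shapes for 3x28x28 inputs -- the exact architecture is
# an assumption, not confirmed by this code:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # conv(5) -> pool(2) -> conv(5) -> pool(2) on 28x28 gives 20x4x4 = 320.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 10, kernel_size=5), nn.MaxPool2d(2), nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=5), nn.MaxPool2d(2), nn.ReLU(),
        )
        self.classifier = nn.Sequential(
            nn.Linear(320, 50), nn.ReLU(), nn.Dropout(),
            nn.Linear(50, 10),
        )

    def forward(self, x):
        features = self.feature_extractor(x).view(x.shape[0], -1)
        return self.classifier(features)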
def main(args):
    model = Net().to(device)
    model.load_state_dict(torch.load(args.MODEL_FILE))
    feature_extractor = model.feature_extractor
    clf = model.classifier

    # Domain discriminator behind a gradient-reversal layer (RevGrad/DANN).
    discriminator = nn.Sequential(
        GradientReversal(),
        nn.Linear(320, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 1)
    ).to(device)

    half_batch = args.batch_size // 2
    source_dataset = MNIST(config.DATA_DIR / 'mnist', train=True, download=True,
                           transform=Compose([GrayscaleToRgb(), ToTensor()]))
    source_loader = DataLoader(source_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)

    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)

    optim = torch.optim.Adam(list(discriminator.parameters()) + list(model.parameters()))

    for epoch in range(1, args.epochs + 1):
        batches = zip(source_loader, target_loader)
        n_batches = min(len(source_loader), len(target_loader))

        total_domain_loss = total_label_accuracy = 0
        target_label_accuracy = 0
        for (source_x, source_labels), (target_x, target_labels) in tqdm(batches, leave=False, total=n_batches):
            x = torch.cat([source_x, target_x])
            x = x.to(device)
            domain_y = torch.cat([torch.ones(source_x.shape[0]),
                                  torch.zeros(target_x.shape[0])])
            domain_y = domain_y.to(device)
            label_y = source_labels.to(device)

            features = feature_extractor(x).view(x.shape[0], -1)
            domain_preds = discriminator(features).squeeze()
            label_preds = clf(features[:source_x.shape[0]])

            domain_loss = F.binary_cross_entropy_with_logits(domain_preds, domain_y)
            label_loss = F.cross_entropy(label_preds, label_y)
            loss = domain_loss + label_loss

            optim.zero_grad()
            loss.backward()
            optim.step()

            total_domain_loss += domain_loss.item()
            total_label_accuracy += (label_preds.max(1)[1] == label_y).float().mean().item()
            # Monitor target accuracy without building a new autograd graph.
            with torch.no_grad():
                target_label_preds = clf(features[source_x.shape[0]:])
            target_label_accuracy += (target_label_preds.cpu().max(1)[1] == target_labels).float().mean().item()

        mean_loss = total_domain_loss / n_batches
        mean_accuracy = total_label_accuracy / n_batches
        target_mean_accuracy = target_label_accuracy / n_batches
        tqdm.write(f'EPOCH {epoch:03d}: domain_loss={mean_loss:.4f}, '
                   f'source_accuracy={mean_accuracy:.4f}, target_accuracy={target_mean_accuracy:.4f}')

        torch.save(model.state_dict(), 'trained_models/revgrad.pt')
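
# `GradientReversal` (used above) is assumed but not defined in this file. The
# standard RevGrad layer is the identity on the forward pass and negates (and
# optionally scales) the gradient on the backward pass; a minimal sketch, with
# the `lambda_` scale factor being an assumption:
class GradientReversalFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the gradient flowing back into the feature extractor.
        return -ctx.lambda_ * grad_output, None

class GradientReversal(nn.Module):
    def __init__(self, lambda_=1.0):
        super().__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.lambda_)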
def main(args):
    clf_model = Net().to(device)
    clf_model.load_state_dict(torch.load(args.MODEL_FILE))
    feature_extractor = clf_model.feature_extractor
    discriminator = clf_model.classifier

    # WDGRL critic: scores features so its mean difference estimates the
    # Wasserstein distance between source and target feature distributions.
    critic = nn.Sequential(
        nn.Linear(320, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 1)
    ).to(device)

    half_batch = args.batch_size // 2
    source_dataset = MNIST(config.DATA_DIR / 'mnist', train=True, download=True,
                           transform=Compose([GrayscaleToRgb(), ToTensor()]))
    source_loader = DataLoader(source_dataset, batch_size=half_batch, drop_last=True,
                               shuffle=True, num_workers=0, pin_memory=True)

    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset, batch_size=half_batch, drop_last=True,
                               shuffle=True, num_workers=0, pin_memory=True)

    critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-4)
    clf_optim = torch.optim.Adam(clf_model.parameters(), lr=1e-4)
    clf_criterion = nn.CrossEntropyLoss()

    for epoch in range(1, args.epochs + 1):
        batch_iterator = zip(loop_iterable(source_loader), loop_iterable(target_loader))

        total_loss = 0
        for _ in trange(args.iterations, leave=False):
            (source_x, source_y), (target_x, _) = next(batch_iterator)

            # Train critic
            set_requires_grad(feature_extractor, requires_grad=False)
            set_requires_grad(critic, requires_grad=True)

            source_x, target_x = source_x.to(device), target_x.to(device)
            source_y = source_y.to(device)

            with torch.no_grad():
                h_s = feature_extractor(source_x).view(source_x.shape[0], -1)
                h_t = feature_extractor(target_x).view(target_x.shape[0], -1)
            for _ in range(args.k_critic):
                gp = gradient_penalty(critic, h_s, h_t)

                critic_s = critic(h_s)
                critic_t = critic(h_t)
                wasserstein_distance = critic_s.mean() - critic_t.mean()

                critic_cost = -wasserstein_distance + args.gamma * gp

                critic_optim.zero_grad()
                critic_cost.backward()
                critic_optim.step()

                total_loss += critic_cost.item()

            # Train classifier
            set_requires_grad(feature_extractor, requires_grad=True)
            set_requires_grad(critic, requires_grad=False)
            for _ in range(args.k_clf):
                source_features = feature_extractor(source_x).view(source_x.shape[0], -1)
                target_features = feature_extractor(target_x).view(target_x.shape[0], -1)

                source_preds = discriminator(source_features)
                clf_loss = clf_criterion(source_preds, source_y)
                wasserstein_distance = critic(source_features).mean() - critic(target_features).mean()

                loss = clf_loss + args.wd_clf * wasserstein_distance
                clf_optim.zero_grad()
                loss.backward()
                clf_optim.step()

        mean_loss = total_loss / (args.iterations * args.k_critic)
        tqdm.write(f'EPOCH {epoch:03d}: critic_loss={mean_loss:.4f}')
        torch.save(clf_model.state_dict(), 'trained_models/wdgrl.pt')
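
# `loop_iterable`, `set_requires_grad`, and `gradient_penalty` are assumed
# helpers, used here and in the ADDA/WDGRL scripts below. Minimal sketches:
# the penalty follows the WGAN-GP recipe (Gulrajani et al.), applied to
# interpolates of source and target features as WDGRL does.
def loop_iterable(iterable):
    # Yield batches forever so next() never exhausts mid-epoch.
    while True:
        yield from iterable

def set_requires_grad(model, requires_grad=True):
    for param in model.parameters():
        param.requires_grad = requires_grad

def gradient_penalty(critic, h_s, h_t):
    # Penalize the critic's gradient norm at random interpolates of h_s, h_t.
    alpha = torch.rand(h_s.size(0), 1, device=h_s.device)
    interpolates = (alpha * h_s + (1 - alpha) * h_t).requires_grad_(True)
    preds = critic(interpolates)
    gradients = torch.autograd.grad(preds, interpolates,
                                    grad_outputs=torch.ones_like(preds),
                                    create_graph=True, retain_graph=True)[0]
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()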
def main(args):
    # ADDA: the source feature extractor stays frozen; only the target
    # feature extractor and the domain discriminator are trained.
    source_model = Net().to(device)
    source_model.load_state_dict(torch.load(args.MODEL_FILE))
    source_model.eval()
    set_requires_grad(source_model, requires_grad=False)

    clf = source_model
    source_model = source_model.feature_extractor

    target_model = Net().to(device)
    target_model.load_state_dict(torch.load(args.MODEL_FILE))
    target_model = target_model.feature_extractor
    target_clf = clf.classifier

    discriminator = nn.Sequential(
        nn.Linear(320, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 1)
    ).to(device)

    half_batch = args.batch_size // 2
    source_dataset = MNIST(config.DATA_DIR / 'mnist', train=True, download=True,
                           transform=Compose([GrayscaleToRgb(), ToTensor()]))
    source_loader = DataLoader(source_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)

    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)

    discriminator_optim = torch.optim.Adam(discriminator.parameters())
    target_optim = torch.optim.Adam(target_model.parameters())
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(1, args.epochs + 1):
        batch_iterator = zip(loop_iterable(source_loader), loop_iterable(target_loader))

        total_loss = 0
        total_accuracy = 0
        target_label_accuracy = 0
        for _ in trange(args.iterations, leave=False):
            # Train discriminator
            set_requires_grad(target_model, requires_grad=False)
            set_requires_grad(discriminator, requires_grad=True)
            for _ in range(args.k_disc):
                (source_x, _), (target_x, _) = next(batch_iterator)
                source_x, target_x = source_x.to(device), target_x.to(device)

                source_features = source_model(source_x).view(source_x.shape[0], -1)
                target_features = target_model(target_x).view(target_x.shape[0], -1)

                discriminator_x = torch.cat([source_features, target_features])
                discriminator_y = torch.cat([torch.ones(source_x.shape[0], device=device),
                                             torch.zeros(target_x.shape[0], device=device)])

                preds = discriminator(discriminator_x).squeeze()
                loss = criterion(preds, discriminator_y)

                discriminator_optim.zero_grad()
                loss.backward()
                discriminator_optim.step()

                total_loss += loss.item()
                total_accuracy += ((preds > 0).long() == discriminator_y.long()).float().mean().item()

            # Train classifier
            set_requires_grad(target_model, requires_grad=True)
            set_requires_grad(discriminator, requires_grad=False)
            for _ in range(args.k_clf):
                _, (target_x, target_labels) = next(batch_iterator)
                target_x = target_x.to(device)
                target_features = target_model(target_x).view(target_x.shape[0], -1)

                # flipped labels: push target features to look like source
                discriminator_y = torch.ones(target_x.shape[0], device=device)

                preds = discriminator(target_features).squeeze()
                loss = criterion(preds, discriminator_y)

                target_optim.zero_grad()
                loss.backward()
                target_optim.step()

                # Monitor target accuracy without building a new autograd graph.
                with torch.no_grad():
                    target_label_preds = target_clf(target_features)
                target_label_accuracy += (target_label_preds.cpu().max(1)[1] == target_labels).float().mean().item()

        mean_loss = total_loss / (args.iterations * args.k_disc)
        mean_accuracy = total_accuracy / (args.iterations * args.k_disc)
        target_mean_accuracy = target_label_accuracy / (args.iterations * args.k_clf)
        tqdm.write(f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, '
                   f'discriminator_accuracy={mean_accuracy:.4f}, target_accuracy={target_mean_accuracy:.4f}')

    # Create the full target model and save it
    clf.feature_extractor = target_model
    torch.save(clf.state_dict(), 'trained_models/adda.pt')
def main(args):
    # TODO: add DTN model
    model = Net().to(device)
    model.load_state_dict(torch.load(args.MODEL_FILE))
    feature_extractor = model.feature_extractor
    clf = model.classifier

    discriminator = nn.Sequential(
        GradientReversal(),
        nn.Linear(320, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 1)
    ).to(device)

    half_batch = args.batch_size // 2
    if args.adapt_setting == 'mnist2mnistm':
        source_dataset = MNIST(config.DATA_DIR / 'mnist', train=True, download=True,
                               transform=Compose([GrayscaleToRgb(), ToTensor()]))
        target_dataset = MNISTM(train=False)
    elif args.adapt_setting == 'svhn2mnist':
        source_dataset = ImageClassdata(txt_file=args.src_list, root_dir=args.src_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.Resize(28),
                                            transforms.ToTensor(),
                                        ]))
        target_dataset = ImageClassdata(txt_file=args.tar_list, root_dir=args.tar_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.ToTensor(),
                                        ]))
    elif args.adapt_setting == 'mnist2usps':
        source_dataset = ImageClassdata(txt_file=args.src_list, root_dir=args.src_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.ToTensor(),
                                        ]))
        target_dataset = ImageClassdata(txt_file=args.tar_list, root_dir=args.tar_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.Resize(28),
                                            transforms.ToTensor(),
                                        ]))
    else:
        raise NotImplementedError
    source_loader = DataLoader(source_dataset, batch_size=half_batch, shuffle=True,
                               num_workers=1, pin_memory=True, drop_last=True)
    target_loader = DataLoader(target_dataset, batch_size=half_batch, shuffle=True,
                               num_workers=1, pin_memory=True, drop_last=True)

    optim = torch.optim.Adam(list(discriminator.parameters()) + list(model.parameters()))

    if not os.path.exists('logs'):
        os.makedirs('logs')
    f = open(f'logs/{args.adapt_setting}_{args.name}.txt', 'w+')

    for epoch in range(1, args.epochs + 1):
        batches = zip(source_loader, target_loader)
        n_batches = min(len(source_loader), len(target_loader))

        total_domain_loss = total_label_accuracy = 0
        target_label_accuracy = 0
        for (source_x, source_labels), (target_x, target_labels) in tqdm(batches, leave=False, total=n_batches):
            x = torch.cat([source_x, target_x])
            x = x.to(device)
            domain_y = torch.cat([torch.ones(source_x.shape[0]),
                                  torch.zeros(target_x.shape[0])])
            domain_y = domain_y.to(device)
            label_y = source_labels.to(device)

            features = feature_extractor(x).view(x.shape[0], -1)
            domain_preds = discriminator(features).squeeze()
            label_preds = clf(features[:source_x.shape[0]])

            domain_loss = F.binary_cross_entropy_with_logits(domain_preds, domain_y)
            label_loss = F.cross_entropy(label_preds, label_y)
            loss = domain_loss + label_loss

            optim.zero_grad()
            loss.backward()
            optim.step()

            total_domain_loss += domain_loss.item()
            total_label_accuracy += (label_preds.max(1)[1] == label_y).float().mean().item()
            # Monitor target accuracy without building a new autograd graph.
            with torch.no_grad():
                target_label_preds = clf(features[source_x.shape[0]:])
            target_label_accuracy += (target_label_preds.cpu().max(1)[1] == target_labels).float().mean().item()

        mean_loss = total_domain_loss / n_batches
        mean_accuracy = total_label_accuracy / n_batches
        target_mean_accuracy = target_label_accuracy / n_batches
        tqdm.write(f'EPOCH {epoch:03d}: domain_loss={mean_loss:.4f}, '
                   f'source_accuracy={mean_accuracy:.4f}, target_accuracy={target_mean_accuracy:.4f}')
        f.write(f'EPOCH {epoch:03d}: domain_loss={mean_loss:.4f}, '
                f'source_accuracy={mean_accuracy:.4f}, target_accuracy={target_mean_accuracy:.4f}\n')
        torch.save(model.state_dict(),
                   f'trained_models/{args.adapt_setting}_{args.name}_ep{epoch}.pt')
    f.close()
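
# `ImageClassdata` (used above and below) is an assumed Dataset reading
# (path, label) pairs from a list file. This sketch guesses the file layout
# ("<relative_path> <label>" per line) and that `img_type` is a PIL convert
# mode such as 'RGB' or 'L' -- treat it as illustrative only.
import os
from PIL import Image
from torch.utils.data import Dataset

class ImageClassdata(Dataset):
    def __init__(self, txt_file, root_dir, img_type='RGB', transform=None):
        with open(txt_file) as fh:
            lines = [line.split() for line in fh if line.strip()]
        self.samples = [(path, int(label)) for path, label in lines]
        self.root_dir = root_dir
        self.img_type = img_type
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(os.path.join(self.root_dir, path)).convert(self.img_type)
        if self.transform:
            img = self.transform(img)
        return img, label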
def main(args):
    clf_model = Net().to(device)
    clf_model.load_state_dict(torch.load(args.MODEL_FILE))
    feature_extractor = clf_model.feature_extractor
    discriminator = clf_model.classifier

    critic = nn.Sequential(
        nn.Linear(320, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 1)
    ).to(device)

    half_batch = args.batch_size // 2
    if args.adapt_setting == 'mnist2mnistm':
        source_dataset = MNIST(config.DATA_DIR / 'mnist', train=True, download=True,
                               transform=Compose([GrayscaleToRgb(), ToTensor()]))
        target_dataset = MNISTM(train=False)
    elif args.adapt_setting == 'svhn2mnist':
        source_dataset = ImageClassdata(txt_file=args.src_list, root_dir=args.src_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.Resize(28),
                                            transforms.ToTensor(),
                                        ]))
        target_dataset = ImageClassdata(txt_file=args.tar_list, root_dir=args.tar_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.ToTensor(),
                                        ]))
    elif args.adapt_setting == 'mnist2usps':
        source_dataset = ImageClassdata(txt_file=args.src_list, root_dir=args.src_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.ToTensor(),
                                        ]))
        target_dataset = ImageClassdata(txt_file=args.tar_list, root_dir=args.tar_root,
                                        img_type=args.img_type,
                                        transform=transforms.Compose([
                                            transforms.Resize(28),
                                            transforms.ToTensor(),
                                        ]))
    else:
        raise NotImplementedError
    source_loader = DataLoader(source_dataset, batch_size=half_batch, shuffle=True,
                               num_workers=1, pin_memory=True, drop_last=True)
    target_loader = DataLoader(target_dataset, batch_size=half_batch, shuffle=True,
                               num_workers=1, pin_memory=True, drop_last=True)

    critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-4)
    clf_optim = torch.optim.Adam(clf_model.parameters(), lr=1e-4)
    clf_criterion = nn.CrossEntropyLoss()

    if not os.path.exists('logs'):
        os.makedirs('logs')
    f = open(f'logs/{args.adapt_setting}_{args.name}.txt', 'w+')

    for epoch in range(1, args.epochs + 1):
        batch_iterator = zip(loop_iterable(source_loader), loop_iterable(target_loader))

        total_loss = 0
        target_label_accuracy = 0
        for _ in trange(args.iterations, leave=False):
            (source_x, source_y), (target_x, target_y) = next(batch_iterator)

            # Train critic
            set_requires_grad(feature_extractor, requires_grad=False)
            set_requires_grad(critic, requires_grad=True)

            source_x, target_x = source_x.to(device), target_x.to(device)
            source_y = source_y.to(device)

            with torch.no_grad():
                h_s = feature_extractor(source_x).view(source_x.shape[0], -1)
                h_t = feature_extractor(target_x).view(target_x.shape[0], -1)
            for _ in range(args.k_critic):
                gp = gradient_penalty(critic, h_s, h_t)

                critic_s = critic(h_s)
                critic_t = critic(h_t)
                wasserstein_distance = critic_s.mean() - critic_t.mean()

                critic_cost = -wasserstein_distance + args.gamma * gp

                critic_optim.zero_grad()
                critic_cost.backward()
                critic_optim.step()

                total_loss += critic_cost.item()

            # Train classifier
            set_requires_grad(feature_extractor, requires_grad=True)
            set_requires_grad(critic, requires_grad=False)
            for _ in range(args.k_clf):
                source_features = feature_extractor(source_x).view(source_x.shape[0], -1)
                target_features = feature_extractor(target_x).view(target_x.shape[0], -1)

                source_preds = discriminator(source_features)
                clf_loss = clf_criterion(source_preds, source_y)
                wasserstein_distance = critic(source_features).mean() - critic(target_features).mean()

                loss = clf_loss + args.wd_clf * wasserstein_distance
                clf_optim.zero_grad()
                loss.backward()
                clf_optim.step()

                # Monitor target accuracy without building a new autograd graph.
                with torch.no_grad():
                    target_preds = discriminator(target_features)
                target_label_accuracy += (target_preds.cpu().max(1)[1] == target_y).float().mean().item()

        mean_loss = total_loss / (args.iterations * args.k_critic)
        target_mean_accuracy = target_label_accuracy / (args.iterations * args.k_clf)
        tqdm.write(f'EPOCH {epoch:03d}: critic_loss={mean_loss:.4f}, '
                   f'target_accuracy={target_mean_accuracy:.4f}')
        f.write(f'EPOCH {epoch:03d}: critic_loss={mean_loss:.4f}, '
                f'target_accuracy={target_mean_accuracy:.4f}\n')
        torch.save(clf_model.state_dict(),
                   f'trained_models/{args.adapt_setting}_{args.name}_ep{epoch}.pt')
    f.close()
def main(args):
    model = BayesNet().to(device)
    model.load_state_dict(torch.load(args.MODEL_FILE))

    # The discriminator operates on the 10-dimensional class logits rather
    # than on intermediate features.
    discriminator = nn.Sequential(
        nn.Linear(10, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 1)
    ).to(device)

    half_batch = args.batch_size // 2
    source_dataset = MNIST(config.DATA_DIR / 'mnist', train=True, download=True,
                           transform=Compose([GrayscaleToRgb(), ToTensor()]))
    source_loader = DataLoader(source_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)

    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)

    optim_D = torch.optim.Adam(discriminator.parameters())
    optim_G = torch.optim.Adam(model.parameters())

    for epoch in range(1, args.epochs + 1):
        batches = zip(source_loader, target_loader)
        n_batches = min(len(source_loader), len(target_loader))

        total_domain_loss = total_label_accuracy = 0
        target_label_accuracy = 0
        for (source_x, source_labels), (target_x, target_labels) in tqdm(batches, leave=False, total=n_batches):
            x = torch.cat([source_x, target_x])
            x = x.to(device)
            domain_y = torch.cat([torch.ones(source_x.shape[0]),
                                  torch.zeros(target_x.shape[0])])
            domain_y = domain_y.to(device)
            label_y = source_labels.to(device)

            # forward
            y_preds, s_preds = model(x)
            logits = reparameterize(y_preds, s_preds)
            domain_feat = logits

            # train discriminator
            domain_preds = discriminator(domain_feat.detach()).squeeze()
            domain_loss = F.binary_cross_entropy_with_logits(domain_preds, domain_y)
            optim_D.zero_grad()
            domain_loss.backward()
            optim_D.step()

            # train generator
            label_preds = logits[:source_x.shape[0]]
            label_loss = F.cross_entropy(label_preds, label_y)
            domain_preds = discriminator(domain_feat[:source_x.shape[0]]).squeeze()
            domain_loss = F.binary_cross_entropy_with_logits(
                domain_preds, torch.zeros(source_x.shape[0]).to(device))
            loss = label_loss + domain_loss
            optim_G.zero_grad()
            loss.backward()
            optim_G.step()

            total_domain_loss += domain_loss.item()
            total_label_accuracy += (label_preds.max(1)[1] == label_y).float().mean().item()
            target_label_preds = logits[source_x.shape[0]:]
            target_label_accuracy += (target_label_preds.cpu().max(1)[1] == target_labels).float().mean().item()

        mean_loss = total_domain_loss / n_batches
        mean_accuracy = total_label_accuracy / n_batches
        target_mean_accuracy = target_label_accuracy / n_batches
        tqdm.write(f'EPOCH {epoch:03d}: domain_loss={mean_loss:.4f}, '
                   f'source_accuracy={mean_accuracy:.4f}, target_accuracy={target_mean_accuracy:.4f}')

        torch.save(model.state_dict(), 'trained_models/auda.pt')
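
# `reparameterize` (used above) is assumed to apply the reparameterization
# trick to the two BayesNet outputs. This sketch treats the first output as
# the logit means and the second as their log-variances -- both the names and
# the exact parameterization are assumptions:
def reparameterize(mu, logvar):
    # Sample logits = mu + sigma * eps with eps ~ N(0, I), keeping gradients
    # flowing through mu and logvar.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std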