def gen_pred(args, model):
    """Collect the raw outputs of ``model`` on the test data of ``args.person``.

    The data returned by ``preprocess_test(args.person)`` is served in
    batches of ``args.batch_size``; the model is switched to eval mode and
    run without gradient tracking.

    Returns:
        A list with one entry per batch, each entry being the model output
        for that batch converted with ``Tensor.tolist()``.
    """
    features, labels = preprocess_test(args.person)
    loader = DataLoader(
        torch.utils.data.TensorDataset(features, labels),
        batch_size=args.batch_size,
        num_workers=1,
        pin_memory=True,
    )

    model.eval()
    batch_outputs = []
    with torch.no_grad():
        for batch_x, batch_y in tqdm(loader, leave=False):
            batch_x = batch_x.to(device)
            # Labels are moved alongside the inputs but are not used here.
            batch_y = batch_y.to(device)
            batch_outputs.append(model(batch_x).tolist())
    return batch_outputs
def main(args):
    """Evaluate the checkpoint ``args.MODEL_FILE`` on the test data and print accuracy.

    A fresh ``Net`` is loaded from ``args.MODEL_FILE``, put in eval mode and
    run over the data returned by ``preprocess_test()`` in batches of
    ``args.batch_size``.

    The printed accuracy is the number of correct predictions divided by the
    number of samples. Averaging per-batch accuracies instead would give a
    smaller final batch the same weight as a full one.

    Raises:
        ZeroDivisionError: if the test data contains no samples.
    """
    X_target, y_target = preprocess_test()
    target_dataset = torch.utils.data.TensorDataset(X_target, y_target)
    target_loader = DataLoader(target_dataset,
                               batch_size=args.batch_size,
                               shuffle=False,
                               num_workers=1,
                               pin_memory=True)

    model = Net().to(device)
    # map_location lets a checkpoint saved on another device be loaded here.
    model.load_state_dict(torch.load(args.MODEL_FILE, map_location=device))
    model.eval()

    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for x, y_true in tqdm(target_loader, leave=False):
            x = x.to(device)
            # Flatten so labels line up one-to-one with the per-sample
            # predictions (assumes one class index per sample -- TODO confirm).
            y_true = y_true.to(device).reshape(-1)
            predicted = model(x).max(1)[1]
            n_correct += (predicted == y_true).sum().item()
            n_seen += y_true.shape[0]

    mean_accuracy = n_correct / n_seen
    print(f'Accuracy on target data: {mean_accuracy:.4f}')
def test(args, model):
    """Return the accuracy of ``model`` on the test data of ``args.person``.

    The model is switched to eval mode (and left in it) and run without
    gradient tracking over the data returned by
    ``preprocess_test(args.person)`` in batches of ``args.batch_size``.

    Accuracy is the number of correct predictions divided by the number of
    samples. Averaging per-batch accuracies instead would give a smaller
    final batch the same weight as a full one and skew the value used for
    model selection.

    Returns:
        Fraction of samples whose highest-scoring class equals the label.

    Raises:
        ZeroDivisionError: if the test data contains no samples.
    """
    X_target, y_target = preprocess_test(args.person)
    target_dataset = torch.utils.data.TensorDataset(X_target, y_target)
    target_loader = DataLoader(target_dataset,
                               batch_size=args.batch_size,
                               num_workers=1,
                               pin_memory=True)

    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for x, y_true in tqdm(target_loader, leave=False):
            x = x.to(device)
            # Flatten so labels line up one-to-one with the per-sample
            # predictions (assumes one class index per sample -- TODO confirm).
            y_true = y_true.to(device).reshape(-1)
            predicted = model(x).max(1)[1]
            n_correct += (predicted == y_true).sum().item()
            n_seen += y_true.shape[0]

    return n_correct / n_seen
def main(args):
    """Adversarially adapt ten copies of the pretrained ``Net`` and track voting accuracy.

    Every copy starts from the checkpoint ``args.MODEL_FILE``. For model
    ``idx`` the source data comes from ``preprocess_train_single(idx)`` and
    the target data from ``preprocess_test(args.person)``. Each epoch, for
    each model, runs ``args.iterations`` rounds of:

    * ``args.k_disc`` discriminator steps (source features labelled 1,
      target features labelled 0), then
    * ``args.k_clf`` target feature-extractor steps against flipped labels.

    After a model's rounds, its classifier gets the adapted feature
    extractor, is scored with ``test``, and is saved to
    ``trained_models/adda<idx>.pt`` whenever it beats its own best score.
    After all ten models, ``test_all(clfs)`` gives the epoch's voting
    accuracy. The per-epoch voting accuracies are written to
    ``<args.seed>/acc<args.person>.json`` once training finishes.

    The discriminator emits raw logits. ``nn.BCEWithLogitsLoss`` applies the
    sigmoid itself, so a trailing ``nn.Sigmoid`` would squash the outputs
    twice. Discriminator accuracy therefore thresholds the logits at 0,
    which is the same decision as probability >= 0.5.
    """
    n_models = 10
    batch_size = args.batch_size
    final_accs = []

    # Read the checkpoint once; load_state_dict copies the values into each
    # model, so the copies stay independent.
    pretrained_state = torch.load(args.MODEL_FILE, map_location=device)

    # Frozen source networks: the full nets double as the classifiers that
    # later receive the adapted feature extractors.
    clfs = []
    source_models = []
    for _ in range(n_models):
        clf = Net().to(device)
        clf.load_state_dict(pretrained_state)
        clf.eval()
        set_requires_grad(clf, requires_grad=False)
        clfs.append(clf)
        source_models.append(clf.feature_extractor)

    # Trainable target feature extractors, initialised from the same weights.
    target_models = []
    for _ in range(n_models):
        target_net = Net().to(device)
        target_net.load_state_dict(pretrained_state)
        target_models.append(target_net.feature_extractor)

    discriminators = [
        nn.Sequential(
            nn.Linear(EXTRACTED_FEATURE_DIM, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Linear(64, 1),  # raw logit; the loss applies the sigmoid
        ).to(device)
        for _ in range(n_models)
    ]
    discriminator_optims = [
        torch.optim.Adam(discriminator.parameters(), lr=1e-5)
        for discriminator in discriminators
    ]
    target_optims = [
        torch.optim.Adam(target_model.parameters(), lr=1e-5)
        for target_model in target_models
    ]
    criterion = nn.BCEWithLogitsLoss()

    # Only the datasets are kept; shuffled loaders are rebuilt every epoch.
    source_datasets = []
    target_datasets = []
    for idx in range(n_models):
        X_source, y_source = preprocess_train_single(idx)
        source_datasets.append(
            torch.utils.data.TensorDataset(X_source, y_source))
        X_target, y_target = preprocess_test(args.person)
        target_datasets.append(
            torch.utils.data.TensorDataset(X_target, y_target))

    best_voting_acc = test_all(clfs)
    best_tar_accs = [0.0] * n_models

    for epoch in range(1, args.epochs + 1):
        for idx in range(n_models):
            source_loader = DataLoader(source_datasets[idx],
                                       batch_size=batch_size,
                                       shuffle=True)
            target_loader = DataLoader(target_datasets[idx],
                                       batch_size=batch_size,
                                       shuffle=True)
            batch_iterator = zip(loop_iterable(source_loader),
                                 loop_iterable(target_loader))

            target_model = target_models[idx]
            discriminator = discriminators[idx]
            source_model = source_models[idx]
            clf = clfs[idx]

            total_loss = 0
            adv_loss = 0
            total_accuracy = 0
            second_acc = 0
            for _ in trange(args.iterations, leave=False):
                # Train discriminator
                set_requires_grad(target_model, requires_grad=False)
                set_requires_grad(discriminator, requires_grad=True)
                discriminator.train()
                for _ in range(args.k_disc):
                    (source_x, _), (target_x, _) = next(batch_iterator)
                    source_x = source_x.to(device)
                    target_x = target_x.to(device)

                    source_features = source_model(source_x).view(
                        source_x.shape[0], -1)
                    target_features = target_model(target_x).view(
                        target_x.shape[0], -1)

                    discriminator_x = torch.cat(
                        [source_features, target_features])
                    discriminator_y = torch.cat([
                        torch.ones(source_x.shape[0], device=device),
                        torch.zeros(target_x.shape[0], device=device)
                    ])

                    logits = discriminator(discriminator_x).squeeze(1)
                    loss = criterion(logits, discriminator_y)

                    discriminator_optims[idx].zero_grad()
                    loss.backward()
                    discriminator_optims[idx].step()

                    total_loss += loss.item()
                    # logit >= 0  <=>  sigmoid(logit) >= 0.5
                    total_accuracy += (
                        (logits >= 0).long() == discriminator_y.long()
                    ).float().mean().item()

                # Train the target feature extractor to fool the discriminator
                set_requires_grad(target_model, requires_grad=True)
                set_requires_grad(discriminator, requires_grad=False)
                target_model.train()
                for _ in range(args.k_clf):
                    _, (target_x, _) = next(batch_iterator)
                    target_x = target_x.to(device)
                    target_features = target_model(target_x).view(
                        target_x.shape[0], -1)

                    # flipped labels
                    discriminator_y = torch.ones(target_x.shape[0],
                                                 device=device)

                    logits = discriminator(target_features).squeeze(1)
                    second_acc += (
                        (logits >= 0).long() == discriminator_y.long()
                    ).float().mean().item()
                    loss = criterion(logits, discriminator_y)
                    adv_loss += loss.item()

                    target_optims[idx].zero_grad()
                    loss.backward()
                    target_optims[idx].step()

            mean_loss = total_loss / (args.iterations * args.k_disc)
            mean_adv_loss = adv_loss / (args.iterations * args.k_clf)
            dis_accuracy = total_accuracy / (args.iterations * args.k_disc)
            sec_acc = second_acc / (args.iterations * args.k_clf)

            # Build the full target model: source classifier head on top of
            # the adapted feature extractor.
            clf.feature_extractor = target_model
            tar_accuarcy = test(args, clf)
            if tar_accuarcy > best_tar_accs[idx]:
                best_tar_accs[idx] = tar_accuarcy
                torch.save(clf.state_dict(),
                           'trained_models/adda' + str(idx) + '.pt')
            tqdm.write(
                f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, adv_loss = {mean_adv_loss:.4f}, '
                f'discriminator_accuracy={dis_accuracy:.4f}, tar_accuary = {tar_accuarcy:.4f}, best_accuracy = {best_tar_accs[idx]:.4f}, sec_acc = {sec_acc:.4f}'
            )

        acc = test_all(clfs)
        final_accs.append(acc)
        if acc > best_voting_acc:
            best_voting_acc = acc
        print("In epoch %d, voting_acc: %.4f, best_voting_acc: %.4f" %
              (epoch, acc, best_voting_acc))

    jd = {"test_acc": final_accs}
    with open(str(args.seed) + '/acc' + str(args.person) + '.json', 'w') as f:
        json.dump(jd, f)
def main(args):
    """Adversarially adapt one pretrained ``Net`` to the data of ``args.person``.

    A frozen source network and a trainable target feature extractor both
    start from the checkpoint ``args.MODEL_FILE``. Source data comes from
    ``preprocess_train_single(1)`` and target data from
    ``preprocess_test(args.person)``. Each epoch runs ``args.iterations``
    rounds of:

    * ``args.k_disc`` discriminator steps (source features labelled 1,
      target features labelled 0), then
    * ``args.k_clf`` target feature-extractor steps against flipped labels.

    After each epoch the classifier gets the adapted feature extractor, is
    scored with ``test``, and is saved to ``trained_models/adda.pt`` when it
    beats the best score so far. The per-epoch accuracies are written to
    ``<args.seed>/acc<args.person>.json`` once training finishes.

    The discriminator emits raw logits. ``nn.BCEWithLogitsLoss`` applies the
    sigmoid itself, so a trailing ``nn.Sigmoid`` would squash the outputs
    twice. Discriminator accuracy therefore thresholds the logits at 0,
    which is the same decision as probability >= 0.5.
    """
    # Read the checkpoint once; load_state_dict copies the values into each
    # model, so the two networks stay independent.
    pretrained_state = torch.load(args.MODEL_FILE, map_location=device)

    # Frozen source network; the full net doubles as the classifier that
    # later receives the adapted feature extractor.
    clf = Net().to(device)
    clf.load_state_dict(pretrained_state)
    clf.eval()
    set_requires_grad(clf, requires_grad=False)
    source_model = clf.feature_extractor

    target_net = Net().to(device)
    target_net.load_state_dict(pretrained_state)
    target_model = target_net.feature_extractor

    classifier = clf.classifier

    discriminator = nn.Sequential(
        nn.Linear(EXTRACTED_FEATURE_DIM, 64),
        nn.ReLU(),
        nn.BatchNorm1d(64),
        nn.Linear(64, 1),  # raw logit; the loss applies the sigmoid
    ).to(device)

    batch_size = args.batch_size

    # Only the datasets are kept; shuffled loaders are rebuilt every epoch.
    X_source, y_source = preprocess_train_single(1)
    source_dataset = torch.utils.data.TensorDataset(X_source, y_source)
    X_target, y_target = preprocess_test(args.person)
    target_dataset = torch.utils.data.TensorDataset(X_target, y_target)

    discriminator_optim = torch.optim.Adam(discriminator.parameters())
    target_optim = torch.optim.Adam(target_model.parameters(), lr=3e-6)

    criterion = nn.BCEWithLogitsLoss()
    criterion_class = nn.CrossEntropyLoss()

    best_tar_acc = test(args, clf)
    final_accs = []

    for epoch in range(1, args.epochs + 1):
        source_loader = DataLoader(source_dataset,
                                   batch_size=batch_size,
                                   shuffle=True)
        target_loader = DataLoader(target_dataset,
                                   batch_size=batch_size,
                                   shuffle=True)
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))

        total_loss = 0
        adv_loss = 0
        total_accuracy = 0
        second_acc = 0
        total_class_loss = 0
        for _ in trange(args.iterations, leave=False):
            # Train discriminator
            set_requires_grad(target_model, requires_grad=False)
            set_requires_grad(discriminator, requires_grad=True)
            discriminator.train()
            for _ in range(args.k_disc):
                (source_x, source_y), (target_x, _) = next(batch_iterator)
                source_y = source_y.to(device).view(-1)
                source_x = source_x.to(device)
                target_x = target_x.to(device)

                source_features = source_model(source_x).view(
                    source_x.shape[0], -1)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)

                discriminator_x = torch.cat(
                    [source_features, target_features])
                discriminator_y = torch.cat([
                    torch.ones(source_x.shape[0], device=device),
                    torch.zeros(target_x.shape[0], device=device)
                ])

                logits = discriminator(discriminator_x).squeeze(1)
                loss = criterion(logits, discriminator_y)

                discriminator_optim.zero_grad()
                loss.backward()
                discriminator_optim.step()

                total_loss += loss.item()
                # logit >= 0  <=>  sigmoid(logit) >= 0.5
                total_accuracy += (
                    (logits >= 0).long() == discriminator_y.long()
                ).float().mean().item()

            # Train feature extractor
            set_requires_grad(target_model, requires_grad=True)
            set_requires_grad(discriminator, requires_grad=False)
            target_model.train()
            for _ in range(args.k_clf):
                _, (target_x, _) = next(batch_iterator)
                target_x = target_x.to(device)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)

                # Classification loss on the last source batch seen by the
                # discriminator step. It is only accumulated for logging;
                # the optimised loss below is the adversarial term alone.
                source_features = target_model(source_x).view(
                    source_x.shape[0], -1)
                source_pred = classifier(source_features)

                # flipped labels
                discriminator_y = torch.ones(target_x.shape[0],
                                             device=device)

                logits = discriminator(target_features).squeeze(1)
                second_acc += (
                    (logits >= 0).long() == discriminator_y.long()
                ).float().mean().item()
                loss_adv = criterion(logits, discriminator_y)
                adv_loss += loss_adv.item()

                loss_class = criterion_class(source_pred, source_y)
                total_class_loss += loss_class.item()

                loss = loss_adv

                target_optim.zero_grad()
                loss.backward()
                target_optim.step()

        mean_loss = total_loss / (args.iterations * args.k_disc)
        mean_adv_loss = adv_loss / (args.iterations * args.k_clf)
        total_class_loss = total_class_loss / (args.iterations * args.k_clf)
        dis_accuracy = total_accuracy / (args.iterations * args.k_disc)
        sec_acc = second_acc / (args.iterations * args.k_clf)

        # Build the full target model: source classifier head on top of the
        # adapted feature extractor.
        clf.feature_extractor = target_model
        tar_accuarcy = test(args, clf)
        final_accs.append(tar_accuarcy)
        if tar_accuarcy > best_tar_acc:
            best_tar_acc = tar_accuarcy
            torch.save(clf.state_dict(), 'trained_models/adda.pt')
        tqdm.write(
            f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, adv_loss = {mean_adv_loss:.4f}, '
            f'discriminator_accuracy={dis_accuracy:.4f}, tar_accuary = {tar_accuarcy:.4f}, best_accuracy = {best_tar_acc:.4f}, '
            f'sec_acc = {sec_acc:.4f}, total_class_loss: {total_class_loss:.4f}'
        )

    jd = {"test_acc": final_accs}
    with open(str(args.seed) + '/acc' + str(args.person) + '.json', 'w') as f:
        json.dump(jd, f)