import os
import time

import torch
import torch.nn.functional as F
import torch.optim as optim

# Helpers such as one_hot_embedding, gp_loss, define_models, initialize,
# infer_iteration, sample, corrupt, critic_loss, transfer_loss, plot_transfer
# and save_models are project-local and defined elsewhere in the repository.


def evaluate_cluster(visualiser, i, nc, loader, classifier, id, device):
    # Collect predictions over the whole loader.
    labels = []
    preds = []
    for data, label in loader:
        data, label = data.to(device), label.to(device)
        pred = F.softmax(classifier(data), 1)
        labels += [label]
        preds += [pred]
    labels = torch.cat(labels)
    preds = torch.cat(preds).argmax(1)

    # For each cluster, count the majority label as correct and record that
    # label as the cluster-to-class mapping.
    correct = 0
    total = 0
    cluster_map = []
    for j in range(nc):
        label = labels[preds == j]
        if len(label):
            l = one_hot_embedding(label, nc).sum(0)
            correct += l.max()
            cluster_map.append(l.argmax().item())
        total += len(label)
    accuracy = correct / total
    accuracy = accuracy.cpu().numpy()
    visualiser.plot(accuracy, title=f'Transfer clustering accuracy {id}', step=i)
    return torch.LongTensor(cluster_map).to(device)
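
# one_hot_embedding is imported from elsewhere in the repository. The sketch
# below shows the behaviour this module assumes (integer labels mapped to
# one-hot rows); it is illustrative only, hence the hypothetical _sketch name,
# and is not the repository's actual implementation.
def one_hot_embedding_sketch(labels, n_classes):
    labels = torch.as_tensor(labels)
    # Index an identity matrix built on the labels' device so that CUDA
    # label tensors remain valid indices.
    return torch.eye(n_classes, device=labels.device)[labels]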
def contrastive_loss(x, n_classes, encoder, contrastive, device):
    enc = encoder(x)
    z = torch.randint(n_classes, size=(enc.shape[0],))
    z = one_hot_embedding(z, n_classes).to(device)
    cz = contrastive(z).mean()
    cenc = contrastive(enc).mean()
    gp = gp_loss(enc, z, contrastive, device)
    return cz, cenc, gp
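
# gp_loss is also project-local. Given how it is called above (encoder
# outputs vs. one-hot prior samples, plus a critic), a plausible reading is
# the standard WGAN-GP gradient penalty of Gulrajani et al. (2017). The
# sketch below implements that standard penalty under this assumption; the
# repository's exact version may differ.
def gp_loss_sketch(real, fake, critic, device):
    # Random interpolation points between real and fake samples.
    alpha = torch.rand(real.shape[0], 1, device=device)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    out = critic(interp)
    grad, = torch.autograd.grad(out, interp,
                                grad_outputs=torch.ones_like(out),
                                create_graph=True)
    # Penalise deviation of the critic's gradient norm from 1.
    return ((grad.norm(2, dim=1) - 1) ** 2).mean()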
def compute_loss(x, xp, encoder, contrastive, device):
    z = encoder(x)
    zp = encoder(xp)
    # ztrue is drawn but not used below; kept from the original implementation.
    ztrue = torch.randint(z.shape[1], size=(z.shape[0],))
    ztrue = one_hot_embedding(ztrue, z.shape[1]).to(device)
    p = contrastive(z)
    closs = p.mean()
    # Consistency between the encodings of the paired inputs.
    dloss = F.mse_loss(zp, z).mean()
    return dloss, closs
def evaluate(visualiser, encoder, nc, data1, target, z_dim, generator, id, device):
    z = torch.randn(data1.shape[0], z_dim, device=device)
    visualiser.image(data1.cpu().numpy(), 'target1', 0)
    visualiser.image(target.cpu().numpy(), 'target2', 0)

    # Generate from the hard (argmax) encoded class of each input.
    enc = encoder(data1).argmax(1)
    enc = one_hot_embedding(enc, nc).to(device)
    X = generator(enc, z)
    visualiser.image(X.cpu().numpy(), f'data{id}', 0)

    # Interleave real and generated images for side-by-side comparison.
    merged = len(X) * 2 * [None]
    merged[:2 * len(data1):2] = data1
    merged[1:2 * len(X):2] = X
    merged = torch.stack(merged)
    visualiser.image(merged.cpu().numpy(), f'Comparison{id}', 0)

    # Vary z while holding the class embedding fixed to show its effect.
    z = torch.stack(nc * [z[:nc - 1]]).transpose(0, 1).reshape(-1, z.shape[1])
    data1 = torch.cat((nc - 1) * [data1[:nc]])
    e1 = encoder(data1).argmax(1)
    e1 = one_hot_embedding(e1, nc).to(device)
    X = generator(e1, z)
    X = torch.cat((data1[:nc], X))
    visualiser.image(X.cpu().numpy(), f'Z effect{id}', 0)
def evaluate_gen_class_accuracy(visualiser, i, loader, nz, nc, encoder, classifier, generator, id, device):
    # Classify samples generated from the encoder's predicted class and
    # compare against the true labels.
    correct = 0
    total = 0
    for data, label in loader:
        data, label = data.to(device), label.to(device)
        z = torch.randn(data.shape[0], nz, device=device)
        l = encoder(data).argmax(1)
        l = one_hot_embedding(l, nc).to(device)
        gen = generator(l, z)
        pred = F.softmax(classifier(gen), 1).argmax(1)
        correct += (pred == label).sum().cpu().float()
        total += len(pred)
    accuracy = correct / total
    accuracy = accuracy.cpu().numpy()
    visualiser.plot(accuracy, title=f'Generated accuracy {id}', step=i)
    return accuracy
def evaluate_accuracy(visualiser, i, loader, classifier, nlabels, id, device):
    labels = []
    preds = []
    for data, label in loader:
        data, label = data.to(device), label.to(device)
        # Softmax is monotone, so raw logits suffice for the argmax below;
        # a single forward pass is enough.
        pred = classifier(data)
        labels += [label]
        preds += [pred]
    labels = torch.cat(labels)
    preds = torch.cat(preds).argmax(1)

    # Majority-label accuracy per predicted class.
    correct = 0
    total = 0
    for j in range(nlabels):
        label = labels[preds == j]
        if len(label):
            correct += one_hot_embedding(label, nlabels).sum(0).max()
        total += len(label)
    accuracy = correct / total
    accuracy = accuracy.cpu().numpy()
    visualiser.plot(accuracy, title=f'Classifier accuracy {id}', step=i)
    return accuracy
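
# sample() and corrupt() are project-local helpers used by train() below.
# The sketches here reflect only how they are called -- sample() appears to
# cycle a loader's iterator, corrupt() to randomise labels with some
# probability -- and are assumptions about their behaviour, not the originals.
def sample_sketch(iterator, loader):
    # Return the next batch, restarting the iterator once exhausted.
    try:
        batch = next(iterator)
    except StopIteration:
        iterator = iter(loader)
        batch = next(iterator)
    return batch, iterator


def corrupt_sketch(labels, n_classes, p):
    # With probability p, replace each label by a uniformly random class.
    mask = torch.rand(labels.shape[0]) < p
    noise = torch.randint(n_classes, size=(labels.shape[0],))
    return torch.where(mask, noise, labels)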
def train(args):
    parameters = vars(args)
    valid_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)
    generator = models['generator'].to(args.device)
    critic = models['critic'].to(args.device)
    # Named eval_model to avoid shadowing the eval builtin.
    eval_model = args.evaluation.eval().to(args.device)
    print(generator)
    print(critic)

    optim_critic = optim.Adam(critic.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    iter2 = iter(train_loader2)
    titer, titer2 = iter(test_loader1), iter(test_loader2)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()

    for i in range(iteration, args.iterations):
        generator.train()
        critic.train()

        # Critic updates: backward(mone) flips the sign of neg_loss, and the
        # gradient penalty is weighted by 10.
        for _ in range(args.d_updates):
            batch, iter2 = sample(iter2, train_loader2)
            data = batch[0].to(args.device)
            label = corrupt(batch[1], args.nc, args.corrupt_tgt)
            label = one_hot_embedding(label, args.nc).to(args.device)
            optim_critic.zero_grad()
            pos_loss, neg_loss, gp = critic_loss(data, label, args.z_dim, critic, generator, args.device)
            pos_loss.backward()
            neg_loss.backward(mone)
            (10 * gp).backward()
            optim_critic.step()

        # Generator update.
        optim_generator.zero_grad()
        t_loss = transfer_loss(data.shape[0], label, args.z_dim, critic, generator, args.device)
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            print('Iter: %s' % i, time.time() - t0)
            generator.eval()
            batch, titer = sample(titer, test_loader1)
            data1 = batch[0].to(args.device)
            label = one_hot_embedding(batch[1], args.nc).to(args.device)
            batch, titer2 = sample(titer2, test_loader2)
            data2 = batch[0].to(args.device)
            plot_transfer(args.visualiser, label, args.nc, data1, data2, args.nz, generator, args.device, i)

            save_path = args.save_path
            # This loader-based evaluate is the accuracy evaluation from the
            # evaluation module, distinct from the visualisation helper above.
            eval_accuracy = evaluate(valid_loader1, args.nz, args.nc, args.corrupt_src, generator, eval_model, args.device)
            test_accuracy = evaluate(test_loader1, args.nz, args.nc, args.corrupt_src, generator, eval_model, args.device)
            with open(os.path.join(save_path, 'critic_loss'), 'a') as f:
                f.write(f'{i},{(pos_loss - neg_loss).cpu().item()}\n')
            with open(os.path.join(save_path, 'tloss'), 'a') as f:
                f.write(f'{i},{t_loss.cpu().item()}\n')
            with open(os.path.join(save_path, 'eval_accuracy'), 'a') as f:
                f.write(f'{i},{eval_accuracy}\n')
            with open(os.path.join(save_path, 'test_accuracy'), 'a') as f:
                f.write(f'{i},{test_accuracy}\n')

            args.visualiser.plot((pos_loss - neg_loss).cpu().detach().numpy(), title='critic_loss', step=i)
            args.visualiser.plot(t_loss.cpu().detach().numpy(), title='tloss', step=i)
            args.visualiser.plot(eval_accuracy, title='Validation transfer accuracy', step=i)
            args.visualiser.plot(test_accuracy, title='Test transfer accuracy', step=i)
            t0 = time.time()
            save_models(models, 0, args.model_path, args.checkpoint)
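
# critic_loss() and transfer_loss() are likewise imported from elsewhere.
# Given how train() unpacks (pos, neg, gp) and calls generator(label, z), a
# plausible sketch is a conditional WGAN objective. The conditional critic
# signature critic(x, y) is an assumption, as are the pos/neg roles; these
# hypothetical _sketch versions are for illustration only.
def critic_loss_sketch(data, label, z_dim, critic, generator, device):
    z = torch.randn(data.shape[0], z_dim, device=device)
    fake = generator(label, z).detach()
    pos = critic(fake, label).mean()  # minimised directly in train()
    neg = critic(data, label).mean()  # maximised via backward(mone)
    # Gradient penalty on interpolates between real and generated samples;
    # alpha is shaped to broadcast over any trailing data dimensions.
    alpha = torch.rand(data.shape[0], *([1] * (data.dim() - 1)), device=device)
    interp = (alpha * data + (1 - alpha) * fake).requires_grad_(True)
    out = critic(interp, label)
    grad, = torch.autograd.grad(out, interp,
                                grad_outputs=torch.ones_like(out),
                                create_graph=True)
    gp = ((grad.reshape(grad.shape[0], -1).norm(2, dim=1) - 1) ** 2).mean()
    return pos, neg, gp


def transfer_loss_sketch(n, label, z_dim, critic, generator, device):
    # Generator objective: make generated samples score high under the critic.
    z = torch.randn(n, z_dim, device=device)
    return -critic(generator(label, z), label).mean()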