Пример #1
0
def evaluate_cluster(visualiser, i, nc, loader, classifier, id, device):
    """Measure how well the classifier's predicted clusters align with labels.

    Every sample in ``loader`` is assigned to the cluster given by the argmax
    of the classifier's softmax output. For each cluster ``j`` in
    ``range(nc)`` the largest per-label count among its members is counted
    as correct, and ``correct / total`` is plotted under the title
    ``'Transfer clustering accuracy {id}'`` at step ``i``.

    Args:
        visualiser: object exposing ``plot(value, title=..., step=...)``.
        i: step index passed to the visualiser.
        nc: number of clusters / label classes.
        loader: iterable of ``(data, label)`` batches.
        classifier: callable mapping a data batch to per-class scores.
        id: suffix appended to the plot title.
        device: device that batches and the returned tensor are moved to.

    Returns:
        A ``torch.long`` tensor on ``device`` that holds, for each non-empty
        cluster in increasing ``j`` order, the argmax of its summed one-hot
        labels. This assumes ``one_hot_embedding(label, nc)`` yields one
        one-hot row per label (TODO confirm). Empty clusters are skipped, so
        the result may be shorter than ``nc``.
    """
    labels = []
    preds = []
    # Inference only: without no_grad the autograd graph of every batch
    # would stay alive inside ``preds`` until the loop finishes.
    with torch.no_grad():
        for data, label in loader:
            data, label = data.to(device), label.to(device)
            labels.append(label)
            preds.append(F.softmax(classifier(data), 1))
    labels = torch.cat(labels)
    preds = torch.cat(preds).argmax(1)

    correct = 0
    total = 0
    cluster_map = []
    for j in range(nc):
        members = labels[preds == j]
        if len(members):
            counts = one_hot_embedding(members, nc).sum(0)
            correct += counts.max()
            cluster_map.append(int(counts.argmax()))
        total += len(members)
    accuracy = (correct / total).cpu().numpy()
    visualiser.plot(accuracy,
                    title=f'Transfer clustering accuracy {id}',
                    step=i)
    # NOTE(review): because empty clusters are skipped, position j of the
    # returned tensor is not necessarily cluster j — confirm callers expect it.
    return torch.tensor(cluster_map, dtype=torch.long, device=device)
Пример #2
0
def contrastive_loss(x, n_classes, encoder, contrastive, device):
    """Score encoded inputs and random one-hot codes with ``contrastive``.

    Args:
        x: input batch fed to ``encoder``.
        n_classes: number of classes for the random one-hot codes.
        encoder: callable producing the encoding of ``x``.
        contrastive: callable scoring a batch; its output is mean-reduced.
        device: device the random one-hot codes are moved to.

    Returns:
        Tuple ``(cz, cenc, gp)``: the mean ``contrastive`` score of the random
        one-hot codes, the mean score of ``encoder(x)``, and ``gp_loss``
        evaluated on the (encoding, codes) pair.
    """
    encoded = encoder(x)
    batch_size = encoded.shape[0]

    # One uniformly drawn class per sample, sampled on the CPU, then
    # one-hot encoded and moved to ``device``.
    random_classes = torch.randint(n_classes, size=(batch_size, ))
    prior = one_hot_embedding(random_classes, n_classes).to(device)

    prior_score = contrastive(prior).mean()
    encoded_score = contrastive(encoded).mean()
    penalty = gp_loss(encoded, prior, contrastive, device)
    return prior_score, encoded_score, penalty
Пример #3
0
def compute_loss(x, xp, encoder, contrastive, device):
    """Compute a consistency loss and a contrastive score for two batches.

    Args:
        x: input batch.
        xp: second batch whose encoding is compared with that of ``x``
            (presumably a perturbed view of ``x`` — TODO confirm).
        encoder: callable applied to both ``x`` and ``xp``.
        contrastive: callable scoring encoded batches.
        device: unused; kept so existing callers keep working.

    Returns:
        ``(dloss, closs)``: the mean squared error between ``encoder(xp)``
        and ``encoder(x)``, and the mean of ``contrastive(encoder(x))``.
    """
    z = encoder(x)
    zp = encoder(xp)

    closs = contrastive(z).mean()
    # F.mse_loss already averages over all elements (reduction='mean'),
    # so no extra .mean() is needed.
    dloss = F.mse_loss(zp, z)
    return dloss, closs
Пример #4
0
def evaluate(visualiser, encoder, nc, data1, target, z_dim, generator,
             device, id=''):
    """Log qualitative generation images to ``visualiser`` at step 0.

    Images logged:
      * ``'target1'`` / ``'target2'``: the raw ``data1`` and ``target``.
      * ``'data{id}'``: ``generator`` output conditioned on the one-hot
        argmax of ``encoder(data1)`` plus Gaussian noise of width ``z_dim``.
      * ``'Comparison{id}'``: inputs and generations interleaved row by row.
      * ``'Z effect{id}'``: the first ``nc`` inputs, followed by their
        generations under each of the first ``nc - 1`` noise vectors.

    Args:
        visualiser: object exposing ``image(array, title, step)``.
        encoder: callable returning per-class scores for a batch.
        nc: number of classes.
        data1: input batch; assumed to hold at least ``nc`` samples, and
            generator outputs are assumed to match its per-sample shape.
        target: batch that is only displayed.
        z_dim: width of the noise vectors.
        generator: callable ``generator(one_hot_codes, noise)``.
        device: device for noise and one-hot codes.
        id: suffix appended to image titles. Before this parameter existed
            the titles interpolated the built-in ``id`` function's repr.
    """
    # Everything here is display only. Without no_grad, ``.numpy()`` on
    # generator outputs that require grad raises a RuntimeError.
    with torch.no_grad():
        z = torch.randn(data1.shape[0], z_dim, device=device)
        visualiser.image(data1.cpu().numpy(), 'target1', 0)
        visualiser.image(target.cpu().numpy(), 'target2', 0)
        enc = encoder(data1).argmax(1)
        enc = one_hot_embedding(enc, nc).to(device)
        X = generator(enc, z)
        visualiser.image(X.cpu().numpy(), f'data{id}', 0)

        # Interleave rows: data1[0], X[0], data1[1], X[1], ...
        merged = torch.stack((data1, X), 1).reshape(-1, *data1.shape[1:])
        visualiser.image(merged.cpu().numpy(), f'Comparison{id}', 0)

        # Row k of the grid uses noise vector k // nc and input k % nc: each
        # of the first nc - 1 noise vectors is repeated nc times, and the
        # first nc inputs are tiled nc - 1 times.
        heads = data1[:nc]
        z_grid = z[:nc - 1].repeat_interleave(nc, dim=0)
        tiled = torch.cat((nc - 1) * [heads])
        e1 = encoder(tiled).argmax(1)
        e1 = one_hot_embedding(e1, nc).to(device)
        X = generator(e1, z_grid)
        X = torch.cat((heads, X))
        visualiser.image(X.cpu().numpy(), f'Z effect{id}', 0)
Пример #5
0
def evaluate_gen_class_accuracy(visualiser, i, loader, nz, nc, encoder,
                                classifier, generator, id, device):
    """Plot and return how often ``classifier`` recovers the batch label
    from a generated sample.

    For every batch, the argmax of ``encoder(data)`` is one-hot encoded and
    passed to ``generator`` together with Gaussian noise of width ``nz``.
    A generated sample counts as correct when the classifier's argmax on it
    equals the batch label. ``id`` is accepted but not used.

    Returns:
        The accuracy as a 0-d NumPy array. The same value is plotted under
        the title ``'Generated accuracy'`` at step ``i``.
    """
    n_correct = 0
    n_seen = 0
    for batch, target in loader:
        batch = batch.to(device)
        target = target.to(device)
        noise = torch.randn(batch.shape[0], nz, device=device)
        code = one_hot_embedding(encoder(batch).argmax(1), nc).to(device)
        logits = classifier(generator(code, noise))
        guess = F.softmax(logits, 1).argmax(1)
        n_correct += (guess == target).sum().cpu().float()
        n_seen += len(guess)
    accuracy = (n_correct / n_seen).cpu().numpy()
    visualiser.plot(accuracy, title='Generated accuracy', step=i)
    return accuracy
Пример #6
0
def evaluate_accuracy(visualiser, i, loader, classifier, nlabels, id, device):
    """Compute, plot and return a label-agreement score for ``classifier``.

    Each sample is assigned to the argmax of ``classifier(data)``. For every
    predicted class ``j`` in ``range(nlabels)``, the largest per-label count
    among its members is counted as correct, and the score is
    ``correct / total``.

    NOTE(review): this is cluster purity (majority label per predicted class),
    not ``(preds == labels).mean()``. Confirm this is the intended
    "classifier accuracy".

    Args:
        visualiser: object exposing ``plot(value, title=..., step=...)``.
        i: step index passed to the visualiser.
        loader: iterable of ``(data, label)`` batches.
        classifier: callable mapping a data batch to per-class scores.
        nlabels: number of classes.
        id: suffix appended to the plot title.
        device: device the batches are moved to.

    Returns:
        The score as a 0-d NumPy array.
    """
    labels = []
    preds = []
    # Inference only: avoid keeping every batch's autograd graph alive.
    with torch.no_grad():
        for data, label in loader:
            data, label = data.to(device), label.to(device)
            labels.append(label)
            # One forward pass per batch. Only the argmax is used, so a
            # softmax would not change the prediction.
            preds.append(classifier(data))
    labels = torch.cat(labels)
    preds = torch.cat(preds).argmax(1)

    correct = 0
    total = 0
    for j in range(nlabels):
        members = labels[preds == j]
        if len(members):
            correct += one_hot_embedding(members, nlabels).sum(0).max()
        total += len(members)
    accuracy = (correct / total).cpu().numpy()
    visualiser.plot(accuracy, title=f'Classifier accuracy {id}', step=i)
    return accuracy
Пример #7
0
def _append_metric(save_path, name, step, value):
    """Append a ``step,value`` line to the text file ``save_path/name``."""
    with open(os.path.join(save_path, name), 'a') as f:
        f.write(f'{step},{value}\n')


def train(args):
    """Adversarially train ``generator`` against ``critic`` and log metrics.

    Each iteration runs ``args.d_updates`` critic steps on batches from the
    training loader in ``args.loaders2`` (labels passed through ``corrupt``
    with ``args.corrupt_tgt``), then one generator step on ``transfer_loss``.
    Every ``args.evaluate`` iterations it plots a transfer sample, evaluates
    on the validation and test loaders of ``args.loaders1``, appends metrics
    to files under ``args.save_path``, plots them and checkpoints the models.

    Args:
        args: namespace holding the loaders, models' config, optimiser
            hyper-parameters, device, visualiser and paths used below.
    """
    parameters = vars(args)
    valid_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    critic = models['critic'].to(args.device)
    evaluator = args.evaluation.eval().to(args.device)
    print(generator)
    print(critic)

    optim_critic = optim.Adam(critic.parameters(),
                              lr=args.lr,
                              betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(),
                                 lr=args.lr,
                                 betas=(args.beta1, args.beta2))

    iter2 = iter(train_loader2)
    titer, titer2 = iter(test_loader1), iter(test_loader2)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        critic.train()
        for _ in range(args.d_updates):
            batch, iter2 = sample(iter2, train_loader2)
            data = batch[0].to(args.device)
            label = corrupt(batch[1], args.nc, args.corrupt_tgt)
            label = one_hot_embedding(label, args.nc).to(args.device)
            optim_critic.zero_grad()
            pos_loss, neg_loss, gp = critic_loss(data, label, args.z_dim,
                                                 critic, generator,
                                                 args.device)
            # Minimise pos_loss - neg_loss + 10 * gp. One backward over the sum
            # accumulates the same gradients as three separate calls.
            # Negating neg_loss avoids passing an explicit [1]-shaped
            # gradient, which PyTorch rejects when the loss is 0-d.
            (pos_loss - neg_loss + 10 * gp).backward()
            optim_critic.step()

        optim_generator.zero_grad()
        t_loss = transfer_loss(data.shape[0], label, args.z_dim, critic,
                               generator, args.device)
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            print('Iter: %s' % i, time.time() - t0)
            generator.eval()
            batch, titer = sample(titer, test_loader1)
            data1 = batch[0].to(args.device)
            label = one_hot_embedding(batch[1], args.nc).to(args.device)
            # Keep the loader-2 iterator in ``titer2``. Writing it to
            # ``titer`` would make later "loader 1" samples come from loader 2.
            batch, titer2 = sample(titer2, test_loader2)
            data2 = batch[0].to(args.device)
            plot_transfer(args.visualiser, label, args.nc, data1, data2,
                          args.nz, generator, args.device, i)
            save_path = args.save_path
            eval_accuracy = evaluate(valid_loader1, args.nz, args.nc,
                                     args.corrupt_src, generator, evaluator,
                                     args.device)
            test_accuracy = evaluate(test_loader1, args.nz, args.nc,
                                     args.corrupt_src, generator, evaluator,
                                     args.device)
            critic_value = (pos_loss - neg_loss).detach().cpu()
            t_value = t_loss.detach().cpu()
            _append_metric(save_path, 'critic_loss', i, critic_value.item())
            _append_metric(save_path, 'tloss', i, t_value.item())
            _append_metric(save_path, 'eval_accuracy', i, eval_accuracy)
            # The test file records the test split's own accuracy.
            _append_metric(save_path, 'test_accuracy', i, test_accuracy)
            args.visualiser.plot(critic_value.numpy(),
                                 title='critic_loss',
                                 step=i)
            args.visualiser.plot(t_value.numpy(),
                                 title='tloss',
                                 step=i)
            args.visualiser.plot(eval_accuracy,
                                 title='Validation transfer accuracy',
                                 step=i)
            args.visualiser.plot(test_accuracy,
                                 title='Test transfer accuracy',
                                 step=i)

            t0 = time.time()
            # NOTE(review): the second argument is always 0. If save_models
            # expects the iteration number, this should pass ``i`` — confirm.
            save_models(models, 0, args.model_path, args.checkpoint)