Example 1
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    discriminator = models['discriminator'].to(args.device)
    print(generator)
    print(discriminator)

    optim_discriminator = optim.Adam(discriminator.parameters(),
                                     lr=args.lr,
                                     betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(),
                                 lr=args.lr,
                                 betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        discriminator.train()
        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx[0].to(args.device)
            optim_discriminator.zero_grad()
            d_pos_loss, d_neg_loss, gp = disc_loss_generation(
                data, args.z_dim, discriminator, generator, args.device)
            d_pos_loss.backward()
            d_neg_loss.backward(mone)
            (10 * gp).backward()
            optim_discriminator.step()

        optim_generator.zero_grad()
        t_loss = transfer_loss(args.train_batch_size, args.z_dim,
                               discriminator, generator, args.device)
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            print('Iter: %s' % i, time.time() - t0)
            noise = torch.randn(data.shape[0], args.z_dim, device=args.device)
            evaluate(args.visualiser, noise, data, generator, i)
            d_loss = (d_pos_loss - d_neg_loss).detach().cpu().numpy()
            args.visualiser.plot(step=i,
                                 data=d_loss,
                                 title=f'Discriminator loss')
            args.visualiser.plot(step=i,
                                 data=t_loss.detach().cpu().numpy(),
                                 title=f'Generator loss')
            args.visualiser.plot(step=i,
                                 data=gp.detach().cpu().numpy(),
                                 title=f'Gradient Penalty')
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
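
These training loops share several helpers that are not shown. From the call sites (batchx, iter1 = sample(iter1, train_loader1)), sample evidently draws the next batch and re-primes the iterator once the DataLoader is exhausted. A minimal sketch consistent with that call shape:

def sample(iterator, loader):
    """Return (batch, iterator), restarting the DataLoader when it is exhausted."""
    try:
        batch = next(iterator)
    except StopIteration:
        iterator = iter(loader)  # start a new pass over the dataset
        batch = next(iterator)
    return batch, iterator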
Example 2
def evaluate_fid(args):
    parameters = vars(args)
    _, _, test_loader1 = args.loaders1
    _, _, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, True, args.save_path, args.model_path)

    generatorXY = models['generatorXY'].to(args.device)
    generatorYX = models['generatorYX'].to(args.device)

    datas1 = []
    labels1 = []
    gens1 = []
    for i, (data, label) in enumerate(test_loader1):
        data, label = data.to(args.device), label.to(args.device)
        datas1 += [data]
        labels1 += [label]
        z = torch.randn(len(data), args.z_dim, device=args.device)
        gen = generatorXY(data, z)
        gens1 += [gen]
    datas1 = torch.cat(datas1)
    labels1 = torch.cat(labels1)
    gens1 = torch.cat(gens1)

    datas2 = []
    labels2 = []
    gens2 = []
    for i, (data, label) in enumerate(test_loader2):
        data, label = data.to(args.device), label.to(args.device)
        datas2 += [data]
        labels2 += [label]
        z = torch.randn(len(data), args.z_dim, device=args.device)
        gen = generatorYX(data, z)
        gens2 += [gen]
    datas2 = torch.cat(datas2)
    labels2 = torch.cat(labels2)
    gens2 = torch.cat(gens2)

    #fid = calculate_fid(datas1[:1000], datas1[1000:2000], 50, args.device, 2048)
    #print(f'fid datasetX: {fid}')
    #fid = calculate_fid(datas2[:1000], datas2[1000:2000], 50, args.device, 2048)
    #print(f'fid datasetY: {fid}')
    fid = calculate_fid(datas1, gens2, 50, args.device, 2048)
    save_path = args.save_path
    with open(os.path.join(save_path, 'fid_yx'), 'w') as f:
        f.write(f'{fid}\n')
    print(f'fid Y->X: {fid}')
    fid = calculate_fid(datas2, gens1, 50, args.device, 2048)
    with open(os.path.join(save_path, 'fid_xy'), 'w') as f:
        f.write(f'{fid}\n')
    print(f'fid X->Y: {fid}')
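
calculate_fid is assumed to implement the standard Fréchet Inception Distance, FID = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2(C1 C2)^{1/2}), computed on Inception activations of the two sample sets. A minimal sketch of the final distance, given precomputed activation statistics (the batched Inception forward pass, controlled by the 50 and 2048 arguments above, is omitted):

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop negligible imaginary parts from sqrtm
    return diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)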
Example 3
def train(args):
    parameters = vars(args)
    train_loader, valid_loader, test_loader = args.loader

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    classifier = models['classifier'].to(args.device)
    print(classifier)
    optim_classifier = optim.SGD(classifier.parameters(),
                                 lr=args.lr,
                                 momentum=args.momentum,
                                 nesterov=args.nesterov,
                                 weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optim_classifier, args.iterations)

    it = iter(train_loader)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    t0 = time.time()
    best_accuracy = 0
    for i in range(iteration, args.iterations):
        classifier.train()
        batch, it = sample(it, train_loader)
        data = batch[0].to(args.device)
        classes = batch[1].to(args.device)

        optim_classifier.zero_grad()
        loss = classification_loss(data, classes, classifier)
        loss.backward()
        optim_classifier.step()
        scheduler.step()

        if i % args.evaluate == 0:
            classifier.eval()
            print('Iter: %s' % i, time.time() - t0)
            valid_accuracy = evaluate(args.visualiser, i, valid_loader,
                                      classifier, 'valid', args.device)
            test_accuracy = evaluate(args.visualiser, i, test_loader,
                                     classifier, 'test', args.device)
            loss = loss.detach().cpu().numpy()
            args.visualiser.plot(loss, title=f'loss', step=i)
            if valid_accuracy > best_accuracy:
                best_accuracy = valid_accuracy
                save_models(models, i, args.model_path, args.checkpoint)
            t0 = time.time()
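
classification_loss is not shown; given the call classification_loss(data, classes, classifier), a plausible minimal form is plain cross-entropy on the classifier's logits:

import torch.nn.functional as F

def classification_loss(data, classes, classifier):
    """Cross-entropy between classifier logits and integer class labels."""
    logits = classifier(data)
    return F.cross_entropy(logits, classes)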
Example 4
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    critic1 = models['critic1'].to(args.device)
    critic2 = models['critic2'].to(args.device)
    print(generator)
    print(critic1)

    optim_critic1 = optim.Adam(critic1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_critic2 = optim.Adam(critic2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    iter1, iter2 = iter(train_loader1), iter(train_loader2)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    titer1, titer2 = iter(test_loader1), iter(test_loader2)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        critic1.train()
        critic2.train()

        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            data1 = batchx.to(args.device)
            batchy, iter2 = sample(iter2, train_loader2)
            data2 = batchy.to(args.device)

            optim_critic1.zero_grad()
            r_loss1 = critic_loss(data1, data1, args.alpha, critic1, generator, args.device)
            r_loss1.backward(mone)
            optim_critic1.step()

            optim_critic2.zero_grad()
            r_loss2 = critic_loss(data1, data2, args.alpha, critic2, generator, args.device)
            r_loss2.backward(mone)
            optim_critic2.step()

        optim_generator.zero_grad()
        for _ in range(10):
            t_ = torch.distributions.beta.Beta(args.alpha, args.alpha).sample((1,)).to(args.device)  # sample_n is deprecated; sample((1,)) is equivalent
            t = torch.stack([t_]*data1.shape[0])
            t_loss1 = transfer_loss(data1, data1, t, critic1, generator, args.device)**2
            t_loss2 = transfer_loss(data1, data2, t, critic2, generator, args.device)**2
            ((1-t_)*t_loss1 + t_*t_loss2).backward()

        #t_ = torch.FloatTensor([0]).to(args.device)
        #t = torch.stack([t_]*data1.shape[0])
        #gen = generator(data1, t)
        #reg = F.mse_loss(gen, data1)
        #(10*reg).backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            batchx, titer1 = sample(titer1, test_loader1)
            data1 = batchx.to(args.device)
            batchy, titer2 = sample(titer2, test_loader2)
            data2 = batchy.to(args.device)
            evaluate(args.visualiser, data1, data2, args.z_dim, generator, i, args.device)
            args.visualiser.plot(step=i, data=r_loss1.detach().cpu().numpy(), title=f'Critic loss 1')
            args.visualiser.plot(step=i, data=r_loss2.detach().cpu().numpy(), title=f'Critic loss 2')
            args.visualiser.plot(step=i, data=t_loss1.detach().cpu().numpy(), title=f'Generator loss 1')
            args.visualiser.plot(step=i, data=t_loss2.detach().cpu().numpy(), title=f'Generator loss 2')
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
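
The generator step above draws a single interpolation coefficient t_ ~ Beta(alpha, alpha) and repeats it across the batch, so every sample is transported to the same point t along the path between the two domains. A tiny illustration of the shapes involved (alpha and the batch size are placeholders):

import torch

alpha = 0.5  # placeholder for args.alpha
t_ = torch.distributions.Beta(alpha, alpha).sample((1,))  # shape (1,)
batch = 64
t = torch.stack([t_] * batch)  # shape (batch, 1): one shared t for the whole batch
# t_.expand(batch, 1) produces the same tensor without copying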
Example 5
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    g12 = models['g12'].to(args.device)
    g21 = models['g21'].to(args.device)
    d1 = models['d1'].to(args.device)
    d2 = models['d2'].to(args.device)
    eval_model = args.evaluation.eval().to(args.device)
    print(g12)
    print(g21)
    print(d1)
    print(d2)

    optim_g12 = optim.Adam(g12.parameters(),
                           lr=args.lr,
                           betas=(args.beta1, args.beta2))
    optim_g21 = optim.Adam(g21.parameters(),
                           lr=args.lr,
                           betas=(args.beta1, args.beta2))
    optim_d1 = optim.Adam(d1.parameters(),
                          lr=args.lr,
                          betas=(args.beta1, args.beta2))
    optim_d2 = optim.Adam(d2.parameters(),
                          lr=args.lr,
                          betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iter2 = iter(train_loader2)
    titer1 = iter(test_loader1)
    titer2 = iter(test_loader2)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        g12.train()
        g21.train()
        d1.train()
        d2.train()
        batchx, iter1 = sample(iter1, train_loader1)
        data1 = batchx[0].to(args.device)
        if data1.shape[0] != args.train_batch_size:
            batchx, iter1 = sample(iter1, train_loader1)
            data1 = batchx[0].to(args.device)

        batchy, iter2 = sample(iter2, train_loader2)
        data2 = batchy[0].to(args.device)
        if data2.shape[0] != args.train_batch_size:
            batchy, iter2 = sample(iter2, train_loader2)
            data2 = batchy[0].to(args.device)

        dloss1 = disc_loss(data2, data1, g21, d1)
        optim_d1.zero_grad()
        dloss1.backward()
        optim_d1.step()

        dloss2 = disc_loss(data1, data2, g12, d2)
        optim_d2.zero_grad()
        dloss2.backward()
        optim_d2.step()

        gloss1 = generator_loss(data1, g12, g21, d2)
        optim_g12.zero_grad()
        optim_g21.zero_grad()
        gloss1.backward()
        optim_g12.step()
        optim_g21.step()
        gloss2 = generator_loss(data2, g21, g12, d1)
        optim_g12.zero_grad()
        optim_g21.zero_grad()
        gloss2.backward()
        optim_g12.step()
        optim_g21.step()

        if i % args.evaluate == 0:
            g12.eval()
            g21.eval()
            batchx, titer1 = sample(titer1, test_loader1)
            data1 = batchx[0].to(args.device)
            batchy, titer2 = sample(titer2, test_loader2)
            data2 = batchy[0].to(args.device)
            print('Iter: %s' % i, time.time() - t0)
            evaluate(args.visualiser, data1, g12, 'x')
            evaluate(args.visualiser, data2, g21, 'y')
            evaluate_gen_class_accuracy(args.visualiser, i, test_loader1,
                                        eval_model, g12, args.device)
            evaluate_class_accuracy(args.visualiser, i, test_loader2,
                                    eval_model, args.device)
            args.visualiser.plot(dloss1.cpu().detach().numpy(),
                                 title=f'Discriminator loss1',
                                 step=i)
            args.visualiser.plot(dloss2.cpu().detach().numpy(),
                                 title=f'Discriminator loss2',
                                 step=i)
            args.visualiser.plot(gloss1.cpu().detach().numpy(),
                                 title=f'Generator loss 1-2',
                                 step=i)
            args.visualiser.plot(gloss2.cpu().detach().numpy(),
                                 title=f'Generator loss 2-1',
                                 step=i)
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
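
disc_loss and generator_loss are not shown. Given the call shapes disc_loss(real_src, real_dst, g, d) and generator_loss(x, g_fwd, g_bwd, d_dst), one plausible CycleGAN-style reading (the exact losses and the cycle weight are assumptions) is:

import torch
import torch.nn.functional as F

def disc_loss(real_src, real_dst, g, d):
    """Hypothetical sketch: BCE discriminator loss, real_dst vs g(real_src)."""
    fake = g(real_src).detach()
    real_logits, fake_logits = d(real_dst), d(fake)
    return (F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits)) +
            F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits)))

def generator_loss(x, g_fwd, g_bwd, d_dst, cycle_weight=10.0):
    """Hypothetical sketch: fool d_dst, plus a cycle-consistency term."""
    fake = g_fwd(x)
    logits = d_dst(fake)
    adv = F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits))
    cycle = F.l1_loss(g_bwd(fake), x)
    return adv + cycle_weight * cycle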
Example 6
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    encoder = models['classifier'].to(args.device)
    contrastive = models['contrastive'].to(args.device)
    print(encoder)
    print(contrastive)

    optim_encoder = optim.Adam(encoder.parameters(),
                               lr=args.lr,
                               betas=(args.beta1, args.beta2))
    optim_contrastive = optim.Adam(contrastive.parameters(),
                                   lr=args.lr,
                                   betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        encoder.train()
        contrastive.train()

        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            datax = batchx[0].float().to(args.device)

            optim_contrastive.zero_grad()
            ploss, nloss, gp = contrastive_loss(datax, args.nc, encoder,
                                                contrastive, args.device)
            (ploss - nloss + gp).backward()
            optim_contrastive.step()

        optim_encoder.zero_grad()

        batchx, iter1 = sample(iter1, train_loader1)
        datax = batchx[0].float().to(args.device)
        dataxp = batchx[1].float().to(args.device)

        dloss, closs = compute_loss(datax, dataxp, encoder, contrastive,
                                    args.device)
        (args.ld * dloss + closs).backward()
        optim_encoder.step()

        if i % args.evaluate == 0:
            encoder.eval()
            contrastive.eval()
            print('Iter: {}'.format(i), end=': ')
            evaluate(args.visualiser, datax, dataxp, 'x')
            _acc = evaluate_accuracy(args.visualiser, i, test_loader1, encoder,
                                     args.nc, 'x', args.device)
            print('disc loss: {}'.format(
                (ploss - nloss).detach().cpu().numpy()),
                  end='\t')
            print('gp: {}'.format(gp.detach().cpu().numpy()), end='\t')
            print('positive dist loss: {}'.format(
                dloss.detach().cpu().numpy()),
                  end='\t')
            print('contrast. loss: {}'.format(closs.detach().cpu().numpy()),
                  end='\t')
            print('Accuracy: {}'.format(_acc))

            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
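
contrastive_loss returns a gradient penalty gp alongside the positive/negative terms, and the same gp pattern appears in several of the other examples. The standard WGAN-GP penalty on interpolated inputs, as a self-contained sketch (the function name is mine):

import torch

def gradient_penalty(critic, real, fake, device):
    """WGAN-GP: penalize deviation of the critic's input-gradient norm from 1."""
    eps = torch.rand(real.shape[0], *([1] * (real.dim() - 1)), device=device)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(interp)
    grads, = torch.autograd.grad(scores.sum(), interp, create_graph=True)
    return ((grads.flatten(1).norm(2, dim=1) - 1) ** 2).mean()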
Example 7
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    ssx = args.ssx.to(args.device)
    ssx.eval()

    zxs, labelsx = get_initial_zx(train_loader1, ssx, args.device)
    zys, labelsy = get_initial_zx(train_loader2, ssx, args.device)

    sc = SpectralClustering(args.nc, affinity='sigmoid', gamma=1.7)
    clusters = sc.fit_predict(zxs.cpu().numpy())
    clusters = torch.from_numpy(clusters).to(args.device)

    classifier = models['classifier'].to(args.device)
    discriminator = models['discriminator'].to(args.device)
    classifier.apply(he_init)
    discriminator.apply(he_init)
    print(classifier)
    print(discriminator)

    optim_discriminator = optim.Adam(discriminator.parameters(),
                                     lr=args.lr,
                                     betas=(args.beta1, args.beta2))
    optim_classifier = optim.Adam(classifier.parameters(),
                                  lr=args.lr,
                                  betas=(args.beta1, args.beta2))
    optims = {
        'optim_discriminator': optim_discriminator,
        'optim_classifier': optim_classifier
    }

    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        classifier.train()
        discriminator.train()

        perm = torch.randperm(len(zxs))
        ix = perm[:args.train_batch_size]
        zx = zxs[ix]
        perm = torch.randperm(len(zys))
        iy = perm[:args.train_batch_size]
        zy = zys[iy]

        optim_discriminator.zero_grad()
        d_loss = disc_loss(zx, zy, discriminator, classifier.x, classifier.mlp,
                           args.device)
        d_loss.backward()
        optim_discriminator.step()

        perm = torch.randperm(len(zxs))
        ix = perm[:args.train_batch_size]
        zx = zxs[ix]
        label = clusters[ix].long()
        perm = torch.randperm(len(zys))
        iy = perm[:args.train_batch_size]
        zy = zys[iy]

        optim_classifier.zero_grad()
        c_loss = classification_loss(zx, label, classifier)
        tcw_loss = classification_target_loss(zy, classifier)
        dw_loss = embed_div_loss(zx, zy, discriminator, classifier.x,
                                 classifier.mlp, args.device)
        m_loss1 = mixup_loss(zx, classifier, args.device)
        m_loss2 = mixup_loss(zy, classifier, args.device)
        (args.cw * c_loss).backward()
        (args.tcw * tcw_loss).backward()
        (args.dw * dw_loss).backward()
        (args.smw * m_loss1).backward()
        (args.tmw * m_loss2).backward()
        optim_classifier.step()

        if i % args.evaluate == 0:
            print('Iter: %s' % i, time.time() - t0)
            classifier.eval()

            class_map = evaluate_cluster(args.visualiser, i, args.nc, zxs,
                                         labelsx, classifier, f'x',
                                         args.device)
            evaluate_cluster_accuracy(args.visualiser, i, zxs, labelsx,
                                      class_map, classifier, f'x', args.device)
            evaluate_cluster_accuracy(args.visualiser, i, zys, labelsy,
                                      class_map, classifier, f'y', args.device)

            save_path = args.save_path
            with open(os.path.join(save_path, 'c_loss'), 'a') as f:
                f.write(f'{i},{c_loss.cpu().item()}\n')
            with open(os.path.join(save_path, 'tcw_loss'), 'a') as f:
                f.write(f'{i},{tcw_loss.cpu().item()}\n')
            with open(os.path.join(save_path, 'dw_loss'), 'a') as f:
                f.write(f'{i},{dw_loss.cpu().item()}\n')
            with open(os.path.join(save_path, 'm_loss1'), 'a') as f:
                f.write(f'{i},{m_loss1.cpu().item()}\n')
            with open(os.path.join(save_path, 'm_loss2'), 'a') as f:
                f.write(f'{i},{m_loss2.cpu().item()}\n')
            with open(os.path.join(save_path, 'd_loss2'), 'a') as f:
                f.write(f'{i},{d_loss.cpu().item()}\n')
            args.visualiser.plot(c_loss.cpu().detach().numpy(),
                                 title='Source classifier loss',
                                 step=i)
            args.visualiser.plot(tcw_loss.cpu().detach().numpy(),
                                 title='Target classifier cross entropy',
                                 step=i)
            args.visualiser.plot(dw_loss.cpu().detach().numpy(),
                                 title='Classifier marginal divergence',
                                 step=i)
            args.visualiser.plot(m_loss1.cpu().detach().numpy(),
                                 title='Source mix up loss',
                                 step=i)
            args.visualiser.plot(m_loss2.cpu().detach().numpy(),
                                 title='Target mix up loss',
                                 step=i)
            args.visualiser.plot(d_loss.cpu().detach().numpy(),
                                 title='Discriminator loss',
                                 step=i)
            t0 = time.time()
            save_models(models, i, args.model_path, args.evaluate)
            save_models(optims, i, args.model_path, args.evaluate)
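
mixup_loss(z, classifier, device) is not defined here. A plausible unsupervised form, assumed for illustration, is a consistency loss: mix two inputs and match the prediction of the mixture to the mixture of predictions.

import torch
import torch.nn.functional as F

def mixup_loss(x, classifier, device, alpha=1.0):
    """Hypothetical consistency mixup: KL(pred(mix) || mix of preds)."""
    lam = torch.distributions.Beta(alpha, alpha).sample().to(device)
    perm = torch.randperm(x.shape[0], device=device)
    mixed = lam * x + (1 - lam) * x[perm]
    with torch.no_grad():
        p = F.softmax(classifier(x), dim=1)
        target = lam * p + (1 - lam) * p[perm]
    log_q = F.log_softmax(classifier(mixed), dim=1)
    return F.kl_div(log_q, target, reduction='batchmean')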
Example 8
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2
    train_loader3, test_loader3 = args.loaders3
    train_loader4, test_loader4 = args.loaders4

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    criticx1 = models['criticx1'].to(args.device)
    criticx2 = models['criticx2'].to(args.device)
    criticy1 = models['criticy1'].to(args.device)
    criticy2 = models['criticy2'].to(args.device)
    criticz1 = models['criticz1'].to(args.device)
    criticz2 = models['criticz2'].to(args.device)
    print(generator)
    print(criticx1)
    print(criticy1)

    optim_criticx1 = optim.Adam(criticx1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticx2 = optim.Adam(criticx2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticy1 = optim.Adam(criticy1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticy2 = optim.Adam(criticy2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticz1 = optim.Adam(criticz1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticz2 = optim.Adam(criticz2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    iter1, iter2, iter3, iter4 = iter(train_loader1), iter(train_loader2), iter(train_loader3), iter(train_loader4)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    titer1, titer2, titer3, titer4 = iter(test_loader1), iter(test_loader2), iter(test_loader3), iter(test_loader4)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()

    generator.train()
    criticx1.train()
    criticx2.train()
    criticy1.train()
    criticy2.train()
    criticz1.train()
    criticz2.train()
    for i in range(0):  # optional critic pre-training loop, disabled here (range(0) never runs)
        batchx, iter1 = sample(iter1, train_loader1)
        data = batchx[0].to(args.device)

        batchy, iter2 = sample(iter2, train_loader2)
        datay = batchy[0].to(args.device)
        datay[:,1] = datay[:,1]*0
        datay[:,2] = datay[:,2]*0

        batchz, iter3 = sample(iter3, train_loader3)
        dataz = batchz[0].to(args.device)
        dataz[:,0] = dataz[:,0]*0
        dataz[:,2] = dataz[:,2]*0

        batchw, iter4 = sample(iter4, train_loader4)
        dataw = batchw[0].to(args.device)
        dataw[:,0] = dataw[:,0]*0
        dataw[:,1] = dataw[:,1]*0

        optim_criticx1.zero_grad()
        optim_criticx2.zero_grad()
        r_loss, g_loss, p = disc_loss_generation(data, datay, args.eps, args.lp, criticx1, criticx2)
        (r_loss + g_loss + p).backward(mone)
        optim_criticx1.step()
        optim_criticx2.step()

        optim_criticy1.zero_grad()
        optim_criticy2.zero_grad()
        r_loss, g_loss, p = disc_loss_generation(data, dataz, args.eps, args.lp, criticy1, criticy2)
        (r_loss + g_loss + p).backward(mone)
        optim_criticy1.step()
        optim_criticy2.step()

        optim_criticz1.zero_grad()
        optim_criticz2.zero_grad()
        r_loss, g_loss, p = disc_loss_generation(data, dataw, args.eps, args.lp, criticz1, criticz2)
        (r_loss + g_loss + p).backward(mone)
        optim_criticz1.step()
        optim_criticz2.step()

        if i % 100 == 0:
            print(f'Critics-{i}')
            print('Iter: %s' % i, time.time() - t0)
            args.visualiser.plot(step=i, data=p.detach().cpu().numpy(), title=f'Penalty')
            d_loss = (r_loss+g_loss).detach().cpu().numpy()
            args.visualiser.plot(step=i, data=d_loss, title=f'Critic loss Y')
            t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        criticx1.train()
        criticx2.train()
        criticy1.train()
        criticy2.train()
        criticz1.train()
        criticz2.train()

        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx[0].to(args.device)

            batchy, iter2 = sample(iter2, train_loader2)
            datay = batchy[0].to(args.device)
            datay[:,1] = datay[:,1]*0
            datay[:,2] = datay[:,2]*0

            batchz, iter3 = sample(iter3, train_loader3)
            dataz = batchz[0].to(args.device)
            dataz[:,0] = dataz[:,0]*0
            dataz[:,2] = dataz[:,2]*0

            batchw, iter4 = sample(iter4, train_loader4)
            dataw = batchw[0].to(args.device)
            dataw[:,0] = dataw[:,0]*0
            dataw[:,1] = dataw[:,1]*0

            optim_criticx1.zero_grad()
            optim_criticx2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(data, datay, args.eps, args.lp, criticx1, criticx2)
            (r_loss + g_loss + p).backward(mone)
            optim_criticx1.step()
            optim_criticx2.step()

            optim_criticy1.zero_grad()
            optim_criticy2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(data, dataz, args.eps, args.lp, criticy1, criticy2)
            (r_loss + g_loss + p).backward(mone)
            optim_criticy1.step()
            optim_criticy2.step()

            optim_criticz1.zero_grad()
            optim_criticz2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(data, dataw, args.eps, args.lp, criticz1, criticz2)
            (r_loss + g_loss + p).backward(mone)
            optim_criticz1.step()
            optim_criticz2.step()

        batchx, iter1 = sample(iter1, train_loader1)
        data = batchx[0].to(args.device)

        batchy, iter2 = sample(iter2, train_loader2)
        datay = batchy[0].to(args.device)
        datay[:,1] = datay[:,1]*0
        datay[:,2] = datay[:,2]*0

        batchz, iter3 = sample(iter3, train_loader3)
        dataz = batchz[0].to(args.device)
        dataz[:,0] = dataz[:,0]*0
        dataz[:,2] = dataz[:,2]*0

        batchw, iter4 = sample(iter4, train_loader4)
        dataw = batchw[0].to(args.device)
        dataw[:,0] = dataw[:,0]*0
        dataw[:,1] = dataw[:,1]*0

        optim_generator.zero_grad()
        t_ = Dirichlet(torch.FloatTensor([1.,1.,1.])).sample().to(args.device)
        t = torch.stack([t_]*data.shape[0])
        t_lossx = transfer_loss(data, datay, t, args.eps, args.lp, criticx1, criticx2, generator)
        t_lossy = transfer_loss(data, dataz, t, args.eps, args.lp, criticy1, criticy2, generator)
        t_lossz = transfer_loss(data, dataw, t, args.eps, args.lp, criticz1, criticz2, generator)
        t_loss = (t_[0]*t_lossx + t_[1]*t_lossy + t_[2]*t_lossz).sum()
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            #batchx, titer1 = sample(titer1, test_loader1)
            #datax = batchx[0].to(args.device)
            #batchy, titer2 = sample(titer2, test_loader2)
            #datay = batchy[0].to(args.device)
            evaluate(args.visualiser, data, datay, dataz, dataw, generator, 'x', args.device)
            args.visualiser.plot(step=i, data=t_lossx.detach().cpu().numpy(), title=f'Generator loss X')
            args.visualiser.plot(step=i, data=t_lossy.detach().cpu().numpy(), title=f'Generator loss Y')
            args.visualiser.plot(step=i, data=t_lossz.detach().cpu().numpy(), title=f'Generator loss Z')
            args.visualiser.plot(step=i, data=p.detach().cpu().numpy(), title=f'Penalty')
            d_loss = (r_loss+g_loss).detach().cpu().numpy()
            args.visualiser.plot(step=i, data=d_loss, title=f'Critic loss Y')
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
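
The generator step draws barycentric weights t_ ~ Dirichlet(1, 1, 1), so the three per-domain transfer losses are combined with non-negative weights that sum to 1. A tiny illustration:

import torch
from torch.distributions import Dirichlet

t_ = Dirichlet(torch.tensor([1., 1., 1.])).sample()  # e.g. tensor([0.21, 0.47, 0.32])
assert torch.isclose(t_.sum(), torch.tensor(1.0))
batch = 8
t = torch.stack([t_] * batch)  # shape (batch, 3): shared weights for the whole batch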
Example 9
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    critic = models['critic'].to(args.device)
    encoder = models['encoder'].to(args.device)
    print(generator)
    print(critic)
    print(encoder)

    optim_critic = optim.Adam(critic.parameters(),
                              lr=args.lr,
                              betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(),
                                 lr=args.lr,
                                 betas=(args.beta1, args.beta2))
    optim_encoder = optim.Adam(encoder.parameters(),
                               lr=args.lr,
                               betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    titer1 = iter(test_loader1)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        critic.train()
        encoder.train()

        for _ in range(10):
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx[0].to(args.device)
            optim_encoder.zero_grad()
            optim_generator.zero_grad()
            e_loss = encoder_loss(data.shape[0], args.lp, args.z_dim, encoder,
                                  generator, critic, args.device)
            e_loss.backward()
            optim_encoder.step()
            optim_generator.step()

        for _ in range(1):
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx[0].to(args.device)
            optim_critic.zero_grad()
            r_loss = critic_loss(data, args.lp, args.z_dim, encoder, critic,
                                 generator, args.device)
            r_loss.backward(mone)
            optim_critic.step()

        for _ in range(1):
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx[0].to(args.device)
            optim_generator.zero_grad()
            t_loss = transfer_loss(data.shape[0], args.lp, args.z_dim, encoder,
                                   critic, generator, args.device)
            t_loss.backward()
            optim_generator.step()

        if i % args.evaluate == 0:
            encoder.eval()
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            batchx, titer1 = sample(titer1, test_loader1)
            data = batchx[0].to(args.device)
            evaluate(args.visualiser, args.z_dim, data, encoder, generator,
                     critic, args.z_dim, i, args.device)
            d_loss = r_loss.detach().cpu().numpy()
            args.visualiser.plot(step=i, data=d_loss, title=f'Critic loss')
            args.visualiser.plot(step=i,
                                 data=e_loss.detach().cpu().numpy(),
                                 title=f'Encoder loss')
            args.visualiser.plot(step=i,
                                 data=t_loss.detach().cpu().numpy(),
                                 title=f'Generator loss')
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
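
infer_iteration recovers the starting step when resuming: every example calls it with a model name, the reload flag, and the checkpoint paths. One plausible sketch, assuming checkpoints are stored as '<name>_<step>' files (the naming scheme, and save_path being unused, are assumptions):

import os
import re

def infer_iteration(name, reload, model_path, save_path):
    """Return the latest saved step for `name`, or 0 when not resuming."""
    if not reload:
        return 0
    pattern = re.compile(rf'{re.escape(name)}_(\d+)$')
    steps = [int(m.group(1)) for f in os.listdir(model_path)
             for m in [pattern.match(f)] if m]
    return max(steps, default=0)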
Example 10
def train(args):
    parameters = vars(args)
    train_loader1, valid_loader1, test_loader1 = args.loaders1
    train_loader2, valid_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    critic = models['critic'].to(args.device)
    generator = models['generator'].to(args.device)
    evalY = args.evalY.to(args.device).eval()
    semantic = args.semantic.to(args.device).eval()
    print(generator)
    print(critic)
    print(semantic)

    optim_critic = optim.Adam(critic.parameters(),
                              lr=args.lr,
                              betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(),
                                 lr=args.lr,
                                 betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iter2 = iter(train_loader2)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        critic.train()
        generator.train()

        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            datax = batchx[0].to(args.device)

            batchy, iter2 = sample(iter2, train_loader2)
            datay = batchy[0].to(args.device)

            critic_lossy = compute_critic_loss(datax, args.z_dim, datay,
                                               critic, generator, args.device)
            optim_critic.zero_grad()
            critic_lossy.backward()
            optim_critic.step()

        batchx, iter1 = sample(iter1, train_loader1)
        datax = batchx[0].to(args.device)

        batchy, iter2 = sample(iter2, train_loader2)
        datay = batchy[0].to(args.device)

        glossxy = generator_loss(datax, args.z_dim, critic, generator,
                                 args.device)
        slossxy = semantic_loss(datax, args.z_dim, generator, semantic,
                                args.device)
        optim_generator.zero_grad()
        glossxy.backward()
        (args.gsxy * slossxy).backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            print('Iter: %s' % i, time.time() - t0)
            generator.eval()
            save_path = args.save_path
            plot_transfer(args.visualiser, datax, datay, args.z_dim, generator,
                          'x-y', i, args.device)
            test_accuracy_xy = evaluate(test_loader1, args.z_dim, generator,
                                        evalY, args.device)
            with open(os.path.join(save_path, 'glossxy'), 'a') as f:
                f.write(f'{i},{glossxy.cpu().item()}\n')
            with open(os.path.join(save_path, 'slossxy'), 'a') as f:
                f.write(f'{i},{slossxy.cpu().item()}\n')
            with open(os.path.join(save_path, 'test_accuracy_xy'), 'a') as f:
                f.write(f'{i},{test_accuracy_xy}\n')
            args.visualiser.plot(critic_lossy.cpu().detach().numpy(),
                                 title='critic_lossy',
                                 step=i)
            args.visualiser.plot(glossxy.cpu().detach().numpy(),
                                 title='glossxy',
                                 step=i)
            args.visualiser.plot(slossxy.cpu().detach().numpy(),
                                 title='slossxy',
                                 step=i)
            args.visualiser.plot(test_accuracy_xy,
                                 title=f'Test transfer accuracy X-Y',
                                 step=i)
            t0 = time.time()
            save_models(models, 0, args.model_path, args.checkpoint)
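
semantic_loss pairs the generator with a fixed pretrained network (semantic, kept in eval mode). A common choice, assumed here, is feature consistency: the translated sample should preserve the semantic network's response to the input.

import torch
import torch.nn.functional as F

def semantic_loss(x, z_dim, generator, semantic, device):
    """Hypothetical sketch: L1 between semantic features of x and of its translation."""
    z = torch.randn(x.shape[0], z_dim, device=device)
    fake = generator(x, z)
    return F.l1_loss(semantic(fake), semantic(x).detach())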
Example 11
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    classifier = models['classifier'].to(args.device)
    discriminator = models['discriminator'].to(args.device)
    cluster = args.cluster.eval().to(args.device)
    print(classifier)
    print(discriminator)

    optim_classifier = optim.Adam(classifier.parameters(),
                                  lr=args.lr,
                                  betas=(args.beta1, args.beta2))
    optim_discriminator = optim.Adam(discriminator.parameters(),
                                     lr=args.lr,
                                     betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iter2 = iter(train_loader2)
    titer1 = iter(test_loader1)
    titer2 = iter(test_loader2)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        classifier.train()
        discriminator.train()
        batchx, iter1 = sample(iter1, train_loader1)
        data1 = batchx[0].to(args.device)
        if data1.shape[0] != args.train_batch_size:
            batchx, iter1 = sample(iter1, train_loader1)
            data1 = batchx[0].to(args.device)
        label = cluster(data1).argmax(1).detach()

        batchy, iter2 = sample(iter2, train_loader2)
        data2 = batchy[0].to(args.device)
        if data2.shape[0] != args.train_batch_size:
            batchy, iter2 = sample(iter2, train_loader2)
            data2 = batchy[0].to(args.device)

        optim_discriminator.zero_grad()
        d_loss = disc_loss(data1, data2, discriminator, classifier.x,
                           classifier.mlp, args.device)
        d_loss.backward()
        optim_discriminator.step()

        optim_classifier.zero_grad()
        c_loss = classification_loss(data1, label, classifier)
        tcw_loss = classification_target_loss(data2, classifier)
        dw_loss = embed_div_loss(data1, data2, discriminator, classifier.x,
                                 classifier.mlp, args.device)
        v_loss1 = vat_loss(data1, classifier, args.radius, args.device)
        v_loss2 = vat_loss(data2, classifier, args.radius, args.device)
        m_loss1 = mixup_loss(data1, classifier, args.device)
        m_loss2 = mixup_loss(data2, classifier, args.device)
        (args.cw * c_loss).backward()
        (args.tcw * tcw_loss).backward()
        (args.dw * dw_loss).backward()
        (args.svw * v_loss1).backward()
        (args.tvw * v_loss2).backward()
        (args.smw * m_loss1).backward()
        (args.tmw * m_loss2).backward()
        optim_classifier.step()

        if i % args.evaluate == 0:
            classifier.eval()
            batchx, titer1 = sample(titer1, test_loader1)
            data1 = batchx[0].to(args.device)
            batchy, titer2 = sample(titer2, test_loader2)
            data2 = batchy[0].to(args.device)
            print('Iter: %s' % i, time.time() - t0)
            class_map = evaluate_cluster(args.visualiser, i, args.nc,
                                         test_loader1, cluster, f'x',
                                         args.device)
            evaluate_cluster_accuracy(args.visualiser, i, test_loader1,
                                      class_map, classifier, f'x', args.device)
            evaluate_cluster_accuracy(args.visualiser, i, test_loader2,
                                      class_map, classifier, f'y', args.device)
            args.visualiser.plot(c_loss.cpu().detach().numpy(),
                                 title=f'Classifier loss',
                                 step=i)
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
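
vat_loss(data, classifier, radius, device) matches virtual adversarial training: find the perturbation of norm radius that most changes the prediction, then penalize that change. A compact one-power-iteration sketch (standard VAT; the exact variant used here is an assumption):

import torch
import torch.nn.functional as F

def vat_loss(x, classifier, radius, device, xi=1e-6):
    """VAT: KL between predictions at x and at x plus an adversarial perturbation."""
    with torch.no_grad():
        p = F.softmax(classifier(x), dim=1)
    d = xi * F.normalize(torch.randn_like(x).flatten(1), dim=1).view_as(x)
    d.requires_grad_(True)
    kl = F.kl_div(F.log_softmax(classifier(x + d), dim=1), p, reduction='batchmean')
    grad, = torch.autograd.grad(kl, d)
    r_adv = radius * F.normalize(grad.flatten(1), dim=1).view_as(x)
    return F.kl_div(F.log_softmax(classifier(x + r_adv), dim=1), p, reduction='batchmean')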
Example 12
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    print(generator)

    optim_generator = optim.SGD(generator.parameters(), lr=args.lr)

    iter1, iter2 = iter(train_loader1), iter(train_loader2)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    titer1, titer2 = iter(test_loader1), iter(test_loader2)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        batchx, iter1 = sample(iter1, train_loader1)
        data = batchx.to(args.device)

        batchy, iter2 = sample(iter2, train_loader2)
        datay = batchy.to(args.device)

        optim_generator.zero_grad()
        x = generator(data)
        t_lossx = sinkhorn_loss(x,
                                data,
                                args.eps,
                                data.shape[0],
                                100,
                                args.device,
                                p=args.lp)**args.p_exp
        t_lossy = sinkhorn_loss(x,
                                datay,
                                args.eps,
                                data.shape[0],
                                100,
                                args.device,
                                p=args.lp)**args.p_exp
        ((1 - args.alphat) * t_lossx + args.alphat * t_lossy).backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            batchx, titer1 = sample(titer1, test_loader1)
            datax = batchx.to(args.device)
            batchy, titer2 = sample(titer2, test_loader2)
            datay = batchy.to(args.device)
            evaluate(args.visualiser, datax, datax, datay, generator, 'x')
            args.visualiser.plot(step=i,
                                 data=t_lossx.detach().cpu().numpy(),
                                 title='Generator loss x')
            args.visualiser.plot(step=i,
                                 data=t_lossy.detach().cpu().numpy(),
                                 title='Generator loss y')
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
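
sinkhorn_loss(x, y, eps, n, n_iter, device, p) computes an entropic-regularized optimal-transport cost between the two batches. A compact log-domain Sinkhorn sketch on uniform marginals (a simplified version; not necessarily the exact variant used above):

import math
import torch

def sinkhorn_loss(x, y, eps, n, n_iter, device, p=2):
    """Entropic OT cost <P, C> after n_iter log-domain Sinkhorn iterations."""
    C = torch.cdist(x.flatten(1), y.flatten(1), p=p) ** p  # pairwise ground cost
    log_mu = torch.full((n,), -math.log(n), device=device)  # uniform marginals
    f = torch.zeros(n, device=device)
    g = torch.zeros(n, device=device)
    for _ in range(n_iter):
        f = -eps * torch.logsumexp((g[None, :] - C) / eps + log_mu[None, :], dim=1)
        g = -eps * torch.logsumexp((f[:, None] - C) / eps + log_mu[:, None], dim=0)
    log_P = (f[:, None] + g[None, :] - C) / eps + log_mu[:, None] + log_mu[None, :]
    return (log_P.exp() * C).sum()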
Example 13
def train(args):
    parameters = vars(args)
    train_loader1, valid_loader1, test_loader1 = args.loaders1
    train_loader2, valid_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    criticX = models['criticX'].to(args.device)
    criticY = models['criticY'].to(args.device)
    generatorXY = models['generatorXY'].to(args.device)
    generatorYX = models['generatorYX'].to(args.device)
    evalX = args.evalX.to(args.device).eval()
    evalY = args.evalY.to(args.device).eval()
    classifier = args.classifier.to(args.device).eval()
    print(generatorXY)
    print(criticX)
    print(classifier)

    optim_criticX = optim.Adam(criticX.parameters(),
                               lr=args.lr,
                               betas=(args.beta1, args.beta2))
    optim_criticY = optim.Adam(criticY.parameters(),
                               lr=args.lr,
                               betas=(args.beta1, args.beta2))
    optim_generatorXY = optim.Adam(generatorXY.parameters(),
                                   lr=args.lr,
                                   betas=(args.beta1, args.beta2))
    optim_generatorYX = optim.Adam(generatorYX.parameters(),
                                   lr=args.lr,
                                   betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iter2 = iter(train_loader2)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        criticX.train()
        criticY.train()
        generatorXY.train()
        generatorYX.train()

        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            datax = batchx[0].to(args.device)
            if datax.shape[0] != args.train_batch_size:
                batchx, iter1 = sample(iter1, train_loader1)
                datax = batchx[0].to(args.device)

            batchy, iter2 = sample(iter2, train_loader2)
            datay = batchy[0].to(args.device)
            if datay.shape[0] != args.train_batch_size:
                batchy, iter2 = sample(iter2, train_loader2)
                datay = batchy[0].to(args.device)

            #critic_lossx = compute_critic_loss(datay, args.z_dim, datax, criticX, generatorYX, args.device)
            #optim_criticX.zero_grad()
            #critic_lossx.backward()
            #optim_criticX.step()

            critic_lossy = compute_critic_loss(datax, args.z_dim, datay,
                                               criticY, generatorXY,
                                               args.device)
            optim_criticY.zero_grad()
            critic_lossy.backward()
            optim_criticY.step()

        batchx, iter1 = sample(iter1, train_loader1)
        datax = batchx[0].to(args.device)
        if datax.shape[0] != args.train_batch_size:
            batchx, iter1 = sample(iter1, train_loader1)
            datax = batchx[0].to(args.device)

        batchy, iter2 = sample(iter2, train_loader2)
        datay = batchy[0].to(args.device)
        if datay.shape[0] != args.train_batch_size:
            batchy, iter2 = sample(iter2, train_loader2)
            datay = batchy[0].to(args.device)

        glossxy = generator_loss(datax, args.z_dim, criticY, generatorXY,
                                 args.device)
        slossxy = semantic_loss(datax, args.z_dim, generatorXY, classifier,
                                args.device)
        #cyclelossxy = cycle_loss(datax, args.z_dim, generatorXY, generatorYX, args.device)
        #idlossxy = identity_loss(datax, args.z_dim, generatorXY, args.device)
        optim_generatorXY.zero_grad()
        glossxy.backward()
        (args.gsxy * slossxy).backward()
        #(args.gcxy*cyclelossxy).backward()
        #(args.gixy*idlossxy).backward()
        optim_generatorXY.step()

        #glossyx = generator_loss(datay, args.z_dim, criticX, generatorYX, args.device)
        #slossyx = semantic_loss(datay, args.z_dim, generatorYX, classifier, args.device)
        #cyclelossyx = cycle_loss(datay, args.z_dim, generatorYX, generatorXY, args.device)
        #idlossyx = identity_loss(datay, args.z_dim, generatorYX, args.device)
        #optim_generatorYX.zero_grad()
        #glossyx.backward()
        #(args.gsyx*slossyx).backward()
        #(args.gcyx*cyclelossyx).backward()
        #(args.giyx*idlossyx).backward()
        #optim_generatorYX.step()

        if i % args.evaluate == 0:
            print('Iter: %s' % i, time.time() - t0)
            generatorXY.eval()
            generatorYX.eval()
            save_path = args.save_path
            plot_transfer(args.visualiser, datax, datay, args.z_dim,
                          generatorXY, 'x-y', i, args.device)
            #plot_transfer(args.visualiser, datay, datax, args.z_dim, generatorYX, 'y-x', i, args.device)
            #eval_accuracy_xy = evaluate(valid_loader1, args.z_dim, generatorXY, evalY, args.device)
            #eval_accuracy_yx = evaluate(valid_loader2, args.z_dim, generatorYX, evalX, args.device)
            test_accuracy_xy = evaluate(test_loader1, args.z_dim, generatorXY,
                                        evalY, args.device)
            #test_accuracy_yx = evaluate(test_loader2, args.z_dim, generatorYX, evalX, args.device)

            #with open(os.path.join(save_path, 'critic_lossx'), 'a') as f: f.write(f'{i},{critic_lossx.cpu().item()}\n')
            #with open(os.path.join(save_path, 'critic_lossy'), 'a') as f: f.write(f'{i},{critic_lossy.cpu().item()}\n')
            with open(os.path.join(save_path, 'glossxy'), 'a') as f:
                f.write(f'{i},{glossxy.cpu().item()}\n')
            #with open(os.path.join(save_path, 'glossyx'), 'a') as f: f.write(f'{i},{glossyx.cpu().item()}\n')
            with open(os.path.join(save_path, 'slossxy'), 'a') as f:
                f.write(f'{i},{slossxy.cpu().item()}\n')
            #with open(os.path.join(save_path, 'slossyx'), 'a') as f: f.write(f'{i},{slossyx.cpu().item()}\n')
            #with open(os.path.join(save_path, 'idlossxy'), 'a') as f: f.write(f'{i},{idlossxy.cpu().item()}')
            #with open(os.path.join(save_path, 'idlossyx'), 'a') as f: f.write(f'{i},{idlossyx.cpu().item()}')
            #with open(os.path.join(save_path, 'cyclelossxy'), 'a') as f: f.write(f'{i},{cyclelossxy.cpu().item()}\n')
            #with open(os.path.join(save_path, 'cyclelossyx'), 'a') as f: f.write(f'{i},{cyclelossyx.cpu().item()}\n')
            #with open(os.path.join(save_path, 'eval_accuracy_xy'), 'a') as f: f.write(f'{i},{eval_accuracy_xy}\n')
            #with open(os.path.join(save_path, 'eval_accuracy_yx'), 'a') as f: f.write(f'{i},{eval_accuracy_yx}\n')
            with open(os.path.join(save_path, 'test_accuracy_xy'), 'a') as f:
                f.write(f'{i},{test_accuracy_xy}\n')
            #with open(os.path.join(save_path, 'test_accuracy_yx'), 'a') as f: f.write(f'{i},{test_accuracy_yx}\n')
            #args.visualiser.plot(critic_lossx.cpu().detach().numpy(), title='critic_lossx', step=i)
            args.visualiser.plot(critic_lossy.cpu().detach().numpy(),
                                 title='critic_lossy',
                                 step=i)
            args.visualiser.plot(glossxy.cpu().detach().numpy(),
                                 title='glossxy',
                                 step=i)
            #args.visualiser.plot(glossyx.cpu().detach().numpy(), title='glossyx', step=i)
            args.visualiser.plot(slossxy.cpu().detach().numpy(),
                                 title='slossxy',
                                 step=i)
            #args.visualiser.plot(slossyx.cpu().detach().numpy(), title='slossyx', step=i)
            #args.visualiser.plot(idlossxy.cpu().detach().numpy(), title='idlossxy', step=i)
            #args.visualiser.plot(idlossyx.cpu().detach().numpy(), title='idlossyx', step=i)
            #args.visualiser.plot(cyclelossxy.cpu().detach().numpy(), title='cyclelossxy', step=i)
            #args.visualiser.plot(cyclelossyx.cpu().detach().numpy(), title='cyclelossyx', step=i)
            #args.visualiser.plot(eval_accuracy_xy, title=f'Validation transfer accuracy X-Y', step=i)
            #args.visualiser.plot(eval_accuracy_yx, title=f'Validation transfer accuracy Y-X', step=i)
            args.visualiser.plot(test_accuracy_xy,
                                 title=f'Test transfer accuracy X-Y',
                                 step=i)
            #args.visualiser.plot(test_accuracy_yx, title=f'Test transfer accuracy Y-X', step=i)
            t0 = time.time()
            save_models(models, 0, args.model_path, args.checkpoint)
Example 14
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    critic1 = models['critic1'].to(args.device)
    critic2 = models['critic2'].to(args.device)
    print(generator)
    print(critic1)
    print(critic2)

    optim_critic1 = optim.Adam(critic1.parameters(),
                               lr=args.lr,
                               betas=(args.beta1, args.beta2))
    optim_critic2 = optim.Adam(critic2.parameters(),
                               lr=args.lr,
                               betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(),
                                 lr=args.lr,
                                 betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        critic1.train()
        critic2.train()
        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx[0].to(args.device)
            if data.shape[0] != args.train_batch_size:
                batchx, iter1 = sample(iter1, train_loader1)
                data = batchx[0].to(args.device)

            optim_critic1.zero_grad()
            optim_critic2.zero_grad()
            z = torch.randn(data.shape[0], args.z_dim, device=args.device)
            gen1 = generator(z).detach()
            r_loss, g_loss, p = disc_loss_generation(data, gen1, args.eps,
                                                     args.lp, critic1, critic2)
            (r_loss + g_loss + p).backward(mone)
            optim_critic1.step()
            optim_critic2.step()

        optim_generator.zero_grad()
        t_loss = transfer_loss(data, args.eps, args.lp, args.z_dim, critic1,
                               critic2, generator, args.device)
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            noise = torch.randn(args.test_batch_size,
                                args.z_dim,
                                device=args.device)
            evaluate(args.visualiser, noise, data[:args.test_batch_size],
                     generator, i)
            d_loss = (r_loss + g_loss).detach().cpu().numpy()
            args.visualiser.plot(step=i, data=d_loss, title=f'Critic loss')
            args.visualiser.plot(step=i,
                                 data=t_loss.detach().cpu().numpy(),
                                 title=f'Generator loss')
            args.visualiser.plot(step=i,
                                 data=p.detach().cpu().numpy(),
                                 title=f'Penalty')
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
Example 15
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    classifier = models['classifier'].to(args.device)
    discriminator = models['discriminator'].to(args.device)
    print(classifier)
    print(discriminator)

    optim_discriminator = optim.Adam(discriminator.parameters(),
                                     lr=args.lr,
                                     betas=(args.beta1, args.beta2))
    optim_classifier = optim.Adam(classifier.parameters(),
                                  lr=args.lr,
                                  betas=(args.beta1, args.beta2))
    optims = {
        'optim_discriminator': optim_discriminator,
        'optim_classifier': optim_classifier
    }
    initialize(optims, args.reload, args.save_path, args.model_path)

    iter1 = iter(train_loader1)
    iter2 = iter(train_loader2)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        classifier.train()
        discriminator.train()
        batchx, iter1 = sample(iter1, train_loader1)
        data1 = batchx[0].to(args.device)
        if data1.shape[0] != args.train_batch_size:
            batchx, iter1 = sample(iter1, train_loader1)
            data1 = batchx[0].to(args.device)
        label = batchx[1].to(args.device)

        batchy, iter2 = sample(iter2, train_loader2)
        data2 = batchy[0].to(args.device)
        if data2.shape[0] != args.train_batch_size:
            batchy, iter2 = sample(iter2, train_loader2)
            data2 = batchy[0].to(args.device)

        optim_discriminator.zero_grad()
        d_loss = disc_loss(data1, data2, discriminator, classifier.x,
                           classifier.mlp, args.device)
        d_loss.backward()
        optim_discriminator.step()

        optim_classifier.zero_grad()
        c_loss = classification_loss(data1, label, classifier)
        tcw_loss = classification_target_loss(data2, classifier)
        dw_loss = embed_div_loss(data1, data2, discriminator, classifier.x,
                                 classifier.mlp, args.device)
        v_loss1 = vat_loss(data1, classifier, args.radius, args.device)
        v_loss2 = vat_loss(data2, classifier, args.radius, args.device)
        m_loss1 = mixup_loss(data1, classifier, args.device)
        m_loss2 = mixup_loss(data2, classifier, args.device)
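        # Each weighted loss is backpropagated separately; the gradients simply
        # accumulate in the classifier parameters before the single step below.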
        (args.cw * c_loss).backward()
        (args.tcw * tcw_loss).backward()
        (args.dw * dw_loss).backward()
        (args.svw * v_loss1).backward()
        (args.tvw * v_loss2).backward()
        (args.smw * m_loss1).backward()
        (args.tmw * m_loss2).backward()
        optim_classifier.step()

        if i % args.evaluate == 0:
            print('Iter: %s' % i, time.time() - t0)
            classifier.eval()

            test_accuracy_x = evaluate(test_loader1, classifier, args.device)
            test_accuracy_y = evaluate(test_loader2, classifier, args.device)

            save_path = args.save_path
            with open(os.path.join(save_path, 'c_loss'), 'a') as f:
                f.write(f'{i},{c_loss.cpu().item()}\n')
            with open(os.path.join(save_path, 'tcw_loss'), 'a') as f:
                f.write(f'{i},{tcw_loss.cpu().item()}\n')
            with open(os.path.join(save_path, 'dw_loss'), 'a') as f:
                f.write(f'{i},{dw_loss.cpu().item()}\n')
            with open(os.path.join(save_path, 'v_loss1'), 'a') as f:
                f.write(f'{i},{v_loss1.cpu().item()}\n')
            with open(os.path.join(save_path, 'v_loss2'), 'a') as f:
                f.write(f'{i},{v_loss2.cpu().item()}\n')
            with open(os.path.join(save_path, 'm_loss1'), 'a') as f:
                f.write(f'{i},{m_loss1.cpu().item()}\n')
            with open(os.path.join(save_path, 'm_loss2'), 'a') as f:
                f.write(f'{i},{m_loss2.cpu().item()}\n')
            with open(os.path.join(save_path, 'd_loss'), 'a') as f:
                f.write(f'{i},{d_loss.cpu().item()}\n')
            #with open(os.path.join(save_path, 'eval_accuracy_x'), 'a') as f: f.write(f'{i},{eval_accuracy_x}\n')
            #with open(os.path.join(save_path, 'eval_accuracy_y'), 'a') as f: f.write(f'{i},{eval_accuracy_y}\n')
            with open(os.path.join(save_path, 'test_accuracy_x'), 'a') as f:
                f.write(f'{i},{test_accuracy_x}\n')
            with open(os.path.join(save_path, 'test_accuracy_y'), 'a') as f:
                f.write(f'{i},{test_accuracy_y}\n')
            args.visualiser.plot(c_loss.cpu().detach().numpy(),
                                 title='Source classifier loss',
                                 step=i)
            args.visualiser.plot(tcw_loss.cpu().detach().numpy(),
                                 title='Target classifier cross entropy',
                                 step=i)
            args.visualiser.plot(dw_loss.cpu().detach().numpy(),
                                 title='Classifier marginal divergence',
                                 step=i)
            args.visualiser.plot(v_loss1.cpu().detach().numpy(),
                                 title='Source virtual adversarial loss',
                                 step=i)
            args.visualiser.plot(v_loss2.cpu().detach().numpy(),
                                 title='Target virtual adversarial loss',
                                 step=i)
            args.visualiser.plot(m_loss1.cpu().detach().numpy(),
                                 title='Source mix up loss',
                                 step=i)
            args.visualiser.plot(m_loss2.cpu().detach().numpy(),
                                 title='Target mix up loss',
                                 step=i)
            args.visualiser.plot(d_loss.cpu().detach().numpy(),
                                 title='Discriminator loss',
                                 step=i)
            #args.visualiser.plot(eval_accuracy_x, title='Eval acc X', step=i)
            #args.visualiser.plot(eval_accuracy_y, title='Eval acc Y', step=i)
            args.visualiser.plot(test_accuracy_x, title='Test acc X', step=i)
            args.visualiser.plot(test_accuracy_y, title='Test acc Y', step=i)
            t0 = time.time()
            save_models(models, i, args.model_path, args.evaluate)
            save_models(optims, i, args.model_path, args.evaluate)
Example No. 16
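# Imports assumed by this snippet; project-level helpers (define_models,
# initialize, sample, infer_iteration, disc_loss_generation, transfer_loss,
# evaluate and save_models) are expected to be in scope.
import time

import torch
from torch import optim
from torch.distributions import Dirichlet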
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2
    train_loader3, test_loader3 = args.loaders3
    #train_loader4, test_loader4 = args.loaders4

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    criticx1 = models['criticx1'].to(args.device)
    criticx2 = models['criticx2'].to(args.device)
    criticy1 = models['criticy1'].to(args.device)
    criticy2 = models['criticy2'].to(args.device)
    criticz1 = models['criticz1'].to(args.device)
    criticz2 = models['criticz2'].to(args.device)
    #criticw1 = models['criticw1'].to(args.device)
    #criticw2 = models['criticw2'].to(args.device)
    print(generator)
    print(criticx1)
    print(criticy1)
    print(criticz1)
    #print(criticw1)

    optim_criticx1 = optim.Adam(criticx1.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_criticx2 = optim.Adam(criticx2.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_criticy1 = optim.Adam(criticy1.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_criticy2 = optim.Adam(criticy2.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_criticz1 = optim.Adam(criticz1.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_criticz2 = optim.Adam(criticz2.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    #optim_criticw1 = optim.Adam(criticw1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    #optim_criticw2 = optim.Adam(criticw2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(),
                                 lr=args.lr,
                                 betas=(args.beta1, args.beta2))

    iter1, iter2, iter3 = (iter(train_loader1), iter(train_loader2),
                           iter(train_loader3))
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    titer1, titer2, titer3 = (iter(test_loader1), iter(test_loader2),
                              iter(test_loader3))
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        criticx1.train()
        criticx2.train()
        criticy1.train()
        criticy2.train()
        criticz1.train()
        criticz2.train()
        #criticw1.train()
        #criticw2.train()

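        # One (critic1, critic2) pair per target dataset; each pair is updated
        # d_updates times for every generator update.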
        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx.to(args.device)
            batchx, iter1 = sample(iter1, train_loader1)
            input_data = batchx.to(args.device)
            batchy, iter2 = sample(iter2, train_loader2)
            datay = batchy.to(args.device)
            batchz, iter3 = sample(iter3, train_loader3)
            dataz = batchz.to(args.device)
            #batchw, iter4 = sample(iter4, train_loader4)
            #dataw = batchw[0].to(args.device)

            optim_criticx1.zero_grad()
            optim_criticx2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(input_data, data,
                                                     args.eps, args.lp,
                                                     criticx1, criticx2)
            (r_loss + g_loss + p).backward(mone)
            optim_criticx1.step()
            optim_criticx2.step()

            optim_criticy1.zero_grad()
            optim_criticy2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(input_data, datay,
                                                     args.eps, args.lp,
                                                     criticy1, criticy2)
            (r_loss + g_loss + p).backward(mone)
            optim_criticy1.step()
            optim_criticy2.step()

            optim_criticz1.zero_grad()
            optim_criticz2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(input_data, dataz,
                                                     args.eps, args.lp,
                                                     criticz1, criticz2)
            (r_loss + g_loss + p).backward(mone)
            optim_criticz1.step()
            optim_criticz2.step()

            #optim_criticw1.zero_grad()
            #optim_criticw2.zero_grad()
            #r_loss, g_loss, p = disc_loss_generation(input_data, dataw, args.eps, args.lp, criticw1, criticw2)
            #(r_loss + g_loss + p).backward(mone)
            #optim_criticw1.step()
            #optim_criticw2.step()

        optim_generator.zero_grad()
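        # Sample barycentric weights t_ from a flat Dirichlet over the three
        # datasets; the generator objective below is the t_-weighted sum of
        # the three transfer losses.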
        t_ = Dirichlet(torch.FloatTensor([1., 1.,
                                          1.])).sample().to(args.device)
        t = torch.stack([t_] * input_data.shape[0])
        tinputdata = torch.cat([input_data] * args.nt)
        tdata = torch.cat([data] * args.nt)
        tdatay = torch.cat([datay] * args.nt)
        tdataz = torch.cat([dataz] * args.nt)
        #tdataw = torch.cat([dataw]*args.nt)
        t_lossx = transfer_loss(tinputdata, tdata, args.nt, t, args.eps,
                                args.lp, criticx1, criticx2, generator)
        t_lossy = transfer_loss(tinputdata, tdatay, args.nt, t, args.eps,
                                args.lp, criticy1, criticy2, generator)
        t_lossz = transfer_loss(tinputdata, tdataz, args.nt, t, args.eps,
                                args.lp, criticz1, criticz2, generator)
        #t_lossw = transfer_loss(tinputdata, tdataw, args.nt, t, args.eps, args.lp, criticw1, criticw2, generator)
        t_loss = (t_[0] * t_lossx + t_[1] * t_lossy + t_[2] * t_lossz).sum()
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            batchx, titer1 = sample(titer1, test_loader1)
            datax = batchx.to(args.device)
            #labelx = batchx[1].to(args.device)
            batchy, titer2 = sample(titer2, test_loader2)
            datay = batchy.to(args.device)
            #labely = batchy[1].to(args.device)
            batchz, titer3 = sample(titer3, test_loader3)
            dataz = batchz.to(args.device)
            #labelz = batchz[1].to(args.device)
            evaluate(args.visualiser, datax, datay, dataz, generator, 'x',
                     args.device)
            #batchw, titerw = sample(titer4, test_loader4)
            #dataw = batchw[0].to(args.device)
            #labelw = batchw[1].to(args.device)
            #evaluate_3d(args.visualiser, datax, datay, dataz, dataw, labelx, labely, labelz, labelw, generator, 'x', args.device)
            # r_loss/g_loss/p below come from the last critic update in the
            # inner loop (the z pair).
            d_loss = (r_loss + g_loss).detach().cpu().numpy()
            args.visualiser.plot(step=i, data=d_loss, title='Critic loss Z')
            args.visualiser.plot(step=i,
                                 data=t_lossx.detach().cpu().numpy(),
                                 title='Generator loss X')
            args.visualiser.plot(step=i,
                                 data=t_lossy.detach().cpu().numpy(),
                                 title='Generator loss Y')
            args.visualiser.plot(step=i,
                                 data=t_lossz.detach().cpu().numpy(),
                                 title='Generator loss Z')
            #args.visualiser.plot(step=i, data=t_lossw.detach().cpu().numpy(), title=f'Generator loss W')
            args.visualiser.plot(step=i,
                                 data=p.detach().cpu().numpy(),
                                 title='Penalty')
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
Example No. 17
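# Imports assumed by this snippet; project-level helpers (define_models,
# initialize, sample, infer_iteration, corrupt, one_hot_embedding,
# critic_loss, transfer_loss, plot_transfer, evaluate and save_models) are
# expected to be in scope.
import os
import time

import torch
from torch import optim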
def train(args):
    parameters = vars(args)
    valid_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    critic = models['critic'].to(args.device)
    evaluator = args.evaluation.eval().to(args.device)  # renamed to avoid shadowing the builtin eval
    print(generator)
    print(critic)

    optim_critic = optim.Adam(critic.parameters(),
                              lr=args.lr,
                              betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(),
                                 lr=args.lr,
                                 betas=(args.beta1, args.beta2))

    iter2 = iter(train_loader2)
    titer, titer2 = iter(test_loader1), iter(test_loader2)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        critic.train()
        for _ in range(args.d_updates):
            batch, iter2 = sample(iter2, train_loader2)
            data = batch[0].to(args.device)
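            # Corrupt a fraction (args.corrupt_tgt) of the target labels, then
            # one-hot encode them for the conditional critic/generator.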
            label = corrupt(batch[1], args.nc, args.corrupt_tgt)
            label = one_hot_embedding(label, args.nc).to(args.device)
            optim_critic.zero_grad()
            pos_loss, neg_loss, gp = critic_loss(data, label, args.z_dim,
                                                 critic, generator,
                                                 args.device)
            pos_loss.backward()
            neg_loss.backward(mone)
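            # Gradient penalty with the usual WGAN-GP coefficient of 10.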
            (10 * gp).backward()
            optim_critic.step()

        optim_generator.zero_grad()
        t_loss = transfer_loss(data.shape[0], label, args.z_dim, critic,
                               generator, args.device)
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            print('Iter: %s' % i, time.time() - t0)
            generator.eval()
            batch, titer = sample(titer, test_loader1)
            data1 = batch[0].to(args.device)
            label = one_hot_embedding(batch[1], args.nc).to(args.device)
            batch, titer2 = sample(titer2, test_loader2)
            data2 = batch[0].to(args.device)
            plot_transfer(args.visualiser, label, args.nc, data1, data2,
                          args.nz, generator, args.device, i)
            save_path = args.save_path
            eval_accuracy = evaluate(valid_loader1, args.nz, args.nc,
                                     args.corrupt_src, generator, evaluator,
                                     args.device)
            test_accuracy = evaluate(test_loader1, args.nz, args.nc,
                                     args.corrupt_src, generator, evaluator,
                                     args.device)
            with open(os.path.join(save_path, 'critic_loss'), 'a') as f:
                f.write(f'{i},{(pos_loss-neg_loss).cpu().item()}\n')
            with open(os.path.join(save_path, 'tloss'), 'a') as f:
                f.write(f'{i},{t_loss.cpu().item()}\n')
            with open(os.path.join(save_path, 'eval_accuracy'), 'a') as f:
                f.write(f'{i},{eval_accuracy}\n')
            with open(os.path.join(save_path, 'test_accuracy'), 'a') as f:
                f.write(f'{i},{test_accuracy}\n')
            args.visualiser.plot((pos_loss - neg_loss).cpu().detach().numpy(),
                                 title='critic_loss',
                                 step=i)
            args.visualiser.plot(t_loss.cpu().detach().numpy(),
                                 title='tloss',
                                 step=i)
            args.visualiser.plot(eval_accuracy,
                                 title='Validation transfer accuracy',
                                 step=i)
            args.visualiser.plot(test_accuracy,
                                 title='Test transfer accuracy',
                                 step=i)

            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
Example No. 18
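# Imports assumed by this snippet; project-level helpers (define_models,
# initialize, sample, infer_iteration, disc_loss_generation, transfer_loss,
# evaluate and save_models) are expected to be in scope.
import time

import torch
from torch import optim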
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    criticx1 = models['criticx1'].to(args.device)
    criticx2 = models['criticx2'].to(args.device)
    criticy1 = models['criticy1'].to(args.device)
    criticy2 = models['criticy2'].to(args.device)
    print(generator)
    print(criticx1)
    print(criticx2)
    print(criticy1)
    print(criticy2)

    optim_criticx1 = optim.Adam(criticx1.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_criticx2 = optim.Adam(criticx2.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_criticy1 = optim.Adam(criticy1.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_criticy2 = optim.Adam(criticy2.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(),
                                 lr=args.lr,
                                 betas=(args.beta1, args.beta2))

    iter1, iter2 = iter(train_loader1), iter(train_loader2)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    titer1, titer2 = iter(test_loader1), iter(test_loader2)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        criticx1.train()
        criticx2.train()
        criticy1.train()
        criticy2.train()
        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx.to(args.device)
            optim_criticx1.zero_grad()
            optim_criticx2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(data, data, args.z_dim,
                                                     args.eps, args.lp,
                                                     criticx1, criticx2,
                                                     generator, args.device)
            (r_loss + g_loss + p).backward(mone)
            optim_criticx1.step()
            optim_criticx2.step()

            batchy, iter2 = sample(iter2, train_loader2)
            datay = batchy.to(args.device)
            optim_criticy1.zero_grad()
            optim_criticy2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(data, datay, args.z_dim,
                                                     args.eps, args.lp,
                                                     criticy1, criticy2,
                                                     generator, args.device)
            (r_loss + g_loss + p).backward(mone)
            optim_criticy1.step()
            optim_criticy2.step()

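        # Equal-weight generator objective: the x critic pair compares
        # generations against domain-1 data, the y pair against domain-2 data.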
        optim_generator.zero_grad()
        t_lossx = transfer_loss(data, data, args.z_dim, args.eps, args.lp,
                                criticx1, criticx2, generator, args.device)
        t_lossy = transfer_loss(data, datay, args.z_dim, args.eps, args.lp,
                                criticy1, criticy2, generator, args.device)
        (0.5 * t_lossx + 0.5 * t_lossy).backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            batchx, titer1 = sample(titer1, test_loader1)
            datax = batchx.to(args.device)
            batchy, titer2 = sample(titer2, test_loader2)
            datay = batchy.to(args.device)
            data = torch.randn(args.test_batch_size,
                               args.z_dim,
                               device=args.device)
            evaluate(args.visualiser, data, datax, datay, generator, 'x')
            d_loss = (r_loss + g_loss).detach().cpu().numpy()
            args.visualiser.plot(step=i, data=d_loss, title='Critic loss')
            args.visualiser.plot(step=i,
                                 data=t_lossx.detach().cpu().numpy(),
                                 title='Generator loss X')
            args.visualiser.plot(step=i,
                                 data=t_lossy.detach().cpu().numpy(),
                                 title='Generator loss Y')
            args.visualiser.plot(step=i,
                                 data=p.detach().cpu().numpy(),
                                 title='Penalty')
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
Example No. 19
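# Imports assumed by this snippet; project-level helpers (define_models,
# initialize, sample, infer_iteration, disc_loss_generation, transfer_loss,
# evaluate and save_models) are expected to be in scope.
import time

import torch
from torch import optim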
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    criticx1 = models['criticx1'].to(args.device)
    criticx2 = models['criticx2'].to(args.device)
    criticy1 = models['criticy1'].to(args.device)
    criticy2 = models['criticy2'].to(args.device)
    print(generator)
    print(criticx1)
    print(criticy1)

    optim_criticx1 = optim.Adam(criticx1.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_criticx2 = optim.Adam(criticx2.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_criticy1 = optim.Adam(criticy1.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_criticy2 = optim.Adam(criticy2.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(),
                                 lr=args.lr,
                                 betas=(args.beta1, args.beta2))

    iter1, iter2 = iter(train_loader1), iter(train_loader2)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    titer1, titer2 = iter(test_loader1), iter(test_loader2)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()

    generator.train()
    criticx1.train()
    criticx2.train()
    criticy1.train()
    criticy2.train()
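    # Pretrain the critics for a fixed 4000 steps; the main loop below only
    # updates the generator.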
    for i in range(4000):
        batchx, iter1 = sample(iter1, train_loader1)
        data = batchx[0].to(args.device)

        batchy, iter2 = sample(iter2, train_loader2)
        datay = batchy[0].to(args.device)

        optim_criticx1.zero_grad()
        optim_criticx2.zero_grad()
        r_loss, g_loss, p = disc_loss_generation(data, data, args.eps, args.lp,
                                                 criticx1, criticx2)
        (r_loss + g_loss + p).backward(mone)
        optim_criticx1.step()
        optim_criticx2.step()

        optim_criticy1.zero_grad()
        optim_criticy2.zero_grad()
        r_loss, g_loss, p = disc_loss_generation(data, datay, args.eps,
                                                 args.lp, criticy1, criticy2)
        (r_loss + g_loss + p).backward(mone)
        optim_criticy1.step()
        optim_criticy2.step()
        if i % 100 == 0:
            print(f'Critics-{i}')
            print('Iter: %s' % i, time.time() - t0)
            args.visualiser.plot(step=i,
                                 data=p.detach().cpu().numpy(),
                                 title='Penalty')
            d_loss = (r_loss + g_loss).detach().cpu().numpy()
            args.visualiser.plot(step=i, data=d_loss, title='Critic loss Y')
            t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        criticx1.train()
        criticx2.train()
        criticy1.train()
        criticy2.train()
        batchx, iter1 = sample(iter1, train_loader1)
        data = batchx[0].to(args.device)
        #if data.shape[0] != args.train_batch_size:
        #    batchx, iter1 = sample(iter1, train_loader1)
        #    data = batchx[0].to(args.device)

        batchy, iter2 = sample(iter2, train_loader2)
        datay = batchy[0].to(args.device)
        #if datay.shape[0] != args.train_batch_size:
        #    batchy, iter2 = sample(iter2, train_loader2)
        #    datay = batchy[0].to(args.device)

        optim_generator.zero_grad()
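        # Random interpolation weights t_ in [0, 1): the generator minimizes
        # (1 - t_) * loss_x + t_ * loss_y, sweeping between the two
        # distributions.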
        t_ = torch.rand(args.nt, device=args.device)
        t = torch.stack([t_] * data.shape[0])
        #t = torch.stack([t_] * input_data.shape[0]).transpose(0, 1).reshape(-1, 1)
        tdata = torch.cat([data] * args.nt)
        tdatay = torch.cat([datay] * args.nt)
        t_lossx = transfer_loss(tdata, tdata, args.nt, t, args.eps, args.lp,
                                criticx1, criticx2, generator)
        t_lossy = transfer_loss(tdata, tdatay, args.nt, t, args.eps, args.lp,
                                criticy1, criticy2, generator)
        t_loss = ((1 - t_) * t_lossx + t_ * t_lossy).sum()
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            batchx, titer1 = sample(titer1, test_loader1)
            datax = batchx[0].to(args.device)
            batchy, titer2 = sample(titer2, test_loader2)
            datay = batchy[0].to(args.device)
            evaluate(args.visualiser, datax, datay, generator, 'x',
                     args.device)
            args.visualiser.plot(step=i,
                                 data=t_lossx.detach().cpu().numpy(),
                                 title='Generator loss X')
            args.visualiser.plot(step=i,
                                 data=t_lossy.detach().cpu().numpy(),
                                 title='Generator loss Y')
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)