示例#1
0
    def __init__(self, args):
        """Build the networks, their EMA copies, optimizers and checkpoint handlers.

        Args:
            args: run configuration. Fields read directly here are ``reload``,
                ``model_path``, ``save_path``, ``lr``, ``f_lr``, ``beta1``,
                ``beta2`` and ``weight_decay`` (``build_model`` and
                ``infer_iteration`` may read more).
        """
        super().__init__()
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        resume_iter = infer_iteration('nets', args.reload, args.model_path, args.save_path)
        print(resume_iter)

        self.nets, self.nets_ema = build_model(args)
        # Register every network as an attribute so it becomes a child module of
        # this Solver, e.g. for self.to(self.device) and self.named_children().
        for name, module in self.nets.items():
            print_network(module, name)
            setattr(self, name, module)
        # EMA copies are registered under '<name>_ema'.
        for name, module in self.nets_ema.items():
            setattr(self, name + '_ema', module)

        # One Adam optimizer per trainable network; the mapping network uses its
        # own learning rate (args.f_lr), every other network uses args.lr.
        self.optims = Munch()
        for net in self.nets.keys():
            self.optims[net] = torch.optim.Adam(
                params=self.nets[net].parameters(),
                lr=args.f_lr if net == 'mapping_network' else args.lr,
                betas=[args.beta1, args.beta2],
                weight_decay=args.weight_decay)

        self.ckptios = [
            CheckpointIO(ospj(args.model_path, 'nets:{:06d}.ckpt'), **self.nets),
            CheckpointIO(ospj(args.model_path, 'nets_ema:{:06d}.ckpt'), **self.nets_ema),
            CheckpointIO(ospj(args.model_path, 'optims:{:06d}.ckpt'), **self.optims)]

        self.to(self.device)
        for name, network in self.named_children():
            # Do not initialize the EMA parameters. EMA copies were registered
            # above with the '_ema' suffix, so match that suffix exactly; a plain
            # substring test ('ema' in name) would also skip any regular network
            # whose name merely contains "ema" (e.g. 'semantic_encoder').
            if not name.endswith('_ema'):
                print('Initializing %s...' % name)
                network.apply(he_init)
示例#2
0
def train(args):
    """Train a generator against a discriminator.

    Each outer iteration takes ``args.d_updates`` discriminator steps on
    training batches, followed by one generator step on ``transfer_loss``.
    In the discriminator step:

    - ``d_pos_loss`` is minimised;
    - ``d_neg_loss`` is maximised (its gradient is negated via
      ``backward(mone)``);
    - the term ``gp`` is added with weight 10.

    Every ``args.evaluate`` iterations the generator is switched to eval
    mode, ``evaluate`` is run on fresh noise, the losses are plotted and the
    models are saved.

    Args:
        args: run configuration (``loaders1``, ``device``, ``lr``, ``beta1``,
            ``beta2``, ``z_dim``, ``d_updates``, ``iterations``,
            ``train_batch_size``, ``evaluate``, ``visualiser``,
            ``model_path``, ``checkpoint``, ...).

    Raises:
        ValueError: if ``args.d_updates < 1``. Evaluation uses the last
            discriminator batch and losses, which would otherwise be
            undefined.
    """
    if args.d_updates < 1:
        raise ValueError(
            'args.d_updates must be at least 1, got {!r}'.format(args.d_updates))

    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    discriminator = models['discriminator'].to(args.device)
    print(generator)
    print(discriminator)

    optim_discriminator = optim.Adam(discriminator.parameters(),
                                     lr=args.lr,
                                     betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(),
                                 lr=args.lr,
                                 betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    # Gradient seed of -1: backward(mone) ascends the loss it is applied to.
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        discriminator.train()
        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx[0].to(args.device)
            optim_discriminator.zero_grad()
            d_pos_loss, d_neg_loss, gp = disc_loss_generation(
                data, args.z_dim, discriminator, generator, args.device)
            d_pos_loss.backward()
            d_neg_loss.backward(mone)
            (10 * gp).backward()
            optim_discriminator.step()

        optim_generator.zero_grad()
        t_loss = transfer_loss(args.train_batch_size, args.z_dim,
                               discriminator, generator, args.device)
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            # Evaluate in inference mode (BatchNorm/Dropout); generator.train()
            # at the top of the next iteration restores training mode.
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            noise = torch.randn(data.shape[0], args.z_dim, device=args.device)
            evaluate(args.visualiser, noise, data, generator, i)
            d_loss = (d_pos_loss - d_neg_loss).detach().cpu().numpy()
            args.visualiser.plot(step=i,
                                 data=d_loss,
                                 title='Discriminator loss')
            args.visualiser.plot(step=i,
                                 data=t_loss.detach().cpu().numpy(),
                                 title='Generator loss')
            args.visualiser.plot(step=i,
                                 data=gp.detach().cpu().numpy(),
                                 title='Gradient Penalty')
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
示例#3
0
def train(args):
    """Supervised classifier training with SGD and a cosine learning-rate schedule.

    Runs from the resumed iteration up to ``args.iterations``. Every
    ``args.evaluate`` iterations it evaluates on the validation and test
    loaders and plots the latest training loss. It saves the models whenever
    the validation accuracy beats the best value seen during this call.
    """
    train_loader, valid_loader, test_loader = args.loader

    models = define_models(**vars(args))
    initialize(models, args.reload, args.save_path, args.model_path)

    classifier = models['classifier'].to(args.device)
    print(classifier)
    optimizer = optim.SGD(classifier.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          nesterov=args.nesterov,
                          weight_decay=args.weight_decay)
    # NOTE(review): the cosine schedule always starts from step 0, even when
    # resuming from a later iteration; confirm this is intended.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           args.iterations)

    batches = iter(train_loader)
    start = infer_iteration(next(iter(models)), args.reload, args.model_path,
                            args.save_path)
    t0 = time.time()
    best_accuracy = 0
    for i in range(start, args.iterations):
        classifier.train()
        batch, batches = sample(batches, train_loader)
        data, classes = batch[0].to(args.device), batch[1].to(args.device)

        optimizer.zero_grad()
        loss = classification_loss(data, classes, classifier)
        loss.backward()
        optimizer.step()
        scheduler.step()

        if i % args.evaluate != 0:
            continue

        classifier.eval()
        print('Iter: %s' % i, time.time() - t0)
        valid_accuracy = evaluate(args.visualiser, i, valid_loader,
                                  classifier, 'valid', args.device)
        # The test accuracy is only reported by evaluate(); it is not used here.
        evaluate(args.visualiser, i, test_loader, classifier, 'test',
                 args.device)
        args.visualiser.plot(loss.detach().cpu().numpy(), title='loss', step=i)
        if valid_accuracy > best_accuracy:
            best_accuracy = valid_accuracy
            save_models(models, i, args.model_path, args.checkpoint)
        t0 = time.time()
示例#4
0
def train(args):
    """Train a generator conditioned on a mixing weight ``t`` against two critics.

    Critic training:

    - ``critic1`` is trained with ``critic_loss`` on (domain-1, domain-1)
      batches and ``critic2`` on (domain-1, domain-2) batches;
    - their gradients are negated (``backward(mone)``), so the optimizers
      ascend the critic losses;
    - ``args.d_updates`` such steps are taken per iteration.

    Generator training:

    - for each of 10 draws of ``t ~ Beta(alpha, alpha)``, the squared
      transfer losses are combined as ``(1 - t) * loss1 + t * loss2``;
    - gradients from the 10 draws are accumulated before a single optimizer
      step.

    Every ``args.evaluate`` iterations a test batch from each domain is
    passed to ``evaluate``, the latest losses are plotted and the models are
    saved.
    """
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    critic1 = models['critic1'].to(args.device)
    critic2 = models['critic2'].to(args.device)
    print(generator)
    print(critic1)

    optim_critic1 = optim.Adam(critic1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_critic2 = optim.Adam(critic2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    iter1, iter2 = iter(train_loader1), iter(train_loader2)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    titer1, titer2 = iter(test_loader1), iter(test_loader2)
    # Gradient seed of -1: backward(mone) ascends the critic losses.
    mone = torch.FloatTensor([-1]).to(args.device)
    # The distribution of the mixing weight t is loop-invariant; build it once.
    t_dist = torch.distributions.beta.Beta(args.alpha, args.alpha)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        critic1.train()
        critic2.train()

        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            data1 = batchx.to(args.device)
            batchy, iter2 = sample(iter2, train_loader2)
            data2 = batchy.to(args.device)

            optim_critic1.zero_grad()
            r_loss1 = critic_loss(data1, data1, args.alpha, critic1, generator, args.device)
            r_loss1.backward(mone)
            optim_critic1.step()

            optim_critic2.zero_grad()
            r_loss2 = critic_loss(data1, data2, args.alpha, critic2, generator, args.device)
            r_loss2.backward(mone)
            optim_critic2.step()

        optim_generator.zero_grad()
        # Gradients are accumulated over 10 independent draws of t before the step.
        for _ in range(10):
            # sample((1,)) is the non-deprecated equivalent of sample_n(1):
            # one draw with shape (1,).
            t_ = t_dist.sample((1,)).to(args.device)
            t = torch.stack([t_] * data1.shape[0])
            t_loss1 = transfer_loss(data1, data1, t, critic1, generator, args.device)**2
            t_loss2 = transfer_loss(data1, data2, t, critic2, generator, args.device)**2
            ((1 - t_) * t_loss1 + t_ * t_loss2).backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            batchx, titer1 = sample(titer1, test_loader1)
            data1 = batchx.to(args.device)
            batchy, titer2 = sample(titer2, test_loader2)
            data2 = batchy.to(args.device)
            evaluate(args.visualiser, data1, data2, args.z_dim, generator, i, args.device)
            args.visualiser.plot(step=i, data=r_loss1.detach().cpu().numpy(), title='Critic loss 1')
            args.visualiser.plot(step=i, data=r_loss2.detach().cpu().numpy(), title='Critic loss 2')
            args.visualiser.plot(step=i, data=t_loss1.detach().cpu().numpy(), title='Generator loss 1')
            args.visualiser.plot(step=i, data=t_loss2.detach().cpu().numpy(), title='Generator loss 2')
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
示例#5
0
def train(args):
    """Train two translators (g12, g21) and two discriminators (d1, d2).

    Each iteration:

    1. draws a full-size batch from each training domain (a short batch is
       resampled once);
    2. steps d1 on ``disc_loss(data2, data1, g21, d1)`` and d2 on
       ``disc_loss(data1, data2, g12, d2)``;
    3. takes two generator steps, on ``generator_loss`` from domain 1 and
       then from domain 2, each updating both g12 and g21.

    Every ``args.evaluate`` iterations, ``evaluate`` is called on a test batch
    from each domain, and accuracies are reported with ``args.evaluation``.
    The losses are plotted and the models saved.
    """
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    g12, g21, d1, d2 = (models[key].to(args.device)
                        for key in ('g12', 'g21', 'd1', 'd2'))
    eval_model = args.evaluation.eval().to(args.device)
    for net in (g12, g21, d1, d2):
        print(net)

    def make_adam(net):
        return optim.Adam(net.parameters(),
                          lr=args.lr,
                          betas=(args.beta1, args.beta2))

    optim_g12, optim_g21, optim_d1, optim_d2 = (
        make_adam(net) for net in (g12, g21, d1, d2))

    iter1, iter2 = iter(train_loader1), iter(train_loader2)
    titer1, titer2 = iter(test_loader1), iter(test_loader2)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)

    def full_batch(it, loader):
        # Draw a batch; if it is not exactly train_batch_size, draw once more.
        batch, it = sample(it, loader)
        data = batch[0].to(args.device)
        if data.shape[0] != args.train_batch_size:
            batch, it = sample(it, loader)
            data = batch[0].to(args.device)
        return data, it

    def step(loss, *optimizers):
        # Zero the given optimizers, backpropagate loss, then step them in order.
        for opt in optimizers:
            opt.zero_grad()
        loss.backward()
        for opt in optimizers:
            opt.step()

    t0 = time.time()
    for i in range(iteration, args.iterations):
        for net in (g12, g21, d1, d2):
            net.train()
        data1, iter1 = full_batch(iter1, train_loader1)
        data2, iter2 = full_batch(iter2, train_loader2)

        dloss1 = disc_loss(data2, data1, g21, d1)
        step(dloss1, optim_d1)
        dloss2 = disc_loss(data1, data2, g12, d2)
        step(dloss2, optim_d2)

        gloss1 = generator_loss(data1, g12, g21, d2)
        step(gloss1, optim_g12, optim_g21)
        gloss2 = generator_loss(data2, g21, g12, d1)
        step(gloss2, optim_g12, optim_g21)

        if i % args.evaluate != 0:
            continue

        g12.eval()
        g21.eval()
        batchx, titer1 = sample(titer1, test_loader1)
        test1 = batchx[0].to(args.device)
        batchy, titer2 = sample(titer2, test_loader2)
        test2 = batchy[0].to(args.device)
        print('Iter: %s' % i, time.time() - t0)
        evaluate(args.visualiser, test1, g12, 'x')
        evaluate(args.visualiser, test2, g21, 'y')
        evaluate_gen_class_accuracy(args.visualiser, i, test_loader1,
                                    eval_model, g12, args.device)
        evaluate_class_accuracy(args.visualiser, i, test_loader2,
                                eval_model, args.device)
        curves = ((dloss1, 'Discriminator loss1'),
                  (dloss2, 'Discriminator loss2'),
                  (gloss1, 'Generator loss 1-2'),
                  (gloss2, 'Generator loss 2-1'))
        for value, title in curves:
            args.visualiser.plot(value.cpu().detach().numpy(),
                                 title=title,
                                 step=i)
        t0 = time.time()
        save_models(models, i, args.model_path, args.checkpoint)
示例#6
0
def train(args):
    """Train an encoder together with a contrastive critic.

    Each iteration:

    1. takes ``args.d_updates`` steps of the ``contrastive`` network,
       minimising ``ploss - nloss + gp``;
    2. takes one encoder step on ``args.ld * dloss + closs`` from
       ``compute_loss``.

    Every ``args.evaluate`` iterations it prints the iteration, elapsed time,
    losses and ``evaluate_accuracy`` on the test loader, then saves the
    models.

    Raises:
        ValueError: if ``args.d_updates < 1``. The logged ``ploss``,
            ``nloss`` and ``gp`` come from the last contrastive step and would
            otherwise be undefined.
    """
    if args.d_updates < 1:
        raise ValueError(
            'args.d_updates must be at least 1, got {!r}'.format(args.d_updates))

    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    encoder = models['classifier'].to(args.device)
    contrastive = models['contrastive'].to(args.device)
    print(encoder)
    print(contrastive)

    optim_encoder = optim.Adam(encoder.parameters(),
                               lr=args.lr,
                               betas=(args.beta1, args.beta2))
    optim_contrastive = optim.Adam(contrastive.parameters(),
                                   lr=args.lr,
                                   betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        encoder.train()
        contrastive.train()

        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            datax = batchx[0].float().to(args.device)

            optim_contrastive.zero_grad()
            ploss, nloss, gp = contrastive_loss(datax, args.nc, encoder,
                                                contrastive, args.device)
            (ploss - nloss + gp).backward()
            optim_contrastive.step()

        # Zeroing here also discards encoder gradients accumulated by the
        # contrastive steps above.
        optim_encoder.zero_grad()

        batchx, iter1 = sample(iter1, train_loader1)
        datax = batchx[0].float().to(args.device)
        dataxp = batchx[1].float().to(args.device)

        dloss, closs = compute_loss(datax, dataxp, encoder, contrastive,
                                    args.device)
        (args.ld * dloss + closs).backward()
        optim_encoder.step()

        if i % args.evaluate == 0:
            encoder.eval()
            contrastive.eval()
            # Report the time since the previous evaluation (t0 is reset below).
            print('Iter: {} ({:.2f}s)'.format(i, time.time() - t0), end=': ')
            evaluate(args.visualiser, datax, dataxp, 'x')
            accuracy = evaluate_accuracy(args.visualiser, i, test_loader1,
                                         encoder, args.nc, 'x', args.device)
            print('disc loss: {}'.format(
                (ploss - nloss).detach().cpu().numpy()),
                  end='\t')
            print('gp: {}'.format(gp.detach().cpu().numpy()), end='\t')
            print('positive dist loss: {}'.format(
                dloss.detach().cpu().numpy()),
                  end='\t')
            print('contrast. loss: {}'.format(closs.detach().cpu().numpy()),
                  end='\t')
            print('Accuracy: {}'.format(accuracy))

            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
示例#7
0
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    ssx = args.ssx.to(args.device)
    ssx.eval()

    zxs, labelsx = get_initial_zx(train_loader1, ssx, args.device)
    zys, labelsy = get_initial_zx(train_loader2, ssx, args.device)

    sc = SpectralClustering(args.nc, affinity='sigmoid', gamma=1.7)
    clusters = sc.fit_predict(zxs.cpu().numpy())
    clusters = torch.from_numpy(clusters).to(args.device)

    classifier = models['classifier'].to(args.device)
    discriminator = models['discriminator'].to(args.device)
    classifier.apply(he_init)
    discriminator.apply(he_init)
    print(classifier)
    print(discriminator)

    optim_discriminator = optim.Adam(discriminator.parameters(),
                                     lr=args.lr,
                                     betas=(args.beta1, args.beta2))
    optim_classifier = optim.Adam(classifier.parameters(),
                                  lr=args.lr,
                                  betas=(args.beta1, args.beta2))
    optims = {
        'optim_discriminator': optim_discriminator,
        'optim_classifier': optim_classifier
    }

    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        classifier.train()
        discriminator.train()

        perm = torch.randperm(len(zxs))
        ix = perm[:args.train_batch_size]
        zx = zxs[ix]
        perm = torch.randperm(len(zys))
        iy = perm[:args.train_batch_size]
        zy = zys[iy]

        optim_discriminator.zero_grad()
        d_loss = disc_loss(zx, zy, discriminator, classifier.x, classifier.mlp,
                           args.device)
        d_loss.backward()
        optim_discriminator.step()

        perm = torch.randperm(len(zxs))
        ix = perm[:args.train_batch_size]
        zx = zxs[ix]
        label = clusters[ix].long()
        perm = torch.randperm(len(zys))
        iy = perm[:args.train_batch_size]
        zy = zys[iy]

        optim_classifier.zero_grad()
        c_loss = classification_loss(zx, label, classifier)
        tcw_loss = classification_target_loss(zy, classifier)
        dw_loss = embed_div_loss(zx, zy, discriminator, classifier.x,
                                 classifier.mlp, args.device)
        m_loss1 = mixup_loss(zx, classifier, args.device)
        m_loss2 = mixup_loss(zy, classifier, args.device)
        (args.cw * c_loss).backward()
        (args.tcw * tcw_loss).backward()
        (args.dw * dw_loss).backward()
        (args.smw * m_loss1).backward()
        (args.tmw * m_loss2).backward()
        optim_classifier.step()

        if i % args.evaluate == 0:
            print('Iter: %s' % i, time.time() - t0)
            classifier.eval()

            class_map = evaluate_cluster(args.visualiser, i, args.nc, zxs,
                                         labelsx, classifier, f'x',
                                         args.device)
            evaluate_cluster_accuracy(args.visualiser, i, zxs, labelsx,
                                      class_map, classifier, f'x', args.device)
            evaluate_cluster_accuracy(args.visualiser, i, zys, labelsy,
                                      class_map, classifier, f'y', args.device)

            save_path = args.save_path
            with open(os.path.join(save_path, 'c_loss'), 'a') as f:
                f.write(f'{i},{c_loss.cpu().item()}\n')
            with open(os.path.join(save_path, 'tcw_loss'), 'a') as f:
                f.write(f'{i},{tcw_loss.cpu().item()}\n')
            with open(os.path.join(save_path, 'dw_loss'), 'a') as f:
                f.write(f'{i},{dw_loss.cpu().item()}\n')
            with open(os.path.join(save_path, 'm_loss1'), 'a') as f:
                f.write(f'{i},{m_loss1.cpu().item()}\n')
            with open(os.path.join(save_path, 'm_loss2'), 'a') as f:
                f.write(f'{i},{m_loss2.cpu().item()}\n')
            with open(os.path.join(save_path, 'd_loss2'), 'a') as f:
                f.write(f'{i},{d_loss.cpu().item()}\n')
            args.visualiser.plot(c_loss.cpu().detach().numpy(),
                                 title='Source classifier loss',
                                 step=i)
            args.visualiser.plot(tcw_loss.cpu().detach().numpy(),
                                 title='Target classifier cross entropy',
                                 step=i)
            args.visualiser.plot(dw_loss.cpu().detach().numpy(),
                                 title='Classifier marginal divergence',
                                 step=i)
            args.visualiser.plot(m_loss1.cpu().detach().numpy(),
                                 title='Source mix up loss',
                                 step=i)
            args.visualiser.plot(m_loss2.cpu().detach().numpy(),
                                 title='Target mix up loss',
                                 step=i)
            args.visualiser.plot(d_loss.cpu().detach().numpy(),
                                 title='Discriminator loss',
                                 step=i)
            t0 = time.time()
            save_models(models, i, args.model_path, args.evaluate)
            save_models(optims, i, args.model_path, args.evaluate)
示例#8
0
def _sample_domain_batches(iters, loaders, device):
    """Draw one batch from each of the four training loaders and apply channel masks.

    Args:
        iters: one iterator per loader, in the same order as ``loaders``.
        loaders: the four training loaders.
        device: device each batch is moved to.

    Returns:
        ``(batches, iters)``:

        - ``batches`` holds ``batch[0]`` of each draw, on ``device``. Batch 1
          is left as is. In batches 2, 3 and 4, channels (1, 2), (0, 2) and
          (0, 1) respectively are multiplied by zero in place.
        - ``iters`` holds the updated iterators returned by ``sample``.
    """
    zeroed_channels = ((), (1, 2), (0, 2), (0, 1))
    batches, new_iters = [], []
    for it, loader, channels in zip(iters, loaders, zeroed_channels):
        batch, it = sample(it, loader)
        data = batch[0].to(device)
        for c in channels:
            data[:, c] = data[:, c] * 0
        batches.append(data)
        new_iters.append(it)
    return batches, new_iters


def _update_critic_pairs(data, targets, critic_pairs, eps, lp, mone):
    """Take one step on every critic pair and return the last pair's losses.

    Each ``(critic_a, critic_b, optim_a, optim_b)`` in ``critic_pairs`` is
    paired with the batch at the same position in ``targets``.
    ``disc_loss_generation`` is evaluated between ``data`` and that batch. The
    gradient of ``r_loss + g_loss + p`` is seeded with ``mone`` (negated)
    before both optimizers step.

    Returns:
        ``(r_loss, g_loss, p)`` from the final pair. ``critic_pairs`` must be
        non-empty.
    """
    for target, (critic_a, critic_b, optim_a, optim_b) in zip(targets, critic_pairs):
        optim_a.zero_grad()
        optim_b.zero_grad()
        r_loss, g_loss, p = disc_loss_generation(data, target, eps, lp, critic_a, critic_b)
        (r_loss + g_loss + p).backward(mone)
        optim_a.step()
        optim_b.step()
    return r_loss, g_loss, p


def train(args):
    """Train a generator against three critic pairs using four domains.

    Batches from domains 2-4 have two channels zeroed (see
    ``_sample_domain_batches``). Each critic pair (x, y, z) is trained between
    the domain-1 batch and one masked batch. The generator is trained on the
    three ``transfer_loss`` terms, weighted by a mixing vector
    ``t ~ Dirichlet(1, 1, 1)``.

    Before the main loop, an optional critic-only warm-up runs for
    ``args.critic_pretrain_iters`` iterations. It defaults to 0, matching the
    previously hard-coded, disabled warm-up.
    """
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2
    train_loader3, test_loader3 = args.loaders3
    train_loader4, test_loader4 = args.loaders4
    train_loaders = (train_loader1, train_loader2, train_loader3, train_loader4)

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    criticx1 = models['criticx1'].to(args.device)
    criticx2 = models['criticx2'].to(args.device)
    criticy1 = models['criticy1'].to(args.device)
    criticy2 = models['criticy2'].to(args.device)
    criticz1 = models['criticz1'].to(args.device)
    criticz2 = models['criticz2'].to(args.device)
    print(generator)
    print(criticx1)
    print(criticy1)

    def make_adam(net):
        return optim.Adam(net.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    optim_criticx1 = make_adam(criticx1)
    optim_criticx2 = make_adam(criticx2)
    optim_criticy1 = make_adam(criticy1)
    optim_criticy2 = make_adam(criticy2)
    optim_criticz1 = make_adam(criticz1)
    optim_criticz2 = make_adam(criticz2)
    optim_generator = make_adam(generator)

    # Pair i is trained between domain 1 and masked domain i + 2.
    critic_pairs = (
        (criticx1, criticx2, optim_criticx1, optim_criticx2),
        (criticy1, criticy2, optim_criticy1, optim_criticy2),
        (criticz1, criticz2, optim_criticz1, optim_criticz2),
    )
    all_nets = (generator, criticx1, criticx2, criticy1, criticy2, criticz1, criticz2)

    iters = [iter(loader) for loader in train_loaders]
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    # NOTE(review): the test iterators are currently unused. They are still
    # created, in the original order, so that iterator creation consumes the
    # RNG exactly as before.
    _test_iters = [iter(test_loader1), iter(test_loader2), iter(test_loader3), iter(test_loader4)]
    # Gradient seed of -1: backward(mone) ascends the critic objective.
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()

    for net in all_nets:
        net.train()
    # Optional critic-only warm-up (disabled by default).
    for i in range(getattr(args, 'critic_pretrain_iters', 0)):
        batches, iters = _sample_domain_batches(iters, train_loaders, args.device)
        r_loss, g_loss, p = _update_critic_pairs(
            batches[0], batches[1:], critic_pairs, args.eps, args.lp, mone)
        if i % 100 == 0:
            print(f'Critics-{i}')
            print('Iter: %s' % i, time.time() - t0)
            args.visualiser.plot(step=i, data=p.detach().cpu().numpy(), title='Penalty')
            d_loss = (r_loss + g_loss).detach().cpu().numpy()
            args.visualiser.plot(step=i, data=d_loss, title='Critic loss Y')
            t0 = time.time()

    for i in range(iteration, args.iterations):
        for net in all_nets:
            net.train()

        for _ in range(args.d_updates):
            batches, iters = _sample_domain_batches(iters, train_loaders, args.device)
            r_loss, g_loss, p = _update_critic_pairs(
                batches[0], batches[1:], critic_pairs, args.eps, args.lp, mone)

        (data, datay, dataz, dataw), iters = _sample_domain_batches(
            iters, train_loaders, args.device)

        optim_generator.zero_grad()
        t_ = Dirichlet(torch.FloatTensor([1., 1., 1.])).sample().to(args.device)
        t = torch.stack([t_] * data.shape[0])
        t_lossx = transfer_loss(data, datay, t, args.eps, args.lp, criticx1, criticx2, generator)
        t_lossy = transfer_loss(data, dataz, t, args.eps, args.lp, criticy1, criticy2, generator)
        t_lossz = transfer_loss(data, dataw, t, args.eps, args.lp, criticz1, criticz2, generator)
        t_loss = (t_[0] * t_lossx + t_[1] * t_lossy + t_[2] * t_lossz).sum()
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            evaluate(args.visualiser, data, datay, dataz, dataw, generator, 'x', args.device)
            args.visualiser.plot(step=i, data=t_lossx.detach().cpu().numpy(), title='Generator loss X')
            args.visualiser.plot(step=i, data=t_lossy.detach().cpu().numpy(), title='Generator loss Y')
            args.visualiser.plot(step=i, data=t_lossz.detach().cpu().numpy(), title='Generator loss Z')
            # NOTE(review): p, r_loss and g_loss come from the last critic pair
            # updated (criticz1/criticz2), although the title below says 'Y'.
            args.visualiser.plot(step=i, data=p.detach().cpu().numpy(), title='Penalty')
            d_loss = (r_loss + g_loss).detach().cpu().numpy()
            args.visualiser.plot(step=i, data=d_loss, title='Critic loss Y')
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
示例#9
0
def train(args):
    """Jointly train an encoder, a generator and a critic.

    Per iteration:

    - 10 encoder steps on ``encoder_loss``, each also stepping the generator;
    - one critic step whose gradient is negated via ``backward(mone)``;
    - one generator step on ``transfer_loss``.

    Every ``args.evaluate`` iterations a test batch is passed to ``evaluate``,
    the latest losses are plotted and the models are saved.
    """
    train_loader1, test_loader1 = args.loaders1

    models = define_models(**vars(args))
    initialize(models, args.reload, args.save_path, args.model_path)

    generator, critic, encoder = (
        models[name].to(args.device) for name in ('generator', 'critic', 'encoder'))
    for net in (generator, critic, encoder):
        print(net)

    def make_adam(net):
        return optim.Adam(net.parameters(),
                          lr=args.lr,
                          betas=(args.beta1, args.beta2))

    optim_critic, optim_generator, optim_encoder = (
        make_adam(net) for net in (critic, generator, encoder))

    iter1 = iter(train_loader1)
    iteration = infer_iteration(
        next(iter(models)), args.reload, args.model_path, args.save_path)
    titer1 = iter(test_loader1)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()

    def next_batch(it):
        # Draw one training batch and return its first element on the device.
        batch, it = sample(it, train_loader1)
        return batch[0].to(args.device), it

    for i in range(iteration, args.iterations):
        for net in (generator, critic, encoder):
            net.train()

        # Encoder updates; the generator is stepped on the same loss.
        for _ in range(10):
            data, iter1 = next_batch(iter1)
            optim_encoder.zero_grad()
            optim_generator.zero_grad()
            e_loss = encoder_loss(data.shape[0], args.lp, args.z_dim, encoder,
                                  generator, critic, args.device)
            e_loss.backward()
            optim_encoder.step()
            optim_generator.step()

        # One critic update (ascent: the gradient is seeded with -1).
        data, iter1 = next_batch(iter1)
        optim_critic.zero_grad()
        r_loss = critic_loss(data, args.lp, args.z_dim, encoder, critic,
                             generator, args.device)
        r_loss.backward(mone)
        optim_critic.step()

        # One generator update.
        data, iter1 = next_batch(iter1)
        optim_generator.zero_grad()
        t_loss = transfer_loss(data.shape[0], args.lp, args.z_dim, encoder,
                               critic, generator, args.device)
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate != 0:
            continue

        encoder.eval()
        generator.eval()
        print('Iter: %s' % i, time.time() - t0)
        batchx, titer1 = sample(titer1, test_loader1)
        data = batchx[0].to(args.device)
        evaluate(args.visualiser, args.z_dim, data, encoder, generator,
                 critic, args.z_dim, i, args.device)
        for value, title in ((r_loss, 'Critic loss'),
                             (e_loss, 'Encoder loss'),
                             (t_loss, 'Generator loss')):
            args.visualiser.plot(step=i,
                                 data=value.detach().cpu().numpy(),
                                 title=title)
        t0 = time.time()
        save_models(models, i, args.model_path, args.checkpoint)
示例#10
0
文件: train.py 项目: alexmlamb/SPUDT
def train(args):
    """Train a classifier for source->target domain adaptation.

    Every iteration:

    1. draws one source batch (images + labels) from ``args.loaders1`` and
       one target batch (images only) from ``args.loaders2``;
    2. updates ``discriminator`` with ``disc_loss``, which is given the
       classifier's ``x`` and ``mlp`` sub-modules;
    3. updates ``classifier`` with seven separately back-propagated terms,
       weighted by ``args.cw``, ``args.tcw``, ``args.dw``, ``args.svw``,
       ``args.tvw``, ``args.smw`` and ``args.tmw``.

    Every ``args.evaluate`` iterations the test accuracy on both domains is
    computed, each loss/accuracy is appended as ``"<iter>,<value>"`` to its
    own file under ``args.save_path`` and plotted via ``args.visualiser``,
    and both the models and the optimizers are saved.
    """
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    classifier = models['classifier'].to(args.device)
    discriminator = models['discriminator'].to(args.device)
    print(classifier)
    print(discriminator)

    adam_kwargs = {'lr': args.lr, 'betas': (args.beta1, args.beta2)}
    optim_discriminator = optim.Adam(discriminator.parameters(), **adam_kwargs)
    optim_classifier = optim.Adam(classifier.parameters(), **adam_kwargs)
    optims = {
        'optim_discriminator': optim_discriminator,
        'optim_classifier': optim_classifier,
    }
    # Optimizer state goes through the same initialize() call as the models.
    initialize(optims, args.reload, args.save_path, args.model_path)

    source_it = iter(train_loader1)
    target_it = iter(train_loader2)
    first_step = infer_iteration(list(models)[0], args.reload,
                                 args.model_path, args.save_path)
    tic = time.time()
    for step in range(first_step, args.iterations):
        classifier.train()
        discriminator.train()

        # A batch whose size differs from train_batch_size (presumably the
        # last, partial batch of an epoch) is replaced by one fresh draw.
        source_batch, source_it = sample(source_it, train_loader1)
        src = source_batch[0].to(args.device)
        if src.shape[0] != args.train_batch_size:
            source_batch, source_it = sample(source_it, train_loader1)
            src = source_batch[0].to(args.device)
        src_labels = source_batch[1].to(args.device)

        target_batch, target_it = sample(target_it, train_loader2)
        tgt = target_batch[0].to(args.device)
        if tgt.shape[0] != args.train_batch_size:
            target_batch, target_it = sample(target_it, train_loader2)
            tgt = target_batch[0].to(args.device)

        # --- discriminator update ---
        optim_discriminator.zero_grad()
        d_loss = disc_loss(src, tgt, discriminator, classifier.x,
                           classifier.mlp, args.device)
        d_loss.backward()
        optim_discriminator.step()

        # --- classifier update ---
        # zero_grad also discards whatever gradients d_loss.backward() may
        # have left on the classifier's sub-modules.
        optim_classifier.zero_grad()
        c_loss = classification_loss(src, src_labels, classifier)
        tcw_loss = classification_target_loss(tgt, classifier)
        dw_loss = embed_div_loss(src, tgt, discriminator, classifier.x,
                                 classifier.mlp, args.device)
        v_loss1 = vat_loss(src, classifier, args.radius, args.device)
        v_loss2 = vat_loss(tgt, classifier, args.radius, args.device)
        m_loss1 = mixup_loss(src, classifier, args.device)
        m_loss2 = mixup_loss(tgt, classifier, args.device)
        weighted_terms = (
            (args.cw, c_loss),
            (args.tcw, tcw_loss),
            (args.dw, dw_loss),
            (args.svw, v_loss1),
            (args.tvw, v_loss2),
            (args.smw, m_loss1),
            (args.tmw, m_loss2),
        )
        for weight, term in weighted_terms:
            (weight * term).backward()
        optim_classifier.step()

        if step % args.evaluate != 0:
            continue

        print('Iter: %s' % step, time.time() - tic)
        classifier.eval()

        test_accuracy_x = evaluate(test_loader1, classifier, args.device)
        test_accuracy_y = evaluate(test_loader2, classifier, args.device)

        # Each value is appended as "<iter>,<value>" to a file of that name.
        # NOTE(review): the discriminator loss is written to 'd_loss2'.
        records = (
            ('c_loss', c_loss.cpu().item()),
            ('tcw_loss', tcw_loss.cpu().item()),
            ('dw_loss', dw_loss.cpu().item()),
            ('v_loss1', v_loss1.cpu().item()),
            ('v_loss2', v_loss2.cpu().item()),
            ('m_loss1', m_loss1.cpu().item()),
            ('m_loss2', m_loss2.cpu().item()),
            ('d_loss2', d_loss.cpu().item()),
            ('test_accuracy_x', test_accuracy_x),
            ('test_accuracy_y', test_accuracy_y),
        )
        for file_name, value in records:
            with open(os.path.join(args.save_path, file_name), 'a') as f:
                f.write(f'{step},{value}\n')

        curves = (
            (c_loss, 'Source classifier loss'),
            (tcw_loss, 'Target classifier cross entropy'),
            (dw_loss, 'Classifier marginal divergence'),
            (v_loss1, 'Source virtual adversarial loss'),
            (v_loss2, 'Target virtual adversarial loss'),
            (m_loss1, 'Source mix up loss'),
            (m_loss2, 'Target mix up loss'),
            (d_loss, 'Discriminator loss'),
        )
        for loss, title in curves:
            args.visualiser.plot(loss.cpu().detach().numpy(),
                                 title=title,
                                 step=step)
        args.visualiser.plot(test_accuracy_x, title='Test acc X', step=step)
        args.visualiser.plot(test_accuracy_y, title='Test acc Y', step=step)

        tic = time.time()
        save_models(models, step, args.model_path, args.evaluate)
        save_models(optims, step, args.model_path, args.evaluate)
示例#11
0
def train(args):
    """Train an X->Y generator against a critic plus a semantic loss.

    ``args.evalY`` and ``args.semantic`` are moved to ``args.device`` and put
    in eval mode; neither is given an optimizer here. Per iteration the
    critic takes ``args.d_updates`` steps on ``compute_critic_loss``, and the
    generator takes one step on ``generator_loss`` plus
    ``args.gsxy * semantic_loss`` (back-propagated as two separate terms).
    Every ``args.evaluate`` iterations a transfer plot is drawn, the X->Y
    test accuracy under ``evalY`` is computed, values are appended to files
    under ``args.save_path`` and plotted, and the models are saved.
    """
    parameters = vars(args)
    train_loader1, _valid_loader1, test_loader1 = args.loaders1
    train_loader2, _valid_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    critic = models['critic'].to(args.device)
    generator = models['generator'].to(args.device)
    evalY = args.evalY.to(args.device).eval()
    semantic = args.semantic.to(args.device).eval()
    for net in (generator, critic, semantic):
        print(net)

    def adam_for(net):
        # Same Adam hyper-parameters for every trained network.
        return optim.Adam(net.parameters(),
                          lr=args.lr,
                          betas=(args.beta1, args.beta2))

    optim_critic = adam_for(critic)
    optim_generator = adam_for(generator)

    x_it, y_it = iter(train_loader1), iter(train_loader2)
    first_step = infer_iteration(list(models)[0], args.reload,
                                 args.model_path, args.save_path)
    tic = time.time()
    for step in range(first_step, args.iterations):
        critic.train()
        generator.train()

        # --- critic updates ---
        for _ in range(args.d_updates):
            x_batch, x_it = sample(x_it, train_loader1)
            datax = x_batch[0].to(args.device)
            y_batch, y_it = sample(y_it, train_loader2)
            datay = y_batch[0].to(args.device)

            critic_lossy = compute_critic_loss(datax, args.z_dim, datay,
                                               critic, generator, args.device)
            optim_critic.zero_grad()
            critic_lossy.backward()
            optim_critic.step()

        # --- generator update on a fresh batch pair (reused for plotting) ---
        x_batch, x_it = sample(x_it, train_loader1)
        datax = x_batch[0].to(args.device)
        y_batch, y_it = sample(y_it, train_loader2)
        datay = y_batch[0].to(args.device)

        glossxy = generator_loss(datax, args.z_dim, critic, generator,
                                 args.device)
        slossxy = semantic_loss(datax, args.z_dim, generator, semantic,
                                args.device)
        optim_generator.zero_grad()
        glossxy.backward()
        (args.gsxy * slossxy).backward()
        optim_generator.step()

        if step % args.evaluate != 0:
            continue

        print('Iter: %s' % step, time.time() - tic)
        generator.eval()
        plot_transfer(args.visualiser, datax, datay, args.z_dim, generator,
                      'x-y', step, args.device)
        test_accuracy_xy = evaluate(test_loader1, args.z_dim, generator,
                                    evalY, args.device)

        # Append "<iter>,<value>" lines, one file per quantity.
        records = (
            ('glossxy', glossxy.cpu().item()),
            ('slossxy', slossxy.cpu().item()),
            ('test_accuracy_xy', test_accuracy_xy),
        )
        for file_name, value in records:
            with open(os.path.join(args.save_path, file_name), 'a') as f:
                f.write(f'{step},{value}\n')

        for loss, title in ((critic_lossy, 'critic_lossy'),
                            (glossxy, 'glossxy'),
                            (slossxy, 'slossxy')):
            args.visualiser.plot(loss.cpu().detach().numpy(),
                                 title=title,
                                 step=step)
        args.visualiser.plot(test_accuracy_xy,
                             title='Test transfer accuracy X-Y',
                             step=step)
        tic = time.time()
        # NOTE(review): every save uses iteration 0, so each checkpoint
        # presumably overwrites the previous one and a reload restarts the
        # loop from 0 — confirm this is intended.
        save_models(models, 0, args.model_path, args.checkpoint)
示例#12
0
def train(args):
    """Train a classifier on pseudo-labels from a fixed clustering network.

    ``args.cluster`` is put in eval mode and never optimised here; the argmax
    of its output on each source batch is used as the classification target
    for ``classifier``. Each iteration then performs:

    * a discriminator step on ``disc_loss`` (source vs. target batch), and
    * a classifier step on separately back-propagated terms weighted by
      ``args.cw``, ``args.tcw``, ``args.dw``, ``args.svw``, ``args.tvw``,
      ``args.smw`` and ``args.tmw`` (``classification_loss``,
      ``classification_target_loss``, ``embed_div_loss``, ``vat_loss`` and
      ``mixup_loss`` on both domains).

    Every ``args.evaluate`` iterations the cluster-to-class mapping and the
    classifier's accuracy on both test sets are reported through
    ``args.visualiser``, and the models are saved.
    """
    import torch  # local import: needed for torch.no_grad below

    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    classifier = models['classifier'].to(args.device)
    discriminator = models['discriminator'].to(args.device)
    cluster = args.cluster.eval().to(args.device)
    print(classifier)
    print(discriminator)

    optim_classifier = optim.Adam(classifier.parameters(),
                                  lr=args.lr,
                                  betas=(args.beta1, args.beta2))
    optim_discriminator = optim.Adam(discriminator.parameters(),
                                     lr=args.lr,
                                     betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iter2 = iter(train_loader2)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        classifier.train()
        discriminator.train()

        # A batch whose size differs from train_batch_size is replaced once.
        batchx, iter1 = sample(iter1, train_loader1)
        data1 = batchx[0].to(args.device)
        if data1.shape[0] != args.train_batch_size:
            batchx, iter1 = sample(iter1, train_loader1)
            data1 = batchx[0].to(args.device)
        # Pseudo-labels are fixed targets: build no autograd graph for the
        # cluster forward pass (previously built and then discarded).
        with torch.no_grad():
            label = cluster(data1).argmax(1)

        batchy, iter2 = sample(iter2, train_loader2)
        data2 = batchy[0].to(args.device)
        if data2.shape[0] != args.train_batch_size:
            batchy, iter2 = sample(iter2, train_loader2)
            data2 = batchy[0].to(args.device)

        # --- discriminator update ---
        optim_discriminator.zero_grad()
        d_loss = disc_loss(data1, data2, discriminator, classifier.x,
                           classifier.mlp, args.device)
        d_loss.backward()
        optim_discriminator.step()

        # --- classifier update ---
        optim_classifier.zero_grad()
        c_loss = classification_loss(data1, label, classifier)
        tcw_loss = classification_target_loss(data2, classifier)
        dw_loss = embed_div_loss(data1, data2, discriminator, classifier.x,
                                 classifier.mlp, args.device)
        v_loss1 = vat_loss(data1, classifier, args.radius, args.device)
        v_loss2 = vat_loss(data2, classifier, args.radius, args.device)
        m_loss1 = mixup_loss(data1, classifier, args.device)
        m_loss2 = mixup_loss(data2, classifier, args.device)
        (args.cw * c_loss).backward()
        (args.tcw * tcw_loss).backward()
        (args.dw * dw_loss).backward()
        (args.svw * v_loss1).backward()
        (args.tvw * v_loss2).backward()
        (args.smw * m_loss1).backward()
        (args.tmw * m_loss2).backward()
        optim_classifier.step()

        if i % args.evaluate == 0:
            classifier.eval()
            # Evaluation works directly on the test loaders; the test batches
            # that used to be sampled here were never used.
            print('Iter: %s' % i, time.time() - t0)
            class_map = evaluate_cluster(args.visualiser, i, args.nc,
                                         test_loader1, cluster, 'x',
                                         args.device)
            evaluate_cluster_accuracy(args.visualiser, i, test_loader1,
                                      class_map, classifier, 'x', args.device)
            evaluate_cluster_accuracy(args.visualiser, i, test_loader2,
                                      class_map, classifier, 'y', args.device)
            args.visualiser.plot(c_loss.cpu().detach().numpy(),
                                 title='Classifier loss',
                                 step=i)
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
示例#13
0
def train(args):
    """Train a generator with a Sinkhorn objective towards two domains.

    Each iteration maps a batch ``x`` from ``args.loaders1`` through
    ``generator`` and minimises

        (1 - alphat) * S(G(x), x) ** p_exp + alphat * S(G(x), y) ** p_exp

    where ``S`` is ``sinkhorn_loss`` (called with ``args.eps`` and
    ``p=args.lp``) and ``y`` is a batch from ``args.loaders2``. Every
    ``args.evaluate`` iterations a test pair is visualised, both Sinkhorn
    terms are plotted and the models are saved.
    """
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    print(generator)

    optim_generator = optim.SGD(generator.parameters(), lr=args.lr)

    iter1, iter2 = iter(train_loader1), iter(train_loader2)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    titer1, titer2 = iter(test_loader1), iter(test_loader2)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        # Batches from these loaders are moved with .to() directly (no
        # (data, label) tuple as in the other training scripts).
        batchx, iter1 = sample(iter1, train_loader1)
        data = batchx.to(args.device)

        batchy, iter2 = sample(iter2, train_loader2)
        datay = batchy.to(args.device)

        optim_generator.zero_grad()
        x = generator(data)
        # NOTE(review): the y-term also passes data.shape[0]; this assumes
        # both batches have the same size. The meaning of the positional
        # 100 (presumably an iteration count) is not visible here.
        t_lossx = sinkhorn_loss(x,
                                data,
                                args.eps,
                                data.shape[0],
                                100,
                                args.device,
                                p=args.lp)**args.p_exp
        t_lossy = sinkhorn_loss(x,
                                datay,
                                args.eps,
                                data.shape[0],
                                100,
                                args.device,
                                p=args.lp)**args.p_exp
        ((1 - args.alphat) * t_lossx + args.alphat * t_lossy).backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            batchx, titer1 = sample(titer1, test_loader1)
            datax = batchx.to(args.device)
            batchy, titer2 = sample(titer2, test_loader2)
            datay = batchy.to(args.device)
            evaluate(args.visualiser, datax, datax, datay, generator, 'x')
            # t_lossx compares G(x) with the x batch and t_lossy with the y
            # batch; the titles were previously swapped.
            args.visualiser.plot(step=i,
                                 data=t_lossx.detach().cpu().numpy(),
                                 title='Generator loss x')
            args.visualiser.plot(step=i,
                                 data=t_lossy.detach().cpu().numpy(),
                                 title='Generator loss y')
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
示例#14
0
def train(args):
    """Train the X->Y half of a two-generator translation setup.

    Four networks are built (``criticX``, ``criticY``, ``generatorXY``,
    ``generatorYX``) and all four get Adam optimizers, but only the X->Y
    direction is trained. ``criticY`` is updated ``args.d_updates`` times per
    iteration via ``compute_critic_loss``. ``generatorXY`` is updated once on
    ``generator_loss`` plus an ``args.gsxy``-weighted ``semantic_loss`` that
    uses ``args.classifier``, which is kept in eval mode.

    The Y->X direction (criticX/generatorYX updates, cycle and identity
    losses, Y->X evaluation and logging) is disabled in this version, so
    ``criticX``, ``evalX``, ``optim_criticX`` and ``optim_generatorYX`` are
    created but otherwise unused.

    Every ``args.evaluate`` iterations a transfer plot is drawn, the X->Y
    test accuracy under ``args.evalY`` is computed, values are appended to
    files under ``args.save_path`` and plotted, and the models are saved.
    """
    parameters = vars(args)
    train_loader1, _valid_loader1, test_loader1 = args.loaders1
    train_loader2, _valid_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    criticX, criticY, generatorXY, generatorYX = (
        models[key].to(args.device)
        for key in ('criticX', 'criticY', 'generatorXY', 'generatorYX'))
    evalX = args.evalX.to(args.device).eval()
    evalY = args.evalY.to(args.device).eval()
    classifier = args.classifier.to(args.device).eval()
    for net in (generatorXY, criticX, classifier):
        print(net)

    def make_adam(net):
        # Same Adam hyper-parameters for every network.
        return optim.Adam(net.parameters(),
                          lr=args.lr,
                          betas=(args.beta1, args.beta2))

    optim_criticX = make_adam(criticX)
    optim_criticY = make_adam(criticY)
    optim_generatorXY = make_adam(generatorXY)
    optim_generatorYX = make_adam(generatorYX)

    def draw_full(it, loader):
        # Sample a batch; if its size differs from train_batch_size,
        # sample once more. Returns (data on device, updated iterator).
        batch, it = sample(it, loader)
        data = batch[0].to(args.device)
        if data.shape[0] != args.train_batch_size:
            batch, it = sample(it, loader)
            data = batch[0].to(args.device)
        return data, it

    x_it = iter(train_loader1)
    y_it = iter(train_loader2)
    first_step = infer_iteration(list(models)[0], args.reload,
                                 args.model_path, args.save_path)
    tic = time.time()
    for step in range(first_step, args.iterations):
        for net in (criticX, criticY, generatorXY, generatorYX):
            net.train()

        # --- criticY updates ---
        for _ in range(args.d_updates):
            datax, x_it = draw_full(x_it, train_loader1)
            datay, y_it = draw_full(y_it, train_loader2)
            critic_lossy = compute_critic_loss(datax, args.z_dim, datay,
                                               criticY, generatorXY,
                                               args.device)
            optim_criticY.zero_grad()
            critic_lossy.backward()
            optim_criticY.step()

        # --- generatorXY update on a fresh batch pair ---
        datax, x_it = draw_full(x_it, train_loader1)
        datay, y_it = draw_full(y_it, train_loader2)

        glossxy = generator_loss(datax, args.z_dim, criticY, generatorXY,
                                 args.device)
        slossxy = semantic_loss(datax, args.z_dim, generatorXY, classifier,
                                args.device)
        optim_generatorXY.zero_grad()
        glossxy.backward()
        (args.gsxy * slossxy).backward()
        optim_generatorXY.step()

        if step % args.evaluate != 0:
            continue

        print('Iter: %s' % step, time.time() - tic)
        generatorXY.eval()
        generatorYX.eval()
        plot_transfer(args.visualiser, datax, datay, args.z_dim,
                      generatorXY, 'x-y', step, args.device)
        test_accuracy_xy = evaluate(test_loader1, args.z_dim, generatorXY,
                                    evalY, args.device)

        # Append "<iter>,<value>" lines, one file per quantity.
        records = (
            ('glossxy', glossxy.cpu().item()),
            ('slossxy', slossxy.cpu().item()),
            ('test_accuracy_xy', test_accuracy_xy),
        )
        for file_name, value in records:
            with open(os.path.join(args.save_path, file_name), 'a') as f:
                f.write(f'{step},{value}\n')

        for loss, title in ((critic_lossy, 'critic_lossy'),
                            (glossxy, 'glossxy'),
                            (slossxy, 'slossxy')):
            args.visualiser.plot(loss.cpu().detach().numpy(),
                                 title=title,
                                 step=step)
        args.visualiser.plot(test_accuracy_xy,
                             title='Test transfer accuracy X-Y',
                             step=step)
        tic = time.time()
        # NOTE(review): every save uses iteration 0, so each checkpoint
        # presumably overwrites the previous one — confirm this is intended.
        save_models(models, 0, args.model_path, args.checkpoint)
示例#15
0
def train(args):
    """Train an unconditional generator against two critics.

    For ``args.d_updates`` steps per iteration, a real batch and a detached
    batch of generated samples (from standard-normal noise of width
    ``args.z_dim``) are passed to ``disc_loss_generation``, which returns
    ``(r_loss, g_loss, p)``. Both critics then take an optimizer step that
    *increases* ``r_loss + g_loss + p`` (the sum is negated before
    ``backward``). The generator is then updated on ``transfer_loss`` using
    the last real batch. Every ``args.evaluate`` iterations samples are
    visualised, the losses are plotted and the models are saved.
    """
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    critic1 = models['critic1'].to(args.device)
    critic2 = models['critic2'].to(args.device)
    print(generator)
    print(critic1)
    print(critic2)

    optim_critic1 = optim.Adam(critic1.parameters(),
                               lr=args.lr,
                               betas=(args.beta1, args.beta2))
    optim_critic2 = optim.Adam(critic2.parameters(),
                               lr=args.lr,
                               betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(),
                                 lr=args.lr,
                                 betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        critic1.train()
        critic2.train()
        for _ in range(args.d_updates):
            # A batch whose size differs from train_batch_size is replaced once.
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx[0].to(args.device)
            if data.shape[0] != args.train_batch_size:
                batchx, iter1 = sample(iter1, train_loader1)
                data = batchx[0].to(args.device)

            optim_critic1.zero_grad()
            optim_critic2.zero_grad()
            z = torch.randn(data.shape[0], args.z_dim, device=args.device)
            gen1 = generator(z).detach()
            r_loss, g_loss, p = disc_loss_generation(data, gen1, args.eps,
                                                     args.lp, critic1, critic2)
            # Gradient ascent on the critic objective. Negating the loss
            # replaces backward(mone) with mone = FloatTensor([-1]), which
            # raises a shape-mismatch error when the loss is a 0-dim tensor;
            # for a shape-[1] loss the gradients are identical.
            (-(r_loss + g_loss + p)).backward()
            optim_critic1.step()
            optim_critic2.step()

        # NOTE(review): relies on `data` from the critic loop above, so
        # args.d_updates must be >= 1.
        optim_generator.zero_grad()
        t_loss = transfer_loss(data, args.eps, args.lp, args.z_dim, critic1,
                               critic2, generator, args.device)
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            noise = torch.randn(args.test_batch_size,
                                args.z_dim,
                                device=args.device)
            evaluate(args.visualiser, noise, data[:args.test_batch_size],
                     generator, i)
            d_loss = (r_loss + g_loss).detach().cpu().numpy()
            args.visualiser.plot(step=i, data=d_loss, title='Critic loss')
            args.visualiser.plot(step=i,
                                 data=t_loss.detach().cpu().numpy(),
                                 title='Generator loss')
            args.visualiser.plot(step=i,
                                 data=p.detach().cpu().numpy(),
                                 title='Penalty')
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
示例#16
0
    def train(self, loaders):
        """Run the main training loop, resuming from the latest checkpoint.

        Each iteration performs:

        * two discriminator updates, one using the latent codes ``z_trg`` and
          one using the reference images ``x_trg``;
        * two generator updates. The latent-code one also steps the mapping
          network and style encoder; the reference one steps only the
          generator.
        * an EMA update (``beta=0.999``) of the generator, mapping network
          and style encoder copies in ``self.nets_ema``.

        Losses are printed every ``args.print_every`` iterations and
        checkpoints are written every ``args.save_every`` iterations.

        Args:
            loaders: object exposing ``src`` (training) and ``val``
                (validation) data loaders.
        """
        args = self.args
        nets = self.nets
        nets_ema = self.nets_ema
        optims = self.optims

        # Input fetchers for training and validation data.
        fetcher = InputFetcher(loaders.src, args.latent_dim, args.device)
        fetcher_val = InputFetcher(loaders.val, args.latent_dim, args.device)
        # NOTE(review): this validation batch is not used in this method.
        inputs_val = next(fetcher_val)

        # Resume: infer_iteration locates the latest saved 'nets' iteration,
        # and the loop below starts there. The checkpoint must therefore be
        # loaded whenever that iteration is non-zero. Gating on
        # args.resume_iter instead could continue at a non-zero iteration
        # without loading the saved weights, or attempt to load step 0.
        resume_iter = infer_iteration('nets', args.reload, args.model_path, args.save_path)
        print(resume_iter)
        if resume_iter > 0:
            self._load_checkpoint(resume_iter)

        print('Start training...')
        start_time = time.time()
        for i in range(resume_iter, args.total_iters):
            # The ds loss weight decays linearly from args.lambda_ds toward 0.
            lambda_ds = args.lambda_ds * (1 - i / args.total_iters)
            # fetch images and labels
            inputs = next(fetcher)
            x_real, y_real, d_org = inputs.x_src, inputs.y_src, inputs.d_src
            x_trg, x_ds, d_trg = inputs.x_src2, inputs.x_ds, inputs.d_src2
            z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2

            # Discriminator step using the latent code z_trg.
            d_loss, d_losses_latent = compute_d_loss(
                nets, args, x_real, y_real, d_org, d_trg, z_trg=z_trg)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            # Discriminator step using the reference image x_trg.
            d_loss, d_losses_ref = compute_d_loss(
                nets, args, x_real, y_real, d_org, d_trg, x_trg=x_trg)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            # Generator step from latent codes; also steps the mapping
            # network and the style encoder.
            g_loss, g_losses_latent = compute_g_loss(
                nets, args, x_real, y_real, d_org, d_trg, lambda_ds, z_trgs=[z_trg, z_trg2])
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()
            optims.mapping_network.step()
            optims.style_encoder.step()

            # Generator step from reference images; only the generator steps.
            g_loss, g_losses_ref = compute_g_loss(
                nets, args, x_real, y_real, d_org, d_trg, lambda_ds, x_refs=[x_trg, x_ds])
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()

            # compute moving average of network parameters
            moving_average(nets.generator, nets_ema.generator, beta=0.999)
            moving_average(nets.mapping_network, nets_ema.mapping_network, beta=0.999)
            moving_average(nets.style_encoder, nets_ema.style_encoder, beta=0.999)

            # print out log info
            if (i+1) % args.print_every == 0:
                # Whole seconds, formatted as H:MM:SS. Truncating with int()
                # avoids slicing "[:-7]" off str(timedelta), which drops the
                # ".ffffff" suffix when microseconds happen to be 0.
                elapsed = str(datetime.timedelta(seconds=int(time.time() - start_time)))
                log = "Elapsed time [%s], Iteration [%i/%i], " % (elapsed, i+1, args.total_iters)
                all_losses = dict()
                for loss, prefix in zip([d_losses_latent, d_losses_ref, g_losses_latent, g_losses_ref],
                                        ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                log += ' '.join(['%s: [%.4f]' % (key, value) for key, value in all_losses.items()])
                print(log)

            # save model checkpoints
            if (i+1) % args.save_every == 0:
                self._save_checkpoint(step=i+1, checkpoint=args.checkpoint)
示例#17
0
def train(args):
    """Train a generator against three pairs of critics (x, y and z).

    Each outer iteration first runs ``args.d_updates`` critic updates. Every
    critic pair compares ``input_data`` (a batch from ``args.loaders1``)
    against a batch from its own loader. Then one generator update is made on
    a Dirichlet-weighted sum of the three per-domain transfer losses. Every
    ``args.evaluate`` iterations the generator is evaluated on test batches,
    losses are plotted and the models are saved.

    Args:
        args: run configuration. Fields read here include ``loaders1`` to
            ``loaders3`` (``(train, test)`` loader pairs), ``device``, ``lr``,
            ``beta1``, ``beta2``, ``reload``, ``save_path``, ``model_path``,
            ``iterations``, ``d_updates``, ``eps``, ``lp``, ``nt``,
            ``evaluate``, ``visualiser`` and ``checkpoint``. ``vars(args)`` is
            also forwarded to ``define_models``.
    """
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2
    train_loader3, test_loader3 = args.loaders3
    #train_loader4, test_loader4 = args.loaders4

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    criticx1 = models['criticx1'].to(args.device)
    criticx2 = models['criticx2'].to(args.device)
    criticy1 = models['criticy1'].to(args.device)
    criticy2 = models['criticy2'].to(args.device)
    criticz1 = models['criticz1'].to(args.device)
    criticz2 = models['criticz2'].to(args.device)
    #criticw1 = models['criticw1'].to(args.device)
    #criticw2 = models['criticw2'].to(args.device)
    print(generator)
    print(criticx1)
    print(criticy1)
    print(criticz1)
    #print(criticw1)

    optim_criticx1 = optim.Adam(criticx1.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_criticx2 = optim.Adam(criticx2.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_criticy1 = optim.Adam(criticy1.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_criticy2 = optim.Adam(criticy2.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_criticz1 = optim.Adam(criticz1.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_criticz2 = optim.Adam(criticz2.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    #optim_criticw1 = optim.Adam(criticw1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    #optim_criticw2 = optim.Adam(criticw2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(),
                                 lr=args.lr,
                                 betas=(args.beta1, args.beta2))

    iter1, iter2, iter3 = iter(train_loader1), iter(train_loader2), iter(
        train_loader3)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    titer1, titer2, titer3, = iter(test_loader1), iter(test_loader2), iter(
        test_loader3)
    # Seeding backward() with -1 negates the critic gradients, so the Adam
    # steps below ascend (r_loss + g_loss + p) instead of descending it.
    # NOTE(review): assumes the critic losses are 1-element tensors so this
    # shape-[1] seed matches them — confirm against disc_loss_generation.
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        criticx1.train()
        criticx2.train()
        criticy1.train()
        criticy2.train()
        criticz1.train()
        criticz2.train()
        #criticw1.train()
        #criticw2.train()

        for _ in range(args.d_updates):
            # Two independent batches from loader 1: `data` is the target for
            # the x critics, `input_data` is the generator-side input for all.
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx.to(args.device)
            batchx, iter1 = sample(iter1, train_loader1)
            input_data = batchx.to(args.device)
            batchy, iter2 = sample(iter2, train_loader2)
            datay = batchy.to(args.device)
            batchz, iter3 = sample(iter3, train_loader3)
            dataz = batchz.to(args.device)
            #batchw, iter4 = sample(iter4, train_loader4)
            #dataw = batchw[0].to(args.device)

            optim_criticx1.zero_grad()
            optim_criticx2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(input_data, data,
                                                     args.eps, args.lp,
                                                     criticx1, criticx2)
            (r_loss + g_loss + p).backward(mone)
            optim_criticx1.step()
            optim_criticx2.step()

            optim_criticy1.zero_grad()
            optim_criticy2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(input_data, datay,
                                                     args.eps, args.lp,
                                                     criticy1, criticy2)
            (r_loss + g_loss + p).backward(mone)
            optim_criticy1.step()
            optim_criticy2.step()

            optim_criticz1.zero_grad()
            optim_criticz2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(input_data, dataz,
                                                     args.eps, args.lp,
                                                     criticz1, criticz2)
            (r_loss + g_loss + p).backward(mone)
            optim_criticz1.step()
            optim_criticz2.step()

            #optim_criticw1.zero_grad()
            #optim_criticw2.zero_grad()
            #r_loss, g_loss, p = disc_loss_generation(input_data, dataw, args.eps, args.lp, criticw1, criticw2)
            #(r_loss + g_loss + p).backward(mone)
            #optim_criticw1.step()
            #optim_criticw2.step()

        optim_generator.zero_grad()
        # One Dirichlet(1, 1, 1) draw gives 3 non-negative weights summing to
        # 1; they are repeated for every row of the batch.
        t_ = Dirichlet(torch.FloatTensor([1., 1.,
                                          1.])).sample().to(args.device)
        t = torch.stack([t_] * input_data.shape[0])
        tinputdata = torch.cat([input_data] * args.nt)
        tdata = torch.cat([data] * args.nt)
        tdatay = torch.cat([datay] * args.nt)
        tdataz = torch.cat([dataz] * args.nt)
        #tdataw = torch.cat([dataw]*args.nt)
        t_lossx = transfer_loss(tinputdata, tdata, args.nt, t, args.eps,
                                args.lp, criticx1, criticx2, generator)
        t_lossy = transfer_loss(tinputdata, tdatay, args.nt, t, args.eps,
                                args.lp, criticy1, criticy2, generator)
        t_lossz = transfer_loss(tinputdata, tdataz, args.nt, t, args.eps,
                                args.lp, criticz1, criticz2, generator)
        #t_lossw = transfer_loss(tinputdata, tdataw, args.nt, t, args.eps, args.lp, criticw1, criticw2, generator)
        t_loss = (t_[0] * t_lossx + t_[1] * t_lossy + t_[2] * t_lossz).sum()
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            batchx, titer1 = sample(titer1, test_loader1)
            datax = batchx.to(args.device)
            #labelx = batchx[1].to(args.device)
            batchy, titer2 = sample(titer2, test_loader2)
            datay = batchy.to(args.device)
            #labely = batchy[1].to(args.device)
            batchz, titer3 = sample(titer3, test_loader3)
            dataz = batchz.to(args.device)
            #labelz = batchz[1].to(args.device)
            evaluate(args.visualiser, datax, datay, dataz, generator, 'x',
                     args.device)
            #batchw, titerw = sample(titer4, test_loader4)
            #dataw = batchw[0].to(args.device)
            #labelw = batchw[1].to(args.device)
            #evaluate_3d(args.visualiser, datax, datay, dataz, dataw, labelx, labely, labelz, labelw, generator, 'x', args.device)
            # r_loss / g_loss / p still hold the values from the LAST critic
            # update of the loop above, i.e. the z critic pair; label them so.
            d_loss = (r_loss + g_loss).detach().cpu().numpy()
            args.visualiser.plot(step=i, data=d_loss, title='Critic loss Z')
            args.visualiser.plot(step=i,
                                 data=t_lossx.detach().cpu().numpy(),
                                 title=f'Generator loss X')
            args.visualiser.plot(step=i,
                                 data=t_lossy.detach().cpu().numpy(),
                                 title=f'Generator loss Y')
            args.visualiser.plot(step=i,
                                 data=t_lossz.detach().cpu().numpy(),
                                 title='Generator loss Z')
            #args.visualiser.plot(step=i, data=t_lossw.detach().cpu().numpy(), title=f'Generator loss W')
            args.visualiser.plot(step=i,
                                 data=p.detach().cpu().numpy(),
                                 title=f'Penalty')
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
示例#18
0
def train(args):
    """Train a label-conditioned generator against a single critic.

    Each outer iteration runs ``args.d_updates`` critic updates on batches
    from the train loader of ``args.loaders2``. Their labels are passed
    through ``corrupt`` and one-hot encoded. One generator update follows.
    Every ``args.evaluate`` iterations the function:

    * plots a transfer sample;
    * measures transfer accuracy with ``args.evaluation`` on both loaders of
      ``args.loaders1`` (validation and test);
    * appends the metrics to files under ``args.save_path``;
    * saves the models.

    Args:
        args: run configuration. Fields read here include ``loaders1``
            (``(valid, test)``), ``loaders2`` (``(train, test)``),
            ``evaluation``, ``device``, ``lr``, ``beta1``, ``beta2``,
            ``reload``, ``save_path``, ``model_path``, ``iterations``,
            ``d_updates``, ``nc``, ``corrupt_tgt``, ``corrupt_src``,
            ``z_dim``, ``nz``, ``evaluate``, ``visualiser`` and
            ``checkpoint``.
    """
    parameters = vars(args)
    valid_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    critic = models['critic'].to(args.device)
    # Named `evaluator` rather than `eval` to avoid shadowing the builtin.
    evaluator = args.evaluation.eval().to(args.device)
    print(generator)
    print(critic)

    optim_critic = optim.Adam(critic.parameters(),
                              lr=args.lr,
                              betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(),
                                 lr=args.lr,
                                 betas=(args.beta1, args.beta2))

    iter2 = iter(train_loader2)
    titer, titer2 = iter(test_loader1), iter(test_loader2)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    # Seeding backward() with -1 negates neg_loss's gradient, so the critic
    # step ascends neg_loss while descending pos_loss and the penalty.
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        critic.train()
        for _ in range(args.d_updates):
            batch, iter2 = sample(iter2, train_loader2)
            data = batch[0].to(args.device)
            label = corrupt(batch[1], args.nc, args.corrupt_tgt)
            label = one_hot_embedding(label, args.nc).to(args.device)
            optim_critic.zero_grad()
            pos_loss, neg_loss, gp = critic_loss(data, label, args.z_dim,
                                                 critic, generator,
                                                 args.device)
            pos_loss.backward()
            neg_loss.backward(mone)
            (10 * gp).backward()
            optim_critic.step()

        # Uses `data` / `label` from the last critic update above
        # (NOTE(review): undefined if args.d_updates == 0).
        optim_generator.zero_grad()
        t_loss = transfer_loss(data.shape[0], label, args.z_dim, critic,
                               generator, args.device)
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            print('Iter: %s' % i, time.time() - t0)
            generator.eval()
            batch, titer = sample(titer, test_loader1)
            data1 = batch[0].to(args.device)
            label = one_hot_embedding(batch[1], args.nc).to(args.device)
            # Store the advanced iterator back into `titer2`. Assigning it to
            # `titer` would make later evaluations draw `data1` from
            # test_loader2 instead of test_loader1.
            batch, titer2 = sample(titer2, test_loader2)
            data2 = batch[0].to(args.device)
            plot_transfer(args.visualiser, label, args.nc, data1, data2,
                          args.nz, generator, args.device, i)
            save_path = args.save_path
            eval_accuracy = evaluate(valid_loader1, args.nz, args.nc,
                                     args.corrupt_src, generator, evaluator,
                                     args.device)
            test_accuracy = evaluate(test_loader1, args.nz, args.nc,
                                     args.corrupt_src, generator, evaluator,
                                     args.device)
            with open(os.path.join(save_path, 'critic_loss'), 'a') as f:
                f.write(f'{i},{(pos_loss-neg_loss).cpu().item()}\n')
            with open(os.path.join(save_path, 'tloss'), 'a') as f:
                f.write(f'{i},{t_loss.cpu().item()}\n')
            with open(os.path.join(save_path, 'eval_accuracy'), 'a') as f:
                f.write(f'{i},{eval_accuracy}\n')
            with open(os.path.join(save_path, 'test_accuracy'), 'a') as f:
                f.write(f'{i},{test_accuracy}\n')
            args.visualiser.plot((pos_loss - neg_loss).cpu().detach().numpy(),
                                 title='critic_loss',
                                 step=i)
            args.visualiser.plot(t_loss.cpu().detach().numpy(),
                                 title='tloss',
                                 step=i)
            args.visualiser.plot(eval_accuracy,
                                 title=f'Validation transfer accuracy',
                                 step=i)
            args.visualiser.plot(test_accuracy,
                                 title=f'Test transfer accuracy',
                                 step=i)

            t0 = time.time()
            # NOTE(review): the sibling trainers pass `i` here; passing 0 may
            # be deliberate (single overwritten checkpoint) — confirm.
            save_models(models, 0, args.model_path, args.checkpoint)
示例#19
0
def train(args):
    """Train a generator against two pairs of critics (x and y).

    Each outer iteration runs ``args.d_updates`` critic updates. The x
    critics compare a batch from ``args.loaders1`` with itself; the y
    critics compare it with a batch from ``args.loaders2``. One generator
    update follows on the equally weighted x/y transfer losses. Every
    ``args.evaluate`` iterations the generator is evaluated from Gaussian
    noise, losses are plotted and the models are saved.

    Args:
        args: run configuration. Fields read here include ``loaders1`` and
            ``loaders2`` (``(train, test)`` loader pairs), ``device``, ``lr``,
            ``beta1``, ``beta2``, ``reload``, ``save_path``, ``model_path``,
            ``iterations``, ``d_updates``, ``z_dim``, ``eps``, ``lp``,
            ``evaluate``, ``test_batch_size``, ``visualiser`` and
            ``checkpoint``.
    """
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    criticx1 = models['criticx1'].to(args.device)
    criticx2 = models['criticx2'].to(args.device)
    criticy1 = models['criticy1'].to(args.device)
    criticy2 = models['criticy2'].to(args.device)
    print(generator)
    print(criticx1)
    print(criticx2)
    print(criticy1)
    print(criticy2)

    optim_criticx1 = optim.Adam(criticx1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticx2 = optim.Adam(criticx2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticy1 = optim.Adam(criticy1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticy2 = optim.Adam(criticy2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    iter1, iter2 = iter(train_loader1), iter(train_loader2)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    titer1, titer2 = iter(test_loader1), iter(test_loader2)
    # Seeding backward() with -1 negates the critic gradients, so the Adam
    # steps ascend (r_loss + g_loss + p) instead of descending it.
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        criticx1.train()
        criticx2.train()
        criticy1.train()
        criticy2.train()
        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx.to(args.device)
            optim_criticx1.zero_grad()
            optim_criticx2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(data, data, args.z_dim, args.eps, args.lp, criticx1, criticx2, generator, args.device)
            (r_loss + g_loss + p).backward(mone)
            optim_criticx1.step()
            optim_criticx2.step()

            batchy, iter2 = sample(iter2, train_loader2)
            datay = batchy.to(args.device)
            optim_criticy1.zero_grad()
            optim_criticy2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(data, datay, args.z_dim, args.eps, args.lp, criticy1, criticy2, generator, args.device)
            (r_loss + g_loss + p).backward(mone)
            optim_criticy1.step()
            optim_criticy2.step()

        # Uses `data` / `datay` from the last critic update above.
        optim_generator.zero_grad()
        t_lossx = transfer_loss(data, data, args.z_dim, args.eps, args.lp, criticx1, criticx2, generator, args.device)
        t_lossy = transfer_loss(data, datay, args.z_dim, args.eps, args.lp, criticy1, criticy2, generator, args.device)
        (0.5*t_lossx + 0.5*t_lossy).backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            batchx, titer1 = sample(titer1, test_loader1)
            datax = batchx.to(args.device)
            batchy, titer2 = sample(titer2, test_loader2)
            datay = batchy.to(args.device)
            data = torch.randn(args.test_batch_size, args.z_dim, device=args.device)
            evaluate(args.visualiser, data, datax, datay, generator, 'x')
            # r_loss / g_loss / p are from the last (y) critic update.
            d_loss = (r_loss+g_loss).detach().cpu().numpy()
            args.visualiser.plot(step=i, data=d_loss, title=f'Critic loss')
            # Each title matches its loss: t_lossx uses the x critics and
            # t_lossy the y critics.
            args.visualiser.plot(step=i, data=t_lossx.detach().cpu().numpy(), title='Generator loss x')
            args.visualiser.plot(step=i, data=t_lossy.detach().cpu().numpy(), title='Generator loss y')
            args.visualiser.plot(step=i, data=p.detach().cpu().numpy(), title=f'Penalty')
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)
示例#20
0
def train(args):
    """Pre-train two pairs of critics, then train the generator against them.

    Phase 1 runs a fixed 4000 critic-only updates:

    * the x critics compare a batch from ``args.loaders1`` with itself;
    * the y critics compare it with a batch from ``args.loaders2``.

    Phase 2 runs from the resumed iteration up to ``args.iterations``. That
    loop only steps ``optim_generator``; the critic optimizers are not
    stepped there.

    Args:
        args: run configuration. Fields read here include ``loaders1`` and
            ``loaders2`` (``(train, test)`` loader pairs), ``device``, ``lr``,
            ``beta1``, ``beta2``, ``reload``, ``save_path``, ``model_path``,
            ``iterations``, ``eps``, ``lp``, ``nt``, ``evaluate``,
            ``visualiser`` and ``checkpoint``. ``vars(args)`` is also
            forwarded to ``define_models``.
    """
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2

    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)

    generator = models['generator'].to(args.device)
    criticx1 = models['criticx1'].to(args.device)
    criticx2 = models['criticx2'].to(args.device)
    criticy1 = models['criticy1'].to(args.device)
    criticy2 = models['criticy2'].to(args.device)
    print(generator)
    print(criticx1)
    print(criticy1)

    optim_criticx1 = optim.Adam(criticx1.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_criticx2 = optim.Adam(criticx2.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_criticy1 = optim.Adam(criticy1.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_criticy2 = optim.Adam(criticy2.parameters(),
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(),
                                 lr=args.lr,
                                 betas=(args.beta1, args.beta2))

    iter1, iter2 = iter(train_loader1), iter(train_loader2)
    iteration = infer_iteration(
        list(models.keys())[0], args.reload, args.model_path, args.save_path)
    titer1, titer2 = iter(test_loader1), iter(test_loader2)
    # Seeding backward() with -1 negates the critic gradients, so the Adam
    # steps ascend (r_loss + g_loss + p) instead of descending it.
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()

    generator.train()
    criticx1.train()
    criticx2.train()
    criticy1.train()
    criticy2.train()
    # Phase 1: critic-only pre-training for a hard-coded 4000 steps.
    # NOTE(review): this also runs when resuming from a checkpoint
    # (iteration > 0) — confirm that re-pre-training is intended.
    for i in range(4000):
        # batch[0] is taken as the data tensor; presumably the loaders yield
        # (data, ...) tuples — TODO confirm.
        batchx, iter1 = sample(iter1, train_loader1)
        data = batchx[0].to(args.device)

        batchy, iter2 = sample(iter2, train_loader2)
        datay = batchy[0].to(args.device)

        optim_criticx1.zero_grad()
        optim_criticx2.zero_grad()
        r_loss, g_loss, p = disc_loss_generation(data, data, args.eps, args.lp,
                                                 criticx1, criticx2)
        (r_loss + g_loss + p).backward(mone)
        optim_criticx1.step()
        optim_criticx2.step()

        optim_criticy1.zero_grad()
        optim_criticy2.zero_grad()
        r_loss, g_loss, p = disc_loss_generation(data, datay, args.eps,
                                                 args.lp, criticy1, criticy2)
        (r_loss + g_loss + p).backward(mone)
        optim_criticy1.step()
        optim_criticy2.step()
        if i % 100 == 0:
            print(f'Critics-{i}')
            print('Iter: %s' % i, time.time() - t0)
            # These plots show the y critic pair's losses (the last update).
            args.visualiser.plot(step=i,
                                 data=p.detach().cpu().numpy(),
                                 title=f'Penalty')
            d_loss = (r_loss + g_loss).detach().cpu().numpy()
            args.visualiser.plot(step=i, data=d_loss, title=f'Critic loss Y')
            t0 = time.time()
    # Phase 2: generator-only training against the pre-trained critics.
    for i in range(iteration, args.iterations):
        generator.train()
        criticx1.train()
        criticx2.train()
        criticy1.train()
        criticy2.train()
        batchx, iter1 = sample(iter1, train_loader1)
        data = batchx[0].to(args.device)
        #if data.shape[0] != args.train_batch_size:
        #batchx, iter1 = sample(iter1, train_loader1)
        #data = batchx[0].to(args.device)

        batchy, iter2 = sample(iter2, train_loader2)
        datay = batchy[0].to(args.device)
        #if datay.shape[0] != args.train_batch_size:
        #batchy, iter2 = sample(iter2, train_loader2)
        #datay = batchy[0].to(args.device)

        optim_generator.zero_grad()
        # args.nt interpolation weights drawn uniformly from [0, 1); the same
        # weight vector is repeated for every row of the batch.
        t_ = torch.rand(args.nt, device=args.device)
        t = torch.stack([t_] * data.shape[0])
        #t = torch.stack([t_] * input_data.shape[0]).transpose(0, 1).reshape(-1, 1)
        # nt copies of each batch concatenated along dim 0.
        tdata = torch.cat([data] * args.nt)
        tdatay = torch.cat([datay] * args.nt)
        t_lossx = transfer_loss(tdata, tdata, args.nt, t, args.eps, args.lp,
                                criticx1, criticx2, generator)
        t_lossy = transfer_loss(tdata, tdatay, args.nt, t, args.eps, args.lp,
                                criticy1, criticy2, generator)
        # Interpolate x and y losses with weights (1 - t_, t_); presumably
        # transfer_loss returns one loss per weight (length nt) — TODO confirm.
        t_loss = ((1 - t_) * t_lossx + t_ * t_lossy).sum()
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            batchx, titer1 = sample(titer1, test_loader1)
            datax = batchx[0].to(args.device)
            batchy, titer2 = sample(titer2, test_loader2)
            datay = batchy[0].to(args.device)
            evaluate(args.visualiser, datax, datay, generator, 'x',
                     args.device)
            args.visualiser.plot(step=i,
                                 data=t_lossx.detach().cpu().numpy(),
                                 title=f'Generator loss X')
            args.visualiser.plot(step=i,
                                 data=t_lossy.detach().cpu().numpy(),
                                 title=f'Generator loss Y')
            t0 = time.time()
            save_models(models, i, args.model_path, args.checkpoint)