Example #1
0
    def learn(self, real_A, real_B, z_B, lambda_A, lambda_B, lambda_z_B,
              **kwargs):
        """Build Adam updates for the discriminators and the generators.

        Keyword arguments ``lr`` and ``beta1`` are required (a ``KeyError``
        is raised when either is missing). ``max_norm`` is optional, defaults
        to ``500.`` and is forwarded to ``nn.adam`` as ``clip_by_norm``.

        Returns the tuple ``(updates_dis, updates_gen, dis_losses, dis_preds,
        gen_losses, grad_norms)``. ``grad_norms`` is an ``OrderedDict`` that
        maps each parameter name (with ``/`` replaced by ``_``) to
        ``nn.utils.p_norm(grad, 2)``; discriminator entries are inserted
        before generator entries.
        """
        lr = kwargs.pop('lr')
        beta1 = kwargs.pop('beta1')
        max_norm = kwargs.pop('max_norm', 500.)

        def _adam(loss, params):
            # Both optimisers share the same hyper-parameters and clipping.
            return nn.adam(loss, params, lr, beta1, clip_by_norm=max_norm)

        dis_losses, dis_preds = self.get_dis_cost(real_A, real_B, z_B)
        gen_losses, _ = self.get_gen_cost(
            real_A, real_B, z_B, lambda_A, lambda_B, lambda_z_B)

        # Discriminator parameters; the latent discriminator is optional.
        dis_params = self.netD_A.trainable + self.netD_B.trainable
        if self.use_latent_gan:
            dis_params += self.netD_z_B.trainable

        # Generator parameters: both translators plus the encoder.
        gen_params = self.netG_A_B.trainable + self.netG_B_A.trainable
        gen_params = gen_params + self.netE_B.trainable

        updates_dis, _, dis_grads = _adam(dis_losses['loss_D'], dis_params)
        updates_gen, _, gen_grads = _adam(gen_losses['loss_G'], gen_params)

        grad_norms = OrderedDict()
        for params, grads in ((dis_params, dis_grads), (gen_params, gen_grads)):
            names = [param.name.replace('/', '_') for param in params]
            norms = [nn.utils.p_norm(grad, 2) for grad in grads]
            grad_norms.update(zip(names, norms))

        return (updates_dis, updates_gen, dis_losses, dis_preds, gen_losses,
                grad_norms)
Example #2
0
 def learn(self, image, noise, lr, beta1):
     """Create the Adam optimisers for the generator and the discriminator.

     The costs come from ``self.get_cost(image, noise)``. Each ``nn.adam``
     call is made with ``return_op=True``; the first element of its result
     is stored on ``self.opt_gen`` / ``self.opt_dis`` and the second element
     is returned as the corresponding updates.

     Returns ``(gen_cost, dis_cost, updates_gen, updates_dis)``.
     """
     def _adam(cost, params):
         # Same learning rate and beta1 for both networks.
         return nn.adam(cost, params, lr, beta1, return_op=True)

     gen_cost, dis_cost = self.get_cost(image, noise)

     # Generator first, then discriminator, mirroring the attribute order.
     self.opt_gen, updates_gen = _adam(gen_cost, self.gen.trainable)
     self.opt_dis, updates_dis = _adam(dis_cost, self.dis.trainable)

     return gen_cost, dis_cost, updates_gen, updates_dis
Example #3
0
 def get_updates(self, x, y, l2_coeff, lr):
     """Return Adam updates, the loss dictionary and per-parameter gradient norms.

     ``self.get_cost(x, y, l2_coeff)`` supplies the losses; its ``'total'``
     entry is the quantity handed to ``nn.adam`` together with
     ``self.trainable`` and ``lr``.

     Returns ``(updates, losses, grad_norms)`` where ``grad_norms`` maps the
     name of ``self.trainable[i]`` (with ``/`` replaced by ``_``) to
     ``nn.utils.p_norm`` of the i-th gradient.
     """
     losses = self.get_cost(x, y, l2_coeff)
     updates, _unused, grads = nn.adam(losses['total'], self.trainable, lr)

     grad_norms = {}
     for position, grad in enumerate(grads):
         # Gradients are matched to parameters by position.
         key = self.trainable[position].name.replace('/', '_')
         grad_norms[key] = nn.utils.p_norm(grad)

     return updates, losses, grad_norms
Example #4
0
def train():
    """Train one decoder per entry of ``indices`` and monitor the progress.

    The function takes no arguments; it relies on names defined outside of it
    (``input_shape``, ``indices``, ``bs``, ``lr``, ``weight``, ``decay``,
    ``n_epochs``, ``num_val_imgs``, ``print_freq``, ``val_freq`` and the
    input/style paths).

    For every batch yielded by the training ``DataManager`` each decoder's
    training function is executed and its two losses are plotted. Every
    ``val_freq`` iterations the reconstruction of *every* layer is written,
    the style-transfer function is run over the test data, and the decoder
    parameters are dumped. Final parameters are dumped once training ends.

    Raises:
        ValueError: if the sum of the two losses of any decoder is NaN or
            infinite.
    """
    enc = VGG19(input_shape)
    decs = [Decoder(enc, i, name='decoder %d' % i) for i in indices]
    sty_net = StyleTransfer(enc, decs)

    # Symbolic inputs and the placeholders that are substituted for them
    # through ``givens`` when the functions are compiled.
    X = T.tensor4('input')
    Y = T.tensor4('style')
    idx = T.scalar('iter', 'int32')
    X_ = nn.placeholder((bs,) + input_shape[1:], name='input_plhd')
    Y_ = nn.placeholder((bs,) + input_shape[1:], name='style_plhd')
    lr_ = nn.placeholder(value=lr, name='lr_plhd')

    nn.set_training_on()
    losses = [dec.cost(X) for dec in decs]
    # NOTE(review): the optimisers receive `lr` while the annealing below is
    # applied to the `lr_` placeholder -- confirm that the annealed value is
    # actually meant to reach `nn.adam`.
    updates = [nn.adam(loss[0] + weight * loss[1], dec.trainable, lr)
               for loss, dec in zip(losses, decs)]
    nn.anneal_learning_rate(lr_, idx, 'inverse', decay=decay)
    # Each training function returns (loss[0], loss[1], dec(X, True)).
    trains = [nn.function([], [loss[0], loss[1], dec(X, True)], givens={X: X_}, updates=update,
                          name='train decoder')
              for loss, dec, update in zip(losses, decs, updates)]

    nn.set_training_off()
    X_styled = sty_net(X, Y)
    transfer = nn.function([], X_styled, givens={X: X_, Y: Y_}, name='transfer style')

    data_train = DataManager(X_, input_path_train, bs, n_epochs, True, num_val_imgs=num_val_imgs,
                             input_shape=input_shape)
    data_test = DataManagerStyleTransfer((X_, Y_), (input_path_val, style_path_val), bs, 1,
                                         input_shape=input_shape)
    mon = nn.Monitor(model_name='WCT', valid_freq=print_freq)

    print('Training...')
    for it in data_train:
        # `train_fn` rather than `train` so the enclosing function's name is
        # not shadowed inside the comprehension.
        results = [train_fn(it) for train_fn in trains]

        with mon:
            for layer, res in zip(indices, results):
                total = res[0] + res[1]
                if np.isnan(total) or np.isinf(total):
                    raise ValueError('Training failed!')
                mon.plot('pixel loss at layer %d' % layer, res[0])
                mon.plot('feature loss at layer %d' % layer, res[1])

            if it % val_freq == 0:
                # Write the reconstruction of every layer. Previously this
                # relied on `layer`/`res` leaking out of the loop above, so
                # only the last layer was ever written.
                for layer, res in zip(indices, results):
                    mon.imwrite('recon img at layer %d' % layer, res[2])

                for i in data_test:
                    img = transfer()
                    mon.imwrite('stylized image %d' % i, img)
                    mon.imwrite('input %d' % i, X_.get_value())
                    mon.imwrite('style %d' % i, Y_.get_value())

                # NOTE(review): the trailing 5 is passed straight to
                # `mon.dump`; presumably a keep-last-N setting -- confirm.
                for layer, dec in zip(indices, decs):
                    mon.dump(nn.utils.shared2numpy(dec.params), 'decoder-%d.npz' % layer, 5)
    mon.flush()
    for layer, dec in zip(indices, decs):
        mon.dump(nn.utils.shared2numpy(dec.params), 'decoder-%d-final.npz' % layer)
    print('Training finished!')
Example #5
0
def train_sngan(z_dim=128, image_shape=(3, 32, 32), bs=64, n_iters=int(1e5)):
    """Train a DCGAN generator against an SN-DCGAN discriminator.

    The discriminator loss is ``mean(softplus(-D(x))) + mean(softplus(D(G(z))))``
    and the generator loss is ``mean(softplus(-D(G(z))))``. Each iteration
    performs one generator step followed by ``args.n_dis`` discriminator
    steps, advancing the batch iterator before every discriminator step.
    Every ``args.valid_freq`` iterations images are generated from a fixed
    noise tensor and written through the monitor.

    Optimiser and schedule settings are read from the externally defined
    ``args`` object (``adam_alpha``, ``adam_beta1``, ``adam_beta2``,
    ``n_dis``, ``valid_freq``, ``use_visdom``).

    Args:
        z_dim: size of the noise vector fed to the generator.
        image_shape: shape of one image, without the batch axis.
        bs: batch size.
        n_iters: number of training iterations.

    Raises:
        ValueError: if a generator or discriminator cost is NaN or infinite.
    """
    def _is_invalid(cost):
        # A cost is unusable as soon as it is NaN or +/-inf.
        return np.isnan(cost) or np.isinf(cost)

    generator = DCGANGenerator((None, z_dim))
    discriminator = SNDCGANDiscriminator(generator.output_shape)

    noise = srng.uniform((bs, z_dim), -1, 1, ndim=2, dtype='float32')
    X = T.tensor4('image', 'float32')
    image_buffer = np.zeros((bs, ) + image_shape, 'float32')
    X_ = theano.shared(image_buffer, 'image_placeholder')

    # ---- training graph ----
    nn.set_training_status(True)
    fake_images = generator(noise)
    fake_scores = discriminator(fake_images)
    real_scores = discriminator(X)

    softplus = T.nnet.softplus
    dis_loss = T.mean(softplus(-real_scores)) + T.mean(softplus(fake_scores))
    gen_loss = T.mean(softplus(-fake_scores))

    adam_settings = (args.adam_alpha, args.adam_beta1, args.adam_beta2)
    updates_gen = nn.adam(gen_loss, generator.trainable, *adam_settings)
    updates_dis = nn.adam(dis_loss, discriminator.trainable, *adam_settings)

    train_gen = nn.function(
        [], gen_loss, updates=updates_gen, name='train generator')
    # Only the discriminator consumes real images, hence the `givens`.
    train_dis = nn.function(
        [], dis_loss, updates=updates_dis, givens={X: X_},
        name='train discriminator')

    # ---- testing graph ----
    nn.set_training_status(False)
    noise_values = np.random.uniform(-1, 1, (bs, z_dim))
    fixed_noise = T.constant(noise_values, 'fixed noise', 2, 'float32')
    generate = nn.function([], generator(fixed_noise), name='generate images')

    dm = DataManager(X_, n_iters, bs, True)
    mon = nn.monitor.Monitor(model_name='LSGAN', use_visdom=args.use_visdom)
    print('Training...')
    batches = dm.get_batches(0, dm.n_epochs, infinite=True)
    start = time.time()
    for iteration in range(n_iters):
        # Generator step.
        gen_cost = train_gen()
        if _is_invalid(gen_cost):
            raise ValueError('Training failed due to NaN cost')
        mon.plot('training gen cost', gen_cost)

        # Discriminator steps, each on a freshly loaded batch.
        disc_costs = []
        for _ in range(args.n_dis):
            next(batches)
            disc_cost = train_dis()
            disc_costs.append(disc_cost)
            if _is_invalid(disc_cost):
                raise ValueError('Training failed due to NaN cost')
        mon.plot('training disc cost', np.mean(disc_costs))

        if iteration % args.valid_freq == 0:
            samples = generate()
            mon.imwrite('generated image', dm.unnormalize(samples))
            elapsed_minutes = (time.time() - start) / 60.
            mon.plot('time elapsed', elapsed_minutes)
            mon.flush()
        mon.tick()
    mon.flush()
    print('Training finished!')