    if ntrain is None:
        ntrain = tr_data.num_examples
    print '# examples', tr_data.num_examples
    print '# training examples', ntrain_s
    print '# validation examples', nval_s

    tr_handle = tr_data.open()
    vaX, labels = tr_data.get_data(tr_handle, slice(0, 10000))
    vaX = transform(vaX)
    means = labels.mean(axis=0)
    print('labels ', labels.shape, means, means[0] / means[1])

    vaY, labels = tr_data.get_data(tr_handle, slice(10000, min(ntrain, 20000)))
    vaY = transform(vaY)

    va_nnd_1k = nnd_score(vaY.reshape((len(vaY), -1)),
                          vaX.reshape((len(vaX), -1)),
                          metric='euclidean')
    print 'va_nnd_1k = %.2f' % (va_nnd_1k)
    means = labels.mean(axis=0)
    print('labels ', labels.shape, means, means[0] / means[1])

#####################################
# shared variables
gifn = inits.Normal(scale=0.02)
difn = inits.Normal(scale=0.02)
gain_ifn = inits.Normal(loc=1., scale=0.02)
bias_ifn = inits.Constant(c=0.)

gw = gifn((nz, ngf * 8 * 4 * 4), 'gw')
gg = gain_ifn((ngf * 8 * 4 * 4), 'gg')
gb = bias_ifn((ngf * 8 * 4 * 4), 'gb')
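
nnd_score is the main diagnostic used throughout these snippets. Judging from its call sites (flattened query samples scored against a flattened reference set), it appears to compute a mean nearest-neighbor distance. A minimal scikit-learn sketch of such a score follows; the name nnd_score_sketch and the exact reduction are assumptions, not the implementation these snippets import.

from sklearn.metrics import pairwise_distances

def nnd_score_sketch(queries, reference, metric='euclidean'):
    # Mean distance from each query row to its nearest neighbor in the
    # reference set; lower values mean the queries lie closer to the
    # reference distribution. (Assumed semantics for nnd_score.)
    d = pairwise_distances(queries, reference, metric=metric)
    return d.min(axis=1).mean()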
Example #2
for epoch in range(niter):
    for imb, in tqdm(tr_stream.get_epoch_iterator(), total=ntrain/nbatch):
        imb = transform(imb)  # Shared image batch used to update both g and d.
        zmb = floatX(np_rng.uniform(-1., 1., size=(len(imb), nz)))  # Shared z batch used to update both g and d;
                                                                    # drawn independently of the image batch.
        if n_updates % (k+1) == 0:
            cost = _train_g(imb, zmb)  # Update g once every (k+1) updates.
        else:
            cost = _train_d(imb, zmb)  # Update d on the other steps; its loss uses both real and fake inputs.
        n_updates += 1
        n_examples += len(imb)
    g_cost = float(cost[0])
    d_cost = float(cost[1])
    gX = gen_samples(100000)
    gX = gX.reshape(len(gX), -1)
    va_nnd_1k = nnd_score(gX[:1000], vaX, metric='euclidean')
    va_nnd_10k = nnd_score(gX[:10000], vaX, metric='euclidean')
    va_nnd_100k = nnd_score(gX[:100000], vaX, metric='euclidean')
    log = [n_epochs, n_updates, n_examples, time()-t, va_nnd_1k, va_nnd_10k, va_nnd_100k, g_cost, d_cost]
    print '%.0f %.2f %.2f %.2f %.4f %.4f'%(epoch, va_nnd_1k, va_nnd_10k, va_nnd_100k, g_cost, d_cost)
    f_log.write(json.dumps(dict(zip(log_fields, log)))+'\n')
    f_log.flush()
    samples = np.asarray(_gen(sample_zmb))
    color_grid_vis(inverse_transform(samples), (14, 14), 'samples/%s/%d.png'%(desc, n_epochs))
    n_epochs += 1
    if n_epochs > niter:
        lrt.set_value(floatX(lrt.get_value() - lr/niter_decay))
    if n_epochs in [1, 2, 3, 4, 5, 10, 15, 20, 25]:
        joblib.dump([p.get_value() for p in gen_params], 'models/%s/%d_gen_params.jl'%(desc, n_epochs))
        joblib.dump([p.get_value() for p in discrim_params], 'models/%s/%d_discrim_params.jl'%(desc, n_epochs))
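
The last few lines of the loop above decay the shared learning rate lrt linearly once n_epochs exceeds niter. A minimal sketch of that schedule; the constants lr, niter, and niter_decay are assumed values, not taken from this page:

lr, niter, niter_decay = 0.0002, 25, 25  # assumed values
lrt = lr
for n_epochs in range(1, niter + niter_decay + 1):
    if n_epochs > niter:
        # Subtract lr/niter_decay per epoch, so lrt reaches ~0 after
        # niter + niter_decay epochs.
        lrt = max(0.0, lrt - lr / float(niter_decay))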
Example #3
def eval_and_disp(epoch, costs, ng=(10 * megabatch_size)):
    start_time = time()
    kwargs = dict(metric='euclidean')
    cost_string = '  '.join('%s: %.4f' % o
                            for o in zip(disp_costs.keys(), costs))
    print '%*d) %s' % (iter_pad, epoch, cost_string)
    outs = OrderedDict()
    _feats = {}

    def _get_feats(f, x):
        key = f, id(x)
        if key not in _feats:
            _feats[key] = batch_feats(f, x)
        return _feats[key]

    def _nnc(inputs, labels, f=None):
        assert len(inputs) == len(labels) == 2
        if f is not None:
            inputs = (_get_feats(f, x) for x in inputs)
        (vaX, trX), (vaY, trY) = inputs, labels
        return nnc_score(flat(trX), trY, flat(vaX), vaY, **kwargs)

    gX = flat(batch_feats(_gen, eval_gen_inputs, wraparound=True))
    nnd_sizes = [100, 10, 1]
    nndVaXImages = flat(transform(vaXImages))
    for subsample in nnd_sizes:
        size = ng // subsample
        gXsubset = gX[:size]
        suffix = '' if (subsample == 1) else '/%d' % subsample
        outs['NND' + suffix] = nnd_score(gXsubset, nndVaXImages, **kwargs)
    labels = vaY, trY
    images = vaXImages, trXImages
    big_images = vaXBigImages, trXBigImages
    if args.encode:
        outs['NNC_e'] = _nnc(big_images, labels, f=_enc_l2distable)
        outs['NNC_e-'] = _nnc(big_images, labels, f=_enc_feats)
    if f_discrim is not None:
        outs['NNC_d'] = _nnc(images, labels, f=_discrim_feats)
    if args.classifier:

        def accuracy(func, feat, Y):
            return 100 * (batch_feats(func, feat).argmax(axis=1) == Y).mean()

        if args.encode:
            f = _get_feats(_enc_feats, big_images[0])
            outs['CLS_e-'] = accuracy(_enc_preds, f, vaY)
        if f_discrim is not None:
            f = _get_feats(_discrim_feats, images[0])
            outs['CLS_d'] = accuracy(_discrim_preds, f, vaY)
    if args.encode:

        def image_recon_error(enc_inputs, recon_sized_inputs=None):
            def sqerr(a, b, axis):
                return ((a - b)**2).sum(axis=axis)**0.5

            def _f_error(enc_inputs, recon_sized_inputs):
                gen_input = _enc_recon(enc_inputs)
                recon = _gen(*gen_input)
                if isinstance(recon_sized_inputs, list):
                    recon_sized_inputs = recon_sized_inputs[0]
                inputs = transform(recon_sized_inputs, crop=args.crop_resize)
                axis = tuple(range(1, inputs.ndim))
                error = sqerr(inputs, recon, axis=axis).reshape(-1, 1)
                assert len(inputs) > 1
                shifted_inputs = np.concatenate([inputs[1:], inputs[:1]],
                                                axis=0)
                base_error = sqerr(shifted_inputs, recon,
                                   axis=axis).reshape(-1, 1)
                return np.concatenate([error, base_error], axis=1)

            if recon_sized_inputs is None:
                recon_sized_inputs = enc_inputs
            errors = batch_feats(_f_error, [enc_inputs, recon_sized_inputs],
                                 wraparound=True)
            return errors.mean(axis=0)

        outs['EGr'], outs['EGr_b'] = image_recon_error(big_images[0],
                                                       images[0])
        if args.crop_size == args.crop_resize:
            outs['EGg'], outs['EGg_b'] = image_recon_error(
                gen_output_to_enc_input(gX))

    def format_str(key):
        def is_prop(key, prop_metrics=['NNC', 'CLS']):
            return any(key.startswith(m) for m in prop_metrics)

        if key in ('El', 'El_r'):
            return '%s: %.2e'
        return '%s: %.2f' + ('%%' if is_prop(key) else '')

    print '  '.join(format_str(k) % (k, v) for k, v in outs.iteritems())
    samples = batch_feats(_gen, sample_inputs, wraparound=True)
    sample_shape = num_sample_rows, num_sample_cols

    def imname(tag=None):
        tag = '' if (tag is None) else (tag + '.')
        return '%s/%d.%spng' % (samples_dir, epoch, tag)

    dataset.grid_vis(inverse_transform(samples), sample_shape, imname())
    if args.encode:
        if args.crop_size == args.crop_resize:
            # pass the generator's samples back through encoder;
            # then pass codes back through generator
            enc_gen_inputs = gen_output_to_enc_input(samples)
            samples_enc = batch_feats(_enc_recon,
                                      enc_gen_inputs,
                                      wraparound=True)
            samples_regen = batch_feats(_gen, samples_enc, wraparound=True)
            dataset.grid_vis(inverse_transform(samples_regen), sample_shape,
                             imname('regen'))
        assert trXVisRaw.dtype == np.uint8
        enc_real_input = trXBigVisRaw
        for func, name in [(_enc_recon, 'real_regen'),
                           (_enc_sample, 'real_regen_s')]:
            real_enc = batch_feats(func, enc_real_input, wraparound=True)
            real_regen = batch_feats(_gen, real_enc, wraparound=True)
            dataset.grid_vis(inverse_transform(real_regen), grid_shape,
                             imname(name))
    eval_time = time() - start_time
    sys.stdout.write('Eval done. (%f seconds)\n' % eval_time)
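
_nnc above delegates to nnc_score, which by its signature (training features/labels, validation features/labels) looks like a 1-nearest-neighbor classification accuracy in feature space. A hedged scikit-learn sketch under that reading; the helper name is hypothetical:

from sklearn.neighbors import KNeighborsClassifier

def nnc_score_sketch(trX, trY, vaX, vaY, metric='euclidean'):
    # 1-NN accuracy (%) of validation points against the training set,
    # computed in whatever feature space the caller passes in.
    # (Assumed semantics for nnc_score.)
    clf = KNeighborsClassifier(n_neighbors=1, metric=metric)
    clf.fit(trX, trY)
    return 100.0 * (clf.predict(vaX) == vaY).mean()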
Example #4
def train(conf):
    # set up object
    gan = model.GAN(conf.z_dim, conf.img_h, conf.img_w, conf.c_dim,
                    conf.g_learning_rate, conf.d_learning_rate,
                    conf.g_beta1, conf.d_beta2,
                    conf.gf_dim, conf.df_dim)
    sample_x = data_utils.DataSet(conf.X)

    # log ground truth
    vis_nsample = min(6, conf.nbatch)
    vis_X = conf.X[:vis_nsample]
    vis_X = vis_X.reshape([vis_X.shape[0], -1])
    vis.plot_series(
        vis_X, os.path.join(conf.dir_samples, "000_real.png"))
    # save variables to log
    save_variables(conf, os.path.join(conf.dir_logs, 'variables_{}'.format(conf.model_name)))
    f_log_train = open(os.path.join(conf.dir_logs, 'log_train_{}.ndjson'.format(conf.model_name)), 'w')
    log_fields = [
        'n_epoches',
        'n_updates',
        'n_examples',
        'n_seconds',
        '1k_va_nnd',
        '10k_va_nnd',
        '100k_va_nnd',
        'g_loss',
        'd_loss_real',
        'd_loss_fake'
    ]

    # set up tf session and train model
    with tf.Session(config=tf_conf) as sess:
        # initialize
        try:
            tf.global_variables_initializer().run()
        except AttributeError:
            # fallback for TF versions without global_variables_initializer
            tf.initialize_all_variables().run()

        # train
        n_updates = 0
        n_epoches = 0
        n_examples = 0
        g_losses, d_losses, d_losses_fake, d_losses_real = [], [], [], []
        nnd_1k, nnd_10k, nnd_100k = [], [], []
        t = time()
        for epoch in xrange(conf.nepoch):
            g_loss, d_loss, d_loss_fake, d_loss_real = np.zeros(4)
            for i in xrange(sample_x.num_examples // conf.nbatch):
                x = sample_x.next_batch(conf.nbatch)
                z = sample_z([conf.nbatch, conf.z_dim])

                _ = sess.run(gan.d_opt, feed_dict={gan.x: x, gan.z: z})
                _ = sess.run(gan.g_opt, feed_dict={gan.z: z})
                _ = sess.run(gan.g_opt, feed_dict={gan.z: z})

                d_loss, d_loss_real, d_loss_fake, g_loss = sess.run(
                    [gan.d_loss, gan.d_loss_real, gan.d_loss_fake, gan.g_loss],
                    feed_dict={gan.x: x, gan.z: z})
                n_updates += 1
                n_examples += len(x)
            n_epoches += 1
            g_losses.append(g_loss)
            d_losses.append(d_loss)
            d_losses_fake.append(d_loss_fake)
            d_losses_real.append(d_loss_real)

            # log
            if epoch % conf.freq_print == 0:
                print("Epoch: [{}/{}], g_loss = {:.4f}, d_loss = {:.4f}, "
                      "d_loss_fake = {:.4f}, d_loss_reak = {:.4f}".format(
                    epoch, conf.nepoch,
                    g_loss, d_loss, d_loss_fake, d_loss_real))
            if epoch % conf.freq_log == 0:
                # eval
                gX = gan_sample(sess, gan, conf, conf.nsample)
                gX = gX.reshape(len(gX), -1)
                teX = conf.X.reshape(len(conf.X), -1)
                # teX = conf.teX.reshape(len(conf.teX), -1)
                va_nnd_1k = metrics.nnd_score(gX[:1000], teX, metric='euclidean')
                va_nnd_10k = metrics.nnd_score(gX[:10000], teX, metric='euclidean')
                va_nnd_100k = metrics.nnd_score(gX[:100000], teX, metric='euclidean')
                nnd_1k.append(va_nnd_1k)
                nnd_10k.append(va_nnd_10k)
                nnd_100k.append(va_nnd_100k)

                log_values = [n_epoches, n_updates, n_examples, time()-t,
                              va_nnd_1k, va_nnd_10k, va_nnd_100k,
                              float(g_loss), float(d_loss_real), float(d_loss_fake)]
                f_log_train.write(
                    json.dumps(dict(zip(log_fields, log_values))) + '\n')
                f_log_train.flush()
                # save checkpoint
                gan.save(sess, conf.dir_checkpoint, n_updates, conf.model_name)

            if epoch % conf.freq_plot == 0:
                samples = gan_sample(sess, gan, conf, vis_nsample)
                samples = samples.reshape([samples.shape[0], -1])
                img_path = os.path.join(
                    conf.dir_samples, "train_{}.png".format(str(epoch + 1).zfill(4)))
                vis.plot_series(samples, img_path)

        # plot loss
        losses = {'g_loss': np.array(g_losses),
                  'd_loss': np.array(d_losses),
                  'd_loss_fake': np.array(d_losses_fake),
                  'd_loss_real': np.array(d_losses_real)}
        vis.plot_dic(losses, title='{}_loss'.format(conf.data_name),
                     save_path=os.path.join(conf.dir_logs, 'loss_{}.png'.format(conf.model_name)))
        nnd = {'nnd_1k': np.array(nnd_1k),
               'nnd_10k': np.array(nnd_10k),
               'nnd_100k': np.array(nnd_100k)}
        vis.plot_dic(nnd, title='{}_nnd'.format(conf.data_name),
                     save_path=os.path.join(conf.dir_logs, 'nnd_{}.png'.format(conf.model_name)))
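
sample_z is not defined in this snippet. Given the z ranges used by the Theano examples above, a plausible stand-in is uniform noise; this helper is an assumption, not the page's actual implementation:

import numpy as np

def sample_z(shape, low=-1.0, high=1.0):
    # Uniform generator input noise, mirroring the
    # np_rng.uniform(-1., 1., ...) draws in the Theano examples.
    return np.random.uniform(low, high, size=shape).astype(np.float32)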
Example #5
    if ntrain is None:
        ntrain = tr_data.num_examples
    print '# examples', tr_data.num_examples
    print '# training examples', ntrain_s
    print '# validation examples', nval_s

    tr_handle = tr_data.open()
    vaX, labels = tr_data.get_data(tr_handle, slice(0, 10000))
    vaX = transform(vaX)
    means = labels.mean(axis=0)
    print('labels ', labels.shape, means, means[0] / means[1])

    vaY, labels = tr_data.get_data(tr_handle, slice(10000, min(ntrain, 20000)))
    vaY = transform(vaY)

    va_nnd_1k = nnd_score(vaY.reshape((len(vaY), -1)),
                          vaX.reshape((len(vaX), -1)),
                          metric='euclidean')
    print 'va_nnd_1k = %.2f' % (va_nnd_1k)
    means = labels.mean(axis=0)
    print('labels ', labels.shape, means, means[0] / means[1])

#####################################
# shared variables
gifn = inits.Normal(scale=0.02)
difn = inits.Normal(scale=0.02)
gain_ifn = inits.Normal(loc=1., scale=0.02)
bias_ifn = inits.Constant(c=0.)

gw = gifn((nz, ngf * 8 * 4 * 4), 'gw')
gg = gain_ifn((ngf * 8 * 4 * 4), 'gg')
gb = bias_ifn((ngf * 8 * 4 * 4), 'gb')
gw2 = gifn((ngf * 8, ngf * 4, 5, 5), 'gw2')
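
These shapes follow the DCGAN pattern: gw projects the nz-dimensional code to ngf*8*4*4 units, which are then reshaped into ngf*8 feature maps of size 4x4 before the 5x5 deconvolutions that gw2 begins. A NumPy sketch of that first projection, with the nz, ngf, and batch sizes assumed:

import numpy as np

nz, ngf, nbatch = 100, 64, 16  # assumed sizes
z = np.random.uniform(-1., 1., size=(nbatch, nz)).astype(np.float32)
gw = np.random.normal(0., 0.02, size=(nz, ngf * 8 * 4 * 4))
h0 = z.dot(gw).reshape(nbatch, ngf * 8, 4, 4)  # ngf*8 maps of 4x4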
Example #6
t = time()
for epoch in range(niter):
    for imb, in tqdm(tr_stream.get_epoch_iterator(), total=ntrain / nbatch):
        imb = transform(imb)
        zmb = floatX(np_rng.uniform(-1., 1., size=(len(imb), nz)))
        if n_updates % (k + 1) == 0:
            cost = _train_g(imb, zmb)
        else:
            cost = _train_d(imb, zmb)
        n_updates += 1
        n_examples += len(imb)
    g_cost = float(cost[0])
    d_cost = float(cost[1])
    gX = gen_samples(100000)
    gX = gX.reshape(len(gX), -1)
    va_nnd_1k = nnd_score(gX[:1000], vaX, metric='euclidean')
    va_nnd_10k = nnd_score(gX[:10000], vaX, metric='euclidean')
    va_nnd_100k = nnd_score(gX[:100000], vaX, metric='euclidean')
    log = [
        n_epochs, n_updates, n_examples,
        time() - t, va_nnd_1k, va_nnd_10k, va_nnd_100k, g_cost, d_cost
    ]
    print '%.0f %.2f %.2f %.2f %.4f %.4f' % (epoch, va_nnd_1k, va_nnd_10k,
                                             va_nnd_100k, g_cost, d_cost)
    f_log.write(json.dumps(dict(zip(log_fields, log))) + '\n')
    f_log.flush()

    samples = np.asarray(_gen(sample_zmb))
    color_grid_vis(inverse_transform(samples), (14, 14),
                   'samples/%s/%d.png' % (desc, n_epochs))
    n_epochs += 1
Example #7
def train(conf):
    # set up object
    gan = model.GAN(conf.z_dim, conf.img_h, conf.img_w, conf.c_dim,
                    conf.g_learning_rate, conf.d_learning_rate, conf.g_beta1,
                    conf.d_beta2, conf.gf_dim, conf.df_dim)
    sample_x = data_utils.DataSet(conf.X)

    # log ground truth
    vis_nsample = min(6, conf.nbatch)
    vis_X = conf.X[:vis_nsample]
    vis_X = vis_X.reshape([vis_X.shape[0], -1])
    vis.plot_series(vis_X, os.path.join(conf.dir_samples, "000_real.png"))
    # save variables to log
    save_variables(
        conf,
        os.path.join(conf.dir_logs, 'variables_{}'.format(conf.model_name)))
    f_log_train = open(
        os.path.join(conf.dir_logs,
                     'log_train_{}.ndjson'.format(conf.model_name)), 'w')
    log_fields = [
        'n_epoches', 'n_updates', 'n_examples', 'n_seconds', '1k_va_nnd',
        '10k_va_nnd', '100k_va_nnd', 'g_loss', 'd_loss_real', 'd_loss_fake'
    ]

    # set up tf session and train model
    with tf.Session(config=tf_conf) as sess:
        # initialize
        try:
            tf.global_variables_initializer().run()
        except AttributeError:
            # fallback for TF versions without global_variables_initializer
            tf.initialize_all_variables().run()

        # train
        n_updates = 0
        n_epoches = 0
        n_examples = 0
        g_losses, d_losses, d_losses_fake, d_losses_real = [], [], [], []
        nnd_1k, nnd_10k, nnd_100k = [], [], []
        t = time()
        for epoch in xrange(conf.nepoch):
            g_loss, d_loss, d_loss_fake, d_loss_real = np.zeros(4)
            for i in xrange(sample_x.num_examples // conf.nbatch):
                x = sample_x.next_batch(conf.nbatch)
                z = sample_z([conf.nbatch, conf.z_dim])

                _ = sess.run(gan.d_opt, feed_dict={gan.x: x, gan.z: z})
                _ = sess.run(gan.g_opt, feed_dict={gan.z: z})
                _ = sess.run(gan.g_opt, feed_dict={gan.z: z})

                d_loss, d_loss_real, d_loss_fake, g_loss = sess.run(
                    [gan.d_loss, gan.d_loss_real, gan.d_loss_fake, gan.g_loss],
                    feed_dict={
                        gan.x: x,
                        gan.z: z
                    })
                n_updates += 1
                n_examples += len(x)
            n_epoches += 1
            g_losses.append(g_loss)
            d_losses.append(d_loss)
            d_losses_fake.append(d_loss_fake)
            d_losses_real.append(d_loss_real)

            # log
            if epoch % conf.freq_print == 0:
                print("Epoch: [{}/{}], g_loss = {:.4f}, d_loss = {:.4f}, "
                      "d_loss_fake = {:.4f}, d_loss_reak = {:.4f}".format(
                          epoch, conf.nepoch, g_loss, d_loss, d_loss_fake,
                          d_loss_real))
            if epoch % conf.freq_log == 0:
                # eval
                gX = gan_sample(sess, gan, conf, conf.nsample)
                gX = gX.reshape(len(gX), -1)
                teX = conf.X.reshape(len(conf.X), -1)
                # teX = conf.teX.reshape(len(conf.teX), -1)
                va_nnd_1k = metrics.nnd_score(gX[:1000],
                                              teX,
                                              metric='euclidean')
                va_nnd_10k = metrics.nnd_score(gX[:10000],
                                               teX,
                                               metric='euclidean')
                va_nnd_100k = metrics.nnd_score(gX[:100000],
                                                teX,
                                                metric='euclidean')
                nnd_1k.append(va_nnd_1k)
                nnd_10k.append(va_nnd_10k)
                nnd_100k.append(va_nnd_100k)

                log_values = [
                    n_epoches, n_updates, n_examples,
                    time() - t, va_nnd_1k, va_nnd_10k, va_nnd_100k,
                    float(g_loss),
                    float(d_loss_real),
                    float(d_loss_fake)
                ]
                f_log_train.write(
                    json.dumps(dict(zip(log_fields, log_values))) + '\n')
                f_log_train.flush()
                # save checkpoint
                gan.save(sess, conf.dir_checkpoint, n_updates, conf.model_name)

            if epoch % conf.freq_plot == 0:
                samples = gan_sample(sess, gan, conf, vis_nsample)
                samples = samples.reshape([samples.shape[0], -1])
                img_path = os.path.join(
                    conf.dir_samples,
                    "train_{}.png".format(str(epoch + 1).zfill(4)))
                vis.plot_series(samples, img_path)

        # plot loss
        losses = {
            'g_loss': np.array(g_losses),
            'd_loss': np.array(d_losses),
            'd_loss_fake': np.array(d_losses_fake),
            'd_loss_real': np.array(d_losses_real)
        }
        vis.plot_dic(losses,
                     title='{}_loss'.format(conf.data_name),
                     save_path=os.path.join(
                         conf.dir_logs, 'loss_{}.png'.format(conf.model_name)))
        nnd = {
            'nnd_1k': np.array(nnd_1k),
            'nnd_10k': np.array(nnd_10k),
            'nnd_100k': np.array(nnd_100k)
        }
        vis.plot_dic(nnd,
                     title='{}_nnd'.format(conf.data_name),
                     save_path=os.path.join(
                         conf.dir_logs, 'nnd_{}.png'.format(conf.model_name)))
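
The training log written above is NDJSON: one JSON object per line, keyed by log_fields. A short sketch of reading it back for plotting; the file path here is an assumption:

import json

with open('logs/log_train_gan.ndjson') as f:  # assumed path
    rows = [json.loads(line) for line in f]
nnd_curve = [r['1k_va_nnd'] for r in rows]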
Example #8
def eval_and_disp(epoch, costs, ng=(10 * megabatch_size), plot_latent=False):
    start_time = time()
    eval_costs(epoch, costs)
    kwargs = dict(metric='euclidean')
    outs = OrderedDict()
    _feats = {}

    def _get_feats(f, x):
        key = f, id(x)
        if key not in _feats:
            _feats[key] = batch_map(f, x)
        return _feats[key]

    def _nnc(inputs, labels, f=None):
        assert len(inputs) == len(labels) == 2
        if f is not None:
            inputs = (_get_feats(f, x) for x in inputs)
        (vaX, trX), (vaY, trY) = inputs, labels
        return nnc_score(flat(trX), trY, flat(vaX), vaY, **kwargs)

    # gXs = [flat(batch_map(_gen, eval_gen_inputs, wraparound=True)) for _gen in _gens]
    nnd_sizes = [100, 10, 1]
    nndVaXImages = flat(transform(vaXImages))

    labels = vaY, trY
    images = vaXImages, trXImages
    big_images = vaXBigImages, trXBigImages
    if args.encode:
        outs['NNC_e'] = _nnc(big_images, labels, f=_enc_l2distable)
        outs['NNC_e-'] = _nnc(big_images, labels, f=_enc_feats)
        if plot_latent:
            f = _get_feats(_enc_feats, big_images[0][:1500])
            #fe = _get_feats(_enc_l2distable, big_images[0][:1000])
            plot_latent_encodings(
                f,
                labels[0][:1500],
                os.path.join(args.exp_dir, "latent_encodings_e.png"),
                title="%s Latent Encodings" % (args.dataset.upper()))
            """
            plot_latent_encodings(
                fe, labels[0][:1000], 
                os.path.join(args.exp_dir, "latent_encodings_e-.png"))
            """
    if f_discrim is not None:
        outs['NNC_d'] = _nnc(images, labels, f=_discrim_feats)
    if args.classifier:

        def accuracy(func, feat, Y):
            return 100 * (batch_map(func, feat).argmax(axis=1) == Y).mean()

        if args.encode:
            f = _get_feats(_enc_feats, big_images[0])
            outs['CLS_e-'] = accuracy(_enc_preds, f, vaY)

        if f_discrim is not None:
            f = _get_feats(_discrim_feats, images[0])
            outs['CLS_d'] = accuracy(_discrim_preds, f, vaY)
    for gi, _gen in enumerate(_gens):
        gX = flat(batch_map(_gen, eval_gen_inputs, wraparound=True))
        for subsample in nnd_sizes:
            size = ng // subsample
            gXsubset = gX[:size]
            suffix = '' if (subsample == 1) else '/%d' % subsample
            outs['NND_g%d_' % gi + suffix] = nnd_score(gXsubset, nndVaXImages,
                                                       **kwargs)
        if args.encode:

            def image_recon_error(enc_inputs, gi, recon_sized_inputs=None):
                def l2err(a, b, axis):
                    return ((a - b)**2).sum(axis=axis)**0.5

                def _f_error(enc_inputs, recon_sized_inputs):
                    gen_input = _enc_recon(enc_inputs)
                    if isinstance(recon_sized_inputs, list):
                        recon_sized_inputs = recon_sized_inputs[0]
                    inputs = transform(recon_sized_inputs,
                                       crop=args.crop_resize)
                    axis = tuple(range(1, inputs.ndim))
                    recon = _gens[gi](*gen_input)

                    error = l2err(inputs, recon, axis=axis).reshape(-1, 1)
                    assert len(inputs) > 1
                    shifted_inputs = np.concatenate([inputs[1:], inputs[:1]],
                                                    axis=0)
                    base_error = l2err(shifted_inputs, recon,
                                       axis=axis).reshape(-1, 1)
                    return np.concatenate([error, base_error], axis=1)

                if recon_sized_inputs is None:
                    recon_sized_inputs = enc_inputs
                errors = batch_map(_f_error, [enc_inputs, recon_sized_inputs],
                                   wraparound=True)
                return errors.mean(axis=0)

            if args.crop_size == args.crop_resize:
                outs['EGg%d' % gi], outs['EGg%d_b' % gi] = image_recon_error(
                    gen_output_to_enc_input(gX), gi)
            else:
                # TO FIX
                outs['EGr'], outs['EGr_b'] = image_recon_error(
                    big_images[0], images[0])

    def format_str(key):
        def is_prop(key, prop_metrics=['NNC', 'CLS']):
            return any(key.startswith(m) for m in prop_metrics)

        return '%s: %.2f' + ('%%' if is_prop(key) else '')

    print('  '.join(format_str(k) % (k, v) for k, v in outs.items()))

    for gi, _gen in enumerate(_gens):
        samples = batch_map(_gen, sample_inputs, wraparound=True)
        sample_shape = num_sample_rows, num_sample_cols

        def imname(tag=None):
            tag = '' if (tag is None) else (tag + '.')
            tag += "g%d" % gi
            return '%s/%d%s.png' % (samples_dir, epoch, tag)

        dataset.grid_vis(inverse_transform(samples), sample_shape, imname())
        if args.encode:
            # if args.crop_size == args.crop_resize:
            # pass the generator's samples back through encoder;
            # then pass codes back through generator
            # enc_gen_inputs = gen_output_to_enc_input(samples)
            # samples_enc = batch_map(_enc_recon, enc_gen_inputs, wraparound=True)
            # samples_regen = batch_map(_gen, samples_enc, wraparound=True)
            # dataset.grid_vis(inverse_transform(samples_regen), sample_shape,
            #          imname('regen'))
            assert trXVisRaw.dtype == np.uint8
            real_enc = batch_map(_enc_recon, trXBigVisRaw, wraparound=True)
            real_regen = batch_map(_gen, real_enc, wraparound=True)
            dataset.grid_vis(inverse_transform(real_regen), grid_shape,
                             imname("real_regen"))

            # for func, name in [(_enc_recon, 'real_regen'), (_enc_sample, 'real_regen_s')]:
            #     real_enc = batch_map(func, trXBigVisRaw, wraparound=True)
            #     real_regen = batch_map(_gen, real_enc, wraparound=True)
            #     dataset.grid_vis(inverse_transform(real_regen), grid_shape, imname(name))

    eval_time = time() - start_time
    print('Eval done. (%f seconds)\n' % eval_time)
    return outs
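
batch_map here (batch_feats in the earlier example) is called with wraparound=True, which suggests it pads the final short minibatch by wrapping around to the start so a compiled function always sees full batches, then trims the extra outputs. A minimal NumPy sketch under that assumption; the semantics are inferred, not confirmed by this page:

import numpy as np

def batch_map_sketch(f, x, nbatch=128, wraparound=True):
    # Apply f over fixed-size minibatches of x; with wraparound, pad
    # the last short batch with rows from the front, then trim the
    # padded outputs. (Assumed semantics for batch_map/batch_feats.)
    n = len(x)
    if wraparound and n % nbatch:
        x = np.concatenate([x, x[:nbatch - n % nbatch]], axis=0)
    outs = [f(x[i:i + nbatch]) for i in range(0, len(x), nbatch)]
    return np.concatenate(outs, axis=0)[:n]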
Example #9
def train(conf):
    # set up object
    gan = model.GAN(conf.z_dim, conf.img_h, conf.img_w, conf.c_dim,
                    conf.g_learning_rate, conf.d_learning_rate, conf.g_beta1,
                    conf.d_beta2, conf.gf_dim, conf.df_dim)
    sample_x = data_utils.DataSet(conf.X)

    # log ground truth
    vis_nsample = min(6, conf.nbatch)
    vis_X = conf.X[:vis_nsample]
    vis_X = vis_X.reshape([vis_X.shape[0], -1])
    vis.plot_series(vis_X, os.path.join(conf.dir_samples, "000_real.png"))
    # save variables to log
    save_variables(conf, os.path.join(conf.dir_logs, 'variables'))
    f_log_train = open(os.path.join(conf.dir_logs, 'log_train.ndjson'), 'w')
    log_fields = [
        'n_epoches', 'n_updates', 'n_examples', 'n_seconds', 'nnd', 'mmd',
        'g_loss', 'd_loss_real', 'd_loss_fake'
    ]

    # set up tf session and train model
    with tf.Session(config=tf_conf) as sess:
        # initialize
        try:
            tf.global_variables_initializer().run()
        except AttributeError:
            # fallback for TF versions without global_variables_initializer
            tf.initialize_all_variables().run()

        t = time()
        n_updates = 0
        n_epoches = 0
        n_examples = 0
        g_losses, d_losses, d_losses_fake, d_losses_real = [], [], [], []
        nnds = []
        mmds = []
        mmd_bandwidths = [2.0, 5.0, 10.0, 20.0, 40.0, 80.0]
        mmd_batchsize = min(conf.nsample, conf.X.shape[0])
        mmd_real_t = tf.placeholder(tf.float32, [mmd_batchsize, conf.img_h],
                                    name='mmd_real')
        mmd_sample_t = tf.placeholder(tf.float32, [mmd_batchsize, conf.img_h],
                                      name='mmd_sample')
        mmd_loss_t = mix_rbf_mmd2(mmd_real_t,
                                  mmd_sample_t,
                                  sigmas=mmd_bandwidths)
        # train
        for epoch in xrange(conf.nepoch):
            g_loss, d_loss, d_loss_fake, d_loss_real = np.zeros(4)
            for i in xrange(sample_x.num_examples // conf.nbatch):
                x = sample_x.next_batch(conf.nbatch)
                z = sample_z([conf.nbatch, conf.z_dim])

                _ = sess.run(gan.d_opt, feed_dict={gan.x: x, gan.z: z})
                _ = sess.run(gan.g_opt, feed_dict={gan.z: z})
                _ = sess.run(gan.g_opt, feed_dict={gan.z: z})

                d_loss, d_loss_real, d_loss_fake, g_loss = sess.run(
                    [gan.d_loss, gan.d_loss_real, gan.d_loss_fake, gan.g_loss],
                    feed_dict={
                        gan.x: x,
                        gan.z: z
                    })
                n_updates += 1
                n_examples += len(x)
            n_epoches += 1
            g_losses.append(g_loss)
            d_losses.append(d_loss)
            d_losses_fake.append(d_loss_fake)
            d_losses_real.append(d_loss_real)

            # log
            if epoch % conf.freq_print == 0:
                print(
                    "Epoch: [{}/{}], g_loss = {:.4f}, d_loss = {:.4f}, d_loss_fake = {:.4f}, d_loss_reak = {:.4f}"
                    .format(epoch, conf.nepoch, g_loss, d_loss, d_loss_fake,
                            d_loss_real))
            if epoch % conf.freq_log == 0 or epoch == conf.nepoch - 1:
                # eval
                gX = gan_sample(sess, gan, conf, conf.nsample)
                gX = gX.reshape(len(gX), -1)
                teX = conf.X.reshape(len(conf.X), -1)
                nnd_ = metrics.nnd_score(gX[:mmd_batchsize],
                                         teX[:mmd_batchsize],
                                         metric='euclidean')
                nnds.append(nnd_)
                mmd_ = sess.run(mmd_loss_t,
                                feed_dict={
                                    mmd_real_t: teX[:mmd_batchsize],
                                    mmd_sample_t: gX[:mmd_batchsize]
                                })
                mmds.append(mmd_)
                log_values = [
                    n_epoches, n_updates, n_examples,
                    time() - t, nnd_,
                    float(mmd_),
                    float(g_loss),
                    float(d_loss_real),
                    float(d_loss_fake)
                ]
                f_log_train.write(
                    json.dumps(dict(zip(log_fields, log_values))) + '\n')
                f_log_train.flush()
                # save checkpoint
                gan.save(sess, conf.dir_checkpoint, n_updates, conf.model_name)

            if epoch % conf.freq_plot == 0 or epoch == conf.nepoch - 1:
                samples = gan_sample(sess, gan, conf, vis_nsample)
                samples = samples.reshape([samples.shape[0], -1])
                img_path = os.path.join(
                    conf.dir_samples,
                    "train_{}.png".format(str(epoch + 1).zfill(4)))
                txt_path = os.path.join(
                    conf.dir_samples,
                    "train_{}".format(str(epoch + 1).zfill(4)))
                vis.plot_series(samples, img_path)
                np.savetxt(txt_path, samples, delimiter=',', newline='\n')

    metrics_dic = {
        'g_loss': np.array(g_losses),
        'd_loss': np.array(d_losses),
        'd_loss_fake': np.array(d_losses_fake),
        'd_loss_real': np.array(d_losses_real),
        'nnd': np.array(nnds),
        'mmd': np.array(mmds)
    }

    metrics_save(metrics_dic, conf)
    metrics_vis(metrics_dic, conf)
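
mix_rbf_mmd2 above is a squared maximum mean discrepancy with a mixture of RBF kernels, one per bandwidth in sigmas. A NumPy sketch of the biased estimator under the usual parameterization k(x, y) = sum_sigma exp(-||x - y||^2 / (2 * sigma^2)); the TF op's exact kernel parameterization may differ:

import numpy as np

def mix_rbf_mmd2_sketch(X, Y, sigmas=(2., 5., 10., 20., 40., 80.)):
    # Biased MMD^2 estimate: mean(Kxx) + mean(Kyy) - 2*mean(Kxy),
    # where K mixes one RBF kernel per bandwidth.
    def K(A, B):
        d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
        return sum(np.exp(-d2 / (2. * s ** 2)) for s in sigmas)
    return K(X, X).mean() + K(Y, Y).mean() - 2. * K(X, Y).mean()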