# Example #1
# 0
# Dashboard catalog: one ``filename,type,name`` entry per series the dashboard shows.
_DASHBOARD_CATALOG = """filename,type,name
options,plain,Options
train_loss.csv,csv,Train Loss
ll.csv,csv,Neg. Log-Likelihood
dec_log_sig_sq.csv,csv,Decoder Log Simga^2
dec_std_log_sig_sq.csv,csv,STD of Decoder Log Simga^2
dec_mean.csv,csv,Decoder Mean
dkl.csv,csv,DKL
enc_log_sig_sq.csv,csv,Encoder Log Sigma^2
enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2
enc_mean.csv,csv,Encoder Mean
val_loss.csv,csv,Validation Loss
"""

# (key used in the logs dict, CSV file name, CSV header line) for every
# dashboard series written during training.
_DASHBOARD_SERIES = (
    ('train', 'train_loss.csv', 'step,time,Train Loss\n'),
    ('val', 'val_loss.csv', 'step,time,Validation Loss\n'),
    ('dkl', 'dkl.csv', 'step,time,DKL\n'),
    ('ll', 'll.csv', 'step,time,-LL\n'),
    ('dec_sig', 'dec_log_sig_sq.csv', 'step,time,Decoder Log Sigma^2\n'),
    ('enc_sig', 'enc_log_sig_sq.csv', 'step,time,Encoder Log Sigma^2\n'),
    ('dec_std_sig', 'dec_std_log_sig_sq.csv',
     'step,time,STD of Decoder Log Sigma^2\n'),
    ('enc_std_sig', 'enc_std_log_sig_sq.csv',
     'step,time,STD of Encoder Log Sigma^2\n'),
    ('dec_mean', 'dec_mean.csv', 'step,time,Decoder Mean\n'),
    ('enc_mean', 'enc_mean.csv', 'step,time,Encoder Mean\n'),
)


def _dashboard_row(handle, step, value):
    """Append one ``step,time,value`` row to a dashboard CSV and flush it.

    The time column is the fixed placeholder ``'2016-04-22'`` used by every
    dashboard series this script writes.
    """
    handle.write('{},{},{}\n'.format(step, '2016-04-22', value))
    handle.flush()


def _batch_accuracy(labels, probs):
    """Return the fraction of rows whose arg-max over axis 1 agrees.

    Args:
        labels: array-like of per-class targets, classes along axis 1.
        probs: array-like of per-class scores with the same layout.
    """
    return np.mean(np.argmax(labels, axis=1) == np.argmax(probs, axis=1))


def _write_options_file(options):
    """Write the run description and every option to ``<dashboard_dir>/options``."""
    options_path = os.path.join(options['dashboard_dir'], 'options')
    with open(options_path, 'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write(
            'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
                options['DKL_weight'], -options['sigma_clip'],
                options['sigma_clip']))
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')


def _open_dashboard_logs(dashboard_dir, open_files):
    """Create the dashboard catalog and all CSV series files in ``dashboard_dir``.

    Each handle is appended to ``open_files`` as soon as it is opened, so the
    caller can close everything even if a later ``open`` fails. The catalog
    stays open because sample-saving code appends to it during training.

    Returns:
        dict mapping ``'catalog'`` and every key of ``_DASHBOARD_SERIES`` to
        its open, writable file handle (headers already written).
    """
    def _open(name):
        handle = open(os.path.join(dashboard_dir, name), 'w')
        open_files.append(handle)
        return handle

    logs = {'catalog': _open('catalog')}
    logs['catalog'].write(_DASHBOARD_CATALOG)
    logs['catalog'].flush()
    for key, filename, header in _DASHBOARD_SERIES:
        logs[key] = _open(filename)
        logs[key].write(header)
    return logs


def _image_provider(options, split):
    """Build a flat-image DataProvider over ``<data_dir>/<split>/patches``."""
    patch_dir = os.path.join(options['data_dir'], split, 'patches')
    # NOTE(review): os.listdir never returns '.' or '..', so this "- 2"
    # excludes two other directory entries -- confirm which ones.
    num_data_points = len(os.listdir(patch_dir)) - 2
    return DataProvider(
        num_data_points, options['batch_size'],
        toolbox.ImageLoader(data_dir=patch_dir,
                            flat=True,
                            extension=options['file_extension']))


def _make_providers(options):
    """Return ``(train_provider, val_provider)`` for ``options['data_dir']``.

    The special value ``'MNIST'`` selects the MNIST loader with 55000
    training and 5000 validation examples; any other value is treated as a
    directory containing ``train/patches`` and ``valid/patches``.
    """
    if options['data_dir'] != 'MNIST':
        train_provider = _image_provider(options, 'train')
        val_provider = _image_provider(options, 'valid')
    else:
        train_provider = DataProvider(
            55000, options['batch_size'],
            toolbox.MNISTLoader(mode='train', flat=True))
        val_provider = DataProvider(
            5000, options['batch_size'],
            toolbox.MNISTLoader(mode='validation', flat=True))
    return train_provider, val_provider


def _build_graph(options, log):
    """Build the VAE, the latent-sample classifier and the update op on /gpu:0.

    Graph construction order is kept stable so variable names still match
    previously saved checkpoints.

    Returns:
        dict of graph handles: model, the three input placeholders, sampler,
        classifier (softmax output), cost_function, grads (non-None
        ``(gradient, variable)`` pairs), backpass (update op) and init_op.
    """
    with tf.device('/gpu:0'):
        model = cupboard(options['model'])(
            options['p_layers'], options['q_layers'],
            np.prod(options['img_shape']), options['latent_dims'],
            options['DKL_weight'], options['sigma_clip'], 'vanilla_vae')
        log.info('Model initialized')

        # Define inputs
        model_input_batch = tf.placeholder(
            tf.float32,
            shape=[
                options['batch_size'],
                np.prod(np.array(options['img_shape']))
            ],
            name='enc_inputs')
        model_label_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], options['num_classes']],
            name='labels')
        sampler_input_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], options['latent_dims']],
            name='dec_inputs')
        log.info('Inputs defined')

        # The VAE's own cost is not optimized here. The call is kept because
        # it builds the encoder/decoder graph; the model.* tensors read below
        # are presumably populated by it -- TODO confirm.
        model(model_input_batch)
        log.info('Forward pass graph built')

        sampler = model.build_sampler(sampler_input_batch)
        log.info('Sampler graph built')

        optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])

        # Sample enc_mean + eps * std with eps ~ N(0, 1) and
        # std = exp(0.5 * enc_log_std_sq), then feed it to an FC layer with
        # no activation to get per-class logits.
        enc_std = tf.exp(tf.mul(0.5, model.enc_log_std_sq))
        logits = FC(
            model.latent_dims,
            options['num_classes'],
            activation=None,
            scale=0.01,
            name='classifier_fc')(tf.add(
                tf.mul(tf.random_normal([model.n_samples, model.latent_dims]),
                       enc_std), model.enc_mean))
        classifier = tf.nn.softmax(logits)

        # Equivalent to sum(-labels * log(softmax(logits))) but evaluated from
        # the logits, so an underflowed probability can no longer produce
        # 0 * log(0) = NaN.
        cost_function = tf.reduce_sum(
            tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                    labels=model_label_batch))
        cost_function *= 1 / float(options['batch_size'])

        # Keep only variables that actually receive a gradient, clip each
        # gradient tensor to norm 5.0, then apply with Adam.
        grads = [gv for gv in optimizer.compute_gradients(cost_function)
                 if gv[0] is not None]
        clip_grads = [(tf.clip_by_norm(grad, 5.0, name='grad_clipping'), var)
                      for grad, var in grads]
        backpass = optimizer.apply_gradients(clip_grads)
        log.info('Optimizer graph built')

        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')

    return {
        'model': model,
        'model_input_batch': model_input_batch,
        'model_label_batch': model_label_batch,
        'sampler_input_batch': sampler_input_batch,
        'sampler': sampler,
        'classifier': classifier,
        'cost_function': cost_function,
        'grads': grads,
        'backpass': backpass,
        'init_op': init_op,
    }


def _log_train_metrics(logs, step, labels, result):
    """Append the current training batch's statistics to the dashboard CSVs.

    ``result`` is the list fetched by the training ``sess.run``: index 2 is
    model.DKL, 3 model.rec_loss, 4 model.dec_log_std_sq, 5
    model.enc_log_std_sq, 6 model.enc_mean, 7 model.dec_mean, 8 the
    classifier's softmax output.
    """
    # NOTE: the 'Train Loss' series (train_loss.csv) receives batch accuracy.
    _dashboard_row(logs['train'], step, _batch_accuracy(labels, result[8]))
    _dashboard_row(logs['dkl'], step, -np.mean(result[2]))
    _dashboard_row(logs['ll'], step, -np.mean(result[3]))
    _dashboard_row(logs['dec_sig'], step, np.mean(result[4]))
    _dashboard_row(logs['enc_sig'], step, np.mean(result[5]))
    _dashboard_row(logs['dec_std_sig'], step, np.std(result[4]))
    _dashboard_row(logs['enc_std_sig'], step, np.std(result[5]))
    _dashboard_row(logs['dec_mean'], step, np.mean(result[7]))
    _dashboard_row(logs['enc_mean'], step, np.mean(result[6]))


def _dump_nan_debug(log, model, result, inputs):
    """Print every fetched value, NaN flags and the inputs after a NaN/inf cost."""
    log.info('NaN detected')
    for i, value in enumerate(result):
        print("\n\nresult[%d]:" % i)
        try:
            print(np.any(np.isnan(value)))
        except (TypeError, ValueError):
            # np.isnan rejects non-numeric fetches, e.g. the None returned for
            # the update op.
            pass
        print(value)
    print(result[3].shape)
    print(model._encoder.layers[0].weights['w'].eval())
    print('\n\nAny:')
    # Index 8 is the classifier output and 9+ are gradient values; there may
    # be fewer than two gradients, so guard instead of raising IndexError.
    for idx in (8, 9, 10):
        if idx < len(result):
            print(np.any(np.isnan(result[idx])))
    print(inputs)


def _validate_and_sample(sess, graph, options, log, logs, val_provider, step,
                         inputs, result):
    """Evaluate validation batches, log metrics and save sample images.

    At most ``options['valid_batches']`` batches are evaluated (all of them if
    that count is never reached). The mean cost is logged and the mean batch
    accuracy is appended to the ``val_loss.csv`` series. Decoder samples from
    N(0, I) latents, the current training inputs and ``result[7]``
    (model.dec_mean for those inputs) are then saved.
    """
    valid_costs = []
    valid_accuracies = []
    for seen_batches, (val_batch, val_labels) in enumerate(val_provider, 1):
        val_cost, val_probs = sess.run(
            [graph['cost_function'], graph['classifier']],
            feed_dict={
                graph['model_input_batch']: val_batch,
                graph['model_label_batch']: val_labels
            })
        valid_costs.append(val_cost)
        valid_accuracies.append(_batch_accuracy(val_labels, val_probs))
        if seen_batches == options['valid_batches']:
            break

    if valid_costs:
        log.info('Validation loss: {:0>15.4f}'.format(
            float(np.mean(valid_costs))))
        _dashboard_row(logs['val'], step, np.mean(valid_accuracies))
    else:
        log.info('No validation batches produced; skipping validation metrics')

    val_samples = sess.run(
        graph['sampler'],
        feed_dict={
            graph['sampler_input_batch']:
            MVN(np.zeros(options['latent_dims']),
                np.diag(np.ones(options['latent_dims'])),
                size=options['batch_size'])
        })

    batch_shape = [options['batch_size']] + options['img_shape']
    save_ae_samples(logs['catalog'],
                    np.reshape(result[7], batch_shape),
                    np.reshape(inputs, batch_shape),
                    np.reshape(val_samples, batch_shape),
                    step,
                    options['dashboard_dir'],
                    num_to_save=5,
                    save_gray=True)

    sample_idx = int(step / options['freq_validation'])
    save_samples(val_samples, sample_idx,
                 os.path.join(options['model_dir'], 'valid_samples'),
                 True, options['img_shape'], 5)
    save_samples(inputs, sample_idx,
                 os.path.join(options['model_dir'], 'input_sanity'),
                 True, options['img_shape'], num_to_save=5)
    save_samples(result[7], sample_idx,
                 os.path.join(options['model_dir'], 'rec_sanity'),
                 True, options['img_shape'], num_to_save=5)


def _run_training(options, log, logs):
    """Build providers and graph, restore the checkpoint and run the train loop.

    Returns ``(1., 1., 1.)`` if a NaN/inf cost aborts training, else ``None``.
    """
    utils.print_options(options, log)

    train_provider, val_provider = _make_providers(options)
    log.info('Data providers initialized.')

    graph = _build_graph(options, log)
    model = graph['model']

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')

    reload_file = (options.get('reload_file') or
                   os.path.join(options['model_dir'], 'model_at_21000.ckpt'))

    # Fetch order matters: _log_train_metrics, _dump_nan_debug and
    # _validate_and_sample index into the result list by position.
    train_fetches = [
        graph['cost_function'], graph['backpass'], model.DKL, model.rec_loss,
        model.dec_log_std_sq, model.enc_log_std_sq, model.enc_mean,
        model.dec_mean, graph['classifier']
    ] + [gv[0] for gv in graph['grads']]

    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        # Initialize every variable, then overwrite them from the checkpoint.
        sess.run(graph['init_op'])
        saver.restore(sess, reload_file)
        log.info('Shared variables restored')

        # The 10 most recent costs, newest first, for a running average.
        last_losses = np.zeros((10))

        batch_abs_idx = 0
        for epoch_idx in xrange(options['n_epochs']):
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs, labels in train_provider:
                batch_abs_idx += 1

                result = sess.run(train_fetches,
                                  feed_dict={
                                      graph['model_input_batch']: inputs,
                                      graph['model_label_batch']: labels
                                  })
                cost = result[0]

                if batch_abs_idx % 10 == 0:
                    _log_train_metrics(logs, batch_abs_idx, labels, result)

                if not np.isfinite(cost):
                    _dump_nan_debug(log, model, result, inputs)
                    return 1., 1., 1.

                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost

                # NOTE(review): gated on the epoch index, so every batch of a
                # matching epoch is logged -- confirm this is intended.
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info(
                        'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                        .format(epoch_idx + 1, options['n_epochs'],
                                batch_abs_idx, float(cost),
                                np.mean(last_losses)))
                    log.info('Batch Mean LL: {:0>15.4f}'.format(
                        np.mean(result[3], axis=0)))
                    log.info('Batch Mean -DKL: {:0>15.4f}'.format(
                        np.mean(result[2], axis=0)))

                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(
                        sess,
                        os.path.join(options['model_dir'],
                                     'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                    _validate_and_sample(sess, graph, options, log, logs,
                                         val_provider, batch_abs_idx, inputs,
                                         result)

            log.info('End of epoch {}'.format(epoch_idx + 1))


def train(options):
    """Train a classifier on encoder samples of a VAE, logging to a dashboard.

    Writes the run options, a dashboard catalog and one CSV series per tracked
    statistic into ``options['dashboard_dir']``. Builds the model returned by
    ``cupboard(options['model'])`` plus a softmax classifier fed with samples
    ``enc_mean + eps * enc_std`` from its encoder. Restores all variables from
    a checkpoint, then minimizes the classifier's cross-entropy with Adam and
    per-tensor gradient-norm clipping (5.0). Every dashboard file is closed
    when the function returns or raises.

    Args:
        options: dict of run settings. Keys read: model_dir, dashboard_dir,
            description, DKL_weight, sigma_clip, data_dir ('MNIST' or a
            directory with train/patches and valid/patches), file_extension
            (non-MNIST only), batch_size, model, p_layers, q_layers,
            img_shape (a list), latent_dims, num_classes, lr, n_epochs,
            freq_logging, freq_saving, freq_validation, valid_batches.
            Optional 'reload_file' overrides the checkpoint restored at
            start (default: <model_dir>/model_at_21000.ckpt).

    Returns:
        ``(1., 1., 1.)`` if a NaN or infinite training cost stops training
        early, otherwise ``None``.
    """
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    _write_options_file(options)

    open_files = []
    try:
        logs = _open_dashboard_logs(options['dashboard_dir'], open_files)
        return _run_training(options, log, logs)
    finally:
        for handle in open_files:
            handle.close()
# Example #2
# 0
def train(options):
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    options_file = open(os.path.join(options['dashboard_dir'], 'options'), 'w')
    options_file.write(options['description'] + '\n')
    options_file.write(
        'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
            options['DKL_weight'], -options['sigma_clip'],
            options['sigma_clip']))
    for optn in options:
        options_file.write(optn)
        options_file.write(':\t')
        options_file.write(str(options[optn]))
        options_file.write('\n')
    options_file.close()

    # Dashboard Catalog
    catalog = open(os.path.join(options['dashboard_dir'], 'catalog'), 'w')
    catalog.write("""filename,type,name
options,plain,Options
train_loss.csv,csv,Discriminator Cross-Entropy
ll.csv,csv,Neg. Log-Likelihood
dec_log_sig_sq.csv,csv,Decoder Log Simga^2
dec_std_log_sig_sq.csv,csv,STD of Decoder Log Simga^2
dec_mean.csv,csv,Decoder Mean
dkl.csv,csv,DKL
enc_log_sig_sq.csv,csv,Encoder Log Sigma^2
enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2
enc_mean.csv,csv,Encoder Mean
val_loss.csv,csv,Validation Loss
""")
    catalog.flush()
    train_log = open(os.path.join(options['dashboard_dir'], 'train_loss.csv'),
                     'w')
    val_log = open(os.path.join(options['dashboard_dir'], 'val_loss.csv'), 'w')
    dkl_log = open(os.path.join(options['dashboard_dir'], 'dkl.csv'), 'w')
    ll_log = open(os.path.join(options['dashboard_dir'], 'll.csv'), 'w')

    dec_sig_log = open(
        os.path.join(options['dashboard_dir'], 'dec_log_sig_sq.csv'), 'w')
    enc_sig_log = open(
        os.path.join(options['dashboard_dir'], 'enc_log_sig_sq.csv'), 'w')

    dec_std_sig_log = open(
        os.path.join(options['dashboard_dir'], 'dec_std_log_sig_sq.csv'), 'w')
    enc_std_sig_log = open(
        os.path.join(options['dashboard_dir'], 'enc_std_log_sig_sq.csv'), 'w')

    dec_mean_log = open(os.path.join(options['dashboard_dir'], 'dec_mean.csv'),
                        'w')
    enc_mean_log = open(os.path.join(options['dashboard_dir'], 'enc_mean.csv'),
                        'w')
    # val_sig_log = open(os.path.join(options['dashboard_dir'], 'val_log_sig_sq.csv'), 'w')

    train_log.write(
        'step,time,Train CE (Training Vanilla),Train CE (Training Gen.),Train CE (Training Disc.)\n'
    )
    val_log.write(
        'step,time,Validation CE (Training Vanilla),Validation CE (Training Gen.),Validation CE (Training Disc.)\n'
    )
    dkl_log.write(
        'step,time,DKL (Training Vanilla),DKL (Training Gen.),DKL (Training Disc.)\n'
    )
    ll_log.write(
        'step,time,-LL (Training Vanilla),-LL (Training Gen.),-LL (Training Disc.)\n'
    )

    dec_sig_log.write(
        'step,time,Decoder Log Sigma^2 (Training Vanilla),Decoder Log Sigma^2 (Training Gen.),Decoder Log Sigma^2 (Training Disc.)\n'
    )
    enc_sig_log.write(
        'step,time,Encoder Log Sigma^2 (Training Vanilla),Encoder Log Sigma^2 (Training Gen.),Encoder Log Sigma^2 (Training Disc.)\n'
    )

    dec_std_sig_log.write(
        'step,time,STD of Decoder Log Sigma^2 (Training Vanilla),STD of Decoder Log Sigma^2 (Training Gen.),STD of Decoder Log Sigma^2 (Training Disc.)\n'
    )
    enc_std_sig_log.write(
        'step,time,STD of Encoder Log Sigma^2 (Training Vanilla),STD of Encoder Log Sigma^2 (Training Gen.),STD of Encoder Log Sigma^2 (Training Disc.)\n'
    )

    dec_mean_log.write(
        'step,time,Decoder Mean (Training Vanilla),Decoder Mean (Training Gen.),Decoder Mean (Training Disc.)\n'
    )
    enc_mean_log.write(
        'step,time,Encoder Mean (Training Vanilla),Encoder Mean (Training Gen.),Encoder Mean (Training Disc.)\n'
    )

    # Print options
    utils.print_options(options, log)

    # Load dataset ----------------------------------------------------------------------
    # Train provider
    train_provider, val_provider, test_provider = get_providers(options,
                                                                log,
                                                                flat=True)

    # Initialize model ------------------------------------------------------------------
    with tf.device('/gpu:0'):
        # Define inputs
        model_input_batch = tf.placeholder(
            tf.float32,
            shape=[
                options['batch_size'],
                np.prod(np.array(options['img_shape']))
            ],
            name='enc_inputs')
        sampler_input_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], options['latent_dims']],
            name='dec_inputs')
        log.info('Inputs defined')

        # Define model
        with tf.variable_scope('vae_scope'):
            vae_model = cupboard('vanilla_vae')(
                options['p_layers'], options['q_layers'],
                np.prod(options['img_shape']), options['latent_dims'],
                options['DKL_weight'], options['sigma_clip'], 'vae_model')

        with tf.variable_scope('disc_scope'):
            disc_model = cupboard('fixed_conv_disc')(
                pickle.load(open(options['disc_params_path'], 'rb')),
                options['num_feat_layers'],
                name='disc_model')

        vae_gan = cupboard('vae_gan')(vae_model,
                                      disc_model,
                                      options['disc_weight'],
                                      options['img_shape'],
                                      options['input_channels'],
                                      'vae_scope',
                                      'disc_scope',
                                      name='vae_gan_model')

        # Define Optimizers ---------------------------------------------------------------------
        optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])

        vae_backpass, disc_backpass, vanilla_backpass = vae_gan(
            model_input_batch, sampler_input_batch, optimizer)

        log.info('Optimizer graph built')
        # --------------------------------------------------------------------------------------
        # Define init operation
        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')
    # --------------------------------------------------------------------------

    # Train loop ---------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        # Initialize shared variables or restore
        if options['reload_all']:
            saver.restore(sess, options['reload_file'])
            log.info('Shared variables restored')
        else:
            sess.run(init_op)
            log.info('Variables initialized')

            if options['reload_vae']:
                vae_model.reload_vae(options['vae_params_path'])

        # Define last losses to compute a running average
        last_losses = np.zeros((10))

        batch_abs_idx = 0
        D_to_G = options['D_to_G']
        total_D2G = sum(D_to_G)
        base = options['initial_G_iters'] + options['initial_D_iters']

        for epoch_idx in xrange(options['n_epochs']):
            batch_rel_idx = 0
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs in train_provider:
                if isinstance(inputs, tuple):
                    inputs = inputs[0]

                batch_abs_idx += 1
                batch_rel_idx += 1

                if batch_abs_idx < options['initial_G_iters']:
                    backpass = vanilla_backpass
                    log_format_string = '{},{},{},,\n'
                elif options['initial_G_iters'] <= batch_abs_idx < base:
                    backpass = disc_backpass
                    log_format_string = '{},{},,,{}\n'
                else:
                    if (batch_abs_idx - base) % total_D2G < D_to_G[0]:
                        backpass = disc_backpass
                        log_format_string = '{},{},,,{}\n'
                    else:
                        backpass = vae_backpass
                        log_format_string = '{},{},,{},\n'

                result = sess.run(
                    [
                        vae_gan.disc_CE, backpass, vae_gan._vae.DKL,
                        vae_gan._vae.rec_loss, vae_gan._vae.dec_log_std_sq,
                        vae_gan._vae.enc_log_std_sq, vae_gan._vae.enc_mean,
                        vae_gan._vae.dec_mean
                    ],
                    feed_dict={
                        model_input_batch:
                        inputs,
                        sampler_input_batch:
                        MVN(np.zeros(options['latent_dims']),
                            np.diag(np.ones(options['latent_dims'])),
                            size=options['batch_size'])
                    })

                cost = result[0]

                if batch_abs_idx % 10 == 0:
                    train_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 np.mean(last_losses)))
                    dkl_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 -np.mean(result[2])))
                    ll_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 -np.mean(result[3])))

                    train_log.flush()
                    dkl_log.flush()
                    ll_log.flush()

                    dec_sig_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 np.mean(result[4])))
                    enc_sig_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 np.mean(result[5])))
                    # val_sig_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(result[6])))

                    dec_sig_log.flush()
                    enc_sig_log.flush()

                    dec_std_sig_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 np.std(result[4])))
                    enc_std_sig_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 np.std(result[5])))

                    dec_mean_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 np.mean(result[7])))
                    enc_mean_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 np.mean(result[6])))

                    dec_std_sig_log.flush()
                    enc_std_sig_log.flush()

                    dec_mean_log.flush()
                    enc_mean_log.flush()

                # Check cost
                if np.isnan(cost) or np.isinf(cost):
                    log.info('NaN detected')
                    for i in range(len(result)):
                        print("\n\nresult[%d]:" % i)
                        try:
                            print(np.any(np.isnan(result[i])))
                        except:
                            pass
                        print(result[i])
                    print(result[3].shape)
                    print(vae_gan._vae._encoder.layers[0].weights['w'].eval())
                    print('\n\nAny:')
                    print(np.any(np.isnan(result[8])))
                    print(np.any(np.isnan(result[9])))
                    print(np.any(np.isnan(result[10])))
                    print(inputs)
                    return 1., 1., 1.

                # Update last losses
                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost

                # Display training information
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info(
                        'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                        .format(epoch_idx + 1, options['n_epochs'],
                                batch_abs_idx, float(cost),
                                np.mean(last_losses)))
                    log.info('Batch Mean LL: {:0>15.4f}'.format(
                        np.mean(result[3], axis=0)))
                    log.info('Batch Mean -DKL: {:0>15.4f}'.format(
                        np.mean(result[2], axis=0)))

                # Save model
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(
                        sess,
                        os.path.join(options['model_dir'],
                                     'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                    save_dict = {}
                    # Save encoder params ------------------------------------------------------------------
                    for i in range(len(vae_gan._vae._encoder.layers)):
                        layer_dict = {
                            'input_dim':
                            vae_gan._vae._encoder.layers[i].input_dim,
                            'output_dim':
                            vae_gan._vae._encoder.layers[i].output_dim,
                            'act_fn':
                            vae_gan._vae._encoder.layers[i].activation,
                            'W':
                            vae_gan._vae._encoder.layers[i].weights['w'].eval(
                            ),
                            'b':
                            vae_gan._vae._encoder.layers[i].weights['b'].eval(
                            )
                        }
                        save_dict['encoder'] = layer_dict

                    layer_dict = {
                        'input_dim': vae_gan._vae._enc_mean.input_dim,
                        'output_dim': vae_gan._vae._enc_mean.output_dim,
                        'act_fn': vae_gan._vae._enc_mean.activation,
                        'W': vae_gan._vae._enc_mean.weights['w'].eval(),
                        'b': vae_gan._vae._enc_mean.weights['b'].eval()
                    }
                    save_dict['enc_mean'] = layer_dict

                    layer_dict = {
                        'input_dim': vae_gan._vae._enc_log_std_sq.input_dim,
                        'output_dim': vae_gan._vae._enc_log_std_sq.output_dim,
                        'act_fn': vae_gan._vae._enc_log_std_sq.activation,
                        'W': vae_gan._vae._enc_log_std_sq.weights['w'].eval(),
                        'b': vae_gan._vae._enc_log_std_sq.weights['b'].eval()
                    }
                    save_dict['enc_log_std_sq'] = layer_dict

                    # Save decoder params ------------------------------------------------------------------
                    for i in range(len(vae_gan._vae._decoder.layers)):
                        layer_dict = {
                            'input_dim':
                            vae_gan._vae._decoder.layers[i].input_dim,
                            'output_dim':
                            vae_gan._vae._decoder.layers[i].output_dim,
                            'act_fn':
                            vae_gan._vae._decoder.layers[i].activation,
                            'W':
                            vae_gan._vae._decoder.layers[i].weights['w'].eval(
                            ),
                            'b':
                            vae_gan._vae._decoder.layers[i].weights['b'].eval(
                            )
                        }
                        save_dict['decoder'] = layer_dict

                    layer_dict = {
                        'input_dim': vae_gan._vae._dec_mean.input_dim,
                        'output_dim': vae_gan._vae._dec_mean.output_dim,
                        'act_fn': vae_gan._vae._dec_mean.activation,
                        'W': vae_gan._vae._dec_mean.weights['w'].eval(),
                        'b': vae_gan._vae._dec_mean.weights['b'].eval()
                    }
                    save_dict['dec_mean'] = layer_dict

                    layer_dict = {
                        'input_dim': vae_gan._vae._dec_log_std_sq.input_dim,
                        'output_dim': vae_gan._vae._dec_log_std_sq.output_dim,
                        'act_fn': vae_gan._vae._dec_log_std_sq.activation,
                        'W': vae_gan._vae._dec_log_std_sq.weights['w'].eval(),
                        'b': vae_gan._vae._dec_log_std_sq.weights['b'].eval()
                    }
                    save_dict['dec_log_std_sq'] = layer_dict

                    pickle.dump(
                        save_dict,
                        open(
                            os.path.join(options['model_dir'],
                                         'vae_dict_%d' % batch_abs_idx), 'wb'))

                # Validate model
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:

                    vae_gan._vae._decoder.layers[0].weights['w'].eval()[:5, :5]

                    valid_costs = []
                    seen_batches = 0
                    for val_batch in val_provider:
                        if isinstance(val_batch, tuple):
                            val_batch = val_batch[0]

                        val_cost = sess.run(
                            vae_gan.disc_CE,
                            feed_dict={
                                model_input_batch:
                                val_batch,
                                sampler_input_batch:
                                MVN(np.zeros(options['latent_dims']),
                                    np.diag(np.ones(options['latent_dims'])),
                                    size=options['batch_size'])
                            })
                        valid_costs.append(val_cost)
                        seen_batches += 1

                        if seen_batches == options['valid_batches']:
                            break

                    # Print results
                    log.info('Validation loss: {:0>15.4f}'.format(
                        float(np.mean(valid_costs))))

                    val_samples = sess.run(
                        vae_gan.sampler,
                        feed_dict={
                            sampler_input_batch:
                            MVN(np.zeros(options['latent_dims']),
                                np.diag(np.ones(options['latent_dims'])),
                                size=options['batch_size'])
                        })

                    val_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 np.mean(valid_costs)))
                    val_log.flush()

                    save_ae_samples(catalog,
                                    np.reshape(result[7],
                                               [options['batch_size']] +
                                               options['img_shape']),
                                    np.reshape(inputs,
                                               [options['batch_size']] +
                                               options['img_shape']),
                                    np.reshape(val_samples,
                                               [options['batch_size']] +
                                               options['img_shape']),
                                    batch_abs_idx,
                                    options['dashboard_dir'],
                                    num_to_save=5,
                                    save_gray=True)

            log.info('End of epoch {}'.format(epoch_idx + 1))
# Example #3 (示例#3)
# 0
def train(options):
    """Build a (convolutional) denoising auto-encoder graph and start a session.

    Depending on ``options['model']`` this instantiates either the ``cnn_ae``
    model (image-shaped inputs) or another ``cupboard`` auto-encoder fed with
    flattened inputs, wires "clean" and "noisy" input placeholders into it,
    builds an Adam ``minimize`` op on the returned cost, and then either
    restores variables from ``<model_dir>/model.ckpt`` or initialises them.

    NOTE(review): this version stops after variable initialisation; the
    ``train_step`` op is built but no training loop is run here.

    Args:
        options: dict of run settings. Keys read here: 'model_dir', 'model',
            'img_shape', 'input_channels', 'enc_params', 'dec_params',
            'batch_size', 'lr', 'reload'.
    """
    # Get logger. The previous body used ``log`` without ever defining it;
    # this matches how every sibling ``train`` creates its logger.
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))

    # Initialize model ------------------------------------------------------------------
    with tf.device('/gpu:0'):
        if options['model'] == 'cnn_ae':
            # Constructor args: input_shape, input_channels, enc_params, dec_params, name
            model = cupboard(options['model'])(
                options['img_shape'],
                options['input_channels'],
                options['enc_params'],
                options['dec_params'],
                'cnn_ae'
            )
            # The convolutional model consumes image-shaped batches.
            input_shape = [options['batch_size']] + options['img_shape'] + [options['input_channels']]
        else:
            flat_dim = np.prod(options['img_shape']) * options['input_channels']
            model = cupboard(options['model'])(
                flat_dim,
                options['enc_params'],
                options['dec_params'],
                'ae'
            )
            # Non-convolutional models consume flattened batches.
            input_shape = [options['batch_size']] + [flat_dim]

        # Define inputs: the model receives both the clean target and its noisy version.
        model_clean_input_batch = tf.placeholder(
            tf.float32,
            shape=input_shape,
            name='clean'
        )
        model_noisy_input_batch = tf.placeholder(
            tf.float32,
            shape=input_shape,
            name='noisy'
        )
        log.info('Inputs defined')

        log.info('Model initialized')

        # Define forward pass
        log.info('Clean input shape: {}'.format(model_clean_input_batch.get_shape()))
        log.info('Noisy input shape: {}'.format(model_noisy_input_batch.get_shape()))
        cost_function = model(model_clean_input_batch, model_noisy_input_batch)
        log.info('Forward pass graph built')

        # Define optimizer
        optimizer = tf.train.AdamOptimizer(
            learning_rate=options['lr']
        )
        train_step = optimizer.minimize(cost_function)
        log.info('Optimizer graph built')

        # Define init operation
        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')
    # --------------------------------------------------------------------------

    # Train loop ---------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        # Initialize shared variables or restore them from the last checkpoint
        if options['reload']:
            saver.restore(sess, os.path.join(options['model_dir'], 'model.ckpt'))
            log.info('Shared variables restored')
        else:
            sess.run(init_op)
            log.info('Shared variables initialized')
# Example #4 (示例#4)
# 0
def train(options):
    """Train a generator against a discriminator built from pre-trained features.

    The discriminator is ``cupboard('fixed_conv_disc')`` constructed from the
    pickled parameters at ``options['disc_params_path']``. The generator is a
    three-layer tanh MLP mapping ``latent_dims`` noise to flattened images.
    The update applied at each batch depends on the absolute batch index:

    * index < ``initial_G_iters``: generator update;
    * the following batches up to ``initial_G_iters + initial_D_iters``:
      discriminator update;
    * afterwards, repeating cycles of ``sum(D_to_G)`` batches, the first
      ``D_to_G[0]`` of which update the discriminator and the rest the
      generator.

    Dashboard CSVs (train/validation cross-entropy and accuracy) have one
    value column per phase: "Vanilla" (initial generator phase), "Gen."
    (generator steps in the alternating phase) and "Disc." (discriminator
    steps). Each row fills the column of the phase active when it is written.

    Args:
        options: dict of run settings. Keys read here include 'model_dir',
            'dashboard_dir', 'description', 'sigma_clip', 'batch_size',
            'img_shape', 'input_channels', 'latent_dims', 'disc_params_path',
            'num_feat_layers', 'lr', 'reload_all', 'reload_file', 'D_to_G',
            'initial_G_iters', 'initial_D_iters', 'n_epochs', 'freq_logging',
            'freq_saving', 'freq_validation', 'valid_batches'.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))

    # Record the run description and every option for the dashboard.
    with open(os.path.join(options['dashboard_dir'], 'options'), 'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write(
            'Log Sigma^2 clipped to: [{}, {}]\n\n'.format(
                -options['sigma_clip'],
                options['sigma_clip']
            )
        )
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Dashboard Catalog
    catalog = open(os.path.join(options['dashboard_dir'], 'catalog'), 'w')
    catalog.write(
"""filename,type,name
options,plain,Options
train_loss.csv,csv,Discriminator Cross-Entropy
train_acc.csv,csv,Discriminator Accuracy
val_loss.csv,csv,Validation Cross-Entropy
val_acc.csv,csv,Validation Accuracy
"""
    )
    catalog.flush()
    train_log = open(os.path.join(options['dashboard_dir'], 'train_loss.csv'), 'w')
    val_log = open(os.path.join(options['dashboard_dir'], 'val_loss.csv'), 'w')
    train_acc = open(os.path.join(options['dashboard_dir'], 'train_acc.csv'), 'w')
    val_acc = open(os.path.join(options['dashboard_dir'], 'val_acc.csv'), 'w')
    # Closed in the ``finally`` below so the handles never leak.
    dashboard_files = [catalog, train_log, val_log, train_acc, val_acc]

    try:
        # One value column per training phase (see docstring).
        train_log.write('step,time,Train CE (Training Vanilla),Train CE (Training Gen.),Train CE (Training Disc.)\n')
        val_log.write('step,time,Validation CE (Training Vanilla),Validation CE (Training Gen.),Validation CE (Training Disc.)\n')
        train_acc.write('step,time,Train Acc. (Training Vanilla),Train Acc. (Training Gen.),Train Acc. (Training Disc.)\n')
        val_acc.write('step,time,Validation Acc. (Training Vanilla),Validation Acc. (Training Gen.),Validation Acc. (Training Disc.)\n')

        # Print options
        utils.print_options(options, log)

        # Load dataset ----------------------------------------------------------------------
        train_provider, val_provider, test_provider = get_providers(options, log, flat=True)

        # Initialize model ------------------------------------------------------------------
        with tf.device('/gpu:0'):
            # Define inputs -----------------------------------------------------------------
            real_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'], np.prod(np.array(options['img_shape']))],
                name='real_inputs'
            )
            sampler_input_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'], options['latent_dims']],
                name='noise_channel'
            )
            # Discriminator targets, shape (2 * batch_size, 1): 1 for the real
            # half of the concatenated batch, 0 for the generated half.
            labels = tf.constant(
                np.expand_dims(
                    np.concatenate(
                        (
                            np.ones(options['batch_size']),
                            np.zeros(options['batch_size'])
                        ),
                        axis=0
                    ).astype(np.float32),
                    axis=1
                )
            )
            labels = tf.cast(labels, tf.float32)
            log.info('Inputs defined')

            # Define model ------------------------------------------------------------------
            with tf.variable_scope('gen_scope'):
                generator = Sequential('generator')
                generator += FullyConnected(options['latent_dims'], 60, tf.nn.tanh, name='fc_1')
                generator += FullyConnected(60, 60, tf.nn.tanh, name='fc_2')
                generator += FullyConnected(60, np.prod(options['img_shape']), tf.nn.tanh, name='fc_3')

                sampler = generator(sampler_input_batch)

            with tf.variable_scope('disc_scope'):
                with open(options['disc_params_path'], 'rb') as params_file:
                    disc_params = pickle.load(params_file)
                disc_model = cupboard('fixed_conv_disc')(
                    disc_params,
                    options['num_feat_layers'],
                    name='disc_model'
                )

                # Real samples first, generated second (matches `labels`),
                # reshaped to images for the convolutional discriminator.
                disc_inputs = tf.concat(0, [real_batch, sampler])
                disc_inputs = tf.reshape(
                    disc_inputs,
                    [disc_inputs.get_shape()[0].value] + options['img_shape'] + [options['input_channels']]
                )

                preds = disc_model(disc_inputs)
                # Keep predictions away from exactly 0/1 so the logs below stay finite.
                preds = tf.clip_by_value(preds, 0.00001, 0.99999)

                # Disc Accuracy: fraction of rounded predictions equal to the labels
                disc_accuracy = (1 / float(labels.get_shape()[0].value)) * tf.reduce_sum(
                    tf.cast(
                        tf.equal(
                            tf.round(preds),
                            labels
                        ),
                        tf.float32
                    )
                )

            # Define Losses -----------------------------------------------------------------
            # Discriminator cross-entropy, averaged over the combined batch
            disc_CE = (1 / float(labels.get_shape()[0].value)) * tf.reduce_sum(
                -tf.add(
                    tf.mul(
                        labels,
                        tf.log(preds)
                    ),
                    tf.mul(
                        1.0 - labels,
                        tf.log(1.0 - preds)
                    )
                )
            )

            # Generator loss: -log D(G(z)) on the generated half; the real
            # half is zeroed out by the `1.0 - labels` factor.
            gen_loss = -tf.mul(
                1.0 - labels,
                tf.log(preds)
            )

            # Define Optimizers -------------------------------------------------------------
            optimizer = tf.train.AdamOptimizer(
                learning_rate=options['lr']
            )

            # Get Generator and Discriminator trainable variables
            gen_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'gen_scope')
            disc_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'disc_scope')

            # Generator update: norm-clipped gradients w.r.t. generator variables only
            grads = optimizer.compute_gradients(gen_loss, gen_train_vars)
            grads = [gv for gv in grads if gv[0] is not None]
            clip_grads = [(tf.clip_by_norm(gv[0], 5.0, name='gen_grad_clipping'), gv[1]) for gv in grads]
            gen_backpass = optimizer.apply_gradients(clip_grads)

            # Discriminator update: norm-clipped gradients w.r.t. discriminator variables only
            grads = optimizer.compute_gradients(disc_CE, disc_train_vars)
            grads = [gv for gv in grads if gv[0] is not None]
            clip_grads = [(tf.clip_by_norm(gv[0], 5.0, name='disc_grad_clipping'), gv[1]) for gv in grads]
            disc_backpass = optimizer.apply_gradients(clip_grads)

            log.info('Optimizer graph built')
            # ------------------------------------------------------------------------------
            # Define init operation
            init_op = tf.initialize_all_variables()
            log.info('Variable initialization graph built')

        # Define op to save and restore variables
        saver = tf.train.Saver()
        log.info('Save operation built')
        # --------------------------------------------------------------------------

        # Train loop ---------------------------------------------------------------
        with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
            log.info('Session started')

            # Initialize shared variables or restore
            if options['reload_all']:
                saver.restore(sess, options['reload_file'])
                log.info('Shared variables restored')
            else:
                sess.run(init_op)
                log.info('Variables initialized')

            # Running windows of the last 10 losses / accuracies
            last_losses = np.zeros((10))
            last_accs = np.zeros((10))

            batch_abs_idx = 0
            D_to_G = options['D_to_G']
            total_D2G = sum(D_to_G)
            base = options['initial_G_iters'] + options['initial_D_iters']

            for epoch_idx in xrange(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))

                for inputs in train_provider:
                    if isinstance(inputs, tuple):
                        inputs = inputs[0]

                    batch_abs_idx += 1

                    # Select the update op for this batch and the dashboard
                    # column (vanilla / gen. / disc.) its values are written to.
                    if batch_abs_idx < options['initial_G_iters']:
                        backpass = gen_backpass
                        log_format_string = '{},{},{},,\n'
                    elif batch_abs_idx < base:
                        backpass = disc_backpass
                        log_format_string = '{},{},,,{}\n'
                    elif (batch_abs_idx - base) % total_D2G < D_to_G[0]:
                        backpass = disc_backpass
                        log_format_string = '{},{},,,{}\n'
                    else:
                        backpass = gen_backpass
                        log_format_string = '{},{},,{},\n'

                    result = sess.run(
                        [
                            disc_CE,
                            backpass,
                            disc_accuracy
                        ],
                        feed_dict={
                            real_batch: inputs,
                            sampler_input_batch: MVN(
                                np.zeros(options['latent_dims']),
                                np.diag(np.ones(options['latent_dims'])),
                                size=options['batch_size']
                            )
                        }
                    )

                    cost = result[0]

                    # Check cost
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')

                    # Update last losses / accuracies
                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost

                    last_accs = np.roll(last_accs, 1)
                    last_accs[0] = result[-1]

                    # Written after the update so the running means include this batch.
                    if batch_abs_idx % 10 == 0:
                        train_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(last_losses)))
                        train_acc.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(last_accs)))

                        train_log.flush()
                        train_acc.flush()

                    # Display training information
                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'.format(
                            epoch_idx + 1,
                            options['n_epochs'],
                            batch_abs_idx,
                            float(cost),
                            np.mean(last_losses)
                        ))
                        log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last accuracies: {:0>15.4f}'.format(
                            epoch_idx + 1,
                            options['n_epochs'],
                            batch_abs_idx,
                            float(cost),
                            np.mean(last_accs)
                        ))

                    # Save model
                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(sess, os.path.join(options['model_dir'], 'model_at_%d.ckpt' % batch_abs_idx))
                        log.info('Model saved')

                    # Validate model
                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:

                        valid_costs = []
                        valid_accs = []
                        seen_batches = 0
                        for val_batch in val_provider:
                            if isinstance(val_batch, tuple):
                                val_batch = val_batch[0]

                            val_result = sess.run(
                                [
                                    disc_CE,
                                    disc_accuracy
                                ],
                                feed_dict={
                                    real_batch: val_batch,
                                    sampler_input_batch: MVN(
                                        np.zeros(options['latent_dims']),
                                        np.diag(np.ones(options['latent_dims'])),
                                        size=options['batch_size']
                                    )
                                }
                            )
                            valid_costs.append(val_result[0])
                            valid_accs.append(val_result[-1])
                            seen_batches += 1

                            if seen_batches == options['valid_batches']:
                                break

                        # Print results
                        log.info('Validation loss: {:0>15.4f}'.format(
                            float(np.mean(valid_costs))
                        ))
                        log.info('Validation accuracies: {:0>15.4f}'.format(
                            float(np.mean(valid_accs))
                        ))

                        val_samples = sess.run(
                            sampler,
                            feed_dict={
                                sampler_input_batch: MVN(
                                    np.zeros(options['latent_dims']),
                                    np.diag(np.ones(options['latent_dims'])),
                                    size=options['batch_size']
                                )
                            }
                        )

                        # Same column as the current training phase.
                        val_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(valid_costs)))
                        val_acc.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(valid_accs)))
                        val_log.flush()
                        val_acc.flush()

                        # No reconstructions exist for a GAN; an all-ones array
                        # fills that slot of save_ae_samples.
                        save_ae_samples(
                            catalog,
                            np.ones([options['batch_size']] + options['img_shape']),
                            np.reshape(inputs, [options['batch_size']] + options['img_shape']),
                            np.reshape(val_samples, [options['batch_size']] + options['img_shape']),
                            batch_abs_idx,
                            options['dashboard_dir'],
                            num_to_save=5,
                            save_gray=True
                        )

                log.info('End of epoch {}'.format(epoch_idx + 1))
    finally:
        for dashboard_file in dashboard_files:
            dashboard_file.close()
def train(options):
    """Train a CIFAR convolutional classifier and log progress to a dashboard.

    Builds the ``options['model']`` classifier from ``cupboard`` and trains it
    with Adam on batches from ``<data_dir>/train``. Along the way it:

    * periodically checkpoints the session;
    * pickles per-layer shape metadata of the conv stack;
    * evaluates on up to ``options['valid_batches']`` batches from
      ``<data_dir>/valid``.

    Running losses and accuracies go to the log and to dashboard CSVs.

    Args:
        options: dict of run settings. Keys read here: 'model_dir',
            'dashboard_dir', 'description', 'data_dir', 'batch_size', 'model',
            'img_shape', 'input_channels', 'num_classes', 'conv_params',
            'fc_params', 'lr', 'reload', 'n_epochs', 'freq_logging',
            'freq_saving', 'freq_validation', 'valid_batches'.

    Returns:
        ``(1., 1., 1.)`` as soon as a NaN/inf training cost is seen;
        otherwise ``None`` after ``options['n_epochs']`` epochs.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))

    # Record the run description and every option for the dashboard.
    with open(os.path.join(options['dashboard_dir'], 'options'), 'w') as options_file:
        options_file.write(options['description'] + '\n')
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Dashboard Catalog
    catalog = open(os.path.join(options['dashboard_dir'], 'catalog'), 'w')
    catalog.write(
"""filename,type,name
options,plain,Options
train_loss.csv,csv,Train Loss
train_acc.csv,csv,Train Accuracy
val_loss.csv,csv,Validation Loss
val_acc.csv,csv,Validation Accuracy
"""
    )
    catalog.flush()
    train_log = open(os.path.join(options['dashboard_dir'], 'train_loss.csv'), 'w')
    train_acc_log = open(os.path.join(options['dashboard_dir'], 'train_acc.csv'), 'w')
    val_log = open(os.path.join(options['dashboard_dir'], 'val_loss.csv'), 'w')
    val_acc_log = open(os.path.join(options['dashboard_dir'], 'val_acc.csv'), 'w')
    # Closed in the ``finally`` below, including on the early NaN return.
    dashboard_files = [catalog, train_log, train_acc_log, val_log, val_acc_log]

    def count_data_points(split):
        """Return the sample count for ``<data_dir>/<split>`` from its ``info`` listing."""
        num_entries = len(os.listdir(os.path.join(options['data_dir'], split, 'info')))
        # NOTE(review): two entries are subtracted, presumably non-sample files
        # inside ``info/`` -- confirm against the dataset layout.
        return num_entries - 2

    def make_provider(split, num_data_points):
        """Build a DataProvider over the non-flattened CIFAR images of ``split``."""
        return DataProvider(
            num_data_points,
            options['batch_size'],
            toolbox.CIFARLoader(
                data_dir=os.path.join(options['data_dir'], split),
                flat=False
            )
        )

    def batch_accuracy(class_scores, one_hot_labels):
        """Return the fraction of rows whose arg-max score matches the arg-max label."""
        return np.mean(np.argmax(class_scores, axis=1) == np.argmax(one_hot_labels, axis=1))

    try:
        train_log.write('step,time,Train Loss\n')
        val_log.write('step,time,Validation Loss\n')
        train_acc_log.write('step,time,Train Accuracy\n')
        val_acc_log.write('step,time,Validation Accuracy\n')

        # Print options
        utils.print_options(options, log)

        # Load dataset ----------------------------------------------------------------------
        train_provider = make_provider('train', count_data_points('train'))

        num_val_points = count_data_points('valid')
        log.info('Validation data points: {}'.format(num_val_points))
        val_provider = make_provider('valid', num_val_points)
        log.info('Data providers initialized.')

        # Initialize model ------------------------------------------------------------------
        with tf.device('/gpu:0'):
            model = cupboard(options['model'])(
                options['img_shape'],
                options['input_channels'],
                options['num_classes'],
                options['conv_params'],
                options['fc_params'],
                'CIFAR_classifier'
            )
            log.info('Model initialized')

            # Define inputs
            input_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size']] + options['img_shape'] + [options['input_channels']],
                name='inputs'
            )
            label_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'], options['num_classes']],
                name='labels'
            )
            log.info('Inputs defined')

            # Define forward pass
            cost_function, classifier = model(input_batch, label_batch)
            log.info('Forward pass graph built')

            # Define optimizer
            optimizer = tf.train.AdamOptimizer(
                learning_rate=options['lr']
            )
            train_step = optimizer.minimize(cost_function)
            log.info('Optimizer graph built')

            # Define init operation
            init_op = tf.initialize_all_variables()
            log.info('Variable initialization graph built')

        # Define op to save and restore variables
        saver = tf.train.Saver()
        log.info('Save operation built')
        # --------------------------------------------------------------------------

        # Train loop ---------------------------------------------------------------
        with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
            log.info('Session started')

            # Initialize shared variables or restore
            if options['reload']:
                saver.restore(sess, os.path.join(options['model_dir'], 'model.ckpt'))
                log.info('Shared variables restored')
            else:
                sess.run(init_op)
                log.info('Shared variables initialized')

            # Running windows of the last 10 losses / accuracies
            last_losses = np.zeros((10))
            last_accs = np.zeros((10))

            batch_abs_idx = 0
            for epoch_idx in xrange(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))

                for inputs, labels in train_provider:
                    batch_abs_idx += 1

                    results = sess.run(
                        [cost_function, classifier, train_step],
                        feed_dict={
                            input_batch: inputs,
                            label_batch: labels
                        }
                    )

                    cost = results[0]
                    # Check cost; abort the run on NaN/inf
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')
                        return 1., 1., 1.

                    accuracy = batch_accuracy(results[1], labels)

                    # Update last losses / accuracies
                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost
                    last_accs = np.roll(last_accs, 1)
                    last_accs[0] = accuracy

                    if batch_abs_idx % 10 == 0:
                        train_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(last_losses)))
                        train_acc_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(last_accs)))
                        train_log.flush()
                        train_acc_log.flush()

                    # Display training information
                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'.format(
                            epoch_idx + 1,
                            options['n_epochs'],
                            batch_abs_idx,
                            float(cost),
                            np.mean(last_losses)
                        ))
                        log.info('Epoch {:02}/{:02} Batch {:03} Current Accuracy: {:0>15.4f}'.format(
                            epoch_idx + 1,
                            options['n_epochs'],
                            batch_abs_idx,
                            np.mean(last_accs)
                        ))

                    # Save model
                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(sess, os.path.join(options['model_dir'], 'model_at_%d.ckpt' % batch_abs_idx))

                        # Shape metadata of every other layer of the conv stack.
                        # NOTE(review): the step of 2 presumably skips activation
                        # layers interleaved with the conv layers -- confirm.
                        conv_layers = model._classifier_conv.layers
                        save_dict = []
                        for c_ind in xrange(0, len(conv_layers), 2):
                            conv_layer = conv_layers[c_ind]
                            save_dict.append({
                                'n_filters_in': conv_layer.n_filters_in,
                                'n_filters_out': conv_layer.n_filters_out,
                                'input_dim': conv_layer.input_dim,
                                'filter_dim': conv_layer.filter_dim,
                                'strides': conv_layer.strides,
                                'padding': conv_layer.padding,
                            })
                        with open(os.path.join(options['model_dir'], 'class_dict_%d' % batch_abs_idx), 'wb') as dict_file:
                            pickle.dump(save_dict, dict_file)

                        log.info('Model saved')

                    # Validate model
                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:

                        valid_costs = []
                        val_accuracies = []
                        seen_batches = 0
                        for val_batch, val_label in val_provider:

                            # Stop once the configured number of batches is evaluated
                            if seen_batches == options['valid_batches']:
                                break

                            val_results = sess.run(
                                [cost_function, classifier],
                                feed_dict={
                                    input_batch: val_batch,
                                    label_batch: val_label
                                }
                            )
                            valid_costs.append(val_results[0])
                            seen_batches += 1

                            val_accuracies.append(batch_accuracy(val_results[1], val_label))

                        # Print results
                        log.info('Validation loss: {:0>15.4f}'.format(
                            float(np.mean(valid_costs))
                        ))
                        log.info('Validation Accuracy: {:0>15.4f}'.format(
                            np.mean(val_accuracies)
                        ))

                        val_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', float(np.mean(valid_costs))))
                        val_acc_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(val_accuracies)))
                        val_log.flush()
                        val_acc_log.flush()

                log.info('End of epoch {}'.format(epoch_idx + 1))
    finally:
        for dashboard_file in dashboard_files:
            dashboard_file.close()
# Example #6 (示例#6)
# 0
# Contents of the dashboard `catalog` file: one row per dashboard file
# (filename, renderer type, display name).
_VAE_CATALOG = """filename,type,name
options,plain,Options
train_loss.csv,csv,Train Loss
ll.csv,csv,Neg. Log-Likelihood
dec_log_sig_sq.csv,csv,Decoder Log Simga^2
dec_std_log_sig_sq.csv,csv,STD of Decoder Log Simga^2
dec_mean.csv,csv,Decoder Mean
dkl.csv,csv,DKL
enc_log_sig_sq.csv,csv,Encoder Log Sigma^2
enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2
enc_mean.csv,csv,Encoder Mean
val_loss.csv,csv,Validation Loss
"""

# (key, CSV filename, value-column header) for each per-step dashboard series,
# in the order the files are created.
_VAE_DASHBOARD_SERIES = [
    ('train', 'train_loss.csv', 'Train Loss'),
    ('val', 'val_loss.csv', 'Validation Loss'),
    ('dkl', 'dkl.csv', 'DKL'),
    ('ll', 'll.csv', '-LL'),
    ('dec_sig', 'dec_log_sig_sq.csv', 'Decoder Log Sigma^2'),
    ('enc_sig', 'enc_log_sig_sq.csv', 'Encoder Log Sigma^2'),
    ('dec_std_sig', 'dec_std_log_sig_sq.csv', 'STD of Decoder Log Sigma^2'),
    ('enc_std_sig', 'enc_std_log_sig_sq.csv', 'STD of Encoder Log Sigma^2'),
    ('dec_mean', 'dec_mean.csv', 'Decoder Mean'),
    ('enc_mean', 'enc_mean.csv', 'Encoder Mean'),
]

# Fixed value written into the `time` column of every dashboard row
# (a hard-coded date, not the wall-clock time).
_DASHBOARD_TIME = '2016-04-22'


def _write_options_file(options):
    """Write the description, the DKL/clipping summary and every option to
    `<dashboard_dir>/options`."""
    path = os.path.join(options['dashboard_dir'], 'options')
    with open(path, 'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write(
            'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
                options['DKL_weight'], -options['sigma_clip'],
                options['sigma_clip']))
        for optn in options:
            options_file.write(optn + ':\t' + str(options[optn]) + '\n')


def _open_vae_dashboard_logs(dashboard_dir):
    """Create every CSV listed in _VAE_DASHBOARD_SERIES and write its header.

    Returns a dict mapping series key -> open file handle.  If any file cannot
    be created, the handles opened so far are closed before re-raising.
    """
    dash_logs = {}
    try:
        for key, filename, title in _VAE_DASHBOARD_SERIES:
            handle = open(os.path.join(dashboard_dir, filename), 'w')
            dash_logs[key] = handle
            handle.write('step,time,' + title + '\n')
    except Exception:
        for handle in dash_logs.values():
            handle.close()
        raise
    return dash_logs


def _write_dash_row(handle, step, value):
    """Append one `step,time,value` row to a dashboard CSV and flush it."""
    handle.write('{},{},{}\n'.format(step, _DASHBOARD_TIME, value))
    handle.flush()


def _log_vae_batch_stats(dash_logs, step, last_losses, result):
    """Append the periodic training statistics to the dashboard CSVs.

    `result` is the list fetched by the training `sess.run`, in the order
    [cost, update op, DKL, rec_loss, dec_log_std_sq, enc_log_std_sq,
    enc_mean, dec_mean].
    """
    dkl, rec_loss, dec_log_sig, enc_log_sig, enc_mean, dec_mean = result[2:8]
    _write_dash_row(dash_logs['train'], step, np.mean(last_losses))
    _write_dash_row(dash_logs['dkl'], step, -np.mean(dkl))
    _write_dash_row(dash_logs['ll'], step, -np.mean(rec_loss))
    _write_dash_row(dash_logs['dec_sig'], step, np.mean(dec_log_sig))
    _write_dash_row(dash_logs['enc_sig'], step, np.mean(enc_log_sig))
    _write_dash_row(dash_logs['dec_std_sig'], step, np.std(dec_log_sig))
    _write_dash_row(dash_logs['enc_std_sig'], step, np.std(enc_log_sig))
    _write_dash_row(dash_logs['dec_mean'], step, np.mean(dec_mean))
    _write_dash_row(dash_logs['enc_mean'], step, np.mean(enc_mean))


def _vae_layer_params(layer):
    """Describe one layer, evaluating its 'w' and 'b' weights in the
    default session."""
    return {
        'input_dim': layer.input_dim,
        'output_dim': layer.output_dim,
        'act_fn': layer.activation,
        'W': layer.weights['w'].eval(),
        'b': layer.weights['b'].eval()
    }


def _vae_param_snapshot(model):
    """Collect the encoder/decoder parameters of `model` into a picklable dict.

    The 'encoder' / 'decoder' entries hold only the LAST hidden layer.  That
    is what the previous save loop left behind, because it overwrote the key
    on every iteration, so it is kept for existing readers.  The full stacks
    are stored under 'encoder_layers' / 'decoder_layers'.
    """
    save_dict = {}

    encoder_layers = [_vae_layer_params(layer)
                      for layer in model._encoder.layers]
    if encoder_layers:
        save_dict['encoder'] = encoder_layers[-1]
        save_dict['encoder_layers'] = encoder_layers
    save_dict['enc_mean'] = _vae_layer_params(model._enc_mean)
    save_dict['enc_log_std_sq'] = _vae_layer_params(model._enc_log_std_sq)

    decoder_layers = [_vae_layer_params(layer)
                      for layer in model._decoder.layers]
    if decoder_layers:
        save_dict['decoder'] = decoder_layers[-1]
        save_dict['decoder_layers'] = decoder_layers
    save_dict['dec_mean'] = _vae_layer_params(model._dec_mean)
    save_dict['dec_log_std_sq'] = _vae_layer_params(model._dec_log_std_sq)

    return save_dict


def _dump_nonfinite_diagnostics(model, result, inputs):
    """Print every fetched value, the first encoder layer's weights and the
    offending input batch, to help debug a non-finite training cost."""
    for i, value in enumerate(result):
        print("\n\nresult[%d]:" % i)
        try:
            print(np.any(np.isnan(value)))
        except TypeError:
            # Non-numeric fetches (e.g. the update op, fetched as None) are
            # rejected by np.isnan; their value is still printed below.
            pass
        print(value)
    print(np.shape(result[3]))
    if model._encoder.layers:
        print(model._encoder.layers[0].weights['w'].eval())
    print(inputs)


def _run_vae_training(options, log, catalog, dash_logs):
    """Build the VAE graph and run the training loop.

    Args:
        options: training configuration dict (see `train`).
        log: logger returned by `utils.get_logger`.
        catalog: open dashboard catalog file, forwarded to `save_ae_samples`.
        dash_logs: series key -> open CSV handle, from
            `_open_vae_dashboard_logs`.

    Returns:
        (1., 1., 1.) if a non-finite training cost is encountered; otherwise
        None.
    """
    utils.print_options(options, log)

    # Load dataset ----------------------------------------------------------
    train_provider, val_provider, _ = get_providers(options, log, flat=True)

    # Initialize model ------------------------------------------------------
    with tf.device('/gpu:0'):
        model = cupboard(options['model'])(
            options['p_layers'], options['q_layers'],
            np.prod(options['img_shape']), options['latent_dims'],
            options['DKL_weight'], options['sigma_clip'], 'vanilla_vae')
        log.info('Model initialized')

        model_input_batch = tf.placeholder(
            tf.float32,
            shape=[
                options['batch_size'],
                np.prod(np.array(options['img_shape']))
            ],
            name='enc_inputs')
        sampler_input_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], options['latent_dims']],
            name='dec_inputs')
        log.info('Inputs defined')

        cost_function = model(model_input_batch)
        log.info('Forward pass graph built')

        sampler = model.build_sampler(sampler_input_batch)
        log.info('Sampler graph built')

        optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])
        # Clip each gradient to L2 norm <= 5 before the Adam update.
        # Variables with no gradient (None) are skipped, because
        # clip_by_norm cannot accept None.
        clip_grads = [
            (tf.clip_by_norm(grad, 5.0, name='grad_clipping'), var)
            for grad, var in optimizer.compute_gradients(cost_function)
            if grad is not None
        ]
        backpass = optimizer.apply_gradients(clip_grads)
        log.info('Optimizer graph built')

        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')

    saver = tf.train.Saver()
    log.info('Save operation built')

    # Fetch order defines the indices of `result` used below and in
    # _log_vae_batch_stats.
    fetches = [
        cost_function, backpass, model.DKL, model.rec_loss,
        model.dec_log_std_sq, model.enc_log_std_sq,
        model.enc_mean, model.dec_mean
    ]
    img_batch_shape = [options['batch_size']] + list(options['img_shape'])

    # Train loop ------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        if options['reload']:
            saver.restore(sess, options['reload_file'])
            log.info('Shared variables restored')
        else:
            sess.run(init_op)
            log.info('Shared variables initialized')

        # Window of the 10 most recent training costs for a running average.
        # np.roll shifts it so index 0 is always the newest.  It starts at
        # zero, so the average is biased low for the first 10 batches.
        last_losses = np.zeros(10)

        batch_abs_idx = 0
        for epoch_idx in xrange(options['n_epochs']):
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs in train_provider:
                # Providers may yield tuples; only the first element is fed.
                if isinstance(inputs, tuple):
                    inputs = inputs[0]
                batch_abs_idx += 1

                result = sess.run(fetches,
                                  feed_dict={model_input_batch: inputs})
                cost = result[0]

                # Dashboard rows are written before this batch's cost enters
                # last_losses.
                if batch_abs_idx % 10 == 0:
                    _log_vae_batch_stats(dash_logs, batch_abs_idx,
                                         last_losses, result)

                if not np.isfinite(cost):
                    log.info('NaN detected')
                    _dump_nonfinite_diagnostics(model, result, inputs)
                    return 1., 1., 1.

                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost

                # NOTE(review): gated on epoch_idx rather than batch_abs_idx,
                # so every batch of a matching epoch is logged -- confirm
                # this is intended.
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info(
                        'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                        .format(epoch_idx + 1, options['n_epochs'],
                                batch_abs_idx, float(cost),
                                np.mean(last_losses)))
                    log.info('Batch Mean LL: {:0>15.4f}'.format(
                        np.mean(result[3], axis=0)))
                    log.info('Batch Mean -DKL: {:0>15.4f}'.format(
                        np.mean(result[2], axis=0)))

                # Save model: TF checkpoint plus a pickled parameter dict.
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(
                        sess,
                        os.path.join(options['model_dir'],
                                     'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                    save_dict = _vae_param_snapshot(model)
                    snapshot_path = os.path.join(
                        options['model_dir'], 'vae_dict_%d' % batch_abs_idx)
                    with open(snapshot_path, 'wb') as snapshot_file:
                        pickle.dump(save_dict, snapshot_file)

                # Validate model
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                    valid_costs = []
                    for val_batch in val_provider:
                        if isinstance(val_batch, tuple):
                            val_batch = val_batch[0]
                        valid_costs.append(sess.run(
                            cost_function,
                            feed_dict={model_input_batch: val_batch}))
                        if len(valid_costs) == options['valid_batches']:
                            break

                    log.info('Validation loss: {:0>15.4f}'.format(
                        float(np.mean(valid_costs))))

                    # Decode a batch of latent codes drawn from MVN with
                    # zero mean and identity covariance.
                    val_samples = sess.run(
                        sampler,
                        feed_dict={
                            sampler_input_batch:
                            MVN(np.zeros(options['latent_dims']),
                                np.diag(np.ones(options['latent_dims'])),
                                size=options['batch_size'])
                        })

                    _write_dash_row(dash_logs['val'], batch_abs_idx,
                                    np.mean(valid_costs))

                    # result[7] is the decoder mean for the last training
                    # batch.
                    save_ae_samples(catalog,
                                    np.reshape(result[7], img_batch_shape),
                                    np.reshape(inputs, img_batch_shape),
                                    np.reshape(val_samples, img_batch_shape),
                                    batch_abs_idx,
                                    options['dashboard_dir'],
                                    num_to_save=5,
                                    save_gray=True)

            log.info('End of epoch {}'.format(epoch_idx + 1))


def train(options):
    """Train the VAE named by options['model'] and stream progress to a dashboard.

    Under options['dashboard_dir'] this writes `options`, `catalog` and the
    per-step CSV series.  Under options['model_dir'] it writes `log.txt`,
    TensorFlow checkpoints (`model_at_<step>.ckpt`) and pickled parameter
    dicts (`vae_dict_<step>`).  All dashboard file handles are closed on
    exit, including when training aborts early or raises.

    Returns:
        (1., 1., 1.) if a non-finite training cost is encountered; otherwise
        None once options['n_epochs'] epochs have run.
    """
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    _write_options_file(options)

    dashboard_dir = options['dashboard_dir']
    catalog = open(os.path.join(dashboard_dir, 'catalog'), 'w')
    dash_logs = {}
    try:
        catalog.write(_VAE_CATALOG)
        catalog.flush()
        dash_logs = _open_vae_dashboard_logs(dashboard_dir)
        return _run_vae_training(options, log, catalog, dash_logs)
    finally:
        for handle in dash_logs.values():
            handle.close()
        catalog.close()
示例#7
0
def train(options):
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    options_file = open(os.path.join(options['dashboard_dir'], 'options'), 'w')
    options_file.write(options['description'] + '\n')
    options_file.write(
        'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
            options['DKL_weight'], -options['sigma_clip'],
            options['sigma_clip']))
    for optn in options:
        options_file.write(optn)
        options_file.write(':\t')
        options_file.write(str(options[optn]))
        options_file.write('\n')
    options_file.close()

    with open(os.path.join(options['dashboard_dir'], 'description'),
              'w') as desc_file:
        desc_file.write(options['description'])

    # Dashboard Catalog
    catalog = open(os.path.join(options['dashboard_dir'], 'catalog'), 'w')
    catalog.write("""filename,type,name
description,plain,Description
options,plain,Options
train_loss.csv,csv,Train Loss
ll.csv,csv,Neg. Log-Likelihood
dec_log_sig_sq.csv,csv,Decoder Log Simga^2
dec_std_log_sig_sq.csv,csv,STD of Decoder Log Simga^2
dec_mean.csv,csv,Decoder Mean
dkl.csv,csv,DKL
enc_log_sig_sq.csv,csv,Encoder Log Sigma^2
enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2
enc_mean.csv,csv,Encoder Mean
val_loss.csv,csv,Validation Loss
""")
    catalog.flush()
    train_log = open(os.path.join(options['dashboard_dir'], 'train_loss.csv'),
                     'w')
    val_log = open(os.path.join(options['dashboard_dir'], 'val_loss.csv'), 'w')
    dkl_log = open(os.path.join(options['dashboard_dir'], 'dkl.csv'), 'w')
    ll_log = open(os.path.join(options['dashboard_dir'], 'll.csv'), 'w')

    dec_sig_log = open(
        os.path.join(options['dashboard_dir'], 'dec_log_sig_sq.csv'), 'w')
    enc_sig_log = open(
        os.path.join(options['dashboard_dir'], 'enc_log_sig_sq.csv'), 'w')

    dec_std_sig_log = open(
        os.path.join(options['dashboard_dir'], 'dec_std_log_sig_sq.csv'), 'w')
    enc_std_sig_log = open(
        os.path.join(options['dashboard_dir'], 'enc_std_log_sig_sq.csv'), 'w')

    dec_mean_log = open(os.path.join(options['dashboard_dir'], 'dec_mean.csv'),
                        'w')
    enc_mean_log = open(os.path.join(options['dashboard_dir'], 'enc_mean.csv'),
                        'w')
    # val_sig_log = open(os.path.join(options['dashboard_dir'], 'val_log_sig_sq.csv'), 'w')

    train_log.write('step,time,Train Loss\n')
    val_log.write('step,time,Validation Loss\n')
    dkl_log.write('step,time,DKL\n')
    ll_log.write('step,time,-LL\n')

    dec_sig_log.write('step,time,Decoder Log Sigma^2\n')
    enc_sig_log.write('step,time,Encoder Log Sigma^2\n')

    dec_std_sig_log.write('step,time,STD of Decoder Log Sigma^2\n')
    enc_std_sig_log.write('step,time,STD of Encoder Log Sigma^2\n')

    dec_mean_log.write('step,time,Decoder Mean\n')
    enc_mean_log.write('step,time,Encoder Mean\n')

    # Print options
    utils.print_options(options, log)

    # Load dataset ----------------------------------------------------------------------
    # Train provider
    train_provider, val_provider, test_provider = get_providers(options,
                                                                log,
                                                                flat=True)

    # Initialize model ------------------------------------------------------------------
    with tf.device('/gpu:0'):

        # Define inputs ----------------------------------------------------------
        model_input_batch = tf.placeholder(
            tf.float32,
            shape=[
                options['batch_size'],
                np.prod(np.array(options['img_shape']))
            ],
            name='enc_inputs')
        sampler_input_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], options['latent_dims']],
            name='dec_inputs')
        log.info('Inputs defined')

        # Discriminator ---------------------------------------------------------
        # with tf.variable_scope('disc_scope'):
        #     disc_model = cupboard('fixed_conv_disc')(
        #         pickle.load(open(options['feat_params_path'], 'rb')),
        #         options['num_feat_layers'],
        #         'discriminator'
        #     )

        # VAE -------------------------------------------------------------------
        # VAE model
        # with tf.variable_scope('vae_scope'):
        vae_model = cupboard('vanilla_vae')(
            options['p_layers'], options['q_layers'],
            np.prod(options['img_shape']), options['latent_dims'],
            options['DKL_weight'], options['sigma_clip'], 'vanilla_vae')
        # VAE/GAN ---------------------------------------------------------------
        # vae_gan = cupboard('vae_gan')(
        #     vae_model,
        #     disc_model,
        #     options['img_shape'],
        #     options['input_channels'],
        #     'vae_scope',
        #     'disc_scope',
        #     name = 'vae_gan_model'
        # )

        log.info('Model initialized')

        # Define optimizer
        optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])
        # Define forward pass
        cost_function = vae_model(model_input_batch)
        # backpass, grads = vae_gan(model_input_batch, sampler_input_batch, optimizer)
        log.info('Forward pass graph built')

        # Define sampler
        # sampler = vae_gan.sampler
        sampler = vae_model.build_sampler(sampler_input_batch)
        log.info('Sampler graph built')

        # optimizer = tf.train.GradientDescentOptimizer(learning_rate=options['lr'])

        # train_step = optimizer.minimize(cost_function)
        log.info('Optimizer graph built')

        # Get gradients
        grads = optimizer.compute_gradients(cost_function)
        grads = [gv for gv in grads if gv[0] != None]
        grad_tensors = [gv[0] for gv in grads]

        # Clip gradients
        clip_grads = [(tf.clip_by_norm(gv[0], 5.0,
                                       name='grad_clipping'), gv[1])
                      for gv in grads]

        # Update op
        backpass = optimizer.apply_gradients(clip_grads)

        # Define init operation
        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')
    # --------------------------------------------------------------------------

    # Train loop ---------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        # Initialize shared variables or restore
        if options['reload']:
            saver.restore(sess, options['reload_file'])
            log.info('Shared variables restored')

            test_LL_and_DKL(sess, test_provider, feat_vae.vae.DKL,
                            feat_vae.vae.rec_loss, options, model_input_batch)
            return

            mean_img = np.load(
                os.path.join(options['data_dir'],
                             'mean' + options['extension']))
            std_img = np.load(
                os.path.join(options['data_dir'],
                             'std' + options['extension']))
            visualize(sess, feat_vae.vae.dec_mean, feat_vae.vae.dec_log_std_sq,
                      sampler, sampler_input_batch, model_input_batch,
                      feat_vae.vae.enc_mean, feat_vae.vae.enc_log_std_sq,
                      train_provider, val_provider, options, catalog, mean_img,
                      std_img)
            return
        else:
            sess.run(init_op)
            log.info('Shared variables initialized')

        # Define last losses to compute a running average
        last_losses = np.zeros((10))

        batch_abs_idx = 0
        D_to_G = options['D_to_G']
        total_D2G = sum(D_to_G)
        for epoch_idx in xrange(options['n_epochs']):
            batch_rel_idx = 0
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs, _ in train_provider:
                batch_abs_idx += 1
                batch_rel_idx += 1

                # if batch_abs_idx < options['initial_G_iters']:
                #     optimizer = vae_optimizer
                # else:
                #     optimizer = disc_optimizer
                # if batch_abs_idx % total_D2G < D_to_G[0]:
                #     optimizer = disc_optimizer
                # else:
                #     optimizer = vae_optimizer
                result = sess.run([
                    cost_function,
                    backpass,
                    vae_model.DKL,
                    vae_model.rec_loss,
                    vae_model.dec_log_std_sq,
                    vae_model.enc_log_std_sq,
                    vae_model.enc_mean,
                    vae_model.dec_mean,
                ] + [gv[0] for gv in grads],
                                  feed_dict={model_input_batch: inputs})

                # print('#'*80)
                # print(result[-1])
                # print('#'*80)

                cost = result[0]

                if batch_abs_idx % 10 == 0:
                    train_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.mean(last_losses)))
                    dkl_log.write('{},{},{}\n'.format(batch_abs_idx,
                                                      '2016-04-22',
                                                      -np.mean(result[2])))
                    ll_log.write('{},{},{}\n'.format(batch_abs_idx,
                                                     '2016-04-22',
                                                     -np.mean(result[3])))

                    train_log.flush()
                    dkl_log.flush()
                    ll_log.flush()

                    dec_sig_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.mean(result[4])))
                    enc_sig_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.mean(result[5])))
                    # val_sig_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(result[6])))

                    dec_sig_log.flush()
                    enc_sig_log.flush()

                    dec_std_sig_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.std(result[4])))
                    enc_std_sig_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.std(result[5])))

                    dec_mean_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.mean(result[7])))
                    enc_mean_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.mean(result[6])))

                    dec_std_sig_log.flush()
                    enc_std_sig_log.flush()

                    dec_mean_log.flush()
                    enc_mean_log.flush()
                    # val_sig_log.flush()

                # Check cost
                if np.isnan(cost) or np.isinf(cost):
                    log.info('NaN detected')
                    for i in range(len(result)):
                        print("\n\nresult[%d]:" % i)
                        try:
                            print(np.any(np.isnan(result[i])))
                        except:
                            pass
                        print(result[i])
                    print(result[3].shape)
                    print(model._encoder.layers[0].weights['w'].eval())
                    print('\n\nAny:')
                    print(np.any(np.isnan(result[8])))
                    print(np.any(np.isnan(result[9])))
                    print(np.any(np.isnan(result[10])))
                    print(inputs)
                    return 1., 1., 1.

                # Update last losses
                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost

                # Display training information
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info(
                        'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                        .format(epoch_idx + 1, options['n_epochs'],
                                batch_abs_idx, float(cost),
                                np.mean(last_losses)))
                    log.info('Batch Mean LL: {:0>15.4f}'.format(
                        np.mean(result[3], axis=0)))
                    log.info('Batch Mean -DKL: {:0>15.4f}'.format(
                        np.mean(result[2], axis=0)))
                    # log.info('Batch Mean Acc.: {:0>15.4f}'.format(result[-2], axis=0))

                # Save model
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(
                        sess,
                        os.path.join(options['model_dir'],
                                     'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                # Validate model
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:

                    valid_costs = []
                    seen_batches = 0
                    for val_batch, _ in val_provider:

                        val_cost = sess.run(
                            vae_model.cost,
                            feed_dict={
                                model_input_batch:
                                val_batch,
                                sampler_input_batch:
                                MVN(np.zeros(options['latent_dims']),
                                    np.diag(np.ones(options['latent_dims'])),
                                    size=options['batch_size'])
                            })
                        valid_costs.append(val_cost)
                        seen_batches += 1

                        if seen_batches == options['valid_batches']:
                            break

                    # Print results
                    log.info('Validation loss: {:0>15.4f}'.format(
                        float(np.mean(valid_costs))))

                    val_samples = sess.run(
                        sampler,
                        feed_dict={
                            sampler_input_batch:
                            MVN(np.zeros(options['latent_dims']),
                                np.diag(np.ones(options['latent_dims'])),
                                size=options['batch_size'])
                        })

                    val_log.write('{},{},{}\n'.format(batch_abs_idx,
                                                      '2016-04-22',
                                                      np.mean(valid_costs)))
                    val_log.flush()

                    save_ae_samples(catalog,
                                    np.reshape(result[7],
                                               [options['batch_size']] +
                                               options['img_shape']),
                                    np.reshape(inputs,
                                               [options['batch_size']] +
                                               options['img_shape']),
                                    np.reshape(val_samples,
                                               [options['batch_size']] +
                                               options['img_shape']),
                                    batch_abs_idx,
                                    options['dashboard_dir'],
                                    num_to_save=5,
                                    save_gray=True)

                    # save_samples(
                    #     val_samples,
                    #     int(batch_abs_idx/options['freq_validation']),
                    #     os.path.join(options['model_dir'], 'valid_samples'),
                    #     True,
                    #     options['img_shape'],
                    #     5
                    # )

                    # save_samples(
                    #     inputs,
                    #     int(batch_abs_idx/options['freq_validation']),
                    #     os.path.join(options['model_dir'], 'input_sanity'),
                    #     True,
                    #     options['img_shape'],
                    #     num_to_save=5
                    # )

                    # save_samples(
                    #     result[8],
                    #     int(batch_abs_idx/options['freq_validation']),
                    #     os.path.join(options['model_dir'], 'rec_sanity'),
                    #     True,
                    #     options['img_shape'],
                    #     num_to_save=5
                    # )

            log.info('End of epoch {}'.format(epoch_idx + 1))
    # Test Model --------------------------------------------------------------------------
        test_results = []

        for inputs in test_provider:
            if isinstance(inputs, tuple):
                inputs = inputs[0]
            batch_results = sess.run([
                feat_vae.vae.DKL, feat_vae.vae.rec_loss,
                feat_vae.vae.dec_log_std_sq, feat_vae.vae.enc_log_std_sq,
                feat_vae.vae.dec_mean, feat_vae.vae.enc_mean
            ],
                                     feed_dict={model_input_batch: inputs})

            test_results.append(
                map(
                    lambda p: np.mean(p, axis=1)
                    if len(p.shape) > 1 else np.mean(p), batch_results))
        test_results = map(list, zip(*test_results))

        # Print results
        log.info('Test Mean Rec. Loss: {:0>15.4f}'.format(
            float(np.mean(test_results[1]))))
        log.info('Test DKL: {:0>15.4f}'.format(float(np.mean(
            test_results[0]))))
        log.info('Test Dec. Mean Log Std Sq: {:0>15.4f}'.format(
            float(np.mean(test_results[2]))))
        log.info('Test Enc. Mean Log Std Sq: {:0>15.4f}'.format(
            float(np.mean(test_results[3]))))
        log.info('Test Dec. Mean Mean: {:0>15.4f}'.format(
            float(np.mean(test_results[4]))))
        log.info('Test Enc. Mean Mean: {:0>15.4f}'.format(
            float(np.mean(test_results[5]))))
示例#8
0
def train(options):
    """Train a softmax classifier on top of a pre-trained encoder.

    The encoder weights/biases are read from the pickle at
    ``options['feat_params_path']`` (keys ``'enc_W'`` and ``'enc_b'``, one
    entry per encoder layer) and substituted into the model as
    ``tf.constant`` tensors.  A fully connected layer (``classifier_fc``)
    maps ``model.encoder`` to ``options['num_classes']`` logits; the loss is
    the batch-mean cross-entropy, minimised with Adam with every gradient
    clipped to norm 5.

    Progress goes to ``log.txt`` in ``options['model_dir']`` and to the
    dashboard files in ``options['dashboard_dir']``.  A checkpoint is saved
    every ``options['freq_saving']`` batches, validation accuracy is measured
    on up to ``options['valid_batches']`` batches every
    ``options['freq_validation']`` batches, and test accuracy is logged after
    the last epoch.

    Args:
        options: dict of run settings (paths, model and optimizer
            hyper-parameters, logging/saving/validation frequencies).

    Returns:
        ``(1., 1., 1.)`` if a NaN/inf training cost is detected, otherwise
        ``None``.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    dashboard_dir = options['dashboard_dir']

    # Dump the run options for the dashboard
    with open(os.path.join(dashboard_dir, 'options'), 'w') as options_file:
        options_file.write(options['description'] + '\n')
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Dashboard catalog; nothing else is appended to it in this function, so
    # it is closed as soon as it is written
    with open(os.path.join(dashboard_dir, 'catalog'), 'w') as catalog:
        catalog.write(
"""filename,type,name
options,plain,Options
train_loss.csv,csv,Train Loss
val_loss.csv,csv,Validation Loss
"""
        )

    # Print options
    utils.print_options(options, log)

    # Load dataset ----------------------------------------------------------------------
    train_provider, val_provider, test_provider = get_providers(options, log, flat=True)

    # Initialize model ------------------------------------------------------------------
    with tf.device('/gpu:0'):
        model = cupboard(options['model'])(
            np.prod(options['img_shape']) * options['input_channels'],
            options['enc_params'],
            options['dec_params'],
            'ae'
        )
        log.info('Model initialized')

        # Define inputs
        model_input_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], np.prod(np.array(options['img_shape']))],
            name='inputs'
        )
        model_label_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], options['num_classes']],
            name='labels'
        )
        log.info('Inputs defined')

        # Build the auto-encoder graph; the same tensor is passed as both arguments
        model(model_input_batch, model_input_batch)

        # NOTE(review): pickle.load can execute arbitrary code -- the file at
        # feat_params_path must come from a trusted source.
        with open(options['feat_params_path'], 'rb') as feat_file:
            feat_params = pickle.load(feat_file)

        # Swap the encoder parameters for the pre-trained values as constants.
        # NOTE(review): this only affects `model.encoder` if the model builds
        # that tensor from `layer.weights` after this point -- confirm in the
        # model class.
        for i, layer in enumerate(model._encoder.layers):
            layer.weights['w'] = tf.constant(feat_params['enc_W'][i])
            layer.weights['b'] = tf.constant(feat_params['enc_b'][i])

        logits = FC(
            options['latent_dims'],
            options['num_classes'],
            activation=None,
            scale=0.01,
            name='classifier_fc'
        )(model.encoder)
        classifier = tf.nn.softmax(logits)

        # Batch-mean cross-entropy. Probabilities are clipped away from 0
        # because log(0) = -inf and 0 * -inf = NaN for the zero labels.
        cost_function = -tf.mul(
            model_label_batch,
            tf.log(tf.clip_by_value(classifier, 1e-10, 1.0))
        )
        cost_function = tf.reduce_sum(cost_function)
        cost_function *= 1 / float(options['batch_size'])

        log.info('Forward pass graph built')

        # Define optimizer
        optimizer = tf.train.AdamOptimizer(
            learning_rate=options['lr']
        )

        # Keep only (gradient, variable) pairs that actually have a gradient,
        # clip each gradient to norm 5 and apply them
        grads = [gv for gv in optimizer.compute_gradients(cost_function)
                 if gv[0] is not None]
        clip_grads = [(tf.clip_by_norm(grad, 5.0, name='grad_clipping'), var)
                      for grad, var in grads]
        backpass = optimizer.apply_gradients(clip_grads)

        log.info('Optimizer graph built')

        # Define init operation
        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')
    # --------------------------------------------------------------------------

    # Train loop ---------------------------------------------------------------
    # The CSV logs and the session are released together on every exit path,
    # including the early return on a NaN cost.
    with open(os.path.join(dashboard_dir, 'train_loss.csv'), 'w') as train_log, \
            open(os.path.join(dashboard_dir, 'val_loss.csv'), 'w') as val_log, \
            tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        train_log.write('step,time,Train Loss\n')
        val_log.write('step,time,Validation Loss\n')
        log.info('Session started')

        # Initialize shared variables or restore
        if options['reload']:
            saver.restore(sess, os.path.join(options['model_dir'], 'model.ckpt'))
            log.info('Shared variables restored')
        else:
            sess.run(init_op)
            log.info('Shared variables initialized')

        # Windows over the last 10 batches for running averages
        last_losses = np.zeros((10))
        last_accs = np.zeros((10))

        batch_abs_idx = 0
        for epoch_idx in xrange(options['n_epochs']):
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs, labels in train_provider:
                batch_abs_idx += 1

                # Fetch layout: [cost, train op (fetches as None), softmax output, *gradients]
                result = sess.run(
                    [cost_function, backpass, classifier] + [gv[0] for gv in grads],
                    feed_dict={
                        model_input_batch: inputs,
                        model_label_batch: labels
                    }
                )
                cost = result[0]

                if batch_abs_idx % 10 == 0:
                    train_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(last_losses)))
                    train_log.flush()

                # Check cost
                if np.isnan(cost) or np.isinf(cost):
                    log.info('NaN detected')
                    # Dump every fetched value; the fetch list length depends on
                    # how many variables have gradients, so no fixed indices
                    for i, value in enumerate(result):
                        print("\n\nresult[%d]:" % i)
                        try:
                            print(np.any(np.isnan(value)))
                        except TypeError:
                            # Non-numeric fetches (e.g. the train op's None)
                            pass
                        print(value)
                    print(model._encoder.layers[0].weights['w'].eval())
                    print(inputs)
                    return 1., 1., 1.

                # Update last losses
                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost

                last_accs = np.roll(last_accs, 1)
                last_accs[0] = np.mean(np.argmax(labels, axis=1) == np.argmax(result[2], axis=1))

                # Display training information
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info('Epoch {:02}/{:02} Batch {:03} Current Acc.: {:0>15.4f} Mean last accs: {:0>15.4f}'.format(
                        epoch_idx + 1,
                        options['n_epochs'],
                        batch_abs_idx,
                        last_accs[0],
                        np.mean(last_accs)
                    ))
                    log.info('Batch Mean Loss: {:0>15.4f}'.format(np.mean(last_losses)))

                # Save model
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(sess, os.path.join(options['model_dir'], 'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                # Validate model: mean accuracy over up to `valid_batches` batches
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                    valid_accs = []
                    for val_batch, val_labels in val_provider:
                        val_probs = sess.run(
                            classifier,
                            feed_dict={
                                model_input_batch: val_batch,
                                model_label_batch: val_labels
                            }
                        )
                        valid_accs.append(np.mean(np.argmax(val_labels, axis=1) == np.argmax(val_probs, axis=1)))

                        if len(valid_accs) == options['valid_batches']:
                            break

                    # Print results
                    log.info('Validation acc.: {:0>15.4f}'.format(
                        float(np.mean(valid_accs))
                    ))

                    val_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(valid_accs)))
                    val_log.flush()

            log.info('End of epoch {}'.format(epoch_idx + 1))

        # Test model ---------------------------------------------------------------
        test_accs = []
        for inputs, labels in test_provider:
            if isinstance(inputs, tuple):
                inputs = inputs[0]
            test_probs = sess.run(
                classifier,
                feed_dict={
                    model_input_batch: inputs,
                    model_label_batch: labels
                }
            )
            test_accs.append(np.mean(np.argmax(labels, axis=1) == np.argmax(test_probs, axis=1)))

        # Print results
        log.info('Test Accuracy: {:0>15.4f}'.format(
            np.mean(test_accs)
        ))
示例#9
0
def train(options):
    """Train an auto-encoder to reconstruct clean inputs from noisy copies.

    Each training batch is fed as the clean input and, with additive Gaussian
    noise of standard deviation ``options['noise_std']``, as the noisy input.
    The cost returned by the model is minimised with Adam, with every gradient
    clipped to norm 5.  ``options['model'] == 'cnn_ae'`` builds the model on
    image-shaped inputs (``[batch] + img_shape + [input_channels]``); any
    other model name builds it on flattened inputs.

    Every ``options['freq_saving']`` batches a checkpoint and a pickle of the
    encoder parameters (``enc_dict_<step>``) are written to
    ``options['model_dir']``.  Every ``options['freq_validation']`` batches,
    up to ``options['valid_batches']`` validation batches are evaluated, the
    mean cost is logged and sample reconstructions are saved to the dashboard.
    The test provider is not used.

    Args:
        options: dict of run settings (paths, model and optimizer
            hyper-parameters, noise level, logging/saving/validation
            frequencies).

    Returns:
        ``(1., 1., 1.)`` if a NaN/inf training cost is detected, otherwise
        ``None``.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    dashboard_dir = options['dashboard_dir']

    # Dump the run options for the dashboard
    with open(os.path.join(dashboard_dir, 'options'), 'w') as options_file:
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Print options
    utils.print_options(options, log)

    # Load dataset ----------------------------------------------------------------------
    # test_provider is not used by this function
    train_provider, val_provider, test_provider = get_providers(options, log)

    # Initialize model ------------------------------------------------------------------
    # Single source of truth for the model-layout check, used for input
    # shapes, encoder export and sample reshaping alike
    is_cnn = options['model'] == 'cnn_ae'

    with tf.device('/gpu:0'):
        if is_cnn:
            # input_shape, input_channels, enc_params, dec_params, name
            model = cupboard(options['model'])(
                options['img_shape'],
                options['input_channels'],
                options['enc_params'],
                options['dec_params'],
                'cnn_ae'
            )
            input_shape = [options['batch_size']] + options['img_shape'] + [options['input_channels']]
        else:
            model = cupboard(options['model'])(
                np.prod(options['img_shape']) * options['input_channels'],
                options['enc_params'],
                options['dec_params'],
                'ae'
            )
            input_shape = [options['batch_size'], np.prod(options['img_shape']) * options['input_channels']]

        # Define inputs
        model_clean_input_batch = tf.placeholder(tf.float32, shape=input_shape, name='clean')
        model_noisy_input_batch = tf.placeholder(tf.float32, shape=input_shape, name='noisy')
        log.info('Inputs defined')

        log.info('Model initialized')

        # Define forward pass
        print(model_clean_input_batch.get_shape())
        print(model_noisy_input_batch.get_shape())
        cost_function = model(model_clean_input_batch, model_noisy_input_batch)
        log.info('Forward pass graph built')

        # Define optimizer
        optimizer = tf.train.AdamOptimizer(
            learning_rate=options['lr']
        )

        # Keep only (gradient, variable) pairs that actually have a gradient,
        # clip each gradient to norm 5 and apply them
        grads = [gv for gv in optimizer.compute_gradients(cost_function)
                 if gv[0] is not None]
        clip_grads = [(tf.clip_by_norm(grad, 5.0, name='grad_clipping'), var)
                      for grad, var in grads]
        backpass = optimizer.apply_gradients(clip_grads)
        log.info('Optimizer graph built')

        # Define init operation
        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')
    # --------------------------------------------------------------------------

    noise_std = np.float32(options['noise_std'])
    image_tail = options['img_shape'] + [options['input_channels']]

    # Train loop ---------------------------------------------------------------
    # The catalog stays open for save_ae_samples; all dashboard files and the
    # session are released together on every exit path, including the early
    # return on a NaN cost.
    with open(os.path.join(dashboard_dir, 'catalog'), 'w') as catalog, \
            open(os.path.join(dashboard_dir, 'train_loss.csv'), 'w') as train_log, \
            open(os.path.join(dashboard_dir, 'val_loss.csv'), 'w') as val_log, \
            tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        # Dashboard catalog
        catalog.write(
"""filename,type,name
options,plain,Options
train_loss.csv,csv,Train Loss
val_loss.csv,csv,Validation Loss
"""
        )
        catalog.flush()
        train_log.write('step,time,Train Loss\n')
        val_log.write('step,time,Validation Loss\n')

        log.info('Session started')

        # Initialize shared variables or restore
        if options['reload']:
            saver.restore(sess, os.path.join(options['model_dir'], 'model.ckpt'))
            log.info('Shared variables restored')
        else:
            sess.run(init_op)
            log.info('Shared variables initialized')

        # Window over the last 10 batch costs for a running average
        last_losses = np.zeros((10))

        batch_abs_idx = 0
        for epoch_idx in xrange(options['n_epochs']):
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs, _ in train_provider:
                batch_abs_idx += 1

                noisy_inputs = np.float32(inputs) + normal(loc=0.0, scale=noise_std, size=inputs.shape)

                # Fetch layout: [cost, train op (fetches as None), *gradients]
                result = sess.run(
                    [cost_function, backpass] + [gv[0] for gv in grads],
                    feed_dict={
                        model_clean_input_batch: inputs,
                        model_noisy_input_batch: noisy_inputs
                    }
                )
                cost = result[0]

                if batch_abs_idx % 10 == 0:
                    train_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(last_losses)))
                    train_log.flush()

                # Check cost
                if np.isnan(cost) or np.isinf(cost):
                    log.info('NaN detected')
                    # Dump every fetched value; the fetch list length depends on
                    # how many variables have gradients, so no fixed indices
                    for i, value in enumerate(result):
                        print("\n\nresult[%d]:" % i)
                        try:
                            print(np.any(np.isnan(value)))
                        except TypeError:
                            # Non-numeric fetches (e.g. the train op's None)
                            pass
                        print(value)
                    print(model._encoder.layers[0].weights['w'].eval())
                    print(inputs)
                    return 1., 1., 1.

                # Update last losses
                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost

                # Display training information
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'.format(
                        epoch_idx + 1,
                        options['n_epochs'],
                        batch_abs_idx,
                        float(cost),
                        np.mean(last_losses)
                    ))

                # Save model
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(sess, os.path.join(options['model_dir'], 'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                    # Save encoder params; only the flat model's layers are
                    # exported, the CNN variant writes the empty lists
                    save_dict = {
                        'enc_W': [],
                        'enc_b': [],
                        'enc_act_fn': [],
                    }
                    if not is_cnn:
                        for i, layer in enumerate(model._encoder.layers):
                            save_dict['enc_W'].append(layer.weights['w'].eval())
                            save_dict['enc_b'].append(layer.weights['b'].eval())
                            save_dict['enc_act_fn'].append(options['enc_params']['act_fn'][i])

                    enc_path = os.path.join(options['model_dir'], 'enc_dict_%d' % batch_abs_idx)
                    with open(enc_path, 'wb') as enc_file:
                        pickle.dump(save_dict, enc_file)

                # Validate model
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                    valid_costs = []
                    for val_batch, _ in val_provider:
                        noisy_val_batch = val_batch + normal(loc=0.0, scale=noise_std, size=val_batch.shape)

                        # (cost, reconstruction)
                        val_results = sess.run(
                            (cost_function, model.decoder),
                            feed_dict={
                                model_clean_input_batch: val_batch,
                                model_noisy_input_batch: noisy_val_batch
                            }
                        )
                        valid_costs.append(val_results[0])

                        if len(valid_costs) == options['valid_batches']:
                            break

                    if not valid_costs:
                        # val_batch / val_results would be undefined below
                        log.info('No validation batches available; skipping validation')
                    else:
                        # Print results
                        log.info('Validation loss: {:0>15.4f}'.format(
                            float(np.mean(valid_costs))
                        ))

                        val_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(valid_costs)))
                        val_log.flush()

                        # Bring the last validation batch, its noisy copy and the
                        # reconstruction into [N] + img_shape + [channels] layout
                        if is_cnn:
                            val_recon = np.reshape(val_results[-1], val_batch.shape)
                        else:
                            image_shape = [val_batch.shape[0]] + image_tail
                            val_batch = np.reshape(val_batch, image_shape)
                            noisy_val_batch = np.reshape(noisy_val_batch, image_shape)
                            val_recon = np.reshape(val_results[-1], image_shape)

                        save_ae_samples(
                            catalog,
                            val_batch,
                            noisy_val_batch,
                            val_recon,
                            batch_abs_idx,
                            options['dashboard_dir'],
                            num_to_save=5,
                            save_gray=True
                        )

            log.info('End of epoch {}'.format(epoch_idx + 1))
示例#10
0
def train(options):
    """Train a CIFAR classifier and stream its metrics to a dashboard.

    Writes the run options, a dashboard catalog and four CSV logs (train and
    validation loss/accuracy) into ``options['dashboard_dir']``, builds the
    model returned by ``cupboard(options['model'])`` and trains it with Adam
    for ``options['n_epochs']`` epochs. Every ``options['freq_saving']``
    batches a checkpoint and a pickled list of conv-layer hyper-parameters
    are written to ``options['model_dir']``; every
    ``options['freq_validation']`` batches up to ``options['valid_batches']``
    validation batches are evaluated.

    Args:
        options: dict of run settings (paths, data/model shapes, learning
            rate, epoch count and logging/saving/validation frequencies).

    Returns:
        ``(1., 1., 1.)`` as soon as a training cost is NaN or infinite,
        otherwise ``None`` after the last epoch. All dashboard files are
        closed on both paths.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    dashboard_dir = options['dashboard_dir']
    # Fixed placeholder written into the 'time' column of every CSV row.
    log_time = '2016-04-22'

    with open(os.path.join(dashboard_dir, 'options'), 'w') as options_file:
        options_file.write(options['description'] + '\n')
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Every dashboard file is registered here so the ``finally`` clause below
    # closes it on every exit path, including the early NaN return.
    dashboard_files = []

    def open_dashboard_file(filename, header):
        """Create ``filename`` in the dashboard dir and write ``header``."""
        handle = open(os.path.join(dashboard_dir, filename), 'w')
        dashboard_files.append(handle)
        handle.write(header)
        handle.flush()
        return handle

    def make_provider(split):
        """Build a DataProvider over ``<data_dir>/<split>``."""
        split_dir = os.path.join(options['data_dir'], split)
        # NOTE(review): two entries of the 'info' directory are assumed not
        # to be samples -- confirm against the dataset layout.
        num_data_points = len(os.listdir(os.path.join(split_dir, 'info'))) - 2
        log.info('{} data points: {}'.format(split, num_data_points))
        return DataProvider(
            num_data_points, options['batch_size'],
            toolbox.CIFARLoader(data_dir=split_dir, flat=False))

    try:
        # Dashboard Catalog
        open_dashboard_file('catalog', """filename,type,name
options,plain,Options
train_loss.csv,csv,Train Loss
train_acc.csv,csv,Train Accuracy
val_loss.csv,csv,Validation Loss
val_acc.csv,csv,Validation Accuracy
""")
        train_log = open_dashboard_file('train_loss.csv',
                                        'step,time,Train Loss\n')
        train_acc_log = open_dashboard_file('train_acc.csv',
                                            'step,time,Train Accuracy\n')
        val_log = open_dashboard_file('val_loss.csv',
                                      'step,time,Validation Loss\n')
        val_acc_log = open_dashboard_file('val_acc.csv',
                                          'step,time,Validation Accuracy\n')

        # Print options
        utils.print_options(options, log)

        # Load dataset -------------------------------------------------------
        train_provider = make_provider('train')
        val_provider = make_provider('valid')
        log.info('Data providers initialized.')

        # Initialize model ---------------------------------------------------
        with tf.device('/gpu:0'):
            model = cupboard(options['model'])(
                options['img_shape'], options['input_channels'],
                options['num_classes'], options['conv_params'],
                options['fc_params'], 'CIFAR_classifier')
            log.info('Model initialized')

            # Define inputs
            input_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size']] + options['img_shape'] +
                [options['input_channels']],
                name='inputs')
            label_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'], options['num_classes']],
                name='labels')
            log.info('Inputs defined')

            # Define forward pass
            cost_function, classifier = model(input_batch, label_batch)
            log.info('Forward pass graph built')

            # Define optimizer
            optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])
            train_step = optimizer.minimize(cost_function)
            log.info('Optimizer graph built')

            # Define init operation
            init_op = tf.initialize_all_variables()
            log.info('Variable initialization graph built')

        # Define op to save and restore variables
        saver = tf.train.Saver()
        log.info('Save operation built')

        # Train loop ---------------------------------------------------------
        config = tf.ConfigProto(log_device_placement=True)
        with tf.Session(config=config) as sess:
            log.info('Session started')

            # Initialize shared variables or restore
            if options['reload']:
                saver.restore(sess, os.path.join(options['model_dir'],
                                                 'model.ckpt'))
                log.info('Shared variables restored')
            else:
                sess.run(init_op)
                log.info('Shared variables initialized')

            # Windows of the last 10 batch losses/accuracies, used for the
            # running means that are logged below.
            last_losses = np.zeros((10))
            last_accs = np.zeros((10))

            batch_abs_idx = 0
            for epoch_idx in xrange(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))

                for inputs, labels in train_provider:
                    batch_abs_idx += 1

                    cost, train_preds, _ = sess.run(
                        [cost_function, classifier, train_step],
                        feed_dict={
                            input_batch: inputs,
                            label_batch: labels
                        })

                    # Check cost
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')
                        return 1., 1., 1.

                    accuracy = np.mean(
                        np.argmax(train_preds, axis=1) ==
                        np.argmax(labels, axis=1))

                    # Update last losses
                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost
                    last_accs = np.roll(last_accs, 1)
                    last_accs[0] = accuracy

                    if batch_abs_idx % 10 == 0:
                        train_log.write('{},{},{}\n'.format(
                            batch_abs_idx, log_time, np.mean(last_losses)))
                        train_acc_log.write('{},{},{}\n'.format(
                            batch_abs_idx, log_time, np.mean(last_accs)))
                        train_log.flush()
                        train_acc_log.flush()

                    # Display training information
                    # NOTE(review): this gates on the epoch index, so every
                    # batch of a matching epoch is logged; batch_abs_idx may
                    # have been intended -- confirm.
                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info(
                            'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                            .format(epoch_idx + 1, options['n_epochs'],
                                    batch_abs_idx, float(cost),
                                    np.mean(last_losses)))
                        log.info(
                            'Epoch {:02}/{:02} Batch {:03} Current Accuracy: {:0>15.4f}'
                            .format(epoch_idx + 1, options['n_epochs'],
                                    batch_abs_idx, np.mean(last_accs)))

                    # Save model
                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(
                            sess,
                            os.path.join(options['model_dir'],
                                         'model_at_%d.ckpt' % batch_abs_idx))

                        # Record the hyper-parameters of every second entry
                        # of the conv stack (the entries in between are
                        # presumably activation layers -- TODO confirm).
                        conv_layers = model._classifier_conv.layers
                        save_dict = []
                        for c_ind in xrange(0, len(conv_layers), 2):
                            layer = conv_layers[c_ind]
                            save_dict.append({
                                'n_filters_in': layer.n_filters_in,
                                'n_filters_out': layer.n_filters_out,
                                'input_dim': layer.input_dim,
                                'filter_dim': layer.filter_dim,
                                'strides': layer.strides,
                                'padding': layer.padding,
                            })
                        dict_path = os.path.join(
                            options['model_dir'],
                            'class_dict_%d' % batch_abs_idx)
                        with open(dict_path, 'wb') as dict_file:
                            pickle.dump(save_dict, dict_file)

                        log.info('Model saved')

                    # Validate model
                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                        valid_costs = []
                        val_accuracies = []
                        for val_batch, val_label in val_provider:
                            # Evaluate at most options['valid_batches'] batches
                            if len(valid_costs) == options['valid_batches']:
                                break

                            val_cost, val_preds = sess.run(
                                [cost_function, classifier],
                                feed_dict={
                                    input_batch: val_batch,
                                    label_batch: val_label
                                })
                            valid_costs.append(val_cost)
                            val_accuracies.append(
                                np.mean(
                                    np.argmax(val_preds, axis=1) ==
                                    np.argmax(val_label, axis=1)))

                        mean_val_cost = float(np.mean(valid_costs))
                        mean_val_acc = np.mean(val_accuracies)

                        # Print results
                        log.info('Validation loss: {:0>15.4f}'.format(
                            mean_val_cost))
                        log.info('Validation Accuracy: {:0>15.4f}'.format(
                            mean_val_acc))

                        val_log.write('{},{},{}\n'.format(
                            batch_abs_idx, log_time, mean_val_cost))
                        val_acc_log.write('{},{},{}\n'.format(
                            batch_abs_idx, log_time, mean_val_acc))
                        val_log.flush()
                        val_acc_log.flush()

                log.info('End of epoch {}'.format(epoch_idx + 1))
    finally:
        for handle in dashboard_files:
            handle.close()
示例#11
0
def train(options):
    """Train an MLP generator against a fixed-feature conv discriminator.

    The generator is a three-layer tanh MLP
    (``latent_dims -> 60 -> 60 -> prod(img_shape)``). The discriminator is
    ``cupboard('fixed_conv_disc')`` built from the pickled parameters at
    ``options['disc_params_path']``. Both are optimised with Adam, and each
    gradient tensor is clipped to norm 5.

    Training schedule, indexed by the 1-based global batch counter:

    * generator-only updates while it is below ``initial_G_iters``;
    * discriminator-only updates until
      ``initial_G_iters + initial_D_iters``;
    * then repeating cycles of ``sum(D_to_G)`` batches, whose first
      ``D_to_G[0]`` batches update the discriminator and whose remaining
      batches update the generator.

    Discriminator cross-entropy and accuracy go to dashboard CSVs in
    ``options['dashboard_dir']``. Each training phase is written to its own
    column. Checkpoints are written to ``options['model_dir']``.
    A NaN/inf cost is only logged; training continues.

    Args:
        options: dict of run settings (paths, shapes, latent size, learning
            rate, schedule and logging/saving/validation frequencies).
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    dashboard_dir = options['dashboard_dir']
    # Fixed placeholder written into the 'time' column of every CSV row.
    log_time = '2016-04-22'

    with open(os.path.join(dashboard_dir, 'options'), 'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write('Log Sigma^2 clipped to: [{}, {}]\n\n'.format(
            -options['sigma_clip'], options['sigma_clip']))
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Every dashboard file is registered here so the ``finally`` clause below
    # closes it on every exit path.
    dashboard_files = []

    def open_dashboard_file(filename, header):
        """Create ``filename`` in the dashboard dir and write ``header``."""
        handle = open(os.path.join(dashboard_dir, filename), 'w')
        dashboard_files.append(handle)
        handle.write(header)
        handle.flush()
        return handle

    def sample_noise():
        """Draw one batch of latent codes via MVN(0, I)."""
        return MVN(np.zeros(options['latent_dims']),
                   np.diag(np.ones(options['latent_dims'])),
                   size=options['batch_size'])

    try:
        # Dashboard Catalog (kept open: it is handed to save_ae_samples)
        catalog = open_dashboard_file('catalog', """filename,type,name
options,plain,Options
train_loss.csv,csv,Discriminator Cross-Entropy
train_acc.csv,csv,Discriminator Accuracy
val_loss.csv,csv,Validation Cross-Entropy
val_acc.csv,csv,Validation Accuracy
""")
        # Each CSV has one value column per training phase.
        train_log = open_dashboard_file(
            'train_loss.csv',
            'step,time,Train CE (Training Vanilla),Train CE (Training Gen.),'
            'Train CE (Training Disc.)\n')
        val_log = open_dashboard_file(
            'val_loss.csv',
            'step,time,Validation CE (Training Vanilla),'
            'Validation CE (Training Gen.),Validation CE (Training Disc.)\n')
        train_acc = open_dashboard_file(
            'train_acc.csv',
            'step,time,Train Acc. (Training Vanilla),'
            'Train Acc. (Training Gen.),Train Acc. (Training Disc.)\n')
        val_acc = open_dashboard_file(
            'val_acc.csv',
            'step,time,Validation Acc. (Training Vanilla),'
            'Validation Acc. (Training Gen.),'
            'Validation Acc. (Training Disc.)\n')

        # Print options
        utils.print_options(options, log)

        # Load dataset -------------------------------------------------------
        # The test provider is not used by this routine.
        train_provider, val_provider, _ = get_providers(options, log,
                                                        flat=True)

        # NOTE: pickle can execute arbitrary code on load; only point
        # 'disc_params_path' at trusted files.
        with open(options['disc_params_path'], 'rb') as params_file:
            disc_params = pickle.load(params_file)

        # Initialize model ---------------------------------------------------
        with tf.device('/gpu:0'):
            # Define inputs --------------------------------------------------
            real_batch = tf.placeholder(
                tf.float32,
                shape=[
                    options['batch_size'],
                    np.prod(np.array(options['img_shape']))
                ],
                name='real_inputs')
            sampler_input_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'], options['latent_dims']],
                name='noise_channel')
            # Discriminator targets: 1 for each real sample followed by 0 for
            # each generated sample, matching the concat order below.
            labels = tf.constant(
                np.expand_dims(
                    np.concatenate((np.ones(options['batch_size']),
                                    np.zeros(options['batch_size'])),
                                   axis=0).astype(np.float32),
                    axis=1))
            log.info('Inputs defined')

            # Define model ---------------------------------------------------
            with tf.variable_scope('gen_scope'):
                generator = Sequential('generator')
                generator += FullyConnected(options['latent_dims'], 60,
                                            tf.nn.tanh, name='fc_1')
                generator += FullyConnected(60, 60, tf.nn.tanh, name='fc_2')
                generator += FullyConnected(60,
                                            np.prod(options['img_shape']),
                                            tf.nn.tanh, name='fc_3')

                sampler = generator(sampler_input_batch)

            with tf.variable_scope('disc_scope'):
                disc_model = cupboard('fixed_conv_disc')(
                    disc_params,
                    options['num_feat_layers'],
                    name='disc_model')

                disc_inputs = tf.concat(0, [real_batch, sampler])
                disc_inputs = tf.reshape(
                    disc_inputs, [disc_inputs.get_shape()[0].value] +
                    options['img_shape'] + [options['input_channels']])

                preds = disc_model(disc_inputs)
                # Keep the logs below finite.
                preds = tf.clip_by_value(preds, 0.00001, 0.99999)

                # Disc Accuracy
                disc_accuracy = (
                    1 / float(labels.get_shape()[0].value)) * tf.reduce_sum(
                        tf.cast(tf.equal(tf.round(preds), labels), tf.float32))

            # Define Losses --------------------------------------------------
            # Discriminator Cross-Entropy
            disc_CE = (1 / float(labels.get_shape()[0].value)) * tf.reduce_sum(
                -tf.add(tf.mul(labels, tf.log(preds)),
                        tf.mul(1.0 - labels, tf.log(1.0 - preds))))

            # Generator loss: -log(pred) on the generated half only.
            gen_loss = -tf.mul(1.0 - labels, tf.log(preds))

            # Define Optimizers ----------------------------------------------
            optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])

            def clipped_update(loss, var_list, clip_name):
                """Adam update of var_list, each gradient clipped to norm 5."""
                grads_and_vars = [
                    (grad, var)
                    for grad, var in optimizer.compute_gradients(loss, var_list)
                    if grad is not None
                ]
                return optimizer.apply_gradients([
                    (tf.clip_by_norm(grad, 5.0, name=clip_name), var)
                    for grad, var in grads_and_vars
                ])

            # Get Generator and Discriminator Trainable Variables
            gen_train_vars = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, 'gen_scope')
            disc_train_vars = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, 'disc_scope')

            gen_backpass = clipped_update(gen_loss, gen_train_vars,
                                          'gen_grad_clipping')
            disc_backpass = clipped_update(disc_CE, disc_train_vars,
                                           'disc_grad_clipping')
            log.info('Optimizer graph built')

            # Define init operation
            init_op = tf.initialize_all_variables()
            log.info('Variable initialization graph built')

        # Define op to save and restore variables
        saver = tf.train.Saver()
        log.info('Save operation built')

        # Train loop ---------------------------------------------------------
        config = tf.ConfigProto(log_device_placement=True)
        with tf.Session(config=config) as sess:
            log.info('Session started')

            # Initialize shared variables or restore
            if options['reload_all']:
                saver.restore(sess, options['reload_file'])
                log.info('Shared variables restored')
            else:
                sess.run(init_op)
                log.info('Variables initialized')

            # Windows of the last 10 batch losses/accuracies, used for the
            # running means that are logged below.
            last_losses = np.zeros((10))
            last_accs = np.zeros((10))

            batch_abs_idx = 0
            D_to_G = options['D_to_G']
            total_D2G = sum(D_to_G)
            # First batch index of the alternating D/G phase.
            base = options['initial_G_iters'] + options['initial_D_iters']

            for epoch_idx in xrange(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))

                for inputs in train_provider:
                    if isinstance(inputs, tuple):
                        inputs = inputs[0]

                    batch_abs_idx += 1

                    # Pick this batch's update and the CSV column it is
                    # logged under (step,time,<vanilla>,<gen>,<disc>).
                    # NOTE(review): batch_abs_idx starts at 1, so the
                    # generator warm-up runs initial_G_iters - 1 batches --
                    # confirm this is intended.
                    if batch_abs_idx < options['initial_G_iters']:
                        backpass = gen_backpass
                        log_format_string = '{},{},{},,\n'
                    elif (batch_abs_idx < base or
                          (batch_abs_idx - base) % total_D2G < D_to_G[0]):
                        # Discriminator warm-up, then the first D_to_G[0]
                        # batches of every alternating cycle.
                        backpass = disc_backpass
                        log_format_string = '{},{},,,{}\n'
                    else:
                        backpass = gen_backpass
                        log_format_string = '{},{},,{},\n'

                    cost, _, batch_acc = sess.run(
                        [disc_CE, backpass, disc_accuracy],
                        feed_dict={
                            real_batch: inputs,
                            sampler_input_batch: sample_noise()
                        })

                    # Check cost (logged only; training continues)
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')

                    # Update the running windows before they are written out
                    # so each CSV row includes the current batch.
                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost
                    last_accs = np.roll(last_accs, 1)
                    last_accs[0] = batch_acc

                    if batch_abs_idx % 10 == 0:
                        train_log.write(
                            log_format_string.format(batch_abs_idx, log_time,
                                                     np.mean(last_losses)))
                        train_acc.write(
                            log_format_string.format(batch_abs_idx, log_time,
                                                     np.mean(last_accs)))
                        train_log.flush()
                        train_acc.flush()

                    # Display training information
                    # NOTE(review): this gates on the epoch index, so every
                    # batch of a matching epoch is logged; batch_abs_idx may
                    # have been intended -- confirm.
                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info(
                            'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                            .format(epoch_idx + 1, options['n_epochs'],
                                    batch_abs_idx, float(cost),
                                    np.mean(last_losses)))
                        log.info(
                            'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last accuracies: {:0>15.4f}'
                            .format(epoch_idx + 1, options['n_epochs'],
                                    batch_abs_idx, float(cost),
                                    np.mean(last_accs)))

                    # Save model
                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(
                            sess,
                            os.path.join(options['model_dir'],
                                         'model_at_%d.ckpt' % batch_abs_idx))
                        log.info('Model saved')

                    # Validate model
                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                        valid_costs = []
                        valid_accs = []
                        for val_batch in val_provider:
                            if isinstance(val_batch, tuple):
                                val_batch = val_batch[0]

                            val_cost, val_batch_acc = sess.run(
                                [disc_CE, disc_accuracy],
                                feed_dict={
                                    real_batch: val_batch,
                                    sampler_input_batch: sample_noise()
                                })
                            valid_costs.append(val_cost)
                            valid_accs.append(val_batch_acc)

                            if len(valid_costs) == options['valid_batches']:
                                break

                        # Print results
                        log.info('Validation loss: {:0>15.4f}'.format(
                            float(np.mean(valid_costs))))
                        log.info('Validation accuracies: {:0>15.4f}'.format(
                            float(np.mean(valid_accs))))

                        val_samples = sess.run(
                            sampler,
                            feed_dict={sampler_input_batch: sample_noise()})

                        # Validation rows go to the column of the current
                        # training phase.
                        val_log.write(
                            log_format_string.format(batch_abs_idx, log_time,
                                                     np.mean(valid_costs)))
                        val_acc.write(
                            log_format_string.format(batch_abs_idx, log_time,
                                                     np.mean(valid_accs)))
                        val_log.flush()
                        val_acc.flush()

                        image_shape = [options['batch_size']] + options['img_shape']
                        save_ae_samples(catalog,
                                        np.ones(image_shape),
                                        np.reshape(inputs, image_shape),
                                        np.reshape(val_samples, image_shape),
                                        batch_abs_idx,
                                        options['dashboard_dir'],
                                        num_to_save=5,
                                        save_gray=True)

                log.info('End of epoch {}'.format(epoch_idx + 1))
    finally:
        for handle in dashboard_files:
            handle.close()
示例#12
0
def train(options):
    """Train the vanilla VAE described by ``options`` and stream dashboard logs.

    Writes an ``options`` summary, a ``catalog`` and one CSV per metric into
    ``options['dashboard_dir']``; builds the model returned by
    ``cupboard(options['model'])``; then runs Adam on per-tensor norm-clipped
    gradients (max norm 5.0) for ``options['n_epochs']`` epochs.  Every
    ``freq_saving`` batches it checkpoints the session and pickles the layer
    parameters into ``options['model_dir']``; every ``freq_validation``
    batches it evaluates up to ``valid_batches`` validation batches and saves
    decoder samples drawn from a standard-normal latent prior.

    Args:
        options: dict of hyper-parameters and paths.  Keys read here include
            'model_dir', 'dashboard_dir', 'description', 'DKL_weight',
            'sigma_clip', 'model', 'p_layers', 'q_layers', 'img_shape',
            'latent_dims', 'batch_size', 'lr', 'reload', 'reload_file',
            'n_epochs', 'freq_logging', 'freq_saving', 'freq_validation'
            and 'valid_batches'.

    Returns:
        ``(1., 1., 1.)`` if a NaN/Inf training cost is detected (after
        printing diagnostics); otherwise ``None``.
    """
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    dash_dir = options['dashboard_dir']

    def dash_path(filename):
        # Path of a file inside the dashboard directory.
        return os.path.join(dash_dir, filename)

    def write_point(handle, step, value):
        # Append one "step,time,value" row; the time column is a fixed placeholder.
        handle.write('{},{},{}\n'.format(step, '2016-04-22', value))
        handle.flush()

    def layer_params(layer):
        # Snapshot a layer's config and current weights (needs the default session).
        return {
            'input_dim': layer.input_dim,
            'output_dim': layer.output_dim,
            'act_fn': layer.activation,
            'W': layer.weights['w'].eval(),
            'b': layer.weights['b'].eval(),
        }

    with open(dash_path('options'), 'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write(
            'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
                options['DKL_weight'], -options['sigma_clip'],
                options['sigma_clip']))
        for optn in options:
            options_file.write(optn + ':\t' + str(options[optn]) + '\n')

    # Dashboard catalog: one row per stream (filename, renderer type, title).
    # It stays open for the whole run because it is handed to save_ae_samples.
    catalog_text = '\n'.join([
        'filename,type,name',
        'options,plain,Options',
        'train_loss.csv,csv,Train Loss',
        'll.csv,csv,Neg. Log-Likelihood',
        'dec_log_sig_sq.csv,csv,Decoder Log Sigma^2',
        'dec_std_log_sig_sq.csv,csv,STD of Decoder Log Sigma^2',
        'dec_mean.csv,csv,Decoder Mean',
        'dkl.csv,csv,DKL',
        'enc_log_sig_sq.csv,csv,Encoder Log Sigma^2',
        'enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2',
        'enc_mean.csv,csv,Encoder Mean',
        'val_loss.csv,csv,Validation Loss',
    ]) + '\n'

    # (key, CSV filename, CSV column title) for every per-step metric stream.
    metric_specs = (
        ('train', 'train_loss.csv', 'Train Loss'),
        ('val', 'val_loss.csv', 'Validation Loss'),
        ('dkl', 'dkl.csv', 'DKL'),
        ('ll', 'll.csv', '-LL'),
        ('dec_sig', 'dec_log_sig_sq.csv', 'Decoder Log Sigma^2'),
        ('enc_sig', 'enc_log_sig_sq.csv', 'Encoder Log Sigma^2'),
        ('dec_std_sig', 'dec_std_log_sig_sq.csv', 'STD of Decoder Log Sigma^2'),
        ('enc_std_sig', 'enc_std_log_sig_sq.csv', 'STD of Encoder Log Sigma^2'),
        ('dec_mean', 'dec_mean.csv', 'Decoder Mean'),
        ('enc_mean', 'enc_mean.csv', 'Encoder Mean'),
    )

    catalog = open(dash_path('catalog'), 'w')
    metric_logs = {}
    try:
        catalog.write(catalog_text)
        catalog.flush()
        for key, filename, title in metric_specs:
            metric_logs[key] = open(dash_path(filename), 'w')
            metric_logs[key].write('step,time,{}\n'.format(title))

        utils.print_options(options, log)

        # Load dataset --------------------------------------------------------
        train_provider, val_provider, test_provider = get_providers(
            options, log, flat=True)

        # Initialize model ----------------------------------------------------
        with tf.device('/gpu:0'):
            model = cupboard(options['model'])(
                options['p_layers'],
                options['q_layers'],
                np.prod(options['img_shape']),
                options['latent_dims'],
                options['DKL_weight'],
                options['sigma_clip'],
                'vanilla_vae')
            log.info('Model initialized')

            model_input_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'],
                       np.prod(np.array(options['img_shape']))],
                name='enc_inputs')
            sampler_input_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'], options['latent_dims']],
                name='dec_inputs')
            log.info('Inputs defined')

            cost_function = model(model_input_batch)
            log.info('Forward pass graph built')

            sampler = model.build_sampler(sampler_input_batch)
            log.info('Sampler graph built')

            # Single update op: Adam applied to per-tensor norm-clipped
            # gradients.  Variables without a gradient (None) are skipped,
            # since tf.clip_by_norm cannot take None.
            optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])
            clip_grads = [
                (tf.clip_by_norm(grad, 5.0, name='grad_clipping'), var)
                for grad, var in optimizer.compute_gradients(cost_function)
                if grad is not None]
            backpass = optimizer.apply_gradients(clip_grads)
            log.info('Optimizer graph built')

            init_op = tf.initialize_all_variables()
            log.info('Variable initialization graph built')

        saver = tf.train.Saver()
        log.info('Save operation built')

        # Train loop ----------------------------------------------------------
        with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
            log.info('Session started')

            if options['reload']:
                saver.restore(sess, options['reload_file'])
                log.info('Shared variables restored')
            else:
                sess.run(init_op)
                log.info('Shared variables initialized')

            # Rolling window of the 10 most recent training costs.
            last_losses = np.zeros((10))

            batch_abs_idx = 0
            for epoch_idx in xrange(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))

                for inputs in train_provider:
                    # Providers may yield (images, labels); only images are used.
                    if isinstance(inputs, tuple):
                        inputs = inputs[0]

                    batch_abs_idx += 1

                    result = sess.run(
                        [cost_function, backpass, model.DKL, model.rec_loss,
                         model.dec_log_std_sq, model.enc_log_std_sq,
                         model.enc_mean, model.dec_mean],
                        feed_dict={model_input_batch: inputs})
                    (cost, _, dkl, rec_loss, dec_log_sig_sq, enc_log_sig_sq,
                     enc_mean, dec_mean) = result

                    if batch_abs_idx % 10 == 0:
                        write_point(metric_logs['train'], batch_abs_idx,
                                    np.mean(last_losses))
                        write_point(metric_logs['dkl'], batch_abs_idx,
                                    -np.mean(dkl))
                        write_point(metric_logs['ll'], batch_abs_idx,
                                    -np.mean(rec_loss))
                        write_point(metric_logs['dec_sig'], batch_abs_idx,
                                    np.mean(dec_log_sig_sq))
                        write_point(metric_logs['enc_sig'], batch_abs_idx,
                                    np.mean(enc_log_sig_sq))
                        write_point(metric_logs['dec_std_sig'], batch_abs_idx,
                                    np.std(dec_log_sig_sq))
                        write_point(metric_logs['enc_std_sig'], batch_abs_idx,
                                    np.std(enc_log_sig_sq))
                        write_point(metric_logs['dec_mean'], batch_abs_idx,
                                    np.mean(dec_mean))
                        write_point(metric_logs['enc_mean'], batch_abs_idx,
                                    np.mean(enc_mean))

                    # Abort on a non-finite cost, dumping every fetched value
                    # (the apply op fetches as None, hence the TypeError guard).
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')
                        for i, value in enumerate(result):
                            print("\n\nresult[%d]:" % i)
                            try:
                                print(np.any(np.isnan(value)))
                            except TypeError:
                                pass
                            print(value)
                        print(result[3].shape)
                        print(model._encoder.layers[0].weights['w'].eval())
                        print(inputs)
                        return 1., 1., 1.

                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost

                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'.format(
                            epoch_idx + 1,
                            options['n_epochs'],
                            batch_abs_idx,
                            float(cost),
                            np.mean(last_losses)))
                        log.info('Batch Mean LL: {:0>15.4f}'.format(
                            np.mean(rec_loss, axis=0)))
                        log.info('Batch Mean -DKL: {:0>15.4f}'.format(
                            np.mean(dkl, axis=0)))

                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(sess, os.path.join(
                            options['model_dir'],
                            'model_at_%d.ckpt' % batch_abs_idx))
                        log.info('Model saved')

                        encoder_layers = [layer_params(layer)
                                          for layer in model._encoder.layers]
                        decoder_layers = [layer_params(layer)
                                          for layer in model._decoder.layers]
                        save_dict = {
                            # Every hidden layer, in order.
                            'encoder_layers': encoder_layers,
                            'decoder_layers': decoder_layers,
                            'enc_mean': layer_params(model._enc_mean),
                            'enc_log_std_sq': layer_params(model._enc_log_std_sq),
                            'dec_mean': layer_params(model._dec_mean),
                            'dec_log_std_sq': layer_params(model._dec_log_std_sq),
                        }
                        # Legacy keys: earlier dumps held only the last hidden
                        # layer here; kept so existing readers still work.
                        if encoder_layers:
                            save_dict['encoder'] = encoder_layers[-1]
                        if decoder_layers:
                            save_dict['decoder'] = decoder_layers[-1]

                        with open(os.path.join(options['model_dir'],
                                               'vae_dict_%d' % batch_abs_idx),
                                  'wb') as dict_file:
                            pickle.dump(save_dict, dict_file)

                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                        valid_costs = []
                        for val_batch in val_provider:
                            if isinstance(val_batch, tuple):
                                val_batch = val_batch[0]

                            valid_costs.append(sess.run(
                                cost_function,
                                feed_dict={model_input_batch: val_batch}))

                            if len(valid_costs) == options['valid_batches']:
                                break

                        log.info('Validation loss: {:0>15.4f}'.format(
                            float(np.mean(valid_costs))))

                        # Decode latents drawn from N(0, I).
                        val_samples = sess.run(
                            sampler,
                            feed_dict={
                                sampler_input_batch: MVN(
                                    np.zeros(options['latent_dims']),
                                    np.diag(np.ones(options['latent_dims'])),
                                    size=options['batch_size'])
                            })

                        write_point(metric_logs['val'], batch_abs_idx,
                                    np.mean(valid_costs))

                        batch_img_shape = ([options['batch_size']] +
                                           options['img_shape'])
                        save_ae_samples(
                            catalog,
                            np.reshape(dec_mean, batch_img_shape),
                            np.reshape(inputs, batch_img_shape),
                            np.reshape(val_samples, batch_img_shape),
                            batch_abs_idx,
                            options['dashboard_dir'],
                            num_to_save=5,
                            save_gray=True)

                log.info('End of epoch {}'.format(epoch_idx + 1))
    finally:
        for handle in metric_logs.values():
            handle.close()
        catalog.close()
示例#13
0
def train(options):
    """Train a (denoising) autoencoder and stream losses to a dashboard.

    The model is ``cupboard(options['model'])``.  With
    ``options['model'] == 'cnn_ae'`` it is built on image-shaped inputs;
    otherwise on flattened inputs of size ``prod(img_shape) * input_channels``.
    Each step feeds the clean batch together with a copy corrupted by additive
    Gaussian noise (std ``options['noise_std']``) and applies Adam to
    per-tensor norm-clipped gradients (max norm 5.0).  Every ``freq_saving``
    batches it checkpoints the session and pickles the encoder weights (dense
    model only); every ``freq_validation`` batches it evaluates up to
    ``valid_batches`` validation batches and saves clean/noisy/reconstructed
    samples.

    Args:
        options: dict of hyper-parameters and paths.  Keys read here include
            'model_dir', 'dashboard_dir', 'model', 'img_shape',
            'input_channels', 'enc_params', 'dec_params', 'batch_size', 'lr',
            'reload', optionally 'reload_file', 'n_epochs', 'noise_std',
            'freq_logging', 'freq_saving', 'freq_validation' and
            'valid_batches'.

    Returns:
        ``(1., 1., 1.)`` if a NaN/Inf training cost is detected (after
        printing diagnostics); otherwise ``None``.
    """
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    dash_dir = options['dashboard_dir']
    is_cnn = options['model'] == 'cnn_ae'

    def write_point(handle, step, value):
        # Append one "step,time,value" row; the time column is a fixed placeholder.
        handle.write('{},{},{}\n'.format(step, '2016-04-22', value))
        handle.flush()

    with open(os.path.join(dash_dir, 'options'), 'w') as options_file:
        for optn in options:
            options_file.write(optn + ':\t' + str(options[optn]) + '\n')

    # Dashboard catalog; kept open because it is handed to save_ae_samples.
    catalog = open(os.path.join(dash_dir, 'catalog'), 'w')
    train_log = None
    val_log = None
    try:
        catalog.write('filename,type,name\n'
                      'options,plain,Options\n'
                      'train_loss.csv,csv,Train Loss\n'
                      'val_loss.csv,csv,Validation Loss\n')
        catalog.flush()
        train_log = open(os.path.join(dash_dir, 'train_loss.csv'), 'w')
        val_log = open(os.path.join(dash_dir, 'val_loss.csv'), 'w')
        train_log.write('step,time,Train Loss\n')
        val_log.write('step,time,Validation Loss\n')

        utils.print_options(options, log)

        # Load dataset --------------------------------------------------------
        train_provider, val_provider, test_provider = get_providers(options, log)

        # Initialize model ----------------------------------------------------
        with tf.device('/gpu:0'):
            if is_cnn:
                model = cupboard(options['model'])(
                    options['img_shape'],
                    options['input_channels'],
                    options['enc_params'],
                    options['dec_params'],
                    'cnn_ae')
                input_shape = ([options['batch_size']] + options['img_shape'] +
                               [options['input_channels']])
            else:
                flat_dim = (np.prod(options['img_shape']) *
                            options['input_channels'])
                model = cupboard(options['model'])(
                    flat_dim, options['enc_params'], options['dec_params'], 'ae')
                input_shape = [options['batch_size'], flat_dim]

            model_clean_input_batch = tf.placeholder(
                tf.float32, shape=input_shape, name='clean')
            model_noisy_input_batch = tf.placeholder(
                tf.float32, shape=input_shape, name='noisy')
            log.info('Inputs defined')
            log.info('Model initialized')

            log.info('Clean input shape: {}'.format(
                model_clean_input_batch.get_shape()))
            log.info('Noisy input shape: {}'.format(
                model_noisy_input_batch.get_shape()))
            cost_function = model(model_clean_input_batch,
                                  model_noisy_input_batch)
            log.info('Forward pass graph built')

            optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])
            log.info('Optimizer graph built')

            # Variables that do not affect the cost get a None gradient;
            # drop them before clipping.
            grads = [(grad, var)
                     for grad, var in optimizer.compute_gradients(cost_function)
                     if grad is not None]
            grad_tensors = [grad for grad, _ in grads]
            clip_grads = [
                (tf.clip_by_norm(grad, 5.0, name='grad_clipping'), var)
                for grad, var in grads]
            backpass = optimizer.apply_gradients(clip_grads)

            init_op = tf.initialize_all_variables()
            log.info('Variable initialization graph built')

        saver = tf.train.Saver()
        log.info('Save operation built')

        # Train loop ----------------------------------------------------------
        with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
            log.info('Session started')

            if options['reload']:
                # Checkpoints are written as model_at_<step>.ckpt below, so an
                # explicit 'reload_file' is preferred; model.ckpt is the legacy
                # fallback.
                reload_file = (options.get('reload_file') or
                               os.path.join(options['model_dir'], 'model.ckpt'))
                saver.restore(sess, reload_file)
                log.info('Shared variables restored')
            else:
                sess.run(init_op)
                log.info('Shared variables initialized')

            # Rolling window of the 10 most recent training costs.
            last_losses = np.zeros((10))
            noise_std = np.float32(options['noise_std'])

            batch_abs_idx = 0
            for epoch_idx in xrange(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))

                for inputs, _ in train_provider:
                    batch_abs_idx += 1

                    noisy_inputs = np.float32(inputs) + normal(
                        loc=0.0, scale=noise_std, size=inputs.shape)
                    result = sess.run(
                        [cost_function, backpass] + grad_tensors,
                        feed_dict={
                            model_clean_input_batch: inputs,
                            model_noisy_input_batch: noisy_inputs
                        })
                    cost = result[0]

                    if batch_abs_idx % 10 == 0:
                        write_point(train_log, batch_abs_idx,
                                    np.mean(last_losses))

                    # Abort on a non-finite cost, dumping the cost and every
                    # gradient (the apply op fetches as None, hence the guard).
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')
                        for i, value in enumerate(result):
                            print("\n\nresult[%d]:" % i)
                            try:
                                print(np.any(np.isnan(value)))
                            except TypeError:
                                pass
                            print(value)
                        if not is_cnn:
                            print(model._encoder.layers[0].weights['w'].eval())
                        print(inputs)
                        return 1., 1., 1.

                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost

                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info(
                            'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                            .format(epoch_idx + 1, options['n_epochs'],
                                    batch_abs_idx, float(cost),
                                    np.mean(last_losses)))

                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(
                            sess,
                            os.path.join(options['model_dir'],
                                         'model_at_%d.ckpt' % batch_abs_idx))
                        log.info('Model saved')

                        # Encoder params (dense model only; left empty for cnn_ae).
                        save_dict = {
                            'enc_W': [],
                            'enc_b': [],
                            'enc_act_fn': [],
                        }
                        if not is_cnn:
                            for i, layer in enumerate(model._encoder.layers):
                                save_dict['enc_W'].append(layer.weights['w'].eval())
                                save_dict['enc_b'].append(layer.weights['b'].eval())
                                save_dict['enc_act_fn'].append(
                                    options['enc_params']['act_fn'][i])

                        with open(os.path.join(options['model_dir'],
                                               'enc_dict_%d' % batch_abs_idx),
                                  'wb') as dict_file:
                            pickle.dump(save_dict, dict_file)

                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                        valid_costs = []
                        for val_batch, _ in val_provider:
                            noisy_val_batch = val_batch + normal(
                                loc=0.0, scale=noise_std, size=val_batch.shape)

                            val_results = sess.run(
                                (cost_function, model.decoder),
                                feed_dict={
                                    model_clean_input_batch: val_batch,
                                    model_noisy_input_batch: noisy_val_batch
                                })
                            valid_costs.append(val_results[0])

                            if len(valid_costs) == options['valid_batches']:
                                break

                        if not valid_costs:
                            # No batch to report or visualise; validation is
                            # the last step of this iteration.
                            log.info('Validation skipped: no validation batches')
                            continue

                        log.info('Validation loss: {:0>15.4f}'.format(
                            float(np.mean(valid_costs))))
                        write_point(val_log, batch_abs_idx, np.mean(valid_costs))

                        # Samples from the last validation batch, as images.
                        if is_cnn:
                            val_recon = np.reshape(val_results[-1],
                                                   val_batch.shape)
                        else:
                            img_batch_shape = ([val_batch.shape[0]] +
                                               options['img_shape'] +
                                               [options['input_channels']])
                            val_batch = np.reshape(val_batch, img_batch_shape)
                            noisy_val_batch = np.reshape(noisy_val_batch,
                                                         img_batch_shape)
                            val_recon = np.reshape(val_results[-1],
                                                   img_batch_shape)

                        save_ae_samples(catalog,
                                        val_batch,
                                        noisy_val_batch,
                                        val_recon,
                                        batch_abs_idx,
                                        options['dashboard_dir'],
                                        num_to_save=5,
                                        save_gray=True)

                log.info('End of epoch {}'.format(epoch_idx + 1))
    finally:
        for handle in (train_log, val_log):
            if handle is not None:
                handle.close()
        catalog.close()
def train(options):
    """Train a VAE whose inputs first pass through a fixed FC feature extractor.

    Writes an ``options`` dump, a dashboard ``catalog`` and per-metric CSV logs
    (train/validation loss, DKL, -LL, encoder/decoder log sigma^2 statistics and
    means) into ``options['dashboard_dir']``. Checkpoints and ``save_samples``
    dumps go under ``options['model_dir']``.

    When ``options['reload']`` is truthy, the checkpoint ``options['reload_file']``
    is restored, ``visualize`` is run and the function returns without training.

    Args:
        options: dict of run configuration. Keys read here include
            'model_dir', 'dashboard_dir', 'description', 'DKL_weight',
            'sigma_clip', 'data_dir', 'batch_size', 'file_extension',
            'extension', 'img_shape', 'input_channels', 'feat_type',
            'feat_params_path', 'num_feat_layers', 'p_layers', 'q_layers',
            'latent_dims', 'lr', 'reload', 'reload_file', 'n_epochs',
            'freq_logging', 'freq_saving', 'freq_validation', 'valid_batches'.

    Returns:
        ``(1., 1., 1.)`` if a NaN/Inf training cost is detected, else None.

    Raises:
        ValueError: if ``options['feat_type']`` is not ``'fc'``; no other
            feature-extractor type is implemented here.
    """
    # Fail fast: previously a non-'fc' feat_type fell through an `else: pass`
    # and crashed later with an UnboundLocalError on `feat_model`, after the
    # data providers had already been built.
    if options['feat_type'] != 'fc':
        raise ValueError(
            "Unsupported options['feat_type'] = {!r}; only 'fc' is implemented".format(
                options['feat_type']
            )
        )

    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    dashboard_dir = options['dashboard_dir']

    # Human-readable dump of the run configuration for the dashboard.
    with open(os.path.join(dashboard_dir, 'options'), 'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write(
            'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
                options['DKL_weight'],
                -options['sigma_clip'],
                options['sigma_clip']
            )
        )
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Every dashboard handle is tracked here and closed in the `finally`
    # clause below, including on the early returns (reload / NaN).
    open_files = []

    def _open_dashboard(filename):
        """Open ``dashboard_dir/filename`` for writing and track it for closing."""
        handle = open(os.path.join(dashboard_dir, filename), 'w')
        open_files.append(handle)
        return handle

    # Fixed placeholder written to the `time` column of every CSV row.
    time_stamp = '2016-04-22'

    try:
        # Dashboard Catalog: index of the files the dashboard should display.
        catalog = _open_dashboard('catalog')
        catalog.write(
"""filename,type,name
options,plain,Options
train_loss.csv,csv,Train Loss
ll.csv,csv,Neg. Log-Likelihood
dec_log_sig_sq.csv,csv,Decoder Log Simga^2
dec_std_log_sig_sq.csv,csv,STD of Decoder Log Simga^2
dec_mean.csv,csv,Decoder Mean
dkl.csv,csv,DKL
enc_log_sig_sq.csv,csv,Encoder Log Sigma^2
enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2
enc_mean.csv,csv,Encoder Mean
val_loss.csv,csv,Validation Loss
"""
        )
        catalog.flush()
        train_log = _open_dashboard('train_loss.csv')
        val_log = _open_dashboard('val_loss.csv')
        dkl_log = _open_dashboard('dkl.csv')
        ll_log = _open_dashboard('ll.csv')

        dec_sig_log = _open_dashboard('dec_log_sig_sq.csv')
        enc_sig_log = _open_dashboard('enc_log_sig_sq.csv')

        dec_std_sig_log = _open_dashboard('dec_std_log_sig_sq.csv')
        enc_std_sig_log = _open_dashboard('enc_std_log_sig_sq.csv')

        dec_mean_log = _open_dashboard('dec_mean.csv')
        enc_mean_log = _open_dashboard('enc_mean.csv')

        # CSV headers
        train_log.write('step,time,Train Loss\n')
        val_log.write('step,time,Validation Loss\n')
        dkl_log.write('step,time,DKL\n')
        ll_log.write('step,time,-LL\n')

        dec_sig_log.write('step,time,Decoder Log Sigma^2\n')
        enc_sig_log.write('step,time,Encoder Log Sigma^2\n')

        dec_std_sig_log.write('step,time,STD of Decoder Log Sigma^2\n')
        enc_std_sig_log.write('step,time,STD of Encoder Log Sigma^2\n')

        dec_mean_log.write('step,time,Decoder Mean\n')
        enc_mean_log.write('step,time,Encoder Mean\n')

        # Print options
        utils.print_options(options, log)

        # Load dataset ------------------------------------------------------
        if options['data_dir'] != 'MNIST':
            # NOTE(review): two directory entries are excluded from each count;
            # presumably non-patch files inside the patches dir -- confirm.
            train_dir = os.path.join(options['data_dir'], 'train', 'patches')
            train_provider = DataProvider(
                len(os.listdir(train_dir)) - 2,
                options['batch_size'],
                toolbox.ImageLoader(
                    data_dir=train_dir,
                    flat=True,
                    extension=options['file_extension']
                )
            )

            valid_dir = os.path.join(options['data_dir'], 'valid', 'patches')
            val_provider = DataProvider(
                len(os.listdir(valid_dir)) - 2,
                options['batch_size'],
                toolbox.ImageLoader(
                    data_dir=valid_dir,
                    flat=True,
                    extension=options['file_extension']
                )
            )
        else:
            train_provider = DataProvider(
                55000,
                options['batch_size'],
                toolbox.MNISTLoader(
                    mode='train',
                    flat=True
                )
            )

            val_provider = DataProvider(
                5000,
                options['batch_size'],
                toolbox.MNISTLoader(
                    mode='validation',
                    flat=True
                )
            )

        log.info('Data providers initialized.')

        # Initialize model --------------------------------------------------
        with tf.device('/gpu:0'):
            # Feature extractor: ConstFC layers built from pickled weights.
            feat_model = Sequential('feat_extractor')
            # NOTE: pickle.load can execute arbitrary code -- feat_params_path
            # must point at a trusted file.
            with open(options['feat_params_path'], 'rb') as params_file:
                feat_params = pickle.load(params_file)
            for i in range(options['num_feat_layers']):
                feat_model += ConstFC(
                    feat_params['enc_W'][i],
                    feat_params['enc_b'][i],
                    activation=feat_params['enc_act_fn'][i],
                    name='feat_layer_%d' % i
                )

            # VAE model
            vae_model = cupboard('vanilla_vae')(
                options['p_layers'],
                options['q_layers'],
                np.prod(options['img_shape']),
                options['latent_dims'],
                options['DKL_weight'],
                options['sigma_clip'],
                'vanilla_vae'
            )
            feat_vae = cupboard('feat_vae')(
                vae_model,
                feat_model,
                options['DKL_weight'],
                0.0,
                img_shape=options['img_shape'],
                input_channels=options['input_channels'],
                flat=True,
                name='feat_vae_model'
            )

            log.info('Model initialized')

            # Define inputs
            model_input_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'], np.prod(np.array(options['img_shape']))],
                name='enc_inputs'
            )
            sampler_input_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'], options['latent_dims']],
                name='dec_inputs'
            )
            log.info('Inputs defined')

            # Define forward pass
            cost_function = feat_vae(model_input_batch)
            log.info('Forward pass graph built')

            # Define sampler
            sampler = feat_vae.build_sampler(sampler_input_batch)
            log.info('Sampler graph built')

            # Define optimizer
            optimizer = tf.train.AdamOptimizer(
                learning_rate=options['lr']
            )
            log.info('Optimizer graph built')

            # Drop variables the cost does not depend on (their gradient is None).
            grads = [
                gv for gv in optimizer.compute_gradients(cost_function)
                if gv[0] is not None
            ]
            grad_tensors = [grad for grad, _ in grads]

            # Clip each gradient to norm 5 before applying it.
            clip_grads = [
                (tf.clip_by_norm(grad, 5.0, name='grad_clipping'), var)
                for grad, var in grads
            ]

            # Update op
            backpass = optimizer.apply_gradients(clip_grads)

            # Define init operation
            init_op = tf.initialize_all_variables()
            log.info('Variable initialization graph built')

        # Define op to save and restore variables
        saver = tf.train.Saver()
        log.info('Save operation built')

        # Train loop --------------------------------------------------------
        with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
            log.info('Session started')

            # Restore and visualize only, or initialize and train.
            if options['reload']:
                saver.restore(sess, options['reload_file'])
                log.info('Shared variables restored')
                if options['data_dir'] == 'MNIST':
                    mean_img = np.zeros(np.prod(options['img_shape']))
                    std_img = np.ones(np.prod(options['img_shape']))
                else:
                    # NOTE(review): reads options['extension'] while the image
                    # loaders above use options['file_extension'] -- confirm
                    # both keys are set.
                    mean_img = np.load(os.path.join(options['data_dir'], 'mean' + options['extension']))
                    std_img = np.load(os.path.join(options['data_dir'], 'std' + options['extension']))
                visualize(feat_vae.vae.sampler_mean, sess, feat_vae.vae.dec_mean, feat_vae.vae.dec_log_std_sq, sampler, sampler_input_batch,
                          model_input_batch, feat_vae.vae.enc_mean, feat_vae.vae.enc_log_std_sq,
                          train_provider, val_provider, options, catalog, mean_img, std_img)
                return

            sess.run(init_op)
            log.info('Shared variables initialized')

            # Window of the last 10 training costs, for a running average.
            last_losses = np.zeros((10))

            batch_abs_idx = 0
            for epoch_idx in xrange(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))

                for inputs in train_provider:
                    batch_abs_idx += 1

                    # Fetch layout: 0 cost, 1 update op (fetches as None),
                    # 2 DKL, 3 rec_loss, 4 dec_log_std_sq, 5 enc_log_std_sq,
                    # 6 enc_mean, 7 dec_mean, 8.. one value per grad_tensors entry.
                    result = sess.run(
                        [cost_function, backpass, feat_vae.vae.DKL, feat_vae.vae.rec_loss,
                         feat_vae.vae.dec_log_std_sq, feat_vae.vae.enc_log_std_sq,
                         feat_vae.vae.enc_mean, feat_vae.vae.dec_mean] + grad_tensors,
                        feed_dict={
                            model_input_batch: inputs
                        }
                    )

                    cost = result[0]

                    # Dashboard rows every 10 batches. The loss column uses the
                    # running window *before* this batch's cost is added.
                    if batch_abs_idx % 10 == 0:
                        dashboard_rows = [
                            (train_log, np.mean(last_losses)),
                            (dkl_log, -np.mean(result[2])),
                            (ll_log, -np.mean(result[3])),
                            (dec_sig_log, np.mean(result[4])),
                            (enc_sig_log, np.mean(result[5])),
                            (dec_std_sig_log, np.std(result[4])),
                            (enc_std_sig_log, np.std(result[5])),
                            (dec_mean_log, np.mean(result[7])),
                            (enc_mean_log, np.mean(result[6])),
                        ]
                        for handle, value in dashboard_rows:
                            handle.write('{},{},{}\n'.format(batch_abs_idx, time_stamp, value))
                            handle.flush()

                    # Check cost. Previously this branch referenced the undefined
                    # name `model` (NameError) and indexed result[8..10]
                    # unconditionally, so it never reached its return.
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')
                        for i, value in enumerate(result):
                            print("\n\nresult[%d]:" % i)
                            try:
                                print(np.any(np.isnan(value)))
                            except (TypeError, ValueError):
                                # e.g. the update op, which fetches as None
                                pass
                            print(value)
                        print(result[3].shape)
                        print('\n\nAny:')
                        # result[8:] are the gradient values.
                        for grad_value in result[8:]:
                            print(np.any(np.isnan(grad_value)))
                        print(inputs)
                        return 1., 1., 1.

                    # Update last losses
                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost

                    # Display training information.
                    # NOTE(review): keyed on epoch_idx, so every batch is logged in
                    # epochs that are multiples of freq_logging -- confirm intent.
                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'.format(
                            epoch_idx + 1,
                            options['n_epochs'],
                            batch_abs_idx,
                            float(cost),
                            np.mean(last_losses)
                        ))
                        log.info('Batch Mean LL: {:0>15.4f}'.format(np.mean(result[3], axis=0)))
                        log.info('Batch Mean -DKL: {:0>15.4f}'.format(np.mean(result[2], axis=0)))

                    # Save model
                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(sess, os.path.join(options['model_dir'], 'model_at_%d.ckpt' % batch_abs_idx))
                        log.info('Model saved')

                    # Validate model on at most options['valid_batches'] batches.
                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                        valid_costs = []
                        for val_batch in val_provider:
                            valid_costs.append(sess.run(
                                cost_function,
                                feed_dict={
                                    model_input_batch: val_batch
                                }
                            ))
                            if len(valid_costs) == options['valid_batches']:
                                break

                        # Print results
                        log.info('Validation loss: {:0>15.4f}'.format(
                            float(np.mean(valid_costs))
                        ))

                        # Decode a batch of standard-normal latent samples.
                        val_samples = sess.run(
                            sampler,
                            feed_dict={
                                sampler_input_batch: MVN(
                                    np.zeros(options['latent_dims']),
                                    np.diag(np.ones(options['latent_dims'])),
                                    size=options['batch_size']
                                )
                            }
                        )

                        val_log.write('{},{},{}\n'.format(batch_abs_idx, time_stamp, np.mean(valid_costs)))
                        val_log.flush()

                        save_ae_samples(
                            catalog,
                            np.reshape(result[7], [options['batch_size']] + options['img_shape']),
                            np.reshape(inputs, [options['batch_size']] + options['img_shape']),
                            np.reshape(val_samples, [options['batch_size']] + options['img_shape']),
                            batch_abs_idx,
                            options['dashboard_dir'],
                            num_to_save=5,
                            save_gray=True
                        )

                        save_samples(
                            val_samples,
                            int(batch_abs_idx / options['freq_validation']),
                            os.path.join(options['model_dir'], 'valid_samples'),
                            True,
                            options['img_shape'],
                            5
                        )

                        save_samples(
                            inputs,
                            int(batch_abs_idx / options['freq_validation']),
                            os.path.join(options['model_dir'], 'input_sanity'),
                            True,
                            options['img_shape'],
                            num_to_save=5
                        )

                        save_samples(
                            result[7],
                            int(batch_abs_idx / options['freq_validation']),
                            os.path.join(options['model_dir'], 'rec_sanity'),
                            True,
                            options['img_shape'],
                            num_to_save=5
                        )

                log.info('End of epoch {}'.format(epoch_idx + 1))
    finally:
        # Close every dashboard file, whichever way the function exits.
        for handle in open_files:
            handle.close()
# Example #15
# 0
def train(options):
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    options_file = open(os.path.join(options['dashboard_dir'], 'options'), 'w')
    options_file.write(options['description'] + '\n')
    options_file.write(
        'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
            options['DKL_weight'],
            -options['sigma_clip'],
            options['sigma_clip']
        )
    )
    for optn in options:
        options_file.write(optn)
        options_file.write(':\t')
        options_file.write(str(options[optn]))
        options_file.write('\n')
    options_file.close()

    # Dashboard Catalog
    catalog = open(os.path.join(options['dashboard_dir'], 'catalog'), 'w')
    catalog.write(
"""filename,type,name
options,plain,Options
train_loss.csv,csv,Train Loss
ll.csv,csv,Neg. Log-Likelihood
dec_log_sig_sq.csv,csv,Decoder Log Simga^2
dec_std_log_sig_sq.csv,csv,STD of Decoder Log Simga^2
dec_mean.csv,csv,Decoder Mean
dkl.csv,csv,DKL
enc_log_sig_sq.csv,csv,Encoder Log Sigma^2
enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2
enc_mean.csv,csv,Encoder Mean
val_loss.csv,csv,Validation Loss
"""
    )
    catalog.flush()
    train_log = open(os.path.join(options['dashboard_dir'], 'train_loss.csv'), 'w')
    val_log = open(os.path.join(options['dashboard_dir'], 'val_loss.csv'), 'w')
    dkl_log = open(os.path.join(options['dashboard_dir'], 'dkl.csv'), 'w')
    ll_log = open(os.path.join(options['dashboard_dir'], 'll.csv'), 'w')

    dec_sig_log = open(os.path.join(options['dashboard_dir'], 'dec_log_sig_sq.csv'), 'w')
    enc_sig_log = open(os.path.join(options['dashboard_dir'], 'enc_log_sig_sq.csv'), 'w')

    dec_std_sig_log = open(os.path.join(options['dashboard_dir'], 'dec_std_log_sig_sq.csv'), 'w')
    enc_std_sig_log = open(os.path.join(options['dashboard_dir'], 'enc_std_log_sig_sq.csv'), 'w')

    dec_mean_log = open(os.path.join(options['dashboard_dir'], 'dec_mean.csv'), 'w')
    enc_mean_log = open(os.path.join(options['dashboard_dir'], 'enc_mean.csv'), 'w')
    # val_sig_log = open(os.path.join(options['dashboard_dir'], 'val_log_sig_sq.csv'), 'w')

    train_log.write('step,time,Train Loss\n')
    val_log.write('step,time,Validation Loss\n')
    dkl_log.write('step,time,DKL\n')
    ll_log.write('step,time,-LL\n')

    dec_sig_log.write('step,time,Decoder Log Sigma^2\n')
    enc_sig_log.write('step,time,Encoder Log Sigma^2\n')

    dec_std_sig_log.write('step,time,STD of Decoder Log Sigma^2\n')
    enc_std_sig_log.write('step,time,STD of Encoder Log Sigma^2\n')

    dec_mean_log.write('step,time,Decoder Mean\n')
    enc_mean_log.write('step,time,Encoder Mean\n')

    # Print options
    utils.print_options(options, log)

    # Load dataset ----------------------------------------------------------------------
    # Train provider
    train_provider, val_provider, test_provider = get_providers(options, log, flat=True)

    # Initialize model ------------------------------------------------------------------
    with tf.device('/gpu:0'):

        # Define inputs ----------------------------------------------------------
        model_input_batch = tf.placeholder(
            tf.float32,
            shape = [options['batch_size'], np.prod(np.array(options['img_shape']))],
            name = 'enc_inputs'
        )
        sampler_input_batch = tf.placeholder(
            tf.float32,
            shape = [options['batch_size'], options['latent_dims']],
            name = 'dec_inputs'
        )
        log.info('Inputs defined')

        # Feature Extractor -----------------------------------------------------
        feat_layers = []
        feat_params = pickle.load(open(options['feat_params_path'], 'rb'))
        _classifier = Sequential('CNN_Classifier')

        conv_count, pool_count, fc_count = 0, 0, 0
        for lay in feat_params:
            print(lay['layer_type'])
        for i in xrange(options['num_feat_layers']):
            if feat_params[i]['layer_type'] == 'conv':
                _classifier += ConvLayer(
                    feat_params[i]['n_filters_in'],
                    feat_params[i]['n_filters_out'],
                    feat_params[i]['input_dim'],
                    feat_params[i]['filter_dim'],
                    feat_params[i]['strides'],
                    name='classifier_conv_%d' % conv_count
                )
                _classifier.layers[-1].weights['W'] = tf.constant(feat_params[i]['W'])
                _classifier.layers[-1].weights['b'] = tf.constant(feat_params[i]['b'])
                _classifier += feat_params[i]['act_fn']
                conv_count += 1
            elif feat_params[i]['layer_type'] == 'pool':
                _classifier += PoolLayer(
                    feat_params[i]['input_dim'],
                    feat_params[i]['filter_dim'],
                    feat_params[i]['strides'],
                    name='classifier_pool_%d' % i
                )
                pool_count += 1
                feat_layers.append(i)
            elif feat_params[i]['layer_type'] == 'fc':
                _classifier += ConstFC(
                    feat_params[i]['W'],
                    feat_params[i]['b'],
                    activation=feat_params[i]['act_fn'],
                    name='classifier_fc_%d' % fc_count
                )
                fc_count += 1
                feat_layers.append(i)

        # if options['feat_type'] == 'fc':
        #     feat_model = Sequential('feat_extractor')
        #     feat_params = pickle.load(open(options['feat_params_path'], 'rb'))
        #     for i in range(options['num_feat_layers']):
        #         feat_model += ConstFC(
        #             feat_params['enc_W'][i],
        #             feat_params['enc_b'][i],
        #             activation=feat_params['enc_act_fn'][i],
        #             name='feat_layer_%d'%i
        #         )
        # else:
        #     pass

        # VAE -------------------------------------------------------------------
        # VAE model
        vae_model = cupboard('vanilla_vae')(
            options['p_layers'],
            options['q_layers'],
            np.prod(options['img_shape']),
            options['latent_dims'],
            options['DKL_weight'],
            options['sigma_clip'],
            'vanilla_vae'
        )
        # -----------------------------------------------------------------------
        feat_vae = cupboard('feat_vae')(
            vae_model,
            _classifier,
            feat_layers,
            options['DKL_weight'],
            options['vae_rec_loss_weight'],
            img_shape=options['img_shape'],
            input_channels=options['input_channels'],
            flat=False, 
            name='feat_vae_model'
        )

        log.info('Model initialized')

        # Define forward pass
        cost_function = feat_vae(model_input_batch)
        log.info('Forward pass graph built')

        # Define sampler
        sampler = feat_vae.build_sampler(sampler_input_batch)
        log.info('Sampler graph built')

        # Define optimizer
        optimizer = tf.train.AdamOptimizer(
            learning_rate=options['lr']
        )
        # optimizer = tf.train.GradientDescentOptimizer(learning_rate=options['lr'])
        
        # train_step = optimizer.minimize(cost_function)
        log.info('Optimizer graph built')

        # Get gradients
        grads = optimizer.compute_gradients(cost_function)
        grads = [gv for gv in grads if gv[0] != None]
        grad_tensors = [gv[0] for gv in grads]

        # Clip gradients
        clip_grads = [(tf.clip_by_norm(gv[0], 5.0, name='grad_clipping'), gv[1]) for gv in grads]

        # Update op
        backpass = optimizer.apply_gradients(clip_grads)

        # Define init operation
        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')
    # --------------------------------------------------------------------------

    # Train loop ---------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        # Initialize shared variables or restore
        if options['reload']:
            saver.restore(sess, options['reload_file'])
            log.info('Shared variables restored')

            test_LL_and_DKL(sess, test_provider, feat_vae.vae.DKL, feat_vae.vae.rec_loss, options, model_input_batch)
            return

            mean_img = np.load(os.path.join(options['data_dir'], 'mean' + options['extension']))
            std_img = np.load(os.path.join(options['data_dir'], 'std' + options['extension']))
            visualize(sess, feat_vae.vae.dec_mean, feat_vae.vae.dec_log_std_sq, sampler, sampler_input_batch,
                        model_input_batch, feat_vae.vae.enc_mean, feat_vae.vae.enc_log_std_sq,
                        train_provider, val_provider, options, catalog, mean_img, std_img)
            return
        else:
            sess.run(init_op)
            log.info('Shared variables initialized')

        # Define last losses to compute a running average
        last_losses = np.zeros((10))

        batch_abs_idx = 0
        for epoch_idx in xrange(options['n_epochs']):
            batch_rel_idx = 0
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs,_ in train_provider:
                batch_abs_idx += 1
                batch_rel_idx += 1

                result = sess.run(
                    # (cost_function, train_step, model.enc_std, model.enc_mean, model.encoder, model.dec_std, model.dec_mean, model.decoder, model.rec_loss, model.DKL),
                    #       0           1               2           3               4               5               6              7               8            9           10
                    [cost_function, backpass, feat_vae.vae.DKL, feat_vae.vae.rec_loss, feat_vae.vae.dec_log_std_sq, feat_vae.vae.enc_log_std_sq, feat_vae.vae.enc_mean, feat_vae.vae.dec_mean] + [gv[0] for gv in grads],
                    feed_dict = {
                        model_input_batch: inputs
                    }
                )

                cost = result[0]

                if batch_abs_idx % 10 == 0:
                    train_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(last_losses)))
                    dkl_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', -np.mean(result[2])))
                    ll_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', -np.mean(result[3])))

                    train_log.flush()
                    dkl_log.flush()
                    ll_log.flush()

                    dec_sig_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(result[4])))
                    enc_sig_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(result[5])))
                    # val_sig_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(result[6])))

                    dec_sig_log.flush()
                    enc_sig_log.flush()

                    dec_std_sig_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.std(result[4])))
                    enc_std_sig_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.std(result[5])))

                    dec_mean_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(result[7])))
                    enc_mean_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(result[6])))

                    dec_std_sig_log.flush()
                    enc_std_sig_log.flush()

                    dec_mean_log.flush()
                    enc_mean_log.flush()                    
                    # val_sig_log.flush()

                # Check cost
                if np.isnan(cost) or np.isinf(cost):
                    log.info('NaN detected')
                    for i in range(len(result)):
                        print("\n\nresult[%d]:" % i)
                        try:
                            print(np.any(np.isnan(result[i])))
                        except:
                            pass
                        print(result[i])
                    print(result[3].shape)
                    print(model._encoder.layers[0].weights['w'].eval())
                    print('\n\nAny:')
                    print(np.any(np.isnan(result[8])))
                    print(np.any(np.isnan(result[9])))
                    print(np.any(np.isnan(result[10])))
                    print(inputs)
                    return 1., 1., 1.

                # Update last losses
                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost

                # Display training information
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'.format(
                        epoch_idx + 1,
                        options['n_epochs'],
                        batch_abs_idx,
                        float(cost),
                        np.mean(last_losses)
                    ))
                    log.info('Batch Mean LL: {:0>15.4f}'.format(np.mean(result[3], axis=0)))
                    log.info('Batch Mean -DKL: {:0>15.4f}'.format(np.mean(result[2], axis=0)))

                # Save model
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(sess, os.path.join(options['model_dir'], 'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                # Validate model
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:

                    valid_costs = []
                    seen_batches = 0
                    for val_batch,_ in val_provider:

                        val_cost = sess.run(
                            cost_function,
                            feed_dict = {
                                model_input_batch: val_batch
                            }
                        )
                        valid_costs.append(val_cost)
                        seen_batches += 1

                        if seen_batches == options['valid_batches']:
                            break

                    # Print results
                    log.info('Validation loss: {:0>15.4f}'.format(
                        float(np.mean(valid_costs))
                    ))

                    val_samples = sess.run(
                        sampler,
                        feed_dict = {
                            sampler_input_batch: MVN(
                                np.zeros(options['latent_dims']),
                                np.diag(np.ones(options['latent_dims'])),
                                size = options['batch_size']
                            )
                        }
                    )

                    val_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(valid_costs)))
                    val_log.flush()

                    save_ae_samples(
                        catalog,
                        np.reshape(result[7], [options['batch_size']]+options['img_shape']),
                        np.reshape(inputs, [options['batch_size']]+options['img_shape']),
                        np.reshape(val_samples, [options['batch_size']]+options['img_shape']),
                        batch_abs_idx,
                        options['dashboard_dir'],
                        num_to_save=5,
                        save_gray=True
                    )

                    # save_dash_samples(
                    #     catalog,
                    #     val_samples,
                    #     batch_abs_idx,
                    #     options['dashboard_dir'],
                    #     flat_samples=True,
                    #     img_shape=options['img_shape'],
                    #     num_to_save=5
                    # )

                    save_samples(
                        val_samples,
                        int(batch_abs_idx/options['freq_validation']),
                        os.path.join(options['model_dir'], 'valid_samples'),
                        True,
                        options['img_shape'],
                        5
                    )

                    save_samples(
                        inputs,
                        int(batch_abs_idx/options['freq_validation']),
                        os.path.join(options['model_dir'], 'input_sanity'),
                        True,
                        options['img_shape'],
                        num_to_save=5
                    )

                    save_samples(
                        result[7],
                        int(batch_abs_idx/options['freq_validation']),
                        os.path.join(options['model_dir'], 'rec_sanity'),
                        True,
                        options['img_shape'],
                        num_to_save=5
                    )


            log.info('End of epoch {}'.format(epoch_idx + 1))
    # Test Model --------------------------------------------------------------------------
        test_results = []

        for inputs in test_provider:
            if isinstance(inputs, tuple):
                inputs = inputs[0]
            batch_results = sess.run(
                [
                    feat_vae.vae.DKL, feat_vae.vae.rec_loss,
                    feat_vae.vae.dec_log_std_sq, feat_vae.vae.enc_log_std_sq,
                    feat_vae.vae.dec_mean, feat_vae.vae.enc_mean
                ],
                feed_dict = {
                    model_input_batch: inputs
                }
            )

            test_results.append(map(lambda p: np.mean(p, axis=1) if len(p.shape) > 1 else np.mean(p), batch_results))
        test_results = map(list, zip(*test_results))

        # Print results
        log.info('Test Mean Rec. Loss: {:0>15.4f}'.format(
            float(np.mean(test_results[1]))
        ))
        log.info('Test DKL: {:0>15.4f}'.format(
            float(np.mean(test_results[0]))
        ))
        log.info('Test Dec. Mean Log Std Sq: {:0>15.4f}'.format(
            float(np.mean(test_results[2]))
        ))
        log.info('Test Enc. Mean Log Std Sq: {:0>15.4f}'.format(
            float(np.mean(test_results[3]))
        ))
        log.info('Test Dec. Mean Mean: {:0>15.4f}'.format(
            float(np.mean(test_results[4]))
        ))
        log.info('Test Enc. Mean Mean: {:0>15.4f}'.format(
            float(np.mean(test_results[5]))
        ))
示例#16
0
def train(options):
    """Train a softmax classifier on latent samples of a frozen VAE encoder.

    A VAE is built with ``cupboard(options['model'])``; its encoder layers and
    its mean / log-variance heads are then given constant weights loaded from
    ``options['feat_params_path']``.  A single FC layer is trained (Adam, each
    gradient clipped to norm 5) to predict one-hot labels from the sample
    ``enc_mean + eps * exp(0.5 * enc_log_std_sq)`` with ``eps ~ N(0, I)``.

    The options dump, the dashboard catalog and the ``train_loss.csv`` /
    ``val_loss.csv`` logs are written to ``options['dashboard_dir']``; the log
    file and checkpoints go to ``options['model_dir']``.

    Args:
        options: dict of hyper-parameters and paths.  Keys read here:
            'model_dir', 'dashboard_dir', 'description', 'DKL_weight',
            'sigma_clip', 'model', 'p_layers', 'q_layers', 'img_shape',
            'latent_dims', 'batch_size', 'num_classes', 'feat_params_path',
            'lr', 'reload', 'n_epochs', 'freq_logging', 'freq_saving',
            'freq_validation' and 'valid_batches'.

    Returns:
        ``(1., 1., 1.)`` if a NaN/Inf training loss is detected (after dumping
        debug output to stdout); otherwise ``None`` once the test accuracy
        has been logged.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))

    # Dump the run description and every option for the dashboard
    with open(os.path.join(options['dashboard_dir'], 'options'),
              'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write(
            'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
                options['DKL_weight'], -options['sigma_clip'],
                options['sigma_clip']))
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Dashboard Catalog
    with open(os.path.join(options['dashboard_dir'], 'catalog'),
              'w') as catalog:
        catalog.write("""filename,type,name
options,plain,Options
train_loss.csv,csv,Train Loss
val_loss.csv,csv,Validation Loss
""")

    # The CSV logs stay open for the whole run; the `with` closes them on
    # normal exit, on the NaN early-return and when an exception escapes.
    with open(os.path.join(options['dashboard_dir'], 'train_loss.csv'),
              'w') as train_log, \
            open(os.path.join(options['dashboard_dir'], 'val_loss.csv'),
                 'w') as val_log:
        train_log.write('step,time,Train Loss\n')
        # NOTE(review): val_loss.csv actually receives validation
        # *accuracies* (see the validation block below), not losses.
        val_log.write('step,time,Validation Loss\n')

        # Print options
        utils.print_options(options, log)

        # Load dataset --------------------------------------------------------
        train_provider, val_provider, test_provider = get_providers(
            options, log, flat=True)

        # Initialize model ----------------------------------------------------
        with tf.device('/gpu:0'):
            model = cupboard(options['model'])(
                options['p_layers'], options['q_layers'],
                np.prod(options['img_shape']), options['latent_dims'],
                options['DKL_weight'], options['sigma_clip'], 'vanilla_vae')
            log.info('Model initialized')

            # Define inputs
            model_input_batch = tf.placeholder(
                tf.float32,
                shape=[
                    options['batch_size'],
                    np.prod(np.array(options['img_shape']))
                ],
                name='enc_inputs')
            model_label_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'], options['num_classes']],
                name='labels')
            log.info('Inputs defined')

            # Load VAE
            model(model_input_batch)

            # Swap the encoder parameters for constants from a previously
            # trained feature VAE.
            # NOTE(review): this happens after model(...) was called; confirm
            # the model reads `weights` lazily, otherwise the constants never
            # reach the graph.  pickle must only be loaded from trusted files.
            with open(options['feat_params_path'], 'rb') as feat_file:
                feat_params = pickle.load(feat_file)

            for i in range(len(model._encoder.layers)):
                model._encoder.layers[i].weights['w'] = tf.constant(
                    feat_params[i]['W'])
                model._encoder.layers[i].weights['b'] = tf.constant(
                    feat_params[i]['b'])

            # The last two pickled entries hold the mean and log-variance heads
            model._enc_mean.weights['w'] = tf.constant(feat_params[-2]['W'])
            model._enc_mean.weights['b'] = tf.constant(feat_params[-2]['b'])

            model._enc_log_std_sq.weights['w'] = tf.constant(
                feat_params[-1]['W'])
            model._enc_log_std_sq.weights['b'] = tf.constant(
                feat_params[-1]['b'])

            enc_std = tf.exp(tf.mul(0.5, model.enc_log_std_sq))

            # Classify a reparameterised latent sample: mean + eps * std
            classifier = FC(
                model.latent_dims,
                options['num_classes'],
                activation=None,
                scale=0.01,
                name='classifier_fc')(tf.add(
                    tf.mul(tf.random_normal([model.n_samples,
                                             model.latent_dims]),
                           enc_std), model.enc_mean))

            classifier = tf.nn.softmax(classifier)
            # Batch-averaged cross-entropy.  Probabilities are clipped away
            # from 0 so an underflowing softmax cannot produce 0 * log(0) =
            # NaN in the loss and its gradients.
            cost_function = -tf.mul(
                model_label_batch,
                tf.log(tf.clip_by_value(classifier, 1e-10, 1.0)))
            cost_function = tf.reduce_sum(cost_function)
            cost_function *= 1 / float(options['batch_size'])

            log.info('Forward pass graph built')

            # Define optimizer
            optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])

            # Get gradients; variables the loss does not touch yield None
            grads = optimizer.compute_gradients(cost_function)
            grads = [gv for gv in grads if gv[0] is not None]

            # Clip every gradient to an L2 norm of at most 5
            clip_grads = [(tf.clip_by_norm(gv[0], 5.0,
                                           name='grad_clipping'), gv[1])
                          for gv in grads]

            # Update op
            backpass = optimizer.apply_gradients(clip_grads)

            log.info('Optimizer graph built')

            # Define init operation
            init_op = tf.initialize_all_variables()
            log.info('Variable initialization graph built')

        # Define op to save and restore variables
        saver = tf.train.Saver()
        log.info('Save operation built')

        # Train loop ----------------------------------------------------------
        with tf.Session(config=tf.ConfigProto(
                log_device_placement=True)) as sess:
            log.info('Session started')

            # Initialize shared variables or restore
            if options['reload']:
                saver.restore(sess, os.path.join(options['model_dir'],
                                                 'model.ckpt'))
                log.info('Shared variables restored')
            else:
                sess.run(init_op)
                log.info('Shared variables initialized')

            # Windows of the last 10 losses / accuracies for running means
            last_losses = np.zeros((10))
            last_accs = np.zeros((10))

            batch_abs_idx = 0
            for epoch_idx in xrange(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))

                for inputs, labels in train_provider:
                    batch_abs_idx += 1

                    # result = [loss, None (update op), softmax, *gradients]
                    result = sess.run(
                        [cost_function, backpass, classifier] +
                        [gv[0] for gv in grads],
                        feed_dict={
                            model_input_batch: inputs,
                            model_label_batch: labels
                        })

                    cost = result[0]

                    # Check cost; dump every fetched value before giving up
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')
                        for i, res in enumerate(result):
                            print("\n\nresult[%d]:" % i)
                            try:
                                print(np.any(np.isnan(res)))
                            except TypeError:
                                # e.g. the None returned for the update op
                                pass
                            print(np.shape(res))
                            print(res)
                        print(model._encoder.layers[0].weights['w'].eval())
                        print(inputs)
                        return 1., 1., 1.

                    # Update last losses / accuracies
                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost

                    last_accs = np.roll(last_accs, 1)
                    last_accs[0] = np.mean(
                        np.argmax(labels, axis=1) == np.argmax(result[2],
                                                               axis=1))

                    # Log the running loss only once it includes this batch
                    if batch_abs_idx % 10 == 0:
                        train_log.write('{},{},{}\n'.format(
                            batch_abs_idx, '2016-04-22',
                            np.mean(last_losses)))
                        train_log.flush()

                    # Display training information
                    # NOTE(review): gated on epoch_idx, so every batch of a
                    # matching epoch is logged -- confirm this is intended.
                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info(
                            'Epoch {:02}/{:02} Batch {:03} Current Acc.: {:0>15.4f} Mean last accs: {:0>15.4f}'
                            .format(epoch_idx + 1, options['n_epochs'],
                                    batch_abs_idx, last_accs[0],
                                    np.mean(last_accs)))
                        log.info('Batch Mean Loss: {:0>15.4f}'.format(
                            np.mean(last_losses)))

                    # Save model
                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(
                            sess,
                            os.path.join(options['model_dir'],
                                         'model_at_%d.ckpt' % batch_abs_idx))
                        log.info('Model saved')

                    # Validate model (accuracy over up to valid_batches)
                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:

                        val_accs = []
                        seen_batches = 0
                        # Separate names so the training `labels` is untouched
                        for val_batch, val_labels in val_provider:

                            val_result = sess.run(
                                [cost_function, classifier],
                                feed_dict={
                                    model_input_batch: val_batch,
                                    model_label_batch: val_labels
                                })
                            val_accs.append(np.mean(
                                np.argmax(val_labels, axis=1) == np.argmax(
                                    val_result[1], axis=1)))
                            seen_batches += 1

                            if seen_batches == options['valid_batches']:
                                break

                        # Print results
                        log.info('Validation acc.: {:0>15.4f}'.format(
                            float(np.mean(val_accs))))

                        val_log.write('{},{},{}\n'.format(
                            batch_abs_idx, '2016-04-22', np.mean(val_accs)))
                        val_log.flush()

                log.info('End of epoch {}'.format(epoch_idx + 1))

            # Test model ------------------------------------------------------
            test_accs = []

            for inputs, labels in test_provider:
                if isinstance(inputs, tuple):
                    inputs = inputs[0]
                batch_results = sess.run(
                    [cost_function, classifier],
                    feed_dict={
                        model_input_batch: inputs,
                        model_label_batch: labels
                    })

                test_accs.append(np.mean(
                    np.argmax(labels, axis=1) == np.argmax(batch_results[1],
                                                           axis=1)))

            # Print results
            log.info('Test Accuracy: {:0>15.4f}'.format(np.mean(test_accs)))
示例#17
0
def train(options):
    """Train a convolutional MNIST classifier and periodically export it.

    MNIST is loaded with ``input_data.read_data_sets("MNIST_data/",
    one_hot=True)``.  The model built by ``cupboard(options['model'])`` is
    trained with Adam, each gradient clipped to norm 5, and dropout keep
    probability 0.5.  Running loss and accuracy are logged to CSV files in
    ``options['dashboard_dir']``.  Every ``freq_saving`` steps a checkpoint
    plus a pickled list of layer parameters (``class_dict_<step>``) are
    written to ``options['model_dir']``, so the parameters can be reused for
    later feature-VAE training.

    Only the first ``options['data_percentage']`` fraction of the training
    and validation sets is used.  It is cut into consecutive,
    non-overlapping batches of ``options['batch_size']`` examples; a
    trailing partial batch is dropped.

    Args:
        options: dict of hyper-parameters and paths.  Keys read here:
            'model_dir', 'dashboard_dir', 'description', 'data_percentage',
            'model', 'img_shape', 'input_channels', 'num_classes',
            'conv_params', 'pool_params', 'fc_params', 'batch_size', 'lr',
            'reload', 'n_epochs', 'freq_logging', 'freq_saving',
            'freq_validation' and 'valid_batches'.

    Returns:
        ``(1., 1., 1.)`` if a NaN/Inf training loss is detected; otherwise
        ``None`` after the test loss and accuracy have been logged.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))

    # Dump the run description and every option for the dashboard
    with open(os.path.join(options['dashboard_dir'], 'options'),
              'w') as options_file:
        options_file.write(options['description'] + '\n')

        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Dashboard Catalog
    with open(os.path.join(options['dashboard_dir'], 'catalog'),
              'w') as catalog:
        catalog.write("""filename,type,name
options,plain,Options
train_loss.csv,csv,Train Loss
train_acc.csv,csv,Train Accuracy
val_loss.csv,csv,Validation Loss
val_acc.csv,csv,Validation Accuracy
""")

    dashboard_dir = options['dashboard_dir']
    # The CSV logs stay open for the whole run; the `with` closes them on
    # normal exit, on the NaN early-return and when an exception escapes.
    with open(os.path.join(dashboard_dir, 'train_loss.csv'),
              'w') as train_log, \
            open(os.path.join(dashboard_dir, 'train_acc.csv'),
                 'w') as train_acc_log, \
            open(os.path.join(dashboard_dir, 'val_loss.csv'),
                 'w') as val_log, \
            open(os.path.join(dashboard_dir, 'val_acc.csv'),
                 'w') as val_acc_log:

        train_log.write('step,time,Train Loss\n')
        val_log.write('step,time,Validation Loss\n')
        train_acc_log.write('step,time,Train Accuracy\n')
        val_acc_log.write('step,time,Validation Accuracy\n')

        # Print options
        utils.print_options(options, log)

        # Load dataset --------------------------------------------------------
        mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

        train_data = mnist.train.images
        train_labels = mnist.train.labels
        validation_data = mnist.validation.images
        validation_labels = mnist.validation.labels
        test_data = mnist.test.images
        test_labels = mnist.test.labels

        data_percentage = options['data_percentage']
        batch_size = options['batch_size']
        # Shape of one fed input batch; the input placeholder uses the same
        # shape, so reshaped batches always match what the graph expects.
        batch_shape = ([batch_size] + list(options['img_shape']) +
                       [options['input_channels']])

        log.info('Train data shape: {}'.format(train_data.shape))
        log.info('Data providers initialized.')

        def _batch(data, labels, idx):
            """Return the idx-th non-overlapping (inputs, labels) batch."""
            start = idx * batch_size
            stop = start + batch_size
            return (np.reshape(data[start:stop, :], batch_shape),
                    labels[start:stop, :])

        # Initialize model ----------------------------------------------------
        with tf.device('/gpu:0'):
            model = cupboard(options['model'])(
                options['img_shape'], options['input_channels'],
                options['num_classes'], options['conv_params'],
                options['pool_params'], options['fc_params'],
                'MNIST_classifier')
            log.info('Model initialized')

            # Define inputs
            input_batch = tf.placeholder(tf.float32,
                                         shape=batch_shape,
                                         name='inputs')
            label_batch = tf.placeholder(
                tf.float32,
                shape=[batch_size, options['num_classes']],
                name='labels')
            keep_prob = tf.placeholder(tf.float32, name='keep_prob')
            log.info('Inputs defined')

            # Define forward pass
            cost_function, classifier = model(input_batch, label_batch,
                                              keep_prob)
            log.info('Forward pass graph built')

            # Define optimizer
            optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])
            log.info('Optimizer graph built')

            # Get gradients; variables the loss does not touch yield None
            grads = optimizer.compute_gradients(cost_function)
            grads = [gv for gv in grads if gv[0] is not None]

            # Clip every gradient to an L2 norm of at most 5
            clip_grads = [(tf.clip_by_norm(gv[0], 5.0,
                                           name='grad_clipping'), gv[1])
                          for gv in grads]

            # Update op
            backpass = optimizer.apply_gradients(clip_grads)

            # Define init operation
            init_op = tf.initialize_all_variables()
            log.info('Variable initialization graph built')

        # Define op to save and restore variables
        saver = tf.train.Saver()
        log.info('Save operation built')

        # Train loop ----------------------------------------------------------
        with tf.Session(config=tf.ConfigProto(
                log_device_placement=True)) as sess:
            log.info('Session started')

            # Initialize shared variables or restore
            if options['reload']:
                saver.restore(sess, os.path.join(options['model_dir'],
                                                 'model.ckpt'))
                log.info('Shared variables restored')
            else:
                sess.run(init_op)
                log.info('Shared variables initialized')

            # Windows of the last 10 losses / accuracies for running means
            last_losses = np.zeros((10))
            last_accs = np.zeros((10))

            n_train_batches = int(
                (data_percentage * train_data.shape[0]) / batch_size)
            n_val_batches = int(
                (data_percentage * validation_data.shape[0]) / batch_size)

            batch_abs_idx = 0
            for epoch_idx in xrange(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))

                for i in xrange(n_train_batches):
                    inputs, labels = _batch(train_data, train_labels, i)

                    batch_abs_idx += 1

                    # Gradients are not fetched: nothing on the host uses them
                    results = sess.run([cost_function, classifier, backpass],
                                       feed_dict={
                                           input_batch: inputs,
                                           label_batch: labels,
                                           keep_prob: 0.5
                                       })

                    cost = results[0]
                    # Check cost
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')
                        return 1., 1., 1.

                    accuracy = np.mean(
                        np.argmax(results[1], axis=1) == np.argmax(labels,
                                                                   axis=1))

                    # Update last losses / accuracies
                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost
                    last_accs = np.roll(last_accs, 1)
                    last_accs[0] = accuracy

                    if batch_abs_idx % 10 == 0:
                        train_log.write('{},{},{}\n'.format(
                            batch_abs_idx, '2016-04-22',
                            np.mean(last_losses)))
                        train_acc_log.write('{},{},{}\n'.format(
                            batch_abs_idx, '2016-04-22', np.mean(last_accs)))
                        train_log.flush()
                        train_acc_log.flush()

                    # Display training information
                    # NOTE(review): gated on epoch_idx, so every batch of a
                    # matching epoch is logged -- confirm this is intended.
                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info(
                            'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                            .format(epoch_idx + 1, options['n_epochs'],
                                    batch_abs_idx, float(cost),
                                    np.mean(last_losses)))
                        log.info(
                            'Epoch {:02}/{:02} Batch {:03} Current Accuracy: {:0>15.4f}'
                            .format(epoch_idx + 1, options['n_epochs'],
                                    batch_abs_idx, np.mean(last_accs)))

                    # Save model (checkpoint + pickled layer parameters)
                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(
                            sess,
                            os.path.join(options['model_dir'],
                                         'model_at_%d.ckpt' % batch_abs_idx))

                        save_dict = []
                        conv_layers = model._classifier_conv.layers
                        for c_ind, layer in enumerate(conv_layers):
                            if isinstance(layer, ConvLayer):
                                save_dict.append({
                                    'n_filters_in': layer.n_filters_in,
                                    'n_filters_out': layer.n_filters_out,
                                    'input_dim': layer.input_dim,
                                    'filter_dim': layer.filter_dim,
                                    'strides': layer.strides,
                                    'padding': layer.padding,
                                    # assumes the entry after each conv is
                                    # its activation -- TODO confirm
                                    'act_fn': conv_layers[c_ind + 1],
                                    'W': layer.weights['W'].eval(),
                                    'b': layer.weights['b'].eval(),
                                    'layer_type': 'conv'
                                })
                            elif isinstance(layer, PoolLayer):
                                save_dict.append({
                                    'input_dim': layer.input_dim,
                                    'filter_dim': layer.filter_dim,
                                    'strides': layer.strides,
                                    'layer_type': 'pool'
                                })

                        fc_layers = model._classifier_fc.layers
                        # NOTE(review): steps by 2 (presumably layer /
                        # activation pairs), skips the last two entries and
                        # indexes fc_params['act_fn'] with the layer index --
                        # confirm all three are intended.
                        for c_ind in xrange(0, len(fc_layers) - 2, 2):
                            layer = fc_layers[c_ind]
                            save_dict.append({
                                'input_dim': layer.input_dim,
                                'output_dim': layer.output_dim,
                                'act_fn': model.fc_params['act_fn'][c_ind],
                                'W': layer.weights['w'].eval(),
                                'b': layer.weights['b'].eval(),
                                'layer_type': 'fc'
                            })

                        with open(os.path.join(options['model_dir'],
                                               'class_dict_%d' % batch_abs_idx),
                                  'wb') as dict_file:
                            pickle.dump(save_dict, dict_file)

                        log.info('Model saved')

                    # Validate model
                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:

                        valid_costs = []
                        val_accuracies = []
                        seen_batches = 0
                        for j in xrange(n_val_batches):
                            # Stop after options['valid_batches'] batches
                            if seen_batches == options['valid_batches']:
                                break

                            val_batch, val_label = _batch(
                                validation_data, validation_labels, j)

                            val_results = sess.run(
                                [cost_function, classifier],
                                feed_dict={
                                    input_batch: val_batch,
                                    label_batch: val_label,
                                    keep_prob: 1.0
                                })
                            valid_costs.append(val_results[0])
                            seen_batches += 1

                            val_accuracies.append(
                                np.mean(
                                    np.argmax(val_results[1], axis=1) ==
                                    np.argmax(val_label, axis=1)))

                        # Print results
                        log.info('Mean Validation loss: {:0>15.4f}'.format(
                            float(np.mean(valid_costs))))
                        log.info('Mean Validation Accuracy: {:0>15.4f}'.format(
                            np.mean(val_accuracies)))

                        val_log.write('{},{},{}\n'.format(
                            batch_abs_idx, '2016-04-22',
                            float(np.mean(valid_costs))))
                        val_acc_log.write('{},{},{}\n'.format(
                            batch_abs_idx, '2016-04-22',
                            np.mean(val_accuracies)))
                        val_log.flush()
                        val_acc_log.flush()

                log.info('End of epoch {}'.format(epoch_idx + 1))

            # Test model ------------------------------------------------------
            test_costs = []
            test_accuracies = []
            for j in xrange(test_data.shape[0] // batch_size):
                test_batch, test_label = _batch(test_data, test_labels, j)

                test_results = sess.run([cost_function, classifier],
                                        feed_dict={
                                            input_batch: test_batch,
                                            label_batch: test_label,
                                            keep_prob: 1.0
                                        })
                test_costs.append(test_results[0])

                test_accuracies.append(
                    np.mean(
                        np.argmax(test_results[1], axis=1) == np.argmax(
                            test_label, axis=1)))

            # Print results
            log.info('Test loss: {:0>15.4f}'.format(
                float(np.mean(test_costs))))
            log.info('Test Accuracy: {:0>15.4f}'.format(
                np.mean(test_accuracies)))
示例#18
0
def _layer_params(layer):
    """Snapshot one layer's configuration and current weights as a plain dict.

    Reads ``input_dim``, ``output_dim``, ``activation`` and the ``'w'`` / ``'b'``
    entries of ``layer.weights``. The weights are fetched with ``.eval()``, so a
    default TensorFlow session must be active when this is called.
    """
    return {
        'input_dim': layer.input_dim,
        'output_dim': layer.output_dim,
        'act_fn': layer.activation,
        'W': layer.weights['w'].eval(),
        'b': layer.weights['b'].eval()
    }


def train(options):
    """Train a VAE against a fixed convolutional discriminator (VAE-GAN).

    Writes an options summary, a dashboard catalog and per-metric CSV logs
    into ``options['dashboard_dir']``. Periodically saves TF checkpoints and a
    pickled dict of VAE layer parameters into ``options['model_dir']``.

    Each training step runs one of three update ops returned by the
    ``vae_gan`` model:
      * vanilla VAE updates while ``batch_abs_idx < initial_G_iters``;
      * discriminator updates for the next ``initial_D_iters`` steps;
      * afterwards, alternates ``D_to_G[0]`` discriminator steps with
        ``sum(D_to_G) - D_to_G[0]`` VAE (generator) steps.

    Args:
        options: dict of run settings. Keys read here include
            'model_dir', 'dashboard_dir', 'description', 'DKL_weight',
            'sigma_clip', 'batch_size', 'img_shape', 'latent_dims',
            'p_layers', 'q_layers', 'disc_params_path', 'num_feat_layers',
            'disc_weight', 'input_channels', 'lr', 'reload_all',
            'reload_file', 'reload_vae', 'vae_params_path', 'D_to_G',
            'initial_G_iters', 'initial_D_iters', 'n_epochs',
            'freq_logging', 'freq_saving', 'freq_validation',
            'valid_batches'.

    Returns:
        ``(1., 1., 1.)`` early if a NaN or Inf discriminator cost is detected.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    options_file = open(os.path.join(options['dashboard_dir'], 'options'), 'w')
    options_file.write(options['description'] + '\n')
    options_file.write(
        'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
            options['DKL_weight'],
            -options['sigma_clip'],
            options['sigma_clip']
        )
    )
    for optn in options:
        options_file.write(optn)
        options_file.write(':\t')
        options_file.write(str(options[optn]))
        options_file.write('\n')
    options_file.close()

    # Dashboard Catalog
    # NOTE(review): the catalog and the CSV log handles below are not closed
    # anywhere in the visible code; `catalog` is also passed to
    # save_ae_samples, so it must stay open for the whole run.
    catalog = open(os.path.join(options['dashboard_dir'], 'catalog'), 'w')
    catalog.write(
"""filename,type,name
options,plain,Options
train_loss.csv,csv,Discriminator Cross-Entropy
ll.csv,csv,Neg. Log-Likelihood
dec_log_sig_sq.csv,csv,Decoder Log Sigma^2
dec_std_log_sig_sq.csv,csv,STD of Decoder Log Sigma^2
dec_mean.csv,csv,Decoder Mean
dkl.csv,csv,DKL
enc_log_sig_sq.csv,csv,Encoder Log Sigma^2
enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2
enc_mean.csv,csv,Encoder Mean
val_loss.csv,csv,Validation Loss
"""
    )
    catalog.flush()
    train_log = open(os.path.join(options['dashboard_dir'], 'train_loss.csv'), 'w')
    val_log = open(os.path.join(options['dashboard_dir'], 'val_loss.csv'), 'w')
    dkl_log = open(os.path.join(options['dashboard_dir'], 'dkl.csv'), 'w')
    ll_log = open(os.path.join(options['dashboard_dir'], 'll.csv'), 'w')

    dec_sig_log = open(os.path.join(options['dashboard_dir'], 'dec_log_sig_sq.csv'), 'w')
    enc_sig_log = open(os.path.join(options['dashboard_dir'], 'enc_log_sig_sq.csv'), 'w')

    dec_std_sig_log = open(os.path.join(options['dashboard_dir'], 'dec_std_log_sig_sq.csv'), 'w')
    enc_std_sig_log = open(os.path.join(options['dashboard_dir'], 'enc_std_log_sig_sq.csv'), 'w')

    dec_mean_log = open(os.path.join(options['dashboard_dir'], 'dec_mean.csv'), 'w')
    enc_mean_log = open(os.path.join(options['dashboard_dir'], 'enc_mean.csv'), 'w')

    # Each CSV has one value column per training phase (vanilla / generator /
    # discriminator); rows leave the other two columns empty.
    train_log.write('step,time,Train CE (Training Vanilla),Train CE (Training Gen.),Train CE (Training Disc.)\n')
    val_log.write('step,time,Validation CE (Training Vanilla),Validation CE (Training Gen.),Validation CE (Training Disc.)\n')
    dkl_log.write('step,time,DKL (Training Vanilla),DKL (Training Gen.),DKL (Training Disc.)\n')
    ll_log.write('step,time,-LL (Training Vanilla),-LL (Training Gen.),-LL (Training Disc.)\n')

    dec_sig_log.write('step,time,Decoder Log Sigma^2 (Training Vanilla),Decoder Log Sigma^2 (Training Gen.),Decoder Log Sigma^2 (Training Disc.)\n')
    enc_sig_log.write('step,time,Encoder Log Sigma^2 (Training Vanilla),Encoder Log Sigma^2 (Training Gen.),Encoder Log Sigma^2 (Training Disc.)\n')

    dec_std_sig_log.write('step,time,STD of Decoder Log Sigma^2 (Training Vanilla),STD of Decoder Log Sigma^2 (Training Gen.),STD of Decoder Log Sigma^2 (Training Disc.)\n')
    enc_std_sig_log.write('step,time,STD of Encoder Log Sigma^2 (Training Vanilla),STD of Encoder Log Sigma^2 (Training Gen.),STD of Encoder Log Sigma^2 (Training Disc.)\n')

    dec_mean_log.write('step,time,Decoder Mean (Training Vanilla),Decoder Mean (Training Gen.),Decoder Mean (Training Disc.)\n')
    enc_mean_log.write('step,time,Encoder Mean (Training Vanilla),Encoder Mean (Training Gen.),Encoder Mean (Training Disc.)\n')

    # Print options
    utils.print_options(options, log)

    # Load dataset ----------------------------------------------------------------------
    # Train provider
    train_provider, val_provider, test_provider = get_providers(options, log, flat=True)


    # Initialize model ------------------------------------------------------------------
    with tf.device('/gpu:0'):
        # Define inputs
        model_input_batch = tf.placeholder(
            tf.float32,
            shape = [options['batch_size'], np.prod(np.array(options['img_shape']))],
            name = 'enc_inputs'
        )
        sampler_input_batch = tf.placeholder(
            tf.float32,
            shape = [options['batch_size'], options['latent_dims']],
            name = 'dec_inputs'
        )
        log.info('Inputs defined')

        # Define model
        with tf.variable_scope('vae_scope'):
            vae_model = cupboard('vanilla_vae')(
                options['p_layers'],
                options['q_layers'],
                np.prod(options['img_shape']),
                options['latent_dims'],
                options['DKL_weight'],
                options['sigma_clip'],
                'vae_model'
            )

        with tf.variable_scope('disc_scope'):
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # disc_params_path at trusted files.
            with open(options['disc_params_path'], 'rb') as disc_params_file:
                disc_params = pickle.load(disc_params_file)
            disc_model = cupboard('fixed_conv_disc')(
                disc_params,
                options['num_feat_layers'],
                name='disc_model'
            )

        vae_gan = cupboard('vae_gan')(
            vae_model,
            disc_model,
            options['disc_weight'],
            options['img_shape'],
            options['input_channels'],
            'vae_scope',
            'disc_scope',
            name='vae_gan_model'
        )

        # Define Optimizers ---------------------------------------------------------------------
        optimizer = tf.train.AdamOptimizer(
            learning_rate=options['lr']
        )

        vae_backpass, disc_backpass, vanilla_backpass = vae_gan(model_input_batch, sampler_input_batch, optimizer)

        log.info('Optimizer graph built')
        # --------------------------------------------------------------------------------------
        # Define init operation
        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')
    # --------------------------------------------------------------------------

    # Train loop ---------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        # Initialize shared variables or restore
        if options['reload_all']:
            saver.restore(sess, options['reload_file'])
            log.info('Shared variables restored')
        else:
            sess.run(init_op)
            log.info('Variables initialized')

            if options['reload_vae']:
                vae_model.reload_vae(options['vae_params_path'])

        # Define last losses to compute a running average
        last_losses = np.zeros((10))

        batch_abs_idx = 0
        D_to_G = options['D_to_G']
        total_D2G = sum(D_to_G)
        base = options['initial_G_iters'] + options['initial_D_iters']

        for epoch_idx in xrange(options['n_epochs']):
            batch_rel_idx = 0
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs in train_provider:
                if isinstance(inputs, tuple):
                    inputs = inputs[0]

                batch_abs_idx += 1
                batch_rel_idx += 1

                # Pick which update op runs this step; log_format_string puts
                # the logged value in the CSV column of the active phase.
                if batch_abs_idx < options['initial_G_iters']:
                    backpass = vanilla_backpass
                    log_format_string = '{},{},{},,\n'
                elif options['initial_G_iters'] <= batch_abs_idx < base:
                    backpass = disc_backpass
                    log_format_string = '{},{},,,{}\n'
                else:
                    if (batch_abs_idx - base) % total_D2G < D_to_G[0]:
                        backpass = disc_backpass
                        log_format_string = '{},{},,,{}\n'
                    else:
                        backpass = vae_backpass
                        log_format_string = '{},{},,{},\n'

                # result indices: 0 disc_CE, 1 update op, 2 DKL, 3 rec_loss,
                # 4 dec_log_std_sq, 5 enc_log_std_sq, 6 enc_mean, 7 dec_mean
                result = sess.run(
                    [
                        vae_gan.disc_CE,
                        backpass,
                        vae_gan._vae.DKL,
                        vae_gan._vae.rec_loss,
                        vae_gan._vae.dec_log_std_sq,
                        vae_gan._vae.enc_log_std_sq,
                        vae_gan._vae.enc_mean,
                        vae_gan._vae.dec_mean
                    ],
                    feed_dict = {
                        model_input_batch: inputs,
                        sampler_input_batch: MVN(
                            np.zeros(options['latent_dims']),
                            np.diag(np.ones(options['latent_dims'])),
                            size = options['batch_size']
                        )
                    }
                )

                cost = result[0]

                if batch_abs_idx % 10 == 0:
                    train_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(last_losses)))
                    dkl_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', -np.mean(result[2])))
                    ll_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', -np.mean(result[3])))

                    train_log.flush()
                    dkl_log.flush()
                    ll_log.flush()

                    dec_sig_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(result[4])))
                    enc_sig_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(result[5])))

                    dec_sig_log.flush()
                    enc_sig_log.flush()

                    dec_std_sig_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.std(result[4])))
                    enc_std_sig_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.std(result[5])))

                    dec_mean_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(result[7])))
                    enc_mean_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(result[6])))

                    dec_std_sig_log.flush()
                    enc_std_sig_log.flush()

                    dec_mean_log.flush()
                    enc_mean_log.flush()

                # Check cost: dump diagnostics and abort on NaN/Inf
                if np.isnan(cost) or np.isinf(cost):
                    log.info('NaN detected')
                    for i in range(len(result)):
                        print("\n\nresult[%d]:" % i)
                        try:
                            print(np.any(np.isnan(result[i])))
                        except (TypeError, ValueError):
                            # Non-numeric fetches (e.g. the update op) cannot
                            # be checked with np.isnan.
                            pass
                        print(result[i])
                    print(result[3].shape)
                    print(vae_gan._vae._encoder.layers[0].weights['w'].eval())
                    # The previous result[8..10] dumps were removed: `result`
                    # only holds 8 fetches, so they raised IndexError here.
                    print(inputs)
                    return 1., 1., 1.

                # Update last losses
                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost

                # Display training information
                # NOTE(review): gated on epoch_idx, so every batch of a matching
                # epoch is logged; batch_abs_idx may have been intended — confirm.
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'.format(
                        epoch_idx + 1,
                        options['n_epochs'],
                        batch_abs_idx,
                        float(cost),
                        np.mean(last_losses)
                    ))
                    log.info('Batch Mean LL: {:0>15.4f}'.format(np.mean(result[3], axis=0)))
                    log.info('Batch Mean -DKL: {:0>15.4f}'.format(np.mean(result[2], axis=0)))

                # Save model
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(sess, os.path.join(options['model_dir'], 'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                    save_dict = {}
                    # Save encoder params ------------------------------------------------------------------
                    # 'encoder_layers' holds every hidden layer. 'encoder' keeps
                    # its historical content (the last hidden layer only) for
                    # readers of the old format.
                    encoder_layers = [_layer_params(layer) for layer in vae_gan._vae._encoder.layers]
                    save_dict['encoder_layers'] = encoder_layers
                    if encoder_layers:
                        save_dict['encoder'] = encoder_layers[-1]

                    save_dict['enc_mean'] = _layer_params(vae_gan._vae._enc_mean)
                    save_dict['enc_log_std_sq'] = _layer_params(vae_gan._vae._enc_log_std_sq)

                    # Save decoder params ------------------------------------------------------------------
                    # Same layout as the encoder: all layers plus the legacy
                    # last-layer-only 'decoder' key.
                    decoder_layers = [_layer_params(layer) for layer in vae_gan._vae._decoder.layers]
                    save_dict['decoder_layers'] = decoder_layers
                    if decoder_layers:
                        save_dict['decoder'] = decoder_layers[-1]

                    save_dict['dec_mean'] = _layer_params(vae_gan._vae._dec_mean)
                    save_dict['dec_log_std_sq'] = _layer_params(vae_gan._vae._dec_log_std_sq)

                    with open(os.path.join(options['model_dir'], 'vae_dict_%d' % batch_abs_idx), 'wb') as vae_dict_file:
                        pickle.dump(save_dict, vae_dict_file)

                # Validate model
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:

                    valid_costs = []
                    seen_batches = 0
                    for val_batch in val_provider:
                        if isinstance(val_batch, tuple):
                            val_batch = val_batch[0]

                        val_cost = sess.run(
                            vae_gan.disc_CE,
                            feed_dict = {
                                model_input_batch: val_batch,
                                sampler_input_batch: MVN(
                                    np.zeros(options['latent_dims']),
                                    np.diag(np.ones(options['latent_dims'])),
                                    size = options['batch_size']
                                )
                            }
                        )
                        valid_costs.append(val_cost)
                        seen_batches += 1

                        if seen_batches == options['valid_batches']:
                            break

                    # Print results
                    log.info('Validation loss: {:0>15.4f}'.format(
                        float(np.mean(valid_costs))
                    ))

                    val_samples = sess.run(
                        vae_gan.sampler,
                        feed_dict = {
                            sampler_input_batch: MVN(
                                np.zeros(options['latent_dims']),
                                np.diag(np.ones(options['latent_dims'])),
                                size = options['batch_size']
                            )
                        }
                    )

                    # Validation CE goes in the column of the current training phase.
                    val_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(valid_costs)))
                    val_log.flush()

                    # Reconstructions (result[7]) and inputs come from the
                    # latest training batch; only val_samples are fresh draws.
                    save_ae_samples(
                        catalog,
                        np.reshape(result[7], [options['batch_size']]+options['img_shape']),
                        np.reshape(inputs, [options['batch_size']]+options['img_shape']),
                        np.reshape(val_samples, [options['batch_size']]+options['img_shape']),
                        batch_abs_idx,
                        options['dashboard_dir'],
                        num_to_save=5,
                        save_gray=True
                    )

            log.info('End of epoch {}'.format(epoch_idx + 1))