Exemple #1
0
def train(options):
    """Train a softmax classifier on sampled VAE latent codes and log metrics.

    The function
      * writes the run description and all options to the dashboard directory,
      * creates the dashboard catalog and one CSV file per tracked quantity,
      * builds train/validation data providers (MNIST or an image directory),
      * builds the VAE selected by ``options['model']``, a fully connected
        softmax classifier on top of a latent sample, and an Adam update with
        per-gradient norm clipping on the cross-entropy cost,
      * restores ``model_at_21000.ckpt`` from ``options['model_dir']`` and then
        runs the training loop with periodic logging, checkpointing and
        validation.

    Args:
        options: dict-like configuration. Keys read here: ``model_dir``,
            ``dashboard_dir``, ``description``, ``DKL_weight``, ``sigma_clip``,
            ``data_dir``, ``batch_size``, ``file_extension`` (non-MNIST only),
            ``model``, ``p_layers``, ``q_layers``, ``img_shape``,
            ``latent_dims``, ``num_classes``, ``lr``, ``n_epochs``,
            ``freq_logging``, ``freq_saving``, ``freq_validation`` and
            ``valid_batches``.

    Returns:
        ``(1., 1., 1.)`` if the training cost becomes NaN or infinite,
        otherwise ``None`` once all epochs have run.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))

    # Dump the description and every option; the context manager closes the
    # file even if one of the writes fails.
    options_path = os.path.join(options['dashboard_dir'], 'options')
    with open(options_path, 'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write(
            'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
                options['DKL_weight'], -options['sigma_clip'],
                options['sigma_clip']))
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Every dashboard handle is registered here so that the finally-clause at
    # the bottom can close all of them, whichever way the function exits.
    dashboard_files = []

    def open_dashboard_file(filename):
        # Open (truncate) a file inside the dashboard directory and track it.
        handle = open(os.path.join(options['dashboard_dir'], filename), 'w')
        dashboard_files.append(handle)
        return handle

    def write_row(handle, step, value):
        # Append one 'step,time,value' row and flush it right away so the row
        # is on disk while training is still running. The 'time' column is a
        # hard-coded date string, not the actual wall-clock time.
        handle.write('{},{},{}\n'.format(step, '2016-04-22', value))
        handle.flush()

    try:
        # Dashboard Catalog
        catalog = open_dashboard_file('catalog')
        catalog.write("""filename,type,name
options,plain,Options
train_loss.csv,csv,Train Loss
ll.csv,csv,Neg. Log-Likelihood
dec_log_sig_sq.csv,csv,Decoder Log Simga^2
dec_std_log_sig_sq.csv,csv,STD of Decoder Log Simga^2
dec_mean.csv,csv,Decoder Mean
dkl.csv,csv,DKL
enc_log_sig_sq.csv,csv,Encoder Log Sigma^2
enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2
enc_mean.csv,csv,Encoder Mean
val_loss.csv,csv,Validation Loss
""")
        catalog.flush()

        train_log = open_dashboard_file('train_loss.csv')
        val_log = open_dashboard_file('val_loss.csv')
        dkl_log = open_dashboard_file('dkl.csv')
        ll_log = open_dashboard_file('ll.csv')

        dec_sig_log = open_dashboard_file('dec_log_sig_sq.csv')
        enc_sig_log = open_dashboard_file('enc_log_sig_sq.csv')

        dec_std_sig_log = open_dashboard_file('dec_std_log_sig_sq.csv')
        enc_std_sig_log = open_dashboard_file('enc_std_log_sig_sq.csv')

        dec_mean_log = open_dashboard_file('dec_mean.csv')
        enc_mean_log = open_dashboard_file('enc_mean.csv')

        # CSV headers
        train_log.write('step,time,Train Loss\n')
        val_log.write('step,time,Validation Loss\n')
        dkl_log.write('step,time,DKL\n')
        ll_log.write('step,time,-LL\n')

        dec_sig_log.write('step,time,Decoder Log Sigma^2\n')
        enc_sig_log.write('step,time,Encoder Log Sigma^2\n')

        dec_std_sig_log.write('step,time,STD of Decoder Log Sigma^2\n')
        enc_std_sig_log.write('step,time,STD of Encoder Log Sigma^2\n')

        dec_mean_log.write('step,time,Decoder Mean\n')
        enc_mean_log.write('step,time,Encoder Mean\n')

        # Print options
        utils.print_options(options, log)

        # Load dataset --------------------------------------------------------
        if options['data_dir'] != 'MNIST':
            # Train provider
            train_dir = os.path.join(options['data_dir'], 'train', 'patches')
            # NOTE(review): the "- 2" presumably discounts two entries in the
            # directory that are not samples -- confirm against the data set.
            num_data_points = len(os.listdir(train_dir)) - 2
            train_provider = DataProvider(
                num_data_points, options['batch_size'],
                toolbox.ImageLoader(data_dir=train_dir,
                                    flat=True,
                                    extension=options['file_extension']))

            # Valid provider
            valid_dir = os.path.join(options['data_dir'], 'valid', 'patches')
            num_data_points = len(os.listdir(valid_dir)) - 2
            val_provider = DataProvider(
                num_data_points, options['batch_size'],
                toolbox.ImageLoader(data_dir=valid_dir,
                                    flat=True,
                                    extension=options['file_extension']))
        else:
            train_provider = DataProvider(
                55000, options['batch_size'],
                toolbox.MNISTLoader(mode='train', flat=True))

            val_provider = DataProvider(
                5000, options['batch_size'],
                toolbox.MNISTLoader(mode='validation', flat=True))

        log.info('Data providers initialized.')

        # Initialize model ----------------------------------------------------
        # The order of the graph-building calls below is kept exactly as it
        # was, because it determines the order in which ops are created.
        with tf.device('/gpu:0'):
            model = cupboard(options['model'])(
                options['p_layers'], options['q_layers'],
                np.prod(options['img_shape']), options['latent_dims'],
                options['DKL_weight'], options['sigma_clip'], 'vanilla_vae')
            log.info('Model initialized')

            # Define inputs
            model_input_batch = tf.placeholder(
                tf.float32,
                shape=[
                    options['batch_size'],
                    np.prod(np.array(options['img_shape']))
                ],
                name='enc_inputs')
            model_label_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'], options['num_classes']],
                name='labels')
            sampler_input_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'], options['latent_dims']],
                name='dec_inputs')
            log.info('Inputs defined')

            # Define forward pass. The returned cost is replaced by the
            # classification cost further down; the call itself is still
            # needed because the model attributes fetched in the training
            # loop (DKL, rec_loss, enc_mean, ...) are read after it.
            cost_function = model(model_input_batch)
            log.info('Forward pass graph built')

            # Define sampler
            sampler = model.build_sampler(sampler_input_batch)
            log.info('Sampler graph built')

            # Define optimizer
            optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])

            # exp(0.5 * enc_log_std_sq); going by the attribute name this is
            # the encoder standard deviation.
            enc_std = tf.exp(tf.mul(0.5, model.enc_log_std_sq))

            # Classifier input: enc_mean + noise * enc_std, with standard
            # normal noise of shape [n_samples, latent_dims].
            latent_sample = tf.add(
                tf.mul(tf.random_normal([model.n_samples, model.latent_dims]),
                       enc_std), model.enc_mean)
            classifier = FC(
                model.latent_dims,
                options['num_classes'],
                activation=None,
                scale=0.01,
                name='classifier_fc')(latent_sample)

            # Cross-entropy between the labels and the softmax output, summed
            # over all entries and divided by the batch size.
            classifier = tf.nn.softmax(classifier)
            cost_function = -tf.mul(model_label_batch, tf.log(classifier))
            cost_function = tf.reduce_sum(cost_function)
            cost_function *= 1 / float(options['batch_size'])

            # NOTE(review): this op is never run (the loop runs `backpass`).
            # It is kept because dropping it would alter the graph that the
            # checkpoint below is restored into -- confirm before removing.
            train_step = optimizer.minimize(cost_function)

            # Get gradients, dropping variables the cost does not depend on.
            # Identity test instead of `!= None`, which would break on tensor
            # types that overload comparison operators.
            grads = optimizer.compute_gradients(cost_function)
            grads = [gv for gv in grads if gv[0] is not None]
            grad_tensors = [gv[0] for gv in grads]

            # Clip gradients (each gradient separately, to norm 5.0)
            clip_grads = [(tf.clip_by_norm(gv[0], 5.0,
                                           name='grad_clipping'), gv[1])
                          for gv in grads]

            # Update op
            backpass = optimizer.apply_gradients(clip_grads)

            log.info('Optimizer graph built')

            # Define init operation
            init_op = tf.initialize_all_variables()
            log.info('Variable initialization graph built')

        # Define op to save and restore variables
        saver = tf.train.Saver()
        log.info('Save operation built')
        # ---------------------------------------------------------------------

        # Fetch list for one training step; built once, it is loop-invariant.
        #   0: cost      1: update op   2: DKL        3: rec_loss
        #   4: dec_log_std_sq           5: enc_log_std_sq
        #   6: enc_mean  7: dec_mean    8: classifier output
        #   9...: unclipped gradients
        fetches = [
            cost_function, backpass, model.DKL, model.rec_loss,
            model.dec_log_std_sq, model.enc_log_std_sq,
            model.enc_mean, model.dec_mean, classifier
        ] + grad_tensors

        # Train loop ----------------------------------------------------------
        session_config = tf.ConfigProto(log_device_placement=True)
        with tf.Session(config=session_config) as sess:
            log.info('Session started')

            # Initialize all variables, then overwrite with the checkpoint.
            # NOTE(review): the checkpoint name is hard-coded.
            sess.run(init_op)
            saver.restore(
                sess,
                os.path.join(options['model_dir'], 'model_at_21000.ckpt'))
            log.info('Shared variables restored')

            # Define last losses to compute a running average
            last_losses = np.zeros(10)

            batch_abs_idx = 0
            for epoch_idx in xrange(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))

                for inputs, labels in train_provider:
                    batch_abs_idx += 1

                    result = sess.run(
                        fetches,
                        feed_dict={
                            model_input_batch: inputs,
                            model_label_batch: labels
                        })

                    (cost, _, dkl, rec_loss, dec_log_sig_sq, enc_log_sig_sq,
                     enc_mean, dec_mean, class_probs) = result[:9]

                    if batch_abs_idx % 10 == 0:
                        # NOTE(review): the value written under the
                        # 'Train Loss' header is the batch classification
                        # accuracy, not the cost.
                        batch_accuracy = np.mean(
                            np.argmax(labels, axis=1) == np.argmax(class_probs,
                                                                   axis=1))
                        write_row(train_log, batch_abs_idx, batch_accuracy)
                        write_row(dkl_log, batch_abs_idx, -np.mean(dkl))
                        write_row(ll_log, batch_abs_idx, -np.mean(rec_loss))

                        write_row(dec_sig_log, batch_abs_idx,
                                  np.mean(dec_log_sig_sq))
                        write_row(enc_sig_log, batch_abs_idx,
                                  np.mean(enc_log_sig_sq))

                        write_row(dec_std_sig_log, batch_abs_idx,
                                  np.std(dec_log_sig_sq))
                        write_row(enc_std_sig_log, batch_abs_idx,
                                  np.std(enc_log_sig_sq))

                        write_row(dec_mean_log, batch_abs_idx,
                                  np.mean(dec_mean))
                        write_row(enc_mean_log, batch_abs_idx,
                                  np.mean(enc_mean))

                    # Check cost: dump everything that was fetched and abort.
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')
                        for i, fetched in enumerate(result):
                            print("\n\nresult[%d]:" % i)
                            try:
                                print(np.any(np.isnan(fetched)))
                            except TypeError:
                                # Non-numeric fetch results (such as the one
                                # for the update op) cannot be NaN-checked.
                                pass
                            print(fetched)
                        print(result[3].shape)
                        print(model._encoder.layers[0].weights['w'].eval())
                        print('\n\nAny:')
                        print(np.any(np.isnan(result[8])))
                        print(np.any(np.isnan(result[9])))
                        print(np.any(np.isnan(result[10])))
                        print(inputs)
                        return 1., 1., 1.

                    # Update last losses
                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost

                    # Display training information.
                    # NOTE(review): the test is on the epoch index, so in a
                    # matching epoch every batch is logged.
                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info(
                            'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                            .format(epoch_idx + 1, options['n_epochs'],
                                    batch_abs_idx, float(cost),
                                    np.mean(last_losses)))
                        log.info('Batch Mean LL: {:0>15.4f}'.format(
                            np.mean(rec_loss, axis=0)))
                        log.info('Batch Mean -DKL: {:0>15.4f}'.format(
                            np.mean(dkl, axis=0)))

                    # Save model
                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(
                            sess,
                            os.path.join(options['model_dir'],
                                         'model_at_%d.ckpt' % batch_abs_idx))
                        log.info('Model saved')

                    # Validate model
                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:

                        valid_costs = []
                        seen_batches = 0
                        for val_batch, val_labels in val_provider:

                            val_result = sess.run(
                                [cost_function, classifier],
                                feed_dict={
                                    model_input_batch: val_batch,
                                    model_label_batch: val_labels
                                })
                            valid_costs.append(val_result[0])
                            seen_batches += 1

                            if seen_batches == options['valid_batches']:
                                break

                        # Print results
                        log.info('Validation loss: {:0>15.4f}'.format(
                            float(np.mean(valid_costs))))

                        # Decode a batch of latent codes drawn from a
                        # zero-mean, identity-covariance normal.
                        val_samples = sess.run(
                            sampler,
                            feed_dict={
                                sampler_input_batch:
                                MVN(np.zeros(options['latent_dims']),
                                    np.diag(np.ones(options['latent_dims'])),
                                    size=options['batch_size'])
                            })

                        # NOTE(review): this is the accuracy of the LAST
                        # validation batch only (the loop variables are reused
                        # after the loop), written under a 'Validation Loss'
                        # header.
                        write_row(
                            val_log, batch_abs_idx,
                            np.mean(
                                np.argmax(val_labels, axis=1) == np.argmax(
                                    val_result[1], axis=1)))

                        batch_img_shape = ([options['batch_size']] +
                                           options['img_shape'])
                        save_ae_samples(catalog,
                                        np.reshape(dec_mean, batch_img_shape),
                                        np.reshape(inputs, batch_img_shape),
                                        np.reshape(val_samples,
                                                   batch_img_shape),
                                        batch_abs_idx,
                                        options['dashboard_dir'],
                                        num_to_save=5,
                                        save_gray=True)

                        # Index of this validation round, used by the three
                        # sample dumps below.
                        validation_idx = int(
                            batch_abs_idx / options['freq_validation'])

                        save_samples(
                            val_samples,
                            validation_idx,
                            os.path.join(options['model_dir'],
                                         'valid_samples'),
                            True, options['img_shape'], 5)

                        save_samples(
                            inputs,
                            validation_idx,
                            os.path.join(options['model_dir'], 'input_sanity'),
                            True,
                            options['img_shape'],
                            num_to_save=5)

                        save_samples(
                            dec_mean,
                            validation_idx,
                            os.path.join(options['model_dir'], 'rec_sanity'),
                            True,
                            options['img_shape'],
                            num_to_save=5)

                log.info('End of epoch {}'.format(epoch_idx + 1))
    finally:
        # Release every dashboard file, including on the early NaN return and
        # on exceptions raised while building the graph or training.
        for handle in dashboard_files:
            handle.close()
def train(options):
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    options_file = open(os.path.join(options['dashboard_dir'], 'options'), 'w')
    options_file.write(options['description'] + '\n')
    options_file.write(
        'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
            options['DKL_weight'],
            -options['sigma_clip'],
            options['sigma_clip']
        )
    )
    for optn in options:
        options_file.write(optn)
        options_file.write(':\t')
        options_file.write(str(options[optn]))
        options_file.write('\n')
    options_file.close()

    # Dashboard Catalog
    catalog = open(os.path.join(options['dashboard_dir'], 'catalog'), 'w')
    catalog.write(
"""filename,type,name
options,plain,Options
train_loss.csv,csv,Train Loss
ll.csv,csv,Neg. Log-Likelihood
dec_log_sig_sq.csv,csv,Decoder Log Simga^2
dec_std_log_sig_sq.csv,csv,STD of Decoder Log Simga^2
dec_mean.csv,csv,Decoder Mean
dkl.csv,csv,DKL
enc_log_sig_sq.csv,csv,Encoder Log Sigma^2
enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2
enc_mean.csv,csv,Encoder Mean
val_loss.csv,csv,Validation Loss
"""
    )
    catalog.flush()
    train_log = open(os.path.join(options['dashboard_dir'], 'train_loss.csv'), 'w')
    val_log = open(os.path.join(options['dashboard_dir'], 'val_loss.csv'), 'w')
    dkl_log = open(os.path.join(options['dashboard_dir'], 'dkl.csv'), 'w')
    ll_log = open(os.path.join(options['dashboard_dir'], 'll.csv'), 'w')

    dec_sig_log = open(os.path.join(options['dashboard_dir'], 'dec_log_sig_sq.csv'), 'w')
    enc_sig_log = open(os.path.join(options['dashboard_dir'], 'enc_log_sig_sq.csv'), 'w')

    dec_std_sig_log = open(os.path.join(options['dashboard_dir'], 'dec_std_log_sig_sq.csv'), 'w')
    enc_std_sig_log = open(os.path.join(options['dashboard_dir'], 'enc_std_log_sig_sq.csv'), 'w')

    dec_mean_log = open(os.path.join(options['dashboard_dir'], 'dec_mean.csv'), 'w')
    enc_mean_log = open(os.path.join(options['dashboard_dir'], 'enc_mean.csv'), 'w')
    # val_sig_log = open(os.path.join(options['dashboard_dir'], 'val_log_sig_sq.csv'), 'w')

    train_log.write('step,time,Train Loss\n')
    val_log.write('step,time,Validation Loss\n')
    dkl_log.write('step,time,DKL\n')
    ll_log.write('step,time,-LL\n')

    dec_sig_log.write('step,time,Decoder Log Sigma^2\n')
    enc_sig_log.write('step,time,Encoder Log Sigma^2\n')

    dec_std_sig_log.write('step,time,STD of Decoder Log Sigma^2\n')
    enc_std_sig_log.write('step,time,STD of Encoder Log Sigma^2\n')

    dec_mean_log.write('step,time,Decoder Mean\n')
    enc_mean_log.write('step,time,Encoder Mean\n')

    # Print options
    utils.print_options(options, log)

    # Load dataset ----------------------------------------------------------------------
    # Train provider
    train_provider, val_provider, test_provider = get_providers(options, log, flat=True)

    # Initialize model ------------------------------------------------------------------
    with tf.device('/gpu:0'):

        # Define inputs ----------------------------------------------------------
        model_input_batch = tf.placeholder(
            tf.float32,
            shape = [options['batch_size'], np.prod(np.array(options['img_shape']))],
            name = 'enc_inputs'
        )
        sampler_input_batch = tf.placeholder(
            tf.float32,
            shape = [options['batch_size'], options['latent_dims']],
            name = 'dec_inputs'
        )
        log.info('Inputs defined')

        # Feature Extractor -----------------------------------------------------
        feat_layers = []
        feat_params = pickle.load(open(options['feat_params_path'], 'rb'))
        _classifier = Sequential('CNN_Classifier')

        conv_count, pool_count, fc_count = 0, 0, 0
        for lay in feat_params:
            print(lay['layer_type'])
        for i in xrange(options['num_feat_layers']):
            if feat_params[i]['layer_type'] == 'conv':
                _classifier += ConvLayer(
                    feat_params[i]['n_filters_in'],
                    feat_params[i]['n_filters_out'],
                    feat_params[i]['input_dim'],
                    feat_params[i]['filter_dim'],
                    feat_params[i]['strides'],
                    name='classifier_conv_%d' % conv_count
                )
                _classifier.layers[-1].weights['W'] = tf.constant(feat_params[i]['W'])
                _classifier.layers[-1].weights['b'] = tf.constant(feat_params[i]['b'])
                _classifier += feat_params[i]['act_fn']
                conv_count += 1
            elif feat_params[i]['layer_type'] == 'pool':
                _classifier += PoolLayer(
                    feat_params[i]['input_dim'],
                    feat_params[i]['filter_dim'],
                    feat_params[i]['strides'],
                    name='classifier_pool_%d' % i
                )
                pool_count += 1
                feat_layers.append(i)
            elif feat_params[i]['layer_type'] == 'fc':
                _classifier += ConstFC(
                    feat_params[i]['W'],
                    feat_params[i]['b'],
                    activation=feat_params[i]['act_fn'],
                    name='classifier_fc_%d' % fc_count
                )
                fc_count += 1
                feat_layers.append(i)

        # if options['feat_type'] == 'fc':
        #     feat_model = Sequential('feat_extractor')
        #     feat_params = pickle.load(open(options['feat_params_path'], 'rb'))
        #     for i in range(options['num_feat_layers']):
        #         feat_model += ConstFC(
        #             feat_params['enc_W'][i],
        #             feat_params['enc_b'][i],
        #             activation=feat_params['enc_act_fn'][i],
        #             name='feat_layer_%d'%i
        #         )
        # else:
        #     pass

        # VAE -------------------------------------------------------------------
        # VAE model
        vae_model = cupboard('vanilla_vae')(
            options['p_layers'],
            options['q_layers'],
            np.prod(options['img_shape']),
            options['latent_dims'],
            options['DKL_weight'],
            options['sigma_clip'],
            'vanilla_vae'
        )
        # -----------------------------------------------------------------------
        feat_vae = cupboard('feat_vae')(
            vae_model,
            _classifier,
            feat_layers,
            options['DKL_weight'],
            options['vae_rec_loss_weight'],
            img_shape=options['img_shape'],
            input_channels=options['input_channels'],
            flat=False, 
            name='feat_vae_model'
        )

        log.info('Model initialized')

        # Define forward pass
        cost_function = feat_vae(model_input_batch)
        log.info('Forward pass graph built')

        # Define sampler
        sampler = feat_vae.build_sampler(sampler_input_batch)
        log.info('Sampler graph built')

        # Define optimizer
        optimizer = tf.train.AdamOptimizer(
            learning_rate=options['lr']
        )
        # optimizer = tf.train.GradientDescentOptimizer(learning_rate=options['lr'])
        
        # train_step = optimizer.minimize(cost_function)
        log.info('Optimizer graph built')

        # Get gradients
        grads = optimizer.compute_gradients(cost_function)
        grads = [gv for gv in grads if gv[0] != None]
        grad_tensors = [gv[0] for gv in grads]

        # Clip gradients
        clip_grads = [(tf.clip_by_norm(gv[0], 5.0, name='grad_clipping'), gv[1]) for gv in grads]

        # Update op
        backpass = optimizer.apply_gradients(clip_grads)

        # Define init operation
        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')
    # --------------------------------------------------------------------------

    # Train loop ---------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        # Initialize shared variables or restore
        if options['reload']:
            saver.restore(sess, options['reload_file'])
            log.info('Shared variables restored')

            test_LL_and_DKL(sess, test_provider, feat_vae.vae.DKL, feat_vae.vae.rec_loss, options, model_input_batch)
            return

            mean_img = np.load(os.path.join(options['data_dir'], 'mean' + options['extension']))
            std_img = np.load(os.path.join(options['data_dir'], 'std' + options['extension']))
            visualize(sess, feat_vae.vae.dec_mean, feat_vae.vae.dec_log_std_sq, sampler, sampler_input_batch,
                        model_input_batch, feat_vae.vae.enc_mean, feat_vae.vae.enc_log_std_sq,
                        train_provider, val_provider, options, catalog, mean_img, std_img)
            return
        else:
            sess.run(init_op)
            log.info('Shared variables initialized')

        # Define last losses to compute a running average
        last_losses = np.zeros((10))

        batch_abs_idx = 0
        for epoch_idx in xrange(options['n_epochs']):
            batch_rel_idx = 0
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs,_ in train_provider:
                batch_abs_idx += 1
                batch_rel_idx += 1

                result = sess.run(
                    # (cost_function, train_step, model.enc_std, model.enc_mean, model.encoder, model.dec_std, model.dec_mean, model.decoder, model.rec_loss, model.DKL),
                    #       0           1               2           3               4               5               6              7               8            9           10
                    [cost_function, backpass, feat_vae.vae.DKL, feat_vae.vae.rec_loss, feat_vae.vae.dec_log_std_sq, feat_vae.vae.enc_log_std_sq, feat_vae.vae.enc_mean, feat_vae.vae.dec_mean] + [gv[0] for gv in grads],
                    feed_dict = {
                        model_input_batch: inputs
                    }
                )

                cost = result[0]

                if batch_abs_idx % 10 == 0:
                    train_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(last_losses)))
                    dkl_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', -np.mean(result[2])))
                    ll_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', -np.mean(result[3])))

                    train_log.flush()
                    dkl_log.flush()
                    ll_log.flush()

                    dec_sig_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(result[4])))
                    enc_sig_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(result[5])))
                    # val_sig_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(result[6])))

                    dec_sig_log.flush()
                    enc_sig_log.flush()

                    dec_std_sig_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.std(result[4])))
                    enc_std_sig_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.std(result[5])))

                    dec_mean_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(result[7])))
                    enc_mean_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(result[6])))

                    dec_std_sig_log.flush()
                    enc_std_sig_log.flush()

                    dec_mean_log.flush()
                    enc_mean_log.flush()                    
                    # val_sig_log.flush()

                # Check cost
                if np.isnan(cost) or np.isinf(cost):
                    log.info('NaN detected')
                    for i in range(len(result)):
                        print("\n\nresult[%d]:" % i)
                        try:
                            print(np.any(np.isnan(result[i])))
                        except:
                            pass
                        print(result[i])
                    print(result[3].shape)
                    print(model._encoder.layers[0].weights['w'].eval())
                    print('\n\nAny:')
                    print(np.any(np.isnan(result[8])))
                    print(np.any(np.isnan(result[9])))
                    print(np.any(np.isnan(result[10])))
                    print(inputs)
                    return 1., 1., 1.

                # Update last losses
                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost

                # Display training information
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'.format(
                        epoch_idx + 1,
                        options['n_epochs'],
                        batch_abs_idx,
                        float(cost),
                        np.mean(last_losses)
                    ))
                    log.info('Batch Mean LL: {:0>15.4f}'.format(np.mean(result[3], axis=0)))
                    log.info('Batch Mean -DKL: {:0>15.4f}'.format(np.mean(result[2], axis=0)))

                # Save model
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(sess, os.path.join(options['model_dir'], 'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                # Validate model
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:

                    valid_costs = []
                    seen_batches = 0
                    for val_batch,_ in val_provider:

                        val_cost = sess.run(
                            cost_function,
                            feed_dict = {
                                model_input_batch: val_batch
                            }
                        )
                        valid_costs.append(val_cost)
                        seen_batches += 1

                        if seen_batches == options['valid_batches']:
                            break

                    # Print results
                    log.info('Validation loss: {:0>15.4f}'.format(
                        float(np.mean(valid_costs))
                    ))

                    val_samples = sess.run(
                        sampler,
                        feed_dict = {
                            sampler_input_batch: MVN(
                                np.zeros(options['latent_dims']),
                                np.diag(np.ones(options['latent_dims'])),
                                size = options['batch_size']
                            )
                        }
                    )

                    val_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(valid_costs)))
                    val_log.flush()

                    save_ae_samples(
                        catalog,
                        np.reshape(result[7], [options['batch_size']]+options['img_shape']),
                        np.reshape(inputs, [options['batch_size']]+options['img_shape']),
                        np.reshape(val_samples, [options['batch_size']]+options['img_shape']),
                        batch_abs_idx,
                        options['dashboard_dir'],
                        num_to_save=5,
                        save_gray=True
                    )

                    # save_dash_samples(
                    #     catalog,
                    #     val_samples,
                    #     batch_abs_idx,
                    #     options['dashboard_dir'],
                    #     flat_samples=True,
                    #     img_shape=options['img_shape'],
                    #     num_to_save=5
                    # )

                    save_samples(
                        val_samples,
                        int(batch_abs_idx/options['freq_validation']),
                        os.path.join(options['model_dir'], 'valid_samples'),
                        True,
                        options['img_shape'],
                        5
                    )

                    save_samples(
                        inputs,
                        int(batch_abs_idx/options['freq_validation']),
                        os.path.join(options['model_dir'], 'input_sanity'),
                        True,
                        options['img_shape'],
                        num_to_save=5
                    )

                    save_samples(
                        result[7],
                        int(batch_abs_idx/options['freq_validation']),
                        os.path.join(options['model_dir'], 'rec_sanity'),
                        True,
                        options['img_shape'],
                        num_to_save=5
                    )


            log.info('End of epoch {}'.format(epoch_idx + 1))
    # Test Model --------------------------------------------------------------------------
        test_results = []

        for inputs in test_provider:
            if isinstance(inputs, tuple):
                inputs = inputs[0]
            batch_results = sess.run(
                [
                    feat_vae.vae.DKL, feat_vae.vae.rec_loss,
                    feat_vae.vae.dec_log_std_sq, feat_vae.vae.enc_log_std_sq,
                    feat_vae.vae.dec_mean, feat_vae.vae.enc_mean
                ],
                feed_dict = {
                    model_input_batch: inputs
                }
            )

            test_results.append(map(lambda p: np.mean(p, axis=1) if len(p.shape) > 1 else np.mean(p), batch_results))
        test_results = map(list, zip(*test_results))

        # Print results
        log.info('Test Mean Rec. Loss: {:0>15.4f}'.format(
            float(np.mean(test_results[1]))
        ))
        log.info('Test DKL: {:0>15.4f}'.format(
            float(np.mean(test_results[0]))
        ))
        log.info('Test Dec. Mean Log Std Sq: {:0>15.4f}'.format(
            float(np.mean(test_results[2]))
        ))
        log.info('Test Enc. Mean Log Std Sq: {:0>15.4f}'.format(
            float(np.mean(test_results[3]))
        ))
        log.info('Test Dec. Mean Mean: {:0>15.4f}'.format(
            float(np.mean(test_results[4]))
        ))
        log.info('Test Enc. Mean Mean: {:0>15.4f}'.format(
            float(np.mean(test_results[5]))
        ))
Exemple #3
0
def _write_options_file(options):
    """Dump the run description and every option to ``<dashboard_dir>/options``."""
    path = os.path.join(options['dashboard_dir'], 'options')
    # ``with`` closes the file even when one of the writes raises.
    with open(path, 'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write(
            'Log Sigma^2 clipped to: [{}, {}]\n\n'.format(
                -options['sigma_clip'],
                options['sigma_clip']
            )
        )
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')


def _open_dashboard_file(options, filename, opened):
    """Open ``<dashboard_dir>/<filename>`` for writing and record it in *opened*."""
    handle = open(os.path.join(options['dashboard_dir'], filename), 'w')
    # Registered immediately so the caller's ``finally`` can close it even if
    # a later open() fails.
    opened.append(handle)
    return handle


def _load_disc_params(options):
    """Unpickle the object stored at ``options['disc_params_path']``."""
    # NOTE(review): unpickling can execute arbitrary code; 'disc_params_path'
    # must only ever point at a trusted file.
    with open(options['disc_params_path'], 'rb') as params_file:
        return pickle.load(params_file)


def _sample_noise(options):
    """Draw one batch of generator inputs.

    Calls ``MVN`` with a zero mean vector and an identity covariance matrix of
    size ``latent_dims`` (``MVN`` is presumably
    ``numpy.random.multivariate_normal`` -- confirm against the file imports).
    """
    latent_dims = options['latent_dims']
    return MVN(
        np.zeros(latent_dims),
        np.diag(np.ones(latent_dims)),
        size=options['batch_size']
    )


def _select_backpass(batch_abs_idx, options, gen_backpass, disc_backpass):
    """Return ``(update_op, csv_row_template)`` for the given batch index.

    The template places the logged value in the CSV column that matches the
    current phase (columns: "Training Vanilla", "Training Gen.",
    "Training Disc.").
    """
    initial_g_iters = options['initial_G_iters']
    base = initial_g_iters + options['initial_D_iters']
    d_to_g = options['D_to_G']

    # Warm-up: generator updates first, logged in the first value column.
    if batch_abs_idx < initial_g_iters:
        return gen_backpass, '{},{},{},,\n'
    # Then the initial discriminator-only phase.
    if batch_abs_idx < base:
        return disc_backpass, '{},{},,,{}\n'
    # Afterwards alternate: the first D_to_G[0] steps of every cycle of
    # sum(D_to_G) steps update the discriminator, the rest the generator.
    if (batch_abs_idx - base) % sum(d_to_g) < d_to_g[0]:
        return disc_backpass, '{},{},,,{}\n'
    return gen_backpass, '{},{},,{},\n'


def _run_training(options, log, catalog, train_log, val_log, train_acc,
                  val_acc):
    """Build the generator/discriminator graph and run the training loop.

    Args:
        options: dict of run settings (see ``train``).
        log: logger used for progress messages.
        catalog: open dashboard catalog file, handed to ``save_ae_samples``.
        train_log, val_log: open CSV files receiving cross-entropy rows.
        train_acc, val_acc: open CSV files receiving accuracy rows.
    """
    # Print options
    utils.print_options(options, log)

    # Load dataset ----------------------------------------------------------
    # The test provider is returned by get_providers but not used here.
    train_provider, val_provider, _ = get_providers(options, log, flat=True)

    batch_size = options['batch_size']
    sample_shape = [batch_size] + options['img_shape']
    # Constant written to the 'time' column of every CSV row.
    timestamp = '2016-04-22'

    # Initialize model ------------------------------------------------------
    with tf.device('/gpu:0'):
        # Define inputs -----------------------------------------------------
        real_batch = tf.placeholder(
            tf.float32,
            shape=[batch_size, np.prod(np.array(options['img_shape']))],
            name='real_inputs'
        )
        sampler_input_batch = tf.placeholder(
            tf.float32,
            shape=[batch_size, options['latent_dims']],
            name='noise_channel'
        )
        # Column vector of targets: 1 for the first batch_size rows (real
        # inputs), 0 for the next batch_size rows (generated samples); this
        # matches the concatenation order of ``disc_inputs`` below.
        labels = tf.constant(
            np.expand_dims(
                np.concatenate(
                    (
                        np.ones(batch_size),
                        np.zeros(batch_size)
                    ),
                    axis=0
                ).astype(np.float32),
                axis=1
            )
        )
        labels = tf.cast(labels, tf.float32)
        log.info('Inputs defined')

        # Define model ------------------------------------------------------
        with tf.variable_scope('gen_scope'):
            generator = Sequential('generator')
            generator += FullyConnected(options['latent_dims'], 60, tf.nn.tanh, name='fc_1')
            generator += FullyConnected(60, 60, tf.nn.tanh, name='fc_2')
            generator += FullyConnected(60, np.prod(options['img_shape']), tf.nn.tanh, name='fc_3')

            sampler = generator(sampler_input_batch)

        with tf.variable_scope('disc_scope'):
            disc_model = cupboard('fixed_conv_disc')(
                _load_disc_params(options),
                options['num_feat_layers'],
                name='disc_model'
            )

            disc_inputs = tf.concat(0, [real_batch, sampler])
            disc_inputs = tf.reshape(
                disc_inputs,
                [disc_inputs.get_shape()[0].value] + options['img_shape'] + [options['input_channels']]
            )

            preds = disc_model(disc_inputs)
            # Keep predictions away from 0 and 1 so the logs below stay finite.
            preds = tf.clip_by_value(preds, 0.00001, 0.99999)

            # Disc accuracy: fraction of rounded predictions equal to labels
            disc_accuracy = (1 / float(labels.get_shape()[0].value)) * tf.reduce_sum(
                tf.cast(
                    tf.equal(
                        tf.round(preds),
                        labels
                    ),
                    tf.float32
                )
            )

        # Define losses -----------------------------------------------------
        # Discriminator cross-entropy, averaged over the 2 * batch_size rows
        disc_CE = (1 / float(labels.get_shape()[0].value)) * tf.reduce_sum(
            -tf.add(
                tf.mul(
                    labels,
                    tf.log(preds)
                ),
                tf.mul(
                    1.0 - labels,
                    tf.log(1.0 - preds)
                )
            )
        )

        # Generator loss: -log(pred) on the rows whose label is 0.
        gen_loss = -tf.mul(
            1.0 - labels,
            tf.log(preds)
        )

        # Define optimizers -------------------------------------------------
        optimizer = tf.train.AdamOptimizer(
            learning_rate=options['lr']
        )

        # Get generator and discriminator trainable variables
        gen_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'gen_scope')
        disc_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'disc_scope')

        # Generator gradients; variables without a gradient are skipped and
        # each remaining gradient is clipped to norm 5.
        grads = optimizer.compute_gradients(gen_loss, gen_train_vars)
        grads = [gv for gv in grads if gv[0] is not None]
        clip_grads = [(tf.clip_by_norm(gv[0], 5.0, name='gen_grad_clipping'), gv[1]) for gv in grads]
        gen_backpass = optimizer.apply_gradients(clip_grads)

        # Discriminator gradients, same treatment
        grads = optimizer.compute_gradients(disc_CE, disc_train_vars)
        grads = [gv for gv in grads if gv[0] is not None]
        clip_grads = [(tf.clip_by_norm(gv[0], 5.0, name='disc_grad_clipping'), gv[1]) for gv in grads]
        disc_backpass = optimizer.apply_gradients(clip_grads)

        log.info('Optimizer graph built')
        # -------------------------------------------------------------------
        # Define init operation
        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')
    # -----------------------------------------------------------------------

    # Train loop ------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        # Initialize shared variables or restore
        if options['reload_all']:
            saver.restore(sess, options['reload_file'])
            log.info('Shared variables restored')
        else:
            sess.run(init_op)
            log.info('Variables initialized')

        # Rolling windows (10 most recent batches) for the running averages
        last_losses = np.zeros((10))
        last_accs = np.zeros((10))

        batch_abs_idx = 0

        for epoch_idx in range(options['n_epochs']):
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs in train_provider:
                # Providers may yield tuples; only the first element is fed.
                if isinstance(inputs, tuple):
                    inputs = inputs[0]

                batch_abs_idx += 1

                backpass, log_format_string = _select_backpass(
                    batch_abs_idx, options, gen_backpass, disc_backpass
                )

                cost, _, accuracy = sess.run(
                    [
                        disc_CE,
                        backpass,
                        disc_accuracy
                    ],
                    feed_dict={
                        real_batch: inputs,
                        sampler_input_batch: _sample_noise(options)
                    }
                )

                # Every 10th batch: log the running means as they stood
                # before the current batch is added to the windows.
                if batch_abs_idx % 10 == 0:
                    train_log.write(log_format_string.format(batch_abs_idx, timestamp, np.mean(last_losses)))
                    train_acc.write(log_format_string.format(batch_abs_idx, timestamp, np.mean(last_accs)))

                    train_log.flush()
                    train_acc.flush()

                # Check cost (only reported; training carries on)
                if np.isnan(cost) or np.isinf(cost):
                    log.info('NaN detected')

                # Update last losses
                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost

                last_accs = np.roll(last_accs, 1)
                last_accs[0] = accuracy

                # Display training information
                # NOTE(review): keyed on the epoch index, so this fires for
                # every batch of a matching epoch -- confirm that is intended.
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'.format(
                        epoch_idx + 1,
                        options['n_epochs'],
                        batch_abs_idx,
                        float(cost),
                        np.mean(last_losses)
                    ))
                    log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last accuracies: {:0>15.4f}'.format(
                        epoch_idx + 1,
                        options['n_epochs'],
                        batch_abs_idx,
                        float(cost),
                        np.mean(last_accs)
                    ))

                # Save model
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(sess, os.path.join(options['model_dir'], 'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                # Validate model
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:

                    valid_costs = []
                    valid_accs = []
                    for val_batch in val_provider:
                        if isinstance(val_batch, tuple):
                            val_batch = val_batch[0]

                        val_cost, val_accuracy = sess.run(
                            [
                                disc_CE,
                                disc_accuracy
                            ],
                            feed_dict={
                                real_batch: val_batch,
                                sampler_input_batch: _sample_noise(options)
                            }
                        )
                        valid_costs.append(val_cost)
                        valid_accs.append(val_accuracy)

                        # Stop once the configured number of batches was seen.
                        if len(valid_costs) == options['valid_batches']:
                            break

                    # Print results
                    log.info('Validation loss: {:0>15.4f}'.format(
                        float(np.mean(valid_costs))
                    ))
                    log.info('Validation accuracies: {:0>15.4f}'.format(
                        float(np.mean(valid_accs))
                    ))

                    val_samples = sess.run(
                        sampler,
                        feed_dict={
                            sampler_input_batch: _sample_noise(options)
                        }
                    )

                    # Validation rows go to the column of the phase that is
                    # currently being trained.
                    val_log.write(log_format_string.format(batch_abs_idx, timestamp, np.mean(valid_costs)))
                    val_acc.write(log_format_string.format(batch_abs_idx, timestamp, np.mean(valid_accs)))
                    val_log.flush()
                    val_acc.flush()

                    # There is no reconstruction here, so an all-ones array is
                    # passed as the first image argument.
                    save_ae_samples(
                        catalog,
                        np.ones(sample_shape),
                        np.reshape(inputs, sample_shape),
                        np.reshape(val_samples, sample_shape),
                        batch_abs_idx,
                        options['dashboard_dir'],
                        num_to_save=5,
                        save_gray=True
                    )

            log.info('End of epoch {}'.format(epoch_idx + 1))


def train(options):
    """Train a fully-connected generator against a 'fixed_conv_disc' discriminator.

    Writes the options dump, the dashboard catalog and four CSV logs
    (train/validation cross-entropy and accuracy) under
    ``options['dashboard_dir']``, and checkpoints plus ``log.txt`` under
    ``options['model_dir']``. All dashboard files are closed when training
    finishes or fails.

    Args:
        options: dict of run settings. Keys read here or in the helpers:
            'model_dir', 'dashboard_dir', 'description', 'sigma_clip',
            'batch_size', 'img_shape' (list), 'latent_dims',
            'disc_params_path', 'num_feat_layers', 'input_channels', 'lr',
            'reload_all', 'reload_file', 'D_to_G', 'initial_G_iters',
            'initial_D_iters', 'n_epochs', 'freq_logging', 'freq_saving',
            'freq_validation', 'valid_batches'.

    Returns:
        None.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    _write_options_file(options)

    dashboard_files = []
    try:
        # Dashboard catalog
        catalog = _open_dashboard_file(options, 'catalog', dashboard_files)
        catalog.write(
            'filename,type,name\n'
            'options,plain,Options\n'
            'train_loss.csv,csv,Discriminator Cross-Entropy\n'
            'train_acc.csv,csv,Discriminator Accuracy\n'
            'val_loss.csv,csv,Validation Cross-Entropy\n'
            'val_acc.csv,csv,Validation Accuracy\n'
        )
        catalog.flush()

        train_log = _open_dashboard_file(options, 'train_loss.csv', dashboard_files)
        val_log = _open_dashboard_file(options, 'val_loss.csv', dashboard_files)
        train_acc = _open_dashboard_file(options, 'train_acc.csv', dashboard_files)
        val_acc = _open_dashboard_file(options, 'val_acc.csv', dashboard_files)

        train_log.write('step,time,Train CE (Training Vanilla),Train CE (Training Gen.),Train CE (Training Disc.)\n')
        val_log.write('step,time,Validation CE (Training Vanilla),Validation CE (Training Gen.),Validation CE (Training Disc.)\n')
        # The accuracy files carry accuracies, so their columns are labelled
        # "Acc." rather than "CE".
        train_acc.write('step,time,Train Acc. (Training Vanilla),Train Acc. (Training Gen.),Train Acc. (Training Disc.)\n')
        val_acc.write('step,time,Validation Acc. (Training Vanilla),Validation Acc. (Training Gen.),Validation Acc. (Training Disc.)\n')

        _run_training(options, log, catalog, train_log, val_log, train_acc,
                      val_acc)
    finally:
        for handle in dashboard_files:
            handle.close()
Exemple #4
0
def train(options):
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    options_file = open(os.path.join(options['dashboard_dir'], 'options'), 'w')
    options_file.write(options['description'] + '\n')
    options_file.write(
        'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
            options['DKL_weight'], -options['sigma_clip'],
            options['sigma_clip']))
    for optn in options:
        options_file.write(optn)
        options_file.write(':\t')
        options_file.write(str(options[optn]))
        options_file.write('\n')
    options_file.close()

    # Dashboard Catalog
    catalog = open(os.path.join(options['dashboard_dir'], 'catalog'), 'w')
    catalog.write("""filename,type,name
options,plain,Options
train_loss.csv,csv,Discriminator Cross-Entropy
ll.csv,csv,Neg. Log-Likelihood
dec_log_sig_sq.csv,csv,Decoder Log Simga^2
dec_std_log_sig_sq.csv,csv,STD of Decoder Log Simga^2
dec_mean.csv,csv,Decoder Mean
dkl.csv,csv,DKL
enc_log_sig_sq.csv,csv,Encoder Log Sigma^2
enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2
enc_mean.csv,csv,Encoder Mean
val_loss.csv,csv,Validation Loss
""")
    catalog.flush()
    train_log = open(os.path.join(options['dashboard_dir'], 'train_loss.csv'),
                     'w')
    val_log = open(os.path.join(options['dashboard_dir'], 'val_loss.csv'), 'w')
    dkl_log = open(os.path.join(options['dashboard_dir'], 'dkl.csv'), 'w')
    ll_log = open(os.path.join(options['dashboard_dir'], 'll.csv'), 'w')

    dec_sig_log = open(
        os.path.join(options['dashboard_dir'], 'dec_log_sig_sq.csv'), 'w')
    enc_sig_log = open(
        os.path.join(options['dashboard_dir'], 'enc_log_sig_sq.csv'), 'w')

    dec_std_sig_log = open(
        os.path.join(options['dashboard_dir'], 'dec_std_log_sig_sq.csv'), 'w')
    enc_std_sig_log = open(
        os.path.join(options['dashboard_dir'], 'enc_std_log_sig_sq.csv'), 'w')

    dec_mean_log = open(os.path.join(options['dashboard_dir'], 'dec_mean.csv'),
                        'w')
    enc_mean_log = open(os.path.join(options['dashboard_dir'], 'enc_mean.csv'),
                        'w')
    # val_sig_log = open(os.path.join(options['dashboard_dir'], 'val_log_sig_sq.csv'), 'w')

    train_log.write(
        'step,time,Train CE (Training Vanilla),Train CE (Training Gen.),Train CE (Training Disc.)\n'
    )
    val_log.write(
        'step,time,Validation CE (Training Vanilla),Validation CE (Training Gen.),Validation CE (Training Disc.)\n'
    )
    dkl_log.write(
        'step,time,DKL (Training Vanilla),DKL (Training Gen.),DKL (Training Disc.)\n'
    )
    ll_log.write(
        'step,time,-LL (Training Vanilla),-LL (Training Gen.),-LL (Training Disc.)\n'
    )

    dec_sig_log.write(
        'step,time,Decoder Log Sigma^2 (Training Vanilla),Decoder Log Sigma^2 (Training Gen.),Decoder Log Sigma^2 (Training Disc.)\n'
    )
    enc_sig_log.write(
        'step,time,Encoder Log Sigma^2 (Training Vanilla),Encoder Log Sigma^2 (Training Gen.),Encoder Log Sigma^2 (Training Disc.)\n'
    )

    dec_std_sig_log.write(
        'step,time,STD of Decoder Log Sigma^2 (Training Vanilla),STD of Decoder Log Sigma^2 (Training Gen.),STD of Decoder Log Sigma^2 (Training Disc.)\n'
    )
    enc_std_sig_log.write(
        'step,time,STD of Encoder Log Sigma^2 (Training Vanilla),STD of Encoder Log Sigma^2 (Training Gen.),STD of Encoder Log Sigma^2 (Training Disc.)\n'
    )

    dec_mean_log.write(
        'step,time,Decoder Mean (Training Vanilla),Decoder Mean (Training Gen.),Decoder Mean (Training Disc.)\n'
    )
    enc_mean_log.write(
        'step,time,Encoder Mean (Training Vanilla),Encoder Mean (Training Gen.),Encoder Mean (Training Disc.)\n'
    )

    # Print options
    utils.print_options(options, log)

    # Load dataset ----------------------------------------------------------------------
    # Train provider
    train_provider, val_provider, test_provider = get_providers(options,
                                                                log,
                                                                flat=True)

    # Initialize model ------------------------------------------------------------------
    with tf.device('/gpu:0'):
        # Define inputs
        model_input_batch = tf.placeholder(
            tf.float32,
            shape=[
                options['batch_size'],
                np.prod(np.array(options['img_shape']))
            ],
            name='enc_inputs')
        sampler_input_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], options['latent_dims']],
            name='dec_inputs')
        log.info('Inputs defined')

        # Define model
        with tf.variable_scope('vae_scope'):
            vae_model = cupboard('vanilla_vae')(
                options['p_layers'], options['q_layers'],
                np.prod(options['img_shape']), options['latent_dims'],
                options['DKL_weight'], options['sigma_clip'], 'vae_model')

        with tf.variable_scope('disc_scope'):
            disc_model = cupboard('fixed_conv_disc')(
                pickle.load(open(options['disc_params_path'], 'rb')),
                options['num_feat_layers'],
                name='disc_model')

        vae_gan = cupboard('vae_gan')(vae_model,
                                      disc_model,
                                      options['disc_weight'],
                                      options['img_shape'],
                                      options['input_channels'],
                                      'vae_scope',
                                      'disc_scope',
                                      name='vae_gan_model')

        # Define Optimizers ---------------------------------------------------------------------
        optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])

        vae_backpass, disc_backpass, vanilla_backpass = vae_gan(
            model_input_batch, sampler_input_batch, optimizer)

        log.info('Optimizer graph built')
        # --------------------------------------------------------------------------------------
        # Define init operation
        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')
    # --------------------------------------------------------------------------

    # Train loop ---------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        # Initialize shared variables or restore
        if options['reload_all']:
            saver.restore(sess, options['reload_file'])
            log.info('Shared variables restored')
        else:
            sess.run(init_op)
            log.info('Variables initialized')

            if options['reload_vae']:
                vae_model.reload_vae(options['vae_params_path'])

        # Define last losses to compute a running average
        last_losses = np.zeros((10))

        batch_abs_idx = 0
        D_to_G = options['D_to_G']
        total_D2G = sum(D_to_G)
        base = options['initial_G_iters'] + options['initial_D_iters']

        for epoch_idx in xrange(options['n_epochs']):
            batch_rel_idx = 0
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs in train_provider:
                if isinstance(inputs, tuple):
                    inputs = inputs[0]

                batch_abs_idx += 1
                batch_rel_idx += 1

                if batch_abs_idx < options['initial_G_iters']:
                    backpass = vanilla_backpass
                    log_format_string = '{},{},{},,\n'
                elif options['initial_G_iters'] <= batch_abs_idx < base:
                    backpass = disc_backpass
                    log_format_string = '{},{},,,{}\n'
                else:
                    if (batch_abs_idx - base) % total_D2G < D_to_G[0]:
                        backpass = disc_backpass
                        log_format_string = '{},{},,,{}\n'
                    else:
                        backpass = vae_backpass
                        log_format_string = '{},{},,{},\n'

                result = sess.run(
                    [
                        vae_gan.disc_CE, backpass, vae_gan._vae.DKL,
                        vae_gan._vae.rec_loss, vae_gan._vae.dec_log_std_sq,
                        vae_gan._vae.enc_log_std_sq, vae_gan._vae.enc_mean,
                        vae_gan._vae.dec_mean
                    ],
                    feed_dict={
                        model_input_batch:
                        inputs,
                        sampler_input_batch:
                        MVN(np.zeros(options['latent_dims']),
                            np.diag(np.ones(options['latent_dims'])),
                            size=options['batch_size'])
                    })

                cost = result[0]

                if batch_abs_idx % 10 == 0:
                    train_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 np.mean(last_losses)))
                    dkl_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 -np.mean(result[2])))
                    ll_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 -np.mean(result[3])))

                    train_log.flush()
                    dkl_log.flush()
                    ll_log.flush()

                    dec_sig_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 np.mean(result[4])))
                    enc_sig_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 np.mean(result[5])))
                    # val_sig_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(result[6])))

                    dec_sig_log.flush()
                    enc_sig_log.flush()

                    dec_std_sig_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 np.std(result[4])))
                    enc_std_sig_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 np.std(result[5])))

                    dec_mean_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 np.mean(result[7])))
                    enc_mean_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 np.mean(result[6])))

                    dec_std_sig_log.flush()
                    enc_std_sig_log.flush()

                    dec_mean_log.flush()
                    enc_mean_log.flush()

                # Check cost
                if np.isnan(cost) or np.isinf(cost):
                    log.info('NaN detected')
                    for i in range(len(result)):
                        print("\n\nresult[%d]:" % i)
                        try:
                            print(np.any(np.isnan(result[i])))
                        except:
                            pass
                        print(result[i])
                    print(result[3].shape)
                    print(vae_gan._vae._encoder.layers[0].weights['w'].eval())
                    print('\n\nAny:')
                    print(np.any(np.isnan(result[8])))
                    print(np.any(np.isnan(result[9])))
                    print(np.any(np.isnan(result[10])))
                    print(inputs)
                    return 1., 1., 1.

                # Update last losses
                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost

                # Display training information
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info(
                        'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                        .format(epoch_idx + 1, options['n_epochs'],
                                batch_abs_idx, float(cost),
                                np.mean(last_losses)))
                    log.info('Batch Mean LL: {:0>15.4f}'.format(
                        np.mean(result[3], axis=0)))
                    log.info('Batch Mean -DKL: {:0>15.4f}'.format(
                        np.mean(result[2], axis=0)))

                # Save model
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(
                        sess,
                        os.path.join(options['model_dir'],
                                     'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                    save_dict = {}
                    # Save encoder params ------------------------------------------------------------------
                    for i in range(len(vae_gan._vae._encoder.layers)):
                        layer_dict = {
                            'input_dim':
                            vae_gan._vae._encoder.layers[i].input_dim,
                            'output_dim':
                            vae_gan._vae._encoder.layers[i].output_dim,
                            'act_fn':
                            vae_gan._vae._encoder.layers[i].activation,
                            'W':
                            vae_gan._vae._encoder.layers[i].weights['w'].eval(
                            ),
                            'b':
                            vae_gan._vae._encoder.layers[i].weights['b'].eval(
                            )
                        }
                        save_dict['encoder'] = layer_dict

                    layer_dict = {
                        'input_dim': vae_gan._vae._enc_mean.input_dim,
                        'output_dim': vae_gan._vae._enc_mean.output_dim,
                        'act_fn': vae_gan._vae._enc_mean.activation,
                        'W': vae_gan._vae._enc_mean.weights['w'].eval(),
                        'b': vae_gan._vae._enc_mean.weights['b'].eval()
                    }
                    save_dict['enc_mean'] = layer_dict

                    layer_dict = {
                        'input_dim': vae_gan._vae._enc_log_std_sq.input_dim,
                        'output_dim': vae_gan._vae._enc_log_std_sq.output_dim,
                        'act_fn': vae_gan._vae._enc_log_std_sq.activation,
                        'W': vae_gan._vae._enc_log_std_sq.weights['w'].eval(),
                        'b': vae_gan._vae._enc_log_std_sq.weights['b'].eval()
                    }
                    save_dict['enc_log_std_sq'] = layer_dict

                    # Save decoder params ------------------------------------------------------------------
                    for i in range(len(vae_gan._vae._decoder.layers)):
                        layer_dict = {
                            'input_dim':
                            vae_gan._vae._decoder.layers[i].input_dim,
                            'output_dim':
                            vae_gan._vae._decoder.layers[i].output_dim,
                            'act_fn':
                            vae_gan._vae._decoder.layers[i].activation,
                            'W':
                            vae_gan._vae._decoder.layers[i].weights['w'].eval(
                            ),
                            'b':
                            vae_gan._vae._decoder.layers[i].weights['b'].eval(
                            )
                        }
                        save_dict['decoder'] = layer_dict

                    layer_dict = {
                        'input_dim': vae_gan._vae._dec_mean.input_dim,
                        'output_dim': vae_gan._vae._dec_mean.output_dim,
                        'act_fn': vae_gan._vae._dec_mean.activation,
                        'W': vae_gan._vae._dec_mean.weights['w'].eval(),
                        'b': vae_gan._vae._dec_mean.weights['b'].eval()
                    }
                    save_dict['dec_mean'] = layer_dict

                    layer_dict = {
                        'input_dim': vae_gan._vae._dec_log_std_sq.input_dim,
                        'output_dim': vae_gan._vae._dec_log_std_sq.output_dim,
                        'act_fn': vae_gan._vae._dec_log_std_sq.activation,
                        'W': vae_gan._vae._dec_log_std_sq.weights['w'].eval(),
                        'b': vae_gan._vae._dec_log_std_sq.weights['b'].eval()
                    }
                    save_dict['dec_log_std_sq'] = layer_dict

                    pickle.dump(
                        save_dict,
                        open(
                            os.path.join(options['model_dir'],
                                         'vae_dict_%d' % batch_abs_idx), 'wb'))

                # Validate model
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:

                    vae_gan._vae._decoder.layers[0].weights['w'].eval()[:5, :5]

                    valid_costs = []
                    seen_batches = 0
                    for val_batch in val_provider:
                        if isinstance(val_batch, tuple):
                            val_batch = val_batch[0]

                        val_cost = sess.run(
                            vae_gan.disc_CE,
                            feed_dict={
                                model_input_batch:
                                val_batch,
                                sampler_input_batch:
                                MVN(np.zeros(options['latent_dims']),
                                    np.diag(np.ones(options['latent_dims'])),
                                    size=options['batch_size'])
                            })
                        valid_costs.append(val_cost)
                        seen_batches += 1

                        if seen_batches == options['valid_batches']:
                            break

                    # Print results
                    log.info('Validation loss: {:0>15.4f}'.format(
                        float(np.mean(valid_costs))))

                    val_samples = sess.run(
                        vae_gan.sampler,
                        feed_dict={
                            sampler_input_batch:
                            MVN(np.zeros(options['latent_dims']),
                                np.diag(np.ones(options['latent_dims'])),
                                size=options['batch_size'])
                        })

                    val_log.write(
                        log_format_string.format(batch_abs_idx, '2016-04-22',
                                                 np.mean(valid_costs)))
                    val_log.flush()

                    save_ae_samples(catalog,
                                    np.reshape(result[7],
                                               [options['batch_size']] +
                                               options['img_shape']),
                                    np.reshape(inputs,
                                               [options['batch_size']] +
                                               options['img_shape']),
                                    np.reshape(val_samples,
                                               [options['batch_size']] +
                                               options['img_shape']),
                                    batch_abs_idx,
                                    options['dashboard_dir'],
                                    num_to_save=5,
                                    save_gray=True)

            log.info('End of epoch {}'.format(epoch_idx + 1))
Exemple #5
0
def train(options):
    """Train the autoencoder selected by ``options['model']`` on noisy inputs.

    The model is built through ``cupboard``, fed each clean training batch
    together with a copy perturbed by ``normal(...)`` noise of scale
    ``options['noise_std']``, and optimised with Adam; every gradient tensor is
    clipped to norm 5.0 before being applied. Progress goes to the logger and
    to CSV files in ``options['dashboard_dir']``; checkpoints and pickled
    encoder parameters go to ``options['model_dir']``.

    Args:
        options: dict of settings. Keys read by this function and its helpers:
            'model_dir', 'dashboard_dir', 'model', 'img_shape',
            'input_channels', 'enc_params', 'dec_params', 'batch_size', 'lr',
            'reload', 'n_epochs', 'noise_std', 'freq_logging', 'freq_saving',
            'freq_validation', 'valid_batches'.

    Returns:
        ``(1., 1., 1.)`` when a NaN/inf cost aborts training, otherwise None.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    dashboard_dir = options['dashboard_dir']

    # Dump every option as "<name>:<tab><value>", one per line
    with open(os.path.join(dashboard_dir, 'options'), 'w') as options_file:
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # The dashboard catalog and the CSV logs stay open for the whole run; the
    # `with` closes them on every exit path (normal end, NaN abort, exception).
    with open(os.path.join(dashboard_dir, 'catalog'), 'w') as catalog, \
            open(os.path.join(dashboard_dir, 'train_loss.csv'),
                 'w') as train_log, \
            open(os.path.join(dashboard_dir, 'val_loss.csv'),
                 'w') as val_log:
        # Dashboard Catalog
        catalog.write('filename,type,name\n'
                      'options,plain,Options\n'
                      'train_loss.csv,csv,Train Loss\n'
                      'val_loss.csv,csv,Validation Loss\n')
        catalog.flush()

        train_log.write('step,time,Train Loss\n')
        val_log.write('step,time,Validation Loss\n')

        # Print options
        utils.print_options(options, log)

        # Load dataset --------------------------------------------------------
        # The third (test) provider is not used by this training routine.
        train_provider, val_provider, _ = get_providers(options, log)

        # Initialize model ----------------------------------------------------
        with tf.device('/gpu:0'):
            if options['model'] == 'cnn_ae':
                # input_shape, input_channels, enc_params, dec_params, name
                model = cupboard(options['model'])(
                    options['img_shape'],
                    options['input_channels'],
                    options['enc_params'],
                    options['dec_params'],
                    'cnn_ae'
                )
                # Image-shaped batches: [batch] + img_shape + [channels]
                input_shape = ([options['batch_size']] + options['img_shape']
                               + [options['input_channels']])
            else:
                model = cupboard(options['model'])(
                    np.prod(options['img_shape']) * options['input_channels'],
                    options['enc_params'],
                    options['dec_params'],
                    'ae'
                )
                # Flattened batches: [batch, prod(img_shape) * channels]
                input_shape = [
                    options['batch_size'],
                    np.prod(options['img_shape']) * options['input_channels']
                ]

            # Define inputs (clean target and noisy network input share a shape)
            model_clean_input_batch = tf.placeholder(
                tf.float32, shape=input_shape, name='clean')
            model_noisy_input_batch = tf.placeholder(
                tf.float32, shape=input_shape, name='noisy')
            log.info('Inputs defined')

            log.info('Model initialized')

            # Define forward pass
            print(model_clean_input_batch.get_shape())
            print(model_noisy_input_batch.get_shape())
            cost_function = model(model_clean_input_batch,
                                  model_noisy_input_batch)
            log.info('Forward pass graph built')

            # NOTE(review): no sampler is built in this function; the message
            # is kept as-is so existing log output does not change.
            log.info('Sampler graph built')

            # Define optimizer
            optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])
            log.info('Optimizer graph built')

            # Get gradients, dropping variables that received no gradient.
            # Identity test: `!= None` would defer to the tensor's own `!=`.
            grads = [gv for gv in optimizer.compute_gradients(cost_function)
                     if gv[0] is not None]
            grad_tensors = [gv[0] for gv in grads]

            # Clip gradients (each tensor independently, to norm 5.0)
            clip_grads = [
                (tf.clip_by_norm(grad, 5.0, name='grad_clipping'), var)
                for grad, var in grads
            ]

            # Update op
            backpass = optimizer.apply_gradients(clip_grads)

            # Define init operation
            init_op = tf.initialize_all_variables()
            log.info('Variable initialization graph built')

        # Define op to save and restore variables
        saver = tf.train.Saver()
        log.info('Save operation built')

        # Train loop ----------------------------------------------------------
        session_config = tf.ConfigProto(log_device_placement=True)
        with tf.Session(config=session_config) as sess:
            log.info('Session started')

            # Initialize shared variables or restore
            if options['reload']:
                saver.restore(
                    sess, os.path.join(options['model_dir'], 'model.ckpt'))
                log.info('Shared variables restored')
            else:
                sess.run(init_op)
                log.info('Shared variables initialized')

            # Last 10 costs, used as a running average
            last_losses = np.zeros((10))
            # Loop-invariant noise scale, converted once
            noise_std = np.float32(options['noise_std'])

            batch_abs_idx = 0
            for epoch_idx in xrange(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))

                for inputs, _ in train_provider:
                    batch_abs_idx += 1

                    # result[0] is the cost, result[1] the update op's fetch,
                    # result[2:] the unclipped gradient values.
                    result = sess.run(
                        [cost_function, backpass] + grad_tensors,
                        feed_dict={
                            model_clean_input_batch: inputs,
                            model_noisy_input_batch:
                                np.float32(inputs) + normal(
                                    loc=0.0,
                                    scale=noise_std,
                                    size=inputs.shape
                                )
                        }
                    )

                    cost = result[0]

                    if batch_abs_idx % 10 == 0:
                        # Written before `last_losses` is updated below, so the
                        # average excludes the current batch. The "time" column
                        # is a fixed placeholder date.
                        train_log.write('{},{},{}\n'.format(
                            batch_abs_idx, '2016-04-22',
                            np.mean(last_losses)))
                        train_log.flush()

                    # Check cost
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')
                        _dump_nonfinite_diagnostics(result, model, inputs)
                        return 1., 1., 1.

                    # Update last losses
                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost

                    # Display training information
                    # NOTE(review): keyed on the epoch index, so every batch of
                    # a matching epoch is logged - confirm this is intended.
                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info(
                            'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'.format(
                                epoch_idx + 1,
                                options['n_epochs'],
                                batch_abs_idx,
                                float(cost),
                                np.mean(last_losses)
                            ))

                    # Save model
                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(
                            sess,
                            os.path.join(options['model_dir'],
                                         'model_at_%d.ckpt' % batch_abs_idx))
                        log.info('Model saved')
                        _save_encoder_params(options, model, batch_abs_idx)

                    # Validate model
                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                        _validate_ae(sess, options, log, catalog, val_log,
                                     val_provider, model, cost_function,
                                     model_clean_input_batch,
                                     model_noisy_input_batch, batch_abs_idx)

                log.info('End of epoch {}'.format(epoch_idx + 1))


def _dump_nonfinite_diagnostics(result, model, inputs):
    """Print the fetched values after a NaN/inf cost to help locate its origin.

    Must be called while the training session is the default session, because
    the encoder weights are read with ``.eval()``.
    """
    for i, value in enumerate(result):
        print("\n\nresult[%d]:" % i)
        try:
            print(np.any(np.isnan(value)))
        except (TypeError, ValueError):
            # Some fetches are not numeric (presumably the update op's result);
            # skip the NaN flag for those and just print the value.
            pass
        print(value)

    # The fixed indices below used to be dereferenced unconditionally and
    # raised IndexError when fewer tensors were fetched, which hid the abort.
    if len(result) > 3:
        print(result[3].shape)
    print(model._encoder.layers[0].weights['w'].eval())
    print('\n\nAny:')
    for i in (8, 9, 10):
        if i < len(result):
            print(np.any(np.isnan(result[i])))
    print(inputs)


def _save_encoder_params(options, model, batch_abs_idx):
    """Pickle encoder weights, biases and activations to ``enc_dict_<step>``.

    For the 'cnn_ae' model the lists are written empty, as before.
    """
    save_dict = {
        'enc_W': [],
        'enc_b': [],
        'enc_act_fn': [],
    }
    if options['model'] != 'cnn_ae':
        for i, layer in enumerate(model._encoder.layers):
            save_dict['enc_W'].append(layer.weights['w'].eval())
            save_dict['enc_b'].append(layer.weights['b'].eval())
            save_dict['enc_act_fn'].append(options['enc_params']['act_fn'][i])

    dict_path = os.path.join(options['model_dir'],
                             'enc_dict_%d' % batch_abs_idx)
    # Context manager so the pickle file is flushed and closed immediately
    with open(dict_path, 'wb') as dict_file:
        pickle.dump(save_dict, dict_file)


def _validate_ae(sess, options, log, catalog, val_log, val_provider, model,
                 cost_function, clean_input, noisy_input, batch_abs_idx):
    """Run up to ``options['valid_batches']`` validation batches and report.

    Logs and records the mean validation cost, then passes the last clean
    batch, its noisy copy and the fetched ``model.decoder`` output (presumably
    the reconstruction) to ``save_ae_samples``.
    """
    valid_costs = []
    val_batch = noisy_val_batch = val_results = None
    noise_std = np.float32(options['noise_std'])

    for val_batch, _ in val_provider:
        noisy_val_batch = val_batch + normal(
            loc=0.0,
            scale=noise_std,
            size=val_batch.shape
        )
        val_results = sess.run(
            (cost_function, model.decoder),
            feed_dict={
                clean_input: val_batch,
                noisy_input: noisy_val_batch
            }
        )
        valid_costs.append(val_results[0])

        if len(valid_costs) == options['valid_batches']:
            break

    if not valid_costs:
        # Nothing was evaluated; the code below would otherwise fail on the
        # undefined last batch.
        log.info('Validation skipped: validation provider yielded no batches')
        return

    mean_val_cost = np.mean(valid_costs)

    # Print results
    log.info('Validation loss: {:0>15.4f}'.format(float(mean_val_cost)))

    val_log.write('{},{},{}\n'.format(
        batch_abs_idx, '2016-04-22', mean_val_cost))
    val_log.flush()

    # 'cnn_ae' is the name used when the model is built; the historical
    # 'conv_ae' spelling is still accepted so its behaviour is unchanged.
    if options['model'] in ('cnn_ae', 'conv_ae'):
        val_recon = np.reshape(val_results[-1], val_batch.shape)
    else:
        # Flat batches are reshaped back to images for the dashboard
        image_shape = ([val_batch.shape[0]] + options['img_shape']
                       + [options['input_channels']])
        val_batch = np.reshape(val_batch, image_shape)
        noisy_val_batch = np.reshape(noisy_val_batch, image_shape)
        val_recon = np.reshape(val_results[-1], image_shape)

    save_ae_samples(
        catalog,
        val_batch,
        noisy_val_batch,
        val_recon,
        batch_abs_idx,
        options['dashboard_dir'],
        num_to_save=5,
        save_gray=True
    )
def train(options):
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    options_file = open(os.path.join(options['dashboard_dir'], 'options'), 'w')
    options_file.write(options['description'] + '\n')
    options_file.write(
        'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
            options['DKL_weight'], -options['sigma_clip'],
            options['sigma_clip']))
    for optn in options:
        options_file.write(optn)
        options_file.write(':\t')
        options_file.write(str(options[optn]))
        options_file.write('\n')
    options_file.close()

    with open(os.path.join(options['dashboard_dir'], 'description'),
              'w') as desc_file:
        desc_file.write(options['description'])

    # Dashboard Catalog
    catalog = open(os.path.join(options['dashboard_dir'], 'catalog'), 'w')
    catalog.write("""filename,type,name
description,plain,Description
options,plain,Options
train_loss.csv,csv,Train Loss
ll.csv,csv,Neg. Log-Likelihood
dec_log_sig_sq.csv,csv,Decoder Log Simga^2
dec_std_log_sig_sq.csv,csv,STD of Decoder Log Simga^2
dec_mean.csv,csv,Decoder Mean
dkl.csv,csv,DKL
enc_log_sig_sq.csv,csv,Encoder Log Sigma^2
enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2
enc_mean.csv,csv,Encoder Mean
val_loss.csv,csv,Validation Loss
""")
    catalog.flush()
    train_log = open(os.path.join(options['dashboard_dir'], 'train_loss.csv'),
                     'w')
    val_log = open(os.path.join(options['dashboard_dir'], 'val_loss.csv'), 'w')
    dkl_log = open(os.path.join(options['dashboard_dir'], 'dkl.csv'), 'w')
    ll_log = open(os.path.join(options['dashboard_dir'], 'll.csv'), 'w')

    dec_sig_log = open(
        os.path.join(options['dashboard_dir'], 'dec_log_sig_sq.csv'), 'w')
    enc_sig_log = open(
        os.path.join(options['dashboard_dir'], 'enc_log_sig_sq.csv'), 'w')

    dec_std_sig_log = open(
        os.path.join(options['dashboard_dir'], 'dec_std_log_sig_sq.csv'), 'w')
    enc_std_sig_log = open(
        os.path.join(options['dashboard_dir'], 'enc_std_log_sig_sq.csv'), 'w')

    dec_mean_log = open(os.path.join(options['dashboard_dir'], 'dec_mean.csv'),
                        'w')
    enc_mean_log = open(os.path.join(options['dashboard_dir'], 'enc_mean.csv'),
                        'w')
    # val_sig_log = open(os.path.join(options['dashboard_dir'], 'val_log_sig_sq.csv'), 'w')

    train_log.write('step,time,Train Loss\n')
    val_log.write('step,time,Validation Loss\n')
    dkl_log.write('step,time,DKL\n')
    ll_log.write('step,time,-LL\n')

    dec_sig_log.write('step,time,Decoder Log Sigma^2\n')
    enc_sig_log.write('step,time,Encoder Log Sigma^2\n')

    dec_std_sig_log.write('step,time,STD of Decoder Log Sigma^2\n')
    enc_std_sig_log.write('step,time,STD of Encoder Log Sigma^2\n')

    dec_mean_log.write('step,time,Decoder Mean\n')
    enc_mean_log.write('step,time,Encoder Mean\n')

    # Print options
    utils.print_options(options, log)

    # Load dataset ----------------------------------------------------------------------
    # Train provider
    train_provider, val_provider, test_provider = get_providers(options,
                                                                log,
                                                                flat=True)

    # Initialize model ------------------------------------------------------------------
    with tf.device('/gpu:0'):

        # Define inputs ----------------------------------------------------------
        model_input_batch = tf.placeholder(
            tf.float32,
            shape=[
                options['batch_size'],
                np.prod(np.array(options['img_shape']))
            ],
            name='enc_inputs')
        sampler_input_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], options['latent_dims']],
            name='dec_inputs')
        log.info('Inputs defined')

        # Discriminator ---------------------------------------------------------
        # with tf.variable_scope('disc_scope'):
        #     disc_model = cupboard('fixed_conv_disc')(
        #         pickle.load(open(options['feat_params_path'], 'rb')),
        #         options['num_feat_layers'],
        #         'discriminator'
        #     )

        # VAE -------------------------------------------------------------------
        # VAE model
        # with tf.variable_scope('vae_scope'):
        vae_model = cupboard('vanilla_vae')(
            options['p_layers'], options['q_layers'],
            np.prod(options['img_shape']), options['latent_dims'],
            options['DKL_weight'], options['sigma_clip'], 'vanilla_vae')
        # VAE/GAN ---------------------------------------------------------------
        # vae_gan = cupboard('vae_gan')(
        #     vae_model,
        #     disc_model,
        #     options['img_shape'],
        #     options['input_channels'],
        #     'vae_scope',
        #     'disc_scope',
        #     name = 'vae_gan_model'
        # )

        log.info('Model initialized')

        # Define optimizer
        optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])
        # Define forward pass
        cost_function = vae_model(model_input_batch)
        # backpass, grads = vae_gan(model_input_batch, sampler_input_batch, optimizer)
        log.info('Forward pass graph built')

        # Define sampler
        # sampler = vae_gan.sampler
        sampler = vae_model.build_sampler(sampler_input_batch)
        log.info('Sampler graph built')

        # optimizer = tf.train.GradientDescentOptimizer(learning_rate=options['lr'])

        # train_step = optimizer.minimize(cost_function)
        log.info('Optimizer graph built')

        # Get gradients
        grads = optimizer.compute_gradients(cost_function)
        grads = [gv for gv in grads if gv[0] != None]
        grad_tensors = [gv[0] for gv in grads]

        # Clip gradients
        clip_grads = [(tf.clip_by_norm(gv[0], 5.0,
                                       name='grad_clipping'), gv[1])
                      for gv in grads]

        # Update op
        backpass = optimizer.apply_gradients(clip_grads)

        # Define init operation
        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')
    # --------------------------------------------------------------------------

    # Train loop ---------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        # Initialize shared variables or restore
        if options['reload']:
            saver.restore(sess, options['reload_file'])
            log.info('Shared variables restored')

            test_LL_and_DKL(sess, test_provider, feat_vae.vae.DKL,
                            feat_vae.vae.rec_loss, options, model_input_batch)
            return

            mean_img = np.load(
                os.path.join(options['data_dir'],
                             'mean' + options['extension']))
            std_img = np.load(
                os.path.join(options['data_dir'],
                             'std' + options['extension']))
            visualize(sess, feat_vae.vae.dec_mean, feat_vae.vae.dec_log_std_sq,
                      sampler, sampler_input_batch, model_input_batch,
                      feat_vae.vae.enc_mean, feat_vae.vae.enc_log_std_sq,
                      train_provider, val_provider, options, catalog, mean_img,
                      std_img)
            return
        else:
            sess.run(init_op)
            log.info('Shared variables initialized')

        # Define last losses to compute a running average
        last_losses = np.zeros((10))

        batch_abs_idx = 0
        D_to_G = options['D_to_G']
        total_D2G = sum(D_to_G)
        for epoch_idx in xrange(options['n_epochs']):
            batch_rel_idx = 0
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs, _ in train_provider:
                batch_abs_idx += 1
                batch_rel_idx += 1

                # if batch_abs_idx < options['initial_G_iters']:
                #     optimizer = vae_optimizer
                # else:
                #     optimizer = disc_optimizer
                # if batch_abs_idx % total_D2G < D_to_G[0]:
                #     optimizer = disc_optimizer
                # else:
                #     optimizer = vae_optimizer
                result = sess.run([
                    cost_function,
                    backpass,
                    vae_model.DKL,
                    vae_model.rec_loss,
                    vae_model.dec_log_std_sq,
                    vae_model.enc_log_std_sq,
                    vae_model.enc_mean,
                    vae_model.dec_mean,
                ] + [gv[0] for gv in grads],
                                  feed_dict={model_input_batch: inputs})

                # print('#'*80)
                # print(result[-1])
                # print('#'*80)

                cost = result[0]

                if batch_abs_idx % 10 == 0:
                    train_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.mean(last_losses)))
                    dkl_log.write('{},{},{}\n'.format(batch_abs_idx,
                                                      '2016-04-22',
                                                      -np.mean(result[2])))
                    ll_log.write('{},{},{}\n'.format(batch_abs_idx,
                                                     '2016-04-22',
                                                     -np.mean(result[3])))

                    train_log.flush()
                    dkl_log.flush()
                    ll_log.flush()

                    dec_sig_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.mean(result[4])))
                    enc_sig_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.mean(result[5])))
                    # val_sig_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(result[6])))

                    dec_sig_log.flush()
                    enc_sig_log.flush()

                    dec_std_sig_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.std(result[4])))
                    enc_std_sig_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.std(result[5])))

                    dec_mean_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.mean(result[7])))
                    enc_mean_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.mean(result[6])))

                    dec_std_sig_log.flush()
                    enc_std_sig_log.flush()

                    dec_mean_log.flush()
                    enc_mean_log.flush()
                    # val_sig_log.flush()

                # Check cost
                if np.isnan(cost) or np.isinf(cost):
                    log.info('NaN detected')
                    for i in range(len(result)):
                        print("\n\nresult[%d]:" % i)
                        try:
                            print(np.any(np.isnan(result[i])))
                        except:
                            pass
                        print(result[i])
                    print(result[3].shape)
                    print(model._encoder.layers[0].weights['w'].eval())
                    print('\n\nAny:')
                    print(np.any(np.isnan(result[8])))
                    print(np.any(np.isnan(result[9])))
                    print(np.any(np.isnan(result[10])))
                    print(inputs)
                    return 1., 1., 1.

                # Update last losses
                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost

                # Display training information
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info(
                        'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                        .format(epoch_idx + 1, options['n_epochs'],
                                batch_abs_idx, float(cost),
                                np.mean(last_losses)))
                    log.info('Batch Mean LL: {:0>15.4f}'.format(
                        np.mean(result[3], axis=0)))
                    log.info('Batch Mean -DKL: {:0>15.4f}'.format(
                        np.mean(result[2], axis=0)))
                    # log.info('Batch Mean Acc.: {:0>15.4f}'.format(result[-2], axis=0))

                # Save model
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(
                        sess,
                        os.path.join(options['model_dir'],
                                     'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                # Validate model
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:

                    valid_costs = []
                    seen_batches = 0
                    for val_batch, _ in val_provider:

                        val_cost = sess.run(
                            vae_model.cost,
                            feed_dict={
                                model_input_batch:
                                val_batch,
                                sampler_input_batch:
                                MVN(np.zeros(options['latent_dims']),
                                    np.diag(np.ones(options['latent_dims'])),
                                    size=options['batch_size'])
                            })
                        valid_costs.append(val_cost)
                        seen_batches += 1

                        if seen_batches == options['valid_batches']:
                            break

                    # Print results
                    log.info('Validation loss: {:0>15.4f}'.format(
                        float(np.mean(valid_costs))))

                    val_samples = sess.run(
                        sampler,
                        feed_dict={
                            sampler_input_batch:
                            MVN(np.zeros(options['latent_dims']),
                                np.diag(np.ones(options['latent_dims'])),
                                size=options['batch_size'])
                        })

                    val_log.write('{},{},{}\n'.format(batch_abs_idx,
                                                      '2016-04-22',
                                                      np.mean(valid_costs)))
                    val_log.flush()

                    save_ae_samples(catalog,
                                    np.reshape(result[7],
                                               [options['batch_size']] +
                                               options['img_shape']),
                                    np.reshape(inputs,
                                               [options['batch_size']] +
                                               options['img_shape']),
                                    np.reshape(val_samples,
                                               [options['batch_size']] +
                                               options['img_shape']),
                                    batch_abs_idx,
                                    options['dashboard_dir'],
                                    num_to_save=5,
                                    save_gray=True)

                    # save_samples(
                    #     val_samples,
                    #     int(batch_abs_idx/options['freq_validation']),
                    #     os.path.join(options['model_dir'], 'valid_samples'),
                    #     True,
                    #     options['img_shape'],
                    #     5
                    # )

                    # save_samples(
                    #     inputs,
                    #     int(batch_abs_idx/options['freq_validation']),
                    #     os.path.join(options['model_dir'], 'input_sanity'),
                    #     True,
                    #     options['img_shape'],
                    #     num_to_save=5
                    # )

                    # save_samples(
                    #     result[8],
                    #     int(batch_abs_idx/options['freq_validation']),
                    #     os.path.join(options['model_dir'], 'rec_sanity'),
                    #     True,
                    #     options['img_shape'],
                    #     num_to_save=5
                    # )

            log.info('End of epoch {}'.format(epoch_idx + 1))
    # Test Model --------------------------------------------------------------------------
        test_results = []

        for inputs in test_provider:
            if isinstance(inputs, tuple):
                inputs = inputs[0]
            batch_results = sess.run([
                feat_vae.vae.DKL, feat_vae.vae.rec_loss,
                feat_vae.vae.dec_log_std_sq, feat_vae.vae.enc_log_std_sq,
                feat_vae.vae.dec_mean, feat_vae.vae.enc_mean
            ],
                                     feed_dict={model_input_batch: inputs})

            test_results.append(
                map(
                    lambda p: np.mean(p, axis=1)
                    if len(p.shape) > 1 else np.mean(p), batch_results))
        test_results = map(list, zip(*test_results))

        # Print results
        log.info('Test Mean Rec. Loss: {:0>15.4f}'.format(
            float(np.mean(test_results[1]))))
        log.info('Test DKL: {:0>15.4f}'.format(float(np.mean(
            test_results[0]))))
        log.info('Test Dec. Mean Log Std Sq: {:0>15.4f}'.format(
            float(np.mean(test_results[2]))))
        log.info('Test Enc. Mean Log Std Sq: {:0>15.4f}'.format(
            float(np.mean(test_results[3]))))
        log.info('Test Dec. Mean Mean: {:0>15.4f}'.format(
            float(np.mean(test_results[4]))))
        log.info('Test Enc. Mean Mean: {:0>15.4f}'.format(
            float(np.mean(test_results[5]))))
Exemple #7
0
def visualize(sampler_mean, sess, dec_mean, dec_log_std_sq, sampler,
              sampler_input_batch, model_input_batch, enc_mean,
              enc_log_std_sq, train_provider, val_provider, options, catalog,
              mean_img, std_img):
    """Generate and save sample, reconstruction and input images.

    Steps performed, in order:

    1. Draw ``options['batch_size']`` latent vectors from N(0, I), run
       ``sampler_mean`` on them and de-normalize with ``std_img``/``mean_img``.
    2. Take the first batch of ``val_provider``, run ``dec_mean`` on it and
       de-normalize both the reconstruction and the inputs.
    3. Save the three image sets with ``save_ae_samples`` and ``save_samples``
       (indices 0, 1, 2) into ``options['visu_save_dir']``.
    4. Sample latent codes from the encoder for the first 11 batches of
       ``train_provider``, fit a per-dimension Gaussian to them, draw new
       latent vectors from that Gaussian, decode them with ``sampler_mean``
       and save the result with index 3.

    Args:
        sampler_mean: Graph output evaluated for a fed latent batch.
        sess: Session used to evaluate all graph outputs.
        dec_mean: Graph output evaluated for a fed input batch
            (used as the reconstruction).
        dec_log_std_sq: Unused; kept for interface compatibility.
        sampler: Unused; kept for interface compatibility.
        sampler_input_batch: Placeholder fed with latent vectors.
        model_input_batch: Placeholder fed with data batches.
        enc_mean: Encoder mean tensor.
        enc_log_std_sq: Encoder log-variance tensor (used as
            ``exp(0.5 * enc_log_std_sq)`` to scale the noise).
        train_provider: Iterable of training batches; tuple batches are
            reduced to their first element.
        val_provider: Iterable of validation batches; tuple batches are
            reduced to their first element.
        options: Dict read for ``latent_dims``, ``batch_size``,
            ``img_shape`` and ``visu_save_dir``.
        catalog: Passed through to ``save_ae_samples``.
        mean_img: Array flattened and added back when de-normalizing.
        std_img: Array flattened and multiplied back when de-normalizing.

    Raises:
        StopIteration: If ``val_provider`` yields no batch.
        ValueError: If ``train_provider`` yields no batch.
    """
    from numpy.random import multivariate_normal as MVN

    mean_img = mean_img.flatten()
    std_img = std_img.flatten()

    batch_size = options['batch_size']
    latent_dims = options['latent_dims']
    save_dir = options['visu_save_dir']
    batch_image_shape = [batch_size] + options['img_shape']

    # Validation Samples ------------------------------------------------------
    print('Generate Samples from N(0,I)')
    val_samples = sess.run(
        sampler_mean,
        feed_dict={
            sampler_input_batch: MVN(np.zeros(latent_dims),
                                     np.eye(latent_dims),
                                     size=batch_size)
        })
    val_samples = (val_samples * std_img) + mean_img

    # Only the first validation batch is reconstructed.
    inputs = next(iter(val_provider))
    if isinstance(inputs, tuple):
        inputs = inputs[0]
    rec_samples = sess.run(dec_mean, feed_dict={model_input_batch: inputs})

    # Reconstruction Samples --------------------------------------------------
    print('Generate Reconstruction Samples')
    print("NOT STUCK HERE!")

    # The decoder mean is used directly as the reconstruction.
    recons = (rec_samples * std_img) + mean_img
    inputs = (inputs * std_img) + mean_img

    print("NOT STUCK HERE!")

    try:
        os.mkdir(save_dir)
    except OSError:
        # Best effort: the directory usually exists already. Any other
        # problem will surface when the images are written below.
        pass

    save_ae_samples(catalog,
                    np.reshape(recons, batch_image_shape),
                    np.reshape(inputs, batch_image_shape),
                    np.reshape(val_samples, batch_image_shape),
                    100,
                    save_dir,
                    num_to_save=10,
                    save_gray=True)
    save_samples(val_samples, 0, save_dir, True, options['img_shape'], 10)
    save_samples(recons, 1, save_dir, True, options['img_shape'], 10)
    save_samples(inputs, 2, save_dir, True, options['img_shape'], 10)

    # Gaussian Sampling -------------------------------------------------------
    print('Fit Gaussian to Samples')

    # Build the encoder sampling op once; creating it inside the loop would
    # add new nodes to the graph on every iteration.
    enc_sample_op = enc_mean + tf.random_normal(
        enc_mean.get_shape()) * tf.exp(0.5 * enc_log_std_sq)

    enc_batches = []
    for i, inputs in enumerate(train_provider):
        if isinstance(inputs, tuple):
            inputs = inputs[0]
        if i == 11:
            break
        enc_batches.append(
            sess.run(enc_sample_op, feed_dict={model_input_batch: inputs}))

    if not enc_batches:
        raise ValueError('train_provider yielded no batches')
    enc_samples = np.concatenate(enc_batches)

    latent_mean = np.mean(enc_samples, axis=0)
    latent_std = np.std(enc_samples, axis=0)

    print("Generate new samples from Gaussian")
    val_samples = sess.run(
        sampler_mean,
        feed_dict={
            # multivariate_normal expects a covariance matrix, so the
            # diagonal holds variances (std ** 2), not standard deviations.
            sampler_input_batch: MVN(latent_mean,
                                     np.diag(np.square(latent_std)),
                                     size=batch_size)
        })
    val_samples = (val_samples * std_img) + mean_img

    save_samples(val_samples, 3, save_dir, True, options['img_shape'], 10)
Exemple #8
0
def train(options):
    """Train a generator against a discriminator and log to the dashboard.

    Writes an ``options`` summary, a ``catalog`` and four CSV logs
    (``train_loss.csv``, ``val_loss.csv``, ``train_acc.csv``,
    ``val_acc.csv``) into ``options['dashboard_dir']``, then builds the graph
    and runs the training loop. All dashboard files are closed when training
    finishes or fails.

    Args:
        options: Dict of settings. Keys read here: ``model_dir``,
            ``dashboard_dir``, ``description``, ``sigma_clip``; further keys
            are read by ``_run_gan_training``.

    Returns:
        None.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))

    dashboard_dir = options['dashboard_dir']

    with open(os.path.join(dashboard_dir, 'options'), 'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write('Log Sigma^2 clipped to: [{}, {}]\n\n'.format(
            -options['sigma_clip'], options['sigma_clip']))
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Every dashboard file is registered here so the finally-clause can close
    # them even if opening a later file or the training itself fails.
    dashboard_files = []

    def open_dashboard_file(filename):
        handle = open(os.path.join(dashboard_dir, filename), 'w')
        dashboard_files.append(handle)
        return handle

    try:
        # Dashboard Catalog
        catalog = open_dashboard_file('catalog')
        catalog.write("""filename,type,name
options,plain,Options
train_loss.csv,csv,Discriminator Cross-Entropy
train_acc.csv,csv,Discriminator Accuracy
val_loss.csv,csv,Validation Cross-Entropy
val_acc.csv,csv,Validation Accuracy
""")
        catalog.flush()

        train_log = open_dashboard_file('train_loss.csv')
        val_log = open_dashboard_file('val_loss.csv')
        train_acc = open_dashboard_file('train_acc.csv')
        val_acc = open_dashboard_file('val_acc.csv')

        train_log.write(
            'step,time,Train CE (Training Vanilla),Train CE (Training Gen.),Train CE (Training Disc.)\n'
        )
        val_log.write(
            'step,time,Validation CE (Training Vanilla),Validation CE (Training Gen.),Validation CE (Training Disc.)\n'
        )
        # The accuracy files hold accuracies, so their columns are labelled
        # "Acc." rather than "CE".
        train_acc.write(
            'step,time,Train Acc. (Training Vanilla),Train Acc. (Training Gen.),Train Acc. (Training Disc.)\n'
        )
        val_acc.write(
            'step,time,Validation Acc. (Training Vanilla),Validation Acc. (Training Gen.),Validation Acc. (Training Disc.)\n'
        )

        _run_gan_training(options, log, catalog, train_log, val_log,
                          train_acc, val_acc)
    finally:
        for handle in dashboard_files:
            handle.close()


def _run_gan_training(options, log, catalog, train_log, val_log, train_acc,
                      val_acc):
    """Build the generator/discriminator graph and run the training loop.

    Per training batch one of two update ops is run: the generator update for
    the first ``options['initial_G_iters']`` batches, then the discriminator
    update until ``initial_G_iters + initial_D_iters``, and afterwards a
    repeating cycle of ``sum(options['D_to_G'])`` batches in which the first
    ``D_to_G[0]`` are discriminator updates and the rest generator updates.

    Args:
        options: Dict of settings (batch size, shapes, frequencies, paths).
        log: Logger used for progress messages.
        catalog: Open dashboard catalog file, passed to ``save_ae_samples``.
        train_log: Open CSV file for the running mean of the training loss.
        val_log: Open CSV file for the mean validation loss.
        train_acc: Open CSV file for the running mean of training accuracy.
        val_acc: Open CSV file for the mean validation accuracy.
    """
    # Print options
    utils.print_options(options, log)

    # Load dataset ------------------------------------------------------------
    train_provider, val_provider, _ = get_providers(options, log, flat=True)

    # Initialize model --------------------------------------------------------
    with tf.device('/gpu:0'):
        # Define inputs -------------------------------------------------------
        real_batch = tf.placeholder(tf.float32,
                                    shape=[
                                        options['batch_size'],
                                        np.prod(np.array(options['img_shape']))
                                    ],
                                    name='real_inputs')
        sampler_input_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], options['latent_dims']],
            name='noise_channel')
        # Ones for the real half of the discriminator batch, zeros for the
        # generated half (same order as the concat below).
        labels = tf.constant(
            np.expand_dims(np.concatenate((np.ones(
                options['batch_size']), np.zeros(options['batch_size'])),
                                          axis=0).astype(np.float32),
                           axis=1))
        labels = tf.cast(labels, tf.float32)
        log.info('Inputs defined')

        # Define model --------------------------------------------------------
        with tf.variable_scope('gen_scope'):
            generator = Sequential('generator')
            generator += FullyConnected(options['latent_dims'],
                                        60,
                                        tf.nn.tanh,
                                        name='fc_1')
            generator += FullyConnected(60, 60, tf.nn.tanh, name='fc_2')
            generator += FullyConnected(60,
                                        np.prod(options['img_shape']),
                                        tf.nn.tanh,
                                        name='fc_3')

            sampler = generator(sampler_input_batch)

        with tf.variable_scope('disc_scope'):
            # NOTE: unpickling executes arbitrary code; disc_params_path must
            # point to a trusted file.
            with open(options['disc_params_path'], 'rb') as params_file:
                disc_params = pickle.load(params_file)

            disc_model = cupboard('fixed_conv_disc')(
                disc_params,
                options['num_feat_layers'],
                name='disc_model')

            disc_inputs = tf.concat(0, [real_batch, sampler])
            disc_inputs = tf.reshape(
                disc_inputs, [disc_inputs.get_shape()[0].value] +
                options['img_shape'] + [options['input_channels']])

            preds = disc_model(disc_inputs)
            # Keep predictions away from 0 and 1 so the logs below are finite.
            preds = tf.clip_by_value(preds, 0.00001, 0.99999)

            num_examples = float(labels.get_shape()[0].value)

            # Disc Accuracy
            disc_accuracy = (1 / num_examples) * tf.reduce_sum(
                tf.cast(tf.equal(tf.round(preds), labels), tf.float32))

        # Define Losses -------------------------------------------------------
        # Discriminator Cross-Entropy
        disc_CE = (1 / num_examples) * tf.reduce_sum(
            -tf.add(tf.mul(labels, tf.log(preds)),
                    tf.mul(1.0 - labels, tf.log(1.0 - preds))))

        gen_loss = -tf.mul(1.0 - labels, tf.log(preds))

        # Define Optimizers ---------------------------------------------------
        optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])

        # Get Generator and Discriminator Trainable Variables
        gen_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                           'gen_scope')
        disc_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                            'disc_scope')

        # Get generator gradients; variables without a gradient are skipped.
        grads = optimizer.compute_gradients(gen_loss, gen_train_vars)
        clip_grads = [(tf.clip_by_norm(grad, 5.0,
                                       name='gen_grad_clipping'), var)
                      for grad, var in grads if grad is not None]
        gen_backpass = optimizer.apply_gradients(clip_grads)

        # Get discriminator gradients; variables without a gradient are
        # skipped.
        grads = optimizer.compute_gradients(disc_CE, disc_train_vars)
        clip_grads = [(tf.clip_by_norm(grad, 5.0,
                                       name='disc_grad_clipping'), var)
                      for grad, var in grads if grad is not None]
        disc_backpass = optimizer.apply_gradients(clip_grads)

        log.info('Optimizer graph built')
        # ---------------------------------------------------------------------
        # Define init operation
        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')

    def sample_noise():
        # One batch of latent vectors drawn from N(0, I).
        return MVN(np.zeros(options['latent_dims']),
                   np.diag(np.ones(options['latent_dims'])),
                   size=options['batch_size'])

    # Constant written into the "time" column of every CSV row.
    log_date = '2016-04-22'

    # NOTE(review): the original code chose a different CSV column per
    # training phase but then unconditionally overwrote that choice with this
    # format, so every value lands in the first data column. That runtime
    # behaviour is preserved here; confirm whether per-phase columns were
    # intended.
    log_format_string = '{},{},{},,\n'

    # Train loop --------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        # Initialize shared variables or restore
        if options['reload_all']:
            saver.restore(sess, options['reload_file'])
            log.info('Shared variables restored')
        else:
            sess.run(init_op)
            log.info('Variables initialized')

        # Last losses/accuracies, used to compute running averages
        last_losses = np.zeros((10))
        last_accs = np.zeros((10))

        batch_abs_idx = 0
        D_to_G = options['D_to_G']
        total_D2G = sum(D_to_G)
        base = options['initial_G_iters'] + options['initial_D_iters']

        for epoch_idx in xrange(options['n_epochs']):
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs in train_provider:
                if isinstance(inputs, tuple):
                    inputs = inputs[0]

                batch_abs_idx += 1

                # Select which network is updated on this batch.
                if batch_abs_idx < options['initial_G_iters']:
                    backpass = gen_backpass
                elif batch_abs_idx < base:
                    backpass = disc_backpass
                elif (batch_abs_idx - base) % total_D2G < D_to_G[0]:
                    backpass = disc_backpass
                else:
                    backpass = gen_backpass

                cost, _, accuracy = sess.run(
                    [disc_CE, backpass, disc_accuracy],
                    feed_dict={
                        real_batch: inputs,
                        sampler_input_batch: sample_noise()
                    })

                # Logged before the running windows are updated, so the
                # means cover the batches preceding this one.
                if batch_abs_idx % 10 == 0:
                    train_log.write(
                        log_format_string.format(batch_abs_idx, log_date,
                                                 np.mean(last_losses)))
                    train_acc.write(
                        log_format_string.format(batch_abs_idx, log_date,
                                                 np.mean(last_accs)))

                    train_log.flush()
                    train_acc.flush()

                # Check cost (only reported; training continues)
                if np.isnan(cost) or np.isinf(cost):
                    log.info('NaN detected')

                # Update last losses
                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost

                last_accs = np.roll(last_accs, 1)
                last_accs[0] = accuracy

                # Display training information
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info(
                        'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                        .format(epoch_idx + 1, options['n_epochs'],
                                batch_abs_idx, float(cost),
                                np.mean(last_losses)))
                    log.info(
                        'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last accuracies: {:0>15.4f}'
                        .format(epoch_idx + 1, options['n_epochs'],
                                batch_abs_idx, float(cost),
                                np.mean(last_accs)))

                # Save model
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(
                        sess,
                        os.path.join(options['model_dir'],
                                     'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                # Validate model
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:

                    valid_costs = []
                    valid_accs = []
                    seen_batches = 0
                    for val_batch in val_provider:
                        if isinstance(val_batch, tuple):
                            val_batch = val_batch[0]

                        val_cost, val_accuracy = sess.run(
                            [disc_CE, disc_accuracy],
                            feed_dict={
                                real_batch: val_batch,
                                sampler_input_batch: sample_noise()
                            })
                        valid_costs.append(val_cost)
                        valid_accs.append(val_accuracy)
                        seen_batches += 1

                        if seen_batches == options['valid_batches']:
                            break

                    # Print results
                    log.info('Validation loss: {:0>15.4f}'.format(
                        float(np.mean(valid_costs))))
                    log.info('Validation accuracies: {:0>15.4f}'.format(
                        float(np.mean(valid_accs))))

                    val_samples = sess.run(
                        sampler,
                        feed_dict={sampler_input_batch: sample_noise()})

                    val_log.write(
                        log_format_string.format(batch_abs_idx, log_date,
                                                 np.mean(valid_costs)))
                    val_acc.write(
                        log_format_string.format(batch_abs_idx, log_date,
                                                 np.mean(valid_accs)))
                    val_log.flush()
                    val_acc.flush()

                    batch_image_shape = ([options['batch_size']] +
                                         options['img_shape'])
                    # An all-ones array is passed in the second position
                    # because this model produces no reconstructions.
                    save_ae_samples(catalog,
                                    np.ones(batch_image_shape),
                                    np.reshape(inputs, batch_image_shape),
                                    np.reshape(val_samples,
                                               batch_image_shape),
                                    batch_abs_idx,
                                    options['dashboard_dir'],
                                    num_to_save=5,
                                    save_gray=True)

            log.info('End of epoch {}'.format(epoch_idx + 1))
Exemple #9
0
def train(options):
    """Train a VAE and stream its training statistics to the dashboard directory.

    Writes an ``options`` dump and a ``catalog`` index into
    ``options['dashboard_dir']``, opens one CSV file per tracked statistic and
    then hands over to ``_run_vae_training``. Every dashboard file is closed
    when training ends, whether it completes, aborts on a NaN cost or raises.

    Args:
        options: dict of settings. Keys read here: ``'model_dir'``,
            ``'dashboard_dir'``, ``'description'``, ``'DKL_weight'`` and
            ``'sigma_clip'``; ``_run_vae_training`` reads the remaining ones.

    Returns:
        ``(1., 1., 1.)`` when training is aborted because the cost became NaN
        or infinite, ``None`` once all epochs have run.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    dashboard_dir = options['dashboard_dir']

    with open(os.path.join(dashboard_dir, 'options'), 'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write(
            'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
                options['DKL_weight'],
                -options['sigma_clip'],
                options['sigma_clip']
            )
        )
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Every dashboard file is registered here so the ``finally`` below can
    # close all of them, including when a later ``open`` fails.
    opened = []

    def open_dashboard_file(filename):
        handle = open(os.path.join(dashboard_dir, filename), 'w')
        opened.append(handle)
        return handle

    try:
        # Dashboard Catalog
        # NOTE(review): the "Simga" spelling in two display names is kept
        # byte-for-byte from the original catalog text.
        catalog = open_dashboard_file('catalog')
        catalog.write(
            'filename,type,name\n'
            'options,plain,Options\n'
            'train_loss.csv,csv,Train Loss\n'
            'll.csv,csv,Neg. Log-Likelihood\n'
            'dec_log_sig_sq.csv,csv,Decoder Log Simga^2\n'
            'dec_std_log_sig_sq.csv,csv,STD of Decoder Log Simga^2\n'
            'dec_mean.csv,csv,Decoder Mean\n'
            'dkl.csv,csv,DKL\n'
            'enc_log_sig_sq.csv,csv,Encoder Log Sigma^2\n'
            'enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2\n'
            'enc_mean.csv,csv,Encoder Mean\n'
            'val_loss.csv,csv,Validation Loss\n'
        )
        catalog.flush()

        # (key used by the training loop, file name, CSV value-column header)
        csv_specs = (
            ('train', 'train_loss.csv', 'Train Loss'),
            ('val', 'val_loss.csv', 'Validation Loss'),
            ('dkl', 'dkl.csv', 'DKL'),
            ('ll', 'll.csv', '-LL'),
            ('dec_sig', 'dec_log_sig_sq.csv', 'Decoder Log Sigma^2'),
            ('enc_sig', 'enc_log_sig_sq.csv', 'Encoder Log Sigma^2'),
            ('dec_std_sig', 'dec_std_log_sig_sq.csv',
             'STD of Decoder Log Sigma^2'),
            ('enc_std_sig', 'enc_std_log_sig_sq.csv',
             'STD of Encoder Log Sigma^2'),
            ('dec_mean', 'dec_mean.csv', 'Decoder Mean'),
            ('enc_mean', 'enc_mean.csv', 'Encoder Mean'),
        )
        logs = {}
        for key, filename, column in csv_specs:
            logs[key] = open_dashboard_file(filename)
            logs[key].write('step,time,{}\n'.format(column))

        return _run_vae_training(options, log, catalog, logs)
    finally:
        for handle in opened:
            handle.close()


def _vae_layer_params(layer):
    """Snapshot one layer (dims, activation, evaluated 'w'/'b') as a dict.

    The weights are read with ``.eval()``, so this must be called while a
    TensorFlow session is the default session.
    """
    return {
        'input_dim': layer.input_dim,
        'output_dim': layer.output_dim,
        'act_fn': layer.activation,
        'W': layer.weights['w'].eval(),
        'b': layer.weights['b'].eval()
    }


def _append_dashboard_row(csv_file, step, value):
    """Append one ``step,time,value`` row to a dashboard CSV and flush it."""
    # The time column is a hard-coded placeholder date, exactly as the rows
    # were written before this helper existed.
    csv_file.write('{},{},{}\n'.format(step, '2016-04-22', value))
    csv_file.flush()


def _run_vae_training(options, log, catalog, logs):
    """Build the VAE graph and run the train / checkpoint / validate loop.

    Args:
        options: settings dict; keys read: ``'model'``, ``'p_layers'``,
            ``'q_layers'``, ``'img_shape'`` (a list), ``'latent_dims'``,
            ``'DKL_weight'``, ``'sigma_clip'``, ``'batch_size'``, ``'lr'``,
            ``'reload'``, ``'reload_file'``, ``'n_epochs'``,
            ``'freq_logging'``, ``'freq_saving'``, ``'freq_validation'``,
            ``'valid_batches'``, ``'model_dir'`` and ``'dashboard_dir'``.
        log: logger used for progress messages.
        catalog: open dashboard catalog file, forwarded to ``save_ae_samples``.
        logs: dict of open CSV files keyed as set up in ``train``.

    Returns:
        ``(1., 1., 1.)`` if the training cost becomes NaN/Inf, else ``None``.
    """
    # Print options
    utils.print_options(options, log)

    # Load dataset -------------------------------------------------------------
    train_provider, val_provider, _ = get_providers(options, log, flat=True)

    # Shape used to turn flat batches back into images for the dashboard.
    image_batch_shape = [options['batch_size']] + options['img_shape']

    # Initialize model ---------------------------------------------------------
    with tf.device('/gpu:0'):
        model = cupboard(options['model'])(
            options['p_layers'],
            options['q_layers'],
            np.prod(options['img_shape']),
            options['latent_dims'],
            options['DKL_weight'],
            options['sigma_clip'],
            'vanilla_vae'
        )
        log.info('Model initialized')

        # Define inputs
        model_input_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'],
                   np.prod(np.array(options['img_shape']))],
            name='enc_inputs'
        )
        sampler_input_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], options['latent_dims']],
            name='dec_inputs'
        )
        log.info('Inputs defined')

        # Define forward pass
        cost_function = model(model_input_batch)
        log.info('Forward pass graph built')

        # Define sampler
        sampler = model.build_sampler(sampler_input_batch)
        log.info('Sampler graph built')

        # Define optimizer
        optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])

        # Get gradients; variables that do not influence the cost come back
        # with a ``None`` gradient, which clip_by_norm cannot handle.
        grads = [
            (grad, var)
            for grad, var in optimizer.compute_gradients(cost_function)
            if grad is not None
        ]

        # Clip gradients
        clip_grads = [
            (tf.clip_by_norm(grad, 5.0, name='grad_clipping'), var)
            for grad, var in grads
        ]

        # Update op
        backpass = optimizer.apply_gradients(clip_grads)
        log.info('Optimizer graph built')

        # Define init operation
        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')

    # Train loop ---------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        # Initialize shared variables or restore
        if options['reload']:
            saver.restore(sess, options['reload_file'])
            log.info('Shared variables restored')
        else:
            sess.run(init_op)
            log.info('Shared variables initialized')

        # Last ten costs, used to report a running average
        last_losses = np.zeros((10))

        batch_abs_idx = 0
        for epoch_idx in range(options['n_epochs']):
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs in train_provider:
                # Providers may yield (data, label) tuples; only data is used.
                if isinstance(inputs, tuple):
                    inputs = inputs[0]

                batch_abs_idx += 1

                result = sess.run(
                    [cost_function, backpass, model.DKL, model.rec_loss,
                     model.dec_log_std_sq, model.enc_log_std_sq,
                     model.enc_mean, model.dec_mean],
                    feed_dict={model_input_batch: inputs}
                )
                (cost, _, dkl_term, rec_term, dec_log_sig_sq, enc_log_sig_sq,
                 enc_mean, dec_mean) = result

                if batch_abs_idx % 10 == 0:
                    step = batch_abs_idx
                    # The running mean is logged before the current cost is
                    # pushed into ``last_losses`` (same order as before).
                    _append_dashboard_row(
                        logs['train'], step, np.mean(last_losses))
                    _append_dashboard_row(
                        logs['dkl'], step, -np.mean(dkl_term))
                    _append_dashboard_row(
                        logs['ll'], step, -np.mean(rec_term))

                    _append_dashboard_row(
                        logs['dec_sig'], step, np.mean(dec_log_sig_sq))
                    _append_dashboard_row(
                        logs['enc_sig'], step, np.mean(enc_log_sig_sq))

                    _append_dashboard_row(
                        logs['dec_std_sig'], step, np.std(dec_log_sig_sq))
                    _append_dashboard_row(
                        logs['enc_std_sig'], step, np.std(enc_log_sig_sq))

                    _append_dashboard_row(
                        logs['dec_mean'], step, np.mean(dec_mean))
                    _append_dashboard_row(
                        logs['enc_mean'], step, np.mean(enc_mean))

                # Check cost
                if np.isnan(cost) or np.isinf(cost):
                    log.info('NaN detected')
                    # Dump every fetched value. Only the eight fetches above
                    # exist; indexing past them used to raise IndexError and
                    # mask the NaN report.
                    for i, fetched in enumerate(result):
                        print("\n\nresult[%d]:" % i)
                        try:
                            print(np.any(np.isnan(fetched)))
                        except TypeError:
                            # The update op is fetched as None, which
                            # np.isnan rejects.
                            pass
                        print(fetched)
                    print(rec_term.shape)
                    print(model._encoder.layers[0].weights['w'].eval())
                    print(inputs)
                    return 1., 1., 1.

                # Update last losses
                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost

                # Display training information
                # NOTE(review): this gates on the epoch index, so every batch
                # of a matching epoch is logged; confirm that is intended.
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info(
                        'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'.format(
                            epoch_idx + 1,
                            options['n_epochs'],
                            batch_abs_idx,
                            float(cost),
                            np.mean(last_losses)
                        ))
                    log.info('Batch Mean LL: {:0>15.4f}'.format(
                        np.mean(rec_term, axis=0)))
                    log.info('Batch Mean -DKL: {:0>15.4f}'.format(
                        np.mean(dkl_term, axis=0)))

                # Save model
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(
                        sess,
                        os.path.join(options['model_dir'],
                                     'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                    encoder_layers = [
                        _vae_layer_params(layer)
                        for layer in model._encoder.layers
                    ]
                    decoder_layers = [
                        _vae_layer_params(layer)
                        for layer in model._decoder.layers
                    ]
                    save_dict = {
                        # All hidden layers, in order.
                        'encoder_layers': encoder_layers,
                        'decoder_layers': decoder_layers,
                        'enc_mean': _vae_layer_params(model._enc_mean),
                        'enc_log_std_sq': _vae_layer_params(
                            model._enc_log_std_sq),
                        'dec_mean': _vae_layer_params(model._dec_mean),
                        'dec_log_std_sq': _vae_layer_params(
                            model._dec_log_std_sq),
                    }
                    # Backward compatibility: 'encoder' / 'decoder' used to be
                    # overwritten per layer, i.e. they held the last layer.
                    if encoder_layers:
                        save_dict['encoder'] = encoder_layers[-1]
                    if decoder_layers:
                        save_dict['decoder'] = decoder_layers[-1]

                    dict_path = os.path.join(
                        options['model_dir'], 'vae_dict_%d' % batch_abs_idx)
                    with open(dict_path, 'wb') as dict_file:
                        pickle.dump(save_dict, dict_file)

                # Validate model
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                    valid_costs = []
                    for val_batch in val_provider:
                        if isinstance(val_batch, tuple):
                            val_batch = val_batch[0]

                        valid_costs.append(sess.run(
                            cost_function,
                            feed_dict={model_input_batch: val_batch}
                        ))

                        if len(valid_costs) == options['valid_batches']:
                            break

                    # Print results
                    mean_valid_cost = np.mean(valid_costs)
                    log.info('Validation loss: {:0>15.4f}'.format(
                        float(mean_valid_cost)))

                    # Decode a batch of latent codes drawn from MVN with zero
                    # mean and identity covariance.
                    val_samples = sess.run(
                        sampler,
                        feed_dict={
                            sampler_input_batch: MVN(
                                np.zeros(options['latent_dims']),
                                np.diag(np.ones(options['latent_dims'])),
                                size=options['batch_size']
                            )
                        }
                    )

                    _append_dashboard_row(
                        logs['val'], batch_abs_idx, mean_valid_cost)

                    # Decoder mean of the last training batch, that batch
                    # itself, and the fresh samples, all reshaped to images.
                    save_ae_samples(
                        catalog,
                        np.reshape(dec_mean, image_batch_shape),
                        np.reshape(inputs, image_batch_shape),
                        np.reshape(val_samples, image_batch_shape),
                        batch_abs_idx,
                        options['dashboard_dir'],
                        num_to_save=5,
                        save_gray=True
                    )

            log.info('End of epoch {}'.format(epoch_idx + 1))
Exemple #10
0
def train(options):
    """Train an autoencoder on clean / noise-corrupted batches with dashboard logs.

    Writes an ``options`` dump and a ``catalog`` index into
    ``options['dashboard_dir']``, opens the train/validation loss CSVs and then
    hands over to ``_run_ae_training``. The dashboard files are closed when
    training ends, whether it completes, aborts on a NaN cost or raises.

    Args:
        options: dict of settings. Keys read here: ``'model_dir'`` and
            ``'dashboard_dir'``; ``_run_ae_training`` reads the remaining ones.

    Returns:
        ``(1., 1., 1.)`` when training is aborted because the cost became NaN
        or infinite, ``None`` once all epochs have run.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    dashboard_dir = options['dashboard_dir']

    with open(os.path.join(dashboard_dir, 'options'), 'w') as options_file:
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Every dashboard file is registered here so the ``finally`` below can
    # close all of them, including when a later ``open`` fails.
    opened = []

    def open_dashboard_file(filename):
        handle = open(os.path.join(dashboard_dir, filename), 'w')
        opened.append(handle)
        return handle

    try:
        # Dashboard Catalog
        catalog = open_dashboard_file('catalog')
        catalog.write(
            'filename,type,name\n'
            'options,plain,Options\n'
            'train_loss.csv,csv,Train Loss\n'
            'val_loss.csv,csv,Validation Loss\n'
        )
        catalog.flush()

        train_log = open_dashboard_file('train_loss.csv')
        val_log = open_dashboard_file('val_loss.csv')

        train_log.write('step,time,Train Loss\n')
        val_log.write('step,time,Validation Loss\n')

        return _run_ae_training(options, log, catalog, train_log, val_log)
    finally:
        for handle in opened:
            handle.close()


def _run_ae_training(options, log, catalog, train_log, val_log):
    """Build the autoencoder graph and run the train / checkpoint / validate loop.

    Args:
        options: settings dict; keys read: ``'model'``, ``'img_shape'`` (a
            list), ``'input_channels'``, ``'enc_params'``, ``'dec_params'``,
            ``'batch_size'``, ``'lr'``, ``'reload'``, ``'n_epochs'``,
            ``'noise_std'``, ``'freq_logging'``, ``'freq_saving'``,
            ``'freq_validation'``, ``'valid_batches'``, ``'model_dir'`` and
            ``'dashboard_dir'``.
        log: logger used for progress messages.
        catalog: open dashboard catalog file, forwarded to ``save_ae_samples``.
        train_log: open CSV file receiving the running mean training loss.
        val_log: open CSV file receiving the mean validation loss.

    Returns:
        ``(1., 1., 1.)`` if the training cost becomes NaN/Inf, else ``None``.
    """
    # Print options
    utils.print_options(options, log)

    # Load dataset -------------------------------------------------------------
    train_provider, val_provider, _ = get_providers(options, log)

    is_cnn = options['model'] == 'cnn_ae'
    # Per-example image dimensions: spatial shape followed by channels.
    image_dims = options['img_shape'] + [options['input_channels']]

    # Initialize model ---------------------------------------------------------
    with tf.device('/gpu:0'):
        if is_cnn:
            model = cupboard(options['model'])(
                options['img_shape'],
                options['input_channels'],
                options['enc_params'],
                options['dec_params'],
                'cnn_ae')
            input_shape = [options['batch_size']] + image_dims
        else:
            flat_dim = np.prod(options['img_shape']) * options['input_channels']
            model = cupboard(options['model'])(
                flat_dim, options['enc_params'], options['dec_params'], 'ae')
            input_shape = [options['batch_size']] + [flat_dim]

        # Define inputs
        model_clean_input_batch = tf.placeholder(
            tf.float32, shape=input_shape, name='clean')
        model_noisy_input_batch = tf.placeholder(
            tf.float32, shape=input_shape, name='noisy')
        log.info('Inputs defined')

        log.info('Model initialized')

        # Define forward pass
        print(model_clean_input_batch.get_shape())
        print(model_noisy_input_batch.get_shape())
        cost_function = model(model_clean_input_batch, model_noisy_input_batch)
        log.info('Forward pass graph built')

        # Define optimizer
        optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])

        # Get gradients, dropping variables the cost does not depend on.
        grads = [
            (grad, var)
            for grad, var in optimizer.compute_gradients(cost_function)
            if grad is not None
        ]
        grad_tensors = [grad for grad, _ in grads]

        # Clip gradients
        clip_grads = [
            (tf.clip_by_norm(grad, 5.0, name='grad_clipping'), var)
            for grad, var in grads
        ]

        # Update op
        backpass = optimizer.apply_gradients(clip_grads)
        log.info('Optimizer graph built')

        # Define init operation
        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')

    # Train loop ---------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        # Initialize shared variables or restore
        if options['reload']:
            saver.restore(sess, os.path.join(options['model_dir'],
                                             'model.ckpt'))
            log.info('Shared variables restored')
        else:
            sess.run(init_op)
            log.info('Shared variables initialized')

        # Last ten costs, used to report a running average
        last_losses = np.zeros((10))

        batch_abs_idx = 0
        for epoch_idx in range(options['n_epochs']):
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs, _ in train_provider:
                batch_abs_idx += 1

                # Corrupt the batch with additive noise drawn by ``normal``.
                noisy_inputs = np.float32(inputs) + normal(
                    loc=0.0,
                    scale=np.float32(options['noise_std']),
                    size=inputs.shape
                )

                # Gradients are fetched too so they can be dumped on NaN.
                result = sess.run(
                    [cost_function, backpass] + grad_tensors,
                    feed_dict={
                        model_clean_input_batch: inputs,
                        model_noisy_input_batch: noisy_inputs
                    }
                )

                cost = result[0]

                if batch_abs_idx % 10 == 0:
                    # Logged before the current cost enters ``last_losses``.
                    train_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.mean(last_losses)))
                    train_log.flush()

                # Check cost
                if np.isnan(cost) or np.isinf(cost):
                    log.info('NaN detected')
                    # Dump every fetched value. The number of gradient
                    # tensors varies with the model, so no index beyond the
                    # fetched list is assumed to exist.
                    for i, fetched in enumerate(result):
                        print("\n\nresult[%d]:" % i)
                        try:
                            print(np.any(np.isnan(fetched)))
                        except TypeError:
                            # The update op is fetched as None, which
                            # np.isnan rejects.
                            pass
                        print(fetched)
                    if len(result) > 3:
                        print(result[3].shape)
                    print(model._encoder.layers[0].weights['w'].eval())
                    print(inputs)
                    return 1., 1., 1.

                # Update last losses
                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost

                # Display training information
                # NOTE(review): this gates on the epoch index, so every batch
                # of a matching epoch is logged; confirm that is intended.
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info(
                        'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                        .format(epoch_idx + 1, options['n_epochs'],
                                batch_abs_idx, float(cost),
                                np.mean(last_losses)))

                # Save model
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(
                        sess,
                        os.path.join(options['model_dir'],
                                     'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                    # Save Encoder Params (left empty for the cnn_ae model)
                    save_dict = {
                        'enc_W': [],
                        'enc_b': [],
                        'enc_act_fn': [],
                    }
                    if not is_cnn:
                        for i, layer in enumerate(model._encoder.layers):
                            save_dict['enc_W'].append(
                                layer.weights['w'].eval())
                            save_dict['enc_b'].append(
                                layer.weights['b'].eval())
                            save_dict['enc_act_fn'].append(
                                options['enc_params']['act_fn'][i])

                    dict_path = os.path.join(
                        options['model_dir'], 'enc_dict_%d' % batch_abs_idx)
                    with open(dict_path, 'wb') as dict_file:
                        pickle.dump(save_dict, dict_file)

                # Validate model
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                    valid_costs = []
                    for val_batch, _ in val_provider:
                        noisy_val_batch = val_batch + normal(
                            loc=0.0,
                            scale=np.float32(options['noise_std']),
                            size=val_batch.shape
                        )

                        val_results = sess.run(
                            (cost_function, model.decoder),
                            feed_dict={
                                model_clean_input_batch: val_batch,
                                model_noisy_input_batch: noisy_val_batch
                            })
                        valid_costs.append(val_results[0])

                        if len(valid_costs) == options['valid_batches']:
                            break

                    # Print results
                    mean_valid_cost = np.mean(valid_costs)
                    log.info('Validation loss: {:0>15.4f}'.format(
                        float(mean_valid_cost)))

                    val_log.write('{},{},{}\n'.format(batch_abs_idx,
                                                      '2016-04-22',
                                                      mean_valid_cost))
                    val_log.flush()

                    # The convolutional model is selected as 'cnn_ae' above;
                    # this check used to test only 'conv_ae', so both
                    # spellings are accepted here.
                    if is_cnn or options['model'] == 'conv_ae':
                        val_recon = np.reshape(val_results[-1],
                                               val_batch.shape)
                    else:
                        # Flat batches are turned back into images.
                        image_batch_shape = [val_batch.shape[0]] + image_dims
                        val_batch = np.reshape(val_batch, image_batch_shape)
                        noisy_val_batch = np.reshape(
                            noisy_val_batch, image_batch_shape)
                        val_recon = np.reshape(
                            val_results[-1], image_batch_shape)

                    # Uses the last validation batch that was evaluated.
                    save_ae_samples(catalog,
                                    val_batch,
                                    noisy_val_batch,
                                    val_recon,
                                    batch_abs_idx,
                                    options['dashboard_dir'],
                                    num_to_save=5,
                                    save_gray=True)

            log.info('End of epoch {}'.format(epoch_idx + 1))
def _feat_vae_build_providers(options):
    """Build the training and validation ``DataProvider`` objects.

    When ``options['data_dir']`` is the literal ``'MNIST'`` the providers wrap
    ``toolbox.MNISTLoader`` with fixed sizes (55000 train / 5000 validation).
    Otherwise they wrap ``toolbox.ImageLoader`` over
    ``<data_dir>/<split>/patches``.

    Returns:
        A ``(train_provider, val_provider)`` tuple.
    """
    if options['data_dir'] == 'MNIST':
        train_provider = DataProvider(
            55000,
            options['batch_size'],
            toolbox.MNISTLoader(mode='train', flat=True)
        )
        val_provider = DataProvider(
            5000,
            options['batch_size'],
            toolbox.MNISTLoader(mode='validation', flat=True)
        )
        return train_provider, val_provider

    providers = []
    for split in ('train', 'valid'):
        patch_dir = os.path.join(options['data_dir'], split, 'patches')
        # NOTE(review): two directory entries are discounted from the count;
        # presumably they are non-sample files -- confirm against the dataset
        # layout.
        num_data_points = len(os.listdir(patch_dir)) - 2
        providers.append(
            DataProvider(
                num_data_points,
                options['batch_size'],
                toolbox.ImageLoader(
                    data_dir=patch_dir,
                    flat=True,
                    extension=options['file_extension']
                )
            )
        )
    return providers[0], providers[1]


def _feat_vae_build_feature_extractor(options):
    """Build the constant fully-connected feature extractor from a pickle.

    Reads ``enc_W``, ``enc_b`` and ``enc_act_fn`` from the pickle at
    ``options['feat_params_path']`` and stacks ``options['num_feat_layers']``
    ``ConstFC`` layers into a ``Sequential`` named ``'feat_extractor'``.

    Raises:
        ValueError: if ``options['feat_type']`` is not ``'fc'``; no other
            feature extractor is implemented here.
    """
    if options['feat_type'] != 'fc':
        # Previously this fell through and crashed later with an unbound
        # ``feat_model``; fail early with an explicit message instead.
        raise ValueError(
            "Unsupported feat_type %r: only 'fc' is implemented"
            % (options['feat_type'],)
        )

    feat_model = Sequential('feat_extractor')
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # 'feat_params_path' at trusted files.
    with open(options['feat_params_path'], 'rb') as params_file:
        feat_params = pickle.load(params_file)
    for i in range(options['num_feat_layers']):
        feat_model += ConstFC(
            feat_params['enc_W'][i],
            feat_params['enc_b'][i],
            activation=feat_params['enc_act_fn'][i],
            name='feat_layer_%d' % i
        )
    return feat_model


def _feat_vae_log_batch_stats(logs, step, last_losses, result):
    """Append one row of training statistics to each dashboard CSV and flush.

    ``result`` is the list returned by the training ``sess.run`` call; see
    the fetch layout documented in ``_feat_vae_run``.
    """
    # NOTE(review): the "time" column is a hard-coded date in the original
    # code; kept as is so the CSV output does not change.
    stamp = '2016-04-22'
    rows = (
        ('train', np.mean(last_losses)),
        ('dkl', -np.mean(result[2])),
        ('ll', -np.mean(result[3])),
        ('dec_sig', np.mean(result[4])),
        ('enc_sig', np.mean(result[5])),
        ('dec_std_sig', np.std(result[4])),
        ('enc_std_sig', np.std(result[5])),
        ('dec_mean', np.mean(result[7])),
        ('enc_mean', np.mean(result[6])),
    )
    for key, value in rows:
        logs[key].write('{},{},{}\n'.format(step, stamp, value))
        logs[key].flush()


def _feat_vae_print_nan_flag(value):
    """Print whether ``value`` contains NaN, ignoring non-numeric values."""
    try:
        print(np.any(np.isnan(value)))
    except (TypeError, ValueError):
        # Some fetched entries are not numeric (the update op is run only for
        # its side effect), so np.isnan cannot be applied to them.
        pass


def _feat_vae_run(options, log, catalog, logs):
    """Build the graph and run the training loop (or the reload/visualize path).

    Args:
        options: the option dictionary passed to ``train``.
        log: logger returned by ``utils.get_logger``.
        catalog: open dashboard catalog file, handed to ``visualize`` and
            ``save_ae_samples``.
        logs: dict mapping a short key to an open dashboard CSV file.

    Returns:
        ``None`` after training or after the reload path, or the tuple
        ``(1., 1., 1.)`` when a NaN/inf cost is encountered.
    """
    # Load dataset -------------------------------------------------------------
    train_provider, val_provider = _feat_vae_build_providers(options)
    log.info('Data providers initialized.')

    # Initialize model ---------------------------------------------------------
    with tf.device('/gpu:0'):
        feat_model = _feat_vae_build_feature_extractor(options)

        vae_model = cupboard('vanilla_vae')(
            options['p_layers'],
            options['q_layers'],
            np.prod(options['img_shape']),
            options['latent_dims'],
            options['DKL_weight'],
            options['sigma_clip'],
            'vanilla_vae'
        )
        feat_vae = cupboard('feat_vae')(
            vae_model,
            feat_model,
            options['DKL_weight'],
            0.0,
            img_shape=options['img_shape'],
            input_channels=options['input_channels'],
            flat=True,
            name='feat_vae_model'
        )
        log.info('Model initialized')

        # Define inputs
        model_input_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], np.prod(np.array(options['img_shape']))],
            name='enc_inputs'
        )
        sampler_input_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], options['latent_dims']],
            name='dec_inputs'
        )
        log.info('Inputs defined')

        # Define forward pass
        cost_function = feat_vae(model_input_batch)
        log.info('Forward pass graph built')

        # Define sampler
        sampler = feat_vae.build_sampler(sampler_input_batch)
        log.info('Sampler graph built')

        # Define optimizer
        optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])
        log.info('Optimizer graph built')

        # Keep only the (gradient, variable) pairs that have a gradient.
        grads = [
            gv for gv in optimizer.compute_gradients(cost_function)
            if gv[0] is not None
        ]
        grad_tensors = [gv[0] for gv in grads]

        # Clip each gradient to norm 5.0 before applying it.
        clip_grads = [
            (tf.clip_by_norm(grad, 5.0, name='grad_clipping'), var)
            for grad, var in grads
        ]
        backpass = optimizer.apply_gradients(clip_grads)

        # Define init operation
        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')

    # Train loop ---------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        # Restore and visualize, or initialize fresh variables.
        if options['reload']:
            saver.restore(sess, options['reload_file'])
            log.info('Shared variables restored')
            if options['data_dir'] == 'MNIST':
                mean_img = np.zeros(np.prod(options['img_shape']))
                std_img = np.ones(np.prod(options['img_shape']))
            else:
                # NOTE(review): this reads options['extension'] while the data
                # loaders read options['file_extension'] -- confirm both keys
                # are intended.
                mean_img = np.load(os.path.join(options['data_dir'], 'mean' + options['extension']))
                std_img = np.load(os.path.join(options['data_dir'], 'std' + options['extension']))
            visualize(
                feat_vae.vae.sampler_mean, sess, feat_vae.vae.dec_mean,
                feat_vae.vae.dec_log_std_sq, sampler, sampler_input_batch,
                model_input_batch, feat_vae.vae.enc_mean,
                feat_vae.vae.enc_log_std_sq, train_provider, val_provider,
                options, catalog, mean_img, std_img
            )
            return
        else:
            sess.run(init_op)
            log.info('Shared variables initialized')

        # Fetch layout of ``result`` below:
        #   0 cost, 1 update op, 2 DKL, 3 rec_loss, 4 dec_log_std_sq,
        #   5 enc_log_std_sq, 6 enc_mean, 7 dec_mean, 8+ raw gradients
        fetches = [
            cost_function, backpass, feat_vae.vae.DKL, feat_vae.vae.rec_loss,
            feat_vae.vae.dec_log_std_sq, feat_vae.vae.enc_log_std_sq,
            feat_vae.vae.enc_mean, feat_vae.vae.dec_mean
        ] + grad_tensors

        sample_shape = [options['batch_size']] + options['img_shape']

        # Last ten losses, used to report a running average.
        last_losses = np.zeros((10))

        batch_abs_idx = 0
        for epoch_idx in xrange(options['n_epochs']):
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs in train_provider:
                batch_abs_idx += 1

                result = sess.run(
                    fetches,
                    feed_dict={model_input_batch: inputs}
                )
                cost = result[0]

                if batch_abs_idx % 10 == 0:
                    _feat_vae_log_batch_stats(logs, batch_abs_idx, last_losses, result)

                # Abort with diagnostics when the cost diverges.
                if np.isnan(cost) or np.isinf(cost):
                    log.info('NaN detected')
                    for i, value in enumerate(result):
                        print("\n\nresult[%d]:" % i)
                        _feat_vae_print_nan_flag(value)
                        print(value)
                    print(result[3].shape)
                    # The VAE is bound as ``vae_model`` here; the original
                    # referenced an undefined ``model`` and raised NameError.
                    print(vae_model._encoder.layers[0].weights['w'].eval())
                    print('\n\nAny:')
                    # Slots 8..10 are the first gradients, if that many exist.
                    for grad_value in result[8:11]:
                        _feat_vae_print_nan_flag(grad_value)
                    print(inputs)
                    return 1., 1., 1.

                # Update last losses
                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost

                # Display training information
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'.format(
                        epoch_idx + 1,
                        options['n_epochs'],
                        batch_abs_idx,
                        float(cost),
                        np.mean(last_losses)
                    ))
                    log.info('Batch Mean LL: {:0>15.4f}'.format(np.mean(result[3], axis=0)))
                    log.info('Batch Mean -DKL: {:0>15.4f}'.format(np.mean(result[2], axis=0)))

                # Save model
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(sess, os.path.join(options['model_dir'], 'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                # Validate model
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                    valid_costs = []
                    seen_batches = 0
                    for val_batch in val_provider:
                        val_cost = sess.run(
                            cost_function,
                            feed_dict={model_input_batch: val_batch}
                        )
                        valid_costs.append(val_cost)
                        seen_batches += 1

                        if seen_batches == options['valid_batches']:
                            break

                    log.info('Validation loss: {:0>15.4f}'.format(
                        float(np.mean(valid_costs))
                    ))

                    # Decode samples drawn from a standard normal prior.
                    val_samples = sess.run(
                        sampler,
                        feed_dict={
                            sampler_input_batch: MVN(
                                np.zeros(options['latent_dims']),
                                np.diag(np.ones(options['latent_dims'])),
                                size=options['batch_size']
                            )
                        }
                    )

                    logs['val'].write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(valid_costs)))
                    logs['val'].flush()

                    save_ae_samples(
                        catalog,
                        np.reshape(result[7], sample_shape),
                        np.reshape(inputs, sample_shape),
                        np.reshape(val_samples, sample_shape),
                        batch_abs_idx,
                        options['dashboard_dir'],
                        num_to_save=5,
                        save_gray=True
                    )

                    validation_idx = int(batch_abs_idx / options['freq_validation'])
                    save_samples(
                        val_samples,
                        validation_idx,
                        os.path.join(options['model_dir'], 'valid_samples'),
                        True,
                        options['img_shape'],
                        5
                    )
                    save_samples(
                        inputs,
                        validation_idx,
                        os.path.join(options['model_dir'], 'input_sanity'),
                        True,
                        options['img_shape'],
                        num_to_save=5
                    )
                    save_samples(
                        result[7],
                        validation_idx,
                        os.path.join(options['model_dir'], 'rec_sanity'),
                        True,
                        options['img_shape'],
                        num_to_save=5
                    )

            log.info('End of epoch {}'.format(epoch_idx + 1))


def train(options):
    """Train (or reload and visualize) a feature-VAE and log to the dashboard.

    Writes the option dump, the dashboard catalog and one CSV per tracked
    statistic into ``options['dashboard_dir']``, builds the data providers and
    the TensorFlow graph, then either restores a checkpoint and calls
    ``visualize`` (when ``options['reload']`` is truthy) or runs the training
    loop with periodic logging, checkpointing and validation.

    All dashboard files opened here are closed when the function exits,
    including on the early-return and error paths.

    Args:
        options: dict of run settings. Keys read here include ``model_dir``,
            ``dashboard_dir``, ``description``, ``DKL_weight``, ``sigma_clip``,
            ``data_dir``, ``batch_size``, ``img_shape``, ``latent_dims``,
            ``lr``, ``n_epochs``, ``reload`` and the ``freq_*`` settings.

    Returns:
        ``None`` on normal completion or after the reload path, or the tuple
        ``(1., 1., 1.)`` when the training cost becomes NaN or infinite.

    Raises:
        ValueError: if ``options['feat_type']`` is not ``'fc'``.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    dashboard_dir = options['dashboard_dir']

    with open(os.path.join(dashboard_dir, 'options'), 'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write(
            'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
                options['DKL_weight'],
                -options['sigma_clip'],
                options['sigma_clip']
            )
        )
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # (key, file name, CSV value-column title) for every dashboard CSV.
    csv_specs = (
        ('train', 'train_loss.csv', 'Train Loss'),
        ('val', 'val_loss.csv', 'Validation Loss'),
        ('dkl', 'dkl.csv', 'DKL'),
        ('ll', 'll.csv', '-LL'),
        ('dec_sig', 'dec_log_sig_sq.csv', 'Decoder Log Sigma^2'),
        ('enc_sig', 'enc_log_sig_sq.csv', 'Encoder Log Sigma^2'),
        ('dec_std_sig', 'dec_std_log_sig_sq.csv', 'STD of Decoder Log Sigma^2'),
        ('enc_std_sig', 'enc_std_log_sig_sq.csv', 'STD of Encoder Log Sigma^2'),
        ('dec_mean', 'dec_mean.csv', 'Decoder Mean'),
        ('enc_mean', 'enc_mean.csv', 'Encoder Mean'),
    )

    opened = []
    try:
        # Dashboard Catalog
        catalog = open(os.path.join(dashboard_dir, 'catalog'), 'w')
        opened.append(catalog)
        # NOTE(review): "Simga" looks like a typo for "Sigma"; kept verbatim
        # because the dashboard may rely on these labels.
        catalog.write(
            'filename,type,name\n'
            'options,plain,Options\n'
            'train_loss.csv,csv,Train Loss\n'
            'll.csv,csv,Neg. Log-Likelihood\n'
            'dec_log_sig_sq.csv,csv,Decoder Log Simga^2\n'
            'dec_std_log_sig_sq.csv,csv,STD of Decoder Log Simga^2\n'
            'dec_mean.csv,csv,Decoder Mean\n'
            'dkl.csv,csv,DKL\n'
            'enc_log_sig_sq.csv,csv,Encoder Log Sigma^2\n'
            'enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2\n'
            'enc_mean.csv,csv,Encoder Mean\n'
            'val_loss.csv,csv,Validation Loss\n'
        )
        catalog.flush()

        logs = {}
        for key, filename, _ in csv_specs:
            logs[key] = open(os.path.join(dashboard_dir, filename), 'w')
            opened.append(logs[key])
        for key, _, title in csv_specs:
            logs[key].write('step,time,' + title + '\n')

        # Print options
        utils.print_options(options, log)

        return _feat_vae_run(options, log, catalog, logs)
    finally:
        for handle in opened:
            handle.close()
Exemple #12
0
def train(options):
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    options_file = open(os.path.join(options['dashboard_dir'], 'options'), 'w')
    options_file.write(options['description'] + '\n')
    options_file.write(
        'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
            options['DKL_weight'], -options['sigma_clip'],
            options['sigma_clip']))
    for optn in options:
        options_file.write(optn)
        options_file.write(':\t')
        options_file.write(str(options[optn]))
        options_file.write('\n')
    options_file.close()

    # Dashboard Catalog
    catalog = open(os.path.join(options['dashboard_dir'], 'catalog'), 'w')
    catalog.write("""filename,type,name
options,plain,Options
train_loss.csv,csv,Train Loss
ll.csv,csv,Neg. Log-Likelihood
dec_log_sig_sq.csv,csv,Decoder Log Simga^2
dec_std_log_sig_sq.csv,csv,STD of Decoder Log Simga^2
dec_mean.csv,csv,Decoder Mean
dkl.csv,csv,DKL
enc_log_sig_sq.csv,csv,Encoder Log Sigma^2
enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2
enc_mean.csv,csv,Encoder Mean
val_loss.csv,csv,Validation Loss
""")
    catalog.flush()
    train_log = open(os.path.join(options['dashboard_dir'], 'train_loss.csv'),
                     'w')
    val_log = open(os.path.join(options['dashboard_dir'], 'val_loss.csv'), 'w')
    dkl_log = open(os.path.join(options['dashboard_dir'], 'dkl.csv'), 'w')
    ll_log = open(os.path.join(options['dashboard_dir'], 'll.csv'), 'w')

    dec_sig_log = open(
        os.path.join(options['dashboard_dir'], 'dec_log_sig_sq.csv'), 'w')
    enc_sig_log = open(
        os.path.join(options['dashboard_dir'], 'enc_log_sig_sq.csv'), 'w')

    dec_std_sig_log = open(
        os.path.join(options['dashboard_dir'], 'dec_std_log_sig_sq.csv'), 'w')
    enc_std_sig_log = open(
        os.path.join(options['dashboard_dir'], 'enc_std_log_sig_sq.csv'), 'w')

    dec_mean_log = open(os.path.join(options['dashboard_dir'], 'dec_mean.csv'),
                        'w')
    enc_mean_log = open(os.path.join(options['dashboard_dir'], 'enc_mean.csv'),
                        'w')
    # val_sig_log = open(os.path.join(options['dashboard_dir'], 'val_log_sig_sq.csv'), 'w')

    train_log.write('step,time,Train Loss\n')
    val_log.write('step,time,Validation Loss\n')
    dkl_log.write('step,time,DKL\n')
    ll_log.write('step,time,-LL\n')

    dec_sig_log.write('step,time,Decoder Log Sigma^2\n')
    enc_sig_log.write('step,time,Encoder Log Sigma^2\n')

    dec_std_sig_log.write('step,time,STD of Decoder Log Sigma^2\n')
    enc_std_sig_log.write('step,time,STD of Encoder Log Sigma^2\n')

    dec_mean_log.write('step,time,Decoder Mean\n')
    enc_mean_log.write('step,time,Encoder Mean\n')

    # Print options
    utils.print_options(options, log)

    # Load dataset ----------------------------------------------------------------------
    # Train provider
    train_provider, val_provider, test_provider = get_providers(options,
                                                                log,
                                                                flat=True)

    # Initialize model ------------------------------------------------------------------
    with tf.device('/gpu:0'):
        model = cupboard(options['model'])(
            options['p_layers'], options['q_layers'],
            np.prod(options['img_shape']), options['latent_dims'],
            options['DKL_weight'], options['sigma_clip'], 'vanilla_vae')
        log.info('Model initialized')

        # Define inputs
        model_input_batch = tf.placeholder(
            tf.float32,
            shape=[
                options['batch_size'],
                np.prod(np.array(options['img_shape']))
            ],
            name='enc_inputs')
        sampler_input_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], options['latent_dims']],
            name='dec_inputs')
        log.info('Inputs defined')

        # Define forward pass
        cost_function = model(model_input_batch)
        log.info('Forward pass graph built')

        # Define sampler
        sampler = model.build_sampler(sampler_input_batch)
        log.info('Sampler graph built')

        # Define optimizer
        optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])
        # optimizer = tf.train.GradientDescentOptimizer(learning_rate=options['lr'])

        train_step = optimizer.minimize(cost_function)

        # Get gradients
        grads = optimizer.compute_gradients(cost_function)
        grad_tensors = [gv[0] for gv in grads]

        # Clip gradients
        clip_grads = [(tf.clip_by_norm(gv[0], 5.0,
                                       name='grad_clipping'), gv[1])
                      for gv in grads]

        # Update op
        backpass = optimizer.apply_gradients(clip_grads)

        log.info('Optimizer graph built')

        # # Get gradients
        # grad = optimizer.compute_gradients(cost_function)

        # # Clip gradients
        # clipped_grad = tf.clip_by_norm(grad, 5.0, name='grad_clipping')

        # # Update op
        # backpass = optimizer.apply_gradients(clipped_grad)

        # Define init operation
        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')
    # --------------------------------------------------------------------------

    # Train loop ---------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        # Initialize shared variables or restore
        if options['reload']:
            saver.restore(sess, options['reload_file'])
            log.info('Shared variables restored')

            # test_LL_and_DKL(sess, test_provider, model.DKL, model.rec_loss, options, model_input_batch)
            # return

            # if options['data_dir'] == 'MNIST':
            #     mean_img = np.zeros(np.prod(options['img_shape']))
            #     std_img = np.ones(np.prod(options['img_shape']))
            # else:
            #     mean_img = np.load(os.path.join(options['data_dir'], 'mean' + options['extension']))
            #     std_img = np.load(os.path.join(options['data_dir'], 'std' + options['extension']))
            # visualize(model.sampler_mean, sess, model.dec_mean, model.dec_log_std_sq, sampler, sampler_input_batch,
            #             model_input_batch, model.enc_mean, model.enc_log_std_sq,
            #             train_provider, val_provider, options, catalog, mean_img, std_img)
            # return
        else:
            sess.run(init_op)
            log.info('Shared variables initialized')

        # Define last losses to compute a running average
        last_losses = np.zeros((10))

        batch_abs_idx = 0
        for epoch_idx in xrange(options['n_epochs']):
            batch_rel_idx = 0
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs in train_provider:
                if isinstance(inputs, tuple):
                    inputs = inputs[0]

                batch_abs_idx += 1
                batch_rel_idx += 1

                result = sess.run(
                    # (cost_function, train_step, model.enc_std, model.enc_mean, model.encoder, model.dec_std, model.dec_mean, model.decoder, model.rec_loss, model.DKL),
                    #       0           1          2           3               4                     5                       6              7               8            9           10
                    [
                        cost_function, backpass, model.DKL, model.rec_loss,
                        model.dec_log_std_sq, model.enc_log_std_sq,
                        model.enc_mean, model.dec_mean
                    ],
                    feed_dict={model_input_batch: inputs})

                cost = result[0]

                if batch_abs_idx % 10 == 0:
                    train_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.mean(last_losses)))
                    dkl_log.write('{},{},{}\n'.format(batch_abs_idx,
                                                      '2016-04-22',
                                                      -np.mean(result[2])))
                    ll_log.write('{},{},{}\n'.format(batch_abs_idx,
                                                     '2016-04-22',
                                                     -np.mean(result[3])))

                    train_log.flush()
                    dkl_log.flush()
                    ll_log.flush()

                    dec_sig_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.mean(result[4])))
                    enc_sig_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.mean(result[5])))
                    # val_sig_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(result[6])))

                    dec_sig_log.flush()
                    enc_sig_log.flush()

                    dec_std_sig_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.std(result[4])))
                    enc_std_sig_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.std(result[5])))

                    dec_mean_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.mean(result[7])))
                    enc_mean_log.write('{},{},{}\n'.format(
                        batch_abs_idx, '2016-04-22', np.mean(result[6])))

                    dec_std_sig_log.flush()
                    enc_std_sig_log.flush()

                    dec_mean_log.flush()
                    enc_mean_log.flush()
                    # val_sig_log.flush()
                # print('\n\nENC_MEAN:')
                # print(result[3])
                # print('\n\nENC_STD:')
                # print(result[2])
                # print('\nDEC_MEAN:')
                # print(result[6])
                # print('\nDEC_STD:')
                # print(result[5])

                # print('\n\nENCODER WEIGHTS:')
                # print(model._encoder.layers[0].weights['w'].eval())
                # print('\n\DECODER WEIGHTS:')
                # print(model._decoder.layers[0].weights['w'].eval())

                # print(model._encoder.layers[0].weights['w'].eval())
                # print(result[2])
                # print(result[3])

                # print(result[3])
                # print(result[2])
                # print(result[-2])
                # print(result[-1])

                # Check cost
                if np.isnan(cost) or np.isinf(cost):
                    log.info('NaN detected')
                    for i in range(len(result)):
                        print("\n\nresult[%d]:" % i)
                        try:
                            print(np.any(np.isnan(result[i])))
                        except:
                            pass
                        print(result[i])
                    print(result[3].shape)
                    print(model._encoder.layers[0].weights['w'].eval())
                    print('\n\nAny:')
                    print(np.any(np.isnan(result[8])))
                    print(np.any(np.isnan(result[9])))
                    print(np.any(np.isnan(result[10])))
                    print(inputs)
                    return 1., 1., 1.

                # Update last losses
                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost

                # Display training information
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info(
                        'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                        .format(epoch_idx + 1, options['n_epochs'],
                                batch_abs_idx, float(cost),
                                np.mean(last_losses)))
                    log.info('Batch Mean LL: {:0>15.4f}'.format(
                        np.mean(result[3], axis=0)))
                    log.info('Batch Mean -DKL: {:0>15.4f}'.format(
                        np.mean(result[2], axis=0)))

                # Save model
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(
                        sess,
                        os.path.join(options['model_dir'],
                                     'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                    save_dict = {}
                    # Save encoder params ------------------------------------------------------------------
                    for i in range(len(model._encoder.layers)):
                        layer_dict = {
                            'input_dim': model._encoder.layers[i].input_dim,
                            'output_dim': model._encoder.layers[i].output_dim,
                            'act_fn': model._encoder.layers[i].activation,
                            'W': model._encoder.layers[i].weights['w'].eval(),
                            'b': model._encoder.layers[i].weights['b'].eval()
                        }
                        save_dict['encoder'] = layer_dict

                    layer_dict = {
                        'input_dim': model._enc_mean.input_dim,
                        'output_dim': model._enc_mean.output_dim,
                        'act_fn': model._enc_mean.activation,
                        'W': model._enc_mean.weights['w'].eval(),
                        'b': model._enc_mean.weights['b'].eval()
                    }
                    save_dict['enc_mean'] = layer_dict

                    layer_dict = {
                        'input_dim': model._enc_log_std_sq.input_dim,
                        'output_dim': model._enc_log_std_sq.output_dim,
                        'act_fn': model._enc_log_std_sq.activation,
                        'W': model._enc_log_std_sq.weights['w'].eval(),
                        'b': model._enc_log_std_sq.weights['b'].eval()
                    }
                    save_dict['enc_log_std_sq'] = layer_dict

                    # Save decoder params ------------------------------------------------------------------
                    for i in range(len(model._decoder.layers)):
                        layer_dict = {
                            'input_dim': model._decoder.layers[i].input_dim,
                            'output_dim': model._decoder.layers[i].output_dim,
                            'act_fn': model._decoder.layers[i].activation,
                            'W': model._decoder.layers[i].weights['w'].eval(),
                            'b': model._decoder.layers[i].weights['b'].eval()
                        }
                        save_dict['decoder'] = layer_dict

                    layer_dict = {
                        'input_dim': model._dec_mean.input_dim,
                        'output_dim': model._dec_mean.output_dim,
                        'act_fn': model._dec_mean.activation,
                        'W': model._dec_mean.weights['w'].eval(),
                        'b': model._dec_mean.weights['b'].eval()
                    }
                    save_dict['dec_mean'] = layer_dict

                    layer_dict = {
                        'input_dim': model._dec_log_std_sq.input_dim,
                        'output_dim': model._dec_log_std_sq.output_dim,
                        'act_fn': model._dec_log_std_sq.activation,
                        'W': model._dec_log_std_sq.weights['w'].eval(),
                        'b': model._dec_log_std_sq.weights['b'].eval()
                    }
                    save_dict['dec_log_std_sq'] = layer_dict

                    pickle.dump(
                        save_dict,
                        open(
                            os.path.join(options['model_dir'],
                                         'vae_dict_%d' % batch_abs_idx), 'wb'))

                # Validate model
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:

                    model._decoder.layers[0].weights['w'].eval()[:5, :5]

                    valid_costs = []
                    seen_batches = 0
                    for val_batch in val_provider:
                        if isinstance(val_batch, tuple):
                            val_batch = val_batch[0]

                        val_cost = sess.run(
                            cost_function,
                            feed_dict={model_input_batch: val_batch})
                        valid_costs.append(val_cost)
                        seen_batches += 1

                        if seen_batches == options['valid_batches']:
                            break

                    # Print results
                    log.info('Validation loss: {:0>15.4f}'.format(
                        float(np.mean(valid_costs))))

                    val_samples = sess.run(
                        sampler,
                        feed_dict={
                            sampler_input_batch:
                            MVN(np.zeros(options['latent_dims']),
                                np.diag(np.ones(options['latent_dims'])),
                                size=options['batch_size'])
                        })

                    val_log.write('{},{},{}\n'.format(batch_abs_idx,
                                                      '2016-04-22',
                                                      np.mean(valid_costs)))
                    val_log.flush()

                    save_ae_samples(catalog,
                                    np.reshape(result[7],
                                               [options['batch_size']] +
                                               options['img_shape']),
                                    np.reshape(inputs,
                                               [options['batch_size']] +
                                               options['img_shape']),
                                    np.reshape(val_samples,
                                               [options['batch_size']] +
                                               options['img_shape']),
                                    batch_abs_idx,
                                    options['dashboard_dir'],
                                    num_to_save=5,
                                    save_gray=True)

                    # save_dash_samples(
                    #     catalog,
                    #     val_samples,
                    #     batch_abs_idx,
                    #     options['dashboard_dir'],
                    #     flat_samples=True,
                    #     img_shape=options['img_shape'],
                    #     num_to_save=5
                    # )

                    # save_samples(
                    #     val_samples,
                    #     int(batch_abs_idx/options['freq_validation']),
                    #     os.path.join(options['model_dir'], 'valid_samples'),
                    #     True,
                    #     options['img_shape'],
                    #     5
                    # )

                    # save_samples(
                    #     inputs,
                    #     int(batch_abs_idx/options['freq_validation']),
                    #     os.path.join(options['model_dir'], 'input_sanity'),
                    #     True,
                    #     options['img_shape'],
                    #     num_to_save=5
                    # )

                    # save_samples(
                    #     result[7],
                    #     int(batch_abs_idx/options['freq_validation']),
                    #     os.path.join(options['model_dir'], 'rec_sanity'),
                    #     True,
                    #     options['img_shape'],
                    #     num_to_save=5
                    # )

            log.info('End of epoch {}'.format(epoch_idx + 1))
Exemple #13
0
def _layer_params(layer):
    """Collect the parameters of one layer into a plain dict.

    Reads ``input_dim``, ``output_dim`` and ``activation`` from ``layer`` and
    evaluates ``layer.weights['w']`` and ``layer.weights['b']``.  ``eval()``
    is called without an explicit session, so this is meant to be called
    inside the ``with tf.Session(...)`` block of ``train``.
    """
    return {
        'input_dim': layer.input_dim,
        'output_dim': layer.output_dim,
        'act_fn': layer.activation,
        'W': layer.weights['w'].eval(),
        'b': layer.weights['b'].eval()
    }


def _export_vae_params(vae, path):
    """Pickle the encoder/decoder parameters of ``vae`` to ``path``.

    The layout of the pickled dict is unchanged from the previous inline
    code: keys ``encoder``, ``enc_mean``, ``enc_log_std_sq``, ``decoder``,
    ``dec_mean`` and ``dec_log_std_sq``, each mapping to a layer dict as
    built by ``_layer_params``.
    """
    save_dict = {}

    # Save encoder params ------------------------------------------------------------------
    # NOTE(review): every hidden layer is stored under the same key, so only
    # the last encoder layer ends up in the file.  Kept as-is because whatever
    # reads these files is not visible here -- confirm before changing.
    for layer in vae._encoder.layers:
        save_dict['encoder'] = _layer_params(layer)
    save_dict['enc_mean'] = _layer_params(vae._enc_mean)
    save_dict['enc_log_std_sq'] = _layer_params(vae._enc_log_std_sq)

    # Save decoder params ------------------------------------------------------------------
    # NOTE(review): same overwrite as for the encoder -- last layer wins.
    for layer in vae._decoder.layers:
        save_dict['decoder'] = _layer_params(layer)
    save_dict['dec_mean'] = _layer_params(vae._dec_mean)
    save_dict['dec_log_std_sq'] = _layer_params(vae._dec_log_std_sq)

    # Use a context manager so the file is flushed and closed right away.
    with open(path, 'wb') as params_file:
        pickle.dump(save_dict, params_file)


def _sample_prior(options):
    """Return ``MVN`` samples with zero mean and identity covariance.

    ``size`` is ``options['batch_size']`` and the dimensionality is
    ``options['latent_dims']``.  ``MVN`` is defined elsewhere in the file
    (presumably numpy's ``multivariate_normal`` -- TODO confirm).
    """
    return MVN(
        np.zeros(options['latent_dims']),
        np.diag(np.ones(options['latent_dims'])),
        size=options['batch_size']
    )


def train(options):
    """Build the VAE-GAN graph described by ``options`` and run the training loop.

    Writes a description of the run plus per-step CSV statistics to
    ``options['dashboard_dir']``, and checkpoints / pickled VAE parameters to
    ``options['model_dir']``.

    Training schedule (``batch_abs_idx`` counts batches across epochs,
    starting at 1):

    * ``batch_abs_idx < initial_G_iters``: run ``vanilla_backpass``;
    * ``batch_abs_idx < initial_G_iters + initial_D_iters``: run
      ``disc_backpass``;
    * afterwards, within each cycle of ``sum(D_to_G)`` batches the first
      ``D_to_G[0]`` run ``disc_backpass`` and the rest run ``vae_backpass``.

    Args:
        options: dict-like configuration.  Keys read here: ``model_dir``,
            ``dashboard_dir``, ``description``, ``DKL_weight``,
            ``sigma_clip``, ``batch_size``, ``img_shape``, ``latent_dims``,
            ``p_layers``, ``q_layers``, ``disc_params_path``,
            ``num_feat_layers``, ``disc_weight``, ``input_channels``, ``lr``,
            ``reload_all``, ``reload_file``, ``reload_vae``,
            ``vae_params_path``, ``D_to_G``, ``initial_G_iters``,
            ``initial_D_iters``, ``n_epochs``, ``freq_logging``,
            ``freq_saving``, ``freq_validation`` and ``valid_batches``.

    Returns:
        ``(1., 1., 1.)`` if the discriminator cross-entropy becomes NaN or
        infinite (training stops at that batch); ``None`` once all epochs
        have completed.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))

    with open(os.path.join(options['dashboard_dir'], 'options'), 'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write(
            'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
                options['DKL_weight'],
                -options['sigma_clip'],
                options['sigma_clip']
            )
        )
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Every dashboard file opened below is remembered here so that the
    # ``finally`` clause can close all of them, whichever way we leave.
    dashboard_files = []

    def open_dashboard_file(name):
        handle = open(os.path.join(options['dashboard_dir'], name), 'w')
        dashboard_files.append(handle)
        return handle

    # Placeholder written to the ``time`` column of every CSV row.
    log_date = '2016-04-22'

    try:
        # Dashboard Catalog
        catalog = open_dashboard_file('catalog')
        catalog.write("""filename,type,name
options,plain,Options
train_loss.csv,csv,Discriminator Cross-Entropy
ll.csv,csv,Neg. Log-Likelihood
dec_log_sig_sq.csv,csv,Decoder Log Simga^2
dec_std_log_sig_sq.csv,csv,STD of Decoder Log Simga^2
dec_mean.csv,csv,Decoder Mean
dkl.csv,csv,DKL
enc_log_sig_sq.csv,csv,Encoder Log Sigma^2
enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2
enc_mean.csv,csv,Encoder Mean
val_loss.csv,csv,Validation Loss
""")
        catalog.flush()

        train_log = open_dashboard_file('train_loss.csv')
        val_log = open_dashboard_file('val_loss.csv')
        dkl_log = open_dashboard_file('dkl.csv')
        ll_log = open_dashboard_file('ll.csv')

        dec_sig_log = open_dashboard_file('dec_log_sig_sq.csv')
        enc_sig_log = open_dashboard_file('enc_log_sig_sq.csv')

        dec_std_sig_log = open_dashboard_file('dec_std_log_sig_sq.csv')
        enc_std_sig_log = open_dashboard_file('enc_std_log_sig_sq.csv')

        dec_mean_log = open_dashboard_file('dec_mean.csv')
        enc_mean_log = open_dashboard_file('enc_mean.csv')

        # One value column per training phase; each row fills exactly one.
        train_log.write('step,time,Train CE (Training Vanilla),Train CE (Training Gen.),Train CE (Training Disc.)\n')
        val_log.write('step,time,Validation CE (Training Vanilla),Validation CE (Training Gen.),Validation CE (Training Disc.)\n')
        dkl_log.write('step,time,DKL (Training Vanilla),DKL (Training Gen.),DKL (Training Disc.)\n')
        ll_log.write('step,time,-LL (Training Vanilla),-LL (Training Gen.),-LL (Training Disc.)\n')

        dec_sig_log.write('step,time,Decoder Log Sigma^2 (Training Vanilla),Decoder Log Sigma^2 (Training Gen.),Decoder Log Sigma^2 (Training Disc.)\n')
        enc_sig_log.write('step,time,Encoder Log Sigma^2 (Training Vanilla),Encoder Log Sigma^2 (Training Gen.),Encoder Log Sigma^2 (Training Disc.)\n')

        dec_std_sig_log.write('step,time,STD of Decoder Log Sigma^2 (Training Vanilla),STD of Decoder Log Sigma^2 (Training Gen.),STD of Decoder Log Sigma^2 (Training Disc.)\n')
        enc_std_sig_log.write('step,time,STD of Encoder Log Sigma^2 (Training Vanilla),STD of Encoder Log Sigma^2 (Training Gen.),STD of Encoder Log Sigma^2 (Training Disc.)\n')

        dec_mean_log.write('step,time,Decoder Mean (Training Vanilla),Decoder Mean (Training Gen.),Decoder Mean (Training Disc.)\n')
        enc_mean_log.write('step,time,Encoder Mean (Training Vanilla),Encoder Mean (Training Gen.),Encoder Mean (Training Disc.)\n')

        # Print options
        utils.print_options(options, log)

        # Load dataset ----------------------------------------------------------------------
        train_provider, val_provider, test_provider = get_providers(options, log, flat=True)

        # Initialize model ------------------------------------------------------------------
        with tf.device('/gpu:0'):
            # Define inputs
            model_input_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'], np.prod(np.array(options['img_shape']))],
                name='enc_inputs'
            )
            sampler_input_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'], options['latent_dims']],
                name='dec_inputs'
            )
            log.info('Inputs defined')

            # Define model
            with tf.variable_scope('vae_scope'):
                vae_model = cupboard('vanilla_vae')(
                    options['p_layers'],
                    options['q_layers'],
                    np.prod(options['img_shape']),
                    options['latent_dims'],
                    options['DKL_weight'],
                    options['sigma_clip'],
                    'vae_model'
                )

            # NOTE: unpickling can execute arbitrary code; only point
            # ``disc_params_path`` at files from a trusted source.
            with open(options['disc_params_path'], 'rb') as disc_params_file:
                disc_params = pickle.load(disc_params_file)

            with tf.variable_scope('disc_scope'):
                disc_model = cupboard('fixed_conv_disc')(
                    disc_params,
                    options['num_feat_layers'],
                    name='disc_model'
                )

            vae_gan = cupboard('vae_gan')(
                vae_model,
                disc_model,
                options['disc_weight'],
                options['img_shape'],
                options['input_channels'],
                'vae_scope',
                'disc_scope',
                name='vae_gan_model'
            )

            # Define Optimizers ---------------------------------------------------------------------
            optimizer = tf.train.AdamOptimizer(
                learning_rate=options['lr']
            )

            vae_backpass, disc_backpass, vanilla_backpass = vae_gan(
                model_input_batch, sampler_input_batch, optimizer
            )

            log.info('Optimizer graph built')
            # --------------------------------------------------------------------------------------
            # Define init operation
            init_op = tf.initialize_all_variables()
            log.info('Variable initialization graph built')

        # Define op to save and restore variables
        saver = tf.train.Saver()
        log.info('Save operation built')
        # --------------------------------------------------------------------------

        # Train loop ---------------------------------------------------------------
        with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
            log.info('Session started')

            # Initialize shared variables or restore
            if options['reload_all']:
                saver.restore(sess, options['reload_file'])
                log.info('Shared variables restored')
            else:
                sess.run(init_op)
                log.info('Variables initialized')

                if options['reload_vae']:
                    vae_model.reload_vae(options['vae_params_path'])

            # Define last losses to compute a running average
            last_losses = np.zeros((10))

            batch_abs_idx = 0
            D_to_G = options['D_to_G']
            total_D2G = sum(D_to_G)
            base = options['initial_G_iters'] + options['initial_D_iters']

            for epoch_idx in range(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))

                for inputs in train_provider:
                    if isinstance(inputs, tuple):
                        inputs = inputs[0]

                    batch_abs_idx += 1

                    # Pick the update op for this batch together with the CSV
                    # row template that puts the value in the matching column.
                    if batch_abs_idx < options['initial_G_iters']:
                        backpass = vanilla_backpass
                        log_format_string = '{},{},{},,\n'
                    elif batch_abs_idx < base:
                        backpass = disc_backpass
                        log_format_string = '{},{},,,{}\n'
                    elif (batch_abs_idx - base) % total_D2G < D_to_G[0]:
                        backpass = disc_backpass
                        log_format_string = '{},{},,,{}\n'
                    else:
                        backpass = vae_backpass
                        log_format_string = '{},{},,{},\n'

                    # result indices: 0 disc_CE, 1 backpass, 2 DKL, 3 rec_loss,
                    # 4 dec_log_std_sq, 5 enc_log_std_sq, 6 enc_mean, 7 dec_mean
                    result = sess.run(
                        [
                            vae_gan.disc_CE,
                            backpass,
                            vae_gan._vae.DKL,
                            vae_gan._vae.rec_loss,
                            vae_gan._vae.dec_log_std_sq,
                            vae_gan._vae.enc_log_std_sq,
                            vae_gan._vae.enc_mean,
                            vae_gan._vae.dec_mean
                        ],
                        feed_dict={
                            model_input_batch: inputs,
                            sampler_input_batch: _sample_prior(options)
                        }
                    )

                    cost = result[0]

                    if batch_abs_idx % 10 == 0:
                        dashboard_rows = [
                            (train_log, np.mean(last_losses)),
                            (dkl_log, -np.mean(result[2])),
                            (ll_log, -np.mean(result[3])),
                            (dec_sig_log, np.mean(result[4])),
                            (enc_sig_log, np.mean(result[5])),
                            (dec_std_sig_log, np.std(result[4])),
                            (enc_std_sig_log, np.std(result[5])),
                            (dec_mean_log, np.mean(result[7])),
                            (enc_mean_log, np.mean(result[6])),
                        ]
                        for stat_log, value in dashboard_rows:
                            stat_log.write(log_format_string.format(batch_abs_idx, log_date, value))
                            stat_log.flush()

                    # Check cost
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')
                        for i, value in enumerate(result):
                            print("\n\nresult[%d]:" % i)
                            try:
                                print(np.any(np.isnan(value)))
                            except (TypeError, ValueError):
                                # np.isnan rejects non-numeric fetches.
                                print('(not a numeric value)')
                            print(value)
                        print(result[3].shape)
                        print(vae_gan._vae._encoder.layers[0].weights['w'].eval())
                        # Only eight values are fetched above, so the former
                        # prints of result[8..10] raised IndexError and hid
                        # the sentinel return below; the loop already covers
                        # every fetched value.
                        print(inputs)
                        return 1., 1., 1.

                    # Update last losses
                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost

                    # Display training information
                    # NOTE(review): the frequency is tested against the epoch
                    # index, not the batch index -- confirm this is intended.
                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'.format(
                            epoch_idx + 1,
                            options['n_epochs'],
                            batch_abs_idx,
                            float(cost),
                            np.mean(last_losses)
                        ))
                        log.info('Batch Mean LL: {:0>15.4f}'.format(np.mean(result[3], axis=0)))
                        log.info('Batch Mean -DKL: {:0>15.4f}'.format(np.mean(result[2], axis=0)))

                    # Save model
                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(sess, os.path.join(options['model_dir'], 'model_at_%d.ckpt' % batch_abs_idx))
                        log.info('Model saved')

                        _export_vae_params(
                            vae_gan._vae,
                            os.path.join(options['model_dir'], 'vae_dict_%d' % batch_abs_idx)
                        )

                    # Validate model
                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                        valid_costs = []
                        for val_batch in val_provider:
                            if isinstance(val_batch, tuple):
                                val_batch = val_batch[0]

                            val_cost = sess.run(
                                vae_gan.disc_CE,
                                feed_dict={
                                    model_input_batch: val_batch,
                                    sampler_input_batch: _sample_prior(options)
                                }
                            )
                            valid_costs.append(val_cost)

                            if len(valid_costs) == options['valid_batches']:
                                break

                        # Print results
                        log.info('Validation loss: {:0>15.4f}'.format(
                            float(np.mean(valid_costs))
                        ))

                        val_samples = sess.run(
                            vae_gan.sampler,
                            feed_dict={
                                sampler_input_batch: _sample_prior(options)
                            }
                        )

                        val_log.write(log_format_string.format(batch_abs_idx, log_date, np.mean(valid_costs)))
                        val_log.flush()

                        # Reconstructions, inputs and samples from the last
                        # training batch, reshaped to image batches.
                        sample_shape = [options['batch_size']] + options['img_shape']
                        save_ae_samples(
                            catalog,
                            np.reshape(result[7], sample_shape),
                            np.reshape(inputs, sample_shape),
                            np.reshape(val_samples, sample_shape),
                            batch_abs_idx,
                            options['dashboard_dir'],
                            num_to_save=5,
                            save_gray=True
                        )

                log.info('End of epoch {}'.format(epoch_idx + 1))
    finally:
        # Release every dashboard file, including on the NaN early return
        # and on exceptions.
        for handle in dashboard_files:
            handle.close()