def train(options):
    """Train a softmax classifier on reparameterised samples of a VAE posterior.

    The model class is looked up with ``cupboard(options['model'])`` and every
    graph variable is restored from ``model_at_21000.ckpt`` in
    ``options['model_dir']`` before training starts.  The optimised cost is the
    batch-averaged categorical cross-entropy of a fully connected classifier
    whose input is ``enc_mean + exp(0.5 * enc_log_std_sq) * eps``.  Gradients
    are taken w.r.t. every trainable variable that receives one and are
    clipped to norm 5.0 before the Adam update.

    Dashboard CSV files plus a ``catalog`` index are written to
    ``options['dashboard_dir']``; text logs go to ``options['model_dir']/log.txt``.
    Note: ``train_loss.csv`` and ``val_loss.csv`` record classification
    *accuracy*, not the cross-entropy.

    Args:
        options: dict of paths and hyper-parameters (keys are read below).

    Returns:
        ``(1., 1., 1.)`` if a NaN/Inf cost is encountered, otherwise ``None``.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))

    # Human-readable dump of the run configuration
    with open(os.path.join(options['dashboard_dir'], 'options'), 'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write(
            'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
                options['DKL_weight'], -options['sigma_clip'], options['sigma_clip']))
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Dashboard catalog (kept open: save_ae_samples appends to it later)
    catalog = open(os.path.join(options['dashboard_dir'], 'catalog'), 'w')
    catalog.write('\n'.join([
        'filename,type,name',
        'options,plain,Options',
        'train_loss.csv,csv,Train Loss',
        'll.csv,csv,Neg. Log-Likelihood',
        'dec_log_sig_sq.csv,csv,Decoder Log Sigma^2',
        'dec_std_log_sig_sq.csv,csv,STD of Decoder Log Sigma^2',
        'dec_mean.csv,csv,Decoder Mean',
        'dkl.csv,csv,DKL',
        'enc_log_sig_sq.csv,csv,Encoder Log Sigma^2',
        'enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2',
        'enc_mean.csv,csv,Encoder Mean',
        'val_loss.csv,csv,Validation Loss',
    ]) + '\n')
    catalog.flush()

    def open_csv(filename, header):
        # Open a dashboard CSV for writing and emit its header row.
        handle = open(os.path.join(options['dashboard_dir'], filename), 'w')
        handle.write(header)
        return handle

    def write_row(handle, step, value):
        # Append one "step,time,value" row; the time column is a fixed placeholder.
        handle.write('{},{},{}\n'.format(step, '2016-04-22', value))
        handle.flush()

    train_log = open_csv('train_loss.csv', 'step,time,Train Loss\n')
    val_log = open_csv('val_loss.csv', 'step,time,Validation Loss\n')
    dkl_log = open_csv('dkl.csv', 'step,time,DKL\n')
    ll_log = open_csv('ll.csv', 'step,time,-LL\n')
    dec_sig_log = open_csv('dec_log_sig_sq.csv', 'step,time,Decoder Log Sigma^2\n')
    enc_sig_log = open_csv('enc_log_sig_sq.csv', 'step,time,Encoder Log Sigma^2\n')
    dec_std_sig_log = open_csv('dec_std_log_sig_sq.csv', 'step,time,STD of Decoder Log Sigma^2\n')
    enc_std_sig_log = open_csv('enc_std_log_sig_sq.csv', 'step,time,STD of Encoder Log Sigma^2\n')
    dec_mean_log = open_csv('dec_mean.csv', 'step,time,Decoder Mean\n')
    enc_mean_log = open_csv('enc_mean.csv', 'step,time,Encoder Mean\n')

    # Print options
    utils.print_options(options, log)

    # Load dataset ----------------------------------------------------------
    def image_provider(split):
        # Provider over <data_dir>/<split>/patches.  Two directory entries are
        # excluded from the count (presumably non-image files -- TODO confirm).
        patch_dir = os.path.join(options['data_dir'], split, 'patches')
        return DataProvider(
            len(os.listdir(patch_dir)) - 2,
            options['batch_size'],
            toolbox.ImageLoader(data_dir=patch_dir, flat=True,
                                extension=options['file_extension']))

    if options['data_dir'] != 'MNIST':
        train_provider = image_provider('train')
        val_provider = image_provider('valid')
    else:
        train_provider = DataProvider(
            55000, options['batch_size'],
            toolbox.MNISTLoader(mode='train', flat=True))
        val_provider = DataProvider(
            5000, options['batch_size'],
            toolbox.MNISTLoader(mode='validation', flat=True))
    log.info('Data providers initialized.')

    # Initialize model ------------------------------------------------------
    with tf.device('/gpu:0'):
        model = cupboard(options['model'])(
            options['p_layers'], options['q_layers'],
            np.prod(options['img_shape']), options['latent_dims'],
            options['DKL_weight'], options['sigma_clip'], 'vanilla_vae')
        log.info('Model initialized')

        # Define inputs
        model_input_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], np.prod(np.array(options['img_shape']))],
            name='enc_inputs')
        model_label_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], options['num_classes']],
            name='labels')
        sampler_input_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], options['latent_dims']],
            name='dec_inputs')
        log.info('Inputs defined')

        # Build the VAE forward pass.  Its own cost is not optimised here, but
        # the call is what defines model.enc_mean, model.DKL, ... used below.
        model(model_input_batch)
        log.info('Forward pass graph built')

        # Define sampler
        sampler = model.build_sampler(sampler_input_batch)
        log.info('Sampler graph built')

        # Define optimizer
        optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])

        # Classifier on a reparameterised posterior sample mu + sigma * eps
        enc_std = tf.exp(tf.mul(0.5, model.enc_log_std_sq))
        posterior_sample = tf.add(
            tf.mul(tf.random_normal([model.n_samples, model.latent_dims]), enc_std),
            model.enc_mean)
        classifier = FC(
            model.latent_dims, options['num_classes'], activation=None,
            scale=0.01, name='classifier_fc')(posterior_sample)
        classifier = tf.nn.softmax(classifier)

        # Batch-averaged categorical cross-entropy
        cost_function = -tf.mul(model_label_batch, tf.log(classifier))
        cost_function = tf.reduce_sum(cost_function)
        cost_function *= 1 / float(options['batch_size'])

        # Gradients (variables without one are dropped), clipped to norm 5.0
        grads = optimizer.compute_gradients(cost_function)
        grads = [gv for gv in grads if gv[0] is not None]
        clip_grads = [(tf.clip_by_norm(grad, 5.0, name='grad_clipping'), var)
                      for grad, var in grads]

        # Update op
        backpass = optimizer.apply_gradients(clip_grads)
        log.info('Optimizer graph built')

    # Define init operation
    init_op = tf.initialize_all_variables()
    log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')

    # Train loop ------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        # Initialize, then overwrite every variable from a fixed checkpoint
        sess.run(init_op)
        saver.restore(
            sess, os.path.join(options['model_dir'], 'model_at_21000.ckpt'))
        log.info('Shared variables restored')

        # Sliding window of the last 10 costs for a running average
        last_losses = np.zeros((10))

        batch_abs_idx = 0
        for epoch_idx in xrange(options['n_epochs']):
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs, labels in train_provider:
                batch_abs_idx += 1

                # Fetch layout: 0 cost, 1 update op, 2 DKL, 3 rec. loss,
                # 4 dec. log sigma^2, 5 enc. log sigma^2, 6 enc. mean,
                # 7 dec. mean, 8 class probabilities, 9+ gradients
                result = sess.run(
                    [cost_function, backpass, model.DKL, model.rec_loss,
                     model.dec_log_std_sq, model.enc_log_std_sq,
                     model.enc_mean, model.dec_mean, classifier]
                    + [gv[0] for gv in grads],
                    feed_dict={model_input_batch: inputs,
                               model_label_batch: labels})

                cost = result[0]

                if batch_abs_idx % 10 == 0:
                    # train_loss.csv receives the batch classification accuracy
                    write_row(train_log, batch_abs_idx, np.mean(
                        np.argmax(labels, axis=1) == np.argmax(result[8], axis=1)))
                    write_row(dkl_log, batch_abs_idx, -np.mean(result[2]))
                    write_row(ll_log, batch_abs_idx, -np.mean(result[3]))
                    write_row(dec_sig_log, batch_abs_idx, np.mean(result[4]))
                    write_row(enc_sig_log, batch_abs_idx, np.mean(result[5]))
                    write_row(dec_std_sig_log, batch_abs_idx, np.std(result[4]))
                    write_row(enc_std_sig_log, batch_abs_idx, np.std(result[5]))
                    write_row(dec_mean_log, batch_abs_idx, np.mean(result[7]))
                    write_row(enc_mean_log, batch_abs_idx, np.mean(result[6]))

                # Check cost: dump diagnostics and abort on NaN/Inf
                if np.isnan(cost) or np.isinf(cost):
                    log.info('NaN detected')
                    for i, value in enumerate(result):
                        print("\n\nresult[%d]:" % i)
                        try:
                            print(np.any(np.isnan(value)))
                        except TypeError:
                            # e.g. the update op's fetched value is None
                            pass
                        print(value)
                    print(result[3].shape)
                    print(model._encoder.layers[0].weights['w'].eval())
                    print('\n\nAny:')
                    # Class probabilities and the first two gradients, if present
                    for value in result[8:11]:
                        print(np.any(np.isnan(value)))
                    print(inputs)
                    return 1., 1., 1.

                # Update last losses
                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost

                # Display training information
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info(
                        'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                        .format(epoch_idx + 1, options['n_epochs'], batch_abs_idx,
                                float(cost), np.mean(last_losses)))
                    log.info('Batch Mean LL: {:0>15.4f}'.format(
                        np.mean(result[3], axis=0)))
                    log.info('Batch Mean -DKL: {:0>15.4f}'.format(
                        np.mean(result[2], axis=0)))

                # Save model
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(
                        sess,
                        os.path.join(options['model_dir'],
                                     'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                # Validate model
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                    valid_costs = []
                    valid_accs = []
                    seen_batches = 0
                    for val_batch, val_labels in val_provider:
                        val_cost, val_probs = sess.run(
                            [cost_function, classifier],
                            feed_dict={model_input_batch: val_batch,
                                       model_label_batch: val_labels})
                        valid_costs.append(val_cost)
                        # Accumulate per-batch accuracy so the logged value
                        # covers every validation batch, not just the last one
                        valid_accs.append(np.mean(
                            np.argmax(val_labels, axis=1) == np.argmax(val_probs, axis=1)))
                        seen_batches += 1
                        if seen_batches == options['valid_batches']:
                            break

                    # Print results
                    log.info('Validation loss: {:0>15.4f}'.format(
                        float(np.mean(valid_costs))))

                    val_samples = sess.run(
                        sampler,
                        feed_dict={
                            sampler_input_batch: MVN(
                                np.zeros(options['latent_dims']),
                                np.diag(np.ones(options['latent_dims'])),
                                size=options['batch_size'])
                        })

                    # val_loss.csv receives the validation accuracy
                    write_row(val_log, batch_abs_idx, np.mean(valid_accs))

                    batch_img_shape = [options['batch_size']] + options['img_shape']
                    save_ae_samples(
                        catalog,
                        np.reshape(result[7], batch_img_shape),
                        np.reshape(inputs, batch_img_shape),
                        np.reshape(val_samples, batch_img_shape),
                        batch_abs_idx,
                        options['dashboard_dir'],
                        num_to_save=5,
                        save_gray=True)

                    sample_idx = int(batch_abs_idx / options['freq_validation'])
                    save_samples(
                        val_samples, sample_idx,
                        os.path.join(options['model_dir'], 'valid_samples'),
                        True, options['img_shape'], 5)
                    save_samples(
                        inputs, sample_idx,
                        os.path.join(options['model_dir'], 'input_sanity'),
                        True, options['img_shape'], num_to_save=5)
                    save_samples(
                        result[7], sample_idx,
                        os.path.join(options['model_dir'], 'rec_sanity'),
                        True, options['img_shape'], num_to_save=5)

            log.info('End of epoch {}'.format(epoch_idx + 1))
def train(options):
    """Train a VAE together with a fixed, pre-trained convolutional feature extractor.

    A ``Sequential`` classifier is assembled from the first
    ``options['num_feat_layers']`` entries of the pickled parameter list at
    ``options['feat_params_path']``.  Conv layers get their weights replaced by
    constants; FC layers are ``ConstFC``.  The indices of pool and FC layers are
    handed to ``cupboard('feat_vae')`` together with a ``vanilla_vae``, and the
    resulting cost is minimised with Adam (gradients clipped to norm 5.0).

    If ``options['reload']`` is set, variables are restored from
    ``options['reload_file']`` and the function only evaluates and returns.
    It runs ``test_LL_and_DKL`` by default, or ``visualize`` when
    ``options.get('visualize_on_reload')`` is true.  Otherwise the model is
    trained from scratch and finally evaluated on the test provider.

    Args:
        options: dict of paths and hyper-parameters (keys are read below).

    Returns:
        ``(1., 1., 1.)`` if a NaN/Inf cost is encountered, otherwise ``None``.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))

    # Human-readable dump of the run configuration
    with open(os.path.join(options['dashboard_dir'], 'options'), 'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write(
            'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
                options['DKL_weight'], -options['sigma_clip'], options['sigma_clip']))
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Dashboard catalog (kept open: save_ae_samples appends to it later)
    catalog = open(os.path.join(options['dashboard_dir'], 'catalog'), 'w')
    catalog.write('\n'.join([
        'filename,type,name',
        'options,plain,Options',
        'train_loss.csv,csv,Train Loss',
        'll.csv,csv,Neg. Log-Likelihood',
        'dec_log_sig_sq.csv,csv,Decoder Log Sigma^2',
        'dec_std_log_sig_sq.csv,csv,STD of Decoder Log Sigma^2',
        'dec_mean.csv,csv,Decoder Mean',
        'dkl.csv,csv,DKL',
        'enc_log_sig_sq.csv,csv,Encoder Log Sigma^2',
        'enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2',
        'enc_mean.csv,csv,Encoder Mean',
        'val_loss.csv,csv,Validation Loss',
    ]) + '\n')
    catalog.flush()

    def open_csv(filename, header):
        # Open a dashboard CSV for writing and emit its header row.
        handle = open(os.path.join(options['dashboard_dir'], filename), 'w')
        handle.write(header)
        return handle

    def write_row(handle, step, value):
        # Append one "step,time,value" row; the time column is a fixed placeholder.
        handle.write('{},{},{}\n'.format(step, '2016-04-22', value))
        handle.flush()

    train_log = open_csv('train_loss.csv', 'step,time,Train Loss\n')
    val_log = open_csv('val_loss.csv', 'step,time,Validation Loss\n')
    dkl_log = open_csv('dkl.csv', 'step,time,DKL\n')
    ll_log = open_csv('ll.csv', 'step,time,-LL\n')
    dec_sig_log = open_csv('dec_log_sig_sq.csv', 'step,time,Decoder Log Sigma^2\n')
    enc_sig_log = open_csv('enc_log_sig_sq.csv', 'step,time,Encoder Log Sigma^2\n')
    dec_std_sig_log = open_csv('dec_std_log_sig_sq.csv', 'step,time,STD of Decoder Log Sigma^2\n')
    enc_std_sig_log = open_csv('enc_std_log_sig_sq.csv', 'step,time,STD of Encoder Log Sigma^2\n')
    dec_mean_log = open_csv('dec_mean.csv', 'step,time,Decoder Mean\n')
    enc_mean_log = open_csv('enc_mean.csv', 'step,time,Encoder Mean\n')

    # Print options
    utils.print_options(options, log)

    # Load dataset ----------------------------------------------------------
    train_provider, val_provider, test_provider = get_providers(options, log, flat=True)

    # Initialize model ------------------------------------------------------
    with tf.device('/gpu:0'):
        # Define inputs
        model_input_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], np.prod(np.array(options['img_shape']))],
            name='enc_inputs')
        sampler_input_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], options['latent_dims']],
            name='dec_inputs')
        log.info('Inputs defined')

        # Feature extractor
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # parameter files from a trusted source.
        with open(options['feat_params_path'], 'rb') as params_file:
            feat_params = pickle.load(params_file)
        for lay in feat_params:
            log.info('Feature layer type: {}'.format(lay['layer_type']))

        # Indices (into feat_params) of layers handed to feat_vae: pool and
        # FC layers are recorded, conv layers are not.
        feat_layers = []
        _classifier = Sequential('CNN_Classifier')
        conv_count, pool_count, fc_count = 0, 0, 0
        for i in xrange(options['num_feat_layers']):
            layer = feat_params[i]
            if layer['layer_type'] == 'conv':
                _classifier += ConvLayer(
                    layer['n_filters_in'],
                    layer['n_filters_out'],
                    layer['input_dim'],
                    layer['filter_dim'],
                    layer['strides'],
                    name='classifier_conv_%d' % conv_count)
                # Replace the layer's weights with the pre-trained values as constants
                _classifier.layers[-1].weights['W'] = tf.constant(layer['W'])
                _classifier.layers[-1].weights['b'] = tf.constant(layer['b'])
                _classifier += layer['act_fn']
                conv_count += 1
            elif layer['layer_type'] == 'pool':
                # NOTE(review): named by layer index i, unlike conv/fc which use
                # per-type counters; kept as-is so graph names do not change.
                _classifier += PoolLayer(
                    layer['input_dim'],
                    layer['filter_dim'],
                    layer['strides'],
                    name='classifier_pool_%d' % i)
                pool_count += 1
                feat_layers.append(i)
            elif layer['layer_type'] == 'fc':
                _classifier += ConstFC(
                    layer['W'],
                    layer['b'],
                    activation=layer['act_fn'],
                    name='classifier_fc_%d' % fc_count)
                fc_count += 1
                feat_layers.append(i)

        # VAE model
        vae_model = cupboard('vanilla_vae')(
            options['p_layers'],
            options['q_layers'],
            np.prod(options['img_shape']),
            options['latent_dims'],
            options['DKL_weight'],
            options['sigma_clip'],
            'vanilla_vae')

        feat_vae = cupboard('feat_vae')(
            vae_model,
            _classifier,
            feat_layers,
            options['DKL_weight'],
            options['vae_rec_loss_weight'],
            img_shape=options['img_shape'],
            input_channels=options['input_channels'],
            flat=False,
            name='feat_vae_model')
        log.info('Model initialized')

        # Define forward pass
        cost_function = feat_vae(model_input_batch)
        log.info('Forward pass graph built')

        # Define sampler
        sampler = feat_vae.build_sampler(sampler_input_batch)
        log.info('Sampler graph built')

        # Define optimizer
        optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])
        log.info('Optimizer graph built')

        # Gradients (variables without one are dropped), clipped to norm 5.0
        grads = optimizer.compute_gradients(cost_function)
        grads = [gv for gv in grads if gv[0] is not None]
        clip_grads = [(tf.clip_by_norm(grad, 5.0, name='grad_clipping'), var)
                      for grad, var in grads]

        # Update op
        backpass = optimizer.apply_gradients(clip_grads)

    # Define init operation
    init_op = tf.initialize_all_variables()
    log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')

    # Train loop ------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        if options['reload']:
            saver.restore(sess, options['reload_file'])
            log.info('Shared variables restored')

            # Evaluation-only run.  The visualisation path is opt-in; by
            # default the restored model is scored on the test provider.
            if options.get('visualize_on_reload', False):
                mean_img = np.load(os.path.join(options['data_dir'], 'mean' + options['extension']))
                std_img = np.load(os.path.join(options['data_dir'], 'std' + options['extension']))
                visualize(sess, feat_vae.vae.dec_mean, feat_vae.vae.dec_log_std_sq,
                          sampler, sampler_input_batch, model_input_batch,
                          feat_vae.vae.enc_mean, feat_vae.vae.enc_log_std_sq,
                          train_provider, val_provider, options, catalog,
                          mean_img, std_img)
            else:
                test_LL_and_DKL(sess, test_provider, feat_vae.vae.DKL,
                                feat_vae.vae.rec_loss, options, model_input_batch)
            return

        sess.run(init_op)
        log.info('Shared variables initialized')

        # Sliding window of the last 10 costs for a running average
        last_losses = np.zeros((10))

        batch_abs_idx = 0
        for epoch_idx in xrange(options['n_epochs']):
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs, _ in train_provider:
                batch_abs_idx += 1

                # Fetch layout: 0 cost, 1 update op, 2 DKL, 3 rec. loss,
                # 4 dec. log sigma^2, 5 enc. log sigma^2, 6 enc. mean,
                # 7 dec. mean, 8+ gradients
                result = sess.run(
                    [cost_function, backpass, feat_vae.vae.DKL,
                     feat_vae.vae.rec_loss, feat_vae.vae.dec_log_std_sq,
                     feat_vae.vae.enc_log_std_sq, feat_vae.vae.enc_mean,
                     feat_vae.vae.dec_mean]
                    + [gv[0] for gv in grads],
                    feed_dict={model_input_batch: inputs})

                cost = result[0]

                if batch_abs_idx % 10 == 0:
                    # Running mean over the window, before this step's cost is added
                    write_row(train_log, batch_abs_idx, np.mean(last_losses))
                    write_row(dkl_log, batch_abs_idx, -np.mean(result[2]))
                    write_row(ll_log, batch_abs_idx, -np.mean(result[3]))
                    write_row(dec_sig_log, batch_abs_idx, np.mean(result[4]))
                    write_row(enc_sig_log, batch_abs_idx, np.mean(result[5]))
                    write_row(dec_std_sig_log, batch_abs_idx, np.std(result[4]))
                    write_row(enc_std_sig_log, batch_abs_idx, np.std(result[5]))
                    write_row(dec_mean_log, batch_abs_idx, np.mean(result[7]))
                    write_row(enc_mean_log, batch_abs_idx, np.mean(result[6]))

                # Check cost: dump diagnostics and abort on NaN/Inf
                if np.isnan(cost) or np.isinf(cost):
                    log.info('NaN detected')
                    for i, value in enumerate(result):
                        print("\n\nresult[%d]:" % i)
                        try:
                            print(np.any(np.isnan(value)))
                        except TypeError:
                            # e.g. the update op's fetched value is None
                            pass
                        print(value)
                    print(result[3].shape)
                    # Previously referenced an undefined name `model` (NameError).
                    # NOTE(review): assumes the vanilla_vae exposes _encoder.layers
                    # with a 'w' weight, as the sibling trainer does -- confirm.
                    print(vae_model._encoder.layers[0].weights['w'].eval())
                    print('\n\nAny:')
                    # First gradients, if present
                    for value in result[8:11]:
                        print(np.any(np.isnan(value)))
                    print(inputs)
                    return 1., 1., 1.

                # Update last losses
                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost

                # Display training information
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info(
                        'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                        .format(epoch_idx + 1, options['n_epochs'], batch_abs_idx,
                                float(cost), np.mean(last_losses)))
                    log.info('Batch Mean LL: {:0>15.4f}'.format(np.mean(result[3], axis=0)))
                    log.info('Batch Mean -DKL: {:0>15.4f}'.format(np.mean(result[2], axis=0)))

                # Save model
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(
                        sess,
                        os.path.join(options['model_dir'],
                                     'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                # Validate model
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                    valid_costs = []
                    seen_batches = 0
                    for val_batch, _ in val_provider:
                        val_cost = sess.run(
                            cost_function,
                            feed_dict={model_input_batch: val_batch})
                        valid_costs.append(val_cost)
                        seen_batches += 1
                        if seen_batches == options['valid_batches']:
                            break

                    # Print results
                    log.info('Validation loss: {:0>15.4f}'.format(
                        float(np.mean(valid_costs))))

                    val_samples = sess.run(
                        sampler,
                        feed_dict={
                            sampler_input_batch: MVN(
                                np.zeros(options['latent_dims']),
                                np.diag(np.ones(options['latent_dims'])),
                                size=options['batch_size'])
                        })

                    write_row(val_log, batch_abs_idx, np.mean(valid_costs))

                    batch_img_shape = [options['batch_size']] + options['img_shape']
                    save_ae_samples(
                        catalog,
                        np.reshape(result[7], batch_img_shape),
                        np.reshape(inputs, batch_img_shape),
                        np.reshape(val_samples, batch_img_shape),
                        batch_abs_idx,
                        options['dashboard_dir'],
                        num_to_save=5,
                        save_gray=True)

                    sample_idx = int(batch_abs_idx / options['freq_validation'])
                    save_samples(
                        val_samples, sample_idx,
                        os.path.join(options['model_dir'], 'valid_samples'),
                        True, options['img_shape'], 5)
                    save_samples(
                        inputs, sample_idx,
                        os.path.join(options['model_dir'], 'input_sanity'),
                        True, options['img_shape'], num_to_save=5)
                    save_samples(
                        result[7], sample_idx,
                        os.path.join(options['model_dir'], 'rec_sanity'),
                        True, options['img_shape'], num_to_save=5)

            log.info('End of epoch {}'.format(epoch_idx + 1))

        # Test model ---------------------------------------------------------
        # Per batch, each fetched tensor is averaged over axis 1 when it has
        # rank > 1, otherwise fully.
        test_results = []
        for inputs in test_provider:
            if isinstance(inputs, tuple):
                inputs = inputs[0]
            batch_results = sess.run(
                [feat_vae.vae.DKL, feat_vae.vae.rec_loss,
                 feat_vae.vae.dec_log_std_sq, feat_vae.vae.enc_log_std_sq,
                 feat_vae.vae.dec_mean, feat_vae.vae.enc_mean],
                feed_dict={model_input_batch: inputs})
            test_results.append([
                np.mean(p, axis=1) if len(p.shape) > 1 else np.mean(p)
                for p in batch_results])

        if not test_results:
            log.info('No test batches; skipping test evaluation')
            return

        # Transpose: one list per fetched quantity
        test_results = [list(column) for column in zip(*test_results)]

        # Print results
        for label, idx in (('Test Mean Rec. Loss', 1),
                           ('Test DKL', 0),
                           ('Test Dec. Mean Log Std Sq', 2),
                           ('Test Enc. Mean Log Std Sq', 3),
                           ('Test Dec. Mean Mean', 4),
                           ('Test Enc. Mean Mean', 5)):
            log.info('{}: {:0>15.4f}'.format(label, float(np.mean(test_results[idx]))))
def train(options):
    """Adversarially train an MLP generator against a fixed-architecture conv discriminator.

    The generator maps ``options['latent_dims']``-dimensional noise through
    three tanh fully connected layers (60, 60, prod(img_shape) units).  The
    discriminator is built by ``cupboard('fixed_conv_disc')`` from the pickled
    parameters at ``options['disc_params_path']``.  Every step feeds one real
    batch plus one generated batch (targets 1 and 0) and runs exactly one
    update op, scheduled on a 0-based step counter:

    * the first ``initial_G_iters`` steps update the generator;
    * the next ``initial_D_iters`` steps update the discriminator;
    * after that, steps cycle with period ``sum(D_to_G)``: the first
      ``D_to_G[0]`` steps of each cycle update the discriminator and the
      remainder update the generator.

    Discriminator cross-entropy and accuracy (training running means and
    validation means) are written to CSV files in ``options['dashboard_dir']``.

    Args:
        options: dict of paths and hyper-parameters (keys are read below).
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))

    # Human-readable dump of the run configuration
    with open(os.path.join(options['dashboard_dir'], 'options'), 'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write(
            'Log Sigma^2 clipped to: [{}, {}]\n\n'.format(
                -options['sigma_clip'], options['sigma_clip']))
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Dashboard catalog (kept open: save_ae_samples appends to it later)
    catalog = open(os.path.join(options['dashboard_dir'], 'catalog'), 'w')
    catalog.write('\n'.join([
        'filename,type,name',
        'options,plain,Options',
        'train_loss.csv,csv,Discriminator Cross-Entropy',
        'train_acc.csv,csv,Discriminator Accuracy',
        'val_loss.csv,csv,Validation Cross-Entropy',
        'val_acc.csv,csv,Validation Accuracy',
    ]) + '\n')
    catalog.flush()

    def open_csv(filename, header):
        # Open a dashboard CSV for writing and emit its header row.
        handle = open(os.path.join(options['dashboard_dir'], filename), 'w')
        handle.write(header)
        return handle

    # Each CSV has one value column per training mode; a row fills only one.
    train_log = open_csv(
        'train_loss.csv',
        'step,time,Train CE (Training Vanilla),Train CE (Training Gen.),Train CE (Training Disc.)\n')
    val_log = open_csv(
        'val_loss.csv',
        'step,time,Validation CE (Training Vanilla),Validation CE (Training Gen.),Validation CE (Training Disc.)\n')
    train_acc = open_csv(
        'train_acc.csv',
        'step,time,Train Acc. (Training Vanilla),Train Acc. (Training Gen.),Train Acc. (Training Disc.)\n')
    val_acc = open_csv(
        'val_acc.csv',
        'step,time,Validation Acc. (Training Vanilla),Validation Acc. (Training Gen.),Validation Acc. (Training Disc.)\n')

    # Print options
    utils.print_options(options, log)

    # Load dataset ----------------------------------------------------------
    train_provider, val_provider, test_provider = get_providers(options, log, flat=True)

    # NOTE(review): pickle.load can execute arbitrary code; only load
    # parameter files from a trusted source.
    with open(options['disc_params_path'], 'rb') as params_file:
        disc_params = pickle.load(params_file)

    # Initialize model ------------------------------------------------------
    with tf.device('/gpu:0'):
        # Define inputs
        real_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], np.prod(np.array(options['img_shape']))],
            name='real_inputs')
        sampler_input_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], options['latent_dims']],
            name='noise_channel')
        # Discriminator targets, shape (2 * batch_size, 1): ones for the real
        # half of the batch, zeros for the generated half.
        labels = tf.constant(
            np.expand_dims(
                np.concatenate(
                    (np.ones(options['batch_size']),
                     np.zeros(options['batch_size'])),
                    axis=0).astype(np.float32),
                axis=1))
        log.info('Inputs defined')

        # Define model
        with tf.variable_scope('gen_scope'):
            generator = Sequential('generator')
            generator += FullyConnected(options['latent_dims'], 60, tf.nn.tanh, name='fc_1')
            generator += FullyConnected(60, 60, tf.nn.tanh, name='fc_2')
            generator += FullyConnected(60, np.prod(options['img_shape']), tf.nn.tanh, name='fc_3')
            sampler = generator(sampler_input_batch)

        with tf.variable_scope('disc_scope'):
            disc_model = cupboard('fixed_conv_disc')(
                disc_params,
                options['num_feat_layers'],
                name='disc_model')
            # Real batch first, generated batch second -- matches `labels`
            disc_inputs = tf.concat(0, [real_batch, sampler])
            disc_inputs = tf.reshape(
                disc_inputs,
                [disc_inputs.get_shape()[0].value] + options['img_shape'] + [options['input_channels']])
            preds = disc_model(disc_inputs)
            # Keep log(preds) and log(1 - preds) finite
            preds = tf.clip_by_value(preds, 0.00001, 0.99999)

        num_disc_inputs = float(labels.get_shape()[0].value)

        # Discriminator accuracy (predictions rounded at 0.5)
        disc_accuracy = (1 / num_disc_inputs) * tf.reduce_sum(
            tf.cast(tf.equal(tf.round(preds), labels), tf.float32))

        # Discriminator binary cross-entropy, averaged over real + fake inputs
        disc_CE = (1 / num_disc_inputs) * tf.reduce_sum(
            -tf.add(
                tf.mul(labels, tf.log(preds)),
                tf.mul(1.0 - labels, tf.log(1.0 - preds))))

        # Generator loss: -log D(G(z)) on the generated half only.  It is not
        # reduced; compute_gradients differentiates the sum of its elements.
        gen_loss = -tf.mul(1.0 - labels, tf.log(preds))

        # Define optimizers
        optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])

        gen_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'gen_scope')
        disc_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'disc_scope')

        def clipped_update(loss, var_list, clip_name):
            # Adam update of `var_list` on `loss`, each gradient clipped to norm 5.0.
            grads = optimizer.compute_gradients(loss, var_list)
            grads = [gv for gv in grads if gv[0] is not None]
            clip_grads = [(tf.clip_by_norm(grad, 5.0, name=clip_name), var)
                          for grad, var in grads]
            return optimizer.apply_gradients(clip_grads)

        gen_backpass = clipped_update(gen_loss, gen_train_vars, 'gen_grad_clipping')
        disc_backpass = clipped_update(disc_CE, disc_train_vars, 'disc_grad_clipping')
        log.info('Optimizer graph built')

    # Define init operation
    init_op = tf.initialize_all_variables()
    log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')

    def draw_noise():
        # Standard-normal samples for the generator's noise channel.
        return MVN(
            np.zeros(options['latent_dims']),
            np.diag(np.ones(options['latent_dims'])),
            size=options['batch_size'])

    # Train loop ------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        # Initialize shared variables or restore
        if options['reload_all']:
            saver.restore(sess, options['reload_file'])
            log.info('Shared variables restored')
        else:
            sess.run(init_op)
            log.info('Variables initialized')

        # Sliding windows of the last 10 costs / accuracies
        last_losses = np.zeros((10))
        last_accs = np.zeros((10))

        batch_abs_idx = 0
        D_to_G = options['D_to_G']
        total_D2G = sum(D_to_G)
        base = options['initial_G_iters'] + options['initial_D_iters']

        for epoch_idx in xrange(options['n_epochs']):
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs in train_provider:
                if isinstance(inputs, tuple):
                    inputs = inputs[0]

                # 0-based step for scheduling, so each warm-up phase runs
                # exactly its configured number of iterations.
                step = batch_abs_idx
                batch_abs_idx += 1

                if step < options['initial_G_iters']:
                    backpass = gen_backpass
                    log_format_string = '{},{},{},,\n'
                elif step < base:
                    backpass = disc_backpass
                    log_format_string = '{},{},,,{}\n'
                else:
                    if (step - base) % total_D2G < D_to_G[0]:
                        backpass = disc_backpass
                    else:
                        backpass = gen_backpass
                    # NOTE(review): in the alternating phase every step is logged
                    # to the first value column regardless of which network was
                    # updated (the per-network columns were always overwritten
                    # here). Confirm this is intended.
                    log_format_string = '{},{},{},,\n'

                cost, _, accuracy = sess.run(
                    [disc_CE, backpass, disc_accuracy],
                    feed_dict={real_batch: inputs,
                               sampler_input_batch: draw_noise()})

                if batch_abs_idx % 10 == 0:
                    # Running means over the windows, before this step is added
                    train_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(last_losses)))
                    train_acc.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(last_accs)))
                    train_log.flush()
                    train_acc.flush()

                # Check cost (logged only; training continues)
                if np.isnan(cost) or np.isinf(cost):
                    log.info('NaN detected')

                # Update last losses / accuracies
                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost
                last_accs = np.roll(last_accs, 1)
                last_accs[0] = accuracy

                # Display training information
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'.format(
                        epoch_idx + 1, options['n_epochs'], batch_abs_idx,
                        float(cost), np.mean(last_losses)))
                    log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last accuracies: {:0>15.4f}'.format(
                        epoch_idx + 1, options['n_epochs'], batch_abs_idx,
                        float(cost), np.mean(last_accs)))

                # Save model
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(sess, os.path.join(options['model_dir'], 'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                # Validate model
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                    valid_costs = []
                    valid_accs = []
                    seen_batches = 0
                    for val_batch in val_provider:
                        if isinstance(val_batch, tuple):
                            val_batch = val_batch[0]
                        val_cost, val_accuracy = sess.run(
                            [disc_CE, disc_accuracy],
                            feed_dict={real_batch: val_batch,
                                       sampler_input_batch: draw_noise()})
                        valid_costs.append(val_cost)
                        valid_accs.append(val_accuracy)
                        seen_batches += 1
                        if seen_batches == options['valid_batches']:
                            break

                    # Print results
                    log.info('Validation loss: {:0>15.4f}'.format(
                        float(np.mean(valid_costs))))
                    log.info('Validation accuracies: {:0>15.4f}'.format(
                        float(np.mean(valid_accs))))

                    val_samples = sess.run(
                        sampler,
                        feed_dict={sampler_input_batch: draw_noise()})

                    val_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(valid_costs)))
                    val_acc.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(valid_accs)))
                    val_log.flush()
                    val_acc.flush()

                    # No reconstructions exist for a GAN; an all-ones image
                    # fills the reconstruction slot.
                    batch_img_shape = [options['batch_size']] + options['img_shape']
                    save_ae_samples(
                        catalog,
                        np.ones(batch_img_shape),
                        np.reshape(inputs, batch_img_shape),
                        np.reshape(val_samples, batch_img_shape),
                        batch_abs_idx,
                        options['dashboard_dir'],
                        num_to_save=5,
                        save_gray=True)

            log.info('End of epoch {}'.format(epoch_idx + 1))
def train(options):
    """Train the VAE/GAN model described by ``options``, logging to a dashboard.

    The update op is chosen from the absolute batch index:

    * ``batch_abs_idx < initial_G_iters``: ``vanilla_backpass``;
    * ``batch_abs_idx < initial_G_iters + initial_D_iters``: ``disc_backpass``;
    * afterwards, in cycles of ``sum(options['D_to_G'])`` steps, the first
      ``options['D_to_G'][0]`` steps of each cycle use ``disc_backpass`` and
      the remaining ones use ``vae_backpass``.

    Per-step statistics are written to CSV files in ``options['dashboard_dir']``.
    Each CSV has one value column per phase (vanilla / generator /
    discriminator), and a row only fills the column of the phase that produced
    it.

    Args:
        options: dict of hyper-parameters and paths (``model_dir``,
            ``dashboard_dir``, ``batch_size``, ``img_shape``, ``latent_dims``,
            ``lr``, ``n_epochs``, ``freq_logging``, ``freq_saving``,
            ``freq_validation``, ...).

    Returns:
        ``(1., 1., 1.)`` if a NaN/Inf training cost is detected, else ``None``.
    """
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    dashboard_dir = options['dashboard_dir']
    # The dashboard CSVs need a time column; a fixed placeholder date is used.
    timestamp = '2016-04-22'

    def _dashboard_file(name):
        """Open ``name`` inside the dashboard directory for writing."""
        return open(os.path.join(dashboard_dir, name), 'w')

    def _layer_params(layer):
        """Return the picklable description and weights of one dense layer."""
        return {
            'input_dim': layer.input_dim,
            'output_dim': layer.output_dim,
            'act_fn': layer.activation,
            'W': layer.weights['w'].eval(),
            'b': layer.weights['b'].eval(),
        }

    def _prior_samples():
        """Draw one batch of latent codes from N(0, I)."""
        return MVN(np.zeros(options['latent_dims']),
                   np.diag(np.ones(options['latent_dims'])),
                   size=options['batch_size'])

    with _dashboard_file('options') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write(
            'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
                options['DKL_weight'], -options['sigma_clip'],
                options['sigma_clip']))
        for optn in options:
            options_file.write(optn + ':\t' + str(options[optn]) + '\n')

    # Dashboard catalog: one "filename,type,name" row per dashboard item.
    # Kept open for the whole run because it is handed to save_ae_samples.
    catalog = _dashboard_file('catalog')
    catalog.write("""filename,type,name
options,plain,Options
train_loss.csv,csv,Discriminator Cross-Entropy
ll.csv,csv,Neg. Log-Likelihood
dec_log_sig_sq.csv,csv,Decoder Log Sigma^2
dec_std_log_sig_sq.csv,csv,STD of Decoder Log Sigma^2
dec_mean.csv,csv,Decoder Mean
dkl.csv,csv,DKL
enc_log_sig_sq.csv,csv,Encoder Log Sigma^2
enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2
enc_mean.csv,csv,Encoder Mean
val_loss.csv,csv,Validation Loss
""")
    catalog.flush()

    train_log = _dashboard_file('train_loss.csv')
    val_log = _dashboard_file('val_loss.csv')
    dkl_log = _dashboard_file('dkl.csv')
    ll_log = _dashboard_file('ll.csv')
    dec_sig_log = _dashboard_file('dec_log_sig_sq.csv')
    enc_sig_log = _dashboard_file('enc_log_sig_sq.csv')
    dec_std_sig_log = _dashboard_file('dec_std_log_sig_sq.csv')
    enc_std_sig_log = _dashboard_file('enc_std_log_sig_sq.csv')
    dec_mean_log = _dashboard_file('dec_mean.csv')
    enc_mean_log = _dashboard_file('enc_mean.csv')
    open_files = [catalog, train_log, val_log, dkl_log, ll_log, dec_sig_log,
                  enc_sig_log, dec_std_sig_log, enc_std_sig_log, dec_mean_log,
                  enc_mean_log]

    try:
        # One value column per training phase, in vanilla / gen. / disc. order.
        for csv_file, label in [
                (train_log, 'Train CE'),
                (val_log, 'Validation CE'),
                (dkl_log, 'DKL'),
                (ll_log, '-LL'),
                (dec_sig_log, 'Decoder Log Sigma^2'),
                (enc_sig_log, 'Encoder Log Sigma^2'),
                (dec_std_sig_log, 'STD of Decoder Log Sigma^2'),
                (enc_std_sig_log, 'STD of Encoder Log Sigma^2'),
                (dec_mean_log, 'Decoder Mean'),
                (enc_mean_log, 'Encoder Mean')]:
            csv_file.write(
                'step,time,{0} (Training Vanilla),{0} (Training Gen.),'
                '{0} (Training Disc.)\n'.format(label))

        utils.print_options(options, log)

        # Load dataset -------------------------------------------------------
        train_provider, val_provider, test_provider = get_providers(
            options, log, flat=True)

        # Initialize model ---------------------------------------------------
        with tf.device('/gpu:0'):
            model_input_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'],
                       np.prod(np.array(options['img_shape']))],
                name='enc_inputs')
            sampler_input_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'], options['latent_dims']],
                name='dec_inputs')
            log.info('Inputs defined')

            with tf.variable_scope('vae_scope'):
                vae_model = cupboard('vanilla_vae')(
                    options['p_layers'], options['q_layers'],
                    np.prod(options['img_shape']), options['latent_dims'],
                    options['DKL_weight'], options['sigma_clip'], 'vae_model')
            with tf.variable_scope('disc_scope'):
                # NOTE: unpickling trusts the file at options['disc_params_path'].
                with open(options['disc_params_path'], 'rb') as disc_file:
                    disc_params = pickle.load(disc_file)
                disc_model = cupboard('fixed_conv_disc')(
                    disc_params, options['num_feat_layers'], name='disc_model')
            vae_gan = cupboard('vae_gan')(
                vae_model, disc_model, options['disc_weight'],
                options['img_shape'], options['input_channels'],
                'vae_scope', 'disc_scope', name='vae_gan_model')

            # Define optimizers ----------------------------------------------
            optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])
            vae_backpass, disc_backpass, vanilla_backpass = vae_gan(
                model_input_batch, sampler_input_batch, optimizer)
            log.info('Optimizer graph built')

        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')
        saver = tf.train.Saver()
        log.info('Save operation built')

        # Train loop ---------------------------------------------------------
        with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
            log.info('Session started')
            if options['reload_all']:
                saver.restore(sess, options['reload_file'])
                log.info('Shared variables restored')
            else:
                sess.run(init_op)
                log.info('Variables initialized')
            if options['reload_vae']:
                vae_model.reload_vae(options['vae_params_path'])

            # Window of the last 10 costs, used for the running average.
            last_losses = np.zeros((10))
            batch_abs_idx = 0
            D_to_G = options['D_to_G']
            total_D2G = sum(D_to_G)
            base = options['initial_G_iters'] + options['initial_D_iters']
            img_batch_shape = [options['batch_size']] + options['img_shape']

            for epoch_idx in range(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))
                for inputs in train_provider:
                    if isinstance(inputs, tuple):
                        inputs = inputs[0]
                    batch_abs_idx += 1

                    # Pick this step's update op and the CSV column it fills.
                    if batch_abs_idx < options['initial_G_iters']:
                        backpass = vanilla_backpass
                        log_format_string = '{},{},{},,\n'
                    elif (batch_abs_idx < base or
                          (batch_abs_idx - base) % total_D2G < D_to_G[0]):
                        backpass = disc_backpass
                        log_format_string = '{},{},,,{}\n'
                    else:
                        backpass = vae_backpass
                        log_format_string = '{},{},,{},\n'

                    result = sess.run(
                        [
                            vae_gan.disc_CE, backpass, vae_gan._vae.DKL,
                            vae_gan._vae.rec_loss, vae_gan._vae.dec_log_std_sq,
                            vae_gan._vae.enc_log_std_sq, vae_gan._vae.enc_mean,
                            vae_gan._vae.dec_mean
                        ],
                        feed_dict={
                            model_input_batch: inputs,
                            sampler_input_batch: _prior_samples()
                        })
                    cost = result[0]

                    if batch_abs_idx % 10 == 0:
                        step_stats = [
                            (train_log, np.mean(last_losses)),
                            (dkl_log, -np.mean(result[2])),
                            (ll_log, -np.mean(result[3])),
                            (dec_sig_log, np.mean(result[4])),
                            (enc_sig_log, np.mean(result[5])),
                            (dec_std_sig_log, np.std(result[4])),
                            (enc_std_sig_log, np.std(result[5])),
                            (dec_mean_log, np.mean(result[7])),
                            (enc_mean_log, np.mean(result[6])),
                        ]
                        for csv_file, value in step_stats:
                            csv_file.write(log_format_string.format(
                                batch_abs_idx, timestamp, value))
                            csv_file.flush()

                    # Check cost: dump every fetched value and abort on NaN/Inf.
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')
                        for i, value in enumerate(result):
                            print("\n\nresult[%d]:" % i)
                            try:
                                print(np.any(np.isnan(value)))
                            except TypeError:
                                # e.g. the training-op fetch, which is None.
                                pass
                            print(value)
                        print(result[3].shape)
                        print(vae_gan._vae._encoder.layers[0].weights['w'].eval())
                        print(inputs)
                        return 1., 1., 1.

                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost

                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info(
                            'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                            .format(epoch_idx + 1, options['n_epochs'],
                                    batch_abs_idx, float(cost),
                                    np.mean(last_losses)))
                        log.info('Batch Mean LL: {:0>15.4f}'.format(
                            np.mean(result[3], axis=0)))
                        log.info('Batch Mean -DKL: {:0>15.4f}'.format(
                            np.mean(result[2], axis=0)))

                    # Save model and a pickled dict of the VAE weights.
                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(sess, os.path.join(
                            options['model_dir'],
                            'model_at_%d.ckpt' % batch_abs_idx))
                        log.info('Model saved')
                        vae = vae_gan._vae
                        encoder_layers = [_layer_params(layer)
                                          for layer in vae._encoder.layers]
                        decoder_layers = [_layer_params(layer)
                                          for layer in vae._decoder.layers]
                        save_dict = {
                            # Every hidden layer, in order.
                            'encoder_layers': encoder_layers,
                            'decoder_layers': decoder_layers,
                            'enc_mean': _layer_params(vae._enc_mean),
                            'enc_log_std_sq': _layer_params(vae._enc_log_std_sq),
                            'dec_mean': _layer_params(vae._dec_mean),
                            'dec_log_std_sq': _layer_params(vae._dec_log_std_sq),
                        }
                        # Legacy keys: only the LAST hidden layer, as before.
                        if encoder_layers:
                            save_dict['encoder'] = encoder_layers[-1]
                        if decoder_layers:
                            save_dict['decoder'] = decoder_layers[-1]
                        dict_path = os.path.join(options['model_dir'],
                                                 'vae_dict_%d' % batch_abs_idx)
                        with open(dict_path, 'wb') as dict_file:
                            pickle.dump(save_dict, dict_file)

                    # Validate model.
                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                        valid_costs = []
                        for seen_batches, val_batch in enumerate(val_provider, 1):
                            if isinstance(val_batch, tuple):
                                val_batch = val_batch[0]
                            valid_costs.append(sess.run(
                                vae_gan.disc_CE,
                                feed_dict={
                                    model_input_batch: val_batch,
                                    sampler_input_batch: _prior_samples()
                                }))
                            if seen_batches == options['valid_batches']:
                                break
                        log.info('Validation loss: {:0>15.4f}'.format(
                            float(np.mean(valid_costs))))
                        val_samples = sess.run(
                            vae_gan.sampler,
                            feed_dict={sampler_input_batch: _prior_samples()})
                        val_log.write(log_format_string.format(
                            batch_abs_idx, timestamp, np.mean(valid_costs)))
                        val_log.flush()
                        save_ae_samples(
                            catalog,
                            np.reshape(result[7], img_batch_shape),
                            np.reshape(inputs, img_batch_shape),
                            np.reshape(val_samples, img_batch_shape),
                            batch_abs_idx, dashboard_dir,
                            num_to_save=5, save_gray=True)
                log.info('End of epoch {}'.format(epoch_idx + 1))
    finally:
        for handle in open_files:
            handle.close()
def train(options):
    """Train a denoising autoencoder and log losses and samples to a dashboard.

    Each step feeds a clean batch and a copy corrupted with additive Gaussian
    noise (``scale=options['noise_std']``).

    * When ``options['model']`` is ``'cnn_ae'``, inputs keep the image shape
      ``[batch_size] + img_shape + [input_channels]``.
    * Any other model receives flattened inputs of length
      ``prod(img_shape) * input_channels``.

    Args:
        options: dict of hyper-parameters and paths (``model``, ``model_dir``,
            ``dashboard_dir``, ``batch_size``, ``img_shape``,
            ``input_channels``, ``enc_params``, ``dec_params``, ``lr``,
            ``noise_std``, ``n_epochs``, ``freq_*``, ...).

    Returns:
        ``(1., 1., 1.)`` if a NaN/Inf training cost is detected, else ``None``.
    """
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    dashboard_dir = options['dashboard_dir']
    # The dashboard CSVs need a time column; a fixed placeholder date is used.
    timestamp = '2016-04-22'

    with open(os.path.join(dashboard_dir, 'options'), 'w') as options_file:
        for optn in options:
            options_file.write(optn + ':\t' + str(options[optn]) + '\n')

    # Dashboard catalog; kept open because it is handed to save_ae_samples.
    catalog = open(os.path.join(dashboard_dir, 'catalog'), 'w')
    catalog.write("""filename,type,name
options,plain,Options
train_loss.csv,csv,Train Loss
val_loss.csv,csv,Validation Loss
""")
    catalog.flush()
    train_log = open(os.path.join(dashboard_dir, 'train_loss.csv'), 'w')
    val_log = open(os.path.join(dashboard_dir, 'val_loss.csv'), 'w')
    open_files = [catalog, train_log, val_log]

    try:
        train_log.write('step,time,Train Loss\n')
        val_log.write('step,time,Validation Loss\n')

        utils.print_options(options, log)

        # Load dataset -------------------------------------------------------
        train_provider, val_provider, test_provider = get_providers(options, log)

        # Initialize model ---------------------------------------------------
        is_cnn = options['model'] == 'cnn_ae'
        with tf.device('/gpu:0'):
            if is_cnn:
                model = cupboard(options['model'])(
                    options['img_shape'], options['input_channels'],
                    options['enc_params'], options['dec_params'], 'cnn_ae')
                input_shape = ([options['batch_size']] + options['img_shape'] +
                               [options['input_channels']])
            else:
                model = cupboard(options['model'])(
                    np.prod(options['img_shape']) * options['input_channels'],
                    options['enc_params'], options['dec_params'], 'ae')
                input_shape = [options['batch_size'],
                               np.prod(options['img_shape']) *
                               options['input_channels']]
            model_clean_input_batch = tf.placeholder(
                tf.float32, shape=input_shape, name='clean')
            model_noisy_input_batch = tf.placeholder(
                tf.float32, shape=input_shape, name='noisy')
            log.info('Inputs defined')
            log.info('Model initialized')
            log.info('Clean input shape: {}'.format(
                model_clean_input_batch.get_shape()))
            log.info('Noisy input shape: {}'.format(
                model_noisy_input_batch.get_shape()))

            # Forward pass
            cost_function = model(model_clean_input_batch,
                                  model_noisy_input_batch)
            log.info('Forward pass graph built')

            optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])
            log.info('Optimizer graph built')

            # Variables with no gradient path to the cost come back as None.
            grads = [gv for gv in optimizer.compute_gradients(cost_function)
                     if gv[0] is not None]
            grad_tensors = [grad for grad, _ in grads]
            clip_grads = [(tf.clip_by_norm(grad, 5.0, name='grad_clipping'), var)
                          for grad, var in grads]
            backpass = optimizer.apply_gradients(clip_grads)

        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')
        saver = tf.train.Saver()
        log.info('Save operation built')

        # Train loop ---------------------------------------------------------
        with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
            log.info('Session started')
            if options['reload']:
                saver.restore(sess, os.path.join(options['model_dir'],
                                                 'model.ckpt'))
                log.info('Shared variables restored')
            else:
                sess.run(init_op)
                log.info('Shared variables initialized')

            # Window of the last 10 costs, used for the running average.
            last_losses = np.zeros((10))
            batch_abs_idx = 0
            noise_std = np.float32(options['noise_std'])

            for epoch_idx in range(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))
                for inputs, _ in train_provider:
                    batch_abs_idx += 1
                    noisy_inputs = np.float32(inputs) + normal(
                        loc=0.0, scale=noise_std, size=inputs.shape)
                    result = sess.run(
                        [cost_function, backpass] + grad_tensors,
                        feed_dict={
                            model_clean_input_batch: inputs,
                            model_noisy_input_batch: noisy_inputs
                        })
                    cost = result[0]

                    if batch_abs_idx % 10 == 0:
                        train_log.write('{},{},{}\n'.format(
                            batch_abs_idx, timestamp, np.mean(last_losses)))
                        train_log.flush()

                    # Check cost: dump every fetched value and abort on NaN/Inf.
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')
                        for i, value in enumerate(result):
                            print("\n\nresult[%d]:" % i)
                            try:
                                print(np.any(np.isnan(value)))
                            except TypeError:
                                # e.g. the training-op fetch, which is None.
                                pass
                            print(value)
                        if not is_cnn:
                            print(model._encoder.layers[0].weights['w'].eval())
                        print(inputs)
                        return 1., 1., 1.

                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost

                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info(
                            'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                            .format(epoch_idx + 1, options['n_epochs'],
                                    batch_abs_idx, float(cost),
                                    np.mean(last_losses)))

                    # Save model and (for dense AEs) the encoder parameters.
                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(sess, os.path.join(
                            options['model_dir'],
                            'model_at_%d.ckpt' % batch_abs_idx))
                        log.info('Model saved')
                        save_dict = {
                            'enc_W': [],
                            'enc_b': [],
                            'enc_act_fn': [],
                        }
                        if not is_cnn:
                            act_fns = options['enc_params']['act_fn']
                            for i, layer in enumerate(model._encoder.layers):
                                save_dict['enc_W'].append(layer.weights['w'].eval())
                                save_dict['enc_b'].append(layer.weights['b'].eval())
                                save_dict['enc_act_fn'].append(act_fns[i])
                        enc_path = os.path.join(options['model_dir'],
                                                'enc_dict_%d' % batch_abs_idx)
                        with open(enc_path, 'wb') as enc_file:
                            pickle.dump(save_dict, enc_file)

                    # Validate model.
                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                        valid_costs = []
                        seen_batches = 0
                        for val_batch, _ in val_provider:
                            noisy_val_batch = val_batch + normal(
                                loc=0.0, scale=noise_std, size=val_batch.shape)
                            val_results = sess.run(
                                (cost_function, model.decoder),
                                feed_dict={
                                    model_clean_input_batch: val_batch,
                                    model_noisy_input_batch: noisy_val_batch
                                })
                            valid_costs.append(val_results[0])
                            seen_batches += 1
                            if seen_batches == options['valid_batches']:
                                break
                        log.info('Validation loss: {:0>15.4f}'.format(
                            float(np.mean(valid_costs))))
                        val_log.write('{},{},{}\n'.format(
                            batch_abs_idx, timestamp, np.mean(valid_costs)))
                        val_log.flush()

                        # Bring clean, noisy and reconstructed batches to image
                        # shape (a no-op for 'cnn_ae', whose inputs already are).
                        image_batch_shape = ([val_batch.shape[0]] +
                                             options['img_shape'] +
                                             [options['input_channels']])
                        val_batch = np.reshape(val_batch, image_batch_shape)
                        noisy_val_batch = np.reshape(noisy_val_batch,
                                                     image_batch_shape)
                        val_recon = np.reshape(val_results[-1], image_batch_shape)
                        save_ae_samples(catalog, val_batch, noisy_val_batch,
                                        val_recon, batch_abs_idx, dashboard_dir,
                                        num_to_save=5, save_gray=True)
                log.info('End of epoch {}'.format(epoch_idx + 1))
    finally:
        for handle in open_files:
            handle.close()
def train(options):
    """Train the vanilla VAE described by ``options`` and log to a dashboard.

    * If ``options['reload']`` is set, the checkpoint ``options['reload_file']``
      is restored, ``test_LL_and_DKL`` is run on the test provider, and the
      function returns without training.
    * Otherwise the VAE is trained for ``options['n_epochs']`` epochs with
      gradients clipped to norm 5.0. Mean test-set statistics are then written
      to the log.

    Args:
        options: dict of hyper-parameters and paths (``model_dir``,
            ``dashboard_dir``, ``description``, ``batch_size``, ``img_shape``,
            ``latent_dims``, ``lr``, ``n_epochs``, ``freq_*``, ...).

    Returns:
        ``(1., 1., 1.)`` if a NaN/Inf training cost is detected, else ``None``.
    """
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    dashboard_dir = options['dashboard_dir']
    # The dashboard CSVs need a time column; a fixed placeholder date is used.
    timestamp = '2016-04-22'

    def _dashboard_file(name):
        """Open ``name`` inside the dashboard directory for writing."""
        return open(os.path.join(dashboard_dir, name), 'w')

    def _prior_samples():
        """Draw one batch of latent codes from N(0, I)."""
        return MVN(np.zeros(options['latent_dims']),
                   np.diag(np.ones(options['latent_dims'])),
                   size=options['batch_size'])

    with _dashboard_file('options') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write(
            'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
                options['DKL_weight'], -options['sigma_clip'],
                options['sigma_clip']))
        for optn in options:
            options_file.write(optn + ':\t' + str(options[optn]) + '\n')

    with _dashboard_file('description') as desc_file:
        desc_file.write(options['description'])

    # Dashboard catalog; kept open because it is handed to save_ae_samples.
    catalog = _dashboard_file('catalog')
    catalog.write("""filename,type,name
description,plain,Description
options,plain,Options
train_loss.csv,csv,Train Loss
ll.csv,csv,Neg. Log-Likelihood
dec_log_sig_sq.csv,csv,Decoder Log Sigma^2
dec_std_log_sig_sq.csv,csv,STD of Decoder Log Sigma^2
dec_mean.csv,csv,Decoder Mean
dkl.csv,csv,DKL
enc_log_sig_sq.csv,csv,Encoder Log Sigma^2
enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2
enc_mean.csv,csv,Encoder Mean
val_loss.csv,csv,Validation Loss
""")
    catalog.flush()

    train_log = _dashboard_file('train_loss.csv')
    val_log = _dashboard_file('val_loss.csv')
    dkl_log = _dashboard_file('dkl.csv')
    ll_log = _dashboard_file('ll.csv')
    dec_sig_log = _dashboard_file('dec_log_sig_sq.csv')
    enc_sig_log = _dashboard_file('enc_log_sig_sq.csv')
    dec_std_sig_log = _dashboard_file('dec_std_log_sig_sq.csv')
    enc_std_sig_log = _dashboard_file('enc_std_log_sig_sq.csv')
    dec_mean_log = _dashboard_file('dec_mean.csv')
    enc_mean_log = _dashboard_file('enc_mean.csv')
    open_files = [catalog, train_log, val_log, dkl_log, ll_log, dec_sig_log,
                  enc_sig_log, dec_std_sig_log, enc_std_sig_log, dec_mean_log,
                  enc_mean_log]

    try:
        for csv_file, label in [
                (train_log, 'Train Loss'),
                (val_log, 'Validation Loss'),
                (dkl_log, 'DKL'),
                (ll_log, '-LL'),
                (dec_sig_log, 'Decoder Log Sigma^2'),
                (enc_sig_log, 'Encoder Log Sigma^2'),
                (dec_std_sig_log, 'STD of Decoder Log Sigma^2'),
                (enc_std_sig_log, 'STD of Encoder Log Sigma^2'),
                (dec_mean_log, 'Decoder Mean'),
                (enc_mean_log, 'Encoder Mean')]:
            csv_file.write('step,time,{}\n'.format(label))

        utils.print_options(options, log)

        # Load dataset -------------------------------------------------------
        train_provider, val_provider, test_provider = get_providers(
            options, log, flat=True)

        # Initialize model ---------------------------------------------------
        with tf.device('/gpu:0'):
            model_input_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'],
                       np.prod(np.array(options['img_shape']))],
                name='enc_inputs')
            sampler_input_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'], options['latent_dims']],
                name='dec_inputs')
            log.info('Inputs defined')

            vae_model = cupboard('vanilla_vae')(
                options['p_layers'], options['q_layers'],
                np.prod(options['img_shape']), options['latent_dims'],
                options['DKL_weight'], options['sigma_clip'], 'vanilla_vae')
            log.info('Model initialized')

            optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])

            cost_function = vae_model(model_input_batch)
            log.info('Forward pass graph built')

            sampler = vae_model.build_sampler(sampler_input_batch)
            log.info('Sampler graph built')
            log.info('Optimizer graph built')

            # Variables with no gradient path to the cost come back as None.
            grads = [gv for gv in optimizer.compute_gradients(cost_function)
                     if gv[0] is not None]
            grad_tensors = [grad for grad, _ in grads]
            clip_grads = [(tf.clip_by_norm(grad, 5.0, name='grad_clipping'), var)
                          for grad, var in grads]
            backpass = optimizer.apply_gradients(clip_grads)

        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')
        saver = tf.train.Saver()
        log.info('Save operation built')

        # Train loop ---------------------------------------------------------
        with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
            log.info('Session started')
            if options['reload']:
                # Evaluation-only mode: restore, report test LL/DKL, stop.
                saver.restore(sess, options['reload_file'])
                log.info('Shared variables restored')
                test_LL_and_DKL(sess, test_provider, vae_model.DKL,
                                vae_model.rec_loss, options, model_input_batch)
                return
            sess.run(init_op)
            log.info('Shared variables initialized')

            # Window of the last 10 costs, used for the running average.
            last_losses = np.zeros((10))
            batch_abs_idx = 0
            img_batch_shape = [options['batch_size']] + options['img_shape']

            for epoch_idx in range(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))
                for inputs, _ in train_provider:
                    batch_abs_idx += 1
                    # result[0..7] are the named fetches; gradients follow.
                    result = sess.run(
                        [
                            cost_function, backpass, vae_model.DKL,
                            vae_model.rec_loss, vae_model.dec_log_std_sq,
                            vae_model.enc_log_std_sq, vae_model.enc_mean,
                            vae_model.dec_mean,
                        ] + grad_tensors,
                        feed_dict={model_input_batch: inputs})
                    cost = result[0]

                    if batch_abs_idx % 10 == 0:
                        step_stats = [
                            (train_log, np.mean(last_losses)),
                            (dkl_log, -np.mean(result[2])),
                            (ll_log, -np.mean(result[3])),
                            (dec_sig_log, np.mean(result[4])),
                            (enc_sig_log, np.mean(result[5])),
                            (dec_std_sig_log, np.std(result[4])),
                            (enc_std_sig_log, np.std(result[5])),
                            (dec_mean_log, np.mean(result[7])),
                            (enc_mean_log, np.mean(result[6])),
                        ]
                        for csv_file, value in step_stats:
                            csv_file.write('{},{},{}\n'.format(
                                batch_abs_idx, timestamp, value))
                            csv_file.flush()

                    # Check cost: dump every fetched value and abort on NaN/Inf.
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')
                        for i, value in enumerate(result):
                            print("\n\nresult[%d]:" % i)
                            try:
                                print(np.any(np.isnan(value)))
                            except TypeError:
                                # e.g. the training-op fetch, which is None.
                                pass
                            print(value)
                        print(result[3].shape)
                        print(vae_model._encoder.layers[0].weights['w'].eval())
                        print(inputs)
                        return 1., 1., 1.

                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost

                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info(
                            'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                            .format(epoch_idx + 1, options['n_epochs'],
                                    batch_abs_idx, float(cost),
                                    np.mean(last_losses)))
                        log.info('Batch Mean LL: {:0>15.4f}'.format(
                            np.mean(result[3], axis=0)))
                        log.info('Batch Mean -DKL: {:0>15.4f}'.format(
                            np.mean(result[2], axis=0)))

                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(sess, os.path.join(
                            options['model_dir'],
                            'model_at_%d.ckpt' % batch_abs_idx))
                        log.info('Model saved')

                    # Validate model.
                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                        valid_costs = []
                        seen_batches = 0
                        for val_batch, _ in val_provider:
                            valid_costs.append(sess.run(
                                vae_model.cost,
                                feed_dict={
                                    model_input_batch: val_batch,
                                    sampler_input_batch: _prior_samples()
                                }))
                            seen_batches += 1
                            if seen_batches == options['valid_batches']:
                                break
                        log.info('Validation loss: {:0>15.4f}'.format(
                            float(np.mean(valid_costs))))
                        val_samples = sess.run(
                            sampler,
                            feed_dict={sampler_input_batch: _prior_samples()})
                        val_log.write('{},{},{}\n'.format(
                            batch_abs_idx, timestamp, np.mean(valid_costs)))
                        val_log.flush()
                        save_ae_samples(
                            catalog,
                            np.reshape(result[7], img_batch_shape),
                            np.reshape(inputs, img_batch_shape),
                            np.reshape(val_samples, img_batch_shape),
                            batch_abs_idx, dashboard_dir,
                            num_to_save=5, save_gray=True)
                log.info('End of epoch {}'.format(epoch_idx + 1))

            # Test model -----------------------------------------------------
            # Per batch: [DKL, rec. loss, dec log std^2, enc log std^2,
            # dec mean, enc mean], each reduced over axis 1 when it is 2-D.
            test_results = []
            for inputs in test_provider:
                if isinstance(inputs, tuple):
                    inputs = inputs[0]
                batch_results = sess.run(
                    [
                        vae_model.DKL, vae_model.rec_loss,
                        vae_model.dec_log_std_sq, vae_model.enc_log_std_sq,
                        vae_model.dec_mean, vae_model.enc_mean
                    ],
                    feed_dict={model_input_batch: inputs})
                test_results.append([
                    np.mean(p, axis=1) if len(p.shape) > 1 else np.mean(p)
                    for p in batch_results
                ])
            if not test_results:
                log.info('No test batches available; skipping test summary')
                return
            # Transpose to one list per statistic, collected across batches.
            test_results = [list(stat) for stat in zip(*test_results)]

            log.info('Test Mean Rec. Loss: {:0>15.4f}'.format(
                float(np.mean(test_results[1]))))
            log.info('Test DKL: {:0>15.4f}'.format(
                float(np.mean(test_results[0]))))
            log.info('Test Dec. Mean Log Std Sq: {:0>15.4f}'.format(
                float(np.mean(test_results[2]))))
            log.info('Test Enc. Mean Log Std Sq: {:0>15.4f}'.format(
                float(np.mean(test_results[3]))))
            log.info('Test Dec. Mean Mean: {:0>15.4f}'.format(
                float(np.mean(test_results[4]))))
            log.info('Test Enc. Mean Mean: {:0>15.4f}'.format(
                float(np.mean(test_results[5]))))
    finally:
        for handle in open_files:
            handle.close()
def visualize(sampler_mean, sess, dec_mean, dec_log_std_sq, sampler,
              sampler_input_batch, model_input_batch, enc_mean,
              enc_log_std_sq, train_provider, val_provider, options, catalog,
              mean_img, std_img):
    """Save qualitative VAE samples to ``options['visu_save_dir']``.

    Writes one ``save_ae_samples`` panel (reconstructions / inputs / prior
    samples, step index 100) plus four ``save_samples`` image sets:

    * 0 -- ``sampler_mean`` evaluated on codes drawn from N(0, I);
    * 1 -- ``dec_mean`` for the first validation batch (reconstructions);
    * 2 -- that validation batch itself;
    * 3 -- ``sampler_mean`` evaluated on codes drawn from a diagonal Gaussian
      fitted to reparameterised encoder samples of up to 11 training batches.

    Every image set is de-normalised as ``x * std_img + mean_img`` first.

    Args:
        sampler_mean: tensor evaluated with ``sampler_input_batch`` fed.
        sess: active TensorFlow session.
        dec_mean: tensor evaluated with ``model_input_batch`` fed; used
            directly as the reconstruction.
        dec_log_std_sq: unused, kept for interface compatibility.
        sampler: unused, kept for interface compatibility.
        sampler_input_batch: latent-code placeholder.
        model_input_batch: encoder-input placeholder.
        enc_mean: encoder mean tensor.
        enc_log_std_sq: encoder log-variance tensor.
        train_provider: iterable of batches; tuple items use element 0.
        val_provider: iterable of batches; tuple items use element 0.
        options: dict; reads 'latent_dims', 'batch_size', 'img_shape' and
            'visu_save_dir'.
        catalog: dashboard catalog file handed to ``save_ae_samples``.
        mean_img: per-pixel mean used for de-normalisation (flattened here).
        std_img: per-pixel std used for de-normalisation (flattened here).

    Raises:
        ValueError: if ``val_provider`` or ``train_provider`` yields nothing.
    """
    from numpy.random import multivariate_normal as MVN

    mean_img = mean_img.flatten()
    std_img = std_img.flatten()
    latent_dims = options['latent_dims']
    batch_size = options['batch_size']
    img_shape = options['img_shape']
    save_dir = options['visu_save_dir']
    image_batch_shape = [batch_size] + img_shape

    # Samples from the prior N(0, I) ------------------------------------------
    print('Generate Samples from N(0,I)')
    prior_codes = MVN(np.zeros(latent_dims), np.diag(np.ones(latent_dims)),
                      size=batch_size)
    val_samples = sess.run(sampler_mean,
                           feed_dict={sampler_input_batch: prior_codes})
    val_samples = (val_samples * std_img) + mean_img

    # Reconstructions of the first validation batch ---------------------------
    try:
        inputs = next(iter(val_provider))
    except StopIteration:
        raise ValueError('val_provider yielded no batches')
    if isinstance(inputs, tuple):
        inputs = inputs[0]
    print('Generate Reconstruction Samples')
    recons = sess.run(dec_mean, feed_dict={model_input_batch: inputs})
    recons = (recons * std_img) + mean_img
    inputs = (inputs * std_img) + mean_img

    # Only create the directory when missing; real failures (permissions,
    # a file in the way) now surface instead of being silently swallowed.
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    save_ae_samples(
        catalog,
        np.reshape(recons, image_batch_shape),
        np.reshape(inputs, image_batch_shape),
        np.reshape(val_samples, image_batch_shape),
        100,
        save_dir,
        num_to_save=10,
        save_gray=True)
    save_samples(val_samples, 0, save_dir, True, img_shape, 10)
    save_samples(recons, 1, save_dir, True, img_shape, 10)
    save_samples(inputs, 2, save_dir, True, img_shape, 10)

    # Samples from a Gaussian fitted to encoder samples -----------------------
    print('Fit Gaussian to Samples')
    # Build the reparameterised sampling op once: creating it inside the loop
    # would add fresh nodes to the graph on every batch.
    enc_sample_op = enc_mean + tf.random_normal(
        enc_mean.get_shape()) * tf.exp(0.5 * enc_log_std_sq)
    code_batches = []
    for i, batch in enumerate(train_provider):
        if isinstance(batch, tuple):
            batch = batch[0]
        if i == 11:
            break
        code_batches.append(
            sess.run(enc_sample_op, feed_dict={model_input_batch: batch}))
    if not code_batches:
        raise ValueError('train_provider yielded no batches')
    enc_samples = np.concatenate(code_batches)

    mean = np.mean(enc_samples, axis=0)
    std = np.std(enc_samples, axis=0)
    print("Generate new samples from Gaussian")
    # multivariate_normal takes a covariance matrix, so its diagonal must hold
    # variances (std ** 2), not standard deviations.
    fitted_codes = MVN(mean, np.diag(std ** 2), size=batch_size)
    val_samples = sess.run(sampler_mean,
                           feed_dict={sampler_input_batch: fitted_codes})
    val_samples = (val_samples * std_img) + mean_img
    save_samples(val_samples, 3, save_dir, True, img_shape, 10)
def train(options):
    """Train a 3-layer tanh MLP generator against a conv discriminator.

    The discriminator is ``cupboard('fixed_conv_disc')`` built from the
    parameter pickle at ``options['disc_params_path']``. Real and generated
    batches are concatenated into one discriminator pass, with targets 1 for
    real and 0 for generated samples.

    Update schedule, keyed on the 1-based absolute batch index:

    * ``< initial_G_iters``: generator updates, written to the "Vanilla"
      CSV column;
    * ``< initial_G_iters + initial_D_iters``: discriminator updates, written
      to the "Disc." column;
    * afterwards, cycles of ``sum(D_to_G)`` batches. The first ``D_to_G[0]``
      batches update the discriminator ("Disc." column) and the rest update
      the generator ("Gen." column).

    Gradients are clipped per tensor to L2 norm 5.0 before the Adam update.
    Dashboard CSVs go to ``options['dashboard_dir']`` (closed on exit, even
    on error); checkpoints go to ``options['model_dir']``.

    Args:
        options: configuration dict. Keys read: 'model_dir', 'dashboard_dir',
            'description', 'sigma_clip', 'batch_size', 'img_shape',
            'latent_dims', 'input_channels', 'disc_params_path',
            'num_feat_layers', 'lr', 'reload_all', 'reload_file', 'D_to_G',
            'initial_G_iters', 'initial_D_iters', 'n_epochs',
            'freq_logging', 'freq_saving', 'freq_validation',
            'valid_batches', plus whatever ``get_providers`` reads.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    dashboard_dir = options['dashboard_dir']

    with open(os.path.join(dashboard_dir, 'options'), 'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write('Log Sigma^2 clipped to: [{}, {}]\n\n'.format(
            -options['sigma_clip'], options['sigma_clip']))
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Dashboard Catalog
    catalog = open(os.path.join(dashboard_dir, 'catalog'), 'w')
    catalog.write("""filename,type,name
options,plain,Options
train_loss.csv,csv,Discriminator Cross-Entropy
train_acc.csv,csv,Discriminator Accuracy
val_loss.csv,csv,Validation Cross-Entropy
val_acc.csv,csv,Validation Accuracy
""")
    catalog.flush()
    train_log = open(os.path.join(dashboard_dir, 'train_loss.csv'), 'w')
    val_log = open(os.path.join(dashboard_dir, 'val_loss.csv'), 'w')
    train_acc = open(os.path.join(dashboard_dir, 'train_acc.csv'), 'w')
    val_acc = open(os.path.join(dashboard_dir, 'val_acc.csv'), 'w')
    dashboard_files = [catalog, train_log, val_log, train_acc, val_acc]
    try:
        # Three value columns, one per schedule phase (see docstring).
        train_log.write(
            'step,time,Train CE (Training Vanilla),Train CE (Training Gen.),'
            'Train CE (Training Disc.)\n')
        val_log.write(
            'step,time,Validation CE (Training Vanilla),'
            'Validation CE (Training Gen.),Validation CE (Training Disc.)\n')
        train_acc.write(
            'step,time,Train Acc. (Training Vanilla),'
            'Train Acc. (Training Gen.),Train Acc. (Training Disc.)\n')
        val_acc.write(
            'step,time,Validation Acc. (Training Vanilla),'
            'Validation Acc. (Training Gen.),Validation Acc. (Training Disc.)\n')

        # Print options
        utils.print_options(options, log)

        # Load dataset ------------------------------------------------------
        train_provider, val_provider, test_provider = get_providers(
            options, log, flat=True)

        batch_size = options['batch_size']
        latent_dims = options['latent_dims']
        img_shape = options['img_shape']
        image_batch_shape = [batch_size] + img_shape

        def sample_noise():
            """Draw one batch of latent codes from N(0, I)."""
            return MVN(np.zeros(latent_dims), np.diag(np.ones(latent_dims)),
                       size=batch_size)

        # Initialize model --------------------------------------------------
        with tf.device('/gpu:0'):
            # Define inputs -------------------------------------------------
            real_batch = tf.placeholder(
                tf.float32,
                shape=[batch_size, np.prod(np.array(img_shape))],
                name='real_inputs')
            sampler_input_batch = tf.placeholder(
                tf.float32, shape=[batch_size, latent_dims],
                name='noise_channel')
            # Targets, shape (2 * batch_size, 1): ones for the real half of
            # the concatenated batch, zeros for the generated half.
            labels = tf.constant(
                np.expand_dims(
                    np.concatenate((np.ones(batch_size),
                                    np.zeros(batch_size)),
                                   axis=0).astype(np.float32),
                    axis=1))
            labels = tf.cast(labels, tf.float32)
            num_labels = float(labels.get_shape()[0].value)
            log.info('Inputs defined')

            # Define model --------------------------------------------------
            with tf.variable_scope('gen_scope'):
                generator = Sequential('generator')
                generator += FullyConnected(latent_dims, 60, tf.nn.tanh,
                                            name='fc_1')
                generator += FullyConnected(60, 60, tf.nn.tanh, name='fc_2')
                generator += FullyConnected(60, np.prod(img_shape),
                                            tf.nn.tanh, name='fc_3')
                sampler = generator(sampler_input_batch)

            with tf.variable_scope('disc_scope'):
                with open(options['disc_params_path'], 'rb') as params_file:
                    disc_params = pickle.load(params_file)
                disc_model = cupboard('fixed_conv_disc')(
                    disc_params, options['num_feat_layers'],
                    name='disc_model')
                disc_inputs = tf.concat(0, [real_batch, sampler])
                disc_inputs = tf.reshape(
                    disc_inputs,
                    [disc_inputs.get_shape()[0].value] + img_shape +
                    [options['input_channels']])
                preds = disc_model(disc_inputs)
                # Keep predictions away from 0/1 so the logs below stay finite.
                preds = tf.clip_by_value(preds, 0.00001, 0.99999)
                # Disc Accuracy
                disc_accuracy = (1 / num_labels) * tf.reduce_sum(
                    tf.cast(tf.equal(tf.round(preds), labels), tf.float32))

            # Define Losses -------------------------------------------------
            # Discriminator cross-entropy, averaged over real + fake samples.
            disc_CE = (1 / num_labels) * tf.reduce_sum(
                -tf.add(tf.mul(labels, tf.log(preds)),
                        tf.mul(1.0 - labels, tf.log(1.0 - preds))))
            # Generator loss: -log D(G(z)); the (1 - labels) mask keeps only
            # the generated half.
            gen_loss = -tf.mul(1.0 - labels, tf.log(preds))

            # Define Optimizers ---------------------------------------------
            optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])
            gen_train_vars = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, 'gen_scope')
            disc_train_vars = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, 'disc_scope')

            # Generator update (variables without a gradient are skipped).
            grads = optimizer.compute_gradients(gen_loss, gen_train_vars)
            clip_grads = [
                (tf.clip_by_norm(grad, 5.0, name='gen_grad_clipping'), var)
                for grad, var in grads if grad is not None]
            gen_backpass = optimizer.apply_gradients(clip_grads)

            # Discriminator update.
            grads = optimizer.compute_gradients(disc_CE, disc_train_vars)
            clip_grads = [
                (tf.clip_by_norm(grad, 5.0, name='disc_grad_clipping'), var)
                for grad, var in grads if grad is not None]
            disc_backpass = optimizer.apply_gradients(clip_grads)
            log.info('Optimizer graph built')

            # Define init operation
            init_op = tf.initialize_all_variables()
            log.info('Variable initialization graph built')
            # Define op to save and restore variables
            saver = tf.train.Saver()
            log.info('Save operation built')

        # Train loop --------------------------------------------------------
        with tf.Session(config=tf.ConfigProto(
                log_device_placement=True)) as sess:
            log.info('Session started')
            if options['reload_all']:
                saver.restore(sess, options['reload_file'])
                log.info('Shared variables restored')
            else:
                sess.run(init_op)
                log.info('Variables initialized')

            # Ring buffers holding the last 10 batch losses / accuracies.
            last_losses = np.zeros((10))
            last_accs = np.zeros((10))
            batch_abs_idx = 0
            D_to_G = options['D_to_G']
            total_D2G = sum(D_to_G)
            base = options['initial_G_iters'] + options['initial_D_iters']

            for epoch_idx in xrange(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))
                for inputs in train_provider:
                    if isinstance(inputs, tuple):
                        inputs = inputs[0]
                    batch_abs_idx += 1

                    # Choose the network to update. The format string routes
                    # the logged value into that phase's CSV column.
                    if batch_abs_idx < options['initial_G_iters']:
                        backpass = gen_backpass
                        log_format_string = '{},{},{},,\n'
                    elif batch_abs_idx < base:
                        backpass = disc_backpass
                        log_format_string = '{},{},,,{}\n'
                    elif (batch_abs_idx - base) % total_D2G < D_to_G[0]:
                        backpass = disc_backpass
                        log_format_string = '{},{},,,{}\n'
                    else:
                        backpass = gen_backpass
                        log_format_string = '{},{},,{},\n'

                    result = sess.run(
                        [disc_CE, backpass, disc_accuracy],
                        feed_dict={
                            real_batch: inputs,
                            sampler_input_batch: sample_noise()
                        })
                    cost = result[0]

                    if batch_abs_idx % 10 == 0:
                        train_log.write(log_format_string.format(
                            batch_abs_idx, '2016-04-22', np.mean(last_losses)))
                        train_acc.write(log_format_string.format(
                            batch_abs_idx, '2016-04-22', np.mean(last_accs)))
                        train_log.flush()
                        train_acc.flush()

                    # Check cost
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')

                    # Update last losses
                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost
                    last_accs = np.roll(last_accs, 1)
                    last_accs[0] = result[-1]

                    # Display training information
                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info(
                            'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                            .format(epoch_idx + 1, options['n_epochs'],
                                    batch_abs_idx, float(cost),
                                    np.mean(last_losses)))
                        log.info(
                            'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last accuracies: {:0>15.4f}'
                            .format(epoch_idx + 1, options['n_epochs'],
                                    batch_abs_idx, float(cost),
                                    np.mean(last_accs)))

                    # Save model
                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(
                            sess,
                            os.path.join(options['model_dir'],
                                         'model_at_%d.ckpt' % batch_abs_idx))
                        log.info('Model saved')

                    # Validate model
                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                        valid_costs = []
                        valid_accs = []
                        seen_batches = 0
                        for val_batch in val_provider:
                            if isinstance(val_batch, tuple):
                                val_batch = val_batch[0]
                            val_cost, val_accuracy = sess.run(
                                [disc_CE, disc_accuracy],
                                feed_dict={
                                    real_batch: val_batch,
                                    sampler_input_batch: sample_noise()
                                })
                            valid_costs.append(val_cost)
                            valid_accs.append(val_accuracy)
                            seen_batches += 1
                            if seen_batches == options['valid_batches']:
                                break
                        # Print results
                        log.info('Validation loss: {:0>15.4f}'.format(
                            float(np.mean(valid_costs))))
                        log.info('Validation accuracies: {:0>15.4f}'.format(
                            float(np.mean(valid_accs))))
                        val_samples = sess.run(
                            sampler,
                            feed_dict={sampler_input_batch: sample_noise()})
                        val_log.write(log_format_string.format(
                            batch_abs_idx, '2016-04-22', np.mean(valid_costs)))
                        val_acc.write(log_format_string.format(
                            batch_abs_idx, '2016-04-22', np.mean(valid_accs)))
                        val_log.flush()
                        val_acc.flush()
                        # There are no reconstructions for a GAN; an all-ones
                        # array fills that panel.
                        save_ae_samples(
                            catalog,
                            np.ones(image_batch_shape),
                            np.reshape(inputs, image_batch_shape),
                            np.reshape(val_samples, image_batch_shape),
                            batch_abs_idx,
                            dashboard_dir,
                            num_to_save=5,
                            save_gray=True)
                log.info('End of epoch {}'.format(epoch_idx + 1))
    finally:
        for handle in dashboard_files:
            handle.close()
def train(options):
    """Train a VAE built by ``cupboard(options['model'])`` with clipped Adam.

    Writes the run options, a dashboard catalog and one CSV per tracked metric
    to ``options['dashboard_dir']``. All dashboard files are closed on exit,
    including on error or early return.

    * Every 10 batches: appends metric rows to the CSVs.
    * Every ``freq_saving`` batches: writes a checkpoint
      (``model_at_<step>.ckpt``) and a pickled parameter dict
      (``vae_dict_<step>``) to ``options['model_dir']``.
    * Every ``freq_validation`` batches: averages the cost over up to
      ``valid_batches`` validation batches and saves sample images to the
      dashboard.

    Each gradient is clipped per tensor to L2 norm 5.0 before the Adam
    update. Variables without a gradient are skipped.

    The parameter pickle holds:

    * 'enc_mean', 'enc_log_std_sq', 'dec_mean', 'dec_log_std_sq': one layer
      dict each;
    * 'encoder_layers', 'decoder_layers': lists with one dict per hidden
      layer;
    * 'encoder', 'decoder': the last hidden layer only, kept for existing
      readers.

    Each layer dict has keys 'input_dim', 'output_dim', 'act_fn', 'W' and 'b'.

    Args:
        options: configuration dict. Keys read: 'model_dir', 'dashboard_dir',
            'description', 'DKL_weight', 'sigma_clip', 'model', 'p_layers',
            'q_layers', 'img_shape', 'latent_dims', 'batch_size', 'lr',
            'reload', 'reload_file', 'n_epochs', 'freq_logging',
            'freq_saving', 'freq_validation', 'valid_batches', plus whatever
            ``get_providers`` reads.

    Returns:
        ``(1., 1., 1.)`` if the training cost becomes NaN/Inf (after printing
        diagnostics to stdout), otherwise ``None``.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    dashboard_dir = options['dashboard_dir']

    with open(os.path.join(dashboard_dir, 'options'), 'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write(
            'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
                options['DKL_weight'], -options['sigma_clip'],
                options['sigma_clip']))
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Dashboard Catalog
    catalog = open(os.path.join(dashboard_dir, 'catalog'), 'w')
    catalog.write("""filename,type,name
options,plain,Options
train_loss.csv,csv,Train Loss
ll.csv,csv,Neg. Log-Likelihood
dec_log_sig_sq.csv,csv,Decoder Log Sigma^2
dec_std_log_sig_sq.csv,csv,STD of Decoder Log Sigma^2
dec_mean.csv,csv,Decoder Mean
dkl.csv,csv,DKL
enc_log_sig_sq.csv,csv,Encoder Log Sigma^2
enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2
enc_mean.csv,csv,Encoder Mean
val_loss.csv,csv,Validation Loss
""")
    catalog.flush()
    dashboard_files = [catalog]

    def open_csv(file_name, header):
        """Open a dashboard CSV for writing, register it and write its header."""
        handle = open(os.path.join(dashboard_dir, file_name), 'w')
        dashboard_files.append(handle)
        handle.write(header)
        return handle

    def layer_params(layer):
        """Snapshot a layer's config and current weights (default session)."""
        return {
            'input_dim': layer.input_dim,
            'output_dim': layer.output_dim,
            'act_fn': layer.activation,
            'W': layer.weights['w'].eval(),
            'b': layer.weights['b'].eval()
        }

    try:
        train_log = open_csv('train_loss.csv', 'step,time,Train Loss\n')
        val_log = open_csv('val_loss.csv', 'step,time,Validation Loss\n')
        dkl_log = open_csv('dkl.csv', 'step,time,DKL\n')
        ll_log = open_csv('ll.csv', 'step,time,-LL\n')
        dec_sig_log = open_csv('dec_log_sig_sq.csv',
                               'step,time,Decoder Log Sigma^2\n')
        enc_sig_log = open_csv('enc_log_sig_sq.csv',
                               'step,time,Encoder Log Sigma^2\n')
        dec_std_sig_log = open_csv('dec_std_log_sig_sq.csv',
                                   'step,time,STD of Decoder Log Sigma^2\n')
        enc_std_sig_log = open_csv('enc_std_log_sig_sq.csv',
                                   'step,time,STD of Encoder Log Sigma^2\n')
        dec_mean_log = open_csv('dec_mean.csv', 'step,time,Decoder Mean\n')
        enc_mean_log = open_csv('enc_mean.csv', 'step,time,Encoder Mean\n')

        # Print options
        utils.print_options(options, log)

        # Load dataset ------------------------------------------------------
        train_provider, val_provider, test_provider = get_providers(
            options, log, flat=True)

        batch_size = options['batch_size']
        latent_dims = options['latent_dims']
        image_batch_shape = [batch_size] + options['img_shape']

        # Initialize model --------------------------------------------------
        with tf.device('/gpu:0'):
            model = cupboard(options['model'])(
                options['p_layers'], options['q_layers'],
                np.prod(options['img_shape']), latent_dims,
                options['DKL_weight'], options['sigma_clip'], 'vanilla_vae')
            log.info('Model initialized')

            # Define inputs
            model_input_batch = tf.placeholder(
                tf.float32,
                shape=[batch_size, np.prod(np.array(options['img_shape']))],
                name='enc_inputs')
            sampler_input_batch = tf.placeholder(
                tf.float32, shape=[batch_size, latent_dims],
                name='dec_inputs')
            log.info('Inputs defined')

            # Define forward pass
            cost_function = model(model_input_batch)
            log.info('Forward pass graph built')

            # Define sampler
            sampler = model.build_sampler(sampler_input_batch)
            log.info('Sampler graph built')

            # Define optimizer: clipped gradients only. An unclipped
            # ``minimize`` op used to be built here as well but was never run.
            optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])
            grads = optimizer.compute_gradients(cost_function)
            clip_grads = [(tf.clip_by_norm(grad, 5.0, name='grad_clipping'), var)
                          for grad, var in grads if grad is not None]
            backpass = optimizer.apply_gradients(clip_grads)
            log.info('Optimizer graph built')

            # Define init operation
            init_op = tf.initialize_all_variables()
            log.info('Variable initialization graph built')

            # Define op to save and restore variables
            saver = tf.train.Saver()
            log.info('Save operation built')

        # Train loop --------------------------------------------------------
        with tf.Session(config=tf.ConfigProto(
                log_device_placement=True)) as sess:
            log.info('Session started')
            if options['reload']:
                saver.restore(sess, options['reload_file'])
                log.info('Shared variables restored')
            else:
                sess.run(init_op)
                log.info('Shared variables initialized')

            # Ring buffer with the last 10 batch costs (running average).
            last_losses = np.zeros((10))
            batch_abs_idx = 0
            for epoch_idx in xrange(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))
                for inputs in train_provider:
                    if isinstance(inputs, tuple):
                        inputs = inputs[0]
                    batch_abs_idx += 1

                    # result indices: 0 cost, 1 update op (None), 2 DKL,
                    # 3 rec_loss, 4 dec_log_std_sq, 5 enc_log_std_sq,
                    # 6 enc_mean, 7 dec_mean
                    result = sess.run(
                        [cost_function, backpass, model.DKL, model.rec_loss,
                         model.dec_log_std_sq, model.enc_log_std_sq,
                         model.enc_mean, model.dec_mean],
                        feed_dict={model_input_batch: inputs})
                    cost = result[0]

                    if batch_abs_idx % 10 == 0:
                        metrics = [
                            (train_log, np.mean(last_losses)),
                            (dkl_log, -np.mean(result[2])),
                            (ll_log, -np.mean(result[3])),
                            (dec_sig_log, np.mean(result[4])),
                            (enc_sig_log, np.mean(result[5])),
                            (dec_std_sig_log, np.std(result[4])),
                            (enc_std_sig_log, np.std(result[5])),
                            (dec_mean_log, np.mean(result[7])),
                            (enc_mean_log, np.mean(result[6])),
                        ]
                        for handle, value in metrics:
                            handle.write('{},{},{}\n'.format(
                                batch_abs_idx, '2016-04-22', value))
                            handle.flush()

                    # Check cost: dump diagnostics and abort. The fetch list
                    # has 8 entries, so only those are inspected.
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')
                        for i, value in enumerate(result):
                            print("\n\nresult[%d]:" % i)
                            try:
                                print(np.any(np.isnan(value)))
                            except TypeError:
                                # e.g. the None returned for the update op
                                pass
                            print(value)
                        print(model._encoder.layers[0].weights['w'].eval())
                        print(inputs)
                        return 1., 1., 1.

                    # Update last losses
                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost

                    # Display training information
                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'.format(
                            epoch_idx + 1, options['n_epochs'], batch_abs_idx,
                            float(cost), np.mean(last_losses)))
                        log.info('Batch Mean LL: {:0>15.4f}'.format(
                            np.mean(result[3], axis=0)))
                        log.info('Batch Mean -DKL: {:0>15.4f}'.format(
                            np.mean(result[2], axis=0)))

                    # Save model
                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(
                            sess,
                            os.path.join(options['model_dir'],
                                         'model_at_%d.ckpt' % batch_abs_idx))
                        log.info('Model saved')

                        encoder_layers = [layer_params(layer)
                                          for layer in model._encoder.layers]
                        decoder_layers = [layer_params(layer)
                                          for layer in model._decoder.layers]
                        save_dict = {
                            'encoder_layers': encoder_layers,
                            'enc_mean': layer_params(model._enc_mean),
                            'enc_log_std_sq': layer_params(
                                model._enc_log_std_sq),
                            'decoder_layers': decoder_layers,
                            'dec_mean': layer_params(model._dec_mean),
                            'dec_log_std_sq': layer_params(
                                model._dec_log_std_sq),
                        }
                        # 'encoder'/'decoder' have always held just the last
                        # hidden layer; kept so existing readers still work.
                        if encoder_layers:
                            save_dict['encoder'] = encoder_layers[-1]
                        if decoder_layers:
                            save_dict['decoder'] = decoder_layers[-1]
                        dict_path = os.path.join(options['model_dir'],
                                                 'vae_dict_%d' % batch_abs_idx)
                        with open(dict_path, 'wb') as dict_file:
                            pickle.dump(save_dict, dict_file)

                    # Validate model
                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                        valid_costs = []
                        seen_batches = 0
                        for val_batch in val_provider:
                            if isinstance(val_batch, tuple):
                                val_batch = val_batch[0]
                            valid_costs.append(sess.run(
                                cost_function,
                                feed_dict={model_input_batch: val_batch}))
                            seen_batches += 1
                            if seen_batches == options['valid_batches']:
                                break
                        # Print results
                        log.info('Validation loss: {:0>15.4f}'.format(
                            float(np.mean(valid_costs))))
                        val_samples = sess.run(
                            sampler,
                            feed_dict={
                                sampler_input_batch: MVN(
                                    np.zeros(latent_dims),
                                    np.diag(np.ones(latent_dims)),
                                    size=batch_size)
                            })
                        val_log.write('{},{},{}\n'.format(
                            batch_abs_idx, '2016-04-22', np.mean(valid_costs)))
                        val_log.flush()
                        # Panels: decoder means of the last training batch,
                        # that batch, and prior samples.
                        save_ae_samples(
                            catalog,
                            np.reshape(result[7], image_batch_shape),
                            np.reshape(inputs, image_batch_shape),
                            np.reshape(val_samples, image_batch_shape),
                            batch_abs_idx,
                            dashboard_dir,
                            num_to_save=5,
                            save_gray=True)
                log.info('End of epoch {}'.format(epoch_idx + 1))
    finally:
        for handle in dashboard_files:
            handle.close()
def train(options):
    """Train a denoising autoencoder: fully connected, or 'cnn_ae'.

    Each training batch is fed as ``(clean, clean + N(0, noise_std^2))`` and
    the model's call returns the cost to minimise. Each gradient is clipped
    per tensor to L2 norm 5.0 before the Adam update. Variables without a
    gradient are skipped.

    * Every 10 batches: appends the running-mean loss to ``train_loss.csv``.
    * Every ``freq_saving`` batches: writes a checkpoint
      (``model_at_<step>.ckpt``) and an encoder-parameter pickle
      (``enc_dict_<step>``) to ``options['model_dir']``. The pickle holds
      'enc_W', 'enc_b' and 'enc_act_fn' lists, which stay empty for
      'cnn_ae'.
    * Every ``freq_validation`` batches: averages the cost over up to
      ``valid_batches`` noisy validation batches and saves clean / noisy /
      reconstructed images to the dashboard.

    Dashboard files are closed on exit, including on error or early return.

    Args:
        options: configuration dict. Keys read: 'model_dir', 'dashboard_dir',
            'model', 'img_shape', 'input_channels', 'enc_params',
            'dec_params', 'batch_size', 'lr', 'reload', 'n_epochs',
            'noise_std', 'freq_logging', 'freq_saving', 'freq_validation',
            'valid_batches', plus whatever ``get_providers`` reads.

    Returns:
        ``(1., 1., 1.)`` if the training cost becomes NaN/Inf (after printing
        diagnostics to stdout), otherwise ``None``.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))
    dashboard_dir = options['dashboard_dir']

    with open(os.path.join(dashboard_dir, 'options'), 'w') as options_file:
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Dashboard Catalog
    catalog = open(os.path.join(dashboard_dir, 'catalog'), 'w')
    catalog.write("""filename,type,name
options,plain,Options
train_loss.csv,csv,Train Loss
val_loss.csv,csv,Validation Loss
""")
    catalog.flush()
    train_log = open(os.path.join(dashboard_dir, 'train_loss.csv'), 'w')
    val_log = open(os.path.join(dashboard_dir, 'val_loss.csv'), 'w')
    dashboard_files = [catalog, train_log, val_log]
    try:
        train_log.write('step,time,Train Loss\n')
        val_log.write('step,time,Validation Loss\n')

        # Print options
        utils.print_options(options, log)

        # Load dataset ------------------------------------------------------
        train_provider, val_provider, test_provider = get_providers(
            options, log)

        is_cnn = options['model'] == 'cnn_ae'
        batch_size = options['batch_size']
        img_shape = options['img_shape']
        input_channels = options['input_channels']

        # Initialize model --------------------------------------------------
        with tf.device('/gpu:0'):
            if is_cnn:
                model = cupboard(options['model'])(
                    img_shape, input_channels, options['enc_params'],
                    options['dec_params'], 'cnn_ae')
                input_shape = [batch_size] + img_shape + [input_channels]
            else:
                model = cupboard(options['model'])(
                    np.prod(img_shape) * input_channels,
                    options['enc_params'], options['dec_params'], 'ae')
                input_shape = [batch_size] + [np.prod(img_shape) *
                                              input_channels]
            # Define inputs
            model_clean_input_batch = tf.placeholder(
                tf.float32, shape=input_shape, name='clean')
            model_noisy_input_batch = tf.placeholder(
                tf.float32, shape=input_shape, name='noisy')
            log.info('Inputs defined')
            log.info('Model initialized')

            # Define forward pass
            print(model_clean_input_batch.get_shape())
            print(model_noisy_input_batch.get_shape())
            cost_function = model(model_clean_input_batch,
                                  model_noisy_input_batch)
            log.info('Forward pass graph built')

            # Define optimizer with per-tensor gradient clipping.
            optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])
            grads = [(grad, var)
                     for grad, var in optimizer.compute_gradients(cost_function)
                     if grad is not None]
            grad_tensors = [grad for grad, _ in grads]
            clip_grads = [(tf.clip_by_norm(grad, 5.0, name='grad_clipping'), var)
                          for grad, var in grads]
            backpass = optimizer.apply_gradients(clip_grads)
            log.info('Optimizer graph built')

            # Define init operation
            init_op = tf.initialize_all_variables()
            log.info('Variable initialization graph built')

            # Define op to save and restore variables
            saver = tf.train.Saver()
            log.info('Save operation built')

        # Train loop --------------------------------------------------------
        with tf.Session(config=tf.ConfigProto(
                log_device_placement=True)) as sess:
            log.info('Session started')
            if options['reload']:
                saver.restore(sess,
                              os.path.join(options['model_dir'], 'model.ckpt'))
                log.info('Shared variables restored')
            else:
                sess.run(init_op)
                log.info('Shared variables initialized')

            noise_std = np.float32(options['noise_std'])
            # Ring buffer with the last 10 batch costs (running average).
            last_losses = np.zeros((10))
            batch_abs_idx = 0
            for epoch_idx in xrange(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))
                for inputs, _ in train_provider:
                    batch_abs_idx += 1

                    # Gradients are fetched as well so they can be dumped if
                    # the cost diverges.
                    result = sess.run(
                        [cost_function, backpass] + grad_tensors,
                        feed_dict={
                            model_clean_input_batch: inputs,
                            model_noisy_input_batch: np.float32(inputs) +
                            normal(loc=0.0, scale=noise_std,
                                   size=inputs.shape)
                        })
                    cost = result[0]

                    if batch_abs_idx % 10 == 0:
                        train_log.write('{},{},{}\n'.format(
                            batch_abs_idx, '2016-04-22', np.mean(last_losses)))
                        train_log.flush()

                    # Check cost: dump every fetched value and abort. No
                    # fixed indices are used because the number of gradient
                    # tensors depends on the model.
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')
                        for i, value in enumerate(result):
                            print("\n\nresult[%d]:" % i)
                            try:
                                print(np.any(np.isnan(value)))
                            except TypeError:
                                # e.g. the None returned for the update op
                                pass
                            print(value)
                        print(model._encoder.layers[0].weights['w'].eval())
                        print(inputs)
                        return 1., 1., 1.

                    # Update last losses
                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost

                    # Display training information
                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info(
                            'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                            .format(epoch_idx + 1, options['n_epochs'],
                                    batch_abs_idx, float(cost),
                                    np.mean(last_losses)))

                    # Save model
                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(
                            sess,
                            os.path.join(options['model_dir'],
                                         'model_at_%d.ckpt' % batch_abs_idx))
                        log.info('Model saved')
                        # Save Encoder Params (only extracted for the fully
                        # connected model; cnn_ae gets empty lists).
                        save_dict = {
                            'enc_W': [],
                            'enc_b': [],
                            'enc_act_fn': [],
                        }
                        if not is_cnn:
                            act_fns = options['enc_params']['act_fn']
                            for i, layer in enumerate(model._encoder.layers):
                                save_dict['enc_W'].append(
                                    layer.weights['w'].eval())
                                save_dict['enc_b'].append(
                                    layer.weights['b'].eval())
                                save_dict['enc_act_fn'].append(act_fns[i])
                        dict_path = os.path.join(options['model_dir'],
                                                 'enc_dict_%d' % batch_abs_idx)
                        with open(dict_path, 'wb') as dict_file:
                            pickle.dump(save_dict, dict_file)

                    # Validate model
                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                        valid_costs = []
                        seen_batches = 0
                        for val_batch, _ in val_provider:
                            noisy_val_batch = val_batch + normal(
                                loc=0.0, scale=noise_std,
                                size=val_batch.shape)
                            val_results = sess.run(
                                (cost_function, model.decoder),
                                feed_dict={
                                    model_clean_input_batch: val_batch,
                                    model_noisy_input_batch: noisy_val_batch
                                })
                            valid_costs.append(val_results[0])
                            seen_batches += 1
                            if seen_batches == options['valid_batches']:
                                break
                        # Print results
                        log.info('Validation loss: {:0>15.4f}'.format(
                            float(np.mean(valid_costs))))
                        val_log.write('{},{},{}\n'.format(
                            batch_abs_idx, '2016-04-22', np.mean(valid_costs)))
                        val_log.flush()

                        # Bring the last validation batch into
                        # (N, *img_shape, channels) layout for saving. For
                        # cnn_ae this is the shape the batch already has.
                        image_batch_shape = ([val_batch.shape[0]] + img_shape +
                                             [input_channels])
                        val_batch = np.reshape(val_batch, image_batch_shape)
                        noisy_val_batch = np.reshape(noisy_val_batch,
                                                     image_batch_shape)
                        val_recon = np.reshape(val_results[-1],
                                               image_batch_shape)
                        save_ae_samples(catalog,
                                        val_batch,
                                        noisy_val_batch,
                                        val_recon,
                                        batch_abs_idx,
                                        dashboard_dir,
                                        num_to_save=5,
                                        save_gray=True)
                log.info('End of epoch {}'.format(epoch_idx + 1))
    finally:
        for handle in dashboard_files:
            handle.close()
def train(options):
    """Train a VAE on top of a feature extractor (``feat_vae`` model).

    Writes the run options, the dashboard catalog and one CSV per logged
    statistic into ``options['dashboard_dir']``. It then builds the graph and
    runs the training loop. Along the way it periodically logs, saves
    checkpoints to ``options['model_dir']``, validates, and dumps
    sample/reconstruction images.

    Args:
        options: dict of run configuration. Keys read here include
            ``model_dir``, ``dashboard_dir``, ``description``,
            ``DKL_weight``, ``sigma_clip``, ``data_dir``, ``batch_size``,
            ``file_extension``, ``feat_type``, ``feat_params_path``,
            ``num_feat_layers``, ``p_layers``, ``q_layers``, ``img_shape``,
            ``latent_dims``, ``input_channels``, ``lr``, ``reload``,
            ``reload_file``, ``extension``, ``n_epochs``, ``freq_logging``,
            ``freq_saving``, ``freq_validation`` and ``valid_batches``.

    Returns:
        ``(1., 1., 1.)`` if a NaN/Inf training cost is detected. Otherwise
        ``None``, including after the ``reload`` + ``visualize`` path.

    Raises:
        ValueError: if ``options['feat_type']`` is not ``'fc'``. That is the
            only feature extractor built here.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))

    # Human-readable dump of the run configuration for the dashboard
    with open(os.path.join(options['dashboard_dir'], 'options'), 'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write(
            'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
                options['DKL_weight'],
                -options['sigma_clip'],
                options['sigma_clip']
            )
        )
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Every dashboard file opened below is registered here and closed in the
    # ``finally`` clause, including on the early-return paths.
    dash_files = []

    def open_dash(filename):
        """Open ``filename`` in the dashboard dir for writing and track it."""
        handle = open(os.path.join(options['dashboard_dir'], filename), 'w')
        dash_files.append(handle)
        return handle

    def write_row(csv_file, step, value):
        """Append one ``step,time,value`` row and flush it immediately."""
        # The time column is a fixed placeholder date in this script.
        csv_file.write('{},{},{}\n'.format(step, '2016-04-22', value))
        csv_file.flush()

    try:
        # Dashboard Catalog: one "filename,type,name" row per dashboard file
        catalog = open_dash('catalog')
        catalog.write('\n'.join([
            'filename,type,name',
            'options,plain,Options',
            'train_loss.csv,csv,Train Loss',
            'll.csv,csv,Neg. Log-Likelihood',
            'dec_log_sig_sq.csv,csv,Decoder Log Sigma^2',
            'dec_std_log_sig_sq.csv,csv,STD of Decoder Log Sigma^2',
            'dec_mean.csv,csv,Decoder Mean',
            'dkl.csv,csv,DKL',
            'enc_log_sig_sq.csv,csv,Encoder Log Sigma^2',
            'enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2',
            'enc_mean.csv,csv,Encoder Mean',
            'val_loss.csv,csv,Validation Loss',
        ]) + '\n')
        catalog.flush()

        train_log = open_dash('train_loss.csv')
        val_log = open_dash('val_loss.csv')
        dkl_log = open_dash('dkl.csv')
        ll_log = open_dash('ll.csv')
        dec_sig_log = open_dash('dec_log_sig_sq.csv')
        enc_sig_log = open_dash('enc_log_sig_sq.csv')
        dec_std_sig_log = open_dash('dec_std_log_sig_sq.csv')
        enc_std_sig_log = open_dash('enc_std_log_sig_sq.csv')
        dec_mean_log = open_dash('dec_mean.csv')
        enc_mean_log = open_dash('enc_mean.csv')

        train_log.write('step,time,Train Loss\n')
        val_log.write('step,time,Validation Loss\n')
        dkl_log.write('step,time,DKL\n')
        ll_log.write('step,time,-LL\n')
        dec_sig_log.write('step,time,Decoder Log Sigma^2\n')
        enc_sig_log.write('step,time,Encoder Log Sigma^2\n')
        dec_std_sig_log.write('step,time,STD of Decoder Log Sigma^2\n')
        enc_std_sig_log.write('step,time,STD of Encoder Log Sigma^2\n')
        dec_mean_log.write('step,time,Decoder Mean\n')
        enc_mean_log.write('step,time,Encoder Mean\n')

        # Print options
        utils.print_options(options, log)

        # Load dataset ----------------------------------------------------------------------
        if options['data_dir'] != 'MNIST':
            train_dir = os.path.join(options['data_dir'], 'train', 'patches')
            # Directory entry count minus 2. Presumably this excludes two
            # non-image entries; TODO confirm.
            num_data_points = len(os.listdir(train_dir)) - 2
            train_provider = DataProvider(
                num_data_points,
                options['batch_size'],
                toolbox.ImageLoader(
                    data_dir=train_dir,
                    flat=True,
                    extension=options['file_extension']
                )
            )

            valid_dir = os.path.join(options['data_dir'], 'valid', 'patches')
            num_data_points = len(os.listdir(valid_dir)) - 2
            val_provider = DataProvider(
                num_data_points,
                options['batch_size'],
                toolbox.ImageLoader(
                    data_dir=valid_dir,
                    flat=True,
                    extension=options['file_extension']
                )
            )
        else:
            train_provider = DataProvider(
                55000,
                options['batch_size'],
                toolbox.MNISTLoader(mode='train', flat=True)
            )
            val_provider = DataProvider(
                5000,
                options['batch_size'],
                toolbox.MNISTLoader(mode='validation', flat=True)
            )
        log.info('Data providers initialized.')

        # Initialize model ------------------------------------------------------------------
        with tf.device('/gpu:0'):
            # Feature extractor: a stack of ConstFC layers built from pickled
            # encoder weights.
            if options['feat_type'] != 'fc':
                raise ValueError(
                    "Unsupported feat_type {!r}: only 'fc' is implemented".format(
                        options['feat_type']))
            feat_model = Sequential('feat_extractor')
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            with open(options['feat_params_path'], 'rb') as params_file:
                feat_params = pickle.load(params_file)
            for i in range(options['num_feat_layers']):
                feat_model += ConstFC(
                    feat_params['enc_W'][i],
                    feat_params['enc_b'][i],
                    activation=feat_params['enc_act_fn'][i],
                    name='feat_layer_%d' % i
                )

            # VAE model
            vae_model = cupboard('vanilla_vae')(
                options['p_layers'],
                options['q_layers'],
                np.prod(options['img_shape']),
                options['latent_dims'],
                options['DKL_weight'],
                options['sigma_clip'],
                'vanilla_vae'
            )

            feat_vae = cupboard('feat_vae')(
                vae_model,
                feat_model,
                options['DKL_weight'],
                0.0,
                img_shape=options['img_shape'],
                input_channels=options['input_channels'],
                flat=True,
                name='feat_vae_model'
            )
            log.info('Model initialized')

            # Define inputs
            model_input_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'], np.prod(np.array(options['img_shape']))],
                name='enc_inputs'
            )
            sampler_input_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'], options['latent_dims']],
                name='dec_inputs'
            )
            log.info('Inputs defined')

            # Define forward pass
            cost_function = feat_vae(model_input_batch)
            log.info('Forward pass graph built')

            # Define sampler
            sampler = feat_vae.build_sampler(sampler_input_batch)
            log.info('Sampler graph built')

            # Define optimizer
            optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])
            log.info('Optimizer graph built')

            # Variables the cost does not depend on get a None gradient. Drop
            # them, because tf.clip_by_norm cannot take None.
            grads = [gv for gv in optimizer.compute_gradients(cost_function)
                     if gv[0] is not None]
            grad_tensors = [grad for grad, _ in grads]
            clip_grads = [(tf.clip_by_norm(grad, 5.0, name='grad_clipping'), var)
                          for grad, var in grads]
            backpass = optimizer.apply_gradients(clip_grads)

            # Define init operation
            init_op = tf.initialize_all_variables()
            log.info('Variable initialization graph built')

            # Define op to save and restore variables
            saver = tf.train.Saver()
            log.info('Save operation built')

        # Train loop ---------------------------------------------------------------
        with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
            log.info('Session started')

            if options['reload']:
                # Reload mode only restores the checkpoint and runs the
                # visualization, then returns without training.
                saver.restore(sess, options['reload_file'])
                log.info('Shared variables restored')

                if options['data_dir'] == 'MNIST':
                    mean_img = np.zeros(np.prod(options['img_shape']))
                    std_img = np.ones(np.prod(options['img_shape']))
                else:
                    # NOTE(review): uses options['extension'], while the providers
                    # above use options['file_extension'] -- confirm intended key.
                    mean_img = np.load(os.path.join(options['data_dir'], 'mean' + options['extension']))
                    std_img = np.load(os.path.join(options['data_dir'], 'std' + options['extension']))

                visualize(feat_vae.vae.sampler_mean, sess, feat_vae.vae.dec_mean,
                          feat_vae.vae.dec_log_std_sq, sampler, sampler_input_batch,
                          model_input_batch, feat_vae.vae.enc_mean,
                          feat_vae.vae.enc_log_std_sq, train_provider, val_provider,
                          options, catalog, mean_img, std_img)
                return

            sess.run(init_op)
            log.info('Shared variables initialized')

            # Last 10 training costs, for a running average
            last_losses = np.zeros(10)

            batch_abs_idx = 0
            for epoch_idx in xrange(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))

                for inputs in train_provider:
                    # Providers may yield (images, labels); only images are fed.
                    if isinstance(inputs, tuple):
                        inputs = inputs[0]
                    batch_abs_idx += 1

                    # result indices: 0 cost, 1 update op, 2 DKL, 3 rec. loss,
                    # 4 dec log sigma^2, 5 enc log sigma^2, 6 enc mean,
                    # 7 dec mean, 8+ gradient values
                    result = sess.run(
                        [cost_function, backpass,
                         feat_vae.vae.DKL, feat_vae.vae.rec_loss,
                         feat_vae.vae.dec_log_std_sq, feat_vae.vae.enc_log_std_sq,
                         feat_vae.vae.enc_mean, feat_vae.vae.dec_mean] + grad_tensors,
                        feed_dict={model_input_batch: inputs}
                    )

                    cost = result[0]

                    if batch_abs_idx % 10 == 0:
                        write_row(train_log, batch_abs_idx, np.mean(last_losses))
                        write_row(dkl_log, batch_abs_idx, -np.mean(result[2]))
                        write_row(ll_log, batch_abs_idx, -np.mean(result[3]))
                        write_row(dec_sig_log, batch_abs_idx, np.mean(result[4]))
                        write_row(enc_sig_log, batch_abs_idx, np.mean(result[5]))
                        write_row(dec_std_sig_log, batch_abs_idx, np.std(result[4]))
                        write_row(enc_std_sig_log, batch_abs_idx, np.std(result[5]))
                        write_row(dec_mean_log, batch_abs_idx, np.mean(result[7]))
                        write_row(enc_mean_log, batch_abs_idx, np.mean(result[6]))

                    # Check cost
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')
                        for i in range(len(result)):
                            print("\n\nresult[%d]:" % i)
                            try:
                                print(np.any(np.isnan(result[i])))
                            except TypeError:
                                # e.g. the update op's result is None
                                pass
                            print(result[i])
                        print(result[3].shape)
                        print(vae_model._encoder.layers[0].weights['w'].eval())
                        print('\n\nAny:')
                        # Slice rather than index so fewer than three gradients
                        # cannot raise IndexError here.
                        for grad_value in result[8:11]:
                            print(np.any(np.isnan(grad_value)))
                        print(inputs)
                        return 1., 1., 1.

                    # Update last losses
                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost

                    # Display training information
                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'.format(
                            epoch_idx + 1,
                            options['n_epochs'],
                            batch_abs_idx,
                            float(cost),
                            np.mean(last_losses)
                        ))
                        log.info('Batch Mean LL: {:0>15.4f}'.format(np.mean(result[3], axis=0)))
                        log.info('Batch Mean -DKL: {:0>15.4f}'.format(np.mean(result[2], axis=0)))

                    # Save model
                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(sess, os.path.join(options['model_dir'], 'model_at_%d.ckpt' % batch_abs_idx))
                        log.info('Model saved')

                    # Validate model
                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                        valid_costs = []
                        seen_batches = 0
                        for val_batch in val_provider:
                            if isinstance(val_batch, tuple):
                                val_batch = val_batch[0]
                            val_cost = sess.run(
                                cost_function,
                                feed_dict={model_input_batch: val_batch}
                            )
                            valid_costs.append(val_cost)
                            seen_batches += 1

                            if seen_batches == options['valid_batches']:
                                break

                        # Print results
                        log.info('Validation loss: {:0>15.4f}'.format(
                            float(np.mean(valid_costs))
                        ))

                        # Decode a batch of latent draws from a standard normal
                        val_samples = sess.run(
                            sampler,
                            feed_dict={
                                sampler_input_batch: MVN(
                                    np.zeros(options['latent_dims']),
                                    np.diag(np.ones(options['latent_dims'])),
                                    size=options['batch_size']
                                )
                            }
                        )

                        write_row(val_log, batch_abs_idx, np.mean(valid_costs))

                        batch_img_shape = [options['batch_size']] + options['img_shape']
                        save_ae_samples(
                            catalog,
                            np.reshape(result[7], batch_img_shape),
                            np.reshape(inputs, batch_img_shape),
                            np.reshape(val_samples, batch_img_shape),
                            batch_abs_idx,
                            options['dashboard_dir'],
                            num_to_save=5,
                            save_gray=True
                        )

                        sample_idx = int(batch_abs_idx / options['freq_validation'])
                        save_samples(
                            val_samples,
                            sample_idx,
                            os.path.join(options['model_dir'], 'valid_samples'),
                            True,
                            options['img_shape'],
                            5
                        )
                        save_samples(
                            inputs,
                            sample_idx,
                            os.path.join(options['model_dir'], 'input_sanity'),
                            True,
                            options['img_shape'],
                            num_to_save=5
                        )
                        save_samples(
                            result[7],
                            sample_idx,
                            os.path.join(options['model_dir'], 'rec_sanity'),
                            True,
                            options['img_shape'],
                            num_to_save=5
                        )

                log.info('End of epoch {}'.format(epoch_idx + 1))
    finally:
        for dash_file in dash_files:
            dash_file.close()
def train(options):
    """Train a VAE built from ``options['model']`` and log to the dashboard.

    Writes the run options, the dashboard catalog and one CSV per logged
    statistic into ``options['dashboard_dir']``. It then builds the graph and
    runs the training loop. Along the way it periodically logs, saves a
    checkpoint and a pickled parameter dict to ``options['model_dir']``,
    validates, and dumps sample images.

    Args:
        options: dict of run configuration. Keys read here include
            ``model_dir``, ``dashboard_dir``, ``description``,
            ``DKL_weight``, ``sigma_clip``, ``model``, ``p_layers``,
            ``q_layers``, ``img_shape``, ``latent_dims``, ``batch_size``,
            ``lr``, ``reload``, ``reload_file``, ``n_epochs``,
            ``freq_logging``, ``freq_saving``, ``freq_validation`` and
            ``valid_batches``, plus whatever ``get_providers`` reads.

    Returns:
        ``(1., 1., 1.)`` if a NaN/Inf training cost is detected, otherwise
        ``None``.
    """
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))

    # Human-readable dump of the run configuration for the dashboard
    with open(os.path.join(options['dashboard_dir'], 'options'), 'w') as options_file:
        options_file.write(options['description'] + '\n')
        options_file.write(
            'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format(
                options['DKL_weight'],
                -options['sigma_clip'],
                options['sigma_clip']))
        for optn in options:
            options_file.write(optn)
            options_file.write(':\t')
            options_file.write(str(options[optn]))
            options_file.write('\n')

    # Every dashboard file opened below is registered here and closed in the
    # ``finally`` clause, including on the early-return path.
    dash_files = []

    def open_dash(filename):
        """Open ``filename`` in the dashboard dir for writing and track it."""
        handle = open(os.path.join(options['dashboard_dir'], filename), 'w')
        dash_files.append(handle)
        return handle

    def write_row(csv_file, step, value):
        """Append one ``step,time,value`` row and flush it immediately."""
        # The time column is a fixed placeholder date in this script.
        csv_file.write('{},{},{}\n'.format(step, '2016-04-22', value))
        csv_file.flush()

    def layer_params(layer):
        """Snapshot one layer's dimensions, activation and current W/b values."""
        return {
            'input_dim': layer.input_dim,
            'output_dim': layer.output_dim,
            'act_fn': layer.activation,
            'W': layer.weights['w'].eval(),
            'b': layer.weights['b'].eval()
        }

    try:
        # Dashboard Catalog: one "filename,type,name" row per dashboard file
        catalog = open_dash('catalog')
        catalog.write('\n'.join([
            'filename,type,name',
            'options,plain,Options',
            'train_loss.csv,csv,Train Loss',
            'll.csv,csv,Neg. Log-Likelihood',
            'dec_log_sig_sq.csv,csv,Decoder Log Sigma^2',
            'dec_std_log_sig_sq.csv,csv,STD of Decoder Log Sigma^2',
            'dec_mean.csv,csv,Decoder Mean',
            'dkl.csv,csv,DKL',
            'enc_log_sig_sq.csv,csv,Encoder Log Sigma^2',
            'enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2',
            'enc_mean.csv,csv,Encoder Mean',
            'val_loss.csv,csv,Validation Loss',
        ]) + '\n')
        catalog.flush()

        train_log = open_dash('train_loss.csv')
        val_log = open_dash('val_loss.csv')
        dkl_log = open_dash('dkl.csv')
        ll_log = open_dash('ll.csv')
        dec_sig_log = open_dash('dec_log_sig_sq.csv')
        enc_sig_log = open_dash('enc_log_sig_sq.csv')
        dec_std_sig_log = open_dash('dec_std_log_sig_sq.csv')
        enc_std_sig_log = open_dash('enc_std_log_sig_sq.csv')
        dec_mean_log = open_dash('dec_mean.csv')
        enc_mean_log = open_dash('enc_mean.csv')

        train_log.write('step,time,Train Loss\n')
        val_log.write('step,time,Validation Loss\n')
        dkl_log.write('step,time,DKL\n')
        ll_log.write('step,time,-LL\n')
        dec_sig_log.write('step,time,Decoder Log Sigma^2\n')
        enc_sig_log.write('step,time,Encoder Log Sigma^2\n')
        dec_std_sig_log.write('step,time,STD of Decoder Log Sigma^2\n')
        enc_std_sig_log.write('step,time,STD of Encoder Log Sigma^2\n')
        dec_mean_log.write('step,time,Decoder Mean\n')
        enc_mean_log.write('step,time,Encoder Mean\n')

        # Print options
        utils.print_options(options, log)

        # Load dataset (the test provider is not used by this function)
        train_provider, val_provider, _ = get_providers(options, log, flat=True)

        # Initialize model ------------------------------------------------------------------
        with tf.device('/gpu:0'):
            model = cupboard(options['model'])(
                options['p_layers'],
                options['q_layers'],
                np.prod(options['img_shape']),
                options['latent_dims'],
                options['DKL_weight'],
                options['sigma_clip'],
                'vanilla_vae')
            log.info('Model initialized')

            # Define inputs
            model_input_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'], np.prod(np.array(options['img_shape']))],
                name='enc_inputs')
            sampler_input_batch = tf.placeholder(
                tf.float32,
                shape=[options['batch_size'], options['latent_dims']],
                name='dec_inputs')
            log.info('Inputs defined')

            # Define forward pass
            cost_function = model(model_input_batch)
            log.info('Forward pass graph built')

            # Define sampler
            sampler = model.build_sampler(sampler_input_batch)
            log.info('Sampler graph built')

            # Define optimizer
            optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])

            # Variables the cost does not depend on get a None gradient. Drop
            # them, because tf.clip_by_norm cannot take None.
            grads = [gv for gv in optimizer.compute_gradients(cost_function)
                     if gv[0] is not None]
            clip_grads = [(tf.clip_by_norm(grad, 5.0, name='grad_clipping'), var)
                          for grad, var in grads]
            # Update op (the only training op run in the loop below)
            backpass = optimizer.apply_gradients(clip_grads)
            log.info('Optimizer graph built')

            # Define init operation
            init_op = tf.initialize_all_variables()
            log.info('Variable initialization graph built')

            # Define op to save and restore variables
            saver = tf.train.Saver()
            log.info('Save operation built')

        # Train loop ---------------------------------------------------------------
        with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
            log.info('Session started')

            # Initialize shared variables or restore
            if options['reload']:
                saver.restore(sess, options['reload_file'])
                log.info('Shared variables restored')
            else:
                sess.run(init_op)
                log.info('Shared variables initialized')

            # Last 10 training costs, for a running average
            last_losses = np.zeros(10)

            batch_abs_idx = 0
            for epoch_idx in xrange(options['n_epochs']):
                log.info('Epoch {}'.format(epoch_idx + 1))

                for inputs in train_provider:
                    # Providers may yield (images, labels); only images are fed.
                    if isinstance(inputs, tuple):
                        inputs = inputs[0]
                    batch_abs_idx += 1

                    # result indices: 0 cost, 1 update op, 2 DKL, 3 rec. loss,
                    # 4 dec log sigma^2, 5 enc log sigma^2, 6 enc mean, 7 dec mean
                    result = sess.run(
                        [
                            cost_function, backpass, model.DKL, model.rec_loss,
                            model.dec_log_std_sq, model.enc_log_std_sq,
                            model.enc_mean, model.dec_mean
                        ],
                        feed_dict={model_input_batch: inputs})

                    cost = result[0]

                    if batch_abs_idx % 10 == 0:
                        write_row(train_log, batch_abs_idx, np.mean(last_losses))
                        write_row(dkl_log, batch_abs_idx, -np.mean(result[2]))
                        write_row(ll_log, batch_abs_idx, -np.mean(result[3]))
                        write_row(dec_sig_log, batch_abs_idx, np.mean(result[4]))
                        write_row(enc_sig_log, batch_abs_idx, np.mean(result[5]))
                        write_row(dec_std_sig_log, batch_abs_idx, np.std(result[4]))
                        write_row(enc_std_sig_log, batch_abs_idx, np.std(result[5]))
                        write_row(dec_mean_log, batch_abs_idx, np.mean(result[7]))
                        write_row(enc_mean_log, batch_abs_idx, np.mean(result[6]))

                    # Check cost
                    if np.isnan(cost) or np.isinf(cost):
                        log.info('NaN detected')
                        # Report NaN-ness and value of every fetched result.
                        # result has only 8 entries here, so no fixed
                        # indices >= 8 are read.
                        for i in range(len(result)):
                            print("\n\nresult[%d]:" % i)
                            try:
                                print(np.any(np.isnan(result[i])))
                            except TypeError:
                                # e.g. the update op's result is None
                                pass
                            print(result[i])
                        print(result[3].shape)
                        print(model._encoder.layers[0].weights['w'].eval())
                        print(inputs)
                        return 1., 1., 1.

                    # Update last losses
                    last_losses = np.roll(last_losses, 1)
                    last_losses[0] = cost

                    # Display training information
                    if np.mod(epoch_idx, options['freq_logging']) == 0:
                        log.info(
                            'Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'
                            .format(epoch_idx + 1, options['n_epochs'],
                                    batch_abs_idx, float(cost),
                                    np.mean(last_losses)))
                        log.info('Batch Mean LL: {:0>15.4f}'.format(
                            np.mean(result[3], axis=0)))
                        log.info('Batch Mean -DKL: {:0>15.4f}'.format(
                            np.mean(result[2], axis=0)))

                    # Save model
                    if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                        saver.save(
                            sess,
                            os.path.join(options['model_dir'],
                                         'model_at_%d.ckpt' % batch_abs_idx))
                        log.info('Model saved')

                        # Pickle the encoder/decoder parameters. The
                        # '*_layers' keys hold every layer. 'encoder' and
                        # 'decoder' keep their previous content (last layer
                        # only) for existing readers.
                        encoder_layers = [layer_params(layer) for layer in model._encoder.layers]
                        decoder_layers = [layer_params(layer) for layer in model._decoder.layers]
                        save_dict = {
                            'encoder_layers': encoder_layers,
                            'enc_mean': layer_params(model._enc_mean),
                            'enc_log_std_sq': layer_params(model._enc_log_std_sq),
                            'decoder_layers': decoder_layers,
                            'dec_mean': layer_params(model._dec_mean),
                            'dec_log_std_sq': layer_params(model._dec_log_std_sq),
                        }
                        if encoder_layers:
                            save_dict['encoder'] = encoder_layers[-1]
                        if decoder_layers:
                            save_dict['decoder'] = decoder_layers[-1]

                        with open(os.path.join(options['model_dir'],
                                               'vae_dict_%d' % batch_abs_idx), 'wb') as dict_file:
                            pickle.dump(save_dict, dict_file)

                    # Validate model
                    if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                        valid_costs = []
                        seen_batches = 0
                        for val_batch in val_provider:
                            if isinstance(val_batch, tuple):
                                val_batch = val_batch[0]

                            val_cost = sess.run(
                                cost_function,
                                feed_dict={model_input_batch: val_batch})
                            valid_costs.append(val_cost)
                            seen_batches += 1

                            if seen_batches == options['valid_batches']:
                                break

                        # Print results
                        log.info('Validation loss: {:0>15.4f}'.format(
                            float(np.mean(valid_costs))))

                        # Decode a batch of latent draws from a standard normal
                        val_samples = sess.run(
                            sampler,
                            feed_dict={
                                sampler_input_batch:
                                MVN(np.zeros(options['latent_dims']),
                                    np.diag(np.ones(options['latent_dims'])),
                                    size=options['batch_size'])
                            })

                        write_row(val_log, batch_abs_idx, np.mean(valid_costs))

                        batch_img_shape = [options['batch_size']] + options['img_shape']
                        save_ae_samples(catalog,
                                        np.reshape(result[7], batch_img_shape),
                                        np.reshape(inputs, batch_img_shape),
                                        np.reshape(val_samples, batch_img_shape),
                                        batch_abs_idx,
                                        options['dashboard_dir'],
                                        num_to_save=5,
                                        save_gray=True)

                log.info('End of epoch {}'.format(epoch_idx + 1))
    finally:
        for dash_file in dash_files:
            dash_file.close()
def train(options): # Get logger log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt')) options_file = open(os.path.join(options['dashboard_dir'], 'options'), 'w') options_file.write(options['description'] + '\n') options_file.write( 'DKL Weight: {}\nLog Sigma^2 clipped to: [{}, {}]\n\n'.format( options['DKL_weight'], -options['sigma_clip'], options['sigma_clip'] ) ) for optn in options: options_file.write(optn) options_file.write(':\t') options_file.write(str(options[optn])) options_file.write('\n') options_file.close() # Dashboard Catalog catalog = open(os.path.join(options['dashboard_dir'], 'catalog'), 'w') catalog.write( """filename,type,name options,plain,Options train_loss.csv,csv,Discriminator Cross-Entropy ll.csv,csv,Neg. Log-Likelihood dec_log_sig_sq.csv,csv,Decoder Log Simga^2 dec_std_log_sig_sq.csv,csv,STD of Decoder Log Simga^2 dec_mean.csv,csv,Decoder Mean dkl.csv,csv,DKL enc_log_sig_sq.csv,csv,Encoder Log Sigma^2 enc_std_log_sig_sq.csv,csv,STD of Encoder Log Sigma^2 enc_mean.csv,csv,Encoder Mean val_loss.csv,csv,Validation Loss """ ) catalog.flush() train_log = open(os.path.join(options['dashboard_dir'], 'train_loss.csv'), 'w') val_log = open(os.path.join(options['dashboard_dir'], 'val_loss.csv'), 'w') dkl_log = open(os.path.join(options['dashboard_dir'], 'dkl.csv'), 'w') ll_log = open(os.path.join(options['dashboard_dir'], 'll.csv'), 'w') dec_sig_log = open(os.path.join(options['dashboard_dir'], 'dec_log_sig_sq.csv'), 'w') enc_sig_log = open(os.path.join(options['dashboard_dir'], 'enc_log_sig_sq.csv'), 'w') dec_std_sig_log = open(os.path.join(options['dashboard_dir'], 'dec_std_log_sig_sq.csv'), 'w') enc_std_sig_log = open(os.path.join(options['dashboard_dir'], 'enc_std_log_sig_sq.csv'), 'w') dec_mean_log = open(os.path.join(options['dashboard_dir'], 'dec_mean.csv'), 'w') enc_mean_log = open(os.path.join(options['dashboard_dir'], 'enc_mean.csv'), 'w') # val_sig_log = open(os.path.join(options['dashboard_dir'], 'val_log_sig_sq.csv'), 
'w') train_log.write('step,time,Train CE (Training Vanilla),Train CE (Training Gen.),Train CE (Training Disc.)\n') val_log.write('step,time,Validation CE (Training Vanilla),Validation CE (Training Gen.),Validation CE (Training Disc.)\n') dkl_log.write('step,time,DKL (Training Vanilla),DKL (Training Gen.),DKL (Training Disc.)\n') ll_log.write('step,time,-LL (Training Vanilla),-LL (Training Gen.),-LL (Training Disc.)\n') dec_sig_log.write('step,time,Decoder Log Sigma^2 (Training Vanilla),Decoder Log Sigma^2 (Training Gen.),Decoder Log Sigma^2 (Training Disc.)\n') enc_sig_log.write('step,time,Encoder Log Sigma^2 (Training Vanilla),Encoder Log Sigma^2 (Training Gen.),Encoder Log Sigma^2 (Training Disc.)\n') dec_std_sig_log.write('step,time,STD of Decoder Log Sigma^2 (Training Vanilla),STD of Decoder Log Sigma^2 (Training Gen.),STD of Decoder Log Sigma^2 (Training Disc.)\n') enc_std_sig_log.write('step,time,STD of Encoder Log Sigma^2 (Training Vanilla),STD of Encoder Log Sigma^2 (Training Gen.),STD of Encoder Log Sigma^2 (Training Disc.)\n') dec_mean_log.write('step,time,Decoder Mean (Training Vanilla),Decoder Mean (Training Gen.),Decoder Mean (Training Disc.)\n') enc_mean_log.write('step,time,Encoder Mean (Training Vanilla),Encoder Mean (Training Gen.),Encoder Mean (Training Disc.)\n') # Print options utils.print_options(options, log) # Load dataset ---------------------------------------------------------------------- # Train provider train_provider, val_provider, test_provider = get_providers(options, log, flat=True) # Initialize model ------------------------------------------------------------------ with tf.device('/gpu:0'): # Define inputs model_input_batch = tf.placeholder( tf.float32, shape = [options['batch_size'], np.prod(np.array(options['img_shape']))], name = 'enc_inputs' ) sampler_input_batch = tf.placeholder( tf.float32, shape = [options['batch_size'], options['latent_dims']], name = 'dec_inputs' ) log.info('Inputs defined') # Define model with 
tf.variable_scope('vae_scope'): vae_model = cupboard('vanilla_vae')( options['p_layers'], options['q_layers'], np.prod(options['img_shape']), options['latent_dims'], options['DKL_weight'], options['sigma_clip'], 'vae_model' ) with tf.variable_scope('disc_scope'): disc_model = cupboard('fixed_conv_disc')( pickle.load(open(options['disc_params_path'], 'rb')), options['num_feat_layers'], name='disc_model' ) vae_gan = cupboard('vae_gan')( vae_model, disc_model, options['disc_weight'], options['img_shape'], options['input_channels'], 'vae_scope', 'disc_scope', name='vae_gan_model' ) # Define Optimizers --------------------------------------------------------------------- optimizer = tf.train.AdamOptimizer( learning_rate=options['lr'] ) vae_backpass, disc_backpass, vanilla_backpass = vae_gan(model_input_batch, sampler_input_batch, optimizer) log.info('Optimizer graph built') # -------------------------------------------------------------------------------------- # Define init operation init_op = tf.initialize_all_variables() log.info('Variable initialization graph built') # Define op to save and restore variables saver = tf.train.Saver() log.info('Save operation built') # -------------------------------------------------------------------------- # Train loop --------------------------------------------------------------- with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess: log.info('Session started') # Initialize shared variables or restore if options['reload_all']: saver.restore(sess, options['reload_file']) log.info('Shared variables restored') else: sess.run(init_op) log.info('Variables initialized') if options['reload_vae']: vae_model.reload_vae(options['vae_params_path']) # Define last losses to compute a running average last_losses = np.zeros((10)) batch_abs_idx = 0 D_to_G = options['D_to_G'] total_D2G = sum(D_to_G) base = options['initial_G_iters'] + options['initial_D_iters'] for epoch_idx in xrange(options['n_epochs']): batch_rel_idx = 0 
log.info('Epoch {}'.format(epoch_idx + 1)) for inputs in train_provider: if isinstance(inputs, tuple): inputs = inputs[0] batch_abs_idx += 1 batch_rel_idx += 1 if batch_abs_idx < options['initial_G_iters']: backpass = vanilla_backpass log_format_string = '{},{},{},,\n' elif options['initial_G_iters'] <= batch_abs_idx < base: backpass = disc_backpass log_format_string = '{},{},,,{}\n' else: if (batch_abs_idx - base) % total_D2G < D_to_G[0]: backpass = disc_backpass log_format_string = '{},{},,,{}\n' else: backpass = vae_backpass log_format_string = '{},{},,{},\n' result = sess.run( [ vae_gan.disc_CE, backpass, vae_gan._vae.DKL, vae_gan._vae.rec_loss, vae_gan._vae.dec_log_std_sq, vae_gan._vae.enc_log_std_sq, vae_gan._vae.enc_mean, vae_gan._vae.dec_mean ], feed_dict = { model_input_batch: inputs, sampler_input_batch: MVN( np.zeros(options['latent_dims']), np.diag(np.ones(options['latent_dims'])), size = options['batch_size'] ) } ) cost = result[0] if batch_abs_idx % 10 == 0: train_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(last_losses))) dkl_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', -np.mean(result[2]))) ll_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', -np.mean(result[3]))) train_log.flush() dkl_log.flush() ll_log.flush() dec_sig_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(result[4]))) enc_sig_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(result[5]))) # val_sig_log.write('{},{},{}\n'.format(batch_abs_idx, '2016-04-22', np.mean(result[6]))) dec_sig_log.flush() enc_sig_log.flush() dec_std_sig_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.std(result[4]))) enc_std_sig_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.std(result[5]))) dec_mean_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(result[7]))) enc_mean_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', 
np.mean(result[6]))) dec_std_sig_log.flush() enc_std_sig_log.flush() dec_mean_log.flush() enc_mean_log.flush() # Check cost if np.isnan(cost) or np.isinf(cost): log.info('NaN detected') for i in range(len(result)): print("\n\nresult[%d]:" % i) try: print(np.any(np.isnan(result[i]))) except: pass print(result[i]) print(result[3].shape) print(vae_gan._vae._encoder.layers[0].weights['w'].eval()) print('\n\nAny:') print(np.any(np.isnan(result[8]))) print(np.any(np.isnan(result[9]))) print(np.any(np.isnan(result[10]))) print(inputs) return 1., 1., 1. # Update last losses last_losses = np.roll(last_losses, 1) last_losses[0] = cost # Display training information if np.mod(epoch_idx, options['freq_logging']) == 0: log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'.format( epoch_idx + 1, options['n_epochs'], batch_abs_idx, float(cost), np.mean(last_losses) )) log.info('Batch Mean LL: {:0>15.4f}'.format(np.mean(result[3], axis=0))) log.info('Batch Mean -DKL: {:0>15.4f}'.format(np.mean(result[2], axis=0))) # Save model if np.mod(batch_abs_idx, options['freq_saving']) == 0: saver.save(sess, os.path.join(options['model_dir'], 'model_at_%d.ckpt' % batch_abs_idx)) log.info('Model saved') save_dict = {} # Save encoder params ------------------------------------------------------------------ for i in range(len(vae_gan._vae._encoder.layers)): layer_dict = { 'input_dim':vae_gan._vae._encoder.layers[i].input_dim, 'output_dim':vae_gan._vae._encoder.layers[i].output_dim, 'act_fn':vae_gan._vae._encoder.layers[i].activation, 'W':vae_gan._vae._encoder.layers[i].weights['w'].eval(), 'b':vae_gan._vae._encoder.layers[i].weights['b'].eval() } save_dict['encoder'] = layer_dict layer_dict = { 'input_dim':vae_gan._vae._enc_mean.input_dim, 'output_dim':vae_gan._vae._enc_mean.output_dim, 'act_fn':vae_gan._vae._enc_mean.activation, 'W':vae_gan._vae._enc_mean.weights['w'].eval(), 'b':vae_gan._vae._enc_mean.weights['b'].eval() } save_dict['enc_mean'] = 
layer_dict layer_dict = { 'input_dim':vae_gan._vae._enc_log_std_sq.input_dim, 'output_dim':vae_gan._vae._enc_log_std_sq.output_dim, 'act_fn':vae_gan._vae._enc_log_std_sq.activation, 'W':vae_gan._vae._enc_log_std_sq.weights['w'].eval(), 'b':vae_gan._vae._enc_log_std_sq.weights['b'].eval() } save_dict['enc_log_std_sq'] = layer_dict # Save decoder params ------------------------------------------------------------------ for i in range(len(vae_gan._vae._decoder.layers)): layer_dict = { 'input_dim':vae_gan._vae._decoder.layers[i].input_dim, 'output_dim':vae_gan._vae._decoder.layers[i].output_dim, 'act_fn':vae_gan._vae._decoder.layers[i].activation, 'W':vae_gan._vae._decoder.layers[i].weights['w'].eval(), 'b':vae_gan._vae._decoder.layers[i].weights['b'].eval() } save_dict['decoder'] = layer_dict layer_dict = { 'input_dim':vae_gan._vae._dec_mean.input_dim, 'output_dim':vae_gan._vae._dec_mean.output_dim, 'act_fn':vae_gan._vae._dec_mean.activation, 'W':vae_gan._vae._dec_mean.weights['w'].eval(), 'b':vae_gan._vae._dec_mean.weights['b'].eval() } save_dict['dec_mean'] = layer_dict layer_dict = { 'input_dim':vae_gan._vae._dec_log_std_sq.input_dim, 'output_dim':vae_gan._vae._dec_log_std_sq.output_dim, 'act_fn':vae_gan._vae._dec_log_std_sq.activation, 'W':vae_gan._vae._dec_log_std_sq.weights['w'].eval(), 'b':vae_gan._vae._dec_log_std_sq.weights['b'].eval() } save_dict['dec_log_std_sq'] = layer_dict pickle.dump(save_dict, open(os.path.join(options['model_dir'], 'vae_dict_%d' % batch_abs_idx), 'wb')) # Validate model if np.mod(batch_abs_idx, options['freq_validation']) == 0: vae_gan._vae._decoder.layers[0].weights['w'].eval()[:5,:5] valid_costs = [] seen_batches = 0 for val_batch in val_provider: if isinstance(val_batch, tuple): val_batch = val_batch[0] val_cost = sess.run( vae_gan.disc_CE, feed_dict = { model_input_batch: val_batch, sampler_input_batch: MVN( np.zeros(options['latent_dims']), np.diag(np.ones(options['latent_dims'])), size = options['batch_size'] ) } ) 
valid_costs.append(val_cost) seen_batches += 1 if seen_batches == options['valid_batches']: break # Print results log.info('Validation loss: {:0>15.4f}'.format( float(np.mean(valid_costs)) )) val_samples = sess.run( vae_gan.sampler, feed_dict = { sampler_input_batch: MVN( np.zeros(options['latent_dims']), np.diag(np.ones(options['latent_dims'])), size = options['batch_size'] ) } ) val_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(valid_costs))) val_log.flush() save_ae_samples( catalog, np.reshape(result[7], [options['batch_size']]+options['img_shape']), np.reshape(inputs, [options['batch_size']]+options['img_shape']), np.reshape(val_samples, [options['batch_size']]+options['img_shape']), batch_abs_idx, options['dashboard_dir'], num_to_save=5, save_gray=True ) log.info('End of epoch {}'.format(epoch_idx + 1))