def build_decoder(self, input_var):
    # Build the decoder
    if len(self.p_layers) > 0:
        self._decoder = Sequential('vae_decoder')
        self._decoder += FullyConnected(self.latent_dims, self.p_layers[0],
                                        coder_act_fn, name='fc_1')
        for i in xrange(1, len(self.p_layers)):
            self._decoder += FullyConnected(self.p_layers[i - 1], self.p_layers[i],
                                            coder_act_fn, name='fc_%d' % (i + 1))

        self.decoder = self._decoder(input_var)

        self._dec_mean = FullyConnected(self.p_layers[-1], self.input_dims,
                                        dec_mean_act_fn, name='dec_mean')
        self.dec_mean = self._dec_mean(self.decoder)
    else:
        self.decoder = input_var

        self._dec_mean = FullyConnected(self.latent_dims, self.input_dims,
                                        dec_mean_act_fn, name='dec_mean')
        self.dec_mean = self._dec_mean(self.decoder)
def __call__(self, disc_input):
    feat_params = self.feat_params

    self._disc = Sequential('Fixed_Conv_Disc')

    conv_count, pool_count, fc_count = 0, 0, 0
    for i in xrange(self.num_feat_layers):
        if feat_params[i]['layer_type'] == 'conv':
            self._disc += ConvLayer(
                feat_params[i]['n_filters_in'],
                feat_params[i]['n_filters_out'],
                feat_params[i]['input_dim'],
                feat_params[i]['filter_dim'],
                feat_params[i]['strides'],
                name='classifier_conv_%d' % conv_count)
            # Freeze the pre-trained filters by replacing them with constants
            self._disc.layers[-1].weights['W'] = tf.constant(feat_params[i]['W'])
            self._disc.layers[-1].weights['b'] = tf.constant(feat_params[i]['b'])
            self._disc += feat_params[i]['act_fn']
            conv_count += 1
        elif feat_params[i]['layer_type'] == 'pool':
            self._disc += PoolLayer(
                feat_params[i]['input_dim'],
                feat_params[i]['filter_dim'],
                feat_params[i]['strides'],
                name='classifier_pool_%d' % pool_count)
            pool_count += 1
        elif feat_params[i]['layer_type'] == 'fc':
            self._disc += ConstFC(
                feat_params[i]['W'],
                feat_params[i]['b'],
                activation=feat_params[i]['act_fn'],
                name='classifier_fc_%d' % fc_count)
            fc_count += 1

    # Infer the flattened input dimension of the trainable logit layer
    # from whatever fixed feature layer came last
    if isinstance(self._disc.layers[-1], ConstFC):
        disc_input_dim = self._disc.layers[-1].weights['w'].get_shape()[1].value
    elif isinstance(self._disc.layers[-1], PoolLayer):
        # layers[-3] is the conv layer feeding this pool layer
        disc_input_dim = np.prod(self._disc.layers[-1].output_dim) * \
            self._disc.layers[-3].n_filters_out
    else:  # activation function after a conv layer
        disc_input_dim = np.prod(self._disc.layers[-1].output_dim) * \
            self._disc.layers[-2].n_filters_out

    self._disc += FullyConnected(disc_input_dim, 1, activation=None,
                                 scale=0.01, name='disc_logit')
    self._disc += lambda p: 1.0 / (1.0 + tf.exp(-p))  # sigmoid output

    self.disc = self._disc(disc_input)

    return self.disc
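# Usage sketch for the fixed-feature discriminator above. Hedged example:
# `cupboard('fixed_conv_disc')`, the pickled `feat_params` layout, and the
# image shape are assumptions taken from train() below; the path and sizes
# are illustrative only.
#
#     feat_params = pickle.load(open('classifier_params.pkl', 'rb'))
#     disc_model = cupboard('fixed_conv_disc')(feat_params, 3, name='disc_model')
#     disc_input = tf.placeholder(tf.float32, shape=[128, 28, 28, 1])
#     disc_probs = disc_model(disc_input)  # sigmoid outputs in (0, 1)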
def __init__(self, input_dims, enc_params, dec_params, name=''):
    super(AutoEncoder, self).__init__()

    self.input_dims = input_dims
    self.enc_params = enc_params
    self.dec_params = dec_params

    # Map activation-function names to callables
    self.enc_params['act_fn'] = map(lambda p: act_lib[p], self.enc_params['act_fn'])
    self.dec_params['act_fn'] = map(lambda p: act_lib[p], self.dec_params['act_fn'])

    self.name = name

    # Encoder: fully connected stack
    self._encoder = Sequential(self.name + '_ae_encoder')
    for i in range(len(enc_params['layer_dims'])):
        in_dim = self.input_dims if i == 0 else self.enc_params['layer_dims'][i - 1]
        self._encoder += FullyConnected(
            in_dim,
            self.enc_params['layer_dims'][i],
            self.enc_params['act_fn'][i],
            name=self.name + '_e_fc_%d' % (i + 1)
        )

    # Decoder: fully connected stack, fed by the encoder's last layer
    self._decoder = Sequential(self.name + '_ae_decoder')
    for i in range(len(self.dec_params['layer_dims'])):
        in_dim = self.enc_params['layer_dims'][-1] if i == 0 else self.dec_params['layer_dims'][i - 1]
        self._decoder += FullyConnected(
            in_dim,
            self.dec_params['layer_dims'][i],
            self.dec_params['act_fn'][i],
            name=self.name + '_d_fc_%d' % (i + 1)
        )
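# Construction sketch for AutoEncoder. Hedged: the act_lib keys ('tanh',
# 'sigmoid') and the 784-dim input are illustrative assumptions; only the
# dict layout ('layer_dims' plus 'act_fn' name strings) is taken from the
# code above.
#
#     enc_params = {'layer_dims': [256, 64], 'act_fn': ['tanh', 'tanh']}
#     dec_params = {'layer_dims': [256, 784], 'act_fn': ['tanh', 'sigmoid']}
#     ae = AutoEncoder(784, enc_params, dec_params, name='mnist')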
def build_encoder(self, input_var, params=None):
    # Build the encoder
    if len(self.q_layers) > 0:
        self._encoder = Sequential('vae_encoder')
        self._encoder += FullyConnected(self.input_dims, self.q_layers[0],
                                        coder_act_fn, name='fc_1')
        for i in xrange(1, len(self.q_layers)):
            self._encoder += FullyConnected(self.q_layers[i - 1], self.q_layers[i],
                                            coder_act_fn, name='fc_%d' % (i + 1))

        self.encoder = self._encoder(input_var)

        self._enc_mean = FullyConnected(self.q_layers[-1], self.latent_dims,
                                        mean_std_act_fn, name='enc_mean')
        self.enc_mean = self._enc_mean(self.encoder)

        self._enc_log_std_sq = FullyConnected(self.q_layers[-1], self.latent_dims,
                                              mean_std_act_fn, name='enc_std')
        # Clip log(sigma^2) to keep the KL term numerically stable
        self.enc_log_std_sq = tf.clip_by_value(
            self._enc_log_std_sq(self.encoder),
            -self.sigma_clip,
            self.sigma_clip
        )
    else:
        self.encoder = input_var

        self._enc_mean = FullyConnected(self.input_dims, self.latent_dims,
                                        mean_std_act_fn, name='enc_mean')
        self.enc_mean = self._enc_mean(self.encoder)

        self._enc_log_std_sq = FullyConnected(self.input_dims, self.latent_dims,
                                              mean_std_act_fn, name='enc_std')
        self.enc_log_std_sq = tf.clip_by_value(
            self._enc_log_std_sq(self.encoder),
            -self.sigma_clip,
            self.sigma_clip
        )
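# Downstream of build_encoder, a VAE typically draws the latent code with the
# reparameterization trick. Sketch only; the actual sampling code lies outside
# this section. Note enc_log_std_sq holds log(sigma^2), so the standard
# deviation is exp(0.5 * log(sigma^2)):
#
#     eps = tf.random_normal(tf.shape(self.enc_mean))
#     z = self.enc_mean + tf.exp(0.5 * self.enc_log_std_sq) * eps
#     self.build_decoder(z)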
def __init__(self, input_shape, input_channels, enc_params, dec_params, name=''):
    """
    enc_params:
        - kernels
        - strides
        - num_filters
        - act_fn

    dec_params:
        - layer_dims
        - act_fn
    """
    super(ConvAutoEncoder, self).__init__()

    self.input_shape = input_shape
    self.input_channels = input_channels
    self.enc_params = enc_params
    self.dec_params = dec_params
    self.name = name

    # Map activation-function names to callables
    self.enc_params['act_fn'] = map(lambda p: act_lib[p], self.enc_params['act_fn'])
    self.dec_params['act_fn'] = map(lambda p: act_lib[p], self.dec_params['act_fn'])

    # Build the encoder, which is fully convolutional with no pooling
    self._encoder = Sequential(self.name + 'ae_encoder')
    for i in range(len(self.enc_params['kernels'])):
        self._encoder += ConvLayer(
            self.input_channels if i == 0 else self.enc_params['num_filters'][i - 1],
            self.enc_params['num_filters'][i],
            self.input_shape if i == 0 else self._encoder.layers[-2].output_dim,
            self.enc_params['kernels'][i],
            self.enc_params['strides'][i],
            name=self.name + '_enc_conv_%d' % (i + 1)
        )
        self._encoder += self.enc_params['act_fn'][i]

    # Build the decoder, which is fully connected
    self._decoder = Sequential(self.name + 'ae_decoder')
    for i in range(len(self.dec_params['layer_dims'])):
        self._decoder += FullyConnected(
            self.enc_params['num_filters'][-1] * np.prod(self._encoder.layers[-2].output_dim)
            if i == 0 else self.dec_params['layer_dims'][i - 1],
            self.dec_params['layer_dims'][i],
            self.dec_params['act_fn'][i],
            name=self.name + '_dec_fc_%d' % (i + 1)
        )
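# Construction sketch for ConvAutoEncoder. Hedged: the concrete shapes,
# kernel sizes, and act_lib keys are illustrative assumptions; the dict
# layout follows the docstring above.
#
#     enc_params = {'kernels': [5, 5],
#                   'strides': [2, 2],
#                   'num_filters': [16, 32],
#                   'act_fn': ['relu', 'relu']}
#     dec_params = {'layer_dims': [512, 28 * 28], 'act_fn': ['relu', 'sigmoid']}
#     cae = ConvAutoEncoder([28, 28], 1, enc_params, dec_params, name='mnist')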
def train(options):
    # Get logger
    log = utils.get_logger(os.path.join(options['model_dir'], 'log.txt'))

    # Dump the options used for this run
    options_file = open(os.path.join(options['dashboard_dir'], 'options'), 'w')
    options_file.write(options['description'] + '\n')
    options_file.write('Log Sigma^2 clipped to: [{}, {}]\n\n'.format(
        -options['sigma_clip'], options['sigma_clip']))
    for optn in options:
        options_file.write(optn)
        options_file.write(':\t')
        options_file.write(str(options[optn]))
        options_file.write('\n')
    options_file.close()

    # Dashboard catalog
    catalog = open(os.path.join(options['dashboard_dir'], 'catalog'), 'w')
    catalog.write("""filename,type,name
options,plain,Options
train_loss.csv,csv,Discriminator Cross-Entropy
train_acc.csv,csv,Discriminator Accuracy
val_loss.csv,csv,Validation Cross-Entropy
val_acc.csv,csv,Validation Accuracy
""")
    catalog.flush()

    train_log = open(os.path.join(options['dashboard_dir'], 'train_loss.csv'), 'w')
    val_log = open(os.path.join(options['dashboard_dir'], 'val_loss.csv'), 'w')
    train_acc = open(os.path.join(options['dashboard_dir'], 'train_acc.csv'), 'w')
    val_acc = open(os.path.join(options['dashboard_dir'], 'val_acc.csv'), 'w')

    train_log.write('step,time,Train CE (Training Vanilla),Train CE (Training Gen.),Train CE (Training Disc.)\n')
    val_log.write('step,time,Validation CE (Training Vanilla),Validation CE (Training Gen.),Validation CE (Training Disc.)\n')
    train_acc.write('step,time,Train Acc. (Training Vanilla),Train Acc. (Training Gen.),Train Acc. (Training Disc.)\n')
    val_acc.write('step,time,Validation Acc. (Training Vanilla),Validation Acc. (Training Gen.),Validation Acc. (Training Disc.)\n')

    # Print options
    utils.print_options(options, log)

    # Load dataset -----------------------------------------------------------
    train_provider, val_provider, test_provider = get_providers(options, log, flat=True)

    # Initialize model -------------------------------------------------------
    with tf.device('/gpu:0'):
        # Define inputs -------------------------------------------------------
        real_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], np.prod(np.array(options['img_shape']))],
            name='real_inputs')
        sampler_input_batch = tf.placeholder(
            tf.float32,
            shape=[options['batch_size'], options['latent_dims']],
            name='noise_channel')

        # Labels: 1 for the real half of the batch, 0 for the generated half
        labels = tf.constant(np.expand_dims(
            np.concatenate(
                (np.ones(options['batch_size']), np.zeros(options['batch_size'])),
                axis=0
            ).astype(np.float32),
            axis=1))

        log.info('Inputs defined')

        # Define model --------------------------------------------------------
        with tf.variable_scope('gen_scope'):
            generator = Sequential('generator')
            generator += FullyConnected(options['latent_dims'], 60, tf.nn.tanh, name='fc_1')
            generator += FullyConnected(60, 60, tf.nn.tanh, name='fc_2')
            generator += FullyConnected(60, np.prod(options['img_shape']), tf.nn.tanh, name='fc_3')

            sampler = generator(sampler_input_batch)

        with tf.variable_scope('disc_scope'):
            disc_model = cupboard('fixed_conv_disc')(
                pickle.load(open(options['disc_params_path'], 'rb')),
                options['num_feat_layers'],
                name='disc_model')

            # Stack real inputs on top of generated samples
            disc_inputs = tf.concat(0, [real_batch, sampler])
            disc_inputs = tf.reshape(
                disc_inputs,
                [disc_inputs.get_shape()[0].value] + options['img_shape'] + [options['input_channels']])

            preds = disc_model(disc_inputs)
            preds = tf.clip_by_value(preds, 0.00001, 0.99999)

            # Discriminator accuracy
            disc_accuracy = (1 / float(labels.get_shape()[0].value)) * tf.reduce_sum(
                tf.cast(tf.equal(tf.round(preds), labels), tf.float32))

        # Define losses ---------------------------------------------------------
        # Discriminator cross-entropy
        disc_CE = (1 / float(labels.get_shape()[0].value)) * tf.reduce_sum(
            -tf.add(tf.mul(labels, tf.log(preds)),
                    tf.mul(1.0 - labels, tf.log(1.0 - preds))))

        # Generator loss: -log D(G(z)) on the generated half of the batch.
        # The tensor is not reduced here; tf.gradients sums over it implicitly.
        gen_loss = -tf.mul(1.0 - labels, tf.log(preds))

        # Define optimizers -------------------------------------------------------
        optimizer = tf.train.AdamOptimizer(learning_rate=options['lr'])

        # Get generator and discriminator trainable variables
        gen_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'gen_scope')
        disc_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'disc_scope')

        # Generator gradients, clipped by norm
        grads = optimizer.compute_gradients(gen_loss, gen_train_vars)
        grads = [gv for gv in grads if gv[0] is not None]
        clip_grads = [(tf.clip_by_norm(gv[0], 5.0, name='gen_grad_clipping'), gv[1]) for gv in grads]
        gen_backpass = optimizer.apply_gradients(clip_grads)

        # Discriminator gradients, clipped by norm
        grads = optimizer.compute_gradients(disc_CE, disc_train_vars)
        grads = [gv for gv in grads if gv[0] is not None]
        clip_grads = [(tf.clip_by_norm(gv[0], 5.0, name='disc_grad_clipping'), gv[1]) for gv in grads]
        disc_backpass = optimizer.apply_gradients(clip_grads)

        log.info('Optimizer graph built')

        # Define init operation
        init_op = tf.initialize_all_variables()
        log.info('Variable initialization graph built')

    # Define op to save and restore variables
    saver = tf.train.Saver()
    log.info('Save operation built')

    # Train loop ------------------------------------------------------------------
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        log.info('Session started')

        # Initialize shared variables or restore
        if options['reload_all']:
            saver.restore(sess, options['reload_file'])
            log.info('Shared variables restored')
        else:
            sess.run(init_op)
            log.info('Variables initialized')

        # Buffers for running averages of recent losses/accuracies
        last_losses = np.zeros((10))
        last_accs = np.zeros((10))
        disc_tracker = np.ones((5000))

        batch_abs_idx = 0
        D_to_G = options['D_to_G']
        total_D2G = sum(D_to_G)
        base = options['initial_G_iters'] + options['initial_D_iters']

        feat_params = pickle.load(open(options['disc_params_path'], 'rb'))

        for epoch_idx in xrange(options['n_epochs']):
            batch_rel_idx = 0
            log.info('Epoch {}'.format(epoch_idx + 1))

            for inputs in train_provider:
                if isinstance(inputs, tuple):
                    inputs = inputs[0]

                batch_abs_idx += 1
                batch_rel_idx += 1

                # Schedule: warm up G for initial_G_iters, then D until `base`,
                # then alternate D_to_G[0] disc steps with D_to_G[1] gen steps.
                # The format string routes the value to the matching CSV column.
                if batch_abs_idx < options['initial_G_iters']:
                    backpass = gen_backpass
                    log_format_string = '{},{},{},,\n'
                elif options['initial_G_iters'] <= batch_abs_idx < base:
                    backpass = disc_backpass
                    log_format_string = '{},{},,,{}\n'
                else:
                    if (batch_abs_idx - base) % total_D2G < D_to_G[0]:
                        backpass = disc_backpass
                        log_format_string = '{},{},,,{}\n'
                    else:
                        backpass = gen_backpass
                        log_format_string = '{},{},,{},\n'

                # MVN: multivariate-normal sampler (assumed to follow the
                # numpy.random.multivariate_normal signature)
                result = sess.run(
                    [disc_CE, backpass, disc_accuracy],
                    feed_dict={
                        real_batch: inputs,
                        sampler_input_batch: MVN(
                            np.zeros(options['latent_dims']),
                            np.diag(np.ones(options['latent_dims'])),
                            size=options['batch_size'])
                    })
                cost = result[0]

                if batch_abs_idx % 10 == 0:
                    # '2016-04-22' is a hard-coded stamp for the CSV time column
                    train_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(last_losses)))
                    train_acc.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(last_accs)))
                    train_log.flush()
                    train_acc.flush()

                # Check cost
                if np.isnan(cost) or np.isinf(cost):
                    log.info('NaN detected')

                # Update running averages
                last_losses = np.roll(last_losses, 1)
                last_losses[0] = cost
                last_accs = np.roll(last_accs, 1)
                last_accs[0] = result[-1]
                disc_tracker = np.roll(disc_tracker, 1)
                disc_tracker[0] = result[-1]

                # Display training information
                if np.mod(epoch_idx, options['freq_logging']) == 0:
                    log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last losses: {:0>15.4f}'.format(
                        epoch_idx + 1, options['n_epochs'], batch_abs_idx, float(cost), np.mean(last_losses)))
                    log.info('Epoch {:02}/{:02} Batch {:03} Current Loss: {:0>15.4f} Mean last accuracies: {:0>15.4f}'.format(
                        epoch_idx + 1, options['n_epochs'], batch_abs_idx, float(cost), np.mean(last_accs)))

                # Save model
                if np.mod(batch_abs_idx, options['freq_saving']) == 0:
                    saver.save(sess, os.path.join(options['model_dir'], 'model_at_%d.ckpt' % batch_abs_idx))
                    log.info('Model saved')

                # Validate model
                if np.mod(batch_abs_idx, options['freq_validation']) == 0:
                    valid_costs = []
                    valid_accs = []
                    seen_batches = 0
                    for val_batch in val_provider:
                        if isinstance(val_batch, tuple):
                            val_batch = val_batch[0]

                        result = sess.run(
                            [disc_CE, disc_accuracy],
                            feed_dict={
                                real_batch: val_batch,
                                sampler_input_batch: MVN(
                                    np.zeros(options['latent_dims']),
                                    np.diag(np.ones(options['latent_dims'])),
                                    size=options['batch_size'])
                            })
                        valid_costs.append(result[0])
                        valid_accs.append(result[-1])
                        seen_batches += 1
                        if seen_batches == options['valid_batches']:
                            break

                    # Print results
                    log.info('Validation loss: {:0>15.4f}'.format(float(np.mean(valid_costs))))
                    log.info('Validation accuracies: {:0>15.4f}'.format(float(np.mean(valid_accs))))

                    val_samples = sess.run(
                        sampler,
                        feed_dict={
                            sampler_input_batch: MVN(
                                np.zeros(options['latent_dims']),
                                np.diag(np.ones(options['latent_dims'])),
                                size=options['batch_size'])
                        })

                    val_log.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(valid_costs)))
                    val_acc.write(log_format_string.format(batch_abs_idx, '2016-04-22', np.mean(valid_accs)))
                    val_log.flush()
                    val_acc.flush()

                    save_ae_samples(
                        catalog,
                        np.ones([options['batch_size']] + options['img_shape']),
                        np.reshape(inputs, [options['batch_size']] + options['img_shape']),
                        np.reshape(val_samples, [options['batch_size']] + options['img_shape']),
                        batch_abs_idx,
                        options['dashboard_dir'],
                        num_to_save=5,
                        save_gray=True)

            log.info('End of epoch {}'.format(epoch_idx + 1))
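# Invocation sketch for train(). Hedged: every value below is an illustrative
# assumption; only the option keys themselves are taken from the code above
# (get_providers may read additional dataset keys not shown in this section).
#
#     options = {
#         'description': 'GAN with a fixed pre-trained conv discriminator',
#         'model_dir': 'models/gan', 'dashboard_dir': 'dashboards/gan',
#         'disc_params_path': 'classifier_params.pkl', 'num_feat_layers': 3,
#         'img_shape': [28, 28], 'input_channels': 1,
#         'batch_size': 128, 'latent_dims': 10, 'lr': 1e-4,
#         'sigma_clip': 5.0, 'n_epochs': 100,
#         'initial_G_iters': 1000, 'initial_D_iters': 1000, 'D_to_G': [5, 1],
#         'freq_logging': 1, 'freq_saving': 1000, 'freq_validation': 500,
#         'valid_batches': 10, 'reload_all': False, 'reload_file': '',
#     }
#     train(options)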