import os

import numpy as np
import pandas as pd
import tensorflow as tf
import tqdm

# `BEGAN`, `dataIterator`, `loadData`, `plot_gens`, `checkpoint_path` and
# `checkpoint_prefix` are assumed to be defined elsewhere in this repo.


# Variant 1: k_t is fed through a placeholder each step, and per-iteration
# losses are appended to a master CSV after every epoch.
def began_train(images, start_epoch=0, add_epochs=None, batch_size=16,
                hidden_size=2048, dim=(64, 64, 3), gpu_id='/gpu:0',
                demo=False, get=False, start_learn_rate=1e-5,
                decay_every=50, save_every=1, batch_norm=True, gamma=0.75):
    num_epochs = start_epoch + add_epochs
    loss_tracker = BEGAN.loss_tracker

    graph = tf.Graph()
    with graph.as_default():
        global_step = tf.get_variable(
            'global_step', [],
            initializer=tf.constant_initializer(0), trainable=False)

        with tf.device(gpu_id):
            learning_rate = tf.placeholder(tf.float32, shape=[])
            opt = tf.train.AdamOptimizer(learning_rate, epsilon=1.0)

            next_batch = tf.placeholder(tf.float32, [batch_size, np.prod(dim)])
            x_tilde, x_tilde_d, x_d = BEGAN.run(next_batch, batch_size, hidden_size)

            k_t = tf.placeholder(tf.float32, shape=[])
            D_loss, G_loss, k_tp, convergence_measure = \
                BEGAN.loss(next_batch, x_d, x_tilde, x_tilde_d, k_t=k_t)

            # Split the trainable variables by scope so each network is
            # updated only with respect to its own loss.
            params = tf.trainable_variables()
            tr_vars = {}
            for s in BEGAN.scopes:
                tr_vars[s] = [i for i in params if s in i.name]

            G_grad = opt.compute_gradients(G_loss, var_list=tr_vars['generator'])
            D_grad = opt.compute_gradients(D_loss, var_list=tr_vars['discriminator'])
            G_train = opt.apply_gradients(G_grad, global_step=global_step)
            D_train = opt.apply_gradients(D_grad, global_step=global_step)

        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

    sess = tf.Session(graph=graph,
                      config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=True))
    sess.run(init)

    if start_epoch > 0:
        path = '{}/{}_{}.tfmod'.format(checkpoint_path, checkpoint_prefix,
                                       str(start_epoch - 1).zfill(4))
        saver.restore(sess, path)

    k_t_ = 0  # We initialise with k_t = 0, as in the paper.
    num_batches_per_epoch = len(images) // batch_size

    for epoch in range(start_epoch, num_epochs):
        print('Epoch {} / {}'.format(epoch + 1, num_epochs))
        for i in tqdm.tqdm(range(num_batches_per_epoch)):
            iter_ = dataIterator([images], batch_size)
            # Halve the learning rate every `decay_every` epochs.
            learning_rate_ = start_learn_rate * pow(0.5, epoch // decay_every)
            next_batch_ = next(iter_)

            _, _, D_loss_, G_loss_, k_t_, M_ = \
                sess.run([G_train, D_train, D_loss, G_loss, k_tp,
                          convergence_measure],
                         {learning_rate: learning_rate_,
                          next_batch: next_batch_,
                          k_t: min(max(k_t_, 0), 1)})  # clamp k_t to [0, 1]

            loss_tracker['epoch'].append(epoch)
            loss_tracker['iteration'].append(i)
            loss_tracker['k'].append(k_t_)
            loss_tracker['generator'].append(G_loss_)
            loss_tracker['discriminator'].append(D_loss_)
            loss_tracker['convergence_measure'].append(M_)

        # After every epoch, append that epoch's convergence info to a
        # master CSV (write the header only on the first write).
        lt_df = pd.DataFrame.from_dict(loss_tracker)
        lt_df = lt_df[['epoch', 'iteration', 'k', 'generator',
                       'discriminator', 'convergence_measure']]
        fname = 'convergence_measure.csv'
        if (epoch == 0) or (not os.path.isfile(fname)):
            with open(fname, 'w') as f:
                lt_df.to_csv(f, header=True)
        else:
            with open(fname, 'a') as f:
                lt_df.loc[lt_df['epoch'] == epoch].to_csv(f, header=False)

        if epoch % save_every == 0:
            path = '{}/{}_{}.tfmod'.format(checkpoint_path, checkpoint_prefix,
                                           str(epoch).zfill(4))
            if not os.path.exists(checkpoint_path):
                os.makedirs(checkpoint_path)
            saver.save(sess, path)

    # Generate samples whenever they are needed, so `get` works without `demo`.
    if demo or get:
        ims = sess.run(x_tilde)
    if demo:
        batch = next(dataIterator([images], batch_size))
        plot_gens((ims, batch),
                  ('Generated 64x64 samples.', 'Random training images.'),
                  loss_tracker)
    if get:
        return ims
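
# For reference, a minimal sketch of what `dataIterator` is assumed to do
# here: yield random, flattened mini-batches drawn from the loaded image
# array. The name and behaviour below are assumptions for illustration, not
# the repo's actual implementation.
def _data_iterator_sketch(datasets, batch_size):
    images = datasets[0]
    while True:
        idx = np.random.randint(0, len(images), size=batch_size)
        # Each yielded batch is (batch_size, H * W * C), matching the
        # `next_batch` placeholder above.
        yield images[idx].reshape(batch_size, -1)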

# Variant 2: k_t is held in a non-trainable TF variable rather than fed as a
# placeholder, the data is reloaded each epoch, and the image size and filter
# count are parameterised explicitly.
def began_train(num_images=50000, start_epoch=0, add_epochs=None, batch_size=16,
                hidden_size=64, image_size=64, gpu_id='/gpu:0',
                demo=False, get=False, start_learn_rate=1e-4,
                decay_every=100, save_every=1, batch_norm=True, gamma=0.75):
    num_epochs = start_epoch + add_epochs
    loss_tracker = BEGAN.loss_tracker

    graph = tf.Graph()
    with graph.as_default():
        global_step = tf.get_variable(
            'global_step', [],
            initializer=tf.constant_initializer(0), trainable=False)

        with tf.device(gpu_id):
            learning_rate = tf.placeholder(tf.float32, shape=[])
            opt = tf.train.AdamOptimizer(learning_rate, epsilon=1.0)

            next_batch = tf.placeholder(
                tf.float32, [batch_size, image_size * image_size * 3])
            x_tilde, x_tilde_d, x_d = BEGAN.run(next_batch,
                                                batch_size=batch_size,
                                                num_filters=128,
                                                hidden_size=hidden_size,
                                                image_size=image_size)

            k_t = tf.get_variable('kt', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
            D_loss, G_loss, k_tp, convergence_measure = \
                BEGAN.loss(next_batch, x_d, x_tilde, x_tilde_d, k_t=k_t)

            # Split the trainable variables by scope so each network is
            # updated only with respect to its own loss.
            params = tf.trainable_variables()
            tr_vars = {}
            for s in BEGAN.scopes:
                tr_vars[s] = [i for i in params if s in i.name]

            G_grad = opt.compute_gradients(G_loss, var_list=tr_vars['generator'])
            D_grad = opt.compute_gradients(D_loss, var_list=tr_vars['discriminator'])
            G_train = opt.apply_gradients(G_grad, global_step=global_step)
            D_train = opt.apply_gradients(D_grad, global_step=global_step)

        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

    sess = tf.Session(graph=graph,
                      config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=True))
    sess.run(init)

    if start_epoch > 0:
        path = '{}/{}_{}.tfmod'.format(checkpoint_path, checkpoint_prefix,
                                       str(start_epoch - 1).zfill(4))
        saver.restore(sess, path)

    k_t_ = sess.run(k_t)  # We initialise with k_t = 0, as in the paper.
    num_batches_per_epoch = num_images // batch_size

    for epoch in range(start_epoch, num_epochs):
        images = loadData(size=num_images)
        print('Epoch {} / {}'.format(epoch + 1, num_epochs))
        for i in tqdm.tqdm(range(num_batches_per_epoch)):
            iter_ = dataIterator([images], batch_size)
            # Halve the learning rate every `decay_every` epochs.
            learning_rate_ = start_learn_rate * pow(0.5, epoch // decay_every)
            next_batch_ = next(iter_)

            _, _, D_loss_, G_loss_, k_t_, M_ = \
                sess.run([G_train, D_train, D_loss, G_loss, k_tp,
                          convergence_measure],
                         {learning_rate: learning_rate_,
                          next_batch: next_batch_,
                          k_t: min(max(k_t_, 0), 1)})  # clamp k_t to [0, 1]

            loss_tracker['generator'].append(G_loss_)
            loss_tracker['discriminator'].append(D_loss_)
            loss_tracker['convergence_measure'].append(M_)

        if epoch % save_every == 0:
            path = '{}/{}_{}.tfmod'.format(checkpoint_path, checkpoint_prefix,
                                           str(epoch).zfill(4))
            if not os.path.exists(checkpoint_path):
                os.makedirs(checkpoint_path)
            saver.save(sess, path)

    # Generate samples whenever they are needed, so `get` works without `demo`.
    if demo or get:
        ims = sess.run(x_tilde)
    if demo:
        batch = next(dataIterator([images], batch_size))
        plot_gens((ims, batch),
                  ('Generated 64x64 samples.', 'Random training images.'),
                  loss_tracker)
    if get:
        return ims
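
# Both variants rely on BEGAN.loss to produce the updated control value k_tp
# and the convergence measure M. For reference, the balancing rule from the
# BEGAN paper (Berthelot et al., 2017) is sketched below; the function name
# and the `lambda_k` argument are illustrative, and the real computation
# lives in the BEGAN module.
def _k_update_sketch(L_x, L_g, k_t, gamma=0.75, lambda_k=0.001):
    """One step of the BEGAN balancing rule.

    L_x: discriminator reconstruction loss on real images, L(x).
    L_g: discriminator reconstruction loss on generated images, L(G(z)).
    """
    # k_{t+1} = k_t + lambda_k * (gamma * L(x) - L(G(z))), clamped to [0, 1].
    k_tp = min(max(k_t + lambda_k * (gamma * L_x - L_g), 0.0), 1.0)
    # Global convergence measure: M = L(x) + |gamma * L(x) - L(G(z))|.
    M = L_x + abs(gamma * L_x - L_g)
    return k_tp, M


# Hypothetical invocation of variant 2 (argument values are illustrative):
#
#     began_train(num_images=50000, add_epochs=100, demo=True)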