def test(args):
    """Restore the newest checkpoint and write a grid of generated samples.

    For every label row returned by ``gen_testdata()`` the sampler is run on
    5 copies of that label with fresh uniform noise in [-1, 1), producing one
    grid row of 5 images. Side effects:

    * each image -> ``samples/sample_<row>_<col>.jpg`` (1-based indices);
    * the noise used for row ``j`` -> pickled to ``samples_<j>`` (0-based);
    * the full grid -> ``samples/1.jpg``.

    Args:
        args: parsed options; reads ``bs``, ``noise_dim``, ``lr`` and
            ``loadpath``.

    Raises:
        ValueError: if no checkpoint is found under ``args.loadpath``.
    """
    import os

    n_cols = 5   # images generated per label row
    size = 64    # edge length in pixels of each tile in the grid

    # NOTE(review): built with trainable=True even at inference time, as in
    # the original; confirm whether GAN uses this flag (e.g. batch norm mode).
    model = GAN(batch_size=args.bs, noise_dim=args.noise_dim,
                learning_rate=args.lr, trainable=True)
    model.build()
    saver = tf.train.Saver(max_to_keep=20)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # The context manager closes the session (releasing device memory) even
    # if restoring or sampling raises.
    with tf.Session(config=config) as sess:
        latest_ckpt = tf.train.latest_checkpoint(args.loadpath)
        if latest_ckpt is None:
            raise ValueError('no checkpoint found in %r' % (args.loadpath,))
        saver.restore(sess, latest_ckpt)
        print('restore from', latest_ckpt)

        batch_y_tot = gen_testdata()
        # Alternative label source: batch_y_tot = gen_fromfile(args.testfile)
        new_im = Image.new('RGB', (size * n_cols, size * len(batch_y_tot)))
        os.makedirs('samples', exist_ok=True)

        for j, label in enumerate(batch_y_tot):
            # Repeat one label n_cols times so the row's images differ only
            # by their noise vectors.
            batch_y = np.tile(label, (n_cols, 1))
            noise = np.random.uniform(-1, 1, [batch_y.shape[0], args.noise_dim])
            generated_test = sess.run(model.sampler, feed_dict={
                model.noises: noise,
                model.labels: batch_y,
            })
            for i in range(n_cols):
                # scale from [-1., 1.] to [0., 255.]
                generated = (generated_test[i] + 1) * 127.5
                generated = np.clip(generated, 0., 255.).astype(np.uint8)
                generated = misc.imresize(generated, [size, size, 3])
                gen_path = 'samples/sample_' + str(j + 1) + '_' + str(i + 1) + '.jpg'
                misc.imsave(gen_path, generated)
                new_im.paste(Image.fromarray(generated, "RGB"), (size * i, size * j))
            # Pickle the noise used for row j; `with` guarantees the file is closed.
            with open('samples_' + str(j), 'wb') as noise_file:
                pickle.dump(noise, noise_file)

    new_im.save('samples/' + '1' + '.jpg')
    print('gen results:', len(batch_y_tot), '* 5 in samples')
def _run_train_epoch(sess, model, dataset, args, ep, n_critic=5):
    """Run one epoch of GAN updates and return the (D, G) loss histories.

    Each of the ``dataset.iters`` steps performs ``n_critic`` discriminator
    updates, each on a fresh batch from ``dataset.next_batch()``. It then
    performs one generator update that reuses the last of those batches with
    fresh noise. A progress line is rewritten in place on stdout.

    Returns:
        (d_losses, g_losses): lists of the per-update loss values.
    """
    d_tot_loss = []
    g_tot_loss = []
    for step in range(dataset.iters):
        for _critic in range(n_critic):
            batch_x, wrong_x, batch_y, wrong_y = dataset.next_batch()
            noise = np.random.uniform(-1, 1, [batch_x.shape[0], args.noise_dim])
            # update D
            d_loss, _ = sess.run(
                [model.d_loss_op, model.d_train_op],
                feed_dict={
                    model.imgs: batch_x,
                    model.labels: batch_y,
                    model.noises: noise,
                    model.imgs_wrong: wrong_x,
                    model.labels_wrong: wrong_y,
                })
            d_tot_loss.append(d_loss)

        # update G with fresh noise on the last D batch
        noise = np.random.uniform(-1, 1, [batch_x.shape[0], args.noise_dim])
        g_loss, _ = sess.run(
            [model.g_loss_op, model.g_train_op],
            feed_dict={
                model.imgs: batch_x,
                model.labels: batch_y,
                model.noises: noise,
                model.imgs_wrong: wrong_x,
                model.labels_wrong: wrong_y,
            })
        g_tot_loss.append(g_loss)

        # Progress is refreshed on even steps and on step 1 (kept as in the
        # original).
        if step % 2 == 0 or step == 1:
            print(
                "Epoch: %5d, Step: %4d/%4d, D_loss: %.4f, G_loss: %.4f " %
                (ep, step, dataset.iters, np.mean(d_tot_loss), np.mean(g_tot_loss)),
                end='\r')
    return d_tot_loss, g_tot_loss


def _render_sample_grid(sess, model, args, n=5, size=64):
    """Return an ``n x n`` PIL image grid of sampler outputs.

    Row ``j`` comes from a fresh ``gen_testdata()`` label batch and fresh
    uniform noise. Only the first ``n`` generated images of each batch are
    used, so ``gen_testdata()`` must yield at least ``n`` rows.
    """
    grid = Image.new('RGB', (size * n, size * n))
    for j in range(n):
        batch_y = gen_testdata()
        noise = np.random.uniform(-1, 1, [batch_y.shape[0], args.noise_dim])
        generated_test = sess.run(model.sampler, feed_dict={
            model.noises: noise,
            model.labels: batch_y,
        })
        for i in range(n):
            # scale from [-1., 1.] to [0., 255.]
            generated = (generated_test[i] + 1) * 127.5
            generated = np.clip(generated, 0., 255.).astype(np.uint8)
            generated = misc.imresize(generated, [size, size, 3])
            grid.paste(Image.fromarray(generated, "RGB"), (size * i, size * j))
    return grid


def train(dataset, args, retrain):
    """Train the GAN for ``args.ep`` epochs, logging and checkpointing.

    Per-epoch timing and mean D/G losses are written to
    ``<savepath>output.log`` and to stdout. On epoch 1 and every 10th epoch,
    a 5x5 sample grid is saved to ``<savepath>generated/<ep>.jpg`` and a
    checkpoint is written to ``<savepath>model.ckpt-<ep>``.

    Args:
        dataset: provides ``iters`` (steps per epoch) and ``next_batch()``
            returning ``(batch_x, wrong_x, batch_y, wrong_y)``.
        args: parsed options; reads ``bs``, ``noise_dim``, ``lr``, ``ep``,
            ``comment``, ``savepath`` and ``loadpath``. ``savepath`` is
            string-concatenated, so it should end with a path separator.
        retrain: if truthy, resume from the newest checkpoint in
            ``args.loadpath``.

    Raises:
        ValueError: if ``retrain`` is set but no checkpoint is found.
    """
    import os

    model = GAN(batch_size=args.bs, noise_dim=args.noise_dim,
                learning_rate=args.lr, trainable=True)
    model.build()
    saver = tf.train.Saver(max_to_keep=20)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # The context manager closes the session even if training raises.
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        if retrain:
            latest_ckpt = tf.train.latest_checkpoint(args.loadpath)
            if latest_ckpt is None:
                raise ValueError('no checkpoint found in %r' % (args.loadpath,))
            saver.restore(sess, latest_ckpt)
            print('restore from', latest_ckpt)
        print(
            'Start training, method=%s, lr=%f, batch_size=%d, epoch=%d, comment=%s' %
            (model.name, args.lr, args.bs, args.ep, args.comment))

        sample_dir = args.savepath + 'generated/'
        os.makedirs(sample_dir, exist_ok=True)

        # `with` closes the log even on error; the per-epoch flush keeps the
        # log current if the process dies mid-run.
        with open(args.savepath + 'output.log', 'w') as log_file:
            for ep in range(1, args.ep + 1):
                start_time = time.time()
                d_tot_loss, g_tot_loss = _run_train_epoch(sess, model, dataset, args, ep)
                elapsed = time.time() - start_time
                d_mean = np.mean(d_tot_loss)
                g_mean = np.mean(g_tot_loss)

                log_file.write(
                    "Epoch: %5d, Take time: %4.1fs, D_loss: %.4f, G_loss: %.4f\n" %
                    (ep, elapsed, d_mean, g_mean))
                log_file.flush()
                print(
                    "Epoch: %5d, Take time: %4.1fs, D_loss: %.4f, G_loss: %.4f " %
                    (ep, elapsed, d_mean, g_mean))

                if ep % 10 == 0 or ep == 1:
                    # test
                    grid = _render_sample_grid(sess, model, args)
                    grid.save(sample_dir + str(ep) + '.jpg')
                    saver.save(sess, args.savepath + 'model.ckpt', global_step=ep)

            print('Done')
            print("Model saved in file: %s\n" % args.savepath)
            log_file.write('Done')