model.feat_holder: batch_feat, model.isTrain: (args.train_bn == 1) }) if (n_epoch % args.info_epoch == 0): print('[n_epoch: %d, D_loss: %f, G_loss: %f]' % (n_epoch, D_loss_curr, G_loss_curr)) save_img_fill = sess.run( [model.fake_img], feed_dict={ model.noise_holder: save_noise_fill, model.feat_holder: save_feat_fill, model.isTrain: (args.test_bn == 1) }) save_img = save_img_fill[0][0:10, :, :, :] save_image_train_by_digit(n_epoch, save_img, args, generated=True) # label = np.argmax(batch_feat[0]) # filename = str(n_epoch)+'_'+str(label)+'.jpg' # misc.imsave(os.path.join(args.save_img_dir, filename), fake_img[0, :, :, :]) # save_image_train(n_epoch, fake_img, args, generated = True) # save_image_train(n_epoch, batch_img, args, generated = False) # save_path = saver.save(sess, args.log_dir+'/model_'+str(n_epoch)+'.ckpt') # print("Model saved in file: %s" % save_path) saver.save(sess, save_path=args.log_dir, global_step=n_epoch)
if __name__ == '__main__':
    # Script entry point: build the CapsGAN graph, start a supervised
    # session, and write one preview image per digit label as '_test'.

    # Create the output directory, including any missing parents. Attempting
    # the creation and tolerating "already a directory" avoids the race
    # between a separate existence check and the mkdir call.
    try:
        os.makedirs(args.save_img_dir)
    except OSError:
        if not os.path.isdir(args.save_img_dir):
            raise

    # One preview sample per one-hot label 0..9.
    n_preview_digits = 10
    if args.batch_size < n_preview_digits:
        # Fail early with a clear message; the zero-padding below would
        # otherwise raise an opaque "negative dimensions" ValueError.
        raise ValueError(
            'batch_size (%d) must be at least %d to hold the preview samples'
            % (args.batch_size, n_preview_digits))

    with tf.Graph().as_default() as graph:
        initializer = tf.random_uniform_initializer(-args.init_scale,
                                                    args.init_scale)
        with tf.variable_scope('model_capsule', reuse=None,
                               initializer=initializer) as scope:
            model = CapsGAN(args)
            scope.reuse_variables()

        # Session configuration: grow GPU memory on demand and enable
        # JIT compilation at level ON_1.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.graph_options.optimizer_options.global_jit_level = (
            tf.OptimizerOptions.ON_1)

        sv = tf.train.Supervisor(logdir=args.log_dir,
                                 save_model_secs=args.save_model_secs)
        saver = sv.saver

        with sv.managed_session(config=config) as sess:
            # Fixed preview inputs: uniform noise in [-1, 1) and one one-hot
            # label row per digit.
            save_noise = np.random.uniform(
                -1., 1., [n_preview_digits, args.noise_dim])
            save_feat = to_categorical(np.arange(n_preview_digits),
                                       num_classes=n_preview_digits)

            # Zero-pad both inputs up to batch_size rows.
            # NOTE(review): presumably the placeholders have a fixed batch
            # dimension, which is why padding is needed - confirm in CapsGAN.
            n_padding = args.batch_size - n_preview_digits
            save_noise_fill = np.concatenate(
                (save_noise, np.zeros((n_padding, args.noise_dim))), axis=0)
            save_feat_fill = np.concatenate(
                (save_feat, np.zeros((n_padding, n_preview_digits))), axis=0)

            # NOTE(review): this preview runs with isTrain=True, whereas the
            # periodic previews in the training loop feed
            # (args.test_bn == 1) - confirm the difference is intended.
            save_img_fill = sess.run(
                [model.fake_img],
                feed_dict={
                    model.noise_holder: save_noise_fill,
                    model.feat_holder: save_feat_fill,
                    model.isTrain: True,
                })

            # Keep only the real preview rows; the rest came from padding.
            save_img = save_img_fill[0][0:n_preview_digits, :, :, :]
            save_image_train_by_digit('_test', save_img, args, generated=True)