def main():
    """Interactively visualize reconstructions of a trained multi-object AE.

    Reads the workspace from AE_WORKSPACE_PATH, rebuilds the multi-queue
    encoder/decoder graph (inference mode), restores the latest (or
    --at_step) checkpoint and shows sample batches with OpenCV until
    ESC or Ctrl+C.
    """
    workspace_path = os.environ.get('AE_WORKSPACE_PATH')
    if workspace_path is None:  # was `== None`; identity test is correct for None
        print('Please define a workspace path:\n')
        print('export AE_WORKSPACE_PATH=/path/to/workspace\n')
        exit(-1)

    # One-element array used as a mutable flag the SIGINT handler can set.
    # `np.bool` was removed in NumPy 1.24 — use the builtin `bool` dtype.
    gentle_stop = np.array((1,), dtype=bool)
    gentle_stop[0] = False

    def on_ctrl_c(signal, frame):
        # Request a graceful stop at the next loop iteration.
        gentle_stop[0] = True

    signal.signal(signal.SIGINT, on_ctrl_c)

    parser = argparse.ArgumentParser()
    parser.add_argument("experiment_name")
    parser.add_argument("-d", action='store_true', default=False)
    parser.add_argument("-gen", action='store_true', default=False)
    parser.add_argument("-vis_emb", action='store_true', default=False)
    parser.add_argument('--at_step', default=None, type=int, required=False)
    arguments = parser.parse_args()

    # "group/name" → experiment_group, experiment_name; bare "name" → no group.
    full_name = arguments.experiment_name.split('/')
    experiment_name = full_name.pop()
    experiment_group = full_name.pop() if len(full_name) > 0 else ''
    debug_mode = arguments.d
    generate_data = arguments.gen
    at_step = arguments.at_step

    cfg_file_path = u.get_config_file_path(workspace_path, experiment_name,
                                           experiment_group)
    log_dir = u.get_log_dir(workspace_path, experiment_name, experiment_group)
    checkpoint_file = u.get_checkpoint_basefilename(log_dir)
    ckpt_dir = u.get_checkpoint_dir(log_dir)
    train_fig_dir = u.get_train_fig_dir(log_dir)
    dataset_path = u.get_dataset_path(workspace_path)

    if not os.path.exists(cfg_file_path):
        print('Could not find config file:\n')
        print('{}\n'.format(cfg_file_path))
        exit(-1)

    args = configparser.ConfigParser()
    args.read(cfg_file_path)

    # In debug mode run (practically) forever so batches can be inspected.
    num_iter = args.getint('Training', 'NUM_ITER') \
        if not debug_mode else np.iinfo(np.int32).max
    save_interval = args.getint('Training', 'SAVE_INTERVAL')
    num_gpus = 1
    model_type = args.get('Dataset', 'MODEL')

    with tf.variable_scope(experiment_name, reuse=tf.AUTO_REUSE):
        dataset = factory.build_dataset(dataset_path, args)
        multi_queue = factory.build_multi_queue(dataset, args)
        # NOTE(review): number of objects is hard-coded to 24 here; the
        # sibling training main derives it from multi_queue._num_objects.
        dev_splits = np.array_split(np.arange(24), num_gpus)

        iterator = multi_queue.create_iterator(dataset_path, args)
        all_object_views = tf.concat(
            [inp[0] for inp in multi_queue.next_element], 0)

        bs = multi_queue._batch_size
        encoding_splits = []
        for dev in range(num_gpus):
            with tf.device('/device:GPU:%s' % dev):
                # Each GPU encodes the contiguous slice of batches that
                # belongs to its object subset (inference mode).
                encoder = factory.build_encoder(
                    all_object_views[dev_splits[dev][0] * bs:
                                     (dev_splits[dev][-1] + 1) * bs],
                    args, is_training=False)
                encoding_splits.append(
                    tf.split(encoder.z, len(dev_splits[dev]), 0))

    with tf.variable_scope(experiment_name):
        decoders = []
        for dev in range(num_gpus):
            with tf.device('/device:GPU:%s' % dev):
                for j, i in enumerate(dev_splits[dev]):
                    decoders.append(
                        factory.build_decoder(multi_queue.next_element[i],
                                              encoding_splits[dev][j],
                                              args, is_training=False, idx=i))
        ae = factory.build_ae(encoder, decoders, args)
        codebook = factory.build_codebook(encoder, dataset, args)
        train_op = factory.build_train_op(ae, args)
        saver = tf.train.Saver(save_relative_paths=True)

    dataset.load_bg_images(dataset_path)
    multi_queue.create_tfrecord_training_images(dataset_path, args)

    widgets = ['Training: ', progressbar.Percentage(), ' ',
               progressbar.Bar(), ' ', progressbar.Counter(),
               ' / %s' % num_iter, ' ', progressbar.ETA(), ' ']
    bar = progressbar.ProgressBar(maxval=num_iter, widgets=widgets)

    gpu_options = tf.GPUOptions(allow_growth=True,
                                per_process_gpu_memory_fraction=0.9)
    config = tf.ConfigProto(gpu_options=gpu_options,
                            log_device_placement=True,
                            allow_soft_placement=True)

    with tf.Session(config=config) as sess:
        sess.run(multi_queue.bg_img_init.initializer)
        sess.run(iterator.initializer)

        chkpt = tf.train.get_checkpoint_state(ckpt_dir)
        if chkpt and chkpt.model_checkpoint_path:
            if at_step is None:
                checkpoint_file_basename = u.get_checkpoint_basefilename(
                    log_dir, latest=args.getint('Training', 'NUM_ITER'))
            else:
                checkpoint_file_basename = u.get_checkpoint_basefilename(
                    log_dir, latest=at_step)
            print('loading ', checkpoint_file_basename)
            saver.restore(sess, checkpoint_file_basename)
        else:
            if encoder._pre_trained_model != 'False':
                # Restore pre-trained encoder weights, then initialize only
                # the variables the pre-trained model does not cover.
                encoder.saver.restore(sess, encoder._pre_trained_model)
                all_vars = set(
                    [var for var in
                     tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)])
                var_list = all_vars.symmetric_difference(
                    [v[1] for v in list(encoder.fil_var_list.items())])
                sess.run(tf.variables_initializer(var_list))
                print(sess.run(tf.report_uninitialized_variables()))
            else:
                sess.run(tf.global_variables_initializer())

        if not debug_mode:
            print('Training with %s model' % args.get('Dataset', 'MODEL'),
                  os.path.basename(args.get('Paths', 'MODEL_PATH')))
            bar.start()

        while True:
            # Pull a batch plus the decoder reconstructions for inspection.
            this, _, reconstr_train, enc_z = sess.run(
                [multi_queue.next_element, multi_queue.next_bg_element,
                 [decoder.x for decoder in decoders], encoder.z])
            this_x = np.concatenate([el[0] for el in this])
            this_y = np.concatenate([el[2] for el in this])
            print(this_x.shape)
            reconstr_train = np.concatenate(reconstr_train)
            print(this_x.shape)
            cv2.imshow('sample batch',
                       np.hstack((u.tiles(this_x, 4, 6),
                                  u.tiles(reconstr_train, 4, 6),
                                  u.tiles(this_y, 4, 6))))
            k = cv2.waitKey(0)

            # Feed one random target view, replicated across the batch,
            # directly into the encoder input to inspect its reconstructions.
            idx = np.random.randint(0, 24)
            this_y = np.repeat(this_y[idx:idx + 1, :, :], 24, axis=0)
            reconstr_train = sess.run([decoder.x for decoder in decoders],
                                      feed_dict={encoder._input: this_y})
            reconstr_train = np.array(reconstr_train)
            print(reconstr_train.shape)
            reconstr_train = reconstr_train.squeeze()

            cv2.imshow('sample batch 2',
                       np.hstack((u.tiles(this_y, 4, 6),
                                  u.tiles(reconstr_train, 4, 6))))
            k = cv2.waitKey(0)
            if k == 27:  # ESC
                break
            if gentle_stop[0]:
                break

    if not debug_mode:
        bar.finish()
    if not gentle_stop[0] and not debug_mode:
        print('To create the embedding run:\n')
        print('ae_embed {}\n'.format(full_name))
def main():
    """Recompute the codebook embedding for a trained experiment.

    Restores the experiment's checkpoint, updates the codebook embedding
    (dsprites has a dedicated path) and saves a new checkpoint.
    """
    tf.disable_eager_execution()
    workspace_path = os.environ.get('AE_WORKSPACE_PATH')
    if workspace_path is None:  # was `== None`
        print('Please define a workspace path:\n')
        print('export AE_WORKSPACE_PATH=/path/to/workspace\n')
        exit(-1)

    parser = argparse.ArgumentParser()
    parser.add_argument("experiment_name")
    # type=int added for consistency with the other entry points; the
    # value selects a specific training step's checkpoint.
    parser.add_argument('--at_step', default=None, type=int, required=False)
    arguments = parser.parse_args()

    full_name = arguments.experiment_name.split('/')
    experiment_name = full_name.pop()
    experiment_group = full_name.pop() if len(full_name) > 0 else ''
    at_step = arguments.at_step

    cfg_file_path = u.get_config_file_path(workspace_path, experiment_name,
                                           experiment_group)
    log_dir = u.get_log_dir(workspace_path, experiment_name, experiment_group)
    checkpoint_file = u.get_checkpoint_basefilename(log_dir)
    ckpt_dir = u.get_checkpoint_dir(log_dir)
    dataset_path = u.get_dataset_path(workspace_path)

    print(checkpoint_file)
    print(ckpt_dir)
    print('#' * 20)

    if not os.path.exists(cfg_file_path):
        print('Could not find config file:\n')
        print('{}\n'.format(cfg_file_path))
        exit(-1)

    args = configparser.ConfigParser()
    args.read(cfg_file_path)

    with tf.variable_scope(experiment_name):
        dataset = factory.build_dataset(dataset_path, args)
        queue = factory.build_queue(dataset, args)
        encoder = factory.build_encoder(queue.x, args)
        decoder = factory.build_decoder(queue.y, encoder, args)
        ae = factory.build_ae(encoder, decoder, args)
        codebook = factory.build_codebook(encoder, dataset, args)
        saver = tf.train.Saver(save_relative_paths=True)

    batch_size = args.getint('Training', 'BATCH_SIZE')
    model = args.get('Dataset', 'MODEL')

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7)
    config = tf.ConfigProto(gpu_options=gpu_options)

    with tf.Session(config=config) as sess:
        print(ckpt_dir)
        print('#' * 20)
        factory.restore_checkpoint(sess, saver, ckpt_dir, at_step=at_step)

        if model == 'dsprites':
            codebook.update_embedding_dsprites(sess, args)
        else:
            codebook.update_embedding(sess, batch_size)

        print('Saving new checkpoint ..')  # typo "checkoint" fixed
        saver.save(sess, checkpoint_file, global_step=ae.global_step)
        print('done')
def main():
    """Add a new object's codebook to an existing (joint) experiment.

    Restores the joint (or single) checkpoint at the chosen iteration,
    grafts a codebook for --model_path onto the graph, reuses previously
    computed embeddings when present, and saves a joint checkpoint.
    """
    workspace_path = os.environ.get('AE_WORKSPACE_PATH')
    if workspace_path is None:  # was `== None`
        print('Please define a workspace path:\n')
        print('export AE_WORKSPACE_PATH=/path/to/workspace\n')
        exit(-1)

    parser = argparse.ArgumentParser()
    parser.add_argument("experiment_name")
    parser.add_argument('--at_step', default=None, type=int, required=False)
    parser.add_argument('--model_path', type=str, required=True)
    arguments = parser.parse_args()

    full_name = arguments.experiment_name.split('/')
    experiment_name = full_name.pop()
    experiment_group = full_name.pop() if len(full_name) > 0 else ''
    at_step = arguments.at_step
    model_path = arguments.model_path

    cfg_file_path = u.get_config_file_path(workspace_path, experiment_name,
                                           experiment_group)
    log_dir = u.get_log_dir(workspace_path, experiment_name, experiment_group)
    ckpt_dir = u.get_checkpoint_dir(log_dir)
    dataset_path = u.get_dataset_path(workspace_path)

    if not os.path.exists(cfg_file_path):
        print('Could not find config file:\n')
        print('{}\n'.format(cfg_file_path))
        exit(-1)

    args = configparser.ConfigParser()
    args.read(cfg_file_path)

    iteration = args.getint('Training', 'NUM_ITER') \
        if at_step is None else at_step

    # Prefer the joint checkpoint; fall back to the single-model one.
    checkpoint_file_basename = u.get_checkpoint_basefilename(
        log_dir, latest=iteration, joint=True)
    if not tf.train.checkpoint_exists(checkpoint_file_basename):
        checkpoint_file_basename = u.get_checkpoint_basefilename(
            log_dir, latest=iteration, joint=False)
    checkpoint_single_encoding = u.get_checkpoint_basefilename(
        log_dir, latest=iteration, model_path=model_path)
    target_checkpoint_file = u.get_checkpoint_basefilename(log_dir, joint=True)

    print(checkpoint_file_basename)
    print(target_checkpoint_file)
    print(ckpt_dir)
    print('#' * 20)

    with tf.variable_scope(experiment_name):
        dataset = factory.build_dataset(dataset_path, args)
        queue = factory.build_queue(dataset, args)
        encoder = factory.build_encoder(queue.x, args)

        before_cb = set(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
        codebook_multi = factory.build_codebook_multi(
            encoder, dataset, args, checkpoint_file_basename)
        # restore_saver only covers variables that exist in the restored
        # checkpoint; it must be created BEFORE the new codebook is added.
        restore_saver = tf.train.Saver(save_relative_paths=True,
                                       max_to_keep=100)
        codebook_multi.add_new_codebook_to_graph(model_path)
        saver = tf.train.Saver(save_relative_paths=True, max_to_keep=100)

    # SECURITY NOTE(review): eval() on a config value executes arbitrary
    # code from the config file; ast.literal_eval would be safer if
    # MODEL_PATH is always a plain list literal.
    batch_size = args.getint('Training', 'BATCH_SIZE') * len(
        eval(args.get('Paths', 'MODEL_PATH')))
    model = args.get('Dataset', 'MODEL')

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
    config = tf.ConfigProto(gpu_options=gpu_options)

    with tf.Session(config=config) as sess:
        print(ckpt_dir)
        print('#' * 20)

        sess.run(tf.global_variables_initializer())
        restore_saver.restore(sess, checkpoint_file_basename)
        print('#' * 20)

        # Best effort: reuse existing single-model embeddings if that
        # checkpoint/variable exists, otherwise recompute from scratch.
        # (was a bare `except:`, which also swallowed KeyboardInterrupt)
        try:
            loaded_emb = tf.train.load_variable(
                checkpoint_single_encoding,
                experiment_name + '/embedding_normalized')
            loaded_obj_bbs = tf.train.load_variable(
                checkpoint_single_encoding,
                experiment_name + '/embed_obj_bbs_var')
        except Exception:
            loaded_emb = None
            loaded_obj_bbs = None

        if model == 'dsprites':
            codebook_multi.update_embedding_dsprites(sess, args)
        else:
            codebook_multi.update_embedding(sess, batch_size, model_path,
                                            loaded_emb=loaded_emb,
                                            loaded_obj_bbs=loaded_obj_bbs)

        print('Saving new checkpoint ..')  # typo "checkoint" fixed
        saver.save(sess, target_checkpoint_file, global_step=iteration)
        print('done')
def main():
    """Train the multi-object autoencoder (optionally across several GPUs).

    Builds the multi-queue graph, splits objects across NUM_GPUS devices,
    restores any existing checkpoint (or pre-trained encoder weights) and
    runs the training loop with periodic summaries, checkpoints and
    reconstruction snapshots. In debug mode (-d) it only visualizes batches.
    """
    workspace_path = os.environ.get('AE_WORKSPACE_PATH')
    if workspace_path is None:
        print('Please define a workspace path:\n')
        print('export AE_WORKSPACE_PATH=/path/to/workspace\n')
        exit(-1)

    # One-element array used as a mutable flag the SIGINT handler can set.
    # `np.bool` was removed in NumPy 1.24 — use the builtin `bool` dtype.
    gentle_stop = np.array((1, ), dtype=bool)
    gentle_stop[0] = False

    def on_ctrl_c(signal, frame):
        gentle_stop[0] = True

    signal.signal(signal.SIGINT, on_ctrl_c)

    parser = argparse.ArgumentParser()
    parser.add_argument("experiment_name")
    parser.add_argument("-d", action='store_true', default=False)
    parser.add_argument("-gen", action='store_true', default=False)
    parser.add_argument('--at_step', default=None, type=int, required=False)
    arguments = parser.parse_args()

    full_name = arguments.experiment_name.split('/')
    experiment_name = full_name.pop()
    experiment_group = full_name.pop() if len(full_name) > 0 else ''
    debug_mode = arguments.d
    generate_data = arguments.gen
    at_step = arguments.at_step

    cfg_file_path = u.get_config_file_path(workspace_path, experiment_name,
                                           experiment_group)
    log_dir = u.get_log_dir(workspace_path, experiment_name, experiment_group)
    checkpoint_file = u.get_checkpoint_basefilename(log_dir)
    ckpt_dir = u.get_checkpoint_dir(log_dir)
    train_fig_dir = u.get_train_fig_dir(log_dir)
    dataset_path = u.get_dataset_path(workspace_path)

    if not os.path.exists(cfg_file_path):
        print('Could not find config file:\n')
        print('{}\n'.format(cfg_file_path))
        exit(-1)

    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    if not os.path.exists(train_fig_dir):
        os.makedirs(train_fig_dir)
    if not os.path.exists(dataset_path):
        os.makedirs(dataset_path)

    args = configparser.ConfigParser(inline_comment_prefixes="#")
    args.read(cfg_file_path)
    # Keep a copy of the config next to the logs for reproducibility.
    shutil.copy2(cfg_file_path, log_dir)

    num_iter = args.getint(
        'Training', 'NUM_ITER') if not debug_mode else np.iinfo(np.int32).max
    save_interval = args.getint('Training', 'SAVE_INTERVAL')
    num_gpus = args.getint('Training', 'NUM_GPUS')

    with tf.device('/device:CPU:0'):
        with tf.variable_scope(experiment_name, reuse=tf.AUTO_REUSE):
            dataset = factory.build_dataset(dataset_path, args)
            multi_queue = factory.build_multi_queue(dataset, args)
            if generate_data:
                # Only generate the tfrecord training images, then quit.
                multi_queue.create_tfrecord_training_images(dataset_path, args)
                print('finished generating training images')
                exit()

            # Assign each GPU a contiguous subset of the objects.
            dev_splits = np.array_split(np.arange(multi_queue._num_objects),
                                        num_gpus)

            iterator = multi_queue.create_iterator(dataset_path, args)
            all_x, all_y = list(
                zip(*[(inp[0], inp[2]) for inp in multi_queue.next_element]))
            all_x, all_y = tf.concat(all_x, axis=0), tf.concat(all_y, axis=0)
            print(all_x.shape)

            encoding_splits = []
            for dev in range(num_gpus):
                with tf.device('/device:GPU:%s' % dev):
                    sta = dev_splits[dev][0] * multi_queue._batch_size
                    end = (dev_splits[dev][-1] + 1) * multi_queue._batch_size
                    print(sta, end)
                    encoder = factory.build_encoder(all_x[sta:end],
                                                    args,
                                                    target=all_y[sta:end],
                                                    is_training=True)
                    encoding_splits.append(
                        tf.split(encoder.z, len(dev_splits[dev]), 0))

        with tf.variable_scope(experiment_name):
            decoders = []
            for dev in range(num_gpus):
                with tf.device('/device:GPU:%s' % dev):
                    for j, i in enumerate(dev_splits[dev]):
                        print(len(encoding_splits))
                        decoders.append(
                            factory.build_decoder(multi_queue.next_element[i],
                                                  encoding_splits[dev][j],
                                                  args,
                                                  is_training=True,
                                                  idx=i))
            ae = factory.build_ae(encoder, decoders, args)
            codebook = factory.build_codebook(encoder, dataset, args)
            train_op = factory.build_train_op(ae, args)
            saver = tf.train.Saver(save_relative_paths=True, max_to_keep=1)

    multi_queue.create_tfrecord_training_images(dataset_path, args)

    if generate_data:
        print('finished generating synthetic training data for ' +
              experiment_name)
        print('exiting...')
        exit()

    widgets = [
        'Training: ', progressbar.Percentage(), ' ',
        progressbar.Bar(), ' ', progressbar.Counter(),
        ' / %s' % num_iter, ' ', progressbar.ETA(), ' '
    ]
    bar = progressbar.ProgressBar(maxval=num_iter, widgets=widgets)

    gpu_options = tf.GPUOptions(allow_growth=True,
                                per_process_gpu_memory_fraction=0.9)
    config = tf.ConfigProto(gpu_options=gpu_options,
                            allow_soft_placement=True)

    with tf.Session(config=config) as sess:
        sess.run(multi_queue.bg_img_init.initializer)
        sess.run(iterator.initializer)

        u.create_summaries(multi_queue, decoders, ae)
        merged_loss_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(ckpt_dir, sess.graph)

        chkpt = tf.train.get_checkpoint_state(ckpt_dir)
        if chkpt and chkpt.model_checkpoint_path:
            if at_step is None:
                checkpoint_file_basename = chkpt.model_checkpoint_path
            else:
                checkpoint_file_basename = u.get_checkpoint_basefilename(
                    log_dir, latest=at_step)
            # was print(('loading ', x)) — a 2to3 artifact printing a tuple
            print('loading ', checkpoint_file_basename)
            saver.restore(sess, checkpoint_file_basename)
        else:
            if encoder._pre_trained_model != 'False':
                # Restore pre-trained encoder weights, then initialize only
                # the variables the pre-trained model does not cover.
                encoder.saver.restore(sess, encoder._pre_trained_model)
                all_vars = set([
                    var for var in
                    tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
                ])
                var_list = all_vars.symmetric_difference(
                    [v[1] for v in list(encoder.fil_var_list.items())])
                sess.run(tf.variables_initializer(var_list))
                print(sess.run(tf.report_uninitialized_variables()))
            else:
                sess.run(tf.global_variables_initializer())

        if not debug_mode:
            # was print((...)) — 2to3 tuple-print artifact
            print('Training with %s model' % args.get('Dataset', 'MODEL'),
                  os.path.basename(args.get('Paths', 'MODEL_PATH')))
            bar.start()

        # Resume from the stored global step.
        for i in range(encoder.global_step.eval(), num_iter):
            if not debug_mode:
                sess.run([train_op, multi_queue.next_bg_element])

                if (i + 1) % 100 == 0:
                    merged_summaries = sess.run(merged_loss_summary)
                    summary_writer.add_summary(merged_summaries, i)

                bar.update(i)
                if (i + 1) % save_interval == 0:
                    saver.save(sess, checkpoint_file,
                               global_step=encoder.global_step)

                    this, reconstr_train = sess.run([
                        multi_queue.next_element,
                        [decoder.x for decoder in decoders]
                    ])
                    this_x = np.concatenate([el[0] for el in this])
                    this_y = np.concatenate([el[2] for el in this])
                    reconstr_train = np.concatenate(reconstr_train)
                    # Identical shuffle for all three stacks keeps
                    # input/target/reconstruction rows aligned.
                    for imgs in [this_x, this_y, reconstr_train]:
                        np.random.seed(0)
                        np.random.shuffle(imgs)
                    train_imgs = np.hstack(
                        (u.tiles(this_x, 4, 4),
                         u.tiles(reconstr_train, 4, 4),
                         u.tiles(this_y, 4, 4)))
                    cv2.imwrite(
                        os.path.join(train_fig_dir,
                                     'training_images_%s.png' % i),
                        train_imgs * 255)
            else:
                this, _, reconstr_train = sess.run([
                    multi_queue.next_element, multi_queue.next_bg_element,
                    [decoder.x for decoder in decoders]
                ])
                this_x = np.concatenate([el[0] for el in this])
                this_y = np.concatenate([el[2] for el in this])
                print(this_x.shape, reconstr_train[0].shape,
                      len(reconstr_train))
                reconstr_train = np.concatenate(reconstr_train, axis=0)
                for imgs in [this_x, this_y, reconstr_train]:
                    np.random.seed(0)
                    np.random.shuffle(imgs)
                print(this_x.shape)
                cv2.imshow(
                    'sample batch',
                    np.hstack((u.tiles(this_x, 4, 6),
                               u.tiles(reconstr_train, 4, 6),
                               u.tiles(this_y, 4, 6))))
                k = cv2.waitKey(0)
                if k == 27:  # ESC
                    break

            if gentle_stop[0]:
                break

    if not debug_mode:
        bar.finish()
    if not gentle_stop[0] and not debug_mode:
        print('To create the embedding run:\n')
        print('ae_embed {}\n'.format(full_name))
def main():
    """Train the single-object autoencoder experiment.

    Builds the queue/encoder/decoder graph, generates or loads the training
    images, restores any existing checkpoint and runs the training loop with
    periodic summaries, checkpoints and reconstruction snapshots. In debug
    mode (-d) it only visualizes batches. Ctrl+C stops gracefully.
    """
    tf.disable_eager_execution()

    workspace_path = os.environ.get('AE_WORKSPACE_PATH')
    if workspace_path is None:
        print('Please define a workspace path:\n')
        print('export AE_WORKSPACE_PATH=/path/to/workspace\n')
        exit(-1)

    # One-element array used as a mutable flag the SIGINT handler can set.
    # `np.bool` was removed in NumPy 1.24 — use the builtin `bool` dtype.
    gentle_stop = np.array((1,), dtype=bool)
    gentle_stop[0] = False

    def on_ctrl_c(signal, frame):
        gentle_stop[0] = True

    signal.signal(signal.SIGINT, on_ctrl_c)

    parser = argparse.ArgumentParser()
    parser.add_argument("experiment_name")
    parser.add_argument("-d", action='store_true', default=False)
    parser.add_argument("-gen", action='store_true', default=False)
    arguments = parser.parse_args()

    full_name = arguments.experiment_name.split('/')
    experiment_name = full_name.pop()
    experiment_group = full_name.pop() if len(full_name) > 0 else ''
    debug_mode = arguments.d
    generate_data = arguments.gen

    cfg_file_path = u.get_config_file_path(workspace_path, experiment_name,
                                           experiment_group)
    log_dir = u.get_log_dir(workspace_path, experiment_name, experiment_group)
    checkpoint_file = u.get_checkpoint_basefilename(log_dir)
    ckpt_dir = u.get_checkpoint_dir(log_dir)
    train_fig_dir = u.get_train_fig_dir(log_dir)
    dataset_path = u.get_dataset_path(workspace_path)

    if not os.path.exists(cfg_file_path):
        print('Could not find config file:\n')
        print('{}\n'.format(cfg_file_path))
        exit(-1)

    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    if not os.path.exists(train_fig_dir):
        os.makedirs(train_fig_dir)
    if not os.path.exists(dataset_path):
        os.makedirs(dataset_path)

    args = configparser.ConfigParser()
    args.read(cfg_file_path)
    # Keep a copy of the config next to the logs for reproducibility.
    shutil.copy2(cfg_file_path, log_dir)

    with tf.variable_scope(experiment_name):
        dataset = factory.build_dataset(dataset_path, args)
        queue = factory.build_queue(dataset, args)
        encoder = factory.build_encoder(queue.x, args, is_training=True)
        decoder = factory.build_decoder(queue.y, encoder, args,
                                        is_training=True)
        ae = factory.build_ae(encoder, decoder, args)
        codebook = factory.build_codebook(encoder, dataset, args)
        train_op = factory.build_train_op(ae, args)
        saver = tf.train.Saver(save_relative_paths=True)

    # Cap debug runs at 100k iterations instead of NUM_ITER.
    num_iter = args.getint('Training', 'NUM_ITER') if not debug_mode \
        else 100000
    save_interval = args.getint('Training', 'SAVE_INTERVAL')
    model_type = args.get('Dataset', 'MODEL')

    if model_type == 'dsprites':
        dataset.get_sprite_training_images(args)
    else:
        dataset.get_training_images(dataset_path, args)
        dataset.load_bg_images(dataset_path)

    if generate_data:
        print('finished generating synthetic training data for ' +
              experiment_name)
        print('exiting...')
        exit()

    widgets = ['Training: ', progressbar.Percentage(), ' ',
               progressbar.Bar(), ' ', progressbar.Counter(),
               ' / %s' % num_iter, ' ', progressbar.ETA(), ' ']
    bar = progressbar.ProgressBar(maxval=num_iter, widgets=widgets)

    gpu_options = tf.GPUOptions(allow_growth=True,
                                per_process_gpu_memory_fraction=0.8)
    config = tf.ConfigProto(gpu_options=gpu_options)

    with tf.Session(config=config) as sess:
        chkpt = tf.train.get_checkpoint_state(ckpt_dir)
        if chkpt and chkpt.model_checkpoint_path:
            saver.restore(sess, chkpt.model_checkpoint_path)
        else:
            sess.run(tf.global_variables_initializer())

        merged_loss_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(ckpt_dir, sess.graph)

        if not debug_mode:
            print('Training with %s model' % args.get('Dataset', 'MODEL'),
                  os.path.basename(args.get('Paths', 'MODEL_PATH')))
            bar.start()

        queue.start(sess)
        # Resume from the stored global step.
        for i in range(ae.global_step.eval(), num_iter):
            if not debug_mode:
                sess.run(train_op)
                if i % 10 == 0:
                    loss = sess.run(merged_loss_summary)
                    summary_writer.add_summary(loss, i)

                bar.update(i)
                if (i + 1) % save_interval == 0:
                    saver.save(sess, checkpoint_file,
                               global_step=ae.global_step)

                    this_x, this_y = sess.run([queue.x, queue.y])
                    reconstr_train = sess.run(decoder.x,
                                              feed_dict={queue.x: this_x})
                    train_imgs = np.hstack((u.tiles(this_x, 4, 4),
                                            u.tiles(reconstr_train, 4, 4),
                                            u.tiles(this_y, 4, 4)))
                    cv2.imwrite(
                        os.path.join(train_fig_dir,
                                     'training_images_%s.png' % i),
                        train_imgs * 255)
            else:
                this_x, this_y = sess.run([queue.x, queue.y])
                reconstr_train = sess.run(decoder.x,
                                          feed_dict={queue.x: this_x})
                cv2.imshow('sample batch',
                           np.hstack((u.tiles(this_x, 3, 3),
                                      u.tiles(reconstr_train, 3, 3),
                                      u.tiles(this_y, 3, 3))))
                k = cv2.waitKey(0)
                if k == 27:  # ESC
                    break

            if gentle_stop[0]:
                break
        queue.stop(sess)

    if not debug_mode:
        bar.finish()
    if not gentle_stop[0] and not debug_mode:
        print('To create the embedding run:\n')
        print('ae_embed {}\n'.format(full_name))