def main(argv=()):
  """Run continuous evaluation of the PTN model on one eval split.

  Watches the checkpoints written to '<checkpoint_dir>/<model_name>/train_bkp'
  and, every FLAGS.eval_interval_secs, runs the metric update ops over the
  whole eval set, logging results to an 'eval_<eval_set>' directory.

  Args:
    argv: Unused command-line arguments.
  """
  del argv  # Unused.
  eval_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name, 'train_bkp')
  log_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name,
                         'eval_%s' % FLAGS.eval_set)
  if not os.path.exists(eval_dir):
    os.makedirs(eval_dir)
  if not os.path.exists(log_dir):
    os.makedirs(log_dir)
  g = tf.Graph()
  with g.as_default():
    # NOTE(review): FLAGS is aliased, not copied -- these writes mutate the
    # global flags object for the rest of the process.
    eval_params = FLAGS
    eval_params.batch_size = 1
    eval_params.step_size = FLAGS.num_views
    ###########
    ## model ##
    ###########
    model = model_ptn.model_PTN(eval_params)
    ##########
    ## data ##
    ##########
    eval_data = model.get_inputs(
        FLAGS.inp_dir, FLAGS.dataset_name, eval_params.eval_set,
        eval_params.batch_size, eval_params.image_size, eval_params.vox_size,
        is_training=False)
    inputs = model.preprocess_with_all_views(eval_data)
    ##############
    ## model_fn ##
    ##############
    model_fn = model.get_model_fn(is_training=False, run_projection=False)
    outputs = model_fn(inputs)
    #############
    ## metrics ##
    #############
    names_to_values, names_to_updates = model.get_metrics(inputs, outputs)
    del names_to_values  # Only the update ops are needed by the eval loop.
    ################
    ## evaluation ##
    ################
    # batch_size is 1, so one eval step per sample covers the full set.
    num_batches = eval_data['num_samples']
    slim.evaluation.evaluation_loop(
        master=FLAGS.master,
        checkpoint_dir=eval_dir,
        logdir=log_dir,
        num_evals=num_batches,
        # Fix: materialize the dict view -- on Python 3, dict.values() is a
        # view object, not the list of ops evaluation_loop expects.
        eval_op=list(names_to_updates.values()),
        eval_interval_secs=FLAGS.eval_interval_secs)
def main(argv=()):
  """Run continuous evaluation of the PTN model on one eval split.

  Watches the checkpoints written to '<checkpoint_dir>/<model_name>/train'
  and, every FLAGS.eval_interval_secs, runs the metric update ops over the
  whole eval set, logging results to an 'eval_<eval_set>' directory.

  Args:
    argv: Unused command-line arguments.
  """
  del argv  # Unused.
  eval_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name, 'train')
  log_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name,
                         'eval_%s' % FLAGS.eval_set)
  if not os.path.exists(eval_dir):
    os.makedirs(eval_dir)
  if not os.path.exists(log_dir):
    os.makedirs(log_dir)
  g = tf.Graph()
  with g.as_default():
    # NOTE(review): FLAGS is aliased, not copied -- these writes mutate the
    # global flags object for the rest of the process.
    eval_params = FLAGS
    eval_params.batch_size = 1
    eval_params.step_size = FLAGS.num_views
    ###########
    ## model ##
    ###########
    model = model_ptn.model_PTN(eval_params)
    ##########
    ## data ##
    ##########
    eval_data = model.get_inputs(
        FLAGS.inp_dir, FLAGS.dataset_name, eval_params.eval_set,
        eval_params.batch_size, eval_params.image_size, eval_params.vox_size,
        is_training=False)
    inputs = model.preprocess_with_all_views(eval_data)
    ##############
    ## model_fn ##
    ##############
    model_fn = model.get_model_fn(is_training=False, run_projection=False)
    outputs = model_fn(inputs)
    #############
    ## metrics ##
    #############
    names_to_values, names_to_updates = model.get_metrics(inputs, outputs)
    del names_to_values  # Only the update ops are needed by the eval loop.
    ################
    ## evaluation ##
    ################
    # batch_size is 1, so one eval step per sample covers the full set.
    num_batches = eval_data['num_samples']
    slim.evaluation.evaluation_loop(
        master=FLAGS.master,
        checkpoint_dir=eval_dir,
        logdir=log_dir,
        num_evals=num_batches,
        # Fix: materialize the dict view -- on Python 3, dict.values() is a
        # view object, not the list of ops evaluation_loop expects.
        eval_op=list(names_to_updates.values()),
        eval_interval_secs=FLAGS.eval_interval_secs)
def main(_):
  """Train the PTN model on the 'train_bkp' split.

  Builds the training graph under a replica device setter, optionally
  initializes the encoder from a pre-trained model, and -- on the chief
  task only -- adds a weight-sharing validation branch whose visualization
  grids are written to disk on every training step.
  """
  train_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name, 'train_bkp')
  save_image_dir = os.path.join(train_dir, 'images')
  if not os.path.exists(train_dir):
    os.makedirs(train_dir)
  if not os.path.exists(save_image_dir):
    os.makedirs(save_image_dir)
  g = tf.Graph()
  with g.as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      global_step = slim.get_or_create_global_step()
      ###########
      ## model ##
      ###########
      model = model_ptn.model_PTN(FLAGS)
      ##########
      ## data ##
      ##########
      train_data = model.get_inputs(FLAGS.inp_dir, FLAGS.dataset_name,
                                    'train_bkp', FLAGS.batch_size,
                                    FLAGS.image_size, FLAGS.vox_size,
                                    is_training=True)
      inputs = model.preprocess(train_data, FLAGS.step_size)
      ##############
      ## model_fn ##
      ##############
      model_fn = model.get_model_fn(is_training=True, reuse=False,
                                    run_projection=True)
      outputs = model_fn(inputs)
      ##################
      ## train_scopes ##
      ##################
      # With an initial model, only the decoder is trained; the encoder
      # weights are restored through init_fn below.
      if FLAGS.init_model:
        train_scopes = ['decoder']
        init_scopes = ['encoder']
      else:
        train_scopes = ['encoder', 'decoder']
      ##########
      ## loss ##
      ##########
      task_loss = model.get_loss(inputs, outputs)
      regularization_loss = model.get_regularization_loss(train_scopes)
      loss = task_loss + regularization_loss
      ###############
      ## optimizer ##
      ###############
      optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
      if FLAGS.sync_replicas:
        # NOTE(review): 'workers_replicas' (extra 's') here vs
        # 'worker_replicas' two lines below looks like a flag-name typo --
        # confirm which flag is actually defined before relying on
        # --sync_replicas.
        optimizer = tf.train.SyncReplicasOptimizer(
            optimizer,
            replicas_to_aggregate=FLAGS.workers_replicas - FLAGS.backup_workers,
            total_num_replicas=FLAGS.worker_replicas)
      ##############
      ## train_op ##
      ##############
      train_op = model.get_train_op_for_scope(loss, optimizer, train_scopes)
      ###########
      ## saver ##
      ###########
      saver = tf.train.Saver(
          max_to_keep=np.minimum(5, FLAGS.worker_replicas + 1))

      if FLAGS.task == 0:
        # Chief task only: build a validation branch for visualization.
        # NOTE(review): `params = FLAGS` aliases (does not copy) the global
        # flags object, so these assignments mutate FLAGS as well.
        params = FLAGS
        params.batch_size = params.num_views
        params.step_size = 1
        model.set_params(params)
        val_data = model.get_inputs(params.inp_dir, params.dataset_name, 'val',
                                    params.batch_size, params.image_size,
                                    params.vox_size, is_training=False)
        val_inputs = model.preprocess(val_data, params.step_size)
        # Note: don't compute loss here
        reused_model_fn = model.get_model_fn(is_training=False, reuse=True)
        val_outputs = reused_model_fn(val_inputs)
        with tf.device(tf.DeviceSpec(device_type='CPU')):
          vis_input_images = val_inputs['images_1'] * 255.0
          # Masks/projections are inverted (1 - x) and scaled to [0, 255].
          vis_gt_projs = (val_outputs['masks_1'] * (-1) + 1) * 255.0
          vis_pred_projs = (val_outputs['projs_1'] * (-1) + 1) * 255.0
          # Tile the single channel to 3 channels for image output.
          vis_gt_projs = tf.concat([vis_gt_projs] * 3, axis=3)
          vis_pred_projs = tf.concat([vis_pred_projs] * 3, axis=3)
          # rescale
          new_size = [FLAGS.image_size] * 2
          vis_gt_projs = tf.image.resize_nearest_neighbor(
              vis_gt_projs, new_size)
          vis_pred_projs = tf.image.resize_nearest_neighbor(
              vis_pred_projs, new_size)
          # flip
          # vis_gt_projs = utils.image_flipud(vis_gt_projs)
          # vis_pred_projs = utils.image_flipud(vis_pred_projs)
          # vis_gt_projs is of shape [batch, height, width, channels]
          write_disk_op = model.write_disk_grid(
              global_step=global_step,
              log_dir=save_image_dir,
              input_images=vis_input_images,
              gt_projs=vis_gt_projs,
              pred_projs=vis_pred_projs,
              input_voxels=val_inputs['voxels'],
              output_voxels=val_outputs['voxels_1'])
        # Piggy-back the visualization write onto every training step.
        with tf.control_dependencies([write_disk_op]):
          train_op = tf.identity(train_op)

      #############
      ## init_fn ##
      #############
      if FLAGS.init_model:
        init_fn = model.get_init_fn(init_scopes)
      else:
        init_fn = None

      ##############
      ## training ##
      ##############
      slim.learning.train(train_op=train_op,
                          logdir=train_dir,
                          init_fn=init_fn,
                          master=FLAGS.master,
                          is_chief=(FLAGS.task == 0),
                          number_of_steps=FLAGS.max_number_of_steps,
                          saver=saver,
                          save_summaries_secs=FLAGS.save_summaries_secs,
                          save_interval_secs=FLAGS.save_interval_secs)
def main(_):
  """Train the PTN model on the 'train' split.

  Builds the training graph under a replica device setter, optionally
  initializes the encoder from a pre-trained model, and -- on the chief
  task only -- adds a weight-sharing validation branch whose visualization
  grids are written to disk on every training step.
  """
  train_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name, 'train')
  save_image_dir = os.path.join(train_dir, 'images')
  if not os.path.exists(train_dir):
    os.makedirs(train_dir)
  if not os.path.exists(save_image_dir):
    os.makedirs(save_image_dir)
  g = tf.Graph()
  with g.as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      global_step = slim.get_or_create_global_step()
      ###########
      ## model ##
      ###########
      model = model_ptn.model_PTN(FLAGS)
      ##########
      ## data ##
      ##########
      train_data = model.get_inputs(
          FLAGS.inp_dir,
          FLAGS.dataset_name,
          'train',
          FLAGS.batch_size,
          FLAGS.image_size,
          FLAGS.vox_size,
          is_training=True)
      inputs = model.preprocess(train_data, FLAGS.step_size)
      ##############
      ## model_fn ##
      ##############
      model_fn = model.get_model_fn(
          is_training=True, reuse=False, run_projection=True)
      outputs = model_fn(inputs)
      ##################
      ## train_scopes ##
      ##################
      # With an initial model, only the decoder is trained; the encoder
      # weights are restored through init_fn below.
      if FLAGS.init_model:
        train_scopes = ['decoder']
        init_scopes = ['encoder']
      else:
        train_scopes = ['encoder', 'decoder']
      ##########
      ## loss ##
      ##########
      task_loss = model.get_loss(inputs, outputs)
      regularization_loss = model.get_regularization_loss(train_scopes)
      loss = task_loss + regularization_loss
      ###############
      ## optimizer ##
      ###############
      optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
      if FLAGS.sync_replicas:
        # NOTE(review): 'workers_replicas' (extra 's') here vs
        # 'worker_replicas' two lines below looks like a flag-name typo --
        # confirm which flag is actually defined before relying on
        # --sync_replicas.
        optimizer = tf.train.SyncReplicasOptimizer(
            optimizer,
            replicas_to_aggregate=FLAGS.workers_replicas - FLAGS.backup_workers,
            total_num_replicas=FLAGS.worker_replicas)
      ##############
      ## train_op ##
      ##############
      train_op = model.get_train_op_for_scope(loss, optimizer, train_scopes)
      ###########
      ## saver ##
      ###########
      saver = tf.train.Saver(max_to_keep=np.minimum(5, FLAGS.worker_replicas + 1))

      if FLAGS.task == 0:
        # Chief task only: build a validation branch for visualization.
        # NOTE(review): `params = FLAGS` aliases (does not copy) the global
        # flags object, so these assignments mutate FLAGS as well.
        params = FLAGS
        params.batch_size = params.num_views
        params.step_size = 1
        model.set_params(params)
        val_data = model.get_inputs(
            params.inp_dir,
            params.dataset_name,
            'val',
            params.batch_size,
            params.image_size,
            params.vox_size,
            is_training=False)
        val_inputs = model.preprocess(val_data, params.step_size)
        # Note: don't compute loss here
        reused_model_fn = model.get_model_fn(is_training=False, reuse=True)
        val_outputs = reused_model_fn(val_inputs)
        with tf.device(tf.DeviceSpec(device_type='CPU')):
          vis_input_images = val_inputs['images_1'] * 255.0
          # Masks/projections are inverted (1 - x) and scaled to [0, 255].
          vis_gt_projs = (val_outputs['masks_1'] * (-1) + 1) * 255.0
          vis_pred_projs = (val_outputs['projs_1'] * (-1) + 1) * 255.0
          # Tile the single channel to 3 channels for image output.
          vis_gt_projs = tf.concat([vis_gt_projs] * 3, axis=3)
          vis_pred_projs = tf.concat([vis_pred_projs] * 3, axis=3)
          # rescale
          new_size = [FLAGS.image_size] * 2
          vis_gt_projs = tf.image.resize_nearest_neighbor(
              vis_gt_projs, new_size)
          vis_pred_projs = tf.image.resize_nearest_neighbor(
              vis_pred_projs, new_size)
          # flip
          # vis_gt_projs = utils.image_flipud(vis_gt_projs)
          # vis_pred_projs = utils.image_flipud(vis_pred_projs)
          # vis_gt_projs is of shape [batch, height, width, channels]
          write_disk_op = model.write_disk_grid(
              global_step=global_step,
              log_dir=save_image_dir,
              input_images=vis_input_images,
              gt_projs=vis_gt_projs,
              pred_projs=vis_pred_projs,
              input_voxels=val_inputs['voxels'],
              output_voxels=val_outputs['voxels_1'])
        # Piggy-back the visualization write onto every training step.
        with tf.control_dependencies([write_disk_op]):
          train_op = tf.identity(train_op)

      #############
      ## init_fn ##
      #############
      if FLAGS.init_model:
        init_fn = model.get_init_fn(init_scopes)
      else:
        init_fn = None

      ##############
      ## training ##
      ##############
      slim.learning.train(
          train_op=train_op,
          logdir=train_dir,
          init_fn=init_fn,
          master=FLAGS.master,
          is_chief=(FLAGS.task == 0),
          number_of_steps=FLAGS.max_number_of_steps,
          saver=saver,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
def main(argv=()):
  """Profile one evaluation batch of the PTN model.

  Restores the latest checkpoint from FLAGS.checkpoint_dir, runs the metric
  update ops once with full run-tracing enabled, and dumps the traced
  op/tensor-shape analysis to 'ptn.pickle' via CompGraph.

  Args:
    argv: Unused command-line arguments.

  Raises:
    NotImplementedError: if no checkpoint is found in FLAGS.checkpoint_dir.
  """
  del argv  # Unused.
  g = tf.Graph()
  with g.as_default():
    # NOTE(review): FLAGS is aliased, not copied -- these writes mutate the
    # global flags object for the rest of the process.
    eval_params = FLAGS
    eval_params.batch_size = 1
    eval_params.step_size = FLAGS.num_views
    ###########
    ## model ##
    ###########
    model = model_ptn.model_PTN(eval_params)
    ##########
    ## data ##
    ##########
    eval_data = model.get_inputs(
        FLAGS.inp_dir, FLAGS.dataset_name, eval_params.eval_set,
        eval_params.batch_size, eval_params.image_size, eval_params.vox_size,
        is_training=False)
    inputs = model.preprocess_with_all_views(eval_data)
    ##############
    ## model_fn ##
    ##############
    model_fn = model.get_model_fn(is_training=False, run_projection=False)
    outputs = model_fn(inputs)
    #############
    ## metrics ##
    #############
    names_to_values, names_to_updates = model.get_metrics(inputs, outputs)
    del names_to_values  # Only the update ops are needed here.
    ################
    ## evaluation ##
    ################
    num_batches = eval_data['num_samples']

    sess = tf.Session()
    saver = tf.train.Saver()

    # Bug fix: initializers must run BEFORE the checkpoint restore.  The
    # original code ran tf.global_variables_initializer() after
    # saver.restore(), which overwrote the restored weights with freshly
    # initialized values.
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    def restore_from_checkpoint(sess, saver):
      """Restore the latest checkpoint; return False if none exists."""
      ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
      if not ckpt or not ckpt.model_checkpoint_path:
        return False
      saver.restore(sess, ckpt.model_checkpoint_path)
      return True

    # NOTE(review): earlier (commented-out) revisions read checkpoints from a
    # 'train' subdirectory; here FLAGS.checkpoint_dir itself is searched.
    if not restore_from_checkpoint(sess, saver):
      raise NotImplementedError

    # Start input-pipeline threads only once the session is fully restored.
    tf.train.start_queue_runners(sess=sess)

    for i in range(num_batches):
      print('Running {} batch out of {} batches.'.format(i, num_batches))
      options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
      run_metadata = tf.RunMetadata()
      sess.run(list(names_to_updates.values()),
               options=options,
               run_metadata=run_metadata)
      # Capture the traced graph, resolve every captured tensor's concrete
      # shape, and dump the analysis for offline inspection.
      cg = CompGraph('ptn', run_metadata, tf.get_default_graph())
      cg_tensor_dict = cg.get_tensors()
      cg_sorted_keys = sorted(cg_tensor_dict.keys())
      cg_sorted_items = []
      for cg_key in cg_sorted_keys:
        cg_sorted_items.append(tf.shape(cg_tensor_dict[cg_key]))
      cg_sorted_shape = sess.run(cg_sorted_items)
      cg.op_analysis(dict(zip(cg_sorted_keys, cg_sorted_shape)), 'ptn.pickle')
      # Profiling a single batch is sufficient; terminate after the first.
      exit(0)