def evaluate(config, restore_path):
    """Evaluate a restored checkpoint and write the results as summaries.

    Builds the inference graph for ``config.NETWORK_CLASS``, restores the
    variables stored at ``restore_path``, feeds
    ``ceil(num_per_epoch / BATCH_SIZE)`` batches of the "test" subset (or
    "validation" when the dataset class has no "test" subset) while
    accumulating the model's metrics, and writes the summaries to
    ``environment.TENSORBOARD_DIR + "/evaluate"``.

    Args:
        config: Experiment configuration. Reads ``DATASET_CLASS``,
            ``NETWORK_CLASS``, ``NETWORK``, ``IS_DEBUG`` and ``BATCH_SIZE``.
        restore_path: Checkpoint path prefix (the file
            ``"<restore_path>.index"`` must exist). When ``None``, the file
            name returned by ``executor.search_restore_filename`` for
            ``environment.CHECKPOINTS_DIR`` is used.

    Raises:
        FileNotFoundError: If ``"<restore_path>.index"`` does not exist.
    """
    if restore_path is None:
        restore_file = executor.search_restore_filename(
            environment.CHECKPOINTS_DIR)
        restore_path = os.path.join(environment.CHECKPOINTS_DIR, restore_file)

    if not os.path.exists("{}.index".format(restore_path)):
        # FileNotFoundError is an Exception subclass, so existing
        # ``except Exception`` handlers keep working.
        raise FileNotFoundError(
            "restore file {} dont exists.".format(restore_path))

    print("restore_path:", restore_path)

    DatasetClass = config.DATASET_CLASS
    ModelClass = config.NETWORK_CLASS
    network_kwargs = {key.lower(): val for key, val in config.NETWORK.items()}

    subset = "test" if "test" in DatasetClass.available_subsets else "validation"
    validation_dataset = setup_dataset(config, subset, seed=0)

    graph = tf.Graph()
    with graph.as_default():
        # Constructor arguments that depend on the network family. These
        # mirror the branches used when the model is built for training.
        network_module = ModelClass.__module__
        family_kwargs = {}
        if network_module.startswith("lmnet.networks.object_detection"):
            family_kwargs["num_max_boxes"] = validation_dataset.num_max_boxes
        elif network_module.startswith("lmnet.networks.segmentation"):
            family_kwargs["label_colors"] = validation_dataset.label_colors

        model = ModelClass(
            classes=validation_dataset.classes,
            is_debug=config.IS_DEBUG,
            **family_kwargs,
            **network_kwargs,
        )

        global_step = tf.Variable(0, name="global_step", trainable=False)
        is_training = tf.constant(False, name="is_training")

        # NOTE(review): this file calls the placeholder factory under two
        # spellings ("placeholders" / "placeholderes"); accept whichever the
        # model class actually defines.
        make_placeholders = (
            getattr(model, "placeholders", None) or model.placeholderes)
        images_placeholder, labels_placeholder = make_placeholders()

        output = model.inference(images_placeholder, is_training)

        metrics_ops_dict, metrics_update_op = model.metrics(
            output, labels_placeholder)
        model.summary(output, labels_placeholder)

        summary_op = tf.summary.merge_all()
        metrics_summary_op, metrics_placeholders = executor.prepare_metrics(
            metrics_ops_dict)

        init_op = tf.global_variables_initializer()
        reset_metrics_op = tf.local_variables_initializer()
        saver = tf.train.Saver(max_to_keep=None)

    session_config = None  # tf.ConfigProto(log_device_placement=True)
    sess = tf.Session(graph=graph, config=session_config)
    validation_writer = None
    try:
        # Initialise global variables and reset the metric (local) variables
        # before the checkpoint values are restored on top.
        sess.run([init_op, reset_metrics_op])

        validation_writer = tf.summary.FileWriter(
            environment.TENSORBOARD_DIR + "/evaluate")

        saver.restore(sess, restore_path)
        last_step = sess.run(global_step)

        test_step_size = int(
            math.ceil(validation_dataset.num_per_epoch / config.BATCH_SIZE))
        print("test_step_size", test_step_size)

        for test_step in range(test_step_size):
            print("test_step", test_step)

            images, labels = validation_dataset.feed()
            feed_dict = {
                images_placeholder: images,
                labels_placeholder: labels,
            }

            # Summarize at only last step.
            if test_step == test_step_size - 1:
                summary, _ = sess.run(
                    [summary_op, metrics_update_op], feed_dict=feed_dict)
                validation_writer.add_summary(summary, last_step)
            else:
                sess.run([metrics_update_op], feed_dict=feed_dict)

        # Read the accumulated metric values and log them through the
        # placeholder-driven metrics summary op.
        metrics_values = sess.run(list(metrics_ops_dict.values()))
        metrics_feed_dict = dict(zip(metrics_placeholders, metrics_values))
        metrics_summary, = sess.run(
            [metrics_summary_op], feed_dict=metrics_feed_dict,
        )
        validation_writer.add_summary(metrics_summary, last_step)
    finally:
        # Closing the writer flushes pending events to disk; without it the
        # last summaries can be lost when the process exits.
        if validation_writer is not None:
            validation_writer.close()
        sess.close()
def start_training(config):
    """Train ``config.NETWORK_CLASS`` and periodically evaluate and save it.

    Builds the training graph, optionally restores pretrained variables and
    the latest checkpoint found in ``environment.CHECKPOINTS_DIR``, then runs
    the training loop from the restored global step up to the maximum step.
    Rank 0 writes TensorBoard summaries below ``environment.TENSORBOARD_DIR``
    and saves checkpoints; the validation pass is executed on every rank.

    When ``config.IS_DISTRIBUTION`` is true, Horovod and mpi4py are imported
    and each worker trains on its own slice of the shuffled dataset indices.

    Args:
        config: Experiment configuration. Reads ``IS_DISTRIBUTION``,
            ``NETWORK_CLASS``, ``NETWORK``, ``DATASET``, ``IS_DEBUG``,
            ``IS_PRETRAIN`` (with ``PRETRAIN_VARS``, ``PRETRAIN_DIR``,
            ``PRETRAIN_FILE``), ``BATCH_SIZE``, ``MAX_EPOCHS`` or
            ``MAX_STEPS``, ``SUMMARISE_STEPS``, ``SAVE_STEPS`` and
            ``TEST_STEPS``.
    """
    if config.IS_DISTRIBUTION:
        import horovod.tensorflow as hvd

        # initialize Horovod.
        hvd.init()
        num_worker = hvd.size()
        rank = hvd.rank()
        # verify that MPI multi-threading is supported.
        assert hvd.mpi_threads_supported()
        # make sure MPI is not re-initialized.
        import mpi4py.rc
        mpi4py.rc.initialize = False
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
        # check size and rank are synchronized
        assert num_worker == comm.Get_size()
        assert rank == comm.Get_rank()
    else:
        num_worker = 1
        rank = 0

    ModelClass = config.NETWORK_CLASS
    network_kwargs = {key.lower(): val for key, val in config.NETWORK.items()}

    # "Train-validation saving": keep only the checkpoint that scores best on
    # a dedicated subset instead of saving every SAVE_STEPS.
    use_train_validation_saving = (
        "TRAIN_VALIDATION_SAVING_SIZE" in config.DATASET
        and config.DATASET.TRAIN_VALIDATION_SAVING_SIZE > 0
    )
    top_train_validation_saving_set_accuracy = 0

    train_dataset = setup_dataset(config, "train", rank)
    print("train dataset num:", train_dataset.num_per_epoch)

    if use_train_validation_saving:
        train_validation_saving_dataset = setup_dataset(
            config, "train_validation_saving", rank)
        print("train_validation_saving dataset num:",
              train_validation_saving_dataset.num_per_epoch)

    validation_dataset = setup_dataset(config, "validation", rank)
    print("validation dataset num:", validation_dataset.num_per_epoch)

    graph = tf.Graph()
    with graph.as_default():
        network_module = ModelClass.__module__
        is_object_detection = network_module.startswith(
            "lmnet.networks.object_detection")

        # Constructor arguments that depend on the network family.
        family_kwargs = {}
        if is_object_detection:
            family_kwargs["num_max_boxes"] = train_dataset.num_max_boxes
        elif network_module.startswith("lmnet.networks.segmentation"):
            family_kwargs["label_colors"] = train_dataset.label_colors

        model = ModelClass(
            classes=train_dataset.classes,
            is_debug=config.IS_DEBUG,
            **family_kwargs,
            **network_kwargs,
        )

        global_step = tf.Variable(0, name="global_step", trainable=False)
        is_training_placeholder = tf.placeholder(
            tf.bool, name="is_training_placeholder")

        # NOTE(review): this file calls the placeholder factory under two
        # spellings ("placeholders" / "placeholderes"); accept whichever the
        # model class actually defines.
        make_placeholders = (
            getattr(model, "placeholders", None) or model.placeholderes)
        images_placeholder, labels_placeholder = make_placeholders()

        output = model.inference(images_placeholder, is_training_placeholder)
        if is_object_detection:
            loss = model.loss(
                output, labels_placeholder, is_training_placeholder)
        else:
            loss = model.loss(output, labels_placeholder)

        opt = model.optimizer(global_step)
        if config.IS_DISTRIBUTION:
            # add Horovod Distributed Optimizer
            opt = hvd.DistributedOptimizer(opt)
        train_op = model.train(loss, opt, global_step)

        metrics_ops_dict, metrics_update_op = model.metrics(
            output, labels_placeholder)
        # TODO(wakisaka): Deal with many networks.
        model.summary(output, labels_placeholder)

        summary_op = tf.summary.merge_all()
        metrics_summary_op, metrics_placeholders = executor.prepare_metrics(
            metrics_ops_dict)

        init_op = tf.global_variables_initializer()
        reset_metrics_op = tf.local_variables_initializer()
        if config.IS_DISTRIBUTION:
            # add Horovod broadcasting variables from rank 0 to all
            bcast_global_variables_op = hvd.broadcast_global_variables(0)

        # With train-validation saving only the best checkpoint is kept.
        saver = tf.train.Saver(
            max_to_keep=1 if use_train_validation_saving else None)

        if config.IS_PRETRAIN:
            pretrain_prefixes = tuple(config.PRETRAIN_VARS)
            pretrain_var_list = [
                var for var in tf.global_variables()
                if var.name.startswith(pretrain_prefixes)
            ]
            print("pretrain_vars", [var.name for var in pretrain_var_list])
            pretrain_saver = tf.train.Saver(
                pretrain_var_list, name="pretrain_saver")

    if config.IS_DISTRIBUTION:
        # For distributed training
        session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True,
            visible_device_list=str(hvd.local_rank())))
    else:
        session_config = tf.ConfigProto()  # tf.ConfigProto(log_device_placement=True)
    # TODO(wakisaka): XLA JIT
    # session_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    sess = tf.Session(graph=graph, config=session_config)

    # Writers exist only on rank 0; None elsewhere means "do not write".
    train_writer = None
    train_val_saving_writer = None
    val_writer = None

    def _write_metrics_summary(writer, summary_step):
        """Fetch the accumulated metric values and log them to ``writer``."""
        metrics_values = sess.run(list(metrics_ops_dict.values()))
        metrics_feed_dict = dict(zip(metrics_placeholders, metrics_values))
        metrics_summary, = sess.run(
            [metrics_summary_op], feed_dict=metrics_feed_dict,
        )
        if writer is not None:
            writer.add_summary(metrics_summary, summary_step)

    def _run_evaluation(dataset, writer, summary_step, label):
        """Reset the metrics, feed ``dataset`` in inference mode, log results."""
        sess.run(reset_metrics_op)
        step_size = int(math.ceil(dataset.num_per_epoch / config.BATCH_SIZE))
        print("{}_step_size".format(label), step_size)

        for eval_step in range(step_size):
            print("{}_step".format(label), eval_step)

            images, labels = dataset.feed()
            feed_dict = {
                is_training_placeholder: False,
                images_placeholder: images,
                labels_placeholder: labels,
            }

            if eval_step % config.SUMMARISE_STEPS == 0:
                summary, _ = sess.run(
                    [summary_op, metrics_update_op], feed_dict=feed_dict)
                if writer is not None:
                    writer.add_summary(summary, summary_step)
            else:
                sess.run([metrics_update_op], feed_dict=feed_dict)

        _write_metrics_summary(writer, summary_step)

    try:
        sess.run([init_op, reset_metrics_op])

        if rank == 0:
            train_writer = tf.summary.FileWriter(
                environment.TENSORBOARD_DIR + "/train", sess.graph)
            if use_train_validation_saving:
                train_val_saving_writer = tf.summary.FileWriter(
                    environment.TENSORBOARD_DIR + "/train_validation_saving")
            val_writer = tf.summary.FileWriter(
                environment.TENSORBOARD_DIR + "/validation")

            if config.IS_PRETRAIN:
                print("------- Load pretrain data ----------")
                pretrain_saver.restore(
                    sess,
                    os.path.join(config.PRETRAIN_DIR, config.PRETRAIN_FILE))
                sess.run(tf.assign(global_step, 0))

            # for recovery
            ckpt = tf.train.get_checkpoint_state(environment.CHECKPOINTS_DIR)
            if ckpt and ckpt.model_checkpoint_path:
                print("--------- Restore last checkpoint -------------")
                saver.restore(sess, ckpt.model_checkpoint_path)
                last_step = sess.run(global_step)
                # TODO(wakisaka): tensorflow v1.3 remain previous event log in tensorboard.
                # https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/python/training/supervisor.py#L1072
                train_writer.add_session_log(
                    SessionLog(status=SessionLog.START),
                    global_step=last_step + 1)
                val_writer.add_session_log(
                    SessionLog(status=SessionLog.START),
                    global_step=last_step + 1)
                print("recovered. last step", last_step)

        if config.IS_DISTRIBUTION:
            # broadcast variables from rank 0 to all other processes
            sess.run(bcast_global_variables_op)
            # calculate step per epoch for each nodes
            train_num_per_epoch = train_dataset.num_per_epoch
            num_per_nodes = (train_num_per_epoch + num_worker - 1) // num_worker
            # At least 1: a shard smaller than one batch would otherwise make
            # ``step % step_per_epoch`` divide by zero.
            step_per_epoch = max(1, num_per_nodes // config.BATCH_SIZE)
            begin_index = (train_num_per_epoch * rank) // num_worker
            end_index = begin_index + num_per_nodes

        last_step = sess.run(global_step)

        # Calculate max steps. The priority of config.MAX_EPOCHS is higher than config.MAX_STEPS.
        if "MAX_EPOCHS" in config:
            max_steps = int(
                train_dataset.num_per_epoch / config.BATCH_SIZE * config.MAX_EPOCHS)
        else:
            max_steps = config.MAX_STEPS
        print("max_steps: {}".format(max_steps))

        for step in range(last_step, max_steps):
            print("step", step)
            next_step = step + 1  # summaries and checkpoints are 1-based

            if config.IS_DISTRIBUTION and step % step_per_epoch == 0:
                # scatter dataset: rank 0 shuffles, everyone takes its slice
                indices = train_dataset.get_shuffle_index() if rank == 0 else None
                # broadcast shuffled indices
                indices = comm.bcast(indices, 0)
                feed_indices = indices[begin_index:end_index]
                # update each dataset by split indices
                train_dataset.update_dataset(feed_indices)

            images, labels = train_dataset.feed()
            feed_dict = {
                is_training_placeholder: True,
                images_placeholder: images,
                labels_placeholder: labels,
            }

            # Summarise on the very first step and then every SUMMARISE_STEPS.
            summarise = step == 0 or next_step % config.SUMMARISE_STEPS == 0
            if summarise and rank == 0:
                sess.run(reset_metrics_op)
                _, summary, _ = sess.run(
                    [train_op, summary_op, metrics_update_op],
                    feed_dict=feed_dict,
                )
                train_writer.add_summary(summary, next_step)
                _write_metrics_summary(train_writer, next_step)
            else:
                sess.run([train_op], feed_dict=feed_dict)

            to_be_saved = (
                step == 0
                or next_step == max_steps
                or next_step % config.SAVE_STEPS == 0
            )

            if to_be_saved and rank == 0:
                if use_train_validation_saving:
                    _run_evaluation(
                        train_validation_saving_dataset,
                        train_val_saving_writer,
                        next_step,
                        "train_validation_saving",
                    )
                    # NOTE(review): requires the model to expose an "accuracy"
                    # metric — confirm for non-classification networks.
                    current_accuracy = sess.run(metrics_ops_dict["accuracy"])

                    # Save only when the score on the saving subset improves.
                    if current_accuracy > top_train_validation_saving_set_accuracy:
                        top_train_validation_saving_set_accuracy = current_accuracy
                        print("New top train_validation_saving accuracy is: ",
                              top_train_validation_saving_set_accuracy)
                        _save_checkpoint(saver, sess, global_step, step)
                else:
                    _save_checkpoint(saver, sess, global_step, step)

                if step == 0:
                    # check create pb on only first step.
                    minimal_graph = tf.graph_util.convert_variables_to_constants(
                        sess,
                        sess.graph.as_graph_def(add_shapes=True),
                        ["output"],
                    )
                    pb_name = "minimal_graph_with_shape_{}.pb".format(next_step)
                    pbtxt_name = "minimal_graph_with_shape_{}.pbtxt".format(next_step)
                    tf.train.write_graph(
                        minimal_graph, environment.CHECKPOINTS_DIR,
                        pb_name, as_text=False)
                    tf.train.write_graph(
                        minimal_graph, environment.CHECKPOINTS_DIR,
                        pbtxt_name, as_text=True)

            if step == 0 or next_step % config.TEST_STEPS == 0:
                # Runs on every rank; only rank 0 has a writer to log to.
                _run_evaluation(validation_dataset, val_writer, next_step, "test")

        # training loop end.
        print("reach max step")
    finally:
        # Closing the writers flushes pending events to disk, also when
        # training is interrupted or fails.
        for writer in (train_writer, train_val_saving_writer, val_writer):
            if writer is not None:
                writer.close()
        sess.close()