def main(_):
    """Entry point: build the graph, then train / validate / infer per FLAGS.

    Depending on which of ``FLAGS.train_db`` / ``FLAGS.validation_db`` /
    ``FLAGS.inference_db`` / ``FLAGS.visualizeModelPath`` are set, this builds
    the corresponding ``Model`` stages, runs inference and/or the training
    loop (with periodic validation and snapshot saving), and exits the
    process directly via ``exit()`` on both success and failure paths.
    """
    # Always keep the cpu as default device; ops without GPU kernels fall
    # back to CPU via allow_soft_placement below.
    with tf.Graph().as_default(), tf.device('/cpu:0'):

        # A validation interval of 0 disables validation entirely.
        if FLAGS.validation_interval == 0:
            FLAGS.validation_db = None

        # Set Tensorboard log directory
        if FLAGS.summaries_dir:
            # The following gives a nice but unrobust timestamp
            FLAGS.summaries_dir = os.path.join(
                FLAGS.summaries_dir,
                datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))

        # NOTE(review): the condition also accepts visualizeModelPath alone,
        # although the error message only mentions the three db sources.
        if not FLAGS.train_db and not FLAGS.validation_db and not FLAGS.inference_db and not FLAGS.visualizeModelPath:
            logging.error(
                "At least one of the following file sources should be specified: "
                "train_db, validation_db or inference_db")
            exit(-1)

        if FLAGS.seed:
            tf.set_random_seed(FLAGS.seed)

        batch_size_train = FLAGS.batch_size
        batch_size_val = FLAGS.batch_size
        logging.info("Train batch size is %s and validation batch size is %s",
                     batch_size_train, batch_size_val)

        # This variable keeps track of next epoch, when to perform validation.
        next_validation = FLAGS.validation_interval
        logging.info("Training epochs to be completed for each validation : %s",
                     next_validation)

        # This variable keeps track of next epoch, when to save model weights.
        next_snapshot_save = FLAGS.snapshotInterval
        logging.info("Training epochs to be completed before taking a snapshot : %s",
                     next_snapshot_save)
        last_snapshot_save_epoch = 0

        # Default the snapshot prefix to the network file's basename.
        snapshot_prefix = FLAGS.snapshotPrefix if FLAGS.snapshotPrefix else FLAGS.network.split('.')[0]
        logging.info("Model weights will be saved as %s_<EPOCH>_Model.ckpt",
                     snapshot_prefix)

        if not os.path.exists(FLAGS.save):
            os.makedirs(FLAGS.save)
            logging.info("Created a directory %s to save all the snapshots",
                         FLAGS.save)

        # Load mean variable (mean subtraction for input normalization).
        if FLAGS.subtractMean == 'none':
            mean_loader = None
        else:
            if not FLAGS.mean:
                logging.error(
                    "subtractMean parameter not set to 'none' yet mean image path is unset"
                )
                exit(-1)
            logging.info("Loading mean tensor from %s file", FLAGS.mean)
            mean_loader = tf_data.MeanLoader(FLAGS.mean, FLAGS.subtractMean,
                                             FLAGS.bitdepth)

        classes = 0
        nclasses = 0
        if FLAGS.labels_list:
            logging.info("Loading label definitions from %s file",
                         FLAGS.labels_list)
            classes = loadLabels(FLAGS.labels_list)
            nclasses = len(classes)
            if not classes:
                logging.error("Reading labels file %s failed.",
                              FLAGS.labels_list)
                exit(-1)
            logging.info("Found %s classes", nclasses)

        # Create a data-augmentation dict
        aug_dict = {
            'aug_flip': FLAGS.augFlip,
            'aug_noise': FLAGS.augNoise,
            'aug_contrast': FLAGS.augContrast,
            'aug_whitening': FLAGS.augWhitening,
            'aug_HSV': {
                'h': FLAGS.augHSVh,
                's': FLAGS.augHSVs,
                'v': FLAGS.augHSVv,
            },
        }

        # Import the network file. NOTE(review): exec of an arbitrary file is
        # inherently unsafe; acceptable only because the network path comes
        # from the operator's own flags, not untrusted input.
        path_network = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            FLAGS.networkDirectory, FLAGS.network)
        exec(open(path_network).read(), globals())

        # The exec above is expected to define a class named 'UserModel'.
        try:
            UserModel
        except NameError:
            logging.error("The user model class 'UserModel' is not defined.")
            exit(-1)
        if not inspect.isclass(UserModel):  # noqa
            logging.error("The user model class 'UserModel' is not a class.")
            exit(-1)
        # @TODO(tzaman) - add mode checks to UserModel

        if FLAGS.train_db:
            with tf.name_scope(digits.STAGE_TRAIN) as stage_scope:
                train_model = Model(digits.STAGE_TRAIN, FLAGS.croplen, nclasses,
                                    FLAGS.optimization, FLAGS.momentum)
                train_model.create_dataloader(FLAGS.train_db)
                train_model.dataloader.setup(FLAGS.train_labels,
                                             FLAGS.shuffle,
                                             FLAGS.bitdepth,
                                             batch_size_train,
                                             FLAGS.epoch,
                                             FLAGS.seed)
                train_model.dataloader.set_augmentation(mean_loader, aug_dict)
                train_model.create_model(UserModel, stage_scope)  # noqa

        if FLAGS.validation_db:
            with tf.name_scope(digits.STAGE_VAL) as stage_scope:
                val_model = Model(digits.STAGE_VAL, FLAGS.croplen, nclasses)
                val_model.create_dataloader(FLAGS.validation_db)
                # Validation data is never shuffled; 1e9 epochs so the loader
                # never runs out. @TODO(tzaman): set numepochs to 1
                val_model.dataloader.setup(FLAGS.validation_labels,
                                           False,
                                           FLAGS.bitdepth,
                                           batch_size_val,
                                           1e9,
                                           FLAGS.seed)
                val_model.dataloader.set_augmentation(mean_loader)
                val_model.create_model(UserModel, stage_scope)  # noqa

        if FLAGS.inference_db:
            with tf.name_scope(digits.STAGE_INF) as stage_scope:
                inf_model = Model(digits.STAGE_INF, FLAGS.croplen, nclasses)
                inf_model.create_dataloader(FLAGS.inference_db)
                inf_model.dataloader.setup(None, False, FLAGS.bitdepth,
                                           FLAGS.batch_size, 1, FLAGS.seed)
                inf_model.dataloader.set_augmentation(mean_loader)
                inf_model.create_model(UserModel, stage_scope)  # noqa

        # Start running operations on the Graph. allow_soft_placement must be
        # set to True to build towers on GPU, as some of the ops do not have
        # GPU implementations.
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,  # will automatically do non-gpu supported ops on cpu
            inter_op_parallelism_threads=TF_INTER_OP_THREADS,
            intra_op_parallelism_threads=TF_INTRA_OP_THREADS,
            log_device_placement=FLAGS.log_device_placement))

        if FLAGS.visualizeModelPath:
            visualize_graph(sess.graph_def, FLAGS.visualizeModelPath)
            exit(0)

        # Saver creation.
        if FLAGS.save_vars == 'all':
            vars_to_save = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        elif FLAGS.save_vars == 'trainable':
            # FIX: was tf.all_variables(), which saves *all* variables and
            # made this branch identical in effect to 'all'.
            vars_to_save = tf.trainable_variables()
        else:
            logging.error('Unknown save_var flag (%s)' % FLAGS.save_vars)
            exit(-1)
        saver = tf.train.Saver(vars_to_save,
                               max_to_keep=0,
                               sharded=FLAGS.serving_export)

        # Initialize variables
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        # If weights option is set, preload weights from existing models appropriately
        if FLAGS.weights:
            load_snapshot(sess, FLAGS.weights,
                          tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        # Tensorboard: Merge all the summaries and write them out
        writer = tf.train.SummaryWriter(
            os.path.join(FLAGS.summaries_dir, 'tb'), sess.graph)

        # If we are inferencing, only do that.
        if FLAGS.inference_db:
            inf_model.start_queue_runners(sess)
            Inference(sess, inf_model)

        # Collect queue-size tensors (development aid, see commented log below).
        queue_size_op = []
        for n in tf.get_default_graph().as_graph_def().node:
            if '_Size' in n.name:
                queue_size_op.append(n.name + ':0')

        start = time.time()  # @TODO(tzaman) - removeme

        # FIX: initialize before the train block so the final-validation check
        # below cannot raise NameError when train_db is unset.
        current_epoch = 0

        # Initial Forward Validation Pass
        if FLAGS.validation_db:
            val_model.start_queue_runners(sess)
            Validation(sess, val_model, 0)

        if FLAGS.train_db:
            # During training, a log output should occur at least X times per
            # epoch or every X images, whichever lower
            train_steps_per_epoch = train_model.dataloader.get_total() / batch_size_train
            if math.ceil(train_steps_per_epoch / MIN_LOGS_PER_TRAIN_EPOCH) < math.ceil(5000 / batch_size_train):
                logging_interval_step = int(
                    math.ceil(train_steps_per_epoch / MIN_LOGS_PER_TRAIN_EPOCH))
            else:
                logging_interval_step = int(math.ceil(5000 / batch_size_train))
            logging.info(
                "During training. details will be logged after every %s steps (batches)",
                logging_interval_step)

            # epoch value will be calculated for every batch size. To maintain
            # unique epoch value between batches, it needs to be rounded to
            # the required number of significant digits.
            epoch_round = 0  # holds the required number of significant digits for round function.
            tmp_batchsize = batch_size_train * logging_interval_step
            while tmp_batchsize <= train_model.dataloader.get_total():
                tmp_batchsize = tmp_batchsize * 10
                epoch_round += 1
            logging.info(
                "While logging, epoch value will be rounded to %s significant digits",
                epoch_round)

            # Create the learning rate policy
            total_training_steps = train_model.dataloader.num_epochs * train_model.dataloader.get_total() / \
                train_model.dataloader.batch_size
            lrpolicy = lr_policy.LRPolicy(FLAGS.lr_policy,
                                          FLAGS.lr_base_rate,
                                          FLAGS.lr_gamma,
                                          FLAGS.lr_power,
                                          total_training_steps,
                                          FLAGS.lr_stepvalues)
            train_model.start_queue_runners(sess)

            # Training
            logging.info('Started training the model')

            try:
                step = 0
                step_last_log = 0
                print_vals_sum = 0
                while not train_model.queue_coord.should_stop():
                    log_runtime = FLAGS.log_runtime_stats_per_step and (
                        step % FLAGS.log_runtime_stats_per_step == 0)

                    run_options = None
                    run_metadata = None
                    if log_runtime:
                        # For a HARDWARE_TRACE you need NVIDIA CUPTI, a 'CUDA-EXTRA'
                        # SOFTWARE_TRACE HARDWARE_TRACE FULL_TRACE
                        run_options = tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE)
                        run_metadata = tf.RunMetadata()

                    feed_dict = {
                        train_model.learning_rate: lrpolicy.get_learning_rate(step)
                    }

                    # FIX: removed a dead `if False:` branch that ran each
                    # train op individually; only this fused run was reachable.
                    _, summary_str, step = sess.run(
                        [
                            train_model.train,
                            train_model.summary,
                            train_model.global_step
                        ],
                        feed_dict=feed_dict,
                        options=run_options,
                        run_metadata=run_metadata)

                    # HACK: global_step is incremented once per tower/op in
                    # train_model.train, so normalize back to batch steps.
                    step = step / len(train_model.train)

                    # logging.info(sess.run(queue_size_op)) # DEVELOPMENT: for checking the queue size

                    if log_runtime:
                        writer.add_run_metadata(run_metadata, str(step))
                        save_timeline_trace(run_metadata, FLAGS.save, int(step))

                    writer.add_summary(summary_str, step)

                    # Parse the summary
                    tags, print_vals = summary_to_lists(summary_str)

                    print_vals_sum = print_vals + print_vals_sum

                    # @TODO(tzaman): account for variable batch_size value on very last epoch
                    current_epoch = round(
                        (step * batch_size_train) / train_model.dataloader.get_total(),
                        epoch_round)

                    # Start with a forward pass
                    if ((step % logging_interval_step) == 0):
                        steps_since_log = step - step_last_log
                        print_list = print_summarylist(
                            tags, print_vals_sum / steps_since_log)
                        logging.info("Training (epoch " + str(current_epoch) +
                                     "): " + print_list)
                        print_vals_sum = 0
                        step_last_log = step

                    # Potential Validation Pass
                    if FLAGS.validation_db and current_epoch >= next_validation:
                        Validation(sess, val_model, current_epoch)
                        # Find next nearest epoch value that exactly divisible
                        # by FLAGS.validation_interval:
                        next_validation = (round(float(current_epoch) / FLAGS.validation_interval) + 1) * \
                            FLAGS.validation_interval

                    # Saving Snapshot
                    if FLAGS.snapshotInterval > 0 and current_epoch >= next_snapshot_save:
                        save_snapshot(sess, saver, FLAGS.save, snapshot_prefix,
                                      current_epoch, FLAGS.serving_export)
                        # To find next nearest epoch value that exactly
                        # divisible by FLAGS.snapshotInterval
                        next_snapshot_save = (round(float(current_epoch) / FLAGS.snapshotInterval) + 1) * \
                            FLAGS.snapshotInterval
                        last_snapshot_save_epoch = current_epoch
                    writer.flush()
            except tf.errors.OutOfRangeError:
                # Normal termination: the input queues ran out of epochs.
                logging.info('Done training for epochs: tf.errors.OutOfRangeError')
            except ValueError as err:
                logging.error(err.args[0])
                exit(-1)  # DIGITS wants a dirty error.
            except (KeyboardInterrupt):
                logging.info('Interrupt signal received.')

            # If required, perform final snapshot save
            if FLAGS.snapshotInterval > 0 and FLAGS.epoch > last_snapshot_save_epoch:
                save_snapshot(sess, saver, FLAGS.save, snapshot_prefix,
                              FLAGS.epoch, FLAGS.serving_export)

        print('Training wall-time:', time.time() - start)  # @TODO(tzaman) - removeme

        # If required, perform final Validation pass
        if FLAGS.validation_db and current_epoch >= next_validation:
            Validation(sess, val_model, current_epoch)

        # Deleting the models tears down their queue coordinators/runners.
        if FLAGS.train_db:
            del train_model
        if FLAGS.validation_db:
            del val_model
        if FLAGS.inference_db:
            del inf_model

        # We need to call sess.close() because we've used a with block
        sess.close()

        writer.close()
        logging.info('END')
        exit(0)
def main(_):
    """Entry point (variant with NaN/Inf numeric checks and GPU mem growth).

    Builds train/validation/inference ``Model`` stages per FLAGS, attaches a
    ``tf.add_check_numerics_ops`` op to every training and validation run,
    then trains with periodic validation and snapshot saving. Exits the
    process directly via ``exit()`` on both success and failure paths.
    """
    # Always keep the cpu as default device; ops without GPU kernels fall
    # back to CPU via allow_soft_placement below.
    with tf.Graph().as_default(), tf.device('/cpu:0'):

        # A validation interval of 0 disables validation entirely.
        if FLAGS.validation_interval == 0:
            FLAGS.validation_db = None

        # Set Tensorboard log directory
        if FLAGS.summaries_dir:
            # The following gives a nice but unrobust timestamp
            FLAGS.summaries_dir = os.path.join(
                FLAGS.summaries_dir,
                datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))

        if not FLAGS.train_db and not FLAGS.validation_db and not FLAGS.inference_db:
            logging.error(
                "At least one of the following file sources should be specified: "
                "train_db, validation_db or inference_db")
            exit(-1)

        if FLAGS.seed:
            tf.set_random_seed(FLAGS.seed)

        batch_size_train = FLAGS.batch_size
        batch_size_val = FLAGS.batch_size
        logging.info("Train batch size is %s and validation batch size is %s",
                     batch_size_train, batch_size_val)

        # This variable keeps track of next epoch, when to perform validation.
        next_validation = FLAGS.validation_interval
        logging.info("Training epochs to be completed for each validation : %s",
                     next_validation)

        # This variable keeps track of next epoch, when to save model weights.
        next_snapshot_save = FLAGS.snapshotInterval
        logging.info("Training epochs to be completed before taking a snapshot : %s",
                     next_snapshot_save)
        last_snapshot_save_epoch = 0

        # Default the snapshot prefix to the network file's basename.
        snapshot_prefix = FLAGS.snapshotPrefix if FLAGS.snapshotPrefix else FLAGS.network.split('.')[0]
        # FIX: log message typo "saves as" -> "saved as".
        logging.info(
            'Model weights will be saved as {}_<EPOCH>_Model.ckpt'.format(
                snapshot_prefix))

        if not os.path.exists(FLAGS.save):
            os.makedirs(FLAGS.save)
            logging.info("Created a directory %s to save all the snapshots",
                         FLAGS.save)

        classes = 0
        nclasses = 0
        if FLAGS.labels_list:
            logging.info("Loading label definitions from %s file",
                         FLAGS.labels_list)
            classes = loadLabels(FLAGS.labels_list)
            nclasses = len(classes)
            if not classes:
                logging.error("Reading labels file %s failed.",
                              FLAGS.labels_list)
                exit(-1)
            logging.info("Found %s classes", nclasses)

        # Debugging NaNs and Inf: this op asserts all graph tensors are finite
        # and is run alongside every train/validation step below.
        check_op = tf.add_check_numerics_ops()

        # Import the network file. NOTE(review): exec of an arbitrary file is
        # inherently unsafe; acceptable only because the network path comes
        # from the operator's own flags, not untrusted input.
        path_network = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            FLAGS.networkDirectory, FLAGS.network)
        exec(open(path_network).read(), globals())

        # The exec above is expected to define a class named 'UserModel'.
        try:
            UserModel
        except NameError:
            logging.error("The user model class 'UserModel' is not defined.")
            exit(-1)
        if not inspect.isclass(UserModel):
            logging.error("The user model class 'UserModel' is not a class.")
            # FIX: was `eixt(-1)` (NameError); the error path now exits cleanly.
            exit(-1)

        if FLAGS.train_db:
            with tf.name_scope(utils.STAGE_TRAIN) as stage_scope:
                train_model = Model(utils.STAGE_TRAIN, FLAGS.croplen, nclasses,
                                    FLAGS.optimization, FLAGS.momentum)
                train_model.create_dataloader(FLAGS.train_db)
                train_model.dataloader.setup(FLAGS.train_labels,
                                             FLAGS.shuffle,
                                             FLAGS.bitdepth,
                                             batch_size_train,
                                             FLAGS.epoch,
                                             FLAGS.seed)
                train_model.create_model(UserModel, stage_scope)

        if FLAGS.validation_db:
            with tf.name_scope(utils.STAGE_VAL) as stage_scope:
                # reuse_variable=True: share weights with the training model.
                val_model = Model(utils.STAGE_VAL, FLAGS.croplen, nclasses,
                                  reuse_variable=True)
                val_model.create_dataloader(FLAGS.validation_db)
                # Validation data is never shuffled; 1e9 epochs so the loader
                # never runs out.
                val_model.dataloader.setup(FLAGS.validation_labels,
                                           False,
                                           FLAGS.bitdepth,
                                           batch_size_val,
                                           1e9,
                                           FLAGS.seed)
                val_model.create_model(UserModel, stage_scope)

        if FLAGS.inference_db:
            with tf.name_scope(utils.STAGE_INF) as stage_scope:
                inf_model = Model(utils.STAGE_INF, FLAGS.croplen, nclasses)
                inf_model.create_dataloader(FLAGS.inference_db)
                inf_model.dataloader.setup(None, False, FLAGS.bitdepth,
                                           FLAGS.batch_size, 1, FLAGS.seed)
                inf_model.create_model(UserModel, stage_scope)

        # Start running operations on the Graph. allow_soft_placement must be
        # set to True to build towers on GPU, as some of the ops do not have
        # GPU implementations.
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,  # will automatically do non-gpu supported ops on cpu
            inter_op_parallelism_threads=TF_INTER_OP_THREADS,
            intra_op_parallelism_threads=TF_INTRA_OP_THREADS,
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(allow_growth=True)))

        # Saver creation.
        if FLAGS.save_vars == 'all':
            vars_to_save = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        elif FLAGS.save_vars == 'trainable':
            vars_to_save = tf.trainable_variables()
        else:
            logging.error('Unknown save_var flag ({})'.format(FLAGS.save_vars))
            exit(-1)
        saver = tf.train.Saver(vars_to_save,
                               max_to_keep=0,
                               sharded=FLAGS.serving_export)

        # Initialize variables
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        # If weights option is set, preload weights from existing models appropriately
        if FLAGS.weights:
            load_snapshot(sess, FLAGS.weights,
                          tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

        # Tensorboard: Merge all the summaries and write them out
        writer = ops.SummaryWriter(os.path.join(FLAGS.summaries_dir, 'tb'),
                                   sess.graph)

        # If we are inferencing, only do that
        if FLAGS.inference_db:
            inf_model.start_queue_runners(sess)
            Inference(sess, inf_model)

        # Collect queue-size tensors (development aid).
        queue_size_op = []
        for n in tf.get_default_graph().as_graph_def().node:
            if '_Size' in n.name:
                logging.debug('graph node name: {}'.format(n.name))
                queue_size_op.append(n.name + ':0')

        start = time.time()

        # FIX: initialize before the train block so the final-validation check
        # below cannot raise NameError when train_db is unset.
        current_epoch = 0

        # Initial forward validation pass
        if FLAGS.validation_db:
            val_model.start_queue_runners(sess)
            Validation(sess, val_model, 0, check_op)

        if FLAGS.train_db:
            # During training, a log output should occur at least X times per
            # epoch or every X images, whichever lower
            train_steps_per_epoch = train_model.dataloader.get_total() / batch_size_train
            if math.ceil(train_steps_per_epoch / MIN_LOGS_PER_TRAIN_EPOCH) < math.ceil(5000 / batch_size_train):
                logging_interval_step = int(
                    math.ceil(train_steps_per_epoch / MIN_LOGS_PER_TRAIN_EPOCH))
            else:
                logging_interval_step = int(math.ceil(5000 / batch_size_train))
            logging.info(
                "During training. details will be logged after every %s steps (batches)",
                logging_interval_step)

            # epoch value will be calculated for every batch size. To maintain
            # unique epoch value between batches, it needs to be rounded to
            # the required number of significant digits.
            epoch_round = 0  # holds the required number of significant digits for round function.
            tmp_batchsize = batch_size_train * logging_interval_step
            while tmp_batchsize <= train_model.dataloader.get_total():
                tmp_batchsize = tmp_batchsize * 10
                epoch_round += 1
            logging.info(
                "While logging, epoch value will be rounded to %s significant digits",
                epoch_round)

            # Create the learning rate policy
            total_training_steps = train_model.dataloader.num_epochs * train_model.dataloader.get_total() / \
                train_model.dataloader.batch_size
            lrpolicy = lr_policy.LRPolicy(FLAGS.lr_policy,
                                          FLAGS.lr_base_rate,
                                          FLAGS.lr_gamma,
                                          FLAGS.lr_power,
                                          total_training_steps,
                                          FLAGS.lr_stepvalues)
            train_model.start_queue_runners(sess)

            # Training
            logging.info('Started training the model')

            try:
                step = 0
                step_last_log = 0
                print_vals_sum = 0
                while not train_model.queue_coord.should_stop():
                    # log runtime for benchmark
                    log_runtime = FLAGS.log_runtime_stats_per_step and (
                        step % FLAGS.log_runtime_stats_per_step == 0)

                    run_options = None
                    run_metadata = None
                    if log_runtime:
                        # For a HARDWARE_TRACE you need NVIDIA CUPTI, a 'CUDA-EXTRA'
                        # SOFTWARE_TRACE HARDWARE_TRACE FULL_TRACE
                        run_options = tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE)
                        run_metadata = tf.RunMetadata()

                    feed_dict = {
                        train_model.learning_rate: lrpolicy.get_learning_rate(step)
                    }

                    # Session.run — check_op runs alongside training to abort
                    # on any NaN/Inf tensor.
                    _, summary_str, step, _ = sess.run(
                        [
                            train_model.train,
                            train_model.summary,
                            train_model.global_step,
                            check_op
                        ],
                        feed_dict=feed_dict,
                        options=run_options,
                        run_metadata=run_metadata)

                    # global_step is incremented once per tower/op in
                    # train_model.train, so normalize back to batch steps.
                    step = step / len(train_model.train)

                    # 1. display the exact total memory, compute time, and
                    # tensor output sizes to TensorBoard
                    if log_runtime:
                        writer.add_run_metadata(run_metadata, str(step))

                    # 2. another method to trace the timeline of operations.
                    # Open Google Chrome, go to the page chrome://tracing and
                    # load the timeline.json file.
                    if log_runtime:
                        # Create the Timeline object, and write it to json.
                        # FIX: class name was misspelled `timeline.TimeLine`,
                        # raising AttributeError on every traced step.
                        tl = timeline.Timeline(run_metadata.step_stats)
                        ctf = tl.generate_chrome_trace_format()
                        prof_path = os.path.join(FLAGS.summaries_dir, 'benchmark')
                        # FIX: ensure the benchmark directory exists before
                        # writing, otherwise open() fails on the first trace.
                        if not os.path.isdir(prof_path):
                            os.makedirs(prof_path)
                        with open(os.path.join(prof_path, 'timeline.json'), 'w') as f:
                            f.write(ctf)

                    # summary for TensorBoard
                    writer.add_summary(summary_str, step)

                    current_epoch = round(
                        (step * batch_size_train) / train_model.dataloader.get_total(),
                        epoch_round)

                    # Potential Validation Pass
                    if FLAGS.validation_db and current_epoch >= next_validation:
                        Validation(sess, val_model, current_epoch, check_op)
                        # Find next nearest epoch value that exactly divisible
                        # by FLAGS.validation_interval:
                        next_validation = (round(float(current_epoch) / FLAGS.validation_interval) + 1) * \
                            FLAGS.validation_interval

                    # Saving Snapshot
                    if FLAGS.snapshotInterval > 0 and current_epoch >= next_snapshot_save:
                        save_snapshot(sess, saver, FLAGS.save, snapshot_prefix,
                                      current_epoch, FLAGS.serving_export)
                        # To find next nearest epoch value that exactly
                        # divisible by FLAGS.snapshotInterval
                        next_snapshot_save = (round(float(current_epoch) / FLAGS.snapshotInterval) + 1) * \
                            FLAGS.snapshotInterval
                        last_snapshot_save_epoch = current_epoch
                    writer.flush()
            except tf.errors.OutOfRangeError:
                # Normal termination: the input queues ran out of epochs.
                logging.info(
                    'Done training for epochs limit reached: tf.errors.OutOfRangeError'
                )
            except ValueError as err:
                logging.error(err.args[0])
                exit(-1)
            except (KeyboardInterrupt):
                logging.info('Interrupt signal received.')

            # If required, perform final snapshot save
            if FLAGS.snapshotInterval > 0 and FLAGS.epoch > last_snapshot_save_epoch:
                save_snapshot(sess, saver, FLAGS.save, snapshot_prefix,
                              FLAGS.epoch, FLAGS.serving_export)

        print('Training wall-time:', time.time() - start)

        # If required, perform final Validation pass
        if FLAGS.validation_db and current_epoch >= next_validation:
            Validation(sess, val_model, current_epoch, check_op)

        # Close and terminate the queues by tearing the models down.
        if FLAGS.train_db:
            del train_model
        if FLAGS.validation_db:
            del val_model
        if FLAGS.inference_db:
            del inf_model

        # We need to call sess.close() because we've used a with block
        sess.close()

        writer.close()
        logging.info('END')
        exit(0)