def _train_deeplab_model(iterator, num_of_classes, ignore_label):
    """Trains the deeplab model.

    Args:
        iterator: An iterator of type tf.data.Iterator for images and labels.
        num_of_classes: Number of classes for the dataset.
        ignore_label: Ignore label for the dataset.

    Returns:
        train_tensor: A tensor to update the model variables.
        summary_op: An operation to log the summaries.
    """
    global_step = tf.train.get_or_create_global_step()

    learning_rate = train_utils.get_model_learning_rate(
        FLAGS.learning_policy, FLAGS.base_learning_rate,
        FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
        FLAGS.training_number_of_steps, FLAGS.learning_power,
        FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)

    tower_losses = []
    tower_grads = []
    for i in range(FLAGS.num_clones):
        with tf.device('/gpu:%d' % i):
            # First tower has default name scope.
            name_scope = ('clone_%d' % i) if i else ''
            with tf.name_scope(name_scope) as scope:
                loss = _tower_loss(iterator=iterator,
                                   num_of_classes=num_of_classes,
                                   ignore_label=ignore_label,
                                   scope=scope,
                                   reuse_variable=(i != 0))
                tower_losses.append(loss)

    if FLAGS.quantize_delay_step >= 0:
        if FLAGS.num_clones > 1:
            raise ValueError('Quantization doesn\'t support multi-clone yet.')
        tf.contrib.quantize.create_training_graph(
            quant_delay=FLAGS.quantize_delay_step)

    for i in range(FLAGS.num_clones):
        with tf.device('/gpu:%d' % i):
            name_scope = ('clone_%d' % i) if i else ''
            with tf.name_scope(name_scope) as scope:
                grads = optimizer.compute_gradients(tower_losses[i])
                tower_grads.append(grads)

    with tf.device('/cpu:0'):
        grads_and_vars = _average_gradients(tower_grads)

        # Modify the gradients for biases and last layer variables.
        last_layers = model.get_extra_layer_scopes(
            FLAGS.last_layers_contain_logits_only)
        grad_mult = train_utils.get_model_gradient_multipliers(
            last_layers, FLAGS.last_layer_gradient_multiplier)
        if grad_mult:
            grads_and_vars = tf.contrib.training.multiply_gradients(
                grads_and_vars, grad_mult)

        # Create gradient update op.
        grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                 global_step=global_step)

        # Gather update_ops. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)

        total_loss = tf.losses.get_total_loss(add_regularization_losses=True)

        # Print total loss to the terminal.
        # This implementation is mirrored from tf.slim.summaries.
        should_log = math_ops.equal(
            math_ops.mod(global_step, FLAGS.log_steps), 0)
        total_loss = tf.cond(
            should_log,
            lambda: tf.Print(total_loss, [total_loss], 'Total loss is :'),
            lambda: total_loss)
        tf.summary.scalar('total_loss', total_loss)

        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Excludes summaries from towers other than the first one.
        summary_op = tf.summary.merge_all(scope='(?!clone_)')

    return train_tensor, summary_op
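# --------------------------------------------------------------------------
# `_average_gradients` is called above (and again in `_train_pgn_model`
# further down) but is not defined in this section. A minimal sketch follows,
# assuming each tower contributes the same (gradient, variable) list layout
# that `optimizer.compute_gradients` returns and that no gradient is None.
# It mirrors the classic TensorFlow multi-GPU CIFAR-10 example and is
# illustrative, not necessarily the original implementation.
def _average_gradients(tower_grads):
    """Averages gradients over towers.

    Args:
        tower_grads: List over towers; each entry is a list of
            (gradient, variable) pairs for that tower.

    Returns:
        One list of (gradient, variable) pairs, averaged across towers.
    """
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # grad_and_vars is ((grad0_gpu0, var0), ..., (grad0_gpuN, var0)).
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        # Variables are shared across towers, so the first tower's suffices.
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads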
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    labels = FLAGS.labels.split(',')
    num_classes = len(labels)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Creating train logdir: %s', FLAGS.train_logdir)

    with tf.Graph().as_default() as graph:
        global_step = tf.train.get_or_create_global_step()

        X = tf.placeholder(tf.float32,
                           [None, FLAGS.height, FLAGS.width, 3],
                           name='X')
        ground_truth = tf.placeholder(tf.int64, [None], name='ground_truth')
        is_training = tf.placeholder(tf.bool, name='is_training')
        keep_prob = tf.placeholder(tf.float32, [], name='keep_prob')
        # learning_rate = tf.placeholder(tf.float32, [])

        # Apply SENet.
        logits, end_points = model.hcd_model(X,
                                             num_classes=num_classes,
                                             is_training=is_training,
                                             keep_prob=keep_prob,
                                             attention_module='se_block')
        # At validation time, average the logits over the TEN_CROP
        # test-time-augmentation crops of each image.
        logits = tf.cond(
            is_training,
            lambda: tf.identity(logits),
            lambda: tf.reduce_mean(
                tf.reshape(logits, [FLAGS.val_batch_size, TEN_CROP, -1]),
                axis=1))

        # Print name and shape of each tensor.
        tf.logging.info("++++++++++++++++++++++++++++++++++")
        tf.logging.info("Layers")
        tf.logging.info("++++++++++++++++++++++++++++++++++")
        for k, v in end_points.items():
            tf.logging.info('name = %s, shape = %s' % (v.name, v.get_shape()))

        # # Print name and shape of parameter nodes (values not yet initialized).
        # tf.logging.info("++++++++++++++++++++++++++++++++++")
        # tf.logging.info("Parameters")
        # tf.logging.info("++++++++++++++++++++++++++++++++++")
        # for v in slim.get_model_variables():
        #     tf.logging.info('name = %s, shape = %s' % (v.name, v.get_shape()))

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        prediction = tf.argmax(logits, axis=1, name='prediction')
        correct_prediction = tf.equal(prediction, ground_truth)
        confusion_matrix = tf.confusion_matrix(ground_truth,
                                               prediction,
                                               num_classes=num_classes)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32),
                                  name='accuracy')
        summaries.add(tf.summary.scalar('accuracy', accuracy))

        # Define loss.
        tf.losses.sparse_softmax_cross_entropy(labels=ground_truth,
                                               logits=logits)

        # Gather update_ops. These contain, for example,
        # the updates for the batch_norm variables created by model.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        # # Add summaries for model variables.
        # for model_var in slim.get_model_variables():
        #     summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        learning_rate = train_utils.get_model_learning_rate(
            FLAGS.learning_policy, FLAGS.base_learning_rate,
            FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
            FLAGS.training_number_of_steps, FLAGS.learning_power,
            FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
        # optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
        optimizer = tf.train.AdamOptimizer(learning_rate)
        summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        total_loss, grads_and_vars = train_utils.optimize(optimizer)
        total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # # Modify the gradients for biases and last layer variables.
        # last_layers = train_utils.get_extra_layer_scopes(
        #     FLAGS.last_layers_contain_logits_only)
        # grad_mult = train_utils.get_model_gradient_multipliers(
        #     last_layers, FLAGS.last_layer_gradient_multiplier)
        # if grad_mult:
        #     grads_and_vars = slim.learning.multiply_gradients(
        #         grads_and_vars, grad_mult)

        # Gradient clipping, by value:
        # clipped_gvs = [(tf.clip_by_value(grad, -1., 1.), var)
        #                for grad, var in grads_and_vars]
        # or by global norm:
        # gradients, variables = zip(*grads_and_vars)
        # gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
        # grads_and_vars = list(zip(gradients, variables))

        # TensorBoard: plot a histogram for each gradient.
        grad_summ_op = tf.summary.merge([
            tf.summary.histogram("%s-grad" % g[1].name, g[0])
            for g in grads_and_vars
        ])

        # Create gradient update op.
        grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_op = tf.identity(total_loss, name='train_op')

        # Add the summaries. These contain the summaries
        # created by model and either optimize() or _gather_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        train_writer = tf.summary.FileWriter(FLAGS.summaries_dir, graph)
        validation_writer = tf.summary.FileWriter(
            FLAGS.summaries_dir + '/validation', graph)

        ###############
        # Prepare data
        ###############
        # Training dataset.
        tfrecord_filenames = tf.placeholder(tf.string, shape=[])
        dataset = train_data.Dataset(tfrecord_filenames,
                                     FLAGS.batch_size,
                                     FLAGS.how_many_training_epochs,
                                     FLAGS.height,
                                     FLAGS.width)
        iterator = dataset.dataset.make_initializable_iterator()
        next_batch = iterator.get_next()

        # Validation dataset.
        val_dataset = val_data.Dataset(tfrecord_filenames,
                                       FLAGS.val_batch_size,
                                       FLAGS.height,
                                       FLAGS.width)
        val_iterator = val_dataset.dataset.make_initializable_iterator()
        val_next_batch = val_iterator.get_next()

        sess_config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(allow_growth=True))
        with tf.Session(config=sess_config) as sess:
            sess.run(tf.global_variables_initializer())

            # Create a saver object which will save all the variables.
            saver = tf.train.Saver()
            if FLAGS.saved_checkpoint_dir:
                if tf.gfile.IsDirectory(FLAGS.train_logdir):
                    checkpoint_path = tf.train.latest_checkpoint(
                        FLAGS.train_logdir)
                else:
                    checkpoint_path = FLAGS.train_logdir
                saver.restore(sess, checkpoint_path)

            if FLAGS.pre_trained_checkpoint:
                train_utils.restore_fn(FLAGS)

            start_epoch = 0
            # Get the number of training/validation steps per epoch.
            tr_batches = int(PCAM_TRAIN_DATA_SIZE / FLAGS.batch_size)
            if PCAM_TRAIN_DATA_SIZE % FLAGS.batch_size > 0:
                tr_batches += 1
            val_batches = int(PCAM_VALIDATE_DATA_SIZE / FLAGS.val_batch_size)
            if PCAM_VALIDATE_DATA_SIZE % FLAGS.val_batch_size > 0:
                val_batches += 1

            # The filenames argument to the TFRecordDataset initializer can
            # either be a string, a list of strings, or a tf.Tensor of strings.
            train_record_filenames = os.path.join(FLAGS.dataset_dir,
                                                  'train.record')
            validate_record_filenames = os.path.join(FLAGS.dataset_dir,
                                                     'validate.record')

            ############################
            # Training loop.
            ############################
            for num_epoch in range(start_epoch,
                                   FLAGS.how_many_training_epochs):
                print("------------------------------------")
                print(" Epoch {} ".format(num_epoch))
                print("------------------------------------")

                sess.run(iterator.initializer,
                         feed_dict={
                             tfrecord_filenames: train_record_filenames
                         })
                for step in range(tr_batches):
                    train_batch_xs, train_batch_ys = sess.run(next_batch)

                    # # Verify image:
                    # # assert not np.any(np.isnan(train_batch_xs))
                    # n_batch = train_batch_xs.shape[0]
                    # for i in range(n_batch):
                    #     img = train_batch_xs[i]
                    #     # scipy.misc.toimage(img).show(), or:
                    #     img = cv2.cvtColor(img.astype(np.uint8),
                    #                        cv2.COLOR_BGR2RGB)
                    #     cv2.imwrite('/home/ace19/Pictures/' + str(i) + '.png',
                    #                 img)
                    #     # cv2.imshow(str(train_batch_ys[idx]), img)
                    #     cv2.waitKey(100)
                    #     cv2.destroyAllWindows()

                    augmented_batch_xs = aug_utils.aug(train_batch_xs)
                    # (The same verification snippet can be applied to
                    # augmented_batch_xs.)

                    # Run the graph with this batch of training data and
                    # learning rate policy.
                    lr, train_summary, train_accuracy, train_loss, grad_vals, _ = \
                        sess.run([learning_rate, summary_op, accuracy,
                                  total_loss, grad_summ_op, train_op],
                                 feed_dict={
                                     X: augmented_batch_xs,
                                     ground_truth: train_batch_ys,
                                     is_training: True,
                                     keep_prob: 0.8
                                 })
                    train_writer.add_summary(train_summary, num_epoch)
                    train_writer.add_summary(grad_vals, num_epoch)
                    tf.logging.info(
                        'Epoch #%d, Step #%d, rate %.10f, accuracy %.1f%%, loss %f'
                        % (num_epoch, step, lr, train_accuracy * 100,
                           train_loss))

                ###################################################
                # Validate the model on the validation set
                ###################################################
                tf.logging.info('--------------------------')
                tf.logging.info(' Start validation ')
                tf.logging.info('--------------------------')

                total_val_accuracy = 0
                validation_count = 0
                total_conf_matrix = None
                # Reinitialize the iterator with the validation dataset.
                sess.run(val_iterator.initializer,
                         feed_dict={
                             tfrecord_filenames: validate_record_filenames
                         })
                for step in range(val_batches):
                    validation_batch_xs, validation_batch_ys = sess.run(
                        val_next_batch)
                    # TTA: fuse the batch and crop dimensions (input is NHWC).
                    batch_size, n_crops, h, w, c = validation_batch_xs.shape
                    tencrop_val_batch_xs = np.reshape(validation_batch_xs,
                                                      (-1, h, w, c))

                    val_summary, val_accuracy, conf_matrix = sess.run(
                        [summary_op, accuracy, confusion_matrix],
                        feed_dict={
                            X: tencrop_val_batch_xs,
                            ground_truth: validation_batch_ys,
                            is_training: False,
                            keep_prob: 1.0
                        })
                    validation_writer.add_summary(val_summary, num_epoch)

                    total_val_accuracy += val_accuracy
                    validation_count += 1
                    if total_conf_matrix is None:
                        total_conf_matrix = conf_matrix
                    else:
                        total_conf_matrix += conf_matrix

                total_val_accuracy /= validation_count
                tf.logging.info('Confusion Matrix:\n %s' % total_conf_matrix)
                tf.logging.info(
                    'Validation accuracy = %.1f%% (N=%d)' %
                    (total_val_accuracy * 100, PCAM_VALIDATE_DATA_SIZE))

                # Save the model checkpoint periodically.
                if num_epoch <= FLAGS.how_many_training_epochs - 1:
                    checkpoint_path = os.path.join(FLAGS.train_logdir,
                                                   FLAGS.ckpt_name_to_save)
                    tf.logging.info('Saving to "%s-%d"', checkpoint_path,
                                    num_epoch)
                    saver.save(sess, checkpoint_path, global_step=num_epoch)
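# --------------------------------------------------------------------------
# Several scripts here call `train_utils.optimize(optimizer)` but the helper
# itself is not shown. A minimal sketch under the assumption that it gathers
# the total loss from the standard losses collection (including
# regularization) and pairs it with the optimizer's gradients; the real
# helper may differ:
def optimize(optimizer):
    """Returns (total_loss, grads_and_vars) for the current default graph."""
    total_loss = tf.losses.get_total_loss(add_regularization_losses=True)
    grads_and_vars = optimizer.compute_gradients(total_loss)
    return total_loss, grads_and_vars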
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    labels = FLAGS.labels.split(',')
    num_classes = len(labels)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Creating train logdir: %s', FLAGS.train_logdir)

    with tf.Graph().as_default() as graph:
        global_step = tf.train.get_or_create_global_step()

        # Define the model.
        X = tf.placeholder(
            tf.float32,
            [None, FLAGS.num_views, FLAGS.height, FLAGS.width, 3],
            name='X')
        # For input size 299; otherwise modify this shape for your size.
        final_X = tf.placeholder(tf.float32,
                                 [FLAGS.num_views, None, 8, 8, 1536],
                                 name='final_X')
        ground_truth = tf.placeholder(tf.int64, [None], name='ground_truth')
        is_training = tf.placeholder(tf.bool)
        is_training2 = tf.placeholder(tf.bool)
        dropout_keep_prob = tf.placeholder(tf.float32)
        grouping_scheme = tf.placeholder(tf.bool,
                                         [NUM_GROUP, FLAGS.num_views])
        grouping_weight = tf.placeholder(tf.float32, [NUM_GROUP, 1])
        # learning_rate = tf.placeholder(tf.float32)

        # Grouping module.
        d_scores, _, final_desc = gvcnn.discrimination_score(
            X, num_classes, is_training)

        # GVCNN.
        logits, _ = gvcnn.gvcnn(final_X, grouping_scheme, grouping_weight,
                                num_classes, is_training2, dropout_keep_prob)

        # Define loss.
        tf.reduce_mean(
            tf.losses.sparse_softmax_cross_entropy(labels=ground_truth,
                                                   logits=logits))

        # Gather update_ops. These contain, for example,
        # the updates for the batch_norm variables created by model.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        prediction = tf.argmax(logits, 1, name='prediction')
        correct_prediction = tf.equal(prediction, ground_truth)
        confusion_matrix = tf.confusion_matrix(ground_truth,
                                               prediction,
                                               num_classes=num_classes)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        summaries.add(tf.summary.scalar('accuracy', accuracy))

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        learning_rate = train_utils.get_model_learning_rate(
            FLAGS.learning_policy, FLAGS.base_learning_rate,
            FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
            FLAGS.training_number_of_steps, FLAGS.learning_power,
            FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
        # optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
        optimizer = tf.train.AdamOptimizer(learning_rate)
        summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        # for variable in slim.get_model_variables():
        #     summaries.add(tf.summary.histogram(variable.op.name, variable))

        total_loss, grads_and_vars = train_utils.optimize(optimizer)
        total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # # Modify the gradients for biases and last layer variables.
        # last_layers = train_utils.get_extra_layer_scopes(
        #     FLAGS.last_layers_contain_logits_only)
        # grad_mult = train_utils.get_model_gradient_multipliers(
        #     last_layers, FLAGS.last_layer_gradient_multiplier)
        # if grad_mult:
        #     grads_and_vars = slim.learning.multiply_gradients(
        #         grads_and_vars, grad_mult)

        # TensorBoard: plot a histogram for each gradient.
        grad_summ_op = tf.summary.merge([
            tf.summary.histogram("%s-grad" % g[1].name, g[0])
            for g in grads_and_vars
        ])

        # Create gradient update op.
        grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_op = tf.identity(total_loss, name='train_op')

        # Add the summaries. These contain the summaries
        # created by model and either optimize() or _gather_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        train_writer = tf.summary.FileWriter(FLAGS.summaries_dir, graph)
        validation_writer = tf.summary.FileWriter(
            FLAGS.summaries_dir + '/validation', graph)

        ################
        # Prepare data
        ################
        filenames = tf.placeholder(tf.string, shape=[])
        tr_dataset = data.Dataset(filenames, FLAGS.num_views, FLAGS.height,
                                  FLAGS.width, FLAGS.batch_size)
        iterator = tr_dataset.dataset.make_initializable_iterator()
        next_batch = iterator.get_next()

        sess_config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(allow_growth=True))
        with tf.Session(config=sess_config) as sess:
            sess.run(tf.global_variables_initializer())

            # TODO:
            # Create a saver object which will save all the variables.
            saver = tf.train.Saver(keep_checkpoint_every_n_hours=1.0)
            if FLAGS.pre_trained_checkpoint:
                train_utils.restore_fn(FLAGS)

            start_epoch = 0
            # Get the number of training/validation steps per epoch.
            tr_batches = int(MODELNET_TRAIN_DATA_SIZE / FLAGS.batch_size)
            if MODELNET_TRAIN_DATA_SIZE % FLAGS.batch_size > 0:
                tr_batches += 1
            val_batches = int(MODELNET_VALIDATE_DATA_SIZE / FLAGS.batch_size)
            if MODELNET_VALIDATE_DATA_SIZE % FLAGS.batch_size > 0:
                val_batches += 1

            # The filenames argument to the TFRecordDataset initializer can
            # either be a string, a list of strings, or a tf.Tensor of strings.
            training_filenames = os.path.join(FLAGS.dataset_dir,
                                              'train.record')
            validate_filenames = os.path.join(FLAGS.dataset_dir,
                                              'validate.record')

            ##################
            # Training loop.
            ##################
            for training_epoch in range(start_epoch,
                                        FLAGS.how_many_training_epochs):
                print("-------------------------------------")
                print(" Epoch {} ".format(training_epoch))
                print("-------------------------------------")

                sess.run(iterator.initializer,
                         feed_dict={filenames: training_filenames})
                for step in range(tr_batches):
                    # Pull the image batch we'll use for training.
                    train_batch_xs, train_batch_ys = sess.run(next_batch)

                    # # Verify image:
                    # assert not np.any(np.isnan(train_batch_xs))
                    # n_batch = train_batch_xs.shape[0]
                    # n_view = train_batch_xs.shape[1]
                    # for i in range(n_batch):
                    #     for j in range(n_view):
                    #         img = train_batch_xs[i][j]
                    #         # scipy.misc.toimage(img).show(), or:
                    #         img = cv2.cvtColor(img.astype(np.uint8),
                    #                            cv2.COLOR_BGR2RGB)
                    #         cv2.imwrite('/home/ace19/Pictures/' + str(i) +
                    #                     '_' + str(j) + '.png', img)
                    #         # cv2.imshow(str(train_batch_ys[idx]), img)
                    #         cv2.waitKey(100)
                    #         cv2.destroyAllWindows()

                    # Sets up a graph with feeds and fetches for partial run.
                    handle = sess.partial_run_setup(
                        [d_scores, final_desc, learning_rate, summary_op,
                         accuracy, total_loss, grad_summ_op, train_op],
                        [X, final_X, ground_truth, grouping_scheme,
                         grouping_weight, is_training, is_training2,
                         dropout_keep_prob])

                    scores, final = sess.partial_run(handle,
                                                     [d_scores, final_desc],
                                                     feed_dict={
                                                         X: train_batch_xs,
                                                         is_training: True
                                                     })
                    schemes = gvcnn.grouping_scheme(scores, NUM_GROUP,
                                                    FLAGS.num_views)
                    weights = gvcnn.grouping_weight(scores, schemes)

                    # Run the graph with this batch of training data.
                    lr, train_summary, train_accuracy, train_loss, grad_vals, _ = \
                        sess.partial_run(handle,
                                         [learning_rate, summary_op, accuracy,
                                          total_loss, grad_summ_op, train_op],
                                         feed_dict={
                                             final_X: final,
                                             ground_truth: train_batch_ys,
                                             grouping_scheme: schemes,
                                             grouping_weight: weights,
                                             is_training2: True,
                                             dropout_keep_prob: 0.8
                                         })
                    train_writer.add_summary(train_summary, training_epoch)
                    train_writer.add_summary(grad_vals, training_epoch)
                    tf.logging.info(
                        'Epoch #%d, Step #%d, rate %.10f, accuracy %.1f%%, loss %f'
                        % (training_epoch, step, lr, train_accuracy * 100,
                           train_loss))

                ###################################################
                # Validate the model on the validation set
                ###################################################
                # tf.logging.info('--------------------------')
                # tf.logging.info(' Start validation ')
                # tf.logging.info('--------------------------')
                #
                # # Reinitialize iterator with the validation dataset.
                # sess.run(iterator.initializer,
                #          feed_dict={filenames: validate_filenames})
                # total_val_accuracy = 0
                # validation_count = 0
                # total_conf_matrix = None
                #
                # for step in range(val_batches):
                #     validation_batch_xs, validation_batch_ys = \
                #         sess.run(next_batch)
                #
                #     # Sets up a graph with feeds and fetches for partial run.
                #     handle = sess.partial_run_setup(
                #         [d_scores, final_desc, summary_op, accuracy,
                #          confusion_matrix],
                #         [X, final_X, ground_truth, learning_rate,
                #          grouping_scheme, grouping_weight, is_training,
                #          is_training2, dropout_keep_prob])
                #
                #     scores, final = sess.partial_run(
                #         handle, [d_scores, final_desc],
                #         feed_dict={X: validation_batch_xs,
                #                    is_training: False})
                #     schemes = gvcnn.grouping_scheme(scores, NUM_GROUP,
                #                                     FLAGS.num_views)
                #     weights = gvcnn.grouping_weight(scores, schemes)
                #
                #     # Run the graph with this batch of validation data.
                #     val_summary, val_accuracy, conf_matrix = \
                #         sess.partial_run(
                #             handle,
                #             [summary_op, accuracy, confusion_matrix],
                #             feed_dict={
                #                 final_X: final,
                #                 ground_truth: validation_batch_ys,
                #                 grouping_scheme: schemes,
                #                 grouping_weight: weights,
                #                 is_training2: False,
                #                 dropout_keep_prob: 1.0})
                #
                #     validation_writer.add_summary(val_summary, training_epoch)
                #
                #     total_val_accuracy += val_accuracy
                #     validation_count += 1
                #     if total_conf_matrix is None:
                #         total_conf_matrix = conf_matrix
                #     else:
                #         total_conf_matrix += conf_matrix
                #
                # total_val_accuracy /= validation_count
                # tf.logging.info('Confusion Matrix:\n %s' % total_conf_matrix)
                # tf.logging.info('Validation accuracy = %.1f%% (N=%d)' %
                #                 (total_val_accuracy * 100,
                #                  MODELNET_VALIDATE_DATA_SIZE))

                # Save the model checkpoint periodically.
                if training_epoch <= FLAGS.how_many_training_epochs - 1:
                    checkpoint_path = os.path.join(FLAGS.train_logdir,
                                                   FLAGS.ckpt_name_to_save)
                    tf.logging.info('Saving to "%s-%d"', checkpoint_path,
                                    training_epoch)
                    saver.save(sess, checkpoint_path,
                               global_step=training_epoch)
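# --------------------------------------------------------------------------
# The two-stage `partial_run` pattern above (run the CNN to get view scores,
# compute the grouping on the host, then feed the grouping back and finish
# the same step) is the subtle part of this script: ordinary `sess.run` would
# recompute the first stage. A self-contained toy example of the same
# Session.partial_run API, independent of GVCNN; the graph and numbers are
# illustrative only:
def _partial_run_demo():
    a = tf.placeholder(tf.float32, shape=[], name='a')
    b = tf.placeholder(tf.float32, shape=[], name='b')
    mid = a * 2.0   # stage 1: depends only on `a`
    out = mid + b   # stage 2: reuses `mid` plus a host-derived feed `b`

    with tf.Session() as sess:
        # Declare every fetch and feed that any partial_run call will use.
        handle = sess.partial_run_setup([mid, out], [a, b])
        m = sess.partial_run(handle, mid, feed_dict={a: 3.0})
        # `m` (6.0) is now a host value; derive the stage-2 feed from it.
        r = sess.partial_run(handle, out, feed_dict={b: m + 1.0})
        print(r)  # 13.0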
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')
    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    # Get dataset-dependent information.
    dataset = segmentation_dataset.get_dataset(FLAGS.dataset,
                                               FLAGS.train_split,
                                               dataset_dir=FLAGS.dataset_dir)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            samples = input_generator.get(
                dataset,
                FLAGS.train_crop_size,
                clone_batch_size,
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                dataset_split=FLAGS.train_split,
                is_training=True,
                model_variant=FLAGS.model_variant)
            inputs_queue = prefetch_queue.prefetch_queue(
                samples, capacity=128 * config.num_clones)
                # samples, capacity=12 * config.num_clones)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_unet
            # model_args = (inputs_queue, {
            #     common.OUTPUT_TYPE: dataset.num_classes
            # }, dataset.ignore_label)
            model_args = (inputs_queue, dataset, dataset.ignore_label)
            clones = model_deploy.create_clones(config, model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for
            # example, the updates for the batch_norm variables created by
            # model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)
            # input('stop!')

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for images, labels, and semantic predictions.
        if FLAGS.save_summaries_images:
            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
            summaries.add(
                tf.summary.image('samples/%s' % common.IMAGE, summary_image))

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
            # Scale up summary image pixel values for better visualization.
            pixel_scaling = max(1, 255 // dataset.num_classes)
            summary_label = tf.cast(first_clone_label * pixel_scaling,
                                    tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL, summary_label))

            first_clone_output = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
            predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)
            summary_predictions = tf.cast(predictions * pixel_scaling,
                                          tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                 summary_predictions))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps, FLAGS.learning_power,
                FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   FLAGS.momentum)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or
        # _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        # input('no training')

        # Start the training.
        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_logdir,
                            log_every_n_steps=FLAGS.log_steps,
                            master=FLAGS.master,
                            number_of_steps=FLAGS.training_number_of_steps,
                            is_chief=(FLAGS.task == 0),
                            session_config=session_config,
                            startup_delay_steps=startup_delay_steps,
                            init_fn=train_utils.get_model_init_fn(
                                FLAGS.train_logdir,
                                FLAGS.tf_initial_checkpoint,
                                FLAGS.initialize_last_layer,
                                last_layers,
                                ignore_missing_vars=True),
                            summary_op=summary_op,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs)
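# --------------------------------------------------------------------------
# Every script in this collection calls `train_utils.get_model_learning_rate`
# with the same eight arguments. The helper is not defined here; the sketch
# below mirrors the DeepLab reference `train_utils` ('step' and 'poly'
# policies plus a constant slow-start warm-up) and should be read as an
# approximation of the version these scripts import:
def get_model_learning_rate(learning_policy, base_learning_rate,
                            learning_rate_decay_step,
                            learning_rate_decay_factor,
                            training_number_of_steps, learning_power,
                            slow_start_step, slow_start_learning_rate):
    """Poly/step learning-rate schedule with a slow-start phase."""
    global_step = tf.train.get_or_create_global_step()
    if learning_policy == 'step':
        learning_rate = tf.train.exponential_decay(
            base_learning_rate, global_step, learning_rate_decay_step,
            learning_rate_decay_factor, staircase=True)
    elif learning_policy == 'poly':
        learning_rate = tf.train.polynomial_decay(
            base_learning_rate, global_step, training_number_of_steps,
            end_learning_rate=0, power=learning_power)
    else:
        raise ValueError('Unknown learning policy.')
    # Use a smaller constant learning rate for the first few steps.
    return tf.where(global_step < slow_start_step,
                    slow_start_learning_rate, learning_rate)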
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')
    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    # Get dataset-dependent information.
    dataset = segmentation_dataset.get_dataset(FLAGS.dataset,
                                               FLAGS.train_split,
                                               dataset_dir=FLAGS.dataset_dir)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    with tf.Graph().as_default():
        with tf.device(config.inputs_device()):
            samples = input_generator.get(
                dataset,
                FLAGS.train_crop_size,
                clone_batch_size,
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                dataset_split=FLAGS.train_split,
                is_training=True,
                model_variant=FLAGS.model_variant)
            inputs_queue = prefetch_queue.prefetch_queue(
                samples, capacity=128 * config.num_clones)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (inputs_queue, {
                common.OUTPUT_TYPE: dataset.num_classes
            }, dataset.ignore_label)
            clones = model_deploy.create_clones(config, model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for
            # example, the updates for the batch_norm variables created by
            # model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps, FLAGS.learning_power,
                FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   FLAGS.momentum)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or
        # _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Start the training.
        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_logdir,
                            log_every_n_steps=FLAGS.log_steps,
                            master=FLAGS.master,
                            number_of_steps=FLAGS.training_number_of_steps,
                            is_chief=(FLAGS.task == 0),
                            session_config=session_config,
                            startup_delay_steps=startup_delay_steps,
                            init_fn=train_utils.get_model_init_fn(
                                FLAGS.train_logdir,
                                FLAGS.tf_initial_checkpoint,
                                FLAGS.initialize_last_layer,
                                last_layers,
                                ignore_missing_vars=True),
                            summary_op=summary_op,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs)
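# --------------------------------------------------------------------------
# `train_utils.get_model_init_fn` supplies the `init_fn` passed to
# `slim.learning.train` above. It is not defined in this section; a sketch
# mirroring the DeepLab reference helper (skip initialization when a
# checkpoint already exists in the log directory, and optionally exclude the
# last layers) follows as an assumption about its behavior:
def get_model_init_fn(train_logdir, tf_initial_checkpoint,
                      initialize_last_layer, last_layers,
                      ignore_missing_vars=False):
    """Returns an init_fn(sess) that restores weights from a checkpoint."""
    if tf_initial_checkpoint is None:
        tf.logging.info('Not initializing the model from a checkpoint.')
        return None
    if tf.train.latest_checkpoint(train_logdir):
        tf.logging.info('Ignoring initialization; other checkpoint exists.')
        return None
    tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)
    # Do not restore the global step; optionally skip the last layers so
    # they are trained from scratch for a new number of classes.
    exclude_list = ['global_step']
    if not initialize_last_layer:
        exclude_list.extend(last_layers)
    variables_to_restore = slim.get_variables_to_restore(exclude=exclude_list)
    return slim.assign_from_checkpoint_fn(
        tf_initial_checkpoint,
        variables_to_restore,
        ignore_missing_vars=ignore_missing_vars)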
def main(unused_argv):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    labels = FLAGS.labels.split(',')
    num_classes = len(labels)

    with tf.Graph().as_default() as graph:
        global_step = tf.compat.v1.train.get_or_create_global_step()

        # Define the model.
        X = tf.compat.v1.placeholder(
            tf.float32,
            [None, FLAGS.num_views, FLAGS.height, FLAGS.width, 3],
            name='X')
        ground_truth = tf.compat.v1.placeholder(tf.int64, [None],
                                                name='ground_truth')
        is_training = tf.compat.v1.placeholder(tf.bool, name='is_training')
        dropout_keep_prob = tf.compat.v1.placeholder(
            tf.float32, name='dropout_keep_prob')
        g_scheme = tf.compat.v1.placeholder(
            tf.int32, [FLAGS.num_group, FLAGS.num_views])
        g_weight = tf.compat.v1.placeholder(tf.float32, [FLAGS.num_group])

        # GVCNN.
        view_scores, _, logits = model.gvcnn(X, num_classes, g_scheme,
                                             g_weight, is_training,
                                             dropout_keep_prob)

        # # basic - for verification
        # _, logits = model.basic(X,
        #                         num_classes,
        #                         is_training,
        #                         dropout_keep_prob)

        # Define loss.
        _loss = tf.losses.sparse_softmax_cross_entropy(labels=ground_truth,
                                                       logits=logits)

        # Gather initial summaries.
        summaries = set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        prediction = tf.argmax(logits, 1, name='prediction')
        correct_prediction = tf.equal(prediction, ground_truth)
        confusion_matrix = tf.math.confusion_matrix(ground_truth,
                                                    prediction,
                                                    num_classes=num_classes)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        summaries.add(tf.compat.v1.summary.scalar('accuracy', accuracy))

        # # Add summaries for model variables.
        # for model_var in slim.get_model_variables():
        #     summaries.add(
        #         tf.compat.v1.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.LOSSES):
            summaries.add(
                tf.compat.v1.summary.scalar('losses/%s' % loss.op.name, loss))

        learning_rate = train_utils.get_model_learning_rate(
            FLAGS.learning_policy, FLAGS.base_learning_rate,
            FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
            FLAGS.training_number_of_steps, FLAGS.learning_power,
            FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
        optimizer = tf.compat.v1.train.MomentumOptimizer(
            learning_rate, FLAGS.momentum)
        summaries.add(
            tf.compat.v1.summary.scalar('learning_rate', learning_rate))

        total_loss, grads_and_vars = train_utils.optimize(optimizer)
        total_loss = tf.debugging.check_numerics(total_loss,
                                                 'Loss is inf or nan.')
        summaries.add(tf.compat.v1.summary.scalar('total_loss', total_loss))

        # Gather update_ops. These contain, for example, the updates for the
        # batch_norm variables created by model.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)

        # Create gradient update op.
        update_ops.append(
            optimizer.apply_gradients(grads_and_vars,
                                      global_step=global_step))
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_op = tf.identity(total_loss, name='train_op')

        ################
        # Prepare data
        ################
        filenames = tf.compat.v1.placeholder(tf.string, shape=[])
        tr_dataset = train_data.Dataset(filenames, FLAGS.num_views,
                                        FLAGS.height, FLAGS.width,
                                        FLAGS.batch_size)
        iterator = tr_dataset.dataset.make_initializable_iterator()
        next_batch = iterator.get_next()

        # Validation dataset.
        val_dataset = val_data.Dataset(filenames, FLAGS.num_views,
                                       FLAGS.height, FLAGS.width,
                                       FLAGS.val_batch_size)
        val_iterator = val_dataset.dataset.make_initializable_iterator()
        val_next_batch = val_iterator.get_next()

        sess_config = tf.compat.v1.ConfigProto(
            gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))
        with tf.compat.v1.Session(config=sess_config) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())

            # Add the summaries. These contain the summaries
            # created by model and either optimize() or _gather_loss().
            summaries |= set(
                tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

            # Merge all summaries together.
            summary_op = tf.compat.v1.summary.merge(list(summaries))

            train_writer = tf.compat.v1.summary.FileWriter(
                FLAGS.summaries_dir, graph)
            validation_writer = tf.compat.v1.summary.FileWriter(
                FLAGS.summaries_dir + '/validation', graph)

            # Create a saver object which will save all the variables.
            saver = tf.compat.v1.train.Saver(keep_checkpoint_every_n_hours=1.0)
            if FLAGS.pre_trained_checkpoint:
                train_utils.restore_fn(FLAGS)
            if FLAGS.saved_checkpoint_dir:
                if tf.gfile.IsDirectory(FLAGS.saved_checkpoint_dir):
                    checkpoint_path = tf.train.latest_checkpoint(
                        FLAGS.saved_checkpoint_dir)
                else:
                    checkpoint_path = FLAGS.saved_checkpoint_dir
                saver.restore(sess, checkpoint_path)

            start_epoch = 0
            # Get the number of training/validation steps per epoch.
            tr_batches = int(MODELNET_TRAIN_DATA_SIZE / FLAGS.batch_size)
            if MODELNET_TRAIN_DATA_SIZE % FLAGS.batch_size > 0:
                tr_batches += 1
            val_batches = int(MODELNET_VALIDATE_DATA_SIZE /
                              FLAGS.val_batch_size)
            if MODELNET_VALIDATE_DATA_SIZE % FLAGS.val_batch_size > 0:
                val_batches += 1

            # The filenames argument to the TFRecordDataset initializer can
            # either be a string, a list of strings, or a tf.Tensor of strings.
            training_filenames = os.path.join(FLAGS.dataset_dir,
                                              'modelnet5_6view_train.record')
            validate_filenames = os.path.join(FLAGS.dataset_dir,
                                              'modelnet5_6view_test.record')

            ###################################
            # Training loop.
            ###################################
            for num_epoch in range(start_epoch,
                                   FLAGS.how_many_training_epochs):
                print("-------------------------------------")
                print(" Epoch {} ".format(num_epoch))
                print("-------------------------------------")

                sess.run(iterator.initializer,
                         feed_dict={filenames: training_filenames})
                for step in range(tr_batches):
                    # Pull the image batch we'll use for training.
                    train_batch_xs, train_batch_ys = sess.run(next_batch)

                    # Sets up a graph with feeds and fetches for partial run.
                    handle = sess.partial_run_setup(
                        [view_scores, learning_rate,
                         # summary_op, top1_acc, loss, optimize_op, dummy],
                         summary_op, accuracy, _loss, train_op],
                        [X, ground_truth, g_scheme, g_weight,
                         is_training, dropout_keep_prob])

                    _view_scores = sess.partial_run(handle, [view_scores],
                                                    feed_dict={
                                                        X: train_batch_xs,
                                                        is_training: True,
                                                        dropout_keep_prob: 0.8
                                                    })
                    _g_schemes = model.group_scheme(_view_scores,
                                                    FLAGS.num_group,
                                                    FLAGS.num_views)
                    _g_weights = model.group_weight(_g_schemes)

                    # Run the graph with this batch of training data.
                    lr, train_summary, train_accuracy, train_loss, _ = \
                        sess.partial_run(handle,
                                         [learning_rate, summary_op, accuracy,
                                          _loss, train_op],
                                         feed_dict={
                                             ground_truth: train_batch_ys,
                                             g_scheme: _g_schemes,
                                             g_weight: _g_weights
                                         })

                    # # For verification:
                    # lr, train_summary, train_accuracy, train_loss, _ = \
                    #     sess.run([learning_rate, summary_op, accuracy,
                    #               _loss, train_op],
                    #              feed_dict={
                    #                  X: train_batch_xs,
                    #                  ground_truth: train_batch_ys,
                    #                  is_training: True,
                    #                  dropout_keep_prob: 0.8})

                    train_writer.add_summary(train_summary, num_epoch)
                    tf.compat.v1.logging.info(
                        'Epoch #%d, Step #%d, rate %.6f, top1_acc %.3f%%, loss %.5f'
                        % (num_epoch, step, lr, train_accuracy, train_loss))

                ###################################################
                # Validate the model on the validation set
                ###################################################
                tf.compat.v1.logging.info('--------------------------')
                tf.compat.v1.logging.info(' Start validation ')
                tf.compat.v1.logging.info('--------------------------')

                total_val_losses = 0.0
                total_val_top1_acc = 0.0
                val_count = 0
                total_conf_matrix = None
                # Reinitialize val_iterator with the validation dataset.
                sess.run(val_iterator.initializer,
                         feed_dict={filenames: validate_filenames})
                for step in range(val_batches):
                    validation_batch_xs, validation_batch_ys = sess.run(
                        val_next_batch)

                    # Sets up a graph with feeds and fetches for partial run.
                    handle = sess.partial_run_setup(
                        [view_scores, summary_op, accuracy, _loss,
                         confusion_matrix],
                        [X, g_scheme, g_weight, ground_truth,
                         is_training, dropout_keep_prob])

                    _view_scores = sess.partial_run(
                        handle, [view_scores],
                        feed_dict={
                            X: validation_batch_xs,
                            is_training: False,
                            dropout_keep_prob: 1.0
                        })
                    _g_schemes = model.group_scheme(_view_scores,
                                                    FLAGS.num_group,
                                                    FLAGS.num_views)
                    _g_weights = model.group_weight(_g_schemes)

                    # Run the graph with this batch of validation data.
                    val_summary, val_accuracy, val_loss, conf_matrix = \
                        sess.partial_run(handle,
                                         [summary_op, accuracy, _loss,
                                          confusion_matrix],
                                         feed_dict={
                                             ground_truth: validation_batch_ys,
                                             g_scheme: _g_schemes,
                                             g_weight: _g_weights
                                         })

                    # # For verification:
                    # val_summary, val_accuracy, val_loss, conf_matrix = \
                    #     sess.run([summary_op, accuracy, _loss,
                    #               confusion_matrix],
                    #              feed_dict={
                    #                  X: validation_batch_xs,
                    #                  ground_truth: validation_batch_ys,
                    #                  is_training: False,
                    #                  dropout_keep_prob: 1.0})

                    validation_writer.add_summary(val_summary, num_epoch)

                    total_val_losses += val_loss
                    total_val_top1_acc += val_accuracy
                    val_count += 1
                    if total_conf_matrix is None:
                        total_conf_matrix = conf_matrix
                    else:
                        total_conf_matrix += conf_matrix

                total_val_losses /= val_count
                total_val_top1_acc /= val_count
                tf.compat.v1.logging.info('Confusion Matrix:\n %s' %
                                          total_conf_matrix)
                tf.compat.v1.logging.info('Validation loss = %.5f' %
                                          total_val_losses)
                tf.compat.v1.logging.info(
                    'Validation accuracy = %.3f%% (N=%d)' %
                    (total_val_top1_acc, MODELNET_VALIDATE_DATA_SIZE))

                # Save the model checkpoint periodically.
                if num_epoch <= FLAGS.how_many_training_epochs - 1:
                    checkpoint_path = os.path.join(FLAGS.train_logdir,
                                                   FLAGS.ckpt_name_to_save)
                    tf.compat.v1.logging.info('Saving to "%s-%d"',
                                              checkpoint_path, num_epoch)
                    saver.save(sess, checkpoint_path, global_step=num_epoch)
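# --------------------------------------------------------------------------
# `model.group_scheme` and `model.group_weight` run on the host between the
# two partial_run stages above, but their implementations are not shown.
# The sketch below is one plausible NumPy rendering of the GVCNN grouping
# idea (discretize each view's discrimination score into num_group bins,
# then weight each group by its share of views). It is illustrative only
# and not necessarily the repository's implementation:
import numpy as np

def group_scheme(view_scores, num_group, num_views):
    """Maps each view to one group by binning its score in [0, 1)."""
    scheme = np.zeros((num_group, num_views), dtype=np.int32)
    scores = np.reshape(np.asarray(view_scores), (num_views,))
    for v, s in enumerate(scores):
        g = min(int(s * num_group), num_group - 1)  # clamp s == 1.0
        scheme[g, v] = 1
    return scheme

def group_weight(scheme):
    """Weights each group by its fraction of views; empty groups get 0."""
    counts = scheme.sum(axis=1).astype(np.float32)
    total = counts.sum()
    return counts / total if total > 0 else counts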
def main(unused_arg):
    tf.logging.set_verbosity(tf.logging.INFO)

    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')
    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    tf.gfile.MakeDirs(FLAGS.train_dir)

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            samples, num_samples = get_dataset.get_dataset(
                FLAGS.dataset,
                FLAGS.dataset_dir,
                split_name=FLAGS.train_split,
                is_training=True,
                image_size=[FLAGS.image_size, FLAGS.image_size],
                batch_size=clone_batch_size,
                channel=FLAGS.input_channel)
            tf.logging.info('Training on %s set: %d', FLAGS.train_split,
                            num_samples)
            inputs_queue = prefetch_queue.prefetch_queue(
                samples, capacity=128 * config.num_clones)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_model
            model_args = (inputs_queue, clone_batch_size)
            clones = model_deploy.create_clones(config, model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for
            # example, the updates for the batch_norm variables created by
            # model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        if FLAGS.save_summaries_variables:
            for model_var in slim.get_model_variables():
                summaries.add(
                    tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor, FLAGS.number_of_steps,
                FLAGS.learning_power, FLAGS.slow_start_step,
                FLAGS.slow_start_learning_rate)
            optimizer = tf.train.AdamOptimizer(learning_rate)
            # optimizer = tf.train.RMSPropOptimizer(learning_rate,
            #                                       momentum=FLAGS.momentum)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('losses/total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            if (FLAGS.dataset == 'protein') and FLAGS.add_counts_logits:
                last_layers = ['Logits', 'Counts_logits']
            else:
                last_layers = ['Logits']
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or
        # _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        session_config.gpu_options.allow_growth = True
        session_config.gpu_options.per_process_gpu_memory_fraction = 0.9

        # Start the training.
        slim.learning.train(train_tensor,
                            FLAGS.train_dir,
                            is_chief=(FLAGS.task == 0),
                            master=FLAGS.master,
                            graph=graph,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            session_config=session_config,
                            startup_delay_steps=startup_delay_steps,
                            number_of_steps=FLAGS.number_of_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs,
                            init_fn=train_utils.get_model_init_fn(
                                FLAGS.train_dir,
                                FLAGS.fine_tune_checkpoint,
                                FLAGS.initialize_last_layer,
                                last_layers,
                                ignore_missing_vars=True),
                            summary_op=summary_op,
                            saver=tf.train.Saver(max_to_keep=50))
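# --------------------------------------------------------------------------
# `train_utils.get_model_gradient_multipliers` is used above to boost the
# learning rate of the last layers. It is not defined in this section; the
# sketch below mirrors the DeepLab reference helper (2x for biases, a larger
# multiplier for last-layer variables) and is offered as an assumption about
# its behavior:
def get_model_gradient_multipliers(last_layers,
                                   last_layer_gradient_multiplier):
    """Returns a {variable name: multiplier} dict for multiply_gradients."""
    gradient_multipliers = {}
    for var in slim.get_model_variables():
        # Double the learning rate for biases.
        if 'biases' in var.op.name:
            gradient_multipliers[var.op.name] = 2.
        # Use a larger learning rate for last-layer variables.
        for layer in last_layers:
            if layer in var.op.name and 'biases' in var.op.name:
                gradient_multipliers[var.op.name] = (
                    2 * last_layer_gradient_multiplier)
                break
            elif layer in var.op.name:
                gradient_multipliers[var.op.name] = (
                    last_layer_gradient_multiplier)
                break
    return gradient_multipliers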
def build_model(): """Builds graph for model to train with rewrites for quantization. Returns: g: Graph with fake quantization ops and batch norm folding suitable for training quantized weights. train_tensor: Train op for execution during training. """ g = tf.Graph() with g.as_default(), tf.device( tf.train.replica_device_setter(FLAGS.ps_tasks)): samples, _ = get_dataset.get_dataset(FLAGS.dataset, FLAGS.dataset_dir, split_name=FLAGS.train_split, is_training=True, image_size=[FLAGS.image_size, FLAGS.image_size], batch_size=FLAGS.batch_size, channel=FLAGS.input_channel) inputs = tf.identity(samples['image'], name='image') labels = tf.identity(samples['label'], name='label') model_options = common.ModelOptions(output_stride=FLAGS.output_stride) net, end_points = model.get_features( inputs, model_options=model_options, weight_decay=FLAGS.weight_decay, is_training=True, fine_tune_batch_norm=FLAGS.fine_tune_batch_norm) logits, _ = model.classification(net, end_points, num_classes=FLAGS.num_classes, is_training=True) logits = slim.softmax(logits) focal_loss_tensor = train_utils.focal_loss(labels, logits, weights=1.0) # f1_loss_tensor = train_utils.f1_loss(labels, logits, weights=1.0) # cls_loss = f1_loss_tensor cls_loss = focal_loss_tensor # Gather update_ops update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # Gather initial summaries. summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) global_step = tf.train.get_or_create_global_step() learning_rate = train_utils.get_model_learning_rate( FLAGS.learning_policy, FLAGS.base_learning_rate, FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor, FLAGS.number_of_steps, FLAGS.learning_power, FLAGS.slow_start_step, FLAGS.slow_start_learning_rate) opt = tf.train.AdamOptimizer(learning_rate) # opt = tf.train.RMSPropOptimizer(learning_rate, momentum=FLAGS.momentum) summaries.add(tf.summary.scalar('learning_rate', learning_rate)) for loss in tf.get_collection(tf.GraphKeys.LOSSES): summaries.add(tf.summary.scalar('sub_losses/%s'%(loss.op.name), loss)) classifation_loss = tf.identity(cls_loss, name='classifation_loss') summaries.add(tf.summary.scalar('losses/classifation_loss', classifation_loss)) regularization_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) regularization_loss = tf.add_n(regularization_loss, name='regularization_loss') summaries.add(tf.summary.scalar('losses/regularization_loss', regularization_loss)) total_loss = tf.add(cls_loss, regularization_loss, name='total_loss') grads_and_vars = opt.compute_gradients(total_loss) total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.') summaries.add(tf.summary.scalar('losses/total_loss', total_loss)) grad_updates = opt.apply_gradients(grads_and_vars, global_step=global_step) update_ops.append(grad_updates) update_op = tf.group(*update_ops, name='update_barrier') with tf.control_dependencies([update_op]): train_tensor = tf.identity(total_loss, name='train_op') # Merge all summaries together. summary_op = tf.summary.merge(list(summaries)) return g, train_tensor, summary_op
def _train_pgn_model(iterator, num_of_classes, model_options, ignore_label,
                     reuse=None):
    """Trains the pgn model.

    Args:
        iterator: An iterator of type tf.data.Iterator for images and labels.
        num_of_classes: Number of classes for the dataset.
        model_options: Model options for building the network.
        ignore_label: Ignore label for the dataset.
        reuse: Whether to reuse the model variables.

    Returns:
        train_tensor: A tensor to update the model variables.
        summary_op: An operation to log the summaries.
    """
    global_step = tf.train.get_or_create_global_step()
    summaries = []

    learning_rate = train_utils.get_model_learning_rate(
        FLAGS.learning_policy, FLAGS.base_learning_rate,
        FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
        FLAGS.training_number_of_steps, FLAGS.learning_power,
        FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)

    tower_grads = []
    total_loss, total_seg_loss = 0, 0
    tower_summaries = None
    for i in range(FLAGS.num_clones):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('clone_%d' % i) as scope:
                loss, seg_loss = _tower_loss(
                    iterator=iterator,
                    num_of_classes=num_of_classes,
                    model_options=model_options,
                    ignore_label=ignore_label,
                    scope=scope,
                    reuse_variable=(i != 0))
                    # reuse_variable=reuse)
                total_loss += loss
                total_seg_loss += seg_loss
                grads = optimizer.compute_gradients(loss)
                tower_grads.append(grads)
                tower_summaries = tf.summary.merge_all()

    summaries.append(tf.summary.scalar('learning_rate', learning_rate))

    with tf.device('/cpu:0'):
        grads_and_vars = _average_gradients(tower_grads)
        if tower_summaries is not None:
            summaries.append(tower_summaries)

        # Create gradient update op.
        grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                 global_step=global_step)

        # Gather update_ops. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)

        should_log = tf.equal(math_ops.mod(global_step, FLAGS.log_steps), 0)
        total_loss = tf.cond(
            should_log,
            lambda: tf.Print(
                total_loss, [total_loss, total_seg_loss, global_step],
                'Total loss, Segmentation loss and Global step:'),
            lambda: total_loss)
        summaries.append(tf.summary.scalar('total_loss', total_loss))

        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')
        summary_op = tf.summary.merge(summaries)

    return train_tensor, summary_op
def main(unused_argv):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    labels = FLAGS.labels.split(',')
    num_classes = len(labels)

    # tf.compat.v1.logging.info('Creating train logdir: %s', FLAGS.train_logdir)

    with tf.Graph().as_default() as graph:
        global_step = tf.compat.v1.train.get_or_create_global_step()

        X = tf.compat.v1.placeholder(
            tf.float32, [None, FLAGS.num_views, FLAGS.height, FLAGS.width, 3],
            name='X')
        ground_truth = tf.compat.v1.placeholder(tf.int64, [None],
                                                name='ground_truth')
        is_training = tf.compat.v1.placeholder(tf.bool, name='is_training')
        dropout_keep_prob = tf.compat.v1.placeholder(tf.float32,
                                                     name='dropout_keep_prob')
        # learning_rate = tf.placeholder(tf.float32, name='lr')

        # metric learning
        logits, features = \
            model.mvcnn_with_deep_cosine_metric_learning(
                X,
                num_classes,
                is_training=is_training,
                keep_prob=dropout_keep_prob,
                attention_module='se_block')
        # logits, features = mvcnn.mvcnn(X, num_classes)

        cross_entropy = tf.compat.v1.losses.sparse_softmax_cross_entropy(
            labels=ground_truth, logits=logits)
        tf.compat.v1.summary.scalar("cross_entropy_loss", cross_entropy)

        # Gather update ops. These contain, for example, the updates for the
        # batch_norm variables created by model.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)

        # Gather initial summaries.
        summaries = set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        prediction = tf.argmax(logits, 1, name='prediction')
        correct_prediction = tf.equal(prediction, ground_truth)
        confusion_matrix = tf.math.confusion_matrix(ground_truth,
                                                    prediction,
                                                    num_classes=num_classes)
        # accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        # summaries.add(tf.summary.scalar('accuracy', accuracy))
        accuracy = slim.metrics.accuracy(tf.cast(prediction, tf.int64),
                                         ground_truth)
        tf.compat.v1.summary.scalar("accuracy", accuracy)

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(
                tf.compat.v1.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.LOSSES):
            summaries.add(
                tf.compat.v1.summary.scalar('losses/%s' % loss.op.name, loss))

        learning_rate = train_utils.get_model_learning_rate(
            FLAGS.learning_policy, FLAGS.base_learning_rate,
            FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
            FLAGS.training_number_of_steps, FLAGS.learning_power,
            FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
        # optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate)
        summaries.add(
            tf.compat.v1.summary.scalar('learning_rate', learning_rate))

        total_loss, grads_and_vars = train_utils.optimize(optimizer)
        total_loss = tf.compat.v1.check_numerics(total_loss,
                                                 'Loss is inf or nan')
        summaries.add(tf.compat.v1.summary.scalar('total_loss', total_loss))

        # TensorBoard: how to plot a histogram for gradients
        # grad_summ_op = tf.compat.v1.summary.merge(
        #     [tf.compat.v1.summary.histogram("%s-grad" % g[1].name, g[0])
        #      for g in grads_and_vars])

        # Create gradient update op.
        grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_op = tf.identity(total_loss, name='train_op')

        # Add the summaries. These contain the summaries created by model
        # and either optimize() or _gather_loss().
        summaries |= set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        # Merge all summaries together.
        summary_op = tf.compat.v1.summary.merge(list(summaries))
        train_writer = tf.compat.v1.summary.FileWriter(FLAGS.summaries_dir,
                                                       graph)
        validation_writer = tf.compat.v1.summary.FileWriter(
            FLAGS.summaries_dir + '/validation', graph)

        #####################
        # Prepare data
        #####################
        tfrecord_names = tf.compat.v1.placeholder(tf.string, shape=[])
        _dataset = data.Dataset(tfrecord_names, FLAGS.num_views, FLAGS.height,
                                FLAGS.width, FLAGS.batch_size)
        iterator = _dataset.dataset.make_initializable_iterator()
        next_batch = iterator.get_next()

        sess_config = tf.compat.v1.ConfigProto(
            gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))
        with tf.compat.v1.Session(config=sess_config) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())

            saver = tf.compat.v1.train.Saver(keep_checkpoint_every_n_hours=1.0)
            if FLAGS.pre_trained_checkpoint:
                train_utils.restore_fn(FLAGS)

            start_epoch = 0
            # Get the number of training/validation steps per epoch.
            training_batches = int(MODELNET10_TRAIN_DATA_SIZE / FLAGS.batch_size)
            if MODELNET10_TRAIN_DATA_SIZE % FLAGS.batch_size > 0:
                training_batches += 1
            val_batches = int(MODELNET10_VALIDATE_DATA_SIZE / FLAGS.batch_size)
            if MODELNET10_VALIDATE_DATA_SIZE % FLAGS.batch_size > 0:
                val_batches += 1

            # The filenames argument to the TFRecordDataset initializer can
            # either be a string, a list of strings, or a tf.Tensor of strings.
            training_tf_filenames = os.path.join(FLAGS.dataset_dir, 'train.record')
            val_tf_filenames = os.path.join(FLAGS.dataset_dir, 'validate.record')

            ##################
            # Training loop.
            ##################
            for n_epoch in range(start_epoch, FLAGS.how_many_training_epochs):
                tf.compat.v1.logging.info('--------------------------')
                tf.compat.v1.logging.info(' Epoch %d' % n_epoch)
                tf.compat.v1.logging.info('--------------------------')

                sess.run(iterator.initializer,
                         feed_dict={tfrecord_names: training_tf_filenames})
                for step in range(training_batches):
                    train_batch_xs, train_batch_ys = sess.run(next_batch)

                    # # Verify the images of a batch.
                    # assert not np.any(np.isnan(train_batch_xs))
                    # n_batch = train_batch_xs.shape[0]
                    # n_view = train_batch_xs.shape[1]
                    # for i in range(n_batch):
                    #     for j in range(n_view):
                    #         img = train_batch_xs[i][j]
                    #         # scipy.misc.toimage(img).show()
                    #         # Or
                    #         img = cv2.cvtColor(img.astype(np.uint8),
                    #                            cv2.COLOR_BGR2RGB)
                    #         cv2.imwrite('/home/ace19/Pictures/' + str(i) +
                    #                     '_' + str(j) + '.png', img)
                    #         # cv2.imshow(str(train_batch_ys[idx]), img)
                    #         cv2.waitKey(100)
                    #         cv2.destroyAllWindows()

                    lr, train_summary, train_accuracy, train_loss, _ = \
                        sess.run([learning_rate, summary_op, accuracy,
                                  total_loss, train_op],
                                 feed_dict={X: train_batch_xs,
                                            ground_truth: train_batch_ys,
                                            is_training: True,
                                            dropout_keep_prob: 0.8})
                    # lr, train_summary, train_accuracy, train_loss, grad_vals, _ = \
                    #     sess.run([learning_rate, summary_op, accuracy,
                    #               total_loss, grad_summ_op, train_op],
                    #              feed_dict={X: train_batch_xs,
                    #                         ground_truth: train_batch_ys,
                    #                         is_training: True,
                    #                         dropout_keep_prob: 0.8})
                    train_writer.add_summary(train_summary, n_epoch)
                    # train_writer.add_summary(grad_vals, n_epoch)
                    tf.compat.v1.logging.info(
                        'Epoch #%d, Step #%d, rate %.10f, accuracy %.1f%%, loss %f' %
                        (n_epoch, step, lr, train_accuracy * 100, train_loss))

                ###################################################
                # Validate the model on the validation set
                ###################################################
                tf.compat.v1.logging.info('--------------------------')
                tf.compat.v1.logging.info(' Start validation ')
                tf.compat.v1.logging.info('--------------------------')

                # Reinitialize the iterator with the validation dataset.
                sess.run(iterator.initializer,
                         feed_dict={tfrecord_names: val_tf_filenames})

                total_val_accuracy = 0
                validation_count = 0
                total_conf_matrix = None
                for step in range(val_batches):
                    validation_batch_xs, validation_batch_ys = sess.run(next_batch)
                    val_summary, val_accuracy, conf_matrix = \
                        sess.run([summary_op, accuracy, confusion_matrix],
                                 feed_dict={X: validation_batch_xs,
                                            ground_truth: validation_batch_ys,
                                            is_training: False,
                                            dropout_keep_prob: 1.0})
                    validation_writer.add_summary(val_summary, n_epoch)

                    total_val_accuracy += val_accuracy
                    validation_count += 1
                    if total_conf_matrix is None:
                        total_conf_matrix = conf_matrix
                    else:
                        total_conf_matrix += conf_matrix

                total_val_accuracy /= validation_count
                tf.compat.v1.logging.info('Confusion Matrix:\n %s' %
                                          total_conf_matrix)
                tf.compat.v1.logging.info(
                    'Validation accuracy = %.1f%% (N=%d)' %
                    (total_val_accuracy * 100, MODELNET10_VALIDATE_DATA_SIZE))

                # Save the model checkpoint at the end of every epoch.
                if n_epoch <= FLAGS.how_many_training_epochs - 1:
                    checkpoint_path = os.path.join(FLAGS.train_logdir,
                                                   FLAGS.ckpt_name_to_save)
                    tf.compat.v1.logging.info('Saving to "%s-%d"',
                                              checkpoint_path, n_epoch)
                    saver.save(sess, checkpoint_path, global_step=n_epoch)
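# ---------------------------------------------------------------------------
# Note: `data.Dataset` above is the project's own wrapper around the input
# pipeline. A minimal sketch of the kind of TFRecord pipeline it is assumed
# to wrap is shown below; the class name, the feature keys ('image/encoded',
# 'image/class/label') and the per-view PNG encoding are illustrative
# assumptions, not the project's actual schema.
# ---------------------------------------------------------------------------
class _MultiViewDatasetSketch(object):
    def __init__(self, tfrecord_path, num_views, height, width, batch_size):
        self.num_views = num_views
        self.height = height
        self.width = width
        dataset = tf.data.TFRecordDataset(tfrecord_path)
        dataset = dataset.map(self._parse)
        dataset = dataset.shuffle(buffer_size=1000)
        # The caller computes the (possibly partial) last batch itself, so the
        # remainder is kept rather than dropped.
        self.dataset = dataset.batch(batch_size)

    def _parse(self, serialized):
        features = tf.io.parse_single_example(
            serialized,
            features={
                'image/encoded': tf.io.FixedLenFeature([self.num_views], tf.string),
                'image/class/label': tf.io.FixedLenFeature([], tf.int64),
            })
        # Decode and resize each of the `num_views` renderings of the object.
        views = tf.map_fn(
            lambda buf: tf.image.resize_images(
                tf.image.decode_png(buf, channels=3),
                [self.height, self.width]),
            features['image/encoded'], dtype=tf.float32)
        return views, features['image/class/label']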
def main(unused_argv):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    labels = FLAGS.labels.split(',')
    num_classes = len(labels)

    with tf.Graph().as_default() as graph:
        global_step = tf.compat.v1.train.get_or_create_global_step()

        X = tf.compat.v1.placeholder(tf.float32,
                                     [None, FLAGS.height, FLAGS.width, 3],
                                     name='X')
        ground_truth = tf.compat.v1.placeholder(tf.int64, [None],
                                                name='ground_truth')
        is_training = tf.compat.v1.placeholder(tf.bool, name='is_training')
        keep_prob = tf.compat.v1.placeholder(tf.float32, [], name='keep_prob')
        tfrecord_filenames = tf.compat.v1.placeholder(tf.string, shape=[])

        # # Print name and shape of each tensor.
        # tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
        # tf.compat.v1.logging.info("Layers")
        # tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
        # for k, v in end_points.items():
        #     tf.compat.v1.logging.info('name = %s, shape = %s' % (v.name, v.get_shape()))

        # Gather initial summaries.
        summaries = set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        # # Add summaries for model variables.
        # for variable in slim.get_model_variables():
        #     summaries.add(tf.compat.v1.summary.histogram(variable.op.name, variable))
        #
        # # Add summaries for losses.
        # for loss in tf.compat.v1.get_collection(tf.GraphKeys.LOSSES):
        #     summaries.add(tf.compat.v1.summary.scalar('losses/%s' % loss.op.name, loss))

        learning_rate = train_utils.get_model_learning_rate(
            FLAGS.learning_policy, FLAGS.base_learning_rate,
            FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
            FLAGS.training_number_of_steps, FLAGS.learning_power,
            FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
        summaries.add(
            tf.compat.v1.summary.scalar('learning_rate', learning_rate))

        # One optimizer instance per GPU tower.
        # optimizers = [tf.train.RMSPropOptimizer(learning_rate, decay=0.9, momentum=0.9)
        #               for _ in range(FLAGS.num_gpu)]
        # optimizers = [tf.compat.v1.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
        #               for _ in range(FLAGS.num_gpu)]
        optimizers = [tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
                      for _ in range(FLAGS.num_gpu)]

        logits = []
        losses = []
        grad_list = []
        filename_batch = []
        image_batch = []
        gt_batch = []
        for gpu_idx in range(FLAGS.num_gpu):
            tf.compat.v1.logging.info('creating gpu tower @ %d' % (gpu_idx + 1))
            image_batch.append(X)
            gt_batch.append(ground_truth)

            scope_name = 'tower%d' % gpu_idx
            with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_idx)), \
                    tf.compat.v1.variable_scope(scope_name):
                # apply SENet
                _, logit = model.deep_cosine_softmax(
                    X,
                    num_classes=num_classes,
                    is_training=is_training,
                    is_reuse=False,
                    keep_prob=keep_prob,
                    attention_module='se_block')

                # Print name and shape of parameter nodes
                # (values not yet initialized).
                tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
                tf.compat.v1.logging.info("Parameters")
                tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
                for v in slim.get_model_variables():
                    tf.compat.v1.logging.info('name = %s, shape = %s' %
                                              (v.name, v.get_shape()))

                # # TTA
                # logit = tf.cond(is_training,
                #                 lambda: tf.identity(logit),
                #                 lambda: tf.reduce_mean(
                #                     tf.reshape(logit, [FLAGS.val_batch_size // FLAGS.num_gpu, TEN_CROP, -1]),
                #                     axis=1))

                logits.append(logit)
                l = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=ground_truth, logits=logit)
                losses.append(l)
                loss_w_reg = tf.reduce_sum(l) + tf.add_n(
                    slim.losses.get_regularization_losses(scope=scope_name))

                grad_list.append([
                    x for x in optimizers[gpu_idx].compute_gradients(loss_w_reg)
                    if x[0] is not None
                ])

        y_hat = tf.concat(logits, axis=0)
        image_batch = tf.concat(image_batch, axis=0)
        gt_batch = tf.concat(gt_batch, axis=0)

        # Accuracy
        top1_acc = tf.reduce_mean(
            tf.cast(tf.nn.in_top_k(y_hat, gt_batch, k=1), dtype=tf.float32))
        summaries.add(tf.compat.v1.summary.scalar('top1_acc', top1_acc))
        # top5_acc = tf.reduce_mean(
        #     tf.cast(tf.nn.in_top_k(y_hat, gt_batch, k=5), dtype=tf.float32))
        # summaries.add(tf.compat.v1.summary.scalar('top5_acc', top5_acc))

        prediction = tf.argmax(y_hat, axis=1, name='prediction')
        confusion_matrix = tf.math.confusion_matrix(gt_batch,
                                                    prediction,
                                                    num_classes=num_classes)
        confusion_matrix = tf.div(confusion_matrix, FLAGS.num_gpu)

        loss = tf.reduce_mean(losses)
        loss = tf.compat.v1.check_numerics(loss, 'Loss is inf or nan.')
        summaries.add(tf.compat.v1.summary.scalar('loss', loss))

        # Use NCCL to all-reduce the tower gradients.
        grads, all_vars = train_helper.split_grad_list(grad_list)
        reduced_grad = train_helper.allreduce_grads(grads, average=True)
        grads = train_helper.merge_grad_list(reduced_grad, all_vars)

        # Apply the reduced gradients with each tower's optimizer.
        train_ops = []
        for idx, grad_and_vars in enumerate(grads):
            # apply_gradients may create variables. Make them LOCAL_VARIABLES.
            with tf.name_scope('apply_gradients'), tf.device(
                    tf.DeviceSpec(device_type="GPU", device_index=idx)):
                update_ops = tf.compat.v1.get_collection(
                    tf.compat.v1.GraphKeys.UPDATE_OPS, scope='tower%d' % idx)
                with tf.control_dependencies(update_ops):
                    train_ops.append(optimizers[idx].apply_gradients(
                        grad_and_vars,
                        name='apply_grad_{}'.format(idx),
                        global_step=global_step))

        # TODO:
        # TensorBoard: how to plot a histogram for gradients
        # grad_summ_op = tf.summary.merge(
        #     [tf.summary.histogram("%s-grad" % g[1].name, g[0])
        #      for g in grads_and_vars])

        optimize_op = tf.group(*train_ops, name='train_op')
        sync_op = train_helper.get_post_init_ops()

        # Create a saver object which will save all the variables.
        saver = tf.compat.v1.train.Saver()
        best_ckpt_saver = BestCheckpointSaver(save_dir=FLAGS.train_logdir,
                                              num_to_keep=100,
                                              maximize=False,
                                              saver=saver)
        best_val_loss = 99999
        best_val_acc = 0

        start_epoch = 0
        epoch_count = tf.Variable(start_epoch, trainable=False)
        epoch_count_add = tf.assign(epoch_count, epoch_count + 1)

        ###############
        # Prepare data
        ###############
        # training dataset
        tr_dataset = train_data.Dataset(tfrecord_filenames,
                                        FLAGS.batch_size // FLAGS.num_gpu,
                                        num_classes,
                                        FLAGS.how_many_training_epochs,
                                        TRAIN_DATA_SIZE,
                                        FLAGS.height,
                                        FLAGS.width)
        iterator = tr_dataset.dataset.make_initializable_iterator()
        next_batch = iterator.get_next()

        # validation dataset
        val_dataset = val_data.Dataset(tfrecord_filenames,
                                       FLAGS.val_batch_size // FLAGS.num_gpu,
                                       num_classes,
                                       FLAGS.how_many_training_epochs,
                                       VALIDATE_DATA_SIZE,
                                       FLAGS.height,  # 256, 256 ~ 480
                                       FLAGS.width)   # 256
        val_iterator = val_dataset.dataset.make_initializable_iterator()
        val_next_batch = val_iterator.get_next()

        sess_config = tf.compat.v1.ConfigProto(
            gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))
        with tf.compat.v1.Session(config=sess_config) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())

            # Add the summaries. These contain the summaries created by model
            # and either optimize() or _gather_loss().
            summaries |= set(
                tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

            # Merge all summaries together.
            summary_op = tf.compat.v1.summary.merge(list(summaries))
            train_writer = tf.compat.v1.summary.FileWriter(FLAGS.summaries_dir,
                                                           graph)
            validation_writer = tf.compat.v1.summary.FileWriter(
                FLAGS.summaries_dir + '/validation', graph)

            # TODO: support multi-GPU -> add scope ('tower%d' % gpu_idx)
            if FLAGS.pre_trained_checkpoint:
                train_utils.restore_fn(FLAGS)

            if FLAGS.saved_checkpoint_dir:
                if tf.gfile.IsDirectory(FLAGS.saved_checkpoint_dir):
                    checkpoint_path = tf.train.latest_checkpoint(
                        FLAGS.saved_checkpoint_dir)
                else:
                    checkpoint_path = FLAGS.saved_checkpoint_dir
                saver.restore(sess, checkpoint_path)
                # global_step = checkpoint_path.split('/')[-1].split('-')[-1]

            # Synchronize the tower variables before training starts.
            sess.run(sync_op)

            # Get the number of training/validation steps per epoch.
            tr_batches = int(TRAIN_DATA_SIZE / (FLAGS.batch_size // FLAGS.num_gpu))
            if TRAIN_DATA_SIZE % (FLAGS.batch_size // FLAGS.num_gpu) > 0:
                tr_batches += 1
            val_batches = int(VALIDATE_DATA_SIZE /
                              (FLAGS.val_batch_size // FLAGS.num_gpu))
            if VALIDATE_DATA_SIZE % (FLAGS.val_batch_size // FLAGS.num_gpu) > 0:
                val_batches += 1

            # The filenames argument to the TFRecordDataset initializer can
            # either be a string, a list of strings, or a tf.Tensor of strings.
            train_record_filenames = os.path.join(FLAGS.dataset_dir,
                                                  'train.record')
            validate_record_filenames = os.path.join(FLAGS.dataset_dir,
                                                     'validate.record')

            ############################
            # Training loop.
            ############################
            for num_epoch in range(start_epoch, FLAGS.how_many_training_epochs):
                print("------------------------------------")
                print(" Epoch {} ".format(num_epoch))
                print("------------------------------------")

                sess.run(epoch_count_add)
                sess.run(iterator.initializer,
                         feed_dict={tfrecord_filenames: train_record_filenames})
                for step in range(tr_batches):
                    filenames, train_batch_xs, train_batch_ys = \
                        sess.run(next_batch)
                    # show_batch_data(filenames, train_batch_xs, train_batch_ys)
                    #
                    # augmented_batch_xs = aug_utils.aug(train_batch_xs)
                    # show_batch_data(filenames, augmented_batch_xs,
                    #                 train_batch_ys, 'aug')

                    # Run the graph with this batch of training data and
                    # the learning-rate policy.
                    lr, train_summary, train_top1_acc, train_loss, _ = \
                        sess.run([learning_rate, summary_op, top1_acc, loss,
                                  optimize_op],
                                 feed_dict={X: train_batch_xs,
                                            ground_truth: train_batch_ys,
                                            is_training: True,
                                            keep_prob: 0.8})
                    train_writer.add_summary(train_summary, num_epoch)
                    # train_writer.add_summary(grad_vals, num_epoch)
                    tf.compat.v1.logging.info(
                        'Epoch #%d, Step #%d, rate %.6f, top1_acc %.3f%%, loss %.5f' %
                        (num_epoch, step, lr, train_top1_acc, train_loss))

                ###################################################
                # Validate the model on the validation set
                ###################################################
                tf.compat.v1.logging.info('--------------------------')
                tf.compat.v1.logging.info(' Start validation ')
                tf.compat.v1.logging.info('--------------------------')

                total_val_losses = 0.0
                total_val_top1_acc = 0.0
                val_count = 0
                total_conf_matrix = None

                sess.run(val_iterator.initializer,
                         feed_dict={tfrecord_filenames: validate_record_filenames})
                for step in range(val_batches):
                    filenames, validation_batch_xs, validation_batch_ys = \
                        sess.run(val_next_batch)

                    # # TTA
                    # batch_size, n_crops, c, h, w = validation_batch_xs.shape
                    # # fuse batch size and ncrops
                    # tencrop_val_batch_xs = np.reshape(validation_batch_xs,
                    #                                   (-1, c, h, w))
                    # show_batch_data(filenames, tencrop_val_batch_xs,
                    #                 validation_batch_ys)
                    # augmented_val_batch_xs = aug_utils.aug(tencrop_val_batch_xs)
                    # show_batch_data(filenames, augmented_val_batch_xs,
                    #                 validation_batch_ys, 'aug')

                    val_summary, val_loss, val_top1_acc, _confusion_matrix = \
                        sess.run([summary_op, loss, top1_acc, confusion_matrix],
                                 feed_dict={X: validation_batch_xs,
                                            ground_truth: validation_batch_ys,
                                            is_training: False,
                                            keep_prob: 1.0})
                    validation_writer.add_summary(val_summary, num_epoch)

                    total_val_losses += val_loss
                    total_val_top1_acc += val_top1_acc
                    # total_val_accuracy += val_top1_acc
                    val_count += 1
                    if total_conf_matrix is None:
                        total_conf_matrix = _confusion_matrix
                    else:
                        total_conf_matrix += _confusion_matrix

                total_val_losses /= val_count
                total_val_top1_acc /= val_count
                # total_val_accuracy /= val_count
                tf.compat.v1.logging.info('Confusion Matrix:\n %s' %
                                          total_conf_matrix)
                tf.compat.v1.logging.info('Validation loss = %.5f' %
                                          total_val_losses)
                tf.compat.v1.logging.info(
                    'Validation top1 accuracy = %.3f%% (N=%d)' %
                    (total_val_top1_acc, VALIDATE_DATA_SIZE))

                # Periodically resynchronize the tower variables.
                sess.run(sync_op)

                # Save the model checkpoint periodically.
                if num_epoch <= FLAGS.how_many_training_epochs - 1:
                    # best_checkpoint_path = os.path.join(
                    #     FLAGS.train_logdir, 'best_' + FLAGS.ckpt_name_to_save)
                    # tf.compat.v1.logging.info('Saving to "%s"', best_checkpoint_path)
                    # saver.save(sess, best_checkpoint_path, global_step=num_epoch)

                    # Save & keep the best model w.r.t. validation loss.
                    best_ckpt_saver.handle(total_val_losses, sess, epoch_count)
                    if best_val_loss > total_val_losses:
                        best_val_loss = total_val_losses
                        best_val_acc = total_val_top1_acc

            chk_path = get_best_checkpoint(FLAGS.train_logdir,
                                           select_maximum_value=False)
            tf.compat.v1.logging.info(
                'training done. best_model val_loss=%.5f, top1_acc=%.3f%%, ckpt=%s' %
                (best_val_loss, best_val_acc, chk_path))
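# ---------------------------------------------------------------------------
# Note: the `train_helper` utilities used above (split_grad_list,
# allreduce_grads, merge_grad_list) are defined elsewhere in the project.
# Minimal sketches in the style of the well-known tensorpack implementations
# are given below for reference (illustrative names; the project's own
# versions may differ in detail). `get_post_init_ops`, which builds the
# tower-variable synchronization op, is omitted here.
# ---------------------------------------------------------------------------
from tensorflow.contrib import nccl


def _split_grad_list_sketch(grad_list):
    """[[(g, v), ...] per tower] -> (grads per tower, vars per tower)."""
    grads = [[gv[0] for gv in tower] for tower in grad_list]
    variables = [[gv[1] for gv in tower] for tower in grad_list]
    return grads, variables


def _allreduce_grads_sketch(all_grads, average=True):
    """All-reduce each variable's gradient across towers with NCCL."""
    nr_tower = len(all_grads)
    if nr_tower == 1:
        return all_grads
    new_all_grads = []  # one entry per variable: its per-tower gradients
    for per_var_grads in zip(*all_grads):
        summed = nccl.all_sum(list(per_var_grads))
        grads_for_devices = []
        for g in summed:
            with tf.device(g.device):
                if average:
                    g = tf.multiply(g, 1.0 / nr_tower)
            grads_for_devices.append(g)
        new_all_grads.append(grads_for_devices)
    # Transpose back to one gradient list per tower.
    return [list(per_tower) for per_tower in zip(*new_all_grads)]


def _merge_grad_list_sketch(all_grads, all_vars):
    """Zip the reduced gradients back with their variables, per tower."""
    return [list(zip(gs, vs)) for gs, vs in zip(all_grads, all_vars)]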