def add_final_retrain_ops(class_count, final_tensor_name, bottleneck_tensor,
                          quantize_layer, is_training):
  batch_size, bottleneck_tensor_size = bottleneck_tensor.get_shape().as_list()
  assert batch_size is None, 'We want to work with arbitrary batch size.'
  with tf.name_scope('input'):
    bottleneck_input = tf.placeholder_with_default(
        bottleneck_tensor,
        shape=[batch_size, bottleneck_tensor_size],
        name='BottleneckInputPlaceholder')
    ground_truth_input = tf.placeholder(
        tf.int64, [batch_size], name='GroundTruthInput')

  # Organizing the following ops so they are easier to see in TensorBoard.
  layer_name = 'final_retrain_ops'
  with tf.name_scope(layer_name):
    with tf.name_scope('weights'):
      initial_value = tf.truncated_normal(
          [bottleneck_tensor_size, class_count], stddev=0.001)
      layer_weights = tf.Variable(initial_value, name='final_weights')
      variable_summaries(layer_weights)

    with tf.name_scope('biases'):
      layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
      variable_summaries(layer_biases)

    with tf.name_scope('Wx_plus_b'):
      logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
      tf.summary.histogram('pre_activations', logits)

  final_tensor = tf.nn.softmax(logits, name=final_tensor_name)

  if quantize_layer:
    if is_training:
      contrib_quantize.create_training_graph()
    else:
      contrib_quantize.create_eval_graph()

  tf.summary.histogram('activations', final_tensor)

  if not is_training:
    return None, None, bottleneck_input, ground_truth_input, final_tensor

  with tf.name_scope('cross_entropy'):
    cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
        labels=ground_truth_input, logits=logits)

  tf.summary.scalar('cross_entropy', cross_entropy_mean)

  with tf.name_scope('train'):
    optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
    train_step = optimizer.minimize(cross_entropy_mean)

  return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input,
          final_tensor)
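# A minimal, self-contained sketch of the training/eval rewrite asymmetry used
# above, assuming TF 1.x with tf.contrib.quantize available. The toy layer and
# tensor names here are illustrative only, not the retraining script's real
# graph: create_training_graph() is applied to the graph being trained, while
# the export path rebuilds the graph and applies create_eval_graph() instead.
import tensorflow as tf
from tensorflow.contrib import quantize as contrib_quantize


def _toy_logits(x):
  w = tf.Variable(tf.truncated_normal([4, 2], stddev=0.1), name='w')
  b = tf.Variable(tf.zeros([2]), name='b')
  return tf.matmul(x, w) + b


# Training graph: fake-quant ops simulate quantization error in the forward
# pass while gradients stay in float.
train_graph = tf.Graph()
with train_graph.as_default():
  x = tf.placeholder(tf.float32, [None, 4], name='x')
  logits = _toy_logits(x)
  contrib_quantize.create_training_graph()  # rewrite before creating a session

# Eval/export graph: rebuilt from scratch and rewritten with create_eval_graph()
# so the frozen graph carries the quantization ranges a converter expects.
eval_graph = tf.Graph()
with eval_graph.as_default():
  x = tf.placeholder(tf.float32, [None, 4], name='x')
  logits = _toy_logits(x)
  contrib_quantize.create_eval_graph()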
def build_model():
  """Builds graph for model to train with rewrites for quantization.

  Returns:
    g: Graph with fake quantization ops and batch norm folding suitable for
      training quantized weights.
    train_tensor: Train op for execution during training.
  """
  g = tf.Graph()
  with g.as_default(), tf.device(
      tf.train.replica_device_setter(FLAGS.ps_tasks)):
    inputs, labels = imagenet_input(is_training=True)
    with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=True)):
      logits, _ = mobilenet_v1.mobilenet_v1(
          inputs,
          is_training=True,
          depth_multiplier=FLAGS.depth_multiplier,
          num_classes=FLAGS.num_classes,
          final_endpoint=FLAGS.final_endpoint)

    tf.losses.softmax_cross_entropy(labels, logits)

    # Call the rewriter to produce a graph with fake quant ops and folded batch
    # norms. quant_delay delays the start of quantization until quant_delay
    # steps, allowing for better model accuracy.
    if FLAGS.quantize:
      contrib_quantize.create_training_graph(quant_delay=get_quant_delay())

    total_loss = tf.losses.get_total_loss(name='total_loss')

    # Configure the learning rate using an exponential decay.
    num_epochs_per_decay = 2.5
    imagenet_size = 1271167
    decay_steps = int(imagenet_size / FLAGS.batch_size * num_epochs_per_decay)

    learning_rate = tf.train.exponential_decay(
        get_learning_rate(),
        tf.train.get_or_create_global_step(),
        decay_steps,
        _LEARNING_RATE_DECAY_FACTOR,
        staircase=True)
    opt = tf.train.GradientDescentOptimizer(learning_rate)

    train_tensor = slim.learning.create_train_op(total_loss, optimizer=opt)

    slim.summaries.add_scalar_summary(total_loss, 'total_loss', 'losses')
    slim.summaries.add_scalar_summary(learning_rate, 'learning_rate',
                                      'training')
  return g, train_tensor
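# For concreteness (illustrative numbers, not taken from the script's flags):
# with imagenet_size = 1271167 and num_epochs_per_decay = 2.5, a batch size of
# 64 gives decay_steps = int(1271167 / 64 * 2.5) = 49654, i.e. the learning
# rate is multiplied by _LEARNING_RATE_DECAY_FACTOR roughly every 2.5 epochs.
assert int(1271167 / 64 * 2.5) == 49654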
def main(unused_argv): tf.logging.set_verbosity(tf.logging.INFO) # Set up deployment (i.e., multi-GPUs and/or multi-replicas). config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones, clone_on_cpu=FLAGS.clone_on_cpu, replica_id=FLAGS.task, num_replicas=FLAGS.num_replicas, num_ps_tasks=FLAGS.num_ps_tasks) # Split the batch across GPUs. assert FLAGS.train_batch_size % config.num_clones == 0, ( 'Training batch size not divisble by number of clones (GPUs).') clone_batch_size = FLAGS.train_batch_size // config.num_clones tf.gfile.MakeDirs(FLAGS.train_logdir) common.outputlogMessage('Training on %s set' % FLAGS.train_split) common.outputlogMessage('Dataset: %s' % FLAGS.dataset) common.outputlogMessage('train_crop_size: %s' % str(FLAGS.train_crop_size)) common.outputlogMessage(str(FLAGS.train_crop_size)) common.outputlogMessage('atrous_rates: %s' % str(FLAGS.atrous_rates)) common.outputlogMessage('number of classes: %s' % str(FLAGS.num_classes)) common.outputlogMessage('Ignore label value: %s' % str(FLAGS.ignore_label)) pid = os.getpid() with open('train_py_pid.txt', 'w') as f_obj: f_obj.writelines('%d' % pid) with tf.Graph().as_default() as graph: with tf.device(config.inputs_device()): dataset = data_generator.Dataset( dataset_name=FLAGS.dataset, split_name=FLAGS.train_split, dataset_dir=FLAGS.dataset_dir, batch_size=clone_batch_size, crop_size=[int(sz) for sz in FLAGS.train_crop_size], min_resize_value=FLAGS.min_resize_value, max_resize_value=FLAGS.max_resize_value, resize_factor=FLAGS.resize_factor, min_scale_factor=FLAGS.min_scale_factor, max_scale_factor=FLAGS.max_scale_factor, scale_factor_step_size=FLAGS.scale_factor_step_size, model_variant=FLAGS.model_variant, num_readers=4, is_training=True, should_shuffle=True, should_repeat=True, num_classes=FLAGS.num_classes, ignore_label=FLAGS.ignore_label) # Create the global step on the device storing the variables. with tf.device(config.variables_device()): global_step = tf.train.get_or_create_global_step() # Define the model and create clones. model_fn = _build_deeplab model_args = (dataset.get_one_shot_iterator(), { common.OUTPUT_TYPE: dataset.num_of_classes }, dataset.ignore_label) clones = model_deploy.create_clones(config, model_fn, args=model_args) # Gather update_ops from the first clone. These contain, for example, # the updates for the batch_norm variables created by model_fn. first_clone_scope = config.clone_scope(0) update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope) # Gather initial summaries. summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) # Add summaries for model variables. for model_var in tf.model_variables(): summaries.add(tf.summary.histogram(model_var.op.name, model_var)) # Add summaries for images, labels, semantic predictions if FLAGS.save_summaries_images: summary_image = graph.get_tensor_by_name( ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/')) summaries.add( tf.summary.image('samples/%s' % common.IMAGE, summary_image)) first_clone_label = graph.get_tensor_by_name( ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/')) # Scale up summary image pixel values for better visualization. 
pixel_scaling = max(1, 255 // dataset.num_of_classes) summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8) summaries.add( tf.summary.image('samples/%s' % common.LABEL, summary_label)) first_clone_output = graph.get_tensor_by_name( ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/')) predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1) summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8) summaries.add( tf.summary.image('samples/%s' % common.OUTPUT_TYPE, summary_predictions)) # Add summaries for losses. for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope): summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss)) # Build the optimizer based on the device specification. with tf.device(config.optimizer_device()): learning_rate = train_utils.get_model_learning_rate( FLAGS.learning_policy, FLAGS.base_learning_rate, FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor, FLAGS.training_number_of_steps, FLAGS.learning_power, FLAGS.slow_start_step, FLAGS.slow_start_learning_rate, decay_steps=FLAGS.decay_steps, end_learning_rate=FLAGS.end_learning_rate) summaries.add(tf.summary.scalar('learning_rate', learning_rate)) if FLAGS.optimizer == 'momentum': optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum) elif FLAGS.optimizer == 'adam': optimizer = tf.train.AdamOptimizer( learning_rate=FLAGS.adam_learning_rate, epsilon=FLAGS.adam_epsilon) else: raise ValueError('Unknown optimizer') if FLAGS.quantize_delay_step >= 0: if FLAGS.num_clones > 1: raise ValueError( 'Quantization doesn\'t support multi-clone yet.') contrib_quantize.create_training_graph( quant_delay=FLAGS.quantize_delay_step) startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps with tf.device(config.variables_device()): total_loss, grads_and_vars = model_deploy.optimize_clones( clones, optimizer) total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.') summaries.add(tf.summary.scalar('total_loss', total_loss)) # Modify the gradients for biases and last layer variables. last_layers = model.get_extra_layer_scopes( FLAGS.last_layers_contain_logits_only) grad_mult = train_utils.get_model_gradient_multipliers( last_layers, FLAGS.last_layer_gradient_multiplier) if grad_mult: grads_and_vars = slim.learning.multiply_gradients( grads_and_vars, grad_mult) # Create gradient update op. grad_updates = optimizer.apply_gradients(grads_and_vars, global_step=global_step) update_ops.append(grad_updates) update_op = tf.group(*update_ops) with tf.control_dependencies([update_op]): train_tensor = tf.identity(total_loss, name='train_op') # Add the summaries from the first clone. These contain the summaries # created by model_fn and either optimize_clones() or _gather_clone_loss(). summaries |= set( tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope)) # Merge all summaries together. summary_op = tf.summary.merge(list(summaries)) # Soft placement allows placing on CPU ops without GPU implementation. session_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) # Start the training. 
profile_dir = FLAGS.profile_logdir if profile_dir is not None: tf.gfile.MakeDirs(profile_dir) with contrib_tfprof.ProfileContext(enabled=profile_dir is not None, profile_dir=profile_dir): init_fn = None if FLAGS.tf_initial_checkpoint: init_fn = train_utils.get_model_init_fn( FLAGS.train_logdir, FLAGS.tf_initial_checkpoint, FLAGS.initialize_last_layer, last_layers, ignore_missing_vars=True) slim.learning.train(train_tensor, logdir=FLAGS.train_logdir, log_every_n_steps=FLAGS.log_steps, master=FLAGS.master, number_of_steps=FLAGS.training_number_of_steps, is_chief=(FLAGS.task == 0), session_config=session_config, startup_delay_steps=startup_delay_steps, init_fn=init_fn, summary_op=summary_op, save_summaries_secs=FLAGS.save_summaries_secs, save_interval_secs=FLAGS.save_interval_secs)
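# A hedged sketch of the export step that typically follows this kind of
# quantization-aware training run (not part of the script above, and DeepLab
# ships its own export_model.py): rebuild the graph for inference, apply the
# eval rewrite, restore the checkpoint, and freeze the variables. The model
# builder callback and output node name below are placeholders.
import tensorflow as tf
from tensorflow.contrib import quantize as contrib_quantize


def export_quantized_graph(build_inference_graph_fn, checkpoint_path,
                           output_node_name, frozen_graph_path):
  """Rebuilds the graph for inference, adds eval fake-quant ops, freezes it."""
  with tf.Graph().as_default() as g:
    build_inference_graph_fn()            # placeholder: rebuild model for inference
    contrib_quantize.create_eval_graph()  # eval rewrite instead of the training rewrite
    saver = tf.train.Saver()
    with tf.Session(graph=g) as sess:
      saver.restore(sess, checkpoint_path)
      frozen = tf.graph_util.convert_variables_to_constants(
          sess, g.as_graph_def(), [output_node_name])
      tf.train.write_graph(frozen, '.', frozen_graph_path, as_text=False)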
def model_fn(features, labels, mode, params, tf_sess=False): """ Create the model for estimator api Args: features: if input_layout == 'nhwc', a tensor with shape: [BATCH_SIZE, go.N, go.N, get_features_planes()] else, a tensor with shape: [BATCH_SIZE, get_features_planes(), go.N, go.N] labels: dict from string to tensor with shape 'pi_tensor': [BATCH_SIZE, go.N * go.N + 1] 'value_tensor': [BATCH_SIZE] mode: a tf.estimator.ModeKeys (batchnorm params update for TRAIN only) params: A dictionary (Typically derived from the FLAGS object.) Returns: tf.estimator.EstimatorSpec with props mode: same as mode arg predictions: dict of tensors 'policy': [BATCH_SIZE, go.N * go.N + 1] 'value': [BATCH_SIZE] loss: a single value tensor train_op: train op eval_metric_ops return dict of tensors logits: [BATCH_SIZE, go.N * go.N + 1] """ policy_output, value_output, logits = model_inference_fn( features, mode == tf.estimator.ModeKeys.TRAIN, params) # train ops policy_cost = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=tf.stop_gradient( labels['pi_tensor']))) value_cost = params['value_cost_weight'] * tf.reduce_mean( tf.square(value_output - labels['value_tensor'])) reg_vars = [ v for v in tf.trainable_variables() if 'bias' not in v.name and 'beta' not in v.name ] l2_cost = params['l2_strength'] * \ tf.add_n([tf.nn.l2_loss(v) for v in reg_vars]) combined_cost = policy_cost + value_cost + l2_cost global_step = tf.train.get_or_create_global_step() learning_rate = tf.train.piecewise_constant(global_step, params['lr_boundaries'], params['lr_rates']) update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # Insert quantization ops if requested if params['quantize']: if mode == tf.estimator.ModeKeys.TRAIN: contrib_quantize.create_training_graph( quant_delay=params['quant_delay']) else: contrib_quantize.create_eval_graph() optimizer = tf.train.MomentumOptimizer(learning_rate, params['sgd_momentum']) # hvd multigpu optimizer = hvd.DistributedOptimizer(optimizer) if params['use_tpu']: optimizer = contrib_tpu_python_tpu_tpu_optimizer.CrossShardOptimizer( optimizer) with tf.control_dependencies(update_ops): train_op = optimizer.minimize(combined_cost, global_step=global_step) # return train_op for sess if tf_sess: return train_op # Computations to be executed on CPU, outside of the main TPU queues. def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor, value_tensor, policy_cost, value_cost, l2_cost, combined_cost, step, est_mode=tf.estimator.ModeKeys.TRAIN): policy_entropy = -tf.reduce_mean( tf.reduce_sum(policy_output * tf.log(policy_output), axis=1)) # pi_tensor is one_hot when generated from sgfs (for supervised learning) # and soft-max when using self-play records. argmax normalizes the two. 
policy_target_top_1 = tf.argmax(pi_tensor, axis=1) policy_output_in_top1 = tf.to_float( tf.nn.in_top_k(policy_output, policy_target_top_1, k=1)) policy_output_in_top3 = tf.to_float( tf.nn.in_top_k(policy_output, policy_target_top_1, k=3)) policy_top_1_confidence = tf.reduce_max(policy_output, axis=1) policy_target_top_1_confidence = tf.boolean_mask( policy_output, tf.one_hot(policy_target_top_1, tf.shape(policy_output)[1])) value_cost_normalized = value_cost / params['value_cost_weight'] avg_value_observed = tf.reduce_mean(value_tensor) with tf.variable_scope('metrics'): metric_ops = { 'policy_cost': tf.metrics.mean(policy_cost), 'value_cost': tf.metrics.mean(value_cost), 'value_cost_normalized': tf.metrics.mean(value_cost_normalized), 'l2_cost': tf.metrics.mean(l2_cost), 'policy_entropy': tf.metrics.mean(policy_entropy), 'combined_cost': tf.metrics.mean(combined_cost), 'avg_value_observed': tf.metrics.mean(avg_value_observed), 'policy_accuracy_top_1': tf.metrics.mean(policy_output_in_top1), 'policy_accuracy_top_3': tf.metrics.mean(policy_output_in_top3), 'policy_top_1_confidence': tf.metrics.mean(policy_top_1_confidence), 'policy_target_top_1_confidence': tf.metrics.mean(policy_target_top_1_confidence), 'value_confidence': tf.metrics.mean(tf.abs(value_output)), } if est_mode == tf.estimator.ModeKeys.EVAL: return metric_ops # NOTE: global_step is rounded to a multiple of FLAGS.summary_steps. eval_step = tf.reduce_min(step) # Create summary ops so that they show up in SUMMARIES collection # That way, they get logged automatically during training summary_writer = contrib_summary.create_file_writer(FLAGS.work_dir) with summary_writer.as_default(), \ contrib_summary.record_summaries_every_n_global_steps( params['summary_steps'], eval_step): for metric_name, metric_op in metric_ops.items(): contrib_summary.scalar(metric_name, metric_op[1], step=eval_step) # Reset metrics occasionally so that they are mean of recent batches. reset_op = tf.variables_initializer(tf.local_variables('metrics')) cond_reset_op = tf.cond( tf.equal(eval_step % params['summary_steps'], tf.to_int64(1)), lambda: reset_op, lambda: tf.no_op()) return contrib_summary.all_summary_ops() + [cond_reset_op] metric_args = [ policy_output, value_output, labels['pi_tensor'], labels['value_tensor'], tf.reshape(policy_cost, [1]), tf.reshape(value_cost, [1]), tf.reshape(l2_cost, [1]), tf.reshape(combined_cost, [1]), tf.reshape(global_step, [1]), ] predictions = { 'policy_output': policy_output, 'value_output': value_output, } eval_metrics_only_fn = functools.partial( eval_metrics_host_call_fn, est_mode=tf.estimator.ModeKeys.EVAL) host_call_fn = functools.partial(eval_metrics_host_call_fn, est_mode=tf.estimator.ModeKeys.TRAIN) tpu_estimator_spec = contrib_tpu_python_tpu_tpu_estimator.TPUEstimatorSpec( mode=mode, predictions=predictions, loss=combined_cost, train_op=train_op) if params['use_tpu']: return tpu_estimator_spec else: return tpu_estimator_spec.as_estimator_spec()
def build_loss(self):
  response = self.response
  response_size = response.get_shape().as_list()[1:3]  # [height, width]

  gt = construct_gt_score_maps(response_size,
                               self.data_config['batch_size'],
                               self.model_config['embed_config']['stride'],
                               self.train_config['gt_config'])

  # loss: https://www.renom.jp/ja/notebooks/tutorial/basic_algorithm/lossfunction/notebook.html
  with tf.name_scope('Loss'):
    loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=response, labels=gt)

    with tf.name_scope('Balance_weights'):
      n_pos = tf.reduce_sum(tf.to_float(tf.equal(gt[0], 1)))
      n_neg = tf.reduce_sum(tf.to_float(tf.equal(gt[0], 0)))
      w_pos = 0.5 / n_pos
      w_neg = 0.5 / n_neg
      class_weights = tf.where(tf.equal(gt, 1),
                               w_pos * tf.ones_like(gt),
                               tf.ones_like(gt))
      class_weights = tf.where(tf.equal(gt, 0),
                               w_neg * tf.ones_like(gt),
                               class_weights)
      loss = loss * class_weights

    # Note that we use reduce_sum instead of reduce_mean since the loss has
    # already been normalized by class_weights in spatial dimension.
    loss = tf.reduce_sum(loss, [1, 2])

    batch_loss = tf.reduce_mean(loss, name='batch_loss')
    tf.losses.add_loss(batch_loss)

    total_loss = tf.losses.get_total_loss()
    self.batch_loss = batch_loss
    self.total_loss = total_loss

    # quantization
    # good note: https://www.tensorflowers.cn/t/7136
    if self.model_config['embed_config']['quantization']:
      if self.train_config["export"]:
        contrib_quantize.create_eval_graph()
      else:
        contrib_quantize.create_training_graph(quant_delay=200000)

    tf.summary.image('exemplar', self.exemplars, family=self.mode)
    tf.summary.image('instance', self.instances, family=self.mode)

    mean_batch_loss, update_op1 = tf.metrics.mean(batch_loss)
    mean_total_loss, update_op2 = tf.metrics.mean(total_loss)
    with tf.control_dependencies([update_op1, update_op2]):
      tf.summary.scalar('batch_loss', mean_batch_loss, family=self.mode)
      tf.summary.scalar('total_loss', mean_total_loss, family=self.mode)

    if self.mode == 'train':
      tf.summary.image('GT',
                       tf.reshape(gt[0], [1] + response_size + [1]),
                       family='GT')
    tf.summary.image('Response',
                     tf.expand_dims(tf.sigmoid(response), -1),
                     family=self.mode)
    tf.summary.histogram('Response', self.response, family=self.mode)

    # Two more metrics to monitor the performance of training
    tf.summary.scalar('center_score_error',
                      center_score_error(response),
                      family=self.mode)
    tf.summary.scalar('center_dist_error',
                      center_dist_error(response),
                      family=self.mode)
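# Worked example of the class balancing above (numbers are illustrative): for
# a 17x17 response map with 13 positive and 276 negative locations,
# w_pos = 0.5 / 13 ~= 0.0385 and w_neg = 0.5 / 276 ~= 0.0018, so after
# weighting, positives and negatives each contribute half of the total weight.
# That normalization is why the spatial reduction uses reduce_sum rather than
# reduce_mean.
import numpy as np

n_pos, n_neg = 13, 276
w_pos, w_neg = 0.5 / n_pos, 0.5 / n_neg
assert np.isclose(n_pos * w_pos + n_neg * w_neg, 1.0)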
def main(_): if not FLAGS.dataset_dir: raise ValueError( 'You must supply the dataset directory with --dataset_dir') tf.logging.set_verbosity(tf.logging.INFO) with tf.Graph().as_default(): ####################### # Config model_deploy # ####################### deploy_config = model_deploy.DeploymentConfig( num_clones=FLAGS.num_clones, clone_on_cpu=FLAGS.clone_on_cpu, replica_id=FLAGS.task, num_replicas=FLAGS.worker_replicas, num_ps_tasks=FLAGS.num_ps_tasks) # Create global_step with tf.device(deploy_config.variables_device()): global_step = slim.create_global_step() ###################### # Select the dataset # ###################### dataset = dataset_factory.get_dataset(FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir) ###################### # Select the network # ###################### network_fn = nets_factory.get_network_fn( FLAGS.model_name, num_classes=(dataset.num_classes - FLAGS.labels_offset), weight_decay=FLAGS.weight_decay, is_training=True) ##################################### # Select the preprocessing function # ##################################### preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name image_preprocessing_fn = preprocessing_factory.get_preprocessing( preprocessing_name, is_training=True, use_grayscale=FLAGS.use_grayscale) ############################################################## # Create a dataset provider that loads data from the dataset # ############################################################## with tf.device(deploy_config.inputs_device()): provider = slim.dataset_data_provider.DatasetDataProvider( dataset, num_readers=FLAGS.num_readers, common_queue_capacity=20 * FLAGS.batch_size, common_queue_min=10 * FLAGS.batch_size) [image, label] = provider.get(['image', 'label']) label -= FLAGS.labels_offset train_image_size = FLAGS.train_image_size or network_fn.default_image_size image = image_preprocessing_fn(image, train_image_size, train_image_size) images, labels = tf.train.batch( [image, label], batch_size=FLAGS.batch_size, num_threads=FLAGS.num_preprocessing_threads, capacity=5 * FLAGS.batch_size) labels = slim.one_hot_encoding( labels, dataset.num_classes - FLAGS.labels_offset) batch_queue = slim.prefetch_queue.prefetch_queue( [images, labels], capacity=2 * deploy_config.num_clones) #################### # Define the model # #################### def clone_fn(batch_queue): """Allows data parallelism by creating multiple clones of network_fn.""" images, labels = batch_queue.dequeue() logits, end_points = network_fn(images) ############################# # Specify the loss function # ############################# if 'AuxLogits' in end_points: slim.losses.softmax_cross_entropy( end_points['AuxLogits'], labels, label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss') slim.losses.softmax_cross_entropy( logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0) return end_points # Gather initial summaries. summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue]) first_clone_scope = deploy_config.clone_scope(0) # Gather update_ops from the first clone. These contain, for example, # the updates for the batch_norm variables created by network_fn. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope) # Add summaries for end_points. 
end_points = clones[0].outputs for end_point in end_points: x = end_points[end_point] summaries.add(tf.summary.histogram('activations/' + end_point, x)) summaries.add( tf.summary.scalar('sparsity/' + end_point, tf.nn.zero_fraction(x))) # Add summaries for losses. for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope): summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss)) # Add summaries for variables. for variable in slim.get_model_variables(): summaries.add(tf.summary.histogram(variable.op.name, variable)) ################################# # Configure the moving averages # ################################# if FLAGS.moving_average_decay: moving_average_variables = slim.get_model_variables() variable_averages = tf.train.ExponentialMovingAverage( FLAGS.moving_average_decay, global_step) else: moving_average_variables, variable_averages = None, None if FLAGS.quantize_delay >= 0: contrib_quantize.create_training_graph( quant_delay=FLAGS.quantize_delay) ######################################### # Configure the optimization procedure. # ######################################### with tf.device(deploy_config.optimizer_device()): learning_rate = _configure_learning_rate(dataset.num_samples, global_step) optimizer = _configure_optimizer(learning_rate) summaries.add(tf.summary.scalar('learning_rate', learning_rate)) if FLAGS.sync_replicas: # If sync_replicas is enabled, the averaging will be done in the chief # queue runner. optimizer = tf.train.SyncReplicasOptimizer( opt=optimizer, replicas_to_aggregate=FLAGS.replicas_to_aggregate, total_num_replicas=FLAGS.worker_replicas, variable_averages=variable_averages, variables_to_average=moving_average_variables) elif FLAGS.moving_average_decay: # Update ops executed locally by trainer. update_ops.append( variable_averages.apply(moving_average_variables)) # Variables to train. variables_to_train = _get_variables_to_train() # and returns a train_tensor and summary_op total_loss, clones_gradients = model_deploy.optimize_clones( clones, optimizer, var_list=variables_to_train) # Add total_loss to summary. summaries.add(tf.summary.scalar('total_loss', total_loss)) # Create gradient updates. grad_updates = optimizer.apply_gradients(clones_gradients, global_step=global_step) update_ops.append(grad_updates) update_op = tf.group(*update_ops) with tf.control_dependencies([update_op]): train_tensor = tf.identity(total_loss, name='train_op') # Add the summaries from the first clone. These contain the summaries # created by model_fn and either optimize_clones() or _gather_clone_loss(). summaries |= set( tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope)) # Merge all summaries together. summary_op = tf.summary.merge(list(summaries), name='summary_op') ########################### # Kicks off the training. # ########################### slim.learning.train( train_tensor, logdir=FLAGS.train_dir, master=FLAGS.master, is_chief=(FLAGS.task == 0), init_fn=_get_init_fn(), summary_op=summary_op, number_of_steps=FLAGS.max_number_of_steps, log_every_n_steps=FLAGS.log_every_n_steps, save_summaries_secs=FLAGS.save_summaries_secs, save_interval_secs=FLAGS.save_interval_secs, sync_optimizer=optimizer if FLAGS.sync_replicas else None)
def main(): # check required input arguments if not FLAGS.project_name: raise ValueError('You must supply a project name with --project_name') if not FLAGS.dataset_name: raise ValueError('You must supply a dataset name with --dataset_name') if not FLAGS.model_name in model_name_to_variables: raise ValueError( 'Model name not supported name please select one of the following model architecture: mobilenet_v1, mobilenet_v1_075, mobilenet_v1_050, mobilenet_v1_025, inception_v1' ) # set and check project_dir and experiment_dir. project_dir = os.path.join(FLAGS.project_dir, FLAGS.project_name) if not FLAGS.experiment_name: # list only directories that are names experiment_ experiment_dir = dataset_utils.create_new_experiment_dir(project_dir) else: experiment_dir = os.path.join(os.path.join(project_dir, 'experiments'), FLAGS.experiment_name) if not os.path.exists(experiment_dir): raise ValueError('Experiment directory {} does not exist.'.format( experiment_dir)) train_dir = os.path.join(experiment_dir, FLAGS.dataset_split_name) if not os.path.exists(train_dir): os.makedirs(train_dir) # set and check dataset_dir if FLAGS.image_dir: dataset_dir = convert_dataset.convert_img_to_tfrecord( project_dir, FLAGS.dataset_name, FLAGS.dataset_dir, FLAGS.image_dir, FLAGS.train_percentage, FLAGS.validation_percentage, FLAGS.test_percentage, FLAGS.train_image_size, FLAGS.train_image_size) else: if os.path.isdir(FLAGS.dataset_dir): dataset_dir = os.path.join(FLAGS.dataset_dir, FLAGS.dataset_name) else: dataset_dir = os.path.join(os.path.join(project_dir, 'datasets'), FLAGS.dataset_name) if not os.path.isdir(dataset_dir): raise ValueError( 'Can not find tfrecord dataset directory {}'.format(dataset_dir)) tf.logging.set_verbosity(tf.logging.INFO) with tf.Graph().as_default(): ####################### # Config model_deploy # ####################### deploy_config = model_deploy.DeploymentConfig( num_clones=FLAGS.num_clones, clone_on_cpu=FLAGS.clone_on_cpu, replica_id=FLAGS.task, num_replicas=FLAGS.worker_replicas, num_ps_tasks=FLAGS.num_ps_tasks) # Create global_step with tf.device(deploy_config.variables_device()): global_step = slim.create_global_step() ###################### # Select the dataset # ###################### dataset = dataset_factory.get_dataset(FLAGS.dataset_name, FLAGS.dataset_split_name, dataset_dir) ###################### # Select the network # ###################### network_fn = nets_factory.get_network_fn( FLAGS.model_name, num_classes=(dataset.num_classes - FLAGS.labels_offset), weight_decay=FLAGS.weight_decay, is_training=not FLAGS.feature_extraction, final_endpoint=FLAGS.final_endpoint) ##################################### # Select the preprocessing function # ##################################### preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name image_preprocessing_fn = preprocessing_factory.get_preprocessing( preprocessing_name, is_training=FLAGS.apply_image_augmentation, use_grayscale=FLAGS.use_grayscale) ############################################################## # Create a dataset provider that loads data from the dataset # ############################################################## with tf.device(deploy_config.inputs_device()): provider = slim.dataset_data_provider.DatasetDataProvider( dataset, num_readers=FLAGS.num_readers, common_queue_capacity=20 * FLAGS.batch_size, common_queue_min=10 * FLAGS.batch_size) [image, label] = provider.get(['image', 'label']) label -= FLAGS.labels_offset train_image_size = FLAGS.train_image_size or network_fn.default_image_size image 
= image_preprocessing_fn( image, train_image_size, train_image_size, add_image_summaries=FLAGS.add_image_summaries, crop_image=FLAGS.random_image_crop, min_object_covered=FLAGS.min_object_covered, rotate_image=FLAGS.random_image_rotation, random_flip=FLAGS.random_image_flip, roi=FLAGS.roi) images, labels = tf.train.batch( [image, label], batch_size=FLAGS.batch_size, num_threads=FLAGS.num_preprocessing_threads, capacity=5 * FLAGS.batch_size) labels = slim.one_hot_encoding( labels, dataset.num_classes - FLAGS.labels_offset) batch_queue = slim.prefetch_queue.prefetch_queue( [images, labels], capacity=2 * deploy_config.num_clones) #################### # Define the model # #################### def clone_fn(batch_queue): """Allows data parallelism by creating multiple clones of network_fn.""" images, labels = batch_queue.dequeue() logits, end_points = network_fn(images) ############################# # Specify the loss function # ############################# if FLAGS.imbalance_correction: # specify some class weightings class_weights = dataset.sorted_class_weights # deduce weights for batch samples based on their true label weights = tf.reduce_sum(tf.multiply(labels, class_weights), 1) slim.losses.softmax_cross_entropy( logits, labels, label_smoothing=FLAGS.label_smoothing, weights=weights) else: if 'AuxLogits' in end_points: slim.losses.softmax_cross_entropy( end_points['AuxLogits'], labels, label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss') else: slim.losses.softmax_cross_entropy( logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0) ############################# ## Calculation of metrics ## ############################# accuracy, accuracy_op = tf.metrics.accuracy( tf.argmax(labels, 1), tf.argmax(logits, 1)) precision, precision_op = tf.metrics.average_precision_at_k( tf.argmax(labels, 1), logits, 1) with tf.device('/device:CPU:0'): for class_id in range(dataset.num_classes): precision_at_k, precision_at_k_op = tf.metrics.precision_at_k( tf.argmax(labels, 1), logits, k=1, class_id=class_id) recall_at_k, recall_at_k_op = tf.metrics.recall_at_k( tf.argmax(labels, 1), logits, k=1, class_id=class_id) tf.add_to_collection('precision_at_{}'.format(class_id), precision_at_k) tf.add_to_collection('precision_at_{}_op'.format(class_id), precision_at_k_op) tf.add_to_collection('recall_at_{}'.format(class_id), recall_at_k) tf.add_to_collection('recall_at_{}_op'.format(class_id), recall_at_k_op) tf.add_to_collection('accuracy', accuracy) tf.add_to_collection('accuracy_op', accuracy_op) tf.add_to_collection('precision', precision) tf.add_to_collection('precision_op', precision_op) return end_points # Gather initial summaries. summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue]) first_clone_scope = deploy_config.clone_scope(0) # Gather update_ops from the first clone. These contain, for example, # the updates for the batch_norm variables created by network_fn. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope) # Add summaries for end_points. end_points = clones[0].outputs for end_point in end_points: x = end_points[end_point] summaries.add(tf.summary.histogram('activations/' + end_point, x)) summaries.add( tf.summary.scalar('sparsity/' + end_point, tf.nn.zero_fraction(x))) # Add summaries for losses. for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope): summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss)) # Add summaries for variables. 
for variable in slim.get_model_variables(): summaries.add(tf.summary.histogram(variable.op.name, variable)) ######################################################### ## Calculation of metrics for all clones ## ######################################################### # Metrics for all clones. accuracy = tf.get_collection('accuracy') accuracy_op = tf.get_collection('accuracy_op') precision = tf.get_collection('precision') precision_op = tf.get_collection('precision_op') # accuracy_op = tf.reshape(accuracy_op, []) # Stack and take the mean. accuracy = tf.reduce_mean(tf.stack(accuracy, axis=0)) accuracy_op = tf.reduce_mean(tf.stack(accuracy_op, axis=0)) precision = tf.reduce_mean(tf.stack(precision, axis=0)) precision_op = tf.reduce_mean(tf.stack(precision_op, axis=0)) # Add metric summaries. summaries.add(tf.summary.scalar('Metrics/accuracy', accuracy)) summaries.add(tf.summary.scalar('op/accuracy_op', accuracy_op)) summaries.add(tf.summary.scalar('Metrics/average_precision', precision)) summaries.add( tf.summary.scalar('op/average_precision_op', precision_op)) # Add precision/recall at each class to summary for class_id in range(dataset.num_classes): precision_at_k = tf.get_collection( 'precision_at_{}'.format(class_id)) precision_at_k_op = tf.get_collection( 'precision_at_{}_op'.format(class_id)) recall_at_k = tf.get_collection('recall_at_{}'.format(class_id)) recall_at_k_op = tf.get_collection( 'recall_at_{}_op'.format(class_id)) precision_at_k = tf.reduce_mean(tf.stack(precision_at_k, axis=0)) precision_at_k_op = tf.reduce_mean( tf.stack(precision_at_k_op, axis=0)) recall_at_k = tf.reduce_mean(tf.stack(recall_at_k, axis=0)) recall_at_k_op = tf.reduce_mean(tf.stack(recall_at_k_op, axis=0)) summaries.add( tf.summary.scalar( 'Metrics/class_{}_precision'.format(class_id), precision_at_k)) summaries.add( tf.summary.scalar('op/class_{}_precision_op'.format(class_id), precision_at_k_op)) summaries.add( tf.summary.scalar('Metrics/class_{}_recall'.format(class_id), recall_at_k)) summaries.add( tf.summary.scalar('op/class_{}_recall_op'.format(class_id), recall_at_k_op)) ################################# # Configure the moving averages # ################################# if FLAGS.moving_average_decay: moving_average_variables = slim.get_model_variables() variable_averages = tf.train.ExponentialMovingAverage( FLAGS.moving_average_decay, global_step) else: moving_average_variables, variable_averages = None, None if FLAGS.quantize_delay >= 0: contrib_quantize.create_training_graph( quant_delay=FLAGS.quantize_delay) ######################################### # Configure the optimization procedure. # ######################################### with tf.device(deploy_config.optimizer_device()): learning_rate = _configure_learning_rate(dataset.num_samples, global_step) optimizer = _configure_optimizer(learning_rate) summaries.add( tf.summary.scalar('Losses/learning_rate', learning_rate)) if FLAGS.sync_replicas: # If sync_replicas is enabled, the averaging will be done in the chief # queue runner. optimizer = tf.train.SyncReplicasOptimizer( opt=optimizer, replicas_to_aggregate=FLAGS.replicas_to_aggregate, total_num_replicas=FLAGS.worker_replicas, variable_averages=variable_averages, variables_to_average=moving_average_variables) elif FLAGS.moving_average_decay: # Update ops executed locally by trainer. update_ops.append( variable_averages.apply(moving_average_variables)) # Variables to train. 
variables_to_train = _get_variables_to_train() # and returns a train_tensor and summary_op total_loss, clones_gradients = model_deploy.optimize_clones( clones, optimizer, var_list=variables_to_train) # Add total_loss to summary. summaries.add(tf.summary.scalar('total_loss', total_loss)) loss = total_loss # Create gradient updates. grad_updates = optimizer.apply_gradients(clones_gradients, global_step=global_step) update_ops.append(grad_updates) update_op = tf.group(*update_ops) with tf.control_dependencies([update_op]): train_tensor = tf.identity(total_loss, name='train_op') # Add the summaries from the first clone. These contain the summaries # created by model_fn and either optimize_clones() or _gather_clone_loss(). summaries |= set( tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope)) # Merge all summaries together. summary_op = tf.summary.merge(list(summaries), name='summary_op') session_config = tf.ConfigProto( log_device_placement=FLAGS.verbose_placement, allow_soft_placement=not FLAGS.hard_placement) if not FLAGS.fixed_memory: session_config.gpu_options.allow_growth = True ########################### # Kicks off the training. # ########################### def train_step_fn(sess, train_op, global_step, train_step_kwargs): """Function that takes a gradient step and specifies whether to stop. Args: sess: The current session. train_op: An `Operation` that evaluates the gradients and returns the total loss. global_step: A `Tensor` representing the global training step. train_step_kwargs: A dictionary of keyword arguments. Returns: The total loss and a boolean indicating whether or not to stop training. Raises: ValueError: if 'should_trace' is in `train_step_kwargs` but `logdir` is not. """ start_time = time.time() trace_run_options = None run_metadata = None if 'should_trace' in train_step_kwargs: if 'logdir' not in train_step_kwargs: raise ValueError( 'logdir must be present in train_step_kwargs when ' 'should_trace is present') if sess.run(train_step_kwargs['should_trace']): trace_run_options = config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE) run_metadata = config_pb2.RunMetadata() total_loss, np_global_step = sess.run([train_op, global_step], options=trace_run_options, run_metadata=run_metadata) time_elapsed = time.time() - start_time if run_metadata is not None: tl = timeline.Timeline(run_metadata.step_stats) trace = tl.generate_chrome_trace_format() trace_filename = os.path.join( train_step_kwargs['logdir'], 'tf_trace-%d.json' % np_global_step) logging.info('Writing trace to %s', trace_filename) file_io.write_string_to_file(trace_filename, trace) if 'summary_writer' in train_step_kwargs: train_step_kwargs['summary_writer'].add_run_metadata( run_metadata, 'run_metadata-%d' % np_global_step) if 'should_log' in train_step_kwargs: if sess.run(train_step_kwargs['should_log']): print('global step {:d}: loss = {:1.4f} ({:.3f} sec/step)'. 
format(np_global_step, total_loss, time_elapsed)) if 'should_stop' in train_step_kwargs: should_stop = sess.run(train_step_kwargs['should_stop']) else: should_stop = False return total_loss, should_stop or train_step_fn.should_stop train_step_fn.should_stop = False # train_step_fn.accuracy = accuracy def exit_gracefully(signum, frame): interrupted = datetime.datetime.utcnow() # if not experiment_file is None : print('Interrupted on (UTC): ', interrupted, sep='', file=experiment_file) experiment_file.flush() train_step_fn.should_stop = True print('Interrupted on (UTC): ', interrupted, sep='') signal.signal(signal.SIGINT, exit_gracefully) signal.signal(signal.SIGTERM, exit_gracefully) start = datetime.datetime.utcnow() print('Started on (UTC): ', start, sep='') # record script flags (FLAGS). write to experiment file experiment_file_path = os.path.join(train_dir, 'experiment_setting.txt') experiment_file = open(experiment_file_path, 'w') print('Experiment metadata file:', file=experiment_file) print(experiment_file_path, file=experiment_file) print('========================', file=experiment_file) print('All command-line flags:', file=experiment_file) print(experiment_file_path, file=experiment_file) for key, value in vars(FLAGS).items(): print(key, ' : ', value, sep='', file=experiment_file) print('========================', file=experiment_file) print('Started on (UTC): ', start, sep='', file=experiment_file) experiment_file.flush() slim.learning.train( train_tensor, train_step_fn=train_step_fn, logdir=train_dir, master=FLAGS.master, is_chief=(FLAGS.task == 0), init_fn=_get_init_fn(), summary_op=summary_op, number_of_steps=FLAGS.max_number_of_steps, log_every_n_steps=FLAGS.log_every_n_steps, save_summaries_secs=FLAGS.save_summaries_secs, save_interval_secs=FLAGS.save_interval_secs, sync_optimizer=optimizer if FLAGS.sync_replicas else None, session_config=session_config) finish = datetime.datetime.utcnow() # generate and save graph (output file model_name_graph.pb) print('Generate frozen graph') # TODO: Simplify by loading checkpoint+graph and freezing together (no need to save graph) # genrate and save inference graph is_training = False is_video_model = False batch_size = None num_frames = None quantize = False write_text_graphdef = False output_file = os.path.join(train_dir, FLAGS.model_name + '_graph.pb') export_inference_graph(FLAGS.dataset_name, dataset_dir, FLAGS.model_name, FLAGS.labels_offset, is_training, FLAGS.final_endpoint, FLAGS.train_image_size, FLAGS.use_grayscale, is_video_model, batch_size, num_frames, quantize, write_text_graphdef, output_file) # record training session end print('Finished on (UTC): ', finish, sep='', file=experiment_file) print('Elapsed: ', finish - start, sep='', file=experiment_file) experiment_file.flush()
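# A hedged sketch of the follow-up freezing step hinted at by the TODO above:
# combine the exported inference GraphDef with the newest checkpoint in
# train_dir using TensorFlow's freeze_graph tool. The output node name is a
# placeholder that depends on the chosen model; this helper is not part of the
# script itself.
import tensorflow as tf
from tensorflow.python.tools import freeze_graph


def freeze_latest_checkpoint(graph_def_path, train_dir, output_node_name,
                             frozen_path):
  checkpoint = tf.train.latest_checkpoint(train_dir)
  freeze_graph.freeze_graph(
      input_graph=graph_def_path,
      input_saver='',
      input_binary=True,
      input_checkpoint=checkpoint,
      output_node_names=output_node_name,
      restore_op_name='save/restore_all',
      filename_tensor_name='save/Const:0',
      output_graph=frozen_path,
      clear_devices=True,
      initializer_nodes='')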
def build_model_fn(features, labels, mode, params): """The model_fn for MnasNet to be used with TPUEstimator. Args: features: `Tensor` of batched images. labels: `Tensor` of labels for the data samples mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}` params: `dict` of parameters passed to the model from the TPUEstimator, `params['batch_size']` is always provided and should be used as the effective batch size. Returns: A `TPUEstimatorSpec` for the model """ is_training = (mode == tf.estimator.ModeKeys.TRAIN) # This is essential, if using a keras-derived model. tf.keras.backend.set_learning_phase(is_training) if isinstance(features, dict): features = features['feature'] if mode == tf.estimator.ModeKeys.PREDICT: # Adds an identify node to help TFLite export. features = tf.identity(features, 'float_image_input') # In most cases, the default data format NCHW instead of NHWC should be # used for a significant performance boost on GPU. NHWC should be used # only if the network needs to be run on CPU since the pooling operations # are only supported on NHWC. TPU uses XLA compiler to figure out best layout. if params['data_format'] == 'channels_first': assert not params['transpose_input'] # channels_first only for GPU features = tf.transpose(features, [0, 3, 1, 2]) stats_shape = [3, 1, 1] else: stats_shape = [1, 1, 3] if params['transpose_input'] and mode != tf.estimator.ModeKeys.PREDICT: features = tf.transpose(features, [3, 0, 1, 2]) # HWCN to NHWC # Normalize the image to zero mean and unit variance. features -= tf.constant( imagenet_input.MEAN_RGB, shape=stats_shape, dtype=features.dtype) features /= tf.constant( imagenet_input.STDDEV_RGB, shape=stats_shape, dtype=features.dtype) has_moving_average_decay = (params['moving_average_decay'] > 0) tf.logging.info('Using open-source implementation for MnasNet definition.') override_params = {} if params['batch_norm_momentum']: override_params['batch_norm_momentum'] = params['batch_norm_momentum'] if params['batch_norm_epsilon']: override_params['batch_norm_epsilon'] = params['batch_norm_epsilon'] if params['dropout_rate']: override_params['dropout_rate'] = params['dropout_rate'] if params['data_format']: override_params['data_format'] = params['data_format'] if params['num_label_classes']: override_params['num_classes'] = params['num_label_classes'] if params['depth_multiplier']: override_params['depth_multiplier'] = params['depth_multiplier'] if params['depth_divisor']: override_params['depth_divisor'] = params['depth_divisor'] if params['min_depth']: override_params['min_depth'] = params['min_depth'] override_params['use_keras'] = params['use_keras'] def _build_model(model_name): """Build the model for a given model name.""" if model_name.startswith('mnasnet'): return mnasnet_models.build_mnasnet_model( features, model_name=model_name, training=is_training, override_params=override_params) elif model_name.startswith('mixnet'): return mixnet_builder.build_model( features, model_name=model_name, training=is_training, override_params=override_params) else: raise ValueError('Unknown model name {}'.format(model_name)) if params['precision'] == 'bfloat16': with tf.tpu.bfloat16_scope(): logits, _ = _build_model(params['model_name']) logits = tf.cast(logits, tf.float32) else: # params['precision'] == 'float32' logits, _ = _build_model(params['model_name']) if params['quantized_training']: try: from tensorflow.contrib import quantize # pylint: disable=g-import-not-at-top except ImportError as e: logging.exception('Quantized training is not 
supported in TensorFlow 2.x') raise e if is_training: tf.logging.info('Adding fake quantization ops for training.') quantize.create_training_graph( quant_delay=int(params['steps_per_epoch'] * FLAGS.quantization_delay_epochs)) else: tf.logging.info('Adding fake quantization ops for evaluation.') quantize.create_eval_graph() if mode == tf.estimator.ModeKeys.PREDICT: scaffold_fn = None if FLAGS.export_moving_average: # If the model is trained with moving average decay, to match evaluation # metrics, we need to export the model using moving average variables. restore_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir) variables_to_restore = get_pretrained_variables_to_restore( restore_checkpoint, load_moving_average=True) tf.logging.info('Restoring from the latest checkpoint: %s', restore_checkpoint) tf.logging.info(str(variables_to_restore)) def restore_scaffold(): saver = tf.train.Saver(variables_to_restore) return tf.train.Scaffold(saver=saver) scaffold_fn = restore_scaffold predictions = { 'classes': tf.argmax(logits, axis=1), 'probabilities': tf.nn.softmax(logits, name='softmax_tensor') } return tf.estimator.tpu.TPUEstimatorSpec( mode=mode, predictions=predictions, export_outputs={ 'classify': tf.estimator.export.PredictOutput(predictions) }, scaffold_fn=scaffold_fn) # If necessary, in the model_fn, use params['batch_size'] instead the batch # size flags (--train_batch_size or --eval_batch_size). batch_size = params['batch_size'] # pylint: disable=unused-variable # Calculate loss, which includes softmax cross entropy and L2 regularization. one_hot_labels = tf.one_hot(labels, params['num_label_classes']) cross_entropy = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=one_hot_labels, label_smoothing=params['label_smoothing']) # Add weight decay to the loss for non-batch-normalization variables. loss = cross_entropy + params['weight_decay'] * tf.add_n([ tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'batch_normalization' not in v.name ]) global_step = tf.train.get_global_step() if has_moving_average_decay: ema = tf.train.ExponentialMovingAverage( decay=params['moving_average_decay'], num_updates=global_step) ema_vars = mnas_utils.get_ema_vars() host_call = None if is_training: # Compute the current epoch and associated learning rate from global_step. current_epoch = ( tf.cast(global_step, tf.float32) / params['steps_per_epoch']) scaled_lr = params['base_learning_rate'] * (params['train_batch_size'] / 256.0) # pylint: disable=line-too-long learning_rate = mnas_utils.build_learning_rate(scaled_lr, global_step, params['steps_per_epoch']) optimizer = mnas_utils.build_optimizer(learning_rate) if params['use_tpu']: # When using TPU, wrap the optimizer with CrossShardOptimizer which # handles synchronization details between different TPU cores. To the # user, this should look like regular synchronous training. optimizer = tf.tpu.CrossShardOptimizer(optimizer) if params['add_summaries']: summary_writer = tf2.summary.create_file_writer( FLAGS.model_dir, max_queue=params['iterations_per_loop']) with summary_writer.as_default(): should_record = tf.equal(global_step % params['iterations_per_loop'], 0) with tf2.summary.record_if(should_record): tf2.summary.scalar('loss', loss, step=global_step) tf2.summary.scalar('learning_rate', learning_rate, step=global_step) tf2.summary.scalar('current_epoch', current_epoch, step=global_step) # Batch normalization requires UPDATE_OPS to be added as a dependency to # the train operation. 
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops + tf.summary.all_v2_summary_ops()): train_op = optimizer.minimize(loss, global_step) if has_moving_average_decay: with tf.control_dependencies([train_op]): train_op = ema.apply(ema_vars) else: train_op = None eval_metrics = None if mode == tf.estimator.ModeKeys.EVAL: def metric_fn(labels, logits): """Evaluation metric function. Evaluates accuracy. This function is executed on the CPU and should not directly reference any Tensors in the rest of the `model_fn`. To pass Tensors from the model to the `metric_fn`, provide as part of the `eval_metrics`. See https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec for more information. Arguments should match the list of `Tensor` objects passed as the second element in the tuple passed to `eval_metrics`. Args: labels: `Tensor` with shape `[batch]`. logits: `Tensor` with shape `[batch, num_classes]`. Returns: A dict of the metrics to return from evaluation. """ predictions = tf.argmax(logits, axis=1) top_1_accuracy = tf.metrics.accuracy(labels, predictions) in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32) top_5_accuracy = tf.metrics.mean(in_top_5) return { 'top_1_accuracy': top_1_accuracy, 'top_5_accuracy': top_5_accuracy, } eval_metrics = (metric_fn, [labels, logits]) num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()]) tf.logging.info('number of trainable parameters: {}'.format(num_params)) # Prepares scaffold_fn if needed. scaffold_fn = None if is_training and FLAGS.init_checkpoint: variables_to_restore = get_pretrained_variables_to_restore( FLAGS.init_checkpoint, has_moving_average_decay) tf.logging.info('Initializing from pretrained checkpoint: %s', FLAGS.init_checkpoint) if FLAGS.use_tpu: def init_scaffold(): tf.train.init_from_checkpoint(FLAGS.init_checkpoint, variables_to_restore) return tf.train.Scaffold() scaffold_fn = init_scaffold else: tf.train.init_from_checkpoint(FLAGS.init_checkpoint, variables_to_restore) restore_vars_dict = None if not is_training and has_moving_average_decay: # Load moving average variables for eval. restore_vars_dict = ema.variables_to_restore(ema_vars) def eval_scaffold(): saver = tf.train.Saver(restore_vars_dict) return tf.train.Scaffold(saver=saver) scaffold_fn = eval_scaffold return tf.estimator.tpu.TPUEstimatorSpec( mode=mode, loss=loss, train_op=train_op, host_call=host_call, eval_metrics=eval_metrics, scaffold_fn=scaffold_fn)
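# Illustrative arithmetic for the quant_delay used in the quantized-training
# branch above (my numbers, not the script's defaults): with ~1.28M ImageNet
# training images and a batch size of 1024, steps_per_epoch is about 1251, so
# delaying quantization by 3 epochs means the fake-quant ops only start
# updating their ranges after roughly 3753 global steps.
steps_per_epoch = 1281167 / 1024        # ~1251.1
quant_delay = int(steps_per_epoch * 3)  # 3753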
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  # Configure the parameters for multi-GPU training.
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,      # number of GPUs
      clone_on_cpu=FLAGS.clone_on_cpu,  # defaults to False
      replica_id=FLAGS.task,            # task id
      num_replicas=FLAGS.num_replicas,  # defaults to 1
      num_ps_tasks=FLAGS.num_ps_tasks)  # defaults to 0

  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')
  clone_batch_size = FLAGS.train_batch_size // config.num_clones  # each GPU gets an equal share of the batch

  tf.gfile.MakeDirs(FLAGS.train_logdir)  # create the directory for training logs
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      dataset = data_generator.Dataset(  # dataset parameters
          dataset_name=FLAGS.dataset,    # dataset name, e.g. cityscapes
          split_name=FLAGS.train_split,  # tfrecord split to train on, defaults to 'train'
          dataset_dir=FLAGS.dataset_dir,  # directory containing the tfrecord files
          batch_size=clone_batch_size,    # per-clone batch size after the split across GPUs
          crop_size=[int(sz) for sz in FLAGS.train_crop_size],  # training crop size, e.g. 513x513
          min_resize_value=FLAGS.min_resize_value,  # defaults to None
          max_resize_value=FLAGS.max_resize_value,  # defaults to None
          resize_factor=FLAGS.resize_factor,        # defaults to None
          min_scale_factor=FLAGS.min_scale_factor,  # minimum rescaling factor for data augmentation, defaults to 0.5
          max_scale_factor=FLAGS.max_scale_factor,  # maximum rescaling factor for data augmentation, defaults to 2
          scale_factor_step_size=FLAGS.scale_factor_step_size,  # step size between scale factors, defaults to 0.25 (from 0.5 to 2)
          model_variant=FLAGS.model_variant,  # model variant, e.g. xception_65
          num_readers=4,  # number of readers; increase for multi-GPU setups to speed up the input pipeline
          is_training=True,
          should_shuffle=True,
          should_repeat=True)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      # Counter: global_step is incremented by one for every trained batch.
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab  # DeepLab model function
      model_args = (dataset.get_one_shot_iterator(), {
          common.OUTPUT_TYPE: dataset.num_of_classes
      }, dataset.ignore_label)  # model arguments
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in tf.model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for images, labels, semantic predictions.
    if FLAGS.save_summaries_images:  # defaults to False
      summary_image = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
      summaries.add(
          tf.summary.image('samples/%s' % common.IMAGE, summary_image))

      first_clone_label = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
      # Scale up summary image pixel values for better visualization.
      pixel_scaling = max(1, 255 // dataset.num_of_classes)
      summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL, summary_label))

      first_clone_output = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
      predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)
      summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image(
              'samples/%s' % common.OUTPUT_TYPE, summary_predictions))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(  # get the model learning rate
          FLAGS.learning_policy,             # 'poly' learning rate policy
          FLAGS.base_learning_rate,          # 0.0001
          FLAGS.learning_rate_decay_step,    # decay the learning rate every 2000 steps
          FLAGS.learning_rate_decay_factor,  # 0.1
          FLAGS.training_number_of_steps,    # total training steps, e.g. 20000
          FLAGS.learning_power,              # poly power 0.9
          FLAGS.slow_start_step,             # 0
          FLAGS.slow_start_learning_rate,    # 1e-4, slow-start learning rate
          decay_steps=FLAGS.decay_steps,     # 0.0
          end_learning_rate=FLAGS.end_learning_rate)  # 0.0

      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

      # Training optimizer.
      if FLAGS.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      elif FLAGS.optimizer == 'adam':
        # Adam optimizer: adaptive method that uses second-moment gradient estimates.
        optimizer = tf.train.AdamOptimizer(
            learning_rate=FLAGS.adam_learning_rate, epsilon=FLAGS.adam_epsilon)
      else:
        raise ValueError('Unknown optimizer')

    if FLAGS.quantize_delay_step >= 0:  # defaults to -1, i.e. quantization disabled
      if FLAGS.num_clones > 1:
        raise ValueError('Quantization doesn\'t support multi-clone yet.')
      contrib_quantize.create_training_graph(
          quant_delay=FLAGS.quantize_delay_step)

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps  # FLAGS.startup_delay_steps defaults to 15

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)  # compute total_loss
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      # Get the gradient multipliers, e.g.
      # grad_mult: {'logits/semantic/biases': 2.0, 'logits/semantic/weights': 1.0}
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op: apply_gradients applies the computed
      # gradients to the variables and also increments global_step.
      grad_updates = optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Start the training.
    profile_dir = FLAGS.profile_logdir  # defaults to None
    if profile_dir is not None:
      tf.gfile.MakeDirs(profile_dir)

    with contrib_tfprof.ProfileContext(
        enabled=profile_dir is not None, profile_dir=profile_dir):
      init_fn = None
      if FLAGS.tf_initial_checkpoint:  # load pre-trained weights
        init_fn = train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True)

      slim.learning.train(
          train_tensor,
          logdir=FLAGS.train_logdir,
          log_every_n_steps=FLAGS.log_steps,
          master=FLAGS.master,
          number_of_steps=FLAGS.training_number_of_steps,
          is_chief=(FLAGS.task == 0),
          session_config=session_config,
          startup_delay_steps=startup_delay_steps,
          init_fn=init_fn,
          summary_op=summary_op,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
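Note that this training script only rewrites the training graph with fake-quantization ops; a checkpoint trained this way still has to go through the eval-graph rewrite before it can be frozen and converted to TFLite. Below is a minimal export sketch of that step; the build_inference_graph helper, the 'ImageTensor' placeholder name, and the 513x513 input shape are assumptions for illustration, not part of the script above.

import tensorflow as tf
from tensorflow.contrib import quantize as contrib_quantize

# Rebuild the model for inference, rewrite it with eval fake-quant ops,
# restore the trained checkpoint, and write a GraphDef that freeze_graph /
# TFLite tooling can consume. build_inference_graph is a hypothetical helper.
eval_graph = tf.Graph()
with eval_graph.as_default():
  images = tf.placeholder(tf.float32, [1, 513, 513, 3], name='ImageTensor')
  logits = build_inference_graph(images)  # assumed model-building helper
  contrib_quantize.create_eval_graph()    # in-place rewrite for inference

  saver = tf.train.Saver()
  with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint(FLAGS.train_logdir))
    tf.train.write_graph(eval_graph.as_graph_def(), FLAGS.train_logdir,
                         'eval_graph.pbtxt', as_text=True)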
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks,
    )

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    keys_to_features = {
        "image/encoded": tf.FixedLenFeature((), tf.string, default_value=""),
        "image/format": tf.FixedLenFeature((), tf.string, default_value="png"),
        "image/class/label": tf.FixedLenFeature(
            [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
    }
    items_to_handlers = {
        "image": slim.tfexample_decoder.Image(),
        "label": slim.tfexample_decoder.Tensor("image/class/label"),
    }
    items_to_descs = {
        "image": "Color image",
        "label": "Class idx",
    }
    label_idx_to_name = {}
    for i, label in enumerate(CLASSES):
      label_idx_to_name[i] = label

    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)

    file_pattern = "tfm_clf_%s.*"
    file_pattern = os.path.join(
        FLAGS.records_name, file_pattern % FLAGS.dataset_split_name)

    dataset = slim.dataset.Dataset(
        data_sources=file_pattern,  # TODO UPDATE
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=80000,  # TODO UPDATE
        items_to_descriptions=items_to_descs,
        num_classes=len(CLASSES),
        labels_to_names=label_idx_to_name,
    )

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True,
    )

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name, is_training=True, use_grayscale=FLAGS.use_grayscale)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=FLAGS.num_readers,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size,
      )
      [image, label] = provider.get(["image", "label"])
      label -= FLAGS.labels_offset

      train_image_size = FLAGS.train_image_size or network_fn.default_image_size

      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size,
      )
      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - FLAGS.labels_offset)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating network_fn clones."""
      images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
      if "AuxLogits" in end_points:
        slim.losses.softmax_cross_entropy(
            end_points["AuxLogits"],
            labels,
            label_smoothing=FLAGS.label_smoothing,
            weights=0.4,
            scope="aux_loss",
        )
      slim.losses.softmax_cross_entropy(
          logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram("activations/" + end_point, x))
      summaries.add(
          tf.summary.scalar("sparsity/" + end_point, tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar("losses/%s" % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    if FLAGS.quantize_delay >= 0:
      contrib_quantize.create_training_graph(quant_delay=FLAGS.quantize_delay)

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar("learning_rate", learning_rate))

    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          total_num_replicas=FLAGS.worker_replicas,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables,
      )
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # optimize_clones computes the total loss across clones and the clone
    # gradients; a train_tensor and summary_op are built from them below.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones, optimizer, var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar("total_loss", total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(
        clones_gradients, global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name="train_op")

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name="summary_op")

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=optimizer if FLAGS.sync_replicas else None,
    )
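FLAGS.quantize_delay above is expressed in global steps, not epochs, so if you want fake-quantization to start only after the float model has roughly converged, the delay has to be derived from the dataset size and batch size. A small sketch of that conversion; the quant_delay_steps helper, the two-epoch figure, and the batch size of 32 are assumptions (80000 just echoes the num_samples placeholder above):

# Hypothetical helper: delay fake-quantization until roughly `num_epochs`
# epochs of float-only training have completed.
def quant_delay_steps(num_samples, batch_size, num_epochs=2.0):
  steps_per_epoch = num_samples // batch_size
  return int(steps_per_epoch * num_epochs)

# e.g. 80000 samples at an assumed batch size of 32 gives 2500 steps/epoch,
# so two epochs -> --quantize_delay=5000.
# contrib_quantize.create_training_graph(quant_delay=quant_delay_steps(80000, 32))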
  sc1 = tf.add(conv2, pool1, name='sc1')
  conv3 = slim.conv2d(sc1, 128, [3, 3], stride=2, scope='conv3')
  downsample1 = tf.image.resize_nearest_neighbor(
      sc1, [64, 64], name='nearest_neighbor_downsample')
  concat1 = tf.concat([conv3, downsample1], axis=3, name='concat1')
  dconv1 = slim.conv2d_transpose(concat1, 128, [3, 3], stride=2, scope='dconv1')
  upsample1 = tf.image.resize_images(concat1, [128, 128])
  concat2 = tf.concat([dconv1, upsample1], axis=3, name='concat2')
  conv4 = slim.conv2d(upsample1, 128, [1, 1], scope='conv4')
  conv5 = slim.conv2d(concat2, 128, [1, 1], scope='conv5')
  sc2 = tf.add(conv4, conv5, name='sc2')
  pool2 = tf.reduce_mean(sc2, axis=[1, 2], name='global_average_pool')
  return pool2


g = tf.Graph()
with g.as_default():
  with slim.arg_scope(arg_scope()):
    end = net()
quantize.create_training_graph(input_graph=g, quant_delay=2000000)
tf.summary.FileWriter('.', g)
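A quick way to confirm that the rewrite took effect is to look for the FakeQuantWithMinMaxVars ops it inserts after the eligible conv/add outputs. A small sketch over the graph g built above:

# Count the fake-quantization ops that create_training_graph inserted into g.
fake_quant_ops = [op.name for op in g.get_operations()
                  if op.type == 'FakeQuantWithMinMaxVars']
print('inserted %d fake-quant ops' % len(fake_quant_ops))
for name in fake_quant_ops[:5]:
  print(name)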
def add_final_retrain_ops(class_count, final_tensor_name, bottleneck_tensor,
                          quantize_layer, is_training):
  """Adds a new softmax and fully-connected layer for training and eval.

  We need to retrain the top layer to identify our new classes, so this
  function adds the right operations to the graph, along with some variables
  to hold the weights, and then sets up all the gradients for the backward
  pass.

  The set up for the softmax and fully-connected layers is based on:
  https://www.tensorflow.org/tutorials/mnist/beginners/index.html

  Args:
    class_count: Integer of how many categories of things we're trying to
        recognize.
    final_tensor_name: Name string for the new final node that produces
        results.
    bottleneck_tensor: The output of the main CNN graph.
    quantize_layer: Boolean, specifying whether the newly added layer should be
        instrumented for quantization with TF-Lite.
    is_training: Boolean, specifying whether the newly added layer is for
        training or eval.

  Returns:
    The tensors for the training and cross entropy results, and tensors for the
    bottleneck input and ground truth input.
  """
  batch_size, bottleneck_tensor_size = bottleneck_tensor.get_shape().as_list()
  assert batch_size is None, 'We want to work with arbitrary batch size.'
  with tf.name_scope('input'):
    bottleneck_input = tf.placeholder_with_default(
        bottleneck_tensor,
        shape=[batch_size, bottleneck_tensor_size],
        name='BottleneckInputPlaceholder')

    ground_truth_input = tf.placeholder(
        tf.int64, [batch_size], name='GroundTruthInput')

  # Organizing the following ops so they are easier to see in TensorBoard.
  layer_name = 'final_retrain_ops'
  with tf.name_scope(layer_name):
    with tf.name_scope('weights'):
      initial_value = tf.truncated_normal(
          [bottleneck_tensor_size, class_count], stddev=0.001)
      layer_weights = tf.Variable(initial_value, name='final_weights')
      variable_summaries(layer_weights)

    with tf.name_scope('biases'):
      layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
      variable_summaries(layer_biases)

    with tf.name_scope('Wx_plus_b'):
      logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
      tf.summary.histogram('pre_activations', logits)

  final_tensor = tf.nn.softmax(logits, name=final_tensor_name)

  # The tf.contrib.quantize functions rewrite the graph in place for
  # quantization. The imported model graph has already been rewritten, so upon
  # calling these rewrites, only the newly added final layer will be
  # transformed.
  if quantize_layer:
    if is_training:
      contrib_quantize.create_training_graph()
    else:
      contrib_quantize.create_eval_graph()

  tf.summary.histogram('activations', final_tensor)

  # If this is an eval graph, we don't need to add loss ops or an optimizer.
  if not is_training:
    return None, None, bottleneck_input, ground_truth_input, final_tensor

  with tf.name_scope('cross_entropy'):
    cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
        labels=ground_truth_input, logits=logits)

  tf.summary.scalar('cross_entropy', cross_entropy_mean)

  with tf.name_scope('train'):
    optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
    train_step = optimizer.minimize(cross_entropy_mean)

  return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input,
          final_tensor)
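For context, here is a hedged sketch of how this function might be called twice: once for the training graph and once for an eval/export graph, since the quantization rewrite differs between the two. The class_count of 5 and the 'final_result' tensor name are assumptions for illustration, not values from the snippet above.

# Training graph: adds the softmax layer plus loss/optimizer and, when
# quantize_layer=True, rewrites the new ops with create_training_graph().
(train_step, cross_entropy, bottleneck_input,
 ground_truth_input, final_tensor) = add_final_retrain_ops(
     class_count=5,                     # assumed number of labels
     final_tensor_name='final_result',  # assumed output tensor name
     bottleneck_tensor=bottleneck_tensor,
     quantize_layer=True,
     is_training=True)

# Eval/export graph: typically applied to a separate copy of the graph so
# that create_eval_graph() is used instead; the first two returned values
# are None because no loss or optimizer is added.
_, _, eval_bottleneck_input, eval_ground_truth, eval_final_tensor = (
    add_final_retrain_ops(5, 'final_result', bottleneck_tensor,
                          quantize_layer=True, is_training=False))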
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError(
        'You must supply the dataset directory with --dataset_dir')
  if not FLAGS.frozen_pb:
    raise ValueError('You must supply the frozen pb with --frozen_pb')
  if not FLAGS.output_node_name:
    raise ValueError(
        'You must supply the output node name with --output_node_name')
  if not FLAGS.output_dir:
    raise ValueError(
        'You must supply the output directory with --output_dir')

  tf.logging.set_verbosity(tf.logging.INFO)

  tfrecords = prepare_tfrecords(FLAGS.dataset_name, FLAGS.dataset_dir,
                                FLAGS.dataset_split_name)

  if FLAGS.max_num_batches:
    num_batches = FLAGS.max_num_batches
  else:
    num_records = sum(
        [len(list(tf.python_io.tf_record_iterator(r))) for r in tfrecords])
    num_batches = int(math.ceil(num_records / float(FLAGS.batch_size)))

  tf.logging.info('Load GraphDef from frozen_pb {}'.format(FLAGS.frozen_pb))
  graph_def = load_graph_def(FLAGS.frozen_pb)

  tf.logging.info('Quantize Graph')
  with tf.Session() as sess:
    tf.import_graph_def(graph_def, name='')
    quantized_graph = qg.create_training_graph(sess.graph)
    quantized_inf_graph = qg.create_eval_graph(sess.graph)

  # Initialize `iterator` with training data.
  with tf.Session(graph=quantized_graph) as sess:
    tf.logging.info('Prepare dataset')
    with tf.name_scope("dataset"):
      filenames = tf.placeholder(tf.string, shape=[None])
      dataset = prepare_dataset(filenames, FLAGS.dataset_name,
                                FLAGS.input_size, batch_size=FLAGS.batch_size)
      iterator = dataset.make_initializable_iterator()
      next_batch = iterator.get_next()

    tf.logging.info('Prepare metrics')
    lbls, preds, accuracy, acc_update_op = prepare_metrics(FLAGS.dataset_name)

    tf.logging.info('Prepare Saver')
    saver = tf.train.Saver()

    if FLAGS.summary_dir:
      tf.logging.info('Prepare summary writer')
      summary_writer = tf.summary.FileWriter(FLAGS.summary_dir)

    # Initialize variables and the dataset iterator.
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(iterator.initializer, feed_dict={filenames: tfrecords})

    graph = sess.graph

    # Get the input (x) and output (y) tensors.
    x = graph.get_tensor_by_name('{}:0'.format(FLAGS.input_node_name))
    y = graph.get_tensor_by_name('{}:0'.format(FLAGS.output_node_name))

    # Summarize all min/max variables.
    # print(graph.get_collection('variables')[3].eval())
    for var in graph.get_collection('variables'):
      tf.summary.scalar(var.name, var)
    summaries = tf.summary.merge_all()

    for step in range(num_batches):
      images, labels = sess.run(next_batch)
      ys = sess.run(y, feed_dict={x: images})
      sess.run(acc_update_op, feed_dict={lbls: labels, preds: ys})
      summary = sess.run(summaries)
      if FLAGS.summary_dir:
        summary_writer.add_summary(summary, step)

    print('Accuracy: [{:.4f}]'.format(sess.run(accuracy)))

    if FLAGS.summary_dir:
      summary_writer.add_graph(graph)

    # Save the graph and checkpoints.
    saver.save(sess, os.path.join(FLAGS.output_dir, "model.ckpt"))
    # tf.train.write_graph(graph, FLAGS.output_dir, 'quantor.pb', as_text=False)
    tf.train.write_graph(quantized_inf_graph, FLAGS.output_dir, 'quantor.pb',
                         as_text=False)
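Once quantor.pb and the saved checkpoint have been frozen together (e.g. with the freeze_graph tool), the usual next step is TFLite conversion with a quantized inference type. A sketch under those assumptions, using the TF 1.x TFLiteConverter; the frozen filename, input/output array names, and the (128, 128) input stats are placeholders, not values from the script above.

import tensorflow as tf

# Assumed: 'frozen_quantor.pb' is quantor.pb with the trained variables folded
# in, and the array names match the --input_node_name / --output_node_name
# flags used above.
converter = tf.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file='frozen_quantor.pb',
    input_arrays=['input'],
    output_arrays=['MobilenetV1/Predictions/Reshape_1'])
converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
# Mean/std-dev stats for the input tensor are model specific; (128, 128) is a
# common choice for Mobilenet-style [-1, 1] preprocessing.
converter.quantized_input_stats = {'input': (128.0, 128.0)}
tflite_model = converter.convert()
with open('quantor.tflite', 'wb') as f:
  f.write(tflite_model)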