Example #1
def add_final_retrain_ops(class_count, final_tensor_name, bottleneck_tensor,
                          quantize_layer, is_training):
  batch_size, bottleneck_tensor_size = bottleneck_tensor.get_shape().as_list()
  assert batch_size is None, 'We want to work with arbitrary batch size.'
  with tf.name_scope('input'):
    bottleneck_input = tf.placeholder_with_default(
        bottleneck_tensor,
        shape=[batch_size, bottleneck_tensor_size],
        name='BottleneckInputPlaceholder')

    ground_truth_input = tf.placeholder(
        tf.int64, [batch_size], name='GroundTruthInput')

  # Organizing the following ops so they are easier to see in TensorBoard.
  layer_name = 'final_retrain_ops'
  with tf.name_scope(layer_name):
    with tf.name_scope('weights'):
      initial_value = tf.truncated_normal(
          [bottleneck_tensor_size, class_count], stddev=0.001)
      layer_weights = tf.Variable(initial_value, name='final_weights')
      variable_summaries(layer_weights)

    with tf.name_scope('biases'):
      layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
      variable_summaries(layer_biases)

    with tf.name_scope('Wx_plus_b'):
      logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
      tf.summary.histogram('pre_activations', logits)

  final_tensor = tf.nn.softmax(logits, name=final_tensor_name)

  if quantize_layer:
    if is_training:
      contrib_quantize.create_training_graph()
    else:
      contrib_quantize.create_eval_graph()

  tf.summary.histogram('activations', final_tensor)

  if not is_training:
    return None, None, bottleneck_input, ground_truth_input, final_tensor

  with tf.name_scope('cross_entropy'):
    cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
        labels=ground_truth_input, logits=logits)

  tf.summary.scalar('cross_entropy', cross_entropy_mean)

  with tf.name_scope('train'):
    optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
    train_step = optimizer.minimize(cross_entropy_mean)

  return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input,
          final_tensor)
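
A minimal sketch of how this function might be driven; the bottleneck width, class count, and session loop below are assumptions, and variable_summaries, contrib_quantize, and FLAGS.learning_rate are expected to come from the surrounding retraining script.

graph = tf.Graph()
with graph.as_default():
  # Stand-in for the bottleneck output of a pretrained feature extractor.
  bottleneck_tensor = tf.placeholder(tf.float32, [None, 2048])
  (train_step, cross_entropy, bottleneck_input, ground_truth_input,
   final_tensor) = add_final_retrain_ops(
       class_count=5,
       final_tensor_name='final_result',
       bottleneck_tensor=bottleneck_tensor,
       quantize_layer=True,
       is_training=True)

with tf.Session(graph=graph) as sess:
  sess.run(tf.global_variables_initializer())
  # sess.run(train_step, feed_dict=...) would feed precomputed bottleneck
  # batches and integer labels for each training step.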
Example #2
def build_model():
    """Builds graph for model to train with rewrites for quantization.

  Returns:
    g: Graph with fake quantization ops and batch norm folding suitable for
    training quantized weights.
    train_tensor: Train op for execution during training.
  """
    g = tf.Graph()
    with g.as_default(), tf.device(
            tf.train.replica_device_setter(FLAGS.ps_tasks)):
        inputs, labels = imagenet_input(is_training=True)
        with slim.arg_scope(
                mobilenet_v1.mobilenet_v1_arg_scope(is_training=True)):
            logits, _ = mobilenet_v1.mobilenet_v1(
                inputs,
                is_training=True,
                depth_multiplier=FLAGS.depth_multiplier,
                num_classes=FLAGS.num_classes,
                final_endpoint=FLAGS.final_endpoint)

        tf.losses.softmax_cross_entropy(labels, logits)

        # Call the rewriter to produce a graph with fake-quant ops and folded
        # batch norms. quant_delay postpones the start of quantization until
        # quant_delay steps have passed, which helps final model accuracy.
        if FLAGS.quantize:
            contrib_quantize.create_training_graph(
                quant_delay=get_quant_delay())

        total_loss = tf.losses.get_total_loss(name='total_loss')
        # Configure the learning rate using an exponential decay.
        num_epochs_per_decay = 2.5
        imagenet_size = 1271167
        decay_steps = int(imagenet_size / FLAGS.batch_size *
                          num_epochs_per_decay)

        learning_rate = tf.train.exponential_decay(
            get_learning_rate(),
            tf.train.get_or_create_global_step(),
            decay_steps,
            _LEARNING_RATE_DECAY_FACTOR,
            staircase=True)
        opt = tf.train.GradientDescentOptimizer(learning_rate)

        train_tensor = slim.learning.create_train_op(total_loss, optimizer=opt)

    slim.summaries.add_scalar_summary(total_loss, 'total_loss', 'losses')
    slim.summaries.add_scalar_summary(learning_rate, 'learning_rate',
                                      'training')
    return g, train_tensor
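
A hedged sketch of how the returned graph and train op might be consumed; the flag names (FLAGS.checkpoint_dir, FLAGS.number_of_steps) are assumptions standing in for whatever the surrounding trainer defines.

def train_model():
    # Build the rewritten (fake-quantized) training graph and drive it with slim.
    g, train_tensor = build_model()
    with g.as_default():
        slim.learning.train(
            train_tensor,
            FLAGS.checkpoint_dir,
            is_chief=(FLAGS.task == 0),
            master=FLAGS.master,
            number_of_steps=FLAGS.number_of_steps,
            global_step=tf.train.get_global_step())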
Example #3
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')

    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    common.outputlogMessage('Training on %s set' % FLAGS.train_split)
    common.outputlogMessage('Dataset: %s' % FLAGS.dataset)
    common.outputlogMessage('train_crop_size: %s' % str(FLAGS.train_crop_size))
    common.outputlogMessage(str(FLAGS.train_crop_size))
    common.outputlogMessage('atrous_rates: %s' % str(FLAGS.atrous_rates))
    common.outputlogMessage('number of classes: %s' % str(FLAGS.num_classes))
    common.outputlogMessage('Ignore label value: %s' % str(FLAGS.ignore_label))
    pid = os.getpid()
    with open('train_py_pid.txt', 'w') as f_obj:
        f_obj.write('%d' % pid)

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            dataset = data_generator.Dataset(
                dataset_name=FLAGS.dataset,
                split_name=FLAGS.train_split,
                dataset_dir=FLAGS.dataset_dir,
                batch_size=clone_batch_size,
                crop_size=[int(sz) for sz in FLAGS.train_crop_size],
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                model_variant=FLAGS.model_variant,
                num_readers=4,
                is_training=True,
                should_shuffle=True,
                should_repeat=True,
                num_classes=FLAGS.num_classes,
                ignore_label=FLAGS.ignore_label)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (dataset.get_one_shot_iterator(), {
                common.OUTPUT_TYPE: dataset.num_of_classes
            }, dataset.ignore_label)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in tf.model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for images, labels, semantic predictions
        if FLAGS.save_summaries_images:
            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
            summaries.add(
                tf.summary.image('samples/%s' % common.IMAGE, summary_image))

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
            # Scale up summary image pixel values for better visualization.
            pixel_scaling = max(1, 255 // dataset.num_of_classes)
            summary_label = tf.cast(first_clone_label * pixel_scaling,
                                    tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL, summary_label))

            first_clone_output = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
            predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

            summary_predictions = tf.cast(predictions * pixel_scaling,
                                          tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                 summary_predictions))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy,
                FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps,
                FLAGS.learning_power,
                FLAGS.slow_start_step,
                FLAGS.slow_start_learning_rate,
                decay_steps=FLAGS.decay_steps,
                end_learning_rate=FLAGS.end_learning_rate)

            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

            if FLAGS.optimizer == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                       FLAGS.momentum)
            elif FLAGS.optimizer == 'adam':
                optimizer = tf.train.AdamOptimizer(
                    learning_rate=FLAGS.adam_learning_rate,
                    epsilon=FLAGS.adam_epsilon)
            else:
                raise ValueError('Unknown optimizer')

        if FLAGS.quantize_delay_step >= 0:
            if FLAGS.num_clones > 1:
                raise ValueError(
                    'Quantization doesn\'t support multi-clone yet.')
            contrib_quantize.create_training_graph(
                quant_delay=FLAGS.quantize_delay_step)

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows ops without a GPU implementation to be placed
        # on the CPU.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Start the training.
        profile_dir = FLAGS.profile_logdir
        if profile_dir is not None:
            tf.gfile.MakeDirs(profile_dir)

        with contrib_tfprof.ProfileContext(enabled=profile_dir is not None,
                                           profile_dir=profile_dir):
            init_fn = None
            if FLAGS.tf_initial_checkpoint:
                init_fn = train_utils.get_model_init_fn(
                    FLAGS.train_logdir,
                    FLAGS.tf_initial_checkpoint,
                    FLAGS.initialize_last_layer,
                    last_layers,
                    ignore_missing_vars=True)

            slim.learning.train(train_tensor,
                                logdir=FLAGS.train_logdir,
                                log_every_n_steps=FLAGS.log_steps,
                                master=FLAGS.master,
                                number_of_steps=FLAGS.training_number_of_steps,
                                is_chief=(FLAGS.task == 0),
                                session_config=session_config,
                                startup_delay_steps=startup_delay_steps,
                                init_fn=init_fn,
                                summary_op=summary_op,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
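
The training-graph rewrite above has a matching rewrite for export. A minimal, self-contained illustration of that pairing on a toy model (the real DeepLab export happens in a separate script, so the layer shapes here are purely illustrative):

import tensorflow as tf
from tensorflow.contrib import quantize as contrib_quantize

with tf.Graph().as_default():
    images = tf.placeholder(tf.float32, [1, 224, 224, 3])
    net = tf.layers.conv2d(images, 8, 3, activation=tf.nn.relu)
    logits = tf.layers.dense(tf.layers.flatten(net), 10)
    # Insert inference-time fake-quant ops; this mirrors the
    # create_training_graph(quant_delay=...) call used during training.
    contrib_quantize.create_eval_graph()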
Example #4
def model_fn(features, labels, mode, params, tf_sess=False):
    """
    Create the model for the tf.estimator API.

    Args:
        features: if input_layout == 'nhwc', a tensor with shape:
                [BATCH_SIZE, go.N, go.N, get_features_planes()]
            else, a tensor with shape:
                [BATCH_SIZE, get_features_planes(), go.N, go.N]
        labels: dict from string to tensor with shape
            'pi_tensor': [BATCH_SIZE, go.N * go.N + 1]
            'value_tensor': [BATCH_SIZE]
        mode: a tf.estimator.ModeKeys (batchnorm params update for TRAIN only)
        params: A dictionary (Typically derived from the FLAGS object.)
    Returns: tf.estimator.EstimatorSpec with props
        mode: same as mode arg
        predictions: dict of tensors
            'policy': [BATCH_SIZE, go.N * go.N + 1]
            'value': [BATCH_SIZE]
        loss: a single value tensor
        train_op: train op
        eval_metric_ops
    return dict of tensors
        logits: [BATCH_SIZE, go.N * go.N + 1]
    """

    policy_output, value_output, logits = model_inference_fn(
        features, mode == tf.estimator.ModeKeys.TRAIN, params)

    # train ops
    policy_cost = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits,
                                                   labels=tf.stop_gradient(
                                                       labels['pi_tensor'])))

    value_cost = params['value_cost_weight'] * tf.reduce_mean(
        tf.square(value_output - labels['value_tensor']))

    reg_vars = [
        v for v in tf.trainable_variables()
        if 'bias' not in v.name and 'beta' not in v.name
    ]
    l2_cost = params['l2_strength'] * \
        tf.add_n([tf.nn.l2_loss(v) for v in reg_vars])

    combined_cost = policy_cost + value_cost + l2_cost

    global_step = tf.train.get_or_create_global_step()
    learning_rate = tf.train.piecewise_constant(global_step,
                                                params['lr_boundaries'],
                                                params['lr_rates'])
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    # Insert quantization ops if requested
    if params['quantize']:
        if mode == tf.estimator.ModeKeys.TRAIN:
            contrib_quantize.create_training_graph(
                quant_delay=params['quant_delay'])
        else:
            contrib_quantize.create_eval_graph()

    optimizer = tf.train.MomentumOptimizer(learning_rate,
                                           params['sgd_momentum'])

    # hvd multigpu
    optimizer = hvd.DistributedOptimizer(optimizer)

    if params['use_tpu']:
        optimizer = contrib_tpu_python_tpu_tpu_optimizer.CrossShardOptimizer(
            optimizer)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(combined_cost, global_step=global_step)

    # return train_op for sess
    if tf_sess: return train_op

    # Computations to be executed on CPU, outside of the main TPU queues.
    def eval_metrics_host_call_fn(policy_output,
                                  value_output,
                                  pi_tensor,
                                  value_tensor,
                                  policy_cost,
                                  value_cost,
                                  l2_cost,
                                  combined_cost,
                                  step,
                                  est_mode=tf.estimator.ModeKeys.TRAIN):
        policy_entropy = -tf.reduce_mean(
            tf.reduce_sum(policy_output * tf.log(policy_output), axis=1))
        # pi_tensor is one_hot when generated from sgfs (for supervised learning)
        # and soft-max when using self-play records. argmax normalizes the two.
        policy_target_top_1 = tf.argmax(pi_tensor, axis=1)

        policy_output_in_top1 = tf.to_float(
            tf.nn.in_top_k(policy_output, policy_target_top_1, k=1))
        policy_output_in_top3 = tf.to_float(
            tf.nn.in_top_k(policy_output, policy_target_top_1, k=3))

        policy_top_1_confidence = tf.reduce_max(policy_output, axis=1)
        policy_target_top_1_confidence = tf.boolean_mask(
            policy_output,
            tf.one_hot(policy_target_top_1,
                       tf.shape(policy_output)[1]))

        value_cost_normalized = value_cost / params['value_cost_weight']
        avg_value_observed = tf.reduce_mean(value_tensor)

        with tf.variable_scope('metrics'):
            metric_ops = {
                'policy_cost':
                tf.metrics.mean(policy_cost),
                'value_cost':
                tf.metrics.mean(value_cost),
                'value_cost_normalized':
                tf.metrics.mean(value_cost_normalized),
                'l2_cost':
                tf.metrics.mean(l2_cost),
                'policy_entropy':
                tf.metrics.mean(policy_entropy),
                'combined_cost':
                tf.metrics.mean(combined_cost),
                'avg_value_observed':
                tf.metrics.mean(avg_value_observed),
                'policy_accuracy_top_1':
                tf.metrics.mean(policy_output_in_top1),
                'policy_accuracy_top_3':
                tf.metrics.mean(policy_output_in_top3),
                'policy_top_1_confidence':
                tf.metrics.mean(policy_top_1_confidence),
                'policy_target_top_1_confidence':
                tf.metrics.mean(policy_target_top_1_confidence),
                'value_confidence':
                tf.metrics.mean(tf.abs(value_output)),
            }

        if est_mode == tf.estimator.ModeKeys.EVAL:
            return metric_ops

        # NOTE: global_step is rounded to a multiple of FLAGS.summary_steps.
        eval_step = tf.reduce_min(step)

        # Create summary ops so that they show up in SUMMARIES collection
        # That way, they get logged automatically during training
        summary_writer = contrib_summary.create_file_writer(FLAGS.work_dir)
        with summary_writer.as_default(), \
                contrib_summary.record_summaries_every_n_global_steps(
                    params['summary_steps'], eval_step):
            for metric_name, metric_op in metric_ops.items():
                contrib_summary.scalar(metric_name,
                                       metric_op[1],
                                       step=eval_step)

        # Reset metrics occasionally so that they are mean of recent batches.
        reset_op = tf.variables_initializer(tf.local_variables('metrics'))
        cond_reset_op = tf.cond(
            tf.equal(eval_step % params['summary_steps'], tf.to_int64(1)),
            lambda: reset_op, lambda: tf.no_op())

        return contrib_summary.all_summary_ops() + [cond_reset_op]

    metric_args = [
        policy_output,
        value_output,
        labels['pi_tensor'],
        labels['value_tensor'],
        tf.reshape(policy_cost, [1]),
        tf.reshape(value_cost, [1]),
        tf.reshape(l2_cost, [1]),
        tf.reshape(combined_cost, [1]),
        tf.reshape(global_step, [1]),
    ]

    predictions = {
        'policy_output': policy_output,
        'value_output': value_output,
    }

    eval_metrics_only_fn = functools.partial(
        eval_metrics_host_call_fn, est_mode=tf.estimator.ModeKeys.EVAL)
    host_call_fn = functools.partial(eval_metrics_host_call_fn,
                                     est_mode=tf.estimator.ModeKeys.TRAIN)

    tpu_estimator_spec = contrib_tpu_python_tpu_tpu_estimator.TPUEstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=combined_cost,
        train_op=train_op)
    if params['use_tpu']:
        return tpu_estimator_spec
    else:
        return tpu_estimator_spec.as_estimator_spec()
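
A hedged sketch of wiring this model_fn into a plain (non-TPU) Estimator; every params value below is an assumption chosen only to satisfy the keys the function reads, hvd.init() is assumed to have been called by the surrounding script, and the extra tf_sess argument is bound with functools.partial so the Estimator sees the standard (features, labels, mode, params) signature.

estimator = tf.estimator.Estimator(
    model_fn=functools.partial(model_fn, tf_sess=False),
    config=tf.estimator.RunConfig(model_dir='/tmp/minigo_model'),
    params={
        'value_cost_weight': 1.0,
        'l2_strength': 1e-4,
        'lr_boundaries': [400000, 600000],
        'lr_rates': [0.01, 0.001, 0.0001],
        'sgd_momentum': 0.9,
        'quantize': True,          # insert fake-quant ops
        'quant_delay': 700000,     # steps before quantization kicks in
        'use_tpu': False,
        'summary_steps': 256,
    })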
Example #5
    def build_loss(self):
        response = self.response
        response_size = response.get_shape().as_list()[1:3]  # [height, width]

        gt = construct_gt_score_maps(
            response_size, self.data_config['batch_size'],
            self.model_config['embed_config']['stride'],
            self.train_config['gt_config'])

        # loss: https://www.renom.jp/ja/notebooks/tutorial/basic_algorithm/lossfunction/notebook.html
        with tf.name_scope('Loss'):
            loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=response,
                                                           labels=gt)

            with tf.name_scope('Balance_weights'):
                n_pos = tf.reduce_sum(tf.to_float(tf.equal(gt[0], 1)))
                n_neg = tf.reduce_sum(tf.to_float(tf.equal(gt[0], 0)))
                w_pos = 0.5 / n_pos
                w_neg = 0.5 / n_neg
                class_weights = tf.where(tf.equal(gt, 1),
                                         w_pos * tf.ones_like(gt),
                                         tf.ones_like(gt))
                class_weights = tf.where(tf.equal(gt, 0),
                                         w_neg * tf.ones_like(gt),
                                         class_weights)
                loss = loss * class_weights

            # Note that we use reduce_sum instead of reduce_mean since the loss
            # has already been normalized by class_weights in the spatial
            # dimensions.
            loss = tf.reduce_sum(loss, [1, 2])

            batch_loss = tf.reduce_mean(loss, name='batch_loss')
            tf.losses.add_loss(batch_loss)

            total_loss = tf.losses.get_total_loss()
            self.batch_loss = batch_loss
            self.total_loss = total_loss

            # quantization
            # good note: https://www.tensorflowers.cn/t/7136
            if self.model_config['embed_config']['quantization']:
                if self.train_config["export"]:
                    contrib_quantize.create_eval_graph()
                else:
                    contrib_quantize.create_training_graph(quant_delay=200000)

            tf.summary.image('exemplar', self.exemplars, family=self.mode)
            tf.summary.image('instance', self.instances, family=self.mode)

            mean_batch_loss, update_op1 = tf.metrics.mean(batch_loss)
            mean_total_loss, update_op2 = tf.metrics.mean(total_loss)
            with tf.control_dependencies([update_op1, update_op2]):
                tf.summary.scalar('batch_loss',
                                  mean_batch_loss,
                                  family=self.mode)
                tf.summary.scalar('total_loss',
                                  mean_total_loss,
                                  family=self.mode)

            if self.mode == 'train':
                tf.summary.image('GT',
                                 tf.reshape(gt[0], [1] + response_size + [1]),
                                 family='GT')
            tf.summary.image('Response',
                             tf.expand_dims(tf.sigmoid(response), -1),
                             family=self.mode)
            tf.summary.histogram('Response', self.response, family=self.mode)

            # Two more metrics to monitor the performance of training
            tf.summary.scalar('center_score_error',
                              center_score_error(response),
                              family=self.mode)
            tf.summary.scalar('center_dist_error',
                              center_dist_error(response),
                              family=self.mode)
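
A sketch of the configuration entries the quantization branch above reads; the key layout is inferred from the attribute accesses in build_loss, and the concrete values are assumptions.

model_config = {
    'embed_config': {
        'stride': 8,
        'quantization': True,   # enables the fake-quant rewrites above
    },
}
train_config = {
    'export': False,            # False -> create_training_graph(quant_delay=200000)
    'gt_config': {},            # ground-truth score-map settings
}
data_config = {'batch_size': 8}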
Example #6
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=True,
            use_grayscale=FLAGS.use_grayscale)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        if FLAGS.quantize_delay >= 0:
            contrib_quantize.create_training_graph(
                quant_delay=FLAGS.quantize_delay)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Compute the total loss and the gradients for the trainable variables
        # across all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
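
A sketch of how the quantize_delay flag consumed above is typically declared; the default value and help text here are assumptions.

tf.app.flags.DEFINE_integer(
    'quantize_delay', -1,
    'Number of steps after which fake-quant ops become active; '
    'a negative value disables quantized training.')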
Example #7
def main():

    # check required input arguments
    if not FLAGS.project_name:
        raise ValueError('You must supply a project name with --project_name')
    if not FLAGS.dataset_name:
        raise ValueError('You must supply a dataset name with --dataset_name')
    if FLAGS.model_name not in model_name_to_variables:
        raise ValueError(
            'Model name not supported. Please select one of the following model '
            'architectures: mobilenet_v1, mobilenet_v1_075, mobilenet_v1_050, '
            'mobilenet_v1_025, inception_v1')

    # set and check project_dir and experiment_dir.
    project_dir = os.path.join(FLAGS.project_dir, FLAGS.project_name)
    if not FLAGS.experiment_name:
        # list only directories that are named experiment_*
        experiment_dir = dataset_utils.create_new_experiment_dir(project_dir)
    else:
        experiment_dir = os.path.join(os.path.join(project_dir, 'experiments'),
                                      FLAGS.experiment_name)
        if not os.path.exists(experiment_dir):
            raise ValueError('Experiment directory {} does not exist.'.format(
                experiment_dir))

    train_dir = os.path.join(experiment_dir, FLAGS.dataset_split_name)
    if not os.path.exists(train_dir):
        os.makedirs(train_dir)

    # set and check dataset_dir
    if FLAGS.image_dir:
        dataset_dir = convert_dataset.convert_img_to_tfrecord(
            project_dir, FLAGS.dataset_name, FLAGS.dataset_dir,
            FLAGS.image_dir, FLAGS.train_percentage,
            FLAGS.validation_percentage, FLAGS.test_percentage,
            FLAGS.train_image_size, FLAGS.train_image_size)
    else:
        if os.path.isdir(FLAGS.dataset_dir):
            dataset_dir = os.path.join(FLAGS.dataset_dir, FLAGS.dataset_name)
        else:
            dataset_dir = os.path.join(os.path.join(project_dir, 'datasets'),
                                       FLAGS.dataset_name)
    if not os.path.isdir(dataset_dir):
        raise ValueError(
            'Can not find tfrecord dataset directory {}'.format(dataset_dir))
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=not FLAGS.feature_extraction,
            final_endpoint=FLAGS.final_endpoint)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=FLAGS.apply_image_augmentation,
            use_grayscale=FLAGS.use_grayscale)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(
                image,
                train_image_size,
                train_image_size,
                add_image_summaries=FLAGS.add_image_summaries,
                crop_image=FLAGS.random_image_crop,
                min_object_covered=FLAGS.min_object_covered,
                rotate_image=FLAGS.random_image_rotation,
                random_flip=FLAGS.random_image_flip,
                roi=FLAGS.roi)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if FLAGS.imbalance_correction:
                # specify some class weightings
                class_weights = dataset.sorted_class_weights
                # deduce weights for batch samples based on their true label
                weights = tf.reduce_sum(tf.multiply(labels, class_weights), 1)

                slim.losses.softmax_cross_entropy(
                    logits,
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=weights)
            else:
                if 'AuxLogits' in end_points:
                    slim.losses.softmax_cross_entropy(
                        end_points['AuxLogits'],
                        labels,
                        label_smoothing=FLAGS.label_smoothing,
                        weights=0.4,
                        scope='aux_loss')
                else:
                    slim.losses.softmax_cross_entropy(
                        logits,
                        labels,
                        label_smoothing=FLAGS.label_smoothing,
                        weights=1.0)
            #############################
            ## Calculation of metrics ##
            #############################
            accuracy, accuracy_op = tf.metrics.accuracy(
                tf.argmax(labels, 1), tf.argmax(logits, 1))
            precision, precision_op = tf.metrics.average_precision_at_k(
                tf.argmax(labels, 1), logits, 1)

            with tf.device('/device:CPU:0'):
                for class_id in range(dataset.num_classes):
                    precision_at_k, precision_at_k_op = tf.metrics.precision_at_k(
                        tf.argmax(labels, 1), logits, k=1, class_id=class_id)
                    recall_at_k, recall_at_k_op = tf.metrics.recall_at_k(
                        tf.argmax(labels, 1), logits, k=1, class_id=class_id)
                    tf.add_to_collection('precision_at_{}'.format(class_id),
                                         precision_at_k)
                    tf.add_to_collection('precision_at_{}_op'.format(class_id),
                                         precision_at_k_op)
                    tf.add_to_collection('recall_at_{}'.format(class_id),
                                         recall_at_k)
                    tf.add_to_collection('recall_at_{}_op'.format(class_id),
                                         recall_at_k_op)

            tf.add_to_collection('accuracy', accuracy)
            tf.add_to_collection('accuracy_op', accuracy_op)
            tf.add_to_collection('precision', precision)
            tf.add_to_collection('precision_op', precision_op)

            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #########################################################
        ## Calculation of metrics for all clones ##
        #########################################################
        # Metrics for all clones.
        accuracy = tf.get_collection('accuracy')
        accuracy_op = tf.get_collection('accuracy_op')
        precision = tf.get_collection('precision')
        precision_op = tf.get_collection('precision_op')
        # accuracy_op = tf.reshape(accuracy_op, [])

        # Stack and take the mean.
        accuracy = tf.reduce_mean(tf.stack(accuracy, axis=0))
        accuracy_op = tf.reduce_mean(tf.stack(accuracy_op, axis=0))
        precision = tf.reduce_mean(tf.stack(precision, axis=0))
        precision_op = tf.reduce_mean(tf.stack(precision_op, axis=0))

        # Add metric summaries.
        summaries.add(tf.summary.scalar('Metrics/accuracy', accuracy))
        summaries.add(tf.summary.scalar('op/accuracy_op', accuracy_op))
        summaries.add(tf.summary.scalar('Metrics/average_precision',
                                        precision))
        summaries.add(
            tf.summary.scalar('op/average_precision_op', precision_op))

        # Add precision/recall at each class to summary
        for class_id in range(dataset.num_classes):
            precision_at_k = tf.get_collection(
                'precision_at_{}'.format(class_id))
            precision_at_k_op = tf.get_collection(
                'precision_at_{}_op'.format(class_id))
            recall_at_k = tf.get_collection('recall_at_{}'.format(class_id))
            recall_at_k_op = tf.get_collection(
                'recall_at_{}_op'.format(class_id))

            precision_at_k = tf.reduce_mean(tf.stack(precision_at_k, axis=0))
            precision_at_k_op = tf.reduce_mean(
                tf.stack(precision_at_k_op, axis=0))
            recall_at_k = tf.reduce_mean(tf.stack(recall_at_k, axis=0))
            recall_at_k_op = tf.reduce_mean(tf.stack(recall_at_k_op, axis=0))

            summaries.add(
                tf.summary.scalar(
                    'Metrics/class_{}_precision'.format(class_id),
                    precision_at_k))
            summaries.add(
                tf.summary.scalar('op/class_{}_precision_op'.format(class_id),
                                  precision_at_k_op))
            summaries.add(
                tf.summary.scalar('Metrics/class_{}_recall'.format(class_id),
                                  recall_at_k))
            summaries.add(
                tf.summary.scalar('op/class_{}_recall_op'.format(class_id),
                                  recall_at_k_op))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        if FLAGS.quantize_delay >= 0:
            contrib_quantize.create_training_graph(
                quant_delay=FLAGS.quantize_delay)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(
                tf.summary.scalar('Losses/learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Compute the total loss and the gradients for the trainable variables
        # across all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))
        loss = total_loss
        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        session_config = tf.ConfigProto(
            log_device_placement=FLAGS.verbose_placement,
            allow_soft_placement=not FLAGS.hard_placement)
        if not FLAGS.fixed_memory:
            session_config.gpu_options.allow_growth = True

        ###########################
        # Kicks off the training. #
        ###########################
        def train_step_fn(sess, train_op, global_step, train_step_kwargs):
            """Function that takes a gradient step and specifies whether to stop.
      Args:
        sess: The current session.
        train_op: An `Operation` that evaluates the gradients and returns the total
          loss.
        global_step: A `Tensor` representing the global training step.
        train_step_kwargs: A dictionary of keyword arguments.
      Returns:
        The total loss and a boolean indicating whether or not to stop training.
      Raises:
        ValueError: if 'should_trace' is in `train_step_kwargs` but `logdir` is not.
      """
            start_time = time.time()

            trace_run_options = None
            run_metadata = None
            if 'should_trace' in train_step_kwargs:
                if 'logdir' not in train_step_kwargs:
                    raise ValueError(
                        'logdir must be present in train_step_kwargs when '
                        'should_trace is present')
                if sess.run(train_step_kwargs['should_trace']):
                    trace_run_options = config_pb2.RunOptions(
                        trace_level=config_pb2.RunOptions.FULL_TRACE)
                    run_metadata = config_pb2.RunMetadata()

            total_loss, np_global_step = sess.run([train_op, global_step],
                                                  options=trace_run_options,
                                                  run_metadata=run_metadata)

            time_elapsed = time.time() - start_time

            if run_metadata is not None:
                tl = timeline.Timeline(run_metadata.step_stats)
                trace = tl.generate_chrome_trace_format()
                trace_filename = os.path.join(
                    train_step_kwargs['logdir'],
                    'tf_trace-%d.json' % np_global_step)
                logging.info('Writing trace to %s', trace_filename)
                file_io.write_string_to_file(trace_filename, trace)
                if 'summary_writer' in train_step_kwargs:
                    train_step_kwargs['summary_writer'].add_run_metadata(
                        run_metadata, 'run_metadata-%d' % np_global_step)

            if 'should_log' in train_step_kwargs:
                if sess.run(train_step_kwargs['should_log']):
                    print('global step {:d}: loss = {:1.4f} ({:.3f} sec/step)'.
                          format(np_global_step, total_loss, time_elapsed))

            if 'should_stop' in train_step_kwargs:
                should_stop = sess.run(train_step_kwargs['should_stop'])
            else:
                should_stop = False

            return total_loss, should_stop or train_step_fn.should_stop

        train_step_fn.should_stop = False

        # train_step_fn.accuracy = accuracy

        def exit_gracefully(signum, frame):
            interrupted = datetime.datetime.utcnow()
            # if not experiment_file is None :
            print('Interrupted on (UTC): ',
                  interrupted,
                  sep='',
                  file=experiment_file)
            experiment_file.flush()
            train_step_fn.should_stop = True
            print('Interrupted on (UTC): ', interrupted, sep='')

        signal.signal(signal.SIGINT, exit_gracefully)
        signal.signal(signal.SIGTERM, exit_gracefully)

        start = datetime.datetime.utcnow()
        print('Started on (UTC): ', start, sep='')

        # record script flags (FLAGS). write to experiment file
        experiment_file_path = os.path.join(train_dir,
                                            'experiment_setting.txt')
        experiment_file = open(experiment_file_path, 'w')
        print('Experiment metadata file:', file=experiment_file)
        print(experiment_file_path, file=experiment_file)
        print('========================', file=experiment_file)
        print('All command-line flags:', file=experiment_file)
        print(experiment_file_path, file=experiment_file)
        for key, value in vars(FLAGS).items():
            print(key, ' : ', value, sep='', file=experiment_file)
        print('========================', file=experiment_file)
        print('Started on (UTC): ', start, sep='', file=experiment_file)
        experiment_file.flush()

        slim.learning.train(
            train_tensor,
            train_step_fn=train_step_fn,
            logdir=train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None,
            session_config=session_config)

        finish = datetime.datetime.utcnow()
        # generate and save graph (output file model_name_graph.pb)
        print('Generate frozen graph')
        # TODO: Simplify by loading checkpoint+graph and freezing together (no need to save graph)
        # generate and save inference graph
        is_training = False
        is_video_model = False
        batch_size = None
        num_frames = None
        quantize = False
        write_text_graphdef = False
        output_file = os.path.join(train_dir, FLAGS.model_name + '_graph.pb')
        export_inference_graph(FLAGS.dataset_name, dataset_dir,
                               FLAGS.model_name, FLAGS.labels_offset,
                               is_training, FLAGS.final_endpoint,
                               FLAGS.train_image_size, FLAGS.use_grayscale,
                               is_video_model, batch_size, num_frames,
                               quantize, write_text_graphdef, output_file)
        # record training session end
        print('Finished on (UTC): ', finish, sep='', file=experiment_file)
        print('Elapsed: ', finish - start, sep='', file=experiment_file)
        experiment_file.flush()
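
A hedged follow-on sketch: once the exported graph has been frozen together with a checkpoint, it could be converted to TFLite; the file names and the input/output array names below are assumptions (quantize was left False above, so this is the float path).

frozen_graph = os.path.join(train_dir, FLAGS.model_name + '_frozen.pb')
converter = tf.lite.TFLiteConverter.from_frozen_graph(
    frozen_graph,
    input_arrays=['input'],
    output_arrays=['MobilenetV1/Predictions/Reshape_1'])
tflite_model = converter.convert()
with open(os.path.join(train_dir, FLAGS.model_name + '.tflite'), 'wb') as f:
    f.write(tflite_model)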
Example #8
def build_model_fn(features, labels, mode, params):
  """The model_fn for MnasNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
      `params['batch_size']` is always provided and should be used as the
      effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  # This is essential, if using a keras-derived model.
  tf.keras.backend.set_learning_phase(is_training)

  if isinstance(features, dict):
    features = features['feature']

  if mode == tf.estimator.ModeKeys.PREDICT:
    # Adds an identity node to help TFLite export.
    features = tf.identity(features, 'float_image_input')

  # In most cases the default data format NCHW, rather than NHWC, should be
  # used for a significant performance boost on GPU. NHWC should be used only
  # if the network needs to run on CPU, since the pooling operations are only
  # supported on NHWC. On TPU, the XLA compiler figures out the best layout.
  if params['data_format'] == 'channels_first':
    assert not params['transpose_input']    # channels_first only for GPU
    features = tf.transpose(features, [0, 3, 1, 2])
    stats_shape = [3, 1, 1]
  else:
    stats_shape = [1, 1, 3]

  if params['transpose_input'] and mode != tf.estimator.ModeKeys.PREDICT:
    features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

  # Normalize the image to zero mean and unit variance.
  features -= tf.constant(
      imagenet_input.MEAN_RGB, shape=stats_shape, dtype=features.dtype)
  features /= tf.constant(
      imagenet_input.STDDEV_RGB, shape=stats_shape, dtype=features.dtype)

  has_moving_average_decay = (params['moving_average_decay'] > 0)

  tf.logging.info('Using open-source implementation for MnasNet definition.')
  override_params = {}
  if params['batch_norm_momentum']:
    override_params['batch_norm_momentum'] = params['batch_norm_momentum']
  if params['batch_norm_epsilon']:
    override_params['batch_norm_epsilon'] = params['batch_norm_epsilon']
  if params['dropout_rate']:
    override_params['dropout_rate'] = params['dropout_rate']
  if params['data_format']:
    override_params['data_format'] = params['data_format']
  if params['num_label_classes']:
    override_params['num_classes'] = params['num_label_classes']
  if params['depth_multiplier']:
    override_params['depth_multiplier'] = params['depth_multiplier']
  if params['depth_divisor']:
    override_params['depth_divisor'] = params['depth_divisor']
  if params['min_depth']:
    override_params['min_depth'] = params['min_depth']
  override_params['use_keras'] = params['use_keras']

  def _build_model(model_name):
    """Build the model for a given model name."""
    if model_name.startswith('mnasnet'):
      return mnasnet_models.build_mnasnet_model(
          features,
          model_name=model_name,
          training=is_training,
          override_params=override_params)
    elif model_name.startswith('mixnet'):
      return mixnet_builder.build_model(
          features,
          model_name=model_name,
          training=is_training,
          override_params=override_params)
    else:
      raise ValueError('Unknown model name {}'.format(model_name))

  if params['precision'] == 'bfloat16':
    with tf.tpu.bfloat16_scope():
      logits, _ = _build_model(params['model_name'])
    logits = tf.cast(logits, tf.float32)
  else:  # params['precision'] == 'float32'
    logits, _ = _build_model(params['model_name'])

  if params['quantized_training']:
    try:
      from tensorflow.contrib import quantize  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      logging.exception('Quantized training is not supported in TensorFlow 2.x')
      raise e

    if is_training:
      tf.logging.info('Adding fake quantization ops for training.')
      quantize.create_training_graph(
          quant_delay=int(params['steps_per_epoch'] *
                          FLAGS.quantization_delay_epochs))
    else:
      tf.logging.info('Adding fake quantization ops for evaluation.')
      quantize.create_eval_graph()

  if mode == tf.estimator.ModeKeys.PREDICT:
    scaffold_fn = None
    if FLAGS.export_moving_average:
      # If the model is trained with moving average decay, to match evaluation
      # metrics, we need to export the model using moving average variables.
      restore_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
      variables_to_restore = get_pretrained_variables_to_restore(
          restore_checkpoint, load_moving_average=True)
      tf.logging.info('Restoring from the latest checkpoint: %s',
                      restore_checkpoint)
      tf.logging.info(str(variables_to_restore))

      def restore_scaffold():
        saver = tf.train.Saver(variables_to_restore)
        return tf.train.Scaffold(saver=saver)

      scaffold_fn = restore_scaffold

    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        },
        scaffold_fn=scaffold_fn)

  # If necessary, in the model_fn, use params['batch_size'] instead of the
  # batch-size flags (--train_batch_size or --eval_batch_size).
  batch_size = params['batch_size']  # pylint: disable=unused-variable

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  one_hot_labels = tf.one_hot(labels, params['num_label_classes'])
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits,
      onehot_labels=one_hot_labels,
      label_smoothing=params['label_smoothing'])

  # Add weight decay to the loss for non-batch-normalization variables.
  loss = cross_entropy + params['weight_decay'] * tf.add_n([
      tf.nn.l2_loss(v)
      for v in tf.trainable_variables()
      if 'batch_normalization' not in v.name
  ])

  global_step = tf.train.get_global_step()
  if has_moving_average_decay:
    ema = tf.train.ExponentialMovingAverage(
        decay=params['moving_average_decay'], num_updates=global_step)
    ema_vars = mnas_utils.get_ema_vars()

  host_call = None
  if is_training:
    # Compute the current epoch and associated learning rate from global_step.
    current_epoch = (
        tf.cast(global_step, tf.float32) / params['steps_per_epoch'])

    scaled_lr = params['base_learning_rate'] * (params['train_batch_size'] / 256.0)  # pylint: disable=line-too-long
    learning_rate = mnas_utils.build_learning_rate(scaled_lr, global_step,
                                                   params['steps_per_epoch'])
    optimizer = mnas_utils.build_optimizer(learning_rate)
    if params['use_tpu']:
      # When using TPU, wrap the optimizer with CrossShardOptimizer which
      # handles synchronization details between different TPU cores. To the
      # user, this should look like regular synchronous training.
      optimizer = tf.tpu.CrossShardOptimizer(optimizer)

      if params['add_summaries']:
        summary_writer = tf2.summary.create_file_writer(
            FLAGS.model_dir, max_queue=params['iterations_per_loop'])
        with summary_writer.as_default():
          should_record = tf.equal(global_step % params['iterations_per_loop'],
                                   0)
          with tf2.summary.record_if(should_record):
            tf2.summary.scalar('loss', loss, step=global_step)
            tf2.summary.scalar('learning_rate', learning_rate, step=global_step)
            tf2.summary.scalar('current_epoch', current_epoch, step=global_step)

    # Batch normalization requires UPDATE_OPS to be added as a dependency to
    # the train operation.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops + tf.summary.all_v2_summary_ops()):
      train_op = optimizer.minimize(loss, global_step)

    if has_moving_average_decay:
      with tf.control_dependencies([train_op]):
        train_op = ema.apply(ema_vars)

  else:
    train_op = None

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:

    def metric_fn(labels, logits):
      """Evaluation metric function.

      Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide them as part of `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
      predictions = tf.argmax(logits, axis=1)
      top_1_accuracy = tf.metrics.accuracy(labels, predictions)
      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
      top_5_accuracy = tf.metrics.mean(in_top_5)

      return {
          'top_1_accuracy': top_1_accuracy,
          'top_5_accuracy': top_5_accuracy,
      }

    eval_metrics = (metric_fn, [labels, logits])

  num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
  tf.logging.info('number of trainable parameters: {}'.format(num_params))

  # Prepares scaffold_fn if needed.
  scaffold_fn = None
  if is_training and FLAGS.init_checkpoint:
    variables_to_restore = get_pretrained_variables_to_restore(
        FLAGS.init_checkpoint, has_moving_average_decay)
    tf.logging.info('Initializing from pretrained checkpoint: %s',
                    FLAGS.init_checkpoint)
    if FLAGS.use_tpu:

      def init_scaffold():
        tf.train.init_from_checkpoint(FLAGS.init_checkpoint,
                                      variables_to_restore)
        return tf.train.Scaffold()

      scaffold_fn = init_scaffold
    else:
      tf.train.init_from_checkpoint(FLAGS.init_checkpoint, variables_to_restore)

  restore_vars_dict = None
  if not is_training and has_moving_average_decay:
    # Load moving average variables for eval.
    restore_vars_dict = ema.variables_to_restore(ema_vars)

    def eval_scaffold():
      saver = tf.train.Saver(restore_vars_dict)
      return tf.train.Scaffold(saver=saver)

    scaffold_fn = eval_scaffold

  return tf.estimator.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      host_call=host_call,
      eval_metrics=eval_metrics,
      scaffold_fn=scaffold_fn)
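
The docstring above describes build_model_fn as a model_fn for TPUEstimator, which injects params (including the mandatory batch_size) at call time. A hedged sketch of how such a model_fn might be wired up follows; the flag names, the params contents, and the input function are illustrative assumptions rather than part of the original script.

# Sketch only: wiring build_model_fn into a TPUEstimator. Every value below is
# an illustrative assumption; params must carry each key build_model_fn reads.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
run_config = tf.estimator.tpu.RunConfig(
    cluster=resolver,
    model_dir=FLAGS.model_dir,
    tpu_config=tf.estimator.tpu.TPUConfig(
        iterations_per_loop=FLAGS.iterations_per_loop))

params = dict(
    data_format='channels_last',
    transpose_input=False,
    precision='float32',
    quantized_training=False,
    moving_average_decay=0.9999,
    steps_per_epoch=1000,  # illustrative value
    # ...plus the remaining keys read above (batch_norm_momentum, dropout_rate,
    # num_label_classes, depth_multiplier, use_keras, add_summaries, ...).
)

estimator = tf.estimator.tpu.TPUEstimator(
    model_fn=build_model_fn,
    config=run_config,
    use_tpu=FLAGS.use_tpu,
    train_batch_size=FLAGS.train_batch_size,
    eval_batch_size=FLAGS.eval_batch_size,
    params=params)

# imagenet_train is assumed to be an input pipeline object exposing input_fn.
estimator.train(input_fn=imagenet_train.input_fn, max_steps=FLAGS.train_steps)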
Ejemplo n.º 9
0
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  # Configure the settings for multi-GPU training.
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,  # number of GPUs
      clone_on_cpu=FLAGS.clone_on_cpu,  # defaults to False
      replica_id=FLAGS.task,    # task id
      num_replicas=FLAGS.num_replicas,  # defaults to 1
      num_ps_tasks=FLAGS.num_ps_tasks)  # defaults to 0

  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones    # split the batch size evenly across GPUs

  tf.gfile.MakeDirs(FLAGS.train_logdir)     # create the directory for training logs
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      dataset = data_generator.Dataset(     # define the dataset parameters
          dataset_name=FLAGS.dataset,   # dataset name, e.g. cityscapes
          split_name=FLAGS.train_split,  # TFRecord split to train on, defaults to 'train'
          dataset_dir=FLAGS.dataset_dir,   # directory containing the TFRecord files
          batch_size=clone_batch_size,  # per-GPU batch size after the split
          crop_size=[int(sz) for sz in FLAGS.train_crop_size],  # training crop size, e.g. 513,513
          min_resize_value=FLAGS.min_resize_value,  # defaults to None
          max_resize_value=FLAGS.max_resize_value,  # defaults to None
          resize_factor=FLAGS.resize_factor,    # defaults to None
          min_scale_factor=FLAGS.min_scale_factor,   # minimum scale for data augmentation, defaults to 0.5
          max_scale_factor=FLAGS.max_scale_factor,   # maximum scale for data augmentation, defaults to 2
          scale_factor_step_size=FLAGS.scale_factor_step_size,      # step size of the scale sweep, defaults to 0.25 (from 0.5 to 2)
          model_variant=FLAGS.model_variant,    # model variant, e.g. xception_65
          num_readers=4,    # number of data readers; increase with multiple GPUs to speed up training
          is_training=True,
          should_shuffle=True,
          should_repeat=True)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      # Counts training steps: global_step is incremented by one per batch.
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab  # DeepLab model function
      model_args = (dataset.get_one_shot_iterator(), {
          common.OUTPUT_TYPE: dataset.num_of_classes
      }, dataset.ignore_label)  # model arguments
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in tf.model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for images, labels, semantic predictions
    if FLAGS.save_summaries_images:      # defaults to False
      summary_image = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
      summaries.add(
          tf.summary.image('samples/%s' % common.IMAGE, summary_image))

      first_clone_label = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
      # Scale up summary image pixel values for better visualization.
      pixel_scaling = max(1, 255 // dataset.num_of_classes)
      summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL, summary_label))

      first_clone_output = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
      predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

      summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image(
              'samples/%s' % common.OUTPUT_TYPE, summary_predictions))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(  # get the model learning rate
          FLAGS.learning_policy,    # 'poly' learning rate policy
          FLAGS.base_learning_rate,     # 0.0001
          FLAGS.learning_rate_decay_step,   # decay the learning rate every 2000 steps
          FLAGS.learning_rate_decay_factor,     # 0.1
          FLAGS.training_number_of_steps,   # total number of training steps, e.g. 20000
          FLAGS.learning_power,     # poly power 0.9
          FLAGS.slow_start_step,    # 0
          FLAGS.slow_start_learning_rate,   # 1e-4, learning rate used during the slow start
          decay_steps=FLAGS.decay_steps,    # 0.0
          end_learning_rate=FLAGS.end_learning_rate)     # 0.0

      summaries.add(tf.summary.scalar('learning_rate', learning_rate))
      # Training optimizer.
      if FLAGS.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      elif FLAGS.optimizer == 'adam':   # Adam: adaptive optimizer with second-moment gradient correction
        optimizer = tf.train.AdamOptimizer(
            learning_rate=FLAGS.adam_learning_rate, epsilon=FLAGS.adam_epsilon)
      else:
        raise ValueError('Unknown optimizer')

    if FLAGS.quantize_delay_step >= 0:  # defaults to -1 (quantization disabled)
      if FLAGS.num_clones > 1:
        raise ValueError('Quantization doesn\'t support multi-clone yet.')
      contrib_quantize.create_training_graph(
          quant_delay=FLAGS.quantize_delay_step)

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps    # FLAGS.startup_delay_steps defaults to 15

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)    # compute the total loss
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      # Get the gradient multipliers.
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      # grad_mult : {'logits/semantic/biases': 2.0, 'logits/semantic/weights': 1.0}
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(     # apply the computed gradients to the variables, returning the update op
          grads_and_vars, global_step=global_step)  # also increments global_step
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Start the training.
    profile_dir = FLAGS.profile_logdir   # defaults to None
    if profile_dir is not None:
      tf.gfile.MakeDirs(profile_dir)

    with contrib_tfprof.ProfileContext(
        enabled=profile_dir is not None, profile_dir=profile_dir):
      init_fn = None
      if FLAGS.tf_initial_checkpoint:   # load pretrained weights
        init_fn = train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True)

      slim.learning.train(
          train_tensor,
          logdir=FLAGS.train_logdir,
          log_every_n_steps=FLAGS.log_steps,
          master=FLAGS.master,
          number_of_steps=FLAGS.training_number_of_steps,
          is_chief=(FLAGS.task == 0),
          session_config=session_config,
          startup_delay_steps=startup_delay_steps,
          init_fn=init_fn,
          summary_op=summary_op,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
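
The 'poly' learning-rate policy configured above decays the base rate toward end_learning_rate as (1 - step / max_steps) ** power. For reference, here is a minimal sketch of an equivalent schedule built directly from the stock polynomial-decay op, assuming the flag defaults noted in the comments (base rate 0.0001, power 0.9, 20000 steps, end rate 0.0).

# Sketch only: a 'poly' schedule matching the assumed defaults above.
global_step = tf.train.get_or_create_global_step()
poly_learning_rate = tf.train.polynomial_decay(
    learning_rate=0.0001,
    global_step=global_step,
    decay_steps=20000,
    end_learning_rate=0.0,
    power=0.9)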
Ejemplo n.º 10
0
def main(_):

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks,
        )

    # Create global_step
    with tf.device(deploy_config.variables_device()):
        global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    keys_to_features = {
        "image/encoded":
        tf.FixedLenFeature((), tf.string, default_value=""),
        "image/format":
        tf.FixedLenFeature((), tf.string, default_value="png"),
        "image/class/label":
        tf.FixedLenFeature([],
                           tf.int64,
                           default_value=tf.zeros([], dtype=tf.int64)),
    }

    items_to_handlers = {
        "image": slim.tfexample_decoder.Image(),
        "label": slim.tfexample_decoder.Tensor("image/class/label"),
    }

    items_to_descs = {
        "image": "Color image",
        "label": "Class idx",
    }

    label_idx_to_name = {}
    for i, label in enumerate(CLASSES):
        label_idx_to_name[i] = label

    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features,
                                                      items_to_handlers)
    file_pattern = "tfm_clf_%s.*"
    file_pattern = os.path.join(FLAGS.records_name,
                                file_pattern % FLAGS.dataset_split_name)
    dataset = slim.dataset.Dataset(
        data_sources=file_pattern,  # TODO UPDATE
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=80000,  # TODO UPDATE
        items_to_descriptions=items_to_descs,
        num_classes=len(CLASSES),
        labels_to_names=label_idx_to_name,
    )

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True,
    )

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True,
        use_grayscale=FLAGS.use_grayscale)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            num_readers=FLAGS.num_readers,
            common_queue_capacity=20 * FLAGS.batch_size,
            common_queue_min=10 * FLAGS.batch_size,
        )
        [image, label] = provider.get(["image", "label"])
        label -= FLAGS.labels_offset

        train_image_size = FLAGS.train_image_size or network_fn.default_image_size

        image = image_preprocessing_fn(image, train_image_size,
                                       train_image_size)

        images, labels = tf.train.batch(
            [image, label],
            batch_size=FLAGS.batch_size,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size,
        )
        labels = slim.one_hot_encoding(
            labels, dataset.num_classes - FLAGS.labels_offset)
        batch_queue = slim.prefetch_queue.prefetch_queue(
            [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
        """Allows data parallelism by creating network_fn clones."""
        images, labels = batch_queue.dequeue()
        logits, end_points = network_fn(images)

        #############################
        # Specify the loss function #
        #############################
        if "AuxLogits" in end_points:
            slim.losses.softmax_cross_entropy(
                end_points["AuxLogits"],
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=0.4,
                scope="aux_loss",
            )
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
        return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
        x = end_points[end_point]
        summaries.add(tf.summary.histogram("activations/" + end_point, x))
        summaries.add(
            tf.summary.scalar("sparsity/" + end_point, tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
        summaries.add(tf.summary.scalar("losses/%s" % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
        summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
        moving_average_variables = slim.get_model_variables()
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay, global_step)
    else:
        moving_average_variables, variable_averages = None, None

    if FLAGS.quantize_delay >= 0:
        contrib_quantize.create_training_graph(
            quant_delay=FLAGS.quantize_delay)

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
        learning_rate = _configure_learning_rate(dataset.num_samples,
                                                 global_step)
        optimizer = _configure_optimizer(learning_rate)
        summaries.add(tf.summary.scalar("learning_rate", learning_rate))

    if FLAGS.sync_replicas:
        # If sync_replicas is enabled, the averaging will be done in the chief
        # queue runner.
        optimizer = tf.train.SyncReplicasOptimizer(
            opt=optimizer,
            replicas_to_aggregate=FLAGS.replicas_to_aggregate,
            total_num_replicas=FLAGS.worker_replicas,
            variable_averages=variable_averages,
            variables_to_average=moving_average_variables,
        )
    elif FLAGS.moving_average_decay:
        # Update ops executed locally by trainer.
        update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # Optimize over all clones, returning the total loss and the gradients to apply.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones, optimizer, var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar("total_loss", total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name="train_op")

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name="summary_op")

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=optimizer if FLAGS.sync_replicas else None,
    )
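
When FLAGS.moving_average_decay is set, the block above keeps shadow copies of the model variables; at evaluation time those averaged values are usually restored in place of the raw weights. A minimal sketch of that restore step, with the checkpoint path as a placeholder assumption:

# Sketch only: restore the moving-average (shadow) variables for evaluation.
variable_averages = tf.train.ExponentialMovingAverage(FLAGS.moving_average_decay)
variables_to_restore = variable_averages.variables_to_restore(
    slim.get_model_variables())
saver = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
    saver.restore(sess, '/path/to/model.ckpt')  # assumed checkpoint path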
Ejemplo n.º 11
0
    sc1 = tf.add(conv2, pool1, name='sc1')

    conv3 = slim.conv2d(sc1, 128, [3, 3], stride=2, scope='conv3')
    downsample1 = tf.image.resize_nearest_neighbor(
        sc1, [64, 64], name='nearest_neighbor_downsample')
    concat1 = tf.concat([conv3, downsample1], axis=3, name='concat1')

    dconv1 = slim.conv2d_transpose(concat1,
                                   128, [3, 3],
                                   stride=2,
                                   scope='dconv1')
    upsample1 = tf.image.resize_images(concat1, [128, 128])

    concat2 = tf.concat([dconv1, upsample1], axis=3, name='concat2')

    conv4 = slim.conv2d(upsample1, 128, [1, 1], scope='conv4')
    conv5 = slim.conv2d(concat2, 128, [1, 1], scope='conv5')
    sc2 = tf.add(conv4, conv5, name='sc2')

    pool2 = tf.reduce_mean(sc2, axis=[1, 2], name='global_average_pool')
    return pool2


g = tf.Graph()
with g.as_default():
    with slim.arg_scope(arg_scope()):
        end = net()
        quantize.create_training_graph(input_graph=g, quant_delay=2000000)

tf.summary.FileWriter('.', g)
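
The snippet above only inserts the training-time fake-quantization ops (with a quant_delay of 2,000,000 steps). After training, the usual counterpart is to rewrite a separate eval graph with create_eval_graph and convert the result to TFLite. The sketch below assumes the same net()/arg_scope() helpers, a known input tensor name, and a placeholder checkpoint path:

# Sketch only: eval-side rewrite plus TFLite conversion after
# quantization-aware training. Names, shapes and paths are assumptions.
eval_graph = tf.Graph()
with eval_graph.as_default():
    with slim.arg_scope(arg_scope()):
        outputs = net()  # in the real code this consumes the eval input tensor
    quantize.create_eval_graph(input_graph=eval_graph)

with tf.Session(graph=eval_graph) as sess:
    tf.train.Saver().restore(sess, '/path/to/model.ckpt')  # assumed path
    converter = tf.lite.TFLiteConverter.from_session(
        sess,
        input_tensors=[eval_graph.get_tensor_by_name('input:0')],  # assumed name
        output_tensors=[outputs])
    open('model_quant.tflite', 'wb').write(converter.convert())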
Ejemplo n.º 12
0
def add_final_retrain_ops(class_count, final_tensor_name, bottleneck_tensor,
                          quantize_layer, is_training):
  """Adds a new softmax and fully-connected layer for training and eval.

  We need to retrain the top layer to identify our new classes, so this function
  adds the right operations to the graph, along with some variables to hold the
  weights, and then sets up all the gradients for the backward pass.

  The set up for the softmax and fully-connected layers is based on:
  https://www.tensorflow.org/tutorials/mnist/beginners/index.html

  Args:
    class_count: Integer of how many categories of things we're trying to
        recognize.
    final_tensor_name: Name string for the new final node that produces results.
    bottleneck_tensor: The output of the main CNN graph.
    quantize_layer: Boolean, specifying whether the newly added layer should be
        instrumented for quantization with TF-Lite.
    is_training: Boolean, specifying whether the newly added layer is for training
        or eval.

  Returns:
    The tensors for the training and cross entropy results, and tensors for the
    bottleneck input and ground truth input.
  """
  batch_size, bottleneck_tensor_size = bottleneck_tensor.get_shape().as_list()
  assert batch_size is None, 'We want to work with arbitrary batch size.'
  with tf.name_scope('input'):
    bottleneck_input = tf.placeholder_with_default(
        bottleneck_tensor,
        shape=[batch_size, bottleneck_tensor_size],
        name='BottleneckInputPlaceholder')

    ground_truth_input = tf.placeholder(
        tf.int64, [batch_size], name='GroundTruthInput')

  # Organizing the following ops so they are easier to see in TensorBoard.
  layer_name = 'final_retrain_ops'
  with tf.name_scope(layer_name):
    with tf.name_scope('weights'):
      initial_value = tf.truncated_normal(
          [bottleneck_tensor_size, class_count], stddev=0.001)
      layer_weights = tf.Variable(initial_value, name='final_weights')
      variable_summaries(layer_weights)

    with tf.name_scope('biases'):
      layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
      variable_summaries(layer_biases)

    with tf.name_scope('Wx_plus_b'):
      logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
      tf.summary.histogram('pre_activations', logits)

  final_tensor = tf.nn.softmax(logits, name=final_tensor_name)

  # The tf.contrib.quantize functions rewrite the graph in place for
  # quantization. The imported model graph has already been rewritten, so upon
  # calling these rewrites, only the newly added final layer will be
  # transformed.
  if quantize_layer:
    if is_training:
      contrib_quantize.create_training_graph()
    else:
      contrib_quantize.create_eval_graph()

  tf.summary.histogram('activations', final_tensor)

  # If this is an eval graph, we don't need to add loss ops or an optimizer.
  if not is_training:
    return None, None, bottleneck_input, ground_truth_input, final_tensor

  with tf.name_scope('cross_entropy'):
    cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
        labels=ground_truth_input, logits=logits)

  tf.summary.scalar('cross_entropy', cross_entropy_mean)

  with tf.name_scope('train'):
    optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
    train_step = optimizer.minimize(cross_entropy_mean)

  return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input,
          final_tensor)
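
Because the contrib.quantize rewrites modify the default graph in place, the training and eval variants of this layer are normally built in two separate graphs. A short sketch of that calling pattern; get_bottleneck_tensor is a hypothetical stand-in for whatever imports the base CNN and returns its bottleneck tensor in each graph:

# Sketch only: build the train and eval graphs separately so each
# quantization rewrite sees a clean copy of the final layer.
# `get_bottleneck_tensor` is a hypothetical helper, not part of this script.
train_graph = tf.Graph()
with train_graph.as_default():
    train_step, cross_entropy, bottleneck_input, ground_truth_input, _ = (
        add_final_retrain_ops(class_count, 'final_result',
                              get_bottleneck_tensor(), quantize_layer=True,
                              is_training=True))

eval_graph = tf.Graph()
with eval_graph.as_default():
    _, _, eval_bottleneck_input, eval_ground_truth_input, final_tensor = (
        add_final_retrain_ops(class_count, 'final_result',
                              get_bottleneck_tensor(), quantize_layer=True,
                              is_training=False))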
Ejemplo n.º 13
0
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')
    if not FLAGS.frozen_pb:
        raise ValueError('You must supply the frozen pb with --frozen_pb')
    if not FLAGS.output_node_name:
        raise ValueError(
            'You must supply the output node name with --output_node_name')
    if not FLAGS.output_dir:
        raise ValueError(
            'You must supply the output directory with --output_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    tfrecords = prepare_tfrecords(FLAGS.dataset_name, FLAGS.dataset_dir,
                                  FLAGS.dataset_split_name)

    if FLAGS.max_num_batches:
        num_batches = FLAGS.max_num_batches
    else:
        num_records = sum(
            [len(list(tf.python_io.tf_record_iterator(r))) for r in tfrecords])
        num_batches = int(math.ceil(num_records / float(FLAGS.batch_size)))

    tf.logging.info('Load GraphDef from frozen_pb {}'.format(FLAGS.frozen_pb))
    graph_def = load_graph_def(FLAGS.frozen_pb)

    tf.logging.info('Quantize Graph')
    with tf.Session() as sess:
        tf.import_graph_def(graph_def, name='')
        quantized_graph = qg.create_training_graph(sess.graph)
        quantized_inf_graph = qg.create_eval_graph(sess.graph)

    # Initialize `iterator` with training data.
    with tf.Session(graph=quantized_graph) as sess:
        tf.logging.info('Prepare dataset')
        with tf.name_scope("dataset"):
            filenames = tf.placeholder(tf.string, shape=[None])
            dataset = prepare_dataset(filenames,
                                      FLAGS.dataset_name,
                                      FLAGS.input_size,
                                      batch_size=FLAGS.batch_size)
            iterator = dataset.make_initializable_iterator()
            next_batch = iterator.get_next()

        tf.logging.info('Prepare metrics')
        lbls, preds, accuracy, acc_update_op = prepare_metrics(
            FLAGS.dataset_name)

        tf.logging.info('Prepare Saver')
        saver = tf.train.Saver()

        if FLAGS.summary_dir:
            tf.logging.info('Prepare summary writer')
            summary_writer = tf.summary.FileWriter(FLAGS.summary_dir)

        # initialize
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(iterator.initializer, feed_dict={filenames: tfrecords})

        graph = sess.graph

        # get x and y
        x = graph.get_tensor_by_name('{}:0'.format(FLAGS.input_node_name))
        y = graph.get_tensor_by_name('{}:0'.format(FLAGS.output_node_name))

        # Add summaries for all min/max quantization variables.
        # print(graph.get_collection('variables')[3].eval())
        for var in graph.get_collection('variables'):
            tf.summary.scalar(var.name, var)
        summaries = tf.summary.merge_all()

        for step in range(num_batches):
            images, labels = sess.run(next_batch)
            ys = sess.run(y, feed_dict={x: images})
            sess.run(acc_update_op, feed_dict={lbls: labels, preds: ys})
            summary = sess.run(summaries)
            if FLAGS.summary_dir:
                summary_writer.add_summary(summary, step)

        print('Accuracy: [{:.4f}]'.format(sess.run(accuracy)))
        if FLAGS.summary_dir:
            summary_writer.add_graph(graph)

        # save graph and ckpts
        saver.save(sess, os.path.join(FLAGS.output_dir, "model.ckpt"))
        # tf.train.write_graph(graph, FLAGS.output_dir, 'quantor.pb', as_text=False)
        tf.train.write_graph(quantized_inf_graph,
                             FLAGS.output_dir,
                             'quantor.pb',
                             as_text=False)