Example #1
  def testCreateLogisticClassifier(self):
    g = tf.Graph()
    with g.as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      model_fn = LogisticClassifier
      clone_args = (tf_inputs, tf_labels)
      deploy_config = model_deploy.DeploymentConfig(num_clones=1)

      self.assertEqual(slim.get_variables(), [])
      clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
      clone = clones[0]
      self.assertEqual(len(slim.get_variables()), 2)
      for v in slim.get_variables():
        self.assertDeviceEqual(v.device, 'CPU:0')
        self.assertDeviceEqual(v.value().device, 'CPU:0')
      self.assertEqual(clone.outputs.op.name,
                       'LogisticClassifier/fully_connected/Sigmoid')
      self.assertEqual(clone.scope, '')
      self.assertDeviceEqual(clone.device, 'GPU:0')
      self.assertEqual(len(slim.losses.get_losses()), 1)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      self.assertEqual(update_ops, [])
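The LogisticClassifier passed as model_fn above is defined elsewhere in model_deploy_test.py. A minimal sketch consistent with the assertions (two variables, one registered loss, an output op ending in fully_connected/Sigmoid), assuming TF 1.x with tf.contrib.slim, might look like this:

import tensorflow as tf

slim = tf.contrib.slim

def LogisticClassifier(inputs, labels, scope=None, reuse=None):
  # One fully-connected layer with a sigmoid output; creates exactly two
  # variables (weights and biases) and registers a single log loss.
  with tf.variable_scope(scope, 'LogisticClassifier', [inputs, labels],
                         reuse=reuse):
    predictions = slim.fully_connected(inputs, 1,
                                       activation_fn=tf.sigmoid,
                                       scope='fully_connected')
    slim.losses.log_loss(predictions, labels)
    return predictions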
Example #2
  def testCreateMulticloneWithPS(self):
    g = tf.Graph()
    with g.as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      model_fn = BatchNormClassifier
      clone_args = (tf_inputs, tf_labels)
      deploy_config = model_deploy.DeploymentConfig(num_clones=2,
                                                    num_ps_tasks=2)

      self.assertEqual(slim.get_variables(), [])
      clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
      self.assertEqual(len(slim.get_variables()), 5)
      for i, v in enumerate(slim.get_variables()):
        t = i % 2
        self.assertDeviceEqual(v.device, '/job:ps/task:%d/device:CPU:0' % t)
        self.assertDeviceEqual(v.device, v.value().device)
      self.assertEqual(len(clones), 2)
      for i, clone in enumerate(clones):
        self.assertEqual(
            clone.outputs.op.name,
            'clone_%d/BatchNormClassifier/fully_connected/Sigmoid' % i)
        self.assertEqual(clone.scope, 'clone_%d/' % i)
        self.assertDeviceEqual(clone.device, '/job:worker/device:GPU:%d' % i)
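As a quick illustration of how DeploymentConfig drives the placement checked above, the sketch below (the import path is an assumption; adjust it to your checkout of the slim/deployment package) prints the devices and scopes that a two-clone, two-parameter-server configuration resolves to:

import tensorflow as tf
from deployment import model_deploy  # import path is an assumption

config = model_deploy.DeploymentConfig(num_clones=2, num_ps_tasks=2)
print(config.clone_device(0))   # '/job:worker/device:GPU:0'
print(config.clone_device(1))   # '/job:worker/device:GPU:1'
print(config.clone_scope(1))    # 'clone_1'
print(config.inputs_device())   # '/job:worker/device:CPU:0'
# variables_device() returns a chooser that round-robins variables across the
# ps tasks, which is why the test above sees '/job:ps/task:0/...' and
# '/job:ps/task:1/...' alternate from variable to variable.
with tf.device(config.variables_device()):
  v = tf.get_variable('w', shape=[1])  # lands on one of the ps tasks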
Example #3
  def testCreateMulticlone(self):
    g = tf.Graph()
    with g.as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      model_fn = BatchNormClassifier
      clone_args = (tf_inputs, tf_labels)
      num_clones = 4
      deploy_config = model_deploy.DeploymentConfig(num_clones=num_clones)

      self.assertEqual(slim.get_variables(), [])
      clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
      self.assertEqual(len(slim.get_variables()), 5)
      for v in slim.get_variables():
        self.assertDeviceEqual(v.device, 'CPU:0')
        self.assertDeviceEqual(v.value().device, 'CPU:0')
      self.assertEqual(len(clones), num_clones)
      for i, clone in enumerate(clones):
        self.assertEqual(
            clone.outputs.op.name,
            'clone_%d/BatchNormClassifier/fully_connected/Sigmoid' % i)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, clone.scope)
        self.assertEqual(len(update_ops), 2)
        self.assertEqual(clone.scope, 'clone_%d/' % i)
        self.assertDeviceEqual(clone.device, 'GPU:%d' % i)
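The five variables and the two UPDATE_OPS per clone checked above are what a batch-norm model would produce. A rough sketch of the BatchNormClassifier these tests assume (beta, moving_mean and moving_variance from batch_norm, plus the fully-connected weights and biases), again assuming TF 1.x with tf.contrib.slim:

import tensorflow as tf

slim = tf.contrib.slim

def BatchNormClassifier(inputs, labels, scope=None, reuse=None):
  # batch_norm contributes 3 variables and 2 moving-average update ops per
  # clone; fully_connected contributes the remaining 2 variables.
  with tf.variable_scope(scope, 'BatchNormClassifier', [inputs, labels],
                         reuse=reuse):
    net = slim.batch_norm(inputs, decay=0.1)
    predictions = slim.fully_connected(net, 1,
                                       activation_fn=tf.sigmoid,
                                       scope='fully_connected')
    slim.losses.log_loss(predictions, labels)
    return predictions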
Example #4
  def testCreateOnecloneWithPS(self):
    g = tf.Graph()
    with g.as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      model_fn = BatchNormClassifier
      model_args = (tf_inputs, tf_labels)
      deploy_config = model_deploy.DeploymentConfig(num_clones=1,
                                                    num_ps_tasks=1)

      self.assertEqual(slim.get_variables(), [])
      clones = model_deploy.create_clones(deploy_config, model_fn, model_args)
      self.assertEqual(len(slim.get_variables()), 5)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      self.assertEqual(len(update_ops), 2)

      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
      total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
                                                                optimizer)
      self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
      self.assertEqual(total_loss.op.name, 'total_loss')
      for g, v in grads_and_vars:
        self.assertDeviceEqual(g.device, '/job:worker/device:GPU:0')
        self.assertDeviceEqual(v.device, '/job:ps/task:0/CPU:0')
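The test stops after inspecting grads_and_vars, but the training scripts further down turn the same result into a train op. Continuing with the names from this example, a typical next step (a sketch, not part of the test) is:

      # Apply the clone gradients and gate the loss on the pending update ops.
      global_step = tf.train.get_or_create_global_step()
      grad_updates = optimizer.apply_gradients(grads_and_vars,
                                               global_step=global_step)
      update_ops.append(grad_updates)
      with tf.control_dependencies([tf.group(*update_ops)]):
        train_op = tf.identity(total_loss, name='train_op')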
Example #5
  def testCreateOnecloneWithPS(self):
    g = tf.Graph()
    with g.as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      model_fn = BatchNormClassifier
      clone_args = (tf_inputs, tf_labels)
      deploy_config = model_deploy.DeploymentConfig(num_clones=1,
                                                    num_ps_tasks=1)

      self.assertEqual(slim.get_variables(), [])
      clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
      self.assertEqual(len(clones), 1)
      clone = clones[0]
      self.assertEqual(clone.outputs.op.name,
                       'BatchNormClassifier/fully_connected/Sigmoid')
      self.assertDeviceEqual(clone.device, '/job:worker/device:GPU:0')
      self.assertEqual(clone.scope, '')
      self.assertEqual(len(slim.get_variables()), 5)
      for v in slim.get_variables():
        self.assertDeviceEqual(v.device, '/job:ps/task:0/CPU:0')
        self.assertDeviceEqual(v.device, v.value().device)
Example #6
def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None,
          gpu_usage=None):
    """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of worker replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    graph_hook_fn: Optional function that is called after the inference graph
      is built (before optimization). This is helpful for making additional
      changes to the training graph, such as adding FakeQuant ops. The
      function should modify the default graph.
    gpu_usage: Optional per-process GPU memory fraction; when set, it is
      applied as session_config.gpu_options.per_process_gpu_memory_fraction.

  Raises:
    ValueError: If both num_clones > 1 and train_config.sync_replicas is true.
  """

    detection_model = create_model_fn()
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        if num_clones != 1 and train_config.sync_replicas:
            raise ValueError('In Synchronous SGD mode num_clones must '
                             'be 1. Found num_clones: {}'.format(num_clones))
        batch_size = train_config.batch_size // num_clones
        if train_config.sync_replicas:
            batch_size //= train_config.replicas_to_aggregate

        with tf.device(deploy_config.inputs_device()):
            input_queue = create_input_queue(
                batch_size, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

        # Gather initial summaries.
        # TODO(rathodv): See if summaries can be added/extracted from global tf
        # collections so that they don't have to be passed around.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn,
                                     train_config=train_config)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        if graph_hook_fn:
            with tf.device(deploy_config.variables_device()):
                graph_hook_fn()

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)
            for var in optimizer_summary_vars:
                tf.summary.scalar(var.op.name, var, family='LearningRate')

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=worker_replicas)
            sync_optimizer = training_optimizer

        with tf.device(deploy_config.optimizer_device()):
            regularization_losses = (
                None if train_config.add_regularization_loss else [])
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones,
                training_optimizer,
                regularization_losses=regularization_losses)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram('ModelVars/' + model_var.op.name,
                                     model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar('Losses/' + loss_tensor.op.name,
                                  loss_tensor))
        global_summaries.add(
            tf.summary.scalar('Losses/TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows ops without a GPU implementation to fall back
        # to the CPU.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        if gpu_usage is not None:
            session_config.gpu_options.per_process_gpu_memory_fraction = gpu_usage

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            max_to_keep=train_config.max_checkpoints_to_keep,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            if not train_config.fine_tune_checkpoint_type:
                # train_config.from_detection_checkpoint field is deprecated. For
                # backward compatibility, fine_tune_checkpoint_type is set based on
                # from_detection_checkpoint.
                if train_config.from_detection_checkpoint:
                    train_config.fine_tune_checkpoint_type = 'detection'
                else:
                    train_config.fine_tune_checkpoint_type = 'classification'
            var_map = detection_model.restore_map(
                fine_tune_checkpoint_type=train_config.
                fine_tune_checkpoint_type,
                load_all_detection_checkpoint_vars=(
                    train_config.load_all_detection_checkpoint_vars))
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map,
                    train_config.fine_tune_checkpoint,
                    include_global_step=False))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            save_summaries_secs=train_config.save_summaries_secs,
            save_interval_secs=train_config.checkpoint_save_interval_secs,
            sync_optimizer=sync_optimizer,
            saver=saver)
Example #7
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.worker_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
      FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
      FLAGS.model_name,
      num_classes=(dataset.num_classes - FLAGS.labels_offset),
      weight_decay=FLAGS.weight_decay,
      is_training=True,
      width_multiplier=FLAGS.width_multiplier)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
      preprocessing_name,
      is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=FLAGS.num_readers,
        common_queue_capacity=20 * FLAGS.batch_size,
        common_queue_min=10 * FLAGS.batch_size)

      # gt_bboxes format [ymin, xmin, ymax, xmax]
      [image, img_shape, gt_labels, gt_bboxes] = provider.get(['image', 'shape',
                                                               'object/label',
                                                               'object/bbox'])

      # Preprocessing
      # gt_bboxes = scale_bboxes(gt_bboxes, img_shape)  # bboxes format [0,1) for tf draw

      image, gt_labels, gt_bboxes = image_preprocessing_fn(image,
                                                           config.IMG_HEIGHT,
                                                           config.IMG_WIDTH,
                                                           labels=gt_labels,
                                                           bboxes=gt_bboxes,
                                                           )

      #############################################
      # Encode annotations for losses computation #
      #############################################

      # anchors format [cx, cy, w, h]
      anchors = tf.convert_to_tensor(config.ANCHOR_SHAPE, dtype=tf.float32)

      # encode annos, box_input format [cx, cy, w, h]
      input_mask, labels_input, box_delta_input, box_input = encode_annos(gt_labels,
                                                                          gt_bboxes,
                                                                          anchors,
                                                                          config.NUM_CLASSES)

      images, b_input_mask, b_labels_input, b_box_delta_input, b_box_input = tf.train.batch(
        [image, input_mask, labels_input, box_delta_input, box_input],
        batch_size=FLAGS.batch_size,
        num_threads=FLAGS.num_preprocessing_threads,
        capacity=5 * FLAGS.batch_size)

      batch_queue = slim.prefetch_queue.prefetch_queue(
        [images, b_input_mask, b_labels_input, b_box_delta_input, b_box_input],
        capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      images, b_input_mask, b_labels_input, b_box_delta_input, b_box_input = batch_queue.dequeue()
      anchors = tf.convert_to_tensor(config.ANCHOR_SHAPE, dtype=tf.float32)
      end_points = network_fn(images)
      end_points["viz_images"] = images
      conv_ds_14 = end_points['MobileNet/conv_ds_14/depthwise_conv']
      dropout = slim.dropout(conv_ds_14, keep_prob=0.5, is_training=True)
      num_output = config.NUM_ANCHORS * (config.NUM_CLASSES + 1 + 4)
      predict = slim.conv2d(dropout, num_output, kernel_size=(3, 3), stride=1, padding='SAME',
                            activation_fn=None,
                            weights_initializer=tf.truncated_normal_initializer(stddev=0.0001),
                            scope="MobileNet/conv_predict")

      with tf.name_scope("Interpre_prediction") as scope:
        pred_box_delta, pred_class_probs, pred_conf, ious, det_probs, det_boxes, det_class = \
          interpre_prediction(predict, b_input_mask, anchors, b_box_input)
        end_points["viz_det_probs"] = det_probs
        end_points["viz_det_boxes"] = det_boxes
        end_points["viz_det_class"] = det_class

      with tf.name_scope("Losses") as scope:
        losses(b_input_mask, b_labels_input, ious, b_box_delta_input, pred_class_probs, pred_conf, pred_box_delta)

      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      if end_point not in ["viz_images", "viz_det_probs", "viz_det_boxes", "viz_det_class"]:
        x = end_points[end_point]
        summaries.add(tf.summary.histogram('activations/' + end_point, x))
        summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                        tf.nn.zero_fraction(x)))

    # Add summaries for detection results. TODO(shizehao): visualize predictions.


    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
        FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
        opt=optimizer,
        replicas_to_aggregate=FLAGS.replicas_to_aggregate,
        variable_averages=variable_averages,
        variables_to_average=moving_average_variables,
        replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
        total_num_replicas=FLAGS.worker_replicas)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # optimize_clones() returns the total loss and the gradients to apply.
    total_loss, clones_gradients = model_deploy.optimize_clones(
      clones,
      optimizer,
      var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                      name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
      train_tensor,
      logdir=FLAGS.train_dir,
      master=FLAGS.master,
      is_chief=(FLAGS.task == 0),
      init_fn=_get_init_fn(),
      summary_op=summary_op,
      number_of_steps=FLAGS.max_number_of_steps,
      log_every_n_steps=FLAGS.log_every_n_steps,
      save_summaries_secs=FLAGS.save_summaries_secs,
      save_interval_secs=FLAGS.save_interval_secs,
      sync_optimizer=optimizer if FLAGS.sync_replicas else None)
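Example #7 relies on helpers such as _get_init_fn(), _configure_learning_rate() and _configure_optimizer() that are defined elsewhere in the script (they follow TF-Slim's train_image_classifier.py). A hedged sketch of the checkpoint-restore helper, reusing the module's FLAGS and slim, where the flag names (checkpoint_path, checkpoint_exclude_scopes) are assumptions:

def _get_init_fn():
  """Sketch of a fine-tuning restore helper built on slim.assign_from_checkpoint_fn."""
  if FLAGS.checkpoint_path is None:
    return None
  exclusions = []
  if FLAGS.checkpoint_exclude_scopes:
    exclusions = [s.strip()
                  for s in FLAGS.checkpoint_exclude_scopes.split(',')]
  # Restore every model variable that does not fall under an excluded scope.
  variables_to_restore = [
      var for var in slim.get_model_variables()
      if not any(var.op.name.startswith(ex) for ex in exclusions)
  ]
  return slim.assign_from_checkpoint_fn(FLAGS.checkpoint_path,
                                        variables_to_restore,
                                        ignore_missing_vars=True)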
Example #8
def main():
    print(args)
    prt('')

    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    if not os.path.isdir(
            log_dir):  # Create the log directory if it doesn't exist
        os.makedirs(log_dir)
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(
            model_dir):  # Create the model directory if it doesn't exist
        os.makedirs(model_dir)

    # Store some git revision info in a text file in the log directory
    src_path, _ = os.path.split(os.path.realpath(__file__))

    np.random.seed(seed=args.seed)

    print('Model directory: %s' % model_dir)
    print('Log directory: %s' % log_dir)
    if args.pretrained_model:
        print('Pre-trained model: %s' %
              os.path.expanduser(args.pretrained_model))

    with tf.Graph().as_default():
        deploy_config = model_deploy.DeploymentConfig(num_clones=args.num_gpus,
                                                      clone_on_cpu=False)
        tf.set_random_seed(args.seed)
        #global_step = tf.Variable(0, trainable=False)
        global_step = variables.get_or_create_global_step()

        # Placeholder for the learning rate
        #learning_rate_placeholder = tf.placeholder(tf.float32, name='learning_rate')

        #batch_size_placeholder = tf.placeholder(tf.int32, name='batch_size')
        with tf.device('/cpu:0'):
            is_training_placeholder = tf.placeholder(tf.bool,
                                                     name='is_training')
            image_paths_placeholder = tf.placeholder(tf.string,
                                                     shape=(None, 3),
                                                     name='image_paths')
            labels_placeholder = tf.placeholder(tf.int64,
                                                shape=(None, 3),
                                                name='labels')

            input_queue = data_flow_ops.FIFOQueue(capacity=100000,
                                                  dtypes=[tf.string, tf.int64],
                                                  shapes=[(3, ), (3, )],
                                                  shared_name=None,
                                                  name=None)
            enqueue_op = input_queue.enqueue_many(
                [image_paths_placeholder, labels_placeholder])

            nrof_preprocess_threads = 8
            images_and_labels = []
            for _ in range(nrof_preprocess_threads):
                filenames, label = input_queue.dequeue()
                #filenames = tf.Print(filenames, [tf.shape(filenames)], 'filenames shape:')
                images = []
                for filename in tf.unstack(filenames):
                    #filename = tf.Print(filename, [filename], 'filename = ')
                    file_contents = tf.read_file(filename)
                    image = tf.image.decode_jpeg(file_contents)
                    #image = tf.Print(image, [tf.shape(image)], 'data count = ')
                    if image.dtype != tf.float32:
                        image = tf.image.convert_image_dtype(image,
                                                             dtype=tf.float32)
                    if args.random_crop:
                        #image = tf.random_crop(image, [args.image_size, args.image_size, 3])
                        bbox = tf.constant([0.0, 0.0, 1.0, 1.0],
                                           dtype=tf.float32,
                                           shape=[1, 1, 4])
                        sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
                            tf.shape(image),
                            bounding_boxes=bbox,
                            area_range=(0.7, 1.0),
                            use_image_if_no_bounding_boxes=True)
                        bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box
                        image = tf.slice(image, bbox_begin, bbox_size)
                    #else:
                    #    image = tf.image.resize_image_with_crop_or_pad(image, args.image_size, args.image_size)
                    image = tf.expand_dims(image, 0)
                    image = tf.image.resize_bilinear(
                        image, [args.image_size, args.image_size],
                        align_corners=False)
                    image = tf.squeeze(image, [0])
                    if args.random_flip:
                        image = tf.image.random_flip_left_right(image)
                    image.set_shape((args.image_size, args.image_size, 3))
                    ##pylint: disable=no-member
                    image = tf.subtract(image, 0.5)
                    image = tf.multiply(image, 2.0)
                    #image = tf.Print(image, [tf.shape(image)], 'data count = ')
                    images.append(image)
                    #images.append(tf.image.per_image_standardization(image))
                images_and_labels.append([images, label])

            learning_rate = get_learning_rate(args)
            opt = get_optimizer(args, learning_rate)
            image_batch, label_batch = tf.train.batch_join(
                images_and_labels,
                batch_size=args.batch_size,
                shapes=[(args.image_size, args.image_size, 3), ()],
                enqueue_many=True,
                capacity=4 * nrof_preprocess_threads * args.batch_size,
                allow_smaller_final_batch=False)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [image_batch, label_batch], capacity=9000)

        def clone_fn(_batch_queue):
            _image_batch, _label_batch = _batch_queue.dequeue()
            embeddings = image_to_embedding(_image_batch,
                                            is_training_placeholder, args)

            # Split embeddings into anchor, positive and negative and calculate triplet loss
            anchor, positive, negative = tf.unstack(
                tf.reshape(embeddings, [-1, 3, args.embedding_size]), 3, 1)
            triplet_loss = triplet_loss_fn(anchor, positive, negative,
                                           args.alpha)
            tf.losses.add_loss(triplet_loss)
            #tf.summary.scalar('learning_rate', learning_rate)
            return embeddings, _label_batch, triplet_loss

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone = clones[0]
        triplet_loss = first_clone.outputs[2]
        embeddings = first_clone.outputs[0]
        _label_batch = first_clone.outputs[1]
        #embedding_clones = model_deploy.create_clones(deploy_config, embedding_fn, [batch_queue])

        #first_clone_scope = deploy_config.clone_scope(0)
        #update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
        update_ops = []
        vdic = [
            v for v in tf.trainable_variables() if v.name.find("Logits/") < 0
        ]
        pretrained_saver = tf.train.Saver(vdic)
        saver = tf.train.Saver(max_to_keep=3)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = get_learning_rate(args)
            opt = get_optimizer(args, learning_rate)

        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, opt, var_list=tf.trainable_variables())

        grad_updates = opt.apply_gradients(clones_gradients,
                                           global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_op = control_flow_ops.with_dependencies([update_op],
                                                      total_loss,
                                                      name='train_op')

        vdic = [
            v for v in tf.trainable_variables() if v.name.find("Logits/") < 0
        ]
        pretrained_saver = tf.train.Saver(vdic)
        saver = tf.train.Saver(max_to_keep=3)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # Start running operations on the Graph.
        #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_memory_fraction)
        #sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        sess = tf.Session()

        # Initialize variables
        sess.run(tf.global_variables_initializer(),
                 feed_dict={is_training_placeholder: True})
        sess.run(tf.local_variables_initializer(),
                 feed_dict={is_training_placeholder: True})

        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)

        with sess.as_default():

            if args.pretrained_model:
                print('Restoring pretrained model: %s' % args.pretrained_model)
                pretrained_saver.restore(
                    sess, os.path.expanduser(args.pretrained_model))

            # Training and validation loop
            epoch = 0
            while epoch < args.max_nrof_epochs:
                eval_one_epoch(args, sess, dataset, image_paths_placeholder,
                               labels_placeholder, is_training_placeholder,
                               enqueue_op, clones)
                # Train for one epoch
                train_one_epoch(args, sess, dataset, image_paths_placeholder,
                                labels_placeholder, is_training_placeholder,
                                enqueue_op, input_queue, clones, total_loss,
                                train_op, summary_op, summary_writer)

                # Save variables and the metagraph if it doesn't exist already
                global_step = variables.get_or_create_global_step()
                step = sess.run(global_step, feed_dict=None)
                print('one epoch finish', step)
                save_variables_and_metagraph(sess, saver, summary_writer,
                                             model_dir, subdir, step)
                print('saver finish')

    sess.close()
    return model_dir
Example #9
def main(_):
    #tf.disable_v2_behavior() ###
    tf.compat.v1.disable_eager_execution()
    tf.compat.v1.enable_resource_variables()

    # Enable habana bf16 conversion pass
    if FLAGS.dtype == 'bf16':
        os.environ['TF_BF16_CONVERSION'] = flags.FLAGS.bf16_config_path
        FLAGS.precision = 'bf16'
    else:
        os.environ['TF_BF16_CONVERSION'] = "0"

    if FLAGS.use_horovod:
        hvd_init()

    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=True,
            use_grayscale=FLAGS.use_grayscale)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs

        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #if FLAGS.quantize_delay >= 0:
        #  quantize.create_training_graph(quant_delay=FLAGS.quantize_delay) #for debugging!!

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # optimize_clones() returns the total loss and the gradients to apply.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        if horovod_enabled():
            hvd.broadcast_global_variables(0)
        ###########################
        # Kicks off the training. #
        ###########################
        with dump_callback():
            with logger.benchmark_context(FLAGS):
                eps1 = ExamplesPerSecondKerasHook(FLAGS.log_every_n_steps,
                                                  output_dir=FLAGS.train_dir,
                                                  batch_size=FLAGS.batch_size)

                write_hparams_v1(
                    eps1.writer, {
                        'batch_size': FLAGS.batch_size,
                        **{x: getattr(FLAGS, x)
                           for x in FLAGS}
                    })

                train_step_kwargs = {}
                if FLAGS.max_number_of_steps:
                    should_stop_op = math_ops.greater_equal(
                        global_step, FLAGS.max_number_of_steps)
                else:
                    should_stop_op = constant_op.constant(False)
                train_step_kwargs['should_stop'] = should_stop_op
                if FLAGS.log_every_n_steps > 0:
                    train_step_kwargs['should_log'] = math_ops.equal(
                        math_ops.mod(global_step, FLAGS.log_every_n_steps), 0)

                eps1.on_train_begin()
                train_step_kwargs['EPS'] = eps1

                slim.learning.train(
                    train_tensor,
                    logdir=FLAGS.train_dir,
                    train_step_fn=train_step1,
                    train_step_kwargs=train_step_kwargs,
                    master=FLAGS.master,
                    is_chief=(FLAGS.task == 0),
                    init_fn=_get_init_fn(),
                    summary_op=summary_op,
                    summary_writer=None,
                    number_of_steps=FLAGS.max_number_of_steps,
                    log_every_n_steps=FLAGS.log_every_n_steps,
                    save_summaries_secs=FLAGS.save_summaries_secs,
                    save_interval_secs=FLAGS.save_interval_secs,
                    sync_optimizer=optimizer if FLAGS.sync_replicas else None)
Example #10
def main(train_dir, num_clones, clone_on_cpu, train_size, val_size,
         num_classes, worker_replicas, log_every_n_steps, save_interval_secs,
         weight_decay, opt, l_r, moving_average_decay, d_set,
         max_number_of_steps, check):

    num_ps_tasks = 0
    num_readers = 4
    num_preprocessing_threads = 4
    task = 0

    if not d_set['dataset_dir']:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(d_set['dataset_name'],
                                              d_set['dataset_split_name'],
                                              d_set['dataset_dir'],
                                              train_size=train_size,
                                              val_size=val_size,
                                              num_classes=num_classes)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            d_set['model_name'],
            num_classes=(dataset.num_classes - d_set['labels_offset']),
            weight_decay=weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = d_set['preprocessing_name'] or d_set['model_name']
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=num_readers,
                common_queue_capacity=20 * d_set['batch_size'],
                common_queue_min=10 * d_set['batch_size'])
            [image, label] = provider.get(['image', 'label'])
            label -= d_set['labels_offset']

            train_image_size = d_set[
                'train_image_size'] or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=d_set['batch_size'],
                num_threads=num_preprocessing_threads,
                capacity=5 * d_set['batch_size'])
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - d_set['labels_offset'])
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=l_r['label_smoothing'],
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=l_r['label_smoothing'],
                weights=1.0)
            return end_points

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs

        #################################
        # Configure the moving averages #
        #################################
        if moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step, l_r,
                                                     d_set['batch_size'])
            optimizer = _configure_optimizer(learning_rate, opt)

        if l_r['sync_replicas']:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=l_r['replicas_to_aggregate'],
                total_num_replicas=worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train(check['trainable_scopes'])

        # optimize_clones() returns the total loss and the gradients to apply.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master='',
            is_chief=(task == 0),
            init_fn=_get_init_fn(train_dir, check),
            number_of_steps=max_number_of_steps,
            log_every_n_steps=log_every_n_steps,
            save_interval_secs=save_interval_secs,
            sync_optimizer=optimizer if l_r['sync_replicas'] else None)
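Examples #7, #9 and #10 all defer to a _configure_optimizer() helper that is not shown in the listings. A hypothetical sketch matching Example #10's call signature, reusing the module's tensorflow import; the keys of the opt dict ('optimizer', 'momentum', 'rmsprop_decay') are assumptions:

def _configure_optimizer(learning_rate, opt):
    """Hypothetical sketch of the optimizer factory used above.

    opt is assumed to be a dict such as
    {'optimizer': 'rmsprop', 'momentum': 0.9, 'rmsprop_decay': 0.9};
    the real helper supports more optimizers and options.
    """
    name = opt['optimizer']
    if name == 'adam':
        return tf.train.AdamOptimizer(learning_rate)
    if name == 'momentum':
        return tf.train.MomentumOptimizer(learning_rate,
                                          momentum=opt['momentum'])
    if name == 'rmsprop':
        return tf.train.RMSPropOptimizer(learning_rate,
                                         decay=opt['rmsprop_decay'],
                                         momentum=opt['momentum'])
    return tf.train.GradientDescentOptimizer(learning_rate)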
Example #11
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.DEBUG)
    with tf.Graph().as_default():
        # Config model_deploy. Keep TF Slim Models structure.
        # Useful if we want to use multiple GPUs and/or servers in the future.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=0,
            num_replicas=1,
            num_ps_tasks=0)
        # Create global_step.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        # Select the dataset.
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        # Get the SSD network and its anchors.
        ssd_class = nets_factory.get_network(FLAGS.model_name)
        ssd_params = ssd_class.default_params._replace(
            num_classes=FLAGS.num_classes)
        ssd_net = ssd_class(ssd_params)
        ssd_shape = ssd_net.params.img_shape
        ssd_anchors = ssd_net.anchors(ssd_shape)

        # Select the preprocessing function.
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        tf_utils.print_configuration(FLAGS.__flags, ssd_params,
                                     dataset.data_sources, FLAGS.train_dir)
        # =================================================================== #
        # Create a dataset provider and batches.
        # =================================================================== #
        with tf.device(deploy_config.inputs_device()):
            with tf.name_scope(FLAGS.dataset_name + '_data_provider'):
                provider = slim.dataset_data_provider.DatasetDataProvider(
                    dataset,
                    num_readers=FLAGS.num_readers,
                    common_queue_capacity=20 * FLAGS.batch_size,
                    common_queue_min=10 * FLAGS.batch_size,
                    shuffle=True)
            # Get for SSD network: image, labels, bboxes.
            [image, shape, glabels, gbboxes] = provider.get(
                ['image', 'shape', 'object/label', 'object/bbox'])
            # Pre-processing image, labels and bboxes.
            image, glabels, gbboxes = \
                image_preprocessing_fn(image, glabels, gbboxes,
                                       out_shape=ssd_shape,
                                       data_format=DATA_FORMAT)
            # Encode groundtruth labels and bboxes.
            gclasses, glocalisations, gscores = \
                ssd_net.bboxes_encode(glabels, gbboxes, ssd_anchors)
            batch_shape = [1] + [len(ssd_anchors)] * 3
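            # batch_shape records the nesting of the list above: one image
            # tensor followed by three lists (classes, localisations, scores),
            # each with one entry per anchor layer, so the flat batched output
            # can be reshaped back into the same structure below.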

            # Training batches and queue.
            r = tf.train.batch(tf_utils.reshape_list(
                [image, gclasses, glocalisations, gscores]),
                               batch_size=FLAGS.batch_size,
                               num_threads=FLAGS.num_preprocessing_threads,
                               capacity=5 * FLAGS.batch_size)
            b_image, b_gclasses, b_glocalisations, b_gscores = \
                tf_utils.reshape_list(r, batch_shape)

            # Intermediate queueing: unique batch computation pipeline for all
            # GPUs running the training.
            batch_queue = slim.prefetch_queue.prefetch_queue(
                tf_utils.reshape_list(
                    [b_image, b_gclasses, b_glocalisations, b_gscores]),
                capacity=2 * deploy_config.num_clones)

        # =================================================================== #
        # Define the model running on every GPU.
        # =================================================================== #
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple
            clones of network_fn."""
            # Dequeue batch.
            b_image, b_gclasses, b_glocalisations, b_gscores = \
                tf_utils.reshape_list(batch_queue.dequeue(), batch_shape)

            # Construct SSD network.
            arg_scope = ssd_net.arg_scope(weight_decay=FLAGS.weight_decay,
                                          data_format=DATA_FORMAT)
            with slim.arg_scope(arg_scope):
                predictions, localisations, logits, end_points = \
                    ssd_net.net(b_image, is_training=True)
            # Add loss function.
            ssd_net.losses(logits,
                           localisations,
                           b_gclasses,
                           b_glocalisations,
                           b_gscores,
                           match_threshold=FLAGS.match_threshold,
                           negative_ratio=FLAGS.negative_ratio,
                           alpha=FLAGS.loss_alpha,
                           label_smoothing=FLAGS.label_smoothing)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # =================================================================== #
        # Add summaries from first clone.
        # =================================================================== #
        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))
        # Add summaries for losses and extra losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))
        for loss in tf.get_collection('EXTRA_LOSSES', first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        # =================================================================== #
        # Configure the moving averages.
        # =================================================================== #
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        # =================================================================== #
        # Configure the optimization procedure.
        # =================================================================== #
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = tf_utils.configure_learning_rate(
                FLAGS, dataset.num_samples, global_step)
            optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = tf_utils.get_variables_to_train(FLAGS)

        # Compute the total loss and the clones' gradients.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # =================================================================== #
        # Kicks off the training.
        # =================================================================== #
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
        config = tf.ConfigProto(log_device_placement=False,
                                gpu_options=gpu_options)
        saver = tf.train.Saver(max_to_keep=5,
                               keep_checkpoint_every_n_hours=1.0,
                               write_version=2,
                               pad_step_number=False)
        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_dir,
                            master='',
                            is_chief=True,
                            init_fn=tf_utils.get_init_fn(FLAGS),
                            summary_op=summary_op,
                            number_of_steps=FLAGS.max_number_of_steps,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            saver=saver,
                            save_interval_secs=FLAGS.save_interval_secs,
                            session_config=config,
                            sync_optimizer=None)
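
The tf.train.batch call above only accepts a flat list of tensors, so the SSD example flattens the nested [image, classes, localisations, scores] structure before batching and rebuilds it afterwards from batch_shape. A rough sketch of such a flatten/unflatten helper (an approximation of tf_utils.reshape_list, not the repository's exact code):

def reshape_list(l, shape=None):
    """Flatten a nested list of tensors, or rebuild the nesting from `shape`.

    With shape=None the (possibly nested) list is flattened; otherwise
    `shape` gives the length of each group, e.g. [1, 6, 6, 6] for one image
    tensor plus three per-anchor-layer lists of length 6.
    """
    r = []
    if shape is None:
        # Flatten: splice nested lists/tuples into the output list.
        for a in l:
            if isinstance(a, (list, tuple)):
                r.extend(a)
            else:
                r.append(a)
    else:
        # Unflatten: consume the flat list group by group.
        i = 0
        for s in shape:
            if s == 1:
                r.append(l[i])
            else:
                r.append(l[i:i + s])
            i += s
    return r
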
def main_fun(argv, ctx):
  import tensorflow as tf
  from tensorflow.python.ops import control_flow_ops
  from datasets import dataset_factory
  from deployment import model_deploy
  from nets import nets_factory
  from preprocessing import preprocessing_factory

  sys.argv = argv

  slim = tf.contrib.slim

  tf.app.flags.DEFINE_integer(
      'num_gpus', 1, 'The number of GPUs to use per node.')

  tf.app.flags.DEFINE_boolean('rdma', False, 'Whether to use rdma.')

  tf.app.flags.DEFINE_string(
      'master', '', 'The address of the TensorFlow master to use.')

  tf.app.flags.DEFINE_string(
      'train_dir', '/tmp/tfmodel/',
      'Directory where checkpoints and event logs are written to.')

  tf.app.flags.DEFINE_integer('num_clones', 1,
                              'Number of model clones to deploy.')

  tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
                              'Use CPUs to deploy clones.')

  tf.app.flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas.')

  tf.app.flags.DEFINE_integer(
      'num_ps_tasks', 0,
      'The number of parameter servers. If the value is 0, then the parameters '
      'are handled locally by the worker.')

  tf.app.flags.DEFINE_integer(
      'num_readers', 4,
      'The number of parallel readers that read data from the dataset.')

  tf.app.flags.DEFINE_integer(
      'num_preprocessing_threads', 4,
      'The number of threads used to create the batches.')

  tf.app.flags.DEFINE_integer(
      'log_every_n_steps', 10,
      'The frequency with which logs are printed.')

  tf.app.flags.DEFINE_integer(
      'save_summaries_secs', 600,
      'The frequency with which summaries are saved, in seconds.')

  tf.app.flags.DEFINE_integer(
      'save_interval_secs', 600,
      'The frequency with which the model is saved, in seconds.')

  tf.app.flags.DEFINE_integer(
      'task', 0, 'Task id of the replica running the training.')

  ######################
  # Optimization Flags #
  ######################

  tf.app.flags.DEFINE_float(
      'weight_decay', 0.00004, 'The weight decay on the model weights.')

  tf.app.flags.DEFINE_string(
      'optimizer', 'rmsprop',
      'The name of the optimizer, one of "adadelta", "adagrad", "adam",'
      '"ftrl", "momentum", "sgd" or "rmsprop".')

  tf.app.flags.DEFINE_float(
      'adadelta_rho', 0.95,
      'The decay rate for adadelta.')

  tf.app.flags.DEFINE_float(
      'adagrad_initial_accumulator_value', 0.1,
      'Starting value for the AdaGrad accumulators.')

  tf.app.flags.DEFINE_float(
      'adam_beta1', 0.9,
      'The exponential decay rate for the 1st moment estimates.')

  tf.app.flags.DEFINE_float(
      'adam_beta2', 0.999,
      'The exponential decay rate for the 2nd moment estimates.')

  tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.')

  tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5,
                            'The learning rate power.')

  tf.app.flags.DEFINE_float(
      'ftrl_initial_accumulator_value', 0.1,
      'Starting value for the FTRL accumulators.')

  tf.app.flags.DEFINE_float(
      'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.')

  tf.app.flags.DEFINE_float(
      'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.')

  tf.app.flags.DEFINE_float(
      'momentum', 0.9,
      'The momentum for the MomentumOptimizer and RMSPropOptimizer.')

  tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')

  #######################
  # Learning Rate Flags #
  #######################

  tf.app.flags.DEFINE_string(
      'learning_rate_decay_type',
      'exponential',
      'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
      ' or "polynomial"')

  tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')

  tf.app.flags.DEFINE_float(
      'end_learning_rate', 0.0001,
      'The minimal end learning rate used by a polynomial decay learning rate.')

  tf.app.flags.DEFINE_float(
      'label_smoothing', 0.0, 'The amount of label smoothing.')

  tf.app.flags.DEFINE_float(
      'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')

  tf.app.flags.DEFINE_float(
      'num_epochs_per_decay', 2.0,
      'Number of epochs after which learning rate decays.')

  tf.app.flags.DEFINE_bool(
      'sync_replicas', False,
      'Whether or not to synchronize the replicas during training.')

  tf.app.flags.DEFINE_integer(
      'replicas_to_aggregate', 1,
      'The number of gradients to collect before updating params.')

  tf.app.flags.DEFINE_float(
      'moving_average_decay', None,
      'The decay to use for the moving average. '
      'If left as None, then moving averages are not used.')

  #################
  # Dataset Flags #
  #################

  tf.app.flags.DEFINE_string(
      'dataset_name', 'imagenet', 'The name of the dataset to load.')

  tf.app.flags.DEFINE_string(
      'dataset_split_name', 'train', 'The name of the train/test split.')

  tf.app.flags.DEFINE_string(
      'dataset_dir', None, 'The directory where the dataset files are stored.')

  tf.app.flags.DEFINE_integer(
      'labels_offset', 0,
      'An offset for the labels in the dataset. This flag is primarily used to '
      'evaluate the VGG and ResNet architectures which do not use a background '
      'class for the ImageNet dataset.')

  tf.app.flags.DEFINE_string(
      'model_name', 'inception_v3', 'The name of the architecture to train.')

  tf.app.flags.DEFINE_string(
      'preprocessing_name', None, 'The name of the preprocessing to use. If left '
      'as `None`, then the model_name flag is used.')

  tf.app.flags.DEFINE_integer(
      'batch_size', 32, 'The number of samples in each batch.')

  tf.app.flags.DEFINE_integer(
      'train_image_size', None, 'Train image size')

  tf.app.flags.DEFINE_integer('max_number_of_steps', None,
                              'The maximum number of training steps.')

  #####################
  # Fine-Tuning Flags #
  #####################

  tf.app.flags.DEFINE_string(
      'checkpoint_path', None,
      'The path to a checkpoint from which to fine-tune.')

  tf.app.flags.DEFINE_string(
      'checkpoint_exclude_scopes', None,
      'Comma-separated list of scopes of variables to exclude when restoring '
      'from a checkpoint.')

  tf.app.flags.DEFINE_string(
      'trainable_scopes', None,
      'Comma-separated list of scopes to filter the set of variables to train. '
      'By default, None trains all the variables.')

  tf.app.flags.DEFINE_boolean(
      'ignore_missing_vars', False,
      'Whether to ignore missing variables when restoring a checkpoint.')

  FLAGS = tf.app.flags.FLAGS
  FLAGS.job_name = ctx.job_name
  FLAGS.task = ctx.task_index
  FLAGS.num_clones = FLAGS.num_gpus
  FLAGS.worker_replicas = len(ctx.cluster_spec['worker'])
  assert(FLAGS.num_ps_tasks == (len(ctx.cluster_spec['ps']) if 'ps' in ctx.cluster_spec else 0))

  def _configure_learning_rate(num_samples_per_epoch, global_step):
    """Configures the learning rate.

    Args:
      num_samples_per_epoch: The number of samples in each epoch of training.
      global_step: The global_step tensor.

    Returns:
      A `Tensor` representing the learning rate.

    Raises:
      ValueError: if `learning_rate_decay_type` is not recognized.
    """
    decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
                      FLAGS.num_epochs_per_decay)
    if FLAGS.sync_replicas:
      decay_steps /= FLAGS.replicas_to_aggregate

    if FLAGS.learning_rate_decay_type == 'exponential':
      return tf.train.exponential_decay(FLAGS.learning_rate,
                                        global_step,
                                        decay_steps,
                                        FLAGS.learning_rate_decay_factor,
                                        staircase=True,
                                        name='exponential_decay_learning_rate')
    elif FLAGS.learning_rate_decay_type == 'fixed':
      return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
    elif FLAGS.learning_rate_decay_type == 'polynomial':
      return tf.train.polynomial_decay(FLAGS.learning_rate,
                                       global_step,
                                       decay_steps,
                                       FLAGS.end_learning_rate,
                                       power=1.0,
                                       cycle=False,
                                       name='polynomial_decay_learning_rate')
    else:
      raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                       FLAGS.learning_rate_decay_type)


  def _configure_optimizer(learning_rate):
    """Configures the optimizer used for training.

    Args:
      learning_rate: A scalar or `Tensor` learning rate.

    Returns:
      An instance of an optimizer.

    Raises:
      ValueError: if FLAGS.optimizer is not recognized.
    """
    if FLAGS.optimizer == 'adadelta':
      optimizer = tf.train.AdadeltaOptimizer(
          learning_rate,
          rho=FLAGS.adadelta_rho,
          epsilon=FLAGS.opt_epsilon)
    elif FLAGS.optimizer == 'adagrad':
      optimizer = tf.train.AdagradOptimizer(
          learning_rate,
          initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value)
    elif FLAGS.optimizer == 'adam':
      optimizer = tf.train.AdamOptimizer(
          learning_rate,
          beta1=FLAGS.adam_beta1,
          beta2=FLAGS.adam_beta2,
          epsilon=FLAGS.opt_epsilon)
    elif FLAGS.optimizer == 'ftrl':
      optimizer = tf.train.FtrlOptimizer(
          learning_rate,
          learning_rate_power=FLAGS.ftrl_learning_rate_power,
          initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
          l1_regularization_strength=FLAGS.ftrl_l1,
          l2_regularization_strength=FLAGS.ftrl_l2)
    elif FLAGS.optimizer == 'momentum':
      optimizer = tf.train.MomentumOptimizer(
          learning_rate,
          momentum=FLAGS.momentum,
          name='Momentum')
    elif FLAGS.optimizer == 'rmsprop':
      optimizer = tf.train.RMSPropOptimizer(
          learning_rate,
          decay=FLAGS.rmsprop_decay,
          momentum=FLAGS.momentum,
          epsilon=FLAGS.opt_epsilon)
    elif FLAGS.optimizer == 'sgd':
      optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    else:
      raise ValueError('Optimizer [%s] was not recognized' % FLAGS.optimizer)
    return optimizer


  def _add_variables_summaries(learning_rate):
    summaries = []
    for variable in slim.get_model_variables():
      summaries.append(tf.summary.histogram(variable.op.name, variable))
    summaries.append(tf.summary.scalar('training/Learning Rate', learning_rate))
    return summaries


  def _get_init_fn():
    """Returns a function run by the chief worker to warm-start the training.

    Note that the init_fn is only run when initializing the model during the very
    first global step.

    Returns:
      An init function run by the supervisor.
    """
    if FLAGS.checkpoint_path is None:
      return None

    # Warn the user if a checkpoint exists in the train_dir. Then we'll be
    # ignoring the checkpoint anyway.
    if tf.train.latest_checkpoint(FLAGS.train_dir):
      tf.logging.info(
          'Ignoring --checkpoint_path because a checkpoint already exists in %s'
          % FLAGS.train_dir)
      return None

    exclusions = []
    if FLAGS.checkpoint_exclude_scopes:
      exclusions = [scope.strip()
                    for scope in FLAGS.checkpoint_exclude_scopes.split(',')]

    # TODO(sguada) variables.filter_variables()
    variables_to_restore = []
    for var in slim.get_model_variables():
      excluded = False
      for exclusion in exclusions:
        if var.op.name.startswith(exclusion):
          excluded = True
          break
      if not excluded:
        variables_to_restore.append(var)

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    tf.logging.info('Fine-tuning from %s' % checkpoint_path)

    return slim.assign_from_checkpoint_fn(
        checkpoint_path,
        variables_to_restore,
        ignore_missing_vars=FLAGS.ignore_missing_vars)


  def _get_variables_to_train():
    """Returns a list of variables to train.

    Returns:
      A list of variables to train by the optimizer.
    """
    if FLAGS.trainable_scopes is None:
      return tf.trainable_variables()
    else:
      scopes = [scope.strip() for scope in FLAGS.trainable_scopes.split(',')]

    variables_to_train = []
    for scope in scopes:
      variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
      variables_to_train.extend(variables)
    return variables_to_train

  # main
  cluster_spec, server = TFNode.start_cluster_server(ctx=ctx, num_gpus=FLAGS.num_gpus, rdma=FLAGS.rdma)
  if ctx.job_name == 'ps':
    # `ps` jobs wait for incoming connections from the workers.
    server.join()
  else:
    # `worker` jobs will actually do the work.
    if not FLAGS.dataset_dir:
      raise ValueError('You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
      #######################
      # Config model_deploy #
      #######################
      deploy_config = model_deploy.DeploymentConfig(
          num_clones=FLAGS.num_clones,
          clone_on_cpu=FLAGS.clone_on_cpu,
          replica_id=FLAGS.task,
          num_replicas=FLAGS.worker_replicas,
          num_ps_tasks=FLAGS.num_ps_tasks)

      # Create global_step
      #with tf.device(deploy_config.variables_device()):
      #  global_step = slim.create_global_step()
      with tf.device("/job:ps/task:0"):
        global_step = tf.Variable(0, name="global_step")

      ######################
      # Select the dataset #
      ######################
      dataset = dataset_factory.get_dataset(
          FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

      ######################
      # Select the network #
      ######################
      network_fn = nets_factory.get_network_fn(
          FLAGS.model_name,
          num_classes=(dataset.num_classes - FLAGS.labels_offset),
          weight_decay=FLAGS.weight_decay,
          is_training=True)

      #####################################
      # Select the preprocessing function #
      #####################################
      preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
      image_preprocessing_fn = preprocessing_factory.get_preprocessing(
          preprocessing_name,
          is_training=True)

      ##############################################################
      # Create a dataset provider that loads data from the dataset #
      ##############################################################
      with tf.device(deploy_config.inputs_device()):
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            num_readers=FLAGS.num_readers,
            common_queue_capacity=20 * FLAGS.batch_size,
            common_queue_min=10 * FLAGS.batch_size)
        [image, label] = provider.get(['image', 'label'])
        label -= FLAGS.labels_offset

        train_image_size = FLAGS.train_image_size or network_fn.default_image_size

        image = image_preprocessing_fn(image, train_image_size, train_image_size)

        images, labels = tf.train.batch(
            [image, label],
            batch_size=FLAGS.batch_size,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size)
        labels = slim.one_hot_encoding(
            labels, dataset.num_classes - FLAGS.labels_offset)
        batch_queue = slim.prefetch_queue.prefetch_queue(
            [images, labels], capacity=2 * deploy_config.num_clones)

      ####################
      # Define the model #
      ####################
      def clone_fn(batch_queue):
        """Allows data parallelism by creating multiple clones of network_fn."""
        images, labels = batch_queue.dequeue()
        logits, end_points = network_fn(images)

        #############################
        # Specify the loss function #
        #############################
        if 'AuxLogits' in end_points:
          tf.losses.softmax_cross_entropy(
              logits=end_points['AuxLogits'], onehot_labels=labels,
              label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
        tf.losses.softmax_cross_entropy(
            logits=logits, onehot_labels=labels,
            label_smoothing=FLAGS.label_smoothing, weights=1.0)
        return end_points

      # Gather initial summaries.
      summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

      clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
      first_clone_scope = deploy_config.clone_scope(0)
      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by network_fn.
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

      # Add summaries for end_points.
      end_points = clones[0].outputs
      for end_point in end_points:
        x = end_points[end_point]
        summaries.add(tf.summary.histogram('activations/' + end_point, x))
        summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                        tf.nn.zero_fraction(x)))

      # Add summaries for losses.
      for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
        summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

      # Add summaries for variables.
      for variable in slim.get_model_variables():
        summaries.add(tf.summary.histogram(variable.op.name, variable))

      #################################
      # Configure the moving averages #
      #################################
      if FLAGS.moving_average_decay:
        moving_average_variables = slim.get_model_variables()
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay, global_step)
      else:
        moving_average_variables, variable_averages = None, None

      #########################################
      # Configure the optimization procedure. #
      #########################################
      with tf.device(deploy_config.optimizer_device()):
        learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
        optimizer = _configure_optimizer(learning_rate)
        summaries.add(tf.summary.scalar('learning_rate', learning_rate))

      if FLAGS.sync_replicas:
        # If sync_replicas is enabled, the averaging will be done in the chief
        # queue runner.
        optimizer = tf.train.SyncReplicasOptimizer(
            opt=optimizer,
            replicas_to_aggregate=FLAGS.replicas_to_aggregate,
            variable_averages=variable_averages,
            variables_to_average=moving_average_variables,
            replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
            total_num_replicas=FLAGS.worker_replicas)
      elif FLAGS.moving_average_decay:
        # Update ops executed locally by trainer.
        update_ops.append(variable_averages.apply(moving_average_variables))

      # Variables to train.
      variables_to_train = _get_variables_to_train()

      # Compute the total loss and the clones' gradients.
      total_loss, clones_gradients = model_deploy.optimize_clones(
          clones,
          optimizer,
          var_list=variables_to_train)
      # Add total_loss to summary.
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Create gradient updates.
      grad_updates = optimizer.apply_gradients(clones_gradients,
                                               global_step=global_step)
      update_ops.append(grad_updates)

      update_op = tf.group(*update_ops)
      train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                        name='train_op')

      # Add the summaries from the first clone. These contain the summaries
      # created by model_fn and either optimize_clones() or _gather_clone_loss().
      summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                         first_clone_scope))

      # Merge all summaries together.
      summary_op = tf.summary.merge(list(summaries), name='summary_op')


      ###########################
      # Kicks off the training. #
      ###########################
      summary_writer = tf.summary.FileWriter("tensorboard_%d" %(ctx.worker_num), graph=tf.get_default_graph())
      slim.learning.train(
          train_tensor,
          logdir=FLAGS.train_dir,
          master=server.target,
          is_chief=(FLAGS.task == 0),
          init_fn=_get_init_fn(),
          summary_op=summary_op,
          number_of_steps=FLAGS.max_number_of_steps,
          log_every_n_steps=FLAGS.log_every_n_steps,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs,
          summary_writer=summary_writer,
          sync_optimizer=optimizer if FLAGS.sync_replicas else None)
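
For context, TFNode.start_cluster_server in the example above comes from TensorFlowOnSpark and, as I understand it, essentially builds a tf.train.Server from the cluster spec that Spark provides. A bare-TensorFlow sketch of the same ps/worker split, with a hypothetical hand-written cluster spec:

import tensorflow as tf

# Hypothetical two-process cluster; in the example above this information
# comes from the TensorFlowOnSpark context (ctx.cluster_spec) instead.
cluster = tf.train.ClusterSpec({
    'ps': ['localhost:2222'],
    'worker': ['localhost:2223'],
})

job_name, task_index = 'worker', 0   # set differently in each process
server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)

if job_name == 'ps':
    # Parameter servers only host variables and wait for workers.
    server.join()
else:
    # Workers build the graph, pin shared state to the ps job and train.
    with tf.device('/job:ps/task:0'):
        global_step = tf.Variable(0, name='global_step', trainable=False)
    # ... build the model, then slim.learning.train(..., master=server.target)
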
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError('You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.DEBUG)
    with tf.Graph().as_default():
        # Config model_deploy. Keep TF Slim Models structure.
        # Useful if you want to use multiple GPUs and/or servers in the future.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=0,
            num_replicas=1,
            num_ps_tasks=0)
        # Create global_step.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        # Select the dataset.
        dataset = dataset_factory.get_dataset(
            FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

        # Get the SSD network and its anchors.
        ssd_class = nets_factory.get_network(FLAGS.model_name)
        ssd_params = ssd_class.default_params._replace(num_classes=FLAGS.num_classes)
        ssd_net = ssd_class(ssd_params)
        ssd_shape = ssd_net.params.img_shape
        ssd_anchors = ssd_net.anchors(ssd_shape)

        # Select the preprocessing function.
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        tf_utils.print_configuration(FLAGS.__flags, ssd_params,
                                     dataset.data_sources, FLAGS.train_dir)
        # =================================================================== #
        # Create a dataset provider and batches.
        # =================================================================== #
        with tf.device(deploy_config.inputs_device()):
            with tf.name_scope(FLAGS.dataset_name + '_data_provider'):
                provider = slim.dataset_data_provider.DatasetDataProvider(
                    dataset,
                    num_readers=FLAGS.num_readers,
                    common_queue_capacity=20 * FLAGS.batch_size,
                    common_queue_min=10 * FLAGS.batch_size,
                    shuffle=True)
            # Get for SSD network: image, labels, bboxes.
            [image, shape, glabels, gbboxes] = provider.get(['image', 'shape',
                                                             'object/label',
                                                             'object/bbox'])
            # Pre-processing image, labels and bboxes.
            image, glabels, gbboxes = \
                image_preprocessing_fn(image, glabels, gbboxes,
                                       out_shape=ssd_shape,
                                       data_format=DATA_FORMAT)
            # Encode groundtruth labels and bboxes.
            gclasses, glocalisations, gscores = \
                ssd_net.bboxes_encode(glabels, gbboxes, ssd_anchors)
            batch_shape = [1] + [len(ssd_anchors)] * 3

            # Training batches and queue.
            r = tf.train.batch(
                tf_utils.reshape_list([image, gclasses, glocalisations, gscores]),
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            b_image, b_gclasses, b_glocalisations, b_gscores = \
                tf_utils.reshape_list(r, batch_shape)

            # Intermediate queueing: unique batch computation pipeline for all
            # GPUs running the training.
            batch_queue = slim.prefetch_queue.prefetch_queue(
                tf_utils.reshape_list([b_image, b_gclasses, b_glocalisations, b_gscores]),
                capacity=2 * deploy_config.num_clones)

        # =================================================================== #
        # Define the model running on every GPU.
        # =================================================================== #
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple
            clones of network_fn."""
            # Dequeue batch.
            b_image, b_gclasses, b_glocalisations, b_gscores = \
                tf_utils.reshape_list(batch_queue.dequeue(), batch_shape)

            # Construct SSD network.
            arg_scope = ssd_net.arg_scope(weight_decay=FLAGS.weight_decay,
                                          data_format=DATA_FORMAT)
            with slim.arg_scope(arg_scope):
                predictions, localisations, logits, end_points = \
                    ssd_net.net(b_image, is_training=True)
            # Add loss function.
            ssd_net.losses(logits, localisations,
                           b_gclasses, b_glocalisations, b_gscores,
                           match_threshold=FLAGS.match_threshold,
                           negative_ratio=FLAGS.negative_ratio,
                           alpha=FLAGS.loss_alpha,
                           label_smoothing=FLAGS.label_smoothing)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # =================================================================== #
        # Add summaries from first clone.
        # =================================================================== #
        clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                            tf.nn.zero_fraction(x)))
        # Add summaries for losses and extra losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))
        for loss in tf.get_collection('EXTRA_LOSSES', first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        # =================================================================== #
        # Configure the moving averages.
        # =================================================================== #
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        # =================================================================== #
        # Configure the optimization procedure.
        # =================================================================== #
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = tf_utils.configure_learning_rate(FLAGS,
                                                             dataset.num_samples,
                                                             global_step)
            optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = tf_utils.get_variables_to_train(FLAGS)

        # and returns a train_tensor and summary_op
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones,
            optimizer,
            var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                           first_clone_scope))
        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # =================================================================== #
        # Kicks off the training.
        # =================================================================== #
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
        config = tf.ConfigProto(log_device_placement=False,
                                gpu_options=gpu_options)
        saver = tf.train.Saver(max_to_keep=5,
                               keep_checkpoint_every_n_hours=1.0,
                               write_version=2,
                               pad_step_number=False)
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master='',
            is_chief=True,
            init_fn=tf_utils.get_init_fn(FLAGS),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            saver=saver,
            save_interval_secs=FLAGS.save_interval_secs,
            session_config=config,
            sync_optimizer=None)
def main(_):
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.device

    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=False,
            replica_id=0,
            num_replicas=1,
            num_ps_tasks=0)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=dataset.num_classes,
            weight_decay=FLAGS.weight_decay,
            batch_norm_decay=None,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(labels, dataset.num_classes)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            # Noise up the images. Don't do this for models where the images
            # are preprocessed with an existing ISP.
            with tf.device("/cpu:0"):
                noisy_batch, a, gauss_std = sensor_model.sensor_noise_rand_light_level(
                    images, [FLAGS.ll_low, FLAGS.ll_high], scale=1.0)
            bayer_mask = sensor_model.get_bayer_mask(train_image_size,
                                                     train_image_size)
            inputs = noisy_batch * bayer_mask

            # These parameters are only relevant for our special ISP functions.
            # Mobilenet, for instance, will simply ignore them.
            logits, end_points, _ = network_fn(
                images=inputs,
                num_classes=dataset.num_classes,
                alpha=a,
                sigma=gauss_std,
                bayer_mask=bayer_mask,
                use_anscombe=FLAGS.use_anscombe,
                noise_channel=FLAGS.noise_channel,
                num_iters=FLAGS.num_iters,
                num_layers=FLAGS.num_layers,
                isp_model_name=FLAGS.isp_model_name,
                is_real_data=False)

            end_points['ground_truth'] = images
            # end_points['noisy'] = noisy_batch

            #############################
            # Specify the loss function #
            #############################
            tf.losses.softmax_cross_entropy(
                logits=logits,
                onehot_labels=labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        # Add image summary for denoised image
        for end_point in end_points:
            if end_point in ['outputs', 'post_anscombe', 'pre_inv_anscombe']:
                summaries.add(
                    tf.summary.image(end_point, end_points[end_point]))
            if end_point in [
                    'mobilenet_input', 'noisy', 'inputs', 'ground_truth', 'R',
                    'G1', 'G2', 'B'
            ]:
                clean_image = end_points[end_point]
                summaries.add(tf.summary.image(end_point, clean_image))
                summaries.add(
                    tf.summary.scalar('bounds/%s_min' % end_point,
                                      tf.reduce_min(clean_image)))
                summaries.add(
                    tf.summary.scalar('bounds/%s_max' % end_point,
                                      tf.reduce_max(clean_image)))

        #################################
        # Configure the moving averages #
        #################################
        moving_average_variables = slim.get_model_variables()
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay, global_step)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        # Update ops executed locally by trainer.
        update_ops.append(variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Compute the total loss and the clones' gradients.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        saver = tf.train.Saver(keep_checkpoint_every_n_hours=2)

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(train_tensor,
                            saver=saver,
                            logdir=FLAGS.train_dir,
                            master='',
                            is_chief=True,
                            init_fn=_get_init_fn(),
                            summary_op=summary_op,
                            number_of_steps=FLAGS.max_number_of_steps,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs,
                            sync_optimizer=None)
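
In the clone_fn above the input is multiplied by a Bayer mask so that each pixel keeps a single colour channel. A simplified, self-contained stand-in for sensor_model.get_bayer_mask (illustrative only; the repository's helper may differ):

import numpy as np
import tensorflow as tf

def simple_bayer_mask(height, width):
    """Binary RGGB Bayer mask of shape [height, width, 3]."""
    mask = np.zeros((height, width, 3), dtype=np.float32)
    mask[0::2, 0::2, 0] = 1.0   # R at even rows, even cols
    mask[0::2, 1::2, 1] = 1.0   # G at even rows, odd cols
    mask[1::2, 0::2, 1] = 1.0   # G at odd rows, even cols
    mask[1::2, 1::2, 2] = 1.0   # B at odd rows, odd cols
    return tf.constant(mask)

# Usage: zero out the two unsampled channels of every pixel in a batch.
# inputs = noisy_batch * simple_bayer_mask(train_image_size, train_image_size)
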
Exemple #15
0
def main(unused_argv):
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

  with tf.Graph().as_default():
    with tf.device(config.inputs_device()):
      train_crop_size = (None if 0 in FLAGS.train_crop_size else
                         FLAGS.train_crop_size)
      assert FLAGS.dataset
      assert len(FLAGS.dataset) == len(FLAGS.dataset_dir)
      if len(FLAGS.first_frame_finetuning) == 1:
        first_frame_finetuning = (list(FLAGS.first_frame_finetuning)
                                  * len(FLAGS.dataset))
      else:
        first_frame_finetuning = FLAGS.first_frame_finetuning
      if len(FLAGS.three_frame_dataset) == 1:
        three_frame_dataset = (list(FLAGS.three_frame_dataset)
                               * len(FLAGS.dataset))
      else:
        three_frame_dataset = FLAGS.three_frame_dataset
      assert len(FLAGS.dataset) == len(first_frame_finetuning)
      assert len(FLAGS.dataset) == len(three_frame_dataset)
      datasets, samples_list = zip(
          *[_get_dataset_and_samples(config, train_crop_size, dataset,
                                     dataset_dir, bool(first_frame_finetuning_),
                                     bool(three_frame_dataset_))
            for dataset, dataset_dir, first_frame_finetuning_,
            three_frame_dataset_ in zip(FLAGS.dataset, FLAGS.dataset_dir,
                                        first_frame_finetuning,
                                        three_frame_dataset)])
      # Note that this way of doing things is wasteful since it will evaluate
      # all branches but just use one of them. But let's do it anyway for now,
      # since it's easy and will probably be fast enough.
      dataset = datasets[0]
      if len(samples_list) == 1:
        samples = samples_list[0]
      else:
        probabilities = FLAGS.dataset_sampling_probabilities
        if probabilities:
          assert len(probabilities) == len(samples_list)
        else:
          # Default to uniform probabilities.
          probabilities = [1.0 / len(samples_list) for _ in samples_list]
        probabilities = tf.constant(probabilities)
        logits = tf.log(probabilities[tf.newaxis])
        rand_idx = tf.squeeze(tf.multinomial(logits, 1, output_dtype=tf.int32),
                              axis=[0, 1])

        def wrap(x):
          def f():
            return x
          return f

        samples = tf.case({tf.equal(rand_idx, idx): wrap(s)
                           for idx, s in enumerate(samples_list)},
                          exclusive=True)

      # Prefetch_queue requires the shape to be known at graph creation time.
      # So we only use it if we crop to a fixed size.
      if train_crop_size is None:
        inputs_queue = samples
      else:
        inputs_queue = prefetch_queue.prefetch_queue(
            samples,
            capacity=FLAGS.prefetch_queue_capacity_factor*config.num_clones,
            num_threads=FLAGS.prefetch_queue_num_threads)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab
      if FLAGS.classification_loss == 'triplet':
        embedding_dim = FLAGS.embedding_dimension
        output_type_to_dim = {'embedding': embedding_dim}
      else:
        output_type_to_dim = {common.OUTPUT_TYPE: dataset.num_classes}
      model_args = (inputs_queue, output_type_to_dim, dataset.ignore_label)
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in tf.contrib.framework.get_model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy,
          FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step,
          FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps,
          FLAGS.learning_power,
          FLAGS.slow_start_step,
          FLAGS.slow_start_learning_rate)
      optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(grads_and_vars,
                                                          grad_mult)

      with tf.name_scope('grad_clipping'):
        grads_and_vars = slim.learning.clip_gradient_norms(grads_and_vars, 5.0)

      # Create histogram summaries for the gradients.
      # We have too many summaries for mldash, so disable this one for now.
      # for grad, var in grads_and_vars:
      #   summaries.add(tf.summary.histogram(
      #       var.name.replace(':0', '_0') + '/gradient', grad))

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(grads_and_vars,
                                               global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows ops without a GPU implementation to be placed on the CPU.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

    # Start the training.
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_logdir,
        log_every_n_steps=FLAGS.log_steps,
        master=FLAGS.master,
        number_of_steps=FLAGS.training_number_of_steps,
        is_chief=(FLAGS.task == 0),
        session_config=session_config,
        startup_delay_steps=startup_delay_steps,
        init_fn=train_utils.get_model_init_fn(FLAGS.train_logdir,
                                              FLAGS.tf_initial_checkpoint,
                                              FLAGS.initialize_last_layer,
                                              last_layers,
                                              ignore_missing_vars=True),
        summary_op=summary_op,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs)
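The training code above follows a common TF 1.x idiom: gather UPDATE_OPS, append the gradient update, group everything, and return the loss as an identity gated on that group. Below is a minimal, self-contained sketch of the same pattern on a toy one-variable model; the hand-written moving-average update stands in for what batch_norm registers inside model_fn, and none of the names come from the example itself.

import tensorflow as tf

# Toy model: one trainable weight plus a non-trainable running statistic whose
# update is registered in UPDATE_OPS, mimicking what batch_norm does.
x = tf.placeholder(tf.float32, shape=[None])
w = tf.get_variable('w', initializer=1.0)
running_mean = tf.get_variable('running_mean', initializer=0.0, trainable=False)
loss = tf.reduce_mean(tf.square(w * x - 2.0))

update_mean = tf.assign(running_mean, 0.9 * running_mean + 0.1 * tf.reduce_mean(x))
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean)

optimizer = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
grad_updates = optimizer.apply_gradients(optimizer.compute_gradients(loss))

# Same structure as the training code above: one session.run of train_tensor
# drives the gradient step and every pending update op.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
update_ops.append(grad_updates)
update_op = tf.group(*update_ops)
with tf.control_dependencies([update_op]):
    train_tensor = tf.identity(loss, name='train_op')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):
        print(sess.run(train_tensor, feed_dict={x: [1.0, 2.0, 3.0]}))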
Exemple #16
0
def main(_):
    if not FLAGS.dataset_dir_iris or not FLAGS.dataset_dir_face:
        raise ValueError(
            'You must supply the dataset directories with '
            '--dataset_dir_iris and --dataset_dir_face')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        ######################
        # Config model_deploy#
        ######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset_iris = dataset_factory.get_dataset(FLAGS.dataset_name_iris,
                                                   FLAGS.dataset_split_name,
                                                   FLAGS.dataset_dir_iris)

        dataset_face = dataset_factory.get_dataset(FLAGS.dataset_name_face,
                                                   FLAGS.dataset_split_name,
                                                   FLAGS.dataset_dir_face)

        ######################
        # Select the network #
        ######################

        #  network_fn_iris = nets_factory.get_network_fn(
        #     FLAGS.model_name_iris,
        #    num_classes=(dataset.num_classes - FLAGS.labels_offset),
        #    weight_decay=FLAGS.weight_decay,
        #   is_training=True)

        network_fn_joint = nets_factory.get_network_fn_joint(
            FLAGS.model_name_joint,
            num_classes=(dataset_face.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name_iris = FLAGS.preprocessing_name_iris or FLAGS.model_name_iris
        image_preprocessing_fn_iris = preprocessing_factory.get_preprocessing(
            preprocessing_name_iris, is_training=True)

        preprocessing_name_face = FLAGS.preprocessing_name_face or FLAGS.model_name_face
        image_preprocessing_fn_face = preprocessing_factory.get_preprocessing(
            preprocessing_name_face, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider_iris = slim.dataset_data_provider.DatasetDataProvider(
                dataset_iris,
                shuffle=False,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image_iris, label_iris] = provider_iris.get(['image', 'label'])
            label_iris -= FLAGS.labels_offset

            #	train_image_size_iris = FLAGS.train_image_size_iris or network_fn_iris.default_image_size
            new_height_iris = FLAGS.New_Height_Of_Image_iris or network_fn_joint.default_image_size
            new_width_iris = FLAGS.New_Width_Of_Image_iris or network_fn_joint.default_image_size

            #         image = image_preprocessing_fn(image, train_image_size, train_image_size)
            image_iris = image_preprocessing_fn_iris(image_iris,
                                                     new_height_iris,
                                                     new_width_iris)

            #  io.imshow(image)
            #  io.show()
            images_iris, labels_iris = tf.train.batch(
                [image_iris, label_iris],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            #      tf.image_summary('images', images)
            labels_iris = slim.one_hot_encoding(
                labels_iris, dataset_iris.num_classes - FLAGS.labels_offset)
            batch_queue_iris = slim.prefetch_queue.prefetch_queue(
                [images_iris, labels_iris],
                capacity=2 * deploy_config.num_clones)

        with tf.device(deploy_config.inputs_device()):
            provider_face = slim.dataset_data_provider.DatasetDataProvider(
                dataset_face,
                shuffle=False,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image_face, label_face] = provider_face.get(['image', 'label'])
            label_face -= FLAGS.labels_offset

            #	train_image_size_face = FLAGS.train_image_size_face or network_fn_face.default_image_size
            new_height_face = FLAGS.New_Height_Of_Image_face or network_fn_joint.default_image_size
            new_width_face = FLAGS.New_Width_Of_Image_face or network_fn_joint.default_image_size

            #         image = image_preprocessing_fn(image, train_image_size, train_image_size)
            image_face = image_preprocessing_fn_face(image_face,
                                                     new_height_face,
                                                     new_width_face)

            #  io.imshow(image)
            #  io.show()
            images_face, labels_face = tf.train.batch(
                [image_face, label_face],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            #      tf.image_summary('images', images)
            labels_face = slim.one_hot_encoding(
                labels_face, dataset_face.num_classes - FLAGS.labels_offset)
            batch_queue_face = slim.prefetch_queue.prefetch_queue(
                [images_face, labels_face],
                capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################

        def clone_fn(batch_queue_iris, batch_queue_face):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images_iris, labels_iris = batch_queue_iris.dequeue()
            images_face, labels_face = batch_queue_face.dequeue()
            logits, end_points = network_fn_joint(images_face, images_iris)

            #  def clone_fn_face(batch_queue_face):
            #      """Allows data parallelism by creating multiple clones of network_fn."""
            #    images_face, labels_face = batch_queue_face.dequeue()
            #    logits_face, end_points_face, features_face,model_var_face = network_fn_face(images_face)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels_face,
                    label_smoothing=FLAGS.label_smoothing,
                    weight=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels_face,
                label_smoothing=FLAGS.label_smoothing,
                weight=1.0)

            # Adding the accuracy metric
            with tf.name_scope('accuracy'):
                predictions = tf.argmax(logits, 1)
                labels_face = tf.argmax(labels_face, 1)
                accuracy = tf.reduce_mean(
                    tf.to_float(tf.equal(predictions, labels_face)))
                tf.add_to_collection('accuracy', accuracy)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(
            deploy_config, clone_fn, [batch_queue_iris, batch_queue_face])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.histogram_summary('activations/' + end_point, x))
            summaries.add(
                tf.scalar_summary('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.scalar_summary('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.histogram_summary(variable.op.name, variable))

        # Add summaries for the input images.
        summaries.add(
            tf.image_summary('face',
                             images_face,
                             max_images=15,
                             name='Face_images'))
        summaries.add(
            tf.image_summary('iris',
                             images_iris,
                             max_images=15,
                             name='Iris_images'))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset_face.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(
                tf.scalar_summary('learning_rate',
                                  learning_rate,
                                  name='learning_rate'))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
                total_num_replicas=FLAGS.worker_replicas)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # optimize_clones() computes the total loss and the gradients for the variables to train.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        # # Add total_loss to summary.
        # summaries.add(tf.scalar_summary('total_loss', total_loss,
        #                                 name='total_loss'))

        # Add total_loss and accuracy to summary.
        summaries.add(
            tf.scalar_summary('eval/Total_Loss', total_loss,
                              name='total_loss'))
        accuracy = tf.get_collection('accuracy', first_clone_scope)[0]
        summaries.add(
            tf.scalar_summary('eval/Accuracy', accuracy, name='accuracy'))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.merge_summary(list(summaries), name='summary_op')

        init_iris, init_feed = _get_init_op()

        #	var_2=[v for v in tf.all_variables() if v.name == "vgg_19/conv3/conv3_3/weights:0"][0]

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=init_iris,
            init_feed_dict=init_feed,
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
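The accuracy in the example above is computed inside clone_fn and exchanged through a custom graph collection rather than a return value. A small sketch of just that mechanism, using toy constant logits/labels that are assumptions for illustration, not tensors from the example:

import tensorflow as tf

# Toy logits/labels standing in for one clone's batch.
logits = tf.constant([[2.0, 0.5], [0.1, 1.0]])
labels = tf.constant([[1.0, 0.0], [0.0, 1.0]])

with tf.name_scope('clone_0') as clone_scope:
    with tf.name_scope('accuracy'):
        predictions = tf.argmax(logits, 1)
        label_ids = tf.argmax(labels, 1)
        accuracy = tf.reduce_mean(tf.to_float(tf.equal(predictions, label_ids)))
        # Register under a custom collection key, as clone_fn does above.
        tf.add_to_collection('accuracy', accuracy)

# Outside the clone, recover the tensor by (collection key, clone scope),
# mirroring tf.get_collection('accuracy', first_clone_scope)[0] above.
accuracy_tensor = tf.get_collection('accuracy', clone_scope)[0]

with tf.Session() as sess:
    print(sess.run(accuracy_tensor))  # 1.0 for this toy batch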
Exemple #17
0
def main(_):
    '''Training with optimization.'''
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        #train_batch_queue = None
        #val_batch_queue = None
        train_set = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                "train2014", FLAGS.dataset_dir)
        #val_set = dataset_factory.get_dataset(FLAGS.dataset_name, "train2014", FLAGS.dataset_dir)

        with tf.device(deploy_config.inputs_device()):
            ##### Consider replacing the following section #####
            options = tf.python_io.TFRecordOptions(
                TFRecordCompressionType.ZLIB)
            train_provider = slim.dataset_data_provider.DatasetDataProvider(
                train_set,
                num_readers=FLAGS.num_readers,
                reader_kwargs={'options': options},
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [train_image, train_label, train_boxes, train_masks
             ] = train_provider.get(['image', 'label', 'gt_boxes', 'gt_masks'])
            ##train_image=tf.reshape(train_image,(height, width, 3))
            ##train_label=tf.reshape(train_label,(height, width, 1))

            print(train_image, train_label, train_masks, train_boxes)
            train_image, train_label, train_boxes, train_masks = coco_preprocessing.preprocess_image(
                train_image,
                train_label,
                train_boxes,
                train_masks,
                is_training=True)

            train_images, train_labels = tf.train.batch(
                [train_image, train_label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)

            train_batch_queue = slim.prefetch_queue.prefetch_queue(
                [train_images, train_labels], capacity=2 * FLAGS.num_clones)
            print(train_batch_queue)


#          val_provider = slim.dataset_data_provider.DatasetDataProvider(
#             val_set,
#             num_readers=FLAGS.num_readers,
#             common_queue_capacity=20 * FLAGS.batch_size,
#             common_queue_min=10 * FLAGS.batch_size)
#
#          [val_image, val_label, val_boxes, val_masks] = val_provider.get(['image', 'label', 'gt_boxes', 'gt_masks'])
#
#          val_image, val_label, val_boxes, val_masks = coco_preprocessing.preprocess_image(val_image, val_label, val_boxes, val_masks)
#
#          val_images, val_labels = tf.train.batch(
#             [val_image, val_label],
#             batch_size=FLAGS.batch_size,
#             num_threads=FLAGS.num_preprocessing_threads,
#             capacity=5 * FLAGS.batch_size)
#
#          val_batch_queue = slim.prefetch_queue.prefetch_queue(
#             [val_images, val_labels], capacity=2 * FLAGS.num_clones)
#          print(val_batch_queue)

        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of networks"""
            images, labels = batch_queue.dequeue()
            print(images, labels)
            images = tf.squeeze(images, [1])
            pred_annotation, fc8s, end_points = vgg_fcn(inputs=images)
            ############################
            ## Loss function #
            ############################
            #slim.losses.softmax_cross_entropy(pred_annotation, labels, label_smoothing=True, weights=1.0)
            print("Pred_annot", pred_annotation, "Labels", labels, "fc8s",
                  fc8s)

            tf.losses.sparse_softmax_cross_entropy(logits=tf.to_float(fc8s),
                                                   labels=tf.to_int32(labels),
                                                   weights=1.0,
                                                   scope="entropy")

            #loss = tf.reduce_mean((tf.losses.sparse_softmax_cross_entropy(logits=tf.to_float(pred_annotation),labels=tf.to_int32(labels),scope="entropy")))

            return images, labels, pred_annotation, end_points

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [train_batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        images, labels, preds, end_points = clones[0].outputs
        summaries.add(tf.summary.image("Original_images", images))
        summaries.add(tf.summary.image("Ground_truth_masks", labels))
        summaries.add(tf.summary.image("Prediction_masks", tf.to_float(preds)))

        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(train_set.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
                total_num_replicas=FLAGS.worker_replicas)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        variables_to_train = _get_variables_to_train()
        for var in variables_to_train:
            print(var.op.name)

        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        print('total_loss', total_loss, 'clone_gradients', clones_gradients)
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)

        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)

        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        slim.learning.train(
            train_tensor,
            logdir=FLAGS.logs_dir,
            master='',
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
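A quick stand-alone illustration of the loss registration above: tf.losses.sparse_softmax_cross_entropy takes an integer label map plus per-pixel logits and adds the result to tf.GraphKeys.LOSSES, the collection that model_deploy.optimize_clones() later sums. The toy shapes below are assumptions, not the dataset's.

import tensorflow as tf

# Toy per-pixel logits and integer labels: [batch, h, w, classes] and [batch, h, w].
logits = tf.random_normal([2, 4, 4, 5])
labels = tf.random_uniform([2, 4, 4], maxval=5, dtype=tf.int32)

loss = tf.losses.sparse_softmax_cross_entropy(
    labels=tf.to_int32(labels), logits=tf.to_float(logits), weights=1.0, scope='entropy')

# The loss is registered automatically; no explicit add_to_collection is needed.
registered = tf.get_collection(tf.GraphKeys.LOSSES)

with tf.Session() as sess:
    print(sess.run(loss), len(registered))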
Exemple #18
0
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default():
        deploy_config = model_deploy.DeploymentConfig(num_clones=1,
                                                      clone_on_cpu=True,
                                                      replica_id=0,
                                                      num_replicas=2,
                                                      num_ps_tasks=1)

        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        images, labels = reader()
        with tf.device(deploy_config.inputs_device()):
            #images, labels = reader()
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        def clone_fn(batch_queue):
            images, labels = batch_queue.dequeue()
            logits = models(images)
            slim.losses.softmax_cross_entropy(logits, labels)
            return logits

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        with tf.device(deploy_config.optimizer_device()):
            learning_rate = tf.train.piecewise_constant(
                global_step, [10000, 12000], [0.001, 0.0001, 0.00001])
            optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if sync_replicas:
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=2,
                total_num_replicas=2,
                variable_averages=None,
                variables_to_average=None)

        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=tf.trainable_variables())
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        variables_to_restore = slim.get_variables_to_restore(
            exclude=['global_step'])
        init_fn = slim.assign_from_checkpoint_fn(fine_tune_path,
                                                 variables_to_restore,
                                                 ignore_missing_vars=True)
        print('start~~~~~~~~~~~~')
        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.per_process_gpu_memory_fraction = 0.4
        slim.learning.train(train_tensor,
                            logdir=output_path,
                            master='grpc://192.168.10.47:2222',
                            is_chief=True,
                            init_fn=init_fn,
                            summary_op=summary_op,
                            number_of_steps=max_steps,
                            log_every_n_steps=1,
                            save_summaries_secs=60,
                            save_interval_secs=600,
                            sync_optimizer=optimizer,
                            session_config=session_config)
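The piecewise schedule above is worth spelling out: tf.train.piecewise_constant takes one more value than boundary, so boundaries [10000, 12000] with values [0.001, 0.0001, 0.00001] mean 1e-3 through step 10000, 1e-4 through step 12000, then 1e-5. A minimal sketch, with step values chosen only for illustration:

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
learning_rate = tf.train.piecewise_constant(
    global_step, boundaries=[10000, 12000], values=[0.001, 0.0001, 0.00001])

# Drive the step manually to show how the rate changes at the boundaries.
step_ph = tf.placeholder(tf.int64, shape=[])
set_step = tf.assign(global_step, step_ph)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in (0, 10000, 10001, 12001):
        sess.run(set_step, feed_dict={step_ph: step})
        print(step, sess.run(learning_rate))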
Exemple #19
0
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)  # Set the logging level to INFO.
    # DEBUG emits fine-grained events that are very helpful when debugging the
    # application; it is mainly used to print runtime information during development.
    # INFO highlights the application's progress at a coarse-grained level and prints
    # interesting or important information. It can be used in production to log
    # important runtime details, but should not be overused to avoid excessive output.

    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')

    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    # Get dataset-dependent information.
    dataset = segmentation_dataset.get_dataset(FLAGS.dataset,
                                               FLAGS.train_split,
                                               dataset_dir=FLAGS.dataset_dir)
    add = '/home/cxx/Deeplab/models/research/deeplab/datasets/cityscapes/exp/train_on_train_set/train'
    tf.gfile.MakeDirs(add)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            samples = input_generator.get(
                dataset,
                FLAGS.train_crop_size,
                clone_batch_size,
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                dataset_split=FLAGS.train_split,
                is_training=True,
                model_variant=FLAGS.model_variant)
            inputs_queue = prefetch_queue.prefetch_queue(samples,
                                                         capacity=128 *
                                                         config.num_clones)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (inputs_queue, {
                common.OUTPUT_TYPE: dataset.num_classes
            }, dataset.ignore_label)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for images, labels, semantic predictions
        if FLAGS.save_summaries_images:
            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
            summaries.add(
                tf.summary.image('samples/%s' % common.IMAGE, summary_image))

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
            # Scale up summary image pixel values for better visualization.
            pixel_scaling = max(1, 255 // dataset.num_classes)
            summary_label = tf.cast(first_clone_label * pixel_scaling,
                                    tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL, summary_label))

            first_clone_output = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
            predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

            summary_predictions = tf.cast(predictions * pixel_scaling,
                                          tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                 summary_predictions))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps, FLAGS.learning_power,
                FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   FLAGS.momentum)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows ops without a GPU implementation to be placed on the CPU.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Start the training.
        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_logdir,
                            log_every_n_steps=FLAGS.log_steps,
                            master=FLAGS.master,
                            number_of_steps=FLAGS.training_number_of_steps,
                            is_chief=(FLAGS.task == 0),
                            session_config=session_config,
                            startup_delay_steps=startup_delay_steps,
                            init_fn=train_utils.get_model_init_fn(
                                FLAGS.train_logdir,
                                FLAGS.tf_initial_checkpoint,
                                FLAGS.initialize_last_layer,
                                last_layers,
                                ignore_missing_vars=True),
                            summary_op=summary_op,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs)
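One detail from the summary code above that is easy to miss: integer label maps render almost black, so they are multiplied by max(1, 255 // num_classes) before being cast to uint8 for tf.summary.image. A tiny sketch of just that scaling; num_classes and the shapes are made up for illustration.

import tensorflow as tf

num_classes = 19                                    # e.g. a Cityscapes-sized label space
labels = tf.random_uniform([1, 4, 4, 1], maxval=num_classes, dtype=tf.int32)

pixel_scaling = max(1, 255 // num_classes)          # Python-side constant, 13 here
summary_label = tf.cast(labels * pixel_scaling, tf.uint8)
label_summary = tf.summary.image('samples/label', summary_label)

with tf.Session() as sess:
    print(sess.run(summary_label)[0, :, :, 0])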
Exemple #20
0
def train(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    tf.logging.set_verbosity(args.log)
    clone_on_cpu = args.gpu_id == ''
    num_clones = len(args.gpu_id.split(','))

    ###
    # get teacher info.
    ###
    teacher_dir = utils.shell_path(args.teacher_dir)
    assert tf.gfile.IsDirectory(teacher_dir)
    json_in_dir = glob.glob(os.path.join(teacher_dir, '*.json'))
    assert len(json_in_dir) == 1
    te_json = json_in_dir[0]
    te_ckpt = tf.train.latest_checkpoint(teacher_dir)
    assert tf.train.checkpoint_exists(te_ckpt)

    with open(te_json, 'rt') as F:
        configs = json.load(F)
    te_hparams = Namespace(**configs)
    teacher = wavenet.Wavenet(te_hparams)

    ###
    # get student info.
    ###
    if args.log_root:
        if args.config is None:
            raise RuntimeError('No config json specified.')
        tf.logging.info('using config from {}'.format(args.config))
        with open(args.config, 'rt') as F:
            configs = json.load(F)
        st_hparams = Namespace(**configs)
        logdir_name = config_str.get_config_time_str(st_hparams,
                                                     'parallel_wavenet',
                                                     EXP_TAG)
        logdir = os.path.join(args.log_root, logdir_name)
        os.makedirs(logdir, exist_ok=True)
        shutil.copy(args.config, logdir)
    else:
        logdir = args.logdir
        config_json = glob.glob(os.path.join(logdir, '*.json'))[0]
        tf.logging.info('using config from {}'.format(config_json))
        with open(config_json, 'rt') as F:
            configs = json.load(F)
        st_hparams = Namespace(**configs)
    tf.logging.info('Saving to {}'.format(logdir))

    pwn = parallel_wavenet.ParallelWavenet(st_hparams, teacher,
                                           args.train_path)

    def _data_dep_init():
        inputs_val = reader.get_init_batch(pwn.train_path,
                                           batch_size=args.total_batch_size,
                                           seq_len=pwn.wave_length)
        mel_data = inputs_val['mel']

        _inputs_dict = {
            'mel': tf.placeholder(dtype=tf.float32, shape=mel_data.shape)
        }

        init_ff_dict = pwn.feed_forward(_inputs_dict, init=True)

        def callback(session):
            tf.logging.info('Calculate initial statistics.')
            init_out = session.run(init_ff_dict,
                                   feed_dict={_inputs_dict['mel']: mel_data})
            new_x = init_out['x']
            mean = init_out['mean_tot']
            scale = init_out['scale_tot']
            _init_logging(new_x, 'new_x')
            _init_logging(mean, 'mean')
            _init_logging(scale, 'scale')
            tf.logging.info('Done calculating initial statistics.')

        return callback

    def _model_fn(_inputs_dict):
        ff_dict = pwn.feed_forward(_inputs_dict)
        ff_dict.update(_inputs_dict)
        loss_dict = pwn.calculate_loss(ff_dict)
        loss = loss_dict['loss']
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)

        for loss_key, loss_val in loss_dict.items():
            tf.summary.scalar(loss_key, loss_val)

    with tf.Graph().as_default():
        total_batch_size = args.total_batch_size
        assert total_batch_size % num_clones == 0
        clone_batch_size = int(total_batch_size / num_clones)

        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            num_ps_tasks=0,
            worker_job_name='localhost',
            ps_job_name='localhost')

        with tf.device(deploy_config.inputs_device()):
            inputs_dict = pwn.get_batch(clone_batch_size)
            # get a mel batch not corresponding to the wave batch.
            # if contrastive loss is not used, this input operation will not be evaluated.
            inputs_dict['mel_rand'] = pwn.get_batch(clone_batch_size)['mel']

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, _model_fn,
                                            [inputs_dict])
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        summaries.update(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        with tf.device(deploy_config.variables_device()):
            global_step = tf.get_variable(
                "global_step", [],
                tf.int32,
                initializer=tf.constant_initializer(0),
                trainable=False)

        ###
        # variables to train
        ###
        st_vars = [
            var for var in tf.trainable_variables() if 'iaf' in var.name
        ]

        with tf.device(deploy_config.optimizer_device()):
            lr = tf.constant(pwn.learning_rate_schedule[0])
            for key, value in pwn.learning_rate_schedule.items():
                lr = tf.cond(tf.less(global_step, key), lambda: lr,
                             lambda: tf.constant(value))
            summaries.add(tf.summary.scalar("learning_rate", lr))

            optimizer = tf.train.AdamOptimizer(lr, epsilon=1e-8)
            ema = tf.train.ExponentialMovingAverage(decay=0.9999,
                                                    num_updates=global_step)
            loss, clone_grads_vars = model_deploy.optimize_clones(
                clones, optimizer, var_list=st_vars)
            update_ops.append(
                optimizer.apply_gradients(clone_grads_vars,
                                          global_step=global_step))
            update_ops.append(ema.apply(st_vars))

            summaries.add(tf.summary.scalar("train_loss", loss))

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(loss, name='train_op')

        ###
        # restore teacher
        ###
        te_vars = [
            var for var in tf.trainable_variables() if 'iaf' not in var.name
        ]
        # The teacher checkpoint stores EMA (ExponentialMovingAverage) variables.
        te_vars = {
            '{}/ExponentialMovingAverage'.format(tv.name[:-2]): tv
            for tv in te_vars
        }
        restore_init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
            te_ckpt, te_vars)
        data_dep_init_fn = _data_dep_init()

        def group_init_fn(session):
            restore_init_fn(session)
            data_dep_init_fn(session)

        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        slim.learning.train(train_tensor,
                            logdir=logdir,
                            number_of_steps=pwn.num_iters,
                            summary_op=summary_op,
                            global_step=global_step,
                            log_every_n_steps=100,
                            save_summaries_secs=600,
                            save_interval_secs=3600,
                            session_config=session_config,
                            init_fn=group_init_fn)
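The teacher-restore block above relies on a naming convention: ema.apply(vars) creates shadow variables named '<var>/ExponentialMovingAverage', so a checkpoint written with EMA can be loaded back into the plain variables by remapping names. A small sketch of that mapping; the variable name is invented, and tf.train.ExponentialMovingAverage.variables_to_restore() builds an equivalent map.

import tensorflow as tf

v = tf.get_variable('teacher/w', initializer=0.0)
ema = tf.train.ExponentialMovingAverage(decay=0.9999)
maintain_op = ema.apply([v])   # creates the shadow variable teacher/w/ExponentialMovingAverage

# v.name is 'teacher/w:0'; dropping ':0' and appending the EMA suffix yields the
# name the shadow variable is saved under, exactly as the dict comprehension above does.
var_map = {'{}/ExponentialMovingAverage'.format(tv.name[:-2]): tv for tv in [v]}
print(var_map)

# A Saver built from this map loads the EMA values into the plain variables:
# tf.train.Saver(var_map).restore(sess, checkpoint_path)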
Exemple #21
0
    def trainer_routine(self, dataset_fn, model_fn, train_config, train_dir):
        """Initialize the training job.

    Args:
      dataset_fn: A T object containing dataset_config,
      model_fn:
      model_config:
      train_dir:
    """

        model_instance = model_fn()

        with tf.Graph().as_default():

            # Parameters for a single worker.
            ps_tasks = 0
            worker_replicas = 1
            worker_job_name = 'lonely_worker'
            task = 0
            is_chief = True
            master = ''
            num_clones = 1
            clone_on_cpu = False

            assert num_clones == 1

            deploy_config = model_deploy.DeploymentConfig(
                num_clones=num_clones,
                clone_on_cpu=clone_on_cpu,
                replica_id=task,
                num_replicas=worker_replicas,
                num_ps_tasks=ps_tasks,
                worker_job_name=worker_job_name)

            # Place the global step on the device storing the variables.
            with tf.device(deploy_config.variables_device()):
                global_step = slim.create_global_step()

            with tf.device(deploy_config.inputs_device()):

                input_queue = model_instance.create_input_queue(
                    dataset_fn,
                    train_config.model_config,
                    queue_type=train_config.model_config['input_queue_type'])

            # Summaries are kept in graph collections so that they don't have to be passed around.
            summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
            global_summaries = set([])

            model_loss = functools.partial(model_instance.create_losses,
                                           train_config=train_config)

            clones = model_deploy.create_clones(deploy_config, model_loss,
                                                [input_queue])
            first_clone_scope = clones[0].scope

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

            with tf.device(deploy_config.optimizer_device()):
                training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                    train_config.model_config.optimizer)
                for var in optimizer_summary_vars:
                    tf.summary.scalar(var.op.name, var, family='LearningRate')

            # import ipdb; ipdb.set_trace()
            with tf.device(deploy_config.optimizer_device()):
                regularization_losses = (
                    None
                    if train_config.model_config.losses.add_regularization_loss
                    else [])

                # # Where variable filters were implemented.
                # trainable_vars = variables_helper.filter_variables(tf.trainable_variables(),
                #                                                   ['grasp'],
                #                                                   invert=False)
                # import ipdb; ipdb.set_trace()

                trainable_vars = tf.trainable_variables()

                total_loss, grads_and_vars = model_deploy.optimize_clones(
                    clones,
                    training_optimizer,
                    regularization_losses=regularization_losses,
                    var_list=trainable_vars)
                total_loss = tf.check_numerics(total_loss,
                                               'LossTensor is inf or nan.')

                # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
                if train_config.model_config.bias_grad_multiplier:
                    biases_regex_list = ['.*/biases']
                    grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                        grads_and_vars,
                        biases_regex_list,
                        multiplier=train_config.bias_grad_multiplier)

                # Optionally freeze some layers by setting their gradients to be zero.
                if train_config.model_config.freeze_variables:
                    grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                        grads_and_vars, train_config.freeze_variables)

                # Optionally clip gradients
                if train_config.model_config.gradient_clipping_by_norm > 0:
                    with tf.name_scope('clip_grads'):
                        grads_and_vars = slim.learning.clip_gradient_norms(
                            grads_and_vars,
                            train_config.gradient_clipping_by_norm)

                # Create gradient updates.
                grad_updates = training_optimizer.apply_gradients(
                    grads_and_vars, global_step=global_step)
                update_ops.append(grad_updates)
                update_op = tf.group(*update_ops, name='update_barrier')
                with tf.control_dependencies([update_op]):
                    train_tensor = tf.identity(total_loss, name='train_op')

            # Add summaries.
            for model_var in slim.get_model_variables():
                global_summaries.add(
                    tf.summary.histogram('ModelVars/' + model_var.op.name,
                                         model_var))
            for loss_tensor in tf.losses.get_losses():
                global_summaries.add(
                    tf.summary.scalar('Losses/' + loss_tensor.op.name,
                                      loss_tensor))
            global_summaries.add(
                tf.summary.scalar('Losses/TotalLoss',
                                  tf.losses.get_total_loss()))

            # Add the summaries from the first clone. These contain the summaries
            # created by model_fn and either optimize_clones() or _gather_clone_loss().
            summaries |= set(
                tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
            summaries |= global_summaries

            # Merge all summaries together.
            summary_op = tf.summary.merge(list(summaries), name='summary_op')

            # Soft placement allows ops without a GPU implementation to be placed on the CPU.
            session_config = tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False)

            # import ipdb; ipdb.set_trace()
            # Save checkpoints regularly.
            keep_checkpoint_every_n_hours = train_config.model_config.keep_checkpoint_every_n_hours

            # Added Target vars param
            saver = tf.train.Saver(  # target_vars,
                keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

            # Create ops required to initialize the model from a given checkpoint.
            init_fn = None
            if train_config.model_config.fine_tune_checkpoint and False:
                if not train_config.model_config.fine_tune_checkpoint_type:
                    # train_config.from_detection_checkpoint field is deprecated. For
                    # backward compatibility, fine_tune_checkpoint_type is set based on
                    # from_detection_checkpoint.
                    if train_config.model_config.from_detection_checkpoint:
                        train_config.fine_tune_checkpoint_type = 'detection'
                    else:
                        train_config.model_config.fine_tune_checkpoint_type = 'classification'

                var_map = model_instance.restore_map(
                    fine_tune_checkpoint_type=train_config.
                    fine_tune_checkpoint_type,
                    load_all_detection_checkpoint_vars=(
                        train_config.load_all_detection_checkpoint_vars))

                available_var_map = (
                    variables_helper.get_variables_available_in_checkpoint(
                        var_map,
                        train_config.model_config.fine_tune_checkpoint,
                        include_global_step=False))

                #  # Add Target Vars
                # print(available_var_map)
                # available_var_map = variables_helper.filter_variables(available_var_map,
                #                                                   ['grasp'],
                #                                                   invert=False)
                # # target_vars = variables_helper.filter_variables(tf.trainable_variables(), 'source', invert=False)
                # # mapping_vars = variables_helper.filter_variables(tf.trainable_variables(), ['source', '.*dann'], invert=False)

                init_saver = tf.train.Saver(available_var_map)

                def initializer_fn(sess):
                    # sess.run(tf.global_variables_initializer())
                    init_saver.restore(sess, train_config.fine_tune_checkpoint)

                init_fn = initializer_fn

            slim.learning.train(
                train_tensor,
                logdir=train_dir,
                master=master,
                is_chief=is_chief,
                session_config=session_config,
                startup_delay_steps=train_config.model_config.
                startup_delay_steps,
                init_fn=init_fn,
                summary_op=summary_op,
                number_of_steps=(train_config.model_config.num_steps
                                 if train_config.model_config.num_steps else
                                 None),
                save_summaries_secs=120,
                save_interval_secs=120,
                sync_optimizer=None,
                saver=saver)
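The fine-tuning branch above boils down to: keep only the variables actually present in the checkpoint, build a Saver over them, and wrap the restore in an init_fn for slim.learning.train. A stripped-down sketch under assumed variable names and an assumed checkpoint path; neither comes from the code above.

import tensorflow as tf

backbone_w = tf.get_variable('backbone/weights', shape=[3], initializer=tf.zeros_initializer())
head_w = tf.get_variable('head/weights', shape=[2], initializer=tf.zeros_initializer())

# Pretend only the backbone exists in the fine-tune checkpoint; the head trains from scratch.
available_var_map = {'backbone/weights': backbone_w}
init_saver = tf.train.Saver(available_var_map)

def initializer_fn(sess):
    # '/tmp/fine_tune/model.ckpt' is an assumed path used for illustration only.
    init_saver.restore(sess, '/tmp/fine_tune/model.ckpt')

# slim.learning.train(train_tensor, logdir, init_fn=initializer_fn, ...) calls this
# once the session exists, after the standard variable initializers have run.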
Exemple #22
0
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=FLAGS.num_readers,
          common_queue_capacity=2 * FLAGS.batch_size,
          common_queue_min=1 * FLAGS.batch_size)
      [image, label] = provider.get(['image', 'label'])
      label -= FLAGS.labels_offset

      train_image_size = FLAGS.train_image_size or network_fn.default_image_size

      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)
      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - FLAGS.labels_offset)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
      if 'AuxLogits' in end_points:
        slim.losses.softmax_cross_entropy(
            end_points['AuxLogits'], labels,
            label_smoothing=FLAGS.label_smoothing, weights=0.4,
            scope='aux_loss')
      slim.losses.softmax_cross_entropy(
          logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

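    # create_clones() runs clone_fn once per configured clone, placing each
    # clone's ops on its own device while the model variables stay shared.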
    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram('activations/' + end_point, x))
      summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

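    # tf.contrib.quantize rewrites the graph with fake-quantization ops so the
    # model can be trained for quantized inference; quant_delay defers the
    # quantization until that many steps have run.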
    if FLAGS.quantize_delay >= 0:
      tf.contrib.quantize.create_training_graph(
          quant_delay=FLAGS.quantize_delay)

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          total_num_replicas=FLAGS.worker_replicas,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # optimize_clones() returns the total loss and the per-variable gradients summed over clones.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    ###########################
    # Kicks off the training. #
    ###########################

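    # Profiling setup: dump per-op time/memory statistics (averaged over the
    # traced steps) to output_file via tf.profiler while training runs.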
    from datetime import datetime
    current_time = datetime.now().strftime("%Y%m%d-%H%M")
    output_file = "results/" + FLAGS.model_name + '/' + current_time + ".txt"
    builder = tf.profiler.ProfileOptionBuilder
    opts = (builder(builder.time_and_memory())
            .with_step(-1)  # -1 averages over all registered steps.
            .with_file_output(output_file)
            .select(["peak_bytes", "residual_bytes", "output_bytes", "micros",
                     "accelerator_micros", "cpu_micros", "params", "float_ops",
                     "occurrence", "tensor_value", "device", "op_types"])
            .order_by("micros")
            .build())

    with tf.contrib.tfprof.ProfileContext('profiler',
                                          trace_steps=range(0, FLAGS.number_of_steps * 20),
                                          dump_steps=range(0, FLAGS.number_of_steps * 20, 100),
                                          debug=True) as pctx:

      slim.learning.train(
          train_tensor,
          logdir=FLAGS.train_dir,
          master=FLAGS.master,
          is_chief=(FLAGS.task == 0),
          init_fn=_get_init_fn(),
          summary_op=summary_op,
          number_of_steps=FLAGS.number_of_steps,
          log_every_n_steps=FLAGS.log_every_n_steps,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs,
          sync_optimizer=optimizer if FLAGS.sync_replicas else None)
      
      pctx.profiler.profile_operations(options=opts)
    with open(output_file) as out:
      text = out.read()
    with open(output_file, 'w') as ofs:
      ofs.write('\n'.join([str(t) for t in FLAGS.flag_values_dict().items()]))
      ofs.write('\n')
      ofs.write(text)
Exemple #23
0
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################

        def calculate_pooling_center_loss(features, label, alfa, nrof_classes,
                                          weights, name):
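            # Center loss: keep a non-trainable, EMA-updated center per class and
            # penalize the squared distance between each feature and its center.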
            features = tf.reshape(features, [features.shape[0], -1])
            label = tf.argmax(label, 1)

            nrof_features = features.get_shape()[1]
            centers = tf.get_variable(name, [nrof_classes, nrof_features],
                                      dtype=tf.float32,
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)
            label = tf.reshape(label, [-1])
            centers_batch = tf.gather(centers, label)
            centers_batch = tf.nn.l2_normalize(centers_batch, axis=-1)

            diff = (1 - alfa) * (centers_batch - features)
            centers = tf.scatter_sub(centers, label, diff)

            with tf.control_dependencies([centers]):
                distance = tf.square(features - centers_batch)
                distance = tf.reduce_sum(distance, axis=-1)
                center_loss = tf.reduce_mean(distance)

            center_loss = tf.identity(center_loss * weights,
                                      name=name + '_loss')
            return center_loss

        def attention_crop(attention_maps):
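            # For each image, sample one attention part (weighted by its mean
            # activation), threshold that part's map, and return a slightly padded
            # bounding box in normalized [ymin, xmin, ymax, xmax] coordinates.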
            batch_size, height, width, num_parts = attention_maps.shape
            bboxes = []
            for i in range(batch_size):
                attention_map = attention_maps[i]
                part_weights = attention_map.mean(axis=0).mean(axis=0)
                part_weights = np.sqrt(part_weights)
                part_weights = part_weights / (np.sum(part_weights) + 1e-7)
                selected_index = np.random.choice(np.arange(0, num_parts),
                                                  1,
                                                  p=part_weights)[0]

                mask = attention_map[:, :, selected_index]

                threshold = random.uniform(0.4, 0.6)
                itemindex = np.where(mask >= mask.max() * threshold)

                ymin = itemindex[0].min() / height - 0.1
                ymax = itemindex[0].max() / height + 0.1
                xmin = itemindex[1].min() / width - 0.1
                xmax = itemindex[1].max() / width + 0.1

                bbox = np.asarray([ymin, xmin, ymax, xmax], dtype=np.float32)
                bboxes.append(bbox)
            bboxes = np.asarray(bboxes, np.float32)
            return bboxes

        def attention_drop(attention_maps):
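            # For each image, sample one attention part and build a mask that
            # zeroes its high-attention region; multiplying the image by this mask
            # hides the most attended part.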
            batch_size, height, width, num_parts = attention_maps.shape
            masks = []
            for i in range(batch_size):
                attention_map = attention_maps[i]
                part_weights = attention_map.mean(axis=0).mean(axis=0)
                part_weights = np.sqrt(part_weights)
                part_weights = part_weights / (np.sum(part_weights) + 1e-7)
                selected_index = np.random.choice(np.arange(0, num_parts),
                                                  1,
                                                  p=part_weights)[0]
                mask = attention_map[:, :, selected_index:selected_index + 1]

                # Binary drop mask: 1 where attention is below a random threshold,
                # 0 where it is high, so multiplying hides the attended region.
                threshold = random.uniform(0.2, 0.5)
                mask = (mask < threshold * mask.max()).astype(np.float32)
                masks.append(mask)
            masks = np.asarray(masks, dtype=np.float32)
            return masks

        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits_1, end_points_1 = network_fn(images)

            attention_maps = end_points_1['attention_maps']
            attention_maps = tf.image.resize_bilinear(
                attention_maps, [train_image_size, train_image_size])

            # attention crop
            bboxes = tf.py_func(attention_crop, [attention_maps], [tf.float32])
            bboxes = tf.reshape(bboxes, [FLAGS.batch_size, 4])
            box_ind = tf.range(FLAGS.batch_size, dtype=tf.int32)
            images_crop = tf.image.crop_and_resize(
                images,
                bboxes,
                box_ind,
                crop_size=[train_image_size, train_image_size])

            # attention drop
            masks = tf.py_func(attention_drop, [attention_maps], [tf.float32])
            masks = tf.reshape(
                masks,
                [FLAGS.batch_size, train_image_size, train_image_size, 1])
            images_drop = images * masks

            logits_2, end_points_2 = network_fn(images_crop, reuse=True)
            logits_3, end_points_3 = network_fn(images_drop, reuse=True)

            slim.losses.softmax_cross_entropy(logits_1,
                                              labels,
                                              weights=1 / 3.0,
                                              scope='cross_entropy_1')
            slim.losses.softmax_cross_entropy(logits_2,
                                              labels,
                                              weights=1 / 3.0,
                                              scope='cross_entropy_2')
            slim.losses.softmax_cross_entropy(logits_3,
                                              labels,
                                              weights=1 / 3.0,
                                              scope='cross_entropy_3')

            embeddings = end_points_1['embeddings']
            center_loss = calculate_pooling_center_loss(
                features=embeddings,
                label=labels,
                alfa=0.95,
                nrof_classes=dataset.num_classes,
                weights=1.0,
                name='center_loss')
            slim.losses.add_loss(center_loss)
            return end_points_1

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # optimize_clones() returns the total loss and the per-variable gradients summed over clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)
        config.gpu_options.allow_growth = True
        # config.gpu_options.per_process_gpu_memory_fraction = 0.9
        config.gpu_options.visible_device_list = FLAGS.gpus

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None,
            session_config=config)
Exemple #24
0
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.DEBUG)

    with tf.Graph().as_default():
        ######################
        # Config model_deploy#
        ######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        network_fn = nets_factory.get_network(FLAGS.model_name)
        params = network_fn.default_params
        params = params._replace(match_threshold=FLAGS.match_threshold)
        # initialize the net
        net = network_fn(params)
        out_shape = net.params.img_shape
        anchors = net.anchors(out_shape)
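        # `anchors` are the network's default (prior) boxes for every feature
        # layer; the ground truth below is encoded relative to them.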

        # create batch dataset
        with tf.device(deploy_config.inputs_device()):
            b_image, b_glocalisations, b_gscores = \
                load_batch.get_batch(FLAGS.dataset_dir,
                                     FLAGS.num_readers,
                                     FLAGS.batch_size,
                                     out_shape,
                                     net,
                                     anchors,
                                     FLAGS,
                                     file_pattern=FLAGS.file_pattern,
                                     is_training=True,
                                     shuffe=FLAGS.shuffle_data)
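            # Flatten the per-layer scores and localisations so every anchor is
            # handled uniformly by the loss.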
            allgscores = []
            allglocalization = []
            for i in range(len(anchors)):
                allgscores.append(tf.reshape(b_gscores[i], [-1]))
                allglocalization.append(
                    tf.reshape(b_glocalisations[i], [-1, 4]))

            b_gscores = tf.concat(allgscores, 0)
            b_glocalisations = tf.concat(allglocalization, 0)

            batch_queue = slim.prefetch_queue.prefetch_queue(
                tf_utils.reshape_list([b_image, b_glocalisations, b_gscores]),
                num_threads=8,
                capacity=16 * deploy_config.num_clones)

        # =================================================================== #
        # Define the model running on every GPU.
        # =================================================================== #
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""

            # Dequeue batch.
            batch_shape = [1] * 3
            b_image, b_glocalisations, b_gscores = \
             tf_utils.reshape_list(batch_queue.dequeue(), batch_shape)
            # Construct SSD network.
            arg_scope = net.arg_scope(weight_decay=FLAGS.weight_decay,
                                      data_format=FLAGS.data_format)
            with slim.arg_scope(arg_scope):
                localisations, logits, end_points = \
                 net.net(b_image, is_training=True, use_batch=FLAGS.use_batch)
            # Add loss function.
            net.losses(logits,
                       localisations,
                       b_glocalisations,
                       b_gscores,
                       negative_ratio=FLAGS.negative_ratio,
                       use_hard_neg=FLAGS.use_hard_neg,
                       alpha=FLAGS.loss_alpha,
                       label_smoothing=FLAGS.label_smoothing)
            return end_points

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        #end_points = clones[0].outputs
        #for end_point in end_points:
        #	x = end_points[end_point]
        #	summaries.add(tf.summary.histogram('activations/' + end_point, x))

        for loss in tf.get_collection('EXTRA_LOSSES', first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))

        #
        #for variable in slim.get_model_variables():
        #	summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = tf_utils.configure_learning_rate(
                FLAGS, FLAGS.num_samples, global_step)
            optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

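        # When fine-tuning, load per-variable gradient multipliers from disk so
        # selected layers can be trained faster (or slower) than the rest.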
        if FLAGS.fine_tune:
            with open('nets/multiplier_300.pkl', 'rb') as f:
                gradient_multipliers = pickle.load(f)
        else:
            gradient_multipliers = None

        if FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = tf_utils.get_variables_to_train(FLAGS)

        # optimize_clones() returns the total loss and the per-variable gradients summed over clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))
        if gradient_multipliers:
            with ops.name_scope('multiply_grads'):
                clones_gradients = slim.learning.multiply_gradients(
                    clones_gradients, gradient_multipliers)

        if FLAGS.clip_gradient_norm > 0:
            with ops.name_scope('clip_grads'):
                clones_gradients = slim.learning.clip_gradient_norms(
                    clones_gradients, FLAGS.clip_gradient_norm)
        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        #train_tensor = slim.learning.create_train_op(total_loss, optimizer, gradient_multipliers=gradient_multipliers)
        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # =================================================================== #
        # Kicks off the training.
        # =================================================================== #
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction,
            allocator_type="BFC")
        config = tf.ConfigProto(
            gpu_options=gpu_options,
            log_device_placement=False,
            allow_soft_placement=True,
            inter_op_parallelism_threads=0,
            intra_op_parallelism_threads=1,
        )
        saver = tf.train.Saver(max_to_keep=5,
                               keep_checkpoint_every_n_hours=1.0,
                               write_version=2,
                               pad_step_number=False)

        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_dir,
                            master='',
                            is_chief=True,
                            init_fn=tf_utils.get_init_fn(FLAGS),
                            summary_op=summary_op,
                            number_of_steps=FLAGS.max_number_of_steps,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            saver=saver,
                            save_interval_secs=FLAGS.save_interval_secs,
                            session_config=config,
                            sync_optimizer=None)
Exemple #25
0
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        deploy_config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                                      clone_on_cpu=FLAGS.clone_on_cpu,
                                                      replica_id=0,
                                                      num_replicas=1,
                                                      num_ps_tasks=0)

        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        image_count, class_count = get_image_and_class_count(FLAGS.dataset_dir, 'train')
        dataset = get_dataset('aerial', FLAGS.dataset_dir, image_count, class_count, 'train')
        network_fn = get_network_fn(num_classes=(dataset.num_classes), weight_decay=FLAGS.weight_decay)
        image_preprocessing_fn = get_preprocessing()

        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(dataset,
                                                                      num_readers=FLAGS.num_readers,
                                                                      common_queue_capacity=20 * FLAGS.batch_size,
                                                                      common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            image = image_preprocessing_fn(image, 224, 224)
            images, labels = tf.train.batch([image, label],
                                            batch_size=FLAGS.batch_size,
                                            num_threads=FLAGS.num_preprocessing_threads,
                                            capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(labels, dataset.num_classes)
            batch_queue = slim.prefetch_queue.prefetch_queue([images, labels], capacity=2 * deploy_config.num_clones)

        def clone_fn(batch_queue):
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)
            logits = tf.squeeze(logits)  # drop singleton dims so logits match the one-hot labels' shape
            slim.losses.softmax_cross_entropy(logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0)
            return(end_points)

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(tf.summary.scalar('sparsity/' + end_point, tf.nn.zero_fraction(x)))
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

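        # Exponentially decayed learning rate: multiply the rate by
        # learning_rate_decay_factor every num_epochs_per_decay epochs
        # (expressed in steps via decay_steps), then feed it to RMSProp.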
        with tf.device(deploy_config.optimizer_device()):
            decay_steps = int(dataset.num_samples / FLAGS.batch_size * FLAGS.num_epochs_per_decay)
            learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                                       global_step,
                                                       decay_steps,
                                                       FLAGS.learning_rate_decay_factor,
                                                       staircase=True,
                                                       name='exponential_decay_learning_rate')
            optimizer = tf.train.RMSPropOptimizer(learning_rate,
                                                  decay=FLAGS.rmsprop_decay,
                                                  momentum=FLAGS.rmsprop_momentum,
                                                  epsilon=FLAGS.opt_epsilon)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        variables_to_train = _get_variables_to_train()
        total_loss, clones_gradients = model_deploy.optimize_clones(clones, optimizer, var_list=variables_to_train)
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        grad_updates = optimizer.apply_gradients(clones_gradients, global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op], total_loss, name='train_op')

        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_dir,
                            master='',
                            is_chief=True,
                            init_fn=_get_init_fn(),
                            summary_op=summary_op,
                            number_of_steps=FLAGS.max_number_of_steps,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs,
                            sync_optimizer=None)
Exemple #26
0
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)

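            # Batch accuracy from the arg-max of logits vs. the one-hot labels;
            # exported through end_points so the custom train_step below can log it.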
            accuracy = slim.metrics.accuracy(tf.to_int32(tf.argmax(logits, 1)),
                                             tf.to_int32(tf.argmax(labels, 1)))
            tf.add_to_collection('accuracy', accuracy)
            end_points['train_accuracy'] = accuracy
            return end_points

        # Get accuracies for the batch

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs

        for end_point in end_points:
            if 'accuracy' in end_point:
                continue
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))
        train_acc = end_points['train_accuracy']
        summaries.add(
            tf.summary.scalar('train_accuracy', end_points['train_accuracy']))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # @philkuz
        # Add accuracy summaries
        # TODO add an if statement for n iterations
        # images_val, labels_val= tf.train.batch(
        #     [image, label],
        #     batch_size=FLAGS.batch_size,
        #     num_threads=FLAGS.num_preprocessing_threads,
        #     capacity=5 * FLAGS.batch_size)

        # # labels_val = slim.one_hot_encoding(
        # #     labels_val, dataset.num_classes - FLAGS.labels_offset)
        # batch_queue_val = slim.prefetch_queue.prefetch_queue(
        #     [images_val, labels_val], capacity=2 * deploy_config.num_clones)
        # logits, end_points = network_fn(images, reuse=True)
        # # predictions = tf.nn.softmax(logits)
        # predictions = tf.to_int32(tf.argmax(logits,1))

        # logits_val, end_points_val = network_fn(images_val, reuse=True)
        # predictions_val = tf.to_int32(tf.argmax(logits_val,1))

        # labels_val = tf.squeeze(labels_val)
        # labels = tf.squeeze(labels)

        # names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        #       'train/accuracy': slim.metrics.streaming_accuracy(predictions, labels),
        #       'val/accuracy': slim.metrics.streaming_accuracy(predictions_val, labels_val),
        # })
        # for metric_name, metric_value in names_to_values.items():
        #   op = tf.summary.scalar(metric_name, metric_value)
        #   # op = tf.Print(op, [metric_value], metric_name)
        #   summaries.add(op)
        # Add summaries for variables.
        # TODO something to remove some of these from tensorboard scalars
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # optimize_clones() returns the total loss and the per-variable gradients summed over clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # @philkuz
        # set the  max_number_of_steps parameter if num_epochs is available
        print('FLAGS.num_epochs', FLAGS.num_epochs)
        if FLAGS.num_epochs is not None and FLAGS.max_number_of_steps is None:
            FLAGS.max_number_of_steps = int(
                FLAGS.num_epochs * dataset.num_samples / FLAGS.batch_size)
            # FLAGS.max_number_of_steps = int(math.round(FLAGS.num_epochs / dataset.num_samples))

        # setup the logdir
        # @philkuz  the train_dir setup
        if FLAGS.experiment_name is not None:
            experiment_dir = 'bs={},lr={},epochs={}/{}'.format(
                FLAGS.batch_size, FLAGS.learning_rate, FLAGS.num_epochs,
                FLAGS.experiment_name)
            print(experiment_dir)
            FLAGS.train_dir = os.path.join(FLAGS.train_dir, experiment_dir)
            print(FLAGS.train_dir)

        # @philkuz overriding train_step
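        # A copy of slim.learning.train_step extended to also fetch train_acc so
        # the batch accuracy is logged alongside the loss.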
        def train_step(sess, train_op, global_step, train_step_kwargs):
            """Function that takes a gradient step and specifies whether to stop.

            Args:
              sess: The current session.
              train_op: An `Operation` that evaluates the gradients and returns the
                total loss.
              global_step: A `Tensor` representing the global training step.
              train_step_kwargs: A dictionary of keyword arguments.

            Returns:
              The total loss and a boolean indicating whether or not to stop training.

            Raises:
              ValueError: if 'should_trace' is in `train_step_kwargs` but `logdir` is
                not.
            """
            start_time = time.time()

            trace_run_options = None
            run_metadata = None
            should_acc = True  # TODO make this not hardcoded @philkuz
            if 'should_trace' in train_step_kwargs:
                if 'logdir' not in train_step_kwargs:
                    raise ValueError(
                        'logdir must be present in train_step_kwargs when '
                        'should_trace is present')
                if sess.run(train_step_kwargs['should_trace']):
                    trace_run_options = config_pb2.RunOptions(
                        trace_level=config_pb2.RunOptions.FULL_TRACE)
                    run_metadata = config_pb2.RunMetadata()
            if not should_acc:
                total_loss, np_global_step = sess.run(
                    [train_op, global_step],
                    options=trace_run_options,
                    run_metadata=run_metadata)
            else:
                total_loss, acc, np_global_step = sess.run(
                    [train_op, train_acc, global_step],
                    options=trace_run_options,
                    run_metadata=run_metadata)
            time_elapsed = time.time() - start_time

            if run_metadata is not None:
                tl = timeline.Timeline(run_metadata.step_stats)
                trace = tl.generate_chrome_trace_format()
                trace_filename = os.path.join(
                    train_step_kwargs['logdir'],
                    'tf_trace-%d.json' % np_global_step)
                tf.logging.info('Writing trace to %s', trace_filename)
                file_io.write_string_to_file(trace_filename, trace)
                if 'summary_writer' in train_step_kwargs:
                    train_step_kwargs['summary_writer'].add_run_metadata(
                        run_metadata, 'run_metadata-%d' % np_global_step)

            if 'should_log' in train_step_kwargs:
                if sess.run(train_step_kwargs['should_log']):
                    if not should_acc:
                        tf.logging.info(
                            'global step %d: loss = %.4f (%.3f sec/step)',
                            np_global_step, total_loss, time_elapsed)
                    else:
                        tf.logging.info(
                            'global step %d: loss = %.4f train_acc = %.4f (%.3f sec/step)',
                            np_global_step, total_loss, acc, time_elapsed)

            if 'should_stop' in train_step_kwargs:
                should_stop = sess.run(train_step_kwargs['should_stop'])
            else:
                should_stop = False

            return total_loss, should_stop

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            train_step_fn=train_step,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
Exemple #27
0
def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir):
    """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
  """

    detection_model = create_model_fn()
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            input_queue = _create_input_queue(
                train_config.batch_size // num_clones, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer = optimizer_builder.build(
                train_config.optimizer, global_summaries)

        sync_optimizer = None
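        # With sync_replicas, gradients from all workers are aggregated before a
        # single update is applied to the shared variables.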
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=train_config.worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        print(train_config.fine_tune_checkpoint)
        if train_config.fine_tune_checkpoint:
            var_map = detection_model.restore_map(
                from_detection_checkpoint=train_config.from_detection_checkpoint)
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map, train_config.fine_tune_checkpoint))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        with tf.device(deploy_config.optimizer_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer, regularization_losses=None)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
Exemple #28
0
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=FLAGS.num_readers,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size)
      [image, label] = provider.get(['image', 'label'])
      label -= FLAGS.labels_offset

      train_image_size = FLAGS.train_image_size or network_fn.default_image_size

      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)
      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - FLAGS.labels_offset)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
      if 'AuxLogits' in end_points:
        slim.losses.softmax_cross_entropy(
            end_points['AuxLogits'], labels,
            label_smoothing=FLAGS.label_smoothing, weights=0.4,
            scope='aux_loss')
      slim.losses.softmax_cross_entropy(
          logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram('activations/' + end_point, x))
      summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          total_num_replicas=FLAGS.worker_replicas,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # optimize_clones() computes the total loss and the gradients across all clones.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    # Create a saver.
    saver = tf.train.Saver(tf.global_variables())

    ###########################
    # Kicks off the training. #
    ###########################
    init = tf.global_variables_initializer()

    # Start running operations on the Graph. allow_soft_placement must be set to
    # True to build towers on GPU, as some of the ops do not have GPU
    # implementations.
    sess = tf.Session(config=tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_device_placement))
    sess.run(init)
    if FLAGS.checkpoint_path == FLAGS.train_dir:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.train_dir))

    # load pretrained weights
    weight_ini_fn = _get_init_fn()
    weight_ini_fn(sess)

    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

    for step in range(FLAGS.max_number_of_steps):
        start_time = time.time()
        # _, loss_value = sess.run([train_tensor, loss])
        # _, loss_value = sess.run([train_tensor, total_loss])
        loss_value = sess.run(train_tensor)
        duration = time.time() - start_time

        assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

        if step % FLAGS.log_every_n_steps == 0:
            # num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
            num_examples_per_step = FLAGS.batch_size
            examples_per_sec = num_examples_per_step / duration
            # sec_per_batch = duration / FLAGS.num_gpus
            sec_per_batch = duration

            format_str = ('step %d, loss = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (step, loss_value,
                                examples_per_sec, sec_per_batch))

        if step % FLAGS.summary_snapshot_steps == 0:
            summary_str = sess.run(summary_op)
            summary_writer.add_summary(summary_str, step)

        # Save the model checkpoint periodically.
        if step % FLAGS.model_snapshot_steps == 0 or (step + 1) == FLAGS.max_number_of_steps:
            checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

    print('OK...')
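
A minimal standalone sketch of the train-op pattern above, where the gradient update is grouped with any extra update ops and the loss is re-exposed behind a control dependency; the toy variable, loss and the counter standing in for batch-norm updates are assumptions for illustration only.

import tensorflow as tf

x = tf.Variable(2.0)
loss = tf.square(x)
optimizer = tf.train.GradientDescentOptimizer(0.1)
grad_updates = optimizer.apply_gradients(optimizer.compute_gradients(loss))

counter = tf.Variable(0, trainable=False)
extra_update = tf.assign_add(counter, 1)       # stands in for batch-norm updates

update_op = tf.group(grad_updates, extra_update)
with tf.control_dependencies([update_op]):
    train_tensor = tf.identity(loss, name='train_op')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    loss_value = sess.run(train_tensor)        # runs both updates, returns the loss
    print(loss_value, sess.run(counter))       # counter is now 1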
Exemple #29
0
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')

    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    # Get dataset-dependent information.
    dataset = segmentation_dataset.get_dataset(FLAGS.dataset,
                                               FLAGS.train_split,
                                               dataset_dir=FLAGS.dataset_dir)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    with tf.Graph().as_default():
        with tf.device(config.inputs_device()):
            samples = input_generator.get(
                dataset,
                FLAGS.train_crop_size,
                clone_batch_size,
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                dataset_split=FLAGS.train_split,
                is_training=True,
                model_variant=FLAGS.model_variant)
            inputs_queue = prefetch_queue.prefetch_queue(samples,
                                                         capacity=128 *
                                                         config.num_clones)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (inputs_queue, {
                common.OUTPUT_TYPE: dataset.num_classes
            }, dataset.ignore_label)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps, FLAGS.learning_power,
                FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   FLAGS.momentum)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes()
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows ops without a GPU implementation to be placed on the CPU.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Start the training.
        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_logdir,
                            log_every_n_steps=FLAGS.log_steps,
                            master=FLAGS.master,
                            number_of_steps=FLAGS.training_number_of_steps,
                            is_chief=(FLAGS.task == 0),
                            session_config=session_config,
                            startup_delay_steps=startup_delay_steps,
                            init_fn=train_utils.get_model_init_fn(
                                FLAGS.train_logdir,
                                FLAGS.tf_initial_checkpoint,
                                FLAGS.initialize_last_layer,
                                last_layers,
                                ignore_missing_vars=True),
                            summary_op=summary_op,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs)
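
A minimal standalone sketch of per-variable gradient multipliers in the spirit of the last-layer multiplier above, applied by hand to (gradient, variable) pairs; the variable names and the 10x multiplier are assumptions for illustration only.

import tensorflow as tf

backbone = tf.Variable(1.0, name='backbone/w')
last_layer = tf.Variable(1.0, name='logits/w')
loss = tf.square(backbone) + tf.square(last_layer)

optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9)
grads_and_vars = optimizer.compute_gradients(loss)

grad_mult = {'logits/w': 10.0}                 # larger steps for the last layer only
scaled = [(g * grad_mult.get(v.op.name, 1.0), v) for g, v in grads_and_vars]
train_op = optimizer.apply_gradients(scaled)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)
    print(sess.run([backbone, last_layer]))    # the last layer moved ~10x further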
Exemple #30
0
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.DEBUG)

    with tf.Graph().as_default():
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

        net = model_cmc.Model()

        with tf.device(deploy_config.inputs_device()):
            if FLAGS.model == 'unet':
                batch_queue = load_batch.get_batch(
                    FLAGS.dataset_dir,
                    FLAGS.num_readers,
                    FLAGS.batch_size,
                    None,
                    FLAGS,
                    file_pattern=FLAGS.file_pattern,
                    is_training=True,
                    shuffe=FLAGS.shuffle_data)
            elif FLAGS.model == 'patch':
                batch_queue = load_batch_patch.get_batch(
                    FLAGS.dataset_dir,
                    FLAGS.num_readers,
                    FLAGS.batch_size,
                    None,
                    FLAGS,
                    file_pattern=FLAGS.file_pattern,
                    is_training=True,
                    shuffe=FLAGS.shuffle_data)
            elif FLAGS.model == 'cmc':
                batch_queue = load_batch_cmc.get_batch(
                    FLAGS.dataset_dir,
                    FLAGS.num_readers,
                    FLAGS.batch_size,
                    None,
                    FLAGS,
                    file_pattern=FLAGS.file_pattern,
                    is_training=True,
                    shuffe=FLAGS.shuffle_data)

        # =================================================================== #
        # Define the model running on every GPU.
        # =================================================================== #
        print("Batch_loading_successful")

        def clone_fn(batch_queue):
            batch_shape = [1] * 3
            b_image, label = batch_queue

            logits, end_points = net.net(b_image)

            # Add loss function.
            loss, mean_iou = net.weighted_losses(logits, label)
            return end_points, mean_iou

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        end_points, mean_iou = clones[0].outputs
        update_ops.append(mean_iou[1])
        #for end_point in end_points:
        #	x = end_points[end_point]
        #	summaries.add(tf.summary.histogram('activations/' + end_point, x))

        for loss in tf.get_collection('EXTRA_LOSSES', first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))

        #
        #for variable in slim.get_model_variables():
        #	summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = tf_utils.configure_learning_rate(
                FLAGS, FLAGS.num_samples, global_step)
            optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.fine_tune:
            gradient_multipliers = pickle.load(
                open('nets/multiplier_300.pkl', 'rb'))
        else:
            gradient_multipliers = None

        if FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = tf_utils.get_variables_to_train(FLAGS)

        # optimize_clones() computes the total loss and the gradients across all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))
        if gradient_multipliers:
            with ops.name_scope('multiply_grads'):
                clones_gradients = slim.learning.multiply_gradients(
                    clones_gradients, gradient_multipliers)

        if FLAGS.clip_gradient_norm > 0:
            with ops.name_scope('clip_grads'):
                clones_gradients = slim.learning.clip_gradient_norms(
                    clones_gradients, FLAGS.clip_gradient_norm)
        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        #train_tensor = slim.learning.create_train_op(total_loss, optimizer, gradient_multipliers=gradient_multipliers)
        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # =================================================================== #
        # Kicks off the training.
        # =================================================================== #

        def train_step_fn(session, *args, **kwargs):
            # visualizer = Beholder(session=session, logdir=FLAGS.train_dir)
            total_loss, should_stop = train_step(session, *args, **kwargs)

            if train_step_fn.step % FLAGS.validation_check == 0:
                _mean_iou = session.run(train_step_fn.mean_iou)
                print('evaluation step %d - loss = %.4f mean_iou = %.2f%%' %
                      (train_step_fn.step, total_loss, _mean_iou))
            # evaluated_tensors = session.run([end_points['conv4'], end_points['up1']])
            # example_frame = session.run(end_points['up2'])
            # visualizer.update(arrays=evaluated_tensors, frame=example_frame)

            train_step_fn.step += 1
            return [total_loss, should_stop]

        train_step_fn.step = 0
        train_step_fn.end_points = end_points
        train_step_fn.mean_iou = mean_iou[0]

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction,
            allocator_type="BFC")
        config = tf.ConfigProto(
            gpu_options=gpu_options,
            log_device_placement=False,
            allow_soft_placement=True,
            inter_op_parallelism_threads=0,
            intra_op_parallelism_threads=1,
        )
        saver = tf.train.Saver(max_to_keep=5,
                               keep_checkpoint_every_n_hours=1.0,
                               write_version=2,
                               pad_step_number=False)

        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_dir,
                            master='',
                            is_chief=True,
                            train_step_fn=train_step_fn,
                            init_fn=tf_utils.get_init_fn(FLAGS),
                            summary_op=summary_op,
                            number_of_steps=FLAGS.max_number_of_steps,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            saver=saver,
                            save_interval_secs=FLAGS.save_interval_secs,
                            session_config=config,
                            sync_optimizer=None)
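
A minimal standalone pure-Python sketch of the function-attribute trick used by train_step_fn above: the step counter (and any tensors the callback needs) are attached to the function object itself; the loss value here is a stand-in for the real train step.

def train_step_fn(session, *args, **kwargs):
    loss = 0.0                                  # stand-in for the real train step
    if train_step_fn.step % 10 == 0:
        print('step %d - loss = %.4f' % (train_step_fn.step, loss))
    train_step_fn.step += 1
    return [loss, False]                        # (total_loss, should_stop)

train_step_fn.step = 0
for _ in range(25):
    train_step_fn(None)                         # logs at steps 0, 10 and 20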
Exemple #31
0
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_biasCNN.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        dataset_val = dataset_biasCNN.get_dataset(FLAGS.dataset_name,
                                                  'validation',
                                                  FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        network_fn_val = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            is_training=False)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_biasCNN.get_preprocessing(
            preprocessing_name,
            is_training=True,
            flipLR=FLAGS.flipLR,
            random_scale=FLAGS.random_scale,
            is_windowed=FLAGS.is_windowed)

        image_preprocessing_fn_val = preprocessing_biasCNN.get_preprocessing(
            preprocessing_name,
            is_training=False,
            flipLR=FLAGS.flipLR,
            random_scale=FLAGS.random_scale,
            is_windowed=FLAGS.is_windowed)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ############################################
        # Create a provider for the validation set #
        ############################################
        provider_val = slim.dataset_data_provider.DatasetDataProvider(
            dataset_val,
            shuffle=True,
            common_queue_capacity=2 * FLAGS.batch_size_val,
            common_queue_min=FLAGS.batch_size_val)
        [image_val, label_val] = provider_val.get(['image', 'label'])
        label_val -= FLAGS.labels_offset

        eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

        image_val = image_preprocessing_fn_val(image_val, eval_image_size,
                                               eval_image_size)

        images_val, labels_val = tf.train.batch(
            [image_val, label_val],
            batch_size=FLAGS.batch_size_val,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size_val)

        ###############################
        # Define the model (training) #
        ###############################

        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()

            with tf.variable_scope('my_scope'):
                logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        if FLAGS.quantize_delay >= 0:
            tf.contrib.quantize.create_training_graph(
                quant_delay=FLAGS.quantize_delay)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # optimize_clones() computes the total loss and the gradients across all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        #################################
        # Define the model (validation) #
        #################################

        with tf.variable_scope('my_scope', reuse=True):
            logits_val, _ = network_fn_val(images_val)

        predictions_val = tf.argmax(logits_val, 1)
        labels_val = tf.squeeze(labels_val)

        # Define the metrics:
        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
            'Accuracy':
            slim.metrics.streaming_accuracy(predictions_val, labels_val),
            'Recall_5':
            slim.metrics.streaming_recall_at_k(logits_val, labels_val, 5),
        })

        for name, value in names_to_values.items():
            summary_name = 'eval/%s' % name
            op = tf.summary.scalar(summary_name, value, collections=[])
            op = tf.Print(op, [value], summary_name)
            tf.add_to_collection('summaries', op)

        # Gather validation summaries
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Create a non-default saver so we don't delete all the old checkpoints.
        my_saver = tf_saver.Saver(
            max_to_keep=FLAGS.max_checkpoints_to_keep,
            keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
        )

        # Create a non-default dictionary of options for train_step_fn
        # This is a hack that lets us pass everything we need to run evaluation into the training loop function.
        from tensorflow.python.framework import ops
        from tensorflow.python.framework import constant_op
        from tensorflow.python.ops import math_ops

        with ops.name_scope('train_step'):
            train_step_kwargs = {}

            if FLAGS.max_number_of_steps:
                should_stop_op = math_ops.greater_equal(
                    global_step, FLAGS.max_number_of_steps)
            else:
                should_stop_op = constant_op.constant(False)
            train_step_kwargs['should_stop'] = should_stop_op
            if FLAGS.log_every_n_steps > 0:
                train_step_kwargs['should_log'] = math_ops.equal(
                    math_ops.mod(global_step, FLAGS.log_every_n_steps), 0)
            train_step_kwargs['should_val'] = math_ops.equal(
                math_ops.mod(global_step, FLAGS.val_every_n_steps), 0)
            train_step_kwargs['eval_op'] = list(names_to_updates.values())


        # assert(FLAGS.max_number_of_steps == 100000)
        print(should_stop_op)
        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None,
            saver=my_saver,
            train_step_fn=learning_biasCNN.train_step_fn,
            train_step_kwargs=train_step_kwargs)
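
A minimal standalone sketch of how the 'my_scope' reuse above shares one set of weights between the training and validation graphs; the scope name, shapes and placeholders are assumptions for illustration only.

import tensorflow as tf

def tiny_net(x):
    w = tf.get_variable('w', shape=[4, 2])      # created once, then reused
    return tf.matmul(x, w)

train_images = tf.placeholder(tf.float32, [None, 4])
val_images = tf.placeholder(tf.float32, [None, 4])

with tf.variable_scope('my_scope'):
    train_logits = tiny_net(train_images)
with tf.variable_scope('my_scope', reuse=True):
    val_logits = tiny_net(val_images)           # same 'my_scope/w' as above

print(len(tf.trainable_variables()))            # 1: a single shared weight matrix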
Exemple #32
0
  def main(self):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set session_config to allow some operations to be run on cpu.
    session_config = tf.ConfigProto(allow_soft_placement=True, )

    with tf.Graph().as_default():
      ######################
      # Select the dataset #
      ######################
      dataset = self._select_dataset()

      ######################
      # Select the network #
      ######################
      networks = self._select_network()

      #####################################
      # Select the preprocessing function #
      #####################################
      image_preprocessing_fn = self._select_image_preprocessing_fn()

      #######################
      # Config model_deploy #
      #######################
      deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

      global_step = slim.create_global_step()

      ##############################################################
      # Create a dataset provider that loads data from the dataset #
      ##############################################################
      data = self._prepare_data(dataset, image_preprocessing_fn, deploy_config, )
      data_batched = self._get_batch(data)
      batch_names = data_batched.keys()
      batch = data_batched.values()

      ###############
      # Is Training #
      ###############
      if FLAGS.is_training:
        if not os.path.isdir(FLAGS.train_dir):
          util_io.touch_folder(FLAGS.train_dir)
        if not os.path.exists(os.path.join(FLAGS.train_dir, FLAGS_FILE_NAME)):
          FLAGS.append_flags_into_file(os.path.join(FLAGS.train_dir, FLAGS_FILE_NAME))

        try:
          batch_queue = slim.prefetch_queue.prefetch_queue(
            batch, capacity=4 * deploy_config.num_clones)
        except ValueError as e:
          tf.logging.warning('Cannot use batch_queue due to error %s', e)
          batch_queue = batch
        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, self._clone_fn,
                                            GeneralModel._dtype_string_to_dtype(FLAGS.variable_dtype),
                                            [networks, batch_queue, batch_names],
                                            {'global_step': global_step,
                                             'is_training': FLAGS.is_training})
        first_clone_scope = deploy_config.clone_scope(0)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        self._end_points_for_debugging = end_points
        self._add_end_point_summaries(end_points, summaries)
        # Add summaries for images, if there are any.
        self._add_image_summaries(end_points, summaries)
        # Add summaries for losses.
        self._add_loss_summaries(first_clone_scope, summaries, end_points)
        # Add summaries for variables.
        for variable in slim.get_model_variables():
          summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
          moving_average_variables = slim.get_model_variables()
          variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay, global_step)
        else:
          moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by generator_network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
          learning_rate = self._configure_learning_rate(self.num_samples, global_step)
          optimizer = self._configure_optimizer(learning_rate)

        if FLAGS.sync_replicas:
          # If sync_replicas is enabled, the averaging will be done in the chief
          # queue runner.
          optimizer = tf.train.SyncReplicasOptimizer(
            opt=optimizer,
            replicas_to_aggregate=FLAGS.replicas_to_aggregate,
            total_num_replicas=FLAGS.worker_replicas,
            variable_averages=variable_averages,
            variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
          # Update ops executed locally by trainer.
          update_ops.append(variable_averages.apply(moving_average_variables))

        summaries.add(tf.summary.scalar('learning_rate', learning_rate))
        # Define optimization process.
        train_tensor = self._add_optimization(clones, optimizer, summaries, update_ops, global_step)

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                           first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Define train_step with eval every `eval_every_n_steps`.
        def train_step_fn(session, *args, **kwargs):
          self.do_extra_train_step(session, end_points, global_step)
          total_loss, should_stop = slim.learning.train_step(session, *args, **kwargs)
          return [total_loss, should_stop]

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
          train_tensor,
          train_step_fn=train_step_fn,
          logdir=FLAGS.train_dir,
          master=FLAGS.master,
          is_chief=(FLAGS.task == 0),
          init_fn=self._get_init_fn(FLAGS.checkpoint_path, FLAGS.checkpoint_exclude_scopes),
          summary_op=summary_op,
          number_of_steps=FLAGS.max_number_of_steps,
          log_every_n_steps=FLAGS.log_every_n_steps,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs,
          sync_optimizer=optimizer if FLAGS.sync_replicas else None,
          session_config=session_config)
      ##########################
      # Eval, Export or Output #
      ##########################
      else:
        # Write flags file.
        if not os.path.isdir(FLAGS.eval_dir):
          util_io.touch_folder(FLAGS.eval_dir)
        if not os.path.exists(os.path.join(FLAGS.eval_dir, FLAGS_FILE_NAME)):
          FLAGS.append_flags_into_file(os.path.join(FLAGS.eval_dir, FLAGS_FILE_NAME))

        with tf.variable_scope(tf.get_variable_scope(),
                               custom_getter=model_deploy.get_custom_getter(
                                 GeneralModel._dtype_string_to_dtype(FLAGS.variable_dtype)),
                               reuse=False):
          end_points = self._clone_fn(networks, batch_queue=None, batch_names=batch_names, data_batched=data_batched,
                                      is_training=False, global_step=global_step)

        num_batches = int(math.ceil(self.num_samples / float(FLAGS.batch_size)))

        checkpoint_path = util_misc.get_latest_checkpoint_path(FLAGS.checkpoint_path)

        if FLAGS.moving_average_decay:
          variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay, global_step)
          variables_to_restore = variable_averages.variables_to_restore(
            slim.get_model_variables())
          variables_to_restore[global_step.op.name] = global_step
        else:
          variables_to_restore = slim.get_variables_to_restore()

        saver = None
        if variables_to_restore is not None:
          saver = tf.train.Saver(variables_to_restore)

        session_creator = tf.train.ChiefSessionCreator(
          scaffold=tf.train.Scaffold(saver=saver),
          checkpoint_filename_with_path=checkpoint_path,
          master=FLAGS.master,
          config=session_config)

        ##########
        # Output #
        ##########
        if FLAGS.do_output:
          tf.logging.info('Output mode.')
          output_ops = self._maybe_encode_output_tensor(self._define_outputs(end_points, data_batched))
          start_time = time.time()
          with tf.train.MonitoredSession(
              session_creator=session_creator) as session:
            for i in range(num_batches):
              output_results = session.run([item[-1] for item in output_ops])
              self._write_outputs(output_results, output_ops)
              if i % FLAGS.log_every_n_steps == 0:
                current_time = time.time()
                speed = (current_time - start_time) / (i + 1)
                time_left = speed * (num_batches - i + 1)
                tf.logging.info('%d / %d done. Time left: %f', i + 1, num_batches, time_left)


        ################
        # Export Model #
        ################
        elif FLAGS.do_export:
          tf.logging.info('Exporting trained model to %s', FLAGS.export_path)
          with tf.Session(config=session_config) as session:
            saver.restore(session, checkpoint_path)
            builder = tf.saved_model.builder.SavedModelBuilder(FLAGS.export_path)
            signature_def_map = self._build_signature_def_map(end_points, data_batched)
            assets_collection = self._build_assets_collection(end_points, data_batched)
            legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
            builder.add_meta_graph_and_variables(
              session, [tf.saved_model.tag_constants.SERVING],
              signature_def_map=signature_def_map,
              legacy_init_op=legacy_init_op,
              assets_collection=assets_collection,
            )
          builder.save()
          tf.logging.info('Done exporting!')

        ########
        # Eval #
        ########
        else:
          tf.logging.info('Eval mode.')
          # Add summaries for images, if there are any.
          self._add_image_summaries(end_points, None)

          # Define the metrics:
          metric_map = self._define_eval_metrics(end_points, data_batched)

          names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(metric_map)
          names_to_values = collections.OrderedDict(**names_to_values)
          names_to_updates = collections.OrderedDict(**names_to_updates)

          # Print the summaries to screen.
          for name, value in names_to_values.items():
            summary_name = 'eval/%s' % name
            if len(value.shape):
              op = tf.summary.tensor_summary(summary_name, value, collections=[])
            else:
              op = tf.summary.scalar(summary_name, value, collections=[])
            op = tf.Print(op, [value], summary_name)
            tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

          if not (FLAGS.do_eval_debug or FLAGS.do_custom_eval):
            tf.logging.info('Evaluating %s' % checkpoint_path)

            slim.evaluation.evaluate_once(
              master=FLAGS.master,
              checkpoint_path=checkpoint_path,
              logdir=FLAGS.eval_dir,
              num_evals=num_batches,
              eval_op=list(names_to_updates.values()),
              variables_to_restore=variables_to_restore,
              session_config=session_config)
            return

          ################################
          # `do_eval_debug` flag is true.#
          ################################
          if FLAGS.do_eval_debug:
            eval_ops = list(names_to_updates.values())
            eval_names = list(names_to_updates.keys())

            # Items to write to a html page.
            encode_ops = self._maybe_encode_output_tensor(self.get_items_to_encode(end_points, data_batched))

            with tf.train.MonitoredSession(session_creator=session_creator) as session:
              if eval_ops is not None:
                for i in range(num_batches):
                  eval_result = session.run(eval_ops, None)
                  print('; '.join(('%s:%s' % (name, str(eval_result[i])) for i, name in enumerate(eval_names))))

              # Write to HTML
              if encode_ops:
                for i in range(num_batches):
                  encode_ops_feed_dict = self._get_encode_op_feed_dict(end_points, encode_ops, i)
                  encoded_items = session.run([item[-1] for item in encode_ops], encode_ops_feed_dict)
                  encoded_list = []
                  for j in range(len(encoded_items)):
                    encoded_list.append((encode_ops[j][0], encode_ops[j][1], encoded_items[j].tolist()))

                  eval_items = self.save_images(encoded_list, os.path.join(FLAGS.eval_dir, 'images'))
                  eval_items = self.to_human_friendly(eval_items, )
                  self._write_eval_html(eval_items)
                  if i % 10 == 0:
                    tf.logging.info('%d/%d' % (i, num_batches))
          if FLAGS.do_custom_eval:
            extra_eval = self._define_extra_eval_actions(end_points, data_batched)
            with tf.train.MonitoredSession(session_creator=session_creator) as session:
              self._do_extra_eval_actions(session, extra_eval)
Exemple #33
0
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=FLAGS.num_readers,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size)
      [image, label] = provider.get(['image', 'label'])
      label -= FLAGS.labels_offset

      train_image_size = FLAGS.train_image_size or network_fn.default_image_size

      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)
      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - FLAGS.labels_offset)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      with tf.device(deploy_config.inputs_device()):
        images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
      if 'AuxLogits' in end_points:
        tf.losses.softmax_cross_entropy(
            logits=end_points['AuxLogits'], onehot_labels=labels,
            label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
      tf.losses.softmax_cross_entropy(
          logits=logits, onehot_labels=labels,
          label_smoothing=FLAGS.label_smoothing, weights=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram('activations/' + end_point, x))
      summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables,
          replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
          total_num_replicas=FLAGS.worker_replicas)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # Compute the total loss and the gradients across all clones.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')


    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=optimizer if FLAGS.sync_replicas else None)
Exemple #34
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones

  # Get dataset-dependent information.
  dataset = segmentation_dataset.get_dataset(
      FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      samples = input_generator.get(
          dataset,
          FLAGS.train_crop_size,
          clone_batch_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          dataset_split=FLAGS.train_split,
          is_training=True,
          model_variant=FLAGS.model_variant)
      inputs_queue = prefetch_queue.prefetch_queue(
          samples, capacity=128 * config.num_clones)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab
      model_args = (inputs_queue, {
          common.OUTPUT_TYPE: dataset.num_classes
      }, dataset.ignore_label)
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in slim.get_model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for images, labels, semantic predictions
    if FLAGS.save_summaries_images:
      summary_image = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
      summaries.add(
          tf.summary.image('samples/%s' % common.IMAGE, summary_image))

      first_clone_label = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
      # Scale up summary image pixel values for better visualization.
      pixel_scaling = max(1, 255 // dataset.num_classes)
      summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL, summary_label))

      first_clone_output = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
      predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

      summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image(
              'samples/%s' % common.OUTPUT_TYPE, summary_predictions))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy, FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps, FLAGS.learning_power,
          FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
      optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Start the training.
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_logdir,
        log_every_n_steps=FLAGS.log_steps,
        master=FLAGS.master,
        number_of_steps=FLAGS.training_number_of_steps,
        is_chief=(FLAGS.task == 0),
        session_config=session_config,
        startup_delay_steps=startup_delay_steps,
        init_fn=train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True),
        summary_op=summary_op,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs)
def train(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    tf.logging.set_verbosity(args.log)
    clone_on_cpu = args.gpu_id == ''
    num_clones = len(args.gpu_id.split(','))

    ###
    # get teacher info.
    ###
    teacher_dir = utils.shell_path(args.teacher_dir)
    assert tf.gfile.IsDirectory(teacher_dir)
    json_in_dir = glob.glob(os.path.join(teacher_dir, '*.json'))
    assert len(json_in_dir) == 1
    te_json = json_in_dir[0]
    te_ckpt = tf.train.latest_checkpoint(teacher_dir)
    assert tf.train.checkpoint_exists(te_ckpt)

    with open(te_json, 'rt') as F:
        configs = json.load(F)
    te_hparams = Namespace(**configs)
    teacher = wavenet.Wavenet(te_hparams)

    ###
    # get student info.
    ###
    if args.config is None:
        raise RuntimeError('No config json specified.')
    with open(args.config, 'rt') as F:
        configs = json.load(F)
    st_hparams = Namespace(**configs)
    pwn = parallel_wavenet.ParallelWavenet(st_hparams, teacher,
                                           args.train_path)

    def _model_fn(_inputs_dict):
        ff_dict = pwn.feed_forward(_inputs_dict)
        ff_dict.update(_inputs_dict)
        loss_dict = pwn.calculate_loss(ff_dict)
        loss = loss_dict['loss']
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
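        # Registering the loss in the LOSSES collection lets
        # model_deploy.optimize_clones gather it for this clone's scope.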

        tf.summary.scalar("kl_loss", loss_dict['kl_loss'])
        tf.summary.scalar("H_Ps", loss_dict['H_Ps'])
        tf.summary.scalar("H_Ps_Pt", loss_dict['H_Ps_Pt'])
        if 'power_loss' in loss_dict:
            tf.summary.scalar('power_loss', loss_dict['power_loss'])

    logdir = args.logdir
    tf.logging.info('Saving to {}'.format(logdir))

    os.makedirs(logdir, exist_ok=True)
    shutil.copy(args.config, logdir)

    with tf.Graph().as_default():
        total_batch_size = args.total_batch_size
        assert total_batch_size % num_clones == 0
        clone_batch_size = int(total_batch_size / num_clones)

        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            num_ps_tasks=0,
            worker_job_name='localhost',
            ps_job_name='localhost')

        with tf.device(deploy_config.inputs_device()):
            inputs_dict = pwn.get_batch(clone_batch_size)

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, _model_fn,
                                            [inputs_dict])
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        summaries.update(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        with tf.device(deploy_config.variables_device()):
            global_step = tf.get_variable(
                "global_step", [],
                tf.int32,
                initializer=tf.constant_initializer(0),
                trainable=False)

        ###
        # variables to train
        ###
        st_vars = [
            var for var in tf.trainable_variables() if 'iaf' in var.name
        ]

        with tf.device(deploy_config.optimizer_device()):
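            # Piecewise-constant learning-rate schedule: the chained tf.cond
            # ops leave `lr` at the value whose step key is the largest one not
            # exceeding global_step (assuming the schedule dict iterates its
            # keys in ascending order).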
            lr = tf.constant(pwn.learning_rate_schedule[0])
            for key, value in pwn.learning_rate_schedule.items():
                lr = tf.cond(tf.less(global_step, key), lambda: lr,
                             lambda: tf.constant(value))
            summaries.add(tf.summary.scalar("learning_rate", lr))

            optimizer = tf.train.AdamOptimizer(lr, epsilon=1e-8)
            ema = tf.train.ExponentialMovingAverage(decay=0.9999,
                                                    num_updates=global_step)
            loss, clone_grads_vars = model_deploy.optimize_clones(
                clones, optimizer, var_list=st_vars)
            update_ops.append(
                optimizer.apply_gradients(clone_grads_vars,
                                          global_step=global_step))
            update_ops.append(ema.apply(st_vars))

            summaries.add(tf.summary.scalar("train_loss", loss))

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(loss, name='train_op')

        ###
        # restore teacher
        ###
        te_vars = [
            var for var in tf.trainable_variables() if 'iaf' not in var.name
        ]
        # The teacher uses EMA weights; restore each variable from its EMA
        # shadow name in the checkpoint.
        te_vars = {
            '{}/ExponentialMovingAverage'.format(tv.name[:-2]): tv
            for tv in te_vars
        }
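        # e.g. a trainable variable named 'wavenet/conv/W:0' would be restored
        # from the checkpoint entry 'wavenet/conv/W/ExponentialMovingAverage'
        # (the variable name here is illustrative only).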
        restore_init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
            te_ckpt, te_vars)

        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        slim.learning.train(train_tensor,
                            logdir=logdir,
                            number_of_steps=pwn.num_iters,
                            summary_op=summary_op,
                            global_step=global_step,
                            log_every_n_steps=100,
                            save_summaries_secs=600,
                            save_interval_secs=3600,
                            session_config=session_config,
                            init_fn=restore_init_fn)
Exemple #36
def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):
  """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    graph_hook_fn: Optional function that is called after the training graph is
      completely built. This is helpful to perform additional changes to the
      training graph such as optimizing batchnorm. The function should modify
      the default graph.
  """

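  # A minimal sketch of how this function might be invoked for single-machine
  # training; `input_reader_fn` and `model_builder_fn` are hypothetical
  # placeholders for the usual dataset/model closures, not names defined here:
  #
  #   train(create_tensor_dict_fn=input_reader_fn,
  #         create_model_fn=model_builder_fn,
  #         train_config=train_config,
  #         master='', task=0, num_clones=1, worker_replicas=1,
  #         clone_on_cpu=False, ps_tasks=0, worker_job_name='worker',
  #         is_chief=True, train_dir='/tmp/train')
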
  detection_model = create_model_fn()
  data_augmentation_options = [
      preprocessor_builder.build(step)
      for step in train_config.data_augmentation_options]

  with tf.Graph().as_default():
    # Build a configuration specifying multi-GPU and multi-replicas.
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=worker_replicas,
        num_ps_tasks=ps_tasks,
        worker_job_name=worker_job_name)

    # Place the global step on the device storing the variables.
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    with tf.device(deploy_config.inputs_device()):
      input_queue = create_input_queue(
          train_config.batch_size // num_clones, create_tensor_dict_fn,
          train_config.batch_queue_capacity,
          train_config.num_batch_queue_threads,
          train_config.prefetch_queue_capacity, data_augmentation_options)

    # Gather initial summaries.
    # TODO(rathodv): See if summaries can be added/extracted from global tf
    # collections so that they don't have to be passed around.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    global_summaries = set([])

    model_fn = functools.partial(_create_losses,
                                 create_model_fn=create_model_fn,
                                 train_config=train_config)
    clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
    first_clone_scope = clones[0].scope

    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    with tf.device(deploy_config.optimizer_device()):
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)
      for var in optimizer_summary_vars:
        tf.summary.scalar(var.op.name, var, family='LearningRate')

    sync_optimizer = None
    if train_config.sync_replicas:
      training_optimizer = tf.train.SyncReplicasOptimizer(
          training_optimizer,
          replicas_to_aggregate=train_config.replicas_to_aggregate,
          total_num_replicas=worker_replicas)
      sync_optimizer = training_optimizer
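      # The same optimizer instance is passed to slim.learning.train below as
      # sync_optimizer so it can start the chief queue runner for aggregation.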

    with tf.device(deploy_config.optimizer_device()):
      regularization_losses = (None if train_config.add_regularization_loss
                               else [])
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, training_optimizer,
          regularization_losses=regularization_losses)
      total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

      # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
      if train_config.bias_grad_multiplier:
        biases_regex_list = ['.*/biases']
        grads_and_vars = variables_helper.multiply_gradients_matching_regex(
            grads_and_vars,
            biases_regex_list,
            multiplier=train_config.bias_grad_multiplier)

      # Optionally freeze some layers by setting their gradients to be zero.
      if train_config.freeze_variables:
        grads_and_vars = variables_helper.freeze_gradients_matching_regex(
            grads_and_vars, train_config.freeze_variables)

      # Optionally clip gradients
      if train_config.gradient_clipping_by_norm > 0:
        with tf.name_scope('clip_grads'):
          grads_and_vars = slim.learning.clip_gradient_norms(
              grads_and_vars, train_config.gradient_clipping_by_norm)

      # Create gradient updates.
      grad_updates = training_optimizer.apply_gradients(grads_and_vars,
                                                        global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops, name='update_barrier')
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    if graph_hook_fn:
      with tf.device(deploy_config.variables_device()):
        graph_hook_fn()

    # Add summaries.
    for model_var in slim.get_model_variables():
      global_summaries.add(tf.summary.histogram('ModelVars/' +
                                                model_var.op.name, model_var))
    for loss_tensor in tf.losses.get_losses():
      global_summaries.add(tf.summary.scalar('Losses/' + loss_tensor.op.name,
                                             loss_tensor))
    global_summaries.add(
        tf.summary.scalar('Losses/TotalLoss', tf.losses.get_total_loss()))

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))
    summaries |= global_summaries

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

    # Save checkpoints regularly.
    keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
    saver = tf.train.Saver(
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

    # Create ops required to initialize the model from a given checkpoint.
    init_fn = None
    if train_config.fine_tune_checkpoint:
      if not train_config.fine_tune_checkpoint_type:
        # train_config.from_detection_checkpoint field is deprecated. For
        # backward compatibility, fine_tune_checkpoint_type is set based on
        # from_detection_checkpoint.
        if train_config.from_detection_checkpoint:
          train_config.fine_tune_checkpoint_type = 'detection'
        else:
          train_config.fine_tune_checkpoint_type = 'classification'
      var_map = detection_model.restore_map(
          fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
          load_all_detection_checkpoint_vars=(
              train_config.load_all_detection_checkpoint_vars))
      available_var_map = (variables_helper.
                           get_variables_available_in_checkpoint(
                               var_map, train_config.fine_tune_checkpoint))
      init_saver = tf.train.Saver(available_var_map)
      def initializer_fn(sess):
        init_saver.restore(sess, train_config.fine_tune_checkpoint)
      init_fn = initializer_fn

    slim.learning.train(
        train_tensor,
        logdir=train_dir,
        master=master,
        is_chief=is_chief,
        session_config=session_config,
        startup_delay_steps=train_config.startup_delay_steps,
        init_fn=init_fn,
        summary_op=summary_op,
        number_of_steps=(
            train_config.num_steps if train_config.num_steps else None),
        save_summaries_secs=120,
        sync_optimizer=sync_optimizer,
        saver=saver)
Exemple #37
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      dataset = data_generator.Dataset(
          dataset_name=FLAGS.dataset,
          split_name=FLAGS.train_split,
          dataset_dir=FLAGS.dataset_dir,
          batch_size=clone_batch_size,
          crop_size=[int(sz) for sz in FLAGS.train_crop_size],
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          model_variant=FLAGS.model_variant,
          num_readers=4,
          is_training=True,
          should_shuffle=True,
          should_repeat=True)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab
      model_args = (dataset.get_one_shot_iterator(), {
          common.OUTPUT_TYPE: dataset.num_of_classes
      }, dataset.ignore_label)
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in tf.model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for images, labels, semantic predictions
    if FLAGS.save_summaries_images:
      summary_image = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
      summaries.add(
          tf.summary.image('samples/%s' % common.IMAGE, summary_image))

      first_clone_label = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
      # Scale up summary image pixel values for better visualization.
      pixel_scaling = max(1, 255 // dataset.num_of_classes)
      summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL, summary_label))

      first_clone_output = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
      predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

      summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image(
              'samples/%s' % common.OUTPUT_TYPE, summary_predictions))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy,
          FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step,
          FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps,
          FLAGS.learning_power,
          FLAGS.slow_start_step,
          FLAGS.slow_start_learning_rate,
          decay_steps=FLAGS.decay_steps,
          end_learning_rate=FLAGS.end_learning_rate)

      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

      if FLAGS.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      elif FLAGS.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(
            learning_rate=FLAGS.adam_learning_rate, epsilon=FLAGS.adam_epsilon)
      else:
        raise ValueError('Unknown optimizer')

    if FLAGS.quantize_delay_step >= 0:
      if FLAGS.num_clones > 1:
        raise ValueError('Quantization doesn\'t support multi-clone yet.')
      contrib_quantize.create_training_graph(
          quant_delay=FLAGS.quantize_delay_step)

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Start the training.
    profile_dir = FLAGS.profile_logdir
    if profile_dir is not None:
      tf.gfile.MakeDirs(profile_dir)

    with contrib_tfprof.ProfileContext(
        enabled=profile_dir is not None, profile_dir=profile_dir):
      init_fn = None
      if FLAGS.tf_initial_checkpoint:
        init_fn = train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True)

      slim.learning.train(
          train_tensor,
          logdir=FLAGS.train_logdir,
          log_every_n_steps=FLAGS.log_steps,
          master=FLAGS.master,
          number_of_steps=FLAGS.training_number_of_steps,
          is_chief=(FLAGS.task == 0),
          session_config=session_config,
          startup_delay_steps=startup_delay_steps,
          init_fn=init_fn,
          summary_op=summary_op,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
Exemple #38
def main(unused_argv):
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    with tf.Graph().as_default():
        with tf.device(config.inputs_device()):
            train_crop_size = (None if 0 in FLAGS.train_crop_size else
                               FLAGS.train_crop_size)
            assert FLAGS.dataset
            assert len(FLAGS.dataset) == len(FLAGS.dataset_dir)
            if len(FLAGS.first_frame_finetuning) == 1:
                first_frame_finetuning = (list(FLAGS.first_frame_finetuning) *
                                          len(FLAGS.dataset))
            else:
                first_frame_finetuning = FLAGS.first_frame_finetuning
            if len(FLAGS.three_frame_dataset) == 1:
                three_frame_dataset = (list(FLAGS.three_frame_dataset) *
                                       len(FLAGS.dataset))
            else:
                three_frame_dataset = FLAGS.three_frame_dataset
            assert len(FLAGS.dataset) == len(first_frame_finetuning)
            assert len(FLAGS.dataset) == len(three_frame_dataset)
            datasets, samples_list = zip(*[
                _get_dataset_and_samples(
                    config, train_crop_size, dataset, dataset_dir,
                    bool(first_frame_finetuning_), bool(three_frame_dataset_))
                for dataset, dataset_dir,
                first_frame_finetuning_, three_frame_dataset_ in zip(
                    FLAGS.dataset, FLAGS.dataset_dir, first_frame_finetuning,
                    three_frame_dataset)
            ])
            # Note that this way of doing things is wasteful since it will evaluate
            # all branches but just use one of them. But let's do it anyway for now,
            # since it's easy and will probably be fast enough.
            dataset = datasets[0]
            if len(samples_list) == 1:
                samples = samples_list[0]
            else:
                probabilities = FLAGS.dataset_sampling_probabilities
                if probabilities:
                    assert len(probabilities) == len(samples_list)
                else:
                    # Default to uniform probabilities.
                    probabilities = [
                        1.0 / len(samples_list) for _ in samples_list
                    ]
                probabilities = tf.constant(probabilities)
                logits = tf.log(probabilities[tf.newaxis])
                rand_idx = tf.squeeze(tf.multinomial(logits,
                                                     1,
                                                     output_dtype=tf.int32),
                                      axis=[0, 1])

                def wrap(x):
                    def f():
                        return x

                    return f

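                # Draw one dataset index per step according to `probabilities`
                # (rand_idx above), then use tf.case to route the matching
                # entry of samples_list into `samples`.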
                samples = tf.case(
                    {
                        tf.equal(rand_idx, idx): wrap(s)
                        for idx, s in enumerate(samples_list)
                    },
                    exclusive=True)

            # Prefetch_queue requires the shape to be known at graph creation time.
            # So we only use it if we crop to a fixed size.
            if train_crop_size is None:
                inputs_queue = samples
            else:
                inputs_queue = prefetch_queue.prefetch_queue(
                    samples,
                    capacity=FLAGS.prefetch_queue_capacity_factor *
                    config.num_clones,
                    num_threads=FLAGS.prefetch_queue_num_threads)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            if FLAGS.classification_loss == 'triplet':
                embedding_dim = FLAGS.embedding_dimension
                output_type_to_dim = {'embedding': embedding_dim}
            else:
                output_type_to_dim = {common.OUTPUT_TYPE: dataset.num_classes}
            model_args = (inputs_queue, output_type_to_dim,
                          dataset.ignore_label)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in tf.contrib.framework.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps, FLAGS.learning_power,
                FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   FLAGS.momentum)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            with tf.name_scope('grad_clipping'):
                grads_and_vars = slim.learning.clip_gradient_norms(
                    grads_and_vars, 5.0)

            # Create histogram summaries for the gradients.
            # We have too many summaries for mldash, so disable this one for now.
            # for grad, var in grads_and_vars:
            #   summaries.add(tf.summary.histogram(
            #       var.name.replace(':0', '_0') + '/gradient', grad))

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Start the training.
        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_logdir,
                            log_every_n_steps=FLAGS.log_steps,
                            master=FLAGS.master,
                            number_of_steps=FLAGS.training_number_of_steps,
                            is_chief=(FLAGS.task == 0),
                            session_config=session_config,
                            startup_delay_steps=startup_delay_steps,
                            init_fn=train_utils.get_model_init_fn(
                                FLAGS.train_logdir,
                                FLAGS.tf_initial_checkpoint,
                                FLAGS.initialize_last_layer,
                                last_layers,
                                ignore_missing_vars=True),
                            summary_op=summary_op,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs)
Exemple #39
def main(_):

    tf.logging.set_verbosity(tf.logging.INFO)
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks,
    )

    # Create global_step
    with tf.device(deploy_config.variables_device()):
        global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    keys_to_features = {
        "image/encoded":
        tf.FixedLenFeature((), tf.string, default_value=""),
        "image/format":
        tf.FixedLenFeature((), tf.string, default_value="png"),
        "image/class/label":
        tf.FixedLenFeature([],
                           tf.int64,
                           default_value=tf.zeros([], dtype=tf.int64)),
    }

    items_to_handlers = {
        "image": slim.tfexample_decoder.Image(),
        "label": slim.tfexample_decoder.Tensor("image/class/label"),
    }

    items_to_descs = {
        "image": "Color image",
        "label": "Class idx",
    }

    label_idx_to_name = {}
    for i, label in enumerate(CLASSES):
        label_idx_to_name[i] = label

    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features,
                                                      items_to_handlers)
    file_pattern = "tfm_clf_%s.*"
    file_pattern = os.path.join(FLAGS.records_name,
                                file_pattern % FLAGS.dataset_split_name)
    dataset = slim.dataset.Dataset(
        data_sources=file_pattern,  # TODO UPDATE
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=80000,  # TODO UPDATE
        items_to_descriptions=items_to_descs,
        num_classes=len(CLASSES),
        labels_to_names=label_idx_to_name,
    )

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True,
    )

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True,
        use_grayscale=FLAGS.use_grayscale)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            num_readers=FLAGS.num_readers,
            common_queue_capacity=20 * FLAGS.batch_size,
            common_queue_min=10 * FLAGS.batch_size,
        )
        [image, label] = provider.get(["image", "label"])
        label -= FLAGS.labels_offset

        train_image_size = FLAGS.train_image_size or network_fn.default_image_size

        image = image_preprocessing_fn(image, train_image_size,
                                       train_image_size)

        images, labels = tf.train.batch(
            [image, label],
            batch_size=FLAGS.batch_size,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size,
        )
        labels = slim.one_hot_encoding(
            labels, dataset.num_classes - FLAGS.labels_offset)
        batch_queue = slim.prefetch_queue.prefetch_queue(
            [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
        """Allows data parallelism by creating network_fn clones."""
        images, labels = batch_queue.dequeue()
        logits, end_points = network_fn(images)

        #############################
        # Specify the loss function #
        #############################
        if "AuxLogits" in end_points:
            slim.losses.softmax_cross_entropy(
                end_points["AuxLogits"],
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=0.4,
                scope="aux_loss",
            )
        slim.losses.softmax_cross_entropy(
            logits,
            labels,
            label_smoothing=FLAGS.label_smoothing,
            weights=1.0)
        return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
        x = end_points[end_point]
        summaries.add(tf.summary.histogram("activations/" + end_point, x))
        summaries.add(
            tf.summary.scalar("sparsity/" + end_point, tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
        summaries.add(tf.summary.scalar("losses/%s" % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
        summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
        moving_average_variables = slim.get_model_variables()
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay, global_step)
    else:
        moving_average_variables, variable_averages = None, None

    if FLAGS.quantize_delay >= 0:
        contrib_quantize.create_training_graph(
            quant_delay=FLAGS.quantize_delay)

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
        learning_rate = _configure_learning_rate(dataset.num_samples,
                                                 global_step)
        optimizer = _configure_optimizer(learning_rate)
        summaries.add(tf.summary.scalar("learning_rate", learning_rate))

    if FLAGS.sync_replicas:
        # If sync_replicas is enabled, the averaging will be done in the chief
        # queue runner.
        optimizer = tf.train.SyncReplicasOptimizer(
            opt=optimizer,
            replicas_to_aggregate=FLAGS.replicas_to_aggregate,
            total_num_replicas=FLAGS.worker_replicas,
            variable_averages=variable_averages,
            variables_to_average=moving_average_variables,
        )
    elif FLAGS.moving_average_decay:
        # Update ops executed locally by trainer.
        update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # Compute the total loss and the gradients across all clones.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones, optimizer, var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar("total_loss", total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name="train_op")

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name="summary_op")

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=optimizer if FLAGS.sync_replicas else None,
    )
Exemple #40
def train(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    tf.logging.set_verbosity(args.log)
    clone_on_cpu = args.gpu_id == ''
    num_clones = len(args.gpu_id.split(','))

    if args.log_root:
        if args.config is None:
            raise RuntimeError('No config json specified.')
        config_json = args.config
        with open(config_json, 'rt') as F:
            configs = json.load(F)
        hparams = Namespace(**configs)
        logdir_name = config_str.get_config_time_str(hparams, 'wavenet',
                                                     EXP_TAG)
        logdir = os.path.join(args.log_root, logdir_name)
        os.makedirs(logdir, exist_ok=True)
        shutil.copy(config_json, logdir)
    else:
        logdir = args.logdir
        config_json = glob.glob(os.path.join(logdir, '*.json'))[0]
        with open(config_json, 'rt') as F:
            configs = json.load(F)
        hparams = Namespace(**configs)

    enhance_log.add_log_file(logdir)
    if not args.log_root:
        tf.logging.info('Continue running\n\n')
    tf.logging.info('Using config from {}'.format(config_json))
    tf.logging.info('Saving to {}'.format(logdir))

    wn = wavenet.Wavenet(hparams,
                         os.path.abspath(os.path.expanduser(args.train_path)))
    wn_config_str = enhance_log.instance_attr_to_str(wn)
    tf.logging.info('\n' + wn_config_str)

    def _data_dep_init():
        # slim.learning.train runs init_fn before the queue runners are
        # started, so this function would deadlock if it used the `input_dict`
        # in L76 as its input; it therefore builds its own placeholders below.
        inputs_val = reader.get_init_batch(wn.train_path,
                                           batch_size=args.total_batch_size,
                                           seq_len=wn.wave_length)
        wave_data = inputs_val['wav']
        mel_data = inputs_val['mel']

        _inputs_dict = {
            'wav': tf.placeholder(dtype=tf.float32, shape=wave_data.shape),
            'mel': tf.placeholder(dtype=tf.float32, shape=mel_data.shape)
        }

        encode_dict = wn.encode_signal(_inputs_dict)
        _inputs_dict.update(encode_dict)
        init_ff_dict = wn.feed_forward(_inputs_dict, init=True)

        def callback(session):
            tf.logging.info('Calculate initial statistics.')
            init_out = session.run(init_ff_dict,
                                   feed_dict={
                                       _inputs_dict['wav']: wave_data,
                                       _inputs_dict['mel']: mel_data
                                   })
            init_out_params = init_out['out_params']
            if wn.loss_type == 'mol':
                _, mean, log_scale = np.split(init_out_params, 3, axis=2)
                scale = np.exp(np.maximum(log_scale, -7.0))
                _init_logging(mean, 'mean')
                _init_logging(scale, 'scale')
            elif wn.loss_type == 'gauss':
                mean, log_std = np.split(init_out_params, 2, axis=2)
                std = np.exp(np.maximum(log_std, -7.0))
                _init_logging(mean, 'mean')
                _init_logging(std, 'std')
            tf.logging.info('Done Calculate initial statistics.')

        return callback

    def _model_fn(_inputs_dict):
        encode_dict = wn.encode_signal(_inputs_dict)
        _inputs_dict.update(encode_dict)
        ff_dict = wn.feed_forward(_inputs_dict)
        ff_dict.update(encode_dict)
        loss_dict = wn.calculate_loss(ff_dict)
        loss = loss_dict['loss']
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)

    with tf.Graph().as_default():
        total_batch_size = args.total_batch_size
        assert total_batch_size % num_clones == 0
        clone_batch_size = int(total_batch_size / num_clones)

        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            num_ps_tasks=0,
            worker_job_name='localhost',
            ps_job_name='localhost')

        with tf.device(deploy_config.inputs_device()):
            inputs_dict = wn.get_batch(clone_batch_size)

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, _model_fn,
                                            [inputs_dict])
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        summaries.update(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        with tf.device(deploy_config.variables_device()):
            global_step = tf.get_variable(
                "global_step", [],
                tf.int32,
                initializer=tf.constant_initializer(0),
                trainable=False)

        with tf.device(deploy_config.optimizer_device()):
            lr = tf.constant(wn.learning_rate_schedule[0])
            for key, value in wn.learning_rate_schedule.items():
                lr = tf.cond(tf.less(global_step, key), lambda: lr,
                             lambda: tf.constant(value))
            summaries.add(tf.summary.scalar("learning_rate", lr))

            optimizer = tf.train.AdamOptimizer(lr, epsilon=1e-8)
            ema = tf.train.ExponentialMovingAverage(decay=0.9999,
                                                    num_updates=global_step)

            loss, clone_grads_vars = model_deploy.optimize_clones(
                clones, optimizer, var_list=tf.trainable_variables())
            if GRAD_CLIP:
                clone_grads_vars = grad_clip(clone_grads_vars)
            update_ops.append(
                optimizer.apply_gradients(clone_grads_vars,
                                          global_step=global_step))
            update_ops.append(ema.apply(tf.trainable_variables()))

            summaries.add(tf.summary.scalar("train_loss", loss))

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(loss, name='train_op')

        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True
        summary_op = tf.summary.merge(list(summaries), name='summary_op')
        data_dep_init_fn = _data_dep_init()
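        # slim.learning.train calls this init_fn once with the session before
        # training starts; the callback runs the init-mode forward pass and
        # logs data-dependent statistics.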

        slim.learning.train(train_tensor,
                            logdir=logdir,
                            number_of_steps=wn.num_iters,
                            summary_op=summary_op,
                            global_step=global_step,
                            log_every_n_steps=100,
                            save_summaries_secs=600,
                            save_interval_secs=3600,
                            session_config=session_config,
                            init_fn=data_dep_init_fn)
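
# A minimal sketch (not part of the example above): the chained tf.cond
# learning-rate schedule built near the optimizer can also be expressed, almost
# equivalently, with tf.train.piecewise_constant, assuming
# wn.learning_rate_schedule is a dict mapping step boundaries (including a
# 0 key) to learning rates.
import tensorflow as tf

def _piecewise_lr(global_step, schedule):
    boundaries = sorted(schedule.keys())        # e.g. [0, 50000, 100000]
    values = [schedule[b] for b in boundaries]  # rate in effect from each boundary
    # piecewise_constant expects len(values) == len(boundaries) + 1, so drop the
    # 0 boundary; its rate is the value used before the first real boundary.
    return tf.train.piecewise_constant(global_step, boundaries[1:], values)
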
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpus
  if FLAGS.num_clones == -1:
    FLAGS.num_clones = len(FLAGS.gpus.split(','))

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    # tf.set_random_seed(42)
    tf.set_random_seed(0)
    ######################
    # Config model_deploy#
    ######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name,
        FLAGS.dataset_dir.split(','),
        dataset_list_dir=FLAGS.dataset_list_dir,
        num_samples=FLAGS.frames_per_video,
        modality=FLAGS.modality,
        split_id=FLAGS.split_id)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        batch_size=FLAGS.batch_size,
        weight_decay=FLAGS.weight_decay,
        is_training=True,
        dropout_keep_prob=(1.0-FLAGS.dropout),
        pooled_dropout_keep_prob=(1.0-FLAGS.pooled_dropout),
        batch_norm=FLAGS.netvlad_batch_norm)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True)  # in the case of pooled images, preprocessing is
                           # now done at the video level

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=FLAGS.num_readers,
        common_queue_capacity=20 * FLAGS.batch_size,
        common_queue_min=10 * FLAGS.batch_size,
        bgr_flips=FLAGS.bgr_flip)
      [image, label] = provider.get(['image', 'label'])
      # Note that the above image might be a 23-channel image if you have both
      # RGB and flow streams. It will need to be split later, but all the
      # preprocessing is applied consistently to all frames across all streams.
      label = tf.string_to_number(label, tf.int32)
      label.set_shape(())
      label -= FLAGS.labels_offset

      train_image_size = FLAGS.train_image_size or network_fn.default_image_size

      scale_ratios = [float(el) for el in FLAGS.scale_ratios.split(',')]
      image = image_preprocessing_fn(image, train_image_size,
                                     train_image_size,
                                     scale_ratios=scale_ratios,
                                     out_dim_scale=FLAGS.out_dim_scale,
                                     model_name=FLAGS.model_name)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)
      if FLAGS.debug:
        images = tf.Print(images, [labels], 'Read batch')
      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - FLAGS.labels_offset)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)
      summarize_images(images, provider.num_channels_stream)

    ####################
    # Define the model #
    ####################
    kwargs = {}
    if FLAGS.conv_endpoint is not None:
      kwargs['conv_endpoint'] = FLAGS.conv_endpoint
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(
          images, pool_type=FLAGS.pooling,
          classifier_type=FLAGS.classifier_type,
          num_channels_stream=provider.num_channels_stream,
          netvlad_centers=FLAGS.netvlad_initCenters.split(','),
          stream_pool_type=FLAGS.stream_pool_type,
          **kwargs)

      #############################
      # Specify the loss function #
      #############################
      if 'AuxLogits' in end_points:
        slim.losses.softmax_cross_entropy(
            end_points['AuxLogits'], labels,
            label_smoothing=FLAGS.label_smoothing, weight=0.4, scope='aux_loss')
      slim.losses.softmax_cross_entropy(
          logits, labels, label_smoothing=FLAGS.label_smoothing, weight=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    global end_points_debug
    end_points = clones[0].outputs
    end_points_debug = dict(end_points)
    end_points_debug['images'] = images
    end_points_debug['labels'] = labels
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.histogram_summary('activations/' + end_point, x))
      summaries.add(tf.scalar_summary('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.scalar_summary('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.histogram_summary(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.scalar_summary('learning_rate', learning_rate,
                                      name='learning_rate'))

    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables,
          replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
          total_num_replicas=FLAGS.worker_replicas)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()
    logging.info('Training the following variables: %s' % (
      ' '.join([el.name for el in variables_to_train])))

    # optimize_clones() computes the total loss and the clones' gradients.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)

    # clip the gradients if needed
    if FLAGS.clip_gradients > 0:
      logging.info('Clipping gradients by %f' % FLAGS.clip_gradients)
      with tf.name_scope('clip_gradients'):
        clones_gradients = slim.learning.clip_gradient_norms(
            clones_gradients,
            FLAGS.clip_gradients)

    # Add total_loss to summary.
    summaries.add(tf.scalar_summary('total_loss', total_loss,
                                    name='total_loss'))

    # Create gradient updates.
    train_ops = {}
    if FLAGS.iter_size == 1:
      grad_updates = optimizer.apply_gradients(clones_gradients,
                                               global_step=global_step)
      update_ops.append(grad_updates)

      update_op = tf.group(*update_ops)
      train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                        name='train_op')
      train_ops = train_tensor
    else:
      gvs = [(grad, var) for grad, var in clones_gradients]
      varnames = [var.name for grad, var in gvs]
      varname_to_var = {var.name: var for grad, var in gvs}
      varname_to_grad = {var.name: grad for grad, var in gvs}
      varname_to_ref_grad = {}
      for vn in varnames:
        grad = varname_to_grad[vn]
        print("accumulating ... ", (vn, grad.get_shape()))
        with tf.variable_scope("ref_grad"):
          with tf.device(deploy_config.variables_device()):
            ref_var = slim.local_variable(
                np.zeros(grad.get_shape(),dtype=np.float32),
                name=vn[:-2])
            varname_to_ref_grad[vn] = ref_var

      all_assign_ref_op = [
          ref.assign(varname_to_grad[vn])
          for vn, ref in varname_to_ref_grad.items()]
      all_assign_add_ref_op = [
          ref.assign_add(varname_to_grad[vn])
          for vn, ref in varname_to_ref_grad.items()]
      assign_gradients_ref_op = tf.group(*all_assign_ref_op)
      accumulate_gradients_op = tf.group(*all_assign_add_ref_op)
      with tf.control_dependencies([accumulate_gradients_op]):
        final_gvs = [(varname_to_ref_grad[var.name] / float(FLAGS.iter_size), var)
                     for grad, var in gvs]
        apply_gradients_op = optimizer.apply_gradients(final_gvs, global_step=global_step)
        update_ops.append(apply_gradients_op)
        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
            total_loss, name='train_op')
      for i in range(FLAGS.iter_size):
        if i == 0:
          train_ops[i] = assign_gradients_ref_op
        elif i < FLAGS.iter_size - 1:  # apply_gradients also accumulates once
                                       # (via the control dependency), so the
                                       # last iteration needs no extra step
          train_ops[i] = accumulate_gradients_op
        else:
          train_ops[i] = train_tensor
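      # train_ops now maps micro-step index -> op: index 0 overwrites the
      # gradient accumulators with fresh gradients, indices 1..iter_size-2 add
      # to them, and the last index adds once more (via the control dependency
      # above) before applying the accumulated gradients divided by iter_size.
      # The custom train_step passed to slim.learning.train below presumably
      # cycles through these ops to emulate a batch iter_size times larger.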


    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.merge_summary(list(summaries), name='summary_op')

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.intra_op_parallelism_threads = FLAGS.cpu_threads
    # config.allow_soft_placement = True
    # config.gpu_options.per_process_gpu_memory_fraction=0.7

    ###########################
    # Kicks off the training. #
    ###########################
    logging.info('RUNNING ON SPLIT %d' % FLAGS.split_id)
    slim.learning.train(
        train_ops,
        train_step_fn=train_step,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=optimizer if FLAGS.sync_replicas else None,
        session_config=config)
Exemple #42
0
def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):
    """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    graph_hook_fn: Optional function that is called after the training graph is
      completely built. This is helpful to perform additional changes to the
      training graph such as optimizing batchnorm. The function should modify
      the default graph.
  """

    detection_model = create_model_fn()

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            input_queue = create_input_queue(create_tensor_dict_fn)

        # Gather initial summaries.
        # TODO(rathodv): See if summaries can be added/extracted from global tf
        # collections so that they don't have to be passed around.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn,
                                     train_config=train_config)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)
            for var in optimizer_summary_vars:
                tf.summary.scalar(var.op.name, var)

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=train_config.worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            restore_checkpoints = [
                path.strip()
                for path in train_config.fine_tune_checkpoint.split(',')
            ]

            restorers = get_restore_checkpoint_ops(restore_checkpoints,
                                                   detection_model,
                                                   train_config)

            def initializer_fn(sess):
                for i, restorer in enumerate(restorers):
                    restorer.restore(sess, restore_checkpoints[i])

            init_fn = initializer_fn

        with tf.device(deploy_config.optimizer_device()):
            regularization_losses = (
                None if train_config.add_regularization_loss else [])
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones,
                training_optimizer,
                regularization_losses=regularization_losses)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                0.9999, global_step)
            update_ops.append(
                variable_averages.apply(moving_average_variables))

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        if graph_hook_fn:
            with tf.device(deploy_config.variables_device()):
                graph_hook_fn()

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, 'critic_loss'))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
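
# A hypothetical invocation sketch (build_input_fn, build_model_fn and
# train_config_proto are placeholder names, not defined in the example above):
# train() expects zero-argument builder functions plus the clone/replica
# topology described in its docstring.
#
#     train(create_tensor_dict_fn=build_input_fn,   # returns the input tensor dict
#           create_model_fn=build_model_fn,         # returns a DetectionModel
#           train_config=train_config_proto,        # a train_pb2.TrainConfig
#           master='', task=0, num_clones=1, worker_replicas=1,
#           clone_on_cpu=False, ps_tasks=0, worker_job_name='worker',
#           is_chief=True, train_dir='/tmp/train')
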
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=FLAGS.num_readers,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size)
      [image, label] = provider.get(['image', 'label'])
      label -= FLAGS.labels_offset

      train_image_size = FLAGS.train_image_size or network_fn.default_image_size

      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)
      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - FLAGS.labels_offset)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      with tf.device(deploy_config.inputs_device()):
        images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
      if 'AuxLogits' in end_points:
        tf.losses.softmax_cross_entropy(
            logits=end_points['AuxLogits'], onehot_labels=labels,
            label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
      tf.losses.softmax_cross_entropy(
          logits=logits, onehot_labels=labels,
          label_smoothing=FLAGS.label_smoothing, weights=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram('activations/' + end_point, x))
      summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables,
          replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
          total_num_replicas=FLAGS.worker_replicas)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # optimize_clones() computes the total loss and the clones' gradients.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=optimizer if FLAGS.sync_replicas else None)
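
# A minimal sketch of the kind of helper referenced above as
# _get_variables_to_train (the real helper is defined elsewhere and may differ;
# FLAGS.trainable_scopes is an assumed flag): restrict training to variables
# under a comma-separated list of scopes, or train everything when unset.
def _get_variables_to_train_sketch():
  if FLAGS.trainable_scopes is None:
    return tf.trainable_variables()
  scopes = [scope.strip() for scope in FLAGS.trainable_scopes.split(',')]
  variables_to_train = []
  for scope in scopes:
    variables_to_train.extend(
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope))
  return variables_to_train
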
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=1,
        clone_on_cpu=False,
        replica_id=0,
        num_replicas=1,
        num_ps_tasks=0)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        'flowers', 'train', FLAGS.dataset_dir)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        'mobilenet_v1',
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True)

    #####################################
    # Select the preprocessing function #
    #####################################
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        'mobilenet_v1',
        is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=4,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size)
      [image, label] = provider.get(['image', 'label'])
      label -= FLAGS.labels_offset

      train_image_size = network_fn.default_image_size

      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=4,
          capacity=5 * FLAGS.batch_size)
      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - FLAGS.labels_offset)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
      slim.losses.softmax_cross_entropy(
          logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram('activations/' + end_point, x))
      summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):

      num_epochs_per_decay = 2.5
      decay_steps = int(dataset.num_samples / FLAGS.batch_size *
                        num_epochs_per_decay)
      learning_rate = tf.train.exponential_decay(
          FLAGS.learning_rate,
          global_step,
          decay_steps,
          _LEARNING_RATE_DECAY_FACTOR,
          staircase=True,
          name='exponential_decay_learning_rate')

      optimizer = tf.train.RMSPropOptimizer(
          learning_rate,
          decay=FLAGS.rmsprop_decay,
          momentum=FLAGS.rmsprop_momentum,
          epsilon=FLAGS.opt_epsilon)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # optimize_clones() computes the total loss and the clones' gradients.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=True,
        session_config=session_config,
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=10,
        save_summaries_secs=300,
        save_interval_secs=300,
        sync_optimizer=None)
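
# For scale, a quick check of the decay schedule above, assuming the standard
# slim flowers training split (3320 images) and a batch size of 32:
#   steps_per_epoch = 3320 / 32      ~= 103.75
#   decay_steps     = 103.75 * 2.5   ~= 259
# i.e. the learning rate is multiplied by _LEARNING_RATE_DECAY_FACTOR roughly
# every 259 steps (about every 2.5 epochs).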