Example #1
  def testCreateOnecloneWithPS(self):
    g = tf.Graph()
    with g.as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      model_fn = BatchNormClassifier
      model_args = (tf_inputs, tf_labels)
      deploy_config = model_deploy.DeploymentConfig(num_clones=1,
                                                    num_ps_tasks=1)

      self.assertEqual(slim.get_variables(), [])
      clones = model_deploy.create_clones(deploy_config, model_fn, model_args)
      self.assertEqual(len(slim.get_variables()), 5)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      self.assertEqual(len(update_ops), 2)

      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
      total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
                                                                optimizer)
      self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
      self.assertEqual(total_loss.op.name, 'total_loss')
      for g, v in grads_and_vars:
        self.assertDeviceEqual(g.device, '/job:worker/device:GPU:0')
        self.assertDeviceEqual(v.device, '/job:ps/task:0/CPU:0')
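A minimal sketch of my own (assuming the same TF-Slim model_deploy module as above; not part of the test) showing how such a one-clone, one-parameter-server DeploymentConfig is queried for device placement, which is the pattern the training scripts below rely on:

  # Sketch only: query a one-clone, one-PS DeploymentConfig for placement.
  config = model_deploy.DeploymentConfig(num_clones=1, num_ps_tasks=1)

  with tf.device(config.variables_device()):
    # Variables (e.g. the global step) are pinned to the parameter server.
    global_step = tf.train.get_or_create_global_step()

  with tf.device(config.clone_device(0)):
    # Model ops for clone 0 land on the worker GPU,
    # '/job:worker/device:GPU:0' for this config (as the test asserts).
    dummy = tf.zeros([1, 10])  # stand-in for real model ops

  print(config.optimizer_device())  # device on which gradients are applied
  print(config.inputs_device())     # device for the input pipeline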
Example #2
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')
    # Sets the threshold for what messages will be logged. (DEBUG / INFO / WARN / ERROR / FATAL)
    tf.logging.set_verbosity(tf.logging.DEBUG)

    with tf.Graph().as_default():
        # Configure model_deploy. Keeps the TF-Slim Models structure.
        # Useful if you want to use multiple GPUs and/or servers in the future.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=0,
            num_replicas=1,
            num_ps_tasks=0)

        # Create global_step, the training iteration counter.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        # Select the dataset.
        dataset = TFrecords2Dataset.get_datasets(FLAGS.dataset_dir)

        # Get the TextBoxes++ network and its anchors.
        text_net = txtbox_384.TextboxNet()

        # Stage 2 training using the 768x768 input size.
        if FLAGS.large_training:
            # Replace the input image shape and the feature map sizes extracted
            # from the layers associated with each textbox layer.
            text_net.params = text_net.params._replace(img_shape=(768, 768))
            text_net.params = text_net.params._replace(
                feat_shapes=[(96, 96), (48, 48), (24, 24), (12, 12), (10, 10),
                             (8, 8)])

        img_shape = text_net.params.img_shape
        print('img_shape: ' + str(img_shape))

        # Compute the default anchor boxes with the given image shape, get anchor list.
        text_anchors = text_net.anchors(img_shape)

        # Print the training configuration before training.
        tf_utils.print_configuration(FLAGS.__flags, text_net.params,
                                     dataset.data_sources, FLAGS.train_dir)

        # =================================================================== #
        # Create a dataset provider and batches.
        # =================================================================== #
        with tf.device(deploy_config.inputs_device()):
            # Set up the dataset provider.
            with tf.name_scope(FLAGS.dataset_name + '_data_provider'):
                provider = slim.dataset_data_provider.DatasetDataProvider(
                    dataset,
                    num_readers=FLAGS.num_readers,
                    common_queue_capacity=1000 * FLAGS.batch_size,
                    common_queue_min=300 * FLAGS.batch_size,
                    shuffle=True)
            # Get image, labels and bboxes for the SSD/TextBoxes network.
            [image, shape, glabels, gbboxes, x1, x2, x3, x4, y1, y2, y3,
             y4] = provider.get([
                 'image', 'shape', 'object/label', 'object/bbox',
                 'object/oriented_bbox/x1', 'object/oriented_bbox/x2',
                 'object/oriented_bbox/x3', 'object/oriented_bbox/x4',
                 'object/oriented_bbox/y1', 'object/oriented_bbox/y2',
                 'object/oriented_bbox/y3', 'object/oriented_bbox/y4'
             ])
            gxs = tf.transpose(tf.stack([x1, x2, x3, x4]))  #shape = (N,4)
            gys = tf.transpose(tf.stack([y1, y2, y3, y4]))
            image = tf.identity(image, 'input_image')
            init_op = tf.global_variables_initializer()

            # Pre-processing image, labels and bboxes.
            training_image_crop_area = FLAGS.training_image_crop_area
            area_split = training_image_crop_area.split(',')
            assert len(area_split) == 2
            training_image_crop_area = [
                float(area_split[0]),
                float(area_split[1])
            ]

            image, glabels, gbboxes, gxs, gys = \
                ssd_vgg_preprocessing.preprocess_for_train(
                    image, glabels, gbboxes, gxs, gys, img_shape,
                    data_format='NHWC',
                    crop_area_range=training_image_crop_area)

            image = tf.identity(image, 'processed_image')

            # Encode groundtruth labels and bboxes.
            glocalisations, gscores, glabels = \
                text_net.bboxes_encode(glabels, gbboxes, text_anchors, gxs, gys)
            batch_shape = [1] + [len(text_anchors)] * 3

            # Training batches and queue.
            r = tf.train.batch(
                tf_utils.reshape_list([image, glocalisations, gscores, glabels]),
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)

            b_image, b_glocalisations, b_gscores, b_glabels = \
                tf_utils.reshape_list(r, batch_shape)

            # Intermediate queueing: unique batch computation pipeline for all
            # GPUs running the training.
            batch_queue = slim.prefetch_queue.prefetch_queue(
                tf_utils.reshape_list(
                    [b_image, b_glocalisations, b_gscores, b_glabels]),
                capacity=2 * deploy_config.num_clones)

        # =================================================================== #
        # Define the model running on every GPU.
        # =================================================================== #
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            # Dequeue batch.
            b_image, b_glocalisations, b_gscores, b_glabels = \
                tf_utils.reshape_list(batch_queue.dequeue(), batch_shape)

            # Construct TextBoxes network.
            arg_scope = text_net.arg_scope(weight_decay=FLAGS.weight_decay)
            with slim.arg_scope(arg_scope):
                predictions, localisations, logits, end_points = \
                    text_net.net(b_image, is_training=True)
            # Add loss function.

            text_net.losses(logits,
                            localisations,
                            b_glabels,
                            b_glocalisations,
                            b_gscores,
                            match_threshold=FLAGS.match_threshold,
                            negative_ratio=FLAGS.negative_ratio,
                            alpha=FLAGS.loss_alpha,
                            label_smoothing=FLAGS.label_smoothing,
                            batch_size=FLAGS.batch_size)
            return end_points

        # Gather initial tensorboard summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # =================================================================== #
        # Add summaries from first clone.
        # =================================================================== #
        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))
        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES):
            summaries.add(tf.summary.scalar(loss.op.name, loss))
        # Add summaries for extra losses.
        for loss in tf.get_collection('EXTRA_LOSSES'):
            summaries.add(tf.summary.scalar(loss.op.name, loss))
        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        # =================================================================== #
        # Configure the moving averages.
        # =================================================================== #
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        # =================================================================== #
        # Configure the optimization procedure.
        # =================================================================== #
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = tf_utils.configure_learning_rate(
                FLAGS, dataset.num_samples, global_step)
            optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
            # Add summaries for learning_rate.
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = tf_utils.get_variables_to_train(FLAGS)

        # Optimize over the clones; returns the total loss and the clones' gradients.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # =================================================================== #
        # Kicks off the training.
        # =================================================================== #
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)

        config = tf.ConfigProto(log_device_placement=False,
                                allow_soft_placement=True,
                                gpu_options=gpu_options)

        saver = tf.train.Saver(max_to_keep=100,
                               keep_checkpoint_every_n_hours=1.0,
                               write_version=2,
                               pad_step_number=False)

        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master='',
            is_chief=True,
            # init_op=init_op,
            init_fn=tf_utils.get_init_fn(FLAGS),
            summary_op=summary_op,  # write summaries to logdir
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            saver=saver,
            save_interval_secs=FLAGS.save_interval_secs,
            session_config=config,
            sync_optimizer=None)
Example #3
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################

        def calculate_pooling_center_loss(features, label, alfa, nrof_classes,
                                          weights, name):
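            """Center loss over pooled features.

            Keeps one non-trainable center per class, moves each center toward
            the batch features of its class with momentum `alfa`, and returns
            the mean squared Euclidean distance between the features and their
            L2-normalized class centers, scaled by `weights`.
            """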
            features = tf.reshape(features, [features.shape[0], -1])
            label = tf.argmax(label, 1)

            nrof_features = features.get_shape()[1]
            centers = tf.get_variable(name, [nrof_classes, nrof_features],
                                      dtype=tf.float32,
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)
            label = tf.reshape(label, [-1])
            centers_batch = tf.gather(centers, label)
            centers_batch = tf.nn.l2_normalize(centers_batch, axis=-1)

            diff = (1 - alfa) * (centers_batch - features)
            centers = tf.scatter_sub(centers, label, diff)

            with tf.control_dependencies([centers]):
                distance = tf.square(features - centers_batch)
                distance = tf.reduce_sum(distance, axis=-1)
                center_loss = tf.reduce_mean(distance)

            center_loss = tf.identity(center_loss * weights,
                                      name=name + '_loss')
            return center_loss

        def attention_crop(attention_maps):
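            """Samples one attention part per image and returns a crop box.

            Parts are sampled with probability proportional to the square root
            of their mean activation; the box is the thresholded region of the
            chosen map, padded by 0.1, in normalized [ymin, xmin, ymax, xmax]
            coordinates, returned as a (batch_size, 4) float32 array.
            """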
            batch_size, height, width, num_parts = attention_maps.shape
            bboxes = []
            for i in range(batch_size):
                attention_map = attention_maps[i]
                part_weights = attention_map.mean(axis=0).mean(axis=0)
                part_weights = np.sqrt(part_weights)
                part_weights = part_weights / (np.sum(part_weights) + 1e-6)
                selected_index = np.random.choice(np.arange(0, num_parts),
                                                  1,
                                                  p=part_weights)[0]

                mask = attention_map[:, :, selected_index]

                threshold = random.uniform(0.4, 0.6)
                itemindex = np.where(mask >= mask.max() * threshold)

                ymin = itemindex[0].min() / height - 0.1
                ymax = itemindex[0].max() / height + 0.1
                xmin = itemindex[1].min() / width - 0.1
                xmax = itemindex[1].max() / width + 0.1

                bbox = np.asarray([ymin, xmin, ymax, xmax], dtype=np.float32)
                bboxes.append(bbox)
            bboxes = np.asarray(bboxes, np.float32)
            return bboxes

        def attention_drop(attention_maps):
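            """Samples one attention part per image and masks out its region.

            Returns a binary (batch, height, width, 1) mask that is zero where
            the chosen attention map exceeds a random fraction of its maximum,
            so the most attended region is dropped when the mask multiplies the
            image.
            """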
            batch_size, height, width, num_parts = attention_maps.shape
            masks = []
            for i in range(batch_size):
                attention_map = attention_maps[i]
                part_weights = attention_map.mean(axis=0).mean(axis=0)
                part_weights = np.sqrt(part_weights)
                part_weights = part_weights / np.sum(part_weights)
                selected_index = np.random.choice(np.arange(0, num_parts),
                                                  1,
                                                  p=part_weights)[0]
                mask = attention_map[:, :, selected_index:selected_index + 1]

                # Binary drop mask: zero out the highly attended region.
                threshold = random.uniform(0.2, 0.5)
                mask = (mask < threshold * mask.max()).astype(np.float32)
                masks.append(mask)
            masks = np.asarray(masks, dtype=np.float32)
            return masks

        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits_1, end_points_1 = network_fn(images)

            attention_maps = end_points_1['attention_maps']
            attention_maps = tf.image.resize_bilinear(
                attention_maps, [train_image_size, train_image_size])

            # attention crop
            bboxes = tf.py_func(attention_crop, [attention_maps], [tf.float32])
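            # tf.py_func outputs lose their static shape; restore it explicitly.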
            bboxes = tf.reshape(bboxes, [FLAGS.batch_size, 4])
            box_ind = tf.range(FLAGS.batch_size, dtype=tf.int32)
            images_crop = tf.image.crop_and_resize(
                images,
                bboxes,
                box_ind,
                crop_size=[train_image_size, train_image_size])

            # attention drop
            masks = tf.py_func(attention_drop, [attention_maps], [tf.float32])
            masks = tf.reshape(
                masks,
                [FLAGS.batch_size, train_image_size, train_image_size, 1])
            images_drop = images * masks

            logits_2, end_points_2 = network_fn(images_crop, reuse=True)
            logits_3, end_points_3 = network_fn(images_drop, reuse=True)

            slim.losses.softmax_cross_entropy(logits_1,
                                              labels,
                                              weights=1 / 3.0,
                                              scope='cross_entropy_1')
            slim.losses.softmax_cross_entropy(logits_2,
                                              labels,
                                              weights=1 / 3.0,
                                              scope='cross_entropy_2')
            slim.losses.softmax_cross_entropy(logits_3,
                                              labels,
                                              weights=1 / 3.0,
                                              scope='cross_entropy_3')

            embeddings = end_points_1['embeddings']
            center_loss = calculate_pooling_center_loss(
                features=embeddings,
                label=labels,
                alfa=0.95,
                nrof_classes=dataset.num_classes,
                weights=1.0,
                name='center_loss')
            slim.losses.add_loss(center_loss)
            return end_points_1

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Optimize over the clones; returns the total loss and the clones' gradients.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)
        config.gpu_options.allow_growth = True
        # config.gpu_options.per_process_gpu_memory_fraction = 0.9
        config.gpu_options.visible_device_list = FLAGS.gpus

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None,
            session_config=config)
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
    if FLAGS.balance_pos:
      dataset_pos = dataset_factory.get_dataset(
          FLAGS.dataset_name, FLAGS.dataset_split_name_pos, FLAGS.dataset_dir)
    dataset_val = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name_val, FLAGS.dataset_dir)
   # num_batches_val = int(dataset_val.num_samples / FLAGS.batch_size)
    num_batches_val = int(math.ceil(dataset_val.num_samples / float(FLAGS.batch_size)))


    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True)
    network_fn_val = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=False)
    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True)
    image_preprocessing_fn_val = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=False)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=FLAGS.num_readers,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size)
##################################################################

#      [image, label] = provider.get(['image', 'label'])
      [image, label, age, exposure, current, thickness,
       compression] = provider.get([
           'image', 'label', 'age', 'exposure', 'current', 'thickness',
           'compression'])
      label -= FLAGS.labels_offset
      if FLAGS.balance_pos:
        provider_pos = slim.dataset_data_provider.DatasetDataProvider(
            dataset_pos,
            num_readers=FLAGS.num_readers,
            common_queue_capacity=20 * FLAGS.batch_size,
            common_queue_min=10 * FLAGS.batch_size)
        [image_pos, label_pos, age_pos, exposure_pos, current_pos,
         thickness_pos, compression_pos] = provider_pos.get([
             'image', 'label', 'age', 'exposure', 'current', 'thickness',
             'compression'])
        label_pos -= FLAGS.labels_offset
      provider_val = slim.dataset_data_provider.DatasetDataProvider(
          dataset_val,
          num_readers=FLAGS.num_readers,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size)
      [image_val, label_val, age_val, exposure_val, current_val, thickness_val,
       compression_val] = provider_val.get([
           'image', 'label', 'age', 'exposure', 'current', 'thickness',
           'compression'])
      label_val -= FLAGS.labels_offset

      train_image_size = FLAGS.train_image_size or network_fn.default_image_size

      image = image_preprocessing_fn(image, train_image_size, train_image_size)
#
#      def add_parameter(im,ag):
#        im = im[:,:,0]
#        ag = ag*tf.ones(tf.shape(im),dtype=im.dtype)
#        im = tf.expand_dims(im,axis=2)
#        ag = tf.expand_dims(ag,axis=2)
#        output = tf.concat([im,im],axis=2)
#        output = tf.concat([output,ag],axis=2)
#        return output

      def add_parameter(im, param):
        param = param * tf.ones(tf.shape(im[:, :, 0]), dtype=im.dtype)
        param = tf.expand_dims(param, axis=2)
        output = tf.concat([im, param], axis=2)
        return output
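        """Broadcasts a scalar acquisition parameter to an extra image channel."""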
      
      image = add_parameter(image, age)
      image = add_parameter(image, exposure)
      image = add_parameter(image, current)
      image = add_parameter(image, thickness)
      image = add_parameter(image, compression)

      if FLAGS.balance_pos:
        image_pos = image_preprocessing_fn(image_pos, train_image_size,
                                           train_image_size)
        image_pos = add_parameter(image_pos, age_pos)
        image_pos = add_parameter(image_pos, exposure_pos)
        image_pos = add_parameter(image_pos, current_pos)
        image_pos = add_parameter(image_pos, thickness_pos)
        image_pos = add_parameter(image_pos, compression_pos)
      image_val = image_preprocessing_fn_val(image_val, train_image_size,
                                             train_image_size)
      image_val = add_parameter(image_val, age_val)
      image_val = add_parameter(image_val, exposure_val)
      image_val = add_parameter(image_val, current_val)
      image_val = add_parameter(image_val, thickness_val)
      image_val = add_parameter(image_val, compression_val)

###################################################################
      if not FLAGS.balance_pos:
        images, labels = tf.train.batch(
            [image, label],
            batch_size=FLAGS.batch_size,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size)
        labels = slim.one_hot_encoding(
            labels, dataset.num_classes - FLAGS.labels_offset)
        batch_queue = slim.prefetch_queue.prefetch_queue(
            [images, labels], capacity=2 * deploy_config.num_clones)
        batch_queue_pos = None
      else:
        images, labels = tf.train.batch(
            [image, label],
            batch_size=int(FLAGS.batch_size/2),
            num_threads=int(FLAGS.num_preprocessing_threads/2),
            capacity=5 * int(FLAGS.batch_size/2))
        labels = slim.one_hot_encoding(
            labels, dataset.num_classes - FLAGS.labels_offset)
        batch_queue = slim.prefetch_queue.prefetch_queue(
            [images, labels], capacity=2 * deploy_config.num_clones)
        images_pos, labels_pos = tf.train.batch(
            [image_pos, label_pos],
            batch_size=int(FLAGS.batch_size/2),
            num_threads=int(FLAGS.num_preprocessing_threads/2),
            capacity=5 * int(FLAGS.batch_size/2))
        labels_pos = slim.one_hot_encoding(
            labels_pos, dataset.num_classes - FLAGS.labels_offset)
        batch_queue_pos = slim.prefetch_queue.prefetch_queue(
            [images_pos, labels_pos], capacity=2 * deploy_config.num_clones)
      images_val, labels_val = tf.train.batch(
          [image_val, label_val],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)
      batch_queue_val = slim.prefetch_queue.prefetch_queue(
          [images_val, labels_val], capacity=2 * deploy_config.num_clones)


    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue, batch_queue_pos):
      """Allows data parallelism by creating multiple clones of network_fn."""
      images, labels = batch_queue.dequeue()
      if batch_queue_pos is not None:
        images_pos, labels_pos = batch_queue_pos.dequeue()
        images = tf.concat([images_pos, images], axis=0, name='image_pos_neg_stack')
        labels = tf.concat([labels_pos, labels], axis=0, name='label_pos_neg_stack')
      logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
      if 'AuxLogits' in end_points:
        slim.losses.softmax_cross_entropy(
            end_points['AuxLogits'], labels,
            label_smoothing=FLAGS.label_smoothing, weights=0.4,
            scope='aux_loss')
      slim.losses.softmax_cross_entropy(
          logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue, batch_queue_pos])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram('activations/' + end_point, x))
      summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Get the logits for the validation set.
    with tf.name_scope('validation') as val_scope:
      device = deploy_config.clone_device(0)
      with tf.device(device):
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
          images_val, labels_val = batch_queue_val.dequeue()
          # logits_val, end_points_val = network_fn(images_val)
          logits_val, end_points_val = network_fn_val(images_val)
          scs_val = tf.nn.softmax(logits_val)
          predictions_val = tf.argmax(logits_val, 1)

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          total_num_replicas=FLAGS.worker_replicas,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # Optimize over the clones; returns the total loss and the clones' gradients.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    ##############################################
    # Define train step function with validation #
    ##############################################
    def train_step_fn(session, *args, **kwargs):
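      """Runs one training step and periodically evaluates the validation split.

      Every FLAGS.eval_every_n_steps steps it runs the validation queue for
      num_batches_val batches, accumulates TP/FP/TN/FN and softmax scores, and
      logs accuracy and ROC AUC (also appended to eval.txt in the train dir).
      """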
      def softmax(x):
        """Compute softmax values for each sets of scores in x."""
        x=np.transpose(x);
        return np.transpose(np.exp(x) / np.sum(np.exp(x), axis=0))
      total_loss, should_stop = train_step(session, *args, **kwargs)

      if train_step_fn.step % FLAGS.eval_every_n_steps == 0:
        start = time.time()
        tp = 0
        fp = 0
        tn = 0
        fn = 0
        y_true = []
        y_scores = []
        for eval_iter in range(num_batches_val):
#          pred, lab, logis = session.run([train_step_fn.predictions_val, train_step_fn.labels_val, train_step_fn.logits_val])
          pred, lab, logis, scs = session.run([
              train_step_fn.predictions_val, train_step_fn.labels_val,
              train_step_fn.logits_val, train_step_fn.scs_val
          ])
#          scs = softmax(logis);
          scs = scs[:, 1]

          y_true = np.append(y_true, lab)
          y_scores = np.append(y_scores, scs)
          pos = np.equal(lab, 1)
          neg = np.invert(pos)
          pred_pos = np.equal(pred, 1)
          pred_neg = np.invert(pred_pos)
          tp += np.sum(np.multiply(pos, pred_pos))
          tn += np.sum(np.multiply(neg, pred_neg))
          fp += np.sum(np.multiply(neg, pred_pos))
          fn += np.sum(np.multiply(pos, pred_neg))
        auc = roc_auc_score(y_true, y_scores)

        # Use float division so accuracy is not truncated by integer division.
        accuracy = float(tp + tn) / (tp + tn + fp + fn)
        logging.info(
            'Step %s - Loss: %.2f Accuracy: %.2f%%, TP: %d, FP: %d, TN: %d, '
            'FN: %d, AUC: %.2f%%', str(train_step_fn.step).rjust(6, '0'),
            total_loss, accuracy * 100, tp, fp, tn, fn, auc * 100)
        of.write(
            'Step %s - Loss: %.2f Accuracy: %.2f%%, TP: %d, FP: %d, TN: %d, '
            'FN: %d, AUC: %.2f%% \n' % (str(train_step_fn.step).rjust(6, '0'),
                                        total_loss, accuracy * 100, tp, fp, tn,
                                        fn, auc * 100))

        end = time.time()
        logging.info('Took %f seconds for evaluation', end - start)
      train_step_fn.step += 1
      return [total_loss, should_stop]

    output_file_path = os.path.join(FLAGS.train_dir, 'eval.txt')
    of = open(output_file_path, 'a')

    train_step_fn.step = 0
    train_step_fn.predictions_val = predictions_val
    train_step_fn.labels_val = labels_val
    train_step_fn.logits_val = logits_val
    train_step_fn.scs_val = scs_val 

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        saver=tf.train.Saver(max_to_keep=20),
        sync_optimizer=optimizer if FLAGS.sync_replicas else None,
        train_step_fn=train_step_fn)

    of.close()
def main(_):
    if not FLAGS.dataset_dir_joint:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir_joint')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset_joint = dataset_factory.get_dataset(FLAGS.dataset_name_joint,
                                                    FLAGS.dataset_split_name,
                                                    FLAGS.dataset_dir_joint)

        ######################
        # Select the network #
        ######################

        #  network_fn_iris = nets_factory.get_network_fn(
        #     FLAGS.model_name_iris,
        #    num_classes=(dataset.num_classes - FLAGS.labels_offset),
        #    weight_decay=FLAGS.weight_decay,
        #   is_training=True)

        network_fn_joint = nets_factory.get_network_fn_joint(
            FLAGS.model_name_joint,
            num_classes=(dataset_joint.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name_iris = FLAGS.preprocessing_name_iris or FLAGS.model_name_iris
        image_preprocessing_fn_iris = preprocessing_factory.get_preprocessing(
            preprocessing_name_iris, is_training=True)

        preprocessing_name_face = FLAGS.preprocessing_name_face or FLAGS.model_name_face
        image_preprocessing_fn_face = preprocessing_factory.get_preprocessing(
            preprocessing_name_face, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider_joint = slim.dataset_data_provider.DatasetDataProvider(
                dataset_joint,
                shuffle=False,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image_iris, label_iris, image_face,
             label_face] = provider_joint.get(
                 ['image_iris', 'label_iris', 'image_face', 'label_face'])
            label_iris -= FLAGS.labels_offset
            label_face -= FLAGS.labels_offset

            #	train_image_size_iris = FLAGS.train_image_size_iris or network_fn_iris.default_image_size
            new_height_iris = FLAGS.New_Height_Of_Image_iris or network_fn_joint.default_image_size
            new_width_iris = FLAGS.New_Width_Of_Image_iris or network_fn_joint.default_image_size

            #         image = image_preprocessing_fn(image, train_image_size, train_image_size)
            image_iris = image_preprocessing_fn_iris(image_iris,
                                                     new_height_iris,
                                                     new_width_iris)

            new_height_face = FLAGS.New_Height_Of_Image_face or network_fn_joint.default_image_size
            new_width_face = FLAGS.New_Width_Of_Image_face or network_fn_joint.default_image_size

            #         image = image_preprocessing_fn(image, train_image_size, train_image_size)
            image_face = image_preprocessing_fn_face(image_face,
                                                     new_height_face,
                                                     new_width_face)

            #  io.imshow(image)
            #  io.show()
            images_iris, labels_iris, images_face, labels_face = tf.train.batch(
                [image_iris, label_iris, image_face, label_face],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                shapes=[(64, 512, 3), (), (224, 224, 3), ()],
                capacity=5 * FLAGS.batch_size)
            #      tf.image_summary('images', images)
            labels_iris = slim.one_hot_encoding(
                labels_iris, dataset_joint.num_classes - FLAGS.labels_offset)
            labels_face = slim.one_hot_encoding(
                labels_face, dataset_joint.num_classes - FLAGS.labels_offset)
            batch_queue_joint = slim.prefetch_queue.prefetch_queue(
                [images_iris, labels_iris, images_face, labels_face],
                capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################

        def clone_fn(batch_queue_joint):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images_iris, labels_iris, images_face, labels_face = \
                batch_queue_joint.dequeue()
            logits, end_points = network_fn_joint(images_face, images_iris)

            #  def clone_fn_face(batch_queue_face):
            #      """Allows data parallelism by creating multiple clones of network_fn."""
            #    images_face, labels_face = batch_queue_face.dequeue()
            #    logits_face, end_points_face, features_face,model_var_face = network_fn_face(images_face)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels_face,
                    label_smoothing=FLAGS.label_smoothing,
                    weight=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels_face,
                label_smoothing=FLAGS.label_smoothing,
                weight=1.0)

            # Adding the accuracy metric
            with tf.name_scope('accuracy'):
                predictions = tf.argmax(logits, 1)
                labels_face = tf.argmax(labels_face, 1)
                accuracy = tf.reduce_mean(
                    tf.to_float(tf.equal(predictions, labels_face)))
                tf.add_to_collection('accuracy', accuracy)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue_joint])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.histogram_summary('activations/' + end_point, x))
            summaries.add(
                tf.scalar_summary('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.scalar_summary('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.histogram_summary(variable.op.name, variable))

        # Add summaries for the input images.
        summaries.add(
            tf.image_summary('face',
                             images_face,
                             max_images=15,
                             name='Face_images'))
        summaries.add(
            tf.image_summary('iris',
                             images_iris,
                             max_images=15,
                             name='Iris_images'))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset_joint.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(
                tf.scalar_summary('learning_rate',
                                  learning_rate,
                                  name='learning_rate'))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
                total_num_replicas=FLAGS.worker_replicas)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Optimize over the clones; returns the total loss and the clones' gradients.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        # # Add total_loss to summary.
        # summaries.add(tf.scalar_summary('total_loss', total_loss,
        #                                 name='total_loss'))

        # Add total_loss and accuracy to summary.
        summaries.add(
            tf.scalar_summary('eval/Total_Loss', total_loss,
                              name='total_loss'))
        accuracy = tf.get_collection('accuracy', first_clone_scope)[0]
        summaries.add(
            tf.scalar_summary('eval/Accuracy', accuracy, name='accuracy'))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.merge_summary(list(summaries), name='summary_op')

        init_iris, init_feed = _get_init_op()

        #	var_2=[v for v in tf.all_variables() if v.name == "vgg_19/conv3/conv3_3/weights:0"][0]

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=init_iris,
            init_feed_dict=init_feed,
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
        global_step = tf.train.get_or_create_global_step()

    ######################
    # Select the dataset #
    ######################
    # dataset = dataset_factory.get_dataset(
    #     FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
    dataset = ucf11.get_split(FLAGS.dataset_split_name, FLAGS.dataset_dir)
    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(FLAGS.model_name,
                                             num_classes=(dataset.num_classes -
                                                          FLAGS.labels_offset),
                                             weight_decay=FLAGS.weight_decay,
                                             is_training=True)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name, is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    train_image_size = FLAGS.train_image_size or network_fn.default_image_size

    input, label = ucf11.build_data(dataset)
    input = image_preprocessing_fn(input, train_image_size, train_image_size)

    inputs, labels = tf.train.batch(
        [input, label],
        batch_size=FLAGS.batch_size,
        num_threads=FLAGS.num_preprocessing_threads,
        capacity=5 * FLAGS.batch_size)
    labels = slim.one_hot_encoding(labels,
                                   dataset.num_classes - FLAGS.labels_offset)
    batch_queue = slim.prefetch_queue.prefetch_queue([inputs, labels],
                                                     capacity=2 *
                                                     deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
        """Allows data parallelism by creating multiple clones of network_fn."""
        inputs, labels = batch_queue.dequeue()
        images = tf.unstack(inputs, axis=1)

        # with tf.variable_scope('Inception', reuse=True):
        logits1, end_points1 = network_fn(images[0])
        logits2, end_points2 = network_fn(images[1])

        # Feature with size maximum k
        f_k_1 = [
            end_points1['Conv2d_1a_7x7'],
        ]
        f_k_2 = [
            end_points2['Conv2d_1a_7x7'],
        ]

        # Feature with size k/2
        f_k2_1 = [
            end_points1['MaxPool_2a_3x3'],
            end_points1['Conv2d_2b_1x1'],
            end_points1['Conv2d_2c_3x3'],
        ]
        f_k2_2 = [
            end_points2['MaxPool_2a_3x3'],
            end_points2['Conv2d_2b_1x1'],
            end_points2['Conv2d_2c_3x3'],
        ]

        # Feature with size k/4
        f_k4_1 = [
            end_points1['MaxPool_3a_3x3'],
            end_points1['Mixed_3b'],
            end_points1['Mixed_3c'],
        ]
        f_k4_2 = [
            end_points2['MaxPool_3a_3x3'],
            end_points2['Mixed_3b'],
            end_points2['Mixed_3c'],
        ]

        logits_off, end_point_off = off(
            f_k_1,
            f_k_2,
            f_k2_1,
            f_k2_2,
            f_k4_1,
            f_k4_2,
            num_classes=dataset.num_classes - FLAGS.labels_offset,
            resnet_model_name=FLAGS.resnet_model_name,
            resnet_weight_decay=FLAGS.resnet_weight_decay)

        logits_gen = tf.reduce_mean(tf.stack([logits1, logits2], axis=2),
                                    axis=2)

        logits = tf.multiply(logits_gen, logits_off, name='logits')

        #############################
        # Specify the loss function #
        #############################
        if FLAGS.use_off_sub_loss:
            logits_tier1 = end_point_off['logits_tier1']
            logits_tier2 = end_point_off['logits_tier2']

            tf.losses.softmax_cross_entropy(
                logits=logits_tier1,
                onehot_labels=labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=0.2,
                scope='off_tier1_loss')

            tf.losses.softmax_cross_entropy(
                logits=logits_tier2,
                onehot_labels=labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=0.4,
                scope='off_tier2_loss')

        tf.losses.softmax_cross_entropy(logits=logits,
                                        onehot_labels=labels,
                                        label_smoothing=FLAGS.label_smoothing,
                                        weights=1.0)
        return dict(end_points1, **end_point_off)

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
        x = end_points[end_point]
        summaries.add(tf.summary.histogram('activations/' + end_point, x))
        summaries.add(
            tf.summary.scalar('sparsity/' + end_point, tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
        summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
        summaries.add(tf.summary.histogram(variable.op.name, variable))

    # # Accuracy
    # predictions = tf.argmax(end_points['Logits'], axis=1)
    # truth = tf.argmax(labels, axis=1)
    # truth = tf.squeeze(truth)
    # # accuracy, accuracy_update = tf.metrics.accuracy(truth, predictions)
    # #
    # # update_ops.append(accuracy_update)
    # # summaries.add(tf.summary.scalar('Accuracy', accuracy))
    #
    # names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
    #     'Accuracy': slim.metrics.streaming_accuracy(predictions, truth),
    #     'Recall_5': slim.metrics.streaming_recall_at_k(
    #         end_points['Logits'], truth, 5),
    # })
    #
    # # Print the summaries to screen.
    # for name, value in names_to_values.items():
    #     summary_name = '%s' % name
    #     op = tf.summary.scalar(summary_name, value, collections=[])
    #     op = tf.Print(op, [value], summary_name)
    #     summaries.add(op)
    #
    # update_ops.append(names_to_updates)

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
        moving_average_variables = slim.get_model_variables()
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay, global_step)
    else:
        moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
        learning_rate = _configure_learning_rate(dataset.num_samples,
                                                 global_step)
        optimizer = _configure_optimizer(learning_rate)
        summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    if FLAGS.sync_replicas:
        # If sync_replicas is enabled, the averaging will be done in the chief
        # queue runner.
        optimizer = tf.train.SyncReplicasOptimizer(
            opt=optimizer,
            replicas_to_aggregate=FLAGS.replicas_to_aggregate,
            variable_averages=variable_averages,
            variables_to_average=moving_average_variables,
            replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
            total_num_replicas=FLAGS.worker_replicas)
    elif FLAGS.moving_average_decay:
        # Update ops executed locally by trainer.
        update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # Optimize over all clones: returns the total loss and the clone gradients.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones, optimizer, var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    train_tensor = control_flow_ops.with_dependencies([update_op],
                                                      total_loss,
                                                      name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    # Allow soft placement so ops without a GPU kernel fall back to CPU.
    session_config = tf.ConfigProto(allow_soft_placement=True)

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=optimizer if FLAGS.sync_replicas else None,
        session_config=session_config)
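
The snippet above passes init_fn=_get_init_fn() without showing that helper. A minimal, hedged sketch of the usual warm-start pattern, written as a standalone function; the argument names and the exclusion list are assumptions, not the snippet's actual settings.

import tensorflow as tf
slim = tf.contrib.slim

def get_init_fn(checkpoint_path, exclude_scopes=('global_step',)):
    """Returns a callable that slim.learning.train runs once to restore weights."""
    if not checkpoint_path:
        return None
    variables_to_restore = slim.get_variables_to_restore(
        exclude=list(exclude_scopes))
    return slim.assign_from_checkpoint_fn(checkpoint_path,
                                          variables_to_restore,
                                          ignore_missing_vars=True)
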
Example #7
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError('You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.DEBUG)
    with tf.Graph().as_default():
        # Config model_deploy. Keep TF Slim Models structure.
        # Useful if we want to use multiple GPUs and/or servers in the future.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=0,
            num_replicas=1,
            num_ps_tasks=0)
        # Create global_step.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        # Select the dataset.
        dataset = dataset_factory.get_dataset(
            FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

        # Get the SSD network and its anchors.
        ssd_class = nets_factory.get_network(FLAGS.model_name)
        ssd_params = ssd_class.default_params._replace(num_classes=FLAGS.num_classes)
        ssd_net = ssd_class(ssd_params)
        ssd_shape = ssd_net.params.img_shape
        ssd_anchors = ssd_net.anchors(ssd_shape)

        # Select the preprocessing function.
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        tf_utils.print_configuration(FLAGS.__flags, ssd_params,
                                     dataset.data_sources, FLAGS.train_dir)
        # =================================================================== #
        # Create a dataset provider and batches.
        # =================================================================== #
        with tf.device(deploy_config.inputs_device()):
            with tf.name_scope(FLAGS.dataset_name + '_data_provider'):
                provider = slim.dataset_data_provider.DatasetDataProvider(
                    dataset,
                    num_readers=FLAGS.num_readers,
                    common_queue_capacity=20 * FLAGS.batch_size,
                    common_queue_min=10 * FLAGS.batch_size,
                    shuffle=True)
            # Get for SSD network: image, labels, bboxes.
            [image, shape, glabels, gbboxes] = provider.get(['image', 'shape',
                                                             'object/label',
                                                             'object/bbox'])
            # Pre-processing image, labels and bboxes.
            image, glabels, gbboxes = \
                image_preprocessing_fn(image, glabels, gbboxes,
                                       out_shape=ssd_shape,
                                       data_format=DATA_FORMAT)
            # Encode groundtruth labels and bboxes.
            gclasses, glocalisations, gscores = \
                ssd_net.bboxes_encode(glabels, gbboxes, ssd_anchors)
            batch_shape = [1] + [len(ssd_anchors)] * 3

            # Training batches and queue.
            r = tf.train.batch(
                tf_utils.reshape_list([image, gclasses, glocalisations, gscores]),
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            b_image, b_gclasses, b_glocalisations, b_gscores = \
                tf_utils.reshape_list(r, batch_shape)

            # Intermediate queueing: unique batch computation pipeline for all
            # GPUs running the training.
            batch_queue = slim.prefetch_queue.prefetch_queue(
                tf_utils.reshape_list([b_image, b_gclasses, b_glocalisations, b_gscores]),
                capacity=2 * deploy_config.num_clones)

        # =================================================================== #
        # Define the model running on every GPU.
        # =================================================================== #
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple
            clones of network_fn."""
            # Dequeue batch.
            b_image, b_gclasses, b_glocalisations, b_gscores = \
                tf_utils.reshape_list(batch_queue.dequeue(), batch_shape)

            # Construct SSD network.
            arg_scope = ssd_net.arg_scope(weight_decay=FLAGS.weight_decay,
                                          data_format=DATA_FORMAT)
            with slim.arg_scope(arg_scope):
                predictions, localisations, logits, end_points = \
                    ssd_net.net(b_image, is_training=True)
            # Add loss function.
            ssd_net.losses(logits, localisations,
                           b_gclasses, b_glocalisations, b_gscores,
                           match_threshold=FLAGS.match_threshold,
                           negative_ratio=FLAGS.negative_ratio,
                           alpha=FLAGS.loss_alpha,
                           label_smoothing=FLAGS.label_smoothing)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # =================================================================== #
        # Add summaries from first clone.
        # =================================================================== #
        clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                            tf.nn.zero_fraction(x)))
        # Add summaries for losses and extra losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))
        for loss in tf.get_collection('EXTRA_LOSSES', first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        # =================================================================== #
        # Configure the moving averages.
        # =================================================================== #
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        # =================================================================== #
        # Configure the optimization procedure.
        # =================================================================== #
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = tf_utils.configure_learning_rate(FLAGS,
                                                             dataset.num_samples,
                                                             global_step)
            optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = tf_utils.get_variables_to_train(FLAGS)

        # Optimize over all clones: returns the total loss and the clone gradients.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones,
            optimizer,
            var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                           first_clone_scope))
        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # =================================================================== #
        # Kicks off the training.
        # =================================================================== #
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
        config = tf.ConfigProto(log_device_placement=False,
                                gpu_options=gpu_options)
        saver = tf.train.Saver(max_to_keep=5,
                               keep_checkpoint_every_n_hours=1.0,
                               write_version=2,
                               pad_step_number=False)
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master='',
            is_chief=True,
            init_fn=tf_utils.get_init_fn(FLAGS),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            saver=saver,
            save_interval_secs=FLAGS.save_interval_secs,
            session_config=config,
            sync_optimizer=None)
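
Example #7 flattens and regroups its (image, classes, localisations, scores) tensors with tf_utils.reshape_list, which is not shown above. The standalone sketch below is only a guess at that helper's behaviour: with no shape it flattens one level of nesting, and with a shape such as [1, N, N, N] it groups a flat list back into sublists.

def reshape_list(l, shape=None):
    r = []
    if shape is None:
        # Flatten one level of nesting.
        for item in l:
            if isinstance(item, (list, tuple)):
                r += list(item)
            else:
                r.append(item)
    else:
        # Regroup a flat list according to shape; an entry of 1 stays a single item.
        i = 0
        for s in shape:
            if s == 1:
                r.append(l[i])
            else:
                r.append(l[i:i + s])
            i += s
    return r

print(reshape_list([['a', 'b'], 'c']))              # ['a', 'b', 'c']
print(reshape_list(['a', 'b', 'c'], shape=[1, 2]))  # ['a', ['b', 'c']]
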
Example #8
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones

  # Get dataset-dependent information.
  dataset = segmentation_dataset.get_dataset(
      FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      samples = input_generator.get(
          dataset,
          FLAGS.train_crop_size,
          clone_batch_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          dataset_split=FLAGS.train_split,
          is_training=True,
          model_variant=FLAGS.model_variant)
      inputs_queue = prefetch_queue.prefetch_queue(
          samples, capacity=128 * config.num_clones)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab
      model_args = (inputs_queue, {
          common.OUTPUT_TYPE: dataset.num_classes
      }, dataset.ignore_label)
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in slim.get_model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for images, labels, semantic predictions
    if FLAGS.save_summaries_images:
      summary_image = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
      summaries.add(
          tf.summary.image('samples/%s' % common.IMAGE, summary_image))

      first_clone_label = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
      # Scale up summary image pixel values for better visualization.
      pixel_scaling = max(1, 255 // dataset.num_classes)
      summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL, summary_label))

      first_clone_output = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
      predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

      summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image(
              'samples/%s' % common.OUTPUT_TYPE, summary_predictions))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy, FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps, FLAGS.learning_power,
          FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
      optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Start the training.
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_logdir,
        log_every_n_steps=FLAGS.log_steps,
        master=FLAGS.master,
        number_of_steps=FLAGS.training_number_of_steps,
        is_chief=(FLAGS.task == 0),
        session_config=session_config,
        startup_delay_steps=startup_delay_steps,
        init_fn=train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True),
        summary_op=summary_op,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs)
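
The last-layer gradient multipliers above come from train_utils helpers that are not shown. The toy, self-contained snippet below (variable names invented) illustrates the underlying mechanism, slim.learning.multiply_gradients, which scales the gradients of selected variables before they are applied.

import tensorflow as tf
slim = tf.contrib.slim

# Two toy variables and a loss over them.
w = tf.get_variable('w', shape=[2], initializer=tf.zeros_initializer())
b = tf.get_variable('logits_bias', shape=[2], initializer=tf.zeros_initializer())
loss = tf.reduce_sum(tf.square(w - 1.0)) + tf.reduce_sum(tf.square(b - 1.0))
grads_and_vars = tf.train.GradientDescentOptimizer(0.1).compute_gradients(loss, [w, b])

# Scale the gradient of every variable whose name mentions 'logits' by 10.0,
# mirroring the "last layer gradient multiplier" idea in the snippet above.
grad_mult = {v.op.name: 10.0 for _, v in grads_and_vars if 'logits' in v.op.name}
grads_and_vars = slim.learning.multiply_gradients(grads_and_vars, grad_mult)
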
Example #9
def train(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    tf.logging.set_verbosity(args.log)
    clone_on_cpu = args.gpu_id == ''
    num_clones = len(args.gpu_id.split(','))

    if args.log_root:
        if args.config is None:
            raise RuntimeError('No config json specified.')
        config_json = args.config
        with open(config_json, 'rt') as F:
            configs = json.load(F)
        hparams = Namespace(**configs)
        logdir_name = config_str.get_config_time_str(hparams, 'wavenet',
                                                     EXP_TAG)
        logdir = os.path.join(args.log_root, logdir_name)
        os.makedirs(logdir, exist_ok=True)
        shutil.copy(config_json, logdir)
    else:
        logdir = args.logdir
        config_json = glob.glob(os.path.join(logdir, '*.json'))[0]
        with open(config_json, 'rt') as F:
            configs = json.load(F)
        hparams = Namespace(**configs)

    enhance_log.add_log_file(logdir)
    if not args.log_root:
        tf.logging.info('Continue running\n\n')
    tf.logging.info('using config from {}'.format(config_json))
    tf.logging.info('Saving to {}'.format(logdir))

    wn = wavenet.Wavenet(hparams,
                         os.path.abspath(os.path.expanduser(args.train_path)))
    wn_config_str = enhance_log.instance_attr_to_str(wn)
    tf.logging.info('\n' + wn_config_str)

    def _data_dep_init():
        # slim.learning.train runs init_fn before the queue runners are started,
        # so this function would deadlock if it used the `input_dict` in L76 as its input.
        inputs_val = reader.get_init_batch(wn.train_path,
                                           batch_size=args.total_batch_size,
                                           seq_len=wn.wave_length)
        wave_data = inputs_val['wav']
        mel_data = inputs_val['mel']

        _inputs_dict = {
            'wav': tf.placeholder(dtype=tf.float32, shape=wave_data.shape),
            'mel': tf.placeholder(dtype=tf.float32, shape=mel_data.shape)
        }

        encode_dict = wn.encode_signal(_inputs_dict)
        _inputs_dict.update(encode_dict)
        init_ff_dict = wn.feed_forward(_inputs_dict, init=True)

        def callback(session):
            tf.logging.info('Calculate initial statistics.')
            init_out = session.run(init_ff_dict,
                                   feed_dict={
                                       _inputs_dict['wav']: wave_data,
                                       _inputs_dict['mel']: mel_data
                                   })
            init_out_params = init_out['out_params']
            if wn.loss_type == 'mol':
                _, mean, log_scale = np.split(init_out_params, 3, axis=2)
                scale = np.exp(np.maximum(log_scale, -7.0))
                _init_logging(mean, 'mean')
                _init_logging(scale, 'scale')
            elif wn.loss_type == 'gauss':
                mean, log_std = np.split(init_out_params, 2, axis=2)
                std = np.exp(np.maximum(log_std, -7.0))
                _init_logging(mean, 'mean')
                _init_logging(std, 'std')
            tf.logging.info('Done calculating initial statistics.')

        return callback

    def _model_fn(_inputs_dict):
        encode_dict = wn.encode_signal(_inputs_dict)
        _inputs_dict.update(encode_dict)
        ff_dict = wn.feed_forward(_inputs_dict)
        ff_dict.update(encode_dict)
        loss_dict = wn.calculate_loss(ff_dict)
        loss = loss_dict['loss']
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)

    with tf.Graph().as_default():
        total_batch_size = args.total_batch_size
        assert total_batch_size % num_clones == 0
        clone_batch_size = int(total_batch_size / num_clones)

        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            num_ps_tasks=0,
            worker_job_name='localhost',
            ps_job_name='localhost')

        with tf.device(deploy_config.inputs_device()):
            inputs_dict = wn.get_batch(clone_batch_size)

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, _model_fn,
                                            [inputs_dict])
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        summaries.update(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        with tf.device(deploy_config.variables_device()):
            global_step = tf.get_variable(
                "global_step", [],
                tf.int32,
                initializer=tf.constant_initializer(0),
                trainable=False)

        with tf.device(deploy_config.optimizer_device()):
            lr = tf.constant(wn.learning_rate_schedule[0])
            for key, value in wn.learning_rate_schedule.items():
                lr = tf.cond(tf.less(global_step, key), lambda: lr,
                             lambda: tf.constant(value))
            summaries.add(tf.summary.scalar("learning_rate", lr))

            optimizer = tf.train.AdamOptimizer(lr, epsilon=1e-8)
            ema = tf.train.ExponentialMovingAverage(decay=0.9999,
                                                    num_updates=global_step)

            loss, clone_grads_vars = model_deploy.optimize_clones(
                clones, optimizer, var_list=tf.trainable_variables())
            if GRAD_CLIP:
                clone_grads_vars = grad_clip(clone_grads_vars)
            update_ops.append(
                optimizer.apply_gradients(clone_grads_vars,
                                          global_step=global_step))
            update_ops.append(ema.apply(tf.trainable_variables()))

            summaries.add(tf.summary.scalar("train_loss", loss))

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(loss, name='train_op')

        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True
        summary_op = tf.summary.merge(list(summaries), name='summary_op')
        data_dep_init_fn = _data_dep_init()

        slim.learning.train(train_tensor,
                            logdir=logdir,
                            number_of_steps=wn.num_iters,
                            summary_op=summary_op,
                            global_step=global_step,
                            log_every_n_steps=100,
                            save_summaries_secs=600,
                            save_interval_secs=3600,
                            session_config=session_config,
                            init_fn=data_dep_init_fn)
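
Example #9 guards its optimizer update with grad_clip() when GRAD_CLIP is set, but the helper itself is not shown. A minimal standalone sketch of per-gradient norm clipping; the 5.0 threshold is an assumption.

import tensorflow as tf

def grad_clip(grads_and_vars, clip_norm=5.0):
    """Clips each gradient to a maximum L2 norm, leaving None gradients untouched."""
    clipped = []
    for grad, var in grads_and_vars:
        if grad is not None:
            grad = tf.clip_by_norm(grad, clip_norm)
        clipped.append((grad, var))
    return clipped
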
Example #10
def main(_):

    tf.logging.set_verbosity(tf.logging.INFO)
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
        global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    keys_to_features = {
        "image/encoded":
        tf.FixedLenFeature((), tf.string, default_value=""),
        "image/format":
        tf.FixedLenFeature((), tf.string, default_value="png"),
        "image/class/label":
        tf.FixedLenFeature([],
                           tf.int64,
                           default_value=tf.zeros([], dtype=tf.int64)),
    }

    items_to_handlers = {
        "image": slim.tfexample_decoder.Image(),
        "label": slim.tfexample_decoder.Tensor("image/class/label"),
    }

    items_to_descs = {
        "image": "Color image",
        "label": "Class idx",
    }

    label_idx_to_name = {}
    for i, label in enumerate(CLASSES):
        label_idx_to_name[i] = label

    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features,
                                                      items_to_handlers)
    file_pattern = "tfm_clf_%s.*"
    file_pattern = os.path.join(FLAGS.records_name,
                                file_pattern % FLAGS.dataset_split_name)
    dataset = slim.dataset.Dataset(
        data_sources=file_pattern,  # TODO UPDATE
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=80000,  # TODO UPDATE
        items_to_descriptions=items_to_descs,
        num_classes=len(CLASSES),
        labels_to_names=label_idx_to_name,
    )

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True,
    )

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True,
        use_grayscale=FLAGS.use_grayscale)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            num_readers=FLAGS.num_readers,
            common_queue_capacity=20 * FLAGS.batch_size,
            common_queue_min=10 * FLAGS.batch_size,
        )
        [image, label] = provider.get(["image", "label"])
        label -= FLAGS.labels_offset

        train_image_size = FLAGS.train_image_size or network_fn.default_image_size

        image = image_preprocessing_fn(image, train_image_size,
                                       train_image_size)

        images, labels = tf.train.batch(
            [image, label],
            batch_size=FLAGS.batch_size,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size,
        )
        labels = slim.one_hot_encoding(
            labels, dataset.num_classes - FLAGS.labels_offset)
        batch_queue = slim.prefetch_queue.prefetch_queue(
            [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
        """Allows data parallelism by creating network_fn clones."""
        images, labels = batch_queue.dequeue()
        logits, end_points = network_fn(images)

        #############################
        # Specify the loss function #
        #############################
        if "AuxLogits" in end_points:
            slim.losses.softmax_cross_entropy(
                end_points["AuxLogits"],
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=0.4,
                scope="aux_loss",
            )
        slim.losses.softmax_cross_entropy(
            logits,
            labels,
            label_smoothing=FLAGS.label_smoothing,
            weights=1.0)
        return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
        x = end_points[end_point]
        summaries.add(tf.summary.histogram("activations/" + end_point, x))
        summaries.add(
            tf.summary.scalar("sparsity/" + end_point, tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
        summaries.add(tf.summary.scalar("losses/%s" % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
        summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
        moving_average_variables = slim.get_model_variables()
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay, global_step)
    else:
        moving_average_variables, variable_averages = None, None

    if FLAGS.quantize_delay >= 0:
        contrib_quantize.create_training_graph(
            quant_delay=FLAGS.quantize_delay)

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
        learning_rate = _configure_learning_rate(dataset.num_samples,
                                                 global_step)
        optimizer = _configure_optimizer(learning_rate)
        summaries.add(tf.summary.scalar("learning_rate", learning_rate))

    if FLAGS.sync_replicas:
        # If sync_replicas is enabled, the averaging will be done in the chief
        # queue runner.
        optimizer = tf.train.SyncReplicasOptimizer(
            opt=optimizer,
            replicas_to_aggregate=FLAGS.replicas_to_aggregate,
            total_num_replicas=FLAGS.worker_replicas,
            variable_averages=variable_averages,
            variables_to_average=moving_average_variables,
        )
    elif FLAGS.moving_average_decay:
        # Update ops executed locally by trainer.
        update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # Optimize over all clones: returns the total loss and the clone gradients.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones, optimizer, var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar("total_loss", total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name="train_op")

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name="summary_op")

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=optimizer if FLAGS.sync_replicas else None,
    )
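
This snippet calls _get_variables_to_train() without defining it, as does an earlier example above. A hedged standalone sketch of the usual scope-filtering helper; the scope strings in the docstring are purely illustrative.

import tensorflow as tf

def get_variables_to_train(trainable_scopes=None):
    """With trainable_scopes like 'InceptionV1/Logits,InceptionV1/Mixed_5c', only
    variables under those scopes are returned; otherwise every trainable variable."""
    if not trainable_scopes:
        return tf.trainable_variables()
    variables_to_train = []
    for scope in [s.strip() for s in trainable_scopes.split(',')]:
        variables_to_train.extend(
            tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope))
    return variables_to_train
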
Example #11
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.DEBUG)

    with tf.Graph().as_default():
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

        net = model_cmc.Model()

        with tf.device(deploy_config.inputs_device()):
            if FLAGS.model == 'unet':
                batch_queue = load_batch.get_batch(
                    FLAGS.dataset_dir,
                    FLAGS.num_readers,
                    FLAGS.batch_size,
                    None,
                    FLAGS,
                    file_pattern=FLAGS.file_pattern,
                    is_training=True,
                    shuffe=FLAGS.shuffle_data)
            elif FLAGS.model == 'patch':
                batch_queue = load_batch_patch.get_batch(
                    FLAGS.dataset_dir,
                    FLAGS.num_readers,
                    FLAGS.batch_size,
                    None,
                    FLAGS,
                    file_pattern=FLAGS.file_pattern,
                    is_training=True,
                    shuffe=FLAGS.shuffle_data)
            elif FLAGS.model == 'cmc':
                batch_queue = load_batch_cmc.get_batch(
                    FLAGS.dataset_dir,
                    FLAGS.num_readers,
                    FLAGS.batch_size,
                    None,
                    FLAGS,
                    file_pattern=FLAGS.file_pattern,
                    is_training=True,
                    shuffe=FLAGS.shuffle_data)

        # =================================================================== #
        # Define the model running on every GPU.
        # =================================================================== #
        print("Batch_loading_successful")

        def clone_fn(batch_queue):
            batch_shape = [1] * 3
            b_image, label = batch_queue

            logits, end_points = net.net(b_image)

            # Add loss function.
            loss, mean_iou = net.weighted_losses(logits, label)
            return end_points, mean_iou

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        end_points, mean_iou = clones[0].outputs
        update_ops.append(mean_iou[1])
        #for end_point in end_points:
        #	x = end_points[end_point]
        #	summaries.add(tf.summary.histogram('activations/' + end_point, x))

        for loss in tf.get_collection('EXTRA_LOSSES', first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))

        #
        #for variable in slim.get_model_variables():
        #	summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = tf_utils.configure_learning_rate(
                FLAGS, FLAGS.num_samples, global_step)
            optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.fine_tune:
            gradient_multipliers = pickle.load(
                open('nets/multiplier_300.pkl', 'rb'))
        else:
            gradient_multipliers = None

        if FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = tf_utils.get_variables_to_train(FLAGS)

        # Optimize over all clones: returns the total loss and the clone gradients.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))
        if gradient_multipliers:
            with ops.name_scope('multiply_grads'):
                clones_gradients = slim.learning.multiply_gradients(
                    clones_gradients, gradient_multipliers)

        if FLAGS.clip_gradient_norm > 0:
            with ops.name_scope('clip_grads'):
                clones_gradients = slim.learning.clip_gradient_norms(
                    clones_gradients, FLAGS.clip_gradient_norm)
        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        #train_tensor = slim.learning.create_train_op(total_loss, optimizer, gradient_multipliers=gradient_multipliers)
        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # =================================================================== #
        # Kicks off the training.
        # =================================================================== #

        def train_step_fn(session, *args, **kwargs):
            # visualizer = Beholder(session=session, logdir=FLAGS.train_dir)
            total_loss, should_stop = train_step(session, *args, **kwargs)

            if train_step_fn.step % FLAGS.validation_check == 0:
                _mean_iou = session.run(train_step_fn.mean_iou)
                print('evaluation step %d - loss = %.4f mean_iou = %.2f%%' %\
                 (train_step_fn.step, total_loss, _mean_iou ))
            # evaluated_tensors = session.run([end_points['conv4'], end_points['up1']])
            # example_frame = session.run(end_points['up2'])
            # visualizer.update(arrays=evaluated_tensors, frame=example_frame)

            train_step_fn.step += 1
            return [total_loss, should_stop]

        train_step_fn.step = 0
        train_step_fn.end_points = end_points
        train_step_fn.mean_iou = mean_iou[0]

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction,
            allocator_type="BFC")
        config = tf.ConfigProto(
            gpu_options=gpu_options,
            log_device_placement=False,
            allow_soft_placement=True,
            inter_op_parallelism_threads=0,
            intra_op_parallelism_threads=1,
        )
        saver = tf.train.Saver(max_to_keep=5,
                               keep_checkpoint_every_n_hours=1.0,
                               write_version=2,
                               pad_step_number=False)

        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_dir,
                            master='',
                            is_chief=True,
                            train_step_fn=train_step_fn,
                            init_fn=tf_utils.get_init_fn(FLAGS),
                            summary_op=summary_op,
                            number_of_steps=FLAGS.max_number_of_steps,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            saver=saver,
                            save_interval_secs=FLAGS.save_interval_secs,
                            session_config=config,
                            sync_optimizer=None)
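
Several of these snippets defer to helpers such as _configure_learning_rate/_configure_optimizer or tf_utils.configure_learning_rate/configure_optimizer without showing them. One plausible, simplified pair is sketched below; the decay schedule, its constants and the choice of momentum optimizer are assumptions, not the snippets' actual settings.

import tensorflow as tf

def configure_learning_rate(base_lr, num_samples, batch_size, global_step,
                            num_epochs_per_decay=2.0, decay_factor=0.94):
    """Exponential decay driven by the number of epochs seen so far."""
    decay_steps = int(num_samples / batch_size * num_epochs_per_decay)
    return tf.train.exponential_decay(base_lr, global_step, decay_steps,
                                      decay_factor, staircase=True,
                                      name='exponential_decay_learning_rate')

def configure_optimizer(learning_rate, momentum=0.9):
    """A momentum optimizer; RMSProp or Adam would be configured the same way."""
    return tf.train.MomentumOptimizer(learning_rate, momentum=momentum)
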
Example #12
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)

            accuracy = slim.metrics.accuracy(tf.to_int32(tf.argmax(logits, 1)),
                                             tf.to_int32(tf.argmax(labels, 1)))
            tf.add_to_collection('accuracy', accuracy)
            end_points['train_accuracy'] = accuracy
            return end_points

        # Get accuracies for the batch

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs

        for end_point in end_points:
            if 'accuracy' in end_point:
                continue
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))
        train_acc = end_points['train_accuracy']
        summaries.add(
            tf.summary.scalar('train_accuracy', end_points['train_accuracy']))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # @philkuz
        # Add accuracy summaries
        # TODO add if statement for n iterations
        # images_val, labels_val= tf.train.batch(
        #     [image, label],
        #     batch_size=FLAGS.batch_size,
        #     num_threads=FLAGS.num_preprocessing_threads,
        #     capacity=5 * FLAGS.batch_size)

        # # labels_val = slim.one_hot_encoding(
        # #     labels_val, dataset.num_classes - FLAGS.labels_offset)
        # batch_queue_val = slim.prefetch_queue.prefetch_queue(
        #     [images_val, labels_val], capacity=2 * deploy_config.num_clones)
        # logits, end_points = network_fn(images, reuse=True)
        # # predictions = tf.nn.softmax(logits)
        # predictions = tf.to_int32(tf.argmax(logits, 1))

        # logits_val, end_points_val = network_fn(images_val, reuse=True)
        # predictions_val = tf.to_int32(tf.argmax(logits_val, 1))

        # labels_val = tf.squeeze(labels_val)
        # labels = tf.squeeze(labels)

        # names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        #       'train/accuracy': slim.metrics.streaming_accuracy(predictions, labels),
        #       'val/accuracy': slim.metrics.streaming_accuracy(predictions_val, labels_val),
        # })
        # for metric_name, metric_value in names_to_values.items():
        #   op = tf.summary.scalar(metric_name, metric_value)
        #   # op = tf.Print(op, [metric_value], metric_name)
        #   summaries.add(op)
        # Add summaries for variables.
        # TODO something to remove some of these from tensorboard scalars
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # optimize_clones computes the total loss and the per-variable gradients across all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # @philkuz
        # set the  max_number_of_steps parameter if num_epochs is available
        print('FLAGS.num_epochs', FLAGS.num_epochs)
        if FLAGS.num_epochs is not None and FLAGS.max_number_of_steps is None:
            FLAGS.max_number_of_steps = int(
                FLAGS.num_epochs * dataset.num_samples / FLAGS.batch_size)
            # FLAGS.max_number_of_steps = int(math.round(FLAGS.num_epochs / dataset.num_samples))

        # setup the logdir
        # @philkuz  the train_dir setup
        if FLAGS.experiment_name is not None:
            experiment_dir = 'bs={},lr={},epochs={}/{}'.format(
                FLAGS.batch_size, FLAGS.learning_rate, FLAGS.num_epochs,
                FLAGS.experiment_name)
            print(experiment_dir)
            FLAGS.train_dir = os.path.join(FLAGS.train_dir, experiment_dir)
            print(FLAGS.train_dir)

        # @philkuz overriding train_step
        def train_step(sess, train_op, global_step, train_step_kwargs):
            """Function that takes a gradient step and specifies whether to stop.
      Args:
        sess: The current session.
        train_op: An `Operation` that evaluates the gradients and returns the
          total loss.
        global_step: A `Tensor` representing the global training step.
        train_step_kwargs: A dictionary of keyword arguments.
      Returns:
        The total loss and a boolean indicating whether or not to stop training.
      Raises:
        ValueError: if 'should_trace' is in `train_step_kwargs` but `logdir` is not.
      """
            start_time = time.time()

            trace_run_options = None
            run_metadata = None
            should_acc = True  # TODO make this not hardcoded @philkuz
            if 'should_trace' in train_step_kwargs:
                if 'logdir' not in train_step_kwargs:
                    raise ValueError(
                        'logdir must be present in train_step_kwargs when '
                        'should_trace is present')
                if sess.run(train_step_kwargs['should_trace']):
                    trace_run_options = config_pb2.RunOptions(
                        trace_level=config_pb2.RunOptions.FULL_TRACE)
                    run_metadata = config_pb2.RunMetadata()
            if not should_acc:
                total_loss, np_global_step = sess.run(
                    [train_op, global_step],
                    options=trace_run_options,
                    run_metadata=run_metadata)
            else:
                total_loss, acc, np_global_step = sess.run(
                    [train_op, train_acc, global_step],
                    options=trace_run_options,
                    run_metadata=run_metadata)
            time_elapsed = time.time() - start_time

            if run_metadata is not None:
                tl = timeline.Timeline(run_metadata.step_stats)
                trace = tl.generate_chrome_trace_format()
                trace_filename = os.path.join(
                    train_step_kwargs['logdir'],
                    'tf_trace-%d.json' % np_global_step)
                tf.logging.info('Writing trace to %s', trace_filename)
                file_io.write_string_to_file(trace_filename, trace)
                if 'summary_writer' in train_step_kwargs:
                    train_step_kwargs['summary_writer'].add_run_metadata(
                        run_metadata, 'run_metadata-%d' % np_global_step)

            if 'should_log' in train_step_kwargs:
                if sess.run(train_step_kwargs['should_log']):
                    if not should_acc:
                        tf.logging.info(
                            'global step %d: loss = %.4f (%.3f sec/step)',
                            np_global_step, total_loss, time_elapsed)
                    else:
                        tf.logging.info(
                            'global step %d: loss = %.4f train_acc = %.4f (%.3f sec/step)',
                            np_global_step, total_loss, acc, time_elapsed)

            if 'should_stop' in train_step_kwargs:
                should_stop = sess.run(train_step_kwargs['should_stop'])
            else:
                should_stop = False

            return total_loss, should_stop

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            train_step_fn=train_step,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
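The custom `train_step` above works because `slim.learning.train` accepts a `train_step_fn` hook with the signature `(sess, train_op, global_step, train_step_kwargs)` that must return `(total_loss, should_stop)`. Below is a minimal sketch of that contract, assuming the same TF 1.x / `tf.contrib.slim` setup as the example; the name `simple_train_step` is illustrative only.

# A minimal sketch of the train_step_fn contract used above (illustrative
# names; assumes TF 1.x and tf.contrib.slim, like the example itself).
import tensorflow as tf

slim = tf.contrib.slim


def simple_train_step(sess, train_op, global_step, train_step_kwargs):
    """Runs one step, logs the loss, and reports whether training should stop."""
    total_loss, np_global_step = sess.run([train_op, global_step])
    tf.logging.info('global step %d: loss = %.4f', np_global_step, total_loss)
    if 'should_stop' in train_step_kwargs:
        should_stop = sess.run(train_step_kwargs['should_stop'])
    else:
        should_stop = False
    return total_loss, should_stop

# Plugged in exactly as in the example:
# slim.learning.train(train_tensor, logdir=FLAGS.train_dir,
#                     train_step_fn=simple_train_step, ...)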
Example #13
0
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.DEBUG)

    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        network_fn = nets_factory.get_network(FLAGS.model_name)
        params = network_fn.default_params
        params = params._replace(match_threshold=FLAGS.match_threshold)
        # initialize the net
        net = network_fn(params)
        out_shape = net.params.img_shape
        anchors = net.anchors(out_shape)

        # create batch dataset
        with tf.device(deploy_config.inputs_device()):
            b_image, b_glocalisations, b_gscores = load_batch.get_batch(
                FLAGS.dataset_dir,
                FLAGS.num_readers,
                FLAGS.batch_size,
                out_shape,
                net,
                anchors,
                FLAGS,
                file_pattern=FLAGS.file_pattern,
                is_training=True,
                shuffe=FLAGS.shuffle_data)
            allgscores = []
            allglocalization = []
            for i in range(len(anchors)):
                allgscores.append(tf.reshape(b_gscores[i], [-1]))
                allglocalization.append(
                    tf.reshape(b_glocalisations[i], [-1, 4]))

            b_gscores = tf.concat(allgscores, 0)
            b_glocalisations = tf.concat(allglocalization, 0)

            batch_queue = slim.prefetch_queue.prefetch_queue(
                tf_utils.reshape_list([b_image, b_glocalisations, b_gscores]),
                num_threads=8,
                capacity=16 * deploy_config.num_clones)

        # =================================================================== #
        # Define the model running on every GPU.
        # =================================================================== #
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            # Dequeue batch.
            batch_shape = [1] * 3
            b_image, b_glocalisations, b_gscores = \
             tf_utils.reshape_list(batch_queue.dequeue(), batch_shape)
            # Construct SSD network.
            arg_scope = net.arg_scope(weight_decay=FLAGS.weight_decay,
                                      data_format=FLAGS.data_format)
            with slim.arg_scope(arg_scope):
                localisations, logits, end_points = \
                 net.net(b_image, is_training=True, use_batch=FLAGS.use_batch)
            # Add loss function.
            net.losses(logits,
                       localisations,
                       b_glocalisations,
                       b_gscores,
                       negative_ratio=FLAGS.negative_ratio,
                       use_hard_neg=FLAGS.use_hard_neg,
                       alpha=FLAGS.loss_alpha,
                       label_smoothing=FLAGS.label_smoothing)
            return end_points

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        #end_points = clones[0].outputs
        #for end_point in end_points:
        #	x = end_points[end_point]
        #	summaries.add(tf.summary.histogram('activations/' + end_point, x))

        for loss in tf.get_collection('EXTRA_LOSSES', first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))

        #
        #for variable in slim.get_model_variables():
        #	summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = tf_utils.configure_learning_rate(
                FLAGS, FLAGS.num_samples, global_step)
            optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.fine_tune:
            gradient_multipliers = pickle.load(
                open('nets/multiplier_300.pkl', 'rb'))
        else:
            gradient_multipliers = None

        if FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = tf_utils.get_variables_to_train(FLAGS)

        # optimize_clones computes the total loss and the per-variable gradients across all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))
        if gradient_multipliers:
            with ops.name_scope('multiply_grads'):
                clones_gradients = slim.learning.multiply_gradients(
                    clones_gradients, gradient_multipliers)

        if FLAGS.clip_gradient_norm > 0:
            with ops.name_scope('clip_grads'):
                clones_gradients = slim.learning.clip_gradient_norms(
                    clones_gradients, FLAGS.clip_gradient_norm)
        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        #train_tensor = slim.learning.create_train_op(total_loss, optimizer, gradient_multipliers=gradient_multipliers)
        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # =================================================================== #
        # Kicks off the training.
        # =================================================================== #
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction,
            allocator_type="BFC")
        config = tf.ConfigProto(
            gpu_options=gpu_options,
            log_device_placement=False,
            allow_soft_placement=True,
            inter_op_parallelism_threads=0,
            intra_op_parallelism_threads=1,
        )
        saver = tf.train.Saver(max_to_keep=5,
                               keep_checkpoint_every_n_hours=1.0,
                               write_version=2,
                               pad_step_number=False)

        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_dir,
                            master='',
                            is_chief=True,
                            init_fn=tf_utils.get_init_fn(FLAGS),
                            summary_op=summary_op,
                            number_of_steps=FLAGS.max_number_of_steps,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            saver=saver,
                            save_interval_secs=FLAGS.save_interval_secs,
                            session_config=config,
                            sync_optimizer=None)
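The example above also passes an explicit `session_config` and `saver` into `slim.learning.train`, which is how GPU memory use and checkpoint retention are controlled. A short, self-contained sketch of just that part (the memory fraction and retention values below are illustrative, not taken from the example):

import tensorflow as tf

# Cap per-process GPU memory, let ops without a GPU kernel fall back to the
# CPU, and keep at most five recent checkpoints plus one per hour.
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
session_config = tf.ConfigProto(gpu_options=gpu_options,
                                allow_soft_placement=True,
                                log_device_placement=False)
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1.0)

# Passed straight through to slim.learning.train, as in the example:
# slim.learning.train(train_tensor, logdir=FLAGS.train_dir,
#                     session_config=session_config, saver=saver, ...)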
def main_fun(argv, ctx):
  import tensorflow as tf
  from tensorflow.python.ops import control_flow_ops
  from datasets import dataset_factory
  from deployment import model_deploy
  from nets import nets_factory
  from preprocessing import preprocessing_factory

  sys.argv = argv

  slim = tf.contrib.slim

  tf.app.flags.DEFINE_integer(
      'num_gpus', 1, 'The number of GPUs to use per node')

  tf.app.flags.DEFINE_boolean('rdma', False, 'Whether to use rdma.')

  tf.app.flags.DEFINE_string(
      'master', '', 'The address of the TensorFlow master to use.')

  tf.app.flags.DEFINE_string(
      'train_dir', '/tmp/tfmodel/',
      'Directory where checkpoints and event logs are written to.')

  tf.app.flags.DEFINE_integer('num_clones', 1,
                              'Number of model clones to deploy.')

  tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
                              'Use CPUs to deploy clones.')

  tf.app.flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas.')

  tf.app.flags.DEFINE_integer(
      'num_ps_tasks', 0,
      'The number of parameter servers. If the value is 0, then the parameters '
      'are handled locally by the worker.')

  tf.app.flags.DEFINE_integer(
      'num_readers', 4,
      'The number of parallel readers that read data from the dataset.')

  tf.app.flags.DEFINE_integer(
      'num_preprocessing_threads', 4,
      'The number of threads used to create the batches.')

  tf.app.flags.DEFINE_integer(
      'log_every_n_steps', 10,
      'The frequency with which logs are printed.')

  tf.app.flags.DEFINE_integer(
      'save_summaries_secs', 600,
      'The frequency with which summaries are saved, in seconds.')

  tf.app.flags.DEFINE_integer(
      'save_interval_secs', 600,
      'The frequency with which the model is saved, in seconds.')

  tf.app.flags.DEFINE_integer(
      'task', 0, 'Task id of the replica running the training.')

  ######################
  # Optimization Flags #
  ######################

  tf.app.flags.DEFINE_float(
      'weight_decay', 0.00004, 'The weight decay on the model weights.')

  tf.app.flags.DEFINE_string(
      'optimizer', 'rmsprop',
      'The name of the optimizer, one of "adadelta", "adagrad", "adam",'
      '"ftrl", "momentum", "sgd" or "rmsprop".')

  tf.app.flags.DEFINE_float(
      'adadelta_rho', 0.95,
      'The decay rate for adadelta.')

  tf.app.flags.DEFINE_float(
      'adagrad_initial_accumulator_value', 0.1,
      'Starting value for the AdaGrad accumulators.')

  tf.app.flags.DEFINE_float(
      'adam_beta1', 0.9,
      'The exponential decay rate for the 1st moment estimates.')

  tf.app.flags.DEFINE_float(
      'adam_beta2', 0.999,
      'The exponential decay rate for the 2nd moment estimates.')

  tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.')

  tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5,
                            'The learning rate power.')

  tf.app.flags.DEFINE_float(
      'ftrl_initial_accumulator_value', 0.1,
      'Starting value for the FTRL accumulators.')

  tf.app.flags.DEFINE_float(
      'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.')

  tf.app.flags.DEFINE_float(
      'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.')

  tf.app.flags.DEFINE_float(
      'momentum', 0.9,
      'The momentum for the MomentumOptimizer and RMSPropOptimizer.')

  tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')

  #######################
  # Learning Rate Flags #
  #######################

  tf.app.flags.DEFINE_string(
      'learning_rate_decay_type',
      'exponential',
      'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
      ' or "polynomial"')

  tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')

  tf.app.flags.DEFINE_float(
      'end_learning_rate', 0.0001,
      'The minimal end learning rate used by a polynomial decay learning rate.')

  tf.app.flags.DEFINE_float(
      'label_smoothing', 0.0, 'The amount of label smoothing.')

  tf.app.flags.DEFINE_float(
      'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')

  tf.app.flags.DEFINE_float(
      'num_epochs_per_decay', 2.0,
      'Number of epochs after which learning rate decays.')

  tf.app.flags.DEFINE_bool(
      'sync_replicas', False,
      'Whether or not to synchronize the replicas during training.')

  tf.app.flags.DEFINE_integer(
      'replicas_to_aggregate', 1,
      'The number of gradients to collect before updating params.')

  tf.app.flags.DEFINE_float(
      'moving_average_decay', None,
      'The decay to use for the moving average.'
      'If left as None, then moving averages are not used.')

  #######################
  # Dataset Flags #
  #######################

  tf.app.flags.DEFINE_string(
      'dataset_name', 'imagenet', 'The name of the dataset to load.')

  tf.app.flags.DEFINE_string(
      'dataset_split_name', 'train', 'The name of the train/test split.')

  tf.app.flags.DEFINE_string(
      'dataset_dir', None, 'The directory where the dataset files are stored.')

  tf.app.flags.DEFINE_integer(
      'labels_offset', 0,
      'An offset for the labels in the dataset. This flag is primarily used to '
      'evaluate the VGG and ResNet architectures which do not use a background '
      'class for the ImageNet dataset.')

  tf.app.flags.DEFINE_string(
      'model_name', 'inception_v3', 'The name of the architecture to train.')

  tf.app.flags.DEFINE_string(
      'preprocessing_name', None, 'The name of the preprocessing to use. If left '
      'as `None`, then the model_name flag is used.')

  tf.app.flags.DEFINE_integer(
      'batch_size', 32, 'The number of samples in each batch.')

  tf.app.flags.DEFINE_integer(
      'train_image_size', None, 'Train image size')

  tf.app.flags.DEFINE_integer('max_number_of_steps', None,
                              'The maximum number of training steps.')

  #####################
  # Fine-Tuning Flags #
  #####################

  tf.app.flags.DEFINE_string(
      'checkpoint_path', None,
      'The path to a checkpoint from which to fine-tune.')

  tf.app.flags.DEFINE_string(
      'checkpoint_exclude_scopes', None,
      'Comma-separated list of scopes of variables to exclude when restoring '
      'from a checkpoint.')

  tf.app.flags.DEFINE_string(
      'trainable_scopes', None,
      'Comma-separated list of scopes to filter the set of variables to train.'
      'By default, None would train all the variables.')

  tf.app.flags.DEFINE_boolean(
      'ignore_missing_vars', False,
      'When restoring a checkpoint would ignore missing variables.')

  FLAGS = tf.app.flags.FLAGS
  FLAGS.job_name = ctx.job_name
  FLAGS.task = ctx.task_index
  FLAGS.num_clones = FLAGS.num_gpus
  FLAGS.worker_replicas = len(ctx.cluster_spec['worker'])
  assert(FLAGS.num_ps_tasks == (len(ctx.cluster_spec['ps']) if 'ps' in ctx.cluster_spec else 0))

  def _configure_learning_rate(num_samples_per_epoch, global_step):
    """Configures the learning rate.

    Args:
      num_samples_per_epoch: The number of samples in each epoch of training.
      global_step: The global_step tensor.

    Returns:
      A `Tensor` representing the learning rate.

    Raises:
      ValueError: if FLAGS.learning_rate_decay_type is not recognized.
    """
    decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
                      FLAGS.num_epochs_per_decay)
    if FLAGS.sync_replicas:
      decay_steps /= FLAGS.replicas_to_aggregate

    if FLAGS.learning_rate_decay_type == 'exponential':
      return tf.train.exponential_decay(FLAGS.learning_rate,
                                        global_step,
                                        decay_steps,
                                        FLAGS.learning_rate_decay_factor,
                                        staircase=True,
                                        name='exponential_decay_learning_rate')
    elif FLAGS.learning_rate_decay_type == 'fixed':
      return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
    elif FLAGS.learning_rate_decay_type == 'polynomial':
      return tf.train.polynomial_decay(FLAGS.learning_rate,
                                       global_step,
                                       decay_steps,
                                       FLAGS.end_learning_rate,
                                       power=1.0,
                                       cycle=False,
                                       name='polynomial_decay_learning_rate')
    else:
      raise ValueError('learning_rate_decay_type [%s] was not recognized',
                       FLAGS.learning_rate_decay_type)


  def _configure_optimizer(learning_rate):
    """Configures the optimizer used for training.

    Args:
      learning_rate: A scalar or `Tensor` learning rate.

    Returns:
      An instance of an optimizer.

    Raises:
      ValueError: if FLAGS.optimizer is not recognized.
    """
    if FLAGS.optimizer == 'adadelta':
      optimizer = tf.train.AdadeltaOptimizer(
          learning_rate,
          rho=FLAGS.adadelta_rho,
          epsilon=FLAGS.opt_epsilon)
    elif FLAGS.optimizer == 'adagrad':
      optimizer = tf.train.AdagradOptimizer(
          learning_rate,
          initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value)
    elif FLAGS.optimizer == 'adam':
      optimizer = tf.train.AdamOptimizer(
          learning_rate,
          beta1=FLAGS.adam_beta1,
          beta2=FLAGS.adam_beta2,
          epsilon=FLAGS.opt_epsilon)
    elif FLAGS.optimizer == 'ftrl':
      optimizer = tf.train.FtrlOptimizer(
          learning_rate,
          learning_rate_power=FLAGS.ftrl_learning_rate_power,
          initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
          l1_regularization_strength=FLAGS.ftrl_l1,
          l2_regularization_strength=FLAGS.ftrl_l2)
    elif FLAGS.optimizer == 'momentum':
      optimizer = tf.train.MomentumOptimizer(
          learning_rate,
          momentum=FLAGS.momentum,
          name='Momentum')
    elif FLAGS.optimizer == 'rmsprop':
      optimizer = tf.train.RMSPropOptimizer(
          learning_rate,
          decay=FLAGS.rmsprop_decay,
          momentum=FLAGS.momentum,
          epsilon=FLAGS.opt_epsilon)
    elif FLAGS.optimizer == 'sgd':
      optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    else:
      raise ValueError('Optimizer [%s] was not recognized', FLAGS.optimizer)
    return optimizer


  def _add_variables_summaries(learning_rate):
    summaries = []
    for variable in slim.get_model_variables():
      summaries.append(tf.summary.histogram(variable.op.name, variable))
    summaries.append(tf.summary.scalar('training/Learning Rate', learning_rate))
    return summaries


  def _get_init_fn():
    """Returns a function run by the chief worker to warm-start the training.

    Note that the init_fn is only run when initializing the model during the very
    first global step.

    Returns:
      An init function run by the supervisor.
    """
    if FLAGS.checkpoint_path is None:
      return None

    # Warn the user if a checkpoint exists in the train_dir. Then we'll be
    # ignoring the checkpoint anyway.
    if tf.train.latest_checkpoint(FLAGS.train_dir):
      tf.logging.info(
          'Ignoring --checkpoint_path because a checkpoint already exists in %s'
          % FLAGS.train_dir)
      return None

    exclusions = []
    if FLAGS.checkpoint_exclude_scopes:
      exclusions = [scope.strip()
                    for scope in FLAGS.checkpoint_exclude_scopes.split(',')]

    # TODO(sguada) variables.filter_variables()
    variables_to_restore = []
    for var in slim.get_model_variables():
      excluded = False
      for exclusion in exclusions:
        if var.op.name.startswith(exclusion):
          excluded = True
          break
      if not excluded:
        variables_to_restore.append(var)

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    tf.logging.info('Fine-tuning from %s' % checkpoint_path)

    return slim.assign_from_checkpoint_fn(
        checkpoint_path,
        variables_to_restore,
        ignore_missing_vars=FLAGS.ignore_missing_vars)


  def _get_variables_to_train():
    """Returns a list of variables to train.

    Returns:
      A list of variables to train by the optimizer.
    """
    if FLAGS.trainable_scopes is None:
      return tf.trainable_variables()
    else:
      scopes = [scope.strip() for scope in FLAGS.trainable_scopes.split(',')]

    variables_to_train = []
    for scope in scopes:
      variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
      variables_to_train.extend(variables)
    return variables_to_train

  # main
  cluster_spec, server = TFNode.start_cluster_server(ctx=ctx, num_gpus=FLAGS.num_gpus, rdma=FLAGS.rdma)
  if ctx.job_name == 'ps':
    # `ps` jobs wait for incoming connections from the workers.
    server.join()
  else:
    # `worker` jobs will actually do the work.
    if not FLAGS.dataset_dir:
      raise ValueError('You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
      #######################
      # Config model_deploy #
      #######################
      deploy_config = model_deploy.DeploymentConfig(
          num_clones=FLAGS.num_clones,
          clone_on_cpu=FLAGS.clone_on_cpu,
          replica_id=FLAGS.task,
          num_replicas=FLAGS.worker_replicas,
          num_ps_tasks=FLAGS.num_ps_tasks)

      # Create global_step
      #with tf.device(deploy_config.variables_device()):
      #  global_step = slim.create_global_step()
      with tf.device("/job:ps/task:0"):
        global_step = tf.Variable(0, name="global_step", trainable=False)

      ######################
      # Select the dataset #
      ######################
      dataset = dataset_factory.get_dataset(
          FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

      ######################
      # Select the network #
      ######################
      network_fn = nets_factory.get_network_fn(
          FLAGS.model_name,
          num_classes=(dataset.num_classes - FLAGS.labels_offset),
          weight_decay=FLAGS.weight_decay,
          is_training=True)

      #####################################
      # Select the preprocessing function #
      #####################################
      preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
      image_preprocessing_fn = preprocessing_factory.get_preprocessing(
          preprocessing_name,
          is_training=True)

      ##############################################################
      # Create a dataset provider that loads data from the dataset #
      ##############################################################
      with tf.device(deploy_config.inputs_device()):
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            num_readers=FLAGS.num_readers,
            common_queue_capacity=20 * FLAGS.batch_size,
            common_queue_min=10 * FLAGS.batch_size)
        [image, label] = provider.get(['image', 'label'])
        label -= FLAGS.labels_offset

        train_image_size = FLAGS.train_image_size or network_fn.default_image_size

        image = image_preprocessing_fn(image, train_image_size, train_image_size)

        images, labels = tf.train.batch(
            [image, label],
            batch_size=FLAGS.batch_size,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size)
        labels = slim.one_hot_encoding(
            labels, dataset.num_classes - FLAGS.labels_offset)
        batch_queue = slim.prefetch_queue.prefetch_queue(
            [images, labels], capacity=2 * deploy_config.num_clones)

      ####################
      # Define the model #
      ####################
      def clone_fn(batch_queue):
        """Allows data parallelism by creating multiple clones of network_fn."""
        images, labels = batch_queue.dequeue()
        logits, end_points = network_fn(images)

        #############################
        # Specify the loss function #
        #############################
        if 'AuxLogits' in end_points:
          tf.losses.softmax_cross_entropy(
              logits=end_points['AuxLogits'], onehot_labels=labels,
              label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
        tf.losses.softmax_cross_entropy(
            logits=logits, onehot_labels=labels,
            label_smoothing=FLAGS.label_smoothing, weights=1.0)
        return end_points

      # Gather initial summaries.
      summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

      clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
      first_clone_scope = deploy_config.clone_scope(0)
      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by network_fn.
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

      # Add summaries for end_points.
      end_points = clones[0].outputs
      for end_point in end_points:
        x = end_points[end_point]
        summaries.add(tf.summary.histogram('activations/' + end_point, x))
        summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                        tf.nn.zero_fraction(x)))

      # Add summaries for losses.
      for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
        summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

      # Add summaries for variables.
      for variable in slim.get_model_variables():
        summaries.add(tf.summary.histogram(variable.op.name, variable))

      #################################
      # Configure the moving averages #
      #################################
      if FLAGS.moving_average_decay:
        moving_average_variables = slim.get_model_variables()
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay, global_step)
      else:
        moving_average_variables, variable_averages = None, None

      #########################################
      # Configure the optimization procedure. #
      #########################################
      with tf.device(deploy_config.optimizer_device()):
        learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
        optimizer = _configure_optimizer(learning_rate)
        summaries.add(tf.summary.scalar('learning_rate', learning_rate))

      if FLAGS.sync_replicas:
        # If sync_replicas is enabled, the averaging will be done in the chief
        # queue runner.
        optimizer = tf.train.SyncReplicasOptimizer(
            opt=optimizer,
            replicas_to_aggregate=FLAGS.replicas_to_aggregate,
            variable_averages=variable_averages,
            variables_to_average=moving_average_variables,
            replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
            total_num_replicas=FLAGS.worker_replicas)
      elif FLAGS.moving_average_decay:
        # Update ops executed locally by trainer.
        update_ops.append(variable_averages.apply(moving_average_variables))

      # Variables to train.
      variables_to_train = _get_variables_to_train()

      # optimize_clones computes the total loss and the per-variable gradients across all clones.
      total_loss, clones_gradients = model_deploy.optimize_clones(
          clones,
          optimizer,
          var_list=variables_to_train)
      # Add total_loss to summary.
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Create gradient updates.
      grad_updates = optimizer.apply_gradients(clones_gradients,
                                               global_step=global_step)
      update_ops.append(grad_updates)

      update_op = tf.group(*update_ops)
      train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                        name='train_op')

      # Add the summaries from the first clone. These contain the summaries
      # created by model_fn and either optimize_clones() or _gather_clone_loss().
      summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                         first_clone_scope))

      # Merge all summaries together.
      summary_op = tf.summary.merge(list(summaries), name='summary_op')


      ###########################
      # Kicks off the training. #
      ###########################
      summary_writer = tf.summary.FileWriter("tensorboard_%d" %(ctx.worker_num), graph=tf.get_default_graph())
      slim.learning.train(
          train_tensor,
          logdir=FLAGS.train_dir,
          master=server.target,
          is_chief=(FLAGS.task == 0),
          init_fn=_get_init_fn(),
          summary_op=summary_op,
          number_of_steps=FLAGS.max_number_of_steps,
          log_every_n_steps=FLAGS.log_every_n_steps,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs,
          summary_writer=summary_writer,
          sync_optimizer=optimizer if FLAGS.sync_replicas else None)
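Every variant in these examples follows the same `model_deploy` skeleton: build a `DeploymentConfig`, replicate a `clone_fn` over the configured devices with `create_clones`, reduce losses and gradients with `optimize_clones`, and turn the result into a train op. A stripped-down sketch of that skeleton follows; the toy model inside `clone_fn` is a placeholder, not part of any example.

import tensorflow as tf
from deployment import model_deploy

slim = tf.contrib.slim


def clone_fn(batch_queue):
    """One clone: dequeue a batch, build the model, register a loss."""
    images, labels = batch_queue.dequeue()
    logits = slim.fully_connected(slim.flatten(images), 10,
                                  activation_fn=None)  # placeholder model
    tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    return logits

deploy_config = model_deploy.DeploymentConfig(num_clones=2)
with tf.device(deploy_config.variables_device()):
    global_step = slim.create_global_step()
# batch_queue would come from slim.prefetch_queue, as in the examples above.
# clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
# optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
# total_loss, grads_and_vars = model_deploy.optimize_clones(clones, optimizer)
# train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)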
Example #15
0
def main(unused_argv):
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

  with tf.Graph().as_default():
    with tf.device(config.inputs_device()):
      train_crop_size = (None if 0 in FLAGS.train_crop_size else
                         FLAGS.train_crop_size)
      assert FLAGS.dataset
      assert len(FLAGS.dataset) == len(FLAGS.dataset_dir)
      if len(FLAGS.first_frame_finetuning) == 1:
        first_frame_finetuning = (list(FLAGS.first_frame_finetuning)
                                  * len(FLAGS.dataset))
      else:
        first_frame_finetuning = FLAGS.first_frame_finetuning
      if len(FLAGS.three_frame_dataset) == 1:
        three_frame_dataset = (list(FLAGS.three_frame_dataset)
                               * len(FLAGS.dataset))
      else:
        three_frame_dataset = FLAGS.three_frame_dataset
      assert len(FLAGS.dataset) == len(first_frame_finetuning)
      assert len(FLAGS.dataset) == len(three_frame_dataset)
      datasets, samples_list = zip(
          *[_get_dataset_and_samples(config, train_crop_size, dataset,
                                     dataset_dir, bool(first_frame_finetuning_),
                                     bool(three_frame_dataset_))
            for dataset, dataset_dir, first_frame_finetuning_,
            three_frame_dataset_ in zip(FLAGS.dataset, FLAGS.dataset_dir,
                                        first_frame_finetuning,
                                        three_frame_dataset)])
      # Note that this way of doing things is wasteful since it will evaluate
      # all branches but just use one of them. But let's do it anyway for now,
      # since it's easy and will probably be fast enough.
      dataset = datasets[0]
      if len(samples_list) == 1:
        samples = samples_list[0]
      else:
        probabilities = FLAGS.dataset_sampling_probabilities
        if probabilities:
          assert len(probabilities) == len(samples_list)
        else:
          # Default to uniform probabilities.
          probabilities = [1.0 / len(samples_list) for _ in samples_list]
        probabilities = tf.constant(probabilities)
        logits = tf.log(probabilities[tf.newaxis])
        rand_idx = tf.squeeze(tf.multinomial(logits, 1, output_dtype=tf.int32),
                              axis=[0, 1])

        def wrap(x):
          def f():
            return x
          return f

        samples = tf.case({tf.equal(rand_idx, idx): wrap(s)
                           for idx, s in enumerate(samples_list)},
                          exclusive=True)

      # Prefetch_queue requires the shape to be known at graph creation time.
      # So we only use it if we crop to a fixed size.
      if train_crop_size is None:
        inputs_queue = samples
      else:
        inputs_queue = prefetch_queue.prefetch_queue(
            samples,
            capacity=FLAGS.prefetch_queue_capacity_factor*config.num_clones,
            num_threads=FLAGS.prefetch_queue_num_threads)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab
      if FLAGS.classification_loss == 'triplet':
        embedding_dim = FLAGS.embedding_dimension
        output_type_to_dim = {'embedding': embedding_dim}
      else:
        output_type_to_dim = {common.OUTPUT_TYPE: dataset.num_classes}
      model_args = (inputs_queue, output_type_to_dim, dataset.ignore_label)
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in tf.contrib.framework.get_model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy,
          FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step,
          FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps,
          FLAGS.learning_power,
          FLAGS.slow_start_step,
          FLAGS.slow_start_learning_rate)
      optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(grads_and_vars,
                                                          grad_mult)

      with tf.name_scope('grad_clipping'):
        grads_and_vars = slim.learning.clip_gradient_norms(grads_and_vars, 5.0)

      # Create histogram summaries for the gradients.
      # We have too many summaries for mldash, so disable this one for now.
      # for grad, var in grads_and_vars:
      #   summaries.add(tf.summary.histogram(
      #       var.name.replace(':0', '_0') + '/gradient', grad))

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(grads_and_vars,
                                               global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

    # Start the training.
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_logdir,
        log_every_n_steps=FLAGS.log_steps,
        master=FLAGS.master,
        number_of_steps=FLAGS.training_number_of_steps,
        is_chief=(FLAGS.task == 0),
        session_config=session_config,
        startup_delay_steps=startup_delay_steps,
        init_fn=train_utils.get_model_init_fn(FLAGS.train_logdir,
                                              FLAGS.tf_initial_checkpoint,
                                              FLAGS.initialize_last_layer,
                                              last_layers,
                                              ignore_missing_vars=True),
        summary_op=summary_op,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs)
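The multi-dataset sampling in the example above rests on a small pattern: draw a dataset index from fixed probabilities with `tf.multinomial` and route the result through `tf.case`. A self-contained sketch with toy tensors (the constants below are stand-ins, not values from the example):

import tensorflow as tf

# Three stand-in "samples"; in the example these are full sample dicts.
samples_list = [tf.constant(1.0), tf.constant(2.0), tf.constant(3.0)]
probabilities = tf.constant([0.5, 0.25, 0.25])

# Sample an index according to the probabilities.
logits = tf.log(probabilities[tf.newaxis])
rand_idx = tf.squeeze(
    tf.multinomial(logits, 1, output_dtype=tf.int32), axis=[0, 1])


def wrap(x):
    return lambda: x

# Select the branch whose index matches the drawn one.
sample = tf.case({tf.equal(rand_idx, idx): wrap(s)
                  for idx, s in enumerate(samples_list)},
                 exclusive=True)

with tf.Session() as sess:
    print(sess.run(sample))  # prints 1.0, 2.0 or 3.0 at random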
Example #16
0
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      dataset = data_generator.Dataset(
          dataset_name=FLAGS.dataset,
          split_name=FLAGS.train_split,
          dataset_dir=FLAGS.dataset_dir,
          batch_size=clone_batch_size,
          crop_size=[int(sz) for sz in FLAGS.train_crop_size],
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          model_variant=FLAGS.model_variant,
          num_readers=4,
          is_training=True,
          should_shuffle=True,
          should_repeat=True)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab
      model_args = (dataset.get_one_shot_iterator(), {
          common.OUTPUT_TYPE: dataset.num_of_classes
      }, dataset.ignore_label)
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in tf.model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for images, labels, semantic predictions
    if FLAGS.save_summaries_images:
      summary_image = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
      summaries.add(
          tf.summary.image('samples/%s' % common.IMAGE, summary_image))

      first_clone_label = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
      # Scale up summary image pixel values for better visualization.
      pixel_scaling = max(1, 255 // dataset.num_of_classes)
      summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL, summary_label))

      first_clone_output = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
      predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

      summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image(
              'samples/%s' % common.OUTPUT_TYPE, summary_predictions))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy,
          FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step,
          FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps,
          FLAGS.learning_power,
          FLAGS.slow_start_step,
          FLAGS.slow_start_learning_rate,
          decay_steps=FLAGS.decay_steps,
          end_learning_rate=FLAGS.end_learning_rate)

      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

      if FLAGS.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      elif FLAGS.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(
            learning_rate=FLAGS.adam_learning_rate, epsilon=FLAGS.adam_epsilon)
      else:
        raise ValueError('Unknown optimizer')

    if FLAGS.quantize_delay_step >= 0:
      if FLAGS.num_clones > 1:
        raise ValueError('Quantization doesn\'t support multi-clone yet.')
      contrib_quantize.create_training_graph(
          quant_delay=FLAGS.quantize_delay_step)

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Start the training.
    profile_dir = FLAGS.profile_logdir
    if profile_dir is not None:
      tf.gfile.MakeDirs(profile_dir)

    with contrib_tfprof.ProfileContext(
        enabled=profile_dir is not None, profile_dir=profile_dir):
      init_fn = None
      if FLAGS.tf_initial_checkpoint:
        init_fn = train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True)

      slim.learning.train(
          train_tensor,
          train_step_fn=train_step,
          logdir=FLAGS.train_logdir,
          log_every_n_steps=FLAGS.log_steps,
          master=FLAGS.master,
          number_of_steps=FLAGS.training_number_of_steps,
          is_chief=(FLAGS.task == 0),
          session_config=session_config,
          startup_delay_steps=startup_delay_steps,
          init_fn=init_fn,
          summary_op=summary_op,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
    OUTPUTS_DIR = os.getenv('VH_OUTPUTS_DIR', None)

    checkpoints = tarfile.open(
        os.path.join(OUTPUTS_DIR, 'checkpoints.tar.gz'), 'w:gz')
    for dirpath, dirnames, filenames in os.walk(FLAGS.train_logdir):
      for filename in filenames:
        filepath = os.path.join(dirpath, filename)
        checkpoints.add(filepath, filename)
    checkpoints.close()
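
The quantization branch above (FLAGS.quantize_delay_step) rewrites the training graph in place with fake-quantization ops. A minimal stand-alone sketch of that TF 1.x pattern, assuming a placeholder model_fn that builds the model and returns a scalar loss:

import tensorflow as tf
from tensorflow.contrib import quantize as contrib_quantize


def build_quantized_training_graph(model_fn, quant_delay=2000):
  """Builds model_fn's graph and rewrites it for quantization-aware training."""
  g = tf.Graph()
  with g.as_default():
    loss = model_fn()  # placeholder: builds the model and returns a scalar loss
    # Insert fake-quant ops; quantization only kicks in after `quant_delay`
    # steps, matching the quantize_delay_step flag used above.
    contrib_quantize.create_training_graph(quant_delay=quant_delay)
    optimizer = tf.train.GradientDescentOptimizer(0.01)
    train_op = optimizer.minimize(loss)
  return g, train_op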
Example #17
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.worker_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
      FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
      FLAGS.model_name,
      num_classes=(dataset.num_classes - FLAGS.labels_offset),
      weight_decay=FLAGS.weight_decay,
      is_training=True,
      width_multiplier=FLAGS.width_multiplier)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
      preprocessing_name,
      is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=FLAGS.num_readers,
        common_queue_capacity=20 * FLAGS.batch_size,
        common_queue_min=10 * FLAGS.batch_size)

      # gt_bboxes format [ymin, xmin, ymax, xmax]
      [image, img_shape, gt_labels, gt_bboxes] = provider.get(['image', 'shape',
                                                               'object/label',
                                                               'object/bbox'])

      # Preprocessing
      # gt_bboxes = scale_bboxes(gt_bboxes, img_shape)  # bboxes format [0,1) for tf draw

      image, gt_labels, gt_bboxes = image_preprocessing_fn(image,
                                                           config.IMG_HEIGHT,
                                                           config.IMG_WIDTH,
                                                           labels=gt_labels,
                                                           bboxes=gt_bboxes,
                                                           )

      #############################################
      # Encode annotations for losses computation #
      #############################################

      # anchors format [cx, cy, w, h]
      anchors = tf.convert_to_tensor(config.ANCHOR_SHAPE, dtype=tf.float32)

      # encode annos, box_input format [cx, cy, w, h]
      input_mask, labels_input, box_delta_input, box_input = encode_annos(gt_labels,
                                                                          gt_bboxes,
                                                                          anchors,
                                                                          config.NUM_CLASSES)

      images, b_input_mask, b_labels_input, b_box_delta_input, b_box_input = tf.train.batch(
        [image, input_mask, labels_input, box_delta_input, box_input],
        batch_size=FLAGS.batch_size,
        num_threads=FLAGS.num_preprocessing_threads,
        capacity=5 * FLAGS.batch_size)

      batch_queue = slim.prefetch_queue.prefetch_queue(
        [images, b_input_mask, b_labels_input, b_box_delta_input, b_box_input], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      images, b_input_mask, b_labels_input, b_box_delta_input, b_box_input = batch_queue.dequeue()
      anchors = tf.convert_to_tensor(config.ANCHOR_SHAPE, dtype=tf.float32)
      end_points = network_fn(images)
      end_points["viz_images"] = images
      conv_ds_14 = end_points['MobileNet/conv_ds_14/depthwise_conv']
      dropout = slim.dropout(conv_ds_14, keep_prob=0.5, is_training=True)
      num_output = config.NUM_ANCHORS * (config.NUM_CLASSES + 1 + 4)
      predict = slim.conv2d(dropout, num_output, kernel_size=(3, 3), stride=1, padding='SAME',
                            activation_fn=None,
                            weights_initializer=tf.truncated_normal_initializer(stddev=0.0001),
                            scope="MobileNet/conv_predict")

      with tf.name_scope("Interpre_prediction") as scope:
        pred_box_delta, pred_class_probs, pred_conf, ious, det_probs, det_boxes, det_class = \
          interpre_prediction(predict, b_input_mask, anchors, b_box_input)
        end_points["viz_det_probs"] = det_probs
        end_points["viz_det_boxes"] = det_boxes
        end_points["viz_det_class"] = det_class

      with tf.name_scope("Losses") as scope:
        losses(b_input_mask, b_labels_input, ious, b_box_delta_input, pred_class_probs, pred_conf, pred_box_delta)

      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      if end_point not in ["viz_images", "viz_det_probs", "viz_det_boxes", "viz_det_class"]:
        x = end_points[end_point]
        summaries.add(tf.summary.histogram('activations/' + end_point, x))
        summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                        tf.nn.zero_fraction(x)))

    # Add summaries for detection results. TODO(shizehao): visualize predictions.


    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
        FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
        opt=optimizer,
        replicas_to_aggregate=FLAGS.replicas_to_aggregate,
        variable_averages=variable_averages,
        variables_to_average=moving_average_variables,
        replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
        total_num_replicas=FLAGS.worker_replicas)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # Optimize over all clones; returns the total loss and the gradients to apply.
    total_loss, clones_gradients = model_deploy.optimize_clones(
      clones,
      optimizer,
      var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                      name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
      train_tensor,
      logdir=FLAGS.train_dir,
      master=FLAGS.master,
      is_chief=(FLAGS.task == 0),
      init_fn=_get_init_fn(),
      summary_op=summary_op,
      number_of_steps=FLAGS.max_number_of_steps,
      log_every_n_steps=FLAGS.log_every_n_steps,
      save_summaries_secs=FLAGS.save_summaries_secs,
      save_interval_secs=FLAGS.save_interval_secs,
      sync_optimizer=optimizer if FLAGS.sync_replicas else None)
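
When sync_replicas is off, Example #17 folds the exponential-moving-average update into the same update_op that applies the gradients. A minimal sketch of that wiring on a placeholder loss tensor (my_loss is an assumption, not a name from the example):

import tensorflow as tf


def build_train_op_with_ema(my_loss, decay=0.999):
  """Groups the gradient update with the EMA update, as in Example #17."""
  global_step = tf.train.get_or_create_global_step()
  optimizer = tf.train.GradientDescentOptimizer(0.1)
  grad_updates = optimizer.apply_gradients(
      optimizer.compute_gradients(my_loss), global_step=global_step)

  ema = tf.train.ExponentialMovingAverage(decay, global_step)
  ema_update = ema.apply(tf.trainable_variables())  # the extra update op

  update_op = tf.group(grad_updates, ema_update)
  with tf.control_dependencies([update_op]):
    return tf.identity(my_loss, name='train_op')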
Example #18
def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):
    """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    graph_hook_fn: Optional function that is called after the training graph is
      completely built. This is helpful to perform additional changes to the
      training graph such as optimizing batchnorm. The function should modify
      the default graph.
  """

    detection_model = create_model_fn()

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            input_queue = create_input_queue(create_tensor_dict_fn)

        # Gather initial summaries.
        # TODO(rathodv): See if summaries can be added/extracted from global tf
        # collections so that they don't have to be passed around.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn,
                                     train_config=train_config)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)
            for var in optimizer_summary_vars:
                tf.summary.scalar(var.op.name, var)

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=train_config.worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            restore_checkpoints = [
                path.strip()
                for path in train_config.fine_tune_checkpoint.split(',')
            ]

            restorers = get_restore_checkpoint_ops(restore_checkpoints,
                                                   detection_model,
                                                   train_config)

            def initializer_fn(sess):
                for i, restorer in enumerate(restorers):
                    restorer.restore(sess, restore_checkpoints[i])

            init_fn = initializer_fn

        with tf.device(deploy_config.optimizer_device()):
            regularization_losses = (
                None if train_config.add_regularization_loss else [])
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones,
                training_optimizer,
                regularization_losses=regularization_losses)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                0.9999, global_step)
            update_ops.append(
                variable_averages.apply(moving_average_variables))

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        if graph_hook_fn:
            with tf.device(deploy_config.variables_device()):
                graph_hook_fn()

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, 'critic_loss'))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows ops without a GPU implementation to be placed on the CPU.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
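
Example #18 post-processes the clone gradients before applying them: bias gradients can be scaled and every gradient can be clipped by norm. A self-contained sketch of those two steps in plain TF 1.x (dense gradients assumed; bias_multiplier and clip_norm are illustrative values):

import re

import tensorflow as tf


def postprocess_gradients(grads_and_vars, bias_multiplier=2.0, clip_norm=10.0):
  """Scales bias gradients and clips all gradients, mimicking the helpers above."""
  processed = []
  for grad, var in grads_and_vars:
    if grad is not None and re.search(r'/biases', var.op.name):
      grad = grad * bias_multiplier  # cf. multiply_gradients_matching_regex
    if grad is not None:
      grad = tf.clip_by_norm(grad, clip_norm)  # cf. slim.learning.clip_gradient_norms
    processed.append((grad, var))
  return processed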
Example #19
def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):
  """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of worker replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    graph_hook_fn: Optional function that is called after the training graph is
      completely built. This is helpful to perform additional changes to the
      training graph such as optimizing batchnorm. The function should modify
      the default graph.
  """

  detection_model = create_model_fn()
  data_augmentation_options = [
      preprocessor_builder.build(step)
      for step in train_config.data_augmentation_options]

  with tf.Graph().as_default():
    # Build a configuration specifying multi-GPU and multi-replicas.
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=worker_replicas,
        num_ps_tasks=ps_tasks,
        worker_job_name=worker_job_name)

    # Place the global step on the device storing the variables.
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    with tf.device(deploy_config.inputs_device()):
      input_queue = create_input_queue(
          train_config.batch_size // num_clones, create_tensor_dict_fn,
          train_config.batch_queue_capacity,
          train_config.num_batch_queue_threads,
          train_config.prefetch_queue_capacity, data_augmentation_options)

    # Gather initial summaries.
    # TODO(rathodv): See if summaries can be added/extracted from global tf
    # collections so that they don't have to be passed around.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    global_summaries = set([])

    model_fn = functools.partial(_create_losses,
                                 create_model_fn=create_model_fn,
                                 train_config=train_config)
    clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
    first_clone_scope = clones[0].scope

    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    with tf.device(deploy_config.optimizer_device()):
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)
      for var in optimizer_summary_vars:
        tf.summary.scalar(var.op.name, var, family='LearningRate')

    sync_optimizer = None
    if train_config.sync_replicas:
      training_optimizer = tf.train.SyncReplicasOptimizer(
          training_optimizer,
          replicas_to_aggregate=train_config.replicas_to_aggregate,
          total_num_replicas=worker_replicas)
      sync_optimizer = training_optimizer

    with tf.device(deploy_config.optimizer_device()):
      regularization_losses = (None if train_config.add_regularization_loss
                               else [])
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, training_optimizer,
          regularization_losses=regularization_losses)
      total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

      # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
      if train_config.bias_grad_multiplier:
        biases_regex_list = ['.*/biases']
        grads_and_vars = variables_helper.multiply_gradients_matching_regex(
            grads_and_vars,
            biases_regex_list,
            multiplier=train_config.bias_grad_multiplier)

      # Optionally freeze some layers by setting their gradients to be zero.
      if train_config.freeze_variables:
        grads_and_vars = variables_helper.freeze_gradients_matching_regex(
            grads_and_vars, train_config.freeze_variables)

      # Optionally clip gradients
      if train_config.gradient_clipping_by_norm > 0:
        with tf.name_scope('clip_grads'):
          grads_and_vars = slim.learning.clip_gradient_norms(
              grads_and_vars, train_config.gradient_clipping_by_norm)

      # Create gradient updates.
      grad_updates = training_optimizer.apply_gradients(grads_and_vars,
                                                        global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops, name='update_barrier')
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    if graph_hook_fn:
      with tf.device(deploy_config.variables_device()):
        graph_hook_fn()

    # Add summaries.
    for model_var in slim.get_model_variables():
      global_summaries.add(tf.summary.histogram('ModelVars/' +
                                                model_var.op.name, model_var))
    for loss_tensor in tf.losses.get_losses():
      global_summaries.add(tf.summary.scalar('Losses/' + loss_tensor.op.name,
                                             loss_tensor))
    global_summaries.add(
        tf.summary.scalar('Losses/TotalLoss', tf.losses.get_total_loss()))

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))
    summaries |= global_summaries

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    # Soft placement allows ops without a GPU implementation to be placed on the CPU.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

    # Save checkpoints regularly.
    keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
    saver = tf.train.Saver(
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

    # Create ops required to initialize the model from a given checkpoint.
    init_fn = None
    if train_config.fine_tune_checkpoint:
      if not train_config.fine_tune_checkpoint_type:
        # train_config.from_detection_checkpoint field is deprecated. For
        # backward compatibility, fine_tune_checkpoint_type is set based on
        # from_detection_checkpoint.
        if train_config.from_detection_checkpoint:
          train_config.fine_tune_checkpoint_type = 'detection'
        else:
          train_config.fine_tune_checkpoint_type = 'classification'
      var_map = detection_model.restore_map(
          fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
          load_all_detection_checkpoint_vars=(
              train_config.load_all_detection_checkpoint_vars))
      available_var_map = (variables_helper.
                           get_variables_available_in_checkpoint(
                               var_map, train_config.fine_tune_checkpoint))
      init_saver = tf.train.Saver(available_var_map)
      def initializer_fn(sess):
        init_saver.restore(sess, train_config.fine_tune_checkpoint)
      init_fn = initializer_fn

    slim.learning.train(
        train_tensor,
        logdir=train_dir,
        master=master,
        is_chief=is_chief,
        session_config=session_config,
        startup_delay_steps=train_config.startup_delay_steps,
        init_fn=init_fn,
        summary_op=summary_op,
        number_of_steps=(
            train_config.num_steps if train_config.num_steps else None),
        save_summaries_secs=120,
        sync_optimizer=sync_optimizer,
        saver=saver)
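
Example #19 warm-starts from train_config.fine_tune_checkpoint by building a Saver over only the variables that actually exist in that checkpoint and wrapping the restore in an init_fn. A minimal core-TF sketch of the same idea, without the Object Detection API helpers (checkpoint_path is an assumed argument):

import tensorflow as tf


def make_init_fn(checkpoint_path):
  """Returns an init_fn restoring only variables present in the checkpoint."""
  reader = tf.train.NewCheckpointReader(checkpoint_path)
  ckpt_shapes = reader.get_variable_to_shape_map()
  # Keep variables whose name and shape both match the checkpoint.
  var_map = {
      v.op.name: v
      for v in tf.global_variables()
      if v.op.name in ckpt_shapes and v.shape.as_list() == ckpt_shapes[v.op.name]
  }
  init_saver = tf.train.Saver(var_map)

  def init_fn(sess):
    init_saver.restore(sess, checkpoint_path)

  return init_fn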
Example #20
def main(train_dir, num_clones, clone_on_cpu, train_size, val_size,
         num_classes, worker_replicas, log_every_n_steps, save_interval_secs,
         weight_decay, opt, l_r, moving_average_decay, d_set,
         max_number_of_steps, check):

    num_ps_tasks = 0
    num_readers = 4
    num_preprocessing_threads = 4
    task = 0

    if not d_set['dataset_dir']:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(d_set['dataset_name'],
                                              d_set['dataset_split_name'],
                                              d_set['dataset_dir'],
                                              train_size=train_size,
                                              val_size=val_size,
                                              num_classes=num_classes)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            d_set['model_name'],
            num_classes=(dataset.num_classes - d_set['labels_offset']),
            weight_decay=weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = d_set['preprocessing_name'] or d_set['model_name']
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=num_readers,
                common_queue_capacity=20 * d_set['batch_size'],
                common_queue_min=10 * d_set['batch_size'])
            [image, label] = provider.get(['image', 'label'])
            label -= d_set['labels_offset']

            train_image_size = (d_set['train_image_size'] or
                                network_fn.default_image_size)

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=d_set['batch_size'],
                num_threads=num_preprocessing_threads,
                capacity=5 * d_set['batch_size'])
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - d_set['labels_offset'])
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=l_r['label_smoothing'],
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=l_r['label_smoothing'],
                weights=1.0)
            return end_points

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs

        #################################
        # Configure the moving averages #
        #################################
        if moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step, l_r,
                                                     d_set['batch_size'])
            optimizer = _configure_optimizer(learning_rate, opt)

        if l_r['sync_replicas']:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=l_r['replicas_to_aggregate'],
                total_num_replicas=worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train(check['trainable_scopes'])

        # Optimize over all clones; returns the total loss and the gradients to apply.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master='',
            is_chief=(task == 0),
            init_fn=_get_init_fn(train_dir, check),
            number_of_steps=max_number_of_steps,
            log_every_n_steps=log_every_n_steps,
            save_interval_secs=save_interval_secs,
            sync_optimizer=optimizer if l_r['sync_replicas'] else None)
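
_configure_learning_rate above derives the decay schedule from the dataset size and batch size. A minimal sketch of the usual exponential-decay variant; the l_r keys used here are assumptions mirroring TF-Slim's train_image_classifier, not definitions from the example:

import tensorflow as tf


def configure_exponential_lr(num_samples_per_epoch, global_step, l_r, batch_size):
  """Exponential decay: the rate drops every `num_epochs_per_decay` epochs."""
  decay_steps = int(num_samples_per_epoch // batch_size *
                    l_r['num_epochs_per_decay'])
  return tf.train.exponential_decay(
      l_r['learning_rate'],
      global_step,
      decay_steps,
      l_r['learning_rate_decay_factor'],
      staircase=True,
      name='exponential_decay_learning_rate')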
Example #21
def train(data_dicts,
          class_num,
          input_size,
          lr,
          n_epochs,
          num_clones,
          iters_cnt,
          val_every,
          model_init_fn,
          save_cback,
          atrous_rates=[6, 12, 18],
          fine_tune_batch_norm=True,
          output_stride=16):
    tf.logging.set_verbosity(tf.logging.INFO)

    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=num_clones,
                                           clone_on_cpu=clone_on_cpu,
                                           replica_id=task,
                                           num_replicas=num_replicas,
                                           num_ps_tasks=num_ps_tasks)

    with tf.Graph().as_default():
        with tf.device(config.inputs_device()):
            samples = get(data_dicts['train'],
                          input_size,
                          is_training=True,
                          model_variant=model_variant)
            samples_val = get(data_dicts['val'],
                              input_size,
                              is_training=True,
                              model_variant=model_variant)

        inputs_queue = prefetch_queue.prefetch_queue(samples,
                                                     capacity=128 *
                                                     config.num_clones,
                                                     dynamic_pad=True)
        inputs_queue_val = prefetch_queue.prefetch_queue(samples_val,
                                                         capacity=128 *
                                                         config.num_clones,
                                                         dynamic_pad=True)
        coord = tf.train.Coordinator()

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (inputs_queue, {
                'semantic': class_num
            }, input_size, atrous_rates, output_stride, fine_tune_batch_norm)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = lr
            optimizer = tf.train.AdamOptimizer(learning_rate)

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')

            model_fn_val = _build_deeplab_val
            model_args_val = (inputs_queue_val, {
                'semantic': class_num
            }, input_size, atrous_rates, output_stride)
            val_clones, val_losses = create_val_clones(num_clones,
                                                       config,
                                                       model_fn_val,
                                                       args=model_args_val)
            val_total_loss = get_clones_val_losses(val_clones, None,
                                                   val_losses)
            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes()
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        coord.clear_stop()
        sess = tf.Session(config=session_config)

        graph = ops.get_default_graph()
        with graph.as_default():
            with ops.name_scope('init_ops'):
                init_op = variables.global_variables_initializer()
                ready_op = variables.report_uninitialized_variables()
                local_init_op = control_flow_ops.group(
                    variables.local_variables_initializer(),
                    lookup_ops.tables_initializer())

        # graph.finalize()
        sess.run([init_op, ready_op, local_init_op])
        queue_runners = graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
        threads = []
        for qr in queue_runners:
            threads.extend(
                qr.create_threads(sess, coord=coord, daemon=True, start=True))

        # # for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
        # #     print(i)
        # vary_23 = [v for v in tf.global_variables() if v.name == 'xception_65/middle_flow/block1/unit_8/xception_module/separable_conv3_depthwise/BatchNorm/moving_mean:0'][0]
        #
        # beta_23 = [v for v in tf.global_variables() if v.name == 'xception_65/middle_flow/block1/unit_8/xception_module/separable_conv3_depthwise/BatchNorm/gamma:0'][0]
        # for i in range(1000):
        #     train_loss = sess.run(train_tensor)
        #     print(train_loss)
        #     vary, beta = sess.run([vary_23, beta_23])
        #     print('mean', vary[0:3])
        #     print('beta', beta[0:3])
        #     if (i + 1) % 10 == 0:
        #         for i in range(10):
        #             val_loss = sess.run(val_total_loss)
        #             vary, beta = sess.run([vary_23, beta_23])
        #             print('mean val', vary[0:3])
        #             print('beta', beta[0:3])
        #             print('VAl_loss', val_loss)

        model_init_fn(sess)
        saver = tf.train.Saver()
        eval_planner = EvalPlanner(n_epochs, val_every)
        progress = sly.progress_counter_train(n_epochs, iters_cnt['train'])
        best_val_loss = float('inf')
        epoch_flt = 0

        for epoch in range(n_epochs):
            logger.info("Before new epoch", extra={'epoch': epoch_flt})
            for train_it in range(iters_cnt['train']):
                total_loss = sess.run(train_tensor)

                metrics_values_train = {
                    'loss': total_loss,
                }

                progress.iter_done_report()
                epoch_flt = epoch_float(epoch, train_it + 1,
                                        iters_cnt['train'])
                sly.report_metrics_training(epoch_flt, metrics_values_train)

                if eval_planner.need_validation(epoch_flt):
                    logger.info("Before validation",
                                extra={'epoch': epoch_flt})

                    overall_val_loss = 0
                    for val_it in range(iters_cnt['val']):
                        overall_val_loss += sess.run(val_total_loss)

                        logger.info("Validation in progress",
                                    extra={
                                        'epoch': epoch_flt,
                                        'val_iter': val_it,
                                        'val_iters': iters_cnt['val']
                                    })

                    metrics_values_val = {
                        'loss': overall_val_loss / iters_cnt['val'],
                    }
                    sly.report_metrics_validation(epoch_flt,
                                                  metrics_values_val)
                    logger.info("Validation has been finished",
                                extra={'epoch': epoch_flt})

                    eval_planner.validation_performed()

                    val_loss = metrics_values_val['loss']
                    model_is_best = val_loss < best_val_loss
                    if model_is_best:
                        best_val_loss = val_loss
                        logger.info(
                            'It\'s been determined that current model is the best one for a while.'
                        )

                    save_cback(saver,
                               sess,
                               model_is_best,
                               opt_data={
                                   'epoch': epoch_flt,
                                   'val_metrics': metrics_values_val,
                               })

            logger.info("Epoch was finished", extra={'epoch': epoch_flt})
Example #22
def main():
    print(args)
    prt('')

    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    if not os.path.isdir(log_dir):
        # Create the log directory if it doesn't exist.
        os.makedirs(log_dir)
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(model_dir):
        # Create the model directory if it doesn't exist.
        os.makedirs(model_dir)

    # Store some git revision info in a text file in the log directory
    src_path, _ = os.path.split(os.path.realpath(__file__))

    np.random.seed(seed=args.seed)

    print('Model directory: %s' % model_dir)
    print('Log directory: %s' % log_dir)
    if args.pretrained_model:
        print('Pre-trained model: %s' %
              os.path.expanduser(args.pretrained_model))

    with tf.Graph().as_default():
        deploy_config = model_deploy.DeploymentConfig(num_clones=args.num_gpus,
                                                      clone_on_cpu=False)
        tf.set_random_seed(args.seed)
        #global_step = tf.Variable(0, trainable=False)
        global_step = variables.get_or_create_global_step()

        # Placeholder for the learning rate
        #learning_rate_placeholder = tf.placeholder(tf.float32, name='learning_rate')

        #batch_size_placeholder = tf.placeholder(tf.int32, name='batch_size')
        with tf.device('/cpu:0'):
            is_training_placeholder = tf.placeholder(tf.bool,
                                                     name='is_training')
            image_paths_placeholder = tf.placeholder(tf.string,
                                                     shape=(None, 3),
                                                     name='image_paths')
            labels_placeholder = tf.placeholder(tf.int64,
                                                shape=(None, 3),
                                                name='labels')

            input_queue = data_flow_ops.FIFOQueue(capacity=100000,
                                                  dtypes=[tf.string, tf.int64],
                                                  shapes=[(3, ), (3, )],
                                                  shared_name=None,
                                                  name=None)
            enqueue_op = input_queue.enqueue_many(
                [image_paths_placeholder, labels_placeholder])

            nrof_preprocess_threads = 8
            images_and_labels = []
            for _ in range(nrof_preprocess_threads):
                filenames, label = input_queue.dequeue()
                #filenames = tf.Print(filenames, [tf.shape(filenames)], 'filenames shape:')
                images = []
                for filename in tf.unstack(filenames):
                    #filename = tf.Print(filename, [filename], 'filename = ')
                    file_contents = tf.read_file(filename)
                    image = tf.image.decode_jpeg(file_contents)
                    #image = tf.Print(image, [tf.shape(image)], 'data count = ')
                    if image.dtype != tf.float32:
                        image = tf.image.convert_image_dtype(image,
                                                             dtype=tf.float32)
                    if args.random_crop:
                        #image = tf.random_crop(image, [args.image_size, args.image_size, 3])
                        bbox = tf.constant([0.0, 0.0, 1.0, 1.0],
                                           dtype=tf.float32,
                                           shape=[1, 1, 4])
                        sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
                            tf.shape(image),
                            bounding_boxes=bbox,
                            area_range=(0.7, 1.0),
                            use_image_if_no_bounding_boxes=True)
                        bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box
                        image = tf.slice(image, bbox_begin, bbox_size)
                    #else:
                    #    image = tf.image.resize_image_with_crop_or_pad(image, args.image_size, args.image_size)
                    image = tf.expand_dims(image, 0)
                    image = tf.image.resize_bilinear(
                        image, [args.image_size, args.image_size],
                        align_corners=False)
                    image = tf.squeeze(image, [0])
                    if args.random_flip:
                        image = tf.image.random_flip_left_right(image)
                    image.set_shape((args.image_size, args.image_size, 3))
                    ##pylint: disable=no-member
                    image = tf.subtract(image, 0.5)
                    image = tf.multiply(image, 2.0)
                    #image = tf.Print(image, [tf.shape(image)], 'data count = ')
                    images.append(image)
                    #images.append(tf.image.per_image_standardization(image))
                images_and_labels.append([images, label])

            learning_rate = get_learning_rate(args)
            opt = get_optimizer(args, learning_rate)
            image_batch, label_batch = tf.train.batch_join(
                images_and_labels,
                batch_size=args.batch_size,
                shapes=[(args.image_size, args.image_size, 3), ()],
                enqueue_many=True,
                capacity=4 * nrof_preprocess_threads * args.batch_size,
                allow_smaller_final_batch=False)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [image_batch, label_batch], capacity=9000)

        def clone_fn(_batch_queue):
            _image_batch, _label_batch = _batch_queue.dequeue()
            embeddings = image_to_embedding(_image_batch,
                                            is_training_placeholder, args)

            # Split embeddings into anchor, positive and negative and calculate triplet loss
            anchor, positive, negative = tf.unstack(
                tf.reshape(embeddings, [-1, 3, args.embedding_size]), 3, 1)
            triplet_loss = triplet_loss_fn(anchor, positive, negative,
                                           args.alpha)
            tf.losses.add_loss(triplet_loss)
            #tf.summary.scalar('learning_rate', learning_rate)
            return embeddings, _label_batch, triplet_loss

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone = clones[0]
        triplet_loss = first_clone.outputs[2]
        embeddings = first_clone.outputs[0]
        _label_batch = first_clone.outputs[1]
        #embedding_clones = model_deploy.create_clones(deploy_config, embedding_fn, [batch_queue])

        #first_clone_scope = deploy_config.clone_scope(0)
        #update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
        update_ops = []
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = get_learning_rate(args)
            opt = get_optimizer(args, learning_rate)

        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, opt, var_list=tf.trainable_variables())

        grad_updates = opt.apply_gradients(clones_gradients,
                                           global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_op = control_flow_ops.with_dependencies([update_op],
                                                      total_loss,
                                                      name='train_op')

        vdic = [
            v for v in tf.trainable_variables() if v.name.find("Logits/") < 0
        ]
        pretrained_saver = tf.train.Saver(vdic)
        saver = tf.train.Saver(max_to_keep=3)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # Start running operations on the Graph.
        #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_memory_fraction)
        #sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        sess = tf.Session()

        # Initialize variables
        sess.run(tf.global_variables_initializer(),
                 feed_dict={is_training_placeholder: True})
        sess.run(tf.local_variables_initializer(),
                 feed_dict={is_training_placeholder: True})

        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)

        with sess.as_default():

            if args.pretrained_model:
                print('Restoring pretrained model: %s' % args.pretrained_model)
                pretrained_saver.restore(
                    sess, os.path.expanduser(args.pretrained_model))

            # Training and validation loop
            epoch = 0
            while epoch < args.max_nrof_epochs:
                eval_one_epoch(args, sess, dataset, image_paths_placeholder,
                               labels_placeholder, is_training_placeholder,
                               enqueue_op, clones)
                # Train for one epoch
                train_one_epoch(args, sess, dataset, image_paths_placeholder,
                                labels_placeholder, is_training_placeholder,
                                enqueue_op, input_queue, clones, total_loss,
                                train_op, summary_op, summary_writer)

                # Save variables and the metagraph if it doesn't exist already
                global_step = variables.get_or_create_global_step()
                step = sess.run(global_step, feed_dict=None)
                print('one epoch finish', step)
                save_variables_and_metagraph(sess, saver, summary_writer,
                                             model_dir, subdir, step)
                print('saver finish')

    sess.close()
    return model_dir
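
triplet_loss_fn itself is not shown in Example #22; the standard FaceNet-style formulation (assumed here) penalizes anchors that end up closer to the negative than to the positive by less than a margin alpha:

import tensorflow as tf


def triplet_loss(anchor, positive, negative, alpha):
  """Mean of max(||a - p||^2 - ||a - n||^2 + alpha, 0) over the batch."""
  pos_dist = tf.reduce_sum(tf.square(anchor - positive), axis=1)
  neg_dist = tf.reduce_sum(tf.square(anchor - negative), axis=1)
  return tf.reduce_mean(tf.maximum(pos_dist - neg_dist + alpha, 0.0),
                        name='triplet_loss')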
Example #23
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():

        #######################
        # Quantizers          #
        #######################

        if FLAGS.intr_grad_quantizer != '':
            qtype, qargs = utils.split_quantizer_str(FLAGS.intr_grad_quantizer)
            intr_grad_quantizer = utils.quantizer_selector(qtype, qargs)
        else:
            intr_grad_quantizer = None

        if FLAGS.extr_grad_quantizer != '':
            qtype, qargs = utils.split_quantizer_str(FLAGS.extr_grad_quantizer)
            extr_grad_quantizer = utils.quantizer_selector(qtype, qargs)
        else:
            extr_grad_quantizer = None

        intr_q_map = utils.quantizer_map(FLAGS.intr_qmap)
        extr_q_map = utils.quantizer_map(FLAGS.extr_qmap)
        weight_q_map = utils.quantizer_map(FLAGS.weight_qmap)

        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True,
            intr_q_map=intr_q_map,
            extr_q_map=extr_q_map,
            weight_q_map=weight_q_map)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                tf.losses.softmax_cross_entropy(
                    logits=end_points['AuxLogits'],
                    onehot_labels=labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            tf.losses.softmax_cross_entropy(
                logits=logits,
                onehot_labels=labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        #############
        # Summaries #
        #############

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparse_activations/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        # Add summaries for sparsity.
        weights_name_list, weights_list = utils.get_variables_list('weights')
        biases_name_list, biases_list = utils.get_variables_list('biases')
        for weight in weights_list:
            summaries.add(
                tf.summary.scalar('weight-sparsity/' + weight.name,
                                  tf.nn.zero_fraction(weight)))
            summaries.add(
                tf.summary.histogram('quantized_weights/' + weight.name,
                                     weight))
        for bias in biases_list:
            summaries.add(
                tf.summary.scalar('weight-sparsity/' + bias.name,
                                  tf.nn.zero_fraction(bias)))
            summaries.add(
                tf.summary.histogram('quantized_weights/' + bias.name, bias))
        # Summaries for overall sparsity.
        if weights_list:
            weights_overall_sparsity = [
                tf.reshape(x, [tf.size(x)]) for x in weights_list
            ]
            weights_overall_sparsity = tf.concat(weights_overall_sparsity,
                                                 axis=0)
            summaries.add(
                tf.summary.scalar(
                    'weight-sparsity/weights-overall',
                    tf.nn.zero_fraction(weights_overall_sparsity)))
        if biases_list:
            biases_overall_sparsity = [
                tf.reshape(x, [tf.size(x)]) for x in biases_list
            ]
            biases_overall_sparsity = tf.concat(biases_overall_sparsity,
                                                axis=0)
            summaries.add(
                tf.summary.scalar(
                    'weight-sparsity/biases-overall',
                    tf.nn.zero_fraction(biases_overall_sparsity)))

        # Add layerwise weight heatmaps
        for it in range(len(weights_name_list)):
            name = weights_name_list[it]
            weight = weights_list[it]
            if weight.get_shape().ndims == 4:
                image = utils.heatmap_conv(weight, pad=1)
            elif weight.get_shape().ndims == 2:
                image = utils.heatmap_fullyconnect(weight, pad=1)
            else:
                continue
            summaries.add(tf.summary.image(name, image))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate,
                                             intr_grad_quantizer)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
                total_num_replicas=FLAGS.worker_replicas)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Optimize over all clones; returns the total loss and the gradients.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        # quantize 'clones_gradients'
        if extr_grad_quantizer is not None:
            clones_gradients = [(extr_grad_quantizer.quantize(gv[0]), gv[1])
                                for gv in clones_gradients]

        # Add gradients to summary
        for gv in clones_gradients:
            summaries.add(
                tf.summary.histogram('gradient/%s' % gv[1].op.name, gv[0]))
            summaries.add(
                tf.summary.scalar('gradient-sparsity/%s' % gv[1].op.name,
                                  tf.nn.zero_fraction(gv[0])))

        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
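
Note: the quantizer objects produced by `utils.quantizer_selector` only need to expose a `quantize(tensor)` method, which is how `extr_grad_quantizer` is applied to the clone gradients above. The class below is an illustrative fixed-point quantizer with that interface; the class name, bit-width parameters and rounding scheme are assumptions, not the repository's actual quantizers.

import tensorflow as tf

class FixedPointQuantizer(object):
    """Illustrative fixed-point quantizer exposing the quantize() interface
    used above (the bit-width split is an assumption)."""

    def __init__(self, int_bits, frac_bits):
        self.int_bits = int_bits
        self.frac_bits = frac_bits

    def quantize(self, tensor):
        # Scale, round to the nearest representable value, then clip to the
        # range of a signed fixed-point format with the given bit widths.
        scale = 2.0 ** self.frac_bits
        max_val = 2.0 ** (self.int_bits - 1) - 1.0 / scale
        quantized = tf.round(tensor * scale) / scale
        return tf.clip_by_value(quantized, -max_val - 1.0 / scale, max_val)

# Usage mirroring the extr_grad_quantizer step above:
# quantizer = FixedPointQuantizer(int_bits=8, frac_bits=8)
# clones_gradients = [(quantizer.quantize(g), v) for g, v in clones_gradients]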
Example #24
0
def main(_):
    #tf.disable_v2_behavior() ###
    tf.compat.v1.disable_eager_execution()
    tf.compat.v1.enable_resource_variables()

    # Enable habana bf16 conversion pass
    if FLAGS.dtype == 'bf16':
        os.environ['TF_BF16_CONVERSION'] = flags.FLAGS.bf16_config_path
        FLAGS.precision = 'bf16'
    else:
        os.environ['TF_BF16_CONVERSION'] = "0"

    if FLAGS.use_horovod:
        hvd_init()

    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=True,
            use_grayscale=FLAGS.use_grayscale)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs

        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #if FLAGS.quantize_delay >= 0:
        #  quantize.create_training_graph(quant_delay=FLAGS.quantize_delay) #for debugging!!

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Optimize over all clones; returns the total loss and the gradients.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        if horovod_enabled():
            hvd.broadcast_global_variables(0)
        ###########################
        # Kicks off the training. #
        ###########################
        with dump_callback():
            with logger.benchmark_context(FLAGS):
                eps1 = ExamplesPerSecondKerasHook(FLAGS.log_every_n_steps,
                                                  output_dir=FLAGS.train_dir,
                                                  batch_size=FLAGS.batch_size)

                write_hparams_v1(
                    eps1.writer, {
                        'batch_size': FLAGS.batch_size,
                        **{x: getattr(FLAGS, x)
                           for x in FLAGS}
                    })

                train_step_kwargs = {}
                if FLAGS.max_number_of_steps:
                    should_stop_op = math_ops.greater_equal(
                        global_step, FLAGS.max_number_of_steps)
                else:
                    should_stop_op = constant_op.constant(False)
                train_step_kwargs['should_stop'] = should_stop_op
                if FLAGS.log_every_n_steps > 0:
                    train_step_kwargs['should_log'] = math_ops.equal(
                        math_ops.mod(global_step, FLAGS.log_every_n_steps), 0)

                eps1.on_train_begin()
                train_step_kwargs['EPS'] = eps1

                slim.learning.train(
                    train_tensor,
                    logdir=FLAGS.train_dir,
                    train_step_fn=train_step1,
                    train_step_kwargs=train_step_kwargs,
                    master=FLAGS.master,
                    is_chief=(FLAGS.task == 0),
                    init_fn=_get_init_fn(),
                    summary_op=summary_op,
                    summary_writer=None,
                    number_of_steps=FLAGS.max_number_of_steps,
                    log_every_n_steps=FLAGS.log_every_n_steps,
                    save_summaries_secs=FLAGS.save_summaries_secs,
                    save_interval_secs=FLAGS.save_interval_secs,
                    sync_optimizer=optimizer if FLAGS.sync_replicas else None)
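
Note: `train_step1` is passed as `train_step_fn` but is not shown here. `slim.learning.train` calls it as `train_step_fn(sess, train_op, global_step, train_step_kwargs)` and expects a `(loss, should_stop)` pair back. Below is a hedged sketch that runs one step, drives the examples-per-second hook passed in under the `'EPS'` key, and honours `should_stop`/`should_log`; the hook's per-batch method name follows the Keras callback convention and is an assumption.

import time
import tensorflow as tf

def train_step1(sess, train_op, global_step, train_step_kwargs):
    """Sketch of a custom train step compatible with slim.learning.train."""
    start_time = time.time()
    eps_hook = train_step_kwargs.get('EPS')

    # Run one optimization step; train_op is the train_tensor built above,
    # so it also returns the total loss.
    total_loss, np_global_step = sess.run([train_op, global_step])

    if eps_hook is not None:
        # Assumed Keras-style callback interface for the throughput hook.
        eps_hook.on_batch_end(np_global_step, logs={'loss': total_loss})

    if 'should_log' in train_step_kwargs:
        if sess.run(train_step_kwargs['should_log']):
            tf.logging.info('global step %d: loss = %.4f (%.3f sec/step)',
                            np_global_step, total_loss,
                            time.time() - start_time)

    should_stop = False
    if 'should_stop' in train_step_kwargs:
        should_stop = sess.run(train_step_kwargs['should_stop'])
    return total_loss, should_stop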
Example #25
0
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.DEBUG)
    with tf.Graph().as_default():
        # Config model_deploy. Keep TF Slim Models structure.
        # Useful if we want to use multiple GPUs and/or servers in the future.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=0,
            num_replicas=1,
            num_ps_tasks=0)
        # Create global_step.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        # Select the dataset.
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        # Get the SSD network and its anchors.
        ssd_class = nets_factory.get_network(FLAGS.model_name)
        ssd_params = ssd_class.default_params._replace(
            num_classes=FLAGS.num_classes)
        ssd_net = ssd_class(ssd_params)
        ssd_shape = ssd_net.params.img_shape
        ssd_anchors = ssd_net.anchors(ssd_shape)

        # Select the preprocessing function.
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        tf_utils.print_configuration(FLAGS.__flags, ssd_params,
                                     dataset.data_sources, FLAGS.train_dir)
        # =================================================================== #
        # Create a dataset provider and batches.
        # =================================================================== #
        with tf.device(deploy_config.inputs_device()):
            with tf.name_scope(FLAGS.dataset_name + '_data_provider'):
                provider = slim.dataset_data_provider.DatasetDataProvider(
                    dataset,
                    num_readers=FLAGS.num_readers,
                    common_queue_capacity=20 * FLAGS.batch_size,
                    common_queue_min=10 * FLAGS.batch_size,
                    shuffle=True)
            # Get for SSD network: image, labels, bboxes.
            [image, shape, glabels, gbboxes] = provider.get(
                ['image', 'shape', 'object/label', 'object/bbox'])
            # Pre-processing image, labels and bboxes.
            image, glabels, gbboxes = \
                image_preprocessing_fn(image, glabels, gbboxes,
                                       out_shape=ssd_shape,
                                       data_format=DATA_FORMAT)
            # Encode groundtruth labels and bboxes.
            gclasses, glocalisations, gscores = \
                ssd_net.bboxes_encode(glabels, gbboxes, ssd_anchors)
            batch_shape = [1] + [len(ssd_anchors)] * 3

            # Training batches and queue.
            r = tf.train.batch(tf_utils.reshape_list(
                [image, gclasses, glocalisations, gscores]),
                               batch_size=FLAGS.batch_size,
                               num_threads=FLAGS.num_preprocessing_threads,
                               capacity=5 * FLAGS.batch_size)
            b_image, b_gclasses, b_glocalisations, b_gscores = \
                tf_utils.reshape_list(r, batch_shape)

            # Intermediate queueing: unique batch computation pipeline for all
            # GPUs running the training.
            batch_queue = slim.prefetch_queue.prefetch_queue(
                tf_utils.reshape_list(
                    [b_image, b_gclasses, b_glocalisations, b_gscores]),
                capacity=2 * deploy_config.num_clones)

        # =================================================================== #
        # Define the model running on every GPU.
        # =================================================================== #
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple
            clones of network_fn."""
            # Dequeue batch.
            b_image, b_gclasses, b_glocalisations, b_gscores = \
                tf_utils.reshape_list(batch_queue.dequeue(), batch_shape)

            # Construct SSD network.
            arg_scope = ssd_net.arg_scope(weight_decay=FLAGS.weight_decay,
                                          data_format=DATA_FORMAT)
            with slim.arg_scope(arg_scope):
                predictions, localisations, logits, end_points = \
                    ssd_net.net(b_image, is_training=True)
            # Add loss function.
            ssd_net.losses(logits,
                           localisations,
                           b_gclasses,
                           b_glocalisations,
                           b_gscores,
                           match_threshold=FLAGS.match_threshold,
                           negative_ratio=FLAGS.negative_ratio,
                           alpha=FLAGS.loss_alpha,
                           label_smoothing=FLAGS.label_smoothing)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # =================================================================== #
        # Add summaries from first clone.
        # =================================================================== #
        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))
        # Add summaries for losses and extra losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))
        for loss in tf.get_collection('EXTRA_LOSSES', first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        # =================================================================== #
        # Configure the moving averages.
        # =================================================================== #
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        # =================================================================== #
        # Configure the optimization procedure.
        # =================================================================== #
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = tf_utils.configure_learning_rate(
                FLAGS, dataset.num_samples, global_step)
            optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = tf_utils.get_variables_to_train(FLAGS)

        # Optimize over all clones; returns the total loss and the gradients.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # =================================================================== #
        # Kicks off the training.
        # =================================================================== #
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
        config = tf.ConfigProto(log_device_placement=False,
                                gpu_options=gpu_options)
        saver = tf.train.Saver(max_to_keep=5,
                               keep_checkpoint_every_n_hours=1.0,
                               write_version=2,
                               pad_step_number=False)
        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_dir,
                            master='',
                            is_chief=True,
                            init_fn=tf_utils.get_init_fn(FLAGS),
                            summary_op=summary_op,
                            number_of_steps=FLAGS.max_number_of_steps,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            saver=saver,
                            save_interval_secs=FLAGS.save_interval_secs,
                            session_config=config,
                            sync_optimizer=None)
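
Note: `tf_utils.reshape_list` is used above to flatten the `(image, classes, localisations, scores)` structure before `tf.train.batch` and to re-nest the batched result with `batch_shape`, but it is not defined in this snippet. The sketch below is consistent with those call sites, not necessarily the repository's exact code.

def reshape_list(l, shape=None):
    """Flatten a nested list of tensors, or re-nest a flat list.

    With shape=None, nested sub-lists are concatenated into one flat list
    (so they can be fed to tf.train.batch). With a shape such as
    [1, n, n, n], the flat list is regrouped: a 1 keeps a single element,
    an integer k > 1 produces a sub-list of k elements.
    """
    r = []
    if shape is None:
        # Flatten.
        for a in l:
            if isinstance(a, (list, tuple)):
                r += list(a)
            else:
                r.append(a)
    else:
        # Re-nest according to shape.
        i = 0
        for s in shape:
            if s == 1:
                r.append(l[i])
            else:
                r.append(l[i:i + s])
            i += s
    return r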
Example #26
0
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_biasCNN.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        dataset_val = dataset_biasCNN.get_dataset(FLAGS.dataset_name,
                                                  'validation',
                                                  FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        network_fn_val = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            is_training=False)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_biasCNN.get_preprocessing(
            preprocessing_name,
            is_training=True,
            flipLR=FLAGS.flipLR,
            random_scale=FLAGS.random_scale,
            is_windowed=FLAGS.is_windowed)

        image_preprocessing_fn_val = preprocessing_biasCNN.get_preprocessing(
            preprocessing_name,
            is_training=False,
            flipLR=FLAGS.flipLR,
            random_scale=FLAGS.random_scale,
            is_windowed=FLAGS.is_windowed)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ############################################
        # Create a provider for the validation set #
        ############################################
        provider_val = slim.dataset_data_provider.DatasetDataProvider(
            dataset_val,
            shuffle=True,
            common_queue_capacity=2 * FLAGS.batch_size_val,
            common_queue_min=FLAGS.batch_size_val)
        [image_val, label_val] = provider_val.get(['image', 'label'])
        label_val -= FLAGS.labels_offset

        eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

        image_val = image_preprocessing_fn_val(image_val, eval_image_size,
                                               eval_image_size)

        images_val, labels_val = tf.train.batch(
            [image_val, label_val],
            batch_size=FLAGS.batch_size_val,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size_val)

        ###############################
        # Define the model (training) #
        ###############################

        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()

            with tf.variable_scope('my_scope'):
                logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        if FLAGS.quantize_delay >= 0:
            tf.contrib.quantize.create_training_graph(
                quant_delay=FLAGS.quantize_delay)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Optimize over all clones; returns the total loss and the gradients.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        #################################
        # Define the model (validation) #
        #################################

        with tf.variable_scope('my_scope', reuse=True):
            logits_val, _ = network_fn_val(images_val)

        predictions_val = tf.argmax(logits_val, 1)
        labels_val = tf.squeeze(labels_val)

        # Define the metrics:
        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
            'Accuracy':
            slim.metrics.streaming_accuracy(predictions_val, labels_val),
            'Recall_5':
            slim.metrics.streaming_recall_at_k(logits_val, labels_val, 5),
        })

        for name, value in names_to_values.items():
            summary_name = 'eval/%s' % name
            op = tf.summary.scalar(summary_name, value, collections=[])
            op = tf.Print(op, [value], summary_name)
            tf.add_to_collection('summaries', op)

        # Gather validation summaries
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Create a non-default saver so we don't delete all the old checkpoints.
        my_saver = tf_saver.Saver(
            max_to_keep=FLAGS.max_checkpoints_to_keep,
            keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
        )

        # Create a non-default dictionary of options for train_step_fn
        # This is a hack that lets us pass everything we need to run evaluation
        # into the training-loop function.
        from tensorflow.python.framework import ops
        from tensorflow.python.framework import constant_op
        from tensorflow.python.ops import math_ops

        with ops.name_scope('train_step'):
            train_step_kwargs = {}

            if FLAGS.max_number_of_steps:
                should_stop_op = math_ops.greater_equal(
                    global_step, FLAGS.max_number_of_steps)
            else:
                should_stop_op = constant_op.constant(False)
            train_step_kwargs['should_stop'] = should_stop_op
            if FLAGS.log_every_n_steps > 0:
                train_step_kwargs['should_log'] = math_ops.equal(
                    math_ops.mod(global_step, FLAGS.log_every_n_steps), 0)
            train_step_kwargs['should_val'] = math_ops.equal(
                math_ops.mod(global_step, FLAGS.val_every_n_steps), 0)
            train_step_kwargs['eval_op'] = list(names_to_updates.values())


        # assert(FLAGS.max_number_of_steps == 100000)
        print(should_stop_op)
        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None,
            saver=my_saver,
            train_step_fn=learning_biasCNN.train_step_fn,
            train_step_kwargs=train_step_kwargs)
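
Note: `_get_variables_to_train` (like `_configure_learning_rate`, `_configure_optimizer` and `_get_init_fn`) comes from the standard TF-Slim `train_image_classifier.py` scaffolding and is not reproduced in these examples. A sketch of the usual scope-filtering logic, assuming a module-level `FLAGS` with a comma-separated `trainable_scopes` flag as in that script:

import tensorflow as tf

def _get_variables_to_train():
    """Return the variables to optimize.

    If FLAGS.trainable_scopes is unset, train everything; otherwise train
    only the variables whose scope matches one of the listed scopes.
    """
    if FLAGS.trainable_scopes is None:
        return tf.trainable_variables()
    scopes = [scope.strip() for scope in FLAGS.trainable_scopes.split(',')]

    variables_to_train = []
    for scope in scopes:
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
        variables_to_train.extend(variables)
    return variables_to_train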
Example #27
0
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')

    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    # Get dataset-dependent information.
    dataset = segmentation_dataset.get_dataset(FLAGS.dataset,
                                               FLAGS.train_split,
                                               dataset_dir=FLAGS.dataset_dir)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            samples = input_generator.get(
                dataset,
                FLAGS.train_crop_size,
                clone_batch_size,
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                dataset_split=FLAGS.train_split,
                is_training=True,
                model_variant=FLAGS.model_variant)
            inputs_queue = prefetch_queue.prefetch_queue(samples,
                                                         capacity=128 *
                                                         config.num_clones)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (inputs_queue, {
                common.OUTPUT_TYPE: dataset.num_classes
            }, dataset.ignore_label)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for images, labels, semantic predictions
        if FLAGS.save_summaries_images:
            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
            summaries.add(
                tf.summary.image('samples/%s' % common.IMAGE, summary_image))

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
            # Scale up summary image pixel values for better visualization.
            pixel_scaling = max(1, 255 // dataset.num_classes)
            summary_label = tf.cast(first_clone_label * pixel_scaling,
                                    tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL, summary_label))

            first_clone_output = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
            predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

            summary_predictions = tf.cast(predictions * pixel_scaling,
                                          tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                 summary_predictions))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps, FLAGS.learning_power,
                FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   FLAGS.momentum)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Start the training.
        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_logdir,
                            log_every_n_steps=FLAGS.log_steps,
                            master=FLAGS.master,
                            number_of_steps=FLAGS.training_number_of_steps,
                            is_chief=(FLAGS.task == 0),
                            session_config=session_config,
                            startup_delay_steps=startup_delay_steps,
                            init_fn=train_utils.get_model_init_fn(
                                FLAGS.train_logdir,
                                FLAGS.tf_initial_checkpoint,
                                FLAGS.initialize_last_layer,
                                last_layers,
                                ignore_missing_vars=True),
                            summary_op=summary_op,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs)
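
Note: `train_utils.get_model_learning_rate` builds DeepLab's learning-rate schedule (for the 'poly' policy, roughly base_lr * (1 - step / max_steps) ** power, plus a slow-start phase not shown here), and `get_model_gradient_multipliers` returns a dict mapping variable names to multipliers that `slim.learning.multiply_gradients` applies, e.g. a larger multiplier for the last-layer weights and biases. A minimal sketch of the poly schedule using the stock TF op; the slow-start handling is omitted and the helper name is hypothetical.

import tensorflow as tf

def poly_learning_rate(base_learning_rate, training_number_of_steps,
                       learning_power=0.9):
    """Sketch of a 'poly' decay: lr = base_lr * (1 - step / max_steps) ** power."""
    global_step = tf.train.get_or_create_global_step()
    return tf.train.polynomial_decay(
        base_learning_rate,
        global_step,
        training_number_of_steps,
        end_learning_rate=0.0,
        power=learning_power)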
Example #28
0
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpus
    if FLAGS.num_clones == -1:
        FLAGS.num_clones = len(FLAGS.gpus.split(','))

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # tf.set_random_seed(42)
        tf.set_random_seed(0)
        ######################
        # Config model_deploy#
        ######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(
            FLAGS.dataset_name,
            FLAGS.dataset_split_name,
            FLAGS.dataset_dir.split(','),
            dataset_list_dir=FLAGS.dataset_list_dir,
            num_samples=FLAGS.frames_per_video,
            modality=FLAGS.modality,
            split_id=FLAGS.split_id)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            batch_size=FLAGS.batch_size,
            weight_decay=FLAGS.weight_decay,
            is_training=True,
            dropout_keep_prob=(1.0 - FLAGS.dropout),
            pooled_dropout_keep_prob=(1.0 - FLAGS.pooled_dropout),
            batch_norm=FLAGS.netvlad_batch_norm)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)
        # In the image-pooling case, preprocessing is now done at the video level.

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size,
                bgr_flips=FLAGS.bgr_flip)
            [image, label] = provider.get(['image', 'label'])
            # Note that the above image might be a 23-channel image if both the
            # RGB and flow streams are present. It will need to be split later,
            # but all preprocessing is done consistently for all frames across
            # all streams.
            label = tf.string_to_number(label, tf.int32)
            label.set_shape(())
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            scale_ratios = [float(el) for el in FLAGS.scale_ratios.split(',')]
            image = image_preprocessing_fn(image,
                                           train_image_size,
                                           train_image_size,
                                           scale_ratios=scale_ratios,
                                           out_dim_scale=FLAGS.out_dim_scale,
                                           model_name=FLAGS.model_name)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            if FLAGS.debug:
                images = tf.Print(images, [labels], 'Read batch')
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)
            summarize_images(images, provider.num_channels_stream)

        ####################
        # Define the model #
        ####################
        kwargs = {}
        if FLAGS.conv_endpoint is not None:
            kwargs['conv_endpoint'] = FLAGS.conv_endpoint

        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(
                images,
                pool_type=FLAGS.pooling,
                classifier_type=FLAGS.classifier_type,
                num_channels_stream=provider.num_channels_stream,
                netvlad_centers=FLAGS.netvlad_initCenters.split(','),
                stream_pool_type=FLAGS.stream_pool_type,
                **kwargs)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weight=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weight=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        global end_points_debug
        end_points = clones[0].outputs
        end_points_debug = dict(end_points)
        end_points_debug['images'] = images
        end_points_debug['labels'] = labels
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.histogram_summary('activations/' + end_point, x))
            summaries.add(
                tf.scalar_summary('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.scalar_summary('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.histogram_summary(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(
                tf.scalar_summary('learning_rate',
                                  learning_rate,
                                  name='learning_rate'))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
                total_num_replicas=FLAGS.worker_replicas)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()
        logging.info('Training the following variables: %s' %
                     (' '.join([el.name for el in variables_to_train])))

        # Compute the total loss and the gradients over all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        # clip the gradients if needed
        if FLAGS.clip_gradients > 0:
            logging.info('Clipping gradients by %f' % FLAGS.clip_gradients)
            with tf.name_scope('clip_gradients'):
                clones_gradients = slim.learning.clip_gradient_norms(
                    clones_gradients, FLAGS.clip_gradients)

        # Add total_loss to summary.
        summaries.add(
            tf.scalar_summary('total_loss', total_loss, name='total_loss'))

        # Create gradient updates.
        train_ops = {}
        if FLAGS.iter_size == 1:
            grad_updates = optimizer.apply_gradients(clones_gradients,
                                                     global_step=global_step)
            update_ops.append(grad_updates)

            update_op = tf.group(*update_ops)
            train_tensor = control_flow_ops.with_dependencies([update_op],
                                                              total_loss,
                                                              name='train_op')
            train_ops = train_tensor
        else:
            gvs = [(grad, var) for grad, var in clones_gradients]
            varnames = [var.name for grad, var in gvs]
            varname_to_var = {var.name: var for grad, var in gvs}
            varname_to_grad = {var.name: grad for grad, var in gvs}
            varname_to_ref_grad = {}
            for vn in varnames:
                grad = varname_to_grad[vn]
                print("accumulating ... ", (vn, grad.get_shape()))
                with tf.variable_scope("ref_grad"):
                    with tf.device(deploy_config.variables_device()):
                        ref_var = slim.local_variable(np.zeros(
                            grad.get_shape(), dtype=np.float32),
                                                      name=vn[:-2])
                        varname_to_ref_grad[vn] = ref_var

            all_assign_ref_op = [
                ref.assign(varname_to_grad[vn])
                for vn, ref in varname_to_ref_grad.items()
            ]
            all_assign_add_ref_op = [
                ref.assign_add(varname_to_grad[vn])
                for vn, ref in varname_to_ref_grad.items()
            ]
            assign_gradients_ref_op = tf.group(*all_assign_ref_op)
            accumulate_gradients_op = tf.group(*all_assign_add_ref_op)
            with tf.control_dependencies([accumulate_gradients_op]):
                final_gvs = [
                    (varname_to_ref_grad[var.name] / float(FLAGS.iter_size),
                     var) for grad, var in gvs
                ]
                apply_gradients_op = optimizer.apply_gradients(
                    final_gvs, global_step=global_step)
                update_ops.append(apply_gradients_op)
                update_op = tf.group(*update_ops)
                train_tensor = control_flow_ops.with_dependencies(
                    [update_op], total_loss, name='train_op')
            for i in range(FLAGS.iter_size):
                if i == 0:
                    train_ops[i] = assign_gradients_ref_op
                elif i < FLAGS.iter_size - 1:
                    # The final apply op also accumulates (via the control
                    # dependency above), so no extra accumulation iteration
                    # is needed here.
                    train_ops[i] = accumulate_gradients_op
                else:
                    train_ops[i] = train_tensor
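        # train_ops thus maps a sub-iteration index to an op: index 0 overwrites
        # the gradient accumulators, the middle indices add to them, and the
        # last index applies the averaged gradients (train_step presumably
        # cycles through these within one effective step).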

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.merge_summary(list(summaries), name='summary_op')

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.intra_op_parallelism_threads = FLAGS.cpu_threads
        # config.allow_soft_placement = True
        # config.gpu_options.per_process_gpu_memory_fraction=0.7

        ###########################
        # Kicks off the training. #
        ###########################
        logging.info('RUNNING ON SPLIT %d' % FLAGS.split_id)
        slim.learning.train(
            train_ops,
            train_step_fn=train_step,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None,
            session_config=config)
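
The iter_size > 1 branch above emulates a larger effective batch by accumulating per-variable gradients over several sub-iterations and applying their average once. Below is a minimal sketch of the same pattern on a single toy variable, assuming TF 1.x; every name in it is illustrative and not taken from the script above.

import tensorflow as tf

iter_size = 4  # number of sub-iterations whose gradients are averaged

x = tf.Variable(1.0)
loss = tf.square(x - 3.0)
opt = tf.train.GradientDescentOptimizer(0.1)
grads_and_vars = opt.compute_gradients(loss, var_list=[x])

# One non-trainable accumulator per variable (the "ref_grad" variables above).
accum = [tf.Variable(tf.zeros_like(v.initialized_value()), trainable=False)
         for _, v in grads_and_vars]
assign_ops = [a.assign(g) for a, (g, _) in zip(accum, grads_and_vars)]
accumulate_op = tf.group(*[a.assign_add(g)
                           for a, (g, _) in zip(accum, grads_and_vars)])
with tf.control_dependencies([accumulate_op]):
    # Applying the averaged gradients also performs one final accumulation,
    # mirroring the control dependency in the script above.
    apply_op = opt.apply_gradients(
        [(a / float(iter_size), v) for a, (_, v) in zip(accum, grads_and_vars)])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(assign_ops)           # sub-iteration 0: overwrite accumulators
    for _ in range(iter_size - 2):
        sess.run(accumulate_op)    # middle sub-iterations: add gradients
    sess.run(apply_op)             # last sub-iteration: add once more, apply

In the real script each sub-iteration would dequeue a fresh mini-batch, so the accumulated gradients differ; here the same toy loss is re-evaluated, which is enough to show the op wiring.
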
Example #29
def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):
  """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    graph_hook_fn: Optional function that is called after the inference graph is
      built (before optimization). This is helpful to perform additional changes
      to the training graph such as adding FakeQuant ops. The function should
      modify the default graph.

  Raises:
    ValueError: If both num_clones > 1 and train_config.sync_replicas is true.
  """

  detection_model = create_model_fn()
  data_augmentation_options = [
      preprocessor_builder.build(step)
      for step in train_config.data_augmentation_options]

  with tf.Graph().as_default():
    # Build a configuration specifying multi-GPU and multi-replicas.
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=worker_replicas,
        num_ps_tasks=ps_tasks,
        worker_job_name=worker_job_name)

    # Place the global step on the device storing the variables.
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    if num_clones != 1 and train_config.sync_replicas:
      raise ValueError('In Synchronous SGD mode num_clones must '
                       'be 1. Found num_clones: {}'.format(num_clones))
    batch_size = train_config.batch_size // num_clones
    if train_config.sync_replicas:
      batch_size //= train_config.replicas_to_aggregate
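    # Example with hypothetical numbers: train_config.batch_size=24,
    # num_clones=2 and replicas_to_aggregate=3 give each clone
    # 24 // 2 // 3 = 4 examples per step.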

    with tf.device(deploy_config.inputs_device()):
      input_queue = create_input_queue(
          batch_size, create_tensor_dict_fn,
          train_config.batch_queue_capacity,
          train_config.num_batch_queue_threads,
          train_config.prefetch_queue_capacity, data_augmentation_options)

    # Gather initial summaries.
    # TODO(rathodv): See if summaries can be added/extracted from global tf
    # collections so that they don't have to be passed around.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    global_summaries = set([])

    model_fn = functools.partial(_create_losses,
                                 create_model_fn=create_model_fn,
                                 train_config=train_config)
    clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
    first_clone_scope = clones[0].scope

    if graph_hook_fn:
      with tf.device(deploy_config.variables_device()):
        graph_hook_fn()

    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    with tf.device(deploy_config.optimizer_device()):
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)
      for var in optimizer_summary_vars:
        global_summaries.add(
            tf.summary.scalar(var.op.name, var, family='LearningRate'))

    sync_optimizer = None
    if train_config.sync_replicas:
      training_optimizer = tf.train.SyncReplicasOptimizer(
          training_optimizer,
          replicas_to_aggregate=train_config.replicas_to_aggregate,
          total_num_replicas=worker_replicas)
      sync_optimizer = training_optimizer

    with tf.device(deploy_config.optimizer_device()):
      regularization_losses = (None if train_config.add_regularization_loss
                               else [])
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, training_optimizer,
          regularization_losses=regularization_losses)
      total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

      # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
      if train_config.bias_grad_multiplier:
        biases_regex_list = ['.*/biases']
        grads_and_vars = variables_helper.multiply_gradients_matching_regex(
            grads_and_vars,
            biases_regex_list,
            multiplier=train_config.bias_grad_multiplier)

      # Optionally freeze some layers by setting their gradients to be zero.
      if train_config.freeze_variables:
        grads_and_vars = variables_helper.freeze_gradients_matching_regex(
            grads_and_vars, train_config.freeze_variables)

      # Optionally clip gradients
      if train_config.gradient_clipping_by_norm > 0:
        with tf.name_scope('clip_grads'):
          grads_and_vars = slim.learning.clip_gradient_norms(
              grads_and_vars, train_config.gradient_clipping_by_norm)

      # Create gradient updates.
      grad_updates = training_optimizer.apply_gradients(grads_and_vars,
                                                        global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops, name='update_barrier')
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add summaries.
    for model_var in slim.get_model_variables():
      global_summaries.add(tf.summary.histogram('ModelVars/' +
                                                model_var.op.name, model_var))
    for loss_tensor in tf.losses.get_losses():
      global_summaries.add(tf.summary.scalar('Losses/' + loss_tensor.op.name,
                                             loss_tensor))
    global_summaries.add(
        tf.summary.scalar('Losses/TotalLoss', tf.losses.get_total_loss()))

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))
    summaries |= global_summaries

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

    # Save checkpoints regularly.
    keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
    saver = tf.train.Saver(
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

    # Create ops required to initialize the model from a given checkpoint.
    init_fn = None
    if train_config.fine_tune_checkpoint:
      if not train_config.fine_tune_checkpoint_type:
        # train_config.from_detection_checkpoint field is deprecated. For
        # backward compatibility, fine_tune_checkpoint_type is set based on
        # from_detection_checkpoint.
        if train_config.from_detection_checkpoint:
          train_config.fine_tune_checkpoint_type = 'detection'
        else:
          train_config.fine_tune_checkpoint_type = 'classification'
      var_map = detection_model.restore_map(
          fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
          load_all_detection_checkpoint_vars=(
              train_config.load_all_detection_checkpoint_vars),
          load_all_optimizer_checkpoint_vars=(
              train_config.load_all_optimizer_checkpoint_vars))
      available_var_map = (variables_helper.
                           get_variables_available_in_checkpoint(
                               var_map, train_config.fine_tune_checkpoint))
      init_saver = tf.train.Saver(available_var_map)
      def initializer_fn(sess):
        init_saver.restore(sess, train_config.fine_tune_checkpoint)
      init_fn = initializer_fn
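      # init_fn is handed to slim.learning.train below; it runs once in the
      # chief's session after variable initialization and restores only the
      # variables found in the fine-tune checkpoint.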

    slim.learning.train(
        train_tensor,
        logdir=train_dir,
        master=master,
        is_chief=is_chief,
        session_config=session_config,
        startup_delay_steps=train_config.startup_delay_steps,
        init_fn=init_fn,
        summary_op=summary_op,
        number_of_steps=(
            train_config.num_steps if train_config.num_steps else None),
        save_summaries_secs=120,
        sync_optimizer=sync_optimizer,
        saver=saver)
Example #30
def main(unused_argv):
    FLAGS.train_logdir = FLAGS.base_logdir + '/' + FLAGS.task_name
    if FLAGS.restore_name is None:
        FLAGS.restore_logdir = FLAGS.train_logdir
    else:
        FLAGS.restore_logdir = FLAGS.base_logdir + '/' + FLAGS.restore_name

    tf.logging.set_verbosity(tf.logging.INFO)

    # Get logging dir ready.
    if not (os.path.isdir(FLAGS.train_logdir)):
        tf.gfile.MakeDirs(FLAGS.train_logdir)
    elif len(os.listdir(FLAGS.train_logdir)) != 0:
        if not (FLAGS.if_restore):
            if_delete_all = raw_input(
                '#### The log folder %s exists and non-empty; delete all logs? [y/n] '
                % FLAGS.train_logdir)
            if if_delete_all == 'y':
                os.system('rm -rf %s/*' % FLAGS.train_logdir)
                print '==== Log folder emptied.'
        else:
            print '==== Log folder exists; not emptying it because we need to restore from it.'
    tf.logging.info('==== Logging in dir:%s; Training on %s set',
                    FLAGS.train_logdir, FLAGS.train_split)

    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.num_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)  # /device:CPU:0

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')
    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    # Get dataset-dependent information.
    dataset = regression_dataset.get_dataset(FLAGS.dataset,
                                             FLAGS.train_split,
                                             dataset_dir=FLAGS.dataset_dir)
    dataset_val = regression_dataset.get_dataset(FLAGS.dataset,
                                                 FLAGS.val_split,
                                                 dataset_dir=FLAGS.dataset_dir)
    print '#### The data has size:', dataset.num_samples, dataset_val.num_samples

    codes = np.load(
        '/ssd2/public/zhurui/Documents/mesh-voxelization/models/cars_64/codes.npy'
    )

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            codes_max = np.amax(codes, axis=1).reshape((-1, 1))
            codes_min = np.amin(codes, axis=1).reshape((-1, 1))
            shape_range = np.hstack(
                (codes_max + (codes_max - codes_min) /
                 (dataset.SHAPE_BINS - 1.), codes_min -
                 (codes_max - codes_min) / (dataset.SHAPE_BINS - 1.)))
            bin_range = [
                np.linspace(r[0], r[1], num=b).tolist()
                for r, b in zip(np.vstack((dataset.pose_range,
                                           shape_range)), dataset.bin_nums)
            ]
            # print np.vstack((dataset.pose_range, shape_range))
            # print bin_range[0]
            # print bin_range[-1]
            outputs_to_num_classes = {}
            outputs_to_indices = {}
            for output, bin_num, idx in zip(dataset.output_names,
                                            dataset.bin_nums,
                                            range(len(dataset.output_names))):
                if FLAGS.if_discrete_loss:
                    outputs_to_num_classes[output] = bin_num
                else:
                    outputs_to_num_classes[output] = 1
                outputs_to_indices[output] = idx
            bin_vals = [
                tf.constant(value=[bin_range[i]],
                            dtype=tf.float32,
                            shape=[1, dataset.bin_nums[i]],
                            name=name)
                for i, name in enumerate(dataset.output_names)
            ]
            # print outputs_to_num_classes
            # print spaces_to_indices

            samples = input_generator.get(dataset,
                                          codes,
                                          clone_batch_size,
                                          dataset_split=FLAGS.train_split,
                                          is_training=True,
                                          model_variant=FLAGS.model_variant)
            inputs_queue = prefetch_queue.prefetch_queue(samples,
                                                         capacity=128 *
                                                         config.num_clones)

            samples_val = input_generator.get(
                dataset_val,
                codes,
                clone_batch_size,
                dataset_split=FLAGS.val_split,
                is_training=False,
                model_variant=FLAGS.model_variant)
            inputs_queue_val = prefetch_queue.prefetch_queue(samples_val,
                                                             capacity=128)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (FLAGS, inputs_queue.dequeue(),
                          outputs_to_num_classes, outputs_to_indices, bin_vals,
                          bin_range, dataset, codes, True, False)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)  # clone_0
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        with tf.device('/device:GPU:3'):
            if FLAGS.if_val:
                ## Construct the validation graph; takes one GPU.
                _build_deeplab(FLAGS,
                               inputs_queue_val.dequeue(),
                               outputs_to_num_classes,
                               outputs_to_indices,
                               bin_vals,
                               bin_range,
                               dataset_val,
                               codes,
                               is_training=False,
                               reuse=True)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for images, labels, semantic predictions
        summary_loss_dict = {}
        if FLAGS.save_summaries_images:
            if FLAGS.num_clones > 1:
                pattern_train = first_clone_scope + '/%s:0'
            else:
                pattern_train = '%s:0'
            pattern_val = 'val-%s:0'
            pattern = pattern_val if FLAGS.if_val else pattern_train
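            # `pattern` resolves tensor names such as 'clone_0/<name>:0'
            # (training clones) or 'val-<name>:0' (validation graph) for the
            # graph.get_tensor_by_name lookups below.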

            gather_list = [0] if FLAGS.num_clones < 3 else [0, 1, 2]

            summary_mask = graph.get_tensor_by_name(pattern %
                                                    'not_ignore_mask_in_loss')
            summary_mask = tf.reshape(summary_mask,
                                      [-1, dataset.height, dataset.width, 1])
            summary_mask_float = tf.to_float(summary_mask)
            summaries.add(
                tf.summary.image(
                    'gt/%s' % 'not_ignore_mask',
                    tf.gather(tf.cast(summary_mask_float * 255., tf.uint8),
                              gather_list)))

            summary_image = graph.get_tensor_by_name(pattern % common.IMAGE)
            summaries.add(
                tf.summary.image('gt/%s' % common.IMAGE,
                                 tf.gather(summary_image, gather_list)))

            summary_image_name = graph.get_tensor_by_name(pattern %
                                                          common.IMAGE_NAME)
            summaries.add(
                tf.summary.text('gt/%s' % common.IMAGE_NAME,
                                tf.gather(summary_image_name, gather_list)))

            summary_image_name = graph.get_tensor_by_name(pattern_train %
                                                          common.IMAGE_NAME)
            summaries.add(
                tf.summary.text('gt/%s_train' % common.IMAGE_NAME,
                                tf.gather(summary_image_name, gather_list)))

            summary_vis = graph.get_tensor_by_name(pattern % 'vis')
            summaries.add(
                tf.summary.image('gt/%s' % 'vis',
                                 tf.gather(summary_vis, gather_list)))

            def scale_to_255(tensor, pixel_scaling=None):
                tensor = tf.to_float(tensor)
                if pixel_scaling is None:
                    offset_to_zero = tf.reduce_min(tensor)
                    scale_to_255 = tf.div(
                        255., tf.reduce_max(tensor - offset_to_zero))
                else:
                    offset_to_zero, scale_to_255 = pixel_scaling
                summary_tensor_float = tensor - offset_to_zero
                summary_tensor_float = summary_tensor_float * scale_to_255
                summary_tensor_float = tf.clip_by_value(
                    summary_tensor_float, 0., 255.)
                summary_tensor_uint8 = tf.cast(summary_tensor_float, tf.uint8)
                return summary_tensor_uint8, (offset_to_zero, scale_to_255)

            label_outputs = graph.get_tensor_by_name(pattern %
                                                     'label_pose_shape_map')
            label_id_outputs = graph.get_tensor_by_name(
                pattern % 'pose_shape_label_id_map')
            logit_outputs = graph.get_tensor_by_name(
                pattern % 'scaled_prob_logits_pose_shape_map')

            summary_rot_diffs = graph.get_tensor_by_name(pattern %
                                                         'rot_error_map')
            summary_rot_diffs = tf.where(summary_mask, summary_rot_diffs,
                                         tf.zeros_like(summary_rot_diffs))
            summary_rot_diffs_uint8, _ = scale_to_255(summary_rot_diffs)
            summaries.add(
                tf.summary.image(
                    'metrics_map/%s' % 'rot_diffs',
                    tf.gather(summary_rot_diffs_uint8, gather_list)))

            summary_trans_diffs = graph.get_tensor_by_name(pattern %
                                                           'trans_error_map')
            summary_trans_diffs = tf.where(summary_mask, summary_trans_diffs,
                                           tf.zeros_like(summary_trans_diffs))
            summary_trans_diffs_uint8, _ = scale_to_255(summary_trans_diffs)
            summaries.add(
                tf.summary.image('metrics_map/%s' % 'trans_diffs',
                                 tf.gather(summary_trans_diffs, gather_list)))

            shape_id_outputs = graph.get_tensor_by_name(pattern %
                                                        'shape_id_map')
            shape_id_outputs = tf.where(summary_mask, shape_id_outputs + 1,
                                        tf.zeros_like(shape_id_outputs))
            summary_shape_id_output_uint8, _ = scale_to_255(shape_id_outputs)
            summaries.add(
                tf.summary.image(
                    'shape/shape_id_map',
                    tf.gather(summary_shape_id_output_uint8, gather_list)))

            shape_id_outputs_gt = graph.get_tensor_by_name(pattern %
                                                           'shape_id_map_gt')
            shape_id_outputs_gt = tf.where(summary_mask,
                                           shape_id_outputs_gt + 1,
                                           tf.zeros_like(shape_id_outputs_gt))
            summary_shape_id_output_uint8_gt, _ = scale_to_255(
                shape_id_outputs_gt)
            summaries.add(
                tf.summary.image(
                    'shape/shape_id_map_gt',
                    tf.gather(summary_shape_id_output_uint8_gt, gather_list)))

            if FLAGS.if_summary_metrics:
                shape_id_outputs = graph.get_tensor_by_name(
                    pattern % 'shape_id_map_predict')
                summary_shape_id_output = tf.where(
                    summary_mask, shape_id_outputs,
                    tf.zeros_like(shape_id_outputs))
                summary_shape_id_output_uint8, _ = scale_to_255(
                    summary_shape_id_output)
                summaries.add(
                    tf.summary.image(
                        'shape/shape_id_map_predict',
                        tf.gather(summary_shape_id_output_uint8, gather_list)))

                shape_id_sim_map_train = graph.get_tensor_by_name(
                    pattern_train % 'shape_id_sim_map')
                # shape_id_sim_map_train = tf.where(summary_mask, shape_id_sim_map_train, tf.zeros_like(shape_id_sim_map_train))
                shape_id_sim_map_uint8_train, _ = scale_to_255(
                    shape_id_sim_map_train, pixel_scaling=(0., 255.))
                summaries.add(
                    tf.summary.image(
                        'metrics_map/shape_id_sim_map-trainInv',
                        tf.gather(shape_id_sim_map_uint8_train, gather_list)))

                shape_id_sim_map = graph.get_tensor_by_name(pattern %
                                                            'shape_id_sim_map')
                # shape_id_sim_map = tf.where(summary_mask, shape_id_sim_map, tf.zeros_like(shape_id_sim_map))
                shape_id_sim_map_uint8, _ = scale_to_255(shape_id_sim_map,
                                                         pixel_scaling=(0.,
                                                                        255.))
                summaries.add(
                    tf.summary.image(
                        'metrics_map/shape_id_sim_map-valInv',
                        tf.gather(shape_id_sim_map_uint8, gather_list)))

            for output_idx, output in enumerate(dataset.output_names):
                # Scale up summary image pixel values for better visualization.
                summary_label_output = tf.gather(label_outputs, [output_idx],
                                                 axis=3)
                summary_label_output = tf.where(
                    summary_mask, summary_label_output,
                    tf.zeros_like(summary_label_output))
                summary_label_output_uint8, pixel_scaling = scale_to_255(
                    summary_label_output)
                summaries.add(
                    tf.summary.image(
                        'output/%s_label' % output,
                        tf.gather(summary_label_output_uint8, gather_list)))

                summary_logit_output = tf.gather(logit_outputs, [output_idx],
                                                 axis=3)
                summary_logit_output = tf.where(
                    summary_mask, summary_logit_output,
                    tf.zeros_like(summary_logit_output))
                summary_logit_output_uint8, _ = scale_to_255(
                    summary_logit_output, pixel_scaling)
                summaries.add(
                    tf.summary.image(
                        'output/%s_logit' % output,
                        tf.gather(summary_logit_output_uint8, gather_list)))

                # summary_label_id_output = tf.to_float(tf.gather(label_id_outputs, [output_idx], axis=3))
                # summary_label_id_output = tf.where(summary_mask, summary_label_id_output+1, tf.zeros_like(summary_label_id_output))
                # summary_label_id_output_uint8, _ = scale_to_255(summary_label_id_output)
                # summary_label_id_output_uint8 = tf.identity(summary_label_id_output_uint8, 'tttt'+output)
                # summaries.add(tf.summary.image(
                #     'test/%s_label_id' % output, tf.gather(summary_label_id_output_uint8, gather_list)))

                summary_diff = tf.abs(
                    tf.to_float(summary_label_output_uint8) -
                    tf.to_float(summary_logit_output_uint8))
                summary_diff = tf.where(summary_mask, summary_diff,
                                        tf.zeros_like(summary_diff))
                summaries.add(
                    tf.summary.image(
                        'diff_map/%s_ldiff' % output,
                        tf.gather(tf.cast(summary_diff, tf.uint8),
                                  gather_list)))

                summary_loss = graph.get_tensor_by_name(
                    (pattern % 'loss_slice_reg_').replace(':0', '') + output +
                    ':0')
                summaries.add(
                    tf.summary.scalar(
                        'slice_loss/' + (pattern % 'reg_').replace(':0', '') +
                        output, summary_loss))

                summary_loss = graph.get_tensor_by_name(
                    (pattern % 'loss_slice_cls_').replace(':0', '') + output +
                    ':0')
                summaries.add(
                    tf.summary.scalar(
                        'slice_loss/' + (pattern % 'cls_').replace(':0', '') +
                        output, summary_loss))

            for pattern in [pattern_train, pattern_val
                            ] if FLAGS.if_val else [pattern_train]:
                add_metrics = ['loss_all_shape_id_cls_metric'
                               ] if FLAGS.if_summary_metrics else []
                for loss_name in [
                        'loss_reg_rot_quat_metric', 'loss_reg_rot_quat',
                        'loss_reg_trans_metric', 'loss_reg_trans',
                        'loss_cls_ALL', 'loss_reg_shape'
                ] + add_metrics:
                    if pattern == pattern_val:
                        summary_loss_avg = graph.get_tensor_by_name(pattern %
                                                                    loss_name)
                        # summary_loss_dict['val-'+loss_name] = summary_loss_avg
                    else:
                        summary_loss_avg = train_utils.get_avg_tensor_from_scopes(
                            FLAGS.num_clones, '%s:0', graph, config, loss_name)
                        # summary_loss_dict['train-'+loss_name] = summary_loss_avg
                    summaries.add(
                        tf.summary.scalar(
                            ('total_loss/' + pattern % loss_name).replace(
                                ':0', ''), summary_loss_avg))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps, FLAGS.learning_power,
                FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   FLAGS.momentum)
            # optimizer = tf.train.AdamOptimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            print '------ total_loss', total_loss, tf.get_collection(
                tf.GraphKeys.LOSSES, first_clone_scope)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss/train', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            print '////last layers', last_layers

            # Filter trainable variables for last layers ONLY.
            # grads_and_vars = train_utils.filter_gradients(last_layers, grads_and_vars)

            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        session_config.gpu_options.allow_growth = False

        def train_step_fn(sess, train_op, global_step, train_step_kwargs):
            train_step_fn.step += 1  # or use global_step.eval(session=sess)

            # calc training losses
            loss, should_stop = slim.learning.train_step(
                sess, train_op, global_step, train_step_kwargs)
            print loss
            # print 'loss: ', loss
            # first_clone_test = graph.get_tensor_by_name(
            #         ('%s/%s:0' % (first_clone_scope, 'shape_map')).strip('/'))
            # test = sess.run(first_clone_test)
            # # print test
            # print 'test: ', test.shape, np.max(test), np.min(test), np.mean(test), test.dtype
            should_stop = 0

            if FLAGS.if_val and train_step_fn.step % FLAGS.val_interval_steps == 0:
                # first_clone_test = graph.get_tensor_by_name('val-loss_all:0')
                # test = sess.run(first_clone_test)
                print '-- Validating...'
                first_clone_test = graph.get_tensor_by_name(
                    ('%s/%s:0' %
                     (first_clone_scope, 'shape_id_map')).strip('/'))
                first_clone_test2 = graph.get_tensor_by_name(
                    ('%s/%s:0' %
                     (first_clone_scope, 'shape_id_sim_map')).strip('/'))
                # 'ttttrow:0')
                first_clone_test3 = graph.get_tensor_by_name((
                    '%s/%s:0' %
                    (first_clone_scope, 'not_ignore_mask_in_loss')).strip('/'))
                # 'ttttrow:0')
                test_out, test_out2, test_out3 = sess.run(
                    [first_clone_test, first_clone_test2, first_clone_test3])
                # test_out = test[:, :, :, 3]
                test_out = test_out[test_out3]
                # test_out2 = test2[:, :, :, 3]
                test_out2 = test_out2[test_out3]
                # print test_out
                print 'shape_id_map: ', test_out.shape, np.max(
                    test_out), np.min(test_out), np.mean(test_out), np.median(
                        test_out), test_out.dtype
                print 'shape_id_sim_map: ', test_out2.shape, np.max(
                    test_out2), np.min(test_out2), np.mean(
                        test_out2), np.median(test_out2), test_out2.dtype
                print 'masks sum: ', test_out3.dtype, np.sum(
                    test_out3.astype(float))
                # assert np.max(test_out) == np.max(test_out2), 'MAtch1!!!'
                # assert np.min(test_out) == np.min(test_out2), 'MAtch2!!!'

            # first_clone_label = graph.get_tensor_by_name(
            #         ('%s/%s:0' % (first_clone_scope, 'pose_map')).strip('/')) # clone_0/val-loss:0
            # # first_clone_pose_dict = graph.get_tensor_by_name(
            # #         ('%s/%s:0' % (first_clone_scope, 'pose_dict')).strip('/'))
            # first_clone_logit = graph.get_tensor_by_name(
            #         ('%s/%s:0' % (first_clone_scope, 'scaled_regression')).strip('/'))
            # not_ignore_mask = graph.get_tensor_by_name(
            #         ('%s/%s:0' % (first_clone_scope, 'not_ignore_mask_in_loss')).strip('/'))
            # label, logits, mask = sess.run([first_clone_label, first_clone_logit, not_ignore_mask])
            # mask = np.reshape(mask, (-1, FLAGS.train_crop_size[0], FLAGS.train_crop_size[1], dataset.num_classes))

            # print '... shapes, types, loss', label.shape, label.dtype, logits.shape, logits.dtype, loss
            # print 'mask', mask.shape, np.mean(mask)
            # logits[mask==0.] = 0.
            # print 'logits', logits.shape, np.max(logits), np.min(logits), np.mean(logits), logits.dtype
            # for idx in range(6):
            #     print idx, np.max(label[:, :, :, idx]), np.min(label[:, :, :, idx])
            # label = label[:, :, :, 5]
            # print 'label', label.shape, np.max(label), np.min(label), np.mean(label), label.dtype
            # print pose_dict, pose_dict.shape
            # # print 'training....... logits stats: ', np.max(logits), np.min(logits), np.mean(logits)
            # # label_one_piece = label[0, :, :, 0]
            # # print 'training....... label stats', np.max(label_one_piece), np.min(label_one_piece), np.sum(label_one_piece[label_one_piece!=255.])
            return [loss, should_stop]

        train_step_fn.step = 0
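        # The counter is stored as an attribute on the function object so it
        # persists across calls and drives the periodic validation above.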

        # trainables = [v.name for v in tf.trainable_variables()]
        # alls =[v.name for v in tf.all_variables()]
        # print '----- Trainables %d: '%len(trainables), trainables
        # print '----- All %d: '%len(alls), alls[:10]
        # print '===== ', len(list(set(trainables) - set(alls)))
        # print '===== ', len(list(set(alls) - set(trainables)))

        if FLAGS.if_print_tensors:
            for op in tf.get_default_graph().get_operations():
                print str(op.name)

        # Start the training.
        slim.learning.train(train_tensor,
                            train_step_fn=train_step_fn,
                            logdir=FLAGS.train_logdir,
                            log_every_n_steps=FLAGS.log_steps,
                            master=FLAGS.master,
                            number_of_steps=FLAGS.training_number_of_steps,
                            is_chief=(FLAGS.task == 0),
                            session_config=session_config,
                            startup_delay_steps=startup_delay_steps,
                            init_fn=train_utils.get_model_init_fn(
                                FLAGS.restore_logdir,
                                FLAGS.tf_initial_checkpoint,
                                FLAGS.if_restore,
                                FLAGS.initialize_last_layer,
                                last_layers,
                                ignore_missing_vars=True),
                            summary_op=summary_op,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs)
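
The example above passes a custom train_step_fn to slim.learning.train so that validation tensors can be evaluated every FLAGS.val_interval_steps steps. Below is a minimal, self-contained sketch of that hook pattern, assuming TF 1.x with tf.contrib.slim; the toy variable, loss and log directory are hypothetical.

import tensorflow as tf

slim = tf.contrib.slim

x = tf.Variable(5.0)
loss = tf.square(x - 3.0)
optimizer = tf.train.GradientDescentOptimizer(0.1)
train_op = slim.learning.create_train_op(loss, optimizer)

def train_step_fn(sess, train_op, global_step, train_step_kwargs):
    # Delegate the actual update to slim's default step function, then run any
    # extra logic (periodic logging here, periodic validation in the example).
    loss_value, should_stop = slim.learning.train_step(
        sess, train_op, global_step, train_step_kwargs)
    train_step_fn.step += 1
    if train_step_fn.step % 10 == 0:
        print('step %d, loss %.4f' % (train_step_fn.step, loss_value))
    return loss_value, should_stop

train_step_fn.step = 0

slim.learning.train(train_op,
                    logdir='/tmp/train_step_fn_demo',  # hypothetical logdir
                    train_step_fn=train_step_fn,
                    number_of_steps=30)

Delegating to slim.learning.train_step keeps the default loss/should_stop behavior while leaving room for arbitrary extra sess.run calls around each step.
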
Example #31
def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir):
    """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
  """

    detection_model = create_model_fn()
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            input_queue = _create_input_queue(
                train_config.batch_size // num_clones, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer = optimizer_builder.build(
                train_config.optimizer, global_summaries)

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=train_config.worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            var_map = detection_model.restore_map(
                from_detection_checkpoint=train_config.from_detection_checkpoint)
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map, train_config.fine_tune_checkpoint))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        with tf.device(deploy_config.optimizer_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer, regularization_losses=None)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        # session_config.gpu_options.allow_growth = True
        session_config.gpu_options.per_process_gpu_memory_fraction = 0.8

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
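
Both detection training examples initialize from a fine-tune checkpoint by restoring only the variables that actually exist in it and handing the resulting saver to slim.learning.train through init_fn. Below is a rough hand-rolled sketch of what variables_helper.get_variables_available_in_checkpoint plus that wiring amounts to, assuming TF 1.x, an already-built graph with matching variables, and a hypothetical checkpoint path.

import tensorflow as tf

ckpt_path = '/tmp/pretrained/model.ckpt'  # hypothetical checkpoint path

# Read the names and shapes of every variable stored in the checkpoint.
reader = tf.train.NewCheckpointReader(ckpt_path)
ckpt_shapes = reader.get_variable_to_shape_map()

# Keep only the graph variables whose name and shape match the checkpoint.
var_map = {
    v.op.name: v
    for v in tf.global_variables()
    if v.op.name in ckpt_shapes
    and v.get_shape().as_list() == ckpt_shapes[v.op.name]
}

init_saver = tf.train.Saver(var_map)

def init_fn(sess):
    # Handed to slim.learning.train; runs once in the chief session after
    # variable initialization and restores only the matched variables.
    init_saver.restore(sess, ckpt_path)
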
Example #32
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)
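            # The prefetch queue buffers up to 2 * num_clones preprocessed
            # batches on the inputs device so the clone towers do not stall
            # waiting for input.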

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Compute the total loss and the gradients for all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')
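        # Fetching train_tensor first runs every op in update_ops (batch-norm
        # statistics, optional EMA updates and the gradient application) via the
        # control dependency, then returns the current total loss.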

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables())

        ###########################
        # Kicks off the training. #
        ###########################
        init = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)
        if FLAGS.checkpoint_path == FLAGS.train_dir:
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.train_dir))
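        # When checkpoint_path equals train_dir, the restore above resumes from
        # the latest checkpoint in that directory; otherwise the init function
        # below loads the pretrained weights.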

        # load pretrained weights
        weight_ini_fn = _get_init_fn()
        weight_ini_fn(sess)

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        for step in xrange(FLAGS.max_number_of_steps):
            start_time = time.time()
            # _, loss_value = sess.run([train_tensor, loss])
            # _, loss_value = sess.run([train_tensor, total_loss])
            loss_value = sess.run(train_tensor)
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % FLAGS.log_every_n_steps == 0:
                # num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                # sec_per_batch = duration / FLAGS.num_gpus
                sec_per_batch = duration

                format_str = ('step %d, loss = %.2f (%.1f examples/sec; %.3f '
                              'sec/batch)')
                print(format_str %
                      (step, loss_value, examples_per_sec, sec_per_batch))

            if step % FLAGS.summary_snapshot_steps == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            # Save the model checkpoint periodically.
            if step % FLAGS.model_snapshot_steps == 0 or (
                    step + 1) == FLAGS.max_number_of_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

        print('OK...')
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Compute the total loss and the gradients for all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
Example #34
0
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        # with tf.device(deploy_config.variables_device()):
        global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        # with tf.device(deploy_config.inputs_device()):
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            num_readers=FLAGS.num_readers,
            common_queue_capacity=20 * FLAGS.batch_size,
            common_queue_min=10 * FLAGS.batch_size)
        [image, label] = provider.get(['image', 'label'])
        label -= FLAGS.labels_offset

        train_image_size = FLAGS.train_image_size or network_fn.default_image_size

        image = image_preprocessing_fn(image, train_image_size,
                                       train_image_size)

        images, labels = tf.train.batch(
            [image, label],
            batch_size=FLAGS.batch_size,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size)
        labels = slim.one_hot_encoding(
            labels, dataset.num_classes - FLAGS.labels_offset)
        batch_queue = slim.prefetch_queue.prefetch_queue(
            [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        if FLAGS.quantize_delay >= 0:
            tf.contrib.quantize.create_training_graph(
                quant_delay=FLAGS.quantize_delay)
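        # create_training_graph rewrites the graph with fake-quantization ops
        # for quantization-aware training; quant_delay postpones quantization
        # until that many steps have run so the float weights can settle first.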

        #########################################
        # Configure the optimization procedure. #
        #########################################
        # with tf.device(deploy_config.optimizer_device()):
        learning_rate = _configure_learning_rate(dataset.num_samples,
                                                 global_step)
        optimizer = _configure_optimizer(learning_rate)
        summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Compute the total loss and the gradients for all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################

        ## TODO(ubaid): tensor graph, metagraph, timing code

        print(len(tf.get_default_graph().get_operations()))
        all_ops = tf.get_default_graph().get_operations()
        adj_list_graph_notensors = {}
        for op in all_ops:
            adj_list_graph_notensors[op.name] = set(
                [inp.name.split(":")[0] for inp in op.inputs])
        adj_list_graph_notensors = {
            op_name: list(op_deps)
            for op_name, op_deps in adj_list_graph_notensors.items()
        }
        with open('%s/org_train_graph_notensors.json' % (FLAGS.log_fn_pref),
                  'w') as outfile:
            json.dump(adj_list_graph_notensors, outfile)

        adj_list_graph = {}
        for op in all_ops:
            adj_list_graph[op.name] = set([inp.name for inp in op.inputs])
        adj_list_graph = {
            op_name: list(op_deps)
            for op_name, op_deps in adj_list_graph.items()
        }
        with open('%s/org_train_graph.json' % (FLAGS.log_fn_pref),
                  'w') as outfile:
            json.dump(adj_list_graph, outfile)
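        # The two JSON dumps above describe the op-level dependency graph: the
        # first keeps op names only (tensor output indices stripped), the second
        # keeps full tensor names for every op input.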

        print(':::::: done creating model ::::::')
        print(len(all_ops))
        print(FLAGS.log_fn_pref)
        print(FLAGS.doLog)

        metagraph = tf.train.export_meta_graph()
        temp_meta = MessageToJson(metagraph.graph_def)
        with open('%s/metagraph.json' % (FLAGS.log_fn_pref), 'w') as outfile:
            json.dump(temp_meta, outfile)

        config_proto = tf.ConfigProto(
            log_device_placement=False,
            allow_soft_placement=False,
            graph_options=tf.GraphOptions(build_cost_model=1))

        config_proto.gpu_options.allow_growth = True
        config_proto.intra_op_parallelism_threads = 1
        config_proto.inter_op_parallelism_threads = 1
        print("*********!!!!!!!No opts!!!!!!!*********")
        config_proto.graph_options.optimizer_options.opt_level = -1
        config_proto.graph_options.rewrite_options.constant_folding = (
            rewriter_config_pb2.RewriterConfig.OFF)
        config_proto.graph_options.rewrite_options.arithmetic_optimization = (
            rewriter_config_pb2.RewriterConfig.OFF)
        config_proto.graph_options.rewrite_options.dependency_optimization = (
            rewriter_config_pb2.RewriterConfig.OFF)
        config_proto.graph_options.rewrite_options.layout_optimizer = (
            rewriter_config_pb2.RewriterConfig.OFF)
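        # Disabling the Grappler rewrites above (constant folding, arithmetic,
        # dependency and layout optimization) and setting opt_level to -1 (L0)
        # keeps the executed graph close to the ops dumped earlier; the
        # single-threaded execution and build_cost_model=1 are presumably meant
        # to make the per-op timing traces easier to interpret.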

        print(FLAGS.train_dir)

        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            session_config=config_proto,
            master=FLAGS.master,
            #is_chief=(FLAGS.task == 0),
            is_chief=True,
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            #number_of_steps=FLAGS.max_number_of_steps,
            number_of_steps=14,
            log_every_n_steps=FLAGS.log_every_n_steps,
            trace_every_n_steps=1,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            #        sync_optimizer=optimizer if FLAGS.sync_replicas else None)
            sync_optimizer=optimizer if FLAGS.sync_replicas else None,
            log_fn_pref=FLAGS.log_fn_pref,
            doLogs=FLAGS.doLog)
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_classification.get_dataset(
            FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes,
            FLAGS.labels_to_names_path)
        """
    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True)
    """
        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=None,
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset
            print('label')
            print(label)

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=2 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)
            with tf.variable_scope("InceptionV3", reuse=True) as scope:
                # Auxiliary classifier branch.
                with slim.arg_scope(
                    [slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
                        stride=1,
                        padding="SAME"):
                    # Fetch Mixed_6e from end_points.
                    aux_logits = end_points["Mixed_6e"]
                    with tf.variable_scope("AuxLogits", reuse=tf.AUTO_REUSE):
                        aux_logits = slim.avg_pool2d(aux_logits,
                                                     kernel_size=[5, 5],
                                                     stride=3,
                                                     padding="VALID",
                                                     scope="Avgpool_1a_5x5")
                        aux_logits = slim.conv2d(aux_logits,
                                                 num_outputs=128,
                                                 kernel_size=[1, 1],
                                                 scope="Conv2d_1b_1x1")
                        aux_logits = slim.conv2d(
                            aux_logits,
                            num_outputs=768,
                            kernel_size=[5, 5],
                            weights_initializer=trunc_normal(0.01),
                            padding="VALID",
                            scope="Conv2d_2a_5x5")
                        print('aux_logits')
                        print(aux_logits)
                        aux_logits = slim.conv2d(
                            aux_logits,
                            num_outputs=2,
                            kernel_size=[1, 1],
                            activation_fn=None,
                            normalizer_fn=None,
                            weights_initializer=trunc_normal(0.001),
                            scope="Conv2d_10b_1x1")
                        # Squeeze out the two singleton spatial dimensions (axes 1 and 2).
                        aux_logits = tf.squeeze(aux_logits,
                                                axis=[1, 2],
                                                name="SpatialSqueeze")
                        # Store the auxiliary classifier output in end_points.
                        end_points["AuxLogits"] = aux_logits
            net = slim.dropout(logits, keep_prob=0.8, scope='Dropout_lb')
            net = tf.squeeze(net, axis=[1, 2])
            print('logits')
            print(logits)
            print('labels')
            print(labels)
            with tf.name_scope('output'):
                weights = tf.Variable(
                    tf.truncated_normal([2048, 2], stddev=0.001))
                biases = tf.Variable(tf.zeros([2]))
                logits2 = tf.matmul(net, weights) + biases
                final_tensor = tf.nn.softmax(logits2, name='prob')
                end_points["final_tensor"] = final_tensor
                end_points["labels"] = labels
                print('weights')
                print(weights)
                cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
                    logits=logits2, labels=labels)
                cross_entropy_mean = tf.reduce_mean(cross_entropy)
                slim.losses.add_loss(cross_entropy_mean)
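                # add_loss registers the custom cross-entropy in the slim
                # LOSSES collection, so optimize_clones() later folds it into
                # total_loss together with the auxiliary loss below.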
            #weights = tf.get_collection('weights')
            #print('weights')
            #print(weights)
            #biase = tf.add_n(tf.get_collection('biases'), 'loss2')
            """
      # Loss
      regularization_loss = tf.reduce_mean(tf.square(weights))
      hinge_loss = tf.reduce_mean(
          tf.square(
              tf.maximum(
                  tf.zeros([16, 2]), 1 - labels * logits
              )
          )
      )
      # with tf.name_scope("loss"):
      loss = regularization_loss + 1 * hinge_loss
      """
            """
      bottleneck_input = tf.squeeze(logits)
      # Fully connected layer
      with tf.name_scope('output'):
          weights = tf.Variable(tf.truncated_normal([2048, 2], stddev=0.001))
          biases = tf.Variable(tf.zeros([2]))
          logits = tf.matmul(bottleneck_input, weights) + biases
          final_tensor = tf.nn.softmax(logits, name='prob')
          # Loss
          regularization_loss = tf.reduce_mean(tf.square(weights))
          hinge_loss = tf.reduce_mean(
              tf.square(
                  tf.maximum(
                      tf.zeros([16, 2]), 1 - labels * logits
                  )
              )
          )
          # with tf.name_scope("loss"):
          my_loss = regularization_loss + 1.0 * hinge_loss
          slim.losses.add_loss(my_loss)
      """
            """
      #############################
      # Specify the loss function #
      #############################
      if 'AuxLogits' in end_points:
        slim.losses.softmax_cross_entropy(
            end_points['AuxLogits'], labels,
            label_smoothing=FLAGS.label_smoothing, weights=0.4,
            scope='aux_loss')
      slim.losses.softmax_cross_entropy(
          logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0)
      """
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        losses_sum = 0
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))
            losses_sum += loss
        print('tf.GraphKeys.LOSSES')
        print(tf.GraphKeys.LOSSES)

        with tf.name_scope('evaluation'):
            correct_prediction = tf.equal(
                tf.argmax(end_points["final_tensor"], 1),
                tf.argmax(end_points["labels"], 1))
            evaluation_step = tf.reduce_mean(
                tf.cast(correct_prediction, tf.float32))
            summaries.add(tf.summary.scalar('accuracy', evaluation_step))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Compute the total loss and the gradients for all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))
        print('total_loss')
        print(total_loss)

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
Example #36
0
def train():
    with tf.Graph().as_default():
        ######################
        # Config model_deploy#
        ######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ####################
        # Select the network #
        ####################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=FLAGS.NUM_CLASSES,
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        if hasattr(network_fn, 'rnn_part'):
            load_batch_size = FLAGS.batch_size * deploy_config.num_clones
        else:
            load_batch_size = FLAGS.batch_size
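        # Models with an RNN part read batch_size * num_clones samples from a
        # single loader and split them across clones below; pure-CNN models let
        # each clone dequeue its own batch from the prefetch queue instead.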
        with tf.device(deploy_config.inputs_device()):
            dataset_size, images, labels, video_name = async_loader.video_inputs(
                FLAGS.dataset_list,
                FLAGS.dataset_dir,
                FLAGS.resize_image_size,
                FLAGS.train_image_size,
                load_batch_size,
                FLAGS.n_steps,
                FLAGS.modality,
                FLAGS.read_stride,
                image_preprocessing_fn,
                shuffle=True,
                label_from_one=(FLAGS.labels_offset > 0),
                length1=FLAGS.length,
                crop=2,
                merge_label=FLAGS.merge_label)
            labels = slim.one_hot_encoding(labels, FLAGS.NUM_CLASSES)
            if hasattr(network_fn, 'rnn_part'):
                assert load_batch_size % FLAGS.n_steps == 0
                total_video_num = int(load_batch_size / FLAGS.n_steps)
                # Split images and labels for cnn
                split_images = tf.split(images, deploy_config.num_clones, 0)
                cnn_labels = labels
                if FLAGS.merge_label:
                    cnn_labels = tf.reshape(cnn_labels,
                                            [1, -1, FLAGS.NUM_CLASSES])
                    cnn_labels = tf.tile(cnn_labels, [FLAGS.n_steps, 1, 1])
                    cnn_labels = tf.reshape(cnn_labels,
                                            [-1, FLAGS.NUM_CLASSES])
                split_cnn_labels = tf.split(cnn_labels,
                                            deploy_config.num_clones, 0)
                # Split labels for rnn
                if not FLAGS.merge_label:
                    split_rnn_labels = tf.reshape(
                        labels,
                        [FLAGS.n_steps, total_video_num, FLAGS.NUM_CLASSES])
                    assert total_video_num % deploy_config.num_clones == 0
                    split_rnn_labels = tf.split(split_rnn_labels,
                                                deploy_config.num_clones, 1)
                    each_video_num = int(total_video_num /
                                         deploy_config.num_clones)
                    split_rnn_labels = [
                        tf.reshape(label, [
                            FLAGS.n_steps * each_video_num, FLAGS.NUM_CLASSES
                        ]) for label in split_rnn_labels
                    ]
                else:
                    split_rnn_labels = tf.split(labels,
                                                deploy_config.num_clones, 0)
            else:
                batch_queue = slim.prefetch_queue.prefetch_queue(
                    [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        if hasattr(network_fn, 'rnn_part'):
            cnn_outputs = []
            end_point_outputs = []

            def clone_bn_part(split_batchs, split_cnn_labels, cnn_outputs,
                              end_point_outputs):
                batch = split_batchs[0]
                split_batchs.remove(batch)
                logits, end_points = network_fn(batch)
                cnn_outputs.append(logits)
                end_point_outputs.append(end_points)
                labels = split_cnn_labels[0]
                split_cnn_labels.remove(labels)
                #############################
                # Specify the loss function #
                #############################
                if 'AuxLogits' in end_points:
                    tf.losses.softmax_cross_entropy(
                        logits=end_points['AuxLogits'],
                        onehot_labels=labels,
                        label_smoothing=FLAGS.label_smoothing,
                        weights=0.4,
                        scope='aux_loss')
                return end_points

            def clone_rnn(cnn_outputs, split_rnn_labels, end_point_outputs):
                cnn_output = cnn_outputs[0]
                cnn_outputs.remove(cnn_output)
                end_point_output = end_point_outputs[0]
                end_point_outputs.remove(end_point_output)
                labels = split_rnn_labels[0]
                split_rnn_labels.remove(labels)
                logits, end_points = network_fn.rnn_part(cnn_output)
                end_points.update(end_point_output)
                #############################
                # Specify the loss function #
                #############################
                tf.losses.softmax_cross_entropy(
                    logits=logits,
                    onehot_labels=labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=1.0)
                return end_points

            # Run the BN (CNN) part first. The CNN and RNN need different
            # labels because their samples are ordered differently.
            model_deploy.create_clones(deploy_config,
                                       clone_bn_part, [
                                           split_images, split_cnn_labels,
                                           cnn_outputs, end_point_outputs
                                       ],
                                       gpu_offset=1)
            # Merge on another GPU to avoid transporting data back to the original GPUs.
            assert len(
                model_deploy.get_available_gpus()) > deploy_config.num_clones
            with tf.device(deploy_config.clone_device(0)):
                # Concat all clones to one tensor
                cnn_outputs = tf.concat(values=cnn_outputs, axis=0)
                output_shape = cnn_outputs.get_shape().as_list()
                # Reshape to expose the video number dimension
                cnn_outputs = tf.reshape(cnn_outputs,
                                         [FLAGS.n_steps, total_video_num] +
                                         output_shape[1:])
                # Split in the video number dimension, so that each clone has an input for lstm
                cnn_outputs = tf.split(cnn_outputs, deploy_config.num_clones,
                                       1)
                # Merge n_steps and video number dimension
                cnn_outputs = [
                    tf.reshape(output, [-1] + output_shape[1:])
                    for output in cnn_outputs
                ]
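                # The reshape/split above exposes the video dimension so each
                # RNN clone receives complete, time-ordered sequences for its
                # own subset of videos.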
            # Run the RNN part on another GPU.
            #  clones = model_deploy.create_extra_clones_on_another_gpu(deploy_config, clone_rnn,
            #  [cnn_outputs, split_rnn_labels, end_point_outputs])
            clones = model_deploy.create_clones(
                deploy_config,
                clone_rnn, [cnn_outputs, split_rnn_labels, end_point_outputs],
                gpu_offset=1)
        else:

            def clone_fn(batch_queue):
                """Allows data parallelism by creating multiple clones of network_fn."""
                images, labels = batch_queue.dequeue()
                logits, end_points = network_fn(images)
                #############################
                # Specify the loss function #
                #############################
                if 'AuxLogits' in end_points:
                    tf.losses.softmax_cross_entropy(
                        logits=end_points['AuxLogits'],
                        onehot_labels=labels,
                        label_smoothing=FLAGS.label_smoothing,
                        weights=0.4,
                        scope='aux_loss')
                tf.losses.softmax_cross_entropy(
                    logits=logits,
                    onehot_labels=labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=1.0)
                return end_points

            clones = model_deploy.create_clones(deploy_config, clone_fn,
                                                [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = train_util._configure_learning_rate(global_step)
            optimizer = train_util._configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                total_num_replicas=FLAGS.worker_replicas)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = train_util._get_variables_to_train()
        # Variables to restore and decay
        variables_to_restore = train_util._get_variables_to_restore()

        # Compute the total loss and the gradients for all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones,
            optimizer,
            var_list=variables_to_train,
            gate_gradients=optimizer.GATE_OP,
            colocate_gradients_with_ops=True)

        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Gradient decay and clipping
        if not FLAGS.no_decay:
            # Set up learning rate decay
            lr_mul = {var: 0.1 for var in variables_to_restore}
            clones_gradients = tf.contrib.slim.learning.multiply_gradients(
                clones_gradients, lr_mul)
        if FLAGS.grad_clipping is not None:
            clones_gradients = tf.contrib.slim.learning.clip_gradient_norms(
                clones_gradients, FLAGS.grad_clipping)
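        # multiply_gradients scales the gradients of the restored (pretrained)
        # variables by 0.1 so they fine-tune more slowly than freshly
        # initialized ones; clip_gradient_norms caps each gradient's norm at
        # FLAGS.grad_clipping.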

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################
        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement)
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=train_util._get_init_fn(variables_to_restore),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None,
            trace_every_n_steps=FLAGS.trace_every_n_steps,
            session_config=sess_config)
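
# A minimal, self-contained sketch (an assumed toy model, not part of the example
# above) of how the two slim gradient helpers used there compose: variables restored
# from a checkpoint get a 0.1 gradient multiplier, and every gradient is then clipped
# by norm. The names `backbone_w` and `head_b` are illustrative only.
import tensorflow as tf

slim = tf.contrib.slim

x = tf.placeholder(tf.float32, [None, 4])
backbone_w = tf.get_variable('backbone_w', [4, 2])  # stands in for restored variables
head_b = tf.get_variable('head_b', [2])             # stands in for newly trained variables
loss = tf.reduce_mean(tf.square(tf.matmul(x, backbone_w) + head_b))

optimizer = tf.train.GradientDescentOptimizer(0.01)
grads_and_vars = optimizer.compute_gradients(loss, var_list=[backbone_w, head_b])
# Scale down the backbone gradient (10x smaller effective learning rate), then clip.
grads_and_vars = slim.learning.multiply_gradients(grads_and_vars, {backbone_w: 0.1})
grads_and_vars = slim.learning.clip_gradient_norms(grads_and_vars, max_norm=10.0)
train_op = optimizer.apply_gradients(grads_and_vars)
# `train_op` now trains the head at the full learning rate and the backbone at 1/10th.
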
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpus
  if FLAGS.num_clones == -1:
    FLAGS.num_clones = len(FLAGS.gpus.split(','))

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    # tf.set_random_seed(42)
    tf.set_random_seed(0)
    ######################
    # Config model_deploy#
    ######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name,
        FLAGS.dataset_dir.split(','),
        dataset_list_dir=FLAGS.dataset_list_dir,
        num_samples=FLAGS.frames_per_video,
        modality=FLAGS.modality,
        split_id=FLAGS.split_id)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        batch_size=FLAGS.batch_size,
        weight_decay=FLAGS.weight_decay,
        is_training=True,
        dropout_keep_prob=(1.0-FLAGS.dropout),
        pooled_dropout_keep_prob=(1.0-FLAGS.pooled_dropout),
        batch_norm=FLAGS.netvlad_batch_norm)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True)  # when frames are pooled, preprocessing
                           # is applied at the video level

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=FLAGS.num_readers,
        common_queue_capacity=20 * FLAGS.batch_size,
        common_queue_min=10 * FLAGS.batch_size,
        bgr_flips=FLAGS.bgr_flip)
      [image, label] = provider.get(['image', 'label'])
      # Note that the image above may be a 23-channel image if both RGB and flow
      # streams are used. It will need to be split later, but all preprocessing is
      # applied consistently to every frame across all streams.
      label = tf.string_to_number(label, tf.int32)
      label.set_shape(())
      label -= FLAGS.labels_offset

      train_image_size = FLAGS.train_image_size or network_fn.default_image_size

      scale_ratios = [float(el) for el in FLAGS.scale_ratios.split(',')]
      image = image_preprocessing_fn(image, train_image_size,
                                     train_image_size,
                                     scale_ratios=scale_ratios,
                                     out_dim_scale=FLAGS.out_dim_scale,
                                     model_name=FLAGS.model_name)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)
      if FLAGS.debug:
        images = tf.Print(images, [labels], 'Read batch')
      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - FLAGS.labels_offset)
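      # Prefetch batched inputs so that clone dequeues do not stall on the input pipeline.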
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)
      summarize_images(images, provider.num_channels_stream)

    ####################
    # Define the model #
    ####################
    kwargs = {}
    if FLAGS.conv_endpoint is not None:
      kwargs['conv_endpoint'] = FLAGS.conv_endpoint
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(
          images, pool_type=FLAGS.pooling,
          classifier_type=FLAGS.classifier_type,
          num_channels_stream=provider.num_channels_stream,
          netvlad_centers=FLAGS.netvlad_initCenters.split(','),
          stream_pool_type=FLAGS.stream_pool_type,
          **kwargs)

      #############################
      # Specify the loss function #
      #############################
      if 'AuxLogits' in end_points:
        slim.losses.softmax_cross_entropy(
            end_points['AuxLogits'], labels,
            label_smoothing=FLAGS.label_smoothing, weight=0.4, scope='aux_loss')
      slim.losses.softmax_cross_entropy(
          logits, labels, label_smoothing=FLAGS.label_smoothing, weight=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    global end_points_debug
    end_points = clones[0].outputs
    end_points_debug = dict(end_points)
    end_points_debug['images'] = images
    end_points_debug['labels'] = labels
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.histogram_summary('activations/' + end_point, x))
      summaries.add(tf.scalar_summary('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.scalar_summary('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.histogram_summary(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.scalar_summary('learning_rate', learning_rate,
                                      name='learning_rate'))

    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables,
          replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
          total_num_replicas=FLAGS.worker_replicas)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()
    logging.info('Training the following variables: %s' % (
      ' '.join([el.name for el in variables_to_train])))

    # Compute the total loss and the gradients for each clone.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)

    # clip the gradients if needed
    if FLAGS.clip_gradients > 0:
      logging.info('Clipping gradients by %f' % FLAGS.clip_gradients)
      with tf.name_scope('clip_gradients'):
        clones_gradients = slim.learning.clip_gradient_norms(
            clones_gradients,
            FLAGS.clip_gradients)

    # Add total_loss to summary.
    summaries.add(tf.scalar_summary('total_loss', total_loss,
                                    name='total_loss'))

    # Create gradient updates.
    train_ops = {}
    if FLAGS.iter_size == 1:
      grad_updates = optimizer.apply_gradients(clones_gradients,
                                               global_step=global_step)
      update_ops.append(grad_updates)

      update_op = tf.group(*update_ops)
      train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                        name='train_op')
      train_ops = train_tensor
    else:
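      # FLAGS.iter_size > 1: accumulate gradients over iter_size mini-batches in local
      # "ref_grad" buffers, then apply the averaged gradients once per cycle.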
      gvs = [(grad, var) for grad, var in clones_gradients]
      varnames = [var.name for grad, var in gvs]
      varname_to_var = {var.name: var for grad, var in gvs}
      varname_to_grad = {var.name: grad for grad, var in gvs}
      varname_to_ref_grad = {}
      for vn in varnames:
        grad = varname_to_grad[vn]
        print("accumulating ... ", (vn, grad.get_shape()))
        with tf.variable_scope("ref_grad"):
          with tf.device(deploy_config.variables_device()):
            ref_var = slim.local_variable(
                np.zeros(grad.get_shape(), dtype=np.float32),
                name=vn[:-2])
            varname_to_ref_grad[vn] = ref_var

      all_assign_ref_op = [ref.assign(varname_to_grad[vn])
                           for vn, ref in varname_to_ref_grad.items()]
      all_assign_add_ref_op = [ref.assign_add(varname_to_grad[vn])
                               for vn, ref in varname_to_ref_grad.items()]
      assign_gradients_ref_op = tf.group(*all_assign_ref_op)
      accumulate_gradients_op = tf.group(*all_assign_add_ref_op)
      with tf.control_dependencies([accumulate_gradients_op]):
        final_gvs = [(varname_to_ref_grad[var.name] / float(FLAGS.iter_size), var)
                     for grad, var in gvs]
        apply_gradients_op = optimizer.apply_gradients(final_gvs, global_step=global_step)
        update_ops.append(apply_gradients_op)
        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
            total_loss, name='train_op')
      for i in range(FLAGS.iter_size):
        if i == 0:
          train_ops[i] = assign_gradients_ref_op
        elif i < FLAGS.iter_size - 1:  # apply_gradients already accumulates via the
                                       # control dependency above, so no extra
                                       # accumulation step is needed on the last step
          train_ops[i] = accumulate_gradients_op
        else:
          train_ops[i] = train_tensor


    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.merge_summary(list(summaries), name='summary_op')

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.intra_op_parallelism_threads = FLAGS.cpu_threads
    # config.allow_soft_placement = True
    # config.gpu_options.per_process_gpu_memory_fraction=0.7

    ###########################
    # Kicks off the training. #
    ###########################
    logging.info('RUNNING ON SPLIT %d' % FLAGS.split_id)
    slim.learning.train(
        train_ops,
        train_step_fn=train_step,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=optimizer if FLAGS.sync_replicas else None,
        session_config=config)
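
# A condensed standalone sketch (toy loss; `k` stands in for FLAGS.iter_size) of the
# gradient-accumulation scheme above: gradients are summed into non-trainable buffers
# over k mini-batches, and the averaged gradients are applied in a single step.
import tensorflow as tf

k = 4  # assumed accumulation length, analogous to FLAGS.iter_size
x = tf.placeholder(tf.float32, [None, 3])
w = tf.get_variable('accum_sketch_w', [3, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))

opt = tf.train.GradientDescentOptimizer(0.1)
grads_and_vars = opt.compute_gradients(loss, var_list=[w])

# One non-trainable buffer per variable, initialised to zero.
accumulators = [
    tf.get_variable('accum_buf_%d' % i, shape=v.get_shape(), trainable=False,
                    initializer=tf.zeros_initializer())
    for i, (_, v) in enumerate(grads_and_vars)]
reset_op = tf.group(*[a.assign(tf.zeros_like(a)) for a in accumulators])
accum_op = tf.group(*[a.assign_add(g)
                      for a, (g, _) in zip(accumulators, grads_and_vars)])
# Apply the mean of the accumulated gradients in one optimizer step.
apply_op = opt.apply_gradients(
    [(a / float(k), v) for a, (_, v) in zip(accumulators, grads_and_vars)])
# Typical loop: run reset_op, then accum_op k times, then apply_op once.
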
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=0,
            num_replicas=1,
            num_ps_tasks=0)

        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        image_count, class_count = get_image_and_class_count(
            FLAGS.dataset_dir, 'train')
        dataset = get_dataset('aerial', FLAGS.dataset_dir, image_count,
                              class_count, 'train')
        network_fn = get_network_fn(num_classes=(dataset.num_classes),
                                    weight_decay=FLAGS.weight_decay)
        image_preprocessing_fn = get_preprocessing()

        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            image = image_preprocessing_fn(image, 224, 224)
            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(labels, dataset.num_classes)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        def clone_fn(batch_queue):
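            """Allows data parallelism by creating multiple clones of network_fn."""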
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)
            logits = tf.squeeze(logits)  # drop singleton dims so logits match the one-hot labels
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        with tf.device(deploy_config.optimizer_device()):
            decay_steps = int(dataset.num_samples / FLAGS.batch_size *
                              FLAGS.num_epochs_per_decay)
            learning_rate = tf.train.exponential_decay(
                FLAGS.learning_rate,
                global_step,
                decay_steps,
                FLAGS.learning_rate_decay_factor,
                staircase=True,
                name='exponential_decay_learning_rate')
            optimizer = tf.train.RMSPropOptimizer(
                learning_rate,
                decay=FLAGS.rmsprop_decay,
                momentum=FLAGS.rmsprop_momentum,
                epsilon=FLAGS.opt_epsilon)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        variables_to_train = _get_variables_to_train()
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_dir,
                            master='',
                            is_chief=True,
                            init_fn=_get_init_fn(),
                            summary_op=summary_op,
                            number_of_steps=FLAGS.max_number_of_steps,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs,
                            sync_optimizer=None)
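
# A small standalone sketch (assumed dataset size, batch size and decay values) of the
# schedule wiring used above: with staircase=True the learning rate is multiplied by
# decay_factor once every num_epochs_per_decay epochs' worth of steps.
import tensorflow as tf

num_samples, batch_size = 50000, 32            # assumed dataset/batch sizes
num_epochs_per_decay, decay_factor = 2.0, 0.94  # assumed decay settings
decay_steps = int(num_samples / batch_size * num_epochs_per_decay)  # 3125 steps

global_step = tf.train.get_or_create_global_step()
learning_rate = tf.train.exponential_decay(
    0.01, global_step, decay_steps, decay_factor,
    staircase=True, name='exponential_decay_learning_rate')
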
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=FLAGS.num_readers,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size)
      [image, label] = provider.get(['image', 'label'])
      label -= FLAGS.labels_offset

      train_image_size = FLAGS.train_image_size or network_fn.default_image_size

      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)
      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - FLAGS.labels_offset)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      with tf.device(deploy_config.inputs_device()):
        images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
      if 'AuxLogits' in end_points:
        tf.losses.softmax_cross_entropy(
            logits=end_points['AuxLogits'], onehot_labels=labels,
            label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
      tf.losses.softmax_cross_entropy(
          logits=logits, onehot_labels=labels,
          label_smoothing=FLAGS.label_smoothing, weights=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram('activations/' + end_point, x))
      summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables,
          replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
          total_num_replicas=FLAGS.worker_replicas)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # Compute the total loss and the gradients for each clone.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=optimizer if FLAGS.sync_replicas else None)
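
# A minimal sketch (toy batch-norm layer, assumed shapes) of the train_op pattern used
# above: the gradient step is grouped with the UPDATE_OPS collection, and the loss is
# re-exported behind a control dependency so evaluating `train_tensor` always runs the
# updates first.
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 8])
net = tf.layers.batch_normalization(x, training=True)  # registers ops in UPDATE_OPS
loss = tf.reduce_mean(tf.square(net))

optimizer = tf.train.GradientDescentOptimizer(0.1)
grad_step = optimizer.minimize(loss)

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
update_op = tf.group(grad_step, *update_ops)
with tf.control_dependencies([update_op]):
  train_tensor = tf.identity(loss, name='train_op')
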
Example #40
0
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=1,
        clone_on_cpu=False,
        replica_id=0,
        num_replicas=1,
        num_ps_tasks=0)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        'flowers', 'train', FLAGS.dataset_dir)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        'mobilenet_v1',
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True)

    #####################################
    # Select the preprocessing function #
    #####################################
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        'mobilenet_v1',
        is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=4,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size)
      [image, label] = provider.get(['image', 'label'])
      label -= FLAGS.labels_offset

      train_image_size = network_fn.default_image_size

      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=4,
          capacity=5 * FLAGS.batch_size)
      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - FLAGS.labels_offset)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
      slim.losses.softmax_cross_entropy(
          logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram('activations/' + end_point, x))
      summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):

      num_epochs_per_decay = 2.5
      decay_steps = int(dataset.num_samples / FLAGS.batch_size *
                        num_epochs_per_decay)
      learning_rate = tf.train.exponential_decay(
          FLAGS.learning_rate,
          global_step,
          decay_steps,
          _LEARNING_RATE_DECAY_FACTOR,
          staircase=True,
          name='exponential_decay_learning_rate')

      optimizer = tf.train.RMSPropOptimizer(
          learning_rate,
          decay=FLAGS.rmsprop_decay,
          momentum=FLAGS.rmsprop_momentum,
          epsilon=FLAGS.opt_epsilon)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # Compute the total loss and the gradients for each clone.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=True,
        session_config=session_config,
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=10,
        save_summaries_secs=300,
        save_interval_secs=300,
        sync_optimizer=None)  # sync_replicas is not used in this example
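
# A hedged sketch of what a helper like _get_variables_to_train() typically looks like
# in these scripts (its actual implementation is not shown here): train every variable,
# or only those under a comma-separated list of scopes such as 'MobilenetV1/Logits'.
import tensorflow as tf

def get_variables_to_train(trainable_scopes=None):
  """Returns the list of variables to optimize, optionally filtered by scope."""
  if not trainable_scopes:
    return tf.trainable_variables()
  variables_to_train = []
  for scope in [s.strip() for s in trainable_scopes.split(',')]:
    variables_to_train.extend(
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope))
  return variables_to_train
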