Example #1
            [image, label] = provider.get(['image', 'label'])

            image = image_preprocessing_fn(image, params['height'],
                                           params['width'])
            images, labels = tf.train.batch(
                [image, label],
                batch_size=args.batch_size,
                num_threads=args.preprocessing_threads,
                capacity=5 * args.batch_size)
            labels = slim.one_hot_encoding(labels, dataset.num_classes)

            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    with tf.name_scope('synchronized_train'):
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = tf.train.exponential_decay(
                args.learning_rate,
                global_step,
                args.learning_rate_decay_steps,
                args.learning_rate_decay,
                staircase=True,
                name='exponential_decay_learning_rate')
            optimizer = tf.train.AdamOptimizer(learning_rate)
        variables_to_train = tf.trainable_variables()
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones, optimizer, var_list=variables_to_train)
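
The snippets in this listing all build on the TF-Slim model_deploy utilities and TF 1.x APIs. As a hedged reference, the imports they typically assume look roughly like the sketch below; the module paths follow the tensorflow/models research/slim layout and are an assumption here, not part of the snippet above.

# Import sketch (assumption: a tensorflow/models research/slim checkout on PYTHONPATH).
import tensorflow as tf

from deployment import model_deploy              # clone/replica deployment helpers
from datasets import dataset_factory             # dataset definitions
from nets import nets_factory                    # network constructors
from preprocessing import preprocessing_factory  # per-model preprocessing

slim = tf.contrib.slim                           # TF 1.x contrib slim
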
Example #2
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')

    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    # Get dataset-dependent information.
    dataset = segmentation_dataset.get_dataset(FLAGS.dataset,
                                               FLAGS.train_split,
                                               dataset_dir=FLAGS.dataset_dir)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            samples = input_generator.get(
                dataset,
                FLAGS.train_crop_size,
                clone_batch_size,
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                dataset_split=FLAGS.train_split,
                is_training=True,
                model_variant=FLAGS.model_variant)
            inputs_queue = prefetch_queue.prefetch_queue(samples,
                                                         capacity=128 *
                                                         config.num_clones)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (inputs_queue, {
                common.OUTPUT_TYPE: dataset.num_classes
            }, dataset.ignore_label)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for images, labels, semantic predictions
        if FLAGS.save_summaries_images:
            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
            summaries.add(
                tf.summary.image('samples/%s' % common.IMAGE, summary_image))

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
            # Scale up summary image pixel values for better visualization.
            pixel_scaling = max(1, 255 // dataset.num_classes)
            summary_label = tf.cast(first_clone_label * pixel_scaling,
                                    tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL, summary_label))

            first_clone_output = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
            predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

            summary_predictions = tf.cast(predictions * pixel_scaling,
                                          tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                 summary_predictions))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps, FLAGS.learning_power,
                FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   FLAGS.momentum)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Start the training.
        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_logdir,
                            log_every_n_steps=FLAGS.log_steps,
                            master=FLAGS.master,
                            number_of_steps=FLAGS.training_number_of_steps,
                            is_chief=(FLAGS.task == 0),
                            session_config=session_config,
                            startup_delay_steps=startup_delay_steps,
                            init_fn=train_utils.get_model_init_fn(
                                FLAGS.train_logdir,
                                FLAGS.tf_initial_checkpoint,
                                FLAGS.initialize_last_layer,
                                last_layers,
                                ignore_missing_vars=True),
                            summary_op=summary_op,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs)
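
Stripped of the dataset- and model-specific pieces, the skeleton shared by these training scripts reduces to a handful of model_deploy calls. The toy sketch below uses a made-up single-variable loss rather than DeepLab, and sets clone_on_cpu=True only so the graph builds without a GPU; everything else mirrors the calls used above.

import tensorflow as tf
from deployment import model_deploy  # assumes the research/slim deployment package

slim = tf.contrib.slim

def toy_model_fn():
    # Stand-in model: one scalar weight regressed toward a constant target.
    w = tf.get_variable('w', shape=[], initializer=tf.zeros_initializer())
    tf.losses.mean_squared_error(labels=tf.constant(3.0), predictions=w)

config = model_deploy.DeploymentConfig(num_clones=1, clone_on_cpu=True)
with tf.Graph().as_default():
    with tf.device(config.variables_device()):
        global_step = tf.train.get_or_create_global_step()
    clones = model_deploy.create_clones(config, toy_model_fn)
    with tf.device(config.optimizer_device()):
        optimizer = tf.train.GradientDescentOptimizer(0.1)
    # optimize_clones() sums the clone losses (plus regularization) and
    # returns the gradients to feed into apply_gradients().
    total_loss, grads_and_vars = model_deploy.optimize_clones(clones, optimizer)
    train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
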
Example #3

def main(_):

    if not os.path.isdir(FLAGS.train_dir):
        os.makedirs(FLAGS.train_dir)

    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    if not FLAGS.aug_mode:
        raise ValueError('aug_mode needs to be specified.')

    if (not FLAGS.train_image_height) or (not FLAGS.train_image_width):
        raise ValueError(
            'The image height and width must be defined explicitly.')

    if FLAGS.hd_data:
        if FLAGS.train_image_height != 400 or FLAGS.train_image_width != 200:
            FLAGS.train_image_height, FLAGS.train_image_width = 400, 200
            print("set the image size to (%d, %d)" % (400, 200))

    # config and print log
    config_and_print_log(FLAGS)

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        #####################################
        # Select the preprocessing function #
        #####################################
        img_func = get_img_func()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.DataLoader(FLAGS.model_name,
                                             FLAGS.dataset_name,
                                             FLAGS.dataset_dir, FLAGS.set,
                                             FLAGS.hd_data, img_func,
                                             FLAGS.batch_size, FLAGS.batch_k,
                                             FLAGS.max_number_of_steps,
                                             get_pair_type())

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True,
            sample_number=FLAGS.sample_number)

        ####################
        # Define the model #
        ####################
        def clone_fn(tf_batch_queue):
            return build_graph(tf_batch_queue, network_fn)

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [dataset.tf_batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs

        # Add summaries for losses.
        loss_dict = {}
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            if loss.name == 'softmax_cross_entropy_loss/value:0':
                loss_dict['clf'] = loss
            elif 'softmax_cross_entropy_loss' in loss.name:
                loss_dict['sample_clf_' +
                          str(loss.name.split('/')[0].split('_')[-1])] = loss
            elif 'entropy' in loss.name:
                loss_dict['entropy'] = loss
            else:
                raise Exception('Loss type error')

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step, FLAGS)
            optimizer = _configure_optimizer(learning_rate)

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
                total_num_replicas=FLAGS.worker_replicas)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # optimize_clones() returns the total loss (the sum of all LOSSES and
        # REGULARIZATION_LOSSES in tf.GraphKeys) and the per-variable gradients.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        train_tensor_list = [train_tensor]
        format_str = 'step %d, loss = %.2f'

        for loss_key in sorted(loss_dict.keys()):
            train_tensor_list.append(loss_dict[loss_key])
            format_str += (', %s_loss = ' % loss_key + '%.8f')

        format_str += ' (%.1f examples/sec; %.3f sec/batch)'

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')

        ###########################
        # Kicks off the training. #
        ###########################
        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # load pretrained weights
        if FLAGS.checkpoint_path is not None:
            print("Load the pretrained weights")
            weight_ini_fn = _get_init_fn()
            weight_ini_fn(sess)
        else:
            print("Train from the scratch")

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        # for step in xrange(FLAGS.max_number_of_steps):
        for step in xrange(FLAGS.max_number_of_steps + 1):
            start_time = time.time()

            loss_value_list = sess.run(train_tensor_list,
                                       feed_dict=dataset.get_feed_dict())

            duration = time.time() - start_time

            # assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % FLAGS.log_every_n_steps == 0:
                # num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                # sec_per_batch = duration / FLAGS.num_gpus
                sec_per_batch = duration

                print(format_str % tuple([step] + loss_value_list +
                                         [examples_per_sec, sec_per_batch]))

            # Save the model checkpoint periodically.
            # if step % FLAGS.model_snapshot_steps == 0 or (step + 1) == FLAGS.max_number_of_steps:
            if step % FLAGS.model_snapshot_steps == 0:
                saver.save(sess, checkpoint_path, global_step=step)

        print('OK...')
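
The log line in this example is assembled dynamically from whatever losses were collected into loss_dict. A small standalone illustration of that pattern, with made-up numbers in place of the values returned by sess.run:

# Hypothetical values, only to show how format_str and the fetched losses line up.
loss_dict_keys = ['clf', 'entropy']          # sorted loss names, as above
loss_values = [0.7312, 0.0185]               # stand-ins for the fetched loss tensors
step, total_loss = 120, 0.7497
examples_per_sec, sec_per_batch = 95.3, 0.672

format_str = 'step %d, loss = %.2f'
for loss_key in loss_dict_keys:
    format_str += (', %s_loss = ' % loss_key + '%.8f')
format_str += ' (%.1f examples/sec; %.3f sec/batch)'

print(format_str % tuple([step, total_loss] + loss_values +
                         [examples_per_sec, sec_per_batch]))
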
Example #4

def main(_):
  # if not FLAGS.dataset_dir:
  #   raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    # dataset = dataset_factory.get_dataset(
    #     FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ######################
    # Select the network #
    ######################
    # network_fn = nets_factory.get_network_fn(
    #     FLAGS.model_name,
    #     num_classes=(dataset.num_classes - FLAGS.labels_offset),
    #     weight_decay=FLAGS.weight_decay,
    #     is_training=True)

    def localization_net_alpha(inputs, num_transformer, num_theta_params):
        """
        Utilize inception_v2 as the localization net of spatial transformer
        """
        # outputs 7*7*1024: default final_endpoint='Mixed_5c' before the fully connected layer
        with tf.variable_scope('inception_net'):
            net, _ = inception_v2.inception_v2_base(inputs)

        # fc layer using [1, 1] convolution kernel: 1*1*1024
        with tf.variable_scope('logits'):
            net = slim.conv2d(net, 128, [1, 1], scope='conv2d_a_1x1')
            kernel_size = inception_v2._reduced_kernel_size_for_small_input(net, [7, 7])
            net = slim.conv2d(net, 128, kernel_size, padding='VALID', scope='conv2d_b_{}x{}'.format(*kernel_size))
            init_bias = tf.constant_initializer([1.1, .0, 1.1, .0] * num_transformer)
            logits = slim.conv2d(net, num_transformer * num_theta_params, [1, 1],
                                 weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
                                 biases_initializer=init_bias,
                                 normalizer_fn=None, activation_fn=tf.nn.tanh, scope='conv2d_c_1x1')

            return tf.squeeze(logits, [1, 2])

    def _inception_logits(inputs, num_outputs, dropout_keep_prob, activ_fn=None):
        with tf.variable_scope('logits'):
            kernel_size = inception_v2._reduced_kernel_size_for_small_input(inputs, [7, 7])
            # shape ?*1*1*?
            net = slim.avg_pool2d(inputs, kernel_size, padding='VALID')
            # drop out neuron before fc conv
            net = slim.dropout(net, keep_prob=dropout_keep_prob, scope='dropout')
            # [1, 1] fc conv
            logits = slim.conv2d(net, num_outputs, [1, 1], normalizer_fn=None, activation_fn=activ_fn,
                                 scope='conv2_a_1x1')

        return tf.squeeze(logits, [1, 2])

    def network_fn(inputs):
        """Fine grained classification with multiplex spatial transformation channels utilizing inception nets

                """
        end_points = {}
        arg_scope = inception_v2.inception_v2_arg_scope(weight_decay=FLAGS.weight_decay)
        with slim.arg_scope(arg_scope):
            with tf.variable_scope('stn'):
                with tf.variable_scope('localization'):
                    transformer_theta = localization_net_alpha(inputs, NUM_TRANSFORMER, NUM_THETA_PARAMS)
                    transformer_theta_split = tf.split(transformer_theta, NUM_TRANSFORMER, axis=1)
                    end_points['stn/localization/transformer_theta'] = transformer_theta

                transformer_outputs = []
                for theta in transformer_theta_split:
                    transformer_outputs.append(
                        transformer(inputs, theta, transformer_output_size, sampling_kernel='bilinear'))

                inception_outputs = []
                transformer_outputs_shape = [FLAGS.batch_size, transformer_output_size[0],
                                             transformer_output_size[1], 3]
                with tf.variable_scope('classification'):
                    for path_idx, inception_inputs in enumerate(transformer_outputs):
                        with tf.variable_scope('path_{}'.format(path_idx)):
                            inception_inputs.set_shape(transformer_outputs_shape)
                            net, _ = inception_v2.inception_v2_base(inception_inputs)
                            inception_outputs.append(net)
                    # concatenate the endpoints: num_batch*7*7*(num_transformer*1024)
                    multipath_outputs = tf.concat(inception_outputs, axis=-1)

                    # final fc layer logits
                    classification_logits = _inception_logits(multipath_outputs, NUM_CLASSES, dropout_keep_prob)
                    end_points['stn/classification/logits'] = classification_logits

        return classification_logits, end_points


    #####################################
    # Select the preprocessing function #
    #####################################
    # preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    # image_preprocessing_fn = preprocessing_factory.get_preprocessing(
    #     preprocessing_name,
    #     is_training=True)

    def image_preprocessing_fn(image, out_height, out_width):
        if image.dtype != tf.float32:
            image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        image = tf.image.central_crop(image, central_fraction=0.975)
        image = tf.expand_dims(image, 0)
        image = tf.image.resize_bilinear(image, [out_height, out_width], align_corners=False)
        image = tf.squeeze(image, [0])
        image = tf.image.random_flip_left_right(image)
        image = tf.subtract(image, 0.5)
        image = tf.multiply(image, 2.0)
        image.set_shape((out_height, out_width, 3))
        return image

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    def _get_filename_list(file_dir, file):
        filename_path = os.path.join(file_dir, file)
        filename_list = []
        cls_label_list = []
        with open(filename_path, 'r') as f:
            for line in f:
                filename, label, nid, attr = line.strip().split(',')
                filename_list.append(filename)
                cls_label_list.append(int(label))

        return filename_list, cls_label_list

    with tf.device(deploy_config.inputs_device()):
      # create the filename and label example
      filename_list, label_list = _get_filename_list(filename_dir, file)
      num_samples = len(filename_list)
      filename, label = tf.train.slice_input_producer([filename_list, label_list], num_epochs)

      # decode and preprocess the image
      file_content = tf.read_file(filename)
      image = tf.image.decode_jpeg(file_content, channels=3)

      train_image_size = FLAGS.train_image_size or default_image_size
      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)
      labels = slim.one_hot_encoding(
          labels, NUM_CLASSES - FLAGS.labels_offset)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      with tf.device(deploy_config.inputs_device()):
        images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
      if 'AuxLogits' in end_points:
        tf.losses.softmax_cross_entropy(
            logits=end_points['AuxLogits'], onehot_labels=labels,
            label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
      tf.losses.softmax_cross_entropy(
          logits=logits, onehot_labels=labels,
          label_smoothing=FLAGS.label_smoothing, weights=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram('activations/' + end_point, x))
      summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate_loc = _configure_learning_rate_loc(num_samples, global_step)
      learning_rate_cls = _configure_learning_rate_cls(num_samples, global_step)
      optimizer_loc = _configure_optimizer(learning_rate_loc)
      optimizer_cls = _configure_optimizer(learning_rate_cls)
      summaries.add(tf.summary.scalar('learning_rate_loc', learning_rate_loc))
      summaries.add(tf.summary.scalar('learning_rate_cls', learning_rate_cls))

    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer_loc = tf.train.SyncReplicasOptimizer(
          opt=optimizer_loc,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables,
          replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
          total_num_replicas=FLAGS.worker_replicas)
      optimizer_cls = tf.train.SyncReplicasOptimizer(
          opt=optimizer_cls,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables,
          replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
          total_num_replicas=FLAGS.worker_replicas)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    # variables_to_train = _get_variables_to_train()

    loc_vars_to_train = _get_localization_vars_to_train(loc_train_vars_scope)
    cls_vars_to_train = _get_classification_vars_to_train(cls_train_vars_scope)

    # optimize_clones() returns the total loss and the per-variable gradients.
    _, clones_gradients_loc = model_deploy.optimize_clones(
        clones,
        optimizer_loc,
        var_list=loc_vars_to_train)
    total_loss, clones_gradients_cls = model_deploy.optimize_clones(
        clones,
        optimizer_cls,
        var_list=cls_vars_to_train)

    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates_loc = optimizer_loc.apply_gradients(clones_gradients_loc)
    grad_updates_cls = optimizer_cls.apply_gradients(clones_gradients_cls, global_step=global_step)
    update_ops.append(grad_updates_loc)
    update_ops.append(grad_updates_cls)

    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')


    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=None)
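
This example trains the localization and classification branches with separate optimizers over disjoint variable lists. The helpers _get_localization_vars_to_train / _get_classification_vars_to_train are not shown above; a plausible minimal stand-in (scope prefixes taken from the variable scopes opened in network_fn, everything else an assumption) could look like:

import tensorflow as tf

def _vars_in_scope(scope_prefix):
    # Collect trainable variables whose names start with the given scope prefix.
    return [v for v in tf.trainable_variables()
            if v.op.name.startswith(scope_prefix)]

# Scope prefixes match the variable scopes defined in network_fn above.
loc_train_vars_scope = 'stn/localization'
cls_train_vars_scope = 'stn/classification'

loc_vars_to_train = _vars_in_scope(loc_train_vars_scope)
cls_vars_to_train = _vars_in_scope(cls_train_vars_scope)
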
Example #5

def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # Config model_deploy #
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        # Select the dataset #
        dataset = nsfw.get_split(FLAGS.dataset_split_name, FLAGS.dataset_dir)

        # Select the network #
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        # Select the preprocessing function #
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        # Create a dataset provider that loads data from the dataset #
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        # Define the model #
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            # Specify the loss function #
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        # Configure the moving averages #
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        # Configure the optimization procedure. #
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # optimize_clones() returns the total loss and the per-variable gradients.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Kicks off the training. #
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
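
Several of these scripts optionally maintain exponential moving averages of the model variables and push the resulting op into update_ops. A self-contained toy (a single scalar variable, TF 1.x graph mode assumed) showing what ExponentialMovingAverage.apply() does:

import tensorflow as tf

w = tf.Variable(0.0)
increment = tf.assign_add(w, 1.0)                 # stand-in for a training step
ema = tf.train.ExponentialMovingAverage(decay=0.9)
maintain_averages_op = ema.apply([w])             # the op appended to update_ops above

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):
        sess.run(increment)                       # "train", then...
        sess.run(maintain_averages_op)            # ...refresh the shadow variable
    # The shadow value trails the raw variable (here w == 3.0, average < 3.0).
    print(sess.run([w, ema.average(w)]))
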
Example #6
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            iterator = coco.get_dataset(FLAGS.train_data_file,
                                        batch_size=FLAGS.batch_size,
                                        num_epochs=500,
                                        buffer_size=250 * FLAGS.num_clones,
                                        num_parallel_calls=4 *
                                        FLAGS.num_clones,
                                        crop_height=FLAGS.height,
                                        crop_width=FLAGS.width,
                                        resize_shape=FLAGS.width,
                                        data_augment=True)

        def clone_fn(iterator):
            with tf.device(deploy_config.inputs_device()):
                batch_image, batch_labels = iterator.get_next()

            s = batch_labels.get_shape().as_list()
            batch_labels.set_shape([FLAGS.batch_size, s[1], s[2], s[3]])

            s = batch_image.get_shape().as_list()
            batch_image.set_shape([FLAGS.batch_size, s[1], s[2], s[3]])

            num_classes = coco.num_classes()

            logits, end_points = resseg_model(
                batch_image,
                FLAGS.height,
                FLAGS.width,
                FLAGS.scale,
                FLAGS.weight_decay,
                FLAGS.use_seperable_convolution,
                num_classes,
                is_training=True,
                use_batch_norm=FLAGS.use_batch_norm,
                num_units=FLAGS.num_units,
                filter_depth_multiplier=FLAGS.filter_depth_multiplier)

            s = logits.get_shape().as_list()
            with tf.device(deploy_config.inputs_device()):
                lmap_size = 256
                lmap = np.array([0] * lmap_size)
                for k, v in coco.id2trainid_objects.items():
                    lmap[k] = v + 1
                lmap = tf.constant(lmap, tf.uint8)
                down_labels = tf.cast(batch_labels, tf.int32)
                label_mask = tf.squeeze((down_labels < 255))
                down_labels = tf.gather(lmap, down_labels)
                down_labels = tf.cast(down_labels, tf.int32)
                down_labels = tf.reshape(
                    down_labels, tf.TensorShape([FLAGS.batch_size, s[1],
                                                 s[2]]))
                down_labels = tf.cast(label_mask, tf.int32) * down_labels

                fg_weights = tf.constant(FLAGS.foreground_weight,
                                         dtype=tf.int32,
                                         shape=label_mask.shape)
                label_weights = tf.cast(label_mask, tf.int32) * fg_weights

            # Specify the loss
            cross_entropy = tf.losses.sparse_softmax_cross_entropy(
                down_labels, logits, weights=label_weights, scope='xentropy')
            tf.losses.add_loss(cross_entropy)

            return end_points, batch_image, down_labels, logits

        # Gather initial summaries
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [iterator])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(
                FLAGS.num_samples_per_epoch, global_step,
                deploy_config.num_clones)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
                total_num_replicas=FLAGS.worker_replicas)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        end_points, batch_image, down_labels, logits = clones[0].outputs

        cmap = np.array(coco.id2color)
        cmap = tf.constant(cmap, tf.uint8)
        seg_map = tf.gather(cmap, down_labels)

        predictions = tf.argmax(logits, axis=3)
        pred_map = tf.gather(cmap, predictions)

        summaries.add(tf.summary.image('labels', seg_map))
        summaries.add(tf.summary.image('predictions', pred_map))
        summaries.add(tf.summary.image('images', batch_image))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Returns a train_tensor and summary_op
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        if FLAGS.sync_replicas:
            sync_optimizer = optimizer
            startup_delay_steps = 0
        else:
            sync_optimizer = None
            startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        ###########################
        # Kick off the training.  #
        ###########################
        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_dir,
                            master=FLAGS.master,
                            is_chief=(FLAGS.task == 0),
                            init_fn=_get_init_fn(),
                            summary_op=summary_op,
                            number_of_steps=FLAGS.max_number_of_steps,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            startup_delay_steps=startup_delay_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs,
                            sync_optimizer=sync_optimizer)
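
Example #6 remaps raw COCO label ids and colorizes label/prediction maps by indexing constant lookup tables with tf.gather. A tiny standalone illustration of that lookup-table trick (made-up table values, TF 1.x graph mode assumed):

import numpy as np
import tensorflow as tf

# Made-up id -> train-id table; index i holds the remapped value for raw id i.
lmap = tf.constant(np.array([0, 3, 1, 2], dtype=np.uint8))
raw_labels = tf.constant([[0, 1],
                          [2, 3]], dtype=tf.int32)
remapped = tf.gather(lmap, raw_labels)    # same shape as raw_labels

with tf.Session() as sess:
    print(sess.run(remapped))             # [[0 3]
                                          #  [1 2]]
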
Example #7
def main(model_root, datasets_dir, model_name):
    # tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    # Training-related parameter settings
    with tf.Graph().as_default():
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=False,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=num_ps_tasks)

        global_step = slim.create_global_step()

        train_dir = os.path.join(model_root, model_name)
        dataset = convert_data.get_datasets('train', dataset_dir=datasets_dir)

        network_fn = net_select.get_network_fn(model_name,
                                               num_classes=dataset.num_classes,
                                               weight_decay=weight_decay,
                                               is_training=True)

        image_preprocessing_fn = preprocessing_select.get_preprocessing(
            model_name, is_training=True)

        print("the data_sources:", dataset.data_sources)

        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=num_readers,
                common_queue_capacity=20 * batch_size,
                common_queue_min=10 * batch_size)
            [image, label] = provider.get(['image', 'label'])

            train_image_size = network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.compat.v1.train.batch(
                [image, label],
                batch_size=batch_size,
                num_threads=num_preprocessing_threads,
                capacity=5 * batch_size)
            labels = slim.one_hot_encoding(labels, dataset.num_classes)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        def calculate_pooling_center_loss(features, label, alfa, nrof_classes,
                                          weights, name):
            features = tf.reshape(features, [features.shape[0], -1])
            label = tf.argmax(label, 1)

            nrof_features = features.get_shape()[1]
            centers = tf.compat.v1.get_variable(
                name, [nrof_classes, nrof_features],
                dtype=tf.float32,
                initializer=tf.constant_initializer(0),
                trainable=False)
            label = tf.reshape(label, [-1])
            centers_batch = tf.gather(centers, label)
            centers_batch = tf.nn.l2_normalize(centers_batch, axis=-1)

            diff = (1 - alfa) * (centers_batch - features)
            centers = tf.compat.v1.scatter_sub(centers, label, diff)

            with tf.control_dependencies([centers]):
                distance = tf.square(features - centers_batch)
                distance = tf.reduce_sum(distance, axis=-1)
                center_loss = tf.reduce_mean(distance)

            center_loss = tf.identity(center_loss * weights,
                                      name=name + '_loss')
            return center_loss

        def attention_crop(attention_maps):
            '''
            Data augmentation driven by the attention maps; this is the Crop Mask from the paper.
            :param attention_maps: obtained by reducing the channel dimension of the feature maps
            :return: one crop box per image as [ymin, xmin, ymax, xmax]
            '''
            batch_size, height, width, num_parts = attention_maps.shape
            bboxes = []
            for i in range(batch_size):
                attention_map = attention_maps[i]
                part_weights = attention_map.mean(axis=0).mean(axis=0)
                part_weights = np.sqrt(part_weights)
                part_weights = part_weights / np.sum(part_weights)
                selected_index = np.random.choice(np.arange(0, num_parts),
                                                  1,
                                                  p=part_weights)[0]

                mask = attention_map[:, :, selected_index]

                threshold = random.uniform(0.4, 0.6)
                itemindex = np.where(mask >= mask.max() * threshold)

                ymin = itemindex[0].min() / height - 0.1
                ymax = itemindex[0].max() / height + 0.1
                xmin = itemindex[1].min() / width - 0.1
                xmax = itemindex[1].max() / width + 0.1

                bbox = np.asarray([ymin, xmin, ymax, xmax], dtype=np.float32)
                bboxes.append(bbox)
            bboxes = np.asarray(bboxes, np.float32)
            return bboxes

        def attention_drop(attention_maps):
            '''
            Attention drop: encourages the model to attend to other parts of the object,
            since different attention maps may focus on the same part.
            :param attention_maps: attention maps for the batch
            :return: one drop mask per image
            '''
            batch_size, height, width, num_parts = attention_maps.shape
            masks = []
            for i in range(batch_size):
                attention_map = attention_maps[i]
                part_weights = attention_map.mean(axis=0).mean(axis=0)
                part_weights = np.sqrt(part_weights)
                if (np.sum(part_weights) != 0):
                    part_weights = part_weights / np.sum(part_weights)
                selected_index = np.random.choice(np.arange(0, num_parts),
                                                  1,
                                                  p=part_weights)[0]
                mask = attention_map[:, :, selected_index:selected_index + 1]

                # soft mask
                threshold = random.uniform(0.2, 0.5)
                mask = (mask < threshold * mask.max()).astype(np.float32)
                masks.append(mask)
            masks = np.asarray(masks, dtype=np.float32)
            return masks

        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits_1, end_points_1 = network_fn(images)

            attention_maps = end_points_1['attention_maps']
            attention_maps = tf.image.resize(
                attention_maps, [train_image_size, train_image_size],
                method=tf.image.ResizeMethod.BILINEAR)

            # attention crop
            bboxes = tf.compat.v1.py_func(attention_crop, [attention_maps],
                                          [tf.float32])
            bboxes = tf.reshape(bboxes, [batch_size, 4])
            box_ind = tf.range(batch_size, dtype=tf.int32)
            images_crop = tf.image.crop_and_resize(
                images,
                bboxes,
                box_ind,
                crop_size=[train_image_size, train_image_size])

            # attention drop
            masks = tf.compat.v1.py_func(attention_drop, [attention_maps],
                                         [tf.float32])
            masks = tf.reshape(
                masks, [batch_size, train_image_size, train_image_size, 1])
            images_drop = images * masks

            logits_2, end_points_2 = network_fn(images_crop, reuse=True)
            logits_3, end_points_3 = network_fn(images_drop, reuse=True)

            slim.losses.softmax_cross_entropy(logits_1,
                                              labels,
                                              weights=1 / 3.0,
                                              scope='cross_entropy_1')
            slim.losses.softmax_cross_entropy(logits_2,
                                              labels,
                                              weights=1 / 3.0,
                                              scope='cross_entropy_2')
            slim.losses.softmax_cross_entropy(logits_3,
                                              labels,
                                              weights=1 / 3.0,
                                              scope='cross_entropy_3')

            embeddings = end_points_1['embeddings']
            center_loss = calculate_pooling_center_loss(
                features=embeddings,
                label=labels,
                alfa=0.95,
                nrof_classes=dataset.num_classes,
                weights=1.0,
                name='center_loss')
            slim.losses.add_loss(center_loss)

            return end_points_1

        # Gather initial summaries.
        summaries = set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS, first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.LOSSES,
                                                first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = configure_learning_rate(dataset.num_samples,
                                                    global_step)
            optimizer = configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = get_variables_to_train(trainable_scopes)

        # optimize_clones() returns the total loss and the per-variable gradients.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES,
                                        first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.compat.v1.summary.merge_all()

        config = tf.compat.v1.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False)
        config.gpu_options.allow_growth = True
        config.gpu_options.visible_device_list = "0"

        save_model_path = os.path.join(checkpoint_path, model_name,
                                       "%s.ckpt" % model_name)
        print(save_model_path)

        # saver = tf.compat.v1.train.import_meta_graph('%s.meta'%save_model_path, clear_devices=True)
        tf.compat.v1.disable_eager_execution()
        # train the model
        slim.learning.train(
            train_op=train_tensor,
            logdir=train_dir,
            is_chief=(task == 0),
            init_fn=_get_init_fn(save_model_path, train_dir=train_dir),
            summary_op=summary_op,
            number_of_steps=max_number_of_steps,
            log_every_n_steps=log_every_n_steps,
            save_summaries_secs=save_summaries_secs,
            save_interval_secs=save_interval_secs,
            # sync_optimizer=None,
            session_config=config)
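
attention_crop and attention_drop above run as NumPy code inside the graph via tf.compat.v1.py_func, which discards static shape information, hence the explicit tf.reshape calls afterwards. A minimal, self-contained sketch of that wrapping pattern (the doubling function is made up; TF 1.x graph mode assumed):

import numpy as np
import tensorflow as tf

def _double(x):
    # Plain NumPy; executed by the TF runtime at session run time.
    return (x * 2.0).astype(np.float32)

inputs = tf.constant(np.ones((2, 3), dtype=np.float32))
outputs = tf.compat.v1.py_func(_double, [inputs], [tf.float32])[0]
outputs = tf.reshape(outputs, [2, 3])   # py_func drops the static shape; restore it
with tf.compat.v1.Session() as sess:
    print(sess.run(outputs))            # [[2. 2. 2.] [2. 2. 2.]]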