def load_imagenet(ckpt_path):
    """Build an initializer that loads ImageNet-pretrained weights.

    The variable names stored in the checkpoint are scanned and a subset of
    them is mapped onto the model variables found under the ``xfcn/`` scope:

    * names containing ``entry_flow``;
    * names containing ``middle_flow`` that belong to ``unit_1/`` or
      ``unit_2/``;
    * names containing ``exit_flow`` that belong to ``block1/``;
    * names containing ``shortcut`` that were not handled by a flow rule.

    Names containing ``gamma`` or ``Momentum`` are never selected; the three
    flow rules additionally reject names containing ``depthwise/BatchNorm``.

    Args:
        ckpt_path: Full path to the pre-trained model checkpoint.

    Returns:
        The function created by ``slim.assign_from_checkpoint_fn`` for the
        selected checkpoint-name -> model-variable mapping.
    """
    checkpoint_names = tf.train.NewCheckpointReader(
        ckpt_path).get_variable_to_shape_map()

    def _should_restore(name):
        # Shared exclusions: every rule skips gamma and Momentum entries.
        base_ok = 'gamma' not in name and 'Momentum' not in name
        if base_ok and 'depthwise/BatchNorm' not in name:
            if "entry_flow" in name:
                return True
            if "middle_flow" in name:
                return any('unit_{}/'.format(i) in name for i in range(1, 3))
            if "exit_flow" in name:
                return 'block1/' in name
        # Reached when no flow rule claimed the name.
        return 'shortcut' in name and base_ok

    vars_corresp = {
        name: slim.get_model_variables('xfcn/' + name)[0]
        for name in checkpoint_names
        if _should_restore(name)
    }

    return slim.assign_from_checkpoint_fn(ckpt_path, vars_corresp)
def _get_init_fn():
  """Builds the callable that initializes the model from a checkpoint.

  Returns:
    None when --checkpoint_path is unset or when train_dir already contains a
    checkpoint. Otherwise the function created by
    slim.assign_from_checkpoint_fn for every model variable whose op name does
    not start with one of the --checkpoint_exclude_scopes prefixes.
  """
  if FLAGS.checkpoint_path is None:
    return None

  # An existing checkpoint in train_dir wins over --checkpoint_path; tell the
  # user that the flag is being ignored.
  if tf.train.latest_checkpoint(FLAGS.train_dir):
    tf.logging.info(
        'Ignoring --checkpoint_path because a checkpoint already exists in %s',
        FLAGS.train_dir)
    return None

  excluded_prefixes = []
  if FLAGS.checkpoint_exclude_scopes:
    excluded_prefixes = [
        part.strip() for part in FLAGS.checkpoint_exclude_scopes.split(',')]

  variables_to_restore = [
      var for var in slim.get_model_variables()
      if not any(var.op.name.startswith(prefix)
                 for prefix in excluded_prefixes)
  ]

  # A directory is resolved to the most recent checkpoint it contains.
  checkpoint_path = FLAGS.checkpoint_path
  if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
    checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)

  tf.logging.info('Fine-tuning from %s', checkpoint_path)

  return slim.assign_from_checkpoint_fn(checkpoint_path, variables_to_restore)
Ejemplo n.º 3
0
 def testModelVariables(self):
   """Checks that building overfeat creates exactly the expected variables."""
   num_classes = 1000
   input_shape = (5, 231, 231, 3)  # (batch_size, height, width, 3)
   expected = {
       'overfeat/conv1/weights',
       'overfeat/conv1/biases',
       'overfeat/conv2/weights',
       'overfeat/conv2/biases',
       'overfeat/conv3/weights',
       'overfeat/conv3/biases',
       'overfeat/conv4/weights',
       'overfeat/conv4/biases',
       'overfeat/conv5/weights',
       'overfeat/conv5/biases',
       'overfeat/fc6/weights',
       'overfeat/fc6/biases',
       'overfeat/fc7/weights',
       'overfeat/fc7/biases',
       'overfeat/fc8/weights',
       'overfeat/fc8/biases',
   }
   with self.test_session():
     images = tf.random.uniform(input_shape)
     overfeat.overfeat(images, num_classes)
     actual = {var.op.name for var in slim.get_model_variables()}
     self.assertSetEqual(actual, expected)
Ejemplo n.º 4
0
def _get_init_fn(checkpoint_path, train_dir):
    if checkpoint_path is None:
        return None
    if tf.train.latest_checkpoint(train_dir):
        tf.compat.v1.logging.info(
            'Ignoring --checkpoint_path because a checkpoint already exists in %s'
            % train_dir)
        return None

    exclusions = []
    if checkpoint_exclude_scopes:
        exclusions = [
            scope.strip() for scope in checkpoint_exclude_scopes.split(',')
        ]

    # TODO(sguada) variables.filter_variables()
    variables_to_restore = []
    for var in slim.get_model_variables():
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                break
        else:
            variables_to_restore.append(var)

    if tf.io.gfile.isdir(checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
    else:
        checkpoint_path = checkpoint_path

    tf.compat.v1.logging.info('Fine-tuning from %s' % checkpoint_path)

    return slim.assign_from_checkpoint_fn(checkpoint_path,
                                          variables_to_restore,
                                          ignore_missing_vars=True)
Ejemplo n.º 5
0
 def testModelVariables(self):
     """Checks that building vgg_a creates exactly the expected variables."""
     num_classes = 1000
     input_shape = (5, 224, 224, 3)  # (batch_size, height, width, 3)
     expected = {
         'vgg_a/conv1/conv1_1/weights',
         'vgg_a/conv1/conv1_1/biases',
         'vgg_a/conv2/conv2_1/weights',
         'vgg_a/conv2/conv2_1/biases',
         'vgg_a/conv3/conv3_1/weights',
         'vgg_a/conv3/conv3_1/biases',
         'vgg_a/conv3/conv3_2/weights',
         'vgg_a/conv3/conv3_2/biases',
         'vgg_a/conv4/conv4_1/weights',
         'vgg_a/conv4/conv4_1/biases',
         'vgg_a/conv4/conv4_2/weights',
         'vgg_a/conv4/conv4_2/biases',
         'vgg_a/conv5/conv5_1/weights',
         'vgg_a/conv5/conv5_1/biases',
         'vgg_a/conv5/conv5_2/weights',
         'vgg_a/conv5/conv5_2/biases',
         'vgg_a/fc6/weights',
         'vgg_a/fc6/biases',
         'vgg_a/fc7/weights',
         'vgg_a/fc7/biases',
         'vgg_a/fc8/weights',
         'vgg_a/fc8/biases',
     }
     with self.test_session():
         images = tf.random.uniform(input_shape)
         vgg.vgg_a(images, num_classes)
         actual = {var.op.name for var in slim.get_model_variables()}
         self.assertSetEqual(actual, expected)
 def testModelHasExpectedNumberOfParameters(self):
   """Checks the total parameter count of the Inception V3 base network."""
   input_shape = (5, 299, 299, 3)  # (batch_size, height, width, 3)
   images = tf.random.uniform(input_shape)
   with slim.arg_scope(inception.inception_v3_arg_scope()):
     inception.inception_v3_base(images)
   model_vars = slim.get_model_variables()
   param_count, _ = slim.model_analyzer.analyze_vars(model_vars)
   self.assertAlmostEqual(21802784, param_count)
Ejemplo n.º 7
0
 def testModelHasExpectedNumberOfParameters(self):
     """Checks the total parameter count of the MobileNet V1 base network."""
     input_shape = (5, 224, 224, 3)  # (batch_size, height, width, 3)
     images = tf.random.uniform(input_shape)
     # Batch norm is attached to both regular and separable convolutions.
     conv_ops = [slim.conv2d, slim.separable_conv2d]
     with slim.arg_scope(conv_ops, normalizer_fn=slim.batch_norm):
         mobilenet_v1.mobilenet_v1_base(images)
         model_vars = slim.get_model_variables()
         param_count, _ = slim.model_analyzer.analyze_vars(model_vars)
         self.assertAlmostEqual(3217920, param_count)
def _get_init_fn():
    """Returns a function run by the chief worker to warm-start the training.

    Note that the init_fn is only run when initializing the model during the
    very first global step.

    Returns:
        An init function run by the supervisor, or None when
        --checkpoint_path is unset or train_dir already holds a checkpoint.

    Raises:
        ValueError: If train_dir already holds a checkpoint while
            --continue_training is false.
    """
    if FLAGS.checkpoint_path is None:
        return None

    # A checkpoint already present in train_dir overrides --checkpoint_path,
    # but only if the user explicitly allowed continuing a previous run.
    if tf.train.latest_checkpoint(FLAGS.train_dir):
        if not FLAGS.continue_training:
            raise ValueError(
                'continue_training set to False but there is a checkpoint in the training_dir.'
            )
        tf.compat.v1.logging.info(
            'Ignoring --checkpoint_path because a checkpoint already exists in %s'
            % FLAGS.train_dir)
        return None

    excluded_prefixes = []
    if FLAGS.checkpoint_exclude_scopes:
        excluded_prefixes = [
            part.strip()
            for part in FLAGS.checkpoint_exclude_scopes.split(',')
        ]

    # TODO(sguada) variables.filter_variables()
    variables_to_restore = []
    for var in slim.get_model_variables():
        for prefix in excluded_prefixes:
            if var.op.name.startswith(prefix):
                break
        else:
            # No exclusion prefix matched this variable.
            variables_to_restore.append(var)

    # A directory is resolved to the most recent checkpoint it contains.
    checkpoint_path = FLAGS.checkpoint_path
    if tf.io.gfile.isdir(FLAGS.checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)

    tf.compat.v1.logging.info('Fine-tuning from %s' % checkpoint_path)

    return slim.assign_from_checkpoint_fn(
        checkpoint_path,
        variables_to_restore,
        ignore_missing_vars=FLAGS.ignore_missing_vars)
Ejemplo n.º 9
0
def main(_):
    """Builds the training graph from FLAGS and runs slim.learning.train.

    The graph is assembled in this order: deployment config and global step,
    dataset, network function, preprocessing, input queue, model clones with
    their losses, summaries, optional moving averages, optimizer, gradient
    updates, and finally the training loop.

    Raises:
        ValueError: If --dataset_dir is not supplied.
    """
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        # Falls back to the model name when no preprocessing name is given.
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            # Shift labels by --labels_offset, matching the num_classes
            # adjustment made when the network was selected.
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            # Networks exposing an 'AuxLogits' end point get an extra loss
            # on it, weighted 0.4 versus 1.0 for the main logits.
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.compat.v1.GraphKeys.LOSSES,
                                      first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        # Moving averages are only tracked when --moving_average_decay is set.
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        # Quantization-aware training is currently disabled.
        #if FLAGS.quantize_delay >= 0:
        #  tf.contrib.quantize.create_training_graph(
        #      quant_delay=FLAGS.quantize_delay)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Compute the total loss and the gradients for the clones, restricted
        # to variables_to_train.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        # Evaluating train_tensor forces every op grouped in update_op
        # (collected update ops, optional moving averages, gradient
        # application) to run first.
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.compat.v1.GraphKeys.SUMMARIES,
                              first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
Ejemplo n.º 10
0
def build_frontend(inputs,
                   frontend,
                   is_training=True,
                   pretrained_dir="models",
                   num_classes=None):
    """Builds a feature-extraction frontend and its checkpoint initializer.

    Args:
        inputs: Input tensor passed to the frontend network.
        frontend: One of 'ResNet50', 'ResNet101', 'ResNet152', 'MobileNetV2'
            or 'InceptionV4'.
        is_training: Forwarded to the network builder.
        pretrained_dir: Directory containing the ``<scope>.ckpt`` checkpoint
            of the selected frontend.
        num_classes: Forwarded to the ResNet builders only; the MobileNetV2
            and InceptionV4 branches do not pass it on.

    Returns:
        A tuple ``(logits, end_points, frontend_scope, init_fn)`` where
        ``frontend_scope`` is the variable scope of the network and
        ``init_fn`` is the function created by
        ``slim.assign_from_checkpoint_fn`` for the variables in that scope
        (variables missing from the checkpoint are ignored).

    Raises:
        ValueError: If ``frontend`` is not one of the supported names.
    """
    # The ResNet builder functions are named after their variable scope.
    resnet_scopes = {
        'ResNet50': 'resnet_v2_50',
        'ResNet101': 'resnet_v2_101',
        'ResNet152': 'resnet_v2_152',
    }

    if frontend in resnet_scopes:
        frontend_scope = resnet_scopes[frontend]
        resnet_fn = getattr(resnet_v2, frontend_scope)
        with slim.arg_scope(resnet_v2.resnet_arg_scope()):
            logits, end_points = resnet_fn(
                inputs,
                is_training=is_training,
                scope=frontend_scope,
                num_classes=num_classes)
    elif frontend == 'MobileNetV2':
        frontend_scope = 'mobilenet_v2'
        with slim.arg_scope(mobilenet_v2.training_scope()):
            logits, end_points = mobilenet_v2.mobilenet(
                inputs,
                is_training=is_training,
                scope=frontend_scope,
                base_only=True)
    elif frontend == 'InceptionV4':
        frontend_scope = 'inception_v4'
        with slim.arg_scope(inception_v4.inception_v4_arg_scope()):
            logits, end_points = inception_v4.inception_v4(
                inputs, is_training=is_training, scope=frontend_scope)
    else:
        raise ValueError(
            "Unsupported frontend model '%s'. This function only supports ResNet50, ResNet101, ResNet152, MobileNetV2, and InceptionV4"
            % (frontend))

    # Every frontend stores its weights in '<scope>.ckpt', so the initializer
    # is built once here instead of once per branch.
    init_fn = slim.assign_from_checkpoint_fn(
        model_path=os.path.join(pretrained_dir, frontend_scope + '.ckpt'),
        var_list=slim.get_model_variables(frontend_scope),
        ignore_missing_vars=True)

    return logits, end_points, frontend_scope, init_fn
Ejemplo n.º 11
0
def build_icnet(inputs, label_size, num_classes, preset_model='ICNet',
                pooling_type="MAX", frontend="ResNet101", weight_decay=1e-5,
                is_training=True, pretrained_dir="models"):
    """
    Builds the ICNet model.

    Arguments:
      inputs: The input tensor
      label_size: Size of the final label tensor. We need to know this for proper upscaling
      num_classes: Number of classes
      preset_model: Unused by this function.
      pooling_type: Max or Average pooling
      frontend: Which ResNet to use for feature extraction. Accepts
        'Res50'/'Res101'/'Res152' as well as the 'ResNet50'/'ResNet101'/
        'ResNet152' spellings (the default value uses the latter).
      weight_decay: Weight decay passed to the ResNet arg scope
      is_training: Forwarded to the ResNet builder
      pretrained_dir: Directory containing the pre-trained ResNet checkpoint

    Returns:
      A tuple (net, init_fn): the ICNet output tensor and the function that
      restores the pre-trained ResNet weights.

    Raises:
      ValueError: If `frontend` is not a supported ResNet name.
    """
    # The previous branches only matched 'Res50'/'Res101'/'Res152', so the
    # default frontend="ResNet101" always raised. Both spellings are accepted.
    resnet_scopes = {
        'Res50': 'resnet_v2_50',
        'ResNet50': 'resnet_v2_50',
        'Res101': 'resnet_v2_101',
        'ResNet101': 'resnet_v2_101',
        'Res152': 'resnet_v2_152',
        'ResNet152': 'resnet_v2_152',
    }
    resnet_scope = resnet_scopes.get(frontend)
    # Validate before any op is added to the graph.
    if resnet_scope is None:
        raise ValueError("Unsupported ResNet model '%s'. This function only supports ResNet 50, ResNet 101, and ResNet 152" % (frontend))
    # The ResNet builder functions are named after their variable scope.
    resnet_fn = getattr(resnet_v2, resnet_scope)

    inputs_4 = tf.image.resize_bilinear(inputs, size=[tf.shape(inputs)[1]*4, tf.shape(inputs)[2]*4])
    inputs_2 = tf.image.resize_bilinear(inputs, size=[tf.shape(inputs)[1]*2, tf.shape(inputs)[2]*2])
    inputs_1 = inputs

    # NOTE(review): the three branches are built under the same scope name
    # without passing `reuse`, exactly as before - confirm that the
    # resnet_v2 module in use shares the variables in this situation.
    with slim.arg_scope(resnet_v2.resnet_arg_scope(weight_decay=weight_decay)):
        _, end_points_32 = resnet_fn(inputs_4, is_training=is_training, scope=resnet_scope)
        _, end_points_16 = resnet_fn(inputs_2, is_training=is_training, scope=resnet_scope)
        _, end_points_8 = resnet_fn(inputs_1, is_training=is_training, scope=resnet_scope)

    # ICNet requires pre-trained ResNet weights
    init_fn = slim.assign_from_checkpoint_fn(
        os.path.join(pretrained_dir, resnet_scope + '.ckpt'),
        slim.get_model_variables(resnet_scope))

    feature_map_shape = [int(x / 32.0) for x in label_size]
    block_32 = PyramidPoolingModule(end_points_32['pool3'], feature_map_shape=feature_map_shape, pooling_type=pooling_type)

    # Was CFFBlock(psp_32, ...): `psp_32` is never defined; the pyramid
    # pooling output above is the only candidate input.
    out_16, block_16 = CFFBlock(block_32, end_points_16['pool3'])
    out_8, _ = CFFBlock(block_16, end_points_8['pool3'])
    out_4 = Upsampling_by_scale(out_8, scale=2)
    out_4 = slim.conv2d(out_4, num_classes, [1, 1], activation_fn=None)

    out_full = Upsampling_by_scale(out_4, scale=2)

    out_full = slim.conv2d(out_full, num_classes, [1, 1], activation_fn=None, scope='logits')

    # Was tf.concat([... , out_final]) with no axis: `out_final` is undefined
    # and `axis` is a required argument.
    # NOTE(review): concatenating along the last (channel) axis is assumed -
    # confirm the four outputs share their leading dimensions.
    net = tf.concat([out_16, out_8, out_4, out_full], axis=-1)

    return net, init_fn
Ejemplo n.º 12
0
def build_gcn(inputs,
              num_classes,
              preset_model='GCN-Res101',
              weight_decay=1e-5,
              is_training=True,
              upscaling_method="bilinear",
              pretrained_dir="models"):
    """
    Builds the GCN model.

    Arguments:
      inputs: The input tensor
      num_classes: Number of classes
      preset_model: Which model you want to use. Selects the ResNet used for
        feature extraction: 'GCN-Res50', 'GCN-Res101' or 'GCN-Res152'
      weight_decay: Weight decay passed to the ResNet arg scope
      is_training: Forwarded to the ResNet builder
      upscaling_method: Unused by this function
      pretrained_dir: Directory containing the pre-trained ResNet checkpoint

    Returns:
      A tuple (net, init_fn): the GCN output tensor and the function that
      restores the pre-trained ResNet weights.

    Raises:
      ValueError: If `preset_model` is not a supported GCN variant.
    """
    resnet_scopes = {
        'GCN-Res50': 'resnet_v2_50',
        'GCN-Res101': 'resnet_v2_101',
        'GCN-Res152': 'resnet_v2_152',
    }
    resnet_scope = resnet_scopes.get(preset_model)
    if resnet_scope is None:
        raise ValueError(
            "Unsupported ResNet model '%s'. This function only supports GCN-Res50, GCN-Res101, and GCN-Res152"
            % (preset_model))

    # The ResNet builder functions are named after their variable scope.
    resnet_fn = getattr(resnet_v2, resnet_scope)
    with slim.arg_scope(
            resnet_v2.resnet_arg_scope(weight_decay=weight_decay)):
        _, end_points = resnet_fn(
            inputs, is_training=is_training, scope=resnet_scope)

    # GCN requires pre-trained ResNet weights
    init_fn = slim.assign_from_checkpoint_fn(
        os.path.join(pretrained_dir, resnet_scope + '.ckpt'),
        slim.get_model_variables(resnet_scope))

    res = [
        end_points['pool5'], end_points['pool4'], end_points['pool3'],
        end_points['pool2']
    ]

    # Decoder: every feature map goes through GlobalConv + BoundaryRefinement.
    # From the second stage on, the upscaled output of the previous stage is
    # added and refined again before upscaling. Blocks are created in the same
    # order as the former unrolled code.
    upscaled = None
    for features in res:
        stage = GlobalConvBlock(features, n_filters=21, size=3)
        stage = BoundaryRefinementBlock(stage,
                                        n_filters=21,
                                        kernel_size=[3, 3])
        if upscaled is not None:
            stage = tf.add(stage, upscaled)
            stage = BoundaryRefinementBlock(stage,
                                            n_filters=21,
                                            kernel_size=[3, 3])
        upscaled = ConvUpscaleBlock(stage,
                                    n_filters=21,
                                    kernel_size=[3, 3],
                                    scale=2)

    net = BoundaryRefinementBlock(upscaled, n_filters=21, kernel_size=[3, 3])
    net = ConvUpscaleBlock(net, n_filters=21, kernel_size=[3, 3], scale=2)
    net = BoundaryRefinementBlock(net, n_filters=21, kernel_size=[3, 3])

    net = slim.conv2d(net,
                      num_classes, [1, 1],
                      activation_fn=None,
                      scope='logits')

    return net, init_fn
Ejemplo n.º 13
0
def main(unused_argv=None):
    """Builds the style-transfer training graph from FLAGS and trains it.

    Content and style input pipelines are created first, then the MobileNet
    based model and its losses, the summaries, the Adam train op, and the
    initializers that restore the VGG16 and MobilenetV2 variables before
    slim.learning.train is started.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # Forces all input processing onto CPU in order to reserve the GPU for
        # the forward inference and back-propagation.
        if FLAGS.ps_tasks:
            input_device = '/job:worker/cpu:0'
        else:
            input_device = '/cpu:0'
        input_device_setter = tf.train.replica_device_setter(
            FLAGS.ps_tasks, worker_device=input_device)
        with tf.device(input_device_setter):
            # Content images.
            content_inputs_, _ = image_utils.imagenet_inputs(
                FLAGS.batch_size, FLAGS.image_size)

            # Style images.
            style_inputs_, _, _ = image_utils.arbitrary_style_image_inputs(
                FLAGS.style_dataset_file,
                batch_size=FLAGS.batch_size,
                image_size=FLAGS.image_size,
                shuffle=True,
                center_crop=FLAGS.center_crop,
                augment_style_images=FLAGS.augment_style_images,
                random_style_image_size=FLAGS.random_style_image_size)

        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # The weight flags hold Python literals encoded as strings.
            content_weights = ast.literal_eval(FLAGS.content_weights)
            style_weights = ast.literal_eval(FLAGS.style_weights)

            # Define the model.
            model_outputs = build_mobilenet_model.build_mobilenet_model(
                content_inputs_,
                style_inputs_,
                mobilenet_trainable=False,
                style_params_trainable=True,
                transformer_trainable=True,
                mobilenet_end_point='layer_19',
                transformer_alpha=FLAGS.alpha,
                style_prediction_bottleneck=100,
                adds_losses=True,
                content_weights=content_weights,
                style_weights=style_weights,
                total_variation_weight=FLAGS.total_variation_weight,
            )
            stylized_images, total_loss, loss_dict, _ = model_outputs

            # Scalar summaries, one per entry of the loss dictionary.
            for loss_name in loss_dict:
                tf.summary.scalar(loss_name, loss_dict[loss_name])

            # Image summaries.
            tf.summary.image('image/0_content_inputs', content_inputs_, 3)
            tf.summary.image('image/1_style_inputs_aug', style_inputs_, 3)
            tf.summary.image('image/2_stylized_images', stylized_images, 3)

            # Set up training.
            adam = tf.train.AdamOptimizer(FLAGS.learning_rate)
            train_op = slim.learning.create_train_op(
                total_loss,
                adam,
                clip_gradient_norm=FLAGS.clip_gradient_norm,
                summarize_gradients=False)

            # Restores the VGG16 parameters.
            vgg_variables = slim.get_variables('vgg_16')
            init_fn_vgg = slim.assign_from_checkpoint_fn(
                vgg.checkpoint_file(), vgg_variables)

            # Restores the Mobilenet V2 parameters, keyed by op name.
            mobilenet_variables_dict = {}
            for var in slim.get_model_variables('MobilenetV2'):
                mobilenet_variables_dict[var.op.name] = var
            init_fn_mobilenet = slim.assign_from_checkpoint_fn(
                FLAGS.mobilenet_checkpoint, mobilenet_variables_dict)

            def init_sub_networks(session):
                """Restores VGG16 first, then Mobilenet V2."""
                init_fn_vgg(session)
                init_fn_mobilenet(session)

            # Run training.
            slim.learning.train(
                train_op=train_op,
                logdir=os.path.expanduser(FLAGS.train_dir),
                master=FLAGS.master,
                is_chief=FLAGS.task == 0,
                number_of_steps=FLAGS.train_steps,
                init_fn=init_sub_networks,
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs)
Ejemplo n.º 14
0
def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):
  """Training function for detection models.

  Builds the whole training graph inside a fresh `tf.Graph` (input queue,
  model clones, optimizer, gradient post-processing, summaries, checkpoint
  saver and an optional fine-tune initializer) and then runs it with
  `slim.learning.train`.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    graph_hook_fn: Optional function that is called after the inference graph is
      built (before optimization). This is helpful to perform additional changes
      to the training graph such as adding FakeQuant ops. The function should
      modify the default graph.

  Raises:
    ValueError: If num_clones is not 1 while train_config.sync_replicas is
      true.
  """

  # This model instance is only used further down to build the restore map
  # for fine-tuning; the clones create their own model via `model_fn`.
  detection_model = create_model_fn()
  data_augmentation_options = [
      preprocessor_builder.build(step)
      for step in train_config.data_augmentation_options]

  with tf.Graph().as_default():
    # Build a configuration specifying multi-GPU and multi-replicas.
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=worker_replicas,
        num_ps_tasks=ps_tasks,
        worker_job_name=worker_job_name)

    # Place the global step on the device storing the variables.
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    if num_clones != 1 and train_config.sync_replicas:
      # A single message string: passing two strings separated by a comma
      # would make the exception carry a tuple instead of readable text.
      raise ValueError('In Synchronous SGD mode num_clones must '
                       'be 1. Found num_clones: {}'.format(num_clones))
    # The configured batch size is split across clones (and, in sync mode,
    # across the replicas being aggregated).
    batch_size = train_config.batch_size // num_clones
    if train_config.sync_replicas:
      batch_size //= train_config.replicas_to_aggregate

    with tf.device(deploy_config.inputs_device()):
      input_queue = create_input_queue(
          batch_size, create_tensor_dict_fn,
          train_config.batch_queue_capacity,
          train_config.num_batch_queue_threads,
          train_config.prefetch_queue_capacity, data_augmentation_options)

    # Gather initial summaries.
    # TODO(rathodv): See if summaries can be added/extracted from global tf
    # collections so that they don't have to be passed around.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    global_summaries = set()

    model_fn = functools.partial(_create_losses,
                                 create_model_fn=create_model_fn,
                                 train_config=train_config)
    clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
    first_clone_scope = clones[0].scope

    if graph_hook_fn:
      with tf.device(deploy_config.variables_device()):
        graph_hook_fn()

    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    with tf.device(deploy_config.optimizer_device()):
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)
      for var in optimizer_summary_vars:
        tf.summary.scalar(var.op.name, var, family='LearningRate')

    # In synchronous mode the optimizer is wrapped, and the wrapper is also
    # handed to slim.learning.train via `sync_optimizer`.
    sync_optimizer = None
    if train_config.sync_replicas:
      training_optimizer = tf.train.SyncReplicasOptimizer(
          training_optimizer,
          replicas_to_aggregate=train_config.replicas_to_aggregate,
          total_num_replicas=worker_replicas)
      sync_optimizer = training_optimizer

    with tf.device(deploy_config.optimizer_device()):
      # `None` vs. an empty list selects whether optimize_clones is given any
      # regularization losses to work with.
      regularization_losses = (None if train_config.add_regularization_loss
                               else [])
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, training_optimizer,
          regularization_losses=regularization_losses)
      total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

      # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
      if train_config.bias_grad_multiplier:
        biases_regex_list = ['.*/biases']
        grads_and_vars = variables_helper.multiply_gradients_matching_regex(
            grads_and_vars,
            biases_regex_list,
            multiplier=train_config.bias_grad_multiplier)

      # Optionally freeze some layers by setting their gradients to be zero.
      if train_config.freeze_variables:
        grads_and_vars = variables_helper.freeze_gradients_matching_regex(
            grads_and_vars, train_config.freeze_variables)

      # Optionally clip gradients
      if train_config.gradient_clipping_by_norm > 0:
        with tf.name_scope('clip_grads'):
          grads_and_vars = slim.learning.clip_gradient_norms(
              grads_and_vars, train_config.gradient_clipping_by_norm)

      # Create gradient updates.
      grad_updates = training_optimizer.apply_gradients(grads_and_vars,
                                                        global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops, name='update_barrier')
      # Evaluating `train_tensor` forces every update op to run first.
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add summaries.
    for model_var in slim.get_model_variables():
      global_summaries.add(tf.summary.histogram('ModelVars/' +
                                                model_var.op.name, model_var))
    for loss_tensor in tf.losses.get_losses():
      global_summaries.add(tf.summary.scalar('Losses/' + loss_tensor.op.name,
                                             loss_tensor))
    global_summaries.add(
        tf.summary.scalar('Losses/TotalLoss', tf.losses.get_total_loss()))

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))
    summaries |= global_summaries

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

    # Save checkpoints regularly.
    keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
    saver = tf.train.Saver(
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

    # Create ops required to initialize the model from a given checkpoint.
    init_fn = None
    if train_config.fine_tune_checkpoint:
      if not train_config.fine_tune_checkpoint_type:
        # train_config.from_detection_checkpoint field is deprecated. For
        # backward compatibility, fine_tune_checkpoint_type is set based on
        # from_detection_checkpoint.
        if train_config.from_detection_checkpoint:
          train_config.fine_tune_checkpoint_type = 'detection'
        else:
          train_config.fine_tune_checkpoint_type = 'classification'
      var_map = detection_model.restore_map(
          fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
          load_all_detection_checkpoint_vars=(
              train_config.load_all_detection_checkpoint_vars))
      # Restrict the map to variables the checkpoint can actually provide.
      available_var_map = (variables_helper.
                           get_variables_available_in_checkpoint(
                               var_map, train_config.fine_tune_checkpoint,
                               include_global_step=False))
      init_saver = tf.train.Saver(available_var_map)
      def initializer_fn(sess):
        init_saver.restore(sess, train_config.fine_tune_checkpoint)
      init_fn = initializer_fn

    slim.learning.train(
        train_tensor,
        logdir=train_dir,
        master=master,
        is_chief=is_chief,
        session_config=session_config,
        startup_delay_steps=train_config.startup_delay_steps,
        init_fn=init_fn,
        summary_op=summary_op,
        number_of_steps=(
            train_config.num_steps if train_config.num_steps else None),
        save_summaries_secs=120,
        sync_optimizer=sync_optimizer,
        saver=saver)
def main(_):
    """Build the classification training graph and run the training loop.

    Reads all settings from the module-level `FLAGS`. The graph is built with
    `model_deploy` clones, summaries are collected by hand, and training is
    run through `slim.learning.train` with a custom `train_step` that also
    logs the batch accuracy.

    Raises:
        ValueError: if `FLAGS.dataset_dir` is not set.
    """
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        # Falls back to the model name when no preprocessing name is given.
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.compat.v1.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            # Labels are one-hot from here on; clone_fn recovers class ids
            # with argmax when computing the accuracy.
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn.

            Dequeues one batch, runs the network, registers the softmax
            cross-entropy loss (plus an auxiliary loss weighted 0.4 when the
            network exposes 'AuxLogits') and stores the batch accuracy under
            end_points['train_accuracy'].

            Returns:
                The network's end_points dictionary, including the accuracy.
            """
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)

            # Accuracy of the current batch: predicted class vs. the class
            # index recovered from the one-hot labels.
            accuracy = slim.metrics.accuracy(
                tf.cast(tf.argmax(input=logits, axis=1), dtype=tf.int32),
                tf.cast(tf.argmax(input=labels, axis=1), dtype=tf.int32))
            tf.compat.v1.add_to_collection('accuracy', accuracy)
            end_points['train_accuracy'] = accuracy
            return end_points

        # Gather initial summaries (those already present before the clones
        # are created).
        summaries = set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS, first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs

        for end_point in end_points:
            # The accuracy entry gets its own scalar summary below instead of
            # a histogram/sparsity pair.
            if 'accuracy' in end_point:
                continue
            x = end_points[end_point]
            summaries.add(
                tf.compat.v1.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.compat.v1.summary.scalar('sparsity/' + end_point,
                                            tf.nn.zero_fraction(x)))
        # Kept in a local so the train_step closure below can fetch it.
        train_acc = end_points['train_accuracy']
        summaries.add(
            tf.compat.v1.summary.scalar('train_accuracy',
                                        end_points['train_accuracy']))

        # Add summaries for losses.
        for loss in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.LOSSES,
                                                first_clone_scope):
            summaries.add(
                tf.compat.v1.summary.scalar('losses/%s' % loss.op.name, loss))

        # @philkuz
        # Disabled experiment: streaming train/validation accuracy summaries.
        # TODO add an if statement so this only runs every n iterations
        # images_val, labels_val= tf.train.batch(
        #     [image, label],
        #     batch_size=FLAGS.batch_size,
        #     num_threads=FLAGS.num_preprocessing_threads,
        #     capacity=5 * FLAGS.batch_size)

        # # labels_val = slim.one_hot_encoding(
        # #     labels_val, dataset.num_classes - FLAGS.labels_offset)
        # batch_queue_val = slim.prefetch_queue.prefetch_queue(
        #     [images_val, labels_val], capacity=2 * deploy_config.num_clones)
        # logits, end_points = network_fn(images, reuse=True)
        # # predictions = tf.nn.softmax(logits)
        # predictions = tf.to_in32(tf.argmax(logits,1))

        # logits_val, end_points_val = network_fn(images_val, reuse=True)
        # predictions_val = tf.to_in32(tf.argmax(logits_val,1))

        # labels_val = tf.squeeze(labels_val)
        # labels = tf.squeeze(labels)

        # names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        #       'train/accuracy': slim.metrics.streaming_accuracy(predictions, labels),
        #       'val/accuracy': slim.metrics.streaming_accuracy(predictions_val, labels_val),
        # })
        # for metric_name, metric_value in names_to_values.items():
        #   op = tf.summary.scalar(metric_name, metric_value)
        #   # op = tf.Print(op, [metric_value], metric_name)
        #   summaries.add(op)

        # Add summaries for variables.
        # TODO something to remove some of these from tensorboard scalars
        for variable in slim.get_model_variables():
            summaries.add(
                tf.compat.v1.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(
                tf.compat.v1.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.compat.v1.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Compute the total loss and the per-variable gradients over all
        # clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.compat.v1.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        # Evaluating train_tensor forces every update op to run first.
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES,
                                        first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.compat.v1.summary.merge(list(summaries),
                                                name='summary_op')

        # @philkuz
        # Derive max_number_of_steps from num_epochs when only the latter is
        # given: epochs * samples / batch size, truncated to an int.
        print('FLAGS.num_epochs', FLAGS.num_epochs)
        if FLAGS.num_epochs is not None and FLAGS.max_number_of_steps is None:
            FLAGS.max_number_of_steps = int(
                FLAGS.num_epochs * dataset.num_samples / FLAGS.batch_size)
            # FLAGS.max_number_of_steps = int(math.round(FLAGS.num_epochs / dataset.num_samples))

        # Set up the logdir.
        # @philkuz: when an experiment name is given, FLAGS.train_dir is
        # rewritten in place to a sub-directory that encodes batch size,
        # learning rate and epoch count. Everything below (including
        # _get_init_fn) sees the rewritten value.
        if FLAGS.experiment_name is not None:
            experiment_dir = 'bs={},lr={},epochs={}/{}'.format(
                FLAGS.batch_size, FLAGS.learning_rate, FLAGS.num_epochs,
                FLAGS.experiment_name)
            print(experiment_dir)
            FLAGS.train_dir = os.path.join(FLAGS.train_dir, experiment_dir)
            print(FLAGS.train_dir)

        # @philkuz: custom train_step passed to slim.learning.train so that
        # the batch accuracy is fetched and logged together with the loss.
        def train_step(sess, train_op, global_step, train_step_kwargs):
            """Function that takes a gradient step and specifies whether to stop.

            Args:
                sess: The current session.
                train_op: An `Operation` that evaluates the gradients and
                    returns the total loss.
                global_step: A `Tensor` representing the global training step.
                train_step_kwargs: A dictionary of keyword arguments.

            Returns:
                The total loss and a boolean indicating whether or not to stop
                training.

            Raises:
                ValueError: if 'should_trace' is in `train_step_kwargs` but
                    `logdir` is not.
            """
            start_time = time.time()

            trace_run_options = None
            run_metadata = None
            # Hard-coded switch: when True, train_acc (captured from the
            # enclosing scope) is fetched in the same sess.run as the loss.
            should_acc = True  # TODO make this not hardcoded @philkuz
            if 'should_trace' in train_step_kwargs:
                if 'logdir' not in train_step_kwargs:
                    raise ValueError(
                        'logdir must be present in train_step_kwargs when '
                        'should_trace is present')
                if sess.run(train_step_kwargs['should_trace']):
                    trace_run_options = config_pb2.RunOptions(
                        trace_level=config_pb2.RunOptions.FULL_TRACE)
                    run_metadata = config_pb2.RunMetadata()
            if not should_acc:
                total_loss, np_global_step = sess.run(
                    [train_op, global_step],
                    options=trace_run_options,
                    run_metadata=run_metadata)
            else:
                total_loss, acc, np_global_step = sess.run(
                    [train_op, train_acc, global_step],
                    options=trace_run_options,
                    run_metadata=run_metadata)
            time_elapsed = time.time() - start_time

            # run_metadata is only set when tracing was requested for this
            # step; dump the timeline next to the other logs.
            if run_metadata is not None:
                tl = timeline.Timeline(run_metadata.step_stats)
                trace = tl.generate_chrome_trace_format()
                trace_filename = os.path.join(
                    train_step_kwargs['logdir'],
                    'tf_trace-%d.json' % np_global_step)
                tf.compat.v1.logging.info('Writing trace to %s',
                                          trace_filename)
                file_io.write_string_to_file(trace_filename, trace)
                if 'summary_writer' in train_step_kwargs:
                    train_step_kwargs['summary_writer'].add_run_metadata(
                        run_metadata, 'run_metadata-%d' % np_global_step)

            if 'should_log' in train_step_kwargs:
                if sess.run(train_step_kwargs['should_log']):
                    if not should_acc:
                        tf.compat.v1.logging.info(
                            'global step %d: loss = %.4f (%.3f sec/step)',
                            np_global_step, total_loss, time_elapsed)
                    else:
                        tf.compat.v1.logging.info(
                            'global step %d: loss = %.4f train_acc = %.4f (%.3f sec/step)',
                            np_global_step, total_loss, acc, time_elapsed)

            if 'should_stop' in train_step_kwargs:
                should_stop = sess.run(train_step_kwargs['should_stop'])
            else:
                should_stop = False

            return total_loss, should_stop

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            train_step_fn=train_step,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
Ejemplo n.º 16
0
def main(model_root, datasets_dir, model_name):
    # tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    # Training-related parameter settings
    with tf.Graph().as_default():
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=False,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=num_ps_tasks)

        global_step = slim.create_global_step()

        train_dir = os.path.join(model_root, model_name)
        dataset = convert_data.get_datasets('train', dataset_dir=datasets_dir)

        network_fn = net_select.get_network_fn(model_name,
                                               num_classes=dataset.num_classes,
                                               weight_decay=weight_decay,
                                               is_training=True)

        image_preprocessing_fn = preprocessing_select.get_preprocessing(
            model_name, is_training=True)

        print("the data_sources:", dataset.data_sources)

        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=num_readers,
                common_queue_capacity=20 * batch_size,
                common_queue_min=10 * batch_size)
            [image, label] = provider.get(['image', 'label'])

            train_image_size = network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.compat.v1.train.batch(
                [image, label],
                batch_size=batch_size,
                num_threads=num_preprocessing_threads,
                capacity=5 * batch_size)
            labels = slim.one_hot_encoding(labels, dataset.num_classes)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        def calculate_pooling_center_loss(features, label, alfa, nrof_classes,
                                          weights, name):
            """Build a center-loss term together with its center update op.

            The features are flattened to one row per sample and compared
            against the L2-normalized center of each sample's class. The
            centers live in a non-trainable, zero-initialized variable called
            `name` and are moved via `scatter_sub` by
            `(1 - alfa) * (normalized_center - feature)`.

            Args:
                features: feature tensor; its first dimension is used as the
                    row count when flattening.
                label: label tensor, reduced to class ids with argmax over
                    axis 1 (presumably one-hot; verify against caller).
                alfa: center update factor; the update is scaled by
                    `1 - alfa`.
                nrof_classes: number of rows of the center variable.
                weights: multiplier applied to the mean squared distance.
                name: name of the center variable; the returned tensor is
                    named `name + '_loss'`.

            Returns:
                Scalar tensor: weighted mean over the batch of the squared
                distance between each feature row and its class center.
            """
            flat_features = tf.reshape(features, [features.shape[0], -1])
            class_ids = tf.argmax(label, 1)

            feature_dim = flat_features.get_shape()[1]
            center_table = tf.compat.v1.get_variable(
                name,
                shape=[nrof_classes, feature_dim],
                dtype=tf.float32,
                initializer=tf.constant_initializer(0),
                trainable=False)
            class_ids = tf.reshape(class_ids, [-1])
            batch_centers = tf.nn.l2_normalize(
                tf.gather(center_table, class_ids), axis=-1)

            center_delta = (1 - alfa) * (batch_centers - flat_features)
            update_centers = tf.compat.v1.scatter_sub(
                center_table, class_ids, center_delta)

            # The distance ops are created under a dependency on the center
            # update, so evaluating the loss also triggers that update.
            with tf.control_dependencies([update_centers]):
                mean_sq_distance = tf.reduce_mean(
                    tf.reduce_sum(
                        tf.square(flat_features - batch_centers), axis=-1))

            return tf.identity(mean_sq_distance * weights,
                               name=name + '_loss')

        def attention_crop(attention_maps):
            '''
            Data augmentation driven by the attention maps; this is the
            "Crop Mask" step of the paper this code follows.

            For every sample one attention channel is drawn at random, with
            probability proportional to the square root of its mean
            activation. The box enclosing all pixels whose value reaches a
            random fraction (0.4 to 0.6) of that channel's maximum is
            returned, padded by 0.1 on every side.

            If the channel weights cannot be normalized (their sum is zero or
            not finite, e.g. for an all-zero attention map) the channel is
            drawn uniformly instead of letting np.random.choice raise.

            :param attention_maps: NumPy array of shape
                (batch, height, width, num_parts), obtained by reducing the
                dimensionality of the feature maps.
            :return: float32 array with one [ymin, xmin, ymax, xmax] row per
                sample, expressed as fractions of height/width. The padding
                is not clipped, so values may fall outside [0, 1].
            '''
            batch_size, height, width, num_parts = attention_maps.shape
            bboxes = []
            for i in range(batch_size):
                attention_map = attention_maps[i]
                part_weights = attention_map.mean(axis=0).mean(axis=0)
                # Clamp at zero so a negative mean cannot turn into NaN.
                part_weights = np.sqrt(np.maximum(part_weights, 0))
                total_weight = np.sum(part_weights)
                if np.isfinite(total_weight) and total_weight > 0:
                    part_weights = part_weights / total_weight
                else:
                    # No usable attention signal: sample channels uniformly.
                    part_weights = np.full(num_parts, 1.0 / num_parts)
                selected_index = np.random.choice(np.arange(0, num_parts),
                                                  1,
                                                  p=part_weights)[0]

                mask = attention_map[:, :, selected_index]

                threshold = random.uniform(0.4, 0.6)
                itemindex = np.where(mask >= mask.max() * threshold)

                ymin = itemindex[0].min() / height - 0.1
                ymax = itemindex[0].max() / height + 0.1
                xmin = itemindex[1].min() / width - 0.1
                xmax = itemindex[1].max() / width + 0.1

                bbox = np.asarray([ymin, xmin, ymax, xmax], dtype=np.float32)
                bboxes.append(bbox)
            bboxes = np.asarray(bboxes, np.float32)
            return bboxes

        def attention_drop(attention_maps):
            '''
            Attention drop: the goal is to let the model attend to other
            parts of the object, since different attention maps may focus on
            the same part.

            For every sample one attention channel is drawn at random, with
            probability proportional to the square root of its mean
            activation. The returned mask is 1.0 where that channel is below
            a random fraction (0.2 to 0.5) of its maximum and 0.0 elsewhere.

            If the channel weights cannot be normalized (their sum is zero or
            not finite, e.g. for an all-zero attention map) the channel is
            drawn uniformly. Merely skipping the division is not enough,
            because np.random.choice rejects probabilities that do not sum
            to 1.

            :param attention_maps: NumPy array of shape
                (batch, height, width, num_parts).
            :return: float32 array of shape (batch, height, width, 1)
                holding 0.0/1.0 values.
            '''
            batch_size, height, width, num_parts = attention_maps.shape
            masks = []
            for i in range(batch_size):
                attention_map = attention_maps[i]
                part_weights = attention_map.mean(axis=0).mean(axis=0)
                # Clamp at zero so a negative mean cannot turn into NaN.
                part_weights = np.sqrt(np.maximum(part_weights, 0))
                total_weight = np.sum(part_weights)
                if np.isfinite(total_weight) and total_weight > 0:
                    part_weights = part_weights / total_weight
                else:
                    # No usable attention signal: sample channels uniformly.
                    part_weights = np.full(num_parts, 1.0 / num_parts)
                selected_index = np.random.choice(np.arange(0, num_parts),
                                                  1,
                                                  p=part_weights)[0]
                # Slice (not index) to keep a trailing channel axis of size 1.
                mask = attention_map[:, :, selected_index:selected_index + 1]

                # soft mask
                threshold = random.uniform(0.2, 0.5)
                mask = (mask < threshold * mask.max()).astype(np.float32)
                masks.append(mask)
            masks = np.asarray(masks, dtype=np.float32)
            return masks

        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn.

            Runs the network three times with shared weights: on the raw
            batch, on an attention-cropped batch and on an attention-dropped
            batch. Each pass contributes a softmax cross-entropy loss with
            weight 1/3, and a center loss on the raw-pass embeddings is added
            on top.

            Returns:
                The end_points dictionary of the raw-image pass.
            """
            images, labels = batch_queue.dequeue()
            logits_1, end_points_1 = network_fn(images)

            # Bring the attention maps up to the input resolution so the
            # drop mask built from them can be multiplied with the images.
            attention_maps = end_points_1['attention_maps']
            attention_maps = tf.image.resize(
                attention_maps, [train_image_size, train_image_size],
                method=tf.image.ResizeMethod.BILINEAR)

            # attention crop
            # The boxes are computed in NumPy through py_func, so the static
            # shape is lost and restored by the reshape that follows.
            bboxes = tf.compat.v1.py_func(attention_crop, [attention_maps],
                                          [tf.float32])
            bboxes = tf.reshape(bboxes, [batch_size, 4])
            # Box i is cropped from image i.
            box_ind = tf.range(batch_size, dtype=tf.int32)
            images_crop = tf.image.crop_and_resize(
                images,
                bboxes,
                box_ind,
                crop_size=[train_image_size, train_image_size])

            # attention drop
            masks = tf.compat.v1.py_func(attention_drop, [attention_maps],
                                         [tf.float32])
            masks = tf.reshape(
                masks, [batch_size, train_image_size, train_image_size, 1])
            images_drop = images * masks

            # reuse=True: both augmented passes share the variables created
            # by the first network_fn call.
            logits_2, end_points_2 = network_fn(images_crop, reuse=True)
            logits_3, end_points_3 = network_fn(images_drop, reuse=True)

            # The three classification losses are weighted equally (1/3 each).
            slim.losses.softmax_cross_entropy(logits_1,
                                              labels,
                                              weights=1 / 3.0,
                                              scope='cross_entropy_1')
            slim.losses.softmax_cross_entropy(logits_2,
                                              labels,
                                              weights=1 / 3.0,
                                              scope='cross_entropy_2')
            slim.losses.softmax_cross_entropy(logits_3,
                                              labels,
                                              weights=1 / 3.0,
                                              scope='cross_entropy_3')

            # Center loss on the embeddings of the raw-image pass only.
            embeddings = end_points_1['embeddings']
            center_loss = calculate_pooling_center_loss(
                features=embeddings,
                label=labels,
                alfa=0.95,
                nrof_classes=dataset.num_classes,
                weights=1.0,
                name='center_loss')
            slim.losses.add_loss(center_loss)

            return end_points_1

        # Gather initial summaries.
        summaries = set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS, first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.LOSSES,
                                                first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = configure_learning_rate(dataset.num_samples,
                                                    global_step)
            optimizer = configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = get_variables_to_train(trainable_scopes)

        #  and returns a train_tensor and summary_op
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES,
                                        first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.compat.v1.summary.merge_all()

        config = tf.compat.v1.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False)
        config.gpu_options.allow_growth = True
        config.gpu_options.visible_device_list = "0"

        save_model_path = os.path.join(checkpoint_path, model_name,
                                       "%s.ckpt" % model_name)
        print(save_model_path)

        # saver = tf.compat.v1.train.import_meta_graph('%s.meta'%save_model_path, clear_devices=True)
        tf.compat.v1.disable_eager_execution()
        # train the model
        slim.learning.train(
            train_op=train_tensor,
            logdir=train_dir,
            is_chief=(task == 0),
            init_fn=_get_init_fn(save_model_path, train_dir=train_dir),
            summary_op=summary_op,
            number_of_steps=max_number_of_steps,
            log_every_n_steps=log_every_n_steps,
            save_summaries_secs=save_summaries_secs,
            save_interval_secs=save_interval_secs,
            # sync_optimizer=None,
            session_config=config)
Ejemplo n.º 17
0
def build_refinenet(inputs, num_classes, preset_model='RefineNet-Res101', weight_decay=1e-5, is_training=True, upscaling_method="bilinear", pretrained_dir="models"):
    """
    Builds the RefineNet model on top of a ResNet-v2 feature extractor.

    Arguments:
      inputs: The input tensor
      num_classes: Number of classes (channels of the final 1x1 convolution)
      preset_model: Which model you want to use. Selects the ResNet used for
        feature extraction: 'RefineNet-Res50', 'RefineNet-Res101' or
        'RefineNet-Res152'
      weight_decay: Weight decay passed to the ResNet arg scope
      is_training: Forwarded to the ResNet constructor
      upscaling_method: "conv" (two learned 2x upscaling stages) or "bilinear"
        (a single 4x upsampling), case-insensitive. Any other value leaves the
        feature map at its current resolution.
      pretrained_dir: Directory containing the '<resnet scope>.ckpt' file with
        the pre-trained ResNet weights

    Returns:
      A tuple (net, init_fn): the logits tensor and a function that restores
      the pre-trained ResNet weights from the checkpoint.

    Raises:
      ValueError: if preset_model is not one of the supported models.
    """

    # Maps each preset to the ResNet constructor name, which is also used as
    # variable scope and checkpoint file stem.
    backbone_scopes = {
        'RefineNet-Res50': 'resnet_v2_50',
        'RefineNet-Res101': 'resnet_v2_101',
        'RefineNet-Res152': 'resnet_v2_152',
    }
    if preset_model not in backbone_scopes:
        raise ValueError("Unsupported ResNet model '%s'. This function only supports ResNet 50, ResNet 101 and ResNet 152" % (preset_model))

    scope = backbone_scopes[preset_model]
    backbone_fn = getattr(resnet_v2, scope)
    with slim.arg_scope(resnet_v2.resnet_arg_scope(weight_decay=weight_decay)):
        _, end_points = backbone_fn(inputs, is_training=is_training, scope=scope)
        # RefineNet requires pre-trained ResNet weights
        init_fn = slim.assign_from_checkpoint_fn(os.path.join(pretrained_dir, scope + '.ckpt'), slim.get_model_variables(scope))

    high = [end_points['pool5'], end_points['pool4'],
            end_points['pool3'], end_points['pool2']]

    low = [None, None, None, None]

    # Get the feature maps to the proper size with bottleneck
    high[0] = slim.conv2d(high[0], 512, 1)
    high[1] = slim.conv2d(high[1], 256, 1)
    high[2] = slim.conv2d(high[2], 256, 1)
    high[3] = slim.conv2d(high[3], 256, 1)

    # RefineNet
    low[0] = RefineBlock(high_inputs=high[0], low_inputs=None)  # Only input ResNet 1/32
    low[1] = RefineBlock(high[1], low[0])  # High input = ResNet 1/16, Low input = Previous 1/16
    low[2] = RefineBlock(high[2], low[1])  # High input = ResNet 1/8, Low input = Previous 1/8
    low[3] = RefineBlock(high[3], low[2])  # High input = ResNet 1/4, Low input = Previous 1/4

    net = low[3]

    net = ResidualConvUnit(net)
    net = ResidualConvUnit(net)

    upscaling = upscaling_method.lower()
    if upscaling == "conv":
        net = ConvUpscaleBlock(net, 128, kernel_size=[3, 3], scale=2)
        net = ConvBlock(net, 128)
        net = ConvUpscaleBlock(net, 64, kernel_size=[3, 3], scale=2)
        net = ConvBlock(net, 64)
    elif upscaling == "bilinear":
        net = Upsampling(net, scale=4)

    net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, scope='logits')

    return net, init_fn
Ejemplo n.º 18
0
    def __init__(self,
                 net_name,
                 snapshot_path,
                 feature_norm_method=None,
                 should_restore_classifier=False,
                 gpu_memory_fraction=None,
                 vgg_16_heads=None):
        """Build the inference graph for `net_name` and restore it from a snapshot.

        Args:
            net_name: name of the network, passed to `nets_factory` and
                `preprocessing_factory`.
            snapshot_path: path to a checkpoint, or a directory with
                checkpoints (the latest one is used).
            feature_norm_method: a method name or a list of method names; each
                must be one of None, 'signed_sqrt', 'unit_norm'.
            should_restore_classifier: if falsy, the classifier layer is not
                restored from the snapshot and the network is built with
                2 classes. If truthy, all model variables are restored and
                the number of classes is read from the classifier weights
                stored in the snapshot. Forced to True when `vgg_16_heads`
                is given.
            gpu_memory_fraction: forwarded as
                `per_process_gpu_memory_fraction` to the session config.
            vgg_16_heads: must be given for 'vgg_16_multihead' (and only for
                it); used as `num_classes` for that network.

        Raises:
            ValueError: on an inconsistent `net_name` / `vgg_16_heads` pair,
                an unknown normalisation method, or when `snapshot_path` is a
                directory that contains no checkpoint.
        """
        self.net_name = net_name
        if net_name != 'vgg_16_multihead' and vgg_16_heads is not None:
            raise ValueError(
                'vgg_16_heads must be not None only for vgg_16_multihead')
        if net_name == 'vgg_16_multihead' and vgg_16_heads is None:
            raise ValueError(
                'vgg_16_heads must be not None for vgg_16_multihead')

        if tf.io.gfile.isdir(snapshot_path):
            snapshot_dir = snapshot_path
            snapshot_path = tf.train.latest_checkpoint(snapshot_dir)
            # latest_checkpoint() returns None when the directory holds no
            # checkpoint; fail here instead of passing None further down.
            if snapshot_path is None:
                raise ValueError(
                    'no checkpoint found in directory: {}'.format(
                        snapshot_dir))

        if not isinstance(feature_norm_method, list):
            feature_norm_method = [feature_norm_method]
        acceptable_methods = [None, 'signed_sqrt', 'unit_norm']
        for method in feature_norm_method:
            if method not in acceptable_methods:
                raise ValueError(
                    'unknown norm method: {}. Use one of {}'.format(
                        method, acceptable_methods))
        self.feature_norm_method = feature_norm_method
        if vgg_16_heads is not None:
            should_restore_classifier = True

        if should_restore_classifier:
            if vgg_16_heads is None:
                # Read the classifier weights from the snapshot; axis 3 of
                # that tensor is used as the number of classes.
                reader = pywrap_tensorflow.NewCheckpointReader(snapshot_path)
                if net_name == 'inception_v1':
                    var_value = reader.get_tensor(
                        'InceptionV1/Logits/Conv2d_0c_1x1/weights')
                else:
                    var_value = reader.get_tensor('vgg_16/fc8/weights')
                num_classes = var_value.shape[3]
            else:
                num_classes = vgg_16_heads
        else:
            num_classes = 2 if vgg_16_heads is None else vgg_16_heads

        network_fn = nets_factory.get_network_fn(net_name,
                                                 num_classes=num_classes,
                                                 is_training=False)
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            net_name, is_training=False)

        eval_image_size = network_fn.default_image_size
        self.img_resize_shape = (eval_image_size, eval_image_size
                                 )  # (224, 224) for VGG

        with tf.Graph().as_default() as graph:
            self.graph = graph
            with tf.compat.v1.variable_scope('input'):
                input_pl = tf.compat.v1.placeholder(
                    tf.float32,
                    shape=[None, eval_image_size, eval_image_size, 3],
                    name='x')
                # not used
                is_phase_train_pl = tf.compat.v1.placeholder(
                    tf.bool, shape=tuple(), name='is_phase_train')

            def preprocess_single_image(image):
                return image_preprocessing_fn(image, eval_image_size,
                                              eval_image_size)

            images = tf.map_fn(preprocess_single_image, input_pl)

            logits, self.end_points = network_fn(images)
            # Expose every end point as an attribute of this object.
            self.__dict__.update(self.end_points)
            default_graph = tf.compat.v1.get_default_graph()
            if net_name == 'inception_v1':
                # Additionally expose the branch outputs of Mixed_4d, looked
                # up by tensor name in the graph.
                for tensor_name in [
                        'Branch_0/Conv2d_0a_1x1', 'Branch_1/Conv2d_0a_1x1',
                        'Branch_1/Conv2d_0b_3x3', 'Branch_2/Conv2d_0a_1x1',
                        'Branch_2/Conv2d_0b_3x3', 'Branch_3/MaxPool_0a_3x3',
                        'Branch_3/Conv2d_0b_1x1'
                ]:
                    full_tensor_name = 'InceptionV1/InceptionV1/Mixed_4d/' + tensor_name
                    if 'MaxPool' in tensor_name:
                        full_tensor_name += '/MaxPool:0'
                    else:
                        full_tensor_name += '/Relu:0'
                    short_name = 'Mixed_4d/' + tensor_name
                    self.__dict__[short_name] = default_graph.get_tensor_by_name(
                        full_tensor_name)
                self.MaxPool_0a_7x7 = default_graph.get_tensor_by_name(
                    "InceptionV1/Logits/MaxPool_0a_7x7/AvgPool:0")
            elif net_name in ['vgg_16', 'vgg_16_multihead']:
                # Expose the BiasAdd outputs (i.e. before the ReLU) of fc6,
                # fc7 and the conv3_x .. conv5_x layers.
                conv_layers = [
                    'conv{0}/conv{0}_{1}'.format(i, j)
                    for i in range(3, 6) for j in range(1, 4)
                ]
                for layer_name in ['fc6', 'fc7'] + conv_layers:
                    self.__dict__['vgg_16/{}_prerelu'.format(layer_name)] = \
                        default_graph.get_tensor_by_name(
                            "vgg_16/{}/BiasAdd:0".format(layer_name))
            config = tf.compat.v1.ConfigProto(
                gpu_options=tf.compat.v1.GPUOptions(
                    per_process_gpu_memory_fraction=gpu_memory_fraction))
            self.sess = tf.compat.v1.Session(config=config)

            if should_restore_classifier:
                variables_to_restore = slim.get_model_variables()
            else:
                # Skip the classifier variables: their shape depends on the
                # number of classes, which is fixed to 2 in this branch.
                variables_to_restore = [
                    var for var in slim.get_model_variables()
                    if not var.op.name.startswith(classifier_scope[net_name])
                ]

            init_fn = slim.assign_from_checkpoint_fn(snapshot_path,
                                                     variables_to_restore)
            init_fn(self.sess)
Ejemplo n.º 19
0
def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):
    """Training function for detection models.

    Builds the multi-clone training graph (input queue, losses, optimizer,
    moving averages, summaries) and runs it with `slim.learning.train`.

    Args:
        create_tensor_dict_fn: a function to create a tensor input dictionary.
        create_model_fn: a function that creates a DetectionModel and
            generates losses.
        train_config: a train_pb2.TrainConfig protobuf.
        master: BNS name of the TensorFlow master to use.
        task: The task id of this training instance.
        num_clones: The number of clones to run per machine.
        worker_replicas: The number of work replicas to train with.
        clone_on_cpu: True if clones should be forced to run on CPU.
        ps_tasks: Number of parameter server tasks.
        worker_job_name: Name of the worker job.
        is_chief: Whether this replica is the chief replica.
        train_dir: Directory to write checkpoints and training summaries to.
        graph_hook_fn: Optional function that is called after the training
            graph is completely built. This is helpful to perform additional
            changes to the training graph such as optimizing batchnorm. The
            function should modify the default graph.
    """

    # Instantiated before the training graph is entered; below it is only
    # used to build the checkpoint restorers.
    detection_model = create_model_fn()

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            input_queue = create_input_queue(create_tensor_dict_fn)

        # Gather initial summaries.
        # TODO(rathodv): See if summaries can be added/extracted from global tf
        # collections so that they don't have to be passed around.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        # Each clone calls _create_losses(input_queue, create_model_fn=...,
        # train_config=...).
        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn,
                                     train_config=train_config)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)
            for var in optimizer_summary_vars:
                tf.summary.scalar(var.op.name, var)

        sync_optimizer = None
        if train_config.sync_replicas:
            # NOTE(review): total_num_replicas is read from
            # train_config.worker_replicas, while the deployment config above
            # uses the `worker_replicas` argument - confirm that the proto
            # really carries this field and that the two values agree.
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=train_config.worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            # fine_tune_checkpoint may list several comma-separated paths.
            restore_checkpoints = [
                path.strip()
                for path in train_config.fine_tune_checkpoint.split(',')
            ]

            restorers = get_restore_checkpoint_ops(restore_checkpoints,
                                                   detection_model,
                                                   train_config)

            def initializer_fn(sess):
                # restorers[i] is applied to restore_checkpoints[i].
                for i, restorer in enumerate(restorers):
                    restorer.restore(sess, restore_checkpoints[i])

            init_fn = initializer_fn

        with tf.device(deploy_config.optimizer_device()):
            # None vs. empty list: presumably None lets optimize_clones pick
            # up the regularization losses itself, while [] disables them -
            # verify against model_deploy.optimize_clones.
            regularization_losses = (
                None if train_config.add_regularization_loss else [])
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones,
                training_optimizer,
                regularization_losses=regularization_losses)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Moving averages of all model variables are always maintained;
            # the decay (0.9999) is hard-coded rather than configurable.
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                0.9999, global_step)
            update_ops.append(
                variable_averages.apply(moving_average_variables))

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            # The train tensor returns the loss value, but only after all
            # update ops (first-clone updates, moving averages, gradient
            # step) have run.
            update_op = tf.group(*update_ops, name='update_barrier')
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        if graph_hook_fn:
            with tf.device(deploy_config.variables_device()):
                graph_hook_fn()

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        # Summaries created under the 'critic_loss' scope are included too.
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, 'critic_loss'))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        # num_steps of 0 / unset is passed as None; summaries are written
        # every 120 seconds (hard-coded).
        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
Ejemplo n.º 20
0
def main(_):
  """Evaluates a classification model once on a dataset split.

  Builds the input pipeline and the network selected through FLAGS, restores
  the variables (or their moving averages when --moving_average_decay is set)
  from --checkpoint_path and reports streaming accuracy and recall@5.

  Raises:
    ValueError: if --dataset_dir is not set, or if --checkpoint_path is a
      directory that does not contain any checkpoint.
  """
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    tf_global_step = slim.get_or_create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ####################
    # Select the model #
    ####################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=False)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=False,
        common_queue_capacity=2 * FLAGS.batch_size,
        common_queue_min=FLAGS.batch_size)
    [image, label] = provider.get(['image', 'label'])
    label -= FLAGS.labels_offset

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=False)

    eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

    image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

    images, labels = tf.train.batch(
        [image, label],
        batch_size=FLAGS.batch_size,
        num_threads=FLAGS.num_preprocessing_threads,
        capacity=5 * FLAGS.batch_size)

    ####################
    # Define the model #
    ####################
    logits, _ = network_fn(images)

    if FLAGS.moving_average_decay:
      # Restore the moving-average shadow values into the model variables.
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, tf_global_step)
      variables_to_restore = variable_averages.variables_to_restore(
          slim.get_model_variables())
      variables_to_restore[tf_global_step.op.name] = tf_global_step
    else:
      variables_to_restore = slim.get_variables_to_restore()

    predictions = tf.argmax(logits, 1)
    labels = tf.squeeze(labels)

    # Define the metrics:
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
        'Recall_5': slim.metrics.streaming_recall_at_k(
            logits, labels, 5),
    })

    # Print the summaries to screen.
    for name, value in names_to_values.items():
      summary_name = 'eval/%s' % name
      op = tf.summary.scalar(summary_name, value, collections=[])
      op = tf.Print(op, [value], summary_name)
      tf.add_to_collection(tf.compat.v1.GraphKeys.SUMMARIES, op)

    # TODO(sguada) use num_epochs=1
    if FLAGS.max_num_batches:
      num_batches = FLAGS.max_num_batches
    else:
      # This ensures that we make a single pass over all of the data.
      num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size))

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
      # latest_checkpoint() returns None when the directory holds no
      # checkpoint; stop here rather than passing None to evaluate_once.
      if checkpoint_path is None:
        raise ValueError(
            'No checkpoint found in directory %s' % FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    tf.logging.info('Evaluating %s', checkpoint_path)

    slim.evaluation.evaluate_once(
        master=FLAGS.master,
        checkpoint_path=checkpoint_path,
        logdir=FLAGS.eval_dir,
        num_evals=num_batches,
        eval_op=list(names_to_updates.values()),
        variables_to_restore=variables_to_restore)
Ejemplo n.º 21
0
Archivo: rcnn.py Proyecto: yekeren/VCR
def RCNN(inputs, proposals, options, is_training=True):
  """Extracts one CNN feature vector per proposal box.

  The image batch is preprocessed, every proposal is cropped and resized to
  the feature extractor's input size, the crops are run through the
  extractor and the selected end point is reshaped to one row per proposal.

  Args:
    inputs: Input image, a [batch, height, width, 3] uint8 tensor. The pixel
      values are in the range of [0, 255].
    proposals: Boxes used to crop the image, using normalized coordinates.
      It should be a [batch, max_num_proposals, 4] float tensor denoting
      [y1, x1, y2, x2]. The batch size is read from the static shape.
    options: A rcnn_pb2.RCNN proto.
    is_training: Not referenced in the body; the preprocessing is created
      with is_training=False and the network function is created without an
      is_training argument.

  Returns:
    A tuple (outputs, init_fn): `outputs` is a
    [batch, max_num_proposals, feature_dims] tensor; `init_fn(_, sess)`
    restores the feature extractor variables from the configured checkpoint.

  Raises:
    ValueError if options is not a rcnn_pb2.RCNN proto or inputs is not a
    uint8 tensor.
  """
  if not isinstance(options, rcnn_pb2.RCNN):
    raise ValueError('The options has to be a rcnn_pb2.RCNN proto!')
  if inputs.dtype != tf.uint8:
    raise ValueError('The inputs has to be a tf.uint8 tensor.')

  extractor_name = options.feature_extractor_name
  backbone_fn = nets_factory.get_network_fn(name=extractor_name,
                                            num_classes=1001)
  # Side length of the square crops fed to the extractor.
  crop_side = getattr(backbone_fn, 'default_image_size', 224)

  # Preprocess the full images without resizing or cropping them.
  preprocess_fn = preprocessing_factory.get_preprocessing(
      extractor_name, is_training=False)
  preprocessed = preprocess_fn(inputs,
                               output_height=None,
                               output_width=None,
                               crop_image=False)

  # Build, for every proposal, the index of the image it belongs to.
  batch = proposals.shape[0]
  num_proposals = tf.shape(proposals)[1]
  image_index = tf.tile(
      tf.expand_dims(tf.range(batch), axis=-1), [1, num_proposals])

  # Flatten boxes and indices so that each row describes one crop.
  flat_boxes = tf.reshape(proposals, [-1, 4])
  flat_index = tf.reshape(image_index, [-1])
  crops = tf.image.crop_and_resize(preprocessed,
                                   boxes=flat_boxes,
                                   box_ind=flat_index,
                                   crop_size=[crop_side, crop_side])

  # Run the CNN and regroup the features per image.
  end_points = backbone_fn(crops)[1]
  features = tf.reshape(end_points[options.feature_extractor_endpoint],
                        [batch, num_proposals, -1])

  restore_fn = slim.assign_from_checkpoint_fn(
      options.feature_extractor_checkpoint,
      slim.get_model_variables(options.feature_extractor_scope))

  def _init_from_ckpt_fn(_, sess):
    # Adapter: drops the first argument and forwards the session.
    return restore_fn(sess)

  return features, _init_from_ckpt_fn
def main(_):
    """Build the slim training graph for ``FLAGS.model_name`` and train it.

    Sets up the bf16 conversion environment variable and (optionally)
    Horovod, then builds the input pipeline, the model clones, the optimizer
    and the summaries inside a fresh graph and hands the resulting train op
    to ``slim.learning.train``.

    Args:
      _: Unused positional argument.

    Raises:
      ValueError: If ``--dataset_dir`` is not supplied.
    """
    tf.compat.v1.disable_eager_execution()
    tf.compat.v1.enable_resource_variables()

    # Enable habana bf16 conversion pass
    if FLAGS.dtype == 'bf16':
        os.environ['TF_BF16_CONVERSION'] = flags.FLAGS.bf16_config_path
        FLAGS.precision = 'bf16'
    else:
        os.environ['TF_BF16_CONVERSION'] = "0"

    if FLAGS.use_horovod:
        hvd_init()

    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=True,
            use_grayscale=FLAGS.use_grayscale)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            # The auxiliary head (when the network exposes one) contributes
            # with a reduced weight of 0.4.
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries (those that exist before the clones are
        # created).
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs

        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #if FLAGS.quantize_delay >= 0:
        #  quantize.create_training_graph(quant_delay=FLAGS.quantize_delay) #for debugging!!

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        #  and returns a train_tensor and summary_op
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        # Evaluating train_tensor forces every update op (batch-norm updates,
        # moving averages, gradient application) to run first.
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # With eager execution disabled, broadcast_global_variables() only
        # *builds* the broadcast op. Keep a handle on it so that it is
        # actually executed once the variables have been initialized (see
        # init_fn below); otherwise the workers never receive rank 0's
        # initial values.
        if horovod_enabled():
            broadcast_op = hvd.broadcast_global_variables(0)
        else:
            broadcast_op = None

        ###########################
        # Kicks off the training. #
        ###########################
        with dump_callback():
            with logger.benchmark_context(FLAGS):
                eps1 = ExamplesPerSecondKerasHook(FLAGS.log_every_n_steps,
                                                  output_dir=FLAGS.train_dir,
                                                  batch_size=FLAGS.batch_size)

                write_hparams_v1(
                    eps1.writer, {
                        'batch_size': FLAGS.batch_size,
                        **{x: getattr(FLAGS, x)
                           for x in FLAGS}
                    })

                # Extra arguments consumed by the custom train_step1 function.
                train_step_kwargs = {}
                if FLAGS.max_number_of_steps:
                    should_stop_op = math_ops.greater_equal(
                        global_step, FLAGS.max_number_of_steps)
                else:
                    should_stop_op = constant_op.constant(False)
                train_step_kwargs['should_stop'] = should_stop_op
                if FLAGS.log_every_n_steps > 0:
                    train_step_kwargs['should_log'] = math_ops.equal(
                        math_ops.mod(global_step, FLAGS.log_every_n_steps), 0)

                eps1.on_train_begin()
                train_step_kwargs['EPS'] = eps1

                restore_fn = _get_init_fn()
                if broadcast_op is None:
                    init_fn = restore_fn
                else:
                    def init_fn(sess):
                        """Warm-start from the checkpoint (if any), then sync workers."""
                        if restore_fn is not None:
                            restore_fn(sess)
                        sess.run(broadcast_op)

                slim.learning.train(
                    train_tensor,
                    logdir=FLAGS.train_dir,
                    train_step_fn=train_step1,
                    train_step_kwargs=train_step_kwargs,
                    master=FLAGS.master,
                    is_chief=(FLAGS.task == 0),
                    init_fn=init_fn,
                    summary_op=summary_op,
                    summary_writer=None,
                    number_of_steps=FLAGS.max_number_of_steps,
                    log_every_n_steps=FLAGS.log_every_n_steps,
                    save_summaries_secs=FLAGS.save_summaries_secs,
                    save_interval_secs=FLAGS.save_interval_secs,
                    sync_optimizer=optimizer if FLAGS.sync_replicas else None)
Ejemplo n.º 23
0
def main(_):
  """Evaluate one checkpoint of ``FLAGS.model_name`` on a dataset split.

  Builds the evaluation graph (dataset provider, preprocessing, batching and
  the network created with per-layer ``quant_params``), registers streaming
  accuracy and recall@5 metrics and runs ``slim.evaluation.evaluate_once``
  over ``FLAGS.max_num_batches`` batches, or one full pass over the dataset
  when that flag is unset.

  Args:
    _: Unused positional argument.

  Raises:
    ValueError: If ``--dataset_dir`` is not supplied.
  """
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  with tf.Graph().as_default():
    tf_global_step = slim.get_or_create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ####################
    # Select the model #
    ####################
    # One [number_of_hashing_functions, neuron_vector_length] pair per entry
    # of --number_hashing_functions; the second flag is indexed in lockstep.
    n_hash = FLAGS.number_hashing_functions
    L_vec = FLAGS.neuron_vector_length
    quant_params = [[int(n), int(L_vec[i])] for i, n in enumerate(n_hash)]

    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        quant_params=quant_params, is_training=False)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=False,
        common_queue_capacity=2 * FLAGS.batch_size,
        common_queue_min=FLAGS.batch_size)
    [image, label] = provider.get(['image', 'label'])
    label -= FLAGS.labels_offset

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=False)

    eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

    image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

    images, labels = tf.compat.v1.train.batch(
        [image, label],
        batch_size=FLAGS.batch_size,
        num_threads=FLAGS.num_preprocessing_threads,
        capacity=5 * FLAGS.batch_size)

    ####################
    # Define the model #
    ####################
    logits, _ = network_fn(images)

    # When moving averages were used for training, restore the averaged
    # values into the model variables instead of the raw ones.
    if FLAGS.moving_average_decay:
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, tf_global_step)
      variables_to_restore = variable_averages.variables_to_restore(
          slim.get_model_variables())
      variables_to_restore[tf_global_step.op.name] = tf_global_step
    else:
      variables_to_restore = slim.get_variables_to_restore()

    predictions = tf.argmax(input=logits, axis=1)
    labels = tf.squeeze(labels)

    # Define the metrics. Every entry must be a (value_op, update_op) pair;
    # the former 'Recall_5' entry was a bare (logits, labels, 5) tuple
    # because the metric call had been dropped.
    names_to_values, names_to_updates = aggregate_metric_map({
        'Accuracy': tf.compat.v1.metrics.accuracy(labels, predictions),
        # recall_at_k takes (labels, predictions, k) and needs int64 labels.
        'Recall_5': tf.compat.v1.metrics.recall_at_k(
            tf.cast(labels, tf.int64), logits, 5),
    })

    # Print the summaries to screen.
    for name, value in names_to_values.items():
      summary_name = 'eval/%s' % name
      op = tf.compat.v1.summary.scalar(summary_name, value, collections=[])
      op = tf.compat.v1.Print(op, [value], summary_name)
      tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.SUMMARIES, op)

    # TODO(sguada) use num_epochs=1
    if FLAGS.max_num_batches:
      num_batches = FLAGS.max_num_batches
    else:
      # This ensures that we make a single pass over all of the data.
      num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size))

    # --checkpoint_path may name either a directory (use its most recent
    # checkpoint) or a specific checkpoint file.
    if tf.io.gfile.isdir(FLAGS.checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    tf.compat.v1.logging.info('Evaluating %s', checkpoint_path)

    # Let the process grab GPU memory on demand rather than all up front.
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True

    slim.evaluation.evaluate_once(
        master=FLAGS.master,
        checkpoint_path=checkpoint_path,
        logdir=FLAGS.eval_dir,
        num_evals=num_batches,
        eval_op=list(names_to_updates.values()),
        session_config=config,
        variables_to_restore=variables_to_restore)
Ejemplo n.º 24
0
def build_deeplabv3(inputs, num_classes, preset_model='DeepLabV3-Res50', weight_decay=1e-5, is_training=True, pretrained_dir="models"):
    """
    Builds the DeepLabV3 model on top of a ResNet-v2 feature extractor.

    Arguments:
      inputs: The input tensor
      num_classes: Number of classes (channels of the final 1x1 convolution)
      preset_model: Which model you want to use. Selects the ResNet used for
        feature extraction: 'DeepLabV3-Res50', 'DeepLabV3-Res101' or
        'DeepLabV3-Res152'
      weight_decay: Weight decay passed to the ResNet arg scope
      is_training: Forwarded to the ResNet builder
      pretrained_dir: Directory containing '<resnet scope>.ckpt'

    Returns:
      A tuple (net, init_fn): the output of the final 'logits' convolution
      and the function that assigns the pre-trained ResNet weights from the
      checkpoint.

    Raises:
      ValueError: If preset_model is not one of the supported presets.
    """
    # Preset name -> ResNet scope. The scope doubles as the name of the
    # builder function in resnet_v2 and as the checkpoint file stem.
    backbone_scopes = {
        'DeepLabV3-Res50': 'resnet_v2_50',
        'DeepLabV3-Res101': 'resnet_v2_101',
        'DeepLabV3-Res152': 'resnet_v2_152',
    }
    if preset_model not in backbone_scopes:
        raise ValueError("Unsupported ResNet model '%s'. This function only supports ResNet 50, ResNet 101, and ResNet 152" % (preset_model))

    backbone_scope = backbone_scopes[preset_model]
    with slim.arg_scope(resnet_v2.resnet_arg_scope(weight_decay=weight_decay)):
        build_backbone = getattr(resnet_v2, backbone_scope)
        _, end_points = build_backbone(inputs, is_training=is_training, scope=backbone_scope)
        # DeepLabV3 requires pre-trained ResNet weights
        checkpoint_file = os.path.join(pretrained_dir, backbone_scope + '.ckpt')
        init_fn = slim.assign_from_checkpoint_fn(checkpoint_file, slim.get_model_variables(backbone_scope))

    # Spatial size of the input, used as the target size when upsampling.
    label_size = tf.shape(inputs)[1:3]

    pyramid = AtrousSpatialPyramidPoolingModule(end_points['pool4'])
    upsampled = Upsampling(pyramid, label_size)
    net = slim.conv2d(upsampled, num_classes, [1, 1], activation_fn=None, scope='logits')

    return net, init_fn
Ejemplo n.º 25
0
def main(model_root, datasets_dir, model_name, test_image_name):
    """Evaluate a two-pass, attention-cropped classifier on the 'train' split.

    The network is run once on the batched images; its 'attention_maps' end
    point is averaged over the last axis, resized to the evaluation size and
    converted into one box per image by `mask2bbox`. The images are cropped
    to those boxes and run through the network a second time with
    `reuse=True`. Evaluation summaries are registered for the first pass,
    the second pass, and the log of the averaged softmax of both passes.

    Besides its arguments, the function reads the names `batch_size`,
    `num_preprocessing_threads`, `moving_average_decay` and
    `max_num_batches`, which are defined outside of it.

    Args:
      model_root: Directory joined with `model_name` to locate the
        checkpoint (either a checkpoint directory or a checkpoint file).
      datasets_dir: Dataset directory passed to `convert_data.get_datasets`.
      model_name: Name used to select the network and its preprocessing.
      test_image_name: File name inside `datasets_dir`; the resulting path is
        only referenced by the disabled single-image snippet further down.
    """
    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()

        test_image = os.path.join(datasets_dir, test_image_name)

        dataset = convert_data.get_datasets('train', dataset_dir=datasets_dir)

        network_fn = net_select.get_network_fn(model_name,
                                               num_classes=dataset.num_classes,
                                               is_training=False)

        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            shuffle=False,
            common_queue_capacity=20 * batch_size,
            common_queue_min=10 * batch_size)
        [image, label] = provider.get(['image', 'label'])

        image_preprocessing_fn = preprocessing_select.get_preprocessing(
            model_name, is_training=False)

        eval_image_size = network_fn.default_image_size
        image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

        images, labels = tf_v1.train.batch(
            [image, label],
            batch_size=batch_size,
            num_threads=num_preprocessing_threads,
            capacity=5 * batch_size)

        # Resolve a checkpoint directory to its most recent checkpoint; the
        # else branch is a no-op that keeps an explicit file path as is.
        checkpoint_path = os.path.join(model_root, model_name)
        if tf.io.gfile.isdir(checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
        else:
            checkpoint_path = checkpoint_path

        # First pass: full images. The attention maps are reduced to a single
        # channel and resized to the evaluation resolution.
        logits_1, end_points_1 = network_fn(images)
        attention_maps = tf.reduce_mean(end_points_1['attention_maps'],
                                        axis=-1,
                                        keepdims=True)
        attention_maps = tf.image.resize(attention_maps,
                                         [eval_image_size, eval_image_size],
                                         method=tf.image.ResizeMethod.BILINEAR)
        # mask2bbox runs as a Python op; its output is reshaped to one
        # 4-value box per batch element.
        # NOTE(review): crop_and_resize expects normalized [y1, x1, y2, x2]
        # boxes - confirm mask2bbox produces that layout.
        bboxes = tf_v1.py_func(mask2bbox, [attention_maps], [tf.float32])
        bboxes = tf.reshape(bboxes, [batch_size, 4])
        box_ind = tf.range(batch_size, dtype=tf.int32)

        # Second pass: each image cropped to its own box, variables reused.
        images = tf.image.crop_and_resize(
            images,
            bboxes,
            box_ind,
            crop_size=[eval_image_size, eval_image_size])
        logits_2, end_points_2 = network_fn(images, reuse=True)

        # Log of the equally weighted average of both softmax outputs.
        logits = tf_v1.log(
            tf.nn.softmax(logits_1) * 0.5 + tf.nn.softmax(logits_2) * 0.5)
        # The string literal below is a disabled single-image prediction
        # snippet kept as a bare string expression; it is never executed.
        """
        tf_v1.enable_eager_execution()

        #测试单张图片
        image_data = tf.io.read_file(test_image)
        image_data = tf.image.decode_jpeg(image_data,channels= 3)

        # plt.figure(1)
        # plt.imshow(image_data)

        image_data = image_preprocessing_fn(image_data, eval_image_size, eval_image_size)
        image_data = tf.expand_dims(image_data, 0)

        logits_3,end_points_3 = network_fn(image_data,reuse =True)
        attention_map = tf.reduce_mean(end_points_3['attention_maps'], axis=-1, keepdims=True)
        attention_map = tf.image.resize(attention_map, [eval_image_size, eval_image_size],
                                         method=tf.image.ResizeMethod.BILINEAR)
        bboxes = tf_v1.py_func(mask2bbox, [attention_map], [tf.float32])
        bboxes = tf.reshape(bboxes, [batch_size, 4])
        box_ind = tf.range(batch_size, dtype=tf.int32)

        image_data = tf.image.crop_and_resize(images, bboxes, box_ind, crop_size=[eval_image_size, eval_image_size])

        logits_4, end_points_4 = network_fn(image_data, reuse=True)
        logits_0 = tf_v1.log(tf.nn.softmax(logits_3) * 0.5 + tf.nn.softmax(logits_4) * 0.5)
        probabilities = logits_0[0,0:]

        print(probabilities)
        # sorted_inds = [i[0] for i in sorted(enumerate(-probabilities),key= lambda x:x[1])]
        sorted_inds = (np.argsort(probabilities.numpy())[::-1])

        train_info = sio.loadmat(os.path.join(datasets_dir, 'devkit', 'cars_train_annos.mat'))['annotations'][0]
        names = train_info['class']
        print(names)
        for i in range(5):
            index = sorted_inds[i]
            #  打印top5的预测类别和相应的概率值。
            print('Probability %0.2f => [%s]' % (probabilities[index],names[index+1][0][0]))
        """
        # With a moving-average decay configured, map the model variables to
        # their averaged counterparts for restoring; otherwise restore the
        # variables as saved.
        if moving_average_decay:
            variable_averages = tf.train.ExponentialMovingAverage(
                moving_average_decay, tf_global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[tf_global_step.op.name] = tf_global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()

        # One set of evaluation summaries per prediction head; each call
        # returns a mapping whose values are used as eval ops below.
        logits_to_updates = add_eval_summary(logits, labels, scope='/bilinear')
        logits_1_to_updates = add_eval_summary(logits_1,
                                               labels,
                                               scope='/logits_1')
        logits_2_to_updates = add_eval_summary(logits_2,
                                               labels,
                                               scope='/logits_2')

        if max_num_batches:
            num_batches = max_num_batches
        else:
            # This ensures that we make a single pass over all of the data.
            num_batches = math.ceil(dataset.num_samples / float(batch_size))

        config = tf_v1.ConfigProto(allow_soft_placement=True,
                                   log_device_placement=False)
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = 1.0

        tf.compat.v1.disable_eager_execution()

        # NOTE(review): no break/return is visible inside this loop in the
        # reviewed excerpt, so as shown it re-runs the evaluation forever -
        # confirm whether an exit condition follows.
        while True:
            if tf.io.gfile.isdir(checkpoint_path):
                checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
            else:
                checkpoint_path = checkpoint_path

            print('Evaluating %s' % checkpoint_path)
            # eval_op is a list of three lists (one per prediction head), not
            # a flat list of ops.
            eval_op = []
            # eval_op = list(logits_to_updates.values())
            eval_op.append(list(logits_to_updates.values()))
            eval_op.append(list(logits_1_to_updates.values()))
            eval_op.append(list(logits_2_to_updates.values()))
            # tf.convert_to_tensor(eval_op)
            # tf.cast(eval_op,dtype=tf.string)
            # print(eval_op)

            # NOTE(review): logdir is set to the checkpoint path itself and
            # master is a single space rather than '' - verify both are
            # intended.
            test_dir = checkpoint_path
            slim.evaluation.evaluate_once(
                master=' ',
                checkpoint_path=checkpoint_path,
                logdir=test_dir,
                num_evals=num_batches,
                eval_op=eval_op,
                variables_to_restore=variables_to_restore,
                final_op=None,
                session_config=config)