Example no. 1
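The snippets in this and the following examples appear to be based on the TF1-style detection/segmentation code in the tensorflow/tpu repository (`models/official/detection`). They are fragments of larger files, so identifiers such as `config_factory`, `model_factory`, `params_dict`, `anchor`, `mode_keys`, and `box_utils` come from imports along these lines (a sketch; exact module paths may differ by checkout):

import tensorflow.compat.v1 as tf
from configs import factory as config_factory
from dataloader import anchor, mode_keys
from hyperparameters import params_dict
from modeling import factory as model_factory  # imported plainly as `factory` in some snippets
from utils import box_utils, input_utils, mask_utils
from utils.object_detection import visualization_utils

The standalone scripts additionally rely on standard imports such as `csv`, `io`, `base64`, `numpy as np`, `PIL.Image`, and `matplotlib.pyplot as plt`.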
def run_executor(params, mode, train_input_fn, callbacks):
    model_builder = model_factory.model_generator(params)
    model = model_builder.build_model(params, mode=mode)
    loss_fn = [model_builder.build_cls_loss_fn(), model_builder.build_box_loss_fn()]
    model.compile(optimizer=model.optimizer,
                  loss=loss_fn)
    model.fit(train_input_fn(), epochs=10, steps_per_epoch=100, callbacks=callbacks)
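A minimal usage sketch (hypothetical wiring): `make_train_dataset` stands in for whatever input pipeline the surrounding project defines, and the call assumes, as the code above does, that the model builder attached an optimizer to the model before `compile` (otherwise `model.optimizer` would be undefined).

params = config_factory.config_generator('retinanet')
run_executor(params, mode_keys.TRAIN,
             train_input_fn=make_train_dataset,  # hypothetical dataset factory
             callbacks=[tf.keras.callbacks.TensorBoard(log_dir='/tmp/logs')])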
Example no. 2
    def _serving_model_graph(features, params):
        """Build the model graph for serving."""
        images = features['images']
        _, height, width, _ = images.get_shape().as_list()

        input_anchor = anchor.Anchor(params.anchor.min_level,
                                     params.anchor.max_level,
                                     params.anchor.num_scales,
                                     params.anchor.aspect_ratios,
                                     params.anchor.anchor_size,
                                     (height, width))

        model_fn = factory.model_generator(params)
        model_outputs = model_fn.build_outputs(
            features['images'],
            labels={'anchor_boxes': input_anchor.multilevel_boxes},
            mode=mode_keys.PREDICT)

        if cast_num_detections_to_float:
            model_outputs['num_detections'] = tf.cast(
                model_outputs['num_detections'], dtype=tf.float32)

        if output_image_info:
            model_outputs.update({
                'image_info': features['image_info'],
            })

        if output_normalized_coordinates:
            model_outputs['detection_boxes'] = box_utils.normalize_boxes(
                model_outputs['detection_boxes'],
                features['image_info'][:, 1:2, :])

        return model_outputs
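Here the per-level anchor boxes from `Anchor.multilevel_boxes` are passed to `build_outputs` unbatched; the next example tiles them across the batch dimension and also threads `image_info` through `labels`. Note that `cast_num_detections_to_float`, `output_image_info`, and `output_normalized_coordinates` are not parameters of `_serving_model_graph`: they are closure variables captured from the enclosing export function.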
Example no. 3
  def _serving_model_graph(features, params):
    """Build the model graph for serving."""
    images = features['images']
    batch_size, height, width, _ = images.get_shape().as_list()

    input_anchor = anchor.Anchor(
        params.anchor.min_level, params.anchor.max_level,
        params.anchor.num_scales, params.anchor.aspect_ratios,
        params.anchor.anchor_size, (height, width))

    multilevel_boxes = {}
    for k, v in six.iteritems(input_anchor.multilevel_boxes):
      multilevel_boxes[k] = tf.tile(
          tf.expand_dims(v, 0), [batch_size, 1, 1])

    model_fn = factory.model_generator(params)
    model_outputs = model_fn.build_outputs(
        features['images'],
        labels={
            'anchor_boxes': multilevel_boxes,
            'image_info': features['image_info'],
        },
        mode=mode_keys.PREDICT)

    if cast_num_detections_to_float:
      model_outputs['num_detections'] = tf.cast(
          model_outputs['num_detections'], dtype=tf.float32)

    if output_image_info:
      model_outputs.update({
          'image_info': features['image_info'],
      })

    if output_normalized_coordinates:
      model_outputs['detection_boxes'] = box_utils.normalize_boxes(
          model_outputs['detection_boxes'],
          features['image_info'][:, 1:2, :])

    predictions = {
        'num_detections': tf.identity(
            model_outputs['num_detections'], 'NumDetections'),
        'detection_boxes': tf.identity(
            model_outputs['detection_boxes'], 'DetectionBoxes'),
        'detection_classes': tf.identity(
            model_outputs['detection_classes'], 'DetectionClasses'),
        'detection_scores': tf.identity(
            model_outputs['detection_scores'], 'DetectionScores'),
    }
    if 'detection_masks' in model_outputs:
      predictions.update({
          'detection_masks':
              tf.identity(model_outputs['detection_masks'], 'DetectionMasks'),
      })

    if output_image_info:
      predictions['image_info'] = tf.identity(
          model_outputs['image_info'], 'ImageInfo')

    return predictions
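The `tf.identity` wrappers give the serving outputs stable, human-readable op names. A hypothetical way to fetch them by name once such a graph is loaded:

# Hypothetical: the identity ops above make these tensors addressable by name.
graph = tf.get_default_graph()
num_detections = graph.get_tensor_by_name('NumDetections:0')
detection_boxes = graph.get_tensor_by_name('DetectionBoxes:0')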
Example no. 4
    def __init__(
        self,
        config_file: str,
        checkpoint_path: str,
        batch_size: int,
        resize_shape: tuple[int, int],
        cache_dir: str,
        device: int | None = None,
    ):
        self.device = device
        self.batch_size = batch_size
        self.resize_shape = resize_shape

        params = config_factory.config_generator("mask_rcnn")
        if config_file:
            params = params_dict.override_params_dict(params,
                                                      config_file,
                                                      is_strict=True)
        params.validate()
        params.lock()
        self.max_level = params.architecture.max_level

        self._model = model_factory.model_generator(params)
        estimator = tf.estimator.Estimator(model_fn=self._model_fn)

        # Use a SavedModel instead of Estimator.predict() because it is
        # difficult to download images from GCS when executing this code
        # on Vertex Pipelines.

        with tempfile.TemporaryDirectory() as tmpdir:
            export_dir_parent = cache_dir or tmpdir
            children = list(Path(export_dir_parent).glob("*"))
            if not children:
                logger.info(f"export saved_model: {export_dir_parent}")
                estimator.export_saved_model(
                    export_dir_base=export_dir_parent,
                    serving_input_receiver_fn=self._serving_input_receiver_fn,
                    checkpoint_path=checkpoint_path,
                )

            children = list(Path(export_dir_parent).glob("*"))
            export_dir = str(children[0])
            logger.info(f"load saved_model from {export_dir}")
            self.saved_model = tf.saved_model.load(export_dir=export_dir)
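A hypothetical instantiation, assuming the enclosing class is named `MaskRCNNPredictor` and that the config and checkpoint paths exist (all values illustrative, echoing the paths used in Example no. 12):

predictor = MaskRCNNPredictor(
    config_file='configs/yaml/spinenet49_amrcnn.yaml',
    checkpoint_path='fashionpedia-spinenet-49/model.ckpt',
    batch_size=1,
    resize_shape=(640, 640),
    cache_dir='/tmp/saved_model_cache',
)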
Example no. 5
def build_model_fn(features, labels, mode, params):
    model_builder = model_factory.model_generator(params)
    model = model_builder.build_model(params, mode=mode)
    loss_fn = model_builder.build_loss_fn()
    global_step = tf.train.get_or_create_global_step()
    outputs = model(features, training=True)
    prediction_loss = loss_fn(labels, outputs)
    total_loss = tf.reduce_mean(prediction_loss['total_loss'])
    optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(total_loss, global_step)
    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=total_loss,
                                      train_op=train_op)
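A sketch of wiring this model_fn into an Estimator (assumes a `params` object from `config_factory` and a `train_input_fn` yielding `(features, labels)` batches; note that the forward pass above hard-codes `training=True`, so this model_fn effectively supports only TRAIN mode):

estimator = tf.estimator.Estimator(
    model_fn=build_model_fn, model_dir='/tmp/model_dir', params=params)
estimator.train(input_fn=train_input_fn, max_steps=1000)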
Example no. 6
def main(unused_argv):
    del unused_argv
    # Load the label map.
    print(' - Loading the label map...')
    label_map_dict = {}
    if FLAGS.label_map_format == 'csv':
        with tf.gfile.Open(FLAGS.label_map_file, 'r') as csv_file:
            reader = csv.reader(csv_file, delimiter=':')
            for row in reader:
                if len(row) != 2:
                    raise ValueError(
                        'Each row of the csv label map file must be in '
                        '`id:name` format.')
                id_index = int(row[0])
                name = row[1]
                label_map_dict[id_index] = {
                    'id': id_index,
                    'name': name,
                }
    else:
        raise ValueError('Unsupported label map format: {}.'.format(
            FLAGS.label_map_format))

    params = config_factory.config_generator(FLAGS.model)
    if FLAGS.config_file:
        params = params_dict.override_params_dict(params,
                                                  FLAGS.config_file,
                                                  is_strict=True)
    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)
    params.validate()
    params.lock()

    model = model_factory.model_generator(params)

    with tf.Graph().as_default():
        image_input = tf.placeholder(shape=(), dtype=tf.string)
        image = tf.io.decode_image(image_input, channels=3)
        image.set_shape([None, None, 3])

        image = input_utils.normalize_image(image)
        image_size = [FLAGS.image_size, FLAGS.image_size]
        image, image_info = input_utils.resize_and_crop_image(
            image,
            image_size,
            image_size,
            aug_scale_min=1.0,
            aug_scale_max=1.0)
        image.set_shape([image_size[0], image_size[1], 3])

        # batching.
        images = tf.reshape(image, [1, image_size[0], image_size[1], 3])
        images_info = tf.expand_dims(image_info, axis=0)

        # model inference
        outputs = model.build_outputs(images, {'image_info': images_info},
                                      mode=mode_keys.PREDICT)

        outputs['detection_boxes'] = (
            outputs['detection_boxes'] /
            tf.tile(images_info[:, 2:3, :], [1, 1, 2]))

        predictions = outputs

        # Create a saver in order to load the pre-trained checkpoint.
        saver = tf.train.Saver()

        image_with_detections_list = []
        with tf.Session() as sess:
            print(' - Loading the checkpoint...')
            saver.restore(sess, FLAGS.checkpoint_path)

            image_files = tf.gfile.Glob(FLAGS.image_file_pattern)
            for i, image_file in enumerate(image_files):
                print(' - Processing image %d...' % i)

                with tf.gfile.GFile(image_file, 'rb') as f:
                    image_bytes = f.read()

                image = Image.open(image_file)
                image = image.convert(
                    'RGB')  # needed for images with 4 channels.
                width, height = image.size
                np_image = (np.array(image.getdata()).reshape(
                    height, width, 3).astype(np.uint8))

                predictions_np = sess.run(predictions,
                                          feed_dict={image_input: image_bytes})

                num_detections = int(predictions_np['num_detections'][0])
                np_boxes = predictions_np['detection_boxes'][
                    0, :num_detections]
                np_scores = predictions_np['detection_scores'][
                    0, :num_detections]
                np_classes = predictions_np['detection_classes'][
                    0, :num_detections]
                np_classes = np_classes.astype(np.int32)
                np_masks = None
                if 'detection_masks' in predictions_np:
                    instance_masks = predictions_np['detection_masks'][
                        0, :num_detections]
                    np_masks = mask_utils.paste_instance_masks(
                        instance_masks, box_utils.yxyx_to_xywh(np_boxes),
                        height, width)

                image_with_detections = (
                    visualization_utils.
                    visualize_boxes_and_labels_on_image_array(
                        np_image,
                        np_boxes,
                        np_classes,
                        np_scores,
                        label_map_dict,
                        instance_masks=np_masks,
                        use_normalized_coordinates=False,
                        max_boxes_to_draw=FLAGS.max_boxes_to_draw,
                        min_score_thresh=FLAGS.min_score_threshold))
                image_with_detections_list.append(image_with_detections)

    print(' - Saving the outputs...')
    formatted_image_with_detections_list = [
        Image.fromarray(image.astype(np.uint8))
        for image in image_with_detections_list
    ]
    html_str = '<html>'
    image_strs = []
    for formatted_image in formatted_image_with_detections_list:
        with io.BytesIO() as stream:
            formatted_image.save(stream, format='JPEG')
            data_uri = base64.b64encode(stream.getvalue()).decode('utf-8')
        image_strs.append(
            '<img src="data:image/jpeg;base64,{}" height=800>'.format(
                data_uri))
    images_str = ' '.join(image_strs)
    html_str += images_str
    html_str += '</html>'
    with tf.gfile.GFile(FLAGS.output_html, 'w') as f:
        f.write(html_str)
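The `FLAGS` referenced above are absl flags; the script presumably defines them roughly as follows (names recovered from the usages above, defaults illustrative):

from absl import flags
FLAGS = flags.FLAGS
flags.DEFINE_string('model', 'retinanet', 'Model architecture name.')
flags.DEFINE_string('config_file', '', 'Optional YAML config override.')
flags.DEFINE_string('params_override', '', 'Inline params override.')
flags.DEFINE_string('label_map_file', '', 'Path to the label map.')
flags.DEFINE_string('label_map_format', 'csv', 'Label map format.')
flags.DEFINE_integer('image_size', 640, 'Square input size.')
flags.DEFINE_string('checkpoint_path', '', 'Checkpoint to restore.')
flags.DEFINE_string('image_file_pattern', '', 'Glob of input images.')
flags.DEFINE_string('output_html', '', 'Where to write the HTML report.')
flags.DEFINE_integer('max_boxes_to_draw', 10, 'Max boxes per image.')
flags.DEFINE_float('min_score_threshold', 0.05, 'Score threshold.')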
Example no. 7
def build_predictions(features,
                      params,
                      output_image_info,
                      output_normalized_coordinates,
                      cast_num_detections_to_float,
                      cast_detection_classes_to_float=False):
    """Builds the model graph for serving.

  Args:
    features: features to be passed to the serving model graph
    params: hyperparameters to be passed to the serving model graph
    output_image_info: bool, whether output the image_info node.
    output_normalized_coordinates: bool, whether box outputs are in the
      normalized coordinates.
    cast_num_detections_to_float: bool, whether to cast the number of detections
      to float type.
    cast_detection_classes_to_float: bool, whether or not cast the detection
      classes  to float type.

  Returns:
    predictions: model outputs for serving.
    model_outputs: a dict of model output tensors.
  """
    images = features['images']
    batch_size, height, width, _ = images.get_shape().as_list()

    input_anchor = anchor.Anchor(params.architecture.min_level,
                                 params.architecture.max_level,
                                 params.anchor.num_scales,
                                 params.anchor.aspect_ratios,
                                 params.anchor.anchor_size, (height, width))

    multilevel_boxes = {}
    for k, v in six.iteritems(input_anchor.multilevel_boxes):
        multilevel_boxes[k] = tf.tile(tf.expand_dims(v, 0), [batch_size, 1, 1])

    model_fn = factory.model_generator(params)
    model_outputs = model_fn.build_outputs(features['images'],
                                           labels={
                                               'anchor_boxes':
                                               multilevel_boxes,
                                               'image_info':
                                               features['image_info'],
                                           },
                                           mode=mode_keys.PREDICT)

    # Return flattened raw outputs.
    if not params.postprocess.apply_nms:
        predictions = {
            'raw_boxes': tf.identity(model_outputs['raw_boxes'], 'RawBoxes'),
            'raw_scores': tf.identity(model_outputs['raw_scores'],
                                      'RawScores'),
        }
        return predictions, model_outputs

    if cast_num_detections_to_float:
        model_outputs['num_detections'] = tf.cast(
            model_outputs['num_detections'], dtype=tf.float32)

    if cast_detection_classes_to_float:
        model_outputs['detection_classes'] = tf.cast(
            model_outputs['detection_classes'], dtype=tf.float32)

    if output_image_info:
        model_outputs.update({
            'image_info': features['image_info'],
        })

    if output_normalized_coordinates:
        detection_boxes = (
            model_outputs['detection_boxes'] /
            tf.tile(features['image_info'][:, 2:3, :], [1, 1, 2]))
        model_outputs['detection_boxes'] = box_utils.normalize_boxes(
            detection_boxes, features['image_info'][:, 0:1, :])

    predictions = {
        'num_detections':
            tf.identity(model_outputs['num_detections'], 'NumDetections'),
        'detection_boxes':
            tf.identity(model_outputs['detection_boxes'], 'DetectionBoxes'),
        'detection_classes':
            tf.identity(model_outputs['detection_classes'], 'DetectionClasses'),
        'detection_scores':
            tf.identity(model_outputs['detection_scores'], 'DetectionScores'),
    }
    if 'detection_masks' in model_outputs:
        predictions.update({
            'detection_masks':
                tf.identity(model_outputs['detection_masks'], 'DetectionMasks'),
        })
        if 'detection_outer_boxes' in model_outputs:
            predictions.update({
                'detection_outer_boxes':
                    tf.identity(model_outputs['detection_outer_boxes'],
                                'DetectionOuterBoxes'),
            })

    if output_image_info:
        predictions['image_info'] = tf.identity(model_outputs['image_info'],
                                                'ImageInfo')

    return predictions, model_outputs
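A sketch of how `build_predictions` might be wrapped in an Estimator serving model_fn (hypothetical wrapper; as above, `features` must carry 'images' and 'image_info'):

def serving_model_fn(features, labels, mode, params):
    del labels  # Unused when serving.
    predictions, _ = build_predictions(
        features,
        params,
        output_image_info=True,
        output_normalized_coordinates=False,
        cast_num_detections_to_float=True)
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)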
Example no. 8
  def __init__(self, params):
    self._model = factory.model_generator(params)
Example no. 9
  def _serving_model_fn(features, labels, mode, params):
    """Builds the serving model_fn."""
    del labels  # unused.
    if mode != tf.estimator.ModeKeys.PREDICT:
      raise ValueError('To build the serving model_fn, set '
                       'mode = `tf.estimator.ModeKeys.PREDICT`')

    model_params = params_dict.ParamsDict(params)

    images = features['images']

    model_fn = factory.model_generator(model_params)
    outputs = model_fn.build_outputs(
        features['images'], labels=None, mode=mode_keys.PREDICT)

    logits = tf.image.resize_bilinear(
        outputs['logits'], tf.shape(images)[1:3], align_corners=False)

    original_image_size = tf.squeeze(features['image_info'][:, 0:1, :])
    height = original_image_size[0]
    width = original_image_size[1]
    offset_height = tf.zeros_like(height, dtype=tf.int32)
    offset_width = tf.zeros_like(width, dtype=tf.int32)

    # Clip the predictions to original image size.
    logits = tf.image.crop_to_bounding_box(logits, offset_height, offset_width,
                                           tf.cast(height, dtype=tf.int32),
                                           tf.cast(width, dtype=tf.int32))
    probabilities = tf.nn.softmax(logits)

    score_threshold_placeholder = features['score_thresholds']
    key_placeholder = features['key']

    score_threshold_pred_expanded = score_threshold_placeholder
    for _ in range(0, logits.shape.ndims - 1):
      score_threshold_pred_expanded = tf.expand_dims(
          score_threshold_pred_expanded, -1)

    scores = tf.where(probabilities > score_threshold_pred_expanded,
                      probabilities, tf.zeros_like(probabilities))
    scores = tf.reduce_max(scores, 3)
    scores = tf.expand_dims(scores, -1)
    scores = tf.cast(tf.minimum(scores * 255.0, 255), tf.uint8)
    categories = tf.to_int32(tf.expand_dims(tf.argmax(probabilities, 3), -1))

    # Generate images for scores and categories.
    score_bytes = tf.map_fn(
        tf.image.encode_png, scores, back_prop=False, dtype=tf.string)
    category_bytes = tf.map_fn(
        tf.image.encode_png,
        tf.cast(categories, tf.uint8),
        back_prop=False,
        dtype=tf.string)

    predictions = {}

    predictions['category_bytes'] = tf.identity(
        category_bytes, name='category_bytes')
    predictions['score_bytes'] = tf.identity(score_bytes, name='score_bytes')
    predictions['key'] = tf.identity(key_placeholder, name='key')
    if output_image_info:
      predictions['image_info'] = tf.identity(
          features['image_info'], name='image_info')

    if export_tpu_model:
      return tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode, predictions=predictions)
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
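As with the detection serving graphs above, `output_image_info` and `export_tpu_model` are not arguments of `_serving_model_fn`; they are closure variables captured from the enclosing export builder.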
Example no. 10
def main(unused_argv):
  del unused_argv

  params = config_factory.config_generator(FLAGS.model)
  if FLAGS.config_file:
    params = params_dict.override_params_dict(
        params, FLAGS.config_file, is_strict=True)
  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=True)
  # We currently only support batch_size = 1 to evaluate images one by one.
  # Override the `eval_batch_size` = 1 here.
  params.override({
      'eval': {
          'eval_batch_size': 1,
      },
  })
  params.validate()
  params.lock()

  model = model_factory.model_generator(params)
  evaluator = evaluator_factory.evaluator_generator(params.eval)

  parse_fn = functools.partial(parse_single_example, params=params)
  with tf.Graph().as_default():
    dataset = tf.data.Dataset.list_files(
        params.eval.eval_file_pattern, shuffle=False)
    dataset = dataset.apply(
        tf.data.experimental.parallel_interleave(
            lambda filename: tf.data.TFRecordDataset(filename).prefetch(1),
            cycle_length=32,
            sloppy=False))
    dataset = dataset.map(parse_fn, num_parallel_calls=64)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(1, drop_remainder=False)

    images, labels, groundtruths = dataset.make_one_shot_iterator().get_next()
    images.set_shape([
        1,
        params.retinanet_parser.output_size[0],
        params.retinanet_parser.output_size[1],
        3])

    # model inference
    outputs = model.build_outputs(images, labels, mode=mode_keys.PREDICT)

    predictions = outputs
    predictions.update({
        'source_id': groundtruths['source_id'],
        'image_info': labels['image_info'],
    })

    # Create a saver in order to load the pre-trained checkpoint.
    saver = tf.train.Saver()

    with tf.Session() as sess:
      saver.restore(sess, FLAGS.checkpoint_path)

      num_batches = params.eval.eval_samples // params.eval.eval_batch_size
      for i in range(num_batches):
        if i % 100 == 0:
          print('{}/{} batches...'.format(i, num_batches))
        predictions_np, groundtruths_np = sess.run([predictions, groundtruths])
        evaluator.update(predictions_np, groundtruths_np)

    if FLAGS.dump_predictions_only:
      print('Dumping the prediction results...')
      evaluator.dump_predictions(FLAGS.predictions_path)
      print('Done!')
    else:
      print('Evaluating the prediction results...')
      metrics = evaluator.evaluate()
      print('Eval results: {}'.format(metrics))
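A sketch of the usual absl entry point for a script like this (assumes flag definitions along the lines sketched after Example no. 6, plus the eval-specific flags used here):

from absl import app

if __name__ == '__main__':
    app.run(main)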
Example no. 11
def main(unused_argv):
    del unused_argv
    # Load the label map.
    print(' - Loading the label map...')
    label_map_dict = {}
    if FLAGS.label_map_format == 'csv':
        with tf.gfile.Open(FLAGS.label_map_file, 'r') as csv_file:
            reader = csv.reader(csv_file, delimiter=':')
            for row in reader:
                if len(row) != 2:
                    raise ValueError(
                        'Each row of the csv label map file must be in '
                        '`id:name` format.')
                id_index = int(row[0])
                name = row[1]
                label_map_dict[id_index] = {
                    'id': id_index,
                    'name': name,
                }
    else:
        raise ValueError('Unsupported label map format: {}.'.format(
            FLAGS.label_map_format))

    params = config_factory.config_generator(FLAGS.model)
    if FLAGS.config_file:
        params = params_dict.override_params_dict(params,
                                                  FLAGS.config_file,
                                                  is_strict=True)
    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)
    params.override(
        {
            'architecture': {
                'use_bfloat16': False,  # The inference runs on CPU/GPU.
            },
        },
        is_strict=True)
    params.validate()
    params.lock()

    model = model_factory.model_generator(params)

    with tf.Graph().as_default():
        image_input = tf.placeholder(shape=(), dtype=tf.string)
        image = tf.io.decode_image(image_input, channels=3)
        image.set_shape([None, None, 3])

        image = input_utils.normalize_image(image)
        image_size = [FLAGS.image_size, FLAGS.image_size]
        image, image_info = input_utils.resize_and_crop_image(
            image,
            image_size,
            image_size,
            aug_scale_min=1.0,
            aug_scale_max=1.0)
        image.set_shape([image_size[0], image_size[1], 3])

        # batching.
        images = tf.reshape(image, [1, image_size[0], image_size[1], 3])
        images_info = tf.expand_dims(image_info, axis=0)

        # model inference
        outputs = model.build_outputs(images, {'image_info': images_info},
                                      mode=mode_keys.PREDICT)

        predictions = outputs

        # Create a saver in order to load the pre-trained checkpoint.
        saver = tf.train.Saver()

        image_with_detections_list = []
        with tf.Session() as sess:
            print(' - Loading the checkpoint...')
            saver.restore(sess, FLAGS.checkpoint_path)

            image_files = tf.gfile.Glob(FLAGS.image_file_pattern)
            for i, image_file in enumerate(image_files):
                print(' - Processing image %d...' % i)

                with tf.gfile.GFile(image_file, 'rb') as f:
                    image_bytes = f.read()

                image = Image.open(image_file)
                image = image.convert(
                    'RGB')  # needed for images with 4 channels.
                width, height = image.size
                np_image = (np.array(image.getdata()).reshape(
                    height, width, 3).astype(np.uint8))
                print(np_image.shape)

                predictions_np = sess.run(predictions,
                                          feed_dict={image_input: image_bytes})

                logits = predictions_np['logits'][0]
                print(logits.shape)

                labels = np.argmax(logits.squeeze(), -1)
                print(labels.shape)
                print(labels)
                labels = np.array(Image.fromarray(labels.astype('uint8')))
                print(labels.shape)

                plt.imshow(labels)
                plt.savefig(f"temp-{i}.png")
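Unlike Example no. 6, this variant reads the segmentation `logits` head rather than detection outputs: it takes a per-pixel argmax over the class dimension and saves each resulting mask with matplotlib (assumed imported as `plt`).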
Example no. 12
def initiate():
    # Load the label map.
    print(' - Loading the label map...')
    label_map_dict = {}
    with tf.gfile.Open('dataset/fashionpedia_label_map.csv', 'r') as csv_file:
        reader = csv.reader(csv_file, delimiter=':')
        for row in reader:
            if len(row) != 2:
                raise ValueError(
                    'Each row of the csv label map file must be in '
                    '`id:name` format.')
            id_index = int(row[0])
            name = row[1]
            label_map_dict[id_index] = {
                'id': id_index,
                'name': name,
            }

    params = config_factory.config_generator('attribute_mask_rcnn')
    params = params_dict.override_params_dict(
        params, 'configs/yaml/spinenet49_amrcnn.yaml', is_strict=True)
    params.override(
        {
            'architecture': {
                'use_bfloat16': False,  # The inference runs on CPU/GPU.
            },
        },
        is_strict=True)
    params.validate()
    params.lock()

    model = model_factory.model_generator(params)

    with tf.Graph().as_default():
        image_input = tf.placeholder(shape=(), dtype=tf.string)
        image = tf.io.decode_image(image_input, channels=3)
        image.set_shape([None, None, 3])

        image = input_utils.normalize_image(image)
        image_size = [640, 640]
        image, image_info = input_utils.resize_and_crop_image(
            image,
            image_size,
            image_size,
            aug_scale_min=1.0,
            aug_scale_max=1.0)
        image.set_shape([image_size[0], image_size[1], 3])

        # batching.
        images = tf.reshape(image, [1, image_size[0], image_size[1], 3])
        images_info = tf.expand_dims(image_info, axis=0)

        # model inference
        outputs = model.build_outputs(images, {'image_info': images_info},
                                      mode=mode_keys.PREDICT)

        outputs['detection_boxes'] = (
            outputs['detection_boxes'] /
            tf.tile(images_info[:, 2:3, :], [1, 1, 2]))

        predictions = outputs

        # Create a saver in order to load the pre-trained checkpoint.
        saver = tf.train.Saver()
        sess = tf.Session()
        print(' - Loading the checkpoint...')
        saver.restore(sess, 'fashionpedia-spinenet-49/model.ckpt')
        print(' - Checkpoint Loaded...')
        return sess, predictions, image_input
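A hypothetical way to use the returned session and placeholder (file name illustrative):

sess, predictions, image_input = initiate()
with tf.gfile.GFile('example.jpg', 'rb') as f:
    image_bytes = f.read()
# Run one image through the restored graph.
predictions_np = sess.run(predictions, feed_dict={image_input: image_bytes})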