Example #1
    def _helper(self, image_path, image_np=None):
        """

    :param image_path: Path to an image with human faces.
    :param image_np: Optional numpy array containing image in [h,w,c] format. Overrides image_path.
    :return: cropped faces as a list of numpy arrays
    """
        if image_np is None:
            image_np = util_io.imread(image_path)
        # The array-based representation of the image will be used later to
        # prepare the result image with boxes and labels on it.
        # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
        image_np_expanded = np.expand_dims(image_np, axis=0)
        image_tensor = self.detection_graph.get_tensor_by_name(
            'image_tensor:0')
        # Each box represents a part of the image where a particular object was detected.
        boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
        # Each score represents the level of confidence for each of the objects.
        scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
        classes = self.detection_graph.get_tensor_by_name(
            'detection_classes:0')
        num_detections = self.detection_graph.get_tensor_by_name(
            'num_detections:0')
        # Actual detection.
        start_time = time.time()
        (boxes, scores, classes, num_detections) = self.sess.run(
            [boxes, scores, classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
        elapsed_time = time.time() - start_time
        print('Face cropping inference time cost: {}'.format(elapsed_time))

        return (image_np, boxes, scores, classes, num_detections)
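The returned tensors keep their leading batch dimension of 1. Below is a minimal post-processing sketch in pure numpy; the function name and default threshold are illustrative, not part of this codebase:

import numpy as np

def filter_detections(boxes, scores, classes, min_score=0.5):
    # Drop the batch dimension and keep detections above a score threshold.
    boxes = np.squeeze(boxes)                        # [num_detections, 4]
    scores = np.squeeze(scores)                      # [num_detections]
    classes = np.squeeze(classes).astype(np.int32)   # [num_detections]
    keep = scores > min_score
    return boxes[keep], scores[keep], classes[keep]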
Example #2
  def do_inference(self, output_dir, image_path=None, image_np=None):
    """Tests PredictionService with concurrent requests.

    Args:
      output_dir: Directory to output image.
      image_path: Path to image.
      image_np: Image in np format. Ignored when image_path is set.

    Returns:
      `output_dir`.
    """
    if image_path is None and image_np is None:
      raise ValueError('Either `image_np` or `image_path` must be specified.')

    if image_path:
      image_resized = util_io.imread(image_path, (self.image_hw, self.image_hw))
    else:
      image_resized = scipy.misc.imresize(image_np, (self.image_hw, self.image_hw))
    # TODO: do preprocessing in a separate function. Check whether image has already been preprocessed.
    image = np.expand_dims(image_resized / np.float32(255.0), 0)

    stub = prediction_service_pb2.beta_create_PredictionService_stub(self.channel)
    request = predict_pb2.PredictRequest()
    request.CopyFrom(self.request_template)
    self._request_set_input_image(request, image)
    result_future = stub.Predict.future(request, 5.0)  # 5 seconds
    result_future.add_done_callback(self._create_rpc_callback(output_dir))
    return output_dir
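The `_request_set_input_image` helper is not shown in this listing. A hedged sketch of what it could look like with the TF 1.x serving API; the 'inputs' key is an assumption about how the model was exported:

import tensorflow as tf

def _request_set_input_image(self, request, image):
    # Assumed implementation: serialize the [1, h, w, 3] float32 batch into
    # the PredictRequest. The 'inputs' key is a guess about the exported model.
    request.inputs['inputs'].CopyFrom(
        tf.make_tensor_proto(image, shape=list(image.shape)))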
Example #3
 def combine_original_and_transferred(images,
                                      transferred_2x_image_file_format,
                                      combined_image_pattern):
     ret = []
     for i in range(len(images)):
         save_image_path = combined_image_pattern % i
         ret.append(save_image_path)
         if os.path.exists(save_image_path):
             continue
         start_time = time.time()
         transferred_2x_image_file_path = transferred_2x_image_file_format % i
         # Wait up to 5 seconds for the transferred image to be written to disk.
         while not os.path.exists(transferred_2x_image_file_path
                                  ) and time.time() - start_time < 5:
             time.sleep(1)
         transferred_image = None
         while time.time() - start_time < 5:
             try:
                 transferred_image = util_io.imread(
                     transferred_2x_image_file_path)
                 break  # Read succeeded; stop polling.
             except IOError:
                 time.sleep(1)
         if transferred_image is None:
             raise IOError('Cannot read image file %s' %
                           (transferred_2x_image_file_path))
         face_image = scipy.misc.imresize(
             images[i],
             (transferred_image.shape[0], transferred_image.shape[1]))
         combined_image = np.concatenate((face_image, transferred_image),
                                         axis=1)
         util_io.imsave(save_image_path, combined_image)
     return ret
Example #4
 def infer(self, input_image_path, return_image_paths=False, num_output=None):
     """Given an image, a path containing images, or a list of paths, return the outputs."""
     one_output = False
     if input_image_path:
         if isinstance(input_image_path, (list, tuple)):
             image_paths = input_image_path
         else:
             if os.path.isfile(input_image_path):
                 image_paths = [input_image_path]
                 one_output = True
             else:
                 image_paths = util_io.get_files_in_dir(
                     input_image_path, do_sort=True, do_random_ordering=False)
         images = [util_io.imread(image_path, dtype=np.uint8)
                   for image_path in image_paths]
     else:
         assert num_output is not None and num_output >= 1
         images = [None for _ in range(num_output)]
         image_paths = [str(i) for i in range(num_output)]
         one_output = (num_output == 1)
     outputs = []
     for image in images:
         if image is None:
             feed_dict = None
         else:
             feed_dict = {self.images_placeholder: image}
         output = self.sess.run(self.output, feed_dict=feed_dict)
         output = output[0] * 255.0  # Batch size == 1, range = 0~1.
         outputs.append(output)
     if one_output:
         outputs = outputs[0]
         image_paths = image_paths[0]
     if return_image_paths:
         return outputs, image_paths
     return outputs
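A short usage sketch based on the docstring above; `model` stands in for an instance of the surrounding class, and both paths are hypothetical:

# Single file: `one_output` is True, so a single array comes back.
output = model.infer('face.png')

# Directory: a list of outputs plus the sorted image paths.
outputs, image_paths = model.infer('input_faces/', return_image_paths=True)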
Example #5
  def crop_face(self, image_path):
    """

    :param image_path: Path to an image with human faces.
    :return: cropped faces as a list of numpy arrays
    """
    image_np = util_io.imread(image_path)
    # The array-based representation of the image will be used later to
    # prepare the result image with boxes and labels on it.
    # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
    image_np_expanded = np.expand_dims(image_np, axis=0)
    image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
    # Each box represents a part of the image where a particular object was detected.
    boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
    # Each score represents the level of confidence for each of the objects.
    scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
    classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
    # Actual detection.
    start_time = time.time()
    (boxes, scores, classes, num_detections) = self.sess.run(
      [boxes, scores, classes, num_detections],
      feed_dict={image_tensor: image_np_expanded})
    elapsed_time = time.time() - start_time
    print('Face cropping inference time cost: {}'.format(elapsed_time))

    return crop_by_category(
      image_np,
      np.squeeze(boxes),
      np.squeeze(classes).astype(np.int32),
      np.squeeze(scores),
      self.category_index,
      category_to_crop=self.face_category_index
    )
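`crop_by_category` is defined elsewhere in the codebase. A rough numpy stand-in (category filtering omitted), assuming the boxes follow the TF object detection convention of normalized `[ymin, xmin, ymax, xmax]`:

import numpy as np

def crop_boxes(image_np, boxes, scores, min_score=0.5):
    # Illustrative only: convert normalized boxes to pixel-space crops.
    h, w = image_np.shape[:2]
    crops = []
    for (ymin, xmin, ymax, xmax), score in zip(boxes, scores):
        if score >= min_score:
            crops.append(image_np[int(ymin * h):int(ymax * h),
                                  int(xmin * w):int(xmax * w)])
    return crops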
Example #6
  def test_imread_and_imsave_utf8(self):
    height = 256
    width = 256

    content_folder = tempfile.mkdtemp()
    image_path = content_folder + u'/骨董屋・三千世界の女主人_12746957.png'
    current_image = np.ones((height, width, 3)) * 255.0
    current_image[0, 0, 0] = 0
    util_io.imsave(image_path, current_image)

    actual_output = util_io.imread(util_io.get_all_image_paths(content_folder + '/')[0])

    expected_answer = np.round(np.array(current_image))
    np.testing.assert_almost_equal(expected_answer, actual_output)
Example #7
  def test_imread_rgba(self):
    height = 256
    width = 256

    content_folder = tempfile.mkdtemp()
    image_path = content_folder + '/image.png'
    current_image = np.ones((height, width, 4)) * 255.0
    current_image[0, 0, 0] = 0
    scipy.misc.imsave(image_path, current_image)

    content_pre_list = util_io.imread(image_path, rgba=True)
    expected_answer = current_image
    np.testing.assert_almost_equal(expected_answer, content_pre_list)

    shutil.rmtree(content_folder)
Example #8
    def test_imread_bw(self):
        height = 256
        width = 256

        content_folder = tempfile.mkdtemp()
        image_path = content_folder + u'/骨董屋・三千世界の女主人_12746957.png'
        current_image = np.ones((height, width, 3)) * 255.0
        current_image[0, 0, 0] = 0
        util_io.imsave(image_path, current_image)

        actual_output = util_io.imread(
            util_io.get_files_in_dir(content_folder + '/')[0], bw=True)

        expected_answer = np.floor(_rgb2gray(np.array(current_image)))
        np.testing.assert_almost_equal(expected_answer, actual_output)
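The `_rgb2gray` helper referenced here is not shown; a plausible minimal version using the standard ITU-R BT.601 luma weights (an assumption about the real helper, which may differ):

import numpy as np

def _rgb2gray(rgb):
    # Weighted channel sum; the weights are the common luma coefficients.
    return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])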
Example #9
 def combine_original_and_transferred(images,
                                      transferred_2x_image_file_format,
                                      combined_image_pattern):
     ret = []
     for i in range(len(images)):
         save_image_path = combined_image_pattern % i
         ret.append(save_image_path)
         if os.path.exists(save_image_path):
             continue
         # raise NotImplementedError('deal with error: IOError: cannot identify image file  u\'./static/images/transferred_faces/')
         transferred_image = util_io.imread(
             transferred_2x_image_file_format % i)
         face_image = scipy.misc.imresize(
             images[i],
             (transferred_image.shape[0], transferred_image.shape[1]))
         combined_image = np.concatenate((face_image, transferred_image),
                                         axis=1)
         util_io.imsave(save_image_path, combined_image)
     return ret
Example #10
 def test_is_sketch(self):
     image_dir = u'/mnt/f032b8a5-c186-4fae-b911-bcfdee99a2e9/pixiv_collected_sketches/PixivUtil2/test_samples_tiny'
     image_paths_and_expected = [
         (u'11739035_p1 - レミリア.jpg', False),
         (u'12925016_p2 - 東方系ラクガキ詰め合わせ.jpg', False),
         (u'14362806_p0 - 無題.jpg', False),  # Background not white enough.
         (u'14485122_p0 - 宮子.jpg', False),  # Background not white enough.
         (u'15444948_p4 - ぐ~てん☆もるげんっ!.jpg', False),
         (u'15469173_p0 - 東方魔理沙 塗ってみた.jpg', False),
         (u'17834866_p0 - 以蔵の闘い.jpg', False),
         (u'24774862_p2 - からくりばーすと、メイキング.jpg', False),
         (u'28474957_p0 - 白澤さん。.jpg',
          True),  # Weird edge case, unrecognizable to human eyes.
         (u'28592442_p0 - 本田菊.jpg', False),
         (u'29908672_p0 - ソードアート・おっぱい.jpg', False),
         (u'30454124_p0 - 水野亜美(セーラーマーキュリー).jpg', False),
         (u'5833646_p0 - Q命病棟でQP化(線画).png', True),
         (u'5943678_p0 - バニー達は塗ってほしそうにこちらを見ている・・.jpg', True),
         (u'5952556_p0 - 死神線画.png', True),
         (u'5981054_p0 - リン・レン線画.jpg', False),
         (u'5987994_p0 - お願いします。.png', True),
         (u'7210533_p0 - カウントダウン2日前.jpg', False),
         (u'7242346_p0 - ミリー【線画】.jpg', True),
         (u'7252006_p0 - ナタネ&ロズレイド線画.jpg', True),
         (u'7304820_p0 - 少女の宮サマ。.jpg', True),
         (u'7431716_p0 - 今吉さん線画.jpg', True),
         (u'8113468_p0 - あんこくのじょおう【再投稿】.png', True),
         (u'8152845_p0 - ヘルメスさん.jpg', False),
         (u'8425348_p0 - 早苗さん(線画).png',
          True),  # PNG bug, minor recall loss.
         (u'8441291_p0 - 【線画】ベルベル.jpg', False),
         (u'8502938_p0 - ベアト線画!.jpg', True),
     ]
     for image_name, expected in image_paths_and_expected:
         image = util_io.imread(os.path.join(image_dir, image_name))
         actual_output = util_misc.is_sketch(image)
         if actual_output != expected:
             print('unexpected: ', image_name)
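`util_misc.is_sketch` is not shown in this listing. Judging from the "background not white enough" comments above, it is plausibly a near-white-pixel heuristic; the sketch below is purely illustrative and not the actual implementation:

import numpy as np

def is_sketch_guess(image, white_thresh=240, white_ratio=0.5):
    # Hypothetical heuristic: treat an image as a sketch when most pixels
    # are near-white. Both thresholds are made up for illustration.
    gray = image.mean(axis=-1) if image.ndim == 3 else image
    return float(np.mean(gray > white_thresh)) > white_ratio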
Example #11
 def __init__(self, hostport, image_hw, **kwargs):
   super(MockTwinGANClient, self).__init__(hostport, image_hw, **kwargs)
   self.mock_output_image = util_io.imread('static/images/mock/mock_twingan_output.png',
                                           shape=(image_hw, image_hw))
Example #12
    def do_POST(self):
        # TODO: refactor this file. A lot of asynchronous functions are hacky and ugly.
        form = self.parse_POST()

        if 'id' in form:
            id_str = form['id'][0]
            id_str = id_str.decode()
        else:
            id_str = 'test'

        if 'register_download' in form and form['register_download']:
            if 'subid' in form:
                subid_str = form['subid'][0]
                subid_str = subid_str.decode()
            else:
                subid_str = '0'
            print(
                'TODO: do action for register download for id: %s and subid: %s'
                % (id_str, subid_str))
            self.post_success(id_str, )
            return

        elif 'image' in form:
            bin1 = form['image'][0]
            input_image_path = interface_utils.save_encoded_image(
                bin1, './static/images/inputs/' + id_str)

            cropped_image_pattern = './static/images/cropped_faces/' + id_str + '_%d.png'
            faces = FACE_DETECTOR.crop_face_and_save(input_image_path,
                                                     cropped_image_pattern)
            num_faces = len(faces)
            if num_faces > FLAGS.max_num_faces:
                faces = faces[:FLAGS.max_num_faces]
                num_faces = FLAGS.max_num_faces
            if num_faces == 0:
                shutil.copy(
                    input_image_path,
                    './static/images/cropped_faces/' + id_str + '_%d.png' % 0)
                faces = [
                    util_io.imread(input_image_path,
                                   (FLAGS.image_hw, FLAGS.image_hw))
                ]
                num_faces = len(faces)

            transferred_image_file_format = './static/images/transferred_faces/' + id_str + '_%d.png'
            succeed, transferred_image_files = self.automatic_retry(
                functools.partial(
                    self.domain_transfer,
                    transferred_image_file_format=transferred_image_file_format,
                    images=faces))
            if not succeed:
                self.post_server_internal_error('Domain transfer failed',
                                                id_str,
                                                {'num_faces': num_faces})
                return

            if 'do_waifu2x' in form:
                do_waifu2x = form['do_waifu2x'][0] == 'true'
            else:
                do_waifu2x = False
            if do_waifu2x:
                transferred_2x_image_file_format = './static/images/transferred_faces_2x/' + id_str + '_%d.png'
                succeed, transferred_2x_image_files = self.automatic_retry(
                    functools.partial(self.call_waifu2x,
                                      transferred_image_file_format=
                                      transferred_image_file_format,
                                      transferred_2x_image_file_format=
                                      transferred_2x_image_file_format,
                                      num_images=num_faces))
                if not succeed:
                    self.post_server_internal_error('Waifu2x failed', id_str,
                                                    {'num_faces': num_faces})
                    return
                transferred_image_to_be_combined_format = transferred_2x_image_file_format
            else:
                transferred_image_to_be_combined_format = transferred_image_file_format

            combined_image_pattern = './static/images/combined/' + id_str + '_%d.png'
            succeed, combined_images = self.automatic_retry(
                functools.partial(
                    self.combine_original_and_transferred,
                    images=faces,
                    combined_image_pattern=combined_image_pattern,
                    transferred_2x_image_file_format=
                    transferred_image_to_be_combined_format))
            if not succeed:
                self.post_server_internal_error(
                    'Combine original and transferred failed.', id_str,
                    {'num_faces': num_faces})
                return

            self.post_success(id_str, {'num_faces': num_faces})
        else:
            self.post_bad_request('Post request must contain image.', id_str)
        return
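`self.automatic_retry` is used throughout but not shown. It evidently takes a zero-argument callable (built with `functools.partial`) and returns a `(succeed, result)` pair; a hedged sketch, with the retry count as an assumption:

def automatic_retry(self, fn, num_retries=3):
    # Assumed behavior: call `fn`, retrying on failure, and report success.
    for _ in range(num_retries):
        try:
            return True, fn()
        except Exception as e:
            print('Retrying after error: %s' % e)
    return False, None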
Example #13
    def do_inference(self,
                     output_dir,
                     center_point_xy,
                     sketch_image_np=None,
                     image_path=None,
                     image_np=None):
        """Tests PredictionService with concurrent requests.

    Args:
      output_dir: Directory to output image.
      image_path: Path to image.

    Returns:
      `output_dir`.
    """
        sketch_image = sketch_image_np
        if len(sketch_image.shape) == 2:
            sketch_image = np.expand_dims(sketch_image, axis=-1)
        if self.supervised:
            if image_path is None and image_np is None:
                raise ValueError(
                    'Either `image_np` or `image_path` must be specified.')

            if image_path:
                image = util_io.imread(image_path, bw=True)
            else:
                image = image_np
            if len(image.shape) == 2:
                image = np.expand_dims(image, axis=-1)
            assert image.shape == sketch_image.shape
            combined = np.concatenate((sketch_image, image),
                                      axis=-1).astype(np.float32)
            # Select the subregion of interest.
            # Note: In numpy and tensorflow we're in (h,w,c) format.
            # Integer division keeps the slice indices integral.
            start = (max(0, center_point_xy[1] - self.image_hw // 2),
                     max(0, center_point_xy[0] - self.image_hw // 2))
            end = (center_point_xy[1] + self.image_hw // 2,
                   center_point_xy[0] + self.image_hw // 2)
            subregion = combined[start[0]:end[0], start[1]:end[1]]
            subregion_shape = subregion.shape
            if subregion_shape[0] != self.image_hw or subregion_shape[
                    1] != self.image_hw:
                # imresize only accepts HxW or HxWx3 images, so resize each channel separately.
                subregion_resized = np.concatenate(
                    (np.expand_dims(scipy.misc.imresize(
                        subregion[..., 0], (self.image_hw, self.image_hw)),
                                    axis=-1),
                     np.expand_dims(scipy.misc.imresize(
                         subregion[..., 1], (self.image_hw, self.image_hw)),
                                    axis=-1)),
                    axis=-1)
            else:
                subregion_resized = subregion

            # TODO: do preprocessing in a separate function. Check whether image has already been preprocessed.
            subregion_resized = np.expand_dims(
                subregion_resized / np.float32(255.0), 0)
            input_image = subregion_resized
            callback_kwargs = {
                'start': start,
                'end': end,
                'subregion_shape': subregion_shape
            }
        else:
            input_image = sketch_image
            callback_kwargs = dict()

        stub = prediction_service_pb2.beta_create_PredictionService_stub(
            self.channel)
        request = predict_pb2.PredictRequest()
        request.CopyFrom(self.request_template)
        self._request_set_input_image(request, input_image)
        result_future = stub.Predict.future(request, self.timeout)  # Timeout in seconds.
        result_future.add_done_callback(
            self._create_rpc_callback(output_dir,
                                      sketch_image,
                                      supervised=self.supervised,
                                      **callback_kwargs))
        return output_dir
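The subregion arithmetic is easier to verify on a toy array; a quick numpy check of the center-crop indices used above (all values illustrative):

import numpy as np

combined = np.zeros((8, 8, 2))   # Stand-in for the [h, w, 2] array above.
image_hw = 4
center_point_xy = (5, 3)         # (x, y), as in `do_inference`.
start = (max(0, center_point_xy[1] - image_hw // 2),
         max(0, center_point_xy[0] - image_hw // 2))
end = (center_point_xy[1] + image_hw // 2,
       center_point_xy[0] + image_hw // 2)
print(combined[start[0]:end[0], start[1]:end[1]].shape)  # (4, 4, 2)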
Example #14
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    inference_class = mask_inference if FLAGS.detect_masks else detection_inference
    if not os.path.exists(FLAGS.output_path):
        tf.gfile.MakeDirs(FLAGS.output_path)

    required_flags = ['input_images', 'output_path', 'inference_graph']
    for flag_name in required_flags:
        if not getattr(FLAGS, flag_name):
            raise ValueError('Flag --{} is required'.format(flag_name))

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, ))

    input_image_paths = []
    for v in FLAGS.input_images.split(','):
        if v:
            input_image_paths += tf.gfile.Glob(v)
    tf.logging.info('Reading input from %d files', len(input_image_paths))

    image_ph, image_tensor = build_input()

    tf.logging.info('Reading graph and building model...')
    detected_tensors = inference_class.build_inference_graph(
        image_tensor,
        FLAGS.inference_graph,
        override_num_detections=FLAGS.override_num_detections)

    tf.logging.info('Running inference and writing output to {}'.format(
        FLAGS.output_path))
    sess.run(tf.local_variables_initializer())

    for i, image_path in enumerate(input_image_paths):
        image_np = util_io.imread(image_path)
        inference_class.crop_out(sess,
                                 image_tensor,
                                 detected_tensors,
                                 min_score_thresh=FLAGS.min_score_thresh,
                                 visualize_inference=FLAGS.visualize_inference,
                                 feed_dict={image_ph: image_np},
                                 img_np=image_np,
                                 filename=os.path.join(
                                     FLAGS.output_path,
                                     os.path.basename(image_path)))

    print('Finished processing all images in data set.')
Example #15
    def do_POST(self):
        form = self.parse_POST()

        post_type = None
        if 'type' in form:
            post_type = form['type'][0]

        if post_type == 'labeler':
            if 'id' in form:
                id_str = form['id'][0]
                id_str = id_str.decode()

                skipped = (form['skip'][0] == 'true')
                if not skipped:
                    bin1 = form['image'][0]
                    bin1 = bin1.decode().split(',')[1]
                    bin1 = base64.b64decode(bin1.encode())
                    image_path = os.path.join(
                        FLAGS.labeler_output_image_path,
                        os.path.splitext(id_str)[0] +
                        '.png')  # id_str happens to be the file name.
                    with open(image_path, 'wb') as fout1:
                        fout1.write(bin1)
                LABELER_CLIENT.mark_current_as_done(skipped)

            image, sketch, image_id = LABELER_CLIENT.get_image_and_id()
            if image:
                self.post_success(image_id, {'image': image, 'sketch': sketch})
            else:
                self.post_success(image_id, {'error': 'Ran out of images.'})
            return
        elif post_type == 'sketch_refinement':
            id_str = form['id'][0]
            id_str = id_str.decode()
            if 'sketch_refinement' not in form or 'sketch' not in form:
                self.post_bad_request(
                    'must contain sketch_refinement and sketch', id_str)
                return
            x_pos = int(form['x'][0])
            y_pos = int(form['y'][0])
            # We only need the alpha channel. That will give us a grayscale (0/1 in this case) image.
            sketch_refinement_img = interface_utils.base64_to_numpy(
                form['sketch_refinement'][0], contains_format=True)[..., 3]
            sketch_img = interface_utils.base64_to_numpy(form['sketch'][0],
                                                         contains_format=True)
            sketch_img = util_misc.im2gray(sketch_img)
            out_path = os.path.join(
                FLAGS.labeler_sketch_refinement_folder,
                os.path.splitext(id_str)[0] + '_' + str(time.time()) + '.png')
            SKETCH_REFINEMENT_CLIENT.do_inference(
                out_path,
                center_point_xy=[x_pos, y_pos],
                sketch_image_np=sketch_img,
                image_np=sketch_refinement_img)
            sketch_refinement_exists = SKETCH_REFINEMENT_CLIENT.block_on_callback(
                out_path)
            if not sketch_refinement_exists:
                self.post_server_internal_error('sketch refinement timed out.',
                                                id_str)
                return
            LABELER_CLIENT.set_current_sketch_path(out_path)
            self.post_success(id_str, {
                'refined_sketch':
                interface_utils.get_image_encoding(out_path)
            })
            return

        if 'id' in form:
            id_str = form['id'][0]
            id_str = id_str.decode()
        else:
            id_str = 'test'

        if 'register_download' in form and form['register_download']:
            if 'subid' in form:
                subid_str = form['subid'][0]
                subid_str = subid_str.decode()
            else:
                subid_str = '0'
            print(
                'TODO: do action for register download for id: %s and subid: %s'
                % (id_str, subid_str))
            self.post_success(id_str, )
            return

        elif 'image' in form:
            bin1 = form['image'][0]
            input_image_path = interface_utils.save_encoded_image(
                bin1, './static/images/inputs/' + id_str)

            cropped_image_pattern = './static/images/cropped_faces/' + id_str + '_%d.png'
            faces = FACE_DETECTOR.crop_face_and_save(input_image_path,
                                                     cropped_image_pattern)
            num_faces = len(faces)
            if num_faces > FLAGS.max_num_faces:
                faces = faces[:FLAGS.max_num_faces]
                num_faces = FLAGS.max_num_faces
            if num_faces == 0:
                shutil.copy(
                    input_image_path,
                    './static/images/cropped_faces/' + id_str + '_%d.png' % 0)
                faces = [
                    util_io.imread(input_image_path,
                                   (FLAGS.image_hw, FLAGS.image_hw))
                ]
                num_faces = len(faces)

            transferred_image_file_format = './static/images/transferred_faces/' + id_str + '_%d.png'
            succeed, transferred_image_files = self.automatic_retry(
                functools.partial(
                    self.domain_transfer,
                    transferred_image_file_format=transferred_image_file_format,
                    images=faces))
            if not succeed:
                self.post_server_internal_error('Domain transfer failed',
                                                id_str,
                                                {'num_faces': num_faces})
                return

            if 'do_waifu2x' in form:
                do_waifu2x = form['do_waifu2x'][0] == 'true'
            else:
                do_waifu2x = False
            if do_waifu2x:
                transferred_2x_image_file_format = './static/images/transferred_faces_2x/' + id_str + '_%d.png'
                succeed, transferred_2x_image_files = self.automatic_retry(
                    functools.partial(self.call_waifu2x,
                                      transferred_image_file_format=
                                      transferred_image_file_format,
                                      transferred_2x_image_file_format=
                                      transferred_2x_image_file_format,
                                      num_images=num_faces))
                if not succeed:
                    self.post_server_internal_error('Waifu2x failed', id_str,
                                                    {'num_faces': num_faces})
                    return
                transferred_image_to_be_combined_format = transferred_2x_image_file_format
            else:
                transferred_image_to_be_combined_format = transferred_image_file_format

            combined_image_pattern = './static/images/combined/' + id_str + '_%d.png'
            succeed, combined_images = self.automatic_retry(
                functools.partial(
                    self.combine_original_and_transferred,
                    images=faces,
                    combined_image_pattern=combined_image_pattern,
                    transferred_2x_image_file_format=
                    transferred_image_to_be_combined_format))
            if not succeed:
                self.post_server_internal_error(
                    'Combine original and transferred failed.', id_str,
                    {'num_faces': num_faces})
                return

            self.post_success(id_str, {'num_faces': num_faces})
        else:
            self.post_bad_request('Post request must contain image.', id_str)
        return
Example #16
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  inference_class = mask_inference if FLAGS.detect_masks else detection_inference
  if not os.path.exists(FLAGS.output_path):
    tf.gfile.MakeDirs(FLAGS.output_path)

  required_flags = ['input_images', 'output_path',
                    'inference_graph']
  for flag_name in required_flags:
    if not getattr(FLAGS, flag_name):
      raise ValueError('Flag --{} is required'.format(flag_name))

  sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, ))

  input_image_paths = []
  for v in FLAGS.input_images.split(','):
    if v:
      input_image_paths += tf.gfile.Glob(v)
  tf.logging.info('Reading input from %d files', len(input_image_paths))
  image_ph, image_tensor = build_input()

  tf.logging.info('Reading graph and building model...')
  detected_tensors = inference_class.build_inference_graph(
    image_tensor, FLAGS.inference_graph, override_num_detections=FLAGS.override_num_detections)

  tf.logging.info('Running inference and writing output to {}'.format(
    FLAGS.output_path))
  sess.run(tf.local_variables_initializer())

  for i, image_path in enumerate(input_image_paths):
    image_np = util_io.imread(image_path)
    result = inference_class.infer_detections(
      sess, image_tensor, detected_tensors,
      min_score_thresh=FLAGS.min_score_thresh,
      visualize_inference=FLAGS.visualize_inference,
      feed_dict={image_ph: image_np})

    if FLAGS.output_cropped_image:
      if FLAGS.only_output_cropped_single_object and len(result["detection_score"]) == 1:
        num_outputs = 1
      else:
        num_outputs = len(result["detection_score"])

      for crop_i in range(0, num_outputs):
        if (result["detection_score"])[crop_i] > FLAGS.min_score_thresh:
          base, ext = os.path.splitext(os.path.basename(image_path))
          output_crop = os.path.join(FLAGS.output_path, base + '_crop_%d.png' %crop_i)
          idims = image_np.shape  # np array with shape (height, width, num_color(1, 3, or 4))
          min_x = int(min(round(result["detection_bbox_xmin"][crop_i] * idims[1]), idims[1]))
          max_x = int(min(round(result["detection_bbox_xmax"][crop_i] * idims[1]), idims[1]))
          min_y = int(min(round(result["detection_bbox_ymin"][crop_i] * idims[0]), idims[0]))
          max_y = int(min(round(result["detection_bbox_ymax"][crop_i] * idims[0]), idims[0]))
          image_cropped = image_np[min_y:max_y, min_x:max_x, :]
          util_io.imsave(output_crop, image_cropped)

    if FLAGS.visualize_inference:
      output_image = os.path.join(FLAGS.output_path, os.path.basename(image_path))
      util_io.imsave(output_image, result['annotated_image'])
      del result['annotated_image']  # No need to write the image to json.
    if FLAGS.detect_masks:
      base, ext = os.path.splitext(os.path.basename(image_path))
      for mask_i in range(len(result['detected_masks'])):
        # Stores as png to preserve accurate mask values.
        output_mask = os.path.join(FLAGS.output_path, base + '_mask_%d' % mask_i + '.png')
        util_io.imsave(output_mask, np.array(result['detected_masks'][mask_i]) * 255)
      del result['detected_masks']  # Storing mask in json is pretty space consuming.

    output_file = os.path.join(FLAGS.output_path, os.path.splitext(os.path.basename(image_path))[0] + '.json')
    with open(output_file, 'w') as f:
      json.dump(result, f)

    tf.logging.log_every_n(tf.logging.INFO, 'Processed %d/%d images...', 10, i, len(input_image_paths))

  print('Finished processing all images in data set.')