コード例 #1
0
 def _get_ckpt(self):
     """Return the checkpoint prefix selected by ``self.config.ckpt_id``.

     Each path returned by ``list_getter(self.config.ckpt_dir, 'index')``
     is cut at its first ".index" occurrence; the entry equal to
     './model/checkpoints/model_step-<ckpt_id>' is then looked up and
     returned.

     Raises:
         ValueError: if no listed checkpoint equals the expected name
             (raised by ``list.index``).
     """
     ckpt_prefixes = []
     for index_path in list_getter(self.config.ckpt_dir, 'index'):
         ckpt_prefixes.append(index_path.split(".index")[0])

     # The expected name is built from a fixed relative path, not from
     # self.config.ckpt_dir.
     wanted = './model/checkpoints/model_step-%d' % self.config.ckpt_id
     position = ckpt_prefixes.index(wanted)
     return ckpt_prefixes[position]
コード例 #2
0
    def _input_from_image(self):
        """Build the image input pipeline and expose its tensors on ``self``.

        Lists the "jpg" files under ``self.config.img_dir``. When
        ``self.config.phase == "eval"`` it also lists the "png" files under
        ``self.seg_dir``, validates that both lists pair up, and zips them
        into one dataset parsed by ``self._image_gt_parser``; otherwise the
        images alone are parsed by ``self._image_parser``.

        Sets ``self.input_data``, ``self.gt`` (``None`` outside the "eval"
        phase), ``self.filename`` and ``self.data_init`` (the iterator
        initializer, which the caller has to run).

        Raises:
            ValueError: in the "eval" phase, if the two lists differ in
                length, if paired files have different base names, or if
                either list is empty or mixes several file extensions.
        """
        def file_stem(path):
            # splitext removes only the final extension, so "a.b.jpg" yields
            # "a.b" (split(".")[-2] compared just "b") and an extension-less
            # name no longer raises IndexError.
            return os.path.splitext(os.path.basename(path))[0]

        def inspect_file_extension(target_list):
            extensions = {
                os.path.basename(img_name).split(".")[-1]
                for img_name in target_list
            }
            if len(extensions) > 1:
                # List the offending formats; the message used to end in a
                # dangling colon.
                raise ValueError("Multiple image formats are used: %s" %
                                 ", ".join(sorted(extensions)))
            elif not extensions:
                raise ValueError("no image files exist")

        def inspect_pairness(list1, list2):
            if len(list1) != len(list2):
                raise ValueError("number of images are different")
            for file1, file2 in zip(list1, list2):
                if file_stem(file1) != file_stem(file2):
                    raise ValueError("image names are different: %s | %s" %
                                     (file2, file1))

        img_list = list_getter(self.config.img_dir, "jpg")
        img_list_tensor = tf.convert_to_tensor(img_list, dtype=tf.string)
        img_data = tf.data.Dataset.from_tensor_slices(img_list_tensor)
        if self.config.phase == "eval":
            # NOTE(review): the ground-truth directory is read from
            # self.seg_dir while images come from self.config.img_dir --
            # confirm this asymmetry is intended.
            gt_list = list_getter(self.seg_dir, "png")
            inspect_pairness(gt_list, img_list)
            inspect_file_extension(gt_list)
            inspect_file_extension(img_list)
            gt_list_tensor = tf.convert_to_tensor(gt_list, dtype=tf.string)
            gt_data = tf.data.Dataset.from_tensor_slices(gt_list_tensor)
            data = tf.data.Dataset.zip((img_data, gt_data))
            # Positional arguments: 4 parallel map calls; False keeps the
            # final partial batch (drop_remainder in the tf.data API).
            data = data.map(self._image_gt_parser,
                            4).batch(self.config.batch_size, False)
        else:
            data = img_data.map(self._image_parser,
                                4).batch(self.config.batch_size, False)
        data = data.prefetch(4)  # fixed depth; AUTOTUNE would be an alternative
        iterator = data.make_initializable_iterator()
        dataset = iterator.get_next()
        self.input_data = dataset["input_data"]
        self.gt = dataset["gt"] if self.config.phase == "eval" else None
        self.filename = dataset["filename"]
        self.data_init = iterator.initializer
コード例 #3
0
 def _get_batch_and_init(self, tfrecord_dir, batch_size):
     """Build a repeating, shuffled, batched pipeline over tfrecord files.

     Args:
         tfrecord_dir: directory handed to ``list_getter`` to find files
             with the "tfrecord" extension.
         batch_size: number of parsed records per batch; the shuffle
             buffer holds ``batch_size * 10`` records.

     Returns:
         The ``get_next()`` result of a one-shot iterator over the
         pipeline.

     Raises:
         ValueError: if ``list_getter`` finds no tfrecord file.
     """
     record_files = list_getter(tfrecord_dir, extension="tfrecord")
     if not record_files:
         raise ValueError("tfrecord does not exist: %s" % tfrecord_dir)

     # Stage order matters: repeat -> shuffle -> parse -> batch -> prefetch.
     pipeline = tf.data.TFRecordDataset(record_files).repeat()
     pipeline = pipeline.shuffle(batch_size * 10)
     # 4 is the second positional argument of map (parallel parse calls).
     pipeline = pipeline.map(self._tfrecord_parser, 4)
     pipeline = pipeline.batch(batch_size, self._drop_remainder)
     # Fixed prefetch depth of 4; an AUTOTUNE setting would be an alternative.
     pipeline = pipeline.prefetch(4)
     return pipeline.make_one_shot_iterator().get_next()
コード例 #4
0
    def _get_ckpt_in_range(self):
        """Return the checkpoint prefixes selected by the configured range.

        Each path returned by ``list_getter(self.config.ckpt_dir, 'index')``
        is cut at its first ".index" occurrence. The resulting list is then
        sliced from ``self.config.ckpt_start`` (or the first entry when it
        is 'beginning') through ``self.config.ckpt_end`` inclusive (or the
        last entry when it is 'end'), taking every
        ``self.config.ckpt_step``-th entry.

        Raises:
            ValueError: if a requested start or end checkpoint is not in
                the list (raised by ``list.index``).
        """
        ckpt_prefixes = []
        for index_path in list_getter(self.config.ckpt_dir, 'index'):
            ckpt_prefixes.append(index_path.split(".index")[0])
        name_template = './model/checkpoints/model_step-%d'

        first = (0 if self.config.ckpt_start == 'beginning' else
                 ckpt_prefixes.index(name_template % self.config.ckpt_start))

        # +1 makes the configured end checkpoint part of the slice.
        last = (None if self.config.ckpt_end == 'end' else
                ckpt_prefixes.index(name_template % self.config.ckpt_end) + 1)

        return ckpt_prefixes[first:last:self.config.ckpt_step]
コード例 #5
0
 def _start_train(self, hvd, sess):
     """Initialise variables, optionally resume, then start training.

     Runs the global and local variable initialisers. If
     ``list_getter(self.config.ckpt_dir, 'index')`` returns any entry, the
     last one (cut at ".index") is restored into ``sess``. Afterwards
     ``hvd.broadcast_global_variables(0)`` is run and control passes to
     ``self._train_step``.

     Args:
         hvd: object providing ``broadcast_global_variables`` (presumably
             the Horovod TensorFlow module -- confirm against the caller).
         sess: session used for initialisation, restore and training.
     """
     default_graph = tf.get_default_graph()
     saver = tf.train.Saver(max_to_keep=5000)
     with default_graph.as_default() as graph:
         init_fn = tf.group(tf.global_variables_initializer(),
                            tf.local_variables_initializer())
         ckpt_prefixes = []
         for index_path in list_getter(self.config.ckpt_dir, 'index'):
             ckpt_prefixes.append(index_path.split(".index")[0])

         # Initialise first so that a restore overwrites the fresh values.
         sess.run(init_fn)
         if not ckpt_prefixes:
             print('Training will be started from scratch...')
         else:
             # An existing checkpoint is taken to mean that training
             # should resume from it.
             print('Training will be continued from the last checkpoint...')
             saver.restore(sess, ckpt_prefixes[-1])
             print('The last checkpoint is loaded!')
         sess.run(hvd.broadcast_global_variables(0))
         self._train_step(graph, sess, saver)
コード例 #6
0
 def _vis_with_video(self, sess):
     """Write a prediction-overlay video for each input video.

     For every "avi"/"mp4" file that ``list_getter`` finds under
     ``self.config.img_dir``, each frame is fed through ``self.pred``,
     combined with the frame by ``self._superimpose`` and written as uint8
     to ``<self.config.vis_result_dir>/<source name>.avi`` using the XVID
     codec, at the source frame rate and frame size.

     A video whose first frame cannot be read is skipped with a printed
     message. The capture and the writer are released even when
     processing fails part-way.

     Args:
         sess: session used to evaluate ``self.pred``.
     """
     vid_list = list_getter(self.config.img_dir, ("avi", "mp4"))
     for vid_name in vid_list:
         vid = VideoCapture(vid_name)
         vid_out = None
         try:
             # Property id 5 -- presumably cv2.CAP_PROP_FPS; confirm that
             # VideoCapture is the OpenCV class.
             fps = round(vid.get(5))
             should_continue, frame = vid.read()
             if not should_continue:
                 # Without this guard frame is None and frame.shape raises.
                 print("Skipping unreadable or empty video: %s" % vid_name)
                 continue

             basename = os.path.splitext(os.path.basename(vid_name))[0]
             dst_name = os.path.join(self.config.vis_result_dir,
                                     basename + ".avi")
             h, w, _ = frame.shape
             # The writer takes the frame size as (width, height).
             vid_out = VideoWriter(dst_name, VideoWriter_fourcc(*"XVID"),
                                   fps, (w, h))
             while should_continue:
                 pred = sess.run(
                     self.pred, {self.input_data: np.expand_dims(frame, 0)})
                 superimposed = self._superimpose(frame, pred)
                 vid_out.write(superimposed.astype(np.uint8))
                 should_continue, frame = vid.read()
         finally:
             vid.release()
             if vid_out is not None:
                 vid_out.release()