def _get_ckpt(self):
    # Return the checkpoint path whose global step matches config.ckpt_id.
    all_ckpt_list = [_.split(".index")[0]
                     for _ in list_getter(self.config.ckpt_dir, 'index')]
    ckpt_pattern = './model/checkpoints/model_step-%d'
    return all_ckpt_list[all_ckpt_list.index(ckpt_pattern % self.config.ckpt_id)]
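# `list_getter` is used throughout this section but defined elsewhere in the repo.
# A minimal sketch of the assumed behavior (collect files under a directory whose
# names end with the given extension(s), sorted); the actual helper may differ:
def list_getter(target_dir, extension):
    file_list = []
    for root, _, files in os.walk(target_dir):
        for filename in files:
            if filename.endswith(extension):  # `extension` may be a str or a tuple
                file_list.append(os.path.join(root, filename))
    file_list.sort()
    return file_list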
def _input_from_image(self):
    def inspect_file_extension(target_list):
        extensions = list(set(
            os.path.basename(img_name).split(".")[-1] for img_name in target_list))
        if len(extensions) > 1:
            raise ValueError("Multiple image formats are used: %s" % extensions)
        elif len(extensions) == 0:
            raise ValueError("No image files exist")

    def inspect_pairness(list1, list2):
        if not len(list1) == len(list2):
            raise ValueError("Numbers of images and ground truths are different")
        for file1, file2 in zip(list1, list2):
            file1_name = os.path.basename(file1).split(".")[-2]
            file2_name = os.path.basename(file2).split(".")[-2]
            if not file1_name == file2_name:
                raise ValueError("Image names are different: %s | %s" % (file2, file1))

    img_list = list_getter(self.config.img_dir, "jpg")
    img_list_tensor = tf.convert_to_tensor(img_list, dtype=tf.string)
    img_data = tf.data.Dataset.from_tensor_slices(img_list_tensor)
    if self.config.phase == "eval":
        gt_list = list_getter(self.seg_dir, "png")
        inspect_pairness(gt_list, img_list)
        inspect_file_extension(gt_list)
        inspect_file_extension(img_list)
        gt_list_tensor = tf.convert_to_tensor(gt_list, dtype=tf.string)
        gt_data = tf.data.Dataset.from_tensor_slices(gt_list_tensor)
        data = tf.data.Dataset.zip((img_data, gt_data))
        data = data.map(self._image_gt_parser, 4).batch(self.config.batch_size, False)
    else:
        data = img_data.map(self._image_parser, 4).batch(self.config.batch_size, False)
    data = data.prefetch(4)  # could use tf.data.experimental.AUTOTUNE
    iterator = data.make_initializable_iterator()
    dataset = iterator.get_next()
    self.input_data = dataset["input_data"]
    self.gt = dataset["gt"] if self.config.phase == "eval" else None
    self.filename = dataset["filename"]
    self.data_init = iterator.initializer
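# The map functions above are assumed to return a dict with the keys consumed after
# `iterator.get_next()` ("input_data", "gt", "filename"). A minimal sketch of the
# image-only parser under that assumption; it is defined elsewhere in the repo, and
# the real decoding/preprocessing (e.g. resizing, normalization) may differ:
def _image_parser(self, img_name):
    img_string = tf.read_file(img_name)
    img = tf.image.decode_jpeg(img_string, channels=3)
    return {"input_data": tf.cast(img, tf.float32),
            "filename": img_name}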
def _get_batch_and_init(self, tfrecord_dir, batch_size):
    # Build a repeating, shuffled training pipeline from tfrecord files.
    tfrecord_list = list_getter(tfrecord_dir, extension="tfrecord")
    if not tfrecord_list:
        raise ValueError("tfrecord does not exist: %s" % tfrecord_dir)
    data = tf.data.TFRecordDataset(tfrecord_list)
    data = data.repeat()
    data = data.shuffle(batch_size * 10)
    data = data.map(self._tfrecord_parser, 4).batch(batch_size, self._drop_remainder)
    data = data.prefetch(4)  # could use tf.data.experimental.AUTOTUNE
    iterator = data.make_one_shot_iterator()
    return iterator.get_next()
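# `_tfrecord_parser` is defined elsewhere in the repo. A minimal sketch of a parser
# compatible with the pipeline above, assuming each example stores an encoded image
# and an encoded ground-truth mask (the feature names and decoding are hypothetical;
# the real feature spec may differ):
def _tfrecord_parser(self, serialized_example):
    features = tf.parse_single_example(
        serialized_example,
        features={
            "image": tf.FixedLenFeature([], tf.string),
            "gt": tf.FixedLenFeature([], tf.string),
        })
    image = tf.image.decode_jpeg(features["image"], channels=3)
    gt = tf.image.decode_png(features["gt"], channels=1)
    return {"input_data": tf.cast(image, tf.float32), "gt": gt}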
def _get_ckpt_in_range(self):
    all_ckpt_list = [_.split(".index")[0]
                     for _ in list_getter(self.config.ckpt_dir, 'index')]
    ckpt_pattern = './model/checkpoints/model_step-%d'
    if self.config.ckpt_start == 'beginning':
        start_idx = 0
    else:
        start_idx = all_ckpt_list.index(ckpt_pattern % self.config.ckpt_start)
    if self.config.ckpt_end == 'end':
        end_idx = None
    else:
        end_idx = all_ckpt_list.index(ckpt_pattern % self.config.ckpt_end) + 1
    return all_ckpt_list[start_idx:end_idx:self.config.ckpt_step]
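# Illustrative sketch of how `_get_ckpt_in_range` would typically be consumed during
# evaluation (hypothetical method name and loop; not the repo's actual code):
def _eval_ckpt_range_sketch(self, sess):
    saver = tf.train.Saver()
    for ckpt in self._get_ckpt_in_range():
        saver.restore(sess, ckpt)  # load weights for this checkpoint
        # ... run evaluation for the restored checkpoint ...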
def _start_train(self, hvd, sess):
    graph = tf.get_default_graph()
    saver = tf.train.Saver(max_to_keep=5000)
    with graph.as_default() as graph:
        global_init_fn = tf.global_variables_initializer()
        local_init_fn = tf.local_variables_initializer()
        init_fn = tf.group(global_init_fn, local_init_fn)
        all_ckpt_list = [_.split(".index")[0]
                         for _ in list_getter(self.config.ckpt_dir, 'index')]
        sess.run(init_fn)
        if all_ckpt_list:  # resume training from the latest checkpoint if one exists
            print('Training will be continued from the last checkpoint...')
            saver.restore(sess, all_ckpt_list[-1])
            print('The last checkpoint is loaded!')
        else:
            print('Training will be started from scratch...')
        sess.run(hvd.broadcast_global_variables(0))
        self._train_step(graph, sess, saver)
def _vis_with_video(self, sess):
    vid_list = list_getter(self.config.img_dir, ("avi", "mp4"))
    for vid_name in vid_list:
        vid = VideoCapture(vid_name)
        fps = round(vid.get(5))  # property id 5 == cv2.CAP_PROP_FPS
        should_continue, frame = vid.read()
        basename = os.path.basename(vid_name)[:-4]
        dst_name = self.config.vis_result_dir + "/" + basename + ".avi"
        h, w, _ = frame.shape
        pred = sess.run(self.pred, {self.input_data: np.expand_dims(frame, 0)})
        superimposed = self._superimpose(frame, pred)
        vid_out = VideoWriter(dst_name, VideoWriter_fourcc(*"XVID"), fps, (w, h))
        vid_out.write(superimposed.astype(np.uint8))
        while should_continue:
            should_continue, frame = vid.read()
            if should_continue:
                pred = sess.run(self.pred,
                                {self.input_data: np.expand_dims(frame, 0)})
                superimposed = self._superimpose(frame, pred)
                vid_out.write(superimposed.astype(np.uint8))
        vid_out.release()
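# `_superimpose` is used above but defined elsewhere in the repo. A minimal sketch of
# the assumed behavior (blend a colorized prediction mask onto the original frame);
# the coloring and blending weights here are illustrative only:
def _superimpose(self, frame, pred):
    pred = np.squeeze(pred).astype(np.uint8)
    color_mask = np.zeros_like(frame)
    color_mask[..., 2] = np.where(pred > 0, 255, 0)  # hypothetical foreground color
    return 0.6 * frame + 0.4 * color_mask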