def Run_QuadTreeSR(self, patch_size):
        """Super-resolve every frame using quadtree-partitioned TempoGAN patches.

        For each frame in ``range(self.start_frame, self.total_frame)``:

        - The matching entry of ``self.input_data`` is reshaped to
          ``(simSizeLow, simSizeLow, 4)`` float32.
        - It is split into patch groups by ``Qtree_GPU``.
        - Every non-empty group is run through ``self.tempoGAN``.
        - The partial results are summed into a
          ``(4 * simSizeLow, 4 * simSizeLow, 1)`` canvas.
        - The canvas is clipped to [0, 255], flipped vertically and written to
          ``result/Octempo_<simSizeLow>_<patch_size>/<frame + 1:05d>.jpg``.

        Timing is written to
        ``result/time_check/Quadtempo_<simSizeLow>_<patch_size>.txt``:

        - The first entry is ``0 <seconds>``, covering graph construction and
          checkpoint restore.
        - It is followed by one ``<frame + 1> <seconds>`` line per frame.

        Args:
            patch_size: Forwarded to ``Gen_Maxmatrix`` and ``Qtree_GPU``.
                Also stored on ``self.patch_size`` and used in the output
                and log file names.

        Side effects:
            - Sets ``self.save_dir``, ``self.x``, ``self.keep_prob``,
              ``self.train``, ``self.sampler_0`` .. ``self.sampler_4``,
              ``self.QT`` and ``self.sess``.
            - Once the TF session has been created, it is closed before
              this method returns, including when an exception is raised.
        """
        self.patch_size = patch_size
        self.save_dir = "result/Octempo_" + str(
            self.simSizeLow) + "_" + str(patch_size)
        self.make_folder(self.save_dir)
        time_log_path = ("result/time_check/Quadtempo_" +
                         str(self.simSizeLow) + "_" + str(patch_size) +
                         ".txt")

        # `with` guarantees the timing log is flushed and closed even if
        # graph construction, checkpoint restore or a frame raises.
        with open(time_log_path, 'w') as f_time:
            pre_process_time_start = timeit.default_timer()

            ph.checkUnusedParams()

            modelPath = "model/model_0199_final.ckpt"

            # Matrices built from the per-patch maximum values.
            # They are indexed with the same counter as self.input_data below.
            Max_Matrix_list = self.Gen_Maxmatrix(patch_size)

            self.x = tf.placeholder(tf.float32, shape=[None, None])
            self.keep_prob = tf.placeholder(tf.float32)
            self.train = tf.placeholder(tf.bool)

            # Build one generator per patch level: self.sampler_0 .. self.sampler_4.
            # Level 0 uses the full low-res size.
            # Level k >= 1 uses simSizeLow // 2**k + 2 * self.overlap.
            # The size is computed inside the loop, after the previous
            # set_parameter/gen_resnet call, matching the original statement order.
            for level in range(5):
                if level == 0:
                    level_size = self.simSizeLow
                else:
                    level_size = (self.simSizeLow // (2 ** level) +
                                  self.overlap * 2)
                self.set_parameter(level_size)
                setattr(self, "sampler_%d" % level,
                        self.gen_resnet(self.x,
                                        level,
                                        reuse=False,
                                        use_batch_norm=True))

            # Load the TempoGAN model.
            config = tf.ConfigProto(allow_soft_placement=True)
            self.sess = tf.Session(config=config)
            try:
                saver = tf.train.Saver()
                saver.restore(self.sess, modelPath)

                pre_process_time = (timeit.default_timer() -
                                    pre_process_time_start)
                f_time.write("0" + " " + str(pre_process_time))

                frames = range(self.start_frame, self.total_frame)
                # input_index counts processed frames from 0.
                # It indexes self.input_data and Max_Matrix_list.
                for input_index, frame in enumerate(frames):
                    print("frame = ", frame)
                    one_frame_start_time = timeit.default_timer()

                    data = self.input_data[input_index]
                    data = np.reshape(data,
                                      (self.simSizeLow, self.simSizeLow, 4))
                    data = np.ascontiguousarray(data, dtype=np.float32)

                    self.QT = Qtree_GPU(data, Max_Matrix_list[input_index],
                                        patch_size)

                    # Receive the patch data and their keys, grouped by patch size.
                    (key0, data0, key1, data1, key2, data2, key3, data3,
                     key4, data4, key_t,
                     data_t) = self.QT.set_data_quadtree()
                    patch_groups = [
                        (key0, np.array(data0)),
                        (key1, np.array(data1)),
                        (key2, np.array(data2)),
                        (key3, np.array(data3)),
                        (key4, np.array(data4)),
                        (key_t, np.array(data_t)),
                    ]

                    # Canvas that accumulates the per-group results.
                    final_result = np.zeros(
                        [self.simSizeLow * 4, self.simSizeLow * 4, 1],
                        dtype=np.float32)

                    # Super-resolve each patch group that has data and add
                    # its contribution to the final result.
                    for keys, patches in patch_groups:
                        if patches.shape[0] > 0:
                            self.set_parameter(patches.shape[1])
                            final_result += self.tempoGAN(patches, keys)

                    # Save the result as an 8-bit image.
                    # cv2.flip with flipCode 0 flips around the x-axis.
                    final_result = np.uint8(
                        np.clip(final_result * 255, 0, 255))
                    final_result = cv2.flip(final_result, 0)
                    cv2.imwrite(
                        self.save_dir + "/%05d" % (frame + 1) + ".jpg",
                        final_result)
                    one_frame_time = (timeit.default_timer() -
                                      one_frame_start_time)
                    f_time.write("\n" + str(frame + 1) + " " +
                                 str(one_frame_time))
            finally:
                self.sess.close()
Ejemplo n.º 2
0
# Override the script defaults with command-line values read through ph.getParam.
# basePath, npSeedstr, resetN, dim, showGui, res, scaleFactor and steps must
# already be bound before this point. Their current values are the fallbacks.
basePath = ph.getParam("basepath", basePath)
npSeedstr = ph.getParam("seed", npSeedstr)
npSeed = int(npSeedstr)
resetN, dim = (
    int(ph.getParam(option, fallback))
    for option, fallback in (("reset", resetN), ("dim", dim))
)
simMode = int(ph.getParam("mode", 1))  # 1 = double sim, 2 = wlt
# Output switches: any integer value > 0 enables the writer.
# The fallback is False, i.e. 0, so each writer is off unless requested.
savenpz, saveuni, saveppm = (
    int(ph.getParam(switch, False)) > 0
    for switch in ("savenpz", "saveuni", "saveppm")
)
showGui, res, scaleFactor, steps = (
    int(ph.getParam(option, fallback))
    for option, fallback in (
        ("gui", showGui),
        ("res", res),
        ("fac", scaleFactor),
        ("steps", steps),
    )
)
# Number of steps to skip at the beginning of the run.
timeOffset = int(ph.getParam("warmup", 20))
ph.checkUnusedParams()

doRecenter = False  # density re-centering; disabled for now

setDebugLevel(1)
# Normalize basePath so that it always ends with a path separator.
if not basePath.endswith("/"):
    basePath += "/"

if savenpz or saveuni or saveppm:
    folderNo = simNo
    simPath, simNo = ph.getNextSimPath(simNo, basePath)

    # add some more info for json file
    ph.paramDict["simNo"] = simNo
    ph.paramDict["type"] = "smoke"
    ph.paramDict["name"] = "gen6combined"
    ph.paramDict["version"] = printBuildInfo()