def valid(self):
        """Run one validation pass over the clips in ``c.VALID_DIR_CLIPS``.

        Every batch yielded by ``Clip_Iterator.sample_valid`` is split along
        axis 1 into ``self._in_seq`` input frames followed by
        ``self._out_seq`` ground-truth frames.  The batch is passed through
        ``self.g_model.valid_step``, the prediction is handed to an
        ``Evaluator`` created for ``self.global_step``, and the returned
        mse / mae / gdl values are logged.

        Raises:
            NotImplementedError: if ``c.IN_CHANEL`` is neither 1 nor 3.
        """
        clip_source = Clip_Iterator(c.VALID_DIR_CLIPS)
        evaluator = Evaluator(self.global_step)

        for batch_idx, clip in enumerate(clip_source.sample_valid(self._batch)):
            # Frames [_in_seq, _in_seq + _out_seq) on axis 1 are the targets.
            target = slice(self._in_seq, self._in_seq + self._out_seq)

            in_data = clip[:, :self._in_seq, ...]
            if c.IN_CHANEL == 1:
                gt_data = clip[:, target, ...]
            elif c.IN_CHANEL == 3:
                # Drop the first and last entry of the last axis for the
                # ground truth.
                gt_data = clip[:, target, :, :, 1:-1]
            else:
                raise NotImplementedError

            if c.NORMALIZE:
                in_data, gt_data = (normalize_frames(in_data),
                                    normalize_frames(gt_data))

            mse, mae, gdl, pred = self.g_model.valid_step(in_data, gt_data)
            evaluator.evaluate(gt_data, pred)

            report = (f"Iter {self.global_step} {batch_idx}: \n\t "
                      f"mse:{mse:.4f} \n\t "
                      f"mae:{mae:.4f} \n\t "
                      f"gdl:{gdl:.4f}")
            self.logger.info(report)

        evaluator.done()
Beispiel #2
0
    def run_benchmark(self, iter, mode="Valid"):
        """Evaluate the model over a configured time interval and save frames.

        Batches are drawn sequentially from an ``Iterator`` until it reports
        ``use_up`` or returns a list instead of an array.  Each batch is
        split into ``c.IN_SEQ`` input frames and ``c.OUT_SEQ`` ground-truth
        frames, run through ``self.model.valid_step`` and scored by an
        ``Evaluator``.  Whenever the (already incremented) batch counter is a
        multiple of the save stride, the input, predicted and ground-truth
        frames of every sample in the batch are written with ``save_png``.

        Args:
            iter: Identifier passed to ``Evaluator`` and used (via ``str``)
                as a sub-directory name for the saved images.  The name
                shadows the builtin but is kept for caller compatibility.
            mode: ``"Valid"`` uses ``c.RAINY_VALID`` / ``c.SAVE_VALID`` and
                saves every 20th counter value; any other value uses
                ``c.RAINY_TEST`` / ``c.SAVE_TEST`` and saves every batch.

        Raises:
            NotImplementedError: if ``c.IN_CHANEL`` is neither 1 nor 3.
        """
        if mode == "Valid":
            time_interval = c.RAINY_VALID
            stride = 20
        else:
            time_interval = c.RAINY_TEST
            stride = 1
        # NOTE: `stride` above only controls how often results are saved;
        # the iterator itself always advances with stride=1.
        test_iter = Iterator(time_interval=time_interval,
                             sample_mode="sequent",
                             seq_len=c.IN_SEQ + c.OUT_SEQ,
                             stride=1)
        evaluator = Evaluator(iter)
        i = 1
        while not test_iter.use_up:
            data, date_clip, *_ = test_iter.sample(batch_size=c.BATCH_SIZE)
            # A list (rather than an array) is the iterator's "no more data"
            # signal; check it before allocating the batch buffers.
            if isinstance(data, list):
                break

            # Fixed-shape buffers: the assignment below fails loudly if the
            # sampled batch does not match the configured dimensions.
            in_data = np.zeros(shape=(c.BATCH_SIZE, c.IN_SEQ, c.H, c.W, c.IN_CHANEL))
            gt_data = np.zeros(shape=(c.BATCH_SIZE, c.OUT_SEQ, c.H, c.W, 1))
            in_data[...] = data[:, :c.IN_SEQ, ...]

            if c.IN_CHANEL == 3:
                # Ground truth is single-channel: keep only the middle entry
                # of the last axis.
                gt_data[...] = data[:, c.IN_SEQ:c.IN_SEQ + c.OUT_SEQ, :, :, 1:-1]
            elif c.IN_CHANEL == 1:
                gt_data[...] = data[:, c.IN_SEQ:c.IN_SEQ + c.OUT_SEQ, ...]
            else:
                raise NotImplementedError

            if c.NORMALIZE:
                in_data = normalize_frames(in_data)
                gt_data = normalize_frames(gt_data)

            mse, mae, gdl, pred = self.model.valid_step(in_data, gt_data)
            evaluator.evaluate(gt_data, pred)
            logging.info(f"Iter {iter} {i}: \n\t mse:{mse} \n\t mae:{mae} \n\t gdl:{gdl}")
            i += 1
            if i % stride == 0:
                if c.IN_CHANEL == 3:
                    # Save only the middle channel of the inputs, matching
                    # the single-channel prediction / ground truth.
                    in_data = in_data[:, :, :, :, 1:-1]

                save_root = c.SAVE_VALID if mode == "Valid" else c.SAVE_TEST
                for b in range(c.BATCH_SIZE):
                    # Timestamp at index IN_SEQ of the clip names the
                    # output directory for this sample.
                    predict_date = date_clip[b][c.IN_SEQ]
                    logging.info(f"Save {predict_date} results")
                    save_path = os.path.join(save_root, str(iter),
                                             predict_date.strftime("%Y%m%d%H%M"))

                    save_png(in_data[b], os.path.join(save_path, "in"))
                    save_png(pred[b], os.path.join(save_path, "pred"))
                    save_png(gt_data[b], os.path.join(save_path, "out"))
        evaluator.done()
Beispiel #3
0
    def valid_clips(self, step):
        """Validate the model on the clips stored in ``c.VALID_DIR_CLIPS``.

        Each batch from ``Clip_Iterator.sample_valid`` is split along axis 1
        into ``c.IN_SEQ`` input frames and the following ``c.OUT_SEQ``
        ground-truth frames, evaluated with ``self.model.valid_step`` and
        accumulated in an ``Evaluator``; the returned losses are logged.

        Args:
            step: Identifier passed to ``Evaluator`` and printed in the log
                line of every batch.

        Raises:
            NotImplementedError: if ``c.IN_CHANEL`` is neither 1 nor 3.
        """
        clip_source = Clip_Iterator(c.VALID_DIR_CLIPS)
        evaluator = Evaluator(step)

        for batch_idx, clip in enumerate(clip_source.sample_valid(c.BATCH_SIZE)):
            # Frames [IN_SEQ, IN_SEQ + OUT_SEQ) on axis 1 are the targets.
            target = slice(c.IN_SEQ, c.IN_SEQ + c.OUT_SEQ)

            in_data = clip[:, :c.IN_SEQ, ...]
            if c.IN_CHANEL == 1:
                gt_data = clip[:, target, ...]
            elif c.IN_CHANEL == 3:
                # Drop the first and last entry of the last axis for the
                # ground truth.
                gt_data = clip[:, target, :, :, 1:-1]
            else:
                raise NotImplementedError

            if c.NORMALIZE:
                in_data, gt_data = (normalize_frames(in_data),
                                    normalize_frames(gt_data))

            mse, mae, gdl, pred = self.model.valid_step(in_data, gt_data)
            evaluator.evaluate(gt_data, pred)

            report = f"Iter {step} {batch_idx}: \n\t mse:{mse} \n\t mae:{mae} \n\t gdl:{gdl}"
            logging.info(report)

        evaluator.done()
    def run_benchmark(self, iter, mode="Test"):
        """Evaluate the generator over a configured time interval and save frames.

        Batches are drawn sequentially from an ``Iterator`` until it reports
        ``use_up`` or returns a list instead of an array.  Each batch is
        split into ``self._in_seq`` input frames and ``self._out_seq``
        ground-truth frames, optionally normalized, cropped with
        ``crop_img``, run through ``self.g_model.valid_step`` and scored by
        an ``Evaluator``.  Whenever the (already incremented) batch counter
        is a multiple of the save stride, the input, predicted and
        ground-truth frames of every sample are written with ``save_png``.
        After the loop the evaluator is finalized and
        ``self.notifier.eval`` is called with the evaluator's result path.

        Args:
            iter: Identifier passed to ``Evaluator`` and the notifier, and
                used (via ``str``) as a sub-directory name for saved images.
                The name shadows the builtin but is kept for caller
                compatibility.
            mode: ``"Valid"`` uses ``c.RAINY_VALID`` / ``c.SAVE_VALID`` and
                saves every 5th counter value; any other value uses
                ``c.RAINY_TEST`` / ``c.SAVE_TEST`` and saves every batch.

        Raises:
            NotImplementedError: if ``c.IN_CHANEL`` is neither 1 nor 3.
        """
        if mode == "Valid":
            time_interval = c.RAINY_VALID
            stride = 5
        else:
            time_interval = c.RAINY_TEST
            stride = 1
        # NOTE: `stride` above only controls how often results are saved;
        # the iterator itself always advances with stride=1.
        test_iter = Iterator(time_interval=time_interval,
                             sample_mode="sequent",
                             seq_len=self._in_seq + self._out_seq,
                             stride=1)
        evaluator = Evaluator(iter, length=self._out_seq, mode=mode)
        i = 1
        while not test_iter.use_up:
            data, date_clip, *_ = test_iter.sample(batch_size=self._batch)
            # A list (rather than an array) is the iterator's "no more data"
            # signal; check it before allocating the batch buffers.
            if isinstance(data, list):
                break

            # Fixed-shape buffers: the assignments below fail loudly if the
            # sampled batch does not match the configured dimensions.
            in_data = np.zeros(shape=(self._batch, self._in_seq, self._h,
                                      self._w, c.IN_CHANEL))
            gt_data = np.zeros(shape=(self._batch, self._out_seq, self._h,
                                      self._w, 1))
            in_data[...] = data[:, :self._in_seq, :, :, :]

            target = slice(self._in_seq, self._in_seq + self._out_seq)
            if c.IN_CHANEL == 3:
                # gt_data has a single channel, so a 3-channel slice cannot
                # be broadcast into it.  Take only the middle channel, as
                # the other validation routines do.  (Assumes `data` carries
                # c.IN_CHANEL channels on its last axis.)
                gt_data[...] = data[:, target, :, :, 1:-1]
            elif c.IN_CHANEL == 1:
                gt_data[...] = data[:, target, :, :, :]
            else:
                raise NotImplementedError

            if c.NORMALIZE:
                in_data = normalize_frames(in_data)
                gt_data = normalize_frames(gt_data)
            in_data = crop_img(in_data)
            gt_data = crop_img(gt_data)
            mse, mae, gdl, pred = self.g_model.valid_step(in_data, gt_data)
            evaluator.evaluate(gt_data, pred)
            self.logger.info(
                f"Iter {iter} {i}: \n\t mse:{mse} \n\t mae:{mae} \n\t gdl:{gdl}"
            )
            i += 1
            if i % stride == 0:
                if c.IN_CHANEL == 3:
                    # Save only the middle channel of the inputs, matching
                    # the single-channel prediction / ground truth.
                    in_data = in_data[:, :, :, :, 1:-1]

                save_root = c.SAVE_VALID if mode == "Valid" else c.SAVE_TEST
                for b in range(self._batch):
                    # Timestamp of the last input frame names the output
                    # directory for this sample.
                    predict_date = date_clip[b][self._in_seq - 1]
                    self.logger.info(f"Save {predict_date} results")
                    save_path = os.path.join(
                        save_root, str(iter),
                        predict_date.strftime("%Y%m%d%H%M"))

                    save_png(in_data[b], os.path.join(save_path, "in"))
                    save_png(pred[b], os.path.join(save_path, "pred"))
                    save_png(gt_data[b], os.path.join(save_path, "out"))
        evaluator.done()
        self.notifier.eval(iter, evaluator.result_path)