示例#1
0
    def upload_prediction(self, prediction):
        """Accept the prediction for the pending batch and feed it to the evaluator.

        Clears the pending-upload flag, optionally dumps the input,
        ground-truth and predicted frames of one selected sequence as
        videos, advances the received-sequence counter and finally updates
        the evaluator with the whole batch.

        Parameters
        ----------
        prediction : np.ndarray
            Predicted frames. Axis 1 indexes the sequences of the batch: its
            length is added to the received-sequence counter, and a single
            sequence is selected with ``prediction[:, ind, 0, ...]`` when
            videos are saved. Presumably laid out as
            (seq_len, batch, 1, height, width) -- TODO confirm with caller.

        Raises
        ------
        AssertionError
            If no prediction upload is pending, if more than one sequence of
            the batch is marked for saving, or if a new episode is flagged
            in "online" mode with a stride of 1.
        """
        assert self._need_upload_prediction, "Must call get_observation first!" \
                                             " Also, check the value of need_upload_prediction" \
                                             " after calling"
        self._need_upload_prediction = False
        batch_size = prediction.shape[1]
        # Global indices of the sequences contained in this batch.
        received_seq_inds = range(
            self._received_pred_seq_num,
            self._received_pred_seq_num + batch_size)
        save_ind_set = set(received_seq_inds).intersection(self._save_seq_inds)
        if save_ind_set:
            assert len(save_ind_set) == 1
            # Convert the global sequence index into a position in the batch.
            ind = save_ind_set.pop() - self._received_pred_seq_num
            out_dir = os.path.join(self._save_dir, self._fingerprint)
            # exist_ok avoids the check-then-create race of a separate
            # os.path.exists() test.
            os.makedirs(out_dir, exist_ok=True)
            print("Saving prediction videos to %s" % out_dir)
            # All three videos share the timestamp of the first input frame
            # so that they sort next to each other on disk.
            stamp = self._in_datetime_clips[ind][0].strftime('%Y%m%d%H%M')
            save_hko_movie(
                im_dat=self._in_frame_dat[:, ind, 0, ...],
                mask_dat=self._in_mask_dat[:, ind, 0, :, :],
                datetime_list=self._in_datetime_clips[ind],
                save_path=os.path.join(out_dir, "%s_in.mp4" % stamp))
            save_hko_movie(
                im_dat=self._out_frame_dat[:, ind, 0, ...],
                mask_dat=self._out_mask_dat[:, ind, 0, :, :],
                masked=False,
                datetime_list=self._out_datetime_clips[ind],
                save_path=os.path.join(out_dir, "%s_out.mp4" % stamp))
            save_hko_movie(
                im_dat=prediction[:, ind, 0, ...],
                mask_dat=self._out_mask_dat[:, ind, 0, :, :],
                masked=False,
                datetime_list=self._out_datetime_clips[ind],
                save_path=os.path.join(out_dir, "%s_pred.mp4" % stamp))
        self._received_pred_seq_num += batch_size
        if self._mode == "online" and self._stride == 1:
            assert not self._begin_new_episode
        # Score the full batch; each clip is keyed by its first output time.
        self._all_eval.update(
            gt=self._out_frame_dat,
            pred=prediction,
            mask=self._out_mask_dat,
            start_datetimes=[ele[0] for ele in self._out_datetime_clips])
示例#2
0
    hko_iter.sample(batch_size=1)

# Rescale pixel values from [0, 255] to [0, 1], then split the clip into the
# conditioning frames and the frames that have to be forecast.
valid_batch = valid_batch.astype(np.float32) / 255.0
_forecast_span = np.s_[IN_LEN:IN_LEN + OUT_LEN, ...]
valid_data = valid_batch[:IN_LEN]
valid_label = valid_batch[_forecast_span]
mask = valid_mask[_forecast_span].astype(int)
torch_valid_data = torch.from_numpy(valid_data).to(cfg.GLOBAL.DEVICE)

# Generate the prediction; no gradients are needed for inference.
with torch.no_grad():
    output = encoder_forecaster(torch_valid_data)

output = output.cpu().numpy().clip(0.0, 1.0)

base_dir = '.'
# Arrays are laid out as S*B*1*H*W: keep the first sample and its only channel.
_first_sample = np.s_[:, 0, 0, :, :]
label = valid_label[_first_sample]
output = output[_first_sample]
mask = mask[_first_sample].astype(np.uint8)
# Render the ground truth and the prediction with save_hko_movie so that the
# two animations can be compared side by side.
for _frames, _movie_name in ((label, 'ground_truth.mp4'),
                             (output, 'pred.mp4')):
    save_hko_movie(_frames,
                   sample_datetimes[0],
                   mask,
                   masked=True,
                   save_path=os.path.join(base_dir, _movie_name))
示例#3
0
    end = time.time()
    print("Test Data Sample FPS: %f" %(minibatch_size * seq_len
                                       * repeat_time / float(end-begin)))
    code = encode_month(np.arange(1, 13))
    month = decode_month(code)
    print(code)
    print(month.T)

    train_time = 0
    for i in range(30):
        train_batch, train_mask, sample_datetimes, new_start = \
            train_hko_iter.sample(batch_size=minibatch_size)
        name_str = 'train_' + str(i) + '_' + sample_datetimes[0][0].strftime('%Y%m%d%H%M')
        save_hko_movie(train_batch[:, 0, 0, :, :],
                       sample_datetimes[0],
                       train_mask[:, 0, 0, :, :],
                       masked=False,
                       save_path=name_str + '.mp4')
        tic = time.time()
        save_hko_movie(train_batch[:, 0, 0, :, :],
                       sample_datetimes[0],
                       train_mask[:, 0, 0, :, :],
                       masked=True,
                       save_path=name_str + '_filtered.mp4')
        toc = time.time()
        save_hko_movie(train_mask[:, 0, 0, :, :].astype(np.uint8) * 255,
                       sample_datetimes[0],
                       None,
                       masked=False,
                       save_path=name_str + '_mask.mp4')
        print('train, time:', toc - tic)