def upload_prediction(self, prediction):
    """Accept the model's prediction for the batch from the last observation.

    Marks the pending upload as consumed. If one of the sequences in this
    batch is listed in ``self._save_seq_inds``, writes three videos for it:
    input, ground truth and prediction. It then advances the
    received-sequence counter and passes the prediction to
    ``self._all_eval.update``.

    Parameters
    ----------
    prediction : np.ndarray
        Predicted frames. Axis 1 indexes the sequences of the batch (the
        received-sequence counter advances by ``prediction.shape[1]``).
        Presumably laid out like ``self._out_frame_dat``, i.e.
        (seq_len, batch, 1, H, W) -- TODO confirm.

    Raises
    ------
    AssertionError
        If there is no pending upload (``get_observation`` must be called
        first), if more than one save index falls inside this batch, or if
        in online mode with stride 1 a new episode has begun.
    """
    assert self._need_upload_prediction, "Must call get_observation first!" \
                                         " Also, check the value of need_upload_prediction" \
                                         " after calling"
    self._need_upload_prediction = False

    batch_size = prediction.shape[1]
    first_seq_ind = self._received_pred_seq_num
    # Global sequence indices covered by this batch that were selected for saving.
    save_ind_set = set(range(first_seq_ind, first_seq_ind + batch_size)) \
        .intersection(self._save_seq_inds)
    if save_ind_set:
        assert len(save_ind_set) == 1
        # Convert the global sequence index into a position inside this batch.
        ind = save_ind_set.pop() - first_seq_ind
        save_dir = os.path.join(self._save_dir, self._fingerprint)
        # exist_ok=True avoids the race between a separate os.path.exists()
        # check and makedirs() when the directory appears in between.
        os.makedirs(save_dir, exist_ok=True)
        print("Saving prediction videos to %s" % save_dir)
        # All three videos share one prefix: the start time of the input clip.
        prefix = self._in_datetime_clips[ind][0].strftime('%Y%m%d%H%M')
        save_hko_movie(
            im_dat=self._in_frame_dat[:, ind, 0, ...],
            mask_dat=self._in_mask_dat[:, ind, 0, :, :],
            datetime_list=self._in_datetime_clips[ind],
            save_path=os.path.join(save_dir, "%s_in.mp4" % prefix))
        save_hko_movie(
            im_dat=self._out_frame_dat[:, ind, 0, ...],
            mask_dat=self._out_mask_dat[:, ind, 0, :, :],
            masked=False,
            datetime_list=self._out_datetime_clips[ind],
            save_path=os.path.join(save_dir, "%s_out.mp4" % prefix))
        save_hko_movie(
            im_dat=prediction[:, ind, 0, ...],
            mask_dat=self._out_mask_dat[:, ind, 0, :, :],
            masked=False,
            datetime_list=self._out_datetime_clips[ind],
            save_path=os.path.join(save_dir, "%s_pred.mp4" % prefix))
    self._received_pred_seq_num += batch_size
    if self._mode == "online" and self._stride == 1:
        # Online mode with stride 1 requires that no new episode has begun.
        assert not self._begin_new_episode
    # Evaluate against ground truth; each clip's first datetime is its start time.
    self._all_eval.update(
        gt=self._out_frame_dat,
        pred=prediction,
        mask=self._out_mask_dat,
        start_datetimes=[ele[0] for ele in self._out_datetime_clips])
hko_iter.sample(batch_size=1)  # draw a single validation sample (batch_size=1)
# Rescale pixel values by 1/255 (presumably uint8 frames -> [0, 1]; TODO confirm).
valid_batch = valid_batch.astype(np.float32) / 255.0
# Axis 0 is time: the first IN_LEN frames are the model input and the next
# OUT_LEN frames are the ground truth the prediction is compared against.
valid_data = valid_batch[:IN_LEN, ...]
valid_label = valid_batch[IN_LEN:IN_LEN + OUT_LEN, ...]
mask = valid_mask[IN_LEN:IN_LEN + OUT_LEN, ...].astype(int)
torch_valid_data = torch.from_numpy(valid_data).to(cfg.GLOBAL.DEVICE)
# Generate the prediction; no_grad() disables autograd tracking for inference.
with torch.no_grad():
    output = encoder_forecaster(torch_valid_data)
# Clamp the prediction to [0, 1], the same range as the rescaled inputs.
output = np.clip(output.cpu().numpy(), 0.0, 1.0)
base_dir = '.'
# S*B*1*H*W: keep only the first batch element and the single channel.
label = valid_label[:, 0, 0, :, :]
output = output[:, 0, 0, :, :]
mask = mask[:, 0, 0, :, :].astype(np.uint8)
# Call save_hko_movie to render the ground truth and the prediction as videos
# so they can be compared side by side.
save_hko_movie(label, sample_datetimes[0], mask, masked=True,
               save_path=os.path.join(base_dir, 'ground_truth.mp4'))
save_hko_movie(output, sample_datetimes[0], mask, masked=True,
               save_path=os.path.join(base_dir, 'pred.mp4'))
end = time.time()
# Throughput of the sampling loop timed by earlier code (`begin` is assigned
# outside this block): frames = minibatch_size * seq_len * repeat_time.
print("Test Data Sample FPS: %f" %(minibatch_size * seq_len * repeat_time / float(end-begin)))
# Print the encoding of months 1..12 and its decoding for manual inspection.
code = encode_month(np.arange(1, 13))
month = decode_month(code)
print(code)
print(month.T)
train_time = 0  # NOTE(review): not read by the code in this block
# Dump 30 training samples as videos: raw, masked ("filtered") and the mask itself.
for i in range(30):
    train_batch, train_mask, sample_datetimes, new_start = \
        train_hko_iter.sample(batch_size=minibatch_size)
    # Name files after the loop index and the first datetime of the first sequence.
    name_str = 'train_' + str(i) + '_' + sample_datetimes[0][0].strftime('%Y%m%d%H%M')
    # Only the first sequence of the batch (index 0 on axis 1) is rendered.
    save_hko_movie(train_batch[:, 0, 0, :, :], sample_datetimes[0], train_mask[:, 0, 0, :, :], masked=False,
                   save_path=name_str + '.mp4')
    # Only the masked render is timed; the elapsed time is printed below.
    tic = time.time()
    save_hko_movie(train_batch[:, 0, 0, :, :], sample_datetimes[0], train_mask[:, 0, 0, :, :], masked=True,
                   save_path=name_str + '_filtered.mp4')
    toc = time.time()
    # Render the mask itself, scaled by 255 (presumably a 0/1 mask -> 0/255) so it is visible.
    save_hko_movie(train_mask[:, 0, 0, :, :].astype(np.uint8) * 255, sample_datetimes[0], None, masked=False,
                   save_path=name_str + '_mask.mp4')
    print('train, time:', toc - tic)