def predict(input, model, config=None):
    """Run ``model`` on one input sequence and return the clipped prediction.

    Args:
        input: NumPy array (it is passed to ``torch.from_numpy``). Axis 0 must
            have length ``config['IN_LEN']``; axes 2 and 3 must both have
            length ``global_config['IMG_SIZE']``.
        model: Callable applied to the tensor built from ``input``.
        config: Mapping providing ``'IN_LEN'``, ``'OUT_LEN'`` and ``'DEVICE'``.
            The ``None`` default is kept only for signature compatibility;
            a real mapping is required.

    Returns:
        NumPy array obtained by taking index 0 on axis 2 of the model output
        and limiting the values to the range ``[mm_dbz(0), mm_dbz(60)]``.

    Raises:
        TypeError: If ``config`` is ``None``.
        AssertionError: If the input or output shapes do not match the
            configured lengths / image size.
    """
    if config is None:
        # Previously this surfaced as "'NoneType' object is not subscriptable".
        raise TypeError(
            "predict() requires a config mapping with 'IN_LEN', 'OUT_LEN' "
            "and 'DEVICE'; got None")

    assert input.shape[0] == config['IN_LEN']
    assert input.shape[2] == global_config['IMG_SIZE']
    assert input.shape[3] == global_config['IMG_SIZE']

    with torch.no_grad():
        # A singleton axis is inserted at position 2 before the array is
        # moved to the configured device.
        batch = torch.from_numpy(input[:, :, None]).to(config['DEVICE'])
        output = model(batch)

    assert output.shape[0] == config['OUT_LEN']
    assert output.shape[3] == global_config['IMG_SIZE']
    assert output.shape[4] == global_config['IMG_SIZE']

    return np.minimum(np.maximum(output.cpu().numpy()[:, :, 0], mm_dbz(0)),
                      mm_dbz(60))
def get_data_test(self, indices):
    """Build one evaluation batch: downscaled, normalised inputs and raw labels.

    Args:
        indices: Sequence of start offsets into ``self.files``.

    Returns:
        Tuple ``(inputs, labels)``. ``inputs`` is a torch tensor on
        ``self.config['DEVICE']`` holding ``IN_LEN`` frames per index, resized
        to ``(h, w)`` and normalised with ``NORM_MIN`` / ``NORM_DIV`` after
        ``mm_dbz``. ``labels`` is a float32 NumPy array holding
        ``OUT_TARGET_LEN`` full-resolution frames per index. When
        ``self.config['DIM']`` is ``'3D'`` a singleton axis is inserted at
        position 1 of both; when it is ``'2D'`` the axis is inserted at
        position 2. The same objects are also kept in ``self.last_data``.
    """
    data_h = global_config['DATA_HEIGHT']
    data_w = global_config['DATA_WIDTH']
    in_len = self.config['IN_LEN']
    target_len = global_config['OUT_TARGET_LEN']

    # Target size comes either from explicit SIZEH/SIZEW or from SCALE.
    if self.config['SCALE'] is None:
        h = self.config['SIZEH']
        w = self.config['SIZEW']
    else:
        scale = self.config['SCALE']
        h = int(data_h * scale)
        w = int(data_w * scale)

    sliced_input = np.zeros((len(indices), in_len, h, w), dtype=np.float32)
    sliced_label = np.zeros((len(indices), target_len, data_h, data_w),
                            dtype=np.float32)

    for i, idx in enumerate(indices):
        for j in range(in_len):
            f = np.fromfile(self.files[idx + j], dtype=np.float32) \
                .reshape((data_h, data_w))
            sliced_input[i, j] = cv2.resize(f, (w, h),
                                            interpolation=cv2.INTER_AREA)

    # NOTE(review): labels are read from the same offsets (idx + j) as the
    # inputs above, so the label window starts on the first input frame
    # rather than after the input window. Left unchanged -- confirm whether
    # an IN_LEN offset is intended.
    for i, idx in enumerate(indices):
        for j in range(target_len):
            sliced_label[i, j] = np.fromfile(self.files[idx + j],
                                             dtype=np.float32) \
                .reshape((data_h, data_w))

    sliced_input = (mm_dbz(sliced_input) -
                    global_config['NORM_MIN']) / global_config['NORM_DIV']

    # Drop our references to the previous batch *before* emptying the cache.
    # The old `for i in self.last_data: del i` only unbound the loop variable,
    # so the tensors stayed referenced and nothing could be released.
    self.last_data = None
    torch.cuda.empty_cache()

    self.last_data = [
        torch.from_numpy(sliced_input).to(self.config['DEVICE']),
        sliced_label,
    ]

    if self.config['DIM'] == '3D':
        self.last_data = [item[:, None, :] for item in self.last_data]
    elif self.config['DIM'] == '2D':
        self.last_data = [item[:, :, None] for item in self.last_data]

    return tuple(self.last_data)
def get_data(self, indices):
    """Load a window of frames per start index, downscale, normalise and crop.

    For every entry of ``indices``, ``self.windows_size`` consecutive files
    from ``self.files`` are read as float32, reshaped to
    ``(DATA_HEIGHT, DATA_WIDTH)`` and resized to ``(out_h, out_w)``. The
    stacked result is passed through ``mm_dbz``, shifted by ``NORM_MIN`` and
    divided by ``NORM_DIV``. Finally 6 rows are trimmed from the top and
    bottom and 1 column from the left and right of every frame.

    Args:
        indices: Sequence of start offsets into ``self.files``.

    Returns:
        The normalised array, sliced with ``[:, :, 6:-6, 1:-1]``.
    """
    if self.config['SCALE'] is not None:
        factor = self.config['SCALE']
        out_h = int(global_config['DATA_HEIGHT'] * factor)
        out_w = int(global_config['DATA_WIDTH'] * factor)
    else:
        out_h = self.config['SIZEH']
        out_w = self.config['SIZEW']

    window_shape = (len(indices), self.windows_size, out_h, out_w)
    batch = np.zeros(window_shape, dtype=np.float32)

    for row, start in enumerate(indices):
        for step in range(self.windows_size):
            raw = np.fromfile(self.files[start + step], dtype=np.float32)
            frame = raw.reshape(
                (global_config['DATA_HEIGHT'], global_config['DATA_WIDTH']))
            batch[row, step] = cv2.resize(
                frame, (out_w, out_h), interpolation=cv2.INTER_AREA)

    normalised = (mm_dbz(batch) -
                  global_config['NORM_MIN']) / global_config['NORM_DIV']
    return normalised[:, :, 6:-6, 1:-1]
def extract(model, config, save_dir, files, file_name, crop=None):
    """Roll a forecast forward from ``file_name`` and dump images and metrics.

    Starting at the position of ``file_name`` in ``files``, the function
    performs ``6 * 24`` iterations. In each one it predicts
    ``OUT_TARGET_LEN`` frames from the current input window, scores them
    against the current label window, writes PNGs and a metrics file, and
    then slides both windows forward by one file.

    Output layout (under ``save_dir + '/extracted'``): one directory per
    iteration, named after ``files[idx + dt]`` with its last four characters
    removed. Each contains ``pred/<j>.png``, ``label/<j>.png`` and
    ``metrics.txt``. ``metrics.txt`` holds, in order: ``csi``, ``micro_csi``,
    the mean of ``csi_multi``, the entries of ``csi_multi``, then the three
    RMSE values averaged over the target frames.

    Args:
        model: Network called on the input tensor; ``model.eval()`` is
            invoked before the loop.
        config: Mapping providing ``'SCALE'``, ``'IN_LEN'``, ``'OUT_LEN'``,
            ``'DEVICE'``, ``'CAL_DEVICE'`` and ``'OPTFLOW'``.
        save_dir: Base output directory.
        files: Ordered sequence of frame file paths.
        file_name: Base name of the first frame to predict.
        crop: Accepted for interface compatibility; not used.

    Returns:
        None. Prints ``'not found'`` and returns early when ``file_name`` is
        not among the base names of ``files``.
    """
    try:
        idx = next(i for i, f in enumerate(files)
                   if os.path.basename(f) == file_name)
    except StopIteration:
        print('not found')
        return

    data_h = global_config['DATA_HEIGHT']
    data_w = global_config['DATA_WIDTH']
    target_len = global_config['OUT_TARGET_LEN']
    norm_min = global_config['NORM_MIN']
    norm_div = global_config['NORM_DIV']

    scale = config['SCALE']
    h = int(data_h * scale)
    w = int(data_w * scale)

    sliced_input = np.zeros((1, config['IN_LEN'], h, w), dtype=np.float32)
    sliced_label = np.zeros((1, target_len, data_h, data_w), dtype=np.float32)

    # Inputs are the IN_LEN files preceding idx; labels start at idx.
    for i, j in enumerate(range(idx - config['IN_LEN'], idx)):
        sliced_input[0, i] = read_file(files[j], h, w, resize=True)
    for i, j in enumerate(range(idx, idx + target_len)):
        sliced_label[0, i] = read_file(files[j])

    sliced_input = (mm_dbz(sliced_input) - norm_min) / norm_div
    # The `[:, None]` makes the tensor (1, 1, IN_LEN, h, w): frames now live
    # on axis 2, which is also the axis the outputs are concatenated on below.
    sliced_input = torch.from_numpy(sliced_input).to(config['DEVICE'])[:, None]

    save_dir = save_dir + '/extracted'
    os.makedirs(save_dir, exist_ok=True)

    model.eval()
    # Number of model calls needed to cover OUT_TARGET_LEN frames.
    out_time = int(np.ceil(target_len / config['OUT_LEN']))

    for dt in tqdm(range(6 * 24)):
        cur_input = sliced_input.clone()
        outputs = None
        with torch.no_grad():
            for _ in range(out_time):
                output = model(cur_input)
                chunk = output.detach().cpu().numpy()
                if outputs is None:
                    outputs = chunk
                else:
                    outputs = np.concatenate([outputs, chunk], axis=2)
                if config['OPTFLOW']:
                    # Feed back the last input frame followed by the output.
                    cur_input = torch.cat(
                        [cur_input[:, :, -1, None], output], dim=2)
                else:
                    cur_input = output

        pred = np.array(outputs)
        pred = pred[:, 0, :target_len]
        pred = denorm(pred)

        # Bring predictions to full resolution for scoring.
        pred_resized = np.zeros(
            (pred.shape[0], pred.shape[1], data_h, data_w))
        for i in range(pred.shape[0]):
            for j in range(pred.shape[1]):
                pred_resized[i, j] = cv2.resize(
                    pred[i, j], (data_w, data_h),
                    interpolation=cv2.INTER_AREA)

        csi = fp_fn_image_csi(pred_resized, sliced_label)
        csi_multi, micro_csi = fp_fn_image_csi_muti(pred_resized, sliced_label)

        sum_rmse = np.zeros((3, ), dtype=np.float32)
        for i in range(target_len):
            rmse, rmse_rain, rmse_non_rain = torch_cal_rmse_all(
                torch.from_numpy(pred_resized[:, i]).to(config['CAL_DEVICE']),
                torch.from_numpy(sliced_label[:, i]).to(config['CAL_DEVICE']))
            sum_rmse += np.array([
                rmse.cpu().numpy(),
                rmse_rain.cpu().numpy(),
                rmse_non_rain.cpu().numpy()
            ])
        mean_rmse = sum_rmse / target_len

        # Labels are shrunk to the prediction's native size for the images.
        h_small = pred.shape[2]
        w_small = pred.shape[3]
        label_small = np.zeros(
            (sliced_label.shape[0], sliced_label.shape[1], h_small, w_small))
        for i in range(sliced_label.shape[0]):
            for j in range(sliced_label.shape[1]):
                label_small[i, j] = cv2.resize(
                    sliced_label[i, j], (w_small, h_small),
                    interpolation=cv2.INTER_AREA)

        time_name = os.path.basename(files[idx + dt])[:-4]
        path = save_dir + '/' + time_name
        os.makedirs(path + '/pred', exist_ok=True)
        os.makedirs(path + '/label', exist_ok=True)

        for i in range(pred_resized.shape[0]):
            for j in range(target_len):
                # Alpha channel: opaque where the value exceeds 0.2.
                alpha_mask = (label_small[i, j] > 0.2).astype(np.uint8) * 255
                lb = rainfall_shade(label_small[i, j], mode='BGR')
                lb = np.dstack((lb, alpha_mask))
                alpha_mask = (pred[i, j] > 0.2).astype(np.uint8) * 255
                prd = rainfall_shade(pred[i, j], mode='BGR')
                prd = np.dstack((prd, alpha_mask))
                cv2.imwrite(path + '/label/' + str(j) + '.png', lb)
                cv2.imwrite(path + '/pred/' + str(j) + '.png', prd)

        np.savetxt(path + '/metrics.txt',
                   [csi, micro_csi, np.mean(csi_multi)] + list(csi_multi) +
                   list(mean_rmse),
                   delimiter=',',
                   fmt='%.2f')

        # Slide the input window by one frame along the frame axis (2).
        # Indexing axis 1 here (size 1) made the shift a no-op and replaced
        # every input frame with the newest one. The clone avoids copying
        # between overlapping views of the same tensor.
        sliced_input[:, :, :-1] = sliced_input[:, :, 1:].clone()
        next_input = (mm_dbz(read_file(files[idx + dt], h, w, resize=True)) -
                      norm_min) / norm_div
        sliced_input[:, :, -1] = torch.from_numpy(next_input).to(
            config['DEVICE'])

        # Slide the label window the same way (NumPy array, frames on axis 1).
        sliced_label[:, :-1] = sliced_label[:, 1:]
        sliced_label[:, -1] = read_file(files[idx + target_len + dt])
import glob
import os
import cv2
import numpy as np
import torch
from utils.units import mm_dbz, dbz_mm, denorm, torch_denorm
from utils.visualizers import make_gif_color, rainfall_shade, make_gif, make_gif_color_label
from utils.evaluators import fp_fn_image_csi, cal_rmse_all, fp_fn_image_csi_muti, torch_cal_rmse_all
from global_config import global_config
from models.unet.model import UNet2D
from tqdm import tqdm

# Bilinear upsampling module targeting (DATA_HEIGHT, DATA_WIDTH).
rs_img = torch.nn.Upsample(size=(global_config['DATA_HEIGHT'], global_config['DATA_WIDTH']), mode='bilinear')
# mm_dbz(0.5) mapped through the same NORM_MIN / NORM_DIV normalisation
# that is applied to model inputs.
thres = (mm_dbz(0.5) - global_config['NORM_MIN']) / global_config['NORM_DIV']


def read_file(file_name, h=None, w=None, resize=False):
    """Read one raw float32 frame from disk, optionally resizing it.

    Args:
        file_name: Path of a file containing float32 values; it is reshaped
            to ``(DATA_HEIGHT, DATA_WIDTH)``.
        h: Target height; only used when ``resize`` is true.
        w: Target width; only used when ``resize`` is true.
        resize: When true, the frame is resized to ``(h, w)`` with
            ``cv2.INTER_AREA`` interpolation.

    Returns:
        The frame as a 2-D NumPy array.
    """
    f = np.fromfile(file_name, dtype=np.float32) \
        .reshape((global_config['DATA_HEIGHT'], global_config['DATA_WIDTH']))
    if resize:
        # cv2.resize takes the size as (width, height).
        f = cv2.resize(f, (w, h), interpolation=cv2.INTER_AREA)
    return f


# NOTE(review): this definition is cut off mid-expression at this point in
# the text under review; it is reproduced untouched.
def extract(model, config, save_dir, files, file_name, crop=None):
    idx = 0
    try:
        idx = next(i for i, f in enumerate(files)
def test(model, data_loader, config, save_dir, files, file_name, crop=None):
    """Evaluate ``model`` on the sequence starting at ``file_name``.

    The ``IN_LEN`` files preceding ``file_name`` form the (downscaled,
    normalised) input; the ``OUT_TARGET_LEN`` files starting at ``file_name``
    are the full-resolution labels. The model is applied repeatedly until
    ``OUT_TARGET_LEN`` frames are available. Predictions are denormalised,
    resized to full resolution, optionally cropped, and scored.

    Files written:
        ``save_dir/imgs/pred_colored.gif``, ``gt_colored.gif``, ``all.gif``
        and ``csis.png``; ``save_dir/result.txt`` (mean CSI, the ``csi_multi``
        entries, then ``rmse``, ``rmse_rain``, ``rmse_non_rain``, rounded to
        3 decimals) and ``save_dir/csi.txt`` (CSI per predicted frame).

    Args:
        model: Network called on the input tensor.
        data_loader: Accepted for interface compatibility; not used.
        config: Mapping providing ``'SCALE'``, ``'IN_LEN'``, ``'OUT_LEN'``,
            ``'DEVICE'``, ``'DIM'`` and, for ``DIM == '3D'``, ``'OPTFLOW'``.
        save_dir: Base output directory.
        files: Ordered sequence of frame file paths.
        file_name: Base name of the first frame to predict.
        crop: Optional crop specification passed to
            ``get_crop_boundary_idx``; when ``None`` the full frame is used.

    Returns:
        None. Prints ``'not found'`` and returns early when ``file_name`` is
        not among the base names of ``files``.
    """
    data_h = global_config['DATA_HEIGHT']
    data_w = global_config['DATA_WIDTH']
    target_len = global_config['OUT_TARGET_LEN']

    # Inclusive crop bounds; default to the whole frame.
    h1, h2, w1, w2 = 0, data_h - 1, 0, data_w - 1
    if crop is not None:
        h1, h2, w1, w2 = get_crop_boundary_idx(crop)

    try:
        idx = next(i for i, f in enumerate(files)
                   if os.path.basename(f) == file_name)
    except StopIteration:
        print('not found')
        return

    scale = config['SCALE']
    h = int(data_h * scale)
    w = int(data_w * scale)

    sliced_input = np.zeros((1, config['IN_LEN'], h, w), dtype=np.float32)
    sliced_label = np.zeros((1, target_len, data_h, data_w), dtype=np.float32)

    for i, j in enumerate(range(idx - config['IN_LEN'], idx)):
        f = np.fromfile(files[j], dtype=np.float32).reshape((data_h, data_w))
        sliced_input[0, i] = cv2.resize(f, (w, h),
                                        interpolation=cv2.INTER_AREA)
    for i, j in enumerate(range(idx, idx + target_len)):
        sliced_label[0, i] = np.fromfile(files[j], dtype=np.float32) \
            .reshape((data_h, data_w))

    sliced_input = (mm_dbz(sliced_input) -
                    global_config['NORM_MIN']) / global_config['NORM_DIV']
    sliced_input = torch.from_numpy(sliced_input).to(config['DEVICE'])

    # '3D' inserts a singleton axis at position 1, '2D' at position 2;
    # any other DIM value keeps the 4-D layout.
    if config['DIM'] == '3D':
        sliced_input = sliced_input[:, None, :]
    elif config['DIM'] == '2D':
        sliced_input = sliced_input[:, :, None]

    # Frames sit on axis 2 for '3D' and on axis 1 otherwise.
    frame_axis = 2 if config['DIM'] == '3D' else 1
    n_calls = int(np.ceil(target_len / config['OUT_LEN']))

    outputs = None
    with torch.no_grad():
        for _ in range(n_calls):
            output = model(sliced_input)
            chunk = output.detach().cpu().numpy()
            if outputs is None:
                outputs = chunk
            else:
                outputs = np.concatenate([outputs, chunk], axis=frame_axis)

            if config['DIM'] == '3D':
                if config['OPTFLOW']:
                    # Feed back the last input frame followed by the output.
                    sliced_input = torch.cat(
                        [sliced_input[:, :, -1, None], output], dim=2)
                else:
                    sliced_input = output
            else:
                # Drop the oldest OUT_LEN frames and append the new output.
                sliced_input = torch.cat(
                    [sliced_input[:, config['OUT_LEN']:], output], dim=1)

    pred = np.array(outputs)
    if config['DIM'] == '3D':
        pred = pred[:, 0]
    elif config['DIM'] == '2D':
        # This branch used to repeat the '3D' test and was unreachable, so
        # '2D' predictions kept the singleton axis inserted above.
        pred = pred[:, :, 0]
    pred = pred[:, :target_len]

    pred = denorm(pred)
    pred_resized = np.zeros((pred.shape[0], pred.shape[1], data_h, data_w))
    for i in range(pred.shape[0]):
        for j in range(pred.shape[1]):
            pred_resized[i, j] = cv2.resize(
                pred[i, j], (data_w, data_h), interpolation=cv2.INTER_AREA)

    pred_resized = pred_resized[:, :, h1: h2 + 1, w1: w2 + 1]
    sliced_label = sliced_label[:, :, h1: h2 + 1, w1: w2 + 1]

    # CSI is computed per predicted frame and then averaged.
    csis = [fp_fn_image_csi(pred_resized[:, c], sliced_label[:, c])
            for c in range(pred.shape[1])]
    csi = np.mean(csis)
    csi_multi, _ = fp_fn_image_csi_muti(pred_resized, sliced_label)
    rmse, rmse_rain, rmse_non_rain = cal_rmse_all(pred_resized, sliced_label)
    result_all = [csi] + list(csi_multi) + [rmse, rmse_rain, rmse_non_rain]

    # Half-size copies are used for the GIFs.
    h_small = int(pred_resized.shape[2] * 0.5)
    w_small = int(pred_resized.shape[3] * 0.5)
    small_shape = (sliced_label.shape[0], sliced_label.shape[1],
                   h_small, w_small)
    pred_small = np.zeros(small_shape)
    label_small = np.zeros(small_shape)
    for i in range(sliced_label.shape[0]):
        for j in range(sliced_label.shape[1]):
            pred_small[i, j] = cv2.resize(
                pred_resized[i, j], (w_small, h_small),
                interpolation=cv2.INTER_AREA)
            label_small[i, j] = cv2.resize(
                sliced_label[i, j], (w_small, h_small),
                interpolation=cv2.INTER_AREA)

    path = save_dir + '/imgs'
    os.makedirs(path, exist_ok=True)

    for i in range(pred_resized.shape[0]):
        # Colored prediction and ground-truth GIFs, then the combined one.
        make_gif_color(pred_small[i], path + '/pred_colored.gif')
        make_gif_color(label_small[i], path + '/gt_colored.gif')
        labels = [os.path.basename(files[idx + k])
                  for k in range(target_len)]
        make_gif_color_label(label_small[i], pred_small[i], labels,
                             fname=path + '/all.gif')

    fig, ax = plt.subplots(figsize=(8, 4), facecolor='white')
    ax.plot(np.arange(len(csis)) + 1, csis)
    ax.set_xticks(np.arange(target_len) + 1)
    ax.set_ylabel('Binary - CSI')
    ax.set_xlabel('Time Steps')
    plt.savefig(path + '/csis.png')
    # Release the figure; otherwise one figure stays open per call.
    plt.close(fig)

    result_all = np.array(result_all)
    result_all = np.around(result_all, decimals=3)
    np.savetxt(save_dir + '/result.txt', result_all, delimiter=',',
               fmt='%.3f')
    np.savetxt(save_dir + '/csi.txt', np.array(csis), delimiter=',',
               fmt='%.3f')
def conv_test(model_path, start_pred_fn, test_case, in_len, out_len,
              batch_size, multitask, crop=None):
    """Build a ConvLSTM encoder-forecaster, load weights, evaluate and dump GIFs.

    The data is obtained through ``get_data(start_pred_fn, crop=crop,
    config=config)`` and predictions through ``prepare_testing`` using
    ``global_config['MERGE_WEIGHT']``. Predictions and labels are converted
    with ``dbz_mm`` before scoring (predictions are floored at 0).

    Output goes to ``./imgs_conv/conv_<test_case>_<in_len>_<out_len>_<multitask>/``:
    ``errs.gif``, ``pred.gif``, ``pred_colored.gif``, ``gt.gif``,
    ``gt_colored.gif`` and ``imgs/<i>.png`` for every predicted frame.

    Args:
        model_path: Path of the state dict passed to ``torch.load``.
        start_pred_fn: Forwarded to ``get_data``.
        test_case: Label used in the output directory name.
        in_len: Stored as ``config['IN_LEN']``.
        out_len: Stored as ``config['OUT_LEN']``.
        batch_size: Stored as ``config['BATCH_SIZE']`` and used in every
            ConvLSTM ``b_h_w``.
        multitask: Label used in the output directory name.
        crop: Forwarded to ``get_data``.

    Returns:
        ``([rmse, rmse_rain, rmse_non_rain], csi, csi_multi)``.
    """
    config = {
        'DEVICE': torch.device('cuda:0'),
        'IN_LEN': in_len,
        'OUT_LEN': out_len,
        'BATCH_SIZE': batch_size,
    }

    convlstm_encoder_params = [
        [
            OrderedDict({'conv1_leaky_1': [1, 8, 7, 5, 1]}),
            OrderedDict({'conv2_leaky_1': [64, 192, 5, 3, 1]}),
            OrderedDict({'conv3_leaky_1': [192, 192, 3, 2, 1]}),
        ],
        [
            ConvLSTM(input_channel=8, num_filter=64,
                     b_h_w=(batch_size, 96, 96),
                     kernel_size=3, stride=1, padding=1, config=config),
            ConvLSTM(input_channel=192, num_filter=192,
                     b_h_w=(batch_size, 32, 32),
                     kernel_size=3, stride=1, padding=1, config=config),
            ConvLSTM(input_channel=192, num_filter=192,
                     b_h_w=(batch_size, 16, 16),
                     kernel_size=3, stride=1, padding=1, config=config),
        ],
    ]

    convlstm_forecaster_params = [
        [
            OrderedDict({'deconv1_leaky_1': [192, 192, 4, 2, 1]}),
            OrderedDict({'deconv2_leaky_1': [192, 64, 5, 3, 1]}),
            OrderedDict({
                'deconv3_leaky_1': [64, 8, 7, 5, 1],
                'conv3_leaky_2': [8, 8, 3, 1, 1],
                'conv3_3': [8, 1, 1, 1, 0]
            }),
        ],
        [
            ConvLSTM(input_channel=192, num_filter=192,
                     b_h_w=(batch_size, 16, 16),
                     kernel_size=3, stride=1, padding=1, config=config),
            ConvLSTM(input_channel=192, num_filter=192,
                     b_h_w=(batch_size, 32, 32),
                     kernel_size=3, stride=1, padding=1, config=config),
            ConvLSTM(input_channel=64, num_filter=64,
                     b_h_w=(batch_size, 96, 96),
                     kernel_size=3, stride=1, padding=1, config=config),
        ],
    ]

    encoder = Encoder(convlstm_encoder_params[0],
                      convlstm_encoder_params[1]).to(config['DEVICE'])
    forecaster = Forecaster(convlstm_forecaster_params[0],
                            convlstm_forecaster_params[1],
                            config=config).to(config['DEVICE'])
    model = EF(encoder, forecaster).to(config['DEVICE'])
    model.load_state_dict(torch.load(model_path, map_location='cuda'))

    data = get_data(start_pred_fn, crop=crop, config=config)
    data = mm_dbz(data)
    weight = global_config['MERGE_WEIGHT']

    pred, label = prepare_testing(data, model, weight=weight, config=config)
    pred = np.maximum(dbz_mm(pred), 0)
    label = dbz_mm(label)

    csi = fp_fn_image_csi(pred, label)
    csi_multi = fp_fn_image_csi_muti_reg(pred, label)
    rmse, rmse_rain, rmse_non_rain = cal_rmse_all(pred, label)

    path = './imgs_conv/conv_{}_{}_{}_{}/'.format(test_case, in_len, out_len,
                                                  multitask)
    # Creates ./imgs_conv, the run directory and imgs/ in one call. Unlike
    # the former `try: ... except: pass`, real failures (e.g. permissions)
    # are no longer hidden.
    os.makedirs(path + 'imgs/', exist_ok=True)

    # Absolute error, scaled to 0..255. Skip the scaling when the error is
    # zero everywhere, which would otherwise divide by zero.
    errs = np.abs(pred - label)
    peak = errs.max()
    make_gif(errs / peak * 255 if peak > 0 else errs, path + 'errs.gif')
    make_gif(pred / 60 * 255, path + 'pred.gif')
    make_gif_color(pred, path + 'pred_colored.gif')
    make_gif(label / 60 * 255, path + 'gt.gif')
    make_gif_color(label, path + 'gt_colored.gif')

    # One grayscale PNG per predicted frame.
    for i in range(pred.shape[0]):
        cv2.imwrite(
            path + 'imgs/' + str(i) + '.png',
            cv2.cvtColor(np.array(pred[i] / 80 * 255, dtype=np.uint8),
                         cv2.COLOR_GRAY2BGR))

    return [rmse, rmse_rain, rmse_non_rain], csi, csi_multi
def conv_test(model_path, start_pred_fn, test_case, in_len, out_len,
              batch_size, multitask, crop=None):
    """Evaluate a ConvLSTM encoder-forecaster with a frame-by-frame rollout.

    Frames are read from the sorted glob of ``global_config['TEST_PATH']``.
    Each frame has 20 columns trimmed from both sides and is downscaled by a
    factor of 4. The ``IN_LEN`` frames before ``start_pred_fn`` are the input
    and the following ``OUT_TARGET_LEN`` frames are the labels. The model is
    called once per label frame; the first frame of each output is kept and
    the output is appended to the input window after dropping its oldest
    frame.

    Output goes to ``./imgs_conv/conv_<test_case>_<in_len>_<out_len>_<multitask>/``:
    ``errs.gif``, ``pred.gif``, ``pred_colored.gif``, ``gt.gif``,
    ``gt_colored.gif`` and ``imgs/<i>.png`` for every predicted frame.

    Args:
        model_path: Path of the state dict passed to ``torch.load``.
        start_pred_fn: Base name of the first frame to predict. An empty
            string selects index 0.
        test_case: Label used in the output directory name.
        in_len: Number of input frames.
        out_len: Stored as ``config['OUT_LEN']``.
        batch_size: Stored as ``config['BATCH_SIZE']`` and used in every
            ConvLSTM ``b_h_w``.
        multitask: Label used in the output directory name.
        crop: Accepted for interface compatibility; not used.

    Returns:
        ``([rmse, rmse_rain, rmse_non_rain], csi, csi_multi)``.

    Raises:
        ValueError: If ``start_pred_fn`` is non-empty and matches no file.
    """
    config = {
        'DEVICE': torch.device('cuda:0'),
        'IN_LEN': in_len,
        'OUT_LEN': out_len,
        'BATCH_SIZE': batch_size,
    }

    convlstm_encoder_params = [
        [
            OrderedDict({'conv1_leaky_1': [1, 8, 7, 5, 1]}),
            OrderedDict({'conv2_leaky_1': [64, 192, 5, 3, 1]}),
            OrderedDict({'conv3_leaky_1': [192, 192, 3, 2, 1]}),
        ],
        [
            ConvLSTM(input_channel=8, num_filter=64,
                     b_h_w=(batch_size, 168, 126),
                     kernel_size=3, stride=1, padding=1, config=config),
            ConvLSTM(input_channel=192, num_filter=192,
                     b_h_w=(batch_size, 56, 42),
                     kernel_size=3, stride=1, padding=1, config=config),
            ConvLSTM(input_channel=192, num_filter=192,
                     b_h_w=(batch_size, 28, 21),
                     kernel_size=3, stride=1, padding=1, config=config),
        ],
    ]

    convlstm_forecaster_params = [
        [
            OrderedDict({'deconv1_leaky_1': [192, 192, 4, 2, 1]}),
            OrderedDict({'deconv2_leaky_1': [192, 64, 5, 3, 1]}),
            OrderedDict({
                'deconv3_leaky_1': [64, 8, 7, 5, 1],
                'conv3_leaky_2': [8, 8, 3, 1, 1],
                'conv3_3': [8, 1, 1, 1, 0]
            }),
        ],
        [
            ConvLSTM(input_channel=192, num_filter=192,
                     b_h_w=(batch_size, 28, 21),
                     kernel_size=3, stride=1, padding=1, config=config),
            ConvLSTM(input_channel=192, num_filter=192,
                     b_h_w=(batch_size, 56, 42),
                     kernel_size=3, stride=1, padding=1, config=config),
            ConvLSTM(input_channel=64, num_filter=64,
                     b_h_w=(batch_size, 168, 126),
                     kernel_size=3, stride=1, padding=1, config=config),
        ],
    ]

    encoder = Encoder(convlstm_encoder_params[0],
                      convlstm_encoder_params[1]).to(config['DEVICE'])
    forecaster = Forecaster(convlstm_forecaster_params[0],
                            convlstm_forecaster_params[1],
                            config=config).to(config['DEVICE'])
    model = EF(encoder, forecaster).to(config['DEVICE'])
    model.load_state_dict(torch.load(model_path, map_location='cuda'))

    files = sorted(glob.glob(global_config['TEST_PATH']))
    idx = 0
    if start_pred_fn != '':
        try:
            idx = next(i for i, f in enumerate(files)
                       if os.path.basename(f) == start_pred_fn)
        except StopIteration:
            # The old fallback (idx = -1) went on to slice the file list with
            # a negative start and evaluated on the wrong (or no) frames.
            raise ValueError(
                'start file {!r} not found under {!r}'.format(
                    start_pred_fn, global_config['TEST_PATH'])) from None
    # NOTE(review): with idx == 0 (empty start_pred_fn) the slice below starts
    # at a negative offset; confirm whether that case is ever used.

    data_h = global_config['DATA_HEIGHT']
    data_w = global_config['DATA_WIDTH']
    target_len = global_config['OUT_TARGET_LEN']
    scale_div = 4
    out_h = int(data_h / scale_div)
    out_w = int((data_w - 40) / scale_div)

    data = np.zeros((config['IN_LEN'] + target_len, out_h, out_w),
                    dtype=np.float32)
    window = files[idx - config['IN_LEN']:idx + target_len]
    for i, file in enumerate(window):
        # Trim 20 columns on each side, then downscale.
        fd = np.fromfile(file, dtype=np.float32) \
            .reshape((data_h, data_w))[:, 20:-20]
        fd = cv2.resize(fd, (out_w, out_h), interpolation=cv2.INTER_AREA)
        data[i, :] = fd

    data = mm_dbz(data)
    frames = data[:in_len]
    label = data[in_len:]

    preds = []
    with torch.no_grad():
        seq = torch.from_numpy(frames[:, None, None]).to(config['DEVICE'])
        # One rollout step per label frame (was a hard-coded 18), so the
        # prediction always lines up with the labels it is scored against.
        for i in range(label.shape[0]):
            pred = model(seq)
            # Step-dependent damping: values below a threshold that rises
            # with i are reduced by an amount that also grows with i.
            pred[pred < 0.2 + 0.035 * i] -= 0.04 * i
            pred = torch.clamp(pred, min=mm_dbz(0), max=mm_dbz(60))
            preds.append(pred.detach().cpu().numpy()[0, 0, 0])
            seq = torch.cat((seq[1:], pred), 0)

    pred = np.array(preds)
    print(pred.shape, label.shape)

    pred = np.maximum(dbz_mm(pred), 0)
    label = dbz_mm(label)

    csi = fp_fn_image_csi(pred, label)
    csi_multi = fp_fn_image_csi_muti_reg(pred, label)
    rmse, rmse_rain, rmse_non_rain = cal_rmse_all(pred, label)

    path = './imgs_conv/conv_{}_{}_{}_{}/'.format(test_case, in_len, out_len,
                                                  multitask)
    # Creates ./imgs_conv, the run directory and imgs/ in one call. Unlike
    # the former `try: ... except: pass`, real failures (e.g. permissions)
    # are no longer hidden.
    os.makedirs(path + 'imgs/', exist_ok=True)

    # Absolute error, scaled to 0..255. Skip the scaling when the error is
    # zero everywhere, which would otherwise divide by zero.
    errs = np.abs(pred - label)
    peak = errs.max()
    make_gif(errs / peak * 255 if peak > 0 else errs, path + 'errs.gif')
    make_gif(pred / 60 * 255, path + 'pred.gif')
    make_gif_color(pred, path + 'pred_colored.gif')
    make_gif(label / 60 * 255, path + 'gt.gif')
    make_gif_color(label, path + 'gt_colored.gif')

    # One grayscale PNG per predicted frame.
    for i in range(pred.shape[0]):
        cv2.imwrite(
            path + 'imgs/' + str(i) + '.png',
            cv2.cvtColor(np.array(pred[i] / 80 * 255, dtype=np.uint8),
                         cv2.COLOR_GRAY2BGR))

    return [rmse, rmse_rain, rmse_non_rain], csi, csi_multi