def train_iteration(self):
    """Run one optimisation pass over every training batch.

    For each batch index reported by ``self.data_loader.n_train_batch()``:
    switch the model to training mode, do a forward/backward pass on the
    sum of ``self.mse_loss`` and ``self.mae_loss``, step ``self.optim``,
    then accumulate

    * ``self.train_loss``: the loss value divided by ``len(train_data)``;
    * ``self.train_metrics_value``: the first value returned by
      ``fp_fn_image_csi_muti`` on the denormalised prediction and label.

    ``self.add_epoch()`` is called once per batch.
    """
    total_batches = self.data_loader.n_train_batch()
    progress = tqdm(range(total_batches))
    for batch_idx in progress:
        # Training mode is re-entered every batch, presumably because
        # add_epoch() can trigger validation, which switches the model to
        # eval mode. TODO confirm against add_epoch().
        self.model.train()
        progress.set_description('Training at batch %d / %d' % (batch_idx, total_batches))

        inputs, targets = self.data_loader.get_train(batch_idx)

        self.optim.zero_grad()
        prediction = self.model(inputs)
        mse_term = self.mse_loss(prediction, targets)
        mae_term = self.mae_loss(prediction, targets)
        loss = mse_term + mae_term
        loss.backward()
        self.optim.step()

        self.train_loss += loss.item() / len(inputs)

        # Metrics are computed on host-side numpy copies; detach() keeps the
        # autograd graph out of the conversion.
        pred_np = denorm(prediction.detach().cpu().numpy())
        true_np = denorm(targets.cpu().numpy())
        batch_csis, _ = fp_fn_image_csi_muti(pred_np, true_np)
        self.train_metrics_value += batch_csis

        self.add_epoch()
def extract(model, config, save_dir, files, file_name, crop=None, n_steps=6 * 24):
    """Roll the model over consecutive start times and dump per-step results.

    ``file_name`` is looked up (by basename) in ``files`` and used as the
    first forecast frame.

    * Input window: the ``config['IN_LEN']`` files before it, read at
      ``config['SCALE']`` resolution and normalised.
    * Label window: the ``global_config['OUT_TARGET_LEN']`` files starting
      at it.

    For each of ``n_steps`` consecutive start times this:

    * forecasts ``OUT_TARGET_LEN`` frames autoregressively,
    * computes CSI metrics (``fp_fn_image_csi`` / ``fp_fn_image_csi_muti``)
      and per-step RMSE (``torch_cal_rmse_all``) at full resolution,
    * writes ``label/<j>.png``, ``pred/<j>.png`` and ``metrics.txt`` under
      ``<save_dir>/extracted/<time_name>/``,
    * slides both windows forward by one file.

    Args:
        model: network called as ``model(x)``.
        config: dict providing 'SCALE', 'IN_LEN', 'OUT_LEN', 'DEVICE',
            'CAL_DEVICE' and 'OPTFLOW'.
        save_dir: output root; results go under ``save_dir + '/extracted'``.
        files: ordered list of frame file paths.
        file_name: basename of the first forecast frame.
        crop: unused; kept for interface compatibility.
        n_steps: number of consecutive start times to process
            (default ``6 * 24``, the previously hard-coded value).

    Returns:
        None. Prints a message and returns early if ``file_name`` is not in
        ``files`` or has fewer than ``IN_LEN`` files before it.
    """
    try:
        idx = next(i for i, f in enumerate(files) if os.path.basename(f) == file_name)
    except StopIteration:
        print('not found')
        return
    in_len = config['IN_LEN']
    if idx < in_len:
        # files[idx - in_len] would otherwise silently wrap to the list end.
        print('not enough input frames before', file_name)
        return

    out_target_len = global_config['OUT_TARGET_LEN']
    data_h = global_config['DATA_HEIGHT']
    data_w = global_config['DATA_WIDTH']
    norm_min = global_config['NORM_MIN']
    norm_div = global_config['NORM_DIV']
    scale = config['SCALE']
    h = int(data_h * scale)
    w = int(data_w * scale)

    sliced_input = np.zeros((1, in_len, h, w), dtype=np.float32)
    sliced_label = np.zeros((1, out_target_len, data_h, data_w), dtype=np.float32)
    for i, j in enumerate(range(idx - in_len, idx)):
        sliced_input[0, i] = read_file(files[j], h, w, resize=True)
    for i, j in enumerate(range(idx, idx + out_target_len)):
        sliced_label[0, i] = read_file(files[j])
    sliced_input = (mm_dbz(sliced_input) - norm_min) / norm_div
    # (1, IN_LEN, h, w) -> (1, 1, IN_LEN, h, w). Axis 1 is a singleton
    # axis, so from here on the time axis of sliced_input is axis 2.
    sliced_input = torch.from_numpy(sliced_input).to(config['DEVICE'])[:, None]

    save_dir = save_dir + '/extracted'
    os.makedirs(save_dir, exist_ok=True)
    model.eval()
    out_time = int(np.ceil(out_target_len / config['OUT_LEN']))

    for dt in tqdm(range(n_steps)):
        # Autoregressive forecast; predictions are stacked along axis 2.
        cur_input = sliced_input.clone()
        outputs = None
        with torch.no_grad():
            for _ in range(out_time):
                output = model(cur_input)
                out_np = output.detach().cpu().numpy()
                if outputs is None:
                    outputs = out_np
                else:
                    outputs = np.concatenate([outputs, out_np], axis=2)
                if config['OPTFLOW']:
                    # Next input = last frame of the current input + new predictions.
                    cur_input = torch.cat([cur_input[:, :, -1, None], output], dim=2)
                else:
                    cur_input = output

        pred = denorm(outputs[:, 0, :out_target_len])
        pred_resized = np.zeros((pred.shape[0], pred.shape[1], data_h, data_w))
        for i in range(pred.shape[0]):
            for j in range(pred.shape[1]):
                pred_resized[i, j] = cv2.resize(pred[i, j], (data_w, data_h),
                                                interpolation=cv2.INTER_AREA)

        # Metrics at full data resolution.
        csi = fp_fn_image_csi(pred_resized, sliced_label)
        csi_multi, micro_csi = fp_fn_image_csi_muti(pred_resized, sliced_label)
        sum_rmse = np.zeros((3,), dtype=np.float32)
        for i in range(out_target_len):
            rmse, rmse_rain, rmse_non_rain = torch_cal_rmse_all(
                torch.from_numpy(pred_resized[:, i]).to(config['CAL_DEVICE']),
                torch.from_numpy(sliced_label[:, i]).to(config['CAL_DEVICE']))
            sum_rmse += np.array([
                rmse.cpu().numpy(),
                rmse_rain.cpu().numpy(),
                rmse_non_rain.cpu().numpy(),
            ])
        mean_rmse = sum_rmse / out_target_len

        # Labels downscaled to the model output resolution for plotting.
        h_small, w_small = pred.shape[2], pred.shape[3]
        label_small = np.zeros((sliced_label.shape[0], sliced_label.shape[1], h_small, w_small))
        for i in range(sliced_label.shape[0]):
            for j in range(sliced_label.shape[1]):
                label_small[i, j] = cv2.resize(sliced_label[i, j], (w_small, h_small),
                                               interpolation=cv2.INTER_AREA)

        time_name = os.path.basename(files[idx + dt])[:-4]
        path = save_dir + '/' + time_name
        os.makedirs(path + '/pred', exist_ok=True)
        os.makedirs(path + '/label', exist_ok=True)
        for i in range(pred_resized.shape[0]):
            for j in range(out_target_len):
                # Pixels <= 0.2 are made fully transparent via the alpha channel.
                alpha_mask = (label_small[i, j] > 0.2).astype(np.uint8) * 255
                lb = rainfall_shade(label_small[i, j], mode='BGR')
                lb = np.dstack((lb, alpha_mask))
                alpha_mask = (pred[i, j] > 0.2).astype(np.uint8) * 255
                prd = rainfall_shade(pred[i, j], mode='BGR')
                prd = np.dstack((prd, alpha_mask))
                cv2.imwrite(path + '/label/' + str(j) + '.png', lb)
                cv2.imwrite(path + '/pred/' + str(j) + '.png', prd)
        np.savetxt(path + '/metrics.txt',
                   [csi, micro_csi, np.mean(csi_multi)] + list(csi_multi) + list(mean_rmse),
                   delimiter=',', fmt='%.2f')

        # Slide both windows forward by one file. The input's time axis is
        # axis 2 (see above); clone() avoids an overlapping in-place copy,
        # which torch refuses.
        sliced_input[:, :, :-1] = sliced_input[:, :, 1:].clone()
        next_input = (mm_dbz(read_file(files[idx + dt], h, w, resize=True)) - norm_min) / norm_div
        sliced_input[:, :, -1] = torch.from_numpy(next_input).to(config['DEVICE'])
        sliced_label[:, :-1] = sliced_label[:, 1:]
        sliced_label[:, -1] = read_file(files[idx + out_target_len + dt])
def validate(self):
    """Evaluate on random validation batches, log results, and checkpoint.

    Draws ``n_val`` (20) batch indices with ``np.random.choice``, which
    samples with replacement. Over those batches it accumulates:

    * ``self.val_loss``: MSE + MAE loss divided by ``len(val_data)``;
    * ``self.val_metrics_value``: CSI values from ``fp_fn_image_csi_muti``.

    It then:

    1. Averages the validation accumulators over ``n_val`` and the training
       accumulators over ``self.interval_validate``.
    2. Writes the scalars and a shaded image of the last validation
       prediction/label to ``self.writer``.
    3. If ``self.val_loss <= self.best_val_loss``, saves the weights to
       ``model_best.pth`` and the epoch to ``best.txt`` in ``self.save_dir``.
    4. Resets the training accumulators.
    """
    self.model.eval()
    n_val_batch = self.data_loader.n_val_batch()
    n_val = 20
    self.val_loss = 0
    self.val_metrics_value[:] = 0
    for ib_val, b_val in enumerate(np.random.choice(n_val_batch, n_val)):
        self.pbar_i.set_description("Validating at batch %d / %d" % (ib_val, n_val))
        val_data, val_label = self.data_loader.get_val(b_val)
        with torch.no_grad():
            output = self.model(val_data)
            loss = self.mse_loss(output, val_label) + self.mae_loss(output, val_label)
        self.val_loss += loss.item() / len(val_data)
        lbl_pred = output.detach().cpu().numpy()
        lbl_true = val_label.cpu().numpy()
        csis, _ = fp_fn_image_csi_muti(denorm(lbl_pred), denorm(lbl_true))
        self.val_metrics_value += csis

    self.train_loss /= self.interval_validate
    self.train_metrics_value /= self.interval_validate
    self.val_loss /= n_val
    self.val_metrics_value /= n_val

    self.writer.add_scalars('loss', {
        'train': self.train_loss,
        'valid': self.val_loss
    }, self.epoch)
    for i, metric_name in enumerate(self.metrics_name):
        self.writer.add_scalars(metric_name, {
            'train': self.train_metrics_value[i],
            'valid': self.val_metrics_value[i]
        }, self.epoch)

    # Log the first frame of the last sample of the last validation batch.
    # NOTE(review): if rainfall_shade returns (H, W, 3), swapaxes(0, 2)
    # yields (3, W, H), a transposed image; transpose(2, 0, 1) would give
    # (3, H, W). Confirm rainfall_shade's output shape before changing this.
    self.writer.add_image('result/pred', rainfall_shade(denorm(lbl_pred[-1, 0, 0, :, :, None])).swapaxes(0, 2), self.epoch)
    self.writer.add_image('result/true', rainfall_shade(denorm(lbl_true[-1, 0, 0, :, :, None])).swapaxes(0, 2), self.epoch)

    if self.val_loss <= self.best_val_loss:
        # Wrappers such as nn.DataParallel expose the wrapped network as
        # `.module`. Saving its state_dict keeps checkpoint keys free of the
        # "module." prefix. Any real save error now propagates instead of
        # being swallowed by a bare except.
        net = getattr(self.model, 'module', self.model)
        torch.save(net.state_dict(), os.path.join(self.save_dir, 'model_best.pth'))
        self.best_val_loss = self.val_loss
        with open(os.path.join(self.save_dir, "best.txt"), "w") as file:
            file.write(str(self.epoch))

    self.train_loss = 0
    self.train_metrics_value[:] = 0
def test(model, data_loader, config, save_dir, files, file_name, crop=None):
    """Forecast from one start file and save metrics, GIFs and a CSI plot.

    ``file_name`` is looked up (by basename) in ``files`` and used as the
    first forecast frame.

    * Input window: the ``config['IN_LEN']`` preceding raw float32 files,
      resized to ``config['SCALE']`` resolution and normalised.
    * Label window: the next ``global_config['OUT_TARGET_LEN']`` files.

    The model is run autoregressively in the layout selected by
    ``config['DIM']`` ('3D' or '2D'). Predictions and labels are optionally
    cropped with ``get_crop_boundary_idx(crop)``.

    Outputs:

    * ``<save_dir>/result.txt``: ``[csi] + csi_multi + [rmse, rmse_rain,
      rmse_non_rain]``;
    * ``<save_dir>/csi.txt``: per-step CSI;
    * GIFs and ``csis.png`` under ``<save_dir>/imgs``.

    ``data_loader`` is unused.

    Returns None. Prints a message and returns early if ``file_name`` is not
    found or has fewer than ``IN_LEN`` files before it.

    NOTE(review): SOURCE also defines ``test(model, data_loader, config,
    save_dir, crop=None)``. If both live in the same module, the later
    definition rebinds the name ``test``.
    """
    data_h = global_config['DATA_HEIGHT']
    data_w = global_config['DATA_WIDTH']
    out_target_len = global_config['OUT_TARGET_LEN']
    h1, h2, w1, w2 = 0, data_h - 1, 0, data_w - 1  # inclusive crop bounds
    if crop is not None:
        h1, h2, w1, w2 = get_crop_boundary_idx(crop)
    try:
        idx = next(i for i, f in enumerate(files) if os.path.basename(f) == file_name)
    except StopIteration:
        print('not found')
        return
    in_len = config['IN_LEN']
    if idx < in_len:
        # files[idx - in_len] would otherwise silently wrap to the list end.
        print('not enough input frames before', file_name)
        return

    scale = config['SCALE']
    h = int(data_h * scale)
    w = int(data_w * scale)
    sliced_input = np.zeros((1, in_len, h, w), dtype=np.float32)
    sliced_label = np.zeros((1, out_target_len, data_h, data_w), dtype=np.float32)
    for i, j in enumerate(range(idx - in_len, idx)):
        frame = np.fromfile(files[j], dtype=np.float32).reshape((data_h, data_w))
        sliced_input[0, i] = cv2.resize(frame, (w, h), interpolation=cv2.INTER_AREA)
    for i, j in enumerate(range(idx, idx + out_target_len)):
        sliced_label[0, i] = np.fromfile(files[j], dtype=np.float32).reshape((data_h, data_w))
    sliced_input = (mm_dbz(sliced_input) - global_config['NORM_MIN']) / global_config['NORM_DIV']
    sliced_input = torch.from_numpy(sliced_input).to(config['DEVICE'])

    dim = config['DIM']
    if dim == '3D':
        sliced_input = sliced_input[:, None, :]   # (B, 1, T, h, w): time on axis 2
    elif dim == '2D':
        sliced_input = sliced_input[:, :, None]   # (B, T, 1, h, w): time on axis 1
    concat_axis = 2 if dim == '3D' else 1

    outputs = None
    with torch.no_grad():
        for _ in range(int(np.ceil(out_target_len / config['OUT_LEN']))):
            output = model(sliced_input)
            out_np = output.detach().cpu().numpy()
            if outputs is None:
                outputs = out_np
            else:
                outputs = np.concatenate([outputs, out_np], axis=concat_axis)
            if dim == '3D':
                if config['OPTFLOW']:
                    sliced_input = torch.cat([sliced_input[:, :, -1, None], output], dim=2)
                else:
                    sliced_input = output
            else:
                # Drop the oldest OUT_LEN frames and append the new predictions.
                sliced_input = torch.cat([sliced_input[:, config['OUT_LEN']:], output], dim=1)

    pred = np.array(outputs)
    # Drop the singleton channel axis. The second branch previously also
    # tested '3D', which made it unreachable and left 2D predictions 5-D.
    if dim == '3D':
        pred = pred[:, 0]
    elif dim == '2D':
        pred = pred[:, :, 0]
    pred = pred[:, :out_target_len]
    pred = denorm(pred)

    pred_resized = np.zeros((pred.shape[0], pred.shape[1], data_h, data_w))
    for i in range(pred.shape[0]):
        for j in range(pred.shape[1]):
            pred_resized[i, j] = cv2.resize(pred[i, j], (data_w, data_h),
                                            interpolation=cv2.INTER_AREA)
    pred_resized = pred_resized[:, :, h1: h2 + 1, w1: w2 + 1]
    sliced_label = sliced_label[:, :, h1: h2 + 1, w1: w2 + 1]

    # Per-step binary CSI, then averaged.
    csis = [fp_fn_image_csi(pred_resized[:, c], sliced_label[:, c]) for c in range(pred.shape[1])]
    csi = np.mean(csis)
    csi_multi, macro_csi = fp_fn_image_csi_muti(pred_resized, sliced_label)
    rmse, rmse_rain, rmse_non_rain = cal_rmse_all(pred_resized, sliced_label)
    result_all = [csi] + list(csi_multi) + [rmse, rmse_rain, rmse_non_rain]

    # Half-resolution copies for the GIFs.
    h_small = int(pred_resized.shape[2] * 0.5)
    w_small = int(pred_resized.shape[3] * 0.5)
    pred_small = np.zeros((sliced_label.shape[0], sliced_label.shape[1], h_small, w_small))
    label_small = np.zeros((sliced_label.shape[0], sliced_label.shape[1], h_small, w_small))
    for i in range(sliced_label.shape[0]):
        for j in range(sliced_label.shape[1]):
            pred_small[i, j] = cv2.resize(pred_resized[i, j], (w_small, h_small),
                                          interpolation=cv2.INTER_AREA)
            label_small[i, j] = cv2.resize(sliced_label[i, j], (w_small, h_small),
                                           interpolation=cv2.INTER_AREA)

    path = save_dir + '/imgs'
    os.makedirs(path, exist_ok=True)
    labels = [os.path.basename(files[idx + k]) for k in range(out_target_len)]
    for i in range(pred_resized.shape[0]):
        make_gif_color(pred_small[i], path + '/pred_colored.gif')
        make_gif_color(label_small[i], path + '/gt_colored.gif')
        make_gif_color_label(label_small[i], pred_small[i], labels, fname=path + '/all.gif')

    fig, ax = plt.subplots(figsize=(8, 4), facecolor='white')
    ax.plot(np.arange(len(csis)) + 1, csis)
    ax.set_xticks(np.arange(out_target_len) + 1)
    ax.set_ylabel('Binary - CSI')
    ax.set_xlabel('Time Steps')
    fig.savefig(path + '/csis.png')
    plt.close(fig)  # release the figure; otherwise it stays alive in pyplot

    result_all = np.around(np.array(result_all), decimals=3)
    np.savetxt(save_dir + '/result.txt', result_all, delimiter=',', fmt='%.3f')
    np.savetxt(save_dir + '/csi.txt', np.array(csis), delimiter=',', fmt='%.3f')
def test(model, data_loader, config, save_dir, crop=None):
    """Evaluate on evenly spaced test batches and save averaged metrics.

    Batches are taken from ``data_loader.get_test`` at start indices
    ``0, 3*OUT_TARGET_LEN, 6*OUT_TARGET_LEN, ...`` below
    ``n_test - 3*OUT_TARGET_LEN``. For each batch the model is run
    autoregressively, feeding back the last ``config['IN_LEN']`` outputs
    along axis 0.

    Outputs under ``<save_dir>/res``:

    * ``csi_time.txt``: per-step CSI averaged over batches;
    * ``result.txt``: the mean over batches of ``[csi, macro_csi] +
      csi_multi + [rmse, rmse_rain, rmse_non_rain]``;
    * one ``imgs/<b>.gif`` per batch.

    ``crop`` is unused.

    Raises:
        ValueError: if the test set is too short to yield any start index.
    """
    save_dir = save_dir + '/res'
    os.makedirs(save_dir, exist_ok=True)
    model.eval()

    out_target_len = global_config['OUT_TARGET_LEN']
    data_h = global_config['DATA_HEIGHT']
    data_w = global_config['DATA_WIDTH']
    stride = 3 * out_target_len
    n_test = data_loader.n_test_batch()
    test_idx = np.arange(0, n_test - stride, stride)
    if len(test_idx) == 0:
        # Without this guard the averages below divide by zero and
        # np.savetxt later fails on a 0-d array.
        raise ValueError('not enough test batches (%d) for OUT_TARGET_LEN=%d'
                         % (n_test, out_target_len))

    img_dir = save_dir + '/imgs'
    os.makedirs(img_dir, exist_ok=True)
    out_time = int(np.ceil(out_target_len / config['OUT_LEN']))
    blank_labels = ['' for _ in range(out_target_len)]
    result_all = []
    all_csis = np.zeros((out_target_len,))

    for b in tqdm(test_idx):
        data, label = data_loader.get_test(b)
        outputs = None
        with torch.no_grad():
            for _ in range(out_time):
                output = model(data)
                out_np = output.detach().cpu().numpy()
                if outputs is None:
                    outputs = out_np
                else:
                    outputs = np.concatenate([outputs, out_np], axis=0)
                # Feed the most recent IN_LEN predictions back as input.
                data = output[-config['IN_LEN']:]

        pred = outputs[:out_target_len, 0, 0][None, :]
        label = label[0, None]
        pred = denorm(pred)
        pred_resized = np.zeros((pred.shape[0], pred.shape[1], data_h, data_w))
        for i in range(pred.shape[0]):
            for j in range(pred.shape[1]):
                pred_resized[i, j] = cv2.resize(pred[i, j], (data_w, data_h),
                                                interpolation=cv2.INTER_AREA)

        # The label is not denormalised here, unlike pred.
        csi_time = [fp_fn_image_csi(pred_resized[:, t], label[:, t]) for t in range(out_target_len)]
        csi = np.mean(csi_time)
        all_csis += csi_time
        csi_multi, macro_csi = fp_fn_image_csi_muti(pred_resized, label)

        sum_rmse = np.zeros((3,), dtype=np.float32)
        for i in range(out_target_len):
            rmse, rmse_rain, rmse_non_rain = torch_cal_rmse_all(
                torch.from_numpy(pred_resized[:, i]).to(config['CAL_DEVICE']),
                torch.from_numpy(label[:, i]).to(config['CAL_DEVICE']))
            sum_rmse += np.array([
                rmse.cpu().numpy(),
                rmse_rain.cpu().numpy(),
                rmse_non_rain.cpu().numpy(),
            ])
        mean_rmse = sum_rmse / out_target_len
        result_all.append([csi, macro_csi] + list(csi_multi) + list(mean_rmse))

        # Labels downscaled to the model output resolution for the GIF.
        h_small, w_small = pred.shape[2], pred.shape[3]
        label_small = np.zeros((label.shape[0], label.shape[1], h_small, w_small))
        for i in range(label.shape[0]):
            for j in range(label.shape[1]):
                label_small[i, j] = cv2.resize(label[i, j], (w_small, h_small),
                                               interpolation=cv2.INTER_AREA)
        for i in range(pred_resized.shape[0]):
            make_gif_color_label(label_small[i], pred[i], blank_labels,
                                 fname=img_dir + '/{}.gif'.format(b))

    all_csis /= len(test_idx)
    np.savetxt(save_dir + '/csi_time.txt', all_csis, delimiter=',', fmt='%.3f')
    result_all_mean = np.around(np.mean(np.array(result_all), axis=0), decimals=3)
    np.savetxt(save_dir + '/result.txt', result_all_mean, delimiter=',', fmt='%.3f')