Esempio n. 1
0
    def train_iteration(self):
        """Run one optimisation pass over every training batch.

        For each batch the model is put in training mode, a forward/backward
        pass is done on ``mse_loss + mae_loss`` followed by ``optim.step()``,
        the loss divided by the batch length is accumulated into
        ``self.train_loss``, the multi-threshold CSI of the denormalised
        prediction vs. label is accumulated into ``self.train_metrics_value``,
        and finally ``self.add_epoch()`` is called.
        """
        total_batches = self.data_loader.n_train_batch()
        progress = tqdm(range(total_batches))
        for batch_idx in progress:
            # Re-enter training mode every batch; presumably add_epoch() can
            # switch the model to eval (e.g. by validating) -- confirm.
            self.model.train()
            progress.set_description('Training at batch %d / %d' % (batch_idx, total_batches))
            inputs, targets = self.data_loader.get_train(batch_idx)

            self.optim.zero_grad()
            prediction = self.model(inputs)
            batch_loss = (self.mse_loss(prediction, targets)
                          + self.mae_loss(prediction, targets))
            batch_loss.backward()
            self.optim.step()

            self.train_loss += batch_loss.item() / len(inputs)

            # Metrics are computed on numpy copies in denormalised units.
            pred_np = prediction.detach().cpu().numpy()
            true_np = targets.cpu().numpy()
            csis, _ = fp_fn_image_csi_muti(denorm(pred_np), denorm(true_np))
            self.train_metrics_value += csis
            self.add_epoch()
Esempio n. 2
0
def extract(model, config, save_dir, files, file_name, crop=None):
    """Roll the forecast forward one frame at a time and dump images + metrics.

    Starting at the frame whose basename equals ``file_name``, the model is run
    for ``6 * 24`` consecutive start times. For each start time:

    * the prediction is upsampled to full resolution and scored against the
      label window (binary CSI, multi-threshold CSI, per-step RMSE averaged);
    * label and prediction frames are rendered to PNG under
      ``<save_dir>/extracted/<time_name>/{label,pred}/<step>.png`` and the
      metrics are written to ``<time_name>/metrics.txt``;
    * the input and label windows slide forward by one frame.

    Args:
        model: forecasting network, called as ``model(x)``.
        config: run configuration; reads ``SCALE``, ``IN_LEN``, ``OUT_LEN``,
            ``DEVICE``, ``CAL_DEVICE`` and ``OPTFLOW``.
        save_dir: output root directory.
        files: ordered list of frame file paths.
        file_name: basename of the first target (label) frame.
        crop: unused; kept for interface compatibility.

    Returns:
        None. Prints a message and returns early when ``file_name`` is not in
        ``files`` or when fewer than ``IN_LEN`` frames precede it.
    """
    try:
        idx = next(i for i, f in enumerate(files)
                   if os.path.basename(f) == file_name)
    except StopIteration:
        print('not found')
        return
    if idx < config['IN_LEN']:
        # Negative indices would silently wrap around to the end of `files`.
        print('not enough history before', file_name)
        return

    full_h = global_config['DATA_HEIGHT']
    full_w = global_config['DATA_WIDTH']
    out_target_len = global_config['OUT_TARGET_LEN']
    scale = config['SCALE']
    h = int(full_h * scale)
    w = int(full_w * scale)

    sliced_input = np.zeros((1, config['IN_LEN'], h, w), dtype=np.float32)
    sliced_label = np.zeros((1, out_target_len, full_h, full_w),
                            dtype=np.float32)
    for i, j in enumerate(range(idx - config['IN_LEN'], idx)):
        sliced_input[0, i] = read_file(files[j], h, w, resize=True)
    for i, j in enumerate(range(idx, idx + out_target_len)):
        sliced_label[0, i] = read_file(files[j])

    sliced_input = (mm_dbz(sliced_input) -
                    global_config['NORM_MIN']) / global_config['NORM_DIV']
    # (1, IN_LEN, h, w) -> (1, 1, IN_LEN, h, w): time is axis 2 from here on.
    sliced_input = torch.from_numpy(sliced_input).to(config['DEVICE'])[:, None]

    save_dir = save_dir + '/extracted'
    os.makedirs(save_dir, exist_ok=True)
    model.eval()
    out_time = int(np.ceil(out_target_len / config['OUT_LEN']))
    for dt in tqdm(range(6 * 24)):
        # Autoregressive rollout until at least OUT_TARGET_LEN steps exist.
        cur_input = sliced_input.clone()
        outputs = None
        with torch.no_grad():
            for _ in range(out_time):
                output = model(cur_input)
                out_np = output.detach().cpu().numpy()
                if outputs is None:
                    outputs = out_np
                else:
                    outputs = np.concatenate([outputs, out_np], axis=2)
                if config['OPTFLOW']:
                    cur_input = torch.cat([cur_input[:, :, -1, None], output],
                                          dim=2)
                else:
                    cur_input = output
        pred = denorm(outputs[:, 0, :out_target_len])

        pred_resized = np.zeros((pred.shape[0], pred.shape[1], full_h, full_w))
        for i in range(pred.shape[0]):
            for j in range(pred.shape[1]):
                pred_resized[i, j] = cv2.resize(pred[i, j], (full_w, full_h),
                                                interpolation=cv2.INTER_AREA)

        csi = fp_fn_image_csi(pred_resized, sliced_label)
        csi_multi, micro_csi = fp_fn_image_csi_muti(pred_resized, sliced_label)
        sum_rmse = np.zeros((3, ), dtype=np.float32)
        for i in range(out_target_len):
            rmse, rmse_rain, rmse_non_rain = torch_cal_rmse_all(
                torch.from_numpy(pred_resized[:, i]).to(config['CAL_DEVICE']),
                torch.from_numpy(sliced_label[:, i]).to(config['CAL_DEVICE']))
            sum_rmse += np.array([
                rmse.cpu().numpy(),
                rmse_rain.cpu().numpy(),
                rmse_non_rain.cpu().numpy()
            ])
        mean_rmse = sum_rmse / out_target_len

        # Labels are downsampled to the model's output size for rendering.
        h_small = pred.shape[2]
        w_small = pred.shape[3]
        label_small = np.zeros(
            (sliced_label.shape[0], sliced_label.shape[1], h_small, w_small))
        for i in range(sliced_label.shape[0]):
            for j in range(sliced_label.shape[1]):
                label_small[i, j] = cv2.resize(sliced_label[i, j],
                                               (w_small, h_small),
                                               interpolation=cv2.INTER_AREA)

        time_name = os.path.basename(files[idx + dt])[:-4]
        path = save_dir + '/' + time_name
        os.makedirs(path + '/pred', exist_ok=True)
        os.makedirs(path + '/label', exist_ok=True)
        for i in range(pred_resized.shape[0]):
            for j in range(out_target_len):
                # Alpha channel hides pixels at or below 0.2.
                alpha_mask = (label_small[i, j] > 0.2).astype(np.uint8) * 255
                lb = np.dstack((rainfall_shade(label_small[i, j], mode='BGR'),
                                alpha_mask))

                alpha_mask = (pred[i, j] > 0.2).astype(np.uint8) * 255
                prd = np.dstack((rainfall_shade(pred[i, j], mode='BGR'),
                                 alpha_mask))

                cv2.imwrite(path + '/label/' + str(j) + '.png', lb)
                cv2.imwrite(path + '/pred/' + str(j) + '.png', prd)
        # The metrics do not depend on i/j, so write them once per start time.
        np.savetxt(path + '/metrics.txt',
                   [csi, micro_csi, np.mean(csi_multi)] +
                   list(csi_multi) + list(mean_rmse),
                   delimiter=',',
                   fmt='%.2f')

        # Slide the input window one frame forward along the TIME axis (2).
        # clone() because source and destination overlap in memory.
        sliced_input[:, :, :-1] = sliced_input[:, :, 1:].clone()
        next_input = (mm_dbz(read_file(files[idx + dt], h, w, resize=True)) -
                      global_config['NORM_MIN']) / global_config['NORM_DIV']
        sliced_input[:, :, -1] = torch.from_numpy(next_input).to(config['DEVICE'])

        # numpy handles the overlapping slice assignment correctly.
        sliced_label[:, :-1] = sliced_label[:, 1:]
        sliced_label[:, -1] = read_file(files[idx + out_target_len + dt])
Esempio n. 3
0
    def validate(self):
        """Score a random sample of validation batches, log, and checkpoint.

        ``n_val`` (20) batch indices are drawn with ``np.random.choice``, which
        samples with replacement by default. The validation loss
        (``mse_loss + mae_loss`` divided by the batch length) and the
        multi-threshold CSI are averaged over those draws. The training
        accumulators are averaged over ``self.interval_validate``. Both are
        written to ``self.writer``, along with a sample prediction /
        ground-truth image from the last validation batch.

        When the validation loss is <= ``self.best_val_loss``, the weights are
        saved to ``model_best.pth`` and the epoch to ``best.txt``. The training
        accumulators are reset at the end.
        """
        self.model.eval()
        n_val_batch = self.data_loader.n_val_batch()
        n_val = 20
        self.val_loss = 0
        self.val_metrics_value[:] = 0
        for ib_val, b_val in enumerate(np.random.choice(n_val_batch, n_val)):
            self.pbar_i.set_description("Validating at batch %d / %d" % (ib_val, n_val))
            val_data, val_label = self.data_loader.get_val(b_val)
            with torch.no_grad():
                output = self.model(val_data)
                loss = self.mse_loss(output, val_label) + self.mae_loss(output, val_label)
            self.val_loss += loss.item() / len(val_data)

            lbl_pred = output.detach().cpu().numpy()
            lbl_true = val_label.cpu().numpy()
            csis, _ = fp_fn_image_csi_muti(denorm(lbl_pred), denorm(lbl_true))
            self.val_metrics_value += csis

        self.train_loss /= self.interval_validate
        self.train_metrics_value /= self.interval_validate
        self.val_loss /= n_val
        self.val_metrics_value /= n_val
        self.writer.add_scalars('loss', {
            'train': self.train_loss,
            'valid': self.val_loss
        }, self.epoch)
        for i, metric_name in enumerate(self.metrics_name):
            self.writer.add_scalars(metric_name, {
                'train': self.train_metrics_value[i],
                'valid': self.val_metrics_value[i]
            }, self.epoch)

        # rainfall_shade output is presumably (H, W, C), with axis 2 as
        # channels. add_image expects CHW by default, so move channels first
        # while keeping H before W. swapaxes(0, 2) would give (C, W, H) and log
        # a transposed image.
        pred_img = rainfall_shade(denorm(lbl_pred[-1, 0, 0, :, :, None]))
        self.writer.add_image('result/pred', np.transpose(pred_img, (2, 0, 1)),
                              self.epoch)
        true_img = rainfall_shade(denorm(lbl_true[-1, 0, 0, :, :, None]))
        self.writer.add_image('result/true', np.transpose(true_img, (2, 0, 1)),
                              self.epoch)

        if self.val_loss <= self.best_val_loss:
            # Unwrap a container exposing `.module` (e.g. nn.DataParallel) so
            # the saved keys carry no 'module.' prefix. getattr is used instead
            # of a bare try/except so that genuine torch.save errors propagate.
            net = getattr(self.model, 'module', self.model)
            torch.save(net.state_dict(), os.path.join(self.save_dir,
                                                      'model_best.pth'))
            self.best_val_loss = self.val_loss
            with open(os.path.join(self.save_dir, "best.txt"), "w") as fh:
                fh.write(str(self.epoch))

        self.train_loss = 0
        self.train_metrics_value[:] = 0
Esempio n. 4
0
def test(model, data_loader, config, save_dir, files, file_name, crop=None):
    """Forecast from one start frame, score it, and save visualisations.

    The ``IN_LEN`` frames preceding ``file_name`` are read, downscaled by
    ``SCALE``, normalised, and fed to the model autoregressively until
    ``OUT_TARGET_LEN`` steps are predicted. The prediction is upsampled back to
    full resolution and optionally cropped. It is then scored against the
    following ``OUT_TARGET_LEN`` frames: binary CSI per step, multi-threshold
    CSI, and RMSE.

    Outputs:

    * GIFs in ``<save_dir>/imgs``;
    * a CSI-per-step plot at ``<save_dir>/imgs/csis.png``;
    * ``<save_dir>/result.txt`` (mean CSI, multi-threshold CSIs, RMSEs);
    * ``<save_dir>/csi.txt`` (per-step CSI).

    Args:
        model: forecasting network, called as ``model(x)``.
        data_loader: unused; kept for interface compatibility.
        config: run configuration; reads ``SCALE``, ``IN_LEN``, ``OUT_LEN``,
            ``DEVICE``, ``DIM`` ('2D' or '3D') and ``OPTFLOW``.
        save_dir: output root directory.
        files: ordered list of raw float32 frame file paths.
        file_name: basename of the first target frame.
        crop: optional crop spec passed to ``get_crop_boundary_idx``.

    Returns:
        None. Prints a message and returns early when ``file_name`` is not in
        ``files`` or when fewer than ``IN_LEN`` frames precede it.
    """
    full_h = global_config['DATA_HEIGHT']
    full_w = global_config['DATA_WIDTH']
    out_target_len = global_config['OUT_TARGET_LEN']

    # Inclusive crop bounds; the default covers the whole frame.
    h1, h2 = 0, full_h - 1
    w1, w2 = 0, full_w - 1
    if crop is not None:
        h1, h2, w1, w2 = get_crop_boundary_idx(crop)

    try:
        idx = next(i for i, f in enumerate(files)
                   if os.path.basename(f) == file_name)
    except StopIteration:
        print('not found')
        return
    if idx < config['IN_LEN']:
        # Negative indices would silently wrap around to the end of `files`.
        print('not enough history before', file_name)
        return

    scale = config['SCALE']
    h = int(full_h * scale)
    w = int(full_w * scale)
    sliced_input = np.zeros((1, config['IN_LEN'], h, w), dtype=np.float32)
    sliced_label = np.zeros((1, out_target_len, full_h, full_w),
                            dtype=np.float32)

    for i, j in enumerate(range(idx - config['IN_LEN'], idx)):
        frame = np.fromfile(files[j], dtype=np.float32).reshape((full_h, full_w))
        sliced_input[0, i] = cv2.resize(frame, (w, h),
                                        interpolation=cv2.INTER_AREA)
    for i, j in enumerate(range(idx, idx + out_target_len)):
        sliced_label[0, i] = np.fromfile(files[j], dtype=np.float32) \
            .reshape((full_h, full_w))

    sliced_input = (mm_dbz(sliced_input) -
                    global_config['NORM_MIN']) / global_config['NORM_DIV']
    sliced_input = torch.from_numpy(sliced_input).to(config['DEVICE'])
    if config['DIM'] == '3D':
        sliced_input = sliced_input[:, None, :]  # (1, 1, IN_LEN, h, w)
    elif config['DIM'] == '2D':
        sliced_input = sliced_input[:, :, None]  # (1, IN_LEN, 1, h, w)

    # Time axis of the model input/output for each layout.
    time_axis = 2 if config['DIM'] == '3D' else 1

    model.eval()
    outputs = None
    with torch.no_grad():
        for _ in range(int(np.ceil(out_target_len / config['OUT_LEN']))):
            output = model(sliced_input)
            out_np = output.detach().cpu().numpy()
            if outputs is None:
                outputs = out_np
            else:
                outputs = np.concatenate([outputs, out_np], axis=time_axis)

            if config['DIM'] == '3D':
                if config['OPTFLOW']:
                    sliced_input = torch.cat([sliced_input[:, :, -1, None], output],
                                             dim=2)
                else:
                    sliced_input = output
            else:
                # Drop the oldest OUT_LEN frames and append the new prediction.
                sliced_input = torch.cat([sliced_input[:, config['OUT_LEN']:], output],
                                         dim=1)

    # Remove the singleton channel axis so pred is (batch, time, h, w).
    pred = np.array(outputs)
    if config['DIM'] == '3D':
        pred = pred[:, 0]
    elif config['DIM'] == '2D':
        pred = pred[:, :, 0]
    pred = denorm(pred[:, :out_target_len])

    pred_resized = np.zeros((pred.shape[0], pred.shape[1], full_h, full_w))
    for i in range(pred.shape[0]):
        for j in range(pred.shape[1]):
            pred_resized[i, j] = cv2.resize(pred[i, j], (full_w, full_h),
                                            interpolation=cv2.INTER_AREA)

    pred_resized = pred_resized[:, :, h1: h2 + 1, w1: w2 + 1]
    sliced_label = sliced_label[:, :, h1: h2 + 1, w1: w2 + 1]
    csis = [fp_fn_image_csi(pred_resized[:, c], sliced_label[:, c])
            for c in range(pred.shape[1])]
    csi = np.mean(csis)
    csi_multi, _ = fp_fn_image_csi_muti(pred_resized, sliced_label)
    rmse, rmse_rain, rmse_non_rain = cal_rmse_all(pred_resized, sliced_label)
    result_all = [csi] + list(csi_multi) + [rmse, rmse_rain, rmse_non_rain]

    # Half-resolution copies for the GIFs.
    h_small = int(pred_resized.shape[2] * 0.5)
    w_small = int(pred_resized.shape[3] * 0.5)
    pred_small = np.zeros(
        (sliced_label.shape[0], sliced_label.shape[1], h_small, w_small))
    label_small = np.zeros(
        (sliced_label.shape[0], sliced_label.shape[1], h_small, w_small))
    for i in range(sliced_label.shape[0]):
        for j in range(sliced_label.shape[1]):
            pred_small[i, j] = cv2.resize(
                pred_resized[i, j], (w_small, h_small), interpolation=cv2.INTER_AREA)
            label_small[i, j] = cv2.resize(
                sliced_label[i, j], (w_small, h_small), interpolation=cv2.INTER_AREA)

    path = save_dir + '/imgs'
    os.makedirs(path, exist_ok=True)
    frame_names = [os.path.basename(files[idx + k]) for k in range(out_target_len)]
    for i in range(pred_resized.shape[0]):
        make_gif_color(pred_small[i], path + '/pred_colored.gif')
        make_gif_color(label_small[i], path + '/gt_colored.gif')
        make_gif_color_label(label_small[i], pred_small[i], frame_names,
                             fname=path + '/all.gif')

    fig, ax = plt.subplots(figsize=(8, 4), facecolor='white')
    ax.plot(np.arange(len(csis)) + 1, csis)
    ax.set_xticks(np.arange(out_target_len) + 1)
    ax.set_ylabel('Binary - CSI')
    ax.set_xlabel('Time Steps')
    fig.savefig(path + '/csis.png')
    plt.close(fig)  # release the figure; pyplot otherwise keeps it alive

    result_all = np.around(np.array(result_all), decimals=3)
    np.savetxt(save_dir + '/result.txt', result_all, delimiter=',', fmt='%.3f')
    np.savetxt(save_dir + '/csi.txt', np.array(csis), delimiter=',', fmt='%.3f')
Esempio n. 5
0
def test(model, data_loader, config, save_dir, crop=None):
    """Evaluate the model on evenly spaced test batches and save results.

    Every ``3 * OUT_TARGET_LEN``-th test batch is forecast autoregressively,
    upsampled to full resolution, and scored (per-step binary CSI,
    multi-threshold CSI, per-step RMSE averaged).

    Outputs:

    * a GIF comparing label and prediction per batch, at
      ``<save_dir>/res/imgs/<b>.gif``;
    * the per-step CSI averaged over batches, at ``res/csi_time.txt``;
    * the mean of every per-batch metric row, at ``res/result.txt``.

    Args:
        model: forecasting network, called as ``model(data)``.
        data_loader: provides ``n_test_batch()`` and ``get_test(b)``.
        config: run configuration; reads ``OUT_LEN``, ``IN_LEN`` and
            ``CAL_DEVICE``.
        save_dir: output root directory.
        crop: unused; kept for interface compatibility.

    Returns:
        None. Prints a message and returns early when there are too few test
        batches to evaluate.
    """
    save_dir = save_dir + '/res'
    os.makedirs(save_dir, exist_ok=True)
    model.eval()

    out_target_len = global_config['OUT_TARGET_LEN']
    stride = 3 * out_target_len
    n_test = data_loader.n_test_batch()
    test_idx = np.arange(0, n_test - stride, stride)
    if len(test_idx) == 0:
        # Averaging over zero batches would produce NaNs, and np.savetxt
        # rejects the resulting 0-d mean.
        print('not enough test batches')
        return

    path = save_dir + '/imgs'
    os.makedirs(path, exist_ok=True)
    out_time = int(np.ceil(out_target_len / config['OUT_LEN']))
    result_all = []
    all_csis = np.zeros((out_target_len, ))
    for b in tqdm(test_idx):
        data, label = data_loader.get_test(b)
        outputs = None
        with torch.no_grad():
            for _ in range(out_time):
                output = model(data)
                out_np = output.detach().cpu().numpy()
                if outputs is None:
                    outputs = out_np
                else:
                    outputs = np.concatenate([outputs, out_np], axis=0)
                # Feed back the last IN_LEN entries along axis 0; presumably
                # axis 0 is time here (outputs are concatenated along it too).
                data = output[-config['IN_LEN']:]
        pred = np.array(outputs)[:out_target_len, 0, 0][None, :]
        label = label[0, None]
        pred = denorm(pred)
        pred_resized = np.zeros(
            (pred.shape[0], pred.shape[1], global_config['DATA_HEIGHT'],
             global_config['DATA_WIDTH']))
        for i in range(pred.shape[0]):
            for j in range(pred.shape[1]):
                pred_resized[i, j] = cv2.resize(pred[i, j],
                                                (global_config['DATA_WIDTH'],
                                                 global_config['DATA_HEIGHT']),
                                                interpolation=cv2.INTER_AREA)
        # Labels are compared as-is; only the prediction is denormalised.
        csi_time = [fp_fn_image_csi(pred_resized[:, t], label[:, t])
                    for t in range(out_target_len)]
        csi = np.mean(csi_time)
        all_csis += csi_time
        csi_multi, macro_csi = fp_fn_image_csi_muti(pred_resized, label)
        sum_rmse = np.zeros((3, ), dtype=np.float32)
        for i in range(out_target_len):
            rmse, rmse_rain, rmse_non_rain = torch_cal_rmse_all(
                torch.from_numpy(pred_resized[:, i]).to(config['CAL_DEVICE']),
                torch.from_numpy(label[:, i]).to(config['CAL_DEVICE']))
            sum_rmse += np.array([
                rmse.cpu().numpy(),
                rmse_rain.cpu().numpy(),
                rmse_non_rain.cpu().numpy()
            ])
        mean_rmse = sum_rmse / out_target_len
        result_all.append([csi, macro_csi] + list(csi_multi) + list(mean_rmse))

        # Labels are downsampled to the model's output size for the GIF.
        h_small = pred.shape[2]
        w_small = pred.shape[3]
        label_small = np.zeros(
            (label.shape[0], label.shape[1], h_small, w_small))
        for i in range(label.shape[0]):
            for j in range(label.shape[1]):
                label_small[i, j] = cv2.resize(label[i, j], (w_small, h_small),
                                               interpolation=cv2.INTER_AREA)

        for i in range(pred_resized.shape[0]):
            make_gif_color_label(label_small[i],
                                 pred[i],
                                 [''] * out_target_len,
                                 fname=path + '/{}.gif'.format(b))

    all_csis /= len(test_idx)
    np.savetxt(save_dir + '/csi_time.txt', all_csis, delimiter=',', fmt='%.3f')

    result_all_mean = np.around(np.mean(np.array(result_all), axis=0),
                                decimals=3)
    np.savetxt(save_dir + '/result.txt',
               result_all_mean,
               delimiter=',',
               fmt='%.3f')