Ejemplo n.º 1
0
def predict(input, model, config=None):
    """Run ``model`` on a single input sequence and return clipped output.

    ``input`` is a NumPy array whose dim 0 must equal ``config['IN_LEN']`` and
    whose dims 2 and 3 must both equal ``global_config['IMG_SIZE']``. A
    singleton axis is inserted at position 2, the array is moved to
    ``config['DEVICE']`` and passed through ``model`` with autograd disabled.

    Returns:
        np.ndarray: the model output with index 0 taken on axis 2, clipped
        element-wise to ``[mm_dbz(0), mm_dbz(60)]``.

    NOTE(review): ``config`` defaults to ``None`` but is subscripted
    unconditionally, so callers must always supply it.
    """
    assert input.shape[0] == config['IN_LEN']
    side = global_config['IMG_SIZE']
    assert input.shape[2] == side
    assert input.shape[3] == side

    with torch.no_grad():
        batch = torch.from_numpy(input[:, :, None]).to(config['DEVICE'])
        result = model(batch)

    assert result.shape[0] == config['OUT_LEN']
    assert result.shape[3] == side
    assert result.shape[4] == side

    frames = result.cpu().numpy()[:, :, 0]
    floor, ceiling = mm_dbz(0), mm_dbz(60)
    return np.minimum(np.maximum(frames, floor), ceiling)
Ejemplo n.º 2
0
    def get_data_test(self, indices):
        """Build one test batch of (input, label) pairs for ``indices``.

        For each start index ``idx`` in ``indices``:
          * inputs are ``self.files[idx : idx + IN_LEN]``, each area-resized
            to ``(h, w)`` and normalised with
            ``(mm_dbz(x) - NORM_MIN) / NORM_DIV``;
          * labels are the ``OUT_TARGET_LEN`` frames that immediately follow
            the input window, ``self.files[idx + IN_LEN : idx + IN_LEN +
            OUT_TARGET_LEN]``, kept at full resolution and unnormalised.

        ``(h, w)`` is ``(SIZEH, SIZEW)`` from ``self.config`` when ``SCALE`` is
        ``None``; otherwise it is the full data size multiplied by ``SCALE``.

        The input is returned as a tensor on ``self.config['DEVICE']``; the
        label stays a NumPy array. When ``DIM`` is ``'3D'`` a singleton axis is
        inserted at position 1 of both, when ``'2D'`` at position 2.

        Returns:
            tuple: ``(input_tensor, label_array)``; the same two objects are
            also kept in ``self.last_data``.
        """
        data_h = global_config['DATA_HEIGHT']
        data_w = global_config['DATA_WIDTH']
        in_len = self.config['IN_LEN']
        out_len = global_config['OUT_TARGET_LEN']

        if self.config['SCALE'] is None:
            h = self.config['SIZEH']
            w = self.config['SIZEW']
        else:
            scale = self.config['SCALE']
            h = int(data_h * scale)
            w = int(data_w * scale)

        sliced_input = np.zeros((len(indices), in_len, h, w),
                                dtype=np.float32)
        sliced_label = np.zeros((len(indices), out_len, data_h, data_w),
                                dtype=np.float32)
        for i, idx in enumerate(indices):
            for j in range(in_len):
                f = np.fromfile(self.files[idx + j], dtype=np.float32) \
                    .reshape((data_h, data_w))
                sliced_input[i, j] = cv2.resize(f, (w, h),
                                                interpolation=cv2.INTER_AREA)
            # Labels are the frames right after the input window; reading
            # them from ``idx + j`` would just duplicate the inputs.
            for j in range(out_len):
                sliced_label[i, j] = np.fromfile(
                    self.files[idx + in_len + j], dtype=np.float32) \
                    .reshape((data_h, data_w))

        sliced_input = (mm_dbz(sliced_input) -
                        global_config['NORM_MIN']) / global_config['NORM_DIV']

        # Drop this object's reference to the previous batch *before*
        # emptying the CUDA cache. ``del`` on a loop variable frees nothing,
        # so the old tensor used to still be alive when empty_cache() ran.
        if self.last_data is not None:
            self.last_data = None
            torch.cuda.empty_cache()

        data = [torch.from_numpy(sliced_input).to(self.config['DEVICE']),
                sliced_label]
        if self.config['DIM'] == '3D':
            data = [d[:, None, :] for d in data]
        elif self.config['DIM'] == '2D':
            data = [d[:, :, None] for d in data]

        self.last_data = data
        return tuple(data)
Ejemplo n.º 3
0
    def get_data(self, indices):
        """Load, resize and normalise sliding windows of consecutive frames.

        For each start index in ``indices``, the ``self.windows_size`` files
        beginning there are read as float32 grids of
        ``(DATA_HEIGHT, DATA_WIDTH)``, area-resized to ``(h, w)`` and stacked.
        ``(h, w)`` comes from ``SIZEH``/``SIZEW`` when ``SCALE`` is ``None``,
        otherwise from the full data size multiplied by ``SCALE``.

        Returns:
            np.ndarray: ``(mm_dbz(x) - NORM_MIN) / NORM_DIV`` of the stack,
            indexed ``[:, :, 6:-6, 1:-1]`` (6 rows trimmed top and bottom,
            1 column trimmed left and right).
        """
        scale = self.config['SCALE']
        if scale is not None:
            h = int(global_config['DATA_HEIGHT'] * scale)
            w = int(global_config['DATA_WIDTH'] * scale)
        else:
            h, w = self.config['SIZEH'], self.config['SIZEW']

        windows = np.zeros((len(indices), self.windows_size, h, w),
                           dtype=np.float32)
        for n, start in enumerate(indices):
            for k in range(self.windows_size):
                raw = np.fromfile(self.files[start + k], dtype=np.float32)
                grid = raw.reshape((global_config['DATA_HEIGHT'],
                                    global_config['DATA_WIDTH']))
                windows[n, k] = cv2.resize(grid, (w, h),
                                           interpolation=cv2.INTER_AREA)

        normalised = (mm_dbz(windows) - global_config['NORM_MIN']) \
            / global_config['NORM_DIV']
        # Fixed border trim: 6 rows at top/bottom, 1 column at left/right.
        return normalised[:, :, 6:-6, 1:-1]
Ejemplo n.º 4
0
def extract(model, config, save_dir, files, file_name, crop=None):
    """Roll a forecast forward one frame at a time and dump images + metrics.

    Starting at the file whose basename equals ``file_name`` (index ``idx``
    in ``files``), the previous ``config['IN_LEN']`` frames form the model
    input and the next ``global_config['OUT_TARGET_LEN']`` frames are the
    label. For each of up to ``6 * 24`` steps, the model runs
    autoregressively until ``OUT_TARGET_LEN`` frames are predicted. The
    prediction is scored against the label (CSI and RMSE), and label and
    prediction PNGs plus a ``metrics.txt`` are written under
    ``save_dir/extracted/<time_name>/``. The input and label windows then
    advance by one frame.

    Args:
        model: network called as ``model(x)`` on a tensor built here as
            ``(1, 1, IN_LEN, h, w)``.
        config: dict providing ``SCALE``, ``IN_LEN``, ``OUT_LEN``, ``DEVICE``,
            ``CAL_DEVICE`` and ``OPTFLOW``.
        save_dir: output root directory.
        files: ordered list of frame file paths.
        file_name: basename of the first label frame.
        crop: unused; kept for interface compatibility.

    Returns:
        None. Prints a message and returns early when ``file_name`` is not in
        ``files`` or when there are not enough frames around it. The number
        of steps is capped so that no file past the end of ``files`` is read.
    """
    try:
        idx = next(i for i, f in enumerate(files)
                   if os.path.basename(f) == file_name)
    except StopIteration:
        print('not found')
        return

    in_len = config['IN_LEN']
    out_target_len = global_config['OUT_TARGET_LEN']
    data_h = global_config['DATA_HEIGHT']
    data_w = global_config['DATA_WIDTH']
    if idx < in_len or idx + out_target_len > len(files):
        # Negative indices would silently wrap to the end of ``files``.
        print('not enough frames around', file_name)
        return

    scale = config['SCALE']
    h = int(data_h * scale)
    w = int(data_w * scale)
    sliced_input = np.zeros((1, in_len, h, w), dtype=np.float32)
    sliced_label = np.zeros((1, out_target_len, data_h, data_w),
                            dtype=np.float32)

    for i, j in enumerate(range(idx - in_len, idx)):
        sliced_input[0, i] = read_file(files[j], h, w, resize=True)

    for i, j in enumerate(range(idx, idx + out_target_len)):
        sliced_label[0, i] = read_file(files[j])

    sliced_input = (mm_dbz(sliced_input) -
                    global_config['NORM_MIN']) / global_config['NORM_DIV']
    # Layout (1, 1, IN_LEN, h, w): the frame sequence lives on axis 2.
    sliced_input = torch.from_numpy(sliced_input).to(config['DEVICE'])[:, None]

    save_dir = save_dir + '/extracted'
    os.makedirs(save_dir, exist_ok=True)

    # Number of model calls needed to cover OUT_TARGET_LEN frames.
    out_time = int(np.ceil(out_target_len / config['OUT_LEN']))
    # Step dt labels frames idx+dt .. idx+dt+OUT_TARGET_LEN-1, so the last
    # possible step is len(files) - idx - OUT_TARGET_LEN.
    n_steps = min(6 * 24, len(files) - idx - out_target_len + 1)

    model.eval()
    for dt in tqdm(range(n_steps)):
        cur_input = sliced_input.clone()
        chunks = []
        with torch.no_grad():
            for _ in range(out_time):
                output = model(cur_input)
                chunks.append(output.detach().cpu().numpy())
                if config['OPTFLOW']:
                    cur_input = torch.cat([cur_input[:, :, -1, None], output],
                                          axis=2)
                else:
                    cur_input = output
        outputs = np.concatenate(chunks, axis=2)
        pred = outputs[:, 0, :out_target_len]
        pred = denorm(pred)
        pred_resized = np.zeros(
            (pred.shape[0], pred.shape[1], data_h, data_w))
        for i in range(pred.shape[0]):
            for j in range(pred.shape[1]):
                pred_resized[i, j] = cv2.resize(pred[i, j], (data_w, data_h),
                                                interpolation=cv2.INTER_AREA)

        csi = fp_fn_image_csi(pred_resized, sliced_label)
        csi_multi, micro_csi = fp_fn_image_csi_muti(pred_resized, sliced_label)
        sum_rmse = np.zeros((3, ), dtype=np.float32)
        for t in range(out_target_len):
            rmse, rmse_rain, rmse_non_rain = torch_cal_rmse_all(
                torch.from_numpy(pred_resized[:, t]).to(config['CAL_DEVICE']),
                torch.from_numpy(sliced_label[:, t]).to(config['CAL_DEVICE']))
            sum_rmse += np.array([
                rmse.cpu().numpy(),
                rmse_rain.cpu().numpy(),
                rmse_non_rain.cpu().numpy()
            ])
        mean_rmse = sum_rmse / out_target_len

        # Downsample the label to the prediction's native resolution for
        # the PNG output.
        h_small = pred.shape[2]
        w_small = pred.shape[3]
        label_small = np.zeros(
            (sliced_label.shape[0], sliced_label.shape[1], h_small, w_small))
        for i in range(sliced_label.shape[0]):
            for j in range(sliced_label.shape[1]):
                label_small[i, j] = cv2.resize(sliced_label[i, j],
                                               (w_small, h_small),
                                               interpolation=cv2.INTER_AREA)

        time_name = os.path.basename(files[idx + dt])[:-4]
        path = save_dir + '/' + time_name
        os.makedirs(path + '/pred', exist_ok=True)
        os.makedirs(path + '/label', exist_ok=True)
        for i in range(pred_resized.shape[0]):
            for j in range(out_target_len):
                # Alpha channel: opaque only where the value exceeds 0.2.
                alpha_mask = (label_small[i, j] > 0.2).astype(np.uint8) * 255
                lb = rainfall_shade(label_small[i, j], mode='BGR')
                lb = np.dstack((lb, alpha_mask))

                alpha_mask = (pred[i, j] > 0.2).astype(np.uint8) * 255
                prd = rainfall_shade(pred[i, j], mode='BGR')
                prd = np.dstack((prd, alpha_mask))

                cv2.imwrite(path + '/label/' + str(j) + '.png', lb)
                cv2.imwrite(path + '/pred/' + str(j) + '.png', prd)
        # Metrics do not depend on (i, j); write them once per step.
        np.savetxt(path + '/metrics.txt',
                   [csi, micro_csi, np.mean(csi_multi)] +
                   list(csi_multi) + list(mean_rmse),
                   delimiter=',',
                   fmt='%.2f')

        if dt + 1 == n_steps:
            break

        # Advance the input window along the frame axis (axis 2): drop the
        # oldest frame and append files[idx + dt], the frame just forecast.
        next_input = (mm_dbz(read_file(files[idx + dt], h, w, resize=True)) -
                      global_config['NORM_MIN']) / global_config['NORM_DIV']
        next_frame = torch.from_numpy(next_input).to(
            device=sliced_input.device, dtype=sliced_input.dtype)
        sliced_input = torch.cat(
            [sliced_input[:, :, 1:], next_frame[None, None, None]], dim=2)

        # Labels are (1, OUT_TARGET_LEN, H, W): shift along axis 1.
        sliced_label[:, :-1] = sliced_label[:, 1:]
        sliced_label[:, -1] = read_file(files[idx + out_target_len + dt])
Ejemplo n.º 5
0
import glob
import os
import cv2
import numpy as np
import torch
from utils.units import mm_dbz, dbz_mm, denorm, torch_denorm
from utils.visualizers import make_gif_color, rainfall_shade, make_gif, make_gif_color_label
from utils.evaluators import fp_fn_image_csi, cal_rmse_all, fp_fn_image_csi_muti, torch_cal_rmse_all
from global_config import global_config
from models.unet.model import UNet2D
from tqdm import tqdm

# Bilinear upsampler to the full (DATA_HEIGHT, DATA_WIDTH) grid.
rs_img = torch.nn.Upsample(
    size=(global_config['DATA_HEIGHT'], global_config['DATA_WIDTH']),
    mode='bilinear',
)
# mm_dbz(0.5) mapped through the same (x - NORM_MIN) / NORM_DIV
# normalisation applied to model inputs elsewhere; presumably a 0.5 rain
# threshold expressed in normalised units — confirm against mm_dbz.
thres = (mm_dbz(0.5) - global_config['NORM_MIN']) / global_config['NORM_DIV']


def read_file(file_name, h=None, w=None, resize=False):
    """Load one raw float32 frame from disk, optionally area-resized.

    The file is read as a flat float32 buffer and reshaped to
    ``(DATA_HEIGHT, DATA_WIDTH)``. When ``resize`` is true, the frame is
    resized to ``h`` rows by ``w`` columns with ``cv2.INTER_AREA``; ``h``
    and ``w`` must be provided in that case.
    """
    shape = (global_config['DATA_HEIGHT'], global_config['DATA_WIDTH'])
    frame = np.fromfile(file_name, dtype=np.float32).reshape(shape)
    if not resize:
        return frame
    return cv2.resize(frame, (w, h), interpolation=cv2.INTER_AREA)


def extract(model, config, save_dir, files, file_name, crop=None):

    idx = 0
    try:
        idx = next(i for i, f in enumerate(files)
Ejemplo n.º 6
0
def test(model, data_loader, config, save_dir, files, file_name, crop=None):
    """Forecast from ``file_name``, score against ground truth and plot.

    The ``config['IN_LEN']`` frames before the file named ``file_name`` are
    area-resized by ``config['SCALE']``, normalised with
    ``(mm_dbz(x) - NORM_MIN) / NORM_DIV`` and fed to ``model``. The model is
    called ``ceil(OUT_TARGET_LEN / OUT_LEN)`` times, feeding its own output
    back in, until ``OUT_TARGET_LEN`` frames are predicted. The
    ``OUT_TARGET_LEN`` frames starting at ``file_name`` are the label.

    Predictions are denormalised, resized to full resolution, cropped to the
    region given by ``crop`` (the whole frame if ``None``) and scored with CSI
    and RMSE. GIFs and a per-step CSI plot go to ``save_dir/imgs``; the
    metrics go to ``save_dir/result.txt`` and ``save_dir/csi.txt``.

    Args:
        model: network called on a tensor shaped by ``config['DIM']``:
            ``(1, 1, IN_LEN, h, w)`` for ``'3D'``, ``(1, IN_LEN, 1, h, w)``
            for ``'2D'`` and ``(1, IN_LEN, h, w)`` otherwise.
        data_loader: unused; kept for interface compatibility.
        config: dict with ``SCALE``, ``IN_LEN``, ``OUT_LEN``, ``DEVICE``,
            ``DIM`` and (for ``'3D'``) ``OPTFLOW``.
        save_dir: output directory.
        files: ordered list of frame file paths.
        file_name: basename of the first label frame.
        crop: optional region passed to ``get_crop_boundary_idx``.

    Returns:
        None. Prints a message and returns early if ``file_name`` is missing
        from ``files`` or fewer than ``IN_LEN`` frames precede it.

    NOTE(review): relies on ``plt`` (matplotlib.pyplot) being imported at
    module level; that import is not visible here.
    """
    data_h = global_config['DATA_HEIGHT']
    data_w = global_config['DATA_WIDTH']
    out_target_len = global_config['OUT_TARGET_LEN']
    dim = config['DIM']

    # Inclusive crop bounds; the default covers the whole frame.
    h1, h2, w1, w2 = 0, data_h - 1, 0, data_w - 1
    if crop is not None:
        h1, h2, w1, w2 = get_crop_boundary_idx(crop)

    try:
        idx = next(i for i, f in enumerate(files)
                   if os.path.basename(f) == file_name)
    except StopIteration:
        print('not found')
        return
    if idx < config['IN_LEN']:
        # Negative indices would silently wrap to the end of ``files``.
        print('not enough input frames before', file_name)
        return

    scale = config['SCALE']
    h = int(data_h * scale)
    w = int(data_w * scale)
    sliced_input = np.zeros((1, config['IN_LEN'], h, w), dtype=np.float32)
    sliced_label = np.zeros((1, out_target_len, data_h, data_w),
                            dtype=np.float32)

    for i, j in enumerate(range(idx - config['IN_LEN'], idx)):
        f = np.fromfile(files[j], dtype=np.float32).reshape((data_h, data_w))
        sliced_input[0, i] = cv2.resize(f, (w, h),
                                        interpolation=cv2.INTER_AREA)

    for i, j in enumerate(range(idx, idx + out_target_len)):
        sliced_label[0, i] = np.fromfile(files[j], dtype=np.float32) \
            .reshape((data_h, data_w))

    sliced_input = (mm_dbz(sliced_input) -
                    global_config['NORM_MIN']) / global_config['NORM_DIV']
    sliced_input = torch.from_numpy(sliced_input).to(config['DEVICE'])

    if dim == '3D':
        sliced_input = sliced_input[:, None, :]
    elif dim == '2D':
        sliced_input = sliced_input[:, :, None]

    # 3D models stack frames on axis 2; every other layout uses axis 1.
    time_axis = 2 if dim == '3D' else 1
    chunks = []
    with torch.no_grad():
        for _ in range(int(np.ceil(out_target_len / config['OUT_LEN']))):
            output = model(sliced_input)
            chunks.append(output.detach().cpu().numpy())
            if dim == '3D':
                if config['OPTFLOW']:
                    sliced_input = torch.cat(
                        [sliced_input[:, :, -1, None], output], axis=2)
                else:
                    sliced_input = output
            else:
                # Slide the input window forward by OUT_LEN frames.
                sliced_input = torch.cat(
                    [sliced_input[:, config['OUT_LEN']:], output], dim=1)
    pred = np.concatenate(chunks, axis=time_axis)

    if dim == '3D':
        pred = pred[:, 0]
    elif dim == '2D':
        # Drop the singleton axis inserted at position 2 above. This branch
        # used to be unreachable because it re-tested '3D'.
        pred = pred[:, :, 0]
    pred = pred[:, :out_target_len]
    pred = denorm(pred)
    pred_resized = np.zeros((pred.shape[0], pred.shape[1], data_h, data_w))
    for i in range(pred.shape[0]):
        for j in range(pred.shape[1]):
            pred_resized[i, j] = cv2.resize(pred[i, j], (data_w, data_h),
                                            interpolation=cv2.INTER_AREA)

    pred_resized = pred_resized[:, :, h1: h2 + 1, w1: w2 + 1]
    sliced_label = sliced_label[:, :, h1: h2 + 1, w1: w2 + 1]
    csis = [fp_fn_image_csi(pred_resized[:, c], sliced_label[:, c])
            for c in range(pred.shape[1])]
    csi = np.mean(csis)
    csi_multi, macro_csi = fp_fn_image_csi_muti(pred_resized, sliced_label)
    rmse, rmse_rain, rmse_non_rain = cal_rmse_all(pred_resized, sliced_label)
    result_all = [csi] + list(csi_multi) + [rmse, rmse_rain, rmse_non_rain]

    # Half-resolution copies for the GIFs.
    h_small = int(pred_resized.shape[2] * 0.5)
    w_small = int(pred_resized.shape[3] * 0.5)
    pred_small = np.zeros(
        (sliced_label.shape[0], sliced_label.shape[1], h_small, w_small))
    label_small = np.zeros(
        (sliced_label.shape[0], sliced_label.shape[1], h_small, w_small))
    for i in range(sliced_label.shape[0]):
        for j in range(sliced_label.shape[1]):
            pred_small[i, j] = cv2.resize(
                pred_resized[i, j], (w_small, h_small),
                interpolation=cv2.INTER_AREA)
            label_small[i, j] = cv2.resize(
                sliced_label[i, j], (w_small, h_small),
                interpolation=cv2.INTER_AREA)

    path = save_dir + '/imgs'
    os.makedirs(path, exist_ok=True)
    frame_names = [os.path.basename(files[idx + t])
                   for t in range(out_target_len)]
    for i in range(pred_resized.shape[0]):
        make_gif_color(pred_small[i], path + '/pred_colored.gif')
        make_gif_color(label_small[i], path + '/gt_colored.gif')
        make_gif_color_label(label_small[i], pred_small[i], frame_names,
                             fname=path + '/all.gif')

    fig, ax = plt.subplots(figsize=(8, 4), facecolor='white')
    ax.plot(np.arange(len(csis)) + 1, csis)
    ax.set_xticks(np.arange(out_target_len) + 1)
    ax.set_ylabel('Binary - CSI')
    ax.set_xlabel('Time Steps')
    plt.savefig(path + '/csis.png')
    # Release the figure; repeated calls would otherwise accumulate figures.
    plt.close(fig)

    result_all = np.around(np.array(result_all), decimals=3)
    np.savetxt(save_dir + '/result.txt', result_all, delimiter=',', fmt='%.3f')
    np.savetxt(save_dir + '/csi.txt', np.array(csis), delimiter=',',
               fmt='%.3f')
Ejemplo n.º 7
0
def conv_test(model_path,
              start_pred_fn,
              test_case,
              in_len,
              out_len,
              batch_size,
              multitask,
              crop=None):
    """Evaluate a ConvLSTM encoder-forecaster checkpoint on one test case.

    Builds a 3-level ConvLSTM ``EF`` model (ConvLSTM grids 96x96, 32x32 and
    16x16) on ``cuda:0``, loads the weights stored at ``model_path`` and
    fetches data with ``get_data(start_pred_fn, crop=crop, config=config)``.
    It then runs ``prepare_testing`` with ``global_config['MERGE_WEIGHT']``.
    Prediction and label are mapped with ``dbz_mm`` (prediction floored at 0)
    and scored with CSI and RMSE. GIFs and per-frame PNGs are written to
    ``./imgs_conv/conv_<test_case>_<in_len>_<out_len>_<multitask>/``.

    Returns:
        tuple: ``([rmse, rmse_rain, rmse_non_rain], csi, csi_multi)``.
    """
    config = {
        'DEVICE': torch.device('cuda:0'),
        'IN_LEN': in_len,
        'OUT_LEN': out_len,
        'BATCH_SIZE': batch_size,
    }

    convlstm_encoder_params = [
        [
            OrderedDict({'conv1_leaky_1': [1, 8, 7, 5, 1]}),
            OrderedDict({'conv2_leaky_1': [64, 192, 5, 3, 1]}),
            OrderedDict({'conv3_leaky_1': [192, 192, 3, 2, 1]}),
        ],
        [
            ConvLSTM(input_channel=8, num_filter=64,
                     b_h_w=(batch_size, 96, 96), kernel_size=3, stride=1,
                     padding=1, config=config),
            ConvLSTM(input_channel=192, num_filter=192,
                     b_h_w=(batch_size, 32, 32), kernel_size=3, stride=1,
                     padding=1, config=config),
            ConvLSTM(input_channel=192, num_filter=192,
                     b_h_w=(batch_size, 16, 16), kernel_size=3, stride=1,
                     padding=1, config=config),
        ],
    ]

    convlstm_forecaster_params = [
        [
            OrderedDict({'deconv1_leaky_1': [192, 192, 4, 2, 1]}),
            OrderedDict({'deconv2_leaky_1': [192, 64, 5, 3, 1]}),
            OrderedDict({
                'deconv3_leaky_1': [64, 8, 7, 5, 1],
                'conv3_leaky_2': [8, 8, 3, 1, 1],
                'conv3_3': [8, 1, 1, 1, 0]
            }),
        ],
        [
            ConvLSTM(input_channel=192, num_filter=192,
                     b_h_w=(batch_size, 16, 16), kernel_size=3, stride=1,
                     padding=1, config=config),
            ConvLSTM(input_channel=192, num_filter=192,
                     b_h_w=(batch_size, 32, 32), kernel_size=3, stride=1,
                     padding=1, config=config),
            ConvLSTM(input_channel=64, num_filter=64,
                     b_h_w=(batch_size, 96, 96), kernel_size=3, stride=1,
                     padding=1, config=config),
        ],
    ]

    encoder = Encoder(convlstm_encoder_params[0],
                      convlstm_encoder_params[1]).to(config['DEVICE'])
    forecaster = Forecaster(convlstm_forecaster_params[0],
                            convlstm_forecaster_params[1],
                            config=config).to(config['DEVICE'])
    model = EF(encoder, forecaster).to(config['DEVICE'])
    model.load_state_dict(torch.load(model_path, map_location='cuda'))

    data = get_data(start_pred_fn, crop=crop, config=config)
    data = mm_dbz(data)

    weight = global_config['MERGE_WEIGHT']

    pred, label = prepare_testing(data, model, weight=weight, config=config)
    pred = np.maximum(dbz_mm(pred), 0)
    label = dbz_mm(label)
    csi = fp_fn_image_csi(pred, label)
    csi_multi = fp_fn_image_csi_muti_reg(pred, label)
    rmse, rmse_rain, rmse_non_rain = cal_rmse_all(pred, label)

    path = './imgs_conv/conv_{}_{}_{}_{}/'.format(test_case, in_len, out_len,
                                                  multitask)
    # One call creates ./imgs_conv, the case directory and imgs/. An existing
    # directory is fine; any other OS error now propagates instead of being
    # swallowed by a bare except.
    os.makedirs(path + 'imgs/', exist_ok=True)

    # Absolute error scaled so that the largest error maps to 255. An
    # all-zero error gives a black GIF instead of 0/0 = NaN.
    errs = np.abs(pred - label)
    err_max = errs.max()
    err_img = errs / err_max * 255 if err_max > 0 else np.zeros_like(errs)
    make_gif(err_img, path + 'errs.gif')
    make_gif(pred / 60 * 255, path + 'pred.gif')
    make_gif_color(pred, path + 'pred_colored.gif')
    make_gif(label / 60 * 255, path + 'gt.gif')
    make_gif_color(label, path + 'gt_colored.gif')
    for i in range(pred.shape[0]):
        # Clip before the uint8 cast: values above 80 would otherwise
        # overflow the 0..255 range.
        frame = np.clip(pred[i] / 80 * 255, 0, 255).astype(np.uint8)
        cv2.imwrite(path + 'imgs/' + str(i) + '.png',
                    cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR))

    return [rmse, rmse_rain, rmse_non_rain], csi, csi_multi
Ejemplo n.º 8
0
def conv_test(model_path, start_pred_fn, test_case, in_len, out_len, batch_size, multitask, crop=None):
    """Evaluate a ConvLSTM encoder-forecaster checkpoint on one test window.

    Builds a 3-level ConvLSTM ``EF`` model (ConvLSTM grids 168x126, 56x42 and
    28x21) on ``cuda:0`` and loads the weights stored at ``model_path``. The
    test files come from ``global_config['TEST_PATH']`` (sorted), and the
    forecast starts at the file whose basename is ``start_pred_fn``. Each
    frame has 20 columns trimmed from both sides, is area-downsampled by 4
    and is converted with ``mm_dbz``. The model then forecasts
    ``OUT_TARGET_LEN`` frames autoregressively, one frame per call. The
    result is mapped back with ``dbz_mm``, scored with CSI and RMSE, and
    GIFs and PNGs are written to
    ``./imgs_conv/conv_<test_case>_<in_len>_<out_len>_<multitask>/``.

    ``crop`` is unused and kept for interface compatibility.

    Returns:
        tuple: ``([rmse, rmse_rain, rmse_non_rain], csi, csi_multi)``, or
        ``None`` (after printing a message) when ``start_pred_fn`` is not
        found, or when fewer than ``in_len`` frames precede it or fewer than
        ``OUT_TARGET_LEN`` frames follow. An empty ``start_pred_fn`` leaves
        the start index at 0, which has no input history, so it is rejected
        as well. Previously these cases silently evaluated garbage or
        all-zero windows.
    """
    config = {
        'DEVICE': torch.device('cuda:0'),
        'IN_LEN': in_len,
        'OUT_LEN': out_len,
        'BATCH_SIZE': batch_size,
    }

    convlstm_encoder_params = [
        [
            OrderedDict({'conv1_leaky_1': [1, 8, 7, 5, 1]}),
            OrderedDict({'conv2_leaky_1': [64, 192, 5, 3, 1]}),
            OrderedDict({'conv3_leaky_1': [192, 192, 3, 2, 1]}),
        ],

        [
            ConvLSTM(input_channel=8, num_filter=64, b_h_w=(batch_size, 168, 126),
                     kernel_size=3, stride=1, padding=1, config=config),
            ConvLSTM(input_channel=192, num_filter=192, b_h_w=(batch_size, 56, 42),
                     kernel_size=3, stride=1, padding=1, config=config),
            ConvLSTM(input_channel=192, num_filter=192, b_h_w=(batch_size, 28, 21),
                     kernel_size=3, stride=1, padding=1, config=config),
        ]
    ]

    convlstm_forecaster_params = [
        [
            OrderedDict({'deconv1_leaky_1': [192, 192, 4, 2, 1]}),
            OrderedDict({'deconv2_leaky_1': [192, 64, 5, 3, 1]}),
            OrderedDict({
                'deconv3_leaky_1': [64, 8, 7, 5, 1],
                'conv3_leaky_2': [8, 8, 3, 1, 1],
                'conv3_3': [8, 1, 1, 1, 0]
            }),
        ],

        [
            ConvLSTM(input_channel=192, num_filter=192, b_h_w=(batch_size, 28, 21),
                     kernel_size=3, stride=1, padding=1, config=config),
            ConvLSTM(input_channel=192, num_filter=192, b_h_w=(batch_size, 56, 42),
                     kernel_size=3, stride=1, padding=1, config=config),
            ConvLSTM(input_channel=64, num_filter=64, b_h_w=(batch_size, 168, 126),
                     kernel_size=3, stride=1, padding=1, config=config),
        ]
    ]

    encoder = Encoder(convlstm_encoder_params[0], convlstm_encoder_params[1]).to(
        config['DEVICE'])
    forecaster = Forecaster(
        convlstm_forecaster_params[0], convlstm_forecaster_params[1], config=config).to(config['DEVICE'])
    model = EF(encoder, forecaster).to(config['DEVICE'])
    model.load_state_dict(
        torch.load(model_path, map_location='cuda'))

    out_target_len = global_config['OUT_TARGET_LEN']
    files = sorted(glob.glob(global_config['TEST_PATH']))
    idx = 0
    if start_pred_fn != '':
        try:
            idx = next(i for i, f in enumerate(files)
                       if os.path.basename(f) == start_pred_fn)
        except StopIteration:
            print('not found')
            return None
    if idx < config['IN_LEN'] or idx + out_target_len > len(files):
        # A negative slice start would wrap around and yield a garbage or
        # empty window.
        print('not enough frames around start index', idx)
        return None

    scale_div = 4
    data_h = global_config['DATA_HEIGHT']
    data_w = global_config['DATA_WIDTH']
    small_h = int(data_h / scale_div)
    small_w = int((data_w - 40) / scale_div)
    data = np.zeros((config['IN_LEN'] + out_target_len, small_h, small_w),
                    dtype=np.float32)
    window = files[idx - config['IN_LEN']:idx + out_target_len]
    for i, file in enumerate(window):
        # Trim 20 columns from each side, then area-downsample.
        fd = np.fromfile(file, dtype=np.float32).reshape((data_h, data_w))[:, 20:-20]
        data[i, :] = cv2.resize(fd, (small_w, small_h),
                                interpolation=cv2.INTER_AREA)

    data = mm_dbz(data)
    label = data[in_len:]
    preds = []
    with torch.no_grad():
        # Axis 0 is the frame sequence; two singleton axes are inserted after
        # it (presumably batch and channel — confirm against EF).
        seq = torch.from_numpy(data[:in_len][:, None, None]).to(config['DEVICE'])
        # One frame per call, so forecast exactly as many frames as the label
        # holds (previously hard-coded to 18).
        for step in range(out_target_len):
            pred = model(seq)
            # NOTE(review): ad-hoc, step-dependent damping of low values;
            # its origin is not documented here.
            pred[pred < 0.2 + 0.035 * step] -= 0.04 * step
            pred = torch.clamp(pred, min=mm_dbz(0), max=mm_dbz(60))
            preds.append(pred.detach().cpu().numpy()[0, 0, 0])
            seq = torch.cat((seq[1:], pred), 0)
    pred = np.array(preds)
    print(pred.shape, label.shape)
    pred = np.maximum(dbz_mm(pred), 0)
    label = dbz_mm(label)
    csi = fp_fn_image_csi(pred, label)
    csi_multi = fp_fn_image_csi_muti_reg(pred, label)
    rmse, rmse_rain, rmse_non_rain = cal_rmse_all(pred, label)

    path = './imgs_conv/conv_{}_{}_{}_{}/'.format(test_case, in_len, out_len, multitask)
    # One call creates ./imgs_conv, the case directory and imgs/. Real OS
    # errors are no longer swallowed by a bare except.
    os.makedirs(path + 'imgs/', exist_ok=True)

    # Absolute error scaled so that the largest error maps to 255. An
    # all-zero error gives a black GIF instead of 0/0 = NaN.
    errs = np.abs(pred - label)
    err_max = errs.max()
    err_img = errs / err_max * 255 if err_max > 0 else np.zeros_like(errs)
    make_gif(err_img, path + 'errs.gif')
    make_gif(pred / 60 * 255, path + 'pred.gif')
    make_gif_color(pred, path + 'pred_colored.gif')
    make_gif(label / 60 * 255, path + 'gt.gif')
    make_gif_color(label, path + 'gt_colored.gif')
    for i in range(pred.shape[0]):
        # Clip before the uint8 cast: values above 80 would overflow.
        frame = np.clip(pred[i] / 80 * 255, 0, 255).astype(np.uint8)
        cv2.imwrite(path + 'imgs/' + str(i) + '.png',
                    cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR))

    return [rmse, rmse_rain, rmse_non_rain], csi, csi_multi