コード例 #1
0
ファイル: extract.py プロジェクト: doliolarzz/fcn_senior
def extract(model, config, save_dir, files, file_name, crop=None):
    """Roll a forecast window forward and dump predictions/metrics per step.

    Starting at the entry of ``files`` whose basename equals ``file_name``,
    the ``config['IN_LEN']`` preceding frames form the model input and the
    ``global_config['OUT_TARGET_LEN']`` following frames form the label.  For
    each of ``6 * 24`` steps the model forecasts the label window, CSI / RMSE
    metrics are computed, colour-shaded label and prediction PNGs plus a
    ``metrics.txt`` are written under ``<save_dir>/extracted/<time_name>/``,
    and both windows are then advanced by one file.

    Args:
        model: network called as ``model(x)``; ``x`` carries time on axis 2
            (see the ``[:, None]`` reshape and the ``OPTFLOW`` branch).
        config: run settings; reads ``SCALE``, ``IN_LEN``, ``OUT_LEN``,
            ``DEVICE``, ``CAL_DEVICE`` and ``OPTFLOW``.
        save_dir: output root directory.
        files: ordered sequence of frame file paths.
        file_name: basename of the first frame to forecast.
        crop: unused; kept for interface compatibility.

    Returns:
        None.  Prints a message and returns early when ``file_name`` is not
        in ``files`` or has fewer than ``IN_LEN`` frames before it.
    """
    try:
        idx = next(i for i, f in enumerate(files)
                   if os.path.basename(f) == file_name)
    except StopIteration:
        print('not found')
        return

    in_len = config['IN_LEN']
    if idx < in_len:
        # files[idx - in_len] would silently wrap around to the list end.
        print('not enough history before {}'.format(file_name))
        return

    out_target_len = global_config['OUT_TARGET_LEN']
    data_h = global_config['DATA_HEIGHT']
    data_w = global_config['DATA_WIDTH']
    norm_min = global_config['NORM_MIN']
    norm_div = global_config['NORM_DIV']
    n_steps = 6 * 24

    scale = config['SCALE']
    h = int(data_h * scale)
    w = int(data_w * scale)
    sliced_input = np.zeros((1, in_len, h, w), dtype=np.float32)
    sliced_label = np.zeros((1, out_target_len, data_h, data_w),
                            dtype=np.float32)

    for i, j in enumerate(range(idx - in_len, idx)):
        sliced_input[0, i] = read_file(files[j], h, w, resize=True)
    for i, j in enumerate(range(idx, idx + out_target_len)):
        sliced_label[0, i] = read_file(files[j])

    sliced_input = (mm_dbz(sliced_input) - norm_min) / norm_div
    # (1, IN_LEN, h, w) -> (1, 1, IN_LEN, h, w): time lives on axis 2 from here.
    sliced_input = torch.from_numpy(sliced_input).to(config['DEVICE'])[:, None]

    save_dir = save_dir + '/extracted'
    os.makedirs(save_dir, exist_ok=True)
    model.eval()
    # Number of autoregressive model calls needed to cover OUT_TARGET_LEN.
    out_time = int(np.ceil(out_target_len / config['OUT_LEN']))
    for dt in tqdm(range(n_steps)):
        cur_input = sliced_input.clone()
        outputs = None
        with torch.no_grad():
            for _ in range(out_time):
                output = model(cur_input)
                output_np = output.detach().cpu().numpy()
                if outputs is None:
                    outputs = output_np
                else:
                    outputs = np.concatenate([outputs, output_np], axis=2)
                if config['OPTFLOW']:
                    # Keep the last input frame (axis 2) as context.
                    cur_input = torch.cat([cur_input[:, :, -1, None], output],
                                          axis=2)
                else:
                    cur_input = output
        # Index 0 on axis 1, first OUT_TARGET_LEN entries on the time axis.
        pred = denorm(np.array(outputs)[:, 0, :out_target_len])

        pred_resized = np.zeros(
            (pred.shape[0], pred.shape[1], data_h, data_w))
        for i in range(pred.shape[0]):
            for j in range(pred.shape[1]):
                pred_resized[i, j] = cv2.resize(pred[i, j], (data_w, data_h),
                                                interpolation=cv2.INTER_AREA)

        csi = fp_fn_image_csi(pred_resized, sliced_label)
        csi_multi, micro_csi = fp_fn_image_csi_muti(pred_resized, sliced_label)
        sum_rmse = np.zeros((3, ), dtype=np.float32)
        for i in range(out_target_len):
            rmse, rmse_rain, rmse_non_rain = torch_cal_rmse_all(
                torch.from_numpy(pred_resized[:, i]).to(config['CAL_DEVICE']),
                torch.from_numpy(sliced_label[:, i]).to(config['CAL_DEVICE']))
            sum_rmse += np.array([
                rmse.cpu().numpy(),
                rmse_rain.cpu().numpy(),
                rmse_non_rain.cpu().numpy()
            ])
        mean_rmse = sum_rmse / out_target_len

        # Downsample the label to the model's output resolution for the PNGs.
        h_small, w_small = pred.shape[2], pred.shape[3]
        label_small = np.zeros(
            (sliced_label.shape[0], sliced_label.shape[1], h_small, w_small))
        for i in range(sliced_label.shape[0]):
            for j in range(sliced_label.shape[1]):
                label_small[i, j] = cv2.resize(sliced_label[i, j],
                                               (w_small, h_small),
                                               interpolation=cv2.INTER_AREA)

        time_name = os.path.basename(files[idx + dt])[:-4]
        path = save_dir + '/' + time_name
        os.makedirs(path + '/pred', exist_ok=True)
        os.makedirs(path + '/label', exist_ok=True)
        for i in range(pred_resized.shape[0]):
            for j in range(out_target_len):
                # Alpha channel: opaque where the value exceeds 0.2.
                alpha_mask = (label_small[i, j] > 0.2).astype(np.uint8) * 255
                lb = np.dstack(
                    (rainfall_shade(label_small[i, j], mode='BGR'), alpha_mask))

                alpha_mask = (pred[i, j] > 0.2).astype(np.uint8) * 255
                prd = np.dstack(
                    (rainfall_shade(pred[i, j], mode='BGR'), alpha_mask))

                cv2.imwrite(path + '/label/' + str(j) + '.png', lb)
                cv2.imwrite(path + '/pred/' + str(j) + '.png', prd)
        # The metrics do not depend on (i, j); write them once per window.
        np.savetxt(path + '/metrics.txt',
                   [csi, micro_csi, np.mean(csi_multi)] +
                   list(csi_multi) + list(mean_rmse),
                   delimiter=',',
                   fmt='%.2f')

        if dt + 1 == n_steps:
            # Last window done; do not read frames that would never be used.
            break

        # Slide the input window one frame along the time axis (axis 2).
        # clone() avoids an overlapping in-place copy.
        sliced_input[:, :, :-1] = sliced_input[:, :, 1:].clone()
        next_input = (mm_dbz(read_file(files[idx + dt], h, w, resize=True)) -
                      norm_min) / norm_div
        sliced_input[:, :, -1] = torch.from_numpy(next_input).to(
            config['DEVICE'])

        sliced_label[:, :-1] = sliced_label[:, 1:]
        sliced_label[:, -1] = read_file(files[idx + out_target_len + dt])
コード例 #2
0
def test(model, data_loader, config, save_dir, files, file_name, crop=None):
    """Forecast the frames following ``file_name`` and save metrics and plots.

    The ``config['IN_LEN']`` frames before ``file_name`` are read from raw
    float32 files, downscaled by ``config['SCALE']`` and normalised; the model
    is then called autoregressively until ``OUT_TARGET_LEN`` frames are
    predicted.  Predictions are resized back to full resolution, optionally
    cropped, scored against the raw label frames, and written as GIFs, a CSI
    plot, ``result.txt`` and ``csi.txt`` under ``save_dir``.

    Args:
        model: network called as ``model(x)``; for ``DIM == '3D'`` time is on
            axis 2 of ``x``, for ``DIM == '2D'`` on axis 1.
        data_loader: unused; kept for interface compatibility.
        config: reads ``SCALE``, ``IN_LEN``, ``OUT_LEN``, ``DEVICE``, ``DIM``
            and ``OPTFLOW``.
        save_dir: output directory.
        files: ordered sequence of frame file paths.
        file_name: basename of the first frame to forecast.
        crop: optional crop spec passed to ``get_crop_boundary_idx``.

    Returns:
        None.  Prints a message and returns early when ``file_name`` is not
        in ``files`` or has fewer than ``IN_LEN`` frames before it.
    """
    data_h = global_config['DATA_HEIGHT']
    data_w = global_config['DATA_WIDTH']
    out_target_len = global_config['OUT_TARGET_LEN']
    in_len = config['IN_LEN']

    # Inclusive crop bounds; the default covers the full frame.
    h1, h2, w1, w2 = 0, data_h - 1, 0, data_w - 1
    if crop is not None:
        h1, h2, w1, w2 = get_crop_boundary_idx(crop)

    try:
        idx = next(i for i, f in enumerate(files)
                   if os.path.basename(f) == file_name)
    except StopIteration:
        print('not found')
        return
    if idx < in_len:
        # files[idx - in_len] would silently wrap around to the list end.
        print('not enough history before {}'.format(file_name))
        return

    scale = config['SCALE']
    h = int(data_h * scale)
    w = int(data_w * scale)
    sliced_input = np.zeros((1, in_len, h, w), dtype=np.float32)
    sliced_label = np.zeros((1, out_target_len, data_h, data_w),
                            dtype=np.float32)

    for i, j in enumerate(range(idx - in_len, idx)):
        frame = np.fromfile(files[j], dtype=np.float32).reshape(
            (data_h, data_w))
        sliced_input[0, i] = cv2.resize(frame, (w, h),
                                        interpolation=cv2.INTER_AREA)
    for i, j in enumerate(range(idx, idx + out_target_len)):
        sliced_label[0, i] = np.fromfile(files[j], dtype=np.float32).reshape(
            (data_h, data_w))

    sliced_input = (mm_dbz(sliced_input) -
                    global_config['NORM_MIN']) / global_config['NORM_DIV']
    sliced_input = torch.from_numpy(sliced_input).to(config['DEVICE'])

    # Insert a singleton axis: before time for '3D', after time for '2D'.
    if config['DIM'] == '3D':
        sliced_input = sliced_input[:, None, :]
    elif config['DIM'] == '2D':
        sliced_input = sliced_input[:, :, None]

    model.eval()
    time_axis = 2 if config['DIM'] == '3D' else 1
    n_calls = int(np.ceil(out_target_len / config['OUT_LEN']))
    outputs = None
    with torch.no_grad():
        for _ in range(n_calls):
            output = model(sliced_input)
            output_np = output.detach().cpu().numpy()
            if outputs is None:
                outputs = output_np
            else:
                outputs = np.concatenate([outputs, output_np], axis=time_axis)

            if config['DIM'] == '3D':
                if config['OPTFLOW']:
                    sliced_input = torch.cat(
                        [sliced_input[:, :, -1, None], output], axis=2)
                else:
                    sliced_input = output
            else:
                # Drop the oldest OUT_LEN frames and append the new forecast.
                sliced_input = torch.cat(
                    [sliced_input[:, config['OUT_LEN']:], output], dim=1)

    pred = np.array(outputs)
    # Remove the singleton axis added above (previously the '2D' branch was
    # unreachable because it re-tested '3D').
    if config['DIM'] == '3D':
        pred = pred[:, 0]
    elif config['DIM'] == '2D':
        pred = pred[:, :, 0]
    pred = denorm(pred[:, :out_target_len])

    pred_resized = np.zeros((pred.shape[0], pred.shape[1], data_h, data_w))
    for i in range(pred.shape[0]):
        for j in range(pred.shape[1]):
            pred_resized[i, j] = cv2.resize(pred[i, j], (data_w, data_h),
                                            interpolation=cv2.INTER_AREA)

    pred_resized = pred_resized[:, :, h1:h2 + 1, w1:w2 + 1]
    sliced_label = sliced_label[:, :, h1:h2 + 1, w1:w2 + 1]

    # Binary CSI per forecast step, then averaged over the steps.
    csis = [fp_fn_image_csi(pred_resized[:, c], sliced_label[:, c])
            for c in range(pred.shape[1])]
    csi = np.mean(csis)
    csi_multi, macro_csi = fp_fn_image_csi_muti(pred_resized, sliced_label)
    rmse, rmse_rain, rmse_non_rain = cal_rmse_all(pred_resized, sliced_label)
    result_all = [csi] + list(csi_multi) + [rmse, rmse_rain, rmse_non_rain]

    # Half-resolution copies for the GIFs.
    h_small = int(pred_resized.shape[2] * 0.5)
    w_small = int(pred_resized.shape[3] * 0.5)
    pred_small = np.zeros(
        (sliced_label.shape[0], sliced_label.shape[1], h_small, w_small))
    label_small = np.zeros(
        (sliced_label.shape[0], sliced_label.shape[1], h_small, w_small))
    for i in range(sliced_label.shape[0]):
        for j in range(sliced_label.shape[1]):
            pred_small[i, j] = cv2.resize(pred_resized[i, j],
                                          (w_small, h_small),
                                          interpolation=cv2.INTER_AREA)
            label_small[i, j] = cv2.resize(sliced_label[i, j],
                                           (w_small, h_small),
                                           interpolation=cv2.INTER_AREA)

    path = save_dir + '/imgs'
    os.makedirs(path, exist_ok=True)
    frame_labels = [os.path.basename(files[idx + k])
                    for k in range(out_target_len)]
    for i in range(pred_resized.shape[0]):
        make_gif_color(pred_small[i], path + '/pred_colored.gif')
        make_gif_color(label_small[i], path + '/gt_colored.gif')
        make_gif_color_label(label_small[i], pred_small[i], frame_labels,
                             fname=path + '/all.gif')

    fig, ax = plt.subplots(figsize=(8, 4), facecolor='white')
    ax.plot(np.arange(len(csis)) + 1, csis)
    ax.set_xticks(np.arange(out_target_len) + 1)
    ax.set_ylabel('Binary - CSI')
    ax.set_xlabel('Time Steps')
    plt.savefig(path + '/csis.png')
    # Release the figure; repeated calls would otherwise accumulate them.
    plt.close(fig)

    result_all = np.around(np.array(result_all), decimals=3)
    np.savetxt(save_dir + '/result.txt', result_all, delimiter=',', fmt='%.3f')
    np.savetxt(save_dir + '/csi.txt', np.array(csis), delimiter=',', fmt='%.3f')
コード例 #3
0
def test(model, data_loader, config, save_dir, crop=None):
    """Evaluate ``model`` on evenly spaced test batches and save the averages.

    Every ``3 * OUT_TARGET_LEN``-th batch index of ``data_loader`` is
    forecast autoregressively.  Per-batch binary CSI, multi-threshold CSI and
    RMSE scores are averaged and written to ``<save_dir>/res/result.txt``.
    The per-step binary CSI goes to ``<save_dir>/res/csi_time.txt``, and one
    GIF per batch is written to ``<save_dir>/res/imgs/<b>.gif``.

    Args:
        model: network called as ``model(data)``; its outputs are concatenated
            along axis 0, which is therefore treated as the time axis.
        data_loader: provides ``n_test_batch()`` and ``get_test(b)`` returning
            ``(data, label)``.
        config: reads ``OUT_LEN``, ``IN_LEN`` and ``CAL_DEVICE``.
        save_dir: output root directory.
        crop: unused; kept for interface compatibility.

    Returns:
        None.  Prints a message and returns early when there are no test
        batches to evaluate.
    """
    save_dir = save_dir + '/res'
    os.makedirs(save_dir, exist_ok=True)
    model.eval()

    out_target_len = global_config['OUT_TARGET_LEN']
    data_h = global_config['DATA_HEIGHT']
    data_w = global_config['DATA_WIDTH']
    stride = 3 * out_target_len
    test_idx = np.arange(0, data_loader.n_test_batch() - stride, stride)
    if len(test_idx) == 0:
        # Averaging over zero batches would only write NaN result files.
        print('no test batches')
        return

    path = save_dir + '/imgs'
    os.makedirs(path, exist_ok=True)
    # Number of autoregressive model calls needed to cover OUT_TARGET_LEN.
    out_time = int(np.ceil(out_target_len / config['OUT_LEN']))
    result_all = []
    all_csis = np.zeros((out_target_len, ))
    for b in tqdm(test_idx):
        data, label = data_loader.get_test(b)
        outputs = None
        with torch.no_grad():
            for _ in range(out_time):
                output = model(data)
                output_np = output.detach().cpu().numpy()
                if outputs is None:
                    outputs = output_np
                else:
                    outputs = np.concatenate([outputs, output_np], axis=0)
                # Feed the last IN_LEN entries (axis 0) of the forecast back in.
                data = output[-config['IN_LEN']:]
        # First OUT_TARGET_LEN entries on axis 0, index 0 on axes 1 and 2,
        # plus a new leading axis.
        pred = denorm(np.array(outputs)[:out_target_len, 0, 0][None, :])
        label = label[0, None]

        pred_resized = np.zeros(
            (pred.shape[0], pred.shape[1], data_h, data_w))
        for i in range(pred.shape[0]):
            for j in range(pred.shape[1]):
                pred_resized[i, j] = cv2.resize(pred[i, j], (data_w, data_h),
                                                interpolation=cv2.INTER_AREA)

        # NOTE(review): the label is used without denorm ("don't need to
        # denorm test") -- confirm it is already in the same units as pred.
        csi_time = [fp_fn_image_csi(pred_resized[:, t], label[:, t])
                    for t in range(out_target_len)]
        csi = np.mean(csi_time)
        all_csis += csi_time
        csi_multi, macro_csi = fp_fn_image_csi_muti(pred_resized, label)
        sum_rmse = np.zeros((3, ), dtype=np.float32)
        for i in range(out_target_len):
            rmse, rmse_rain, rmse_non_rain = torch_cal_rmse_all(
                torch.from_numpy(pred_resized[:, i]).to(config['CAL_DEVICE']),
                torch.from_numpy(label[:, i]).to(config['CAL_DEVICE']))
            sum_rmse += np.array([
                rmse.cpu().numpy(),
                rmse_rain.cpu().numpy(),
                rmse_non_rain.cpu().numpy()
            ])
        mean_rmse = sum_rmse / out_target_len
        result_all.append([csi, macro_csi] + list(csi_multi) + list(mean_rmse))

        # Downsample the label to the model's output resolution for the GIF.
        h_small, w_small = pred.shape[2], pred.shape[3]
        label_small = np.zeros(
            (label.shape[0], label.shape[1], h_small, w_small))
        for i in range(label.shape[0]):
            for j in range(label.shape[1]):
                label_small[i, j] = cv2.resize(label[i, j], (w_small, h_small),
                                               interpolation=cv2.INTER_AREA)

        for i in range(pred_resized.shape[0]):
            make_gif_color_label(label_small[i],
                                 pred[i],
                                 [''] * out_target_len,
                                 fname=path + '/{}.gif'.format(b))

    all_csis /= len(test_idx)
    np.savetxt(save_dir + '/csi_time.txt', all_csis, delimiter=',', fmt='%.3f')

    result_all_mean = np.around(np.mean(np.array(result_all), axis=0),
                                decimals=3)
    np.savetxt(save_dir + '/result.txt',
               result_all_mean,
               delimiter=',',
               fmt='%.3f')
コード例 #4
0
def conv_test(model_path,
              start_pred_fn,
              test_case,
              in_len,
              out_len,
              batch_size,
              multitask,
              crop=None):
    """Evaluate a ConvLSTM encoder-forecaster checkpoint on one test case.

    Builds a three-level ConvLSTM encoder/forecaster (hidden states of
    96x96, 32x32 and 16x16) and loads the weights from ``model_path``.  It
    then forecasts with ``prepare_testing`` on the data returned by
    ``get_data(start_pred_fn, ...)`` and scores the result.  GIFs and
    per-frame PNGs are written under
    ``./imgs_conv/conv_<test_case>_<in_len>_<out_len>_<multitask>/``.

    Args:
        model_path: path of the saved ``state_dict``.
        start_pred_fn: identifier forwarded to ``get_data``.
        test_case: label used in the output directory name.
        in_len: number of input frames (``config['IN_LEN']``).
        out_len: number of frames per model call (``config['OUT_LEN']``).
        batch_size: batch size baked into the ConvLSTM state shapes.
        multitask: label used in the output directory name.
        crop: optional crop spec forwarded to ``get_data``.

    Returns:
        ``([rmse, rmse_rain, rmse_non_rain], csi, csi_multi)``.
    """
    config = {
        'DEVICE': torch.device('cuda:0'),
        'IN_LEN': in_len,
        'OUT_LEN': out_len,
        'BATCH_SIZE': batch_size,
    }

    convlstm_encoder_params = [
        [
            OrderedDict({'conv1_leaky_1': [1, 8, 7, 5, 1]}),
            OrderedDict({'conv2_leaky_1': [64, 192, 5, 3, 1]}),
            OrderedDict({'conv3_leaky_1': [192, 192, 3, 2, 1]}),
        ],
        [
            ConvLSTM(input_channel=8, num_filter=64,
                     b_h_w=(batch_size, 96, 96), kernel_size=3, stride=1,
                     padding=1, config=config),
            ConvLSTM(input_channel=192, num_filter=192,
                     b_h_w=(batch_size, 32, 32), kernel_size=3, stride=1,
                     padding=1, config=config),
            ConvLSTM(input_channel=192, num_filter=192,
                     b_h_w=(batch_size, 16, 16), kernel_size=3, stride=1,
                     padding=1, config=config),
        ],
    ]

    convlstm_forecaster_params = [
        [
            OrderedDict({'deconv1_leaky_1': [192, 192, 4, 2, 1]}),
            OrderedDict({'deconv2_leaky_1': [192, 64, 5, 3, 1]}),
            OrderedDict({
                'deconv3_leaky_1': [64, 8, 7, 5, 1],
                'conv3_leaky_2': [8, 8, 3, 1, 1],
                'conv3_3': [8, 1, 1, 1, 0]
            }),
        ],
        [
            ConvLSTM(input_channel=192, num_filter=192,
                     b_h_w=(batch_size, 16, 16), kernel_size=3, stride=1,
                     padding=1, config=config),
            ConvLSTM(input_channel=192, num_filter=192,
                     b_h_w=(batch_size, 32, 32), kernel_size=3, stride=1,
                     padding=1, config=config),
            ConvLSTM(input_channel=64, num_filter=64,
                     b_h_w=(batch_size, 96, 96), kernel_size=3, stride=1,
                     padding=1, config=config),
        ],
    ]

    encoder = Encoder(convlstm_encoder_params[0],
                      convlstm_encoder_params[1]).to(config['DEVICE'])
    forecaster = Forecaster(convlstm_forecaster_params[0],
                            convlstm_forecaster_params[1],
                            config=config).to(config['DEVICE'])
    model = EF(encoder, forecaster).to(config['DEVICE'])
    model.load_state_dict(torch.load(model_path, map_location='cuda'))
    # Freshly built modules start in training mode; evaluate in eval mode.
    model.eval()

    data = mm_dbz(get_data(start_pred_fn, crop=crop, config=config))
    pred, label = prepare_testing(data, model,
                                  weight=global_config['MERGE_WEIGHT'],
                                  config=config)
    pred = np.maximum(dbz_mm(pred), 0)
    label = dbz_mm(label)
    csi = fp_fn_image_csi(pred, label)
    csi_multi = fp_fn_image_csi_muti_reg(pred, label)
    rmse, rmse_rain, rmse_non_rain = cal_rmse_all(pred, label)

    path = './imgs_conv/conv_{}_{}_{}_{}/'.format(test_case, in_len, out_len,
                                                  multitask)
    # exist_ok replaces the old bare ``except: pass``: an existing directory
    # is fine, but real failures (e.g. permissions) now surface here.
    os.makedirs(path + 'imgs/', exist_ok=True)

    # Absolute error scaled to [0, 255]; guard the all-zero-error case.
    errs = np.abs(pred - label)
    err_max = errs.max()
    if err_max > 0:
        make_gif(errs / err_max * 255, path + 'errs.gif')
    else:
        make_gif(np.zeros_like(errs), path + 'errs.gif')
    make_gif(pred / 60 * 255, path + 'pred.gif')
    make_gif_color(pred, path + 'pred_colored.gif')
    make_gif(label / 60 * 255, path + 'gt.gif')
    make_gif_color(label, path + 'gt_colored.gif')
    for i in range(pred.shape[0]):
        # Clip before the uint8 cast: values above 80 would overflow 255.
        gray = np.clip(pred[i] / 80 * 255, 0, 255).astype(np.uint8)
        cv2.imwrite(path + 'imgs/' + str(i) + '.png',
                    cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR))

    return [rmse, rmse_rain, rmse_non_rain], csi, csi_multi
コード例 #5
0
def conv_test(model_path, start_pred_fn, test_case, in_len, out_len, batch_size, multitask, crop=None):
    """Evaluate a ConvLSTM encoder-forecaster checkpoint on raw test files.

    Builds a three-level ConvLSTM encoder/forecaster (hidden states of
    168x126, 56x42 and 28x21) and loads the weights from ``model_path``.  It
    reads ``in_len`` input frames and ``OUT_TARGET_LEN`` label frames around
    ``start_pred_fn`` from ``global_config['TEST_PATH']``, each trimmed by 20
    columns per side and downscaled by 4.  It then forecasts one frame per
    model call and writes GIFs and per-frame PNGs under
    ``./imgs_conv/conv_<test_case>_<in_len>_<out_len>_<multitask>/``.

    Args:
        model_path: path of the saved ``state_dict``.
        start_pred_fn: basename of the first frame to forecast; ``''`` starts
            at the first file that has ``in_len`` frames before it.
        test_case: label used in the output directory name.
        in_len: number of input frames.
        out_len: frames per model call (``config['OUT_LEN']``).
        batch_size: batch size baked into the ConvLSTM state shapes.
        multitask: label used in the output directory name.
        crop: unused; kept for interface compatibility.

    Returns:
        ``([rmse, rmse_rain, rmse_non_rain], csi, csi_multi)``.

    Raises:
        ValueError: if ``start_pred_fn`` is not among the test files, or the
            files around it cannot fill the input and label windows.
    """
    config = {
        'DEVICE': torch.device('cuda:0'),
        'IN_LEN': in_len,
        'OUT_LEN': out_len,
        'BATCH_SIZE': batch_size,
    }

    convlstm_encoder_params = [
        [
            OrderedDict({'conv1_leaky_1': [1, 8, 7, 5, 1]}),
            OrderedDict({'conv2_leaky_1': [64, 192, 5, 3, 1]}),
            OrderedDict({'conv3_leaky_1': [192, 192, 3, 2, 1]}),
        ],
        [
            ConvLSTM(input_channel=8, num_filter=64, b_h_w=(batch_size, 168, 126),
                     kernel_size=3, stride=1, padding=1, config=config),
            ConvLSTM(input_channel=192, num_filter=192, b_h_w=(batch_size, 56, 42),
                     kernel_size=3, stride=1, padding=1, config=config),
            ConvLSTM(input_channel=192, num_filter=192, b_h_w=(batch_size, 28, 21),
                     kernel_size=3, stride=1, padding=1, config=config),
        ],
    ]

    convlstm_forecaster_params = [
        [
            OrderedDict({'deconv1_leaky_1': [192, 192, 4, 2, 1]}),
            OrderedDict({'deconv2_leaky_1': [192, 64, 5, 3, 1]}),
            OrderedDict({
                'deconv3_leaky_1': [64, 8, 7, 5, 1],
                'conv3_leaky_2': [8, 8, 3, 1, 1],
                'conv3_3': [8, 1, 1, 1, 0]
            }),
        ],
        [
            ConvLSTM(input_channel=192, num_filter=192, b_h_w=(batch_size, 28, 21),
                     kernel_size=3, stride=1, padding=1, config=config),
            ConvLSTM(input_channel=192, num_filter=192, b_h_w=(batch_size, 56, 42),
                     kernel_size=3, stride=1, padding=1, config=config),
            ConvLSTM(input_channel=64, num_filter=64, b_h_w=(batch_size, 168, 126),
                     kernel_size=3, stride=1, padding=1, config=config),
        ],
    ]

    encoder = Encoder(convlstm_encoder_params[0],
                      convlstm_encoder_params[1]).to(config['DEVICE'])
    forecaster = Forecaster(convlstm_forecaster_params[0],
                            convlstm_forecaster_params[1],
                            config=config).to(config['DEVICE'])
    model = EF(encoder, forecaster).to(config['DEVICE'])
    model.load_state_dict(torch.load(model_path, map_location='cuda'))
    # Freshly built modules start in training mode; evaluate in eval mode.
    model.eval()

    out_target_len = global_config['OUT_TARGET_LEN']
    data_h = global_config['DATA_HEIGHT']
    data_w = global_config['DATA_WIDTH']

    files = sorted(glob.glob(global_config['TEST_PATH']))
    if start_pred_fn == '':
        # Previously idx = 0, which made files[-in_len:...] an empty or
        # wrapped slice; start at the first index with a full history.
        idx = in_len
    else:
        try:
            idx = next(i for i, f in enumerate(files)
                       if os.path.basename(f) == start_pred_fn)
        except StopIteration:
            raise ValueError(
                'not found: {}'.format(start_pred_fn)) from None
    if idx < in_len:
        raise ValueError('not enough history before {}'.format(start_pred_fn))
    window = files[idx - in_len:idx + out_target_len]
    if len(window) < in_len + out_target_len:
        raise ValueError('not enough files after {}'.format(start_pred_fn))

    scale_div = 4
    # Frames are trimmed by 20 columns on each side, then downscaled.
    small_h = int(data_h / scale_div)
    small_w = int((data_w - 40) / scale_div)
    data = np.zeros((in_len + out_target_len, small_h, small_w),
                    dtype=np.float32)
    for i, file in enumerate(window):
        frame = np.fromfile(file, dtype=np.float32).reshape(
            (data_h, data_w))[:, 20:-20]
        data[i] = cv2.resize(frame, (small_w, small_h),
                             interpolation=cv2.INTER_AREA)

    data = mm_dbz(data)
    model_input = data[:in_len]
    label = data[in_len:]
    preds = []
    clamp_min, clamp_max = mm_dbz(0), mm_dbz(60)
    with torch.no_grad():
        # (in_len, h, w) -> (in_len, 1, 1, h, w); frames are stacked on axis 0.
        x = torch.from_numpy(model_input[:, None, None]).to(config['DEVICE'])
        # One forecast per label frame (was hard-coded to 18 steps).
        for step in range(out_target_len):
            pred = model(x)
            # Empirical adjustment: values below a threshold that grows with
            # the step index are lowered by an amount that also grows with it.
            pred[pred < 0.2 + 0.035 * step] -= 0.04 * step
            pred = torch.clamp(pred, min=clamp_min, max=clamp_max)
            preds.append(pred.detach().cpu().numpy()[0, 0, 0])
            # Drop the first entry on axis 0 and append the forecast.
            x = torch.cat((x[1:], pred), 0)
    pred = np.array(preds)
    print(pred.shape, label.shape)
    pred = np.maximum(dbz_mm(pred), 0)
    label = dbz_mm(label)
    csi = fp_fn_image_csi(pred, label)
    csi_multi = fp_fn_image_csi_muti_reg(pred, label)
    rmse, rmse_rain, rmse_non_rain = cal_rmse_all(pred, label)

    path = './imgs_conv/conv_{}_{}_{}_{}/'.format(test_case, in_len, out_len, multitask)
    # exist_ok replaces the old bare ``except: pass``: an existing directory
    # is fine, but real failures (e.g. permissions) now surface here.
    os.makedirs(path + 'imgs/', exist_ok=True)

    # Absolute error scaled to [0, 255]; guard the all-zero-error case.
    errs = np.abs(pred - label)
    err_max = errs.max()
    if err_max > 0:
        make_gif(errs / err_max * 255, path + 'errs.gif')
    else:
        make_gif(np.zeros_like(errs), path + 'errs.gif')
    make_gif(pred / 60 * 255, path + 'pred.gif')
    make_gif_color(pred, path + 'pred_colored.gif')
    make_gif(label / 60 * 255, path + 'gt.gif')
    make_gif_color(label, path + 'gt_colored.gif')
    for i in range(pred.shape[0]):
        # Clip before the uint8 cast: values above 80 would overflow 255.
        gray = np.clip(pred[i] / 80 * 255, 0, 255).astype(np.uint8)
        cv2.imwrite(path + 'imgs/' + str(i) + '.png',
                    cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR))

    return [rmse, rmse_rain, rmse_non_rain], csi, csi_multi