def prepare_submission_multifolds(model_name, run, epoch_nums, threshold,
                                  submission_name, use_global_cat):
    """Predict on the test set with an ensemble of checkpoints and write a CSV.

    One model is loaded per (epoch, fold) pair for folds 0-3. For every
    patient listed in ``data/sample_submission.csv`` the DICOM image is
    resized to the model input size, the raw outputs of all models are
    averaged element-wise, and the first model's ``boxes`` method converts
    the averaged outputs into scored boxes. Boxes that pass ``threshold``
    are written to ``submissions/{submission_name}.csv``.

    Args:
        model_name: Key into ``MODELS``; also part of the checkpoint paths.
        run: Optional run suffix; ``None`` or ``''`` means no suffix.
        epoch_nums: Epoch numbers whose checkpoints are loaded.
        threshold: Cut-off compared against the scaled box score.
        submission_name: Base name (without ``.csv``) of the output file.
        use_global_cat: If true, each box score is multiplied by the global
            classification value before thresholding.
    """
    run_str = '' if run is None or run == '' else f'_{run}'

    model_info = MODELS[model_name]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Nothing below writes into this directory; it is still created so the
    # function keeps its original on-disk side effect.
    predictions_dir = f'output/oof2/{model_name}{run_str}_fold_combined'
    os.makedirs(predictions_dir, exist_ok=True)

    models = []
    for epoch_num in epoch_nums:
        for fold in range(4):
            checkpoint = f'checkpoints/{model_name}{run_str}_fold_{fold}/{model_name}_{epoch_num:03}.pt'
            print('load', checkpoint)
            model = torch.load(checkpoint, map_location=device)
            model = model.to(device)
            model.eval()
            models.append(model)

    sample_submission = pd.read_csv('data/sample_submission.csv')
    img_size = model_info.img_size

    # The context manager guarantees the CSV is flushed and closed, even if
    # a prediction fails part-way through.
    with open(f'submissions/{submission_name}.csv', 'w') as submission:
        submission.write('patientId,PredictionString\n')

        for patient_id in sample_submission.patientId:
            dcm_data = pydicom.read_file(f'{config.TEST_DIR}/{patient_id}.dcm')
            img = skimage.transform.resize(dcm_data.pixel_array,
                                           (img_size, img_size),
                                           order=1)

            # Build a (1, 1, H, W) float tensor on the same device as the
            # models (the original forced CUDA and broke the CPU fallback).
            img_tensor = torch.zeros(1, img_size, img_size, 1)
            img_tensor[0, :, :, 0] = torch.from_numpy(img)
            img_tensor = img_tensor.permute(0, 3, 1, 2).to(device)

            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                model_raw_results = [
                    model(img_tensor,
                          return_loss=False,
                          return_boxes=False,
                          return_raw=True) for model in models
                ]

                # Element-wise mean of each raw output across all models.
                model_raw_results_mean = [
                    sum(outputs) / len(models)
                    for outputs in zip(*model_raw_results)
                ]

                nms_scores, global_classification, transformed_anchors = \
                    models[0].boxes(img_tensor, *model_raw_results_mean)

            scores = nms_scores.cpu().detach().numpy()
            category = global_classification.cpu().detach().numpy()
            boxes = transformed_anchors.cpu().detach().numpy()
            # Scalar mix of two global-classification outputs; the meaning
            # of columns 0 and 2 is defined by the model, not visible here.
            category = category[0, 2] + 0.1 * category[0, 0]

            if len(scores):
                # Drop boxes scoring below half of the first box's score.
                # NOTE(review): presumably scores[0] is the highest score
                # (sorted descending by NMS) - confirm.
                scores[scores < scores[0] * 0.5] = 0.0

            if use_global_cat:
                mask = scores * category * 10 > threshold
            else:
                mask = scores * 5 > threshold

            submission_str = ''

            if np.any(mask):
                boxes_selected = p1p2_to_xywh(boxes[mask])  # x y w h format
                # Rescale from model input size to a 1024-pixel frame.
                # NOTE(review): assumes source images are 1024x1024 - confirm.
                boxes_selected *= 1024.0 / img_size
                scores_selected = scores[mask]

                submission_str = ''.join(
                    f' {score:.3f} {x:.1f} {y:.1f} {w:.1f} {h:.1f}'
                    for score, (x, y, w, h) in zip(scores_selected,
                                                   boxes_selected))

            print(f'{patient_id},{submission_str}      {category:.2f}')
            submission.write(f'{patient_id},{submission_str}\n')


def prepare_submission_from_saved(model_name, run, epoch_nums, threshold,
                                  submission_name, use_global_cat, size_scale):
    """Build a submission CSV from raw model outputs previously saved to disk.

    For every patient in ``config.SAMPLE_SUBMISSION_FILE`` the pickled raw
    outputs of each (epoch, fold) pair for folds 0-3 are loaded from
    ``config.TEST_PREDICTIONS_DIR`` and combined:

    * the first two regression columns (named "pos" in the code) are
      averaged over all saved results;
    * the remaining regression columns (named "size") are set to
      ``p20 + (p20 - p80) * size_scale`` where ``p20`` / ``p80`` are the
      20th / 80th percentiles over all saved results;
    * classification and global classification outputs are averaged.

    The fold-0 checkpoint of ``epoch_nums[0]`` is loaded only to provide
    ``anchors`` and ``boxes``. Selected boxes are written to
    ``submissions/{submission_name}.csv`` and ``check_submission_stat`` is
    called on the result.

    Args:
        model_name: Key into ``MODELS``; also part of checkpoint/result paths.
        run: Optional run suffix; ``None`` or ``''`` means no suffix.
        epoch_nums: Sequence of epoch numbers; the first one selects the
            checkpoint used for box decoding.
        threshold: Cut-off compared against the scaled box score.
        submission_name: Base name (without ``.csv``) of the output file.
        use_global_cat: If true, the geometric mean of box score and global
            classification value is thresholded instead of the box score.
        size_scale: Weight of the ``(p20 - p80)`` term added to the size
            regression outputs.
    """
    run_str = '' if run is None or run == '' else f'_{run}'

    model_info = MODELS[model_name]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    checkpoint = f'checkpoints/{model_name}{run_str}_fold_0/{model_name}_{epoch_nums[0]:03}.pt'
    print('load', checkpoint)
    model = torch.load(checkpoint, map_location=device)
    model = model.to(device)
    model.eval()

    img_size = model_info.img_size
    # All-zero placeholder image of shape (1, 1, H, W); it is only passed to
    # ``anchors`` / ``boxes``, the predictions themselves come from disk.
    img_tensor = torch.zeros(1, img_size, img_size, 1).permute(0, 3, 1,
                                                               2).to(device)

    sample_submission = pd.read_csv(config.SAMPLE_SUBMISSION_FILE)

    # The set of result directories does not depend on the patient, so it is
    # computed once instead of once per patient.
    saved_dirs = [
        f'{config.TEST_PREDICTIONS_DIR}/{model_name}{run_str}_fold_{fold}/{epoch_num:03}/'
        for epoch_num in epoch_nums for fold in range(4)
    ]

    with torch.no_grad():
        anchors = model.anchors(img_tensor)

    # The context manager closes the CSV on success and on failure alike.
    with open(f'submissions/{submission_name}.csv', 'w') as submission:
        submission.write('patientId,PredictionString\n')

        for patient_id in sample_submission.patientId:
            regression_results = []
            classification_results = []
            global_classification_results = []

            for saved_dir in saved_dirs:
                # SECURITY: pickle can execute arbitrary code while loading;
                # only read prediction files produced by this project.
                with open(f'{saved_dir}/{patient_id}.pkl', 'rb') as raw_file:
                    model_raw_result = pickle.load(raw_file)

                regression_results.append(model_raw_result[0])
                classification_results.append(model_raw_result[1])
                global_classification_results.append(model_raw_result[2])

            # NOTE(review): presumably every saved array has a leading
            # dimension of 1, so axis 0 indexes the saved results - confirm.
            regression_results = np.concatenate(regression_results, axis=0)

            regression_results_pos = np.mean(regression_results[:, :, :2],
                                             axis=0,
                                             keepdims=True)

            size_samples = regression_results[:, :, 2:]
            regression_results_size_p80 = np.percentile(size_samples,
                                                        q=80,
                                                        axis=0,
                                                        keepdims=True)
            regression_results_size = np.percentile(size_samples,
                                                    q=20,
                                                    axis=0,
                                                    keepdims=True)
            # Start from the 20th percentile and move further by size_scale
            # times the (non-positive) p20 - p80 spread.
            regression_results_size += (regression_results_size -
                                        regression_results_size_p80) * size_scale

            regression_results = np.concatenate(
                [regression_results_pos, regression_results_size],
                axis=2).astype(np.float32)

            classification_results = np.mean(
                np.concatenate(classification_results, axis=0),
                axis=0,
                keepdims=True)

            global_classification_results = np.mean(
                np.concatenate(global_classification_results, axis=0),
                axis=0,
                keepdims=True)

            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                nms_scores, global_classification, transformed_anchors = \
                    model.boxes(
                        img_tensor,
                        torch.from_numpy(regression_results).to(device),
                        torch.from_numpy(classification_results).to(device),
                        torch.from_numpy(
                            global_classification_results).to(device),
                        anchors)

            scores = nms_scores.cpu().detach().numpy()
            category = global_classification.cpu().detach().numpy()
            boxes = transformed_anchors.cpu().detach().numpy()
            # Scalar mix of two global-classification outputs; the meaning
            # of columns 0 and 2 is defined by the model, not visible here.
            category = category[0, 2] + 0.1 * category[0, 0]

            if len(scores):
                # Drop boxes scoring below half of the first box's score.
                # NOTE(review): presumably scores[0] is the highest score
                # (sorted descending by NMS) - confirm.
                scores[scores < scores[0] * 0.5] = 0.0

            if use_global_cat:
                mask = ((scores * category)**0.5) * 5 > threshold
            else:
                mask = scores * 5 > threshold

            submission_str = ''

            if np.any(mask):
                boxes_selected = p1p2_to_xywh(boxes[mask])  # x y w h format
                # Rescale from model input size to a 1024-pixel frame.
                # NOTE(review): assumes source images are 1024x1024 - confirm.
                boxes_selected *= 1024.0 / img_size
                scores_selected = scores[mask]

                submission_str = ''.join(
                    f' {score:.3f} {x:.1f} {y:.1f} {w:.1f} {h:.1f}'
                    for score, (x, y, w, h) in zip(scores_selected,
                                                   boxes_selected))

            print(f'{patient_id},{submission_str}      {category:.2f}')
            submission.write(f'{patient_id},{submission_str}\n')

    # Runs only after the file has been closed, i.e. fully flushed to disk.
    check_submission_stat(submission_name)