Exemplo n.º 1
0
def generate_and_save_heave_files():
    """Create and save heave-shifted label files for every echogram.

    For each echogram returned by ``get_echograms()`` the heave values are
    converted from meters to a whole number of pixels (using the median
    spacing of ``range_vector`` as the vertical pixel resolution). Every
    label column is then shifted vertically by that many pixels and the
    result is written with ``save_memmap`` to
    ``<root><echogram name>/labels_heave``.

    Rows that are shifted in from outside the image stay zero.
    """
    root = path_to_echograms()

    echs = get_echograms()
    for i, ech in enumerate(echs):

        # Progress report every 100 echograms: total count, current index
        if i % 100 == 0:
            print(len(echs), i)

        # Get vertical pixel resolution
        r = ech.range_vector
        r_diff = np.median(r[1:] - r[:-1])

        # Convert heave value from meters to number of pixels.
        # The builtin ``int`` is used because the ``np.int`` alias was
        # deprecated in NumPy 1.20 and removed in NumPy 1.24.
        heave = np.round(ech.heave / r_diff).astype(int)
        assert heave.size == ech.shape[1]

        labels_old = ech.label_numpy()
        labels_new = np.zeros_like(labels_old)

        # Create new labels: Move each labels column up/down corresponding to heave
        for x, h in enumerate(heave):
            if h == 0:
                labels_new[:, x] = labels_old[:, x]
            elif h > 0:
                # Positive heave: rows h.. of the old column move to the top
                labels_new[:-h, x] = labels_old[h:, x]
            else:
                # Negative heave: the old column moves down by |h| rows
                labels_new[-h:, x] = labels_old[:h, x]

        # Save new labels as new memmap file
        path_save = root + ech.name + '/labels_heave'
        save_memmap(labels_new, path_save, dtype=labels_new.dtype)
Exemplo n.º 2
0
def sampling_echograms_full(args):
    """Split the stored training samplers into a supervised part and the rest.

    Loads the per-class sampler lists from 'sampler3_tr.pt' in the echogram
    directory. The first ``int(len(first class) * args.semi_ratio)`` entries
    of every class form the supervised set; the remaining entries form the
    second set.

    :param args: namespace providing ``semi_ratio`` and ``sampler_probs``
    :return: tuple ``(dataset_semi, dataset_tr_rest)`` of ``DatasetImg``
    """
    echogram_dir = paths.path_to_echograms()
    per_class = torch.load(os.path.join(echogram_dir, 'sampler3_tr.pt'))

    # The cut position is derived from the first class and reused for all.
    cut = int(len(per_class[0]) * args.semi_ratio)
    head_part = [entries[:cut] for entries in per_class]
    tail_part = [entries[cut:] for entries in per_class]

    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])

    dataset_semi = DatasetImg(
        head_part,
        args.sampler_probs,
        augmentation_function=augmentation,
        data_transform_function=data_transform,
    )
    dataset_tr_rest = DatasetImg(
        tail_part,
        args.sampler_probs,
        augmentation_function=augmentation,
        data_transform_function=data_transform,
    )
    return dataset_semi, dataset_tr_rest
def sampling_echograms_full(args):
    """Build the full training dataset and its semi-supervised subset.

    Loads the per-class sampler lists from 'sampler3_tr.pt' in the echogram
    directory. ``dataset_cp`` wraps all of them; ``dataset_semi`` wraps only
    the first ``int(len(first class) * args.semi_ratio)`` entries per class.

    :param args: namespace providing ``semi_ratio`` and ``sampler_probs``
    :return: tuple ``(dataset_cp, dataset_semi)`` of ``DatasetImg``
    """
    path_to_echograms = paths.path_to_echograms()

    samplers_train = torch.load(
        os.path.join(path_to_echograms, 'sampler3_tr.pt'))

    # The cut position is derived from the first class and reused for all.
    semi_count = int(len(samplers_train[0]) * args.semi_ratio)
    samplers_semi = [samplers[:semi_count] for samplers in samplers_train]

    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])

    dataset_cp = DatasetImg(samplers_train,
                            args.sampler_probs,
                            augmentation_function=augmentation,
                            data_transform_function=data_transform)

    dataset_semi = DatasetImg(samplers_semi,
                              args.sampler_probs,
                              augmentation_function=augmentation,
                              data_transform_function=data_transform)

    return dataset_cp, dataset_semi
Exemplo n.º 4
0
def sampling_echograms_full_for_comparisonP2(args):
    """Build the full and the semi-supervised comparison datasets.

    Reads 'data_tr_TEST_200.pt' and 'label_tr_TEST_200.pt' from the echogram
    directory. ``dataset_cp`` uses all items; ``dataset_semi`` uses only the
    first ``int(len(data) * args.semi_ratio)`` items.

    :param args: namespace providing ``semi_ratio``
    :return: tuple ``(dataset_cp, dataset_semi)``
    """
    echogram_dir = paths.path_to_echograms()
    data = torch.load(os.path.join(echogram_dir, 'data_tr_TEST_200.pt'))
    label = torch.load(os.path.join(echogram_dir, 'label_tr_TEST_200.pt'))

    data_transform = CombineFunctions(
        [remove_nan_inf_for_comparisonP2, db_with_limits_for_comparisonP2])
    label_steps = [
        index_0_1_27_for_comparisonP2,
        relabel_with_threshold_morph_close_for_comparisonP2,
        seabed_checker_for_comparisonP2,
    ]
    label_transform = CombineFunctions(label_steps)

    n_semi = int(len(data) * args.semi_ratio)

    def make_dataset(items, targets):
        # Both datasets share the same transform pipelines.
        return DatasetImg_for_comparisonP2(
            data=items,
            label=targets,
            label_transform_function=label_transform,
            data_transform_function=data_transform)

    dataset_cp = make_dataset(data, label)
    dataset_semi = make_dataset(data[:n_semi], label[:n_semi])
    return dataset_cp, dataset_semi
Exemplo n.º 5
0
def _class_map_to_rgb(class_map, height, width):
    """Render a class-index map as a white RGB image with coloured classes.

    Pixels equal to 1 (sandeel) become blue and pixels equal to 2 (other)
    become red; everything else stays white.
    """
    rgb = np.ones((height, width, 3))
    # class id -> channels set to zero (1: keep blue only, 2: keep red only)
    for class_id, zeroed_channels in ((1, (0, 1)), (2, (1, 2))):
        hits = np.where(class_map == class_id)
        for channel in zeroed_channels:
            rgb[hits[0], hits[1], channel] = 0
    return rgb


def test_and_plot_2019(test_pred_large_2019, test_label_large_2019, epoch, args, idx=2):
    """Plot labels, predictions and data for the 2019 test patches as a PDF.

    Loads ``(data, label, patch_loc)`` from
    'data_label_patch_loc_te_2019_<idx>.pt' in the echogram directory and
    draws four stacked panels, each a ``patch_loc[0]`` x ``patch_loc[1]``
    grid of patches:

    1. the stored labels after the label/data transforms,
    2. ``test_pred_large_2019``,
    3. ``test_label_large_2019``,
    4. the last channel of the transformed data.

    The figure is saved to
    ``<args.pred_2019>/<epoch>_data_pred_label_2019_<idx>_patch.pdf``.

    :param test_pred_large_2019: per-patch predicted class maps
    :param test_label_large_2019: per-patch label class maps
    :param epoch: epoch number, used in the output file name
    :param args: namespace providing ``pred_2019`` (output directory)
    :param idx: index of the 2019 echogram file to load
    :return: None
    """
    path_to_echograms = paths.path_to_echograms()
    data_2019, label_2019, patch_loc = torch.load(
        os.path.join(path_to_echograms,
                     'data_label_patch_loc_te_2019_%d.pt' % idx))
    data_transform = CombineFunctions(
        [remove_nan_inf_for_comparisonP2, db_with_limits_for_comparisonP2])
    label_transform = CombineFunctions([
        index_0_1_27_for_comparisonP2,
        relabel_with_threshold_morph_close_for_comparisonP2,
        seabed_checker_for_comparisonP2
    ])

    boxsize = 5
    grid_rows = patch_loc[0] * 4  # four stacked panels
    grid_cols = patch_loc[1]
    patches_per_panel = patch_loc[1] * patch_loc[0]

    plt.figure(figsize=(boxsize * patch_loc[1], boxsize * patch_loc[0] * 4))
    for i, patch in enumerate(data_2019):
        patch_label = label_2019[i]
        patch, patch_label = label_transform(patch, patch_label)
        patch, patch_label = data_transform(patch, patch_label)
        dim = np.shape(patch_label)

        # All three class maps are rendered at the transformed label's size.
        labels_rgb = _class_map_to_rgb(patch_label, dim[0], dim[1])
        pred_rgb = _class_map_to_rgb(test_pred_large_2019[i], dim[0], dim[1])
        lbb_rgb = _class_map_to_rgb(test_label_large_2019[i], dim[0], dim[1])

        plt.subplot(grid_rows, grid_cols, i + 1)
        plt.imshow(labels_rgb)

        plt.subplot(grid_rows, grid_cols, patches_per_panel + i + 1)
        plt.imshow(pred_rgb)

        plt.subplot(grid_rows, grid_cols, patches_per_panel * 2 + i + 1)
        plt.imshow(lbb_rgb)

        plt.subplot(grid_rows, grid_cols, patches_per_panel * 3 + i + 1)
        plt.imshow(patch[-1])

    plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])
    plt.tight_layout()
    plt.savefig(os.path.join(args.pred_2019, '%d_data_pred_label_2019_%d_patch.pdf' % (epoch, idx)))
    plt.close()
def sampling_echograms_full(args):
    """Build a mixed (supervised + unbalanced unlabeled) dataset and the
    supervised-only dataset.

    Loads the per-class training samplers ('sampler3_tr.pt') and extra
    background samplers ('train_bg_32766.pt'). Each class is split at
    ``int(len(first class) * args.semi_ratio)``. The unlabeled remainder
    (class 0 extended with the extra background) is truncated per class
    according to the fixed ratios below. Supervised and truncated unlabeled
    parts are concatenated class by class into one long list, which is then
    re-cut into ``num_classes`` consecutive chunks of the original per-class
    length.

    :param args: namespace providing ``semi_ratio``, ``nmb_category`` and
        ``sampler_probs``
    :return: tuple ``(dataset_cp, dataset_semi)`` of ``DatasetImg``
    """
    # Share of the unlabeled pool assigned to each class.
    tr_ratio = [0.97808653, 0.01301181, 0.00890166]
    echogram_dir = paths.path_to_echograms()

    samplers_train = torch.load(os.path.join(echogram_dir, 'sampler3_tr.pt'))
    samplers_bg = torch.load(os.path.join(echogram_dir, 'train_bg_32766.pt'))

    list_length = len(samplers_train[0])
    num_classes = len(samplers_train)

    supervised_count = int(list_length * args.semi_ratio)
    total_unsupervised_count = int(
        (list_length - supervised_count) * args.nmb_category)

    unlab_size = [int(share * total_unsupervised_count) for share in tr_ratio]
    # Truncation in int() can lose a few items; give them to class 0.
    shortfall = total_unsupervised_count - np.sum(unlab_size)
    if shortfall != 0:
        unlab_size[0] += shortfall

    samplers_supervised = [cls[:supervised_count] for cls in samplers_train]
    samplers_unsupervised = [cls[supervised_count:] for cls in samplers_train]
    samplers_unsupervised[0].extend(samplers_bg)

    samplers_unbal_unlab = [
        pool[:quota] for pool, quota in zip(samplers_unsupervised, unlab_size)
    ]

    # Flatten: for each class, supervised part followed by unlabeled part.
    flat = []
    for labelled, unlabelled in zip(samplers_supervised, samplers_unbal_unlab):
        flat.extend(np.concatenate([labelled, unlabelled]))

    # Re-cut the flat list into num_classes chunks of list_length each.
    samplers_cp = []
    for k in range(num_classes):
        samplers_cp.append(flat[k * list_length:(k + 1) * list_length])

    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])

    dataset_cp = DatasetImg(
        samplers_cp,
        args.sampler_probs,
        augmentation_function=augmentation,
        data_transform_function=data_transform,
    )
    dataset_semi = DatasetImg(
        samplers_supervised,
        args.sampler_probs,
        augmentation_function=augmentation,
        data_transform_function=data_transform,
    )
    return dataset_cp, dataset_semi
Exemplo n.º 7
0
def save_all_seabeds():
    """
    Loop through all echograms and generate seabed-estimates.

    Every entry of the echogram directory is opened as an ``Echogram`` and
    ``get_seabed`` is called on it with ``save_to_file=True`` and
    ``ignore_saved=True``.

    :return: None
    """
    path_to_echograms = paths.path_to_echograms()
    echogram_names = os.listdir(path_to_echograms)
    # os.path.join works whether or not the root ends with a path separator;
    # plain string concatenation silently built a wrong path without one.
    echograms = [
        Echogram(os.path.join(path_to_echograms, e)) for e in echogram_names
    ]
    for e in echograms:
        e.get_seabed(save_to_file=True, ignore_saved=True)
Exemplo n.º 8
0
def get_echograms_revised(eg_names_full, sample_idx, num_echograms=100):
    """Open one consecutive chunk of echograms and advance the chunk index.

    The name list is viewed as ``len(eg_names_full) // num_echograms`` full
    chunks. If ``sample_idx`` is not one of those chunk indices it is reset
    to 0 (and 'Reset_sample_idx' is printed).

    :param eg_names_full: sequence of echogram names
    :param sample_idx: index of the chunk to open
    :param num_echograms: number of echograms per chunk
    :return: tuple ``(echograms, next_sample_idx)``
    """
    echogram_dir = paths.path_to_echograms()

    valid_chunks = set(np.arange(len(eg_names_full) // num_echograms))
    if sample_idx not in valid_chunks:
        sample_idx = 0
        print('Reset_sample_idx')

    positions = np.arange(num_echograms * sample_idx,
                          num_echograms * (sample_idx + 1))
    chosen_names = [eg_names_full[pos] for pos in positions]

    echograms = []
    for name in chosen_names:
        echograms.append(Echogram(os.path.join(echogram_dir, name)))

    return echograms, sample_idx + 1
Exemplo n.º 9
0
def sampling_echograms_test(args):
    """Build the test dataset from the samplers stored in 'sampler6_te.pt'.

    No augmentation is applied; only the data transform is attached.

    :param args: namespace providing ``sampler_probs``
    :return: the ``DatasetImg`` for testing
    """
    sampler_file = os.path.join(paths.path_to_echograms(), 'sampler6_te.pt')
    samplers_test = torch.load(sampler_file)

    transform_steps = [remove_nan_inf_img, db_with_limits_img]
    data_transform = CombineFunctions(transform_steps)

    return DatasetImg(
        samplers_test,
        args.sampler_probs,
        augmentation_function=None,
        data_transform_function=data_transform,
    )
Exemplo n.º 10
0
def sampling_echograms_full(args):
    """Build the training dataset from 'samplers_sb_1024_3.pt'.

    :param args: namespace providing ``sampler_probs``
    :return: the ``DatasetImg`` wrapping all loaded samplers
    """
    sampler_file = os.path.join(paths.path_to_echograms(),
                                'samplers_sb_1024_3.pt')
    samplers_train = torch.load(sampler_file)

    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])

    return DatasetImg(
        samplers_train,
        args.sampler_probs,
        augmentation_function=augmentation,
        data_transform_function=data_transform,
    )
Exemplo n.º 11
0
def test_analysis(predictions, predictions_mat, epoch, args):
    """Evaluate test predictions: ROC curves, confusion matrix and plots.

    Reference labels are loaded from 'label_TEST_60_after_transformation.pt'.
    Only positions whose label is greater than -1 are evaluated. If
    ``predictions_mat`` has shape (60, 256, 256, 3) it is first transposed
    to (60, 3, 256, 256).

    :param predictions: predicted classes, indexable like the label array
    :param predictions_mat: per-class prediction scores
    :param epoch: epoch number forwarded to the plotting helpers
    :param args: namespace forwarded to ``conf_mat`` and the plotting helpers
    :return: ``(fpr, tpr, roc_auc, roc_auc_macro, prob_mat, mat, f1_score,
        kappa, acc_bg, acc_se, acc_ot)``
    """
    label_file = os.path.join(paths.path_to_echograms(),
                              'label_TEST_60_after_transformation.pt')
    labels_origin = torch.load(label_file)

    # Bring channel-last score tensors into channel-first layout.
    if np.shape(predictions_mat) == (60, 256, 256, 3):
        predictions_mat = predictions_mat.transpose(0, 3, 1, 2)

    valid = np.where(labels_origin > -1)
    labels_vec = labels_origin[valid]
    predictions_vec = predictions[valid]
    # Pick the full class-score vector at every valid position.
    scores_at_valid = predictions_mat[valid[0], :, valid[1], valid[2]]

    fpr, tpr, roc_auc, roc_auc_macro = roc_curve_macro(labels_vec,
                                                       scores_at_valid)
    prob_mat, mat, f1_score, kappa = conf_mat(ylabel=labels_vec,
                                              ypred=predictions_vec,
                                              args=args)
    # Diagonal of the normalized confusion matrix = per-class accuracy.
    acc_bg, acc_se, acc_ot = prob_mat.diagonal()

    plot_macro(fpr, tpr, roc_auc, epoch, args)
    plot_conf(epoch, prob_mat, mat, f1_score, kappa, args)

    roc_part = (fpr, tpr, roc_auc, roc_auc_macro)
    conf_part = (prob_mat, mat, f1_score, kappa)
    return roc_part + conf_part + (acc_bg, acc_se, acc_ot)
Exemplo n.º 12
0
def sampling_echograms_for_s3vm(args):
    """Build a dataset from the background samplers in 'train_bg_32766.pt'.

    The loaded list is cut into ``args.nmb_category`` consecutive chunks of
    equal length (``len // nmb_category``); leftover items at the end are
    not used.

    :param args: namespace providing ``nmb_category`` and ``sampler_probs``
    :return: the ``DatasetImg`` built from the chunks
    """
    echogram_dir = paths.path_to_echograms()
    background = torch.load(os.path.join(echogram_dir, 'train_bg_32766.pt'))

    chunk_len = len(background) // args.nmb_category
    chunks = []
    for k in range(args.nmb_category):
        chunks.append(background[k * chunk_len:(k + 1) * chunk_len])

    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])

    return DatasetImg(
        chunks,
        args.sampler_probs,
        augmentation_function=augmentation,
        data_transform_function=data_transform,
    )
Exemplo n.º 13
0
def sampling_echograms_test_for_comparisonP2():
    """Build the comparison test dataset.

    Reads 'data_te_TEST_60.pt' and 'label_te_TEST_60.pt' from the echogram
    directory and attaches the comparison label and data transforms.

    :return: the ``DatasetImg_for_comparisonP2`` for testing
    """
    echogram_dir = paths.path_to_echograms()
    data = torch.load(os.path.join(echogram_dir, 'data_te_TEST_60.pt'))
    label = torch.load(os.path.join(echogram_dir, 'label_te_TEST_60.pt'))

    data_steps = [remove_nan_inf_for_comparisonP2,
                  db_with_limits_for_comparisonP2]
    data_transform = CombineFunctions(data_steps)

    label_steps = [
        index_0_1_27_for_comparisonP2,
        relabel_with_threshold_morph_close_for_comparisonP2,
        seabed_checker_for_comparisonP2,
    ]
    label_transform = CombineFunctions(label_steps)

    return DatasetImg_for_comparisonP2(
        data=data,
        label=label,
        label_transform_function=label_transform,
        data_transform_function=data_transform)
Exemplo n.º 14
0
def sampling_echograms_full(window_size, args):
    """Build the training dataset from all echograms in the pickled name list.

    The names in 'memmap_2014_heave.pkl' are opened with
    ``get_echograms_full`` and partitioned with ``cps.partition_data``; only
    the training partition is used. Five samplers are created on it, in this
    order: background, school 27, school 27 near seabed, school 1, school 1
    near seabed.

    :param window_size: window size handed to every sampler and the dataset
    :param args: namespace providing ``partition``, ``window_dim``,
        ``frequencies``, ``batch``, ``iteration_train`` and ``sampler_probs``
    :return: the training ``Dataset``
    """
    names_file = os.path.join(paths.path_to_echograms(),
                              'memmap_2014_heave.pkl')
    with open(names_file, 'rb') as handle:
        eg_names_full = pickle.load(handle)

    all_echograms = get_echograms_full(eg_names_full)
    # The validation and test partitions are not needed here.
    echograms_train, _, _ = cps.partition_data(
        all_echograms,
        args.partition,
        portion_train_test=0.8,
        portion_train_val=0.75)

    samplers_train = [
        Background(echograms_train, window_size),
        Shool(echograms_train, window_size, 27),
        ShoolSeabed(echograms_train, window_size, args.window_dim // 4,
                    fish_type=27),
        Shool(echograms_train, window_size, 1),
        ShoolSeabed(echograms_train, window_size, args.window_dim // 4,
                    fish_type=1),
    ]

    augmentation = CombineFunctions([add_noise, flip_x_axis])
    label_transform = CombineFunctions(
        [index_0_1_27, relabel_with_threshold_morph_close])
    data_transform = CombineFunctions([remove_nan_inf, db_with_limits])

    return Dataset(
        samplers_train,
        window_size,
        args.frequencies,
        args.batch * args.iteration_train,
        args.sampler_probs,
        augmentation_function=augmentation,
        label_transform_function=label_transform,
        data_transform_function=data_transform,
    )
Exemplo n.º 15
0
def sampling_echograms_2019_for_comparisonP2(echogram_idx=2,
                                             path_to_echograms=None):
    """Build the comparison dataset for one 2019 test echogram.

    Loads ``(data, label, patch_loc)`` from
    'data_label_patch_loc_te_2019_<echogram_idx>.pt'.

    :param echogram_idx: index used in the file name
    :param path_to_echograms: directory holding the file; when ``None`` the
        default from ``paths.path_to_echograms()`` is used
    :return: tuple ``(dataset_2019, label, patch_loc)`` where ``label`` and
        ``patch_loc`` are returned exactly as loaded
    """
    # Identity check: ``None`` is a singleton, so ``is`` is the correct test.
    if path_to_echograms is None:
        path_to_echograms = paths.path_to_echograms()
    data, label, patch_loc = torch.load(
        os.path.join(path_to_echograms,
                     'data_label_patch_loc_te_2019_%d.pt' % echogram_idx))
    data_transform = CombineFunctions(
        [remove_nan_inf_for_comparisonP2, db_with_limits_for_comparisonP2])
    label_transform = CombineFunctions([
        index_0_1_27_for_comparisonP2,
        relabel_with_threshold_morph_close_for_comparisonP2,
        seabed_checker_for_comparisonP2
    ])

    dataset_2019 = DatasetImg_for_comparisonP2(
        data=data,
        label=label,
        label_transform_function=label_transform,
        data_transform_function=data_transform)
    return dataset_2019, label, patch_loc
Exemplo n.º 16
0
    parser.add_argument('--partition', type=str, default='train_only',
                        help='echogram partition (tr/val/te) by year')
    parser.add_argument('--iteration_train', type=int, default=1200,
                        help='num_tr_iterations per one batch and epoch')
    parser.add_argument('--sampler_probs', type=list, default=None,
                        help='[bg, sh27, sbsh27, sh01, sbsh01], default=[1, 1, 1, 1, 1]')
    parser.add_argument('--resume',
                        default=os.path.join(current_dir, 'checkpoint.pth.tar'), type=str, metavar='PATH',
                        help='path to checkpoint (default: None)')
    parser.add_argument('--exp', type=str,
                        default=current_dir, help='path to exp folder')
    return parser.parse_args(args=[])

# Script section: parse the default arguments, open every echogram listed in
# 'memmap_2014_heave.pkl', partition them and build a background-only
# sampler and dataset on the training partition.
args = parse_args()
# Two-element window size; both entries are args.window_dim.
window_size = [args.window_dim, args.window_dim]
path_to_echograms = paths.path_to_echograms()
# The pickle file holds the list of echogram names to open.
with open(os.path.join(path_to_echograms, 'memmap_2014_heave.pkl'), 'rb') as fp:
    eg_names_full = pickle.load(fp)
echograms = get_echograms_full(eg_names_full)
# Only echograms_train is used in the statements visible below.
echograms_train, echograms_val, echograms_test = cps.partition_data(echograms, args.partition, portion_train_test=0.8,
                                                                    portion_train_val=0.75)

# Only the background sampler is active; the other samplers are disabled.
sampler_bg_train = Background(echograms_train, window_size)
# sampler_sh27_train = Shool(echograms_train, window_size, 27)
# sampler_sh01_train = Shool(echograms_train, window_size, 1)
# sampler_sbsh27_train = ShoolSeabed(echograms_train, window_size, args.window_dim // 4, fish_type=27)
# sampler_sbsh01_train = ShoolSeabed(echograms_train, window_size, args.window_dim // 4, fish_type=1)


# Dataset wrapper around the background sampler.
data_bg = DatasetSampler(sampler_bg_train, window_size, args.frequencies)
# data_sh27 = DatasetSampler(sampler_sh27_train, window_size, args.frequencies)
Exemplo n.º 17
0
def sampling_echograms_full(args):
    """Build the training dataset from 'samplers_500_three.pt'.

    :param args: namespace providing ``sampler_probs``
    :return: the ``DatasetImg`` wrapping the loaded samplers
    """
    # NOTE: an earlier, commented-out recipe lived here. It drew 3000 random
    # items per class (bg, sh27, sbsh27, sh01, sbsh01) from 'numpy_<class>_*.pt'
    # shards and stored them with torch.save ('samplers_3000.pt',
    # 'samplers_2000.pt'). It was dead code and has been removed.
    sampler_file = os.path.join(paths.path_to_echograms(),
                                'samplers_500_three.pt')
    samplers_train = torch.load(sampler_file)

    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])

    return DatasetImg(
        samplers_train,
        1500,
        args.sampler_probs,
        augmentation_function=augmentation,
        data_transform_function=data_transform,
    )
def sampling_echograms_full(args):
    """Build a mixed (supervised + unbalanced unlabeled) dataset and the
    supervised-only dataset.

    Loads the per-class training samplers ('sampler6_tr.pt') and extra
    background samplers ('bg_unb_11000_tr.pt'). Each class is split at
    ``int(len(first class) * args.semi_ratio)``. The unlabeled remainder
    (class 0 extended with the extra background) is truncated per class to
    hard-coded sizes that depend on the number of classes (6 or 3) and on
    ``args.semi_ratio``. Supervised and truncated unlabeled parts are
    concatenated class by class into one long list, which is then re-cut
    into ``num_classes`` consecutive chunks of the original per-class length.

    :param args: namespace providing ``semi_ratio`` (one of 0.01, 0.05, 0.1,
        0.2) and ``sampler_probs``
    :return: tuple ``(dataset_cp, dataset_semi)`` of ``DatasetImg``
    :raises AssertionError: if ``args.semi_ratio`` is not a supported value
    :raises ValueError: if the loaded sampler file has neither 6 nor 3
        classes, or no size entry exists for the ratio
    """
    path_to_echograms = paths.path_to_echograms()
    # Kept as an assert so existing callers see the same exception type.
    assert (args.semi_ratio in [0.01, 0.05, 0.1, 0.2]), 'Fix args.semi-ratio in a given range'

    samplers_train = torch.load(os.path.join(path_to_echograms, 'sampler6_tr.pt'))
    samplers_bg_unb = torch.load(os.path.join(path_to_echograms, 'bg_unb_11000_tr.pt'))

    list_length = len(samplers_train[0])
    num_classes = len(samplers_train)
    semi_count = int(list_length * args.semi_ratio)

    # Number of unlabeled items kept per class:
    # {number of classes: {semi_ratio: [size per class]}}
    unlab_size_table = {
        6: {
            0.01: [12992, 195, 200, 1344, 90, 29],
            0.05: [12468, 187, 192, 1289, 86, 28],
            0.1: [11811, 177, 182, 1222, 82, 26],
            0.2: [10499, 157, 162, 1086, 73, 23],
        },
        3: {  # combined 3classes case
            0.01: [14336, 285, 229],
            0.05: [13757, 273, 220],
            0.1: [13033, 259, 208],
            0.2: [11585, 230, 185],
        },
    }
    # Fail here with a clear message instead of a NameError further down.
    if num_classes not in unlab_size_table:
        raise ValueError(
            'Unsupported number of sampler classes: %d (expected 6 or 3)'
            % num_classes)
    unlab_size = unlab_size_table[num_classes].get(args.semi_ratio)
    if unlab_size is None:
        # Reachable only when asserts are disabled (python -O).
        raise ValueError(
            'No unlabeled sizes defined for semi_ratio=%r' % (args.semi_ratio,))

    samplers_semi = [samplers[:semi_count] for samplers in samplers_train]
    samplers_rest = [samplers[semi_count:] for samplers in samplers_train]
    samplers_rest[0].extend(samplers_bg_unb)

    samplers_unbal_unlab = [
        sampler[:size] for sampler, size in zip(samplers_rest, unlab_size)
    ]

    # Flatten: for each class, supervised part followed by unlabeled part.
    samplers_semi_unbal_unlab_long = []
    for sampler_semi, sampler_unb_unl in zip(samplers_semi, samplers_unbal_unlab):
        samplers_semi_unbal_unlab_long.extend(np.concatenate([sampler_semi, sampler_unb_unl]))

    # Re-cut the flat list into num_classes chunks of list_length each.
    samplers_cp = [
        samplers_semi_unbal_unlab_long[i * list_length:(i + 1) * list_length]
        for i in range(num_classes)
    ]

    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])

    dataset_cp = DatasetImg(
        samplers_cp,
        args.sampler_probs,
        augmentation_function=augmentation,
        data_transform_function=data_transform)

    dataset_semi = DatasetImg(
        samplers_semi,
        args.sampler_probs,
        augmentation_function=augmentation,
        data_transform_function=data_transform)

    return dataset_cp, dataset_semi
Exemplo n.º 19
0
def sampling_echograms(sample_idx, window_size, args):
    """Build the training dataset from one chunk of echograms.

    The chunk number ``sample_idx`` is handed to ``get_echograms_revised``,
    which opens ``args.num_echogram`` echograms and returns the next chunk
    index. The echograms are partitioned with ``cps.partition_data``; only
    the training partition is used. Five samplers are created on it, in this
    order: background, school 27, school 27 near seabed, school 1, school 1
    near seabed.

    :param sample_idx: index of the echogram chunk (converted with ``int``)
    :param window_size: window size handed to every sampler and the dataset
    :param args: namespace providing ``num_echogram``, ``partition``,
        ``window_dim``, ``frequencies``, ``batch``, ``iteration_train`` and
        ``sampler_probs``
    :return: tuple ``(dataset_train, next_sample_idx)``
    """
    # NOTE: commented-out code that built validation samplers, a DatasetVal
    # and its DataLoader used to live in this function; it was dead code and
    # has been removed.
    names_file = os.path.join(paths.path_to_echograms(),
                              'memmap_2014_heave.pkl')
    with open(names_file, 'rb') as handle:
        eg_names_full = pickle.load(handle)

    print(sample_idx, "IN SAMPLING_ECHOGRAMS")
    sample_idx = int(sample_idx)

    echograms, sample_idx = get_echograms_revised(
        eg_names_full, sample_idx, num_echograms=args.num_echogram)
    # The validation and test partitions are not needed here.
    echograms_train, _, _ = cps.partition_data(
        echograms,
        args.partition,
        portion_train_test=0.8,
        portion_train_val=0.75)

    samplers_train = [
        Background(echograms_train, window_size),
        Shool(echograms_train, window_size, 27),
        ShoolSeabed(echograms_train, window_size, args.window_dim // 4,
                    fish_type=27),
        Shool(echograms_train, window_size, 1),
        ShoolSeabed(echograms_train, window_size, args.window_dim // 4,
                    fish_type=1),
    ]

    augmentation = CombineFunctions([add_noise, flip_x_axis])
    label_transform = CombineFunctions(
        [index_0_1_27, relabel_with_threshold_morph_close])
    data_transform = CombineFunctions([remove_nan_inf, db_with_limits])

    dataset_train = Dataset(
        samplers_train,
        window_size,
        args.frequencies,
        args.batch * args.iteration_train,
        args.sampler_probs,
        augmentation_function=augmentation,
        label_transform_function=label_transform,
        data_transform_function=data_transform,
    )
    return dataset_train, int(sample_idx)