Example #1
def sampling_echograms_full(args):
    """Split the stored training samplers into a supervised subset and the rest.

    Loads the per-class sampler lists from 'sampler3_tr.pt', keeps the first
    ``args.semi_ratio`` fraction of each class as the supervised (semi) part,
    and wraps both parts in DatasetImg with the standard augmentation and
    data-cleaning transforms.

    Args:
        args: namespace providing ``semi_ratio`` and ``sampler_probs``.

    Returns:
        tuple: (dataset_semi, dataset_tr_rest) as DatasetImg instances.
    """
    echogram_dir = paths.path_to_echograms()
    samplers_train = torch.load(
        os.path.join(echogram_dir, 'sampler3_tr.pt'))

    # Per-class cut point between the supervised slice and the remainder.
    n_supervised = int(len(samplers_train[0]) * args.semi_ratio)
    samplers_supervised = [cls_samplers[:n_supervised]
                           for cls_samplers in samplers_train]
    samplers_tr_rest = [cls_samplers[n_supervised:]
                        for cls_samplers in samplers_train]

    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])

    dataset_semi = DatasetImg(samplers_supervised,
                              args.sampler_probs,
                              augmentation_function=augmentation,
                              data_transform_function=data_transform)
    dataset_tr_rest = DatasetImg(samplers_tr_rest,
                                 args.sampler_probs,
                                 augmentation_function=augmentation,
                                 data_transform_function=data_transform)
    return dataset_semi, dataset_tr_rest
def sampling_echograms_full(args):
    """Build the full training dataset and its semi-supervised subset.

    Loads the per-class sampler lists from 'sampler3_tr.pt', keeps the first
    ``args.semi_ratio`` fraction of each class as the semi-supervised subset,
    and wraps both the complete lists and the subset in DatasetImg.

    Args:
        args: namespace providing ``semi_ratio`` (fraction in (0, 1]) and
            ``sampler_probs`` (per-class sampling probabilities).

    Returns:
        tuple: (dataset_cp, dataset_semi) as DatasetImg instances.
    """
    # Fix: removed the unused local `tr_ratio` left over from another variant.
    path_to_echograms = paths.path_to_echograms()

    samplers_train = torch.load(
        os.path.join(path_to_echograms, 'sampler3_tr.pt'))

    # Number of labeled samplers kept per class.
    semi_count = int(len(samplers_train[0]) * args.semi_ratio)
    samplers_semi = [samplers[:semi_count] for samplers in samplers_train]

    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])

    dataset_cp = DatasetImg(samplers_train,
                            args.sampler_probs,
                            augmentation_function=augmentation,
                            data_transform_function=data_transform)

    dataset_semi = DatasetImg(samplers_semi,
                              args.sampler_probs,
                              augmentation_function=augmentation,
                              data_transform_function=data_transform)

    return dataset_cp, dataset_semi
def sampling_echograms_full(args):
    """Build cluster-pretraining and supervised datasets with an unbalanced unlabeled pool.

    Loads per-class training samplers ('sampler3_tr.pt') plus an extra pool of
    background samplers ('train_bg_32766.pt'). The first ``args.semi_ratio``
    fraction of each class is kept as the supervised subset; the remainder,
    with the background pool appended to class 0, is truncated per class
    according to ``tr_ratio`` and then re-chunked into equal-length per-class
    lists for the pretraining dataset.

    Args:
        args: namespace providing ``semi_ratio``, ``nmb_category`` and
            ``sampler_probs``.

    Returns:
        tuple: (dataset_cp, dataset_semi) as DatasetImg instances.
    """
    # Target class proportions for the unbalanced unlabeled pool
    # (presumably background / class 1 / class 2 shares — TODO confirm).
    tr_ratio = [0.97808653, 0.01301181, 0.00890166]
    path_to_echograms = paths.path_to_echograms()

    ########
    samplers_train = torch.load(
        os.path.join(path_to_echograms, 'sampler3_tr.pt'))
    samplers_bg = torch.load(
        os.path.join(path_to_echograms, 'train_bg_32766.pt'))

    # Labeled count per class, and the total size of the unlabeled pool
    # spread across args.nmb_category classes.
    supervised_count = int(len(samplers_train[0]) * args.semi_ratio)
    total_unsupervised_count = int(
        (len(samplers_train[0]) - supervised_count) * args.nmb_category)
    unlab_size = [int(ratio * total_unsupervised_count) for ratio in tr_ratio]
    # int() truncation can lose a few samples; give the rounding slack to class 0.
    if np.sum(unlab_size) != total_unsupervised_count:
        unlab_size[0] += total_unsupervised_count - np.sum(unlab_size)

    # Split each class into its supervised head and unsupervised tail.
    samplers_supervised = []
    samplers_unsupervised = []
    for samplers in samplers_train:
        samplers_supervised.append(samplers[:supervised_count])
        samplers_unsupervised.append(samplers[supervised_count:])
    # Grow the class-0 (background) unlabeled pool with the extra samplers.
    # NOTE: mutates the list loaded by torch.load in place.
    samplers_unsupervised[0].extend(samplers_bg)

    # Truncate each class' unlabeled pool to its unbalanced target size.
    samplers_unbal_unlab = []
    for sampler, size in zip(samplers_unsupervised, unlab_size):
        samplers_unbal_unlab.append(sampler[:size])

    # Flatten labeled + unbalanced-unlabeled samplers, class by class.
    samplers_semi_unbal_unlab_long = []
    for sampler_semi, sampler_unb_unl in zip(samplers_supervised,
                                             samplers_unbal_unlab):
        samplers_semi_unbal_unlab_long.extend(
            np.concatenate([sampler_semi, sampler_unb_unl]))

    # Re-chunk the flat list into num_classes lists of the original
    # per-class length for the pretraining dataset.
    list_length = len(samplers_train[0])
    num_classes = len(samplers_train)
    samplers_cp = [
        samplers_semi_unbal_unlab_long[i * list_length:(i + 1) * list_length]
        for i in range(num_classes)
    ]
    ########

    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])

    dataset_cp = DatasetImg(samplers_cp,
                            args.sampler_probs,
                            augmentation_function=augmentation,
                            data_transform_function=data_transform)

    dataset_semi = DatasetImg(samplers_supervised,
                              args.sampler_probs,
                              augmentation_function=augmentation,
                              data_transform_function=data_transform)

    return dataset_cp, dataset_semi
Example #4
def sampling_echograms_test(args):
    path_to_echograms = paths.path_to_echograms()
    samplers_test = torch.load(
        os.path.join(path_to_echograms, 'sampler6_te.pt'))
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])

    dataset_test = DatasetImg(samplers_test,
                              args.sampler_probs,
                              augmentation_function=None,
                              data_transform_function=data_transform)
    return dataset_test
def sampling_echograms_full(args):
    """Build the full training dataset from the stored samplers.

    Loads 'samplers_sb_1024_3.pt' and applies noise/flip augmentation plus
    the NaN/Inf-removal and dB-limit transforms.

    Args:
        args: namespace providing ``sampler_probs``.

    Returns:
        DatasetImg: the training dataset.
    """
    sampler_file = os.path.join(paths.path_to_echograms(),
                                'samplers_sb_1024_3.pt')
    train_samplers = torch.load(sampler_file)
    augment = CombineFunctions([add_noise_img, flip_x_axis_img])
    transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])
    return DatasetImg(train_samplers,
                      args.sampler_probs,
                      augmentation_function=augment,
                      data_transform_function=transform)
Example #6
def sampling_echograms_for_s3vm(args):
    path_to_echograms = paths.path_to_echograms()
    samplers_bg = torch.load(
        os.path.join(path_to_echograms, 'train_bg_32766.pt'))
    list_length = len(samplers_bg) // args.nmb_category
    samplers_bg = [
        samplers_bg[i * list_length:(i + 1) * list_length]
        for i in range(args.nmb_category)
    ]
    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])

    dataset_bg_full = DatasetImg(samplers_bg,
                                 args.sampler_probs,
                                 augmentation_function=augmentation,
                                 data_transform_function=data_transform)
    return dataset_bg_full
Example #7
def sampling_echograms_full(args):
    """Load the stored training samplers and wrap them in a DatasetImg.

    Reads the per-class sampler lists from 'samplers_500_three.pt' in the
    project echogram directory and applies noise/flip augmentation plus the
    NaN/Inf-removal and dB-limit transforms.

    Args:
        args: namespace providing ``sampler_probs``.

    Returns:
        DatasetImg: the training dataset.
    """
    # Fix: removed ~50 lines of commented-out, one-off data-preparation code
    # (random subsampling and re-saving of sampler files); dead code belongs
    # in version control history, not in the function body.
    path_to_echograms = paths.path_to_echograms()
    samplers_train = torch.load(
        os.path.join(path_to_echograms, 'samplers_500_three.pt'))
    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])

    # NOTE(review): unlike every other DatasetImg call site in this file,
    # this one passes an extra positional argument (1500) before
    # sampler_probs — verify against the DatasetImg signature that this is
    # intentional and not a stale argument.
    dataset_cp = DatasetImg(samplers_train,
                            1500,
                            args.sampler_probs,
                            augmentation_function=augmentation,
                            data_transform_function=data_transform)
    return dataset_cp
def sampling_echograms_full(args):
    """Build the cluster-pretraining and semi-supervised training datasets.

    Loads per-class training samplers ('sampler6_tr.pt') and an extra pool of
    unbalanced background samplers ('bg_unb_11000_tr.pt'). The first
    ``args.semi_ratio`` fraction of each class becomes the labeled (semi)
    subset; the remainder, with the extra background appended to class 0, is
    truncated per class to a precomputed unbalanced size, and the labeled +
    unbalanced-unlabeled samplers are re-chunked into equal-length per-class
    lists for the pretraining dataset.

    Args:
        args: namespace with ``semi_ratio`` (one of 0.01, 0.05, 0.1, 0.2) and
            ``sampler_probs``.

    Returns:
        tuple: (dataset_cp, dataset_semi) as DatasetImg instances.

    Raises:
        ValueError: if ``args.semi_ratio`` or the number of classes in the
            stored samplers is not one of the supported combinations.
    """
    path_to_echograms = paths.path_to_echograms()

    samplers_train = torch.load(os.path.join(path_to_echograms, 'sampler6_tr.pt'))
    samplers_bg_unb = torch.load(os.path.join(path_to_echograms, 'bg_unb_11000_tr.pt'))
    semi_count = int(len(samplers_train[0]) * args.semi_ratio)

    # Precomputed per-class unlabeled-pool sizes, keyed by
    # (number of classes, semi_ratio). 3 classes is the combined-classes case.
    unlab_sizes = {
        6: {0.01: [12992, 195, 200, 1344, 90, 29],
            0.05: [12468, 187, 192, 1289, 86, 28],
            0.1: [11811, 177, 182, 1222, 82, 26],
            0.2: [10499, 157, 162, 1086, 73, 23]},
        3: {0.01: [14336, 285, 229],
            0.05: [13757, 273, 220],
            0.1: [13033, 259, 208],
            0.2: [11585, 230, 185]},
    }
    try:
        unlab_size = unlab_sizes[len(samplers_train)][args.semi_ratio]
    except KeyError:
        # Fix: the original used `assert` (stripped under -O) for the ratio
        # check and hit a NameError on an unsupported class count; fail with
        # an explicit, descriptive error instead.
        raise ValueError(
            'Unsupported combination: %d classes, semi_ratio=%s '
            '(semi_ratio must be one of 0.01, 0.05, 0.1, 0.2)'
            % (len(samplers_train), args.semi_ratio))

    # Split each class into its labeled head and unlabeled tail.
    samplers_semi = []
    samplers_rest = []
    for samplers in samplers_train:
        samplers_semi.append(samplers[:semi_count])
        samplers_rest.append(samplers[semi_count:])
    # Grow the class-0 (background) unlabeled pool with the extra samplers.
    samplers_rest[0].extend(samplers_bg_unb)

    # Truncate each class' unlabeled pool to its precomputed unbalanced size.
    samplers_unbal_unlab = [sampler[:size]
                            for sampler, size in zip(samplers_rest, unlab_size)]

    # Flatten labeled + unbalanced-unlabeled samplers, class by class...
    samplers_semi_unbal_unlab_long = []
    for sampler_semi, sampler_unb_unl in zip(samplers_semi, samplers_unbal_unlab):
        samplers_semi_unbal_unlab_long.extend(
            np.concatenate([sampler_semi, sampler_unb_unl]))

    # ...then re-chunk the flat list into num_classes lists of the original
    # per-class length for the pretraining dataset.
    list_length = len(samplers_train[0])
    num_classes = len(samplers_train)
    samplers_cp = [samplers_semi_unbal_unlab_long[i * list_length:(i + 1) * list_length]
                   for i in range(num_classes)]

    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])

    dataset_cp = DatasetImg(
        samplers_cp,
        args.sampler_probs,
        augmentation_function=augmentation,
        data_transform_function=data_transform)

    dataset_semi = DatasetImg(
        samplers_semi,
        args.sampler_probs,
        augmentation_function=augmentation,
        data_transform_function=data_transform)

    return dataset_cp, dataset_semi