def generate_and_save_heave_files():
    root = path_to_echograms()
    echs = get_echograms()
    for i, ech in enumerate(echs):
        if i % 100 == 0:
            print(len(echs), i)

        # Get vertical pixel resolution
        r = ech.range_vector
        r_diff = np.median(r[1:] - r[:-1])

        # Convert heave value from meters to number of pixels
        # (np.int was removed in NumPy >= 1.24; use the builtin int)
        heave = np.round(ech.heave / r_diff).astype(int)
        assert heave.size == ech.shape[1]

        labels_old = ech.label_numpy()
        labels_new = np.zeros_like(labels_old)

        # Create new labels: move each label column up/down corresponding to heave
        for x, h in enumerate(list(heave)):
            if h == 0:
                labels_new[:, x] = labels_old[:, x]
            elif h > 0:
                labels_new[:-h, x] = labels_old[h:, x]
            else:
                labels_new[-h:, x] = labels_old[:h, x]

        # Save new labels as a new memmap file
        path_save = root + ech.name + '/labels_heave'
        save_memmap(labels_new, path_save, dtype=labels_new.dtype)

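# Toy sketch of the per-column heave shift above (hypothetical 4x3 label array,
# not part of the original pipeline): h = 1 moves a column up one pixel, h = -1
# moves it down one pixel, and zeros fill the vacated rows.
#
#   labels_old = np.arange(12).reshape(4, 3)
#   heave = [0, 1, -1]
#   column 0: [0, 3, 6,  9] -> [0, 3, 6, 9]    (unchanged)
#   column 1: [1, 4, 7, 10] -> [4, 7, 10, 0]   (shifted up)
#   column 2: [2, 5, 8, 11] -> [0, 2, 5, 8]    (shifted down)
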
def sampling_echograms_full(args):
    path_to_echograms = paths.path_to_echograms()
    samplers_train = torch.load(os.path.join(path_to_echograms, 'sampler3_tr.pt'))

    # Split each per-class sampler list into a supervised head and the remainder
    supervised_count = int(len(samplers_train[0]) * args.semi_ratio)
    samplers_supervised = []
    samplers_tr_rest = []
    for samplers in samplers_train:
        samplers_supervised.append(samplers[:supervised_count])
        samplers_tr_rest.append(samplers[supervised_count:])

    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])

    dataset_semi = DatasetImg(samplers_supervised,
                              args.sampler_probs,
                              augmentation_function=augmentation,
                              data_transform_function=data_transform)
    dataset_tr_rest = DatasetImg(samplers_tr_rest,
                                 args.sampler_probs,
                                 augmentation_function=augmentation,
                                 data_transform_function=data_transform)
    return dataset_semi, dataset_tr_rest

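# Example usage (sketch; assumes `args` provides semi_ratio, sampler_probs,
# batch and workers, that sampler3_tr.pt exists under the echogram path, and
# that DatasetImg follows the torch.utils.data.Dataset protocol):
#
#   dataset_semi, dataset_rest = sampling_echograms_full(args)
#   loader_semi = torch.utils.data.DataLoader(dataset_semi,
#                                             batch_size=args.batch,
#                                             shuffle=True,
#                                             num_workers=args.workers,
#                                             pin_memory=True)
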
def sampling_echograms_full(args):
    tr_ratio = [0.97808653, 0.01301181, 0.00890166]  # note: unused in this variant
    path_to_echograms = paths.path_to_echograms()
    samplers_train = torch.load(os.path.join(path_to_echograms, 'sampler3_tr.pt'))

    # Keep the first semi_ratio fraction of each per-class sampler list
    semi_count = int(len(samplers_train[0]) * args.semi_ratio)
    samplers_semi = [samplers[:semi_count] for samplers in samplers_train]

    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])

    dataset_cp = DatasetImg(samplers_train,
                            args.sampler_probs,
                            augmentation_function=augmentation,
                            data_transform_function=data_transform)
    dataset_semi = DatasetImg(samplers_semi,
                              args.sampler_probs,
                              augmentation_function=augmentation,
                              data_transform_function=data_transform)
    return dataset_cp, dataset_semi

def sampling_echograms_full_for_comparisonP2(args):
    path_to_echograms = paths.path_to_echograms()
    data = torch.load(os.path.join(path_to_echograms, 'data_tr_TEST_200.pt'))
    label = torch.load(os.path.join(path_to_echograms, 'label_tr_TEST_200.pt'))

    data_transform = CombineFunctions(
        [remove_nan_inf_for_comparisonP2, db_with_limits_for_comparisonP2])
    label_transform = CombineFunctions([
        index_0_1_27_for_comparisonP2,
        relabel_with_threshold_morph_close_for_comparisonP2,
        seabed_checker_for_comparisonP2
    ])

    semi_count = int(len(data) * args.semi_ratio)
    dataset_cp = DatasetImg_for_comparisonP2(
        data=data,
        label=label,
        label_transform_function=label_transform,
        data_transform_function=data_transform)
    dataset_semi = DatasetImg_for_comparisonP2(
        data=data[:semi_count],
        label=label[:semi_count],
        label_transform_function=label_transform,
        data_transform_function=data_transform)
    return dataset_cp, dataset_semi

def test_and_plot_2019(test_pred_large_2019, test_label_large_2019, epoch, args, idx=2):
    path_to_echograms = paths.path_to_echograms()
    data_2019, label_2019, patch_loc = torch.load(
        os.path.join(path_to_echograms, 'data_label_patch_loc_te_2019_%d.pt' % idx))
    data_transform = CombineFunctions(
        [remove_nan_inf_for_comparisonP2, db_with_limits_for_comparisonP2])
    label_transform = CombineFunctions([
        index_0_1_27_for_comparisonP2,
        relabel_with_threshold_morph_close_for_comparisonP2,
        seabed_checker_for_comparisonP2
    ])

    def to_rgb(mask, dim):
        """Paint a class mask as RGB: sandeel (1) blue, other (2) red, rest white."""
        rgb = np.ones((dim[0], dim[1], 3))
        rgb[mask == 1, :2] = 0  # zero R and G -> blue
        rgb[mask == 2, 1:] = 0  # zero G and B -> red
        return rgb

    boxsize = 5
    plt.figure(figsize=(boxsize * patch_loc[1], boxsize * patch_loc[0] * 4))
    for i in range(len(data_2019)):
        d, l = label_transform(data_2019[i], label_2019[i])
        d, l = data_transform(d, l)
        dim = np.shape(l)

        labels_rgb = to_rgb(l, dim)
        pred_rgb = to_rgb(test_pred_large_2019[i], dim)
        lbb_rgb = to_rgb(test_label_large_2019[i], dim)

        # Four stacked panel rows per patch: transformed labels, predictions,
        # large-scale labels, and the last data channel
        plt.subplot(patch_loc[0] * 4, patch_loc[1], i + 1)
        plt.imshow(labels_rgb)
        plt.subplot(patch_loc[0] * 4, patch_loc[1], patch_loc[1] * patch_loc[0] + i + 1)
        plt.imshow(pred_rgb)
        plt.subplot(patch_loc[0] * 4, patch_loc[1], patch_loc[1] * patch_loc[0] * 2 + i + 1)
        plt.imshow(lbb_rgb)
        plt.subplot(patch_loc[0] * 4, patch_loc[1], patch_loc[1] * patch_loc[0] * 3 + i + 1)
        plt.imshow(d[-1])

    plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])
    plt.tight_layout()
    plt.savefig(os.path.join(args.pred_2019,
                             '%d_data_pred_label_2019_%d_patch.pdf' % (epoch, idx)))
    plt.close()

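# Example usage (sketch; assumes the 2019 predictions and large-scale labels
# were produced upstream and that args.pred_2019 is a writable directory):
#
#   test_and_plot_2019(test_pred_large_2019, test_label_large_2019,
#                      epoch=epoch, args=args, idx=2)
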
def sampling_echograms_full(args):
    # Class distribution of the unlabeled pool: [background, sandeel, other]
    tr_ratio = [0.97808653, 0.01301181, 0.00890166]
    path_to_echograms = paths.path_to_echograms()
    samplers_train = torch.load(os.path.join(path_to_echograms, 'sampler3_tr.pt'))
    samplers_bg = torch.load(os.path.join(path_to_echograms, 'train_bg_32766.pt'))

    # Sizes of the labeled (supervised) head and the unlabeled budget
    supervised_count = int(len(samplers_train[0]) * args.semi_ratio)
    total_unsupervised_count = int(
        (len(samplers_train[0]) - supervised_count) * args.nmb_category)

    # Distribute the unlabeled budget across classes by tr_ratio; assign any
    # rounding remainder to the background class
    unlab_size = [int(ratio * total_unsupervised_count) for ratio in tr_ratio]
    if np.sum(unlab_size) != total_unsupervised_count:
        unlab_size[0] += total_unsupervised_count - np.sum(unlab_size)

    samplers_supervised = []
    samplers_unsupervised = []
    for samplers in samplers_train:
        samplers_supervised.append(samplers[:supervised_count])
        samplers_unsupervised.append(samplers[supervised_count:])
    samplers_unsupervised[0].extend(samplers_bg)  # extra background patches

    samplers_unbal_unlab = [sampler[:size]
                            for sampler, size in zip(samplers_unsupervised, unlab_size)]

    # Concatenate labeled + unlabeled per class, then re-chunk into
    # num_classes lists of the original per-class length
    samplers_semi_unbal_unlab_long = []
    for sampler_semi, sampler_unb_unl in zip(samplers_supervised, samplers_unbal_unlab):
        samplers_semi_unbal_unlab_long.extend(
            np.concatenate([sampler_semi, sampler_unb_unl]))

    list_length = len(samplers_train[0])
    num_classes = len(samplers_train)
    samplers_cp = [samplers_semi_unbal_unlab_long[i * list_length:(i + 1) * list_length]
                   for i in range(num_classes)]

    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])
    dataset_cp = DatasetImg(samplers_cp,
                            args.sampler_probs,
                            augmentation_function=augmentation,
                            data_transform_function=data_transform)
    dataset_semi = DatasetImg(samplers_supervised,
                              args.sampler_probs,
                              augmentation_function=augmentation,
                              data_transform_function=data_transform)
    return dataset_cp, dataset_semi

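# Worked example of the unlabeled-budget arithmetic above (hypothetical sizes,
# not taken from the actual sampler files): with 3000 patches per class,
# semi_ratio = 0.1 and nmb_category = 3:
#
#   supervised_count          = int(3000 * 0.1)        = 300
#   total_unsupervised_count  = (3000 - 300) * 3       = 8100
#   unlab_size (by tr_ratio)  = [7922, 105, 72]         -> sums to 8099
#   after remainder to bg     = [7923, 105, 72]         -> sums to 8100
#
# The concatenated labeled + unlabeled pool then holds 3 * 300 + 8100 = 9000
# patches, which re-chunks exactly into 3 lists of the original length 3000.
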
def save_all_seabeds():
    """Loop through all echograms and generate seabed estimates."""
    path_to_echograms = paths.path_to_echograms()
    echogram_names = os.listdir(path_to_echograms)
    echograms = [Echogram(os.path.join(path_to_echograms, e)) for e in echogram_names]
    for e in echograms:
        e.get_seabed(save_to_file=True, ignore_saved=True)

def get_echograms_revised(eg_names_full, sample_idx, num_echograms=100):
    path_to_echograms = paths.path_to_echograms()

    # Wrap around to the first chunk if sample_idx is out of range
    index_list = np.arange(len(eg_names_full) // num_echograms)
    if sample_idx not in set(index_list):
        sample_idx = 0
        print('Reset sample_idx')

    # Load the sample_idx-th chunk of num_echograms echograms
    eg_idx = np.arange(num_echograms * sample_idx, num_echograms * (sample_idx + 1))
    eg_names = [eg_names_full[i] for i in eg_idx]
    echograms = [Echogram(os.path.join(path_to_echograms, e)) for e in eg_names]

    sample_idx += 1
    return echograms, sample_idx

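# Example (sketch; assumes eg_names_full was loaded from the pickled name list):
# cycling through 100-echogram chunks by feeding the returned index back in.
# Once sample_idx runs past the last full chunk it resets to 0.
#
#   sample_idx = 0
#   for _ in range(5):
#       echograms, sample_idx = get_echograms_revised(eg_names_full, sample_idx)
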
def sampling_echograms_test(args):
    path_to_echograms = paths.path_to_echograms()
    samplers_test = torch.load(os.path.join(path_to_echograms, 'sampler6_te.pt'))
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])
    dataset_test = DatasetImg(samplers_test,
                              args.sampler_probs,
                              augmentation_function=None,
                              data_transform_function=data_transform)
    return dataset_test

def sampling_echograms_full(args):
    path_to_echograms = paths.path_to_echograms()
    samplers_train = torch.load(
        os.path.join(path_to_echograms, 'samplers_sb_1024_3.pt'))
    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])
    dataset_cp = DatasetImg(samplers_train,
                            args.sampler_probs,
                            augmentation_function=augmentation,
                            data_transform_function=data_transform)
    return dataset_cp

def test_analysis(predictions, predictions_mat, epoch, args):
    path_to_echograms = paths.path_to_echograms()
    labels_origin = torch.load(os.path.join(path_to_echograms,
                                            'label_TEST_60_after_transformation.pt'))

    # Ensure predictions_mat is channel-first: (N, C, H, W)
    if np.shape(predictions_mat) == (60, 256, 256, 3):
        predictions_mat = predictions_mat.transpose(0, 3, 1, 2)

    # Evaluate only on annotated pixels (label > -1)
    keep_test_idx = np.where(labels_origin > -1)
    labels_vec = labels_origin[keep_test_idx]
    predictions_vec = predictions[keep_test_idx]
    predictions_mat_sampled = predictions_mat[keep_test_idx[0], :,
                                              keep_test_idx[1], keep_test_idx[2]]

    fpr, tpr, roc_auc, roc_auc_macro = roc_curve_macro(labels_vec, predictions_mat_sampled)
    prob_mat, mat, f1_score, kappa = conf_mat(ylabel=labels_vec, ypred=predictions_vec, args=args)
    acc_bg, acc_se, acc_ot = prob_mat.diagonal()

    plot_macro(fpr, tpr, roc_auc, epoch, args)
    plot_conf(epoch, prob_mat, mat, f1_score, kappa, args)
    return fpr, tpr, roc_auc, roc_auc_macro, prob_mat, mat, f1_score, kappa, acc_bg, acc_se, acc_ot

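# Example usage (sketch; assumes `predictions` holds per-pixel class indices of
# shape (60, 256, 256) and `predictions_mat` the matching per-class scores):
#
#   (fpr, tpr, roc_auc, roc_auc_macro, prob_mat, mat,
#    f1_score, kappa, acc_bg, acc_se, acc_ot) = test_analysis(
#       predictions, predictions_mat, epoch, args)
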
def sampling_echograms_for_s3vm(args):
    path_to_echograms = paths.path_to_echograms()
    samplers_bg = torch.load(os.path.join(path_to_echograms, 'train_bg_32766.pt'))

    # Split the background pool into nmb_category equally sized lists
    list_length = len(samplers_bg) // args.nmb_category
    samplers_bg = [samplers_bg[i * list_length:(i + 1) * list_length]
                   for i in range(args.nmb_category)]

    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])
    dataset_bg_full = DatasetImg(samplers_bg,
                                 args.sampler_probs,
                                 augmentation_function=augmentation,
                                 data_transform_function=data_transform)
    return dataset_bg_full

def sampling_echograms_test_for_comparisonP2():
    path_to_echograms = paths.path_to_echograms()
    data = torch.load(os.path.join(path_to_echograms, 'data_te_TEST_60.pt'))
    label = torch.load(os.path.join(path_to_echograms, 'label_te_TEST_60.pt'))
    data_transform = CombineFunctions(
        [remove_nan_inf_for_comparisonP2, db_with_limits_for_comparisonP2])
    label_transform = CombineFunctions([
        index_0_1_27_for_comparisonP2,
        relabel_with_threshold_morph_close_for_comparisonP2,
        seabed_checker_for_comparisonP2
    ])
    dataset = DatasetImg_for_comparisonP2(
        data=data,
        label=label,
        label_transform_function=label_transform,
        data_transform_function=data_transform)
    return dataset

def sampling_echograms_full(window_size, args):
    path_to_echograms = paths.path_to_echograms()
    with open(os.path.join(path_to_echograms, 'memmap_2014_heave.pkl'), 'rb') as fp:
        eg_names_full = pickle.load(fp)

    echograms = get_echograms_full(eg_names_full)
    echograms_train, echograms_val, echograms_test = cps.partition_data(
        echograms, args.partition, portion_train_test=0.8, portion_train_val=0.75)

    # One sampler per class: background, schools of fish types 27 and 1, and
    # schools of each type close to the seabed
    sampler_bg_train = Background(echograms_train, window_size)
    sampler_sh27_train = Shool(echograms_train, window_size, 27)
    sampler_sbsh27_train = ShoolSeabed(echograms_train, window_size,
                                       args.window_dim // 4, fish_type=27)
    sampler_sh01_train = Shool(echograms_train, window_size, 1)
    sampler_sbsh01_train = ShoolSeabed(echograms_train, window_size,
                                       args.window_dim // 4, fish_type=1)
    samplers_train = [sampler_bg_train,
                      sampler_sh27_train, sampler_sbsh27_train,
                      sampler_sh01_train, sampler_sbsh01_train]

    augmentation = CombineFunctions([add_noise, flip_x_axis])
    label_transform = CombineFunctions([index_0_1_27, relabel_with_threshold_morph_close])
    data_transform = CombineFunctions([remove_nan_inf, db_with_limits])

    dataset_train = Dataset(samplers_train,
                            window_size,
                            args.frequencies,
                            args.batch * args.iteration_train,
                            args.sampler_probs,
                            augmentation_function=augmentation,
                            label_transform_function=label_transform,
                            data_transform_function=data_transform)
    return dataset_train

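# Example usage (sketch; assumes `args` provides window_dim, frequencies, batch,
# iteration_train, sampler_probs and partition, and that Dataset implements
# __getitem__/__len__ in the usual torch.utils.data style):
#
#   window_size = [args.window_dim, args.window_dim]
#   dataset_train = sampling_echograms_full(window_size, args)
#   data, label = dataset_train[0]   # one (cropped patch, label mask) pair
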
def sampling_echograms_2019_for_comparisonP2(echogram_idx=2, path_to_echograms=None):
    if path_to_echograms is None:
        path_to_echograms = paths.path_to_echograms()
    data, label, patch_loc = torch.load(
        os.path.join(path_to_echograms,
                     'data_label_patch_loc_te_2019_%d.pt' % echogram_idx))
    data_transform = CombineFunctions(
        [remove_nan_inf_for_comparisonP2, db_with_limits_for_comparisonP2])
    label_transform = CombineFunctions([
        index_0_1_27_for_comparisonP2,
        relabel_with_threshold_morph_close_for_comparisonP2,
        seabed_checker_for_comparisonP2
    ])
    dataset_2019 = DatasetImg_for_comparisonP2(
        data=data,
        label=label,
        label_transform_function=label_transform,
        data_transform_function=data_transform)
    return dataset_2019, label, patch_loc

    parser.add_argument('--partition', type=str, default='train_only',
                        help='echogram partition (tr/val/te) by year')
    parser.add_argument('--iteration_train', type=int, default=1200,
                        help='num_tr_iterations per one batch and epoch')
    parser.add_argument('--sampler_probs', type=list, default=None,
                        help='[bg, sh27, sbsh27, sh01, sbsh01], default=[1, 1, 1, 1, 1]')
    parser.add_argument('--resume',
                        default=os.path.join(current_dir, 'checkpoint.pth.tar'),
                        type=str, metavar='PATH',
                        help='path to checkpoint (default: checkpoint.pth.tar in current_dir)')
    parser.add_argument('--exp', type=str, default=current_dir,
                        help='path to exp folder')
    return parser.parse_args(args=[])


args = parse_args()
window_size = [args.window_dim, args.window_dim]

path_to_echograms = paths.path_to_echograms()
with open(os.path.join(path_to_echograms, 'memmap_2014_heave.pkl'), 'rb') as fp:
    eg_names_full = pickle.load(fp)
echograms = get_echograms_full(eg_names_full)
echograms_train, echograms_val, echograms_test = cps.partition_data(
    echograms, args.partition, portion_train_test=0.8, portion_train_val=0.75)

sampler_bg_train = Background(echograms_train, window_size)
# sampler_sh27_train = Shool(echograms_train, window_size, 27)
# sampler_sh01_train = Shool(echograms_train, window_size, 1)
# sampler_sbsh27_train = ShoolSeabed(echograms_train, window_size, args.window_dim // 4, fish_type=27)
# sampler_sbsh01_train = ShoolSeabed(echograms_train, window_size, args.window_dim // 4, fish_type=1)
data_bg = DatasetSampler(sampler_bg_train, window_size, args.frequencies)
# data_sh27 = DatasetSampler(sampler_sh27_train, window_size, args.frequencies)

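# Example (sketch; assumes DatasetSampler supports the standard
# torch.utils.data indexing protocol like the Dataset classes above):
#
#   crop = data_bg[0]              # one background crop at the given window size
#   print(np.shape(crop))          # expect (len(args.frequencies), *window_size)
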
def sampling_echograms_full(args):
    # The sampler file loaded below was generated offline: per-class numpy patch
    # sets (bg, sh27, sbsh27, sh01, sbsh01) were loaded from memmap .pt chunks,
    # a fixed-size random subset of each class was drawn with
    # np.random.choice(..., replace=False) (optionally equalized across classes
    # with a small sample_align helper), and the resulting list of per-class
    # sampler lists was saved with torch.save (earlier variants:
    # samplers_2000.pt, samplers_3000.pt).
    path_to_echograms = paths.path_to_echograms()
    samplers_train = torch.load(
        os.path.join(path_to_echograms, 'samplers_500_three.pt'))

    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])
    dataset_cp = DatasetImg(samplers_train,
                            1500,  # extra positional argument specific to this DatasetImg variant
                            args.sampler_probs,
                            augmentation_function=augmentation,
                            data_transform_function=data_transform)
    return dataset_cp

def sampling_echograms_full(args):
    path_to_echograms = paths.path_to_echograms()
    assert args.semi_ratio in [0.01, 0.05, 0.1, 0.2], 'Fix args.semi_ratio in a given range'

    samplers_train = torch.load(os.path.join(path_to_echograms, 'sampler6_tr.pt'))
    samplers_bg_unb = torch.load(os.path.join(path_to_echograms, 'bg_unb_11000_tr.pt'))
    semi_count = int(len(samplers_train[0]) * args.semi_ratio)

    # Precomputed unlabeled-pool sizes per class for each supported semi_ratio
    if len(samplers_train) == 6:
        unlab_size = {0.01: [12992, 195, 200, 1344, 90, 29],
                      0.05: [12468, 187, 192, 1289, 86, 28],
                      0.1:  [11811, 177, 182, 1222, 82, 26],
                      0.2:  [10499, 157, 162, 1086, 73, 23]}[args.semi_ratio]
    elif len(samplers_train) == 3:  # combined 3-class case
        unlab_size = {0.01: [14336, 285, 229],
                      0.05: [13757, 273, 220],
                      0.1:  [13033, 259, 208],
                      0.2:  [11585, 230, 185]}[args.semi_ratio]

    samplers_semi = []
    samplers_rest = []
    for samplers in samplers_train:
        samplers_semi.append(samplers[:semi_count])
        samplers_rest.append(samplers[semi_count:])
    samplers_rest[0].extend(samplers_bg_unb)  # extra unbalanced background patches

    samplers_unbal_unlab = [sampler[:size]
                            for sampler, size in zip(samplers_rest, unlab_size)]

    # Concatenate labeled + unlabeled per class, then re-chunk into
    # num_classes lists of the original per-class length
    samplers_semi_unbal_unlab_long = []
    for sampler_semi, sampler_unb_unl in zip(samplers_semi, samplers_unbal_unlab):
        samplers_semi_unbal_unlab_long.extend(
            np.concatenate([sampler_semi, sampler_unb_unl]))

    list_length = len(samplers_train[0])
    num_classes = len(samplers_train)
    samplers_cp = [samplers_semi_unbal_unlab_long[i * list_length:(i + 1) * list_length]
                   for i in range(num_classes)]

    augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
    data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])
    dataset_cp = DatasetImg(samplers_cp,
                            args.sampler_probs,
                            augmentation_function=augmentation,
                            data_transform_function=data_transform)
    dataset_semi = DatasetImg(samplers_semi,
                              args.sampler_probs,
                              augmentation_function=augmentation,
                              data_transform_function=data_transform)
    return dataset_cp, dataset_semi

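# Consistency-check sketch for the tables above: the re-chunking tiles exactly
# when num_classes * semi_count + sum(unlab_size) == num_classes * list_length.
# All eight tables satisfy this for per-class list lengths of 2500 (6-class)
# and 5000 (3-class) — an inferred property of the sampler files, not stated in
# the original code. E.g. for the 6-class case with semi_ratio = 0.1:
#
#   semi_count = int(2500 * 0.1) = 250
#   6 * 250 + sum([11811, 177, 182, 1222, 82, 26]) = 1500 + 13500 = 15000
#   6 * 2500                                       = 15000   # tiles exactly
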
def sampling_echograms(sample_idx, window_size, args):
    path_to_echograms = paths.path_to_echograms()
    with open(os.path.join(path_to_echograms, 'memmap_2014_heave.pkl'), 'rb') as fp:
        eg_names_full = pickle.load(fp)

    print(sample_idx, "IN SAMPLING_ECHOGRAMS")
    sample_idx = int(sample_idx)
    echograms, sample_idx = get_echograms_revised(eg_names_full, sample_idx,
                                                  num_echograms=args.num_echogram)
    echograms_train, echograms_val, echograms_test = cps.partition_data(
        echograms, args.partition, portion_train_test=0.8, portion_train_val=0.75)

    sampler_bg_train = Background(echograms_train, window_size)
    # sampler_sb_train = Seabed(echograms_train, window_size)
    sampler_sh27_train = Shool(echograms_train, window_size, 27)
    sampler_sbsh27_train = ShoolSeabed(echograms_train, window_size,
                                       args.window_dim // 4, fish_type=27)
    sampler_sh01_train = Shool(echograms_train, window_size, 1)
    sampler_sbsh01_train = ShoolSeabed(echograms_train, window_size,
                                       args.window_dim // 4, fish_type=1)

    samplers_train = [sampler_bg_train,  # sampler_sb_train,
                      sampler_sh27_train, sampler_sbsh27_train,
                      sampler_sh01_train, sampler_sbsh01_train]

    # Validation counterparts (sampler_*_val over echograms_val, DatasetVal with
    # args.batch * args.iteration_val samples, and its DataLoader) mirror the
    # train construction above but are disabled in this train-only variant.

    augmentation = CombineFunctions([add_noise, flip_x_axis])
    label_transform = CombineFunctions([index_0_1_27, relabel_with_threshold_morph_close])
    data_transform = CombineFunctions([remove_nan_inf, db_with_limits])

    dataset_train = Dataset(samplers_train,
                            window_size,
                            args.frequencies,
                            args.batch * args.iteration_train,
                            args.sampler_probs,
                            augmentation_function=augmentation,
                            label_transform_function=label_transform,
                            data_transform_function=data_transform)
    return dataset_train, int(sample_idx)

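# Example training-loop sketch (hypothetical; `num_epochs` is illustrative and
# `args` is assumed to provide num_echogram, partition, batch, etc.): rotate
# through echogram chunks by feeding the returned sample_idx back into the
# next call.
#
#   sample_idx = 0
#   for epoch in range(num_epochs):
#       dataset_train, sample_idx = sampling_echograms(sample_idx, window_size, args)
#       ...  # build a DataLoader from dataset_train and run one epoch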