Example #1
def tuning(model, name):
    logging.info("Fine tuning model: {}".format(name))
    criterion = cross_entropy2d
    optimizer = optim.RMSprop(model.parameters(),
                              lr=lr,
                              momentum=momentum,
                              weight_decay=w_decay)
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=step_size,
                                    gamma=gamma)

    dsets = {
        x: SegDataset(os.path.join(DATA_DIR, x))
        for x in ['train', 'val']
    }
    dset_loaders = {
        x: DataLoader(dsets[x],
                      batch_size=batch_size,
                      shuffle=True,
                      num_workers=1)
        for x in ['train', 'val']
    }

    train_loader, val_loader = dset_loaders['train'], dset_loaders['val']
    train(model, name, criterion, optimizer, scheduler, train_loader,
          val_loader, epochs)
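
Example #1 leans on a cross_entropy2d criterion and bare hyperparameter names (lr, momentum, w_decay, step_size, gamma, batch_size, epochs) defined elsewhere in its project. A minimal sketch of plausible definitions, assuming (N, C, H, W) logits and (N, H, W) integer masks; the values are illustrative only:

import torch.nn.functional as F

lr, momentum, w_decay = 1e-4, 0.9, 1e-5   # optimizer hyperparameters
step_size, gamma = 10, 0.5                # StepLR schedule
batch_size, epochs = 4, 50                # loader / training-loop settings

def cross_entropy2d(output, target, weight=None):
    # output: (N, C, H, W) raw logits; target: (N, H, W) long class indices
    return F.cross_entropy(output, target, weight=weight)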
Example #2
def fine_tune(model, name):
    logging.info("Fine tuning model: {}".format(name))
    # criterion = nn.CrossEntropyLoss()
    criterion = cross_entropy2d
    optimizer = optim.RMSprop(model.parameters(),
                              lr=vals['lr'],
                              momentum=vals['momentum'],
                              weight_decay=vals['w_decay'])
    scheduler = lr_scheduler.StepLR(
        optimizer, step_size=vals['step_size'], gamma=vals['gamma']
    )  # decay LR by a factor of 0.5 every {step_size} epochs

    batch_size = vals['batch_size']
    epochs = vals['epochs']
    dsets = {
        x: SegDataset(os.path.join(DATA_DIR, x))
        for x in ['train', 'val']
    }
    dset_loaders = {
        x: DataLoader(dsets[x],
                      batch_size=batch_size,
                      shuffle=True,
                      num_workers=1)
        for x in ['train', 'val']
    }

    logging.info(
        'Parameters. batch_size: {}, epochs: {}, lr: {}, momentum: {}, w_decay: {}, step_size: {}, gamma: {}.'
        .format(batch_size, epochs, vals['lr'], vals['momentum'],
                vals['w_decay'], vals['step_size'], vals['gamma']))

    train_loader, val_loader = dset_loaders['train'], dset_loaders['val']
    train(model, name, criterion, optimizer, scheduler, train_loader,
          val_loader, epochs)
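
Example #2 reads the same hyperparameters from a vals dict defined elsewhere; a plausible shape for it, with illustrative values only:

vals = {
    'lr': 1e-4, 'momentum': 0.9, 'w_decay': 1e-5,
    'step_size': 10, 'gamma': 0.5,
    'batch_size': 4, 'epochs': 50,
}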
Example #3
def get_curves_object_detection(net, args, save_dir, fg_id=1):
    pos_scores, neg_scores = [], []  # mean scores of every detected region across all images, regardless of threshold

    test_set = SegDataset(args.test_set,
                          num_classes=args.out_channels,
                          appoint_size=(args.height, args.width),
                          erode=args.dilate)
    for img_id in tqdm(range(len(test_set))):
        data, label = test_set[img_id]
        batch_data = data.unsqueeze(0).cuda()
        obj_scores = get_fg_scores_from_net_output(net, batch_data,
                                                   fg_id)  # mean score of each detected region in this image
        pos_scores = pos_scores + obj_scores

    img_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((args.height, args.width)),
        transforms.ToTensor()
    ])
    for name in [i.name for i in Path(args.test_videos).glob('*.*')]:
        video_path = os.path.join(args.test_videos, name)
        cap = cv2.VideoCapture(video_path)
        for idx in tqdm(
                range(0,
                      int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1, 200)):
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            _, frame = cap.read()
            batch_data = img_transform(frame).unsqueeze(0).cuda()
            obj_scores = get_fg_scores_from_net_output(
                net, batch_data, fg_id)  # mean score of each detected region in this frame
            neg_scores = neg_scores + obj_scores

    th_stride = 0.1
    lower_bounds = [i / 10 for i in range(0, 10, int(th_stride * 10))]

    pos_num_intergrate = get_obj_num_each_interval(pos_scores,
                                                   lower_bounds,
                                                   th_stride,
                                                   intergrate=False)
    neg_num_intergrate = get_obj_num_each_interval(neg_scores,
                                                   lower_bounds,
                                                   th_stride,
                                                   intergrate=False)
    save_path = save_dir + '/num.png'
    draw_two_curves(lower_bounds, pos_num_intergrate, neg_num_intergrate,
                    'pos', 'neg', 'thresh', save_path)

    pos_num_intergrate = get_obj_num_each_interval(pos_scores,
                                                   lower_bounds,
                                                   th_stride,
                                                   intergrate=True)
    neg_num_intergrate = get_obj_num_each_interval(neg_scores,
                                                   lower_bounds,
                                                   th_stride,
                                                   intergrate=True)
    save_path = save_dir + '/num_int.png'
    draw_two_curves(lower_bounds, pos_num_intergrate, neg_num_intergrate,
                    'pos', 'neg', 'thresh', save_path)
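
get_obj_num_each_interval is not shown here; judging from how Example #3 calls it, it buckets the region scores into the [lower_bound, lower_bound + th_stride) intervals and, with intergrate=True, accumulates the counts. A sketch under that assumption:

def get_obj_num_each_interval(scores, lower_bounds, th_stride, intergrate=False):
    # Count how many scores fall into [lb, lb + th_stride) for each lower bound.
    counts = [sum(1 for s in scores if lb <= s < lb + th_stride)
              for lb in lower_bounds]
    if intergrate:
        # Running totals instead of per-interval counts.
        for i in range(1, len(counts)):
            counts[i] += counts[i - 1]
    return counts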
Example #4
def get_score_info(net, args, save_dir=None, fg_id=1, type='pos'):
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    means, medians = [], []
    if type == 'pos':
        test_set = SegDataset(args.test_set,
                              num_classes=args.out_channels,
                              appoint_size=(args.height, args.width),
                              erode=args.dilate)
        for i in range(len(test_set)):
            data, label = test_set[i]
            batch_data = data.unsqueeze(0).cuda()
            mean, median = get_score_info_from_net_output(
                net, batch_data, fg_id)
            if mean is not None and median is not None:
                means.append(mean)
                medians.append(median)
    else:
        stride = 200
        img_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((args.height, args.width)),
            transforms.ToTensor()
        ])
        names = [i.name for i in Path(args.test_videos).glob('*.*')]
        for name in names:
            video_path = os.path.join(args.test_videos, name)
            cap = cv2.VideoCapture(video_path)
            total_frame_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            for idx in range(0, total_frame_num - 1, stride):  # skip the last frame; otherwise the video cannot be saved correctly
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                _, frame = cap.read()
                batch_data = img_transform(frame).unsqueeze(0).cuda()

                mean, median = get_score_info_from_net_output(
                    net, batch_data, fg_id)
                if mean is not None and median is not None:
                    means.append(mean)
                    medians.append(median)

    save_dict = {}
    save_dict.setdefault('mean', np.mean(means).astype('float'))
    save_dict.setdefault('median', np.median(medians).astype('float'))
    if save_dir is not None:
        if type == 'pos':
            path = save_dir + '/score_info_from_pos.json'
        else:
            path = save_dir + '/score_info_from_neg.json'
        with open(path, 'w') as f:
            json.dump(save_dict, f, indent=2)
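
Example #4 relies on get_score_info_from_net_output, which evidently returns the mean and median foreground score of one image, or (None, None) when nothing is detected. A minimal sketch, assuming the net outputs per-class logits and that "detected" means a softmax probability above 0.5:

import torch

def get_score_info_from_net_output(net, batch_data, fg_id):
    with torch.no_grad():
        probs = torch.softmax(net(batch_data), dim=1)[0, fg_id]  # (H, W) foreground map
    fg = probs[probs > 0.5]
    if fg.numel() == 0:
        return None, None
    return fg.mean().item(), fg.median().item()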
Example #5
def train():
    model = ReSeg(2, pretrained=False, use_coordinates=True, usegpu=False)
    batch_size = 2
    criterion1 = DiceLoss()
    criterion2 = DiscriminativeLoss(0.5, 1.5, 2)
    optimizer = optim.Adam(model.parameters())
    dst = SegDataset(list_path=list_path,
                     img_root=img_root,
                     height=448,
                     width=448,
                     number_of_instances=number_of_instances,
                     semantic_ann_npy=semantic_ann_npy,
                     instances_ann_npy=instances_ann_npy,
                     transform=x_transform)
    ac = AlignCollate(2, 100, 448, 448)
    trainloader = torch.utils.data.DataLoader(dst, batch_size=batch_size, collate_fn=ac)
    train_model(model, criterion1, criterion2, optimizer, trainloader)
Example #6
def do_test(mode, args):
    print('\nTesting: {}. ################################# Mode: {}'.format(args.pt_dir, mode))
    pt_dir = args.pt_root + '/' + args.pt_dir
    args = merge_args_from_train_json(args, json_path=pt_dir + '/train_args.json')
    pt_path = find_latest_pt(pt_dir)
    if pt_path is None:
        return
    print('Loading:', pt_path)
    net = choose_net(args.net_name, args.out_channels).cuda()
    net.load_state_dict(torch.load(pt_path, map_location={'cuda:5':'cuda:0'}))
    net.eval()
    test_loader, class_weights = None, None
    if mode in (0, 1):
        test_set = SegDataset(args.test_set, args.out_channels, appoint_size=(args.height, args.width), erode=0)
        test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=4)
        class_weights = get_class_weights(test_loader, args.out_channels, args.weighting)

    if mode == 0:
        eval_dataset_full(net, args.out_channels, test_loader, class_weights=class_weights, save_dir=pt_dir)

    elif mode == 1:
        qualitative_results_from_dataset(net, args, pause=0, save_dir=pt_dir)

    elif mode == 2:
        predict_videos(net, args, partial=True, save_vid=False, dst_size=(960, 540), save_dir=pt_dir)

    elif mode == 3:
        predict_videos(net, args, partial=False, save_vid=True, dst_size=(960, 540), save_dir='/workspace/lanenet_mod_cbam_withneg/')

    elif mode == 4:
        predict_images(net, args, dst_size=(960, 540), save_dir='/workspace/tmp/')

    elif mode == 5:
        save_all_negs_from_videos(net, args, save_dir='/workspace/negs_from_videos0904_more_negs/')

    elif mode == 6:
        get_score_info(net, args, save_dir=pt_dir, type='neg')

    elif mode == 7:
        get_pos_neg_thresh_curves(net, args, save_dir=pt_dir)

    elif mode == 8:
        get_curves_object_detection(net, args, save_dir=pt_dir)
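
Example #6 loads the newest checkpoint via find_latest_pt, which is not shown; a sketch, assuming checkpoints are *.pt files and "latest" means most recently modified:

from pathlib import Path

def find_latest_pt(pt_dir):
    # Return the most recently modified .pt checkpoint in pt_dir, or None if there is none.
    pts = sorted(Path(pt_dir).glob('*.pt'), key=lambda p: p.stat().st_mtime)
    return str(pts[-1]) if pts else None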
Example #7
def qualitative_results_from_dataset(net,
                                     args,
                                     sample_rate=0.2,
                                     pause=0.2,
                                     save_dir=None):
    test_set = SegDataset(args.test_set,
                          num_classes=args.out_channels,
                          appoint_size=(args.height, args.width),
                          erode=args.dilate)
    stride = max(1, int(len(test_set) * sample_rate))  # guard against a zero stride on small datasets
    for i in range(0, len(test_set), stride):
        data, label = test_set[i]
        batch_data = data.unsqueeze(0).cuda()
        batch_label = label.unsqueeze(0).cuda()
        prediction_np, _, (_, _) = predict_a_batch(net,
                                                   args.out_channels,
                                                   batch_data,
                                                   batch_label,
                                                   class_weights=None,
                                                   do_criterion=False,
                                                   do_metric=False)
        if args.dilate > 0:
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT,
                                               (args.dilate, args.dilate))
            prediction_np = cv2.dilate(prediction_np.astype('uint8'), kernel)

        data_transform = transforms.Compose([transforms.ToPILImage()])
        data_pil = data_transform(data)
        data_np = cv2.cvtColor(np.array(data_pil), cv2.COLOR_BGR2RGB)
        label_np = np.array(label.unsqueeze(0))

        show_np = add_mask_to_source_multi_classes(data_np, prediction_np,
                                                   args.out_channels)
        save_path = save_dir + '/' + args.pt_dir + str(i) + '.png' if save_dir is not None else None
        subplots(data_np,
                 label_np.squeeze(0),
                 prediction_np,
                 show_np,
                 text='data/label/prediction/merge',
                 pause=pause,
                 save_path=save_path)
        print('Processed image', i)
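
Example #7 blends the prediction over the input with add_mask_to_source_multi_classes; a sketch of one plausible implementation, assuming class 0 is background and using a fixed random palette:

import numpy as np

def add_mask_to_source_multi_classes(source_np, mask_np, num_classes):
    # Deterministic per-class colours; blend each class region 50/50 with the source.
    palette = np.random.RandomState(0).randint(0, 256, (num_classes, 3)).astype(np.uint8)
    out = source_np.copy()
    for c in range(1, num_classes):  # skip background
        region = mask_np == c
        out[region] = (0.5 * out[region] + 0.5 * palette[c]).astype(np.uint8)
    return out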
Example #8
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=2)
parser.add_argument('--epoch', type=int, default=40)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--dataset', type=str, default='./data/')
parser.add_argument('--workers', type=int, default=4)
parser.add_argument('--save_model', type=str, default='./saved_model/')

cfg = parser.parse_args()
print(cfg)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

if __name__ == "__main__":
    ds_train = SegDataset(transforms=preprocessing)
    ds_test = SegDataset(split='test', transforms=preprocessing)
    dl_train = DataLoader(ds_train,
                          batch_size=cfg.batch_size,
                          shuffle=True,
                          num_workers=cfg.workers)
    dl_test = DataLoader(ds_test,
                         batch_size=cfg.batch_size,
                         shuffle=False,
                         num_workers=cfg.workers)

    print("DATA LOADED")
    model = U_Net()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    criterion = dice_loss
    success_metric = dice_coeff
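
Example #8 stops right after building the pieces; a minimal epoch loop that could follow, assuming the loaders yield (image, mask) pairs and that dice_loss takes (predictions, masks):

    model = model.to(device)
    for epoch in range(cfg.epoch):
        model.train()
        for imgs, masks in dl_train:
            imgs, masks = imgs.to(device), masks.to(device)
            optimizer.zero_grad()
            loss = criterion(model(imgs), masks)
            loss.backward()
            optimizer.step()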
Example #9
    # PREPROCESSING FOR IMAGES
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        # transforms.Normalize((0.485, 0.456, 0.406),(0.485, 0.456, 0.406)),
    ])
    """
    Shape of images is (batch,CLASSES , 256, 256)
    Shape of masks is (batch,  256, 256)
    """
    # LOSS FUNCTION
    loss_fn = nn.CrossEntropyLoss()

    # LOADING THE DATASET INTO TRAINLOADER
    trainset = SegDataset(image_path, mask_path, transform=preprocess)
    train_loader = DataLoader(trainset,
                              BATCH_SIZE,
                              num_workers=NUM_WORKERS,
                              pin_memory=PIN_MEM,
                              shuffle=True)

    # Load the model & Optimizer
    model = SegNet()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    if args.train:
        print("Train Mode.")
        # Train Model
        train_model(model, optimizer, train_loader, loss_fn, device, EPOCHS)
    else:
Example #10
    # Preprocess dataset
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    ])

    # Video preprocessing
    video_preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load the datasets needed.
    # Shape: 3 x 256 x 256 image
    trainset = SegDataset(img_dir, mask_dir, transform=preprocess)
    train_loader = DataLoader(trainset,
                              BATCH_SIZE,
                              num_workers=NUM_WORKERS,
                              pin_memory=PIN_MEM,
                              shuffle=True)

    valset = SegDataset(val_imgs_dir, val_masks_dir, transform=preprocess)
    val_loader = DataLoader(valset,
                            BATCH_SIZE,
                            num_workers=NUM_WORKERS,
                            pin_memory=PIN_MEM)

    # Load Model
    model = SegNetV2()
    # Loss Function
Example #11
def train(args):
    # Prepare training set
    train_set = SegDataset(args.train_set,
                           num_classes=args.out_channels,
                           appoint_size=(args.height, args.width),
                           erode=args.erode,
                           aug=args.train_aug)
    print('Length of train_set:', len(train_set))
    train_dataloader = DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=4)
    train_class_weights = get_class_weights(train_dataloader,
                                            out_channels=args.out_channels,
                                            weighting=args.weighting)
    if args.eval:
        val_set = SegDataset(args.val_set,
                             num_classes=args.out_channels,
                             appoint_size=(args.height, args.width),
                             erode=0)
        val_dataloader = DataLoader(val_set,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=4)
        val_class_weights = get_class_weights(val_dataloader,
                                              out_channels=args.out_channels,
                                              weighting=args.weighting)
    else:
        val_dataloader, val_class_weights = None, None

    # Prepare save dir
    save_dir = './Results/' + args.save_suffix + '-' + args.net_name + '-h' + str(train_set[0][0].shape[1]) + 'w' \
               + str(train_set[0][0].shape[2]) + '-erode' + str(args.erode) + '-weighting_' + str(args.weighting)
    print('Save dir is:{}  Input size is:{}'.format(save_dir,
                                                    train_set[0][0].shape))
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    with open(save_dir + '/train_args.json', 'w') as f:
        json.dump(vars(args), f, indent=2)

    # Prepare network
    writer = SummaryWriter(save_dir)
    val_dicts = []
    net = choose_net(args.net_name, args.out_channels).cuda()
    train_criterion = get_criterion(args.out_channels,
                                    class_weights=train_class_weights)
    optimizer = get_optimizer(net, args.opt_name)

    if args.f16:
        net, optimizer = amp.initialize(net, optimizer,
                                        opt_level="O1")  # opt_level is the letter O followed by one ("O1"), not zero-one

    steps = len(train_dataloader)
    lr_scheduler = get_lr_scheduler(optimizer,
                                    max_iters=args.epoch * steps,
                                    sch_name=args.sch_name)

    # Begin to train
    iter_cnt = 0
    for epo in range(args.epoch):
        net.train()
        for batch_id, (batch_data, batch_label) in enumerate(train_dataloader):
            if args.out_channels == 1:
                batch_label = batch_label.float()  # a logistic (BCE) loss needs labels with the same dtype as the data: float, not long
            else:
                batch_label = batch_label.squeeze(1)  # cross-entropy labels keep the default long dtype, but the channel dim C must be removed

            iter_cnt += 1
            output = net(batch_data.cuda())
            loss = train_criterion(output, batch_label.cuda())
            iter_loss = loss.item()
            print('Epoch:{} Batch:[{}/{}] Train loss:{}'.format(
                epo + 1,
                str(batch_id + 1).zfill(3), steps, round(iter_loss, 4)))
            writer.add_scalar('Train loss', iter_loss, iter_cnt)

            optimizer.zero_grad()
            if args.f16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()
            if lr_scheduler is not None and args.opt_name != 'adam':
                lr_scheduler.step()

        if args.eval:
            v_loss, (miou,
                     pa) = eval_dataset_full(net.eval(),
                                             args.out_channels,
                                             val_dataloader,
                                             class_weights=val_class_weights,
                                             save_dir=None)
            writer.add_scalar('Val loss', v_loss, epo + 1)
            writer.add_scalar('Val miou', miou, epo + 1)
            writer.add_scalar('Val pa', pa, epo + 1)
            val_dict_tmp = {}
            val_dict_tmp.setdefault('epoch', epo + 1)
            val_dict_tmp.setdefault('loss', v_loss)
            val_dict_tmp.setdefault('miou', miou)
            val_dict_tmp.setdefault('pa', pa)
            val_dicts.append(val_dict_tmp)

        if (epo + 1) == args.epoch or (epo + 1) % 25 == 0 or epo == 0:
            save_file = save_dir + '/' + args.net_name + '_{}.pt'.format(epo + 1)
            torch.save(net.state_dict(), save_file)
            print('Saved checkpoint:', save_file)

    writer.close()
    with open(save_dir + '/val_log.json', 'w') as f2:
        json.dump(val_dicts, f2, indent=2)

    if args.eval:
        predict_images(net, args, dst_size=(960, 540), save_dir=save_dir)
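
Example #11's get_criterion is defined elsewhere; the dtype comments in its training loop suggest a logistic loss for a single output channel and (optionally weighted) cross-entropy otherwise. A sketch under that assumption:

import torch.nn as nn

def get_criterion(out_channels, class_weights=None):
    if out_channels == 1:
        return nn.BCEWithLogitsLoss()  # labels must be float, matching the data
    return nn.CrossEntropyLoss(weight=class_weights)  # labels stay long, no channel dim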
Example #12
def get_pos_neg_thresh_curves(net,
                              args,
                              save_dir=None,
                              fg_id=1,
                              measure='mean'):
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    def normalize_list(values):
        max_ = max(values)
        min_ = min(values)
        return [(i - min_) / (max_ - min_) for i in values]

    pos = []
    neg = []
    dicts = []
    th_stride = 0.1
    lower_bounds = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    test_set = SegDataset(args.test_set,
                          num_classes=args.out_channels,
                          appoint_size=(args.height, args.width),
                          erode=args.dilate)
    for lower_bound in lower_bounds:
        values_pos, values_neg = [], []
        for i in range(len(test_set)):
            data, label = test_set[i]
            batch_data = data.unsqueeze(0).cuda()
            value = get_fg_value_from_net_output(net,
                                                 batch_data,
                                                 fg_id,
                                                 lower_bound,
                                                 th_stride=th_stride,
                                                 measure=measure)
            values_pos.append(value)
        mean_pos = np.mean(values_pos).astype('float')

        stride = 200
        img_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((args.height, args.width)),
            transforms.ToTensor()
        ])
        names = [i.name for i in Path(args.test_videos).glob('*.*')]
        for name in names:
            video_path = os.path.join(args.test_videos, name)
            cap = cv2.VideoCapture(video_path)
            total_frame_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            for idx in range(0, total_frame_num - 1, stride):  # skip the last frame; otherwise the video cannot be saved correctly
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                _, frame = cap.read()
                batch_data = img_transform(frame).unsqueeze(0).cuda()
                value = get_fg_value_from_net_output(net,
                                                     batch_data,
                                                     fg_id,
                                                     lower_bound,
                                                     th_stride=th_stride,
                                                     measure=measure)
                values_neg.append(value)
        mean_neg = np.mean(values_neg).astype('float')

        save_dict = {}
        save_dict.setdefault('interval',
                             (lower_bound, lower_bound + th_stride))
        save_dict.setdefault('mean_pos', mean_pos)
        save_dict.setdefault('mean_neg', mean_neg)
        dicts.append(save_dict)

        pos.append(mean_pos)
        neg.append(mean_neg)

    def intergrate_list(values):
        # cumulative sum over the thresh intervals
        total, out = 0, []
        for v in values:
            total += v
            out.append(total)
        return out

    pos = intergrate_list(pos)
    neg = intergrate_list(neg)

    # pos = normalize_list(pos)
    # neg = normalize_list(neg)

    save_dict1 = {}
    save_dict1.setdefault('pos', pos)
    save_dict1.setdefault('neg', neg)
    dicts.append(save_dict1)

    with open(save_dir + '/' + measure + '-pos_neg.json', 'w') as f:
        json.dump(dicts, f, indent=2)

    def show_curves(x, pos, neg):
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.set_xlabel("thresh")
        lns1 = ax.plot(x, pos, '-', label='pos')
        # ax.set_ylim(0, 1)
        plt.yticks([])
        ax2 = ax.twinx()
        lns2 = ax2.plot(x, neg, '-r', label='neg')
        # ax2.set_ylim(0, 1)
        lns = lns1 + lns2
        labs = [l.get_label() for l in lns]
        ax.legend(lns, labs, loc=0)
        plt.yticks([])
        plt.savefig(save_dir + '/' + measure + '-pos_neg.jpg')
        plt.show()

    show_curves(lower_bounds, pos, neg)
Example #13
import torch
import numpy as np

from label_io import read_data_names, read_images, read_segmaps, read_labels
from dataset import SegDataset

train_data_names = read_data_names('./train_labels')
train_segmaps = read_segmaps('./train_labels', train_data_names)

train_dataset = SegDataset('./train_images', train_segmaps, train_data_names,
                           None)

for train_d in train_dataset:
    i = train_d['image']
    l = train_d['label']

    print(i.shape)
    print(torch.mean(i))
Example #14
batchsize = 64
epochs = 200
imagesize = 256  # size images are resized to
cropsize = 224  # crop size used for training
train_data_path = 'data/train.txt'  # training set list file
val_data_path = 'data/val.txt'  # validation set list file

# Data preprocessing
data_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])


# Image segmentation datasets
train_dataset = SegDataset(train_data_path, imagesize, cropsize, data_transform)
train_dataloader = DataLoader(train_dataset, batch_size=batchsize, shuffle=True)
val_dataset = SegDataset(val_data_path, imagesize, cropsize, data_transform)
val_dataloader = DataLoader(val_dataset, batch_size=len(val_dataset), shuffle=True)

image_datasets = {}
image_datasets['train'] = train_dataset
image_datasets['val'] = val_dataset
dataloaders = {}
dataloaders['train'] = train_dataloader
dataloaders['val'] = val_dataloader

# Define the network, the loss, and the optimizer
device = torch.device('cpu')
net = simpleNet5().to(device)
criterion = nn.CrossEntropyLoss()  # softmax cross-entropy loss; the labels are per-pixel class maps
Example #15
    return model

def get_transform(train):
    transforms = []
    transforms.append(T.ToTensor())
#     if train:
#         transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)

if __name__== "__main__":
    
    torch.cuda.empty_cache()

    dbroot = '/datasets/OpenImages/processedv4'
    dataset_train = SegDataset(os.path.join(dbroot, 'test'), get_transform(train=True))
    dataset_val = SegDataset(os.path.join(dbroot, 'validation'), get_transform(train=False))

    # define training and validation data loaders
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, drop_last=True, batch_size=48, shuffle=True, num_workers=12)
#         collate_fn=utils.collate_fn)

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, drop_last=True, batch_size=48, shuffle=False, num_workers=12)
#         collate_fn=utils.collate_fn)

    print(f"Train set size: {len(data_loader_train.dataset)}, n_batches: {len(data_loader_train)}")
    print(f"Validation set size: {len(data_loader_val.dataset)}, n_batches: {len(data_loader_val)}")

    # model