Example 1
def transform_image(image, image_size, num_channels):
    image = np.array(image, dtype=np.uint8)
    image = utility.to_channels(image, num_channels)
    obj = ObjectImageTransform(image)
    obj = get_transforms_det(image_size)(obj)
    # tensor is channel x height x width
    # image = image.transpose((2, 0, 1))
    # obj.to_tensor()
    # value = torch.from_numpy(obj.image).float()
    obj.image = obj.image.unsqueeze(0)  # add the batch dimension
    return obj.to_dict()
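A minimal usage sketch, not from the source: the input file and the 64x64 size are assumptions, and it presumes numpy, utility, ObjectImageTransform and get_transforms_det are importable as in the snippet above.

from PIL import Image  # assumption: the input is a PIL image

img = Image.open('face.jpg')  # hypothetical input file
sample = transform_image(img, image_size=64, num_channels=3)
print(sample.keys())  # the dict produced by obj.to_dict()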
Example 2
def main(params=None):
    # This model has a lot of variability, so it needs a lot of parameters.
    # We use an arg parser to get all the arguments we need.
    # See above for the default values, definitions and information on the datatypes.
    parser = arg_parser()
    if params:
        args = parser.parse_args(params)
    else:
        args = parser.parse_args()

    # Configuration
    project = args.project
    projectname = args.projectname
    pathnamedataset = args.pathdataset
    pathnamemodel = args.model
    pathproject = os.path.join(project, projectname)
    namedataset = args.namedataset
    breal = args.breal
    name_method = args.name_method
    iteration = args.iteration

    fname = args.name_method
    fnet = {
        'attnet': AttentionNeuralNet,
        'attgmmnet': AttentionGMMNeuralNet,
        'classnet': ClassNeuralNet,
    }

    no_cuda = False
    parallel = False
    gpu = 0
    seed = 1
    brepresentation = True
    bclassification_test = True
    brecover_test = False

    imagesize = 64
    kfold = 5
    nactores = 10
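    # hold out one fold of actor identities (indices kfold*nactores ... kfold*nactores + nactores - 1)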
    idenselect = np.arange(nactores) + kfold * nactores

    # experiments
    experiments = [{
        'name': namedataset,
        'subset': FactoryDataset.training,
        'status': breal
    }, {
        'name': namedataset,
        'subset': FactoryDataset.validation,
        'status': breal
    }]

    if brepresentation:

        # create an instance of a model
        print('>> Load model ...')
        network = fnet[fname](
            patchproject=project,
            nameproject=projectname,
            no_cuda=no_cuda,
            parallel=parallel,
            seed=seed,
            gpu=gpu,
        )

        cudnn.benchmark = True

        # load trained model; fail fast if the checkpoint cannot be read
        if network.load(pathnamemodel) is not True:
            raise RuntimeError('>> Error: failed to load model {}'.format(pathnamemodel))

        # Perform the experiments
        for i, experiment in enumerate(experiments):

            name_dataset = experiment['name']
            subset = experiment['subset']
            breal = experiment['status']
            dataset = []

            # load dataset
            if breal == 'real':

                # real dataset
                dataset = Dataset(
                    data=FactoryDataset.factory(pathname=pathnamedataset,
                                                name=namedataset,
                                                subset=subset,
                                                idenselect=idenselect,
                                                download=True),
                    num_channels=3,
                    transform=get_transforms_det(imagesize),
                )

            else:

                # synthetic dataset
                dataset = SyntheticFaceDataset(
                    data=FactoryDataset.factory(pathname=pathnamedataset,
                                                name=namedataset,
                                                subset=subset,
                                                idenselect=idenselect,
                                                download=True),
                    pathnameback='~/.datasets/coco',
                    ext='jpg',
                    count=iteration,
                    num_channels=3,
                    iluminate=True,
                    angle=45,
                    translation=0.3,
                    warp=0.2,
                    factor=0.2,
                    transform_data=get_transforms_aug(imagesize),
                    transform_image=get_transforms_det(imagesize),
                )

            dataloader = DataLoader(dataset,
                                    batch_size=64,
                                    shuffle=False,
                                    num_workers=10)

            print("\ndataset:", breal)
            print("Subset:", subset)
            print("Classes", dataloader.dataset.data.classes)
            print("size of data:", len(dataset))
            print("num of batches", len(dataloader))

            # if the method is attgmmnet, the output also includes the representation vectors Zs;
            # otherwise it contains only the predicted emotions and the ground truth
            if name_method == 'attgmmnet':
                # representation
                Y_labs, Y_lab_hats, Zs = network.representation(
                    dataloader, breal)
                print(Y_lab_hats.shape, Zs.shape, Y_labs.shape)

                reppathname = os.path.join(
                    pathproject,
                    'rep_{}_{}_{}.pth'.format(namedataset, subset, breal))
                torch.save({
                    'Yh': Y_lab_hats,
                    'Z': Zs,
                    'Y': Y_labs
                }, reppathname)
                print('save representation ...', reppathname)

            else:
                Y_labs, Y_lab_hats = network.representation(dataloader, breal)
                print("Y_lab_hats shape: {}, y_labs shape: {}".format(
                    Y_lab_hats.shape, Y_labs.shape))

                reppathname = os.path.join(
                    pathproject,
                    'rep_{}_{}_{}.pth'.format(namedataset, subset, breal))
                torch.save({'Yh': Y_lab_hats, 'Y': Y_labs}, reppathname)
                print('save representation ...', reppathname)

    # compute the classification results: accuracy, precision, recall and F1
    if bclassification_test:
        tuplas = []
        print('|Num\t|Acc\t|Prec\t|Rec\t|F1\t|Set\t|Type\t|Accuracy_type\t')
        for i, experiment in enumerate(experiments):

            name_dataset = experiment['name']
            subset = experiment['subset']
            breal = experiment['status']
            real = breal

            rep_pathname = os.path.join(
                pathproject, 'rep_{}_{}_{}.pth'.format(namedataset, subset,
                                                       breal))

            data_emb = torch.load(rep_pathname)
            Yto = data_emb['Y']
            Yho = data_emb['Yh']

            yhat = np.argmax(Yho, axis=1)
            y = Yto

            acc = metrics.accuracy_score(y, yhat)
            precision = metrics.precision_score(y, yhat, average='macro')
            recall = metrics.recall_score(y, yhat, average='macro')
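            # F1 as the harmonic mean of the macro-averaged precision and recall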
            f1_score = 2 * precision * recall / (precision + recall)

            print(
                '|{}\t|{:0.3f}\t|{:0.3f}\t|{:0.3f}\t|{:0.3f}\t|{}\t|{}\t|{}\t'.
                format(i, acc, precision, recall, f1_score, subset, real,
                       'topk'))

            cm = metrics.confusion_matrix(y, yhat)
            # label = ['Neutral', 'Happiness', 'Surprise', 'Sadness', 'Anger', 'Disgust', 'Fear', 'Contempt']
            # cm_display = metrics.ConfusionMatrixDisplay(cm, display_labels=label).plot()
            print(cm)

            print(f'save y and yhat to {real}_{subset}_y.npz')
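            # saved keys: name1 = predictions (yhat), name2 = ground-truth labels (y)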
            np.savez(os.path.join(pathproject, f'{real}_{subset}_y.npz'),
                     name1=yhat,
                     name2=y)

            #|Name|Dataset|Cls|Acc| ...
            tupla = {
                'Name': projectname,
                'Dataset': '{}({})_{}'.format(name_dataset, subset, real),
                'Accuracy': acc,
                'Precision': precision,
                'Recall': recall,
                'F1 score': f1_score,
            }
            tuplas.append(tupla)

        # save
        df = pd.DataFrame(tuplas)
        df.to_csv(os.path.join(pathproject, 'experiments_cls.csv'),
                  index=False,
                  encoding='utf-8')
        print('save experiments class ...')
        print()
    print('DONE!!!')
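A hypothetical follow-up, not from the source: it shows how one of the representation files saved above could be read back. The concrete path is an assumption built from the 'rep_{}_{}_{}.pth' pattern; the 'Yh' and 'Y' keys match the torch.save calls above.

data = torch.load('out/exp01/rep_ck_val_real.pth')  # path is an assumption
Yh, Y = data['Yh'], data['Y']  # predictions and ground-truth labels
yhat = np.argmax(Yh, axis=1)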
Example 3
def main():

    parser = arg_parser()
    args = parser.parse_args()

    # Configuration
    project = args.project
    projectname = args.projectname
    pathnamedataset = args.pathdataset
    pathnamemodel = args.model
    pathproject = os.path.join(project, projectname)
    pathnameout = args.pathnameout
    filename = args.filename
    namedataset = args.namedataset

    no_cuda = False
    parallel = False
    gpu = 0
    seed = 1
    imagesize = 128
    batch_size = 100
    idenselect = []

    # experiments
    experiments = [
        {
            'name': namedataset,
            'subset': FactoryDataset.training,
            'real': True
        },
        {
            'name': namedataset,
            'subset': FactoryDataset.validation,
            'real': True
        },
    ]

    # Load models
    print('>> Load model ...')
    network = AttentionNeuralNet(
        patchproject=project,
        nameproject=projectname,
        no_cuda=no_cuda,
        parallel=parallel,
        seed=seed,
        gpu=gpu,
    )

    cudnn.benchmark = True

    # load model; fail fast if the checkpoint cannot be read
    if network.load(pathnamemodel) is not True:
        raise RuntimeError('>> Error: failed to load model {}'.format(pathnamemodel))

    size_input = network.size_input
    for i, experiment in enumerate(experiments):

        name_dataset = experiment['name']
        subset = experiment['subset']
        breal = experiment['real']
        dataset = []

        # real dataset
        dataset = Dataset(
            data=FactoryDataset.factory(pathname=pathnamedataset,
                                        name=namedataset,
                                        subset=subset,
                                        idenselect=idenselect,
                                        download=True),
            num_channels=3,
            transform=get_transforms_det(imagesize),
        )

        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=10)
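        # shuffle=False keeps the extracted representations aligned with the dataset order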

        print(breal)
        print(subset)
        #print(dataloader.dataset.data.classes)
        print(len(dataset))
        print(len(dataloader))

        # representation
        Y_labs, Y_lab_hats, Zs = network.representation(dataloader, breal)
        print(Y_lab_hats.shape, Zs.shape, Y_labs.shape)

        reppathname = os.path.join(
            pathproject,
            'rep_{}_{}_{}_{}.pth'.format(projectname, namedataset, subset,
                                         'real' if breal else 'no_real'))
        torch.save({'Yh': Y_lab_hats, 'Z': Zs, 'Y': Y_labs}, reppathname)
        print('save representation ...')

    print('DONE!!!')
Example 4
def main():

    parser = arg_parser()
    args = parser.parse_args()

    # Configuration
    project = args.project
    projectname = args.projectname
    pathnamedataset = args.pathdataset
    pathnamemodel = args.model
    pathproject = os.path.join(project, projectname)
    pathnameout = args.pathnameout
    filename = args.filename
    namedataset = args.namedataset

    no_cuda = False
    parallel = False
    gpu = 0
    seed = 1
    brepresentation = True
    bclassification_test = True
    brecover_test = True

    imagesize = 64
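    # the first 10 actor identities (equivalent to fold 0 with nactores=10 in the other examples)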
    idenselect = np.arange(10)

    # experiments
    experiments = [
        {
            'name': namedataset,
            'subset': FactoryDataset.training,
            'real': True
        },
        {
            'name': namedataset,
            'subset': FactoryDataset.validation,
            'real': True
        },
        {
            'name': namedataset + 'dark',
            'subset': FactoryDataset.training,
            'real': False
        },
        {
            'name': namedataset + 'dark',
            'subset': FactoryDataset.validation,
            'real': False
        },
    ]

    # representation datasets
    if brepresentation:

        # Load models
        print('>> Load model ...')
        network = AttentionNeuralNet(
            patchproject=project,
            nameproject=projectname,
            no_cuda=no_cuda,
            parallel=parallel,
            seed=seed,
            gpu=gpu,
        )

        cudnn.benchmark = True

        # load model; fail fast if the checkpoint cannot be read
        if network.load(pathnamemodel) is not True:
            raise RuntimeError('>> Error: failed to load model {}'.format(pathnamemodel))

        size_input = network.size_input
        for i, experiment in enumerate(experiments):

            name_dataset = experiment['name']
            subset = experiment['subset']
            breal = experiment['real']
            dataset = []

            # load dataset
            if breal:

                # real dataset
                dataset = Dataset(
                    data=FactoryDataset.factory(pathname=pathnamedataset,
                                                name=namedataset,
                                                subset=subset,
                                                idenselect=idenselect,
                                                download=True),
                    num_channels=3,
                    transform=get_transforms_det(imagesize),
                )

            else:

                # synthetic dataset
                dataset = SyntheticFaceDataset(
                    data=FactoryDataset.factory(pathname=pathnamedataset,
                                                name=namedataset,
                                                subset=subset,
                                                idenselect=idenselect,
                                                download=True),
                    pathnameback='~/.datasets/coco',
                    ext='jpg',
                    count=2000,
                    num_channels=3,
                    iluminate=True,
                    angle=45,
                    translation=0.3,
                    warp=0.2,
                    factor=0.2,
                    transform_data=get_transforms_aug(imagesize),
                    transform_image=get_transforms_det(imagesize),
                )

            dataloader = DataLoader(dataset,
                                    batch_size=100,
                                    shuffle=False,
                                    num_workers=10)

            print(breal)
            print(subset)
            print(dataloader.dataset.data.classes)
            print(len(dataset))
            print(len(dataloader))

            # representation
            Y_labs, Y_lab_hats, Zs = network.representation(dataloader, breal)
            print(Y_lab_hats.shape, Zs.shape, Y_labs.shape)

            reppathname = os.path.join(
                pathproject,
                'rep_{}_{}_{}_{}.pth'.format(projectname, namedataset, subset,
                                             'real' if breal else 'no_real'))
            torch.save({'Yh': Y_lab_hats, 'Z': Zs, 'Y': Y_labs}, reppathname)
            print('save representation ...')

    if bclassification_test:

        tuplas = []
        print('|Num\t|Acc\t|Prec\t|Rec\t|F1\t|Set\t|Type\t')
        for i, experiment in enumerate(experiments):

            name_dataset = experiment['name']
            subset = experiment['subset']
            breal = experiment['real']
            real = 'real' if breal else 'no_real'

            rep_pathname = os.path.join(
                pathproject,
                'rep_{}_{}_{}_{}.pth'.format(projectname, namedataset, subset,
                                             real))

            data_emb = torch.load(rep_pathname)
            Xto = data_emb['Z']
            Yto = data_emb['Y']
            Yho = data_emb['Yh']

            yhat = np.argmax(Yho, axis=1)
            y = Yto

            acc = metrics.accuracy_score(y, yhat)
            precision = metrics.precision_score(y, yhat, average='macro')
            recall = metrics.recall_score(y, yhat, average='macro')
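            # F1 as the harmonic mean of the macro-averaged precision and recall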
            f1_score = 2 * precision * recall / (precision + recall)

            # print with comma as the decimal separator
            print('|{}\t|{:0.3f}\t|{:0.3f}\t|{:0.3f}\t|{:0.3f}\t|{}\t|{}\t'.
                  format(i, acc, precision, recall, f1_score, subset,
                         real).replace('.', ','))

            #|Name|Dataset|Cls|Acc| ...
            tupla = {
                'Name': projectname,
                'Dataset': '{}({})_{}'.format(name_dataset, subset, real),
                'Accuracy': acc,
                'Precision': precision,
                'Recall': recall,
                'F1 score': f1_score,
            }
            tuplas.append(tupla)

        # save
        df = pd.DataFrame(tuplas)
        df.to_csv(os.path.join(pathnameout, 'experiments_cls.csv'),
                  index=False,
                  encoding='utf-8')
        print('save experiments class ...')
        print()

    if brecover_test:
        experiments = [
            {
                'name': namedataset,
                'train': True,
                'val': True
            },
            {
                'name': namedataset,
                'train': False,
                'val': False
            },
            {
                'name': namedataset,
                'train': False,
                'val': True
            },
            {
                'name': namedataset,
                'train': True,
                'val': False
            },
        ]

        tuplas = []
        print('|Num\t|Acc\t|Prec\t|Rec\t|F1\t|Type\t')
        for i, experiment in enumerate(experiments):
            name_dataset = experiment['name']
            real_train = 'real' if experiment['train'] else 'no_real'
            real_val = 'real' if experiment['val'] else 'no_real'

            rep_trn_pathname = os.path.join(
                pathproject,
                'rep_{}_{}_{}_{}.pth'.format(projectname, name_dataset,
                                             'train', real_train))
            rep_val_pathname = os.path.join(
                pathproject,
                'rep_{}_{}_{}_{}.pth'.format(projectname, name_dataset, 'val',
                                             real_val))

            data_emb_train = torch.load(rep_trn_pathname)
            data_emb_val = torch.load(rep_val_pathname)
            Xo = data_emb_train['Z']
            Yo = data_emb_train['Y']
            Xto = data_emb_val['Z']
            Yto = data_emb_val['Y']

            clf = KNeighborsClassifier(n_neighbors=11)
            #clf = GaussianNB()
            #clf = RandomForestClassifier(n_estimators=150, oob_score=True, random_state=123456)
            #clf = MLPClassifier(hidden_layer_sizes=(100,100), max_iter=100, alpha=1e-4,
            #                     solver='sgd', verbose=10, tol=1e-4, random_state=1,
            #                     learning_rate_init=.01)

            clf.fit(Xo, Yo)

            y = Yto
            yhat = clf.predict(Xto)

            acc = metrics.accuracy_score(y, yhat)
            nmi_s = metrics.cluster.normalized_mutual_info_score(y, yhat)
            mi = metrics.cluster.mutual_info_score(y, yhat)
            h1 = metrics.cluster.entropy(y)
            h2 = metrics.cluster.entropy(yhat)
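            # manual NMI (mutual information normalized by the mean entropy), kept for reference
            # alongside nmi_s; metrics.cluster.entropy may be missing in newer scikit-learn releases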
            nmi = 2 * mi / (h1 + h2)
            precision = metrics.precision_score(y, yhat, average='macro')
            recall = metrics.recall_score(y, yhat, average='macro')
            f1_score = 2 * precision * recall / (precision + recall)

            #|Name|Dataset|Cls|Acc| ...
            tupla = {
                'Name': projectname,
                'Dataset': '{}({}_{})'.format(name_dataset, real_train,
                                              real_val),
                'Accuracy': acc,
                'NMI': nmi_s,
                'Precision': precision,
                'Recall': recall,
                'F1 score': f1_score,
            }
            tuplas.append(tupla)

            print(
                '|{}\t|{:0.3f}\t|{:0.3f}\t|{:0.3f}\t|{:0.3f}\t|{}/{}\t'.format(
                    i,
                    acc,
                    precision,
                    recall,
                    f1_score,
                    real_train,
                    real_val,
                ).replace('.', ','))

        # save
        df = pd.DataFrame(tuplas)
        df.to_csv(os.path.join(pathnameout, 'experiments_recovery.csv'),
                  index=False,
                  encoding='utf-8')
        print('save experiments recovery ...')
        print()

    print('DONE!!!')
Example 5
def main():
    
    # parameters
    parser = arg_parser()
    args = parser.parse_args()
    imsize = args.image_size
    parallel = args.parallel
    num_classes = args.num_classes
    num_channels = args.channels
    dim = args.dim
    view_freq = 1

    fname = args.name_method
    fnet = {
        'attnet': AttentionNeuralNet,
        'attstnnet': AttentionSTNNeuralNet,
        'attgmmnet': AttentionGMMNeuralNet,
        'attgmmstnnet': AttentionGMMSTNNeuralNet,
    }

    network = fnet[fname](
        patchproject=args.project,
        nameproject=args.name,
        no_cuda=args.no_cuda,
        parallel=parallel,
        seed=args.seed,
        print_freq=args.print_freq,
        gpu=args.gpu,
        view_freq=view_freq,
    )

    network.create(
        arch=args.arch,
        num_output_channels=dim,
        num_input_channels=num_channels,
        loss=args.loss,
        lr=args.lr,
        momentum=args.momentum,
        optimizer=args.opt,
        lrsch=args.scheduler,
        pretrained=args.finetuning,
        size_input=imsize,
        num_classes=num_classes,
    )
    
    # resume
    network.resume(os.path.join(network.pathmodels, args.resume))
    cudnn.benchmark = True

    kfold = args.kfold
    nactores = args.nactor
    idenselect = np.arange(nactores) + kfold * nactores

    # datasets
    # training dataset
    # SyntheticFaceDataset, SecuencialSyntheticFaceDataset
    train_data = SecuencialSyntheticFaceDataset(
        data=FactoryDataset.factory(
            pathname=args.data,
            name=args.name_dataset,
            subset=FactoryDataset.training,
            idenselect=idenselect,
            download=True,
        ),
        pathnameback=args.databack,
        ext='jpg',
        count=50000,  # 100000
        num_channels=num_channels,
        iluminate=True, angle=30, translation=0.2, warp=0.1, factor=0.2,
        # iluminate=True, angle=45, translation=0.3, warp=0.2, factor=0.2,
        transform_data=get_transforms_aug(imsize),
        transform_image=get_transforms_det(imsize),
    )
    
    
    # labels, counts = np.unique(train_data.labels, return_counts=True)
    # weights = 1 / (counts / counts.sum())
    # samples_weights = np.array([weights[x] for x in train_data.labels])

    num_train = len(train_data)
    sampler = SubsetRandomSampler(np.random.permutation(num_train))
    # sampler = WeightedRandomSampler(weights=samples_weights, num_samples=len(samples_weights), replacement=True)

    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              pin_memory=network.cuda,
                              drop_last=True,
                              sampler=sampler)  # shuffle=True,
    
    
    # validate dataset
    # SyntheticFaceDataset, SecuencialSyntheticFaceDataset
    val_data = SecuencialSyntheticFaceDataset(
        data=FactoryDataset.factory(
            pathname=args.data,
            name=args.name_dataset,
            idenselect=idenselect,
            subset=FactoryDataset.validation,
            download=True,
        ),
        pathnameback=args.databack,
        ext='jpg',
        # count=1000,  # 10000
        num_channels=num_channels,
        iluminate=True, angle=30, translation=0.2, warp=0.1, factor=0.2,
        # iluminate=True, angle=45, translation=0.3, warp=0.2, factor=0.2,
        transform_data=get_transforms_aug(imsize),
        transform_image=get_transforms_det(imsize),
    )

    val_loader = DataLoader(val_data,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.workers,
                            pin_memory=network.cuda,
                            drop_last=False)
       
    # print neural net class
    print('SEG-Torch: {}'.format(datetime.datetime.now()))
    print(network)

    # training neural net
    network.fit(train_loader, val_loader, args.epochs, args.snapshot)

    print("Optimization Finished!")
    print("DONE!!!")
Example 6
def main():
    # This model has a lot of variability, so it needs a lot of parameters.
    # We use an arg parser to get all the arguments we need.
    # See above for the default values, definitions and information on the datatypes.
    parser = arg_parser()
    args = parser.parse_args()
    imsize = args.image_size
    parallel = args.parallel
    num_classes = args.num_classes
    num_channels = args.channels
    num_filters = args.num_filters
    dim = args.dim
    view_freq = 1
    trainiteration = args.trainiteration
    testiteration = args.testiteration
    alpha = args.alpha
    beta = args.beta

    # Which network do we want to use? Initialize it.
    fname = args.name_method
    fnet = {
        'attnet': AttentionNeuralNet,
        'attgmmnet': AttentionGMMNeuralNet,
        'classnet': ClassNeuralNet,
    }
    network = fnet[fname](
        patchproject=args.project,
        nameproject=args.name,
        no_cuda=args.no_cuda,
        parallel=parallel,
        seed=args.seed,
        print_freq=args.print_freq,
        gpu=args.gpu,
        view_freq=view_freq,
    )
    network.create(arch=args.arch,
                   num_output_channels=dim,
                   num_input_channels=num_channels,
                   loss=args.loss,
                   lr=args.lr,
                   momentum=args.momentum,
                   optimizer=args.opt,
                   lrsch=args.scheduler,
                   pretrained=args.finetuning,
                   size_input=imsize,
                   num_classes=num_classes,
                   breal=args.breal,
                   alpha=alpha,
                   beta=beta)

    # resume from a checkpoint if one was specified
    if args.resume is not None:
        network.resume(os.path.join(network.pathmodels, args.resume))
    cudnn.benchmark = True

    kfold = args.kfold
    nactores = args.nactor
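    # hold out one fold of actor identities (indices kfold*nactores ... kfold*nactores + nactores - 1)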
    idenselect = np.arange(nactores) + kfold * nactores

    # choose and load our datasets
    # training dataset
    if args.breal == 'real':
        train_data = Dataset(
            data=FactoryDataset.factory(pathname=args.data,
                                        name=args.name_dataset,
                                        subset=FactoryDataset.training,
                                        idenselect=idenselect,
                                        download=True),
            num_channels=3,
            transform=get_transforms_det(args.image_size),
        )

        val_data = Dataset(
            data=FactoryDataset.factory(pathname=args.data,
                                        name=args.name_dataset,
                                        subset=FactoryDataset.validation,
                                        idenselect=idenselect,
                                        download=True),
            num_channels=3,
            transform=get_transforms_det(args.image_size),
        )
    else:
        # SyntheticFaceDataset
        train_data = SyntheticFaceDataset(
            data=FactoryDataset.factory(pathname=args.data,
                                        name=args.name_dataset,
                                        subset=FactoryDataset.training,
                                        idenselect=idenselect,
                                        download=True),
            pathnameback=args.databack,
            ext='jpg',
            count=trainiteration,
            num_channels=num_channels,
            iluminate=True,
            angle=30,
            translation=0.2,
            warp=0.1,
            factor=0.2,
            transform_data=get_transforms_aug(imsize),
            transform_image=get_transforms_det(imsize),
        )

        # validate dataset
        val_data = SyntheticFaceDataset(
            data=FactoryDataset.factory(pathname=args.data,
                                        name=args.name_dataset,
                                        idenselect=idenselect,
                                        subset=FactoryDataset.validation,
                                        download=True),
            pathnameback=args.databack,
            ext='jpg',
            count=testiteration,
            num_channels=num_channels,
            iluminate=True,
            angle=30,
            translation=0.2,
            warp=0.1,
            factor=0.2,
            transform_data=get_transforms_aug(imsize),
            transform_image=get_transforms_det(imsize),
        )

    num_train = len(train_data)
    # sample to balance the dataset
    if args.balance:
        labels, counts = np.unique(train_data.labels, return_counts=True)
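        # inverse class-frequency weights: rare classes get proportionally larger sampling weight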
        weights = 1 / (counts / counts.sum())
        samples_weights = np.array([weights[x] for x in train_data.labels])
        sampler = WeightedRandomSampler(weights=samples_weights,
                                        num_samples=len(samples_weights),
                                        replacement=True)
    else:
        sampler = SubsetRandomSampler(np.random.permutation(num_train))

    # Now that we have our dataset, make loaders to facilitate training and validation
    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              pin_memory=network.cuda,
                              drop_last=True,
                              sampler=sampler)  # shuffle=True,

    val_loader = DataLoader(val_data,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.workers,
                            pin_memory=network.cuda,
                            drop_last=False)

    print(f"Train size: {len(train_data)}, test size: {len(val_data)}")

    # print neural net class
    start = datetime.datetime.now()
    print('SEG-Torch: {}'.format(start))
    print(network)

    # fit the neural net - this function performs both training and validation,
    # printing loss values to the screen and saving the best model for use later.
    network.fit(train_loader, val_loader, args.epochs, args.snapshot)

    print("Optimization Finished!")
    end = datetime.datetime.now()
    print("End time is", end)
    print("Time duration is:", round((end - start).total_seconds() / 60, 2))
    print("DONE!!!")