import os
import sys

import torch

# Make the parent directory importable so that `OCT_train` resolves.
# The path is computed from this file's location rather than the current
# working directory, so the script works no matter where it is launched from.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# ===================
from OCT_train import trainModels

# Request deterministic cuDNN algorithm selection and disable cuDNN's
# benchmarking autotuner (see PyTorch reproducibility notes).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

if __name__ == '__main__':
    # Guarded so that importing this module (e.g. when a spawn-based
    # multiprocessing worker re-imports the main module) does not start
    # training.
    # All hyper-parameters are forwarded unchanged to OCT_train.trainModels;
    # see that function for their meaning.
    trainModels(model='SOASNet_single',
                data_set='duke',
                input_dim=1,
                epochs=250,
                width=64,
                depth=4,
                depth_limit=6,
                repeat=5,
                l_r=1e-3,
                l_r_s=True,
                train_batch=4,
                shuffle=True,
                loss='ce',
                norm='bn',
                log='MICCAI_Duke_Results',
                class_no=8,
                cluster=True,
                data_augmentation_train='all',
                data_augmentation_test='none')
import os
import sys

import torch

# Make the parent directory importable so that `OCT_train` resolves.
# The path is computed from this file's location rather than the current
# working directory, so the script works no matter where it is launched from.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# ===================
from OCT_train import trainModels

# Request deterministic cuDNN algorithm selection and disable cuDNN's
# benchmarking autotuner (see PyTorch reproducibility notes).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

if __name__ == '__main__':
    # Train SOASNet on the 'ours' data set. Every hyper-parameter below is
    # forwarded unchanged to OCT_train.trainModels; see that function for
    # their meaning.
    trainModels(model='SOASNet',
                data_set='ours',
                input_dim=1,
                epochs=50,
                width=16,
                depth=4,
                depth_limit=6,
                repeat=3,
                l_r=1e-3,
                l_r_s=True,
                train_batch=4,
                shuffle=True,
                loss='ce',
                norm='bn',
                log='MICCAI_Our_Data_Results',
                class_no=2,
                cluster=True,
                data_augmentation_train='all',
                data_augmentation_test='all')

    print('Finished.')
 #             cluster=False)
 # #
 # ====================================
 # SegNet based
 # ====================================
 #
 # Short run of SOASNet_very_large_kernel on the 'ours' data set
 # (epochs=1, repeat=1, cluster=False), logged under 'Test'. The keyword
 # arguments are forwarded unchanged to trainModels.
 # NOTE(review): this statement sits at a one-space indent; confirm which
 # enclosing block it belongs to (the header is not visible here).
 trainModels(**{
     'model': 'SOASNet_very_large_kernel',
     'data_set': 'ours',
     'input_dim': 1,
     'epochs': 1,
     'width': 16,
     'depth': 4,
     'depth_limit': 6,
     'repeat': 1,
     'l_r': 1e-3,
     'l_r_s': True,
     'train_batch': 4,
     'shuffle': True,
     'loss': 'ce',
     'norm': 'bn',
     'log': 'Test',
     'class_no': 2,
     'cluster': False,
     'data_augmentation_train': 'all',
     'data_augmentation_test': 'none',
 })
 #
 # trainModels(model='RelayNet',
 #             input_dim=1,
 #             epochs=250,
 #             width=64,
 #             depth=4,