def main(ratio_known_abnormal_i):
    """
    Joint DSAD experiment with different fractions of known abnormal samples:
    0.0%, 0.25%, 0.5%, 1.0%, 5.0% and 10.0%, over a single replicate.
    """
    # ratio_known_abnormal_i : fraction of known abnormal samples used for the
    #     split (e.g. 0.0025 for 0.25%). It is rendered with two decimals
    #     ('0.25%') in the log, result and model file names and in the log.
    # Every other setting (OUTPUT_PATH, DATA_PATH, seed, lr, n_epoch, ...) is
    # read from module scope.

    # initialize logger. basicConfig installs a console handler only when the
    # root logger has none; a file handler left at index 1 by a previous call
    # is replaced by this run's one.
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    if len(logger.handlers) > 1:
        previous_handler = logger.handlers[1]
        logger.removeHandler(previous_handler)
        # Handler.close() flushes and releases a file; unlike closing
        # `.stream` directly it never closes a shared stream such as stdout.
        previous_handler.close()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s')
    # Two decimals keep 0.25% and 0.5% distinct ('.0%' renders both as '0%').
    ratio_tag = f'{ratio_known_abnormal_i:.2%}'
    log_file = OUTPUT_PATH + 'logs/' + f'log_{ratio_tag}.txt'
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # print path and main docstring with experiment summary
    logger.info('Brief summary of experiment : \n' + main.__doc__)
    logger.info(f'Log file : {log_file}')
    logger.info(f'Data path : {DATA_PATH}')
    logger.info(f'Outputs path : {OUTPUT_PATH}' + '\n')

    ############################## Make datasets ###############################
    # load data_info and drop its first column (presumably the index written
    # alongside the CSV -- TODO confirm)
    df_info = pd.read_csv(DATA_INFO_PATH)
    df_info = df_info.drop(df_info.columns[0], axis=1)
    # remove low contrast images (all black)
    df_info = df_info[df_info.low_contrast == 0]

    # Train Validation Test Split (fixed random_state for reproducibility)
    spliter = MURA_TrainValidTestSplitter(
        df_info,
        train_frac=train_frac,
        ratio_known_normal=ratio_known_normal,
        ratio_known_abnormal=ratio_known_abnormal_i,
        random_state=42)
    spliter.split_data(verbose=False)
    train_df = spliter.get_subset('train')
    valid_df = spliter.get_subset('valid')
    test_df = spliter.get_subset('test')
    # make datasets (online data augmentation on the training set only)
    train_dataset = MURA_Dataset(train_df,
                                 data_path=DATA_PATH,
                                 load_mask=True,
                                 load_semilabels=True,
                                 output_size=img_size,
                                 data_augmentation=True)
    valid_dataset = MURA_Dataset(valid_df,
                                 data_path=DATA_PATH,
                                 load_mask=True,
                                 load_semilabels=True,
                                 output_size=img_size,
                                 data_augmentation=False)
    test_dataset = MURA_Dataset(test_df,
                                data_path=DATA_PATH,
                                load_mask=True,
                                load_semilabels=True,
                                output_size=img_size,
                                data_augmentation=False)
    # print info to logger
    logger.info(f'Train fraction : {train_frac:.0%}')
    logger.info(f'Fraction known normal : {ratio_known_normal:.0%}')
    logger.info(f'Fraction known abnormal : {ratio_tag}')
    logger.info('Split Summary \n' + str(spliter.print_stat(returnTable=True)))
    logger.info('Online preprocessing pipeline : \n' +
                str(train_dataset.transform) + '\n')

    ################################ Set Up ####################################
    # Set seed (-1 disables seeding)
    if seed != -1:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # seeds every visible GPU, not only the current one
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        logger.info(f'Set seed to {seed}')

    # set number of thread
    if n_thread > 0:
        torch.set_num_threads(n_thread)

    # print info in logger
    logger.info(f'Device : {device}')
    logger.info(f'Number of thread : {n_thread}')
    logger.info(
        f'Number of dataloader worker for Joint DeepSAD : {n_jobs_dataloader}'
        + '\n')

    ######################### Networks Initialization ##########################
    net = AE_SVDD_Hybrid(pretrain_ResNetEnc=ae_pretrain,
                         output_channels=ae_out_size[0],
                         return_svdd_embed=True)
    net = net.to(device)

    # add info to logger
    logger.info(f'Network : {net.__class__.__name__}')
    logger.info(f'Autoencoder pretrained on ImageNet : {ae_pretrain}')
    logger.info(f'DeepSAD eta : {eta}')
    logger.info('Network architecture: \n' +
                summary_string(net, (1, img_size, img_size),
                               device=str(device),
                               batch_size=batch_size) + '\n')

    # initialization of the Model
    jointDeepSAD = JointDeepSAD(net, eta=eta)

    if model_path_to_load:
        jointDeepSAD.load_model(model_path_to_load, map_location=device)
        logger.info(f'Model Loaded from {model_path_to_load}' + '\n')

    ################################ Training ##################################
    # add parameter info
    logger.info(f'Joint DeepSAD number of epoch : {n_epoch}')
    logger.info(
        f'Joint DeepSAD number of pretraining epoch: {n_epoch_pretrain}')
    logger.info(f'Joint DeepSAD learning rate : {lr}')
    logger.info(f'Joint DeepSAD learning rate milestone : {lr_milestone}')
    logger.info(f'Joint DeepSAD weight_decay : {weight_decay}')
    logger.info('Joint DeepSAD optimizer : Adam')
    logger.info(f'Joint DeepSAD batch_size {batch_size}')
    logger.info(
        f'Joint DeepSAD number of dataloader worker : {n_jobs_dataloader}')
    logger.info(
        f'Joint DeepSAD criterion weighting : {criterion_weight[0]} Reconstruction loss + {criterion_weight[1]} SVDD embdedding loss'
        + '\n')

    # train DeepSAD (no validation during training)
    jointDeepSAD.train(train_dataset,
                       lr=lr,
                       n_epoch=n_epoch,
                       n_epoch_pretrain=n_epoch_pretrain,
                       lr_milestone=lr_milestone,
                       batch_size=batch_size,
                       weight_decay=weight_decay,
                       device=device,
                       n_jobs_dataloader=n_jobs_dataloader,
                       print_batch_progress=print_batch_progress,
                       criterion_weight=criterion_weight,
                       valid_dataset=None)

    # validate DeepSAD
    jointDeepSAD.validate(valid_dataset,
                          device=device,
                          n_jobs_dataloader=n_jobs_dataloader,
                          print_batch_progress=print_batch_progress,
                          criterion_weight=criterion_weight)

    # test DeepSAD
    jointDeepSAD.test(test_dataset,
                      device=device,
                      n_jobs_dataloader=n_jobs_dataloader,
                      print_batch_progress=print_batch_progress,
                      criterion_weight=criterion_weight)

    # save results (path built once so the saved and the logged path agree)
    results_path = (OUTPUT_PATH +
                    f'results/JointDeepSAD_results_{ratio_tag}.json')
    jointDeepSAD.save_results(results_path)
    logger.info('Test results saved at ' + results_path)

    # save model
    model_path = OUTPUT_PATH + f'model/JointDeepSAD_model_{ratio_tag}.pt'
    jointDeepSAD.save_model(model_path)
    logger.info('Model saved at ' + model_path)
def main(seed_i):
    """
    Train jointly the AutoEncoder and the DeepSAD model following Lukas Ruff et
    al. (2019) work adapted to the MURA dataset (preprocessing inspired from the
    work of Davletshina et al. (2020)). The network structure is a ResNet18
    AutoEncoder until the Adaptative average pooling layer. The AE embdedding is
    thus (512, 16, 16). This embdedding is further processed through 3 convolutional
    layers (specific to the SVDD embdedding generation) to provide the SVDD
    embdedding of 512. The network is trained with two loss functions: a masked MSE
    loss for the reconstruction and the DeepSAD loss on the embedding. The two
    losses are scaled to be comparable by perfoming one forward pass prior the
    training. The Encoder is not initialized with weights trained on ImageNet.
    The AE masked reconstruction loss is not computed for known abnormal sample
    so that the AE learn to reconstruct normal samples only. The network input is
    masked with the mask : only the body part is kept and the background is set
    to zero. The AE is pretrained over 5 epochs in order to improve the initialization
    of the hypersphere center (we hypothetize that with a pretrained AE the
    hypersphere center estimation will be more meaningful). Note that the 'affine'
    parameters of BatchNorm2d layers has been changed to False in this implementation.
    """
    # The docstring above is also written to the log as the experiment
    # summary, so its text is part of the run's output.
    # seed_i : index into the module-level `seeds` list; seed_i + 1 is the
    #          replicate number used in the log / result / model file names.
    run_id = seed_i + 1

    # ---- Logging: console handler from basicConfig + one file per run ------
    logging.basicConfig(level=logging.INFO)
    root_logger = logging.getLogger()
    try:
        # a file handler from a previous run sits at index 1: close, detach
        stale_handler = root_logger.handlers[1]
        stale_handler.stream.close()
        root_logger.removeHandler(stale_handler)
    except IndexError:
        pass  # no handler at index 1 yet
    root_logger.setLevel(logging.INFO)
    log_file = OUTPUT_PATH + f'logs/log_{run_id}.txt'
    run_handler = logging.FileHandler(log_file)
    run_handler.setLevel(logging.INFO)
    run_handler.setFormatter(
        logging.Formatter('%(asctime)s | %(levelname)s | %(message)s'))
    root_logger.addHandler(run_handler)

    for header_line in ('Brief summary of experiment : \n' + main.__doc__,
                        f'Log file : {log_file}',
                        f'Data path : {DATA_PATH}',
                        f'Outputs path : {OUTPUT_PATH}\n'):
        root_logger.info(header_line)

    # ---- Data: metadata table, split, datasets -----------------------------
    info = pd.read_csv(DATA_INFO_PATH)
    info = info.drop(columns=info.columns[0])
    info = info[info['low_contrast'] == 0]  # drop all-black images

    splitter = MURA_TrainValidTestSplitter(
        info,
        train_frac=train_frac,
        ratio_known_normal=ratio_known_normal,
        ratio_known_abnormal=ratio_known_abnormal,
        random_state=42)
    splitter.split_data(verbose=False)
    subsets = {part: splitter.get_subset(part)
               for part in ('train', 'valid', 'test')}
    # only the training set receives online data augmentation
    datasets = {
        part: MURA_Dataset(frame,
                           data_path=DATA_PATH,
                           load_mask=True,
                           load_semilabels=True,
                           output_size=img_size,
                           data_augmentation=part == 'train')
        for part, frame in subsets.items()
    }

    root_logger.info(f'Train fraction : {train_frac:.0%}')
    root_logger.info(f'Fraction knonw normal : {ratio_known_normal:.0%}')
    root_logger.info(f'Fraction known abnormal : {ratio_known_abnormal:.0%}')
    split_table = splitter.print_stat(returnTable=True)
    root_logger.info('Split Summary \n' + str(split_table))
    root_logger.info('Online preprocessing pipeline : \n'
                     + str(datasets['train'].transform) + '\n')

    # ---- Reproducibility and threading --------------------------------------
    seed = seeds[seed_i]
    if seed != -1:
        # Python, NumPy, PyTorch CPU and current-CUDA-device generators
        for seed_rng in (random.seed, np.random.seed, torch.manual_seed,
                         torch.cuda.manual_seed):
            seed_rng(seed)
        torch.backends.cudnn.deterministic = True
        root_logger.info(f'Set seed {run_id:02}/{n_seeds:02} to {seed}')

    if n_thread > 0:
        torch.set_num_threads(n_thread)

    root_logger.info(f'Device : {device}')
    root_logger.info(f'Number of thread : {n_thread}')
    root_logger.info(
        f'Number of dataloader worker for Joint DeepSAD : {n_jobs_dataloader}\n')

    # ---- Network and model ---------------------------------------------------
    net = AE_SVDD_Hybrid(pretrain_ResNetEnc=ae_pretrain,
                         output_channels=ae_out_size[0],
                         return_svdd_embed=True).to(device)
    root_logger.info(f'Network : {type(net).__name__}')
    root_logger.info(f'Autoencoder pretrained on ImageNet : {ae_pretrain}')
    root_logger.info(f'DeepSAD eta : {eta}')
    architecture = summary_string(net, (1, img_size, img_size),
                                  device=str(device),
                                  batch_size=batch_size)
    root_logger.info('Network architecture: \n' + architecture + '\n')

    model = JointDeepSAD(net, eta=eta)
    if model_path_to_load:
        model.load_model(model_path_to_load, map_location=device)
        root_logger.info(f'Model Loaded from {model_path_to_load}\n')

    # ---- Training hyper-parameters ------------------------------------------
    for detail in (f'number of epoch : {n_epoch}',
                   f'number of pretraining epoch: {n_epoch_pretrain}',
                   f'learning rate : {lr}',
                   f'learning rate milestone : {lr_milestone}',
                   f'weight_decay : {weight_decay}',
                   'optimizer : Adam',
                   f'batch_size {batch_size}',
                   f'number of dataloader worker : {n_jobs_dataloader}'):
        root_logger.info('Joint DeepSAD ' + detail)
    root_logger.info(
        f'Joint DeepSAD criterion weighting : {criterion_weight[0]} Reconstruction loss + {criterion_weight[1]} SVDD embdedding loss\n')

    # ---- Train, validate, test -----------------------------------------------
    shared = dict(device=device,
                  n_jobs_dataloader=n_jobs_dataloader,
                  print_batch_progress=print_batch_progress,
                  criterion_weight=criterion_weight)
    model.train(datasets['train'],
                lr=lr,
                n_epoch=n_epoch,
                n_epoch_pretrain=n_epoch_pretrain,
                lr_milestone=lr_milestone,
                batch_size=batch_size,
                weight_decay=weight_decay,
                **shared)
    model.validate(datasets['valid'], **shared)
    model.test(datasets['test'], **shared)

    # ---- Persist results and weights -----------------------------------------
    results_file = OUTPUT_PATH + f'results/JointDeepSAD_results_{run_id}.json'
    model.save_results(results_file)
    root_logger.info('Test results saved at ' + results_file + '\n')

    model_file = OUTPUT_PATH + f'model/JointDeepSAD_model_{run_id}.pt'
    model.save_model(model_file)
    root_logger.info('Model saved at ' + model_file)
# Example #3
# 0
def main(config_path):
    """
    Train a DSAD on the MURA dataset using a SimCLR pretraining.

    For each seed listed in the config, a SimCLR representation network is
    trained and evaluated, its encoder is handed over through
    `transfer_encoder`, and the DSAD is then trained, validated and tested.
    Models, results and one log file per seed are written under a timestamped
    output folder; the config (including the device used) is saved there at
    the end.

    Args:
        config_path: path of the configuration file read by
            `Config.load_config`.
    """
    # Load config file
    cfg = Config(settings=None)
    cfg.load_config(config_path)

    # Get path to output
    OUTPUT_PATH = cfg.settings['PATH']['OUTPUT'] + cfg.settings[
        'Experiment_Name'] + datetime.today().strftime('%Y_%m_%d_%Hh%M') + '/'
    # make output dirs. exist_ok already covers existing folders, and 'model/'
    # (singular) is the folder every save call below writes to.
    for sub_dir in ('model/', 'results/', 'logs/'):
        os.makedirs(OUTPUT_PATH + sub_dir, exist_ok=True)

    # check if GPU available. Resolved once, before the loop: it does not
    # depend on the seed and is read again when the config is saved.
    cfg.settings['device'] = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    device = cfg.settings['device']
    seeds = cfg.settings['seeds']

    # Root logger, defined outside the loop so it also exists when `seeds` is
    # empty. basicConfig installs a console handler only when the root logger
    # has none; each seed then gets its own file handler (index 1).
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s')

    for seed_i, seed in enumerate(seeds):
        run_id = seed_i + 1
        # frequently used config sections
        split_cfg = cfg.settings['Split']
        clr_cfg = cfg.settings['SimCLR']
        ad_cfg = cfg.settings['DSAD']
        data_path = cfg.settings['PATH']['DATA']
        img_size = split_cfg['img_size']
        print_batch_progress = cfg.settings['print_batch_progress']

        ############################### Set Up #################################
        # Replace the previous seed's file handler (if any) with this one.
        if len(logger.handlers) > 1:
            previous_handler = logger.handlers[1]
            logger.removeHandler(previous_handler)
            # close() flushes and releases the file without closing a shared
            # stream (e.g. stdout) the handler might write to
            previous_handler.close()
        logger.setLevel(logging.INFO)
        log_file = OUTPUT_PATH + 'logs/' + f'log_{run_id}.txt'
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

        # print path
        logger.info(f"Log file : {log_file}")
        logger.info(f"Data path : {data_path}")
        logger.info(f"Outputs path : {OUTPUT_PATH}" + "\n")

        # Set seed (-1 disables seeding)
        if seed != -1:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)  # every visible GPU
            torch.backends.cudnn.deterministic = True
            logger.info(f"Set seed {run_id:02}/{len(seeds):02} to {seed}")

        # set number of thread
        if cfg.settings['n_thread'] > 0:
            torch.set_num_threads(cfg.settings['n_thread'])

        # Print technical info in logger
        logger.info(f"Device : {device}")
        logger.info(f"Number of thread : {cfg.settings['n_thread']}")

        ############################### Split Data #############################
        # Load data informations (first column presumably a saved index --
        # TODO confirm) and remove low contrast images (all black)
        df_info = pd.read_csv(cfg.settings['PATH']['DATA_INFO'])
        df_info = df_info.drop(df_info.columns[0], axis=1)
        df_info = df_info[df_info.low_contrast == 0]

        # Train Validation Test Split
        spliter = MURA_TrainValidTestSplitter(
            df_info,
            train_frac=split_cfg['train_frac'],
            ratio_known_normal=split_cfg['known_normal'],
            ratio_known_abnormal=split_cfg['known_abnormal'],
            random_state=42)
        spliter.split_data(verbose=False)
        train_df = spliter.get_subset('train')
        valid_df = spliter.get_subset('valid')
        test_df = spliter.get_subset('test')

        # print info to logger
        for key, value in split_cfg.items():
            logger.info(f"Split param {key} : {value}")
        logger.info("Split Summary \n" +
                    str(spliter.print_stat(returnTable=True)))

        ############################# Build Model  #############################
        # make networks
        net_CLR = SimCLR_net(MLP_Neurons_layer=clr_cfg['MLP_head'])
        net_CLR = net_CLR.to(device)
        net_DSAD = SimCLR_net(MLP_Neurons_layer=ad_cfg['MLP_head'])
        net_DSAD = net_DSAD.to(device)
        # print network architectures
        input_size = (1, img_size, img_size)
        net_architecture = summary_string(net_CLR, input_size,
                                          batch_size=clr_cfg['batch_size'],
                                          device=str(device))
        logger.info("SimCLR net architecture: \n" + net_architecture + '\n')
        net_architecture = summary_string(net_DSAD, input_size,
                                          batch_size=ad_cfg['batch_size'],
                                          device=str(device))
        logger.info("DSAD net architecture: \n" + net_architecture + '\n')

        # make model
        clr_DSAD = SimCLR_DSAD(net_CLR,
                               net_DSAD,
                               tau=clr_cfg['tau'],
                               eta=ad_cfg['eta'])

        ############################# Train SimCLR #############################
        # make datasets
        train_dataset_CLR = MURADataset_SimCLR(train_df,
                                               data_path=data_path,
                                               output_size=img_size,
                                               mask_img=True)
        valid_dataset_CLR = MURADataset_SimCLR(valid_df,
                                               data_path=data_path,
                                               output_size=img_size,
                                               mask_img=True)
        test_dataset_CLR = MURADataset_SimCLR(test_df,
                                              data_path=data_path,
                                              output_size=img_size,
                                              mask_img=True)

        logger.info("SimCLR Online preprocessing pipeline : \n" +
                    str(train_dataset_CLR.transform) + "\n")

        # Load model if required
        if clr_cfg['model_path_to_load']:
            clr_DSAD.load_repr_net(clr_cfg['model_path_to_load'],
                                   map_location=device)
            logger.info(
                f"SimCLR Model Loaded from {clr_cfg['model_path_to_load']}"
                + "\n")

        # print Train parameters
        for key, value in clr_cfg.items():
            logger.info(f"SimCLR {key} : {value}")

        # Train SimCLR
        clr_DSAD.train_SimCLR(train_dataset_CLR,
                              valid_dataset=None,
                              n_epoch=clr_cfg['n_epoch'],
                              batch_size=clr_cfg['batch_size'],
                              lr=clr_cfg['lr'],
                              weight_decay=clr_cfg['weight_decay'],
                              lr_milestones=clr_cfg['lr_milestone'],
                              n_job_dataloader=clr_cfg['num_worker'],
                              device=device,
                              print_batch_progress=print_batch_progress)

        # Evaluate SimCLR to get embeddings
        for subset_name, dataset in (('valid', valid_dataset_CLR),
                                     ('test', test_dataset_CLR)):
            clr_DSAD.evaluate_SimCLR(dataset,
                                     batch_size=clr_cfg['batch_size'],
                                     n_job_dataloader=clr_cfg['num_worker'],
                                     device=device,
                                     print_batch_progress=print_batch_progress,
                                     set=subset_name)

        # save repr net
        clr_path = OUTPUT_PATH + f'model/SimCLR_net_{run_id}.pt'
        clr_DSAD.save_repr_net(clr_path)
        logger.info("SimCLR model saved at " + clr_path)

        # save Results. This intermediate file is overwritten with the same
        # path once the DSAD has been evaluated (end of the loop body).
        results_path = OUTPUT_PATH + f'results/results_{run_id}.json'
        clr_DSAD.save_results(results_path)
        logger.info("Results saved at " + results_path)

        ######################## Transfer Encoder Weight #######################

        clr_DSAD.transfer_encoder()

        ############################## Train DSAD ##############################
        # make dataset
        train_dataset_AD = MURA_Dataset(train_df,
                                        data_path=data_path,
                                        load_mask=True,
                                        load_semilabels=True,
                                        output_size=img_size)
        valid_dataset_AD = MURA_Dataset(valid_df,
                                        data_path=data_path,
                                        load_mask=True,
                                        load_semilabels=True,
                                        output_size=img_size)
        test_dataset_AD = MURA_Dataset(test_df,
                                       data_path=data_path,
                                       load_mask=True,
                                       load_semilabels=True,
                                       output_size=img_size)

        logger.info("DSAD Online preprocessing pipeline : \n" +
                    str(train_dataset_AD.transform) + "\n")

        # Load model if required
        if ad_cfg['model_path_to_load']:
            clr_DSAD.load_AD(ad_cfg['model_path_to_load'],
                             map_location=device)
            logger.info(
                f"DSAD Model Loaded from {ad_cfg['model_path_to_load']} \n")

        # print Train parameters
        for key, value in ad_cfg.items():
            logger.info(f"DSAD {key} : {value}")

        # Train DSAD
        clr_DSAD.train_AD(train_dataset_AD,
                          valid_dataset=valid_dataset_AD,
                          n_epoch=ad_cfg['n_epoch'],
                          batch_size=ad_cfg['batch_size'],
                          lr=ad_cfg['lr'],
                          weight_decay=ad_cfg['weight_decay'],
                          lr_milestone=ad_cfg['lr_milestone'],
                          n_job_dataloader=ad_cfg['num_worker'],
                          device=device,
                          print_batch_progress=print_batch_progress)
        for title, subset_name, dataset in (
                ('--- Validation', 'valid', valid_dataset_AD),
                ('--- Test', 'test', test_dataset_AD)):
            logger.info(title)
            clr_DSAD.evaluate_AD(dataset,
                                 batch_size=ad_cfg['batch_size'],
                                 n_job_dataloader=ad_cfg['num_worker'],
                                 device=device,
                                 print_batch_progress=print_batch_progress,
                                 set=subset_name)

        # save DSAD
        dsad_path = OUTPUT_PATH + f'model/DSAD_{run_id}.pt'
        clr_DSAD.save_AD(dsad_path)
        logger.info("model saved at " + dsad_path)

        ########################## Save Results ################################
        clr_DSAD.save_results(results_path)
        logger.info("Results saved at " + results_path)

    # save config file (torch.device converted to str, presumably so that it
    # can be serialized)
    cfg.settings['device'] = str(cfg.settings['device'])
    cfg.save_config(OUTPUT_PATH + 'config.json')
    logger.info("Config saved at " + OUTPUT_PATH + "config.json")
# Example #4
# 0
def main(seed_i):
    """
    Implementation of the unsupervised DROCC model proposed by Goyal et al (2020).
    The model uses a binary classifier with a ResNet18 backbone. The output is a
    logit that serves as anomaly score. The loss is computed as Binary Cross
    Entropy loss with logit. The training consist of few epoch trained only with
    the normal samples. Then each epoch starts with the generation of adversarial
    examples. The adversarial search is performed only on normal samples as we
    want the network to learn the manifold of normal samples. It uses a slightly
    modifided projection gradient descent algorithm. Then the samples and
    adversarial samples are passed through the network similarly to a standard
    classification task. Note that the input samples are masked with the mask
    generated in the preprocesing steps.
    """
    # NOTE: the docstring above is logged verbatim as the experiment summary.
    # seed_i : index into the module-level `seeds` list; seed_i + 1 numbers
    #          this replicate's log, result and model files.
    run_number = seed_i + 1

    # --- logging: console handler (basicConfig) + one file handler per run --
    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger()
    try:
        old_handler = log.handlers[1]
        old_handler.stream.close()
        log.removeHandler(old_handler)
    except IndexError:
        pass
    log.setLevel(logging.INFO)
    log_file = OUTPUT_PATH + f'logs/log_{run_number}.txt'
    handler = logging.FileHandler(log_file)
    handler.setLevel(logging.INFO)
    handler.setFormatter(
        logging.Formatter('%(asctime)s | %(levelname)s | %(message)s'))
    log.addHandler(handler)

    log.info('Brief summary of experiment : \n' + main.__doc__)
    log.info(f'Log file : {log_file}')
    log.info(f'Data path : {DATA_PATH}')
    log.info(f'Outputs path : {OUTPUT_PATH}\n')

    # --- data: metadata (all-black images removed), split, datasets ---------
    raw_info = pd.read_csv(DATA_INFO_PATH)
    info = raw_info.drop(columns=raw_info.columns[0])
    info = info[info['low_contrast'] == 0]

    splitter = MURA_TrainValidTestSplitter(
        info,
        train_frac=train_frac,
        ratio_known_normal=ratio_known_normal,
        ratio_known_abnormal=ratio_known_abnormal,
        random_state=42)
    splitter.split_data(verbose=False)
    frames = [splitter.get_subset(part) for part in ('train', 'valid', 'test')]
    train_dataset, valid_dataset, test_dataset = (
        MURA_Dataset(frame,
                     data_path=DATA_PATH,
                     load_mask=True,
                     load_semilabels=True,
                     output_size=img_size)
        for frame in frames)

    log.info(f'Train fraction : {train_frac:.0%}')
    log.info(f'Fraction knonw normal : {ratio_known_normal:.0%}')
    log.info(f'Fraction known abnormal : {ratio_known_abnormal:.0%}')
    log.info('Split Summary \n' + str(splitter.print_stat(returnTable=True)))
    log.info('Online preprocessing pipeline : \n' +
             str(train_dataset.transform) + '\n')

    # --- reproducibility & threads -------------------------------------------
    seed = seeds[seed_i]
    if seed != -1:
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                        torch.cuda.manual_seed):
            seed_fn(seed)
        torch.backends.cudnn.deterministic = True
        log.info(f'Set seed {run_number:02}/{n_seeds:02} to {seed}')

    if n_thread > 0:
        torch.set_num_threads(n_thread)

    log.info(f'Device : {device}')
    log.info(f'Number of thread : {n_thread}')
    log.info(
        f'Number of dataloader worker for {Experiment_Name} : {n_jobs_dataloader}\n')

    # --- network & model ------------------------------------------------------
    net = ResNet18_binary(pretrained=pretrain).to(device)
    log.info(f'Network : {type(net).__name__}')
    log.info(f'ResNet18 pretrained on ImageNet : {pretrain}')
    architecture = summary_string(net, (1, img_size, img_size),
                                  device=str(device),
                                  batch_size=batch_size)
    log.info('Network architecture: \n' + architecture + '\n')

    drocc = DROCC(net, r)
    if model_path_to_load:
        drocc.load_model(model_path_to_load, map_location=device)
        log.info(f'Model Loaded from {model_path_to_load}\n')

    # --- hyper-parameters (each line prefixed with the experiment name) -----
    hyperparameter_lines = (
        f'radius r : {r}',
        f'gamma : {gamma}',
        f'adversarial importance mu : {mu}',
        f'number of initial epoch : {n_epoch_init}',
        f'number of epoch : {n_epoch}',
        f'number of adversarial search epoch: {n_epoch_adv}',
        f'learning rate : {lr}',
        f'adversarial search learning rate : {lr_adv}',
        f'learning rate milestone : {lr_milestone}',
        f'weight_decay : {weight_decay}',
        'optimizer : Adam',
        f'batch_size {batch_size}',
        f'number of dataloader worker : {n_jobs_dataloader}',
    )
    for line in hyperparameter_lines:
        log.info(f'{Experiment_Name} {line}')

    # --- train, validate, test -------------------------------------------------
    loader_kwargs = {'device': device,
                     'n_jobs_dataloader': n_jobs_dataloader,
                     'print_batch_progress': print_batch_progress}
    drocc.train(train_dataset,
                gamma=gamma,
                mu=mu,
                lr=lr,
                lr_adv=lr_adv,
                lr_milestone=lr_milestone,
                weight_decay=weight_decay,
                n_epoch=n_epoch,
                n_epoch_init=n_epoch_init,
                n_epoch_adv=n_epoch_adv,
                batch_size=batch_size,
                LFOC=use_LFOC,
                **loader_kwargs)
    drocc.validate(valid_dataset, **loader_kwargs)
    drocc.test(test_dataset, **loader_kwargs)

    # --- persist results and weights ------------------------------------------
    results_file = (OUTPUT_PATH +
                    f'results/{Experiment_Name}_results_{run_number}.json')
    drocc.save_results(results_file)
    log.info('Test results saved at ' + results_file + '\n')

    model_file = OUTPUT_PATH + f'model/{Experiment_Name}_model_{run_number}.pt'
    drocc.save_model(model_file)
    log.info('Model saved at ' + model_file)
def main(seed_i):
    """
    Extension of the deep multi-sphere SVDD to semi-supervised settings inpired
    from the DSAD of Ruff et al. (2020).

    MSAD loss changed to use the sqrt(dist) for normal samples and 1/(dist^2)
    for abnormal samples. The network is pretrained for longer (30 epochs) to get
    a better KMeans initialization. Anomaly score is dist - R
    """
    # NOTE: the docstring above is logged verbatim through `main.__doc__`, so
    # its text is part of the run output and is intentionally left unchanged.
    #
    # Run one replicate of the joint DMSAD experiment.
    #
    # Args:
    #     seed_i (int): 0-based replicate index. Selects the seed
    #         `seeds[seed_i]` and appears 1-based in the log, results and
    #         model file names.
    #
    # Returns:
    #     None. Outputs are written under OUTPUT_PATH:
    #     logs/log_{seed_i+1}.txt,
    #     results/{Experiment_Name}_results_{seed_i+1}.json and
    #     model/{Experiment_Name}_model_{seed_i+1}.pt.
    #
    # Paths, split ratios, hyper-parameters and the device are read from
    # module-level globals.

    # initialize logger
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    # Drop the file handler left over from a previous replicate (index 0 is
    # expected to be the console handler added by basicConfig). Removing it
    # and then calling Handler.close() releases it cleanly; closing `.stream`
    # directly left the handler registered with the logging module and, were
    # it a console StreamHandler, would have closed sys.stderr.
    try:
        stale_handler = logger.handlers[1]
    except IndexError:
        pass
    else:
        logger.removeHandler(stale_handler)
        stale_handler.close()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s')
    log_file = OUTPUT_PATH + 'logs/' + f'log_{seed_i+1}.txt'
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # print path and main docstring with experiment summary
    logger.info('Brief summary of experiment : \n' + main.__doc__)
    logger.info(f'Log file : {log_file}')
    logger.info(f'Data path : {DATA_PATH}')
    logger.info(f'Outputs path : {OUTPUT_PATH}' + '\n')

    ############################## Make datasets ###############################
    # load data_info
    df_info = pd.read_csv(DATA_INFO_PATH)
    # the first CSV column is dropped (presumably a saved index column)
    df_info = df_info.drop(df_info.columns[0], axis=1)
    # remove low contrast images (all black)
    df_info = df_info[df_info.low_contrast == 0]

    # Train Validation Test Split. The split uses its own fixed random_state,
    # so it is identical across replicates (the replicate seed is set later).
    spliter = MURA_TrainValidTestSplitter(
        df_info,
        train_frac=train_frac,
        ratio_known_normal=ratio_known_normal,
        ratio_known_abnormal=ratio_known_abnormal,
        random_state=42)
    spliter.split_data(verbose=False)
    train_df = spliter.get_subset('train')
    valid_df = spliter.get_subset('valid')
    test_df = spliter.get_subset('test')
    # make datasets
    train_dataset = MURA_Dataset(train_df,
                                 data_path=DATA_PATH,
                                 load_mask=True,
                                 load_semilabels=True,
                                 output_size=img_size)
    valid_dataset = MURA_Dataset(valid_df,
                                 data_path=DATA_PATH,
                                 load_mask=True,
                                 load_semilabels=True,
                                 output_size=img_size)
    test_dataset = MURA_Dataset(test_df,
                                data_path=DATA_PATH,
                                load_mask=True,
                                load_semilabels=True,
                                output_size=img_size)
    # print info to logger
    logger.info(f'Train fraction : {train_frac:.0%}')
    logger.info(f'Fraction knonw normal : {ratio_known_normal:.0%}')
    logger.info(f'Fraction known abnormal : {ratio_known_abnormal:.0%}')
    logger.info('Split Summary \n' + str(spliter.print_stat(returnTable=True)))
    logger.info('Online preprocessing pipeline : \n' +
                str(train_dataset.transform) + '\n')

    ################################ Set Up ####################################
    # Set seed (a seed of -1 disables seeding for this replicate)
    seed = seeds[seed_i]
    if seed != -1:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # deterministic cuDNN kernels and no autotuning for reproducibility
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        logger.info(f'Set seed {seed_i+1:02}/{n_seeds:02} to {seed}')

    # set number of thread (non-positive keeps the torch default)
    if n_thread > 0:
        torch.set_num_threads(n_thread)

    # print info in logger
    logger.info(f'Device : {device}')
    logger.info(f'Number of thread : {n_thread}')
    logger.info(
        f'Number of dataloader worker for {Experiment_Name} : {n_jobs_dataloader}'
        + '\n')

    ######################### Networks Initialization ##########################
    net = AE_SVDD_Hybrid(pretrain_ResNetEnc=ae_pretrain,
                         output_channels=ae_out_size[0],
                         return_svdd_embed=True)
    net = net.to(device)

    # add info to logger
    logger.info(f'Network : {net.__class__.__name__}')
    logger.info(f'Autoencoder pretrained on ImageNet : {ae_pretrain}')
    logger.info('Network architecture: \n' +
                summary_string(net, (1, img_size, img_size),
                               device=str(device),
                               batch_size=batch_size) + '\n')

    # initialization of the Model
    jointDMSAD = joint_DMSAD(net, eta=eta, gamma=gamma)

    # optionally resume from previously saved weights
    if model_path_to_load:
        jointDMSAD.load_model(model_path_to_load, map_location=device)
        logger.info(f'Model Loaded from {model_path_to_load}' + '\n')

    ################################ Training ##################################
    # add parameter info
    logger.info(f'{Experiment_Name} eta : {eta}')
    logger.info(f'{Experiment_Name} gamma : {gamma}')
    logger.info(f'{Experiment_Name} number of epoch : {n_epoch}')
    logger.info(
        f'{Experiment_Name} number of pretraining epoch: {n_epoch_pretrain}')
    logger.info(
        f'{Experiment_Name} number of initial hypersphere: {n_sphere_init}')
    logger.info(f'{Experiment_Name} learning rate : {lr}')
    logger.info(f'{Experiment_Name} learning rate milestones : {lr_milestone}')
    logger.info(f'{Experiment_Name} weight_decay : {weight_decay}')
    logger.info(f'{Experiment_Name} optimizer : Adam')
    logger.info(f'{Experiment_Name} batch_size {batch_size}')
    logger.info(
        f'{Experiment_Name} number of dataloader worker : {n_jobs_dataloader}')
    logger.info(
        f'{Experiment_Name} criterion weighting : {criterion_weight[0]} Reconstruction loss + {criterion_weight[1]} MSAD embdedding loss'
    )
    logger.info(
        f'{Experiment_Name} reset scaling epoch : {reset_scaling_epoch}')

    # train DMSAD
    jointDMSAD.train(train_dataset,
                     valid_dataset=valid_dataset,
                     n_sphere_init=n_sphere_init,
                     n_epoch=n_epoch,
                     n_epoch_pretrain=n_epoch_pretrain,
                     lr=lr,
                     weight_decay=weight_decay,
                     lr_milestone=lr_milestone,
                     criterion_weight=criterion_weight,
                     reset_scaling_epoch=reset_scaling_epoch,
                     batch_size=batch_size,
                     n_jobs_dataloader=n_jobs_dataloader,
                     device=device,
                     print_batch_progress=print_batch_progress)

    # validate DMSAD
    jointDMSAD.validate(valid_dataset,
                        batch_size=batch_size,
                        n_jobs_dataloader=n_jobs_dataloader,
                        criterion_weight=criterion_weight,
                        device=device,
                        print_batch_progress=print_batch_progress)

    # test DMSAD
    jointDMSAD.test(test_dataset,
                    batch_size=batch_size,
                    n_jobs_dataloader=n_jobs_dataloader,
                    criterion_weight=criterion_weight,
                    device=device,
                    print_batch_progress=print_batch_progress)

    # save results
    jointDMSAD.save_results(
        OUTPUT_PATH + f'results/{Experiment_Name}_results_{seed_i+1}.json')
    logger.info('Test results saved at ' + OUTPUT_PATH +
                f'results/{Experiment_Name}_results_{seed_i+1}.json' + '\n')

    # save model
    jointDMSAD.save_model(OUTPUT_PATH +
                          f'model/{Experiment_Name}_model_{seed_i+1}.pt')
    logger.info('Model saved at ' + OUTPUT_PATH +
                f'model/{Experiment_Name}_model_{seed_i+1}.pt')
def main(seed_i):
    """
    Train a DeepSAD model following Lukas Ruff et al. (2019) work and code structure
    adapted to the MURA dataset (preprocessing inspired from the work of Davletshina
    et al. (2020)). The DeepSAD network structure is a ResNet18 Encoder. The Encoder
    is pretrained via Autoencoder training. The Autoencoder itself is not initialized
    with weights trained on ImageNet. The best threshold on the scores is defined
    using the validation set as the one maximizing the F1-score. The ROC AUC is
    reported on the test and validation set.
    """
    # NOTE: the docstring above is logged verbatim through `main.__doc__`, so
    # its text is part of the run output and is intentionally left unchanged.
    #
    # Run one replicate of the DeepSAD experiment.
    #
    # Args:
    #     seed_i (int): 0-based replicate index. Selects the seed
    #         `seeds[seed_i]` and appears 1-based in the log, results and
    #         model file names.
    #
    # Returns:
    #     None. Outputs are written under OUTPUT_PATH:
    #     logs/log_{seed_i+1}.txt, results/DeepSAD_results_{seed_i+1}.json and
    #     model/DeepSAD_model_{seed_i+1}.pt.
    #
    # Paths, split ratios, hyper-parameters and the device are read from
    # module-level globals.

    # initialize logger
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    # Drop the file handler left over from a previous replicate (index 0 is
    # expected to be the console handler added by basicConfig). Removing it
    # and then calling Handler.close() releases it cleanly; closing `.stream`
    # directly left the handler registered with the logging module and, were
    # it a console StreamHandler, would have closed sys.stderr.
    try:
        stale_handler = logger.handlers[1]
    except IndexError:
        pass
    else:
        logger.removeHandler(stale_handler)
        stale_handler.close()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s')
    log_file = OUTPUT_PATH + 'logs/' + f'log_{seed_i+1}.txt'
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # print path
    logger.info('Brief summary of experiment : \n' + main.__doc__)
    if Note is not None:
        logger.info(Note + '\n')
    logger.info(f'Log file : {log_file}')
    logger.info(f'Data path : {DATA_PATH}')
    logger.info(f'Outputs path : {OUTPUT_PATH}' + '\n')

    ############################## Make datasets ###############################
    # load data_info
    df_info = pd.read_csv(DATA_INFO_PATH)
    # the first CSV column is dropped (presumably a saved index column)
    df_info = df_info.drop(df_info.columns[0], axis=1)
    # remove low contrast images (all black)
    df_info = df_info[df_info.low_contrast == 0]

    # Train Validation Test Split. The split uses its own fixed random_state,
    # so it is identical across replicates (the replicate seed is set later).
    spliter = MURA_TrainValidTestSplitter(
        df_info,
        train_frac=train_frac,
        ratio_known_normal=ratio_known_normal,
        ratio_known_abnormal=ratio_known_abnormal,
        random_state=42)
    spliter.split_data(verbose=False)
    train_df = spliter.get_subset('train')
    valid_df = spliter.get_subset('valid')
    test_df = spliter.get_subset('test')
    # make datasets
    train_dataset = MURA_Dataset(train_df,
                                 data_path=DATA_PATH,
                                 load_mask=True,
                                 load_semilabels=True,
                                 output_size=img_size)
    valid_dataset = MURA_Dataset(valid_df,
                                 data_path=DATA_PATH,
                                 load_mask=True,
                                 load_semilabels=True,
                                 output_size=img_size)
    test_dataset = MURA_Dataset(test_df,
                                data_path=DATA_PATH,
                                load_mask=True,
                                load_semilabels=True,
                                output_size=img_size)
    # print info to logger
    logger.info(f'Train fraction : {train_frac:.0%}')
    logger.info(f'Fraction knonw normal : {ratio_known_normal:.0%}')
    logger.info(f'Fraction known abnormal : {ratio_known_abnormal:.0%}')
    logger.info('Split Summary \n' + str(spliter.print_stat(returnTable=True)))
    logger.info('Online preprocessing pipeline : \n' +
                str(train_dataset.transform) + '\n')

    ################################ Set Up ####################################
    # Set seed (a seed of -1 disables seeding for this replicate)
    seed = seeds[seed_i]
    if seed != -1:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # deterministic cuDNN kernels and no autotuning, matching the other
        # experiment scripts, so that seeded runs are reproducible
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        logger.info(f'Set seed {seed_i+1:02}/{n_seeds:02} to {seed}')

    # set number of thread (non-positive keeps the torch default)
    if n_thread > 0:
        torch.set_num_threads(n_thread)

    # print info in logger
    logger.info(f'Device : {device}')
    logger.info(f'Number of thread : {n_thread}')
    logger.info(
        f'Number of dataloader worker for DeepSAD : {n_jobs_dataloader}')
    logger.info(
        f'Autoencoder number of dataloader worker : {ae_n_jobs_dataloader}' +
        '\n')

    ######################### Networks Initialization ##########################
    ae_net = AE_ResNet18(embed_dim=embed_dim,
                         pretrain_ResNetEnc=ae_pretrain,
                         output_size=ae_out_size)
    ae_net = ae_net.to(device)
    net = ResNet18_Encoder(embed_dim=embed_dim, pretrained=False)
    net = net.to(device)

    # initialization of the Model
    deepSAD = DeepSAD(net, ae_net=ae_net, eta=eta)
    # add info to logger
    logger.info(f'Autoencoder : {ae_net.__class__.__name__}')
    logger.info(f'Encoder : {net.__class__.__name__}')
    logger.info(f'Embedding dimension : {embed_dim}')
    logger.info(f'Autoencoder pretrained on ImageNet : {ae_pretrain}')
    logger.info(f'DeepSAD eta : {eta}')
    logger.info(
        'Autoencoder architecture: \n' +
        summary_string(ae_net, (1, img_size, img_size), device=str(device)) +
        '\n')

    # optionally resume from previously saved weights (autoencoder included)
    if model_path_to_load:
        deepSAD.load_model(model_path_to_load,
                           load_ae=True,
                           map_location=device)
        logger.info(f'Model Loaded from {model_path_to_load}' + '\n')

    ############################## Pretraining #################################
    logger.info(f'Pretraining DeepSAD via Autoencoder : {pretrain}')
    if pretrain:
        # add parameter info
        logger.info(f'Autoencoder number of epoch : {ae_n_epoch}')
        logger.info(f'Autoencoder learning rate : {ae_lr}')
        logger.info(f'Autoencoder learning rate milestone : {ae_lr_milestone}')
        logger.info(f'Autoencoder weight_decay : {ae_weight_decay}')
        logger.info(f'Autoencoder optimizer : Adam')
        logger.info(f'Autoencoder batch_size {ae_batch_size}' + '\n')
        # train AE
        deepSAD.pretrain(train_dataset,
                         valid_dataset,
                         test_dataset,
                         lr=ae_lr,
                         n_epoch=ae_n_epoch,
                         lr_milestone=ae_lr_milestone,
                         batch_size=ae_batch_size,
                         weight_decay=ae_weight_decay,
                         device=device,
                         n_jobs_dataloader=ae_n_jobs_dataloader,
                         print_batch_progress=print_batch_progress)

    ################################ Training ##################################
    # add parameter info
    logger.info(f'DeepSAD number of epoch : {n_epoch}')
    logger.info(f'DeepSAD learning rate : {lr}')
    logger.info(f'DeepSAD learning rate milestone : {lr_milestone}')
    logger.info(f'DeepSAD weight_decay : {weight_decay}')
    logger.info(f'DeepSAD optimizer : Adam')
    logger.info(f'DeepSAD batch_size {batch_size}')
    logger.info(f'DeepSAD number of dataloader worker : {n_jobs_dataloader}' +
                '\n')

    # train DeepSAD
    deepSAD.train(train_dataset,
                  lr=lr,
                  n_epoch=n_epoch,
                  lr_milestone=lr_milestone,
                  batch_size=batch_size,
                  weight_decay=weight_decay,
                  device=device,
                  n_jobs_dataloader=n_jobs_dataloader,
                  print_batch_progress=print_batch_progress)

    # validate DeepSAD
    deepSAD.validate(valid_dataset,
                     device=device,
                     n_jobs_dataloader=n_jobs_dataloader,
                     print_batch_progress=print_batch_progress)

    # test DeepSAD
    deepSAD.test(test_dataset,
                 device=device,
                 n_jobs_dataloader=n_jobs_dataloader,
                 print_batch_progress=print_batch_progress)

    # save results
    deepSAD.save_results(OUTPUT_PATH +
                         f'results/DeepSAD_results_{seed_i+1}.json')
    logger.info('Test results saved at ' + OUTPUT_PATH +
                f'results/DeepSAD_results_{seed_i+1}.json' + '\n')
    # save model
    deepSAD.save_model(OUTPUT_PATH + f'model/DeepSAD_model_{seed_i+1}.pt')
    logger.info('Model saved at ' + OUTPUT_PATH +
                f'model/DeepSAD_model_{seed_i+1}.pt')
# Example #7
# 0
def main(seed_i):
    """
    Implementation of the unsupervised ARAE model proposed by Salehi et al (2020).
    This unsupervised method apply a projected gradient descent algorithm to find
    a more meaningful lattent space for the autoencoder. The encoder composed of
    a ResNet18 encoder. The decoder is composed of a mirrored ResNet18. The latent
    space has dimension (16,16,512).
    """
    # NOTE: the docstring above is logged verbatim through `main.__doc__`, so
    # its text is part of the run output and is intentionally left unchanged.
    #
    # Run one replicate of the ARAE experiment on the HAND subset of MURA.
    #
    # Args:
    #     seed_i (int): 0-based replicate index. Selects the seed
    #         `seeds[seed_i]` and appears 1-based in the log, results and
    #         model file names.
    #
    # Returns:
    #     None. Outputs are written under OUTPUT_PATH:
    #     logs/log_{seed_i+1}.txt,
    #     results/{Experiment_Name}_results_{seed_i+1}.json and
    #     model/{Experiment_Name}_model_{seed_i+1}.pt.
    #
    # Paths, split ratios, hyper-parameters and the device are read from
    # module-level globals.

    # initialize logger
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    # Drop the file handler left over from a previous replicate (index 0 is
    # expected to be the console handler added by basicConfig). Removing it
    # and then calling Handler.close() releases it cleanly; closing `.stream`
    # directly left the handler registered with the logging module and, were
    # it a console StreamHandler, would have closed sys.stderr.
    try:
        stale_handler = logger.handlers[1]
    except IndexError:
        pass
    else:
        logger.removeHandler(stale_handler)
        stale_handler.close()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s')
    log_file = OUTPUT_PATH + 'logs/' + f'log_{seed_i+1}.txt'
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # print path and main docstring with experiment summary
    logger.info('Brief summary of experiment : \n' + main.__doc__)
    logger.info(f'Log file : {log_file}')
    logger.info(f'Data path : {DATA_PATH}')
    logger.info(f'Outputs path : {OUTPUT_PATH}' + '\n')

    ############################## Make datasets ###############################
    # load data_info
    df_info = pd.read_csv(DATA_INFO_PATH)
    # the first CSV column is dropped (presumably a saved index column)
    df_info = df_info.drop(df_info.columns[0], axis=1)
    # remove low contrast images (all black)
    df_info = df_info[df_info.low_contrast == 0]
    # keep only hands
    df_info = df_info[df_info.body_part == 'HAND']

    # Train Validation Test Split. The split uses its own fixed random_state,
    # so it is identical across replicates (the replicate seed is set later).
    spliter = MURA_TrainValidTestSplitter(
        df_info,
        train_frac=train_frac,
        ratio_known_normal=ratio_known_normal,
        ratio_known_abnormal=ratio_known_abnormal,
        random_state=42)
    spliter.split_data(verbose=False)
    train_df = spliter.get_subset('train')
    valid_df = spliter.get_subset('valid')
    test_df = spliter.get_subset('test')
    # make datasets
    train_dataset = MURA_Dataset(train_df,
                                 data_path=DATA_PATH,
                                 load_mask=True,
                                 load_semilabels=True,
                                 output_size=img_size)
    valid_dataset = MURA_Dataset(valid_df,
                                 data_path=DATA_PATH,
                                 load_mask=True,
                                 load_semilabels=True,
                                 output_size=img_size)
    test_dataset = MURA_Dataset(test_df,
                                data_path=DATA_PATH,
                                load_mask=True,
                                load_semilabels=True,
                                output_size=img_size)
    # print info to logger
    logger.info(f'Train fraction : {train_frac:.0%}')
    logger.info(f'Fraction knonw normal : {ratio_known_normal:.0%}')
    logger.info(f'Fraction known abnormal : {ratio_known_abnormal:.0%}')
    logger.info('Split Summary \n' + str(spliter.print_stat(returnTable=True)))
    logger.info('Online preprocessing pipeline : \n' +
                str(train_dataset.transform) + '\n')

    ################################ Set Up ####################################
    # Set seed (a seed of -1 disables seeding for this replicate)
    seed = seeds[seed_i]
    if seed != -1:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # deterministic cuDNN kernels and no autotuning, matching the other
        # experiment scripts, so that seeded runs are reproducible
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        logger.info(f'Set seed {seed_i+1:02}/{n_seeds:02} to {seed}')

    # set number of thread (non-positive keeps the torch default)
    if n_thread > 0:
        torch.set_num_threads(n_thread)

    # print info in logger
    logger.info(f'Device : {device}')
    logger.info(f'Number of thread : {n_thread}')
    logger.info(
        f'Number of dataloader worker for {Experiment_Name} : {n_jobs_dataloader}'
        + '\n')

    ######################### Networks Initialization ##########################
    net = AE_ResNet18(pretrain_ResNetEnc=pretrain,
                      output_channels=ae_output_size[0])
    net = net.to(device)

    # add info to logger
    logger.info(f'Network : {net.__class__.__name__}')
    logger.info(f'ResNet18 pretrained on ImageNet : {pretrain}')
    logger.info('Network architecture: \n' +
                summary_string(net, (1, img_size, img_size),
                               device=str(device),
                               batch_size=batch_size) + '\n')

    # initialization of the Model
    arae = ARAE(net, gamma, epsilon)

    # optionally resume from previously saved weights
    if model_path_to_load:
        arae.load_model(model_path_to_load, map_location=device)
        logger.info(f'Model Loaded from {model_path_to_load}' + '\n')

    ################################ Training ##################################
    # add parameter info
    logger.info(f'{Experiment_Name} epsilon : {epsilon}')
    logger.info(f'{Experiment_Name} adversarial importance gamma : {gamma}')
    logger.info(f'{Experiment_Name} number of epoch : {n_epoch}')
    logger.info(
        f'{Experiment_Name} number of adversarial search epoch: {n_epoch_adv}')
    logger.info(f'{Experiment_Name} learning rate : {lr}')
    logger.info(
        f'{Experiment_Name} adversarial search learning rate : {lr_adv}')
    # adversarial search is PGD when use_PGD is truthy, FGSM otherwise
    method_adv = 'PGD' if use_PGD else 'FGSM'
    logger.info(f'{Experiment_Name} adversarial search method : {method_adv}')
    logger.info(f'{Experiment_Name} learning rate milestone : {lr_milestone}')
    logger.info(f'{Experiment_Name} weight_decay : {weight_decay}')
    logger.info(f'{Experiment_Name} optimizer : Adam')
    logger.info(f'{Experiment_Name} batch_size {batch_size}')
    logger.info(
        f'{Experiment_Name} number of dataloader worker : {n_jobs_dataloader}')

    # train ARAE
    arae.train(train_dataset,
               lr=lr,
               lr_adv=lr_adv,
               lr_milestone=lr_milestone,
               weight_decay=weight_decay,
               n_epoch=n_epoch,
               n_epoch_adv=n_epoch_adv,
               use_PGD=use_PGD,
               batch_size=batch_size,
               device=device,
               n_jobs_dataloader=n_jobs_dataloader,
               print_batch_progress=print_batch_progress,
               valid_dataset=valid_dataset)

    # validate ARAE
    arae.validate(valid_dataset,
                  device=device,
                  n_jobs_dataloader=n_jobs_dataloader,
                  print_batch_progress=print_batch_progress)

    # test ARAE
    arae.test(test_dataset,
              device=device,
              n_jobs_dataloader=n_jobs_dataloader,
              print_batch_progress=print_batch_progress)

    # save results
    arae.save_results(OUTPUT_PATH +
                      f'results/{Experiment_Name}_results_{seed_i+1}.json')
    logger.info('Test results saved at ' + OUTPUT_PATH +
                f'results/{Experiment_Name}_results_{seed_i+1}.json' + '\n')

    # save model
    arae.save_model(OUTPUT_PATH +
                    f'model/{Experiment_Name}_model_{seed_i+1}.pt')
    logger.info('Model saved at ' + OUTPUT_PATH +
                f'model/{Experiment_Name}_model_{seed_i+1}.pt')