Exemplo n.º 1
0
def infer_ma(hparams):
    """
    Prepare the test subjects and build the inference data loader.

    Prints the run configuration, resolves (or generates) the test CSV, loads
    the network weights, preprocesses the listed subjects in a process pool
    and wraps the folder of preprocessed subjects in a ``DataLoader``.

    Parameters
    ----------
    hparams : object
        Parsed hyper-parameters, read through attribute access. Attributes
        used here: ``model_dir``, ``threads``, ``test_csv``, ``num_channels``,
        ``model``, ``modalities``, ``num_modalities``, ``num_classes``,
        ``base_filters``, ``weights``, ``results_dir``, ``csv_provided``,
        ``test_dir``, ``mode``, ``device`` and ``batch_size``.

    Returns
    -------
    None.

    Notes
    -----
    ``patients``, ``patients_path``, ``n_processes``, ``model_dir``,
    ``preprocessed_output_dir`` and ``temp_output_dir`` are assigned as module
    globals. Presumably ``preprocess_batch_works`` reads them inside the
    worker processes -- TODO confirm against that function.
    """
    global patients, patients_path, n_processes, model_dir, preprocessed_output_dir, temp_output_dir
    model_dir = hparams.model_dir
    training_start_time = time.asctime()
    startstamp = time.time()
    print("\nHostname   :" + str(os.getenv("HOSTNAME")))
    print("\nStart Time :" + str(training_start_time))
    print("\nStart Stamp:" + str(startstamp))
    sys.stdout.flush()
    # Parsing the number of CPU's used
    n_processes = int(hparams.threads)
    print("Number of CPU's used : ", n_processes)

    # PRINT PARSED HPARAMS
    print("\n\n")
    print("Model Dir               :", hparams.model_dir)
    print("Test CSV                :", hparams.test_csv)
    print("Number of channels      :", hparams.num_channels)
    print("Model Name              :", hparams.model)
    print("Modalities              :", hparams.modalities)
    print("Number of classes       :", hparams.num_classes)
    print("Base Filters            :", hparams.base_filters)
    print("Load Weights            :", hparams.weights)

    print("Generating Test csv")
    # Attribute access, like every other use of hparams in this function
    # (the previous hparams["results_dir"] subscript was the odd one out).
    os.makedirs(hparams.results_dir, exist_ok=True)
    if not hparams.csv_provided == "True":
        print("Since CSV were not provided, we are gonna create for you")
        csv_creator_adv.generate_csv(
            hparams.test_dir,
            to_save=hparams.results_dir,
            mode=hparams.mode,
            ftype="test",
            modalities=hparams.modalities,
        )
        test_csv = os.path.join(hparams.results_dir, "test.csv")
    else:
        test_csv = hparams.test_csv

    model = fetch_model(
        hparams.model,
        int(hparams.num_modalities),
        int(hparams.num_classes),
        int(hparams.base_filters),
    )
    # On a CPU run, remap tensors that were saved from a GPU; otherwise keep
    # torch.load's default behaviour.
    map_location = "cpu" if hparams.device == "cpu" else None
    checkpoint = torch.load(str(hparams.weights), map_location=map_location)
    # The checkpoint is indexed by key, as in the other inference entry points.
    model.load_state_dict(checkpoint["model_state_dict"])

    if hparams.device != "cpu":
        model.cuda()
    model.eval()

    test_df = pd.read_csv(test_csv)
    preprocessed_output_dir = os.path.join(hparams.model_dir, "preprocessed")
    os.makedirs(preprocessed_output_dir, exist_ok=True)
    # First CSV column: subject identifiers, second column: image paths.
    patients = test_df.iloc[:, 0].astype(str)
    patients_path = test_df.iloc[:, 1]
    if len(patients) < n_processes:
        print("\n*********** WARNING ***********")
        print(
            "You are processing less number of patients as compared to the\n" +
            "threads provided, which means you are asking for more resources than \n"
            +
            "necessary which is not a great practice. Anyway, we have accounted for that \n"
            +
            "and reduced the number of threads to the maximum number of patients for \n"
            + "better resource management!\n")
        n_processes = len(patients)
    print("*" * 80)
    print("Intializing preprocessing")
    print("*" * 80)
    print("Initiating the CPU workload on %d threads.\n\n" % n_processes)
    print("Currently processing the following patients : ")
    preprocessing_start = time.time()
    pool = Pool(processes=n_processes)
    try:
        # One task per worker index; each worker receives only its own index.
        pool.map(preprocess_batch_works, range(n_processes))
    finally:
        # Release the worker processes even when a task raised.
        pool.close()
        pool.join()
    preprocessing_end = time.time()
    print("\n\n Preprocessing time taken : {} seconds".format(
        preprocessing_end - preprocessing_start))

    # Load the preprocessed patients to the dataloader
    print("*" * 80)
    print("Intializing Deep Neural Network")
    print("*" * 80)
    print("Initiating the GPU workload on CUDA threads.\n\n")
    print("Currently processing the following patients : ")
    preprocessed_data_dir = os.path.join(hparams.model_dir, "preprocessed")
    temp_output_dir = os.path.join(hparams.model_dir, "temp_output")
    os.makedirs(temp_output_dir, exist_ok=True)
    dataset_infer = VolSegDatasetInfer(preprocessed_data_dir)
    infer_loader = DataLoader(
        dataset_infer,
        batch_size=int(hparams.batch_size),
        shuffle=False,
        num_workers=int(hparams.threads),
        pin_memory=False,
    )
Exemplo n.º 2
0
def _strip_nifti_extension(filename):
    """Return ``filename`` without a trailing ``.nii.gz`` or ``.nii`` extension."""
    for extension in (".nii.gz", ".nii"):
        if filename.endswith(extension):
            return filename[:-len(extension)]
    return filename


def infer_multi_4(cfg, device, save_brain, weights):
    """
    Inference using multi modality network

    Parameters
    ----------
    cfg : string
        Location of the config file (``name = value`` lines, ``#`` comments)
    device : int/str
        device to be run on; the string ``"cpu"`` keeps everything on the CPU,
        any other value moves the model and the inputs to CUDA
    save_brain : int
        whether to save brain or not
    weights : string
        Path of the checkpoint to load; it must contain a
        ``model_state_dict`` entry. A ``weights`` entry in the config file
        takes precedence over this argument.

    Returns
    -------
    None.

    """
    cfg = os.path.abspath(cfg)

    if os.path.isfile(cfg):
        params_df = pd.read_csv(
            cfg,
            sep=" = ",
            names=["param_name", "param_value"],
            comment="#",
            skip_blank_lines=True,
            engine="python",
        ).fillna(" ")
    else:
        print("Missing test_params.cfg file? Please give one!")
        # NOTE(review): exits with status 0 although this is an error path;
        # kept because callers may rely on it.
        sys.exit(0)
    params = {}
    # Assigned before the config rows are copied, so a "weights" row in the
    # config file overrides the argument.
    params["weights"] = weights
    for row in range(params_df.shape[0]):
        params[params_df.iloc[row, 0]] = params_df.iloc[row, 1]
    start = time.asctime()
    startstamp = time.time()
    print("\nHostname   :" + str(os.getenv("HOSTNAME")))
    print("\nStart Time :" + str(start))
    print("\nStart Stamp:" + str(startstamp))
    sys.stdout.flush()

    print("Generating Test csv")
    os.makedirs(params["results_dir"], exist_ok=True)
    if not params["csv_provided"] == "True":
        print("Since CSV were not provided, we are gonna create for you")
        csv_creator_adv.generate_csv(
            params["test_dir"],
            to_save=params["results_dir"],
            mode=params["mode"],
            ftype="test",
            modalities=params["modalities"],
        )
        test_csv = os.path.join(params["results_dir"], "test.csv")
    else:
        test_csv = params["test_csv"]

    test_df = pd.read_csv(test_csv)
    nmods = int(params["num_modalities"])

    model = fetch_model(
        params["model"],
        nmods,
        int(params["num_classes"]),
        int(params["base_filters"]),
    )
    if device != "cpu":
        model.cuda()

    # On a CPU run, remap tensors that were saved from a GPU; otherwise keep
    # torch.load's default behaviour.
    map_location = "cpu" if device == "cpu" else None
    checkpoint = torch.load(str(params["weights"]), map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    for patient in tqdm.tqdm(test_df.values):
        # Numeric IDs come back from pandas as numbers; paths need strings.
        patient_id = str(patient[0])
        patient_dir = os.path.join(params["results_dir"], patient_id)
        os.makedirs(patient_dir, exist_ok=True)
        stack = np.zeros([nmods, 128, 128, 128], dtype=np.float32)
        for i in range(nmods):
            # Columns 1..nmods of the CSV hold the modality image paths.
            patient_nib = nib.load(patient[i + 1])
            stack[i] = preprocess_image(patient_nib)
        # Add the batch axis. The stacked tensor is the model input on every
        # device (previously only the CUDA branch used it, so CPU runs fed
        # the last single modality to the network).
        batch = torch.FloatTensor(stack[np.newaxis, ...])
        if device != "cpu":
            batch = batch.cuda()
        with torch.no_grad():
            output = model(batch)
        output = output.cpu().numpy()[0][0]
        to_save = interpolate_image(output, (240, 240, 160))
        to_save = unpad_image(to_save)
        # Binarise the prediction at 0.9.
        to_save[to_save >= 0.9] = 1
        to_save[to_save < 0.9] = 0
        # Fill holes slice by slice along the third axis.
        for i in range(to_save.shape[2]):
            if np.any(to_save[:, :, i]):
                to_save[:, :, i] = binary_fill_holes(to_save[:, :, i])
        to_save = postprocess_prediction(to_save).astype(np.uint8)
        # The affine of the last loaded modality is reused for the mask.
        to_save_mask = nib.Nifti1Image(to_save, patient_nib.affine)
        nib.save(
            to_save_mask,
            os.path.join(patient_dir, patient_id + "_mask.nii.gz"),
        )
    print("Done with running the model.")
    if save_brain:
        print(
            "You chose to save the brain. We are now saving it with the masks."
        )
        for patient in tqdm.tqdm(test_df.values):
            patient_id = str(patient[0])
            patient_dir = os.path.join(params["results_dir"], patient_id)
            mask_nib = nib.load(
                os.path.join(patient_dir, patient_id + "_mask.nii.gz"))
            mask_data = mask_nib.get_fdata().astype(np.int8)
            for i in range(nmods):
                image_path = patient[i + 1]
                # str.strip() removes a *set of characters* from both ends
                # and would also eat trailing letters of the stem, so the
                # extension is removed as a suffix instead.
                image_name = _strip_nifti_extension(
                    os.path.basename(image_path))
                patient_nib = nib.load(image_path)
                image_data = patient_nib.get_fdata()
                # Zero out every voxel outside the predicted mask.
                image_data[mask_data == 0] = 0
                to_save_image = nib.Nifti1Image(image_data, patient_nib.affine)
                nib.save(
                    to_save_image,
                    os.path.join(patient_dir, image_name + "_brain.nii.gz"),
                )

    print("Final output stored in : %s" % (params["results_dir"]))
    print("Thank you for using BrainMaGe")
    print("*" * 60)
Exemplo n.º 3
0
def train_network(cfg, device, weights):
    """
    Receiving a configuration file and a device, the training is pushed through this file

    Parameters
    ----------
    cfg : string
        Path of the training config file (``name = value`` lines, ``#``
        comments).
    device : int/str
        Device specification. When it is not already a string it is
        stringified and stored as ``params['device']``, which is handed to
        the trainer's ``gpus`` argument; a string value leaves the ``device``
        entry of the config file in charge.
    weights : string or None
        Checkpoint passed to the trainer as ``resume_from_checkpoint`` and
        stored as ``params['weights']``.

    Returns
    -------
    None.

    """
    training_start_time = time.asctime()
    startstamp = time.time()
    print("\nHostname   :" + str(os.getenv("HOSTNAME")))
    print("\nStart Time :" + str(training_start_time))
    print("\nStart Stamp:" + str(startstamp))
    sys.stdout.flush()
    print("Checking for this cfg file : ", cfg)
    # READING FROM A CFG FILE and check if file exists or not
    if os.path.isfile(cfg):
        params_df = pd.read_csv(cfg,
                                sep=' = ',
                                names=['param_name', 'param_value'],
                                comment='#',
                                skip_blank_lines=True,
                                engine='python').fillna(' ')
    else:
        print('Missing train_params.cfg file?')
        # NOTE(review): exits with status 0 although this is an error path;
        # kept because callers may rely on it.
        sys.exit(0)

    # Reading in all the parameters
    params = {}
    for row in range(params_df.shape[0]):
        params[params_df.iloc[row, 0]] = params_df.iloc[row, 1]
    print(type(device), device)
    if not isinstance(device, str):
        params['device'] = str(device)
    params['weights'] = weights
    # Create the model directory once, up front (parents included); the CSV
    # generation and the checkpoints below all write into it.
    os.makedirs(str(params['model_dir']), exist_ok=True)

    # PRINT PARSED ARGS
    print("\n\n")
    print("Training Folder Dir     :", params['train_dir'])
    print("Validation Dir          :", params['validation_dir'])
    print("Model Directory         :", params['model_dir'])
    print("Mode                    :", params['mode'])
    print("Number of modalities    :", params['num_modalities'])
    print("Modalities              :", params['modalities'])
    print("Number of classes       :", params['num_classes'])
    print("Max Number of epochs    :", params['max_epochs'])
    print("Batch size              :", params['batch_size'])
    print("Optimizer               :", params['optimizer'])
    print("Learning Rate           :", params['learning_rate'])
    print("Learning Rate Milestones:", params['lr_milestones'])
    print("Patience to decay       :", params['decay_milestones'])
    print("Early Stopping Patience :", params['early_stop_patience'])
    print("Depth Layers            :", params['layers'])
    print("Model used              :", params['model'])
    print("Weights used            :", params['weights'])
    sys.stdout.flush()
    print("Device Given :", device)
    sys.stdout.flush()
    # Only query the current CUDA device when CUDA is usable: those calls
    # raise on CPU-only hosts, while this function otherwise falls back to
    # the CPU device below.
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        print("Current Device : ", torch.cuda.current_device())
    print("Device Count on Machine : ", torch.cuda.device_count())
    if cuda_available:
        print("Device Name : ", torch.cuda.get_device_name())
    print("Cuda Availibility : ", cuda_available)
    device = torch.device('cuda' if cuda_available else 'cpu')
    print('Using device:', device)
    if device.type == 'cuda':
        print('Memory Usage:')
        print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024**3, 1),
              'GB')
        print('Cached: ', round(torch.cuda.memory_cached(0) / 1024**3, 1),
              'GB')
    sys.stdout.flush()

    # We generate CSV for training if not provided
    print("Generating CSV Files")
    # Generating training csv files; when they are provided, 'train_csv' and
    # 'validation_csv' are taken directly from params.
    if not params['csv_provided'] == 'True':
        print('Since CSV were not provided, we are gonna create for you')
        generate_csv(params['train_dir'],
                     to_save=params['model_dir'],
                     mode=params['mode'],
                     ftype='train',
                     modalities=params['modalities'])
        generate_csv(params['validation_dir'],
                     to_save=params['model_dir'],
                     mode=params['mode'],
                     ftype='validation',
                     modalities=params['modalities'])
        params['train_csv'] = os.path.join(params['model_dir'], 'train.csv')
        params['validation_csv'] = os.path.join(params['model_dir'],
                                                'validation.csv')

    log_dir = os.path.join(params['model_dir'])
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(log_dir, 'checkpoints'),
        monitor='val_loss',
        verbose=True,
        save_top_k=1,
        mode='auto',
        save_weights_only=False,
        # str() keeps the concatenation valid even if the value is numeric.
        prefix='deep_resunet_' + str(params['base_filters']))
    stop_callback = EarlyStopping(monitor='val_loss',
                                  mode='auto',
                                  patience=int(params['early_stop_patience']),
                                  verbose=True)
    model = SkullStripper(params)

    trainer = Trainer(
        checkpoint_callback=checkpoint_callback,
        early_stop_callback=stop_callback,
        default_root_dir=params['model_dir'],
        gpus=params['device'],
        fast_dev_run=False,
        max_epochs=int(params['max_epochs']),
        min_epochs=int(params['min_epochs']),
        distributed_backend='ddp',
        weights_summary='full',
        weights_save_path=params['model_dir'],
        amp_level='O1',
        num_sanity_val_steps=5,
        resume_from_checkpoint=weights,
    )
    trainer.fit(model)
Exemplo n.º 4
0
def infer_ma(cfg, device, save_brain, weights):
    """
    Run inference on the first image column of the test CSV.

    Every listed image is resampled to 1mm spacing, padded with
    ``pad_image``, pushed through the network at 128x128x128, and the
    prediction is mapped back to the original grid and saved as
    ``<ID>_mask.nii.gz`` in the results directory. Intermediate volumes are
    kept in ``<results_dir>/Temp``.

    Parameters
    ----------
    cfg : string
        Location of the config file (``name = value`` lines, ``#`` comments)
    device : int/str
        device to be run on; the string ``'cpu'`` keeps everything on the CPU,
        any other value moves the model and the inputs to CUDA
    save_brain : int
        whether to also save the masked input image as ``<ID>_brain.nii.gz``
    weights : string
        Path of the checkpoint to load; it must contain a
        ``model_state_dict`` entry. Overrides any ``weights`` entry of the
        config file.

    Returns
    -------
    None.

    """
    cfg = os.path.abspath(cfg)

    if os.path.isfile(cfg):
        params_df = pd.read_csv(cfg,
                                sep=' = ',
                                names=['param_name', 'param_value'],
                                comment='#',
                                skip_blank_lines=True,
                                engine='python').fillna(' ')
    else:
        print('Missing test_params.cfg file? Please give one!')
        # NOTE(review): exits with status 0 although this is an error path;
        # kept because callers may rely on it.
        sys.exit(0)
    params = {}
    for row in range(params_df.shape[0]):
        params[params_df.iloc[row, 0]] = params_df.iloc[row, 1]
    params['weights'] = weights
    start = time.asctime()
    startstamp = time.time()
    print("\nHostname   :" + str(os.getenv("HOSTNAME")))
    print("\nStart Time :" + str(start))
    print("\nStart Stamp:" + str(startstamp))
    sys.stdout.flush()

    print("Generating Test csv")
    os.makedirs(params['results_dir'], exist_ok=True)
    if not params['csv_provided'] == 'True':
        print('Since CSV were not provided, we are gonna create for you')
        csv_creator_adv.generate_csv(params['test_dir'],
                                     to_save=params['results_dir'],
                                     mode=params['mode'],
                                     ftype='test',
                                     modalities=params['modalities'])
        test_csv = os.path.join(params['results_dir'], 'test.csv')
    else:
        test_csv = params['test_csv']

    test_df = pd.read_csv(test_csv)
    # IDs are used to build paths, so numeric IDs must become strings.
    test_df.ID = test_df.ID.astype(str)
    temp_dir = os.path.join(params['results_dir'], 'Temp')
    os.makedirs(temp_dir, exist_ok=True)

    # Per-subject geometry needed to map the prediction back afterwards.
    patients_dict = {}

    print("Resampling the images to isotropic resolution of 1mm x 1mm x 1mm")
    print("Also Converting the images to RAI and brats for smarter use.")
    for patient in tqdm.tqdm(test_df.values):
        patient_id = patient[0]
        patient_temp_dir = os.path.join(temp_dir, patient_id)
        os.makedirs(patient_temp_dir, exist_ok=True)
        image = nib.load(patient[1])
        old_spacing = image.header.get_zooms()
        old_affine = image.affine
        old_shape = image.header.get_data_shape()
        new_spacing = (1, 1, 1)
        # Voxel count per axis that keeps the physical extent at the new
        # spacing.
        new_shape = tuple(
            int(np.round(old_spacing[axis] / new_spacing[axis] *
                         float(image.shape[axis]))) for axis in range(3))
        new_image = resize(image.get_fdata(),
                           new_shape,
                           order=3,
                           mode='edge',
                           cval=0,
                           anti_aliasing=False)
        # Keep the translation column but reduce every entry of the 3x3 part
        # to its sign (+1, -1 or 0). The former element-wise
        # ``x * (1 / x)`` loop did the same with a float round-trip.
        # NOTE(review): for oblique affines this is not a proper
        # unit-spacing rotation -- confirm inputs are axis-aligned.
        new_affine = np.array(old_affine)
        new_affine[:3, :3] = np.sign(old_affine[:3, :3])
        resampled_path = os.path.join(patient_temp_dir,
                                      patient_id + '_resamp111.nii.gz')
        nib.save(nib.Nifti1Image(new_image, new_affine), resampled_path)

        temp_dict = {
            'name': patient_id,
            'old_spacing': old_spacing,
            'old_affine': old_affine,
            'old_shape': old_shape,
            'new_spacing': new_spacing,
            'new_affine': new_affine,
            'new_shape': new_shape,
        }

        # Reload the resampled volume from disk and pad it; pad_info is kept
        # so the padding can be undone on the prediction.
        patient_nib = nib.load(resampled_path)
        patient_data, pad_info = pad_image(patient_nib.get_fdata())
        nib.save(
            nib.Nifti1Image(patient_data, patient_nib.affine),
            os.path.join(patient_temp_dir,
                         patient_id + '_bratsized.nii.gz'))
        temp_dict['pad_info'] = pad_info
        patients_dict[patient_id] = temp_dict

    model = fetch_model(params['model'], int(params['num_modalities']),
                        int(params['num_classes']),
                        int(params['base_filters']))
    # On a CPU run, remap tensors that were saved from a GPU; otherwise keep
    # torch.load's default behaviour.
    map_location = 'cpu' if device == 'cpu' else None
    checkpoint = torch.load(str(params['weights']), map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])

    if device != 'cpu':
        model.cuda()
    model.eval()

    print("Done Resampling the Data.\n")
    print("--" * 30)
    print("Running the model on the subjects")
    for patient in tqdm.tqdm(test_df.values):
        patient_id = patient[0]
        patient_temp_dir = os.path.join(temp_dir, patient_id)
        patient_nib = nib.load(
            os.path.join(patient_temp_dir, patient_id + '_bratsized.nii.gz'))
        image = process_image(patient_nib.get_fdata())
        image = resize(image, (128, 128, 128),
                       order=3,
                       mode='edge',
                       cval=0,
                       anti_aliasing=False)
        # Add the batch and channel axes expected by the network.
        image = torch.FloatTensor(image[np.newaxis, np.newaxis, ...])
        if device != 'cpu':
            image = image.cuda()
        with torch.no_grad():
            output = model(image)
        output = output.cpu().numpy()[0][0]

        # Back to the padded grid, then binarise at 0.9.
        to_save = interpolate_image(output, patient_nib.shape)
        to_save[to_save >= 0.9] = 1
        to_save[to_save < 0.9] = 0
        nib.save(
            nib.Nifti1Image(to_save, patient_nib.affine),
            os.path.join(patient_temp_dir,
                         patient_id + '_bratsized_mask.nii.gz'))

        # Undo the padding using the information stored during resampling.
        current_patient_dict = patients_dict[patient_id]
        new_image = padder_and_cropper(to_save,
                                       current_patient_dict['pad_info'])
        nib.save(
            nib.Nifti1Image(new_image, patient_nib.affine),
            os.path.join(patient_temp_dir,
                         patient_id + '_resample111_mask.nii.gz'))

        # Back to the original voxel grid; the cubic interpolation produces
        # non-binary values, so threshold again. ``>=`` matches the first
        # threshold (with ``>`` a value of exactly 0.9 survived both
        # assignments).
        to_save_final = resize(new_image,
                               current_patient_dict['old_shape'],
                               order=3,
                               mode='edge',
                               cval=0)
        to_save_final[to_save_final >= 0.9] = 1
        to_save_final[to_save_final < 0.9] = 0
        # Fill holes slice by slice along the third axis.
        for i in range(to_save_final.shape[2]):
            if np.any(to_save_final[:, :, i]):
                to_save_final[:, :, i] = binary_fill_holes(
                    to_save_final[:, :, i])
        to_save_final = postprocess_prediction(to_save_final).astype(
            np.uint8)
        to_save_final_nib = nib.Nifti1Image(
            to_save_final, current_patient_dict['old_affine'])

        patient_results_dir = os.path.join(params['results_dir'], patient_id)
        os.makedirs(patient_results_dir, exist_ok=True)
        nib.save(
            to_save_final_nib,
            os.path.join(patient_results_dir, patient_id + '_mask.nii.gz'))

    print("Done with running the model.")
    if save_brain:
        print(
            "You chose to save the brain. We are now saving it with the masks."
        )
        for patient in tqdm.tqdm(test_df.values):
            patient_id = patient[0]
            patient_results_dir = os.path.join(params['results_dir'],
                                               patient_id)
            image = nib.load(patient[1])
            image_data = image.get_fdata()
            mask = nib.load(
                os.path.join(patient_results_dir,
                             patient_id + '_mask.nii.gz'))
            mask_data = mask.get_fdata().astype(np.int8)
            # Zero out every voxel outside the predicted mask.
            image_data[mask_data == 0] = 0
            nib.save(
                nib.Nifti1Image(image_data, image.affine),
                os.path.join(patient_results_dir,
                             patient_id + '_brain.nii.gz'))

    # The previous literal contained a stray `\"+\` continuation, so the
    # message was printed with `"+` and a run of spaces in the middle.
    print("Please check the %s folder for the intermediate outputs if you "
          "would like to see some intermediate steps." % temp_dir)
    print("Final output stored in : %s" % (params['results_dir']))
    print("Thank you for using BrainMaGe")
    print('*' * 60)
Exemplo n.º 5
0
def train_network(cfg, device, weights):
    """
    Receiving a configuration file and a device, the training is pushed through this file

    Parameters
    ----------
    cfg : string
        Path of the training config file (``name = value`` lines, ``#``
        comments).
    device : int/str
        Device specification. When it is not already a string it is
        stringified and stored as ``params["device"]``, which is handed to
        the trainer's ``gpus`` argument; a string value leaves the ``device``
        entry of the config file in charge.
    weights : string or None
        Checkpoint passed to the trainer as ``resume_from_checkpoint`` and
        stored as ``params["weights"]``.

    Returns
    -------
    None.

    """
    training_start_time = time.asctime()
    startstamp = time.time()
    print("\nHostname   :" + str(os.getenv("HOSTNAME")))
    print("\nStart Time :" + str(training_start_time))
    print("\nStart Stamp:" + str(startstamp))
    sys.stdout.flush()
    print("Checking for this cfg file : ", cfg)
    # READING FROM A CFG FILE and check if file exists or not
    if os.path.isfile(cfg):
        params_df = pd.read_csv(
            cfg,
            sep=" = ",
            names=["param_name", "param_value"],
            comment="#",
            skip_blank_lines=True,
            engine="python",
        ).fillna(" ")
    else:
        print("Missing train_params.cfg file?")
        # NOTE(review): exits with status 0 although this is an error path;
        # kept because callers may rely on it.
        sys.exit(0)

    # Reading in all the parameters
    params = {}
    for row in range(params_df.shape[0]):
        params[params_df.iloc[row, 0]] = params_df.iloc[row, 1]
    print(type(device), device)
    if not isinstance(device, str):
        params["device"] = str(device)
    params["weights"] = weights
    # Create the model directory once, up front (parents included); the CSV
    # generation and the checkpoints below all write into it.
    os.makedirs(str(params["model_dir"]), exist_ok=True)

    # PRINT PARSED ARGS
    print("\n\n")
    print("Training Folder Dir     :", params["train_dir"])
    print("Validation Dir          :", params["validation_dir"])
    print("Model Directory         :", params["model_dir"])
    print("Mode                    :", params["mode"])
    print("Number of modalities    :", params["num_modalities"])
    print("Modalities              :", params["modalities"])
    print("Number of classes       :", params["num_classes"])
    print("Max Number of epochs    :", params["max_epochs"])
    print("Batch size              :", params["batch_size"])
    print("Optimizer               :", params["optimizer"])
    print("Learning Rate           :", params["learning_rate"])
    print("Learning Rate Milestones:", params["lr_milestones"])
    print("Patience to decay       :", params["decay_milestones"])
    print("Early Stopping Patience :", params["early_stop_patience"])
    print("Depth Layers            :", params["layers"])
    print("Model used              :", params["model"])
    print("Weights used            :", params["weights"])
    sys.stdout.flush()
    print("Device Given :", device)
    sys.stdout.flush()
    # Only query the current CUDA device when CUDA is usable: those calls
    # raise on CPU-only hosts, while this function otherwise falls back to
    # the CPU device below.
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        print("Current Device : ", torch.cuda.current_device())
    print("Device Count on Machine : ", torch.cuda.device_count())
    if cuda_available:
        print("Device Name : ", torch.cuda.get_device_name())
    print("Cuda Availibility : ", cuda_available)
    device = torch.device("cuda" if cuda_available else "cpu")
    print("Using device:", device)
    if device.type == "cuda":
        print("Memory Usage:")
        print("Allocated:", round(torch.cuda.memory_allocated(0) / 1024**3, 1),
              "GB")
        print("Cached: ", round(torch.cuda.memory_cached(0) / 1024**3, 1),
              "GB")
    sys.stdout.flush()

    # We generate CSV for training if not provided
    print("Generating CSV Files")
    # Generating training csv files; when they are provided, "train_csv" and
    # "validation_csv" are taken directly from params.
    if not params["csv_provided"] == "True":
        print("Since CSV were not provided, we are gonna create for you")
        generate_csv(
            params["train_dir"],
            to_save=params["model_dir"],
            mode=params["mode"],
            ftype="train",
            modalities=params["modalities"],
        )
        generate_csv(
            params["validation_dir"],
            to_save=params["model_dir"],
            mode=params["mode"],
            ftype="validation",
            modalities=params["modalities"],
        )
        params["train_csv"] = os.path.join(params["model_dir"], "train.csv")
        params["validation_csv"] = os.path.join(params["model_dir"],
                                                "validation.csv")

    log_dir = os.path.join(params["model_dir"])
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(log_dir, "checkpoints"),
        monitor="val_loss",
        verbose=True,
        save_top_k=1,
        mode="auto",
        save_weights_only=False,
        # str() keeps the concatenation valid even if the value is numeric.
        prefix="deep_resunet_" + str(params["base_filters"]),
    )
    stop_callback = EarlyStopping(
        monitor="val_loss",
        mode="auto",
        patience=int(params["early_stop_patience"]),
        verbose=True,
    )
    model = SkullStripper(params)

    trainer = Trainer(
        checkpoint_callback=checkpoint_callback,
        early_stop_callback=stop_callback,
        default_root_dir=params["model_dir"],
        gpus=params["device"],
        fast_dev_run=False,
        max_epochs=int(params["max_epochs"]),
        min_epochs=int(params["min_epochs"]),
        distributed_backend="ddp",
        weights_summary="full",
        weights_save_path=params["model_dir"],
        amp_level="O1",
        num_sanity_val_steps=5,
        resume_from_checkpoint=weights,
    )
    trainer.fit(model)