Example No. 1
if not os.path.isdir(dir_name):
    os.makedirs(dir_name)
    os.makedirs(os.path.join(dir_name, 'pos'))
    os.makedirs(os.path.join(dir_name, 'neg'))

# Setup the NN Model
model = get_model(json_opts.model)
if hasattr(model.net, 'classification_mode'):
    model.net.classification_mode = 'attention'
if hasattr(model.net, 'deep_supervised'):
    model.net.deep_supervised = False

# Setup Dataset and Augmentation
dataset_class = get_dataset(train_opts.arch_type)
dataset_path = get_dataset_path(train_opts.arch_type, json_opts.data_path)
dataset_transform = get_dataset_transformation(train_opts.arch_type,
                                               opts=json_opts.augmentation)

# Setup Data Loader
dataset = dataset_class(dataset_path,
                        split='train',
                        transform=dataset_transform['valid'])
data_loader = DataLoader(dataset=dataset,
                         num_workers=1,
                         batch_size=1,
                         shuffle=True)

# test
for iteration, data in enumerate(data_loader, 1):
    model.set_input(data[0], data[1])

    cls = dataset.label_names[int(data[1])]
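
The snippet ends right after the class name is looked up. A purely hypothetical continuation, assuming the intent is to sort each sample into the 'pos'/'neg' folders created at the top of the snippet (the folder rule, file name and file format below are illustrative assumptions, not part of the source):

import os
import numpy as np

for iteration, data in enumerate(data_loader, 1):
    model.set_input(data[0], data[1])
    cls = dataset.label_names[int(data[1])]

    # Hypothetical: route the sample into the matching sub-folder and save it.
    out_dir = os.path.join(dir_name, 'pos' if cls == 'pos' else 'neg')
    np.save(os.path.join(out_dir, 'sample_{:04d}.npy'.format(iteration)),
            data[0].cpu().numpy())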
Example No. 2
def train(arguments):

    # Parse input arguments
    json_filename = arguments.config
    network_debug = arguments.debug

    # Load options
    json_opts = json_file_to_pyobj(json_filename)
    train_opts = json_opts.training

    # Architecture type
    arch_type = train_opts.arch_type

    # Setup Dataset and Augmentation
    ds_class = get_dataset(arch_type)
    ds_path = get_dataset_path(arch_type, json_opts.data_path)
    ds_transform = get_dataset_transformation(arch_type,
                                              opts=json_opts.augmentation)

    # Setup the NN Model
    model = get_model(json_opts.model)
    if network_debug:
        print('# of pars: ', model.get_number_parameters())
        print('fp time: {0:.3f} sec\tbp time: {1:.3f} sec per sample'.format(
            *model.get_fp_bp_time()))
        exit()

    # Setup Data Loader
    num_workers = train_opts.num_workers if hasattr(train_opts,
                                                    'num_workers') else 16
    train_dataset = ds_class(ds_path,
                             split='train',
                             transform=ds_transform['train'],
                             preload_data=train_opts.preloadData)
    valid_dataset = ds_class(ds_path,
                             split='val',
                             transform=ds_transform['valid'],
                             preload_data=train_opts.preloadData)
    test_dataset = ds_class(ds_path,
                            split='test',
                            transform=ds_transform['valid'],
                            preload_data=train_opts.preloadData)

    # create sampler
    if train_opts.sampler == 'stratified':
        print('stratified sampler')
        train_sampler = StratifiedSampler(train_dataset.labels,
                                          train_opts.batchSize)
        batch_size = 52
    elif train_opts.sampler == 'weighted2':
        print('weighted sampler with background weight={}x'.format(
            train_opts.bgd_weight_multiplier))
        # modify and increase background weight
        weight = train_dataset.weight
        bgd_weight = np.min(weight)
        weight[abs(weight - bgd_weight) <
               1e-8] = bgd_weight * train_opts.bgd_weight_multiplier
        train_sampler = sampler.WeightedRandomSampler(
            weight, len(train_dataset.weight))
        batch_size = train_opts.batchSize
    else:
        print('weighted sampler')
        train_sampler = sampler.WeightedRandomSampler(
            train_dataset.weight, len(train_dataset.weight))
        batch_size = train_opts.batchSize

    # loader
    train_loader = DataLoader(dataset=train_dataset,
                              num_workers=num_workers,
                              batch_size=batch_size,
                              sampler=train_sampler)
    valid_loader = DataLoader(dataset=valid_dataset,
                              num_workers=num_workers,
                              batch_size=train_opts.batchSize,
                              shuffle=True)
    test_loader = DataLoader(dataset=test_dataset,
                             num_workers=num_workers,
                             batch_size=train_opts.batchSize,
                             shuffle=True)

    # Visualisation Parameters
    visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)
    error_logger = ErrorLogger()

    # Training Function
    track_labels = np.arange(len(train_dataset.label_names))
    model.set_labels(track_labels)
    model.set_scheduler(train_opts)

    if hasattr(model, 'update_state'):
        model.update_state(0)

    for epoch in range(model.which_epoch, train_opts.n_epochs):
        print('(epoch: %d, total # iters: %d)' % (epoch, len(train_loader)))

        # # # --- Start ---
        # import matplotlib.pyplot as plt
        # plt.ion()
        # plt.figure()
        # target_arr = np.zeros(14)
        # # # --- End ---

        # Training Iterations
        for epoch_iter, (images, labels) in tqdm(enumerate(train_loader, 1),
                                                 total=len(train_loader)):
            # Make a training update
            model.set_input(images, labels)
            model.optimize_parameters()

            if epoch == (train_opts.n_epochs - 1):
                import time
                time.sleep(36000)

            if train_opts.max_it == epoch_iter:
                break

            # # # --- visualise distribution ---
            # for lab in labels.numpy():
            #     target_arr[lab] += 1
            # plt.clf(); plt.bar(train_dataset.label_names, target_arr); plt.pause(0.01)
            # # # --- End ---

            # Visualise predictions
            if epoch_iter <= 100:
                visuals = model.get_current_visuals()
                visualizer.display_current_results(visuals,
                                                   epoch=epoch,
                                                   save_result=False)

            # Error visualisation
            errors = model.get_current_errors()
            error_logger.update(errors, split='train')

        # Validation and Testing Iterations
        pr_lbls = []
        gt_lbls = []
        for loader, split in zip([valid_loader, test_loader],
                                 ['validation', 'test']):
            model.reset_results()

            for epoch_iter, (images, labels) in tqdm(enumerate(loader, 1),
                                                     total=len(loader)):

                # Make a forward pass with the model
                model.set_input(images, labels)
                model.validate()

                # Visualise predictions
                visuals = model.get_current_visuals()
                visualizer.display_current_results(visuals,
                                                   epoch=epoch,
                                                   save_result=False)

                if train_opts.max_it == epoch_iter:
                    break

            # Error visualisation
            errors = model.get_accumulated_errors()
            stats = model.get_classification_stats()
            error_logger.update({**errors, **stats}, split=split)

            # HACK save validation error
            if split == 'validation':
                valid_err = errors['CE']

        # Update the plots
        for split in ['train', 'validation', 'test']:
            # exclude bckground
            #track_labels = np.delete(track_labels, 3)
            #show_labels = train_dataset.label_names[:3] + train_dataset.label_names[4:]
            show_labels = train_dataset.label_names
            visualizer.plot_current_errors(epoch,
                                           error_logger.get_errors(split),
                                           split_name=split,
                                           labels=show_labels)
            visualizer.print_current_errors(epoch,
                                            error_logger.get_errors(split),
                                            split_name=split)
        error_logger.reset()

        # Save the model parameters
        if epoch % train_opts.save_epoch_freq == 0:
            model.save(epoch)

        if hasattr(model, 'update_state'):
            model.update_state(epoch)

        # Update the model learning rate
        model.update_learning_rate(metric=valid_err, epoch=epoch)
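
Most of these examples rely on a `json_file_to_pyobj` helper that turns the JSON config into nested objects with attribute access (the `_replace` call in Example No. 5 suggests namedtuples). A minimal sketch of such a loader, assuming a plain JSON config file:

import json
import collections

def json_file_to_pyobj(filename):
    # Recursively convert the parsed JSON into nested namedtuples so that
    # options can be read as attributes, e.g. json_opts.training.n_epochs.
    def _json_object_hook(d):
        return collections.namedtuple('X', d.keys())(*d.values())
    with open(filename) as f:
        return json.load(f, object_hook=_json_object_hook)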
Example No. 3
def validation(json_name):
    # Load options
    json_opts = json_file_to_pyobj(json_name)
    train_opts = json_opts.training

    # Setup the NN Model
    model = get_model(json_opts.model)
    save_directory = os.path.join(model.save_dir, train_opts.arch_type)
    mkdirfun(save_directory)

    # Setup Dataset and Augmentation
    dataset_class = get_dataset(train_opts.arch_type)
    dataset_path = get_dataset_path(train_opts.arch_type, json_opts.data_path)
    dataset_transform = get_dataset_transformation(train_opts.arch_type,
                                                   opts=json_opts.augmentation)

    # Setup Data Loader
    dataset = dataset_class(dataset_path,
                            split='validation',
                            transform=dataset_transform['valid'])
    data_loader = DataLoader(dataset=dataset,
                             num_workers=8,
                             batch_size=1,
                             shuffle=False)

    # Visualisation Parameters
    #visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)

    # Setup stats logger
    stat_logger = StatLogger()

    # test
    for iteration, data in enumerate(data_loader, 1):
        model.set_input(data[0], data[1])
        model.test()

        input_arr = np.squeeze(data[0].cpu().numpy()).astype(np.float32)
        label_arr = np.squeeze(data[1].cpu().numpy()).astype(np.int16)
        output_arr = np.squeeze(model.pred_seg.cpu().byte().numpy()).astype(
            np.int16)

        # If there is a label image - compute statistics
        dice_vals = dice_score(label_arr, output_arr, n_class=int(4))
        md, hd = distance_metric(label_arr, output_arr, dx=2.00, k=2)
        precision, recall = precision_and_recall(label_arr,
                                                 output_arr,
                                                 n_class=int(4))
        stat_logger.update(split='test',
                           input_dict={
                               'img_name': '',
                               'dice_LV': dice_vals[1],
                               'dice_MY': dice_vals[2],
                               'dice_RV': dice_vals[3],
                               'prec_MYO': precision[2],
                               'reca_MYO': recall[2],
                               'md_MYO': md,
                               'hd_MYO': hd
                           })

        # Write a nifti image
        import SimpleITK as sitk
        input_img = sitk.GetImageFromArray(np.transpose(input_arr, (2, 1, 0)))
        input_img.SetDirection([-1, 0, 0, 0, -1, 0, 0, 0, 1])
        label_img = sitk.GetImageFromArray(np.transpose(label_arr, (2, 1, 0)))
        label_img.SetDirection([-1, 0, 0, 0, -1, 0, 0, 0, 1])
        predi_img = sitk.GetImageFromArray(np.transpose(output_arr, (2, 1, 0)))
        predi_img.SetDirection([-1, 0, 0, 0, -1, 0, 0, 0, 1])

        sitk.WriteImage(
            input_img,
            os.path.join(save_directory, '{}_img.nii.gz'.format(iteration)))
        sitk.WriteImage(
            label_img,
            os.path.join(save_directory, '{}_lbl.nii.gz'.format(iteration)))
        sitk.WriteImage(
            predi_img,
            os.path.join(save_directory, '{}_pred.nii.gz'.format(iteration)))

    stat_logger.statlogger2csv(split='test',
                               out_csv_name=os.path.join(
                                   save_directory, 'stats.csv'))
    for key, (mean_val,
              std_val) in stat_logger.get_errors(split='test').items():
        print('-', key, ': \t{0:.3f}+-{1:.3f}'.format(mean_val, std_val), '-')
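
The `dice_score` metric used above is assumed, from the surrounding code, to return one overlap value per class. A minimal per-class Dice sketch over integer label maps (the repository's own implementation may differ):

import numpy as np

def dice_score(label_arr, pred_arr, n_class):
    # Dice overlap per class: 2 * |GT and Pred| / (|GT| + |Pred|).
    dice = np.zeros(n_class, dtype=np.float32)
    for c in range(n_class):
        gt = (label_arr == c)
        pr = (pred_arr == c)
        denom = gt.sum() + pr.sum()
        dice[c] = 2.0 * np.logical_and(gt, pr).sum() / denom if denom > 0 else 1.0
    return dice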
Example No. 4
def train(arguments):

    # Parse input arguments
    json_filename = arguments.config
    network_debug = arguments.debug

    # Load options
    json_opts = json_file_to_pyobj(json_filename)
    train_opts = json_opts.training

    # Architecture type
    arch_type = train_opts.arch_type

    # Setup Dataset and Augmentation
    ds_class = get_dataset(arch_type)
    ds_path = get_dataset_path(arch_type, json_opts.data_path)
    ds_transform = get_dataset_transformation(arch_type,
                                              opts=json_opts.augmentation)

    # Setup the NN Model
    model = get_model(json_opts.model)
    if network_debug:
        print('# of pars: ', model.get_number_parameters())
        print('fp time: {0:.3f} sec\tbp time: {1:.3f} sec per sample'.format(
            *model.get_fp_bp_time()))
        exit()

    # Setup Data Loader
    print("\n\n\n\n\nOK FOR DATASET\n\n\n\n\n\n")
    train_dataset = ds_class(ds_path,
                             split='train',
                             transform=ds_transform['train'],
                             preload_data=train_opts.preloadData)
    valid_dataset = ds_class(ds_path,
                             split='validation',
                             transform=ds_transform['valid'],
                             preload_data=train_opts.preloadData)
    test_dataset = ds_class(ds_path,
                            split='test',
                            transform=ds_transform['valid'],
                            preload_data=train_opts.preloadData)
    train_loader = DataLoader(dataset=train_dataset,
                              num_workers=1,
                              batch_size=train_opts.batchSize,
                              shuffle=True)
    valid_loader = DataLoader(dataset=valid_dataset,
                              num_workers=1,
                              batch_size=train_opts.batchSize,
                              shuffle=False)
    test_loader = DataLoader(dataset=test_dataset,
                             num_workers=1,
                             batch_size=train_opts.batchSize,
                             shuffle=False)

    print("\n\n\n\n\nOK FOR DATASET\n\n\n\n\n\n")

    # Visualisation Parameters
    visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)
    error_logger = ErrorLogger()

    # Training Function
    model.set_scheduler(train_opts)
    for epoch in range(model.which_epoch, train_opts.n_epochs):
        print('(epoch: %d, total # iters: %d)' % (epoch, len(train_loader)))

        # Training Iterations
        for epoch_iter, (images, labels) in tqdm(enumerate(train_loader, 1),
                                                 total=len(train_loader)):
            # Make a training update
            model.set_input(images, labels)
            model.optimize_parameters()
            #model.optimize_parameters_accumulate_grd(epoch_iter)

            # Error visualisation
            errors = model.get_current_errors()
            error_logger.update(errors, split='train')

        # Validation and Testing Iterations
        for loader, split in zip([valid_loader, test_loader],
                                 ['validation', 'test']):
            for epoch_iter, (images, labels) in tqdm(enumerate(loader, 1),
                                                     total=len(loader)):

                # Make a forward pass with the model
                model.set_input(images, labels)
                model.validate()

                # Error visualisation
                errors = model.get_current_errors()
                stats = model.get_segmentation_stats()
                error_logger.update({**errors, **stats}, split=split)

                # Visualise predictions
                visuals = model.get_current_visuals()
                visualizer.display_current_results(visuals,
                                                   epoch=epoch,
                                                   save_result=False)

        # Update the plots
        for split in ['train', 'validation', 'test']:
            visualizer.plot_current_errors(epoch,
                                           error_logger.get_errors(split),
                                           split_name=split)
            visualizer.print_current_errors(epoch,
                                            error_logger.get_errors(split),
                                            split_name=split)
        error_logger.reset()
        print("Memory Usage :",
              convert_bytes(torch.cuda.max_memory_allocated()))
        print("Number of parameters :", model.get_number_parameters())

        # Save the model parameters
        if epoch % train_opts.save_epoch_freq == 0:
            model.save(epoch)

        # Update the model learning rate
        model.update_learning_rate()
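
`convert_bytes` is an assumed helper that formats `torch.cuda.max_memory_allocated()` for printing; a small sketch:

def convert_bytes(num_bytes):
    # Format a byte count as a human-readable string, e.g. 1536 -> '1.5 KB'.
    for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
        if num_bytes < 1024.0:
            return '{:.1f} {}'.format(num_bytes, unit)
        num_bytes /= 1024.0
    return '{:.1f} PB'.format(num_bytes)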
Example No. 5
for epoch in epochs:

    # Load options and replace the epoch attribute
    json_opts = json_file_to_pyobj(
        '/vol/biomedic2/oo2113/projects/syntAI/ukbb_pytorch/configs_final/debug_ct.json'
    )
    json_opts = json_opts._replace(model=json_opts.model._replace(
        which_epoch=epoch))

    # Setup the NN Model
    model = get_model(json_opts.model)

    # Setup Dataset and Augmentation
    dataset_class = get_dataset('test_sax')
    dataset_path = get_dataset_path('test_sax', json_opts.data_path)
    dataset_transform = get_dataset_transformation('test_sax',
                                                   json_opts.augmentation)

    # Setup Data Loader
    dataset = dataset_class(dataset_path, transform=dataset_transform['test'])
    data_loader = DataLoader(dataset=dataset,
                             num_workers=1,
                             batch_size=1,
                             shuffle=False)

    # test
    for iteration, (input_arr, input_meta, _) in enumerate(data_loader, 1):
        # look for the subject_id
        if iteration == subject_id:
            # load the input image into the model
            model.set_input(input_arr)
            inp_fmap, out_fmap = model.get_feature_maps(layer_name=layer_name,
                                                        upscale=False)
Example No. 6
def train(arguments):

    # Parse input arguments
    json_filename = arguments.config
    network_debug = arguments.debug

    # Load options
    json_opts = json_file_to_pyobj(json_filename)
    train_opts = json_opts.training

    # Architecture type
    arch_type = train_opts.arch_type

    # Setup Dataset and Augmentation
    ds_class = get_dataset(arch_type)
    ds_path = get_dataset_path(arch_type, json_opts.data_path)
    ds_transform = get_dataset_transformation(
        arch_type,
        opts=json_opts.augmentation,
        max_output_channels=json_opts.model.output_nc,
        verbose=json_opts.training.verbose)

    # Setup channels
    channels = json_opts.data_opts.channels
    if len(channels) != json_opts.model.input_nc \
            or len(channels) != getattr(json_opts.augmentation, arch_type).scale_size[-1]:
        raise Exception(
            'Number of data channels must match number of model channels, and patch and scale size dimensions'
        )

    # Setup the NN Model
    model = get_model(json_opts.model)
    if network_debug:
        print('# of pars: ', model.get_number_parameters())
        print('fp time: {0:.3f} sec\tbp time: {1:.3f} sec per sample'.format(
            *model.get_fp_bp_time()))
        exit()

    # Setup Data Loader
    split_opts = json_opts.data_split
    train_dataset = ds_class(ds_path,
                             split='train',
                             transform=ds_transform['train'],
                             preload_data=train_opts.preloadData,
                             train_size=split_opts.train_size,
                             test_size=split_opts.test_size,
                             valid_size=split_opts.validation_size,
                             split_seed=split_opts.seed,
                             channels=channels)
    valid_dataset = ds_class(ds_path,
                             split='validation',
                             transform=ds_transform['valid'],
                             preload_data=train_opts.preloadData,
                             train_size=split_opts.train_size,
                             test_size=split_opts.test_size,
                             valid_size=split_opts.validation_size,
                             split_seed=split_opts.seed,
                             channels=channels)
    test_dataset = ds_class(ds_path,
                            split='test',
                            transform=ds_transform['valid'],
                            preload_data=train_opts.preloadData,
                            train_size=split_opts.train_size,
                            test_size=split_opts.test_size,
                            valid_size=split_opts.validation_size,
                            split_seed=split_opts.seed,
                            channels=channels)
    train_loader = DataLoader(dataset=train_dataset,
                              num_workers=16,
                              batch_size=train_opts.batchSize,
                              shuffle=True)
    valid_loader = DataLoader(dataset=valid_dataset,
                              num_workers=16,
                              batch_size=train_opts.batchSize,
                              shuffle=False)
    test_loader = DataLoader(dataset=test_dataset,
                             num_workers=16,
                             batch_size=train_opts.batchSize,
                             shuffle=False)

    # Visualisation Parameters
    visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)
    error_logger = ErrorLogger()

    # Training Function
    model.set_scheduler(train_opts)
    # Setup Early Stopping
    early_stopper = EarlyStopper(json_opts.training.early_stopping,
                                 verbose=json_opts.training.verbose)
    for epoch in range(model.which_epoch, train_opts.n_epochs):
        print('(epoch: %d, total # iters: %d)' % (epoch, len(train_loader)))
        train_volumes = []
        validation_volumes = []

        # Training Iterations
        for epoch_iter, (images, labels,
                         indices) in tqdm(enumerate(train_loader, 1),
                                          total=len(train_loader)):
            # Make a training update
            model.set_input(images, labels)
            model.optimize_parameters()
            #model.optimize_parameters_accumulate_grd(epoch_iter)

            # Error visualisation
            errors = model.get_current_errors()
            error_logger.update(errors, split='train')

            ids = train_dataset.get_ids(indices)
            volumes = model.get_current_volumes()
            visualizer.display_current_volumes(volumes, ids, 'train', epoch)
            train_volumes.append(volumes)

        # Validation and Testing Iterations
        for loader, split, dataset in zip([valid_loader, test_loader],
                                          ['validation', 'test'],
                                          [valid_dataset, test_dataset]):
            for epoch_iter, (images, labels,
                             indices) in tqdm(enumerate(loader, 1),
                                              total=len(loader)):
                ids = dataset.get_ids(indices)

                # Make a forward pass with the model
                model.set_input(images, labels)
                model.validate()

                # Error visualisation
                errors = model.get_current_errors()
                stats = model.get_segmentation_stats()
                error_logger.update({**errors, **stats}, split=split)

                if split == 'validation':  # do not look at testing
                    # Visualise predictions
                    volumes = model.get_current_volumes()
                    visualizer.display_current_volumes(volumes, ids, split,
                                                       epoch)
                    validation_volumes.append(volumes)

                    # Track validation loss values
                    early_stopper.update({**errors, **stats})

        # Update the plots
        for split in ['train', 'validation', 'test']:
            visualizer.plot_current_errors(epoch,
                                           error_logger.get_errors(split),
                                           split_name=split)
            visualizer.print_current_errors(epoch,
                                            error_logger.get_errors(split),
                                            split_name=split)
        visualizer.save_plots(epoch, save_frequency=5)
        error_logger.reset()

        # Save the model parameters
        if early_stopper.is_improving is not False:
            model.save(json_opts.model.model_type, epoch)
            save_config(json_opts, json_filename, model, epoch)

        # Update the model learning rate
        model.update_learning_rate(
            metric=early_stopper.get_current_validation_loss())

        if early_stopper.interrogate(epoch):
            break
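
The `EarlyStopper` used here is not shown. A deliberately simplified sketch of such a tracker follows: it takes a plain patience value and monitors a hard-coded 'Seg_Loss' key, both of which are illustrative assumptions (the real class is configured from json_opts.training.early_stopping and may track a different quantity):

class EarlyStopper:
    # Simplified early-stopping sketch: average a monitored validation loss
    # over the epoch and request a stop after `patience` epochs without
    # improvement.
    def __init__(self, patience, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.best_loss = float('inf')
        self.is_improving = True   # reflects the most recent interrogate() call
        self.bad_epochs = 0
        self._running = []

    def update(self, error_dict):
        # Called once per validation batch with the current error/stat dict.
        if 'Seg_Loss' in error_dict:
            self._running.append(float(error_dict['Seg_Loss']))

    def get_current_validation_loss(self):
        return sum(self._running) / len(self._running) if self._running else float('inf')

    def interrogate(self, epoch):
        # Called once per epoch: check improvement and decide whether to stop.
        current = self.get_current_validation_loss()
        self.is_improving = current < self.best_loss
        if self.is_improving:
            self.best_loss = current
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.verbose:
                print('EarlyStopper: no improvement for %d epoch(s)' % self.bad_epochs)
        self._running = []
        return self.bad_epochs >= self.patience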
Example No. 7
def validation(json_name):
    # Load options
    json_opts = json_file_to_pyobj(json_name)
    train_opts = json_opts.training

    # Should be consistent with what was used as input to train the network
    model_types = train_opts.modalities

    # Setup Dataset and Augmentation
    dataset_class = get_dataset(train_opts.arch_type)
    dataset_path = get_dataset_path(train_opts.arch_type, json_opts.data_path)
    dataset_transform = get_dataset_transformation(train_opts.arch_type,
                                                   opts=json_opts.augmentation)

    # Setup Data Loader
    dataset = dataset_class(
        dataset_path, split='test_val',
        transform=dataset_transform['valid'])  # 'validation'
    data_loader = DataLoader(dataset=dataset,
                             num_workers=8,
                             batch_size=1,
                             shuffle=False)

    # Setup the NN Model
    dataDims = [
        1, 1, dataset.image_dims[0], dataset.image_dims[1],
        dataset.image_dims[2]
    ]
    model = get_model(json_opts.model, dataDims)
    save_directory = os.path.join(model.save_dir, train_opts.arch_type)
    mkdirfun(save_directory)

    # Visualisation Parameters
    #visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)

    # Setup stats logger
    stat_logger = StatLogger()

    # test
    for iteration, data in enumerate(data_loader, 1):
        identifier = data[2].cpu().numpy()[0]
        #pname = identifier2id(identifier, json_opts.data_path[0])
        #print("Iteration {}, patient {}".format(iteration, pname))
        print("Iteration {}".format(iteration))

        model.set_input(data[0], data[1])
        model.test()

        #print("data shape {}".format(data[0].size))

        try:
            input_arr = np.squeeze(data[0].cpu().numpy()).astype(np.float32)
        except Exception:
            input_arr = np.squeeze(data[0][0].cpu().numpy()).astype(np.float32)
        label_arr = np.squeeze(data[1].cpu().numpy()).astype(np.int16)
        output_arr = np.squeeze(model.pred_seg.cpu().byte().numpy()).astype(
            np.int16)
        logit_arr = np.squeeze(model.logits.data.cpu().numpy()).astype(
            np.float32)

        #print("output_arr shape: {}".format(output_arr.shape))

        # Clean the prediction
        output_arr = remove_islands_gp(output_arr)

        # If there is a label image - compute statistics
        #dice_vals = dice_score(label_arr, output_arr, n_class=int(json_opts.model.output_nc)) # DICE - left and right together
        dice_vals, dice_vals_left_right = dice_score_average_left_right(
            label_arr, output_arr, n_class=int(
                json_opts.model.output_nc))  # DICE - av. of left + right
        print("DICE scores: {}".format(dice_vals_left_right))

        # hd, msd = distance_metric_new(label_arr, output_arr, vox_size=0.390625) # Hausdorff distance and mean surface distance
        hd, msd = distance_metric_new_left_right(
            label_arr, output_arr,
            vox_size=0.390625)  # Hausdorff distance and mean surface distance

        volumes = calculate_volume(output_arr,
                                   pixdim=[0.390625])  # Segmentation volume
        volumes_label = calculate_volume(label_arr,
                                         pixdim=[0.390625
                                                 ])  # Segmentation volume
        cm_dist, cm_diff = measure_cm_dist_wrapper(
            output_arr, label_arr,
            pixdim=[0.390625])  # Center of mass difference
        precision, recall = precision_and_recall(
            label_arr, output_arr,
            n_class=int(json_opts.model.output_nc))  # Precision and recall

        # Accumulate stats
        stat_logger.update(split='test',
                           input_dict={
                               'img_name': iteration,
                               'Background': dice_vals[0],
                               'GPe_dice left': dice_vals_left_right[0, 0],
                               'GPe_dice right': dice_vals_left_right[0, 1],
                               'GPi_dice left': dice_vals_left_right[1, 0],
                               'GPi_dice right': dice_vals_left_right[1, 1],
                               'GPe_hd left [mm]': hd[0, 0],
                               'GPe_hd right [mm]': hd[0, 1],
                               'GPi_hd left [mm]': hd[1, 0],
                               'GPi_hd right [mm]': hd[1, 1],
                               'GPe_msd left [mm]': msd[0, 0],
                               'GPe_msd right [mm]': msd[0, 1],
                               'GPi_msd left [mm]': msd[1, 0],
                               'GPi_msd right [mm]': msd[1, 1],
                               'GPe right vol [cm^3]': volumes[0] / 1000,
                               'GPe left vol [cm^3]': volumes[1] / 1000,
                               'GPi right vol [cm^3]': volumes[2] / 1000,
                               'GPi left vol [cm^3]': volumes[3] / 1000,
                               'GPe right label vol [cm^3]':
                               volumes_label[0] / 1000,
                               'GPe left label vol [cm^3]':
                               volumes_label[1] / 1000,
                               'GPi right label vol [cm^3]':
                               volumes_label[2] / 1000,
                               'GPi left label vol [cm^3]':
                               volumes_label[3] / 1000,
                               'GPe right cm dist [mm]': cm_dist[0, 0],
                               'GPe left cm dist [mm]': cm_dist[0, 1],
                               'GPe right cm X [mm]': cm_diff[0][0][0],
                               'GPe right cm Y [mm]': cm_diff[0][0][1],
                               'GPe right cm Z [mm]': cm_diff[0][0][2],
                               'GPe left cm X [mm]': cm_diff[0][1][0],
                               'GPe left cm Y [mm]': cm_diff[0][1][1],
                               'GPe left cm Z [mm]': cm_diff[0][1][2],
                               'GPi right cm dist [mm]': cm_dist[1, 0],
                               'GPi left cm dist [mm]': cm_dist[1, 1],
                               'GPi right cm X [mm]': cm_diff[1][0][0],
                               'GPi right cm Y [mm]': cm_diff[1][0][1],
                               'GPi right cm Z [mm]': cm_diff[1][0][2],
                               'GPi left cm X [mm]': cm_diff[1][1][0],
                               'GPi left cm Y [mm]': cm_diff[1][1][1],
                               'GPi left cm Z [mm]': cm_diff[1][1][2],
                               'GPe_prec': precision[1],
                               'GPe_reca': recall[1],
                               'GPi_prec': precision[2],
                               'GPi_reca': recall[2],
                           })

        # Write a nifti image
        import SimpleITK as sitk
        if input_arr.ndim <= 3:
            input_img = sitk.GetImageFromArray(
                np.transpose(input_arr, (2, 1, 0)))
            input_img.SetDirection([-1, 0, 0, 0, -1, 0, 0, 0, 1])  # Original
        else:
            input_arr = np.squeeze(
                input_arr)  #; input_arr = np.squeeze(input_arr[0, :, :, :])

        # Save labels and predictions
        label_img = sitk.GetImageFromArray(np.transpose(label_arr, (2, 1, 0)))
        label_img.SetDirection([-1, 0, 0, 0, -1, 0, 0, 0, 1])
        predi_img = sitk.GetImageFromArray(np.transpose(output_arr, (2, 1, 0)))
        predi_img.SetDirection([-1, 0, 0, 0, -1, 0, 0, 0, 1])
        sitk.WriteImage(
            label_img,
            os.path.join(save_directory, '{}_lbl.nii.gz'.format(identifier)))
        sitk.WriteImage(
            predi_img,
            os.path.join(save_directory, '{}_pred.nii.gz'.format(identifier)))

        # Save the logits - probability maps for each class
        # for qq in range(int(json_opts.model.output_nc)):
        #     logit_img = sitk.GetImageFromArray(np.transpose(np.squeeze(logit_arr[qq, ...]), (2, 1, 0))) #
        #     logit_img.SetDirection([-1, 0, 0, 0, -1, 0, 0, 0, 1])
        #     sitk.WriteImage(logit_img, os.path.join(save_directory, '{}_logit_class_{}.nii.gz'.format(iteration, qq)))

        #print("iteration: {}".format(iteration))
        for ii in range(len(model_types)):
            #print("{}".format(input_arr.shape))
            try:
                input_img = sitk.GetImageFromArray(
                    np.transpose(np.squeeze(input_arr[ii, :, :, :]),
                                 (2, 1, 0)))
            except Exception:
                input_img = sitk.GetImageFromArray(
                    np.transpose(np.squeeze(input_arr), (2, 1, 0)))
            input_img.SetDirection([-1, 0, 0, 0, -1, 0, 0, 0,
                                    1])  # For multimodal
            sitk.WriteImage(
                input_img,
                os.path.join(
                    save_directory,
                    '{}_img_{}.nii.gz'.format(identifier, model_types[ii])))

    stat_logger.statlogger2csv(split='test',
                               out_csv_name=os.path.join(
                                   save_directory, 'stats.csv'))
    for key, (mean_val,
              std_val) in stat_logger.get_errors(split='test').items():
        print('-', key, ': \t{0:.3f}+-{1:.3f}'.format(mean_val, std_val), '-')
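
`precision_and_recall` is another assumed metric helper; a straightforward per-class sketch over integer label maps:

import numpy as np

def precision_and_recall(label_arr, pred_arr, n_class):
    # Per-class precision and recall between an integer label map and a prediction map.
    precision = np.zeros(n_class, dtype=np.float32)
    recall = np.zeros(n_class, dtype=np.float32)
    for c in range(n_class):
        tp = np.logical_and(pred_arr == c, label_arr == c).sum()
        n_pred = (pred_arr == c).sum()
        n_true = (label_arr == c).sum()
        precision[c] = tp / n_pred if n_pred > 0 else 0.0
        recall[c] = tp / n_true if n_true > 0 else 0.0
    return precision, recall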
Example No. 8
def visualization(json_name):
    layer_name = 'attentionblock2'
    json_opts = json_file_to_pyobj(json_name)
    train_opts = json_opts.training

    # Setup the NN Model
    model = get_model(json_opts.model)
    save_directory = os.path.join(model.save_dir, train_opts.arch_type,
                                  layer_name)
    mkdirfun(save_directory)
    #epochs = range(485, 490, 3)
    att_maps = list()
    int_imgs = list()
    subject_id = int(1)

    #json_opts = json_opts._replace(model=json_opts.model._replace(which_epoch=epoch))
    model = get_model(json_opts.model)
    # Setup Dataset and Augmentation
    dataset_class = get_dataset(train_opts.arch_type)
    dataset_path = get_dataset_path(train_opts.arch_type, json_opts.data_path)
    dataset_transform = get_dataset_transformation(train_opts.arch_type,
                                                   opts=json_opts.augmentation)

    # Setup Data Loader
    dataset = dataset_class(dataset_path,
                            split='test',
                            transform=dataset_transform['valid'])
    data_loader = DataLoader(dataset=dataset,
                             num_workers=8,
                             batch_size=1,
                             shuffle=False)

    for iteration, (input_arr, input_meta, _) in enumerate(data_loader, 1):
        if iteration == subject_id:
            # load the input image into the model
            model.set_input(input_arr)
            inp_fmap, out_fmap = model.get_feature_maps(layer_name=layer_name,
                                                        upscale=False)

            # Display the input image and Down_sample the input image
            orig_input_img = model.input.permute(2, 3, 4, 1, 0).cpu().numpy()
            upsampled_attention = F.upsample(
                out_fmap[1], size=input_arr.size()[2:],
                mode='trilinear').data.squeeze().permute(1, 2, 3,
                                                         0).cpu().numpy()

            # Append it to the list
            int_imgs.append(orig_input_img[:, :, :, 0, 0])
            att_maps.append(upsampled_attention[:, :, :, 1])

            # return the model
            model.destructor()

    # Write the attentions to a nifti image
    input_meta['name'][0] = str(subject_id) + '_img_2.nii.gz'
    int_imgs = numpy.array(int_imgs).transpose([1, 2, 3, 0])
    write_nifti_img(int_imgs, input_meta, savedir=save_directory)

    input_meta['name'][0] = str(subject_id) + '_att_2.nii.gz'
    att_maps = numpy.array(att_maps).transpose([1, 2, 3, 0])
    write_nifti_img(att_maps, input_meta, savedir=save_directory)
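
`write_nifti_img` is assumed to save a numpy volume under the file name stored in the metadata dictionary; a minimal sketch that ignores spacing and origin information:

import os
import numpy as np
import SimpleITK as sitk

def write_nifti_img(image_arr, meta, savedir):
    # Reverse the axis order for SimpleITK and write the volume to
    # <savedir>/<meta['name'][0]>; spacing/origin handling is omitted here.
    img = sitk.GetImageFromArray(np.asarray(image_arr).transpose())
    sitk.WriteImage(img, os.path.join(savedir, str(meta['name'][0])))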
Example No. 9
def train(arguments):
    # Parse input arguments
    json_filename = arguments.config
    network_debug = arguments.debug

    # Load options
    json_opts = json_file_to_pyobj(json_filename)
    train_opts = json_opts.training

    # Architecture type
    arch_type = train_opts.arch_type

    # Setup Dataset and Augmentation
    ds_class = get_dataset(train_opts.arch_type)
    ds_path = get_dataset_path(arch_type, json_opts.data_path)
    ds_transform = get_dataset_transformation(arch_type,
                                              opts=json_opts.augmentation)

    # Setup Data Loader - to RAM
    train_dataset = ds_class(ds_path,
                             split='train',
                             transform=ds_transform['train'],
                             preload_data=train_opts.preloadData)
    valid_dataset = ds_class(ds_path,
                             split='validation',
                             transform=ds_transform['valid'],
                             preload_data=train_opts.preloadData)
    test_dataset = ds_class(ds_path,
                            split='test',
                            transform=ds_transform['valid'],
                            preload_data=train_opts.preloadData)
    train_loader = DataLoader(dataset=train_dataset,
                              num_workers=0,
                              batch_size=train_opts.batchSize,
                              shuffle=False)
    valid_loader = DataLoader(dataset=valid_dataset,
                              num_workers=0,
                              batch_size=train_opts.batchSize,
                              shuffle=False)  # num_workers=16
    test_loader = DataLoader(dataset=test_dataset,
                             num_workers=0,
                             batch_size=train_opts.batchSize,
                             shuffle=False)

    # Setup the NN Model
    dataDims = [
        json_opts.training.batchSize, 1, train_dataset.image_dims[0],
        train_dataset.image_dims[1], train_dataset.image_dims[2]
    ]  # This is required only for the STN based network
    model = get_model(json_opts.model, dataDims)
    if network_debug:
        print('# of pars: ', model.get_number_parameters())
        print('fp time: {0:.3f} sec\tbp time: {1:.3f} sec per sample'.format(
            *model.get_fp_bp_time()))
        exit()

    # Visualisation Parameters
    visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)
    error_logger = ErrorLogger()

    # Save the json configuration file to the checkpoints directory
    Ind = json_filename.rfind('/')
    copyfile(
        json_filename,
        os.path.join(json_opts.model.checkpoints_dir,
                     json_opts.model.experiment_name, json_filename[Ind + 1:]))

    # Training Function
    model.set_scheduler(train_opts)
    for epoch in range(model.which_epoch, train_opts.n_epochs):
        print('(epoch: %d, total # iters: %d)' % (epoch, len(train_loader)))

        # Training Iterations
        for epoch_iter, (images, labels, _) in tqdm(enumerate(train_loader, 1),
                                                    total=len(train_loader)):
            #for epoch_iter, (images, labels) in enumerate(train_loader, 1):
            # Make a training update
            model.set_input(images, labels)  # Load data to GPU memory
            model.optimize_parameters()
            #model.optimize_parameters_accumulate_grd(epoch_iter)

            # So we won't increase lambda inside the epoch except for the first time
            model.haus_flag = False

            # Error visualisation
            errors = model.get_current_errors()
            error_logger.update(errors, split='train')

        # Validation and Testing Iterations
        for loader, split in zip([valid_loader, test_loader],
                                 ['validation', 'test']):
            for epoch_iter, (images, labels, _) in tqdm(enumerate(loader, 1),
                                                        total=len(loader)):

                # Make a forward pass with the model
                model.set_input(images, labels)
                model.validate()

                # Error visualisation
                errors = model.get_current_errors()
                stats = model.get_segmentation_stats()
                error_logger.update({**errors, **stats}, split=split)

                # Visualise predictions
                visuals = model.get_current_visuals()
                #visualizer.display_current_results(visuals, epoch=epoch, save_result=False)

        # Update the plots
        for split in ['train', 'validation', 'test']:
            visualizer.plot_current_errors(epoch,
                                           error_logger.get_errors(split),
                                           split_name=split)
            visualizer.print_current_errors(epoch,
                                            error_logger.get_errors(split),
                                            split_name=split)
        error_logger.reset()

        # Save the model parameters
        if epoch % train_opts.save_epoch_freq == 0:
            model.save(epoch)

        # Update the model learning rate
        model.update_learning_rate()

        # Update the Hausdorff distance lambda
        if (epoch + 1) % json_opts.model.haus_update_rate == 0:
            model.haus_flag = True
            print("Hausdorff distance lambda has been updated.")
Example No. 10
def train(arguments):

    # Parse input arguments
    json_filename = arguments.config

    # Load options
    json_opts = json_file_to_pyobj(json_filename)
    train_opts = json_opts.training

    # Architecture type
    arch_type = train_opts.arch_type

    # Create train-test-validation splits
    np.random.seed(28)
    root_dir = json_opts.data_path.ct_82
    num_files = len(get_dicom_dirs(os.path.join(root_dir, "image")))
    train_idx, test_idx, val_idx = get_train_test_val_indices(num_files,
                                                              test_frac=0.25,
                                                              val_frac=0.0)

    ds_transform = get_dataset_transformation(arch_type,
                                              opts=json_opts.augmentation)
    train_dataset = CT82Dataset(root_dir,
                                "train",
                                train_idx,
                                transform=ds_transform['train'],
                                resample=True,
                                preload_data=train_opts.preloadData)
    test_dataset = CT82Dataset(root_dir,
                               "test",
                               test_idx,
                               transform=ds_transform['valid'],
                               resample=True,
                               preload_data=train_opts.preloadData)
    # val_dataset = CT82Dataset(root_dir, "validation", val_idx, transform=ds_transform['valid'], resample=True)

    # Setup the NN Model
    model = get_model(json_opts.model)

    # Setup Data Loaders
    train_loader = DataLoader(dataset=train_dataset,
                              num_workers=train_opts.num_workers,
                              batch_size=train_opts.batchSize,
                              shuffle=True)
    # val_loader = DataLoader(dataset=val_dataset, num_workers=train_opts.num_workers, batch_size=train_opts.batchSize, shuffle=False)
    test_loader = DataLoader(dataset=test_dataset,
                             num_workers=train_opts.num_workers,
                             batch_size=train_opts.batchSize,
                             shuffle=False)

    # Define tensorboard writer
    writer = SummaryWriter(os.path.join(model.save_dir, "runs"))

    ## Add model architecture to TensorBoard
    # images, labels = iter(train_loader).next()
    # images = images.to('cuda') if model.use_cuda else images
    # writer.add_graph(model.net, images)
    # writer.flush()

    # Training
    model.set_scheduler(train_opts)
    for epoch in range(model.which_epoch, train_opts.n_epochs):
        print('(epoch: %d, total # iters: %d)' % (epoch, len(train_loader)))

        # Training Iterations
        stats_dict = {}
        for epoch_iter, (images, labels) in tqdm(enumerate(train_loader, 1),
                                                 total=len(train_loader)):
            # Make a training update
            model.set_input(images, labels)
            model.optimize_parameters()
            #model.optimize_parameters_accumulate_grd(epoch_iter)

            # Error visualisation
            errors = model.get_current_errors()
            for error_name, error in errors.items():
                if 'Train ' + error_name in stats_dict:
                    stats_dict['Train ' + error_name].append(error)
                else:
                    stats_dict['Train ' + error_name] = [error]

        # Validation and Testing Iterations
        with torch.no_grad():
            for epoch_iter, (images, labels) in tqdm(enumerate(test_loader, 1),
                                                     total=len(test_loader)):

                # Make a forward pass with the model
                model.set_input(images, labels)
                model.validate()

                errors = model.get_current_errors()
                for error_name, error in errors.items():
                    if 'Test ' + error_name in stats_dict:
                        stats_dict['Test ' + error_name].append(error)
                    else:
                        stats_dict['Test ' + error_name] = [error]
                stats = model.get_segmentation_stats()
                for stat_name, stat in stats.items():
                    if 'Test ' + stat_name in stats_dict:
                        stats_dict['Test ' + stat_name].append(stat)
                    else:
                        stats_dict['Test ' + stat_name] = [stat]

        # Save the model parameters
        if epoch % train_opts.save_epoch_freq == 0:
            model.save(epoch)

        # Update the model learning rate
        model.update_learning_rate()

        for k, v in stats_dict.items():
            writer.add_scalar(k, sum(v) / len(v), epoch)
        writer.flush()

    writer.close()
Example No. 11
def test(arguments):

    # Parse input arguments
    json_filename = arguments.config
    network_debug = arguments.debug

    # Load options
    json_opts = json_file_to_pyobj(json_filename)
    train_opts = json_opts.training

    # Architecture type
    arch_type = train_opts.arch_type

    # Setup Dataset and Augmentation
    ds_class = get_dataset(arch_type)
    ds_path = get_dataset_path(arch_type, json_opts.data_path)
    ds_transform = get_dataset_transformation(arch_type,
                                              opts=json_opts.augmentation)

    # Setup the NN Model
    with HiddenPrints():
        model = get_model(json_opts.model)

    if network_debug:
        print('# of pars: ', model.get_number_parameters())
        print('fp time: {0:.8f} sec\tbp time: {1:.8f} sec per sample'.format(
            *model.get_fp_bp_time2((1, 1, 224, 288))))
        exit()

    # Setup Data Loader
    num_workers = train_opts.num_workers if hasattr(train_opts,
                                                    'num_workers') else 16

    valid_dataset = ds_class(ds_path,
                             split='val',
                             transform=ds_transform['valid'],
                             preload_data=train_opts.preloadData)
    test_dataset = ds_class(ds_path,
                            split='test',
                            transform=ds_transform['valid'],
                            preload_data=train_opts.preloadData)
    # loader
    batch_size = train_opts.batchSize
    valid_loader = DataLoader(dataset=valid_dataset,
                              num_workers=num_workers,
                              batch_size=train_opts.batchSize,
                              shuffle=False)
    test_loader = DataLoader(dataset=test_dataset,
                             num_workers=0,
                             batch_size=train_opts.batchSize,
                             shuffle=False)

    # Visualisation Parameters
    filename = 'test_loss_log.txt'
    visualizer = Visualiser(json_opts.visualisation,
                            save_dir=model.save_dir,
                            filename=filename)
    error_logger = ErrorLogger()

    # Training Function
    track_labels = np.arange(len(valid_dataset.label_names))
    model.set_labels(track_labels)
    model.set_scheduler(train_opts)

    if hasattr(model.net, 'deep_supervised'):
        model.net.deep_supervised = False

    # Validation and Testing Iterations
    pr_lbls = []
    gt_lbls = []
    for loader, split in zip([test_loader], ['test']):
        #for loader, split in zip([valid_loader, test_loader], ['validation', 'test']):
        model.reset_results()

        for epoch_iter, (images, labels) in tqdm(enumerate(loader, 1),
                                                 total=len(loader)):

            # Make a forward pass with the model
            model.set_input(images, labels)
            model.validate()

        # Error visualisation
        errors = model.get_accumulated_errors()
        stats = model.get_classification_stats()
        error_logger.update({**errors, **stats}, split=split)

    # Update the plots
    # for split in ['train', 'validation', 'test']:
    for split in ['test']:
        # exclude bckground
        #track_labels = np.delete(track_labels, 3)
        #show_labels = train_dataset.label_names[:3] + train_dataset.label_names[4:]
        show_labels = valid_dataset.label_names
        visualizer.plot_current_errors(300,
                                       error_logger.get_errors(split),
                                       split_name=split,
                                       labels=show_labels)
        visualizer.print_current_errors(300,
                                        error_logger.get_errors(split),
                                        split_name=split)

        import pickle as pkl
        dst_file = os.path.join(model.save_dir, 'test_result.pkl')
        with open(dst_file, 'wb') as f:
            d = error_logger.get_errors(split)
            d['labels'] = valid_dataset.label_names
            d['pr_lbls'] = np.hstack(model.pr_lbls)
            d['gt_lbls'] = np.hstack(model.gt_lbls)
            pkl.dump(d, f)

    error_logger.reset()

    if arguments.time:
        print('# of pars: ', model.get_number_parameters())
        print('fp time: {0:.8f} sec\tbp time: {1:.8f} sec per sample'.format(
            *model.get_fp_bp_time2((1, 1, 224, 288))))
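
`HiddenPrints` above is assumed to be the usual stdout-silencing context manager, for example:

import os
import sys

class HiddenPrints:
    # Context manager that temporarily silences stdout (e.g. while the model
    # prints its architecture during construction).
    def __enter__(self):
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, 'w')

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout.close()
        sys.stdout = self._original_stdout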
Example No. 12
def train(arguments):

    # Parse input arguments
    json_filename = arguments.config
    network_debug = arguments.debug

    # Load options
    json_opts = json_file_to_pyobj(json_filename)
    train_opts = json_opts.training

    # Architecture type
    arch_type = train_opts.arch_type

    # Create train-test-validation splits
    np.random.seed(41)
    root_dir = json_opts.data_path.ct_82
    num_files = len(get_dicom_dirs(os.path.join(root_dir, "image")))
    train_idx, test_idx, val_idx = get_train_test_val_indices(num_files,
                                                              test_frac=0.25,
                                                              val_frac=0.0)

    ds_transform = get_dataset_transformation(arch_type,
                                              opts=json_opts.augmentation)
    train_dataset = CT82Dataset(root_dir,
                                "train",
                                train_idx,
                                transform=ds_transform['train'],
                                resample=True,
                                preload_data=train_opts.preloadData)
    test_dataset = CT82Dataset(root_dir,
                               "test",
                               test_idx,
                               transform=ds_transform['valid'],
                               resample=True,
                               preload_data=train_opts.preloadData)
    # val_dataset = CT82Dataset(root_dir, "validation", val_idx, transform=ds_transform['valid'], resample=True)

    # Setup the NN Model
    model = get_model(json_opts.model)
    if network_debug:
        print('# of pars: ', model.get_number_parameters())
        print('fp time: {0:.3f} sec\tbp time: {1:.3f} sec per sample'.format(
            *model.get_fp_bp_time()))
        exit()

    # Setup Data Loaders
    train_loader = DataLoader(dataset=train_dataset,
                              num_workers=train_opts.num_workers,
                              batch_size=train_opts.batchSize,
                              shuffle=True)
    # val_loader = DataLoader(dataset=val_dataset, num_workers=train_opts.num_workers, batch_size=train_opts.batchSize, shuffle=False)
    test_loader = DataLoader(dataset=test_dataset,
                             num_workers=train_opts.num_workers,
                             batch_size=train_opts.batchSize,
                             shuffle=False)

    # Visualisation Parameters
    visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)
    error_logger = ErrorLogger()

    # Training Function
    model.set_scheduler(train_opts)
    for epoch in range(model.which_epoch, train_opts.n_epochs):
        print('(epoch: %d, total # iters: %d)' % (epoch, len(train_loader)))

        # Training Iterations
        for epoch_iter, (images, labels) in tqdm(enumerate(train_loader, 1),
                                                 total=len(train_loader)):
            # Make a training update
            model.set_input(images, labels)
            model.optimize_parameters()
            #model.optimize_parameters_accumulate_grd(epoch_iter)

            # Error visualisation
            errors = model.get_current_errors()
            error_logger.update(errors, split='train')

        # Validation and Testing Iterations
        with torch.no_grad():
            for epoch_iter, (images, labels) in tqdm(enumerate(test_loader, 1),
                                                     total=len(test_loader)):

                # Make a forward pass with the model
                model.set_input(images, labels)
                model.validate()

                # Error visualisation
                errors = model.get_current_errors()
                stats = model.get_segmentation_stats()
                error_logger.update({**errors, **stats}, split='test')

                # Visualise predictions
                visuals = model.get_current_visuals()
                visualizer.display_current_results(visuals,
                                                   epoch=epoch,
                                                   save_result=False)

            # Update the plots
            for split in ['train', 'test']:
                visualizer.plot_current_errors(epoch,
                                               error_logger.get_errors(split),
                                               split_name=split)
                visualizer.print_current_errors(epoch,
                                                error_logger.get_errors(split),
                                                split_name=split)
            error_logger.reset()

        # Save the model parameters
        if epoch % train_opts.save_epoch_freq == 0:
            model.save(epoch)

        # Update the model learning rate
        model.update_learning_rate()
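# --- Illustrative sketch, not part of the original example ---
# The loop above only relies on ErrorLogger exposing update(), get_errors()
# and reset(). A minimal stand-in with that interface could look like the
# class below, assuming the logged metrics are scalar values; the actual
# ErrorLogger shipped with the codebase may differ.
from collections import defaultdict


class SimpleErrorLogger:
    def __init__(self):
        self.reset()

    def reset(self):
        # one dict of metric lists per split ('train', 'test', ...)
        self._metrics = defaultdict(lambda: defaultdict(list))

    def update(self, errors, split):
        # errors: dict mapping metric name -> scalar value for one batch
        for name, value in errors.items():
            self._metrics[split][name].append(float(value))

    def get_errors(self, split):
        # return the running mean of every metric recorded for this split
        return {name: sum(values) / len(values)
                for name, values in self._metrics[split].items()}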
Example No. 13
def train(arguments, data_splits, n_split=0):

    # Parse input arguments
    json_filename = arguments.config
    network_debug = arguments.debug
    predict_path = arguments.predict_path

    # Load options
    json_opts = json_file_to_pyobj(json_filename)
    train_opts = json_opts.training

    # Architecture type
    arch_type = train_opts.arch_type

    # Setup Dataset and Augmentation
    ds_class = get_dataset(arch_type)
    ds_path  = get_dataset_path(arch_type, json_opts.data_path)
    ds_transform = get_dataset_transformation(arch_type, opts=json_opts.augmentation)

    # Setup the NN Model
    model = get_model(json_opts.model, im_dim = train_opts.im_dim, split=n_split)
    if network_debug:
        print('# of pars: ', model.get_number_parameters())
        print('fp time: {0:.3f} sec\tbp time: {1:.3f} sec per sample'.format(*model.get_fp_bp_time()))
        exit()

    # Setup Data Loader
    test_dataset = ds_class(ds_path,
                            split='test',
                            data_splits=data_splits['test'],
                            im_dim=train_opts.im_dim,
                            transform=ds_transform['valid'],
                            preload_data=train_opts.preloadData)
    test_loader = DataLoader(dataset=test_dataset,
                             num_workers=2,
                             batch_size=train_opts.batchSize,
                             shuffle=False)

    # Visualisation Parameters
    visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)
    error_logger = ErrorLogger()

    # Training Function
    model.set_scheduler(train_opts)

    # Validation and Testing Iterations
    loader, split = test_loader, 'test'
    for epoch_iter, (images, labels) in tqdm(enumerate(loader, 1), total=len(loader)):

        # Make a forward pass with the model
        model.set_input(images, labels)
        model.predict(predict_path)

        # Error visualisation
        errors = model.get_current_errors()
        stats = model.get_segmentation_stats()
        error_logger.update({**errors, **stats}, split=split)

        # Visualise predictions
        visuals = model.get_current_visuals()
        visualizer.display_current_results(visuals, epoch=1, save_result=False)

        del images, labels

    # Update the plots
    for split in ['test']:
        visualizer.plot_current_errors(1, error_logger.get_errors(split), split_name=split)
        visualizer.print_current_errors(1, error_logger.get_errors(split), split_name=split)
    error_logger.reset()
    # print("Memory Usage :", convert_bytes(torch.cuda.max_memory_allocated()))
    print("Number of parameters :", model.get_number_parameters())
Example No. 14
def train(arguments):

    # Parse input arguments
    json_filename = arguments.config
    network_debug = arguments.debug

    # Load options
    json_opts = json_file_to_pyobj(json_filename)
    train_opts = json_opts.training

    # Architecture type
    arch_type = train_opts.arch_type

    # Setup Dataset and Augmentation
    ds_class = get_dataset("HDF5")
    ds_path = get_dataset_path("HDF5", json_opts.data_path)
    ds_transform = get_dataset_transformation(arch_type,
                                              opts=json_opts.augmentation)

    # Setup the NN Model
    model = get_model(json_opts.model)
    if network_debug:
        print('# of pars: ', model.get_number_parameters())
        print('fp time: {0:.3f} sec\tbp time: {1:.3f} sec per sample'.format(
            *model.get_fp_bp_time()))
        exit()

    # Setup Data Loader
    train_dataset = ds_class(ds_path,
                             split='train',
                             transform=ds_transform['train'],
                             preload_data=train_opts.preloadData)
    valid_dataset = ds_class(ds_path,
                             split='validation',
                             transform=ds_transform['valid'],
                             preload_data=train_opts.preloadData)
    #test_dataset  = ds_class(ds_path, split='test',       transform=ds_transform['valid'], preload_data=train_opts.preloadData)
    train_loader = DataLoader(dataset=train_dataset,
                              num_workers=1,
                              batch_size=train_opts.batchSizeTrain,
                              shuffle=True)
    valid_loader = DataLoader(dataset=valid_dataset,
                              num_workers=1,
                              batch_size=train_opts.batchSizeVal,
                              shuffle=False)
    #test_loader  = DataLoader(dataset=test_dataset,  num_workers=16, batch_size=train_opts.batchSize, shuffle=False)

    # Visualisation Parameters
    #visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)
    #error_logger = ErrorLogger()
    writer = SummaryWriter(log_dir=model.save_dir)

    # Training Function
    model.set_scheduler(train_opts)
    for epoch in range(model.which_epoch, train_opts.n_epochs):
        print('(epoch: %d, total # iters: %d)' % (epoch, len(train_loader)))

        # Training Iterations
        train_loss_total = 0.0
        num_steps = 0
        for epoch_iter, (images, labels) in tqdm(enumerate(train_loader, 1),
                                                 total=len(train_loader)):
            # Make a training update
            model.set_input(images, labels)
            model.optimize_parameters()
            # model.optimize_parameters_accumulate_grd(epoch_iter)

            # # Error visualisation
            # errors = model.get_current_errors()
            # error_logger.update(errors, split='train')

            # Accumulate training loss for TensorBoard
            train_loss_total += model.get_loss()
            num_steps += 1

        # Average training loss for TensorBoard
        train_loss_total_avg = train_loss_total / num_steps

        # Validation and Testing Iterations
        val_loss_total = 0.0
        num_steps = 0
        #for loader, split in zip([valid_loader, test_loader], ['validation', 'test']):
        for epoch_iter, (images, labels) in tqdm(enumerate(valid_loader, 1),
                                                 total=len(valid_loader)):
            split = 'validation'

            # Make a forward pass with the model
            model.set_input(images, labels)
            model.validate()

            # # Error visualisation
            # errors = model.get_current_errors()
            # stats = model.get_segmentation_stats()
            # error_logger.update({**errors, **stats}, split=split)

            # # Visualise predictions
            # visuals = model.get_current_visuals()
            # visualizer.display_current_results(visuals, epoch=epoch, save_result=False)

            # Accumulate validation loss for TensorBoard
            val_loss_total += model.get_loss()
            num_steps += 1
        # Average validation loss for TensorBoard
        val_loss_total_avg = val_loss_total / num_steps

        # # Update the plots
        # for split in ['train', 'validation']:#, 'test']:
        # #visualizer.plot_current_errors(epoch, error_logger.get_errors(split), split_name=split)
        # visualizer.print_current_errors(epoch, error_logger.get_errors(split), split_name=split)
        # error_logger.reset()

        # Visualize progress in tensorboard
        writer.add_scalars('losses', {
            'val_loss': val_loss_total_avg,
            'train_loss': train_loss_total_avg
        }, epoch)
        lr = model.optimizers[0].param_groups[0]['lr']
        writer.add_scalar('learning_rate', lr, epoch)

        # Save the model parameters
        if epoch % train_opts.save_epoch_freq == 0:
            model.save(epoch)

        # Update the model learning rate
        model.update_learning_rate()
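# --- Illustrative sketch, not part of the original example ---
# A minimal entry point for the TensorBoard-logging train() above. The flag
# names mirror the attributes read from `arguments`; the defaults and help
# strings are assumptions. Once training runs, the logged curves can be
# inspected with `tensorboard --logdir <model.save_dir>`.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='HDF5 segmentation training')
    parser.add_argument('-c', '--config', required=True,
                        help='path to the training configuration JSON file')
    parser.add_argument('-d', '--debug', action='store_true',
                        help='report parameter count and fp/bp timing, then exit')
    train(parser.parse_args())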