def config_filenames(self, net_type, use_core, keep_spatial_dim=False):
        """Build the weights and embedding/prediction file managers for a network.

        Args:
            net_type: network identifier. With ``use_core`` False it must be one
                of 'dir', 'dirR', 'dirS' or 'dirRS'.
            use_core: if True, the second result is a ``File.Embed`` manager;
                otherwise it is a ``File.Pred`` manager chosen by ``net_type``.
            keep_spatial_dim: only used when ``use_core`` is True; prefixes the
                embedding name with ``'SP_'``.

        Returns:
            ``(Weights, Embed)``. ``Weights`` is ``File.Weights(net_type)`` rooted
            at ``input_dir``. ``Embed`` is a single manager, or for the 'dirRS'
            prediction case a dict with keys ``'R'`` (rating) and ``'S'`` (size).

        Raises:
            ValueError: if ``use_core`` is False and ``net_type`` is not one of
                the supported prediction network types.
        """
        # NOTE(review): input_dir / output_dir are not defined in this method;
        # presumably module-level globals -- confirm.
        Weights = File.Weights(net_type, output_dir=input_dir)

        if use_core:
            embed_name = 'SP_' + net_type if keep_spatial_dim else net_type
            return Weights, File.Embed(embed_name, output_dir=output_dir)

        # single-output direct networks: net_type -> prediction file type
        # (the net_type itself is used as the file prefix)
        single_pred_types = {'dir': 'malig', 'dirR': 'rating', 'dirS': 'size'}
        if net_type in single_pred_types:
            Embed = File.Pred(type=single_pred_types[net_type], pre=net_type,
                              output_dir=output_dir)
        elif net_type == 'dirRS':
            # rating and size predictions are saved in separate files
            Embed = {
                'R': File.Pred(type='rating', pre='dirRS', output_dir=output_dir),
                'S': File.Pred(type='size', pre='dirRS', output_dir=output_dir),
            }
        else:
            # Explicit raise rather than `assert False`: asserts are stripped
            # under `python -O`, which would leave Embed unbound.
            raise ValueError('{} not recognized'.format(net_type))

        return Weights, Embed
# Example #2
# 0
    def __init__(self, network='dir', pooling='max', categorize=False):
        """Store the network id, its file managers and the data/net settings.

        Args:
            network: network identifier used to create the ``FileManager``
                weights and embedding managers.
            pooling: pooling mode, stored as ``net_pool``.
            categorize: flag stored unchanged as ``categorize``.
        """
        self.network = network
        # weight / embedding file managers for this network
        self.Weights = FileManager.Weights(network)
        self.Embed = FileManager.Embed(network)

        # data settings (alternative resolution: 'Legacy')
        self.data_size, self.data_res, self.data_sample = 144, '0.5I', 'Normal'

        # network settings: square, single-channel input
        side = 128
        self.net_in_size = side
        self.net_input_shape = (side, side, 1)
        self.net_out_size = 128
        self.net_normalize = True
        self.net_pool = pooling
        self.categorize = categorize

        # no model has been built yet
        self.model = None
def run(choose_model="DIR",
        epochs=200,
        config=0,
        skip_validation=False,
        no_training=False,
        config_name='LEGACY',
        load_data_from_predications=False):
    """Build one of the nodule models and either train it or prepare it for validation.

    Args:
        choose_model: one of "DIR", "DIR_RATING", "SIAM", "SIAM_RATING",
            "TRIPLET" (compared by value).
        epochs: passed to ``model.train`` as ``n_epoch``. When ``no_training``
            is set it is stored as ``model.last_epoch`` instead.
        config: configuration index. Used directly as the run-id suffix when
            ``config_name == 'LEGACY'``; otherwise mapped through
            ``CrossValidationManager(config_name).get_run_id``.
        skip_validation: if True, generators are built with ``val_factor=0``.
        no_training: if True, augmentation is disabled, predicted ratings are
            not requested from the loader, and the model is not trained (only
            ``last_epoch`` and ``run`` are set on it).
        config_name: cross-validation scheme name forwarded to ``build_loader``.
        load_data_from_predications: forwarded to ``build_loader`` as
            ``load_data_from_predictions``.

    Returns:
        The constructed model (``DirectArch``, ``SiamArch`` or ``TripArch``).

    Raises:
        ValueError: if ``choose_model`` is unknown, or the selected branch has
            no ``run`` id set.
    """
    # Fail fast, before seeding, session creation and data loading.
    # Model names are compared by value (==), never identity (`is`): string
    # identity depends on interning, so an equal string built at runtime
    # would otherwise match no branch.
    known_models = ("DIR", "DIR_RATING", "SIAM", "SIAM_RATING", "TRIPLET")
    if choose_model not in known_models:
        raise ValueError("Unknown choose_model {!r}; expected one of: {}".format(
            choose_model, ", ".join(known_models)))

    np.random.seed(1337)
    random.seed(1337)
    tf.set_random_seed(1234)
    K.set_session(tf.Session(graph=tf.get_default_graph()))

    ## --------------------------------------- ##
    ## ------- General Setup ----------------- ##
    ## --------------------------------------- ##

    #data
    dataset_type = 'Primary'
    data_size = 160
    res = 0.5  # 'Legacy' #0.7 #0.5 #'0.5I'
    sample = 'Normal'  # 'UniformNC' #'Normal' #'Uniform'
    data_run = '813'
    data_epoch = 70
    return_predicted_ratings = not no_training
    use_gen = True
    #model
    model_size = 128
    input_shape = (model_size, model_size, 1)
    normalize = True
    out_size = 128
    # augmentation is disabled when not training
    do_augment = not no_training
    preload_weight = None

    print("-" * 30)
    print("Running {} for --** {} **-- model, with #{} configuration".format(
        "training" if not no_training else "validation", choose_model, config))
    if load_data_from_predications:
        print(
            "\tdata_run = {}, \n\tdata_epoch = {}, return_predicted_ratings = {}"
            .format(data_run, data_epoch, return_predicted_ratings))
    else:
        print(
            "\tdata_size = {},\n\tmodel_size = {},\n\tres = {},\n\tdo_augment = {}"
            .format(data_size, model_size, res, do_augment))
        print("\tdataset_type = {}".format(dataset_type))
    print("-" * 30)

    model = None
    # Run id; each model branch below is expected to select one (the
    # commented-out assignments are kept as an experiment log).
    run = None

    data_augment_params = {
        'max_angle': 30,
        'flip_ratio': 0.5,
        'crop_stdev': 0.15,
        'epoch': 0
    }

    data_loader = build_loader(
        size=data_size,
        res=res,
        sample=sample,
        dataset_type=dataset_type,
        config_name=config_name,
        configuration=config,
        run=data_run,
        epoch=data_epoch,
        load_data_from_predictions=load_data_from_predications,
        return_predicted_ratings=return_predicted_ratings)

    ## --------------------------------------- ##
    ## ------- Prepare Direct Architecture ------- ##
    ## --------------------------------------- ##

    if choose_model == "DIR":
        # run = '300'  # SPIE avg-pool (data-aug, balanced=False,class_weight=True)
        # run = '301'  # SPIE max-pool (data-aug, balanced=False,class_weight=True)
        # run = '302'  # SPIE rmac-pool (data-aug, balanced=False,class_weight=True)

        # run = 'zzz'

        model = DirectArch(miniXception_loader,
                           input_shape,
                           output_size=out_size,
                           normalize=normalize,
                           pooling='msrmac')
        model.model.summary()
        model.compile(learning_rate=1e-3, decay=0)
        if use_gen:
            generator = DataGeneratorDir(
                data_loader,
                val_factor=0 if skip_validation else 1,
                balanced=False,
                data_size=data_size,
                model_size=model_size,
                batch_size=32,
                do_augment=do_augment,
                augment=data_augment_params,
                use_class_weight=True,
                use_confidence=False)
            model.load_generator(generator)
        else:
            dataset = load_nodule_dataset(size=data_size,
                                          res=res,
                                          sample=sample)
            images_train, labels_train, class_train, masks_train, _ = prepare_data_direct(
                dataset[2], num_of_classes=2)
            images_valid, labels_valid, class_valid, masks_valid, _ = prepare_data_direct(
                dataset[1], num_of_classes=2)
            images_train = np.array([
                crop_center(im, msk, size=model_size)[0]
                for im, msk in zip(images_train, masks_train)
            ])
            images_valid = np.array([
                crop_center(im, msk, size=model_size)[0]
                for im, msk in zip(images_valid, masks_valid)
            ])
            model.load_data(images_train,
                            labels_train,
                            images_valid,
                            labels_valid,
                            batch_size=32)

    if choose_model == "DIR_RATING":

        ### CLEAN SET
        # run = '800'  # rmac conf:size
        # run = '801'  # rmac conf:none
        # run = '802'  # rmac conf:rating-std
        # run = '803'  # max conf:none

        ### PRIMARY SET
        # run = '810'  # rmac conf:size
        # run = '811'  # rmac conf:none
        # run = '812'  # rmac conf:rating-std
        # run = '813'  # max conf:none
        # run = '814'  # max separated_prediction

        # run = '820'  # dirD, max, logcoh-loss
        # run = '821'  # dirD, max, pearson-loss
        # run = '822'  # dirD, max, KL-rank-loss
        # run = '823'  # dirD, max, poisson-rank-loss
        # run = '824'  # dirD, max, categorical-cross-entropy-loss
        # run = '825'  # dirD, max, ranked-pearson-loss
        # run = '826'  # dirD, max, KL-normalized-rank-loss
        # run = '827'  # dirD, max, KL-normalized-rank-loss (local-scaled) softmax
        # run = '828'  # dirD, max, KL-normalized-rank-loss (local-scaled) l2
        # run = '829'  # dirD, max, ranked-pearson-loss (local-scaled)

        # run = '830'  # dirD, rmac, logcoh-loss
        # run = '831'  # dirD, rmac, pearson-loss
        # run = '832'  # dirD, rmac, KL-rank-loss
        # run = '833'  # dirD, rmac, poisson-rank-loss
        # run = '834'  # dirD, rmac, categorical-cross-entropy-loss
        # run = '835'  # dirD, rmac, ranked-pearson-loss
        # run = '836'  # dirD, rmac, KL-normalized-rank-loss

        # run = '841'  # dirD, max, pearson-loss    pre:dirR813-50
        # run = '842b'  # dirD, max, KL-rank-loss    pre:dirR813-50  (b:lr-4)
        # run = '846'  # dirD, max, KL-norm-loss    pre:dirR813-50

        # run = '851'  # dirD, rmac, pearson-loss   pre:dirR813-50
        # run = '852'  # dirD, rmac, KL-rank-loss   pre:dirR813-50
        # run = '856'  # dirD, rmac, KL-norm-loss   pre:dirR813-50

        # run = '860'  # dirD, max, KL-loss    pre:dirR813-50  (b:lr-4, freeze:7)
        # run = '861'  # dirD, max, KL-loss    pre:dirR813-50  (b:lr-4, freeze:17)
        # run = '862'  # dirD, max, KL-loss    pre:dirR813-50  (b:lr-4, freeze:28)
        # run = '863'  # dirD, max, KL-loss    pre:dirR813-50  (b:lr-4, freeze:39)

        # run = '870'  # dirRD, max, KL-loss    schd: 00
        # run = '871'  # dirRD, max, KL-loss    schd: 01
        # run = '872'  # dirRD, max, KL-loss    schd: 02
        # run = '873'  # dirRD, max, KL-loss    schd: 03
        # run = '874'  # dirRD, max, KL-loss    schd: 04
        # run = '875'  # dirRD, max, KL-loss    schd: 05
        # run = '876'  # dirRD, max, KL-loss    schd: 06
        # run = '877b'  # dirRD, max, KL-loss    schd: 07b
        # run = '878'  # dirRD, max, KL-loss    schd: 08
        # run = '879'  # dirRD, max, KL-loss    schd: 09

        # run = '888'  # dirRD, max, KL-loss    schd: 08, on partial data SUP
        # run = '882'  # dirRD, max, KL-loss    schd:

        run = '898b'  # dirRD, max, KL-loss    schd: 08, on partial data UNSUP
        # run = '890b'  # dirR
        # run = '892b'  # dirRD, max, KL-loss

        # run = 'ccc'

        obj = 'rating_distance-matrix'  # 'distance-matrix' 'rating' 'rating-size'

        rating_scale = 'none'
        reg_loss = None  # {'SampleCorrelation': 0.0}  # 'Dispersion', 'Std', 'FeatureCorrelation', 'SampleCorrelation'
        batch_size = 32

        epoch_pre = 50
        preload_weight = None
        # FileManager.Weights('dirR', output_dir=input_dir).name(run='813c{}'.format(config), epoch=epoch_pre)
        # FileManager.Weights('dirR', output_dir=input_dir).name(run='251c{}'.format(config), epoch=epoch_pre)

        model = DirectArch(miniXception_loader,
                           input_shape,
                           output_size=out_size,
                           objective=obj,
                           separated_prediction=False,
                           normalize=normalize,
                           pooling='max',
                           l1_regularization=None,
                           regularization_loss=reg_loss,
                           batch_size=batch_size)

        if preload_weight is not None:
            model.load_core_weights(preload_weight, 39)
            # 7:    freeze 1 blocks
            # 17:   freeze 2 blocks
            # 28:   freeze 3 blocks
            # 39:   freeze 4 blocks

        model.model.summary()

        should_use_scheduale = (reg_loss is not None) or (obj in [
            'rating_size', 'rating_distance-matrix'
        ])

        # scheduale 00:     870
        # sched = [{'epoch': 00, 'weights': [0.9, 0.1]},
        #         {'epoch': 40, 'weights': [0.5, 0.5]},
        #         {'epoch': 80, 'weights': [0.1, 0.9]}] \
        #    if should_use_scheduale else []

        # scheduale 01:     871
        # sched = [{'epoch': 00, 'weights': [1.0, 0.0]},
        #         {'epoch': 50, 'weights': [0.0, 1.0]}] \
        #    if should_use_scheduale else []

        # scheduale 02:     872
        # sched = [{'epoch': 00, 'weights': [0.9, 0.1]},
        #       {'epoch': 50, 'weights': [0.1, 0.9]}] \
        #   if should_use_scheduale else []

        # scheduale 03:     873
        # sched = [{'epoch': 00, 'weights': [0.9, 0.1]},
        #        {'epoch': 50, 'weights': [0.5, 0.5]},
        #         {'epoch': 100, 'weights': [0.1, 0.9]}] \
        #    if should_use_scheduale else []

        # scheduale 04:     874
        # sched = [{'epoch': 00, 'weights': [1.0, 0.0]},
        #        {'epoch': 50, 'weights': [0.0, 0.1]}] \
        #   if should_use_scheduale else []

        # scheduale 05:     875
        # sched = [{'epoch': 00, 'weights': [1.0, 0.0]},
        #        {'epoch': 50, 'weights': [0.0, 1.0]},
        #         {'epoch': 100, 'weights': [0.0, 0.1]}] \
        #    if should_use_scheduale else []

        # scheduale 06:     876
        # sched = [{'epoch': 00, 'weights': [0.9, 0.1]},
        #         {'epoch': 40, 'weights': [0.5, 0.5]},
        #         {'epoch': 60, 'weights': [0.1, 0.1]},
        #         {'epoch': 80, 'weights': [0.0, 0.1]},
        #         {'epoch': 100, 'weights': [0.0, 0.05]}] \
        #    if should_use_scheduale else []

        # scheduale 07b:     877b
        # sched = [{'epoch': 00,  'weights': [1.0, 0.0]},
        #         {'epoch': 50,  'weights': [0.0, 1.0]},
        #         {'epoch': 80,  'weights': [0.0, 0.1]},
        #         {'epoch': 100, 'weights': [0.0, 0.05]}] \
        #    if should_use_scheduale else []

        # scheduale 08b:     878
        # sched = [{'epoch': 00, 'weights': [0.9, 0.1]},
        #         {'epoch': 40, 'weights': [0.5, 0.5]},
        #         {'epoch': 80, 'weights': [0.0, 0.1]}] \
        #    if should_use_scheduale else []

        # scheduale 09:     879
        # sched = [{'epoch': 00, 'weights': [0.9, 0.1]},
        #         {'epoch': 20, 'weights': [0.7, 0.3]},
        #         {'epoch': 40, 'weights': [0.5, 0.5]},
        #         {'epoch': 60, 'weights': [0.3, 0.3]},
        #         {'epoch': 80, 'weights': [0.0, 0.1]}] \
        #    if should_use_scheduale else []

        # scheduale      892/882
        sched = [{'epoch': 00, 'weights': [0.9, 0.1]},
                 {'epoch': 80, 'weights': [0.5, 0.5]},
                 {'epoch': 120, 'weights': [0.0, 0.1]}] \
            if should_use_scheduale else []

        loss = dict()
        loss['predictions'] = 'logcosh'
        loss['predictions_size'] = 'logcosh'
        loss['distance_matrix'] = distance_matrix_rank_loss_adapter(
            K_losses.kullback_leibler_divergence, 'KL')
        # distance_matrix_logcosh
        # pearson_correlation
        # distance_matrix_rank_loss_adapter(K_losses.kullback_leibler_divergence, 'KL')
        # distance_matrix_rank_loss_adapter(K_losses.poisson, 'poisson')
        # distance_matrix_rank_loss_adapter(K_losses.categorical_crossentropy, 'entropy')
        model.compile(
            learning_rate=1e-3 if (preload_weight is None) else 1e-4,
            loss=loss,
            scheduale=sched
        )  # mean_squared_logarithmic_error, binary_crossentropy, logcosh

        if use_gen:
            generator = DataGeneratorDir(
                data_loader,
                val_factor=0 if skip_validation else 1,
                data_size=data_size,
                model_size=model_size,
                batch_size=batch_size,
                objective=obj,
                rating_scale=rating_scale,
                weighted_rating=('distance-matrix' in obj),
                balanced=False,
                do_augment=do_augment,
                augment=data_augment_params,
                use_class_weight=False,
                use_confidence=False)
            model.load_generator(generator)
        else:
            dataset = load_nodule_dataset(size=data_size,
                                          res=res,
                                          sample=sample,
                                          dataset_type=dataset_type)
            images_train, labels_train, masks_train = prepare_data_direct(
                dataset[2], objective='rating', rating_scale=rating_scale)
            images_valid, labels_valid, masks_valid = prepare_data_direct(
                dataset[1], objective='rating', rating_scale=rating_scale)
            images_train = np.array([
                crop_center(im, msk, size=model_size)[0]
                for im, msk in zip(images_train, masks_train)
            ])
            images_valid = np.array([
                crop_center(im, msk, size=model_size)[0]
                for im, msk in zip(images_valid, masks_valid)
            ])
            model.load_data(images_train,
                            labels_train,
                            images_valid,
                            labels_valid,
                            batch_size=batch_size)

    ## --------------------------------------- ##
    ## ------- Prepare Siamese Architecture ------ ##
    ## --------------------------------------- ##

    if choose_model == "SIAM":
        # run = '300'  # l1, avg-pool (data-aug, balanced=True, class_weight=False)
        # run = '301'  # l1, max-pool (data-aug, balanced=True, class_weight=False)
        # run = '302'  # l1, rmac-pool (data-aug, balanced=True, class_weight=False)
        # run = '310'  # l2, avg-pool (data-aug, balanced=True, class_weight=False)
        # run = '311'  # l2, max-pool (data-aug, balanced=True, class_weight=False)
        # run = '312'  # l2, rmac-pool (data-aug, balanced=True, class_weight=False)
        # run = '320'  # cos, avg-pool (data-aug, balanced=True, class_weight=False)
        # run = '321'  # cos, max-pool (data-aug, balanced=True, class_weight=False)
        # run = '322b'  # cos, rmac-pool (data-aug, balanced=True, class_weight=False)

        # b/c - changed margin-loss params
        # run = '313c'  # l2, max-pool MARGINAL-LOSS (data-aug, balanced=True, class_weight=False)
        # run = '314c'  # l2, rmac-pool MARGINAL-LOSS (data-aug, balanced=True, class_weight=False)
        # run = '323c'  # cos, max-pool MARGINAL-LOSS (data-aug, balanced=True, class_weight=False)
        # run = '324c'  # cos, rmac-pool MARGINAL-LOSS (data-aug, balanced=True, class_weight=False)

        # run = 'zzz'

        # NOTE(review): `local` is not defined in run(); presumably a
        # module-level flag -- confirm.
        batch_size = 64 if local else 128

        # model
        generator = DataGeneratorSiam(data_loader,
                                      data_size=data_size,
                                      model_size=model_size,
                                      batch_size=batch_size,
                                      val_factor=0 if skip_validation else 3,
                                      balanced=True,
                                      objective="malignancy",
                                      do_augment=do_augment,
                                      augment=data_augment_params,
                                      use_class_weight=False)

        model = SiamArch(miniXception_loader,
                         input_shape,
                         output_size=out_size,
                         batch_size=batch_size,
                         distance='l2',
                         normalize=normalize,
                         pooling='msrmac')
        model.model.summary()
        model.compile(learning_rate=1e-3, decay=0)
        if use_gen:
            model.load_generator(generator)
        else:
            imgs_trn, lbl_trn = generator.next_train().__next__()
            imgs_val, lbl_val = generator.next_val().__next__()
            model.load_data(imgs_trn, lbl_trn, imgs_val, lbl_val)

    if choose_model == "SIAM_RATING":
        ### clean set
        # run = '400'  # l2-rmac no-conf
        # run = '401'  # cosine-rmac no-conf
        # run = '402'  # l2-rmac conf
        # run = '403'  # cosine-rmac conf
        # run = '404'  # l2-max no-conf
        # run = '405'  # cosine-max no-conf

        ### primary set
        # run = '410'  # l2-rmac no-conf
        # run = '411'  # cosine-rmac no-conf
        # run = '412'  # l2-rmac conf
        # run = '413'  # cosine-rmac conf
        # run = '414'  # l2-max no-conf
        # run = '415'  # cosine-max no-conf

        run = 'zzz'

        obj = 'rating'  # rating / size / rating_size
        batch_size = 16 if local else 64
        reg_loss = None  # {'SampleCorrating_clusters_distance_and_stdrelation': 0.1}  # 'Dispersion', 'Std', 'FeatureCorrelation', 'SampleCorrelation'

        epoch_pre = 60
        preload_weight = None  # FileManager.Weights('dirR', output_dir=input_dir).name(run='251c{}'.format(config), epoch=70)

        should_use_scheduale = (reg_loss is not None) or (obj == 'rating_size')
        '''
        sched = [{'epoch': 00, 'weights': [0.1, 0.9]},
                 {'epoch': 30, 'weights': [0.4, 0.6]},
                 {'epoch': 60, 'weights': [0.6, 0.4]},
                 {'epoch': 80, 'weights': [0.9, 0.1]},
                 {'epoch': 100, 'weights': [1.0, 0.0]}] \
            if should_use_scheduale else []
        '''
        sched = [{'epoch': 00, 'weights': [0.1, 0.9]},
                 {'epoch': 20, 'weights': [0.4, 0.6]},
                 {'epoch': 30, 'weights': [0.6, 0.4]},
                 {'epoch': 50, 'weights': [0.9, 0.1]},
                 {'epoch': 80, 'weights': [1.0, 0.0]}] \
            if should_use_scheduale else []
        # model
        # NOTE(review): `train_facotr` looks like a misspelling of
        # `train_factor` (DataGeneratorTrip below takes `train_factor`);
        # confirm against DataGeneratorSiam's signature.
        generator = DataGeneratorSiam(data_loader,
                                      data_size=data_size,
                                      model_size=model_size,
                                      batch_size=batch_size,
                                      train_facotr=2,
                                      val_factor=0 if skip_validation else 3,
                                      balanced=False,
                                      objective=obj,
                                      weighted_rating=True,
                                      do_augment=do_augment,
                                      augment=data_augment_params,
                                      use_class_weight=False,
                                      use_confidence=False)

        model = SiamArch(miniXception_loader,
                         input_shape,
                         output_size=out_size,
                         objective=obj,
                         batch_size=batch_size,
                         distance='cosine',
                         normalize=normalize,
                         pooling='rmac',
                         regularization_loss=reg_loss,
                         l1_regularization=False)

        if preload_weight is not None:
            model.load_core_weights(preload_weight)
        model.model.summary()
        model.compile(learning_rate=1e-3,
                      decay=0,
                      loss='logcosh',
                      scheduale=sched)  # mean_squared_error, logcosh
        model.load_generator(generator)

    ## --------------------------------------- ##
    ## ------- Prepare Triplet Architecture ------ ##
    ## --------------------------------------- ##

    if choose_model == "TRIPLET":

        # run = '000'  # rmac softplus, b16
        # run = '001'  # rmac hinge, b16, pre:dirR813-50
        # run = '002'  # rmac hinge, b32, pre:dirR813-50
        # run = '003'  # rmac hinge, b64, pre:dirR813-50
        # run = '004'  # rmac hinge, b128, pre:dirR813-50
        # run = '005'  # rmac hinge, b64, pre:dirR813-50
        run = '006'  # rmac rank, b64, pre:dirR813-50

        # run = 'zzz'

        objective = 'rating'
        use_rank_loss = True

        batch_size = 16 if local else 64

        gen = True
        epoch_pre = 50
        # NOTE(review): `input_dir` is not defined in run(); presumably a
        # module-level global -- confirm.
        preload_weight = FileManager.Weights(
            'dirR', output_dir=input_dir).name(run='813c{}'.format(config),
                                               epoch=epoch_pre)

        # model
        model = TripArch(miniXception_loader,
                         input_shape,
                         objective=objective,
                         output_size=out_size,
                         distance='l2',
                         normalize=True,
                         pooling='msrmac',
                         categorize=use_rank_loss)

        if preload_weight is not None:
            model.load_core_weights(preload_weight)
        model.model.summary()
        model.compile(learning_rate=1e-3, decay=0)

        generator = DataGeneratorTrip(data_loader,
                                      data_size=data_size,
                                      model_size=model_size,
                                      batch_size=batch_size,
                                      objective=objective,
                                      balanced=(objective == 'malignancy'),
                                      categorize=use_rank_loss,
                                      val_factor=0 if skip_validation else 1,
                                      train_factor=2,
                                      do_augment=do_augment,
                                      augment=data_augment_params,
                                      use_class_weight=False,
                                      use_confidence=False)
        if gen:
            model.load_generator(generator)
        else:
            imgs_trn, lbl_trn = generator.next_train().__next__()
            imgs_val, lbl_val = generator.next_val().__next__()
            model.load_data(imgs_trn, lbl_trn, imgs_val, lbl_val)

    ## --------------------------------------- ##
    ## -------      RUN             ------ ##
    ## --------------------------------------- ##

    # The DIR and SIAM branches currently have every `run = ...` commented
    # out; report that clearly instead of failing with UnboundLocalError.
    if run is None:
        raise ValueError(
            "No run id selected for {}: set one of the `run = ...` lines "
            "in its branch".format(choose_model))

    cnf_id = config if config_name == 'LEGACY' else CrossValidationManager(
        config_name).get_run_id(config)
    run_name = '{}c{}'.format(run, cnf_id)
    print('Current Run: {}'.format(run_name))
    if no_training:
        model.last_epoch = epochs
        model.run = run_name
    else:
        # resume epoch numbering after the pre-trained weights, if any
        model.train(run=run_name,
                    epoch=(0 if preload_weight is None else epoch_pre),
                    n_epoch=epochs,
                    gen=use_gen,
                    do_graph=False)

    return model
    # NOTE(review): everything below is unreachable -- it follows
    # `return model` above. It appears to be a fragment of a separate
    # rating-prediction script: it uses run_id, pooling, full, include_unknown,
    # post, PredictRating and PredictMal, none of which are defined in run().
    # The final `for e in epochs:` has no body, so this span does not parse
    # as-is. TODO: remove it, or restore the full original source.
    for conf in [1]:  # range(n_groups):
        run = run_id + 'c{}'.format(conf)
        # always true: only the rating branch is ever taken; the else is dead
        if True:
            print("Predicting Rating for " + run)
            PredRating = PredictRating(pooling=pooling)
            PredRating.load_dataset(data_subset_id=DataSubSet,
                                    full=full,
                                    include_unknown=include_unknown,
                                    size=128,
                                    rating_scale=rating_scale,
                                    configuration=conf)
            preds = []
            valid_epochs = []  # NOTE(review): never used
            for e in epochs:
                WeightsFile = FileManager.Weights('dirR').name(run, epoch=e)
                PredFile = FileManager.Pred(type='rating', pre='dirR')
                out_file = PredFile(run=run, dset=post)

                data, out_filename = PredRating.predict_rating(
                    WeightsFile, out_file)
                images_test, pred, meta_test, classes_test, labels_test, masks_test = data
                # add a leading epoch axis so per-epoch predictions stack below
                preds.append(np.expand_dims(pred, axis=0))
            preds = np.concatenate(preds, axis=0)
            # NOTE(review): the file object from open() is never closed;
            # prefer `with open(out_filename, 'wb') as f: pickle.dump(..., f)`.
            pickle.dump((preds, np.array(epochs), meta_test, images_test,
                         classes_test, labels_test, masks_test),
                        open(out_filename, 'bw'))
        else:
            print("Predicting Malignancy for " + run)
            PredMal = PredictMal(pooling=pooling)
            for e in epochs:
# Network input/output sizes and normalization flag.
in_size = out_size = 128
normalize = True

# Run-mode switches (all disabled).
load = evaluate = force = False

# Dataset subset to use:
#   0 - Test
#   1 - Validation
#   2 - Training
DataSubSet = 2

# Weights to load: siamR run '000' at epoch 5.
# NOTE(review): this rebinds the module-level name `run`, shadowing the
# `run()` function defined earlier in this file -- confirm intended.
run, epoch = '000', 5
WeightsFile = FileManager.Weights('siamR').name(run, epoch=epoch)

# Output pickle path template, formatted with (run, epoch, post).
# Raw string: the Windows separators must stay literal backslashes. In a plain
# literal, "\o", "\e" and "\p" are invalid escape sequences (SyntaxWarning
# since Python 3.12, slated to become an error); the value is unchanged.
pred_file_format = r'.\output\embed\pred_siam{}_E{}_{}.p'


def pred_filename(run, epoch, post):
    """Return the prediction pickle path for *run*, *epoch* and suffix *post*.

    The path comes from filling the module-level ``pred_file_format``
    template with the three arguments, in that order.
    """
    return str.format(pred_file_format, run, epoch, post)


## ========================= ##
## ======= Load Data ======= ##
## ========================= ##

# Choose the output-file suffix `post` for the selected DataSubSet.
# NOTE(review): this chain is truncated in this file -- the `elif` below has
# no body (the next code line is unindented), so it does not parse as-is.
# Restore the missing branch(es) for DataSubSet 1 and 2.
if DataSubSet == 0:
    post = "Test"
elif DataSubSet == 1:
# ===================

# Sizes: dataset patch size, network input size, network output size.
inp_size, net_size, out_size = 144, 128, 128
# square, single-channel network input
input_shape = (net_size, net_size, 1)
res = 'Legacy'
sample = 'Normal'  # alternative: 'UniformNC'

# Dataset subset index:
#   0 - Test
#   1 - Validation
#   2 - Training
DataSubSet = 1
dsets = ['Test', 'Valid', 'Train']

Weights = FileManager.Weights('siam')

# Candidate weight runs and their epochs; the first pair is the one used.
wRuns = ['078X']
wEpchs = [24]

run, epoch = wRuns[0], wEpchs[0]

# Load Data
# =================

images, labels, masks, meta = \
                    prepare_data(load_nodule_dataset(size=inp_size, res=res, sample=sample)[DataSubSet],
                                 categorize=False,
                                 reshuffle=False,
                                 return_meta=True,