コード例 #1
0
])
# Flatten the censorship indicator to a 1-D vector of length N.
Censored = Data['Censored'].reshape([
    N,
])
# Feature names are stored under the key '<dtype>_Symbs'.
fnames = Data[dtype + '_Symbs']
# Keep only the part of each name before the first space.
fnames = [j.split(' ')[0] for j in fnames]
# Drop the reference to the loaded container; only the extracted
# variables are used from here on.
Data = None

#%%
# Get result files
#==============================================================================

# Getting at-risk groups.
# calc_at_risk receives an "observed" indicator (1 - Censored) and hands
# back times, observed flags, at-risk indices and features.
# NOTE(review): presumably the outputs are re-ordered by survival time --
# confirm against SurvivalUtils.calc_at_risk.
t_batch, o_batch, at_risk_batch, x_batch = \
    sUtils.calc_at_risk(Survival,
                        1-Censored,
                        Features)

# NOTE(review): this terminates the script when it is run top-to-bottom, so
# the mask code below only executes when its cell is run interactively.
# Looks like a debugging stop -- confirm it is intentional.
sys.exit()

#%%
# Get mask (to be multiplied by Pij) ******************

# Square mask with one row and one column per case in the batch.
n_batch = t_batch.shape[0]
Pij_mask = np.zeros((n_batch, n_batch))

# Get difference in outcomes between all cases.
# Broadcasting a row against a column gives every pairwise |t_i - t_j|
# (assumes t_batch is 1-D -- TODO confirm).
# NOTE(review): outcome_diff is only defined when mask_type == 'observed'.
if mask_type == 'observed':
    outcome_diff = np.abs(t_batch[None, :] - t_batch[:, None])

for idx in range(n_batch):
コード例 #2
0
import sys
sys.path.append('/home/mohamed/Desktop/CooperLab_Research/KNN_Survival/Codes')
import SurvivalUtils as sUtils

import tensorflow as tf
import numpy as np

#%% 
#
# Generate simulated data
#
# n simulated cases, each with d features
n, d = 30, 140

# Random features, integer event times in [0, 300) and a 0/1 censorship flag.
# The three draws are kept in this order so a seeded run is reproducible.
X_input = np.random.rand(n, d)
T = np.random.randint(low=0, high=300, size=(n,))
C = np.random.randint(low=0, high=2, size=(n,))

# calc_at_risk expects an "observed" indicator, i.e. the complement of C
T, O, at_risk, X_input = sUtils.calc_at_risk(T, 1 - C, X_input)

#%%

# -----------------------------
# Add to graph (for demo): start from an empty default graph and wrap every
# array in a TF variable (times and observed flags as float32).
tf.reset_default_graph()
X_input = tf.Variable(initial_value=X_input)
T = tf.Variable(initial_value=T, dtype='float32')
O = tf.Variable(initial_value=O, dtype='float32')
at_risk = tf.Variable(initial_value=at_risk)

# Stand-in for the NCA transform: for now the features are used unchanged
X_transformed = X_input

# no of feats and split size
コード例 #3
0
    def predict(self,
                neighbor_idxs,
                Survival_train,
                Censored_train,
                Survival_test=None,
                Censored_test=None,
                K=15,
                Method='non-cumulative'):
        """
        Predict testing set using 'prototype' (i.e. training) set using KNN

        neighbor_idxs - indices of nearest neighbors; (N_test, N_train)
        Survival_train - training sample time-to-event; (N,) np array
        Censored_train - training sample censorship status; (N,) np array
        Survival_test - optional testing time-to-event, used only for c-index
        Censored_test - optional testing censorship status; required whenever
                        Survival_test is given
        K           - number of nearest-neighbours to use, int
        Method      - cumulative vs non-cumulative probability

        Returns (T_test, CI):
        T_test - (N_test,) np array of predicted survival times
        CI     - c-index from sUtils.c_index, or 0 when Survival_test is None

        Raises ValueError when Method is not 'cumulative' or 'non-cumulative'.
        """

        # Fail fast on a bad Method instead of after the setup work
        if Method not in ('non-cumulative', 'cumulative'):
            raise ValueError(
                "Method is either 'cumulative' or 'non-cumulative'.")

        # Keep only desired K
        neighbor_idxs = neighbor_idxs[:, 0:K]

        # Initialize
        N_test = neighbor_idxs.shape[0]
        # The slice above yields fewer than K columns when fewer than K
        # prototypes exist, so count what is really there.
        n_neighbors = neighbor_idxs.shape[1]
        T_test = np.zeros([N_test])

        if Method == 'non-cumulative':

            # Convert outcomes to "alive status" at each time point
            alive_train = sUtils.getAliveStatus(Survival_train, Censored_train)

            # Get survival prediction for each patient
            for idx in range(N_test):

                # Indexing with an index array copies, so the in-place edit
                # below never touches alive_train itself.
                status = alive_train[neighbor_idxs[idx, :], :]
                # negative entries mark an unknown status at that time point
                totalKnown = np.sum(status >= 0, axis=0)
                status[status < 0] = 0

                # remove timepoints where there are no known statuses
                status = status[:, totalKnown != 0]
                totalKnown = totalKnown[totalKnown != 0]

                # get "average" predicted survival time
                status = np.sum(status, axis=0) / totalKnown

                # now get overall time prediction
                T_test[idx] = np.sum(status)

        else:  # Method == 'cumulative'

            for idx in range(N_test):

                # Get at-risk groups for each time point for nearest neighbors
                T = Survival_train[neighbor_idxs[idx, :]]
                O = 1 - Censored_train[neighbor_idxs[idx, :]]
                T, O, at_risk, _ = sUtils.calc_at_risk(T, O)

                # at_risk is an offset into the neighbour list, so the number
                # still at risk is measured against the actual neighbour
                # count, which may be smaller than the requested K.
                N_at_risk = n_neighbors - at_risk

                # Calculate cumulative probability of survival
                P = np.cumprod((N_at_risk - O) / N_at_risk)

                # now get overall time prediction
                T_test[idx] = np.sum(P)

        # Get c-index
        #======================================================================
        CI = 0
        if Survival_test is not None:
            assert (Censored_test is not None)
            CI = sUtils.c_index(T_test,
                                Survival_test,
                                Censored_test,
                                prediction_type='survival_time')

        return T_test, CI
コード例 #4
0
    # Censorship indicator as a 1-D int32 vector of length N.
    Censored = np.int32(Data['Censored']).reshape([
        N,
    ])
    # Feature names for the integrated feature set (gene symbols are the
    # commented-out alternative).
    fnames = Data['Integ_Symbs']
    #fnames = Data['Gene_Symbs']

    # remove zero-variance features
    # (fvars actually holds the per-feature standard deviation along axis 0;
    # zero std is equivalent to zero variance)
    fvars = np.std(Features, 0)
    keep = fvars > 0
    Features = Features[:, keep]
    # boolean-mask the names the same way so they stay aligned with columns
    fnames = fnames[keep]
    N, D = Features.shape  # after feature removal

    # Getting at-risk groups (training set)
    # NOTE(review): this call passes (Features, Survival, observed) and
    # unpacks (Features, Survival, Observed, at_risk); confirm that this
    # matches the calc_at_risk signature this file imports.
    Features, Survival, Observed, at_risk = \
      sUtils.calc_at_risk(Features, Survival, 1-Censored)

    ## Limit N (for prototyping) ----
    #n = 100
    #Features = Features[0:n, :]
    #Survival = Survival[0:n]
    #Observed = Observed[0:n]
    #at_risk = at_risk[0:n]
    #--------------------------------

    # *************************************************************
    # Z-scoring survival to prevent numerical errors
    # (subtract the mean, divide by the standard deviation)
    Survival = (Survival - np.mean(Survival)) / np.std(Survival)
    # *************************************************************

    #%%============================================================================
コード例 #5
0
    def train(self,
              features,
              survival,
              censored,
              features_valid=None,
              survival_valid=None,
              censored_valid=None,
              COMPUT_GRAPH_PARAMS=None,
              BATCH_SIZE=20,
              PLOT_STEP=10,
              MODEL_SAVE_STEP=10,
              MAX_ITIR=100):
        """
        Train a survivalNCA model and return the learned feature weights.

        features - (N,D) np array
        survival and censored - (N,) np array
        features_valid, survival_valid, censored_valid - optional validation
            set laid out like the training set; when features_valid is given
            the other two are required
        COMPUT_GRAPH_PARAMS - dict of options forwarded to
            self._build_computational_graph ('dim_input' is filled in here).
            The caller's dict is copied, never modified.
        BATCH_SIZE - batch size handed to dm.get_balanced_batches; when it is
            not smaller than N the whole training set forms a single batch
        PLOT_STEP - _monitorProgress runs every PLOT_STEP epochs
        MODEL_SAVE_STEP - currently unused (periodic saving is disabled)
        MAX_ITIR - maximum number of epochs run by this call

        Returns the evaluated graph.W, which is also written to
        self.RESULTPATH + 'model/' + self.description + 'featWeights.npy'.
        A KeyboardInterrupt only stops the epoch loop; the weights are still
        fetched, saved and returned.
        """

        #pUtils.Log_and_print("Training survival NCA model.")

        # Initial preprocessing and sanity checks
        #======================================================================

        #pUtils.Log_and_print("Initial preprocessing.")

        assert len(features.shape) == 2
        assert len(survival.shape) == 1
        assert len(censored.shape) == 1

        USE_VALID = features_valid is not None
        if USE_VALID:
            assert (features_valid.shape[1] == features.shape[1])
            assert (survival_valid is not None)
            assert (censored_valid is not None)

        # normalize (for numeric stability)
        epsilon = 1e-10
        survival = (survival / self.T_MAX) + epsilon
        if USE_VALID:
            survival_valid = (survival_valid / self.T_MAX) + epsilon

        # Define computational graph
        #======================================================================

        # Work on a private copy: a shared default dict (or the caller's own
        # dict) must not pick up 'dim_input' from this call.
        graph_params = dict(COMPUT_GRAPH_PARAMS or {})
        graph_params['dim_input'] = features.shape[1]
        graph = self._build_computational_graph(graph_params)

        # Begin session
        #======================================================================

        #print("Running TF session.")
        #pUtils.Log_and_print("Running TF session.")

        with tf.Session() as sess:

            # Initial ground work
            #==================================================================

            # op to save/restore all the variables
            saver = tf.train.Saver()

            if "checkpoint" in os.listdir(self.WEIGHTPATH):
                # load existing weights
                #pUtils.Log_and_print("Restoring saved model ...")
                saver.restore(sess,
                              self.WEIGHTPATH + self.description + ".ckpt")
                #pUtils.Log_and_print("Model restored.")

            else:
                # start a new model
                sess.run(tf.global_variables_initializer())

            # for tensorboard visualization
            #train_writer = tf.summary.FileWriter(self.RESULTPATH + 'model/tensorboard',
            #                                     sess.graph)

            # Define some methods
            #==================================================================

            # periodically save model
            def _saveTFmodel():
                """Saves model attributes (TF weight saving is disabled)"""

                # save weights
                #pUtils.Log_and_print("\nSaving TF model weights...")
                #save_path = saver.save(sess, \
                #                self.WEIGHTPATH + self.description + ".ckpt")
                #pUtils.Log_and_print("Model saved in file: %s" % save_path)

                # save attributes
                self.save()

            # monitor
            def _monitorProgress():
                """Monitor cost (the plotting call is currently disabled)"""

                # NOTE(review): the concatenation needs a 2-D cost array, i.e.
                # each recorded epoch cost must itself be array-like --
                # confirm against the shape of graph.cost.
                cs = np.array(self.Costs_epochLevel_train)
                epoch_no = np.arange(len(cs))
                cs = np.concatenate((epoch_no[:, None], cs), axis=1)

                cs_valid = None
                if USE_VALID:
                    cs_valid = np.array(self.Costs_epochLevel_valid)

                #timestamp = str(datetime.datetime.today()).replace(' ','_')
                #timestamp.replace(":", '_')
                #self._plotMonitor(arr= cs, arr2= cs_valid,
                #             title= "cost vs. epoch",
                #             xlab= "epoch", ylab= "cost",
                #             savename= self.RESULTPATH + "plots/" +
                #              self.description + "cost_" + timestamp + ".svg")

            # Begin epochs
            #==================================================================

            # Bound up front so the final weight fetch below still works when
            # no batch was ever run (MAX_ITIR == 0 or an immediate interrupt).
            itir = 0
            feed_dict = None

            try:
                #print("\n\tepoch\tbatch\tcost")
                #print("\t-----------------------")

                while itir < MAX_ITIR:

                    #pUtils.Log_and_print("\n\tTraining epoch {}\n".format(self.EPOCHS_RUN))

                    itir += 1
                    cost_tot = 0
                    cost_tot_valid = 0

                    # Shuffle so that training batches differ every epoch
                    #==========================================================

                    idxs = np.arange(features.shape[0])
                    np.random.shuffle(idxs)
                    features = features[idxs, :]
                    survival = survival[idxs]
                    censored = censored[idxs]

                    # Divide into balanced batches
                    #==========================================================

                    # Get balanced batches (if relevant)
                    if BATCH_SIZE < censored.shape[0]:
                        batchIdxs = dm.get_balanced_batches(
                            censored, BATCH_SIZE=BATCH_SIZE)
                    else:
                        batchIdxs = [np.arange(censored.shape[0])]

                    if USE_VALID:
                        batchIdxs_valid = \
                            dm.get_balanced_batches(censored_valid, BATCH_SIZE = BATCH_SIZE)

                    # Run over training set
                    #==========================================================

                    for batchidx, batch in enumerate(batchIdxs):

                        # Getting at-risk groups
                        t_batch, o_batch, at_risk_batch, x_batch = \
                            sUtils.calc_at_risk(survival[batch],
                                                1-censored[batch],
                                                features[batch, :])

                        # run optimizer and fetch cost

                        feed_dict = {
                            graph.X_input: x_batch,
                            graph.T: t_batch,
                            graph.O: o_batch,
                            graph.At_Risk: at_risk_batch,
                        }

                        _, cost = sess.run([graph.optimizer, graph.cost], \
                                            feed_dict = feed_dict)

                        # normalize cost for sample size
                        cost = cost / len(batch)

                        # record/append cost
                        #self.Costs_batchLevel_train.append(cost)
                        cost_tot += cost

                        #print("\t{}\t{}\t{}".format(self.EPOCHS_RUN, batchidx, round(cost[0], 3)))
                        #pUtils.Log_and_print("\t\tTraining: Batch {} of {}, cost = {}".\
                        #     format(batchidx, len(batchIdxs)-1, round(cost[0], 3)))

                    # Run over validation set
                    #==========================================================
                    if USE_VALID:
                        for batchidx, batch in enumerate(batchIdxs_valid):

                            # Getting at-risk groups. The batch indices come
                            # from censored_valid, so they must index the
                            # validation arrays, not the training ones.
                            t_batch, o_batch, at_risk_batch, x_batch = \
                                sUtils.calc_at_risk(survival_valid[batch],
                                                    1-censored_valid[batch],
                                                    features_valid[batch, :])

                            # fetch cost (no optimizer step on validation)

                            feed_dict = {
                                graph.X_input: x_batch,
                                graph.T: t_batch,
                                graph.O: o_batch,
                                graph.At_Risk: at_risk_batch,
                            }

                            cost = sess.run(graph.cost, feed_dict=feed_dict)

                            # normalize cost for sample size
                            cost = cost / len(batch)

                            # record/append cost
                            #self.Costs_batchLevel_valid.append(cost)
                            cost_tot_valid += cost

                            #pUtils.Log_and_print("\t\tValidation: Batch {} of {}, cost = {}".\
                            #     format(batchidx, len(batchIdxs_valid)-1, round(cost[0], 3)))

                    # Update and save
                    #==========================================================

                    # update epochs and append costs
                    self.EPOCHS_RUN += 1
                    self.Costs_epochLevel_train.append(cost_tot)
                    if USE_VALID:
                        self.Costs_epochLevel_valid.append(cost_tot_valid)

                    # periodically save model
                    #if (self.EPOCHS_RUN % MODEL_SAVE_STEP) == 0:
                    #    _saveTFmodel()

                    # periodically monitor progress
                    if (self.EPOCHS_RUN % PLOT_STEP == 0) and \
                        (self.EPOCHS_RUN > 0):
                        _monitorProgress()

            except KeyboardInterrupt:
                # deliberate: Ctrl-C ends training but the results are kept
                pass

            # save final model and plot costs
            #_saveTFmodel()
            _monitorProgress()

            #pUtils.Log_and_print("Finished training model.")
            #pUtils.Log_and_print("Obtaining final results.")

            # save learned weights
            W = sess.run(graph.W, feed_dict=feed_dict)
            np.save(self.RESULTPATH + 'model/' + self.description + \
                    'featWeights.npy', W)

        return W
コード例 #6
0
    def train(self,
              features,
              survival,
              censored,
              features_valid=None,
              survival_valid=None,
              censored_valid=None,
              graph_hyperparams=None,
              BATCH_SIZE=20,
              PLOT_STEP=10,
              MODEL_SAVE_STEP=10,
              MAX_ITIR=100,
              MODEL_BUFFER=4,
              EARLY_STOPPING=False,
              MONITOR=True,
              PLOT=True,
              K=35,
              Method='cumulative-time',
              norm=2):
        """
        Train a survivalNCA model and return the learned NCA matrix.

        features - (N,D) np array
        survival and censored - (N,) np array
        features_valid, survival_valid, censored_valid - optional validation
            set; when features_valid is given the other two are required
        graph_hyperparams - dict merged with self.default_graph_hyperparams;
            ALPHA, LAMBDA, SIGMA and DROPOUT_FRACTION are fed to the graph
        BATCH_SIZE - batch size for dm.get_balanced_batches; when it is not
            smaller than N the whole training set is a single batch
        PLOT_STEP - _monitorProgress runs every PLOT_STEP epochs (if MONITOR)
        MODEL_SAVE_STEP - currently unused (periodic saving is disabled)
        MAX_ITIR - maximum number of epochs run by this call
        MODEL_BUFFER - number of weight snapshots kept for early stopping
        EARLY_STOPPING - stop once the mean validation c-index of the last
            MODEL_BUFFER epochs drops below that of the MODEL_BUFFER epochs
            before them; requires a validation set
        MONITOR - print per-epoch numbers, write cost/c-index files and save
            the final matrix
        PLOT - also draw the monitoring plots via self._plotMonitor
        K, Method - forwarded to knn.SurvivalKNN.predict
        norm - forwarded to knn.SurvivalKNN._get_neighbor_idxs

        Returns the NCA matrix evaluated with dropout switched off (or the
        restored snapshot when early stopping fired). Returns None when not a
        single epoch completed, e.g. MAX_ITIR == 0 or an interrupt during the
        first epoch.
        """

        #pUtils.Log_and_print("Training survival NCA model.")

        # Initial preprocessing and sanity checks
        #======================================================================

        #pUtils.Log_and_print("Initial preprocessing.")

        assert len(features.shape) == 2
        assert len(survival.shape) == 1
        assert len(censored.shape) == 1

        # Read D only after the rank check: on a 1-D input shape[1] would
        # raise IndexError before the assertion above could report it.
        D = features.shape[1]

        USE_VALID = features_valid is not None
        if USE_VALID:
            assert (features_valid.shape[1] == D)
            assert (survival_valid is not None)
            assert (censored_valid is not None)

        if EARLY_STOPPING:
            assert USE_VALID

        # Define computational graph
        #======================================================================

        # None stands in for "no overrides" so no dict is shared across calls
        if graph_hyperparams is None:
            graph_hyperparams = {}

        graph_hyperparams = \
            pUtils.Merge_dict_with_default(\
                    dict_given = graph_hyperparams,
                    dict_default = self.default_graph_hyperparams,
                    keys_Needed = self.userspecified_graph_hyperparams)

        # Begin session
        #======================================================================

        #print("Running TF session.")
        #pUtils.Log_and_print("Running TF session.")

        with tf.Session() as sess:

            # Initial ground work
            #==================================================================

            # op to save/restore all the variables
            saver = tf.train.Saver()

            if "checkpoint" in os.listdir(self.WEIGHTPATH):
                # load existing weights
                #pUtils.Log_and_print("Restoring saved model ...")
                saver.restore(sess,
                              self.WEIGHTPATH + self.description + ".ckpt")
                #pUtils.Log_and_print("Model restored.")

            else:
                # start a new model
                sess.run(tf.global_variables_initializer())

            # for tensorboard visualization
            #train_writer = tf.summary.FileWriter(self.RESULTPATH + 'model/tensorboard',
            #                                     sess.self.graph)

            # Define some methods
            #==================================================================

            # periodically save model
            def _saveTFmodel():
                """Saves model attributes (TF weight saving is disabled)"""

                # save weights
                #pUtils.Log_and_print("\nSaving TF model weights...")
                #save_path = saver.save(sess, \
                #                self.WEIGHTPATH + self.description + ".ckpt")
                #pUtils.Log_and_print("Model saved in file: %s" % save_path)

                # save attributes
                self.save()

            # monitor
            def _monitorProgress(snapshot_idx=None):
                """
                Monitor cost - save txt and plots cost
                """
                # The three histories can differ in length after a keyboard
                # interrupt; only show epochs present in all of them.
                max_epoch = np.min([
                    len(self.Costs_epochLevel_train),
                    len(self.CIs_train),
                    len(self.CIs_valid)
                ])

                # concatenate costs
                costs = np.array(self.Costs_epochLevel_train[0:max_epoch])
                cis_train = np.array(self.CIs_train[0:max_epoch])
                if USE_VALID:
                    cis_valid = np.array(self.CIs_valid[0:max_epoch])
                else:
                    cis_valid = None

                # two-column tables: epoch number, value
                epoch_no = np.arange(max_epoch)
                costs = np.concatenate((epoch_no[:, None], costs[:, None]),
                                       axis=1)
                cis_train = np.concatenate(
                    (epoch_no[:, None], cis_train[:, None]), axis=1)

                # Saving raw numbers for later reference
                savename = self.RESULTPATH + "plots/" + self.description + self.timestamp

                with open(savename + '_costs.txt', 'wb') as f:
                    np.savetxt(f, costs, fmt='%s', delimiter='\t')

                with open(savename + '_cis_train.txt', 'wb') as f:
                    np.savetxt(f, cis_train, fmt='%s', delimiter='\t')

                if USE_VALID:
                    with open(savename + '_cis_valid.txt', 'wb') as f:
                        np.savetxt(f, cis_valid, fmt='%s', delimiter='\t')

                #
                # Note, plotting would not work when running
                # this using screen (Xdisplay is not supported)
                #
                if PLOT:
                    self._plotMonitor(arr=costs,
                                      title="Cost vs. epoch",
                                      xlab="epoch",
                                      ylab="Cost",
                                      savename=savename + "_costs.svg")
                    self._plotMonitor(arr=cis_train,
                                      arr2=cis_valid,
                                      title="C-index vs. epoch",
                                      xlab="epoch",
                                      ylab="C-index",
                                      savename=savename + "_Ci.svg",
                                      snapshot_idx=snapshot_idx)

            # Begin epochs
            #==================================================================

            # Bound before the loop so the code after it never meets an
            # undefined name when no epoch completes.
            itir = 0
            W_grabbed = None

            try:
                if MONITOR:
                    print("\n\tepoch\tcost\tCi_train\tCi_valid")
                    print("\t----------------------------------------------")

                knnmodel = knn.SurvivalKNN(self.RESULTPATH,
                                           description=self.description)

                # Initialize weights buffer
                # (keep a snapshot of model for early stopping)
                # each "channel" in 3rd dim is one snapshot of the model
                if USE_VALID:
                    Ws = np.zeros((D, D, MODEL_BUFFER))
                    Cis = []

                while itir < MAX_ITIR:

                    #pUtils.Log_and_print("\n\tTraining epoch {}\n".format(self.EPOCHS_RUN))

                    itir += 1
                    cost_tot = 0
                    self._update_timestamp()

                    # Divide into balanced batches
                    #==========================================================

                    n = censored.shape[0]

                    # Get balanced batches (if relevant)
                    if BATCH_SIZE < n:
                        # Shuffle so that training batches differ every epoch
                        idxs = np.arange(features.shape[0])
                        np.random.shuffle(idxs)
                        features = features[idxs, :]
                        survival = survival[idxs]
                        censored = censored[idxs]
                        # stochastic mini-batch GD
                        batchIdxs = dm.get_balanced_batches(
                            censored, BATCH_SIZE=BATCH_SIZE)
                    else:
                        # Global GD
                        batchIdxs = [np.arange(n)]

                    # Run over training set
                    #==========================================================

                    for batchidx, batch in enumerate(batchIdxs):

                        # Getting at-risk groups
                        t_batch, o_batch, at_risk_batch, x_batch = \
                            sUtils.calc_at_risk(survival[batch],
                                                1-censored[batch],
                                                features[batch, :])

                        # Get at-risk mask (to be multiplied by Pij)
                        n_batch = t_batch.shape[0]

                        # print("\tbatch {} of {}".format(batchidx, n_batch-1))

                        Pij_mask = np.zeros((n_batch, n_batch))
                        for idx in range(n_batch):
                            # only observed cases
                            if o_batch[idx] == 1:
                                # only at-risk cases
                                Pij_mask[idx, at_risk_batch[idx]:] = 1

                        # run optimizer and fetch cost
                        feed_dict = {
                            self.graph.X_input:
                            x_batch,
                            self.graph.Pij_mask:
                            Pij_mask,
                            self.graph.ALPHA:
                            graph_hyperparams['ALPHA'],
                            self.graph.LAMBDA:
                            graph_hyperparams['LAMBDA'],
                            self.graph.SIGMA:
                            graph_hyperparams['SIGMA'],
                            self.graph.DROPOUT_FRACTION:
                            graph_hyperparams['DROPOUT_FRACTION'],
                        }
                        _, cost = sess.run(
                            [self.graph.optimizer, self.graph.cost],
                            feed_dict=feed_dict)

                        # normalize cost for sample size
                        cost = cost / len(batch)

                        # record/append cost
                        #self.Costs_batchLevel_train.append(cost)
                        cost_tot += cost

                        #pUtils.Log_and_print("\t\tTraining: Batch {} of {}, cost = {}".\
                        #     format(batchidx, len(batchIdxs)-1, round(cost[0], 3)))

                    # Now get final NCA matrix (without dropout)
                    #==========================================================

                    # reuse the last batch's feed, with dropout switched off
                    feed_dict[self.graph.DROPOUT_FRACTION] = 0
                    W_grabbed = self.graph.W.eval(feed_dict=feed_dict)

                    # Get Ci for training/validation set
                    #==========================================================

                    # transform
                    x_train_transformed = np.dot(features, W_grabbed)
                    if USE_VALID:
                        x_valid_transformed = np.dot(features_valid, W_grabbed)

                    # get neighbor indices
                    neighbor_idxs_train = \
                        knnmodel._get_neighbor_idxs(x_train_transformed,
                                                    x_train_transformed,
                                                    norm=norm)
                    if USE_VALID:
                        neighbor_idxs_valid = \
                            knnmodel._get_neighbor_idxs(x_valid_transformed,
                                                        x_train_transformed,
                                                        norm=norm)

                    # Predict training/validation set
                    _, Ci_train = knnmodel.predict(neighbor_idxs_train,
                                                   Survival_train=survival,
                                                   Censored_train=censored,
                                                   Survival_test=survival,
                                                   Censored_test=censored,
                                                   K=K,
                                                   Method=Method)
                    if USE_VALID:
                        _, Ci_valid = knnmodel.predict(
                            neighbor_idxs_valid,
                            Survival_train=survival,
                            Censored_train=censored,
                            Survival_test=survival_valid,
                            Censored_test=censored_valid,
                            K=K,
                            Method=Method)
                    else:
                        # placeholder so the histories stay the same length
                        Ci_valid = 0

                    if MONITOR:
                        print("\t{}\t{}\t{}\t{}".format(\
                                self.EPOCHS_RUN,
                                round(cost_tot, 3),
                                round(Ci_train, 3),
                                round(Ci_valid, 3)))

                    # Update and save
                    #==========================================================

                    # update epochs and append costs
                    self.EPOCHS_RUN += 1
                    self.Costs_epochLevel_train.append(cost_tot)
                    self.CIs_train.append(Ci_train)
                    self.CIs_valid.append(Ci_valid)

                    # periodically save model
                    #if (self.EPOCHS_RUN % MODEL_SAVE_STEP) == 0:
                    #    _saveTFmodel()

                    # periodically monitor progress
                    if MONITOR:
                        if (self.EPOCHS_RUN % PLOT_STEP == 0) and \
                            (self.EPOCHS_RUN > 0):
                            _monitorProgress()

                    # Early stopping
                    #==========================================================

                    if EARLY_STOPPING:
                        # Save snapshot (the buffer is used as a ring)
                        Ws[:, :, itir % MODEL_BUFFER] = W_grabbed
                        Cis.append(Ci_valid)

                        # Stop when overfitting starts to occur
                        if len(Cis) > (2 * MODEL_BUFFER):
                            ci_new = np.mean(Cis[-MODEL_BUFFER:])
                            ci_old = np.mean(Cis[-2 *
                                                 MODEL_BUFFER:-MODEL_BUFFER])

                            if ci_new < ci_old:
                                # the next ring slot holds the oldest snapshot
                                snapshot_idx = (itir - MODEL_BUFFER +
                                                1) % MODEL_BUFFER
                                W_grabbed = Ws[:, :, snapshot_idx]
                                break

            except KeyboardInterrupt:
                # deliberate: Ctrl-C ends training but the results are kept
                pass

            #pUtils.Log_and_print("Finished training model.")
            #pUtils.Log_and_print("Obtaining final results.")

            if MONITOR:
                # save final model
                #_saveTFmodel()

                # plot costs
                if EARLY_STOPPING:
                    snapshot = itir - MODEL_BUFFER
                else:
                    snapshot = None
                _monitorProgress(snapshot_idx=snapshot)

                # save learned weights (nothing to save if no epoch finished)
                if W_grabbed is not None:
                    np.save(self.RESULTPATH + 'model/' + self.description + \
                            self.timestamp + 'NCA_matrix.npy', W_grabbed)

        return W_grabbed
コード例 #7
0
#==============================================================================

if GETLOGS == True:
    # Hold out the tail of the data as a validation set: the first
    # PERC_TRAIN fraction of the rows stays for training, the rest is split
    # off before the training arrays are truncated.
    N_tot = np.size(Features, 0)
    _split_at = int(PERC_TRAIN * N_tot)
    _valid_rows = slice(_split_at, N_tot)
    _train_rows = slice(0, _split_at)

    Features_valid = Features[_valid_rows, :]
    Survival_valid = Survival[_valid_rows]
    Censored_valid = Censored[_valid_rows]

    Features = Features[_train_rows, :]
    Survival = Survival[_train_rows]
    Censored = Censored[_train_rows]

    # At-risk groups for the validation set
    (Features_valid, Survival_valid,
     Observed_valid, at_risk_valid) = sUtils.calc_at_risk(
         Features_valid, Survival_valid, 1 - Censored_valid)

# At-risk groups for the training set (computed whether or not a validation
# set was split off above)
Features, Survival, Observed, at_risk = sUtils.calc_at_risk(
    Features, Survival, 1 - Censored)

#%%============================================================================
# Setting params and other stuff
#==============================================================================

# BayesOpt proposes float values, so the hyper-parameters that must be whole
# numbers are cast back to int before use.
EPOCHS = int(EPOCHS)
DEPTH = int(DEPTH)
MAXWIDTH = int(MAXWIDTH)