Example #1
            # Training ======================================================================
            batch_size = args.batch_size
            lr_finder = model.lr_finder(x_train,
                                        y_train,
                                        batch_size,
                                        tolerance=3)
            best = lr_finder.get_best_lr()

            model.optimizer.set_lr(best)

            epochs = args.epochs
            callbacks = [tt.callbacks.EarlyStopping()]
            verbose = True
            log = model.fit(x_train,
                            y_train,
                            batch_size,
                            epochs,
                            callbacks,
                            verbose,
                            val_data=val)
            # Evaluation ===================================================================
            surv = model.interpolate(10).predict_surv_df(x_test)
            ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

            ctd = ev.concordance_td('antolini')  # time-dependent concordance (Antolini et al.)
            time_grid = np.linspace(durations_test.min(), durations_test.max(),
                                    100)

            ibs = ev.integrated_brier_score(time_grid)
            nbll = ev.integrated_nbll(time_grid)
            val_loss = min(log.monitors['val_'].scores['loss']['score'])
            # completion of the truncated snippet: log the metrics computed
            # above (the exact keys used in the original are not shown)
            wandb.log({
                'val_loss': val_loss,
                'ctd': ctd,
                'ibs': ibs,
                'nbll': nbll,
            })
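
For context, a minimal sketch of the input layout EvalSurv consumes: surv is a DataFrame indexed by evaluation times with one column per test individual. All values below are synthetic and purely illustrative.

# Illustrative only: the survival-probability layout EvalSurv expects.
import numpy as np
import pandas as pd
from pycox.evaluation import EvalSurv

times = np.array([0.0, 2.5, 5.0, 7.5, 10.0])        # row index: evaluation times
surv = pd.DataFrame({0: [1.0, 0.9, 0.7, 0.5, 0.3],  # one column per individual
                     1: [1.0, 0.8, 0.6, 0.4, 0.2],
                     2: [1.0, 0.95, 0.9, 0.8, 0.7]},
                    index=times)
durations = np.array([4.0, 7.5, 9.0])               # observed times
events = np.array([1, 0, 1])                        # event indicators (1 = event)
ev = EvalSurv(surv, durations, events, censor_surv='km')
print(ev.concordance_td('antolini'))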
Example #2
                    model_filename = \
                        os.path.join(output_dir, 'models',
                                     '%s_%s_exp%d_bs%d_nep%d_nla%d_nno%d_'
                                     % (init_survival_estimator_name, dataset,
                                        experiment_idx, init_batch_size,
                                        init_n_epochs, n_layers, n_nodes)
                                     +
                                     'lr%f_a%f_s%f_nd%d_test.pt'
                                     % (init_lr, init_alpha, init_sigma,
                                        init_num_durations))
                    time_elapsed_filename = model_filename[:-3] + '_time.txt'
                    print('*** Pre-training...')
                    # assert os.path.isfile(model_filename)
                    if not os.path.isfile(model_filename):
                        # print('*** Fitting with hyperparam:', hyperparam, flush=True)
                        surv_model.fit(X_train_std,
                                       y_train_discrete,
                                       init_batch_size,
                                       init_n_epochs,
                                       verbose=False)
                        elapsed = time.time() - tic  # 'tic' was set with time.time() before fitting (earlier in the truncated script)
                        # print('Time elapsed: %f second(s)' % elapsed)
                        np.savetxt(time_elapsed_filename,
                                   np.array(elapsed).reshape(1, -1))
                        surv_model.save_net(model_filename)
                    else:
                        # print('*** Loading ***', flush=True)
                        surv_model.load_net(model_filename)
                        elapsed = float(np.loadtxt(time_elapsed_filename))
                        # print('Time elapsed (from previous fitting): %f second(s)'
                        #       % elapsed)
                    surv_model.net.train()
Example #3
                    model_filename = \
                        os.path.join(output_dir, 'models',
                                     '%s_%s_exp%d_bs%d_nep%d_nla%d_nno%d_'
                                     % (survival_estimator_name, dataset,
                                        experiment_idx, batch_size, n_epochs,
                                        n_layers, n_nodes)
                                     +
                                     'lr%f_a%f_s%f_nd%d_cv%d.pt'
                                     % (lr, alpha, sigma,
                                        num_durations, cross_val_idx))
                    time_elapsed_filename = model_filename[:-3] + '_time.txt'
                    if not os.path.isfile(model_filename):
                        # print('*** Fitting with hyperparam:', hyperparam,
                        #       '-- cross val index:', cross_val_idx, flush=True)
                        surv_model.fit(fold_X_train_std,
                                       fold_y_train_discrete,
                                       batch_size,
                                       n_epochs,
                                       verbose=False)
                        elapsed = time.time() - tic
                        print('Time elapsed: %f second(s)' % elapsed)
                        np.savetxt(time_elapsed_filename,
                                   np.array(elapsed).reshape(1, -1))
                        surv_model.save_net(model_filename)
                    else:
                        # print('*** Loading ***', flush=True)
                        surv_model.load_net(model_filename)
                        elapsed = float(np.loadtxt(time_elapsed_filename))
                        print('Time elapsed (from previous fitting): ' +
                              '%f second(s)' % elapsed)

                    surv_df = \
Example #4
 model_filename = \
     os.path.join(output_dir, 'models',
                  '%s_%s_exp%d_%s_bs%d_nep%d_nla%d_nno%d_'
                  % (survival_estimator_name, dataset,
                     experiment_idx, val_string, batch_size,
                     n_epochs, n_layers, n_nodes)
                  +
                  'lr%f_a%f_s%f_nd%d_fold%d.pt'
                  % (lr, alpha, sigma,
                     num_durations, fold_idx))
 time_elapsed_filename = model_filename[:-3] + '_time.txt'
 if not os.path.isfile(model_filename):
     if use_early_stopping and not use_cross_val:
         surv_model.fit(fold_X_train_std,
                        fold_y_train_discrete,
                        batch_size,
                        n_epochs,
                        [tt.callbacks.EarlyStopping()],
                        val_data=(fold_X_val_std,
                                  fold_y_val_discrete),
                        verbose=False)
     else:
         surv_model.fit(fold_X_train_std,
                        fold_y_train_discrete,
                        batch_size,
                        n_epochs,
                        verbose=False)
     elapsed = time.time() - tic
     print('Time elapsed: %f second(s)' % elapsed)
     np.savetxt(time_elapsed_filename,
                np.array(elapsed).reshape(1, -1))
     surv_model.save_net(model_filename)
 else:
     # completion of the truncated snippet, following the identical load
     # branch in Examples #2 and #3
     surv_model.load_net(model_filename)
     elapsed = float(np.loadtxt(time_elapsed_filename))
Example #5
import numpy as np
import pandas as pd
import torch
import torchtuples as tt
from pycox.models import DeepHitSingle
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper


class DeepHit_pycox:
    def __init__(self, nodes_per_layer, layers, dropout, weight_decay,
                 batch_size, num_durations, alpha, sigma, lr=0.0001, seed=47):
        # set seed
        np.random.seed(seed)
        _ = torch.manual_seed(seed)

        self.standardalizer = None
        self.standardize_data = True

        self._duration_col = "duration"
        self._event_col = "event"

        self.in_features = None
        self.batch_norm = True
        self.output_bias = False
        self.activation = torch.nn.ReLU
        self.epochs = 512
        self.num_workers = 2
        self.callbacks = [tt.callbacks.EarlyStopping()]

        # parameters tuned
        self.alpha = alpha
        self.sigma = 10**sigma  # sigma is searched in log10 space
        self.num_nodes = [int(nodes_per_layer) for _ in range(int(layers))]
        self.dropout = dropout
        self.weight_decay = weight_decay
        self.lr = lr
        self.batch_size = int(batch_size)

        self.num_durations = int(num_durations)
        self.labtrans = DeepHitSingle.label_transform(self.num_durations)

    def set_standardize(self, standardize_bool):
        self.standardize_data = standardize_bool

    def _format_to_pycox(self, X, Y, F):
        # from numpy to pandas df
        df = pd.DataFrame(data=X, columns=F)
        if Y is not None:
            df[self._duration_col] = Y[:, 0]
            df[self._event_col] = Y[:, 1]
        return df

    def _standardize_df(self, df, flag):
        # when flag == 'test', the df passed in does not contain Y labels
        if self.standardize_data:
            # standardize x
            df_x = df if flag == 'test' else df.drop(
                columns=[self._duration_col, self._event_col])
            if flag == "train":
                cols_leave = []
                cols_standardize = []
                for column in df_x.columns:
                    # leave strictly {0, 1}-valued (binary) columns unscaled
                    if set(pd.unique(df[column])) == {0, 1}:
                        cols_leave.append(column)
                    else:
                        cols_standardize.append(column)
                standardize = [([col], StandardScaler())
                               for col in cols_standardize]
                leave = [(col, None) for col in cols_leave]
                self.standardalizer = DataFrameMapper(standardize + leave)

                x = self.standardalizer.fit_transform(df_x).astype('float32')
                # standardize y
                y = self.labtrans.fit_transform(
                    df[self._duration_col].values.astype('float32'),
                    df[self._event_col].values.astype('float32'))
            elif flag == "val":
                x = self.standardalizer.transform(df_x).astype('float32')
                y = self.labtrans.transform(
                    df[self._duration_col].values.astype('float32'),
                    df[self._event_col].values.astype('float32'))

            elif flag == "test":
                x = self.standardalizer.transform(df_x).astype('float32')
                y = None

            else:
                raise NotImplementedError

            return x, y
        else:
            raise NotImplementedError

    def fit(self, X, y, column_names):
        # format data
        self.column_names = column_names
        full_df = self._format_to_pycox(X, y, self.column_names)
        val_df = full_df.sample(frac=0.2)
        train_df = full_df.drop(val_df.index)
        train_x, train_y = self._standardize_df(train_df, "train")
        val_x, val_y = self._standardize_df(val_df, "val")
        # configure model
        self.in_features = train_x.shape[1]
        self.out_features = self.labtrans.out_features
        net = tt.practical.MLPVanilla(
            in_features=self.in_features,
            num_nodes=self.num_nodes,
            out_features=self.out_features,
            batch_norm=self.batch_norm,
            dropout=self.dropout,
            activation=self.activation,
            output_bias=self.output_bias,
            w_init_=lambda w: torch.nn.init.xavier_normal_(w))
        self.model = DeepHitSingle(
            net,
            tt.optim.Adam(lr=self.lr, weight_decay=self.weight_decay),
            alpha=self.alpha,
            sigma=self.sigma,
            duration_index=self.labtrans.cuts)
        n_train = train_x.shape[0]
        while n_train % self.batch_size == 1:  # this will cause issues in batch norm
            self.batch_size += 1
        self.model.fit(train_x,
                       train_y,
                       self.batch_size,
                       self.epochs,
                       self.callbacks,
                       verbose=True,
                       val_data=(val_x, val_y),
                       val_batch_size=self.batch_size,
                       num_workers=self.num_workers)

    def predict(self, test_x, time_list):
        # format data
        test_df = self._format_to_pycox(test_x, None, self.column_names)
        test_x, _ = self._standardize_df(test_df, "test")

        proba_matrix_ = self.model.interpolate().predict_surv_df(test_x)
        time_list_ = list(proba_matrix_.index)
        proba_matrix = np.transpose(proba_matrix_.values)
        pred_medians = []
        # if the predicted proba never goes below 0.5, predict the largest seen value
        for test_idx, survival_proba in enumerate(proba_matrix):
            # reset for every sample; keeping the previous sample's value here
            # would leak it into samples whose curve never drops below 0.5
            median_time = max(time_list_)
            # the survival_proba is in descending order
            for col, proba in enumerate(survival_proba):
                if proba > 0.5:
                    continue
                if proba == 0.5 or col == 0:
                    median_time = time_list_[col]
                else:
                    median_time = (time_list_[col - 1] + time_list_[col]) / 2
                break
            pred_medians.append(median_time)

        return np.array(pred_medians), proba_matrix_
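
A minimal usage sketch for the class above. The data, feature names, and hyperparameter values are synthetic and purely illustrative, not from the original source.

# Usage sketch for DeepHit_pycox (synthetic data; illustrative only).
import numpy as np

rng = np.random.RandomState(0)
n = 200
X = rng.randn(n, 3)                              # three continuous features
durations = rng.exponential(10, size=n)          # synthetic event times
events = rng.binomial(1, 0.7, size=n)            # synthetic event indicators
Y = np.stack([durations, events], axis=1)        # col 0: duration, col 1: event

model = DeepHit_pycox(nodes_per_layer=32, layers=2, dropout=0.1,
                      weight_decay=1e-4, batch_size=64, num_durations=10,
                      alpha=0.5, sigma=0)        # sigma is exponentiated: 10**0 = 1
model.fit(X, Y, column_names=['f0', 'f1', 'f2'])  # up to 512 epochs, early stopping
pred_medians, surv_df = model.predict(X, time_list=None)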
Example #6
import numpy as np
import torchtuples as ttup
from pycox.models import DeepHitSingle


def _train_dht(x, t, e, folds, params):
    """Helper Function to train a deep-hit model (van der schaar et. al).

  Args:
    x:
      a numpy array of input features (Training Data).
    t:
      a numpy vector of event times (Training Data).
    e:
      a numpy vector of event indicators (1 if event occured, 0 otherwise)
      (Training Data).
    folds:
       vector of the training cv folds.

  Returns:
    Trained pycox.DeepHitSingle model.

  """
    if params is None:
        num_nodes = [100, 100]
        lr = 1e-3
        bs = 128
    else:
        num_nodes = params['num_nodes']
        lr = params['lr']
        bs = params['bs']

    x = x.astype('float32')
    t = t.astype('float32')
    e = e.astype('int32')

    #   num_durations = int(0.5*max(t))
    #   print ("num_durations:", num_durations)

    num_durations = int(max(t))
    #num_durations = int(30)

    print("num_durations:", num_durations)

    labtrans = DeepHitSingle.label_transform(num_durations, scheme='quantiles')
    #labtrans = DeepHitSingle.label_transform(num_durations,)

    #print (labtrans)

    in_features = x.shape[1]
    batch_norm = False
    dropout = 0.0
    output_bias = False

    fold_model = {}

    for f in set(folds):

        xf = x[folds != f]
        tf = t[folds != f]
        ef = e[folds != f]

        validx = sorted(
            np.random.choice(len(xf),
                             size=(int(0.15 * len(xf))),
                             replace=False))

        vidx = np.array([False] * len(xf))
        vidx[validx] = True

        y_train = labtrans.fit_transform(tf[~vidx], ef[~vidx])
        y_val = labtrans.transform(tf[vidx], ef[vidx])
        out_features = labtrans.out_features

        net = ttup.practical.MLPVanilla(in_features, num_nodes, out_features,
                                        batch_norm, dropout)

        model = DeepHitSingle(net,
                              ttup.optim.Adam,
                              alpha=0.5,
                              sigma=1,
                              duration_index=labtrans.cuts)

        y_train = y_train[0].astype('int64'), y_train[1].astype('float32')
        y_val = y_val[0].astype('int64'), y_val[1].astype('float32')

        val = xf[vidx], y_val

        batch_size = bs
        model.optimizer.set_lr(lr)
        epochs = 10
        callbacks = [ttup.callbacks.EarlyStopping()]

        model.fit(
            xf[~vidx],
            y_train,
            batch_size,
            epochs,
            callbacks,
            verbose=True,
            val_data=val,
        )

        fold_model[f] = model

    return fold_model
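
A minimal sketch of how _train_dht might be called. The arrays, fold assignments, and hyperparameters below are synthetic and purely illustrative.

# Illustrative call to _train_dht (synthetic data).
import numpy as np

rng = np.random.RandomState(0)
n = 300
x = rng.randn(n, 5)                        # input features
t = rng.exponential(20, size=n)            # event times
e = rng.binomial(1, 0.6, size=n)           # event indicators
folds = rng.randint(0, 3, size=n)          # three cross-validation folds

fold_models = _train_dht(x, t, e, folds,
                         params={'num_nodes': [64, 64], 'lr': 1e-3, 'bs': 64})
surv = fold_models[0].predict_surv_df(x[folds == 0].astype('float32'))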