Beispiel #1
0
    def _model_factory(self,
                       n_trees=None,
                       n_input_features=None,
                       n_neurons=None):
        """Instantiate a survival model according to ``self.algorithm``.

        Parameters
        ----------
        n_trees : int, optional
            Number of trees; used only by the 'RSF' algorithm.
        n_input_features : int, optional
            Input dimensionality for the neural-network based methods.
        n_neurons : int or sequence, optional
            Hidden-layer sizes for the neural-network based methods.

        Returns
        -------
        object
            An unfitted model for the selected algorithm.

        Raises
        ------
        Exception
            If ``self.algorithm`` is not recognized, including the case
            where it is listed in ``self._pycox_methods`` but has no
            construction branch here (the original code silently
            returned ``None`` in that case).
        """
        if self.algorithm == 'CPH':
            return CoxPHFitter()
        if self.algorithm == 'RSF':
            return RandomSurvivalForestModel(num_trees=n_trees)
        if self.algorithm in self._pycox_methods:
            # Arguments shared by every pycox network below.
            net_args = {
                'in_features': n_input_features,
                'num_nodes': n_neurons,
                'batch_norm': True,
                'dropout': 0.1,
            }

            if self.algorithm == 'DeepSurv':
                net = tt.practical.MLPVanilla(out_features=1,
                                              output_bias=False,
                                              **net_args)
                return CoxPH(net, tt.optim.Adam)

            if self.algorithm == 'CoxTime':
                net = MLPVanillaCoxTime(**net_args)
                return CoxTime(net, tt.optim.Adam)

            if self.algorithm in self._discrete_time_methods:
                # Discrete-time methods share one label discretization.
                num_durations = 30
                print(f'   {num_durations} equidistant intervals')

            # NOTE(review): the branches below assume their algorithm is
            # also in self._discrete_time_methods so that num_durations
            # is defined above — confirm the two sets stay in sync.
            if self.algorithm == 'DeepHit':
                labtrans = DeepHitSingle.label_transform(num_durations)
                net = self._get_discrete_time_net(labtrans, net_args)
                return DeepHitSingle(net,
                                     tt.optim.Adam,
                                     alpha=0.2,
                                     sigma=0.1,
                                     duration_index=labtrans.cuts)

            if self.algorithm == 'MTLR':
                labtrans = MTLR.label_transform(num_durations)
                net = self._get_discrete_time_net(labtrans, net_args)
                return MTLR(net, tt.optim.Adam, duration_index=labtrans.cuts)

            if self.algorithm == 'Nnet-survival':
                labtrans = LogisticHazard.label_transform(num_durations)
                net = self._get_discrete_time_net(labtrans, net_args)
                return LogisticHazard(net,
                                      tt.optim.Adam(0.01),
                                      duration_index=labtrans.cuts)

        # Fix: previously, an algorithm in self._pycox_methods with no
        # matching branch fell through and returned None; now every
        # unhandled algorithm raises.
        raise Exception('Unrecognized model.')
Beispiel #2
0
def test_cox_time_runs(numpy):
    """Smoke test: CoxTime fits and produces baseline hazards and
    survival estimates, for both numpy and tensor inputs."""
    features, labels = make_dataset(False).apply(lambda t: t.float()).to_numpy()
    labtrans = CoxTime.label_transform()
    labels = labtrans.fit_transform(*labels)
    dataset = tt.tuplefy(features, labels)
    if not numpy:
        dataset = dataset.to_tensor()
    network = MLPVanillaCoxTime(dataset[0].shape[1], [4], False)
    model = CoxTime(network)
    fit_model(dataset, model)
    model.compute_baseline_hazards()
    assert_survs(dataset[0], model)
Beispiel #3
0
# Cast features and labels to float32, as required by the torch-based model.
X_train_std = X_train_std.astype('float32')
X_test_std = X_test_std.astype('float32')
y_train = y_train.astype('float32')
y_test = y_test.astype('float32')

# Seed both torch and numpy so training is reproducible per method seed.
torch.manual_seed(method_random_seed)
np.random.seed(method_random_seed)

# CoxTime's label transform standardizes durations; y_train.T is unpacked,
# so rows are presumably (duration, event) pairs — TODO confirm upstream.
labtrans = CoxTime.label_transform()
y_train_std = labtrans.fit_transform(*y_train.T)

batch_norm = True
dropout = 0.0

# MLP for CoxTime: n_layers hidden layers of n_nodes neurons each.
net = MLPVanillaCoxTime(X_train_std.shape[1],
                        [n_nodes for layer_idx in range(n_layers)], batch_norm,
                        dropout)

optimizer = tt.optim.Adam(lr=lr)

surv_model = CoxTime(net, optimizer, labtrans=labtrans)

# Checkpoint filename encodes the experiment's hyperparameters.
model_filename = \
    os.path.join(output_dir, 'models',
                 '%s_%s_exp%d_bs%d_nep%d_nla%d_nno%d_lr%f_test.pt'
                 % (survival_estimator_name, dataset, experiment_idx,
                    batch_size, n_epochs, n_layers, n_nodes, lr))
# NOTE(review): this assert requires the checkpoint to already exist, which
# makes the `if not os.path.isfile(...)` training branch that follows
# unreachable — confirm whether this is a leftover debug check.
assert os.path.isfile(model_filename)
if not os.path.isfile(model_filename):
    print('*** Fitting with hyperparam:', hyperparam, flush=True)
    surv_model.fit(X_train_std,
Beispiel #4
0
def objective(x_train,
              neurons,
              drop,
              activation,
              lr_opt,
              optimizer,
              n_layers,
              name,
              labtrans=""):
    """Build the neural network and pycox model for Cox-MLP (CC), CoxTime or DeepHit.

    # Arguments
        x_train: input data as formatted by the function "prepare_data"
        neurons: number of neurons per hidden layer in the neural network
        drop: dropout rate applied after each hidden layer
        activation: activation name: 'relu', 'elu' or 'tanh'
        lr_opt: learning rate chosen for optimization
        optimizer: optimization algorithm: 'rmsprop', 'adam',
            'adam_amsgrad' or 'sgdwr'
        n_layers: number of hidden layers
        name: name of the model: 'Cox-CC', 'CoxTime' or 'DeepHit'
        labtrans: transformed labels, including the time variable; the
            default "" means no label transform (DeepHit requires a real
            transform since it reads labtrans.cuts)
    # Returns
        model: pycox model (based on pytorch) with the architecture defined previously
        callbacks: list of callback functions
    # Raises
        ValueError: if optimizer, activation or name is unrecognized.
            (Previously an unknown value silently fell through and the
            function later crashed with a NameError on optim/act/model.)
    """
    in_features = x_train.shape[1]
    out_features = labtrans.out_features if labtrans != "" else 1
    nb_neurons = [neurons] * n_layers

    # Optimizer selection; SGD with warm restarts gets a cosine-annealing
    # LR callback, all others a no-op callback.
    if optimizer == "rmsprop":
        optim = tt.optim.RMSprop()
        callbacks = [tt.callbacks.Callback()]
    elif optimizer == "adam":
        optim = tt.optim.Adam()
        callbacks = [tt.callbacks.Callback()]
    elif optimizer == "adam_amsgrad":
        optim = tt.optim.Adam(amsgrad=True)
        callbacks = [tt.callbacks.Callback()]
    elif optimizer == "sgdwr":
        optim = tt.optim.SGD(momentum=0.9)
        callbacks = [tt.callbacks.LRCosineAnnealing()]
    else:
        raise ValueError(f"Unrecognized optimizer: {optimizer!r}")

    # Activation selection (class, not instance: the nets instantiate it).
    _activations = {
        'relu': torch.nn.ReLU,
        'elu': torch.nn.ELU,
        'tanh': torch.nn.Tanh,
    }
    try:
        act = _activations[activation]
    except KeyError:
        raise ValueError(f"Unrecognized activation: {activation!r}") from None

    if name == "Cox-CC":
        net = tt.practical.MLPVanilla(in_features,
                                      nb_neurons,
                                      out_features,
                                      batch_norm=True,
                                      dropout=drop,
                                      activation=act,
                                      output_bias=False)
        model = CoxCC(net, optim)
    elif name == "CoxTime":
        # CoxTime nets take no out_features: time enters as an extra input.
        net = MLPVanillaCoxTime(in_features,
                                nb_neurons,
                                batch_norm=True,
                                dropout=drop,
                                activation=act)
        model = CoxTime(net, optim, labtrans=labtrans)
    elif name == "DeepHit":
        net = tt.practical.MLPVanilla(in_features,
                                      nb_neurons,
                                      out_features,
                                      batch_norm=True,
                                      dropout=drop,
                                      activation=act,
                                      output_bias=False)
        model = DeepHitSingle(net,
                              optim,
                              alpha=0.2,
                              sigma=0.1,
                              duration_index=labtrans.cuts)
    else:
        raise ValueError(f"Unrecognized model name: {name!r}")

    model.optimizer.set_lr(lr_opt)

    return model, callbacks
Beispiel #5
0
                cols_categorical) else x_train.shape[1]
            num_nodes = [args.num_nodes] * args.num_layers
            out_features = 1
            batch_norm = args.use_BN
            dropout = args.dropout
            output_bias = args.use_output_bias
            device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')

            if len(cols_categorical) > 0:
                net = MixedInputMLPCoxTime(in_features, num_embeddings,
                                           embedding_dims, num_nodes,
                                           batch_norm, dropout)
                # net = Transformer(in_features, num_embeddings, num_nodes, out_features, batch_norm, dropout, output_bias=output_bias)
            else:
                net = MLPVanillaCoxTime(in_features, num_nodes, batch_norm,
                                        dropout)
            net = net.to(device)

            if args.optimizer == 'AdamWR':
                model = CoxTime(net,
                                optimizer=tt.optim.AdamWR(
                                    lr=args.lr,
                                    decoupled_weight_decay=args.weight_decay,
                                    cycle_eta_multiplier=0.8),
                                device=device,
                                shrink=args.shrink,
                                labtrans=labtrans)

            wandb.init(
                project='icml_' + args.dataset + '_baseline',
                group=f'coxtime_fold{fold}_' + args.loss,