# Example #1
def main():
    """Train `hp.num_trials` NetRNNFinal models on every (fold, swap) split.

    Loads the preprocessed arrays for `hp.gender`, then for each
    cross-validation fold and each swap half: selects the matching rows,
    optionally drops redundant predictor columns, sorts the data for the
    Cox partial-likelihood computation, and trains/saves one network per
    trial under `hp.log_dir`.

    Relies on module-level names defined elsewhere in this file:
    Hyperparameters, np, feather, torch, utils (torch.utils.data), optim,
    NetRNNFinal, CoxPHLoss, trn, load_obj, sort_and_case_indices.
    """
    hp = Hyperparameters()

    print('Load data...')
    data = np.load(hp.data_pp_dir + 'data_arrays_' + hp.gender + '.npz')
    df_index_code = feather.read_dataframe(hp.data_pp_dir + 'df_index_code_' +
                                           hp.gender + '.feather')

    # Hoisted out of the fold loop: the column list depends only on `hp`,
    # so load it once instead of once per (fold, swap) iteration.
    if not hp.redundant_predictors:
        cols_list = load_obj(hp.data_pp_dir + 'cols_list.pkl')
        keep_cols = [cols_list.index(i) for i in hp.reduced_col_list]

    print('Train on each fold...')
    for fold in range(hp.num_folds):
        for swap in range(2):
            print('Fold: {} Swap: {}'.format(fold, swap))

            # Rows assigned to this swap half of this fold.
            idx = (data['fold'][:, fold] == swap)
            x = data['x'][idx]
            time = data['time'][idx]
            event = data['event'][idx]
            codes = data['codes'][idx]
            month = data['month'][idx]
            diagt = data['diagt'][idx]

            if not hp.redundant_predictors:
                x = x[:, keep_cols]

            # Sort for the Cox loss; case_idx picks out event rows and
            # max_idx_control bounds the risk set for each case.
            sort_idx, case_idx, max_idx_control = sort_and_case_indices(
                x, time, event)
            x, time, event = x[sort_idx], time[sort_idx], event[sort_idx]
            codes, month, diagt = codes[sort_idx], month[sort_idx], diagt[
                sort_idx]

            print('Create data loaders and tensors...')
            case = utils.TensorDataset(torch.from_numpy(x[case_idx]),
                                       torch.from_numpy(time[case_idx]),
                                       torch.from_numpy(max_idx_control),
                                       torch.from_numpy(codes[case_idx]),
                                       torch.from_numpy(month[case_idx]),
                                       torch.from_numpy(diagt[case_idx]))

            x = torch.from_numpy(x)
            time = torch.from_numpy(time)
            event = torch.from_numpy(event)
            codes = torch.from_numpy(codes)
            month = torch.from_numpy(month)
            diagt = torch.from_numpy(diagt)

            for trial in range(hp.num_trials):
                print('Trial: {}'.format(trial))

                # Create batch queues
                trn_loader = utils.DataLoader(case,
                                              batch_size=hp.batch_size,
                                              shuffle=True,
                                              drop_last=True)

                print('Train...')
                # Neural Net: timestamped name keeps repeated trials distinct.
                hp.model_name = str(trial) + '_' + datetime.now().strftime(
                    '%Y%m%d_%H%M%S_%f') + '.pt'
                num_input = x.shape[1] + 1 if hp.nonprop_hazards else x.shape[1]
                net = NetRNNFinal(num_input, df_index_code.shape[0] + 1,
                                  hp).to(hp.device)  #+1 for zero padding
                criterion = CoxPHLoss().to(hp.device)
                optimizer = optim.Adam(net.parameters(), lr=hp.learning_rate)

                for epoch in range(hp.max_epochs):
                    trn(trn_loader, x, codes, month, diagt, net, criterion,
                        optimizer, hp)
                if hp.redundant_predictors:
                    torch.save(
                        net.state_dict(), hp.log_dir + 'fold_' + str(fold) +
                        '_' + str(swap) + '/' + hp.model_name)
                else:
                    torch.save(
                        net.state_dict(), hp.log_dir + 'fold_' + str(fold) +
                        '_' + str(swap) + '_no_redundancies/' + hp.model_name)
                print('Done')
def objective(trial, data, df_index_code):
    """Optuna objective: train one NetRNNFinal and return its best val loss.

    Args:
        trial: Optuna trial supplying sampled hyperparameters (via
            `Hyperparameters(trial)`), pruning, and a unique `trial.number`.
        data: npz-style mapping of arrays; rows with fold == 99 form the
            validation split, all other rows the training split.
        df_index_code: dataframe whose row count (+1 for zero padding) sets
            the embedding vocabulary size.

    Returns:
        The lowest validation loss observed before early stopping.

    Raises:
        optuna.TrialPruned: when the pruner decides to stop this trial.
    """
    hp = Hyperparameters(trial)
    print(trial.params)

    # Training split: every row not flagged as validation (fold != 99).
    idx_trn = (data['fold'] != 99)
    x_trn = data['x'][idx_trn]
    time_trn = data['time'][idx_trn]
    event_trn = data['event'][idx_trn]
    codes_trn = data['codes'][idx_trn]
    month_trn = data['month'][idx_trn]
    diagt_trn = data['diagt'][idx_trn]

    # Validation split (fold == 99).
    idx_val = (data['fold'] == 99)
    x_val = data['x'][idx_val]
    time_val = data['time'][idx_val]
    event_val = data['event'][idx_val]
    codes_val = data['codes'][idx_val]
    month_val = data['month'][idx_val]
    diagt_val = data['diagt'][idx_val]

    # Sort both splits for the Cox loss; case_idx selects event rows and
    # max_idx_control bounds each case's risk set.
    # could move this outside objective function for efficiency
    sort_idx_trn, case_idx_trn, max_idx_control_trn = sort_and_case_indices(
        x_trn, time_trn, event_trn)
    sort_idx_val, case_idx_val, max_idx_control_val = sort_and_case_indices(
        x_val, time_val, event_val)

    x_trn, time_trn, event_trn = x_trn[sort_idx_trn], time_trn[
        sort_idx_trn], event_trn[sort_idx_trn]
    codes_trn, month_trn, diagt_trn = codes_trn[sort_idx_trn], month_trn[
        sort_idx_trn], diagt_trn[sort_idx_trn]

    x_val, time_val, event_val = x_val[sort_idx_val], time_val[
        sort_idx_val], event_val[sort_idx_val]
    codes_val, month_val, diagt_val = codes_val[sort_idx_val], month_val[
        sort_idx_val], diagt_val[sort_idx_val]

    #######################################################################################################

    print('Create data loaders and tensors...')
    case_trn = utils.TensorDataset(torch.from_numpy(x_trn[case_idx_trn]),
                                   torch.from_numpy(time_trn[case_idx_trn]),
                                   torch.from_numpy(max_idx_control_trn),
                                   torch.from_numpy(codes_trn[case_idx_trn]),
                                   torch.from_numpy(month_trn[case_idx_trn]),
                                   torch.from_numpy(diagt_trn[case_idx_trn]))
    case_val = utils.TensorDataset(torch.from_numpy(x_val[case_idx_val]),
                                   torch.from_numpy(time_val[case_idx_val]),
                                   torch.from_numpy(max_idx_control_val),
                                   torch.from_numpy(codes_val[case_idx_val]),
                                   torch.from_numpy(month_val[case_idx_val]),
                                   torch.from_numpy(diagt_val[case_idx_val]))

    x_trn, x_val = torch.from_numpy(x_trn), torch.from_numpy(x_val)
    time_trn, time_val = torch.from_numpy(time_trn), torch.from_numpy(time_val)
    event_trn, event_val = torch.from_numpy(event_trn), torch.from_numpy(
        event_val)
    codes_trn, codes_val = torch.from_numpy(codes_trn), torch.from_numpy(
        codes_val)
    month_trn, month_val = torch.from_numpy(month_trn), torch.from_numpy(
        month_val)
    diagt_trn, diagt_val = torch.from_numpy(diagt_trn), torch.from_numpy(
        diagt_val)

    # Create batch queues
    trn_loader = utils.DataLoader(case_trn,
                                  batch_size=hp.batch_size,
                                  shuffle=True,
                                  drop_last=True)
    val_loader = utils.DataLoader(case_val,
                                  batch_size=hp.batch_size,
                                  shuffle=False,
                                  drop_last=False)

    print('Train...')
    # Neural Net: prefix with the trial number so checkpoints don't collide.
    hp.model_name = str(trial.number) + '_' + hp.model_name
    num_input = x_trn.shape[1] + 1 if hp.nonprop_hazards else x_trn.shape[1]
    net = NetRNNFinal(num_input, df_index_code.shape[0] + 1,
                      hp).to(hp.device)  #+1 for zero padding
    criterion = CoxPHLoss().to(hp.device)
    optimizer = optim.Adam(net.parameters(), lr=hp.learning_rate)

    # Start from +inf (not an arbitrary constant) so the first epoch always
    # establishes a real baseline and saves a checkpoint; with the previous
    # hard-coded 100. a trial whose loss never dipped below 100 would save
    # nothing yet still report 100.0 as its "best" loss.
    best, num_bad_epochs = float('inf'), 0
    for epoch in range(1000):
        trn(trn_loader, x_trn, codes_trn, month_trn, diagt_trn, net, criterion,
            optimizer, hp)
        loss_val = val(val_loader, x_val, codes_val, month_val, diagt_val, net,
                       criterion, epoch, hp)
        # early stopping
        if loss_val < best:
            print(
                '############### Saving good model ###############################'
            )
            torch.save(net.state_dict(), hp.log_dir + hp.model_name)
            best = loss_val
            num_bad_epochs = 0
        else:
            num_bad_epochs += 1
            if num_bad_epochs == hp.patience:
                break
        # pruning: `best` is always a real loss here (set on epoch 0 above)
        trial.report(best, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

    print('Done')
    return best