示例#1
0
def fit_CRN_encoder(dataset_train, dataset_val, model_name, model_dir, hyperparams_file,
                    b_hyperparam_opt):
    """Train the CRN encoder, optionally preceded by a random hyperparameter search.

    Args:
        dataset_train: Mapping with the array entries 'current_covariates'
            (three-dimensional; the last two axes give the sequence length and
            the number of covariates), 'current_treatments' and 'outputs'
            (their last axis gives the number of treatments / outputs).
        dataset_val: Validation data; passed to ``CRN_Model.train`` and, during
            the search, to ``CRN_Model.evaluate_predictions``.
        model_name: Forwarded unchanged to ``CRN_Model.train``.
        model_dir: Forwarded unchanged to ``CRN_Model.train``.
        hyperparams_file: Destination handed to ``write_results_to_file``
            together with the selected hyperparameters.
        b_hyperparam_opt: If truthy, run the random search; otherwise use the
            built-in default hyperparameters.

    Returns:
        None. A final model is trained with the selected hyperparameters.

    Raises:
        ValueError: If the search finishes without any simulation producing a
            finite validation MSE (for example, every run diverged to NaN).
    """
    _, length, num_covariates = dataset_train['current_covariates'].shape
    num_treatments = dataset_train['current_treatments'].shape[-1]
    num_outputs = dataset_train['outputs'].shape[-1]
    num_inputs = num_covariates + num_treatments

    params = {'num_treatments': num_treatments,
              'num_covariates': num_covariates,
              'num_outputs': num_outputs,
              'max_sequence_length': length,
              'num_epochs': 100}

    num_simulations = 50
    size_multipliers = [0.5, 1.0, 2.0, 3.0, 4.0]

    if b_hyperparam_opt:
        logging.info("Performing hyperparameter optimization")

        # Start from +inf (not an arbitrary large constant) so that any finite
        # validation MSE is accepted, and from None so that the "best" value is
        # always bound even when the first run does not qualify.
        best_validation_mse = np.inf
        best_hyperparams = None

        for simulation in range(num_simulations):
            logging.info("Simulation %s out of %s", simulation + 1, num_simulations)

            # Fresh dict per simulation; the sampling order is kept fixed so a
            # seeded run draws the same sequence of random choices.
            # Layer sizes are clamped to at least 1: int(0.5 * n) is 0 for n == 1.
            hyperparams = dict()
            hyperparams['rnn_hidden_units'] = max(1, int(np.random.choice(size_multipliers) * num_inputs))
            hyperparams['br_size'] = max(1, int(np.random.choice(size_multipliers) * num_inputs))
            hyperparams['fc_hidden_units'] = max(
                1, int(np.random.choice(size_multipliers) * hyperparams['br_size']))
            hyperparams['learning_rate'] = np.random.choice([0.01, 0.001])
            hyperparams['batch_size'] = np.random.choice([64, 128, 256])
            hyperparams['rnn_keep_prob'] = np.random.choice([0.7, 0.8, 0.9])

            logging.info("Current hyperparams used for training \n %s", hyperparams)
            model = CRN_Model(params, hyperparams)
            model.train(dataset_train, dataset_val, model_name, model_dir)
            validation_mse, _ = model.evaluate_predictions(dataset_val)

            # A NaN validation MSE compares False here and is therefore skipped.
            if validation_mse < best_validation_mse:
                logging.info(
                    "Updating best validation loss | Previous best validation loss: %s | "
                    "Current best validation loss: %s",
                    best_validation_mse, validation_mse)
                best_validation_mse = validation_mse
                best_hyperparams = hyperparams

            if best_hyperparams is not None:
                logging.info("Best hyperparams: \n %s", best_hyperparams)

        if best_hyperparams is None:
            raise ValueError(
                "Hyperparameter optimization failed: none of the {} simulations "
                "produced a finite validation MSE.".format(num_simulations))

        write_results_to_file(hyperparams_file, best_hyperparams)

    else:
        logging.info("Using default hyperparameters")
        best_hyperparams = {
            'rnn_hidden_units': 24,
            'br_size': 12,
            'fc_hidden_units': 36,
            'learning_rate': 0.01,
            'batch_size': 128,
            'rnn_keep_prob': 0.9}
        logging.info("Best hyperparams: \n %s", best_hyperparams)
        write_results_to_file(hyperparams_file, best_hyperparams)

    # Train the final model with the selected configuration.
    model = CRN_Model(params, best_hyperparams)
    model.train(dataset_train, dataset_val, model_name, model_dir)
示例#2
0
def fit_CRN_decoder(dataset_train, dataset_val, model_name, model_dir,
                    encoder_hyperparams_file, decoder_hyperparams_file,
                    b_hyperparam_opt):
    """Train the CRN decoder, optionally preceded by a random hyperparameter search.

    The decoder's 'rnn_hidden_units' is always taken from the 'br_size' entry
    of the encoder hyperparameters loaded from ``encoder_hyperparams_file``.

    Args:
        dataset_train: Mapping with the array entries 'current_covariates'
            (three-dimensional; the last two axes give the sequence length and
            the number of covariates), 'current_treatments' and 'outputs'
            (their last axis gives the number of treatments / outputs).
        dataset_val: Validation data; passed to ``CRN_Model.train`` and, during
            the search, to ``CRN_Model.evaluate_predictions``.
        model_name: Forwarded unchanged to ``CRN_Model.train``.
        model_dir: Forwarded unchanged to ``CRN_Model.train``.
        encoder_hyperparams_file: Path of a pickle file holding a mapping with
            at least a 'br_size' entry.
        decoder_hyperparams_file: Destination handed to
            ``write_results_to_file`` together with the selected hyperparameters.
        b_hyperparam_opt: If truthy, run the random search; otherwise use the
            built-in default hyperparameters.

    Returns:
        None. A final decoder model is trained with the selected hyperparameters.

    Raises:
        ValueError: If the search finishes without any simulation producing a
            finite validation MSE (for example, every run diverged to NaN).
    """
    logging.info("Fitting CRN decoder.")

    _, length, num_covariates = dataset_train['current_covariates'].shape
    num_treatments = dataset_train['current_treatments'].shape[-1]
    num_outputs = dataset_train['outputs'].shape[-1]
    num_inputs = num_covariates + num_treatments

    params = {
        'num_treatments': num_treatments,
        'num_covariates': num_covariates,
        'num_outputs': num_outputs,
        'max_sequence_length': length,
        'num_epochs': 100
    }

    num_simulations = 30
    size_multipliers = [0.5, 1.0, 2.0, 3.0, 4.0]

    # NOTE(security): pickle.load can execute arbitrary code; only point this at
    # hyperparameter files produced by a trusted run of this project.
    with open(encoder_hyperparams_file, 'rb') as handle:
        encoder_best_hyperparams = pickle.load(handle)

    if b_hyperparam_opt:
        logging.info("Performing hyperparameter optimization.")

        # Start from +inf (not an arbitrary large constant) so that any finite
        # validation MSE is accepted, and from None so that the "best" value is
        # always bound even when the first run does not qualify.
        best_validation_mse = np.inf
        best_hyperparams = None

        for simulation in range(num_simulations):
            logging.info("Simulation %s out of %s", simulation + 1, num_simulations)

            # Fresh dict per simulation; the sampling order is kept fixed so a
            # seeded run draws the same sequence of random choices.
            hyperparams = dict()

            # The first rnn hidden state in the decoder is initialized with the
            # balancing representation output by the encoder.
            hyperparams['rnn_hidden_units'] = encoder_best_hyperparams['br_size']

            # Layer sizes are clamped to at least 1: int(0.5 * n) is 0 for n == 1.
            hyperparams['br_size'] = max(
                1, int(np.random.choice(size_multipliers) * num_inputs))
            hyperparams['fc_hidden_units'] = max(
                1, int(np.random.choice(size_multipliers) * hyperparams['br_size']))
            hyperparams['learning_rate'] = np.random.choice([0.01, 0.001, 0.0001])
            hyperparams['batch_size'] = np.random.choice([256, 512, 1024])
            hyperparams['rnn_keep_prob'] = np.random.choice([0.7, 0.8, 0.9])

            logging.info("Current hyperparams used for training \n %s", hyperparams)
            model = CRN_Model(params, hyperparams, b_train_decoder=True)
            model.train(dataset_train, dataset_val, model_name, model_dir)
            validation_mse, _ = model.evaluate_predictions(dataset_val)

            # A NaN validation MSE compares False here and is therefore skipped.
            if validation_mse < best_validation_mse:
                logging.info(
                    "Updating best validation loss | Previous best validation loss: %s | "
                    "Current best validation loss: %s",
                    best_validation_mse, validation_mse)
                best_validation_mse = validation_mse
                best_hyperparams = hyperparams

            if best_hyperparams is not None:
                logging.info("Best hyperparams: \n %s", best_hyperparams)

        if best_hyperparams is None:
            raise ValueError(
                "Hyperparameter optimization failed: none of the {} simulations "
                "produced a finite validation MSE.".format(num_simulations))

        write_results_to_file(decoder_hyperparams_file, best_hyperparams)

    else:
        # The rnn_hidden_units needs to be the same as the encoder br_size.
        logging.info("Using default hyperparameters")
        best_hyperparams = {
            'br_size': 18,
            'rnn_keep_prob': 0.9,
            'fc_hidden_units': 36,
            'batch_size': 1024,
            'learning_rate': 0.001,
            'rnn_hidden_units': encoder_best_hyperparams['br_size']
        }
        logging.info("Best hyperparams: \n %s", best_hyperparams)
        write_results_to_file(decoder_hyperparams_file, best_hyperparams)

    # Train the final decoder with the selected configuration.
    model = CRN_Model(params, best_hyperparams, b_train_decoder=True)
    model.train(dataset_train, dataset_val, model_name, model_dir)