import logging
import pickle

import numpy as np

# CRN_Model and write_results_to_file are imported from elsewhere in this codebase (not shown here).


def fit_CRN_encoder(dataset_train, dataset_val, model_name, model_dir, hyperparams_file,
                    b_hyperparam_opt):
    _, length, num_covariates = dataset_train['current_covariates'].shape
    num_treatments = dataset_train['current_treatments'].shape[-1]
    num_outputs = dataset_train['outputs'].shape[-1]
    num_inputs = dataset_train['current_covariates'].shape[-1] + dataset_train['current_treatments'].shape[-1]

    params = {'num_treatments': num_treatments,
              'num_covariates': num_covariates,
              'num_outputs': num_outputs,
              'max_sequence_length': length,
              'num_epochs': 100}

    hyperparams = dict()
    num_simulations = 50
    best_validation_mse = 1000000

    if b_hyperparam_opt:
        logging.info("Performing hyperparameter optimization")
        for simulation in range(num_simulations):
            logging.info("Simulation {} out of {}".format(simulation + 1, num_simulations))

            # Random search: sample layer sizes as multiples of the input size and
            # draw the remaining hyperparameters from fixed grids.
            hyperparams['rnn_hidden_units'] = int(np.random.choice([0.5, 1.0, 2.0, 3.0, 4.0]) * num_inputs)
            hyperparams['br_size'] = int(np.random.choice([0.5, 1.0, 2.0, 3.0, 4.0]) * num_inputs)
            hyperparams['fc_hidden_units'] = int(np.random.choice([0.5, 1.0, 2.0, 3.0, 4.0]) * hyperparams['br_size'])
            hyperparams['learning_rate'] = np.random.choice([0.01, 0.001])
            hyperparams['batch_size'] = np.random.choice([64, 128, 256])
            hyperparams['rnn_keep_prob'] = np.random.choice([0.7, 0.8, 0.9])

            logging.info("Current hyperparams used for training \n {}".format(hyperparams))
            model = CRN_Model(params, hyperparams)
            model.train(dataset_train, dataset_val, model_name, model_dir)
            validation_mse, _ = model.evaluate_predictions(dataset_val)

            # Keep the configuration with the lowest validation MSE.
            if validation_mse < best_validation_mse:
                logging.info(
                    "Updating best validation loss | Previous best validation loss: {} | "
                    "Current best validation loss: {}".format(best_validation_mse, validation_mse))
                best_validation_mse = validation_mse
                best_hyperparams = hyperparams.copy()

            logging.info("Best hyperparams: \n {}".format(best_hyperparams))
            write_results_to_file(hyperparams_file, best_hyperparams)

    else:
        logging.info("Using default hyperparameters")
        best_hyperparams = {
            'rnn_hidden_units': 24,
            'br_size': 12,
            'fc_hidden_units': 36,
            'learning_rate': 0.01,
            'batch_size': 128,
            'rnn_keep_prob': 0.9}

        logging.info("Best hyperparams: \n {}".format(best_hyperparams))
        write_results_to_file(hyperparams_file, best_hyperparams)

    # Retrain the encoder with the selected hyperparameters.
    model = CRN_Model(params, best_hyperparams)
    model.train(dataset_train, dataset_val, model_name, model_dir)
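# Usage sketch (illustrative, not part of the original module): fit_CRN_encoder only
# assumes the dataset dictionaries expose 3-D arrays shaped
# [num_patients, timesteps, features] under the keys 'current_covariates',
# 'current_treatments' and 'outputs'. The shapes, names and paths below are
# placeholder assumptions, and the full data pipeline in this codebase may attach
# further keys that are consumed inside CRN_Model.train.
#
#     dummy_split = {'current_covariates': np.zeros((100, 30, 5)),
#                    'current_treatments': np.zeros((100, 30, 4)),
#                    'outputs': np.zeros((100, 30, 1))}
#     fit_CRN_encoder(dataset_train=dummy_split, dataset_val=dummy_split,
#                     model_name='crn_encoder', model_dir='results/crn_encoder',
#                     hyperparams_file='results/crn_encoder_hyperparams',
#                     b_hyperparam_opt=False)
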
def fit_CRN_decoder(dataset_train, dataset_val, model_name, model_dir, encoder_hyperparams_file,
                    decoder_hyperparams_file, b_hyperparam_opt):
    logging.info("Fitting CRN decoder.")
    _, length, num_covariates = dataset_train['current_covariates'].shape
    num_treatments = dataset_train['current_treatments'].shape[-1]
    num_outputs = dataset_train['outputs'].shape[-1]
    num_inputs = dataset_train['current_covariates'].shape[-1] + dataset_train['current_treatments'].shape[-1]

    params = {'num_treatments': num_treatments,
              'num_covariates': num_covariates,
              'num_outputs': num_outputs,
              'max_sequence_length': length,
              'num_epochs': 100}

    hyperparams = dict()
    num_simulations = 30
    best_validation_mse = 1000000

    with open(encoder_hyperparams_file, 'rb') as handle:
        encoder_best_hyperparams = pickle.load(handle)

    if b_hyperparam_opt:
        logging.info("Performing hyperparameter optimization.")
        for simulation in range(num_simulations):
            logging.info("Simulation {} out of {}".format(simulation + 1, num_simulations))

            # The first RNN hidden state in the decoder is initialized with the balancing
            # representation output by the encoder, so its size must equal the encoder br_size.
            hyperparams['rnn_hidden_units'] = encoder_best_hyperparams['br_size']
            hyperparams['br_size'] = int(np.random.choice([0.5, 1.0, 2.0, 3.0, 4.0]) * num_inputs)
            hyperparams['fc_hidden_units'] = int(np.random.choice([0.5, 1.0, 2.0, 3.0, 4.0]) * hyperparams['br_size'])
            hyperparams['learning_rate'] = np.random.choice([0.01, 0.001, 0.0001])
            hyperparams['batch_size'] = np.random.choice([256, 512, 1024])
            hyperparams['rnn_keep_prob'] = np.random.choice([0.7, 0.8, 0.9])

            logging.info("Current hyperparams used for training \n {}".format(hyperparams))
            model = CRN_Model(params, hyperparams, b_train_decoder=True)
            model.train(dataset_train, dataset_val, model_name, model_dir)
            validation_mse, _ = model.evaluate_predictions(dataset_val)

            # Keep the configuration with the lowest validation MSE.
            if validation_mse < best_validation_mse:
                logging.info(
                    "Updating best validation loss | Previous best validation loss: {} | "
                    "Current best validation loss: {}".format(best_validation_mse, validation_mse))
                best_validation_mse = validation_mse
                best_hyperparams = hyperparams.copy()

            logging.info("Best hyperparams: \n {}".format(best_hyperparams))
            write_results_to_file(decoder_hyperparams_file, best_hyperparams)

    else:
        # The rnn_hidden_units needs to match the encoder br_size.
        logging.info("Using default hyperparameters")
        best_hyperparams = {
            'br_size': 18,
            'rnn_keep_prob': 0.9,
            'fc_hidden_units': 36,
            'batch_size': 1024,
            'learning_rate': 0.001,
            'rnn_hidden_units': encoder_best_hyperparams['br_size']}

        write_results_to_file(decoder_hyperparams_file, best_hyperparams)

    # Retrain the decoder with the selected hyperparameters.
    model = CRN_Model(params, best_hyperparams, b_train_decoder=True)
    model.train(dataset_train, dataset_val, model_name, model_dir)
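# Usage sketch (illustrative, not part of the original module): the two stages are
# fitted in order, and the decoder reloads the encoder's hyperparameter file so that
# its rnn_hidden_units matches the encoder's br_size. write_results_to_file is
# assumed to pickle the hyperparameter dict, since fit_CRN_decoder reads it back
# with pickle.load; the dataset variables, names and paths below are placeholder
# assumptions.
#
#     encoder_hp_file = 'results/crn_encoder_hyperparams'
#     decoder_hp_file = 'results/crn_decoder_hyperparams'
#     fit_CRN_encoder(enc_train, enc_val, 'crn_encoder', 'results/crn_encoder',
#                     encoder_hp_file, b_hyperparam_opt=True)
#     fit_CRN_decoder(dec_train, dec_val, 'crn_decoder', 'results/crn_decoder',
#                     encoder_hp_file, decoder_hp_file, b_hyperparam_opt=True)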