Beispiel #1
0
def train(data_hparams, model_hparams, training_hparams):
    """Build train/validation splits and a Regressor, then run training.

    Args:
      data_hparams: Data hyperparameters. Forwarded to `procure_dataset`;
        its `context_dim` and `num_actions` also set the model input width.
      model_hparams: Model hyperparameters. Each architecture attribute
        listed below is forwarded unchanged as a same-named keyword
        argument to `regressor.Regressor`.
      training_hparams: Passed through untouched to `training_loop`.
    """
    # Training split: 100 wheels generated with seed 0.
    train_pairs, train_rewards = procure_dataset(
        data_hparams, num_wheels=100, seed=0)
    train_dataset = (train_pairs, train_rewards)

    # Validation split: 10 wheels generated with a different seed (42).
    valid_pairs, valid_rewards = procure_dataset(
        data_hparams, num_wheels=10, seed=42)
    valid_dataset = (valid_pairs, valid_rewards)

    # Architecture fields copied verbatim from model_hparams onto the
    # Regressor constructor (same keyword name on both sides).
    forwarded_fields = (
        'x_encoder_sizes',
        'x_y_encoder_sizes',
        'global_latent_net_sizes',
        'local_latent_net_sizes',
        'heteroskedastic_net_sizes',
        'att_type',
        'att_heads',
        'uncertainty_type',
        'mean_att_type',
        'scale_att_type_1',
        'scale_att_type_2',
        'activation',
        'output_activation',
        'data_uncertainty',
        'local_variational',
    )
    # Input width is context_dim + num_actions (presumably a context joined
    # with an action encoding -- confirm against procure_dataset); the model
    # produces a single output value.
    model = regressor.Regressor(
        input_dim=data_hparams.context_dim + data_hparams.num_actions,
        output_dim=1,
        **{field: getattr(model_hparams, field) for field in forwarded_fields})

    training_loop(train_dataset, valid_dataset, model, training_hparams)
Beispiel #2
0
    def __init__(self, name, hparams, optimizer='RMS'):
        """Create the data buffer, learning-rate/optimizer config and SNP model.

        Args:
          name: Identifier for this model; echoed in the init message.
          hparams: Hyperparameter object. Read here: `training_freq`,
            `training_freq_network`, `training_epochs`, `context_dim`,
            `num_actions`, `activate_decay`, `initial_lr`, `lr_decay_rate`
            (only when decay is active), `max_grad_norm`, `model_path` and
            the Regressor architecture fields. `verbose` is optional and
            defaults to True.
          optimizer: Not consulted. NOTE(review): an RMSProp optimizer is
            always constructed regardless of this value -- confirm whether
            other choices were meant to be supported.
        """
        self.name = name
        self.hparams = hparams
        self.verbose = getattr(hparams, 'verbose', True)

        self.update_freq_lr = hparams.training_freq
        self.update_freq_nn = hparams.training_freq_network

        self.t = 0
        self.num_epochs = hparams.training_epochs
        self.data_h = contextual_dataset.ContextualDataset(
            hparams.context_dim, hparams.num_actions, intercept=False)

        # Non-trainable step counter, passed as the global step of the decay
        # schedule when decay is enabled.
        self.gradient_updates = tf.Variable(0, trainable=False)
        decay_enabled = self.hparams.activate_decay
        # With decay: inverse-time decay with decay_steps=1, i.e.
        # initial_lr / (1 + lr_decay_rate * step). Without: a constant,
        # non-trainable learning-rate variable.
        self.lr = (
            tf.train.inverse_time_decay(self.hparams.initial_lr,
                                        self.gradient_updates, 1,
                                        self.hparams.lr_decay_rate)
            if decay_enabled else
            tf.Variable(self.hparams.initial_lr, trainable=False))

        # The `optimizer` argument is intentionally not used here.
        rms_prop = tf.train.RMSPropOptimizer(self.lr)
        self._optimizer_config = dict(optimizer=rms_prop,
                                      max_grad_norm=hparams.max_grad_norm)

        if self.verbose:
            print('Initializing model {}.'.format(self.name))
        # Architecture fields (plus model_path) are forwarded verbatim from
        # hparams as same-named keyword arguments.
        self.snp = regressor.Regressor(
            input_dim=hparams.context_dim + hparams.num_actions,
            output_dim=1,
            **{field: getattr(hparams, field) for field in (
                'x_encoder_sizes',
                'x_y_encoder_sizes',
                'global_latent_net_sizes',
                'local_latent_net_sizes',
                'heteroskedastic_net_sizes',
                'att_type',
                'att_heads',
                'uncertainty_type',
                'mean_att_type',
                'scale_att_type_1',
                'scale_att_type_2',
                'activation',
                'output_activation',
                'data_uncertainty',
                'local_variational',
                'model_path',
            )})

        # Wrap the raw Python function behind utils.mse_step (presumably itself
        # a tf.function, given `.python_function`) in a fresh tf.function owned
        # by this instance.
        self._step = tf.function(utils.mse_step.python_function)  # pytype: disable=module-attr

        # One-hot encodings of every action index: row i encodes action i,
        # giving a num_actions x num_actions identity.
        self._one_hot_vectors = tf.one_hot(
            indices=np.arange(hparams.num_actions), depth=hparams.num_actions)
Beispiel #3
0
    def __init__(self, name, hparams):
        """Create the data buffer and the SNP regressor.

        The model layout depends on `hparams.is_anp` (default False):
          * ANP: input width `context_dim`, output width `num_actions`.
          * otherwise: input width `context_dim + num_actions`, output width 1.

        Args:
          name: Identifier for this model; echoed in the init message.
          hparams: Hyperparameter object. Read here: `context_dim`,
            `num_actions`, `model_path` and the Regressor architecture
            fields; `verbose` (default True) and `is_anp` (default False)
            are optional.
        """
        self.name = name
        self.hparams = hparams
        self.verbose = getattr(hparams, 'verbose', True)
        self._is_anp = getattr(hparams, 'is_anp', False)
        input_dim, output_dim = (
            (hparams.context_dim, hparams.num_actions) if self._is_anp
            else (hparams.context_dim + hparams.num_actions, 1))

        self.t = 0
        self.data_h = contextual_dataset.ContextualDataset(
            hparams.context_dim, hparams.num_actions, intercept=False)

        if self.verbose:
            print('Initializing model {}.'.format(self.name))
        # Architecture fields (plus model_path) are forwarded verbatim from
        # hparams as same-named keyword arguments.
        self.snp = regressor.Regressor(
            input_dim=input_dim,
            output_dim=output_dim,
            **{field: getattr(hparams, field) for field in (
                'x_encoder_sizes',
                'x_y_encoder_sizes',
                'global_latent_net_sizes',
                'local_latent_net_sizes',
                'heteroskedastic_net_sizes',
                'att_type',
                'att_heads',
                'uncertainty_type',
                'mean_att_type',
                'scale_att_type_1',
                'scale_att_type_2',
                'activation',
                'output_activation',
                'data_uncertainty',
                'local_variational',
                'model_path',
            )})

        # One-hot encodings of every action index: row i encodes action i,
        # giving a num_actions x num_actions identity.
        self._one_hot_vectors = tf.one_hot(
            indices=np.arange(hparams.num_actions), depth=hparams.num_actions)