コード例 #1
0
 def create_cells(self):
     """
     Build the RNN cell described by `self.rnn_config` together with its initial state.

     `self.cell` receives the cell returned by `get_rnn_cell` (scope 'rnn_cell',
     honouring `self.reuse`); `self.initial_states` receives that cell's float32
     `zero_state` for `self.batch_size`.
     """
     self.cell = get_rnn_cell(scope='rnn_cell', reuse=self.reuse, **self.rnn_config)
     zero_state_kwargs = dict(batch_size=self.batch_size, dtype=tf.float32)
     self.initial_states = self.cell.zero_state(**zero_state_kwargs)
コード例 #2
0
    def build_cell(self):
        """
        Builds forward and backward Tensorflow RNN cells from `self.rnn_layer_config`.

        If `self.stack_fw_bw_cells` is set, `num_layers` single-layer cells are created per direction
        (`num_layers` is treated as 1 when it is missing from the config); otherwise one cell per
        direction is built from the full configuration. Every cell is appended to `self.cells_fw` /
        `self.cells_bw`, and its float32 zero state for `self.batch_size` to `self.initial_states_fw` /
        `self.initial_states_bw`.
        """
        def append_cell_pair(cell_config):
            # Build the forward cell before the backward one; each direction uses its own scope name.
            directions = (
                ('rnn_cell_fw', self.cells_fw, self.initial_states_fw),
                ('rnn_cell_bw', self.cells_bw, self.initial_states_bw),
            )
            for scope, cells, initial_states in directions:
                cell = get_rnn_cell(scope=scope, reuse=self.reuse, **cell_config)
                cells.append(cell)
                initial_states.append(cell.zero_state(batch_size=self.batch_size, dtype=tf.float32))

        if self.stack_fw_bw_cells:
            # One single-layer cell per layer and direction instead of one multi-layer cell.
            single_cell_config = self.rnn_layer_config.copy()
            single_cell_config['num_layers'] = 1
            for _ in range(self.rnn_layer_config.get('num_layers', 1)):
                append_cell_pair(single_cell_config)
        else:
            append_cell_pair(self.rnn_layer_config)
コード例 #3
0
    def create_cells(self):
        """
        Creates forward and backward Tensorflow RNN cells from `self.rnn_config`.

        If `self.stack_fw_bw_cells` is set, `num_layers` single-layer cells are
        created per direction (`num_layers` is treated as 1 when it is missing
        from the config); otherwise one cell per direction is built from the
        full configuration. Every cell is appended to `self.cells_fw` /
        `self.cells_bw`, and its float32 zero state for `self.batch_size` to
        `self.initial_states_fw` / `self.initial_states_bw`.
        """
        def append_cell_pair(cell_config):
            # Forward cell first, then backward; each direction has its own
            # scope name.
            directions = (
                ('rnn_cell_fw', self.cells_fw, self.initial_states_fw),
                ('rnn_cell_bw', self.cells_bw, self.initial_states_bw),
            )
            for scope, cells, initial_states in directions:
                cell = get_rnn_cell(scope=scope,
                                    reuse=self.reuse,
                                    **cell_config)
                cells.append(cell)
                initial_states.append(
                    cell.zero_state(batch_size=self.batch_size,
                                    dtype=tf.float32))

        if self.stack_fw_bw_cells:
            # One single-layer cell per layer and direction instead of one
            # multi-layer cell.
            single_cell_config = self.rnn_config.copy()
            single_cell_config['num_layers'] = 1
            for _ in range(self.rnn_config.get('num_layers', 1)):
                append_cell_pair(single_cell_config)
        else:
            append_cell_pair(self.rnn_config)
コード例 #4
0
    def __init__(self, reuse, mode, config):
        """
        Reads the cell configuration and creates the underlying RNN cells.

        Args:
            reuse: reuse model parameters.
            mode: 'training' or 'sampling'. Input dropout is disabled (keep
                probability 1.0) in every mode other than 'training'.
            config (dict): In addition to standard <key, value> pairs, stores the following dictionaries for rnn and
                output configurations.

                config['output'] = {}
                config['output']['keys']
                config['output']['dims']
                config['output']['activation_funcs']

                config['*_rnn'] = {}
                config['*_rnn']['num_layers'] (default: 1)
                config['*_rnn']['cell_type'] (default: lstm)
                config['*_rnn']['size'] (default: 512)

                'latent_rnn' is required. 'input_rnn' and 'output_rnn' are
                optional and only used when present and non-empty; linear
                layers are used otherwise.
        """
        self.input_dims = config['input_dims']
        self.h_dim = config['latent_hidden_size']
        self.z_dim = config['latent_size']
        self.additive_q_mu = config['additive_q_mu']

        # NOTE(review): the key is spelled 'input_keep_prop' (sic); kept as-is
        # since existing configs presumably use this spelling.
        self.dropout_keep_prob = config.get('input_keep_prop', 1)
        self.num_linear_layers = config.get('num_fc_layers', 1)
        self.use_latent_h_in_outputs = config.get('use_latent_h_in_outputs',
                                                  True)
        self.use_batch_norm = config['use_batch_norm_fc']

        self.reuse = reuse
        self.mode = mode
        self.is_sampling = mode == 'sampling'

        # Dropout is only applied during training.
        if mode != "training":
            self.dropout_keep_prob = 1.0

        self.output_config = config['output']

        # q_mu, q_sigma, p_mu, p_sigma + model outputs
        self.output_size_ = [self.z_dim] * 4
        self.output_size_.extend(self.output_config['dims'])

        def has_optional_rnn(key):
            # Enabled only by a config entry that is present, not None and
            # non-empty.
            return bool(config.get(key))

        self.state_size_ = []
        # Optional. Linear layers will be used if not passed.
        self.input_rnn = has_optional_rnn('input_rnn')
        if self.input_rnn:
            self.input_rnn_config = config['input_rnn']
            self.input_rnn_cell = get_rnn_cell(scope='input_rnn',
                                               **self.input_rnn_config)
            self.state_size_.append(self.input_rnn_cell.state_size)

        self.latent_rnn_config = config['latent_rnn']
        # Fall back to the documented default instead of raising KeyError.
        self.latent_rnn_cell_type = self.latent_rnn_config.get('cell_type',
                                                               'lstm')
        self.latent_rnn_cell = get_rnn_cell(scope='latent_rnn',
                                            **self.latent_rnn_config)
        self.state_size_.append(self.latent_rnn_cell.state_size)

        # Optional. Linear layers will be used if not passed.
        self.output_rnn = has_optional_rnn('output_rnn')
        if self.output_rnn:
            self.output_rnn_config = config['output_rnn']
            self.output_rnn_cell = get_rnn_cell(scope='output_rnn',
                                                **self.output_rnn_config)
            self.state_size_.append(self.output_rnn_cell.state_size)

        self.activation_func = get_activation_fn(
            config.get('fc_layer_activation_func', 'relu'))
        self.sigma_func = get_activation_fn('softplus')
コード例 #5
0
    def __init__(self, reuse, mode, sample_fn, config):
        """
        Reads the cell configuration and creates the underlying RNN cells.

        Args:
            reuse: reuse model parameters.
            mode: 'training', 'sampling', 'validation' or 'test'. Input
                dropout is disabled (keep probability 1.0) in every mode other
                than 'training'.
            sample_fn: function to generate sample given model outputs.

            config (dict): In addition to standard <key, value> pairs, stores the following dictionaries for rnn and
                output configurations.

                config['output_layer'] = {}
                config['output_layer']['out_keys']
                config['output_layer']['out_dims']
                config['output_layer']['out_activation_fn']

                config['*_rnn'] = {}
                config['*_rnn']['num_layers'] (default: 1)
                config['*_rnn']['cell_type'] (default: lstm)
                config['*_rnn']['size'] (default: 512)
                config['input_rnn']['use_variational_dropout'] (default: False)

                'latent_rnn' is required. 'input_rnn' and 'output_rnn' are
                optional and only used when present and non-empty; linear
                layers are used otherwise.
        """
        self.reuse = reuse
        self.mode = mode
        self.sample_fn = sample_fn
        self.is_sampling = mode == 'sampling'
        self.is_evaluation = mode in ("validation", "test")

        self.input_dims = config['input_dims']
        self.h_dim = config['hidden_size']
        # The latent network falls back to the main hidden size.
        self.latent_h_dim = config.get('latent_hidden_size', self.h_dim)
        self.z_dim = config['latent_size']
        self.additive_q_mu = config['additive_q_mu']

        # NOTE(review): the key is spelled 'input_keep_prop' (sic); kept as-is
        # since existing configs presumably use this spelling.
        self.dropout_keep_prob = config.get('input_keep_prop', 1)
        self.num_linear_layers = config.get('num_fc_layers', 1)
        self.use_latent_h_in_outputs = config.get('use_latent_h_in_outputs',
                                                  True)
        self.use_batch_norm = config['use_batch_norm_fc']

        # Dropout is only applied during training.
        if mode != "training":
            self.dropout_keep_prob = 1.0

        self.output_config = config['output_layer']

        # q_mu, q_sigma, p_mu, p_sigma + model outputs
        self.output_size_ = [self.z_dim] * 4
        self.output_size_.extend(self.output_config['out_dims'])

        def has_optional_rnn(key):
            # Enabled only by a config entry that is present, not None and
            # non-empty.
            return bool(config.get(key))

        self.state_size_ = []
        # Optional. Linear layers will be used if not passed.
        self.input_rnn = has_optional_rnn('input_rnn')
        if self.input_rnn:
            self.input_rnn_config = config['input_rnn']
            self.input_rnn_cell = get_rnn_cell(scope='input_rnn',
                                               **self.input_rnn_config)

            # Variational dropout is applied by the cell wrapper, so the
            # feed-forward input dropout is then switched off (presumably to
            # avoid applying dropout twice).
            if self.input_rnn_config.get('use_variational_dropout', False):
                # TODO input dimensions are hard-coded.
                self.input_rnn_cell = tf.contrib.rnn.DropoutWrapper(
                    self.input_rnn_cell,
                    input_keep_prob=self.dropout_keep_prob,
                    output_keep_prob=self.dropout_keep_prob,
                    variational_recurrent=True,
                    input_size=216,
                    dtype=tf.float32)
                self.dropout_keep_prob = 1.0

            self.state_size_.append(self.input_rnn_cell.state_size)

        self.latent_rnn_config = config['latent_rnn']
        # Fall back to the documented default instead of raising KeyError.
        self.latent_rnn_cell_type = self.latent_rnn_config.get('cell_type',
                                                               'lstm')
        self.latent_rnn_cell = get_rnn_cell(scope='latent_rnn',
                                            **self.latent_rnn_config)
        self.state_size_.append(self.latent_rnn_cell.state_size)

        # Optional. Linear layers will be used if not passed.
        self.output_rnn = has_optional_rnn('output_rnn')
        if self.output_rnn:
            self.output_rnn_config = config['output_rnn']
            self.output_rnn_cell = get_rnn_cell(scope='output_rnn',
                                                **self.output_rnn_config)
            self.state_size_.append(self.output_rnn_cell.state_size)

        self.activation_func = get_activation_fn(
            config.get('fc_layer_activation_func', 'relu'))
        # NOTE: attribute name misspelled ("activaction"); kept for callers.
        self.sigma_activaction_fn = tf.nn.softplus
コード例 #6
0
 def build_cell(self):
     """
     Create the RNN cell described by `self.rnn_layer_config` plus its zero initial state.

     The cell built by `get_rnn_cell` (scope 'rnn_cell', honouring `self.reuse`) is stored
     in `self.cell`; its float32 `zero_state` for `self.batch_size` goes into
     `self.initial_states`.
     """
     layer_cell = get_rnn_cell(scope='rnn_cell',
                               reuse=self.reuse,
                               **self.rnn_layer_config)
     self.cell = layer_cell
     self.initial_states = layer_cell.zero_state(batch_size=self.batch_size,
                                                 dtype=tf.float32)