コード例 #1
0
def build(config, is_training):
    """Create a bidirectional RNN object from a `BidirectionalRnn` proto.

    Args:
        config: a `bidirectional_rnn_pb2.BidirectionalRnn` message.
        is_training: forwarded to `hyperparams_builder.build` when the fully
            connected hyperparams are built.

    Returns:
        An instance of `bidirectional_rnn.StaticBidirectionalRnn` when
        `config.static` is set, otherwise of
        `bidirectional_rnn.DynamicBidirectionalRnn`.

    Raises:
        ValueError: if `config` is not a `BidirectionalRnn` message, or if
            `config.num_output_units` is positive while
            `config.fc_hyperparams.op` is not `Hyperparams.FC`.
    """
    if not isinstance(config, bidirectional_rnn_pb2.BidirectionalRnn):
        raise ValueError(
            'config not of type bidirectional_rnn_pb2.BidirectionalRnn')

    rnn_class = (bidirectional_rnn.StaticBidirectionalRnn
                 if config.static
                 else bidirectional_rnn.DynamicBidirectionalRnn)

    # Both directions are built from the same cell config, via two separate
    # calls to the cell builder.
    forward_cell = rnn_cell_builder.build(config.fw_bw_rnn_cell)
    backward_cell = rnn_cell_builder.build(config.fw_bw_rnn_cell)
    regularizer = hyperparams_builder._build_regularizer(
        config.rnn_regularizer)

    # FC hyperparams are only validated and built when output units are
    # requested; otherwise None is passed through.
    has_output_units = config.num_output_units > 0
    if (has_output_units
            and config.fc_hyperparams.op != hyperparams_pb2.Hyperparams.FC):
        raise ValueError('op type must be FC')
    fc_hyperparams = (
        hyperparams_builder.build(config.fc_hyperparams, is_training)
        if has_output_units else None)

    return rnn_class(
        forward_cell,
        backward_cell,
        rnn_regularizer=regularizer,
        num_output_units=config.num_output_units,
        fc_hyperparams=fc_hyperparams,
        summarize_activations=config.summarize_activations,
    )
コード例 #2
0
def build(config, is_training):
  """Build a predictor object from a `predictor_pb2.Predictor` config.

  Only the `attention_predictor` case of the `predictor_oneof` field is
  supported.

  Args:
    config: a `predictor_pb2.Predictor` message.
    is_training: passed through to the `AttentionPredictor` constructor.

  Returns:
    An `attention_predictor.AttentionPredictor` instance.

  Raises:
    ValueError: if `config` is not a `predictor_pb2.Predictor`, or if
      `predictor_oneof` is unset or names a predictor other than
      `attention_predictor`.
  """
  if not isinstance(config, predictor_pb2.Predictor):
    # Name the type that is actually checked above.
    raise ValueError('config not of type predictor_pb2.Predictor')

  predictor_oneof = config.WhichOneof('predictor_oneof')
  if predictor_oneof != 'attention_predictor':
    # WhichOneof returns None when no field of the oneof is set.
    raise ValueError('Unknown predictor_oneof: {}'.format(predictor_oneof))

  predictor_config = config.attention_predictor
  rnn_cell_object = rnn_cell_builder.build(predictor_config.rnn_cell)
  rnn_regularizer_object = hyperparams_builder._build_regularizer(
      predictor_config.rnn_regularizer)
  label_map_object = label_map_builder.build(predictor_config.label_map)
  loss_object = loss_builder.build(predictor_config.loss)
  # The language-model cell is optional; None is passed when it is absent.
  if predictor_config.HasField('lm_rnn_cell'):
    lm_rnn_cell_object = _build_language_model_rnn_cell(
        predictor_config.lm_rnn_cell)
  else:
    lm_rnn_cell_object = None

  return attention_predictor.AttentionPredictor(
      rnn_cell=rnn_cell_object,
      rnn_regularizer=rnn_regularizer_object,
      num_attention_units=predictor_config.num_attention_units,
      max_num_steps=predictor_config.max_num_steps,
      multi_attention=predictor_config.multi_attention,
      beam_width=predictor_config.beam_width,
      reverse=predictor_config.reverse,
      label_map=label_map_object,
      loss=loss_object,
      sync=predictor_config.sync,
      lm_rnn_cell=lm_rnn_cell_object,
      is_training=is_training,
  )
コード例 #3
0
def _build_language_model_rnn_cell(config):
  """Stack the cells described by a `LanguageModelRnnCell` config.

  Args:
    config: a `predictor_pb2.LanguageModelRnnCell` message; every entry of
      `config.rnn_cell` is built with `rnn_cell_builder.build`.

  Returns:
    An `rnn.MultiRNNCell` wrapping the built cells, in config order.

  Raises:
    ValueError: if `config` is not a `predictor_pb2.LanguageModelRnnCell`.
  """
  if not isinstance(config, predictor_pb2.LanguageModelRnnCell):
    raise ValueError('config not of type predictor_pb2.LanguageModelRnnCell')
  built_cells = list(map(rnn_cell_builder.build, config.rnn_cell))
  return rnn.MultiRNNCell(built_cells)
コード例 #4
0
    def test_build_gru_cell(self):
        """Check that a GRU cell built from a text proto has state_size 1024.

        1024 is the `num_units` value set in the proto text below.
        """
        gru_proto_text = """
    gru_cell {
      num_units: 1024
      initializer { orthogonal_initializer { seed: 1 } }
    }
    """
        proto = rnn_cell_pb2.RnnCell()
        text_format.Merge(gru_proto_text, proto)

        cell = rnn_cell_builder.build(proto)
        self.assertEqual(cell.state_size, 1024)
コード例 #5
0
    def test_build_lstm_cell(self):
        """Build an LSTM cell from a text proto and check its state sizes.

        The first two entries of `state_size` must both equal the configured
        `num_units` (1024).
        """
        lstm_proto_text = """
    lstm_cell {
      num_units: 1024
      use_peepholes: true
      forget_bias: 1.5
      initializer { orthogonal_initializer { seed: 1 } }
    }
    """
        proto = rnn_cell_pb2.RnnCell()
        text_format.Merge(lstm_proto_text, proto)
        state_size = rnn_cell_builder.build(proto).state_size

        # Presumably the (c, h) pair of an LSTM state tuple; checked by index.
        for position in (0, 1):
            self.assertEqual(state_size[position], 1024)