def build(config, is_training):
  if not isinstance(config, bidirectional_rnn_pb2.BidirectionalRnn):
    raise ValueError(
        'config not of type bidirectional_rnn_pb2.BidirectionalRnn')

  # Select the static or dynamic bidirectional RNN implementation.
  if config.static:
    brnn_class = bidirectional_rnn.StaticBidirectionalRnn
  else:
    brnn_class = bidirectional_rnn.DynamicBidirectionalRnn

  # The forward and backward cells are built from the same cell config.
  fw_cell_object = rnn_cell_builder.build(config.fw_bw_rnn_cell)
  bw_cell_object = rnn_cell_builder.build(config.fw_bw_rnn_cell)
  rnn_regularizer_object = hyperparams_builder._build_regularizer(
      config.rnn_regularizer)

  # An optional fully-connected output layer is only added when
  # num_output_units is positive.
  fc_hyperparams_object = None
  if config.num_output_units > 0:
    if config.fc_hyperparams.op != hyperparams_pb2.Hyperparams.FC:
      raise ValueError('op type must be FC')
    fc_hyperparams_object = hyperparams_builder.build(
        config.fc_hyperparams, is_training)

  return brnn_class(
      fw_cell_object, bw_cell_object,
      rnn_regularizer=rnn_regularizer_object,
      num_output_units=config.num_output_units,
      fc_hyperparams=fc_hyperparams_object,
      summarize_activations=config.summarize_activations)
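# A minimal usage sketch for the builder above, following the text-proto
# pattern used in the tests below. The field names (static, fw_bw_rnn_cell,
# rnn_regularizer, num_output_units, summarize_activations) come from the
# builder code itself, but the nested layout of rnn_regularizer is an
# illustrative assumption, not the verbatim .proto definition.
def _example_build_bidirectional_rnn():
  brnn_text_proto = """
  static: true
  fw_bw_rnn_cell {
    lstm_cell { num_units: 256 }
  }
  rnn_regularizer {
    l2_regularizer { weight: 1e-4 }
  }
  num_output_units: 0
  summarize_activations: false
  """
  brnn_proto = bidirectional_rnn_pb2.BidirectionalRnn()
  text_format.Merge(brnn_text_proto, brnn_proto)
  # Returns a StaticBidirectionalRnn since static is set to true.
  return build(brnn_proto, is_training=True)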
def build(config, is_training):
  if not isinstance(config, predictor_pb2.Predictor):
    raise ValueError('config not of type predictor_pb2.Predictor')

  predictor_oneof = config.WhichOneof('predictor_oneof')
  if predictor_oneof == 'attention_predictor':
    predictor_config = config.attention_predictor
    rnn_cell_object = rnn_cell_builder.build(predictor_config.rnn_cell)
    rnn_regularizer_object = hyperparams_builder._build_regularizer(
        predictor_config.rnn_regularizer)
    label_map_object = label_map_builder.build(predictor_config.label_map)
    loss_object = loss_builder.build(predictor_config.loss)

    # The language-model cell is optional; build it only when configured.
    if not predictor_config.HasField('lm_rnn_cell'):
      lm_rnn_cell_object = None
    else:
      lm_rnn_cell_object = _build_language_model_rnn_cell(
          predictor_config.lm_rnn_cell)

    attention_predictor_object = attention_predictor.AttentionPredictor(
        rnn_cell=rnn_cell_object,
        rnn_regularizer=rnn_regularizer_object,
        num_attention_units=predictor_config.num_attention_units,
        max_num_steps=predictor_config.max_num_steps,
        multi_attention=predictor_config.multi_attention,
        beam_width=predictor_config.beam_width,
        reverse=predictor_config.reverse,
        label_map=label_map_object,
        loss=loss_object,
        sync=predictor_config.sync,
        lm_rnn_cell=lm_rnn_cell_object,
        is_training=is_training)
    return attention_predictor_object
  else:
    raise ValueError('Unknown predictor_oneof: {}'.format(predictor_oneof))
def _build_language_model_rnn_cell(config):
  if not isinstance(config, predictor_pb2.LanguageModelRnnCell):
    raise ValueError('config not of type predictor_pb2.LanguageModelRnnCell')
  rnn_cell_list = [
      rnn_cell_builder.build(rnn_cell_config)
      for rnn_cell_config in config.rnn_cell
  ]
  lm_rnn_cell = rnn.MultiRNNCell(rnn_cell_list)
  return lm_rnn_cell
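# A minimal usage sketch for the predictor builder, again following the
# text-proto pattern from the tests below. Only fields referenced in
# build() are shown; the nested layout of the attention_predictor message
# is an illustrative assumption, and omitted sub-messages (label_map,
# loss, rnn_regularizer) are left to their proto defaults.
def _example_build_attention_predictor():
  predictor_text_proto = """
  attention_predictor {
    rnn_cell {
      lstm_cell { num_units: 256 }
    }
    num_attention_units: 128
    max_num_steps: 30
    beam_width: 1
    reverse: false
    sync: true
  }
  """
  predictor_proto = predictor_pb2.Predictor()
  text_format.Merge(predictor_text_proto, predictor_proto)
  # Dispatches on the predictor_oneof and returns an AttentionPredictor.
  return build(predictor_proto, is_training=True)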
def test_build_gru_cell(self):
  rnn_cell_text_proto = """
  gru_cell {
    num_units: 1024
    initializer { orthogonal_initializer { seed: 1 } }
  }
  """
  rnn_cell_proto = rnn_cell_pb2.RnnCell()
  text_format.Merge(rnn_cell_text_proto, rnn_cell_proto)
  rnn_cell_object = rnn_cell_builder.build(rnn_cell_proto)
  self.assertEqual(rnn_cell_object.state_size, 1024)
def test_build_lstm_cell(self):
  rnn_cell_text_proto = """
  lstm_cell {
    num_units: 1024
    use_peepholes: true
    forget_bias: 1.5
    initializer { orthogonal_initializer { seed: 1 } }
  }
  """
  rnn_cell_proto = rnn_cell_pb2.RnnCell()
  text_format.Merge(rnn_cell_text_proto, rnn_cell_proto)
  rnn_cell_object = rnn_cell_builder.build(rnn_cell_proto)
  lstm_state_tuple = rnn_cell_object.state_size
  self.assertEqual(lstm_state_tuple[0], 1024)
  self.assertEqual(lstm_state_tuple[1], 1024)