예제 #1
0
 def default_params():
   """Return default hyperparameters, extending AttentionSeq2Seq's defaults
   with discriminator-specific settings.

   Returns:
     A dict of parameter name -> default value.
   """
   # Copy before mutating: without .copy() this writes the discriminator
   # overrides into the dict returned by the base class, which may be a
   # shared object (every other override in this file copies first).
   params = AttentionSeq2Seq.default_params().copy()
   params["discriminator_units"] = 256
   params["discriminator_loss_multiplier"] = 10.0
   params["discriminator_reverse_grad"] = False
   params["discriminator_mix_context"] = False
   return params
예제 #2
0
 def create_model(self, params=None):
     """Build an AttentionSeq2Seq using this fixture's vocabulary for both
     source and target, with source-sequence reversal enabled.

     Args:
       params: optional dict of parameter overrides applied last.
     """
     merged = dict(AttentionSeq2Seq.default_params())
     merged["source.reverse"] = True
     if params:
         merged.update(params)
     return AttentionSeq2Seq(
         source_vocab_info=self.vocab_info,
         target_vocab_info=self.vocab_info,
         params=merged)
예제 #3
0
 def create_model(self, mode, params=None):
     """Build an AttentionSeq2Seq for `mode` with test defaults applied.

     Layering (later wins): class defaults < TEST_PARAMS < fixture
     vocab/reversal settings < caller-supplied `params`.
     """
     merged = dict(AttentionSeq2Seq.default_params())
     merged.update(TEST_PARAMS)
     merged["source.reverse"] = True
     merged["vocab_source"] = self.vocab_file.name
     merged["vocab_target"] = self.vocab_file.name
     if params:
         merged.update(params)
     return AttentionSeq2Seq(params=merged, mode=mode)
예제 #4
0
 def create_model(self, mode, params=None):
   """Build an AttentionSeq2Seq for `mode` with test defaults applied.

   Override precedence, lowest to highest: class defaults, TEST_PARAMS,
   fixture vocab/reversal settings, then caller-supplied `params`.
   """
   fixture_overrides = {
       "source.reverse": True,
       "vocab_source": self.vocab_file.name,
       "vocab_target": self.vocab_file.name,
   }
   merged = {
       **AttentionSeq2Seq.default_params(),
       **TEST_PARAMS,
       **fixture_overrides,
       **(params or {}),
   }
   return AttentionSeq2Seq(params=merged, mode=mode)
예제 #5
0
 def default_params():
   """Return the default hyperparameters for this configurable class.

   Starts from AttentionSeq2Seq's defaults and overrides the encoder,
   decoder, attention, and bridge settings for the pointer-generator
   (copy) model with coverage enabled.
   """
   overrides = {
       "pointer_gen": True,
       "coverage": True,
       "embedding.share": True,
       "attention.class": "AttentionLayerBahdanau",
       "attention.params": {},  # Arbitrary attention layer parameters
       "bridge.class": "seq2seq.models.bridges.ZeroBridge",
       "encoder.class": "seq2seq.encoders.BidirectionalRNNEncoder",
       "encoder.params": {},  # Arbitrary parameters for the encoder
       "decoder.class": "seq2seq.decoders.CopyGenDecoder",
       "decoder.params": {}  # Arbitrary parameters for the decoder
   }
   merged = AttentionSeq2Seq.default_params().copy()
   merged.update(overrides)
   return merged
예제 #6
0
def test_model(source_path, target_path, vocab_path):
    """Build an AttentionSeq2Seq training graph over the given parallel
    text files and run its fetches once in a TF1 session.

    Args:
      source_path: path to the source-side text file.
      target_path: path to the target-side text file.
      vocab_path: path to the shared source/target vocabulary file.

    Returns:
      A (model, fetched_values) tuple, where fetched_values is the result
      of one sess.run over the model's non-None outputs.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    batch_size = 2

    # Build model graph in TRAIN mode.
    mode = tf.contrib.learn.ModeKeys.TRAIN
    model_params = dict(AttentionSeq2Seq.default_params())
    model_params["vocab_source"] = vocab_path
    model_params["vocab_target"] = vocab_path
    model = AttentionSeq2Seq(params=model_params, mode=mode)

    tf.logging.info(vocab_path)

    pipeline = input_pipeline.ParallelTextInputPipeline(
        params={
            "source_files": [source_path],
            "target_files": [target_path]
        },
        mode=mode)
    input_fn = training_utils.create_input_fn(
        pipeline=pipeline, batch_size=batch_size)
    features, labels = input_fn()
    fetches = model(features, labels, None)

    # Drop placeholders the model did not produce for this mode.
    fetches = list(filter(lambda t: t is not None, fetches))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())
        # Queue runners feed the input pipeline while we evaluate.
        with tf.contrib.slim.queues.QueueRunners(sess):
            fetched = sess.run(fetches)

    return model, fetched
예제 #7
0
 def create_model(self):
     """Build an AttentionSeq2Seq with default parameters, sharing this
     fixture's vocabulary between source and target."""
     vocab = self.vocab_info
     return AttentionSeq2Seq(
         source_vocab_info=vocab,
         target_vocab_info=vocab,
         params=AttentionSeq2Seq.default_params())