Example #1
def create_or_restore_model(session, buckets, forward_only, beam_search,
                            beam_size):
    """Create model and initialize or load parameters"""

    model = seq2seq_model.Seq2SeqModel(
        source_vocab_size=config.MAX_ENC_VOCABULARY,
        target_vocab_size=config.MAX_DEC_VOCABULARY,
        buckets=buckets,
        size=config.LAYER_SIZE,
        num_layers=config.NUM_LAYERS,
        max_gradient_norm=config.MAX_GRADIENT_NORM,
        batch_size=config.BATCH_SIZE,
        learning_rate=config.LEARNING_RATE,
        learning_rate_decay_factor=config.LEARNING_RATE_DECAY_FACTOR,
        beam_search=beam_search,
        attention=True,
        forward_only=forward_only,
        beam_size=beam_size)

    print("model initialized")
    ckpt = tf.train.get_checkpoint_state(config.GENERATED_DIR)
    # In TensorFlow releases after 0.12 the checkpoint data is accompanied by
    # a ".index" file, so probe for that file when checking the checkpoint.
    checkpoint_suffix = ".index"
    if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path +
                                checkpoint_suffix):
        print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
        model.saver.restore(session, ckpt.model_checkpoint_path)
    else:
        print("Created model with fresh parameters.")
        session.run(tf.global_variables_initializer())
    return model
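A minimal call site for this helper might look like the following sketch; the bucket list and argument values here are illustrative assumptions, not part of the original example.

import tensorflow as tf

BUCKETS = [(5, 10), (10, 15), (20, 25), (40, 50)]  # hypothetical bucket sizes

with tf.Session() as sess:
    # forward_only=True plus beam_search=True configures the model for decoding.
    model = create_or_restore_model(sess, BUCKETS, forward_only=True,
                                    beam_search=True, beam_size=5)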
Example #2
def create_model(session, args, forward_only=True):
    """Create translation model and initialize or load parameters in session."""

    model = seq2seq_model.Seq2SeqModel(
        source_vocab_size=args.vocab_size,
        target_vocab_size=args.vocab_size,
        buckets=args.buckets,
        size=args.size,
        num_layers=args.num_layers,
        max_gradient_norm=args.max_gradient_norm,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        learning_rate_decay_factor=args.learning_rate_decay_factor,
        forward_only=forward_only,
    )

    # for tensorboard
    if args.en_tfboard:
        # tf.train.SummaryWriter was renamed tf.summary.FileWriter in TF 1.0.
        summary_writer = tf.train.SummaryWriter(args.tf_board_dir,
                                                session.graph)

    ckpt = tf.train.get_checkpoint_state(args.model_dir)
    # if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
    if ckpt and ckpt.model_checkpoint_path:
        print("Reading model parameters from %s @ %s" %
              (ckpt.model_checkpoint_path, datetime.now()))
        model.saver.restore(session, ckpt.model_checkpoint_path)
        print("Model reloaded @ %s" % (datetime.now()))
    else:
        print("Created model with fresh parameters.")
        session.run(tf.global_variables_initializer())

    return model
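One caveat: tf.train.SummaryWriter went away when TensorFlow 1.0 renamed it to tf.summary.FileWriter. A small compatibility shim, sketched here and not part of the original source, keeps the snippet working on both sides of the rename:

import tensorflow as tf

def make_summary_writer(log_dir, graph):
    """Create a TensorBoard writer on both pre- and post-1.0 TensorFlow."""
    if hasattr(tf, "summary") and hasattr(tf.summary, "FileWriter"):
        return tf.summary.FileWriter(log_dir, graph)  # TensorFlow >= 1.0
    return tf.train.SummaryWriter(log_dir, graph)     # older releases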
Example #3
def create_or_restore_model(session, buckets, forward_only):
    """Create model and initialize or load parameters"""
    model = seq2seq_model.Seq2SeqModel(config.MAX_ENC_VOCABULARY,
                                       config.MAX_DEC_VOCABULARY,
                                       buckets,
                                       config.LAYER_SIZE,
                                       config.NUM_LAYERS,
                                       config.MAX_GRADIENT_NORM,
                                       config.BATCH_SIZE,
                                       config.LEARNING_RATE,
                                       config.LEARNING_RATE_DECAY_FACTOR,
                                       forward_only=forward_only)

    print("model initialized")
    ckpt = tf.train.get_checkpoint_state(config.GENERATED_DIR)
    # The checkpoint filename gained a ".index" suffix after TensorFlow 0.12.
    # Compare parsed versions, not raw strings: lexicographically "0.9" > "0.12",
    # so a plain string comparison misfires on older releases.
    from distutils.version import LooseVersion
    checkpoint_suffix = ""
    if LooseVersion(tf.__version__) > LooseVersion("0.12"):
        checkpoint_suffix = ".index"
    if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path +
                                checkpoint_suffix):
        print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
        model.saver.restore(session, ckpt.model_checkpoint_path)
    else:
        print("Created model with fresh parameters.")
        # tf.initialize_all_variables() is the pre-0.12 spelling of
        # tf.global_variables_initializer(), kept for old-TF compatibility.
        session.run(tf.initialize_all_variables())
    return model
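The same version sensitivity applies to the initializer on the fresh-parameters path: tf.initialize_all_variables() was deprecated in TensorFlow 0.12 in favor of tf.global_variables_initializer(). A minimal shim, again a sketch rather than part of the original example:

import tensorflow as tf

def global_init_op():
    """Return the variable-initializer op across old and new TF releases."""
    if hasattr(tf, "global_variables_initializer"):   # TensorFlow >= 0.12
        return tf.global_variables_initializer()
    return tf.initialize_all_variables()              # older releases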
Example #4
def self_test():
    """Test the translation model."""
    with tf.Session() as sess:
        print("Self-test for neural translation model.")
        # Create model with vocabularies of 10, 2 small buckets, 2 layers of 32.
        model = seq2seq_model.Seq2SeqModel(10, 10, [(3, 3), (6, 6)], 32, 2,
                                           5.0, 32, 0.3, 0.99, num_samples=8)
        sess.run(tf.global_variables_initializer())

        # Fake data set for both the (3, 3) and (6, 6) bucket.
        data_set = ([([1, 1], [2, 2]), ([3, 3], [4]), ([5], [6])],
                    [([1, 1, 1, 1, 1], [2, 2, 2, 2, 2]), ([3, 3, 3], [5, 6])])
        # xrange comes from six.moves in the original Python 2/3 tutorial code.
        for _ in xrange(5):  # Train the fake model for 5 steps.
            bucket_id = random.choice([0, 1])
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                data_set, bucket_id)
            model.step(sess, encoder_inputs, decoder_inputs, target_weights,
                       bucket_id, False)
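The self-test above only runs training steps. In the upstream tutorial the same step() call doubles as an evaluation step; assuming the tutorial's convention that step() returns a (gradient norm, loss, outputs) triple, with logits populated only when forward_only is true, the loop body could be extended like this:

        # Evaluation step: with forward_only=True no gradients are applied and
        # the third return value holds the decoder output logits (assuming the
        # tutorial's step() signature).
        _, eval_loss, output_logits = model.step(sess, encoder_inputs,
                                                 decoder_inputs,
                                                 target_weights, bucket_id,
                                                 True)
        print("eval loss on bucket %d: %.4f" % (bucket_id, eval_loss))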
Example #5
def create_model(session, args):
    """Create translation model and initialize or load parameters in session."""
    model = seq2seq_model.Seq2SeqModel(
        source_vocab_size=args.vocab_size,
        target_vocab_size=args.vocab_size,
        buckets=args.buckets,
        size=args.size,
        num_layers=args.num_layers,
        max_gradient_norm=args.max_gradient_norm,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        learning_rate_decay_factor=args.learning_rate_decay_factor,
        use_lstm=False,
    )

    # for tensorboard
    if args.en_tfboard:
        summary_writer = tf.train.SummaryWriter(args.tf_board_dir,
                                                session.graph)

    ckpt = tf.train.get_checkpoint_state(args.model_dir)
    if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
        print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
        model.saver.restore(session, ckpt.model_checkpoint_path)
    else:
        print("Created model with fresh parameters.")
        session.run(tf.initialize_all_variables())
    return model
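Note that this variant calls gfile.Exists without a tf. prefix, so it relies on a module-level import. The snippets in this listing generally presuppose a preamble along these lines (the exact module names are assumptions based on the identifiers used above):

from datetime import datetime

import tensorflow as tf
from tensorflow.python.platform import gfile

import seq2seq_model  # the tutorial's model definition
import config         # project-specific hyperparameters (Examples #1 and #3)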
Example #6
def create_model(session, forward_only):
    model = seq2seq_model.Seq2SeqModel(
        source_vocab_size=FLAGS.vocab_size,
        target_vocab_size=FLAGS.vocab_size,
        buckets=BUCKETS,
        size=FLAGS.size,
        num_layers=FLAGS.num_layers,
        max_gradient_norm=FLAGS.max_gradient_norm,
        batch_size=FLAGS.batch_size,
        learning_rate=FLAGS.learning_rate,
        learning_rate_decay_factor=FLAGS.learning_rate_decay_factor,
        use_lstm=True,
        forward_only=forward_only)

    ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)

    # This only checks that a checkpoint record exists, not that its files are
    # still on disk; compare the stricter test in Example #7.
    if ckpt:
        print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
        model.saver.restore(session, ckpt.model_checkpoint_path)
    else:
        print("Created model with fresh parameters.")
        session.run(tf.initialize_all_variables())
    return model
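A common follow-up to this helper in the tutorial code is interactive decoding: load the model with forward_only=True, then shrink the batch to one sentence at a time. A sketch, assuming the tutorial's mutable batch_size attribute:

with tf.Session() as sess:
    model = create_model(sess, forward_only=True)
    model.batch_size = 1  # decode a single sentence per step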
Example #7
def create_model(session, forward_only, train_dir=None):
    """Create translation model and initialize or load parameters in session.
  """
    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    model = seq2seq_model.Seq2SeqModel(FLAGS.from_vocab_size,
                                       FLAGS.to_vocab_size,
                                       _buckets,
                                       FLAGS.size,
                                       FLAGS.num_layers,
                                       FLAGS.max_gradient_norm,
                                       FLAGS.batch_size,
                                       FLAGS.learning_rate,
                                       FLAGS.learning_rate_decay_factor,
                                       forward_only=forward_only,
                                       dtype=dtype)
    train_dir = train_dir or FLAGS.train_dir
    ckpt = tf.train.get_checkpoint_state(train_dir)
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
        model.saver.restore(session, ckpt.model_checkpoint_path)
    else:
        print("Created model with fresh parameters.")
        session.run(tf.global_variables_initializer())
    return model
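The optional train_dir argument lets a caller restore from somewhere other than FLAGS.train_dir, for example to evaluate a separately trained checkpoint. A hypothetical call (the path is purely illustrative):

with tf.Session() as sess:
    # Restore an evaluation copy from a non-default checkpoint directory.
    eval_model = create_model(sess, True, train_dir="/tmp/eval_checkpoints")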