Exemplo n.º 1
0
def create_seq2seq(session, mode):
  """Build a ``seq2seq_model.Seq2seq`` and restore or initialize its weights.

  Side effects on FLAGS: when ``FLAGS.mode`` is ``'TEST'`` or contains
  ``'val'``, ``FLAGS.schedule_sampling`` is set to False; otherwise
  ``FLAGS.beam_search`` is set to False.

  Weight source, in priority order:
    1. the path in ``FLAGS.load`` when it is not the empty string;
    2. the checkpoint recorded in ``FLAGS.model_rl_dir`` (when
       ``FLAGS.mode == 'val_rl'``) or in ``FLAGS.model_pre_dir`` (otherwise);
    3. fresh initialization of all global variables.

  Args:
    session: TF session that receives the restored/initialized variables.
    mode: forwarded to ``Seq2seq``. Note that the TEST/val checks above read
      ``FLAGS.mode``, not this argument.

  Returns:
    The constructed model.
  """
  evaluating = FLAGS.mode == 'TEST' or 'val' in FLAGS.mode
  if evaluating:
    FLAGS.schedule_sampling = False
  else:
    FLAGS.beam_search = False

  print('FLAGS.beam_search: ', FLAGS.beam_search)
  if FLAGS.beam_search:
    print('FLAGS.beam_size: ', FLAGS.beam_size)
    print('FLAGS.debug: ', bool(FLAGS.debug))

  model = seq2seq_model.Seq2seq(mode)

  # The checkpoint state is looked up even when FLAGS.load overrides it.
  ckpt_dir = FLAGS.model_rl_dir if FLAGS.mode == 'val_rl' else FLAGS.model_pre_dir
  ckpt = tf.train.get_checkpoint_state(ckpt_dir)

  if FLAGS.load != '':
    weights_path = FLAGS.load
  elif ckpt:
    weights_path = ckpt.model_checkpoint_path
  else:
    print("Create model with fresh parameters, mode: %s" % FLAGS.mode)
    session.run(tf.global_variables_initializer())
    return model

  print("Reading model from %s, mode: %s" % (weights_path, FLAGS.mode))
  model.saver.restore(session, weights_path)
  return model
def create_seq2seq(session, mode):
  """Build a ``seq2seq_model.Seq2seq`` from FLAGS and restore or initialize it.

  Side effects on FLAGS: in ``'TEST'`` mode ``FLAGS.schedule_sampling`` is
  set to False; in every other mode ``FLAGS.beam_search`` is set to False.

  Checkpoint directory, in priority order:
    1. ``FLAGS.bind`` when it is non-empty;
    2. otherwise a mode flag (``FLAGS.test_mode`` when ``mode == 'TEST'``,
       ``FLAGS.mode`` when it is not) selects ``FLAGS.model_dir`` for
       ``"MLE"`` or ``FLAGS.model_rl_dir`` for ``"RL"``.
  If that directory holds no checkpoint, variables are freshly initialized.

  Args:
    session: TF session that receives the restored/initialized variables.
    mode: ``'TEST'`` or another run mode; forwarded to ``Seq2seq``.

  Returns:
    The constructed model.

  Raises:
    ValueError: if ``FLAGS.bind`` is empty and the selected mode flag is
      neither ``"MLE"`` nor ``"RL"``. Previously this case crashed with an
      ``UnboundLocalError`` on ``ckpt``.
  """
  if mode == 'TEST':
    FLAGS.schedule_sampling = False
  else:
    FLAGS.beam_search = False
  print('FLAGS.beam_search: ', FLAGS.beam_search)
  print('FLAGS.length_penalty: ', FLAGS.length_penalty)
  print('FLAGS.length_penalty_factor: ', FLAGS.length_penalty_factor)
  if FLAGS.beam_search:
    print('FLAGS.beam_size: ', FLAGS.beam_size)
    print('FLAGS.debug: ', bool(FLAGS.debug))

  model = seq2seq_model.Seq2seq(src_vocab_size = FLAGS.src_vocab_size,
                                trg_vocab_size = FLAGS.trg_vocab_size,
                                buckets = buckets,
                                size = FLAGS.hidden_size,
                                num_layers = FLAGS.num_layers,
                                batch_size = FLAGS.batch_size,
                                mode = mode,
                                input_keep_prob = FLAGS.input_keep_prob,
                                output_keep_prob = FLAGS.output_keep_prob,
                                state_keep_prob = FLAGS.state_keep_prob,
                                beam_search = FLAGS.beam_search,
                                beam_size = FLAGS.beam_size,
                                schedule_sampling = FLAGS.schedule_sampling,
                                sampling_decay_rate = FLAGS.sampling_decay_rate,
                                sampling_global_step = FLAGS.sampling_global_step,
                                sampling_decay_steps = FLAGS.sampling_decay_steps,
                                pretrain_vec = FLAGS.pretrain_vec,
                                pretrain_trainable = FLAGS.pretrain_trainable,
                                length_penalty = FLAGS.length_penalty,
                                length_penalty_factor = FLAGS.length_penalty_factor
                                )

  # Truthiness (not len(...) > 0) so an unset/None bind is also treated as empty.
  if FLAGS.bind:
    ckpt_dir = FLAGS.bind
  else:
    # TEST runs pick their checkpoint by FLAGS.test_mode; all others by FLAGS.mode.
    if mode == 'TEST':
      flag_name, run_mode = 'FLAGS.test_mode', FLAGS.test_mode
    else:
      flag_name, run_mode = 'FLAGS.mode', FLAGS.mode
    if run_mode == "MLE":
      ckpt_dir = FLAGS.model_dir
    elif run_mode == "RL":
      ckpt_dir = FLAGS.model_rl_dir
    else:
      # Without this, ckpt would be unbound below and fail opaquely.
      raise ValueError("Unknown %s %r (mode=%r); expected 'MLE' or 'RL'"
                       % (flag_name, run_mode, mode))
  ckpt = tf.train.get_checkpoint_state(ckpt_dir)

  print("FLAGS.mode: ", FLAGS.mode)
  if ckpt:
    print("Reading model from %s, mode: %s" % (ckpt.model_checkpoint_path, mode))
    model.saver.restore(session, ckpt.model_checkpoint_path)
  else:
    print("Create model with fresh parameters, mode: %s" % mode)
    session.run(tf.global_variables_initializer())

  return model
Exemplo n.º 3
0
def create_seq2seq(session, mode):
  """Build a ``seq2seq_model.Seq2seq`` and restore or initialize its weights.

  Side effects on FLAGS: in ``'TEST'`` mode ``FLAGS.schedule_sampling`` is
  set to False; in any other mode ``FLAGS.beam_search`` is set to False.

  The model is restored from the checkpoint recorded in ``FLAGS.model_dir``
  when one exists; otherwise all global variables are freshly initialized.

  Args:
    session: TF session that receives the restored/initialized variables.
    mode: ``'TEST'`` or another run mode; forwarded to ``Seq2seq``.

  Returns:
    The constructed model.
  """
  if mode != 'TEST':
    FLAGS.beam_search = False
  else:
    FLAGS.schedule_sampling = False

  print('FLAGS.beam_search: ', FLAGS.beam_search)
  if FLAGS.beam_search:
    print('FLAGS.beam_size: ', FLAGS.beam_size)
    print('FLAGS.debug: ', bool(FLAGS.debug))

  hparams = dict(
      vocab_size=FLAGS.vocab_size,
      buckets=buckets,
      size=FLAGS.hidden_size,
      num_layers=FLAGS.num_layers,
      batch_size=FLAGS.batch_size,
      mode=mode,
      input_keep_prob=FLAGS.input_keep_prob,
      output_keep_prob=FLAGS.output_keep_prob,
      state_keep_prob=FLAGS.state_keep_prob,
      beam_search=FLAGS.beam_search,
      beam_size=FLAGS.beam_size,
      schedule_sampling=FLAGS.schedule_sampling,
      sampling_decay_rate=FLAGS.sampling_decay_rate,
      sampling_global_step=FLAGS.sampling_global_step,
      sampling_decay_steps=FLAGS.sampling_decay_steps,
  )
  model = seq2seq_model.Seq2seq(**hparams)

  # Checkpoints are always read from FLAGS.model_dir, regardless of mode.
  state = tf.train.get_checkpoint_state(FLAGS.model_dir)
  if not state:
    print("Create model with fresh parameters, mode: %s" % mode)
    session.run(tf.global_variables_initializer())
    return model

  weights_path = state.model_checkpoint_path
  print("Reading model from %s, mode: %s" % (weights_path, mode))
  model.saver.restore(session, weights_path)
  return model
def create_seq2seq(session, mode):
    """Build a ``seq2seq_model.Seq2seq`` and restore or initialize its weights.

    The model is configured from ``FLAGS`` (vocab size, hidden size, layer
    count, batch size) and a ``buckets`` name resolved outside this function.
    If ``FLAGS.model_dir`` records a checkpoint, it is restored into
    *session*; otherwise all global variables are freshly initialized.

    Args:
        session: TF session that receives the restored/initialized variables.
        mode: forwarded to ``Seq2seq`` and echoed in the log line.

    Returns:
        The constructed model.
    """
    config = {
        'vocab_size': FLAGS.vocab_size,
        'buckets': buckets,
        'size': FLAGS.hidden_size,
        'num_layers': FLAGS.num_layers,
        'batch_size': FLAGS.batch_size,
        'mode': mode,
    }
    model = seq2seq_model.Seq2seq(**config)

    state = tf.train.get_checkpoint_state(FLAGS.model_dir)
    if not state:
        print("Create model with fresh parameters, mode: %s" % mode)
        session.run(tf.global_variables_initializer())
        return model

    checkpoint_path = state.model_checkpoint_path
    print("Reading model from %s, mode: %s" % (checkpoint_path, mode))
    model.saver.restore(session, checkpoint_path)
    return model