Ejemplo n.º 1
0
def create_model(session,
                 gen_config,
                 vocab_size,
                 forward_only,
                 name_scope,
                 initializer=None):
    """Build the generator seq2seq model, then restore or initialise its weights.

    The model is constructed inside ``tf.variable_scope(name_scope)``.

    * If a checkpoint is recorded under ``<gen_config.train_dir>/checkpoints``
      and its file exists, the model's saver restores it into ``session``.
    * Otherwise, only the global variables whose name contains ``name_scope``
      are initialised. Variables belonging to other models in the same graph
      are left untouched.

    Args:
        session: TF session to restore into or initialise in.
        gen_config: generator configuration. ``train_dir`` is read here; the
            object is also forwarded to ``Seq2SeqModel``.
        vocab_size: forwarded to ``Seq2SeqModel``.
        forward_only: forwarded to ``Seq2SeqModel``.
        name_scope: variable-scope name. It is also used as a substring filter
            to pick the variables to initialise when no checkpoint exists.
        initializer: optional default initializer for the variable scope.

    Returns:
        The constructed ``seq2seq_model.Seq2SeqModel``.
    """
    with tf.variable_scope(name_or_scope=name_scope, initializer=initializer):
        model = seq2seq_model.Seq2SeqModel(gen_config,
                                           vocab_size=vocab_size,
                                           name_scope=name_scope,
                                           forward_only=forward_only)
        ckpt_dir = os.path.abspath(
            os.path.join(gen_config.train_dir, "checkpoints"))
        state = tf.train.get_checkpoint_state(ckpt_dir)
        has_checkpoint = state and tf.train.checkpoint_exists(
            state.model_checkpoint_path)

        if not has_checkpoint:
            print("Created Gen model with fresh parameters.")
            # Only this model's variables: other models may already hold weights.
            own_vars = list(
                filter(lambda var: name_scope in var.name,
                       tf.global_variables()))
            session.run(tf.variables_initializer(own_vars))
            return model

        ckpt_path = state.model_checkpoint_path
        print("Reading Gen model parameters from %s" % ckpt_path)
        model.saver.restore(session, ckpt_path)
        return model
Ejemplo n.º 2
0
def create_model(session, gen_config, initializer=None, name="gen_model"):
    """Create the generator model and restore or initialise its parameters.

    The model is built inside ``tf.variable_scope(name)``.

    * If a checkpoint is recorded under ``<gen_config.data_dir>/checkpoints``
      and its file exists, the model's saver restores it into ``session``.
    * Otherwise, only the global variables whose name contains ``name`` are
      initialised.

    Args:
        session: TF session to restore into or initialise in.
        gen_config: generator configuration. ``data_dir`` is read here; the
            object is also forwarded to ``Seq2SeqModel``.
        initializer: optional default initializer for the variable scope.
        name: variable-scope name. It is also used as a substring filter to
            pick the variables to initialise when no checkpoint exists.

    Returns:
        The constructed ``seq2seq_model.Seq2SeqModel``.
    """
    with tf.variable_scope(name_or_scope=name, initializer=initializer):
        model = seq2seq_model.Seq2SeqModel(gen_config)
        gen_ckpt_dir = os.path.abspath(os.path.join(gen_config.data_dir, "checkpoints"))
        ckpt = tf.train.get_checkpoint_state(gen_ckpt_dir)
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            print("Reading Gen model parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(session, ckpt.model_checkpoint_path)
        else:
            print("Created Gen model with fresh parameters.")
            # Initialise only this model's variables. A graph-wide
            # tf.global_variables_initializer() would also reset every other
            # variable in the graph, wiping weights that another model
            # (presumably e.g. a discriminator) already restored into this
            # session.
            gen_global_variables = [
                gv for gv in tf.global_variables() if name in gv.name
            ]
            session.run(tf.variables_initializer(gen_global_variables))
        return model
Ejemplo n.º 3
0
def create_model(session,
                 gen_config,
                 forward_only,
                 name_scope,
                 word2id,
                 initializer=None):
    """Build the generator model and restore or initialise its parameters.

    Checkpoint directory selection:

    * non-adversarial: ``<data_dir>/gen_model/checkpoints``;
    * adversarial (``gen_config.adv``):
      ``<model_dir>/gen_model/data-<data_id>_pre_embed-<pre_embed>_exp-1``;
    * ``gen_config.continue_train`` overrides both with
      ``<model_dir>/gen_model/data-..._ent-..._exp-<exp_id>_teacher-...``.

    If a checkpoint is found there, it is restored and
    ``load_embeddings_generator(..., True)`` is called. Otherwise, the
    variables whose name contains ``name_scope`` are initialised, and when
    ``gen_config.pre_embed`` is truthy ``load_embeddings_generator(..., False)``
    is called. The meaning of that final flag is defined elsewhere
    (TODO confirm).

    Args:
        session: TF session to restore into or initialise in.
        gen_config: generator configuration (fields listed above are read).
        forward_only: forwarded to ``Seq2SeqModel``.
        name_scope: variable-scope name; also the variable filter for fresh
            initialisation and an argument to ``load_embeddings_generator``.
        word2id: forwarded to ``load_embeddings_generator``.
        initializer: optional default initializer for the variable scope.

    Returns:
        The constructed ``seq2seq_model.Seq2SeqModel``.
    """
    with tf.variable_scope(name_or_scope=name_scope, initializer=initializer):
        model = seq2seq_model.Seq2SeqModel(gen_config,
                                           name_scope=name_scope,
                                           forward_only=forward_only)

        if gen_config.adv:
            # NOTE(review): the experiment id is hard-coded to 1 here, whereas
            # the continue_train path below uses gen_config.exp_id. Confirm
            # this is intended.
            path_parts = [gen_config.model_dir, 'gen_model',
                          "data-{}_pre_embed-{}_exp-{}".format(
                              gen_config.data_id, gen_config.pre_embed, 1)]
        else:
            path_parts = [gen_config.data_dir, 'gen_model', "checkpoints"]
        gen_ckpt_dir = os.path.abspath(os.path.join(*path_parts))

        # continue_train replaces the directory chosen above, so adversarial
        # training resumes from its own experiment directory.
        if gen_config.continue_train:
            model_root = gen_config.model_dir
            run_name = "data-{}_pre_embed-{}_ent-{}_exp-{}_teacher-{}".format(
                gen_config.data_id, gen_config.pre_embed,
                gen_config.ent_weight, gen_config.exp_id,
                gen_config.teacher_forcing)
            gen_ckpt_dir = os.path.abspath(
                os.path.join(model_root, 'gen_model', run_name))

        print("check model path: %s" % gen_ckpt_dir)
        ckpt = tf.train.get_checkpoint_state(gen_ckpt_dir)
        restorable = ckpt and tf.train.checkpoint_exists(
            ckpt.model_checkpoint_path)

        if not restorable:
            print("Create Gen model with fresh parameters.")
            scope_vars = list(
                filter(lambda var: name_scope in var.name,
                       tf.global_variables()))
            session.run(tf.variables_initializer(scope_vars))
            print("Finished Creating Gen model with fresh parameters.")
            if gen_config.pre_embed:
                load_embeddings_generator(session, name_scope, word2id,
                                          gen_config.word_embedding, False)
            return model

        print("Reading Gen model parameters from %s" %
              ckpt.model_checkpoint_path)
        model.saver.restore(session, ckpt.model_checkpoint_path)
        load_embeddings_generator(session, name_scope, word2id,
                                  gen_config.word_embedding, True)
        return model
def create_model(session, gen_config, forward_only):
    """Build the generator seq2seq model and restore or initialise its weights.

    The model is constructed from individual ``gen_config`` fields plus the
    module-level ``_buckets``.

    * If ``gen_config.train_dir`` records a checkpoint whose file exists, the
      model's saver restores it into ``session``.
    * Otherwise, every global variable in the graph is initialised.

    Args:
        session: TF session to restore into or initialise in.
        gen_config: configuration providing ``vocab_size``, ``size``,
            ``num_layers``, ``max_gradient_norm``, ``batch_size``,
            ``learning_rate``, ``learning_rate_decay_factor`` and
            ``train_dir``.
        forward_only: forwarded to ``Seq2SeqModel``.

    Returns:
        The constructed ``seq2seq_model.Seq2SeqModel``.
    """
    cfg = gen_config
    model = seq2seq_model.Seq2SeqModel(
        cfg.vocab_size,  # presumably the source vocabulary size
        cfg.vocab_size,  # presumably the target vocabulary size (same value)
        _buckets,
        cfg.size,
        cfg.num_layers,
        cfg.max_gradient_norm,
        cfg.batch_size,
        cfg.learning_rate,
        cfg.learning_rate_decay_factor,
        forward_only=forward_only,
    )

    state = tf.train.get_checkpoint_state(cfg.train_dir)
    if not (state and tf.train.checkpoint_exists(state.model_checkpoint_path)):
        print("Created Gen_RNN model with fresh parameters.")
        session.run(tf.global_variables_initializer())
        return model

    print("Reading model parameters from %s" % state.model_checkpoint_path)
    model.saver.restore(session, state.model_checkpoint_path)
    return model
Ejemplo n.º 5
0
def create_model(session, gen_config, forward_only, name_scope, initializer=None):
    """Create the generator model and restore its parameters from a checkpoint.

    This variant only restores: it never falls back to fresh initialisation.
    The model is built inside ``tf.variable_scope(name_scope)``, and the
    checkpoint is read from ``<gen_config.train_dir>/checkpoints``.

    Args:
        session: TF session to restore the weights into.
        gen_config: generator configuration. ``train_dir`` is read here; the
            object is also forwarded to ``Seq2SeqModel``.
        forward_only: forwarded to ``Seq2SeqModel``.
        name_scope: variable-scope name, also forwarded to ``Seq2SeqModel``.
        initializer: optional default initializer for the variable scope.

    Returns:
        The constructed ``seq2seq_model.Seq2SeqModel`` with restored weights.

    Raises:
        ValueError: if no checkpoint state is recorded in the checkpoint
            directory.
    """
    with tf.variable_scope(name_or_scope=name_scope, initializer=initializer):
        model = seq2seq_model.Seq2SeqModel(gen_config, name_scope=name_scope, forward_only=forward_only)
        gen_ckpt_dir = os.path.abspath(os.path.join(gen_config.train_dir, "checkpoints"))
        ckpt = tf.train.get_checkpoint_state(gen_ckpt_dir)
        # get_checkpoint_state returns None when the directory holds no
        # checkpoint. Fail with a clear message instead of an AttributeError
        # on ``None.model_checkpoint_path``.
        if not ckpt or not ckpt.model_checkpoint_path:
            raise ValueError("No Gen model checkpoint found in %s" % gen_ckpt_dir)
        print("Reading Gen model parameters from %s" % ckpt.model_checkpoint_path)
        model.saver.restore(session, ckpt.model_checkpoint_path)
        return model