def create_model(session, gen_config, vocab_size, forward_only, name_scope, initializer=None):
    """Build the generator Seq2SeqModel and restore or initialize its weights.

    The model is constructed inside the variable scope ``name_scope``. Its
    parameters are restored from the latest checkpoint found in
    ``<gen_config.train_dir>/checkpoints`` when one exists; otherwise every
    global variable whose name contains ``name_scope`` is initialized.

    Args:
        session: TensorFlow session that receives the restored/initialized values.
        gen_config: generator configuration; forwarded to ``Seq2SeqModel`` and
            its ``train_dir`` is read here to locate checkpoints.
        vocab_size: forwarded to ``Seq2SeqModel``.
        forward_only: forwarded to ``Seq2SeqModel``.
        name_scope: variable scope name; also used to pick which variables
            get initialized when no checkpoint is available.
        initializer: optional default initializer for the variable scope.

    Returns:
        The constructed ``Seq2SeqModel``.
    """
    with tf.variable_scope(name_or_scope=name_scope, initializer=initializer):
        seq2seq = seq2seq_model.Seq2SeqModel(
            gen_config,
            vocab_size=vocab_size,
            name_scope=name_scope,
            forward_only=forward_only,
        )
        checkpoint_dir = os.path.abspath(
            os.path.join(gen_config.train_dir, "checkpoints"))
        state = tf.train.get_checkpoint_state(checkpoint_dir)

        if not (state and tf.train.checkpoint_exists(state.model_checkpoint_path)):
            print("Created Gen model with fresh parameters.")
            # NOTE(review): substring match -- variables of any other scope whose
            # name contains ``name_scope`` are initialized as well.
            scoped = [var for var in tf.global_variables() if name_scope in var.name]
            session.run(tf.variables_initializer(scoped))
            return seq2seq

        print("Reading Gen model parameters from %s" % state.model_checkpoint_path)
        seq2seq.saver.restore(session, state.model_checkpoint_path)
        return seq2seq
def create_model(session, gen_config, initializer=None, name="gen_model"):
    """Create the generator model and load or initialize its parameters.

    Builds a ``seq2seq_model.Seq2SeqModel`` inside the variable scope ``name``.
    If a checkpoint is found in ``<gen_config.data_dir>/checkpoints`` it is
    restored into ``session``. Otherwise, only the variables that are not yet
    initialized in ``session`` are initialized, so models that were already
    restored or trained in the same session keep their values.

    Args:
        session: TensorFlow session in which variables are restored/initialized.
        gen_config: generator configuration; passed to ``Seq2SeqModel`` and its
            ``data_dir`` is read here to locate checkpoints.
        initializer: optional default initializer for the variable scope.
        name: variable scope name for the model.

    Returns:
        The constructed ``Seq2SeqModel``.
    """
    with tf.variable_scope(name_or_scope=name, initializer=initializer):
        model = seq2seq_model.Seq2SeqModel(gen_config)
        gen_ckpt_dir = os.path.abspath(os.path.join(gen_config.data_dir, "checkpoints"))
        ckpt = tf.train.get_checkpoint_state(gen_ckpt_dir)
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            print("Reading Gen model parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(session, ckpt.model_checkpoint_path)
        else:
            print("Created Gen model with fresh parameters.")
            # A blanket tf.global_variables_initializer() would also reset the
            # variables of any other model already restored/trained in this
            # session. Initialize only what is still uninitialized instead; this
            # still covers variables the model creates outside ``name`` (e.g.
            # optimizer state).
            uninitialized = {
                raw.decode("utf-8") if isinstance(raw, bytes) else raw
                for raw in session.run(tf.report_uninitialized_variables())
            }
            fresh_vars = [v for v in tf.global_variables() if v.op.name in uninitialized]
            session.run(tf.variables_initializer(fresh_vars))
    return model
def create_model(session, gen_config, forward_only, name_scope, word2id, initializer=None):
    """Create translation model and initialize or load parameters in session.

    The model is built inside the variable scope ``name_scope``. The checkpoint
    directory is chosen from the config:

    * ``gen_config.adv`` false: ``<data_dir>/gen_model/checkpoints``.
    * ``gen_config.adv`` true: ``<model_dir>/gen_model/data-<data_id>_pre_embed-
      <pre_embed>_exp-1``, or, when ``gen_config.continue_train`` is set,
      ``<model_dir>/gen_model/data-<data_id>_pre_embed-<pre_embed>_ent-
      <ent_weight>_exp-<exp_id>_teacher-<teacher_forcing>``.

    If a usable checkpoint exists it is restored and ``load_embeddings_generator``
    is called with a final argument of True. Otherwise the global variables
    whose names contain ``name_scope`` are initialized, and, if
    ``gen_config.pre_embed`` is truthy, ``load_embeddings_generator`` is called
    with a final argument of False.

    Args:
        session: TensorFlow session in which variables are restored/initialized.
        gen_config: generator configuration (fields read: adv, data_dir,
            model_dir, data_id, pre_embed, continue_train, ent_weight, exp_id,
            teacher_forcing, word_embedding).
        forward_only: forwarded to ``Seq2SeqModel``.
        name_scope: variable scope name; also used to select the variables to
            initialize and passed to ``load_embeddings_generator``.
        word2id: forwarded to ``load_embeddings_generator`` (presumably a
            token -> id mapping; confirm against that helper).
        initializer: optional default initializer for the variable scope.

    Returns:
        The constructed ``Seq2SeqModel``.
    """
    with tf.variable_scope(name_or_scope=name_scope, initializer=initializer):
        model = seq2seq_model.Seq2SeqModel(gen_config, name_scope=name_scope, forward_only=forward_only)
        if not gen_config.adv:
            # Non-adversarial run: use the plain generator checkpoint directory.
            gen_ckpt_dir = os.path.abspath(
                os.path.join(gen_config.data_dir, 'gen_model', "checkpoints"))
        else:
            # NOTE(review): the experiment id is hard-coded to 1 here, unlike the
            # continue_train path below which uses gen_config.exp_id. Presumably
            # adversarial training always starts from pre-trained experiment 1;
            # confirm.
            gen_ckpt_dir = os.path.abspath(
                os.path.join(
                    gen_config.model_dir, 'gen_model',
                    "data-{}_pre_embed-{}_exp-{}".format(
                        gen_config.data_id, gen_config.pre_embed, 1)))
            # gen_config.continue_train==True will overwrite the previous gen_ckpt_dir and continue adv-training
            # NOTE(review): this check's nesting was reconstructed from a
            # single-line source; confirm whether it should also apply when
            # gen_config.adv is False.
            if gen_config.continue_train:
                gen_ckpt_dir = os.path.abspath(
                    os.path.join(
                        gen_config.model_dir, 'gen_model',
                        "data-{}_pre_embed-{}_ent-{}_exp-{}_teacher-{}".format(
                            gen_config.data_id, gen_config.pre_embed,
                            gen_config.ent_weight, gen_config.exp_id,
                            gen_config.teacher_forcing)))
        print("check model path: %s" % gen_ckpt_dir)
        ckpt = tf.train.get_checkpoint_state(gen_ckpt_dir)
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            print("Reading Gen model parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(session, ckpt.model_checkpoint_path)
            # NOTE(review): called regardless of gen_config.pre_embed, unlike the
            # fresh-init branch; the meaning of the final boolean is defined by
            # load_embeddings_generator (not visible here) -- confirm intent.
            load_embeddings_generator(session, name_scope, word2id, gen_config.word_embedding, True)
            # reset_lr = model.learning_rate.assign(gen_config.learning_rate)
            # session.run(reset_lr)
        else:
            print("Create Gen model with fresh parameters.")
            # Initialize only variables whose name contains name_scope (substring
            # match), so other models sharing the session are left untouched.
            gen_global_variables = [
                gv for gv in tf.global_variables() if name_scope in gv.name
            ]
            session.run(tf.variables_initializer(gen_global_variables))
            print("Finished Creating Gen model with fresh parameters.")
            if gen_config.pre_embed:
                load_embeddings_generator(session, name_scope, word2id, gen_config.word_embedding, False)
        return model
def create_model(session, gen_config, forward_only):
    """Build the generator Seq2SeqModel and restore or initialize its weights.

    Hyper-parameters come from ``gen_config`` plus the module-level
    ``_buckets``. If ``gen_config.train_dir`` holds a usable checkpoint it is
    restored into ``session``; otherwise all global variables are initialized.

    Args:
        session: TensorFlow session that receives the restored/initialized values.
        gen_config: generator configuration supplying the model hyper-parameters
            and the checkpoint directory ``train_dir``.
        forward_only: forwarded to ``Seq2SeqModel``.

    Returns:
        The constructed ``Seq2SeqModel``.
    """
    hparams = (
        # vocab_size is passed twice -- presumably as source and target
        # vocabulary sizes; confirm against Seq2SeqModel's signature.
        gen_config.vocab_size,
        gen_config.vocab_size,
        _buckets,
        gen_config.size,
        gen_config.num_layers,
        gen_config.max_gradient_norm,
        gen_config.batch_size,
        gen_config.learning_rate,
        gen_config.learning_rate_decay_factor,
    )
    model = seq2seq_model.Seq2SeqModel(*hparams, forward_only=forward_only)

    state = tf.train.get_checkpoint_state(gen_config.train_dir)
    if not (state and tf.train.checkpoint_exists(state.model_checkpoint_path)):
        print("Created Gen_RNN model with fresh parameters.")
        session.run(tf.global_variables_initializer())
    else:
        print("Reading model parameters from %s" % state.model_checkpoint_path)
        model.saver.restore(session, state.model_checkpoint_path)
    return model
def create_model(session, gen_config, forward_only, name_scope, initializer=None):
    """Create translation model and initialize or load parameters in session.

    If a model checkpoint exists it is reused. The model is built inside the
    variable scope ``name_scope``. When ``<gen_config.train_dir>/checkpoints``
    holds a usable checkpoint it is restored into ``session``. Otherwise the
    global variables whose names contain ``name_scope`` are freshly initialized.

    Args:
        session: TensorFlow session in which variables are restored/initialized.
        gen_config: generator configuration; forwarded to ``Seq2SeqModel`` and
            its ``train_dir`` is read here to locate checkpoints.
        forward_only: forwarded to ``Seq2SeqModel``.
        name_scope: variable scope name; also used to select the variables to
            initialize when no checkpoint is available.
        initializer: optional default initializer for the variable scope.

    Returns:
        The constructed ``Seq2SeqModel``.
    """
    with tf.variable_scope(name_or_scope=name_scope, initializer=initializer):
        model = seq2seq_model.Seq2SeqModel(gen_config, name_scope=name_scope, forward_only=forward_only)
        gen_ckpt_dir = os.path.abspath(os.path.join(gen_config.train_dir, "checkpoints"))
        ckpt = tf.train.get_checkpoint_state(gen_ckpt_dir)
        # get_checkpoint_state returns None when no checkpoint is recorded, so the
        # restore must be guarded; an unguarded ckpt.model_checkpoint_path would
        # raise AttributeError on a fresh run.
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            print("Reading Gen model parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(session, ckpt.model_checkpoint_path)
        else:
            print("Created Gen model with fresh parameters.")
            # Filter by scope only because, during adversarial training, several
            # models share the session and we must choose which one to initialize.
            gen_global_variables = [gv for gv in tf.global_variables() if name_scope in gv.name]
            session.run(tf.variables_initializer(gen_global_variables))
    return model