def train(config):
    """Train a model with a config file.

    Builds the model class named by ``config.model``, restores any variables
    that ``available_variables`` reports for the latest checkpoint under
    ``config.model_dir``, then trains for up to ``config.train.num_epochs``
    epochs. A checkpoint is attempted (via ``maybe_save_model``) either every
    ``config.train.save_freq`` steps when that value is > 0, or once per
    epoch otherwise. It is written unless the dev score regressed; it is
    always written when ``config.train.eval_on_dev`` is false.

    Training stops when:
      * all ``config.train.num_epochs`` epochs have run, or
      * ``config.train.num_steps`` is set and the global step reaches it, or
      * the early-stop budget ``toleration`` reaches zero (checked per epoch).

    Args:
        config: Project configuration object. Reads ``config.model``,
            ``config.model_dir``, ``config.dev`` and ``config.train.*``.

    Side effects:
        Sets the module-level globals ``dev_bleu`` and ``toleration``.
        Writes TensorBoard summaries and checkpoints into ``config.model_dir``.
    """
    logger = logging.getLogger('')

    data_reader = DataReader(config=config)
    # NOTE(review): ``eval`` executes arbitrary code taken from the config;
    # this is only acceptable while config files come from a trusted source.
    model = eval(config.model)(config=config, num_gpus=config.train.num_gpus)
    model.build_train_model(test=config.train.eval_on_dev)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    summary_writer = tf.summary.FileWriter(config.model_dir, graph=model.graph)

    try:
        with tf.Session(config=sess_config, graph=model.graph) as sess:
            # Initialize all variables.
            sess.run(tf.global_variables_initializer())

            # Reload whatever variables the latest checkpoint provides.
            latest_ckpt = tf.train.latest_checkpoint(config.model_dir)
            available_vars = (available_variables(config.model_dir)
                              if latest_ckpt else None)
            if available_vars:
                saver = tf.train.Saver(var_list=available_vars)
                saver.restore(sess, latest_ckpt)
                for v in available_vars:
                    logger.info('Reload {} from disk.'.format(v.name))
            else:
                logger.info('Nothing to be reload from disk.')

            evaluator = Evaluator()
            evaluator.init_from_existed(model, sess, data_reader)

            # Kept as module globals because the nested helpers below update them.
            global dev_bleu, toleration
            dev_bleu = evaluator.evaluate(
                **config.dev) if config.train.eval_on_dev else 0
            toleration = config.train.toleration

            def train_one_step(batch):
                """Run one optimizer step; return (global_step, lr, loss)."""
                feat_batch, target_batch, _batch_size = batch
                feed_dict = expand_feed_dict({
                    model.src_pls: feat_batch,
                    model.dst_pls: target_batch
                })
                step, lr, loss, _ = sess.run([
                    model.global_step, model.learning_rate, model.loss,
                    model.train_op
                ], feed_dict=feed_dict)
                if step % config.train.summary_freq == 0:
                    summary = sess.run(model.summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary, global_step=step)
                return step, lr, loss

            def maybe_save_model():
                """Save a checkpoint unless the dev score regressed.

                A regression consumes one unit of ``toleration``; an
                accepted save resets it.
                """
                global dev_bleu, toleration
                # Without dev evaluation, a score that always improves is
                # used so every call saves.
                new_dev_bleu = evaluator.evaluate(
                    **config.dev) if config.train.eval_on_dev else dev_bleu + 1
                if new_dev_bleu >= dev_bleu:
                    # `step` is read from the enclosing train() scope.
                    mp = config.model_dir + '/model_step_{}'.format(step)
                    model.saver.save(sess, mp)
                    logger.info('Save model in %s.' % mp)
                    toleration = config.train.toleration
                    dev_bleu = new_dev_bleu
                else:
                    toleration -= 1

            step = 0
            for epoch in range(1, config.train.num_epochs + 1):
                for batch in data_reader.get_training_batches_with_buckets():
                    # Train normal instances.
                    start_time = time.time()
                    step, lr, loss = train_one_step(batch)
                    logger.info(
                        'epoch: {0}\tstep: {1}\tlr: {2:.6f}\tloss: {3:.4f}\ttime: {4:.4f}\tbatch_size: {5}'
                        .format(epoch, step, lr, loss, time.time() - start_time,
                                batch[2]))
                    # Save model every `save_freq` steps.
                    if config.train.save_freq > 0 and step % config.train.save_freq == 0:
                        maybe_save_model()
                    if config.train.num_steps and step >= config.train.num_steps:
                        break

                # Save model per epoch if config.train.save_freq is less or equal than zero
                if config.train.save_freq <= 0:
                    maybe_save_model()

                # Early stop
                if toleration <= 0:
                    break
                # The inner `break` above only ends the current epoch; without
                # this check each remaining epoch would train one more step
                # past `num_steps` (and possibly evaluate/save again).
                if config.train.num_steps and step >= config.train.num_steps:
                    break
            logger.info("Finish training.")
    finally:
        # Flush pending events and release the event file even on error.
        summary_writer.close()
def train(config, num_epoch, last_pretrain_model_dir, pretrain_model_dir,
          model_dir, block_idx_enc, block_idx_dec):
    """Pretrain a model built with the given encoder/decoder block counts.

    Sets ``config.num_blocks_enc``/``config.num_blocks_dec`` (mutating
    ``config``) and builds the model. Restores the variables that
    ``available_variables_without_global_step`` reports for the latest
    checkpoint in ``last_pretrain_model_dir``, then trains for up to
    ``num_epoch`` epochs (or until ``config.train.num_steps``). Finally, it
    evaluates once and, unless the dev score regressed, saves a checkpoint.

    Args:
        config: Project configuration object (mutated, see above).
        num_epoch: Maximum number of passes over the training batches.
        last_pretrain_model_dir: Directory whose latest checkpoint seeds the
            variables.
        pretrain_model_dir: Directory for TensorBoard summaries and for the
            checkpoint saved with ``model.saver``.
        model_dir: Optional second directory. When truthy, the same
            checkpoint is also saved there with a saver built from
            ``global_variables_without_global_step()``.
        block_idx_enc: Value assigned to ``config.num_blocks_enc``.
        block_idx_dec: Value assigned to ``config.num_blocks_dec``.

    Side effects:
        Mutates ``config``. Sets the module-level globals ``dev_bleu`` and
        ``toleration``. Writes summaries and checkpoints to disk.
    """
    logger = logging.getLogger('')
    config.num_blocks_enc = block_idx_enc
    config.num_blocks_dec = block_idx_dec
    logger.info("config.num_blocks_enc=" + str(config.num_blocks_enc) +
                ",config.num_blocks_dec=" + str(config.num_blocks_dec) +
                ',config.train.var_filter=' + str(config.train.var_filter))

    data_reader = DataReader(config=config)
    # NOTE(review): ``eval`` executes arbitrary code taken from the config;
    # this is only acceptable while config files come from a trusted source.
    model = eval(config.model)(config=config, num_gpus=config.train.num_gpus)
    model.build_train_model(test=config.train.eval_on_dev)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    summary_writer = tf.summary.FileWriter(pretrain_model_dir, graph=model.graph)

    try:
        with tf.Session(config=sess_config, graph=model.graph) as sess:
            # Initialize all variables.
            sess.run(tf.global_variables_initializer())

            # Reload variables from the previous stage's latest checkpoint.
            latest_ckpt = tf.train.latest_checkpoint(last_pretrain_model_dir)
            available_vars = (available_variables_without_global_step(
                last_pretrain_model_dir) if latest_ckpt else None)
            if available_vars:
                saver = tf.train.Saver(var_list=available_vars)
                saver.restore(sess, latest_ckpt)
                for v in available_vars:
                    logger.info('Reload {} from disk.'.format(v.name))
            else:
                logger.info('Nothing to be reload from disk.')

            evaluator = Evaluator()
            evaluator.init_from_existed(model, sess, data_reader)

            # Kept as module globals because the nested helpers below update them.
            global dev_bleu, toleration
            dev_bleu = evaluator.evaluate(
                **config.dev) if config.train.eval_on_dev else 0
            toleration = config.train.toleration

            def train_one_step(batch):
                """Run one optimizer step; return (global_step, lr, loss)."""
                # Index instead of 2-tuple unpacking so readers that yield
                # extra trailing fields (e.g. (feat, target, batch_size), as
                # the single-argument train() in this file expects) also work.
                feat_batch, target_batch = batch[0], batch[1]
                feed_dict = expand_feed_dict({
                    model.src_pls: feat_batch,
                    model.dst_pls: target_batch
                })
                step, lr, loss, _ = sess.run([
                    model.global_step, model.learning_rate, model.loss,
                    model.train_op
                ], feed_dict=feed_dict)
                if step % config.train.summary_freq == 0:
                    logger.info('pretrain summary_writer...')
                    summary = sess.run(model.summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary, global_step=step)
                    summary_writer.flush()
                return step, lr, loss

            def save_checkpoint(target_dir, is_save_global_step=True):
                """Write a checkpoint named after the current step into target_dir."""
                # `step` is read from the enclosing train() scope.
                mp = target_dir + '/pretrain_model_step_{}'.format(step)
                if is_save_global_step:
                    model.saver.save(sess, mp)
                else:
                    # Saver restricted to global_variables_without_global_step();
                    # presumably so a later stage restarts its own step
                    # count -- confirm.
                    saver = tf.train.Saver(
                        var_list=global_variables_without_global_step(),
                        max_to_keep=10)
                    saver.save(sess, mp)
                logger.info('Save model in %s.' % mp)

            def maybe_save_model():
                """Evaluate once; unless the dev score regressed, save checkpoints.

                Saves to ``pretrain_model_dir`` and, when ``model_dir`` is
                truthy, also to ``model_dir``. A single evaluation gates
                both saves, so the (expensive) dev run is not repeated and
                ``toleration`` is decremented at most once.
                """
                global dev_bleu, toleration
                # Without dev evaluation, a score that always improves is
                # used so the save always happens.
                new_dev_bleu = evaluator.evaluate(
                    **config.dev) if config.train.eval_on_dev else dev_bleu + 1
                if new_dev_bleu >= dev_bleu:
                    save_checkpoint(pretrain_model_dir)
                    if model_dir:
                        save_checkpoint(model_dir, is_save_global_step=False)
                    toleration = config.train.toleration
                    dev_bleu = new_dev_bleu
                else:
                    toleration -= 1

            step = 0
            for epoch in range(1, num_epoch + 1):
                for batch in data_reader.get_training_batches_with_buckets():
                    # Train normal instances.
                    start_time = time.time()
                    step, lr, loss = train_one_step(batch)
                    logger.info(
                        'epoch: {0}\tstep: {1}\tlr: {2:.6f}\tloss: {3:.4f}\ttime: {4:.4f}'
                        .format(epoch, step, lr, loss, time.time() - start_time))
                    if config.train.num_steps and step >= config.train.num_steps:
                        break

                # Early stop.
                # NOTE(review): `toleration` is only updated by
                # maybe_save_model(), which runs after this loop, so this can
                # only trigger when config.train.toleration <= 0.
                if toleration <= 0:
                    break
                # The inner `break` only ends the current epoch; stop the
                # outer loop too so later epochs don't train past `num_steps`.
                if config.train.num_steps and step >= config.train.num_steps:
                    break

            maybe_save_model()
            logger.info("Finish pretrain block_idx_enc=" + str(block_idx_enc) +
                        ',block_idx_dec=' + str(block_idx_dec))
    finally:
        # Flush pending events and release the event file even on error.
        summary_writer.close()