# Shared context for the `model_fn` variants below: the module-level imports
# are reconstructed here; `FLAGS`, `n_token`, and `metric_fn` are assumed to
# come from the enclosing file / `get_model_fn` scope, and the local modules
# (`model_func_builder`, `model_utils`, `optimization`, `model_function`)
# belong to the original codebase.
import os

import numpy as np
import tensorflow as tf  # TF 1.x: the code relies on tf.contrib.tpu


def model_fn(features, labels, mode, params):
  """Pretraining model_fn supporting ELECTRA, MLM, and XLNet losses."""
  #### Training or Evaluation
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  #### Retrieve `mems` from `params["cache"]`
  mems = {}
  idx = 0
  for obj_len, key in zip([FLAGS.seq_len], ["mems"]):
    if obj_len > 0 and FLAGS.mem_len > 0:
      n_layer = FLAGS.n_layer
      if FLAGS.use_extra_layer:
        n_layer += 1
      mems[key] = params["cache"][idx * n_layer:(idx + 1) * n_layer]
      idx += 1

  #### Get loss from inputs
  if FLAGS.loss_type == "electra":
    total_loss, new_mems, monitor_dict = model_func_builder.electra_loss(
        features, labels, mems, n_token, is_training)
  elif FLAGS.loss_type == "mlm":
    total_loss, new_mems, monitor_dict = model_func_builder.mlm_loss(
        features, labels, mems, n_token, is_training)
  elif FLAGS.loss_type == "xlnet":
    total_loss, new_mems, monitor_dict = model_func_builder.xlnet_loss(
        features, labels, mems, n_token, is_training)
  else:
    raise NotImplementedError

  #### Turn `new_mems` into `new_cache`
  new_cache = []
  for obj_len, key in zip([FLAGS.seq_len], ["mems"]):
    if obj_len > 0 and FLAGS.mem_len > 0:
      new_cache += new_mems[key]

  #### Check model parameters
  num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
  tf.logging.info("#params: %d", num_params)
  if FLAGS.verbose:
    format_str = "{{:<{0}s}}\t{{}}".format(
        max([len(v.name) for v in tf.trainable_variables()]))
    for v in tf.trainable_variables():
      tf.logging.info(format_str.format(v.name, v.get_shape()))

  #### Evaluation mode
  if mode == tf.estimator.ModeKeys.EVAL:
    #### Reduce sum losses from all TPU cores
    with tf.colocate_with(total_loss):
      total_loss = tf.contrib.tpu.cross_replica_sum(total_loss)
      total_loss = total_loss / FLAGS.num_hosts / FLAGS.num_core_per_host
    metric_loss = tf.reshape(total_loss, [1])

    #### Constructing evaluation TPUEstimatorSpec with new cache.
    eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=total_loss, eval_metrics=(metric_fn, [metric_loss]))
    eval_spec.cache = new_cache

    return eval_spec

  #### Get the train op
  train_op, optim_dict = optimization.get_train_op(total_loss)
  monitor_dict.update(optim_dict)

  #### Customized initial checkpoint
  tvars = tf.global_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  if FLAGS.init_checkpoint is not None:
    if FLAGS.init_checkpoint.endswith("latest"):
      ckpt_dir = os.path.dirname(FLAGS.init_checkpoint)
      init_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
    else:
      init_checkpoint = FLAGS.init_checkpoint

    tf.logging.info("Initialize from the ckpt %s", init_checkpoint)

    (assignment_map, initialized_variable_names
    ) = model_utils.get_assignment_map_from_checkpoint(
        tvars, init_checkpoint)
    if FLAGS.use_tpu:
      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    # Log which variables are restored from the checkpoint
    tf.logging.info("**** Global Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

  #### Creating host calls
  host_call = model_utils.construct_scalar_host_call(
      monitor_dict=monitor_dict,
      model_dir=FLAGS.model_dir,
      prefix="train/",
      reduce_fn=tf.reduce_mean,
      log_freq=FLAGS.log_freq)

  #### Constructing training TPUEstimatorSpec with new cache.
  train_spec = tf.contrib.tpu.TPUEstimatorSpec(
      mode=mode, loss=total_loss, train_op=train_op,
      host_call=host_call, scaffold_fn=scaffold_fn)
  train_spec.cache = new_cache

  return train_spec
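# `metric_fn` above is referenced but not defined in this excerpt; it is
# presumably supplied by the enclosing scope. A minimal sketch of what such a
# helper could look like, assuming only the mean evaluation loss is reported
# (the exact metrics are an assumption, not the original implementation):
def metric_fn(loss):
  """Evaluation metric fn which runs on CPU."""
  return {"eval_loss": tf.metrics.mean(loss)}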
def model_fn(features, labels, mode, params):
  """Actual model function."""
  #### Training or Evaluation
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  #### Get loss from inputs
  if FLAGS.seq2seq_type == "dec_only":
    seq2seq_loss = model_func_builder.joint_loss
  elif FLAGS.seq2seq_type == "encdec":
    seq2seq_loss = model_func_builder.encdec_loss
  else:
    raise NotImplementedError
  total_loss, monitor_dict = seq2seq_loss(
      features, labels, n_token, is_training)

  #### Check model parameters
  num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
  tf.logging.info("#params: %d", num_params)
  if FLAGS.verbose:
    format_str = "{{:<{0}s}}\t{{}}".format(
        max([len(v.name) for v in tf.trainable_variables()]))
    for v in tf.trainable_variables():
      tf.logging.info(format_str.format(v.name, v.get_shape()))

  #### Evaluation mode
  if mode == tf.estimator.ModeKeys.EVAL:
    #### Reduce sum losses from all TPU cores
    with tf.colocate_with(total_loss):
      total_loss = tf.contrib.tpu.cross_replica_sum(total_loss)
      total_loss = total_loss / FLAGS.num_hosts / FLAGS.num_core_per_host
    metric_loss = tf.reshape(total_loss, [1])

    #### Constructing evaluation TPUEstimatorSpec.
    eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=total_loss, eval_metrics=(metric_fn, [metric_loss]))

    return eval_spec

  #### Customized initial checkpoint
  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  if FLAGS.init_checkpoint:
    if FLAGS.init_checkpoint.endswith("latest"):
      ckpt_dir = os.path.dirname(FLAGS.init_checkpoint)
      init_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
    else:
      init_checkpoint = FLAGS.init_checkpoint

    tf.logging.info("Initialize from the ckpt %s", init_checkpoint)

    (assignment_map, initialized_variable_names
    ) = model_utils.get_assignment_map_from_checkpoint(
        tvars, init_checkpoint)
    if FLAGS.use_tpu:
      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    # Log which variables are restored from the checkpoint
    tf.logging.info("**** Global Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

  #### Get the train op
  train_op, optim_dict = optimization.get_train_op(total_loss)
  monitor_dict.update(optim_dict)

  #### Creating host calls
  host_call = model_utils.construct_scalar_host_call(
      monitor_dict=monitor_dict,
      model_dir=FLAGS.model_dir,
      prefix="train/",
      reduce_fn=tf.reduce_mean,
      log_freq=FLAGS.log_freq)

  #### Constructing training TPUEstimatorSpec.
  train_spec = tf.contrib.tpu.TPUEstimatorSpec(
      mode=mode, loss=total_loss, train_op=train_op,
      host_call=host_call, scaffold_fn=scaffold_fn)

  return train_spec
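# For context, a minimal usage sketch of how a `model_fn` like the ones above
# is typically handed to a TPUEstimator. `run_training`, `train_input_fn`,
# and the FLAGS.iterations / train_batch_size / train_steps flags are
# hypothetical names for illustration; the real wiring lives elsewhere in
# the codebase:
def run_training(model_fn, train_input_fn):
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=FLAGS.model_dir,
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations,
          num_shards=FLAGS.num_hosts * FLAGS.num_core_per_host))
  estimator = tf.contrib.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size)
  estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)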
def model_fn(features, labels, mode, params):
  """Language-model pretraining model_fn with Transformer-XL style memory."""
  # `labels` is unused; the LM loss is computed from `features` directly.
  del labels

  #### Training or Evaluation
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  #### Retrieve `mems` from `params["cache"]`
  mems = None
  if FLAGS.mem_len > 0:
    mems = params["cache"]

  #### Get loss from inputs
  total_loss, new_mems, monitor_dict = model_func_builder.get_lm_loss(
      features, mems, n_token, is_training)

  #### Put `new_mems` into `new_cache`
  new_cache = []
  if FLAGS.mem_len > 0:
    new_cache += new_mems

  #### Check model parameters
  num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
  tf.logging.info("#params: %d", num_params)

  #### Evaluation mode
  if mode == tf.estimator.ModeKeys.EVAL:
    def metric_fn(loss):
      """Evaluation metric Fn which runs on CPU."""
      perplexity = tf.exp(tf.reduce_mean(loss) * 1.2)
      return {
          "perplexity": tf.metrics.mean(perplexity),
      }

    metric_loss = tf.reshape(total_loss, [1, 1])
    eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        eval_metrics=(metric_fn, [metric_loss]))

    eval_spec.cache = new_cache

    return eval_spec

  #### Configuring the optimizer
  train_op, optim_dict = optimization.get_train_op(total_loss)
  monitor_dict.update(optim_dict)

  #### Customized initial checkpoint
  tvars = tf.global_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  if FLAGS.init_checkpoint is not None:
    if FLAGS.init_checkpoint.endswith("latest"):
      ckpt_dir = os.path.dirname(FLAGS.init_checkpoint)
      init_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
    else:
      init_checkpoint = FLAGS.init_checkpoint

    tf.logging.info("Initialize from the ckpt %s", init_checkpoint)

    (assignment_map, initialized_variable_names
    ) = model_utils.get_assignment_map_from_checkpoint(
        tvars, init_checkpoint)
    if FLAGS.use_tpu:
      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    # Log which variables are restored from the checkpoint
    tf.logging.info("**** Global Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

  #### Creating host calls
  host_call = model_utils.construct_scalar_host_call(
      monitor_dict=monitor_dict,
      model_dir=FLAGS.model_dir,
      prefix="train/",
      reduce_fn=tf.reduce_mean,
      log_freq=FLAGS.log_freq)

  #### Constructing training TPUEstimatorSpec with new cache.
  train_spec = tf.contrib.tpu.TPUEstimatorSpec(
      mode=mode, loss=total_loss, train_op=train_op,
      host_call=host_call, scaffold_fn=scaffold_fn)
  train_spec.cache = new_cache

  return train_spec
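# The `params["cache"]` read above relies on a customized TPUEstimator that
# threads the recurrent memory between TPU training-loop iterations. A
# minimal sketch of how the initial cache could be constructed (zero-valued
# mems, one tensor per layer); `FLAGS.d_model` and the `cache_fn` hook name
# are assumptions for illustration:
def cache_fn(batch_size):
  """Builds the initial (empty) memory cache for one TPU core."""
  mems = []
  if FLAGS.mem_len > 0:
    for _ in range(FLAGS.n_layer):
      mems.append(
          tf.zeros([FLAGS.mem_len, batch_size, FLAGS.d_model],
                   dtype=tf.float32))
  return mems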
def model_fn(features, labels, mode, params):
  """Language-model pretraining model_fn with memory cache."""
  # `labels` is unused; the LM loss is computed from `features` directly.
  del labels

  #### Training or Evaluation
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  #### Retrieve `mems` from `params["cache"]`
  mems = None
  if FLAGS.mem_len > 0:
    mems = params["cache"]

  #### Get loss from inputs
  total_loss, new_mems, monitor_dict = model_function.get_lm_loss(
      features, mems, is_training)

  #### Put `new_mems` into `new_cache`
  new_cache = []
  if FLAGS.mem_len > 0:
    new_cache += new_mems

  #### Check model parameters
  num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
  tf.logging.info("#params: %d", num_params)

  #### Evaluation mode
  if mode == tf.estimator.ModeKeys.EVAL:
    def metric_fn(loss):
      """Evaluation metric Fn which runs on CPU."""
      perplexity = tf.exp(tf.reduce_mean(loss) * 1.2)
      return {
          "perplexity": tf.metrics.mean(perplexity),
      }

    metric_loss = tf.reshape(total_loss, [1, 1])
    eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        eval_metrics=(metric_fn, [metric_loss]))

    eval_spec.cache = new_cache

    return eval_spec

  #### Configuring the optimizer
  train_op, optim_dict = optimization.get_train_op(total_loss)
  monitor_dict.update(optim_dict)

  #### Customized initial checkpoint
  scaffold_fn = model_utils.init_from_checkpoint(global_vars=True)

  #### Creating host calls
  host_call = model_function.construct_scalar_host_call(
      monitor_dict=monitor_dict,
      model_dir=FLAGS.model_dir,
      prefix="train/",
      reduce_fn=tf.reduce_mean)

  #### Constructing training TPUEstimatorSpec with new cache.
  train_spec = tf.contrib.tpu.TPUEstimatorSpec(
      mode=mode, loss=total_loss, train_op=train_op,
      host_call=host_call, scaffold_fn=scaffold_fn)
  train_spec.cache = new_cache

  return train_spec
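# Unlike the first three variants, this one delegates checkpoint restoration
# to `model_utils.init_from_checkpoint`. A sketch of what that helper
# plausibly does, reconstructed from the inline logic of the variants above
# (the exact body of the real helper is an assumption):
def init_from_checkpoint(global_vars=False):
  """Returns a scaffold_fn restoring variables from FLAGS.init_checkpoint."""
  tvars = tf.global_variables() if global_vars else tf.trainable_variables()
  scaffold_fn = None
  if FLAGS.init_checkpoint is not None:
    if FLAGS.init_checkpoint.endswith("latest"):
      init_checkpoint = tf.train.latest_checkpoint(
          os.path.dirname(FLAGS.init_checkpoint))
    else:
      init_checkpoint = FLAGS.init_checkpoint
    tf.logging.info("Initialize from the ckpt %s", init_checkpoint)
    assignment_map, _ = model_utils.get_assignment_map_from_checkpoint(
        tvars, init_checkpoint)
    if FLAGS.use_tpu:
      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()
      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
  return scaffold_fn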