def model_fn(features, labels, mode, params):
  """doc."""
  #### Training or Evaluation
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  assert is_training

  #### Retrieve `mems` via `cache_fn()`
  mems = {}
  if FLAGS.mem_len > 0:
    mems["mems"] = cache_fn()

  #### Get loss from inputs
  total_loss, new_mems, monitor_dict = function_builder.get_loss(
      FLAGS, features, labels, mems, is_training)

  #### Turn `new_mems` into `new_cache`
  new_cache = []
  if FLAGS.mem_len > 0:
    new_cache += new_mems["mems"]

  #### Check model parameters
  num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
  tf.logging.info("#params: {}".format(num_params))

  #### Configure the optimizer
  train_op, learning_rate, gnorm = model_utils.get_train_op(
      FLAGS, total_loss)
  monitor_dict["lr"] = learning_rate
  monitor_dict["gnorm"] = gnorm

  #### Customized initial checkpoint
  scaffold_fn = model_utils.init_from_checkpoint(FLAGS, global_vars=True)

  #### Creating host calls
  host_call = function_builder.construct_scalar_host_call(
      monitor_dict=monitor_dict,
      model_dir=FLAGS.model_dir,
      prefix="train/",
      reduce_fn=tf.reduce_mean)

  #### Constructing training (TPU)EstimatorSpec with new cache.
  if FLAGS.use_tpu:
    train_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=total_loss, train_op=train_op,
        host_call=host_call, scaffold_fn=scaffold_fn)
  else:
    train_spec = tf.estimator.EstimatorSpec(
        mode=mode, loss=total_loss, train_op=train_op)
  train_spec.cache = new_cache

  return train_spec
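# NOTE: `cache_fn` above is assumed to build the initial recurrence memory.
# A minimal sketch under that assumption (hypothetical helper; FLAGS.n_layer
# and FLAGS.d_model are assumed flag names, not confirmed by this file): one
# zero tensor of shape [mem_len, batch_size, d_model] per layer, matching the
# structure that `new_mems["mems"]` carries forward.
def make_cache_fn(batch_size):
  """Returns a fn that creates per-layer zero-initialized memory tensors."""
  def cache_fn():
    mems = []
    if FLAGS.mem_len > 0:
      for _ in range(FLAGS.n_layer):  # assumed flag: number of layers
        mems.append(tf.zeros(
            [FLAGS.mem_len, batch_size, FLAGS.d_model], dtype=tf.float32))
    return mems
  return cache_fn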
def model_fn(features, labels, mems):
  #### Get loss from inputs
  total_loss, new_mems = function_builder.get_loss(
      FLAGS, features, labels, mems, False)

  #### Check model parameters
  num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
  tf.logging.info('#params: {}'.format(num_params))

  # GPU: evaluation only, so `is_training` is hard-coded to False above.
  return total_loss, new_mems
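# Usage sketch for this GPU eval variant (assumptions: a tf.Session-based
# loop, memory placeholders `mems_ph`, initial zero memories `init_mems`,
# and `num_batches`; none of these names come from the original code).
# The point is that the returned `new_mems` must be fed back as `mems` so
# recurrence state carries across consecutive batches.
def evaluate(sess, total_loss, new_mems, mems_ph, init_mems, num_batches):
  feed_mems = init_mems
  losses = []
  for _ in range(num_batches):
    feed = {ph: m for ph, m in zip(mems_ph, feed_mems)}
    loss_np, feed_mems = sess.run([total_loss, new_mems], feed_dict=feed)
    losses.append(loss_np)
  return sum(losses) / len(losses)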
def model_fn(features, labels, mems, is_training):
  #### Get loss from inputs
  total_loss, new_mems, monitor_dict = function_builder.get_loss(
      FLAGS, features, labels, mems, is_training)

  #### Check model parameters
  num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
  tf.compat.v1.logging.info('#params: {}'.format(num_params))

  # GPU
  assert is_training

  all_vars = tf.trainable_variables()
  grads = tf.gradients(total_loss, all_vars)
  grads_and_vars = list(zip(grads, all_vars))

  return total_loss, new_mems, grads_and_vars
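# Usage sketch for the (gradient, variable) pairs returned above
# (assumptions: single-tower GPU training, an Adam optimizer, and
# FLAGS.clip / FLAGS.learning_rate as hypothetical flag names; not part
# of the original code). Global-norm clipping is the usual companion to
# a raw tf.gradients call like the one in this model_fn.
def build_train_op(grads_and_vars):
  grads, all_vars = zip(*grads_and_vars)
  clipped, gnorm = tf.clip_by_global_norm(grads, FLAGS.clip)
  optimizer = tf.compat.v1.train.AdamOptimizer(
      learning_rate=FLAGS.learning_rate)
  global_step = tf.compat.v1.train.get_or_create_global_step()
  train_op = optimizer.apply_gradients(
      zip(clipped, all_vars), global_step=global_step)
  return train_op, gnorm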
def model_fn(features, labels, mode, params):
  """doc."""
  #### Training or Evaluation
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  assert is_training

  #### Retrieve `mems` from `params["cache"]`
  mems = {}
  if FLAGS.mem_len > 0:
    mems['mems'] = params['cache']

  #### Get loss from inputs
  total_loss, new_mems, monitor_dict = function_builder.get_loss(
      FLAGS, features, labels, mems, is_training)

  #### Turn `new_mems` into `new_cache`
  new_cache = []
  if FLAGS.mem_len > 0:
    new_cache += new_mems['mems']

  #### Check model parameters
  num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
  tf.logging.info('#params: {}'.format(num_params))

  #### Configure the optimizer
  train_op, learning_rate, gnorm = model_utils.get_train_op(
      FLAGS, total_loss)
  monitor_dict['lr'] = learning_rate
  monitor_dict['gnorm'] = gnorm

  #### Customized initial checkpoint
  scaffold_fn = model_utils.init_from_checkpoint(FLAGS, global_vars=True)

  # `EstimatorSpec` expects a `tf.train.Scaffold` instance, not the
  # scaffold-building fn itself, so call it if one was returned.
  output_spec = tf.estimator.EstimatorSpec(
      mode=mode,
      loss=total_loss,
      train_op=train_op,
      scaffold=scaffold_fn() if scaffold_fn is not None else None,
  )
  return output_spec
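# Usage sketch for this Estimator-based variant (assumptions: a
# `train_input_fn`, an `initial_cache`, and FLAGS.model_dir /
# FLAGS.train_steps as hypothetical names; not part of the original code).
# `params` is the only channel through which the initial `cache` read by
# the model_fn above could be injected.
def run_training(train_input_fn, initial_cache):
  run_config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)
  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      config=run_config,
      params={'cache': initial_cache})
  estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)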
def model_fn(features, labels, mode, params):
  """doc."""
  #### Training or Evaluation
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  assert tf.gfile.Exists(logdir)

  #### Retrieve `mems` from `params["cache"]`
  mems = {}
  if FLAGS.mem_len > 0:
    mems["mems"] = params["cache"]

  #### Get loss from inputs
  if is_training:
    total_loss, new_mems, monitor_dict = function_builder.get_loss(
        FLAGS, features, labels, mems, is_training)
  else:
    total_loss, batch_loss, batch_tgt_mask, new_mems = function_builder.get_loss(
        FLAGS, features, labels, mems, is_training)

  #### Turn `new_mems` into `new_cache`
  new_cache = []
  if FLAGS.mem_len > 0:
    new_cache += new_mems["mems"]

  #### Check model parameters
  num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
  tf.logging.info("#params: {}".format(num_params))

  #### Customized initial checkpoint
  scaffold_fn = model_utils.init_from_checkpoint(FLAGS, global_vars=True)

  if is_training:
    #### Configure the optimizer
    train_op, learning_rate, gnorm = model_utils.get_train_op(
        FLAGS, total_loss, None)
    monitor_dict["gnorm"] = gnorm
    monitor_dict["lr"] = learning_rate
    monitor_dict["pplx"] = tf.math.exp(total_loss)

    # #### Creating host calls (disabled)
    # host_call = function_builder.construct_scalar_host_call(
    #     monitor_dict=monitor_dict, log_dir=logdir,
    #     prefix="train/", reduce_fn=tf.reduce_mean)

    #### Constructing training TPUEstimatorSpec with new cache.
    train_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=total_loss, train_op=train_op,
        # host_call=host_call,
        scaffold_fn=scaffold_fn)
    train_spec.cache = new_cache
    return train_spec
  else:
    #### Constructing validation TPUEstimatorSpec with new cache.
    eval_metrics = function_builder.construct_scalar_metric_fn(
        batch_loss, batch_tgt_mask)
    eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=total_loss, eval_metrics=eval_metrics,
        scaffold_fn=scaffold_fn)
    eval_spec.cache = new_cache
    return eval_spec
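# Sketch of what `construct_scalar_metric_fn` is assumed to return for the
# eval branch above: TPUEstimatorSpec takes `eval_metrics` as a
# `(metric_fn, [tensors])` pair. This body is hypothetical (the real helper
# lives in function_builder); it averages per-example losses with the target
# mask as weights, so padding positions do not count toward the mean.
def construct_scalar_metric_fn(batch_loss, batch_tgt_mask):
  def metric_fn(batch_loss, batch_tgt_mask):
    return {
        "eval/loss": tf.metrics.mean(batch_loss, weights=batch_tgt_mask),
    }
  return metric_fn, [batch_loss, batch_tgt_mask]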