Esempio n. 1
0
    def model_fn(features, labels, mode, params):
        """Build the TPUEstimatorSpec for pretraining with an optional memory cache.

        `FLAGS.loss_type` selects the loss from `model_func_builder`
        ("electra", "mlm" or "xlnet"). Memory is read from `params["cache"]`
        and the updated memory is attached to the returned spec as `.cache`.

        Args:
          features: input features, passed unchanged to the loss function.
          labels: labels, passed unchanged to the loss function.
          mode: a `tf.estimator.ModeKeys` value.
          params: estimator params; `params["cache"]` is read only when
            `FLAGS.seq_len > 0` and `FLAGS.mem_len > 0`.

        Returns:
          A `tf.contrib.tpu.TPUEstimatorSpec` whose `cache` attribute is the
          list built from `new_mems` (empty when memory is disabled).

        Raises:
          NotImplementedError: if `FLAGS.loss_type` is not supported.
          ValueError: (non-EVAL modes) if `FLAGS.init_checkpoint` ends with
            "latest" but no checkpoint is found in its directory.
        """
        #### Training or Evaluation
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        #### Retrieve `mems` from `params["cache"]`
        # The cache is a flat list; each memory object owns a contiguous slice
        # of `n_layer` entries (one more when `FLAGS.use_extra_layer` is set).
        mems = {}
        idx = 0
        for obj_len, key in zip([FLAGS.seq_len], ["mems"]):
            if obj_len > 0 and FLAGS.mem_len > 0:
                n_layer = FLAGS.n_layer
                if FLAGS.use_extra_layer:
                    n_layer += 1
                mems[key] = params["cache"][idx * n_layer:(idx + 1) * n_layer]
                idx += 1

        #### Get loss from inputs
        if FLAGS.loss_type == "electra":
            total_loss, new_mems, monitor_dict = model_func_builder.electra_loss(
                features, labels, mems, n_token, is_training)
        elif FLAGS.loss_type == "mlm":
            total_loss, new_mems, monitor_dict = model_func_builder.mlm_loss(
                features, labels, mems, n_token, is_training)
        elif FLAGS.loss_type == "xlnet":
            total_loss, new_mems, monitor_dict = model_func_builder.xlnet_loss(
                features, labels, mems, n_token, is_training)
        else:
            raise NotImplementedError(
                "Unsupported loss_type: {}".format(FLAGS.loss_type))

        #### Turn `new_mems` into `new_cache`
        # Flattened in the same order the cache was sliced above.
        new_cache = []
        for obj_len, key in zip([FLAGS.seq_len], ["mems"]):
            if obj_len > 0 and FLAGS.mem_len > 0:
                new_cache += new_mems[key]

        #### Check model parameters
        num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
        tf.logging.info("#params: %d", num_params)

        if FLAGS.verbose:
            format_str = "{{:<{0}s}}\t{{}}".format(
                max([len(v.name) for v in tf.trainable_variables()]))
            for v in tf.trainable_variables():
                tf.logging.info(format_str.format(v.name, v.get_shape()))

        #### Evaluation mode
        if mode == tf.estimator.ModeKeys.EVAL:
            #### Reduce sum losses from all TPU cores
            # Summed across replicas, then divided by the total core count,
            # i.e. averaged over all cores.
            with tf.colocate_with(total_loss):
                total_loss = tf.contrib.tpu.cross_replica_sum(total_loss)
                total_loss = total_loss / FLAGS.num_hosts / FLAGS.num_core_per_host
            metric_loss = tf.reshape(total_loss, [1])

            #### Constructing evaluation TPUEstimatorSpec with new cache.
            eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=(metric_fn, [metric_loss]))

            eval_spec.cache = new_cache

            return eval_spec

        #### Get the train op
        train_op, optim_dict = optimization.get_train_op(total_loss)
        monitor_dict.update(optim_dict)

        #### Customized initial checkpoint
        # Collected after the train op is built, so any variables created by
        # `get_train_op` are presumably included as well.
        tvars = tf.global_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        # Truthiness (not `is not None`): an empty-string flag means "no
        # checkpoint" rather than an attempt to restore from "".
        if FLAGS.init_checkpoint:
            if FLAGS.init_checkpoint.endswith("latest"):
                ckpt_dir = os.path.dirname(FLAGS.init_checkpoint)
                init_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
                # `latest_checkpoint` returns None when the directory holds no
                # checkpoint; fail here with a clear message instead of
                # passing None further down.
                if init_checkpoint is None:
                    raise ValueError(
                        "No checkpoint found in {} (init_checkpoint={})".format(
                            ckpt_dir, FLAGS.init_checkpoint))
            else:
                init_checkpoint = FLAGS.init_checkpoint

            tf.logging.info("Initialize from the ckpt %s", init_checkpoint)

            (assignment_map, initialized_variable_names
             ) = model_utils.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if FLAGS.use_tpu:
                # On TPU the restore is deferred to the scaffold function
                # handed to TPUEstimatorSpec.
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

            # Log customized initialization
            tf.logging.info("**** Global Variables ****")
            for var in tvars:
                init_string = ""
                if var.name in initialized_variable_names:
                    init_string = ", *INIT_FROM_CKPT*"
                tf.logging.info("  name = %s, shape = %s%s", var.name,
                                var.shape, init_string)

        #### Creating host calls
        host_call = model_utils.construct_scalar_host_call(
            monitor_dict=monitor_dict,
            model_dir=FLAGS.model_dir,
            prefix="train/",
            reduce_fn=tf.reduce_mean,
            log_freq=FLAGS.log_freq)

        #### Constructing training TPUEstimatorSpec with new cache.
        train_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                                     loss=total_loss,
                                                     train_op=train_op,
                                                     host_call=host_call,
                                                     scaffold_fn=scaffold_fn)

        train_spec.cache = new_cache

        return train_spec
Esempio n. 2
0
    def model_fn(features, labels, mode, params):
        """Build the TPUEstimatorSpec for seq2seq training or evaluation.

        `FLAGS.seq2seq_type` selects the loss: "dec_only" uses
        `model_func_builder.joint_loss`, "encdec" uses
        `model_func_builder.encdec_loss`.

        Args:
          features: input features, passed to the loss function.
          labels: labels, passed to the loss function.
          mode: a `tf.estimator.ModeKeys` value.
          params: estimator params (not read here).

        Returns:
          A `tf.contrib.tpu.TPUEstimatorSpec` for evaluation (EVAL mode) or
          training (any other mode).

        Raises:
          NotImplementedError: if `FLAGS.seq2seq_type` is not supported.
          ValueError: (non-EVAL modes) if `FLAGS.init_checkpoint` ends with
            "latest" but no checkpoint is found in its directory.
        """
        #### Training or Evaluation
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        #### Get loss from inputs
        if FLAGS.seq2seq_type == "dec_only":
            seq2seq_loss = model_func_builder.joint_loss
        elif FLAGS.seq2seq_type == "encdec":
            seq2seq_loss = model_func_builder.encdec_loss
        else:
            raise NotImplementedError(
                "Unsupported seq2seq_type: {}".format(FLAGS.seq2seq_type))
        total_loss, monitor_dict = seq2seq_loss(features, labels, n_token,
                                                is_training)

        #### Check model parameters
        num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
        tf.logging.info("#params: %d", num_params)

        if FLAGS.verbose:
            format_str = "{{:<{0}s}}\t{{}}".format(
                max([len(v.name) for v in tf.trainable_variables()]))
            for v in tf.trainable_variables():
                tf.logging.info(format_str.format(v.name, v.get_shape()))

        #### Evaluation mode
        if mode == tf.estimator.ModeKeys.EVAL:
            #### Reduce sum losses from all TPU cores
            # Summed across replicas, then divided by the total core count,
            # i.e. averaged over all cores.
            with tf.colocate_with(total_loss):
                total_loss = tf.contrib.tpu.cross_replica_sum(total_loss)
                total_loss = total_loss / FLAGS.num_hosts / FLAGS.num_core_per_host
            metric_loss = tf.reshape(total_loss, [1])

            #### Constructing evaluation TPUEstimatorSpec.
            eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=(metric_fn, [metric_loss]))

            return eval_spec

        #### Customized initial checkpoint
        # Runs before the train op is built and only covers trainable
        # variables.
        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if FLAGS.init_checkpoint:
            if FLAGS.init_checkpoint.endswith("latest"):
                ckpt_dir = os.path.dirname(FLAGS.init_checkpoint)
                init_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
                # `latest_checkpoint` returns None when the directory holds no
                # checkpoint; fail here with a clear message instead of
                # passing None further down.
                if init_checkpoint is None:
                    raise ValueError(
                        "No checkpoint found in {} (init_checkpoint={})".format(
                            ckpt_dir, FLAGS.init_checkpoint))
            else:
                init_checkpoint = FLAGS.init_checkpoint

            tf.logging.info("Initialize from the ckpt %s", init_checkpoint)

            (assignment_map, initialized_variable_names
             ) = model_utils.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if FLAGS.use_tpu:
                # On TPU the restore is deferred to the scaffold function
                # handed to TPUEstimatorSpec.
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

            # Log customized initialization (`tvars` holds trainable
            # variables only, so label the section accordingly).
            tf.logging.info("**** Trainable Variables ****")
            for var in tvars:
                init_string = ""
                if var.name in initialized_variable_names:
                    init_string = ", *INIT_FROM_CKPT*"
                tf.logging.info("  name = %s, shape = %s%s", var.name,
                                var.shape, init_string)

        #### Get the train op
        train_op, optim_dict = optimization.get_train_op(total_loss)
        monitor_dict.update(optim_dict)

        #### Creating host calls
        host_call = model_utils.construct_scalar_host_call(
            monitor_dict=monitor_dict,
            model_dir=FLAGS.model_dir,
            prefix="train/",
            reduce_fn=tf.reduce_mean,
            log_freq=FLAGS.log_freq)

        #### Constructing training TPUEstimatorSpec.
        train_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                                     loss=total_loss,
                                                     train_op=train_op,
                                                     host_call=host_call,
                                                     scaffold_fn=scaffold_fn)

        return train_spec
Esempio n. 3
0
    def model_fn(features, labels, mode, params):
        """Build the TPUEstimatorSpec for language-model training or evaluation.

        When `FLAGS.mem_len > 0`, memory is read from `params["cache"]` and
        the updated memory is attached to the returned spec as `.cache`;
        otherwise `None` is passed as memory and `.cache` is an empty list.

        Args:
          features: input features, passed to `get_lm_loss`.
          labels: unused.
          mode: a `tf.estimator.ModeKeys` value.
          params: estimator params; `params["cache"]` is read only when
            `FLAGS.mem_len > 0`.

        Returns:
          A `tf.contrib.tpu.TPUEstimatorSpec` with a `cache` attribute.

        Raises:
          ValueError: (non-EVAL modes) if `FLAGS.init_checkpoint` ends with
            "latest" but no checkpoint is found in its directory.
        """
        # not used
        del labels

        #### Training or Evaluation
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        #### Retrieve `mems` from `params["cache"]`
        mems = None
        if FLAGS.mem_len > 0:
            mems = params["cache"]

        #### Get loss from inputs
        total_loss, new_mems, monitor_dict = model_func_builder.get_lm_loss(
            features, mems, n_token, is_training)

        #### Put `new_mems` into `new_cache`
        new_cache = []
        if FLAGS.mem_len > 0:
            new_cache += new_mems

        #### Check model parameters
        num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
        tf.logging.info("#params: %d", num_params)

        if mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(loss):
                """Evaluation metric Fn which runs on CPU."""
                # NOTE(review): the mean loss is scaled by 1.2 before
                # exponentiation; presumably a token-to-word length ratio so
                # the metric is reported per word. Confirm before comparing
                # this "perplexity" against other runs.
                perplexity = tf.exp(tf.reduce_mean(loss) * 1.2)
                return {
                    "perplexity": tf.metrics.mean(perplexity),
                }

            # Each core's scalar loss is reshaped to [1, 1] before being
            # handed to `metric_fn`.
            metric_loss = tf.reshape(total_loss, [1, 1])
            eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=(metric_fn, [metric_loss]))

            eval_spec.cache = new_cache

            return eval_spec

        #### Configuring the optimizer
        train_op, optim_dict = optimization.get_train_op(total_loss)
        monitor_dict.update(optim_dict)

        #### Customized initial checkpoint
        # Collected after the train op is built, so any variables created by
        # `get_train_op` are presumably included as well.
        tvars = tf.global_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        # Truthiness (not `is not None`): an empty-string flag means "no
        # checkpoint" rather than an attempt to restore from "".
        if FLAGS.init_checkpoint:
            if FLAGS.init_checkpoint.endswith("latest"):
                ckpt_dir = os.path.dirname(FLAGS.init_checkpoint)
                init_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
                # `latest_checkpoint` returns None when the directory holds no
                # checkpoint; fail here with a clear message instead of
                # passing None further down.
                if init_checkpoint is None:
                    raise ValueError(
                        "No checkpoint found in {} (init_checkpoint={})".format(
                            ckpt_dir, FLAGS.init_checkpoint))
            else:
                init_checkpoint = FLAGS.init_checkpoint

            tf.logging.info("Initialize from the ckpt %s", init_checkpoint)

            (assignment_map, initialized_variable_names
             ) = model_utils.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if FLAGS.use_tpu:
                # On TPU the restore is deferred to the scaffold function
                # handed to TPUEstimatorSpec.
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

            # Log customized initialization
            tf.logging.info("**** Global Variables ****")
            for var in tvars:
                init_string = ""
                if var.name in initialized_variable_names:
                    init_string = ", *INIT_FROM_CKPT*"
                tf.logging.info("  name = %s, shape = %s%s", var.name,
                                var.shape, init_string)

        #### Creating host calls
        host_call = model_utils.construct_scalar_host_call(
            monitor_dict=monitor_dict,
            model_dir=FLAGS.model_dir,
            prefix="train/",
            reduce_fn=tf.reduce_mean,
            log_freq=FLAGS.log_freq)

        #### Constructing training TPUEstimatorSpec with new cache.
        train_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                                     loss=total_loss,
                                                     train_op=train_op,
                                                     host_call=host_call,
                                                     scaffold_fn=scaffold_fn)

        train_spec.cache = new_cache

        return train_spec
Esempio n. 4
0
    def model_fn(features, labels, mode, params):
        """Create the TPUEstimatorSpec for language-model training or eval.

        With `FLAGS.mem_len > 0` the previous memory comes from
        `params["cache"]` and the refreshed memory is stored on the returned
        spec as `.cache`; otherwise `None` is fed as memory and `.cache` is an
        empty list. EVAL mode gets a perplexity metric; every other mode gets
        a training spec. `labels` is ignored.
        """
        del labels  # the LM loss is computed from `features` alone

        training = mode == tf.estimator.ModeKeys.TRAIN

        # Memory carried over from the previous step, if enabled.
        mems = params["cache"] if FLAGS.mem_len > 0 else None

        total_loss, new_mems, monitor_dict = model_function.get_lm_loss(
            features, mems, training)

        # Memory to hand back to the estimator for the next step.
        new_cache = list(new_mems) if FLAGS.mem_len > 0 else []

        # Report model size.
        trainable = tf.trainable_variables()
        tf.logging.info("#params: %d",
                        sum([np.prod(var.shape) for var in trainable]))

        if mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(loss):
                """Return the perplexity metric; evaluated on the CPU host."""
                # NOTE(review): the 1.2 scale factor is unexplained here;
                # presumably a token-to-word ratio — confirm.
                mean_loss = tf.reduce_mean(loss)
                return {
                    "perplexity": tf.metrics.mean(tf.exp(mean_loss * 1.2)),
                }

            spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=(metric_fn, [tf.reshape(total_loss, [1, 1])]))
        else:
            # Optimizer, warm-start scaffold and summaries, in that order.
            train_op, optim_dict = optimization.get_train_op(total_loss)
            monitor_dict.update(optim_dict)
            scaffold_fn = model_utils.init_from_checkpoint(global_vars=True)
            host_call = model_function.construct_scalar_host_call(
                monitor_dict=monitor_dict,
                model_dir=FLAGS.model_dir,
                prefix="train/",
                reduce_fn=tf.reduce_mean)
            spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                host_call=host_call,
                scaffold_fn=scaffold_fn)

        spec.cache = new_cache
        return spec