Example 1
  def _build_train_spec(self, checkpoint_path):
    train_hooks = [
        hooks.LogParametersCountHook()]

    if checkpoint_path is not None:
      train_hooks.append(hooks.LoadWeightsFromCheckpointHook(checkpoint_path))
    if self._hvd is not None:
      train_hooks.append(self._hvd.BroadcastGlobalVariablesHook(0))

    train_steps = self._config["train"].get("train_steps")
    if train_steps is not None and self._hvd is not None:
      train_steps //= self._hvd.size()
    train_spec = tf.estimator.TrainSpec(
        input_fn=estimator_util.make_input_fn(
            self._model,
            tf.estimator.ModeKeys.TRAIN,
            self._config["train"]["batch_size"],
            features_file=self._config["data"]["train_features_file"],
            labels_file=self._config["data"].get("train_labels_file"),
            batch_type=self._config["train"]["batch_type"],
            batch_multiplier=self._num_devices,
            bucket_width=self._config["train"]["bucket_width"],
            maximum_features_length=self._config["train"].get("maximum_features_length"),
            maximum_labels_length=self._config["train"].get("maximum_labels_length"),
            shuffle_buffer_size=self._config["train"]["sample_buffer_size"],
            single_pass=self._config["train"].get("single_pass", False),
            num_shards=self._hvd.size() if self._hvd is not None else 1,
            shard_index=self._hvd.rank() if self._hvd is not None else 0,
            num_threads=self._config["train"].get("num_threads"),
            prefetch_buffer_size=self._config["train"].get("prefetch_buffer_size"),
            return_dataset=False),
        max_steps=train_steps,
        hooks=train_hooks)
    return train_spec
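This variant adds Horovod data parallelism: the total step budget is divided by the worker count, and the input pipeline is sharded so each worker reads a distinct slice of the data. A minimal sketch of that arithmetic, with a hypothetical shard_params helper standing in for the real hvd.size()/hvd.rank() calls:

def shard_params(train_steps, num_workers, worker_rank):
    # Mirrors the Horovod branch above: each worker runs an equal share of
    # the step budget and reads every num_workers-th example of the dataset.
    per_worker_steps = train_steps // num_workers if train_steps is not None else None
    shard_args = {"num_shards": num_workers, "shard_index": worker_rank}
    return per_worker_steps, shard_args

# With 4 workers and a 100000-step budget, each worker runs 25000 steps;
# worker 2 reads shard 2 of 4.
print(shard_params(100000, 4, 2))
# (25000, {'num_shards': 4, 'shard_index': 2})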
Example 2
  def _build_train_spec(self, checkpoint_path):
    train_hooks = [
        hooks.LogParametersCountHook()]

    if checkpoint_path is not None:
      train_hooks.append(hooks.LoadWeightsFromCheckpointHook(checkpoint_path))

    train_spec = tf.estimator.TrainSpec(
        input_fn=self._model.input_fn(
            tf.estimator.ModeKeys.TRAIN,
            self._config["train"]["batch_size"],
            self._config["data"],
            self._config["data"]["train_features_file"],
            labels_file=self._config["data"]["train_labels_file"],
            batch_type=self._config["train"]["batch_type"],
            batch_multiplier=self._num_devices,
            bucket_width=self._config["train"]["bucket_width"],
            single_pass=self._config["train"].get("single_pass", False),
            num_threads=self._config["train"].get("num_threads"),
            sample_buffer_size=self._config["train"]["sample_buffer_size"],
            prefetch_buffer_size=self._config["train"].get("prefetch_buffer_size"),
            maximum_features_length=self._config["train"].get("maximum_features_length"),
            maximum_labels_length=self._config["train"].get("maximum_labels_length")),
        max_steps=self._config["train"].get("train_steps"),
        hooks=train_hooks)
    return train_spec
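For reference, a minimal sketch of the configuration this example reads, expressed as a Python dict. The key layout follows the lookups in the code (keys read with .get() are optional); the file names and values below are placeholders, not recommended settings:

config = {
    "data": {
        "train_features_file": "train.src",  # placeholder path
        "train_labels_file": "train.tgt",    # placeholder path
    },
    "train": {
        "batch_size": 64,
        "batch_type": "examples",
        "bucket_width": 1,
        "sample_buffer_size": 500000,
        # Optional keys, read with .get():
        "single_pass": False,
        "num_threads": 4,
        "prefetch_buffer_size": 1,
        "maximum_features_length": 100,
        "maximum_labels_length": 100,
        "train_steps": 500000,  # omit to train until interrupted
    },
}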
Example 3
    def _build_train_spec(self, checkpoint_path):
        train_hooks = [hooks.LogParametersCountHook()]
        #if checkpoint_path is not None:
        #  train_hooks.append(hooks.LoadWeightsFromCheckpointHook(checkpoint_path))

        # NEW: selectively loads parameters based on the config "load_weights" section - see config*.yml
        if checkpoint_path is not None and "load_weights" in self._config:
            not_restore = []
            loadw = self._config["load_weights"]
            if loadw is not None:
                if not loadw.get("src_embs"):
                    not_restore.append("encoder/w_embs")
                if not loadw.get("tgt_embs"):
                    not_restore.append("decoder/w_embs")
                if not loadw.get("projection"):
                    not_restore.append("decoder/dense")
                if not loadw.get("shared_embs"):
                    not_restore.append("shared_embeddings/w_embs")
                if not loadw.get("encoder"):
                    not_restore.append("encoder")
                if not loadw.get("decoder"):
                    not_restore.append("decoder")
                if not loadw.get("optim"):
                    # when optim is skipped, global_step and words_per_sec are also candidates for skipping
                    not_restore.append("optim")
                    if not loadw.get("global_step"):
                        not_restore.append("global_step")
                    if not loadw.get("words_per_sec"):
                        not_restore.append("words_per_sec")

            tf.logging.info("NOT RESTORING: %s",
                            json.dumps(not_restore, indent=2, sort_keys=True))
            train_hooks.append(
                hooks.LoadWeightsFromCheckpointHook(checkpoint_path,
                                                    not_restore))

        train_spec = tf.estimator.TrainSpec(
            input_fn=self._model.input_fn(
                tf.estimator.ModeKeys.TRAIN,
                self._config["train"]["batch_size"],
                self._config["data"],
                self._config["data"]["train_features_file"],
                labels_file=self._config["data"]["train_labels_file"],
                batch_type=self._config["train"]["batch_type"],
                batch_multiplier=self._num_devices,
                bucket_width=self._config["train"]["bucket_width"],
                single_pass=self._config["train"].get("single_pass", False),
                num_threads=self._config["train"].get("num_threads"),
                sample_buffer_size=self._config["train"]["sample_buffer_size"],
                prefetch_buffer_size=self._config["train"].get(
                    "prefetch_buffer_size"),
                maximum_features_length=self._config["train"].get(
                    "maximum_features_length"),
                maximum_labels_length=self._config["train"].get(
                    "maximum_labels_length")),
            max_steps=self._config["train"].get("train_steps"),
            hooks=train_hooks)
        return train_spec
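The "load_weights" section consulted here is a flat mapping of booleans: every flag that is false or missing adds the corresponding variable scope to not_restore. A hypothetical example, keeping the encoder but re-initializing the decoder and the optimizer state:

load_weights = {
    "src_embs": True,        # keep source embeddings (encoder/w_embs)
    "tgt_embs": False,       # skip decoder/w_embs
    "projection": False,     # skip decoder/dense
    "shared_embs": False,    # skip shared_embeddings/w_embs
    "encoder": True,         # keep the encoder
    "decoder": False,        # skip the decoder
    "optim": False,          # skip optimizer state...
    "global_step": False,    # ...and the step counter
    "words_per_sec": False,  # ...and the words/sec counters
}
# With these flags the hook skips: decoder/w_embs, decoder/dense,
# shared_embeddings/w_embs, decoder, optim, global_step, words_per_sec.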
Example 4
    def _build_train_spec(self, checkpoint_path):
        train_hooks = [hooks.LogParametersCountHook()]

        #if checkpoint_path is not None:
        #  train_hooks.append(hooks.LoadWeightsFromCheckpointHook(checkpoint_path))

        # TODO: decide what to load/skip from the "load_weights" booleans in the config; move this option handling to a utility function.

        if checkpoint_path is not None and "load_weights" in self._config:
            #not_restore = ['encoder', 'decoder', 'shared_embeddings', 'optim', 'global_step', 'word_per_sec', 'output_layer']

            #not_restore = ['optim', 'global_step', 'word_per_sec'] #, 'encoder'
            # NOTE: in some runs the word_per_sec variable does not exist and
            # restoring it raises an error; matching prefixes during checkpoint
            # loading (e.g. name.startswith("words_per_sec")) is one way to handle it.
            not_restore = []  # make sure it is empty on each call
            loadw = self._config["load_weights"]  # in var_list: [0] optim..., [1] enc/dec, [1:2] embs/projection
            if loadw is not None:
                if not loadw.get("src_embs"):
                    not_restore.append("encoder/w_embs")
                if not loadw.get("tgt_embs"):
                    not_restore.append("decoder/w_embs")

                if not loadw.get("projection"):
                    not_restore.append("decoder/dense")

                if not loadw.get("shared_embs"):
                    not_restore.append("shared_embeddings/w_embs")

                if not loadw.get("encoder"):
                    not_restore.append("encoder")
                if not loadw.get("decoder"):
                    not_restore.append("decoder")

                if not loadw.get("optim"):
                    # if optim is skipped, the next two are also checked for skipping
                    not_restore.append("optim")

                    if not loadw.get("global_step"):
                        not_restore.append("global_step")
                    if not loadw.get("words_per_sec"):
                        not_restore.append("words_per_sec")

            tf.logging.info("NOT RESTORING SUB-NETWORKS: %s",
                            json.dumps(not_restore, indent=2, sort_keys=True))

            train_hooks.append(
                hooks.LoadWeightsFromCheckpointHook(
                    checkpoint_path,
                    not_restore))

        train_spec = tf.estimator.TrainSpec(
            input_fn=self._model.input_fn(
                tf.estimator.ModeKeys.TRAIN,
                self._config["train"]["batch_size"],
                self._config["data"],
                self._config["data"]["train_features_file"],
                labels_file=self._config["data"]["train_labels_file"],
                batch_type=self._config["train"]["batch_type"],
                batch_multiplier=self._num_devices,
                bucket_width=self._config["train"]["bucket_width"],
                single_pass=self._config["train"].get("single_pass", False),
                num_threads=self._config["train"].get("num_threads"),
                sample_buffer_size=self._config["train"]["sample_buffer_size"],
                prefetch_buffer_size=self._config["train"].get(
                    "prefetch_buffer_size"),
                maximum_features_length=self._config["train"].get(
                    "maximum_features_length"),
                maximum_labels_length=self._config["train"].get(
                    "maximum_labels_length")),
            max_steps=self._config["train"].get("train_steps"),
            hooks=train_hooks)
        return train_spec
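The TODO in this example asks for the flag handling to be moved into a utility function. A minimal sketch of such a helper, under the same assumptions (the name _weights_to_skip is hypothetical; the flag and scope names are the ones used above):

def _weights_to_skip(load_weights):
    """Map "load_weights" booleans to variable scopes the checkpoint hook should skip."""
    if not load_weights:
        return []
    # Flag name -> variable scope skipped when the flag is false or missing.
    flag_to_scope = [
        ("src_embs", "encoder/w_embs"),
        ("tgt_embs", "decoder/w_embs"),
        ("projection", "decoder/dense"),
        ("shared_embs", "shared_embeddings/w_embs"),
        ("encoder", "encoder"),
        ("decoder", "decoder"),
    ]
    not_restore = [scope for flag, scope in flag_to_scope
                   if not load_weights.get(flag)]
    if not load_weights.get("optim"):
        not_restore.append("optim")
        # global_step and words_per_sec are only considered once the
        # optimizer state itself is skipped, as in the examples above.
        for counter in ("global_step", "words_per_sec"):
            if not load_weights.get(counter):
                not_restore.append(counter)
    return not_restore

# Usage inside _build_train_spec:
#   not_restore = _weights_to_skip(self._config["load_weights"])
#   train_hooks.append(
#       hooks.LoadWeightsFromCheckpointHook(checkpoint_path, not_restore))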