Example #1
def __init__(self, config):
  """
  :param Config config:
  """
  print("Initialize distributed TensorFlow", file=log.v2)
  self.config = config
  opts = config.get_of_type("distributed_tf", dict, {})
  opts = CollectionReadCheckCovered(opts)
  self.opts = opts
  if opts.get("local_only", False):  # might be useful for testing
    cluster_resolver = LocalOnlyClusterResolver()
    print("Use local-only cluster resolver,", file=log.v4, end=" ")
  elif os.environ.get("TF_CONFIG", ""):
    cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
    print("Use TF_CONFIG %s," % os.environ["TF_CONFIG"], file=log.v4, end=" ")
  else:
    cluster_resolver = MPIClusterResolver()
    print("Use MPI cluster resolver,", file=log.v4, end=" ")
  print("cluster spec %s, master %s" % (cluster_resolver.cluster_spec(), cluster_resolver.master()), file=log.v4)
  self.cluster_resolver = cluster_resolver
  cluster_spec = cluster_resolver.cluster_spec()
  self.cluster_spec = cluster_spec
  tf_session_opts = config.typed_dict.get("tf_session_opts", {})
  server_config = tf.compat.v1.ConfigProto(**tf_session_opts)
  # Note that there is no clean way currently in TF to uninit the TF server.
  # If we would use this multiple times (e.g. in tests),
  # it might actually be better to cache the server as a singleton...
  server = tf.distribute.Server(
    cluster_spec,
    job_name=cluster_resolver.task_type, task_index=cluster_resolver.task_id,
    config=server_config)
  self.server = server
  self.strategy = ReturnnDefaultStrategy()  # not really used currently...
  self.opts.assert_all_read()
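
As a rough illustration (not part of the example above): the options this constructor reads could be set in a RETURNN config roughly as sketched below. Only the keys "distributed_tf", "local_only" and "tf_session_opts" are taken from the code; the concrete values are made-up examples.

# Hypothetical config sketch; keys are the ones read by the constructor above, values are examples.
distributed_tf = {
  "local_only": True,  # use LocalOnlyClusterResolver, e.g. for local testing
}
tf_session_opts = {
  "log_device_placement": False,  # forwarded to tf.compat.v1.ConfigProto(**tf_session_opts)
}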
Example #2
class Optimization:
    """
  Hyper parameter optimization handler class.
  """
    def __init__(self, config, train_data):
        """
    :param returnn.config.Config config:
    :param Dataset train_data:
    """
        self.config = config
        self.opts = CollectionReadCheckCovered(
            config.get_of_type("hyper_param_tuning", dict, {}))
        self.log = log.v1
        train_data.init_seq_order(epoch=1)
        self.train_data = StaticDataset.copy_from_dataset(
            train_data, max_seqs=self.opts.get("num_train_steps", 100))
        self.hyper_params = []  # type: typing.List[HyperParam]
        self._find_hyper_params()
        if not self.hyper_params:
            raise Exception("No hyper params found.")
        self.hyper_params.sort(key=lambda p_: p_.unique_idx)
        print("We have found these hyper params:")
        for p in self.hyper_params:
            print(" %s" % p.description())
        self.dry_run_first_individual = self.opts.get(
            "dry_run_first_individual", True)
        self.num_iterations = self.opts["num_tune_iterations"]
        self.num_individuals = self.opts["num_individuals"]
        self.num_kill_individuals = self.opts.get("num_kill_individuals",
                                                  self.num_individuals // 2)
        self.num_best = self.opts.get("num_best", 10)
        self.num_threads = self.opts.get("num_threads",
                                         guess_requested_max_num_threads())
        self.opts.assert_all_read()
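        # Illustrative sketch, not part of the original code: the "hyper_param_tuning" dict read
        # above could look roughly like this in the config. Only the keys are taken from this
        # __init__; the values are made-up examples.
        #   hyper_param_tuning = {
        #     "num_train_steps": 100,            # max_seqs for the StaticDataset copy
        #     "num_tune_iterations": 5,          # required
        #     "num_individuals": 10,             # required
        #     "num_kill_individuals": 5,         # default: num_individuals // 2
        #     "num_best": 10,                    # default: 10
        #     "num_threads": 4,                  # default: guess_requested_max_num_threads()
        #     "dry_run_first_individual": True,  # default: True
        #   }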

    def _find_hyper_params(self, base=None, visited=None):
        """
    :param _AttrChain base:
    :param set[int] visited: set of ids
    """
        from inspect import ismodule
        if base is None:
            base = _AttrChain(base=self.config)
        if isinstance(base.value, HyperParam):
            base.value.usages.append(base)
            if base.value not in self.hyper_params:
                self.hyper_params.append(base.value)
            return
        if visited is None:
            visited = set()
        if id(base.value) in visited:
            return
        visited.add(id(base.value))
        if ismodule(base.value):
            return
        if isinstance(base.value, dict):
            col_type = _AttribOrKey.ColTypeDict
            keys = base.value.keys()
        elif isinstance(base.value, Config):
            col_type = _AttribOrKey.ColTypeConfig
            keys = base.value.typed_dict.keys()
        else:
            # Add other specific object types, but not in generic all.
            return
        for key in sorted(keys):
            child = base.get_extended_chain(
                _AttribOrKey(key=key, col_type=col_type))
            self._find_hyper_params(base=child, visited=visited)

    def get_population(self, iteration_idx, num_individuals):
        """
    :param int iteration_idx:
    :param int num_individuals:
    :rtype: list[Individual]
    """
        assert num_individuals > 0
        return [
            self.get_individual(iteration_idx=iteration_idx, individual_idx=i)
            for i in range(num_individuals)
        ]

    def get_individual(self, iteration_idx, individual_idx):
        """
    :param int iteration_idx:
    :param int individual_idx:
    :rtype: Individual
    """
        return Individual(
            {
                p: p.get_random_value_by_idx(iteration_idx=iteration_idx,
                                             individual_idx=individual_idx)
                for p in self.hyper_params
            },
            name="%i-%i" % (iteration_idx, individual_idx))

    def cross_over(self, population, iteration_idx):
        """
    :param list[Individual] population: modified in-place
    :param int iteration_idx:
    """
        for i in range(len(population) - 1):
            population[i] = population[i].cross_over(
                hyper_params=self.hyper_params,
                population=population[:i] + population[i + 1:],
                random_seed=iteration_idx * 1013 + i * 17)

    def create_config_instance(self, hyper_param_mapping, gpu_ids):
        """
    :param dict[HyperParam] hyper_param_mapping: maps each hyper param to some value
    :param set[int] gpu_ids:
    :rtype: Config
    """
        assert set(self.hyper_params) == set(hyper_param_mapping.keys())
        from returnn.util.basic import deepcopy
        config = deepcopy(self.config)
        assert isinstance(config, Config)
        for p, value in hyper_param_mapping.items():
            assert isinstance(p, HyperParam)
            for attr_chain in p.usages:
                attr_chain.write_attrib(base=config, new_value=value)
        tf_session_opts = config.typed_dict.setdefault("tf_session_opts", {})
        # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/config.proto
        gpu_opts = tf_session_opts.setdefault("gpu_options",
                                              tf_compat.v1.GPUOptions())
        if isinstance(gpu_opts, dict):
            gpu_opts = tf_compat.v1.GPUOptions(**gpu_opts)
        gpu_opts.visible_device_list = ",".join(map(str, sorted(gpu_ids)))
        return config

    def work(self):
        """
    Start the optimization.
    """
        print("Starting hyper param search. Using %i threads." %
              self.num_threads,
              file=log.v1)
        from returnn.tf.util.basic import get_available_gpu_devices
        from returnn.log import wrap_log_streams, StreamDummy
        from threading import Thread, Condition
        from returnn.util.basic import progress_bar, hms, is_tty

        class Outstanding:
            """
      Queue of outstanding work.
      """
            cond = Condition()
            threads = []  # type: typing.List[WorkerThread]
            population = []
            exit = False
            exception = None

        class WorkerThread(Thread):
            """
      Worker threader.
      """
            def __init__(self, gpu_ids):
                """
        :param set[int] gpu_ids:
        """
                super(WorkerThread,
                      self).__init__(name="Hyper param tune train thread")
                self.gpu_ids = gpu_ids
                self.trainer = None  # type: typing.Optional[_IndividualTrainer]
                self.finished = False
                self.start()

            def cancel(self, join=False):
                """
        :param bool join:
        """
                with Outstanding.cond:
                    if self.trainer:
                        self.trainer.cancel_flag = True
                        if self.trainer.runner:
                            self.trainer.runner.cancel_flag = True
                if join:
                    self.join()

            def get_complete_frac(self):
                """
        :rtype: float
        """
                with Outstanding.cond:
                    if self.trainer and self.trainer.runner:
                        return self.trainer.runner.data_provider.get_complete_frac()
                return 0.0

            # noinspection PyMethodParameters
            def run(self_thread):
                """
        Run thread.
        """
                try:
                    while True:
                        with Outstanding.cond:
                            if Outstanding.exit or Outstanding.exception:
                                return
                            if not Outstanding.population:
                                self_thread.finished = True
                                Outstanding.cond.notify_all()
                                return
                            # noinspection PyShadowingNames
                            individual = Outstanding.population.pop(0)
                            self_thread.trainer = _IndividualTrainer(
                                optim=self,
                                individual=individual,
                                gpu_ids=self_thread.gpu_ids)
                        self_thread.name = "Hyper param tune train thread on %r" % individual.name
                        self_thread.trainer.run()
                except Exception as exc:
                    with Outstanding.cond:
                        if not Outstanding.exception:
                            Outstanding.exception = exc or True
                        Outstanding.cond.notify_all()
                    for thread in Outstanding.threads:
                        if thread is not self_thread:
                            thread.cancel()
                    if not isinstance(exc, CancelTrainingException):
                        with Outstanding.cond:  # So that we don't mix up multiple on sys.stderr.
                            # This would normally dump it on sys.stderr so it's fine.
                            sys.excepthook(*sys.exc_info())

        best_individuals = []
        population = []
        num_gpus = len(get_available_gpu_devices())
        print("Num available GPUs:", num_gpus)
        num_gpus = num_gpus or 1  # Would be ignored anyway.
        interactive = is_tty()
        try:
            print(
                "Population of %i individuals (hyper param setting instances), running for %i evaluation iterations."
                % (self.num_individuals, self.num_iterations),
                file=log.v2)
            for cur_iteration_idx in range(1, self.num_iterations + 1):
                print("Starting iteration %i." % cur_iteration_idx,
                      file=log.v2)
                if cur_iteration_idx == 1:
                    population.append(
                        Individual(
                            {
                                p: p.get_default_value()
                                for p in self.hyper_params
                            },
                            name="default"))
                    population.append(
                        Individual(
                            {
                                p: p.get_initial_value()
                                for p in self.hyper_params
                            },
                            name="canonical"))
                population.extend(
                    self.get_population(iteration_idx=cur_iteration_idx,
                                        num_individuals=self.num_individuals - len(population)))
                if cur_iteration_idx > 1:
                    self.cross_over(population=population,
                                    iteration_idx=cur_iteration_idx)
                if cur_iteration_idx == 1 and self.dry_run_first_individual:
                    # Train first directly for testing and to see log output.
                    # Later we will strip away all log output.
                    print("Very first try with log output:", file=log.v2)
                    _IndividualTrainer(optim=self,
                                       individual=population[0],
                                       gpu_ids={0}).run()
                print("Starting training with thread pool of %i threads." %
                      self.num_threads)
                iteration_start_time = time.time()
                with wrap_log_streams(StreamDummy(),
                                      also_sys_stdout=True,
                                      tf_log_verbosity="WARN"):
                    Outstanding.exit = False
                    Outstanding.population = list(population)
                    Outstanding.threads = [
                        WorkerThread(gpu_ids={i % num_gpus})
                        for i in range(self.num_threads)
                    ]
                    try:
                        while True:
                            with Outstanding.cond:
                                if all([
                                        thread.finished
                                        for thread in Outstanding.threads
                                ]) or Outstanding.exception:
                                    break
                                complete_frac = max(
                                    len(population) -
                                    len(Outstanding.population) -
                                    len(Outstanding.threads), 0)
                                complete_frac += sum([
                                    thread.get_complete_frac()
                                    for thread in Outstanding.threads
                                ])
                                complete_frac /= float(len(population))
                                remaining_str = ""
                                if complete_frac > 0:
                                    start_elapsed = time.time() - iteration_start_time
                                    total_time_estimated = start_elapsed / complete_frac
                                    remaining_estimated = total_time_estimated - start_elapsed
                                    remaining_str = hms(remaining_estimated)
                                if interactive:
                                    progress_bar(complete_frac,
                                                 prefix=remaining_str,
                                                 file=sys.__stdout__)
                                else:
                                    print("Progress: %.02f%%" %
                                          (complete_frac * 100),
                                          "remaining:",
                                          remaining_str or "unknown",
                                          file=sys.__stdout__)
                                    sys.__stdout__.flush()
                                Outstanding.cond.wait(1 if interactive else 10)
                        for thread in Outstanding.threads:
                            thread.join()
                    finally:
                        Outstanding.exit = True
                        for thread in Outstanding.threads:
                            thread.cancel(join=True)
                Outstanding.threads = []
                print("Training iteration elapsed time:",
                      hms(time.time() - iteration_start_time))
                if Outstanding.exception:
                    raise Outstanding.exception
                assert not Outstanding.population
                print("Training iteration finished.")
                # Selection: sort by cost (lower is better), drop the worst individuals,
                # keep the overall best list, and seed the next population with some of the best.
                population.sort(key=lambda p: p.cost)
                del population[-self.num_kill_individuals:]
                best_individuals.extend(population)
                best_individuals.sort(key=lambda p: p.cost)
                del best_individuals[self.num_best:]
                population = best_individuals[:self.num_kill_individuals // 4] + population
                print(
                    "Current best setting, individual %s" %
                    best_individuals[0].name, "cost:",
                    best_individuals[0].cost)
                for p in self.hyper_params:
                    print(" %s -> %s" %
                          (p.description(),
                           best_individuals[0].hyper_param_mapping[p]))
        except KeyboardInterrupt:
            print("KeyboardInterrupt, canceled search.")

        print("Best %i settings:" % len(best_individuals))
        for individual in best_individuals:
            print("Individual %s" % individual.name, "cost:", individual.cost)
            for p in self.hyper_params:
                print(" %s -> %s" %
                      (p.description(), individual.hyper_param_mapping[p]))
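
For orientation only, a minimal hypothetical driver for Example #2 is sketched below; it assumes `config` is a loaded returnn.config.Config containing a "hyper_param_tuning" entry and `train_data` is an initialized RETURNN Dataset.

# Hypothetical usage sketch; `config` and `train_data` are assumed to exist as described above.
optim = Optimization(config=config, train_data=train_data)
optim.work()  # runs the iterative search and prints the best hyper param settings found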