Example #1
File: entry.py  Project: charles9304/FastNN
def main(_):
    assert FLAGS.runner_name in runner_map, \
        "Unknown runner_name: {}".format(FLAGS.runner_name)
    ####################
    # Kicks off Whale. #
    ####################
    import whale as wh
    if FLAGS.enable_whale:
        wh.init()
    layout = {"average": FLAGS.average}
    if FLAGS.runner_name == "large_scale_classification":
        layout = "all"
    cluster = wh.cluster(worker_hosts=FLAGS.worker_hosts,
                         ps_hosts=FLAGS.ps_hosts,
                         job_name=FLAGS.job_name,
                         rank=FLAGS.task_index,
                         layout=layout)
    config_proto = tf.ConfigProto(log_device_placement=False,
                                  allow_soft_placement=True,
                                  gpu_options=tf.GPUOptions(
                                      force_gpu_compatible=True,
                                      allow_growth=True))

    if FLAGS.task_type in ['pretrain', 'finetune']:
        with cluster:
            runner_map[FLAGS.runner_name].run_model(cluster, config_proto)
    else:
        raise ValueError('task_type [%s] was not recognized' % FLAGS.task_type)
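This entry point assumes FLAGS, runner_map, and tensorflow (as tf) are defined elsewhere at module level in entry.py. Below is a minimal sketch of the flag definitions it relies on, using TF1-style tf.app.flags; the flag names come from the usage above, but the types and default values here are guesses:

import tensorflow as tf

tf.app.flags.DEFINE_string("runner_name", "", "Key into runner_map.")
tf.app.flags.DEFINE_string("task_type", "pretrain", "pretrain or finetune.")
tf.app.flags.DEFINE_boolean("enable_whale", True, "Whether to call wh.init().")
tf.app.flags.DEFINE_boolean("average", True, "Passed through in the Whale layout.")
tf.app.flags.DEFINE_string("worker_hosts", "", "Comma-separated worker hosts.")
tf.app.flags.DEFINE_string("ps_hosts", "", "Comma-separated parameter server hosts.")
tf.app.flags.DEFINE_string("job_name", "worker", "worker or ps.")
tf.app.flags.DEFINE_integer("task_index", 0, "Rank of this task.")

FLAGS = tf.app.flags.FLAGS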
Example #2
    def main():
        app = init_flask_app()

        whale.init()

        whale.route(flask_app=app)
        static.route(flask_app=app)

        app.run(debug=True)
Example #3
    def main():
        app = init_flask_app()

        whale.init()

        # Each of the modules below represents a category of endpoints and
        # binds the required endpoints to the Flask app so we can serve those requests.
        static.route(flask_app=app)
        pages.route(flask_app=app)
        whale.route(flask_app=app)

        app.run(debug=True)
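In these Flask examples, static, pages, and whale are project-local modules, each exposing a route(flask_app) function that registers its category of endpoints on the app. A minimal sketch of what one such module might look like; the module name pages.py and the /about endpoint are hypothetical:

# pages.py (hypothetical): binds one category of endpoints to the app.
def route(flask_app):
    # Register a view on the Flask app that main() passes in.
    @flask_app.route("/about")
    def about():
        return "About this site"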
Example #4
    def __init__(self, **kwargs):

        if self.config.mode in ("train", "train_and_evaluate",
                                "train_and_evaluate_on_the_fly",
                                "train_on_the_fly"):

            tf.logging.info("***********Running in {} mode***********".format(
                self.config.mode))

            if self.config.enable_xla:
                tf.logging.info("***********Enable Tao***********")
                os.environ['BRIDGE_ENABLE_TAO'] = 'True'
                os.environ["TAO_ENABLE_CHECK"] = "false"
                os.environ["TAO_COMPILATION_MODE_ASYNC"] = "false"
                os.environ["DISABLE_DEADNESS_ANALYSIS"] = "true"
            else:
                tf.logging.info("***********Disable Tao***********")

            if self.config.enable_auto_mixed_precision:
                tf.logging.info(
                    "***********Enable AUTO_MIXED_PRECISION***********")
                os.environ['TF_AUTO_MIXED_PRECISION'] = 'True'
                os.environ['lossScaling'] = 'auto'
            else:
                tf.logging.info(
                    "***********Disable AUTO_MIXED_PRECISION***********")

            NCCL_MAX_NRINGS = "4"
            NCCL_MIN_NRINGS = "4"
            TF_JIT_PROFILING = 'False'
            PAI_ENABLE_HLO_DUMPER = 'False'
            os.environ['PAI_ENABLE_HLO_DUMPER'] = PAI_ENABLE_HLO_DUMPER
            os.environ['TF_JIT_PROFILING'] = TF_JIT_PROFILING
            os.environ["NCCL_MAX_NRINGS"] = NCCL_MAX_NRINGS
            os.environ["NCCL_MIN_NRINGS"] = NCCL_MIN_NRINGS
            os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
            tf.logging.info("***********NCCL_MAX_NRINGS {}***********".format(
                NCCL_MAX_NRINGS))
            tf.logging.info("***********NCCL_MIN_NRINGS {}***********".format(
                NCCL_MIN_NRINGS))
            tf.logging.info("***********TF_JIT_PROFILING {}***********".format(
                TF_JIT_PROFILING))
            tf.logging.info(
                "***********PAI_ENABLE_HLO_DUMPER {}***********".format(
                    PAI_ENABLE_HLO_DUMPER))

            self.strategy = None
            if self.config.num_gpus >= 1 and self.config.num_workers >= 1 and \
                    self.config.distribution_strategy in (
                        "ExascaleStrategy", "CollectiveAllReduceStrategy"):

                if FLAGS.usePAI:
                    import pai
                    worker_hosts = self.config.worker_hosts.split(',')
                    tf.logging.info(
                        "***********Job Name is {}***********".format(
                            self.config.job_name))
                    tf.logging.info(
                        "***********Task Index is {}***********".format(
                            self.config.task_index))
                    tf.logging.info(
                        "***********Worker Hosts is {}***********".format(
                            self.config.worker_hosts))
                    pai.distribute.set_tf_config(
                        self.config.job_name,
                        self.config.task_index,
                        worker_hosts,
                        has_evaluator=self.config.pull_evaluation_in_multiworkers_training)

                if self.config.distribution_strategy == "ExascaleStrategy":
                    tf.logging.info(
                        "*****************Using ExascaleStrategy*********************"
                    )
                    if FLAGS.usePAI:
                        self.strategy = pai.distribute.ExascaleStrategy(
                            num_gpus=self.config.num_gpus,
                            num_micro_batches=self.config.num_accumulated_batches,
                            max_splits=1,
                            enable_sparse_allreduce=False)
                    else:
                        raise ValueError("Please set usePAI is True")

                elif self.config.distribution_strategy == "CollectiveAllReduceStrategy":
                    tf.logging.info(
                        "*****************Using CollectiveAllReduceStrategy*********************"
                    )
                    if FLAGS.usePAI:
                        self.strategy = tf.contrib.distribute.CollectiveAllReduceStrategy(
                            num_gpus_per_worker=self.config.num_gpus,
                            cross_tower_ops_type='default',
                            all_dense=True,
                            iter_size=self.config.num_accumulated_batches)
                    else:
                        self.strategy = tf.contrib.distribute.CollectiveAllReduceStrategy(
                            num_gpus_per_worker=self.config.num_gpus)

                if self.config.pull_evaluation_in_multiworkers_training:
                    real_num_workers = self.config.num_workers - 1
                else:
                    real_num_workers = self.config.num_workers

                global_batch_size = (self.config.train_batch_size *
                                     self.config.num_gpus * real_num_workers)

            elif self.config.num_gpus > 1 and self.config.num_workers == 1 and \
                    self.config.distribution_strategy == "MirroredStrategy":
                tf.logging.info(
                    "*****************Using MirroredStrategy*********************"
                )
                if FLAGS.usePAI:
                    from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
                    cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps(
                        'nccl')
                    self.strategy = tf.contrib.distribute.MirroredStrategy(
                        num_gpus=self.config.num_gpus,
                        cross_tower_ops=cross_tower_ops,
                        all_dense=True,
                        iter_size=self.config.num_accumulated_batches)
                else:
                    self.strategy = tf.contrib.distribute.MirroredStrategy(
                        num_gpus=self.config.num_gpus)

                global_batch_size = (self.config.train_batch_size *
                                     self.config.num_gpus *
                                     self.config.num_accumulated_batches)

            elif self.config.num_gpus >= 1 and self.config.num_workers >= 1 and \
                    self.config.distribution_strategy == "WhaleStrategy":

                if FLAGS.usePAI:
                    import pai
                    worker_hosts = self.config.worker_hosts.split(',')
                    tf.logging.info(
                        "***********Job Name is {}***********".format(
                            self.config.job_name))
                    tf.logging.info(
                        "***********Task Index is {}***********".format(
                            self.config.task_index))
                    tf.logging.info(
                        "***********Worker Hosts is {}***********".format(
                            self.config.worker_hosts))
                    pai.distribute.set_tf_config(
                        self.config.job_name,
                        self.config.task_index,
                        worker_hosts,
                        has_evaluator=self.config.pull_evaluation_in_multiworkers_training)

                tf.logging.info(
                    "*****************Using WhaleStrategy*********************"
                )
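                # Whale communication tuning: send sparse gradients as dense,
                # with 2 communicators and 8 tensor splits.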
                os.environ["WHALE_COMMUNICATION_SPARSE_AS_DENSE"] = "True"
                os.environ["WHALE_COMMUNICATION_NUM_COMMUNICATORS"] = "2"
                os.environ["WHALE_COMMUNICATION_NUM_SPLITS"] = "8"
                global_batch_size = (self.config.train_batch_size *
                                     self.config.num_accumulated_batches *
                                     self.config.num_model_replica)

            elif self.config.num_gpus == 1 and self.config.num_workers == 1:
                global_batch_size = self.config.train_batch_size * self.config.num_accumulated_batches
                tf.logging.info(
                    "***********Single worker, Single gpu, Don't use distribution strategy***********"
                )

            elif self.config.num_gpus == 0 and self.config.num_workers == 1:
                global_batch_size = self.config.train_batch_size * self.config.num_accumulated_batches
                tf.logging.info(
                    "***********Single worker, Running on CPU***********")

            else:
                raise ValueError(
                    "In train mode, please set num_workers, num_gpus and distribution_strategy correctly:\n"
                    "num_workers >= 1, num_gpus >= 1, distribution_strategy=WhaleStrategy|ExascaleStrategy|CollectiveAllReduceStrategy\n"
                    "num_workers == 1, num_gpus > 1, distribution_strategy=MirroredStrategy\n"
                    "num_workers == 1, num_gpus == 1, distribution_strategy=None")

            # num_train_examples is required; it drives the schedule computed below.
            if "num_train_examples" not in kwargs:
                raise ValueError('Please pass num_train_examples')

            self.num_train_examples = kwargs['num_train_examples']

            # If save_steps is None, save once per epoch.
            if self.config.save_steps is None:
                self.save_steps = int(self.num_train_examples /
                                      global_batch_size)
            else:
                self.save_steps = self.config.save_steps

            self.train_steps = int(
                self.num_train_examples * self.config.num_epochs /
                global_batch_size) + 1

            self.throttle_secs = self.config.throttle_secs
            self.model_dir = self.config.model_dir
            tf.logging.info("model_dir: {}".format(self.config.model_dir))
            tf.logging.info("num workers: {}".format(self.config.num_workers))
            tf.logging.info("num gpus: {}".format(self.config.num_gpus))
            tf.logging.info("learning rate: {}".format(
                self.config.learning_rate))
            tf.logging.info("train batch size: {}".format(
                self.config.train_batch_size))
            tf.logging.info("global batch size: {}".format(global_batch_size))
            tf.logging.info("num accumulated batches: {}".format(
                self.config.num_accumulated_batches))
            tf.logging.info("num model replica: {}".format(
                self.config.num_model_replica))
            tf.logging.info("num train examples per epoch: {}".format(
                self.num_train_examples))
            tf.logging.info("num epochs: {}".format(self.config.num_epochs))
            tf.logging.info("train steps: {}".format(self.train_steps))
            tf.logging.info("save steps: {}".format(self.save_steps))
            tf.logging.info("throttle secs: {}".format(self.throttle_secs))
            tf.logging.info("keep checkpoint max: {}".format(
                self.config.keep_checkpoint_max))
            tf.logging.info("warmup ratio: {}".format(
                self.config.warmup_ratio))
            tf.logging.info("gradient clip: {}".format(
                self.config.gradient_clip))
            tf.logging.info("log step count steps: {}".format(
                self.config.log_step_count_steps))

            if self.config.distribution_strategy != "WhaleStrategy":
                self.estimator = tf.estimator.Estimator(
                    model_fn=self._build_model_fn(),
                    model_dir=self.config.model_dir,
                    config=self._get_run_train_config(config=self.config))
            else:
                tf.logging.info("***********Using Whale Estimator***********")
                try:
                    from easytransfer.engines.whale_estimator import WhaleEstimator
                    import whale as wh
                    wh.init()
                    self.estimator = WhaleEstimator(
                        model_fn=self._build_model_fn(),
                        model_dir=self.config.model_dir,
                        num_model_replica=self.config.num_model_replica,
                        num_accumulated_batches=self.config.num_accumulated_batches)
                except ImportError:
                    raise NotImplementedError(
                        "WhaleStrategy requires the whale package and "
                        "easytransfer's WhaleEstimator")

            if self.config.mode in ('train_and_evaluate',
                                    'train_and_evaluate_on_the_fly'):
                self.num_eval_steps = self.config.num_eval_steps
                tf.logging.info("num eval steps: {}".format(
                    self.num_eval_steps))

        elif self.config.mode in ('evaluate', 'evaluate_on_the_fly'):
            self.num_eval_steps = self.config.num_eval_steps
            tf.logging.info("num eval steps: {}".format(self.num_eval_steps))
            tf.logging.info("***********Running in {} mode***********".format(
                self.config.mode))
            self.estimator = tf.estimator.Estimator(
                model_fn=self._build_model_fn(),
                config=self._get_run_predict_config())

        elif self.config.mode in ('predict', 'predict_on_the_fly'):
            tf.logging.info("***********Running in {} mode***********".format(
                self.config.mode))
            self.estimator = tf.estimator.Estimator(
                model_fn=self._build_model_fn(),
                config=self._get_run_predict_config())

        elif self.config.mode == 'export':
            tf.logging.info("***********Running in {} mode***********".format(
                self.config.mode))
            self.estimator = tf.estimator.Estimator(
                model_fn=self._build_model_fn(),
                config=self._get_run_predict_config())

        elif self.config.mode == 'preprocess':
            tf.logging.info("***********Running in {} mode***********".format(
                self.config.mode))
            self.estimator = tf.estimator.Estimator(
                model_fn=self._build_model_fn(),
                config=tf.estimator.RunConfig())

            self.first_sequence = self.config.first_sequence
            self.second_sequence = self.config.second_sequence
            self.label_enumerate_values = self.config.label_enumerate_values
            self.label_name = self.config.label_name
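The constructor above derives global_batch_size from a different formula in each strategy branch. As a reference, here is that arithmetic collected into one standalone helper; it merely restates the branches above under the same config fields and is illustrative, not part of the original class:

def compute_global_batch_size(config, real_num_workers):
    """Restates the per-strategy global batch size arithmetic above."""
    if config.distribution_strategy in ("ExascaleStrategy",
                                        "CollectiveAllReduceStrategy"):
        # Multi-worker all-reduce: per-GPU batch * GPUs per worker * workers.
        return config.train_batch_size * config.num_gpus * real_num_workers
    if config.distribution_strategy == "MirroredStrategy":
        # Single worker, multiple GPUs, with gradient accumulation.
        return (config.train_batch_size * config.num_gpus *
                config.num_accumulated_batches)
    if config.distribution_strategy == "WhaleStrategy":
        # Whale scales by model replicas rather than raw GPU count.
        return (config.train_batch_size * config.num_accumulated_batches *
                config.num_model_replica)
    # Single worker with one GPU, or CPU only.
    return config.train_batch_size * config.num_accumulated_batches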
Example #5
    def main():
        app = create_app()
        whale.init()
        whale.route(flask_app=app)
        app.run(debug=True)