Example #1
    def train_model(self,
                    num_epochs=300,
                    num_workers=1,
                    early_stop=False,
                    tenacity=10):
        if num_workers == 1:
            return self.train_model_local(num_epochs=num_epochs,
                                          early_stop=early_stop,
                                          tenacity=tenacity)
        else:
            from bigdl.chronos.model.tcmf.local_model_distributed_trainer import\
                train_yseq_hvd
            import ray

            # Check whether there is already an active RayContext.
            from bigdl.orca.ray import RayContext
            ray_ctx = RayContext.get()
            Ymat_id = ray.put(self.Ymat)
            covariates_id = ray.put(self.covariates)
            Ycov_id = ray.put(self.Ycov)
            trainer_config_keys = [
                "vbsize", "hbsize", "end_index", "val_len", "lr", "num_inputs",
                "num_channels", "kernel_size", "dropout"
            ]
            trainer_config = {k: self.__dict__[k] for k in trainer_config_keys}
            model, val_loss = train_yseq_hvd(epochs=num_epochs,
                                             workers_per_node=num_workers //
                                             ray_ctx.num_ray_nodes,
                                             Ymat_id=Ymat_id,
                                             covariates_id=covariates_id,
                                             Ycov_id=Ycov_id,
                                             **trainer_config)
            self.seq = model
            return val_loss
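A minimal driver sketch for the method above, assuming bigdl-orca and bigdl-chronos are installed and that `model` is an object exposing this train_model method (the names here are illustrative, not part of the original example):

from bigdl.orca import init_orca_context, stop_orca_context

# Start Ray on Spark so that RayContext.get() inside train_model finds an active context.
sc = init_orca_context(cluster_mode="local", cores=4, init_ray_on_spark=True)

# num_workers > 1 takes the distributed train_yseq_hvd path; num_workers == 1 trains locally.
val_loss = model.train_model(num_epochs=10, num_workers=2)

stop_orca_context()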
Example #2
    def test_parquet_images_training(self):
        from bigdl.orca.learn.tf2 import Estimator
        temp_dir = tempfile.mkdtemp()
        try:
            ParquetDataset.write("file://" + temp_dir, images_generator(),
                                 images_schema)
            path = "file://" + temp_dir
            output_types = {
                "id": tf.string,
                "image": tf.string,
                "label": tf.float32
            }
            output_shapes = {"id": (), "image": (), "label": ()}

            def data_creator(config, batch_size):
                dataset = read_parquet("tf_dataset",
                                       path=path,
                                       output_types=output_types,
                                       output_shapes=output_shapes)
                dataset = dataset.shuffle(10)
                dataset = dataset.map(lambda data_dict:
                                      (data_dict["image"], data_dict["label"]))
                dataset = dataset.map(parse_data_train)
                dataset = dataset.batch(batch_size)
                return dataset

            ray_ctx = RayContext.get()
            trainer = Estimator.from_keras(model_creator=model_creator)
            trainer.fit(data=data_creator, epochs=1, batch_size=2)
        finally:
            shutil.rmtree(temp_dir)
Example #3
 def get_default_num_workers():
     from bigdl.orca.ray import RayContext
     try:
         ray_ctx = RayContext.get(initialize=False)
         num_workers = ray_ctx.num_ray_nodes
     except:
         num_workers = 1
     return num_workers
Example #4
    def predict(self, x=None, horizon=24, mc=False,
                future_covariates=None,
                future_dti=None,
                num_workers=None):
        """
        Predict `horizon` time points ahead of the input x used in fit_eval.
        :param x: input x is not supported currently.
        :param horizon: horizon length to predict.
        :param mc:
        :param future_covariates: covariates corresponding to the future horizon steps to predict.
        :param future_dti: dti corresponding to the future horizon steps to predict.
        :param num_workers: the number of workers to use. Note that there has to be an active
            RayContext if num_workers > 1.
        :return:
        """
        if x is not None:
            raise ValueError("We don't support input x directly.")
        if self.model is None:
            raise Exception("Needs to call fit_eval or restore first before calling predict")
        self._check_covariates_dti(covariates=future_covariates, dti=future_dti, ts_len=horizon,
                                   method_name="predict")
        if num_workers is None:
            num_workers = TCMF.get_default_num_workers()
        if num_workers > 1:
            import ray
            from bigdl.orca.ray import RayContext
            try:
                RayContext.get(initialize=False)
            except:
                try:
                    # detect whether ray has been started.
                    ray.put(None)
                except:
                    raise RuntimeError(f"There must be an activate ray context while running with "
                                       f"{num_workers} workers. You can either start and init a "
                                       f"RayContext by init_orca_context(..., init_ray_on_spark="
                                       f"True) or start Ray with ray.init()")

        out = self.model.predict_horizon(
            future=horizon,
            bsize=90,
            num_workers=num_workers,
            future_covariates=future_covariates,
            future_dti=future_dti,
        )
        return out[:, -horizon::]
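An illustrative call sketch, assuming `tcmf` is a fitted instance of the model class above (the variable name is hypothetical):

import ray

# Either ray.init() or init_orca_context(..., init_ray_on_spark=True) provides the
# active ray context required for num_workers > 1.
ray.init()
forecast = tcmf.predict(horizon=24, num_workers=2)  # expected shape: (num_series, 24)
ray.shutdown()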
Example #5
    def fit(
        self,
        input_df,
        validation_df=None,
        metric="mse",
        recipe=SmokeRecipe(),
        mc=False,
        resources_per_trial={"cpu": 2},
        upload_dir=None,
    ):
        """
        Trains the model for time sequence prediction.
        If future sequence length > 1, use seq2seq model, else use vanilla LSTM model.
        :param input_df: The input time series data frame, Example:
         datetime   value   "extra feature 1"   "extra feature 2"
         2019-01-01 1.9 1   2
         2019-01-02 2.3 0   2
        :param validation_df: validation data
        :param metric: String. Metric used for training and validation. Available values are
                       "mean_squared_error" or "r_square".
        :param recipe: a Recipe object. Various recipes cover different search spaces and stopping
                      criteria. Default is SmokeRecipe().
        :param resources_per_trial: Machine resources to allocate per trial,
            e.g. ``{"cpu": 64, "gpu": 8}``
        :param upload_dir: Optional URI to sync training results and checkpoints. Only hdfs
            URIs are supported for now. It defaults to
            "hdfs:///user/{hadoop_user_name}/ray_checkpoints/{predictor_name}",
            where hadoop_user_name is specified in init_orca_context or init_spark_on_yarn
            and defaults to "root", and predictor_name is the name used in predictor instantiation.
        :return: a pipeline constructed with the best model and configs.
        """
        self._check_df(input_df)
        if validation_df is not None:
            self._check_df(validation_df)

        ray_ctx = RayContext.get()
        is_local = ray_ctx.is_local
        # BasePredictor._check_fit_metric(metric)
        if not is_local:
            if not upload_dir:
                hadoop_user_name = os.getenv("HADOOP_USER_NAME")
                upload_dir = os.path.join(os.sep, "user", hadoop_user_name,
                                          "ray_checkpoints", self.name)
            cmd = "hadoop fs -mkdir -p {}".format(upload_dir)
            process(cmd)
        else:
            upload_dir = None

        self.pipeline = self._hp_search(
            input_df,
            validation_df=validation_df,
            metric=metric,
            recipe=recipe,
            mc=mc,
            resources_per_trial=resources_per_trial,
            remote_dir=upload_dir)
        return self.pipeline
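A minimal usage sketch, assuming `predictor` exposes the fit method above and the input is a pandas DataFrame with a datetime column and a value column (the names are illustrative):

import pandas as pd

input_df = pd.DataFrame({
    "datetime": pd.date_range("2019-01-01", periods=100, freq="D"),
    "value": range(100),
})
# Returns a pipeline constructed with the best model and configs found by the search.
pipeline = predictor.fit(input_df, resources_per_trial={"cpu": 2})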
Example #6
 def get_default_remote_dir(name):
     from bigdl.orca.ray import RayContext
     from bigdl.orca.automl.search.utils import process
     ray_ctx = RayContext.get()
     if ray_ctx.is_local:
         return None
     else:
         default_remote_dir = f"hdfs:///tmp/{name}"
         process(command=f"hadoop fs -mkdir -p {default_remote_dir}")
         return default_remote_dir
Example #7
    def test_horovod_learning_rate_schedule(self):
        import horovod
        major, minor, patch = horovod.__version__.split(".")

        larger_major = int(major) > 0
        larger_minor = int(major) == 0 and int(minor) > 19
        larger_patch = int(major) == 0 and int(minor) == 19 and int(patch) >= 2

        if larger_major or larger_minor or larger_patch:
            ray_ctx = RayContext.get()
            batch_size = 32
            workers_per_node = 4
            global_batch_size = batch_size * workers_per_node
            config = {"lr": 0.8}
            trainer = Estimator.from_keras(model_creator=simple_model,
                                           compile_args_creator=compile_args,
                                           verbose=True,
                                           config=config,
                                           backend="horovod",
                                           workers_per_node=workers_per_node)
            import horovod.tensorflow.keras as hvd
            callbacks = [
                hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=5,
                                                         initial_lr=0.4,
                                                         verbose=True),
                hvd.callbacks.LearningRateScheduleCallback(start_epoch=5,
                                                           end_epoch=10,
                                                           multiplier=1.,
                                                           initial_lr=0.4),
                hvd.callbacks.LearningRateScheduleCallback(start_epoch=10,
                                                           end_epoch=15,
                                                           multiplier=1e-1,
                                                           initial_lr=0.4),
                hvd.callbacks.LearningRateScheduleCallback(start_epoch=15,
                                                           end_epoch=20,
                                                           multiplier=1e-2,
                                                           initial_lr=0.4),
                hvd.callbacks.LearningRateScheduleCallback(start_epoch=20,
                                                           multiplier=1e-3,
                                                           initial_lr=0.4),
                LRChecker()
            ]
            for i in range(30):
                trainer.fit(create_train_datasets,
                            epochs=1,
                            batch_size=global_batch_size,
                            callbacks=callbacks)
        else:
            # skip the test for lower Horovod versions
            pass
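The version guard above can be expressed more compactly with packaging.version, which this codebase already uses elsewhere (a sketch, not part of the original test):

from packaging import version
import horovod

if version.parse(horovod.__version__) >= version.parse("0.19.2"):
    pass  # run the learning-rate schedule test
else:
    pass  # skip for lower Horovod versions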
Example #8
 def test_gluon(self):
     current_ray_ctx = RayContext.get()
     address_info = current_ray_ctx.address_info
     assert "object_store_address" in address_info
     config = create_config(log_interval=2, optimizer="adam",
                            optimizer_params={'learning_rate': 0.02})
     estimator = Estimator.from_mxnet(config=config,
                                      model_creator=get_model,
                                      loss_creator=get_loss,
                                      eval_metrics_creator=get_metrics,
                                      validation_metrics_creator=get_metrics,
                                      num_workers=2)
     estimator.fit(get_train_data_iter, validation_data=get_test_data_iter, epochs=2)
     estimator.shutdown()
Example #9
def stop_orca_context():
    """
    Stop the SparkContext (and stop Ray services across the cluster if necessary).
    """
    from pyspark import SparkContext
    # If users successfully call stop_orca_context after the program finishes,
    # namely when there is no active SparkContext, the registered exit function
    # should do nothing.
    if SparkContext._active_spark_context is not None:
        print("Stopping orca context")
        from bigdl.orca.ray import RayContext
        ray_ctx = RayContext.get(initialize=False)
        if ray_ctx.initialized:
            ray_ctx.stop()
        sc = SparkContext.getOrCreate()
        if sc.getConf().get("spark.master").startswith("spark://"):
            from bigdl.dllib.nncontext import stop_spark_standalone
            stop_spark_standalone()
        sc.stop()
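A typical lifecycle sketch around stop_orca_context, assuming bigdl-orca is installed:

from bigdl.orca import init_orca_context, stop_orca_context

sc = init_orca_context(cluster_mode="local", cores=2, init_ray_on_spark=True)
try:
    pass  # run Spark / Ray workloads here
finally:
    # Stops Ray if it was started, then stops the SparkContext, as implemented above.
    stop_orca_context()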
Example #10
    def to_spark_xshards(self):
        from bigdl.orca.data import SparkXShards
        ray_ctx = RayContext.get()
        sc = ray_ctx.sc
        address = ray_ctx.redis_address
        password = ray_ctx.redis_password
        num_parts = self.num_partitions()
        partition2store = self.partition2store_name
        rdd = self.rdd.mapPartitionsWithIndex(lambda idx, _: get_from_ray(
            idx, address, password, partition2store))

        # The reason why we trigger computation here is to ensure we get the data
        # from Ray before the RayXShards goes out of scope and the data gets garbage collected.
        from pyspark.storagelevel import StorageLevel
        rdd = rdd.cache()
        result_rdd = rdd.map(
            lambda x: x)  # SparkXShards will uncache the rdd when garbage collected
        spark_xshards = SparkXShards(result_rdd)
        return spark_xshards
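A usage sketch, assuming `ray_xshards` is an existing RayXShards instance (illustrative only):

# Pull the Ray partitions back into a SparkXShards; the cached rdd keeps the data
# alive even after the RayXShards goes out of scope.
spark_xshards = ray_xshards.to_spark_xshards()
data = spark_xshards.collect()  # materialize the partitions on the driver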
Example #11
    def _from_spark_xshards_ray_api(spark_xshards):
        ray_ctx = RayContext.get()
        address = ray_ctx.redis_address
        password = ray_ctx.redis_password
        driver_ip = ray._private.services.get_node_ip_address()
        uuid_str = str(uuid.uuid4())
        resources = ray.cluster_resources()
        nodes = []
        for key, value in resources.items():
            if key.startswith("node:"):
                # if running in cluster, filter out driver ip
                if key != f"node:{driver_ip}":
                    nodes.append(key)
        # for the case of local mode and single node spark standalone
        if not nodes:
            nodes.append(f"node:{driver_ip}")

        partition_stores = {}
        for node in nodes:
            name = f"partition:{uuid_str}:{node}"
            if version.parse(ray.__version__) >= version.parse("1.4.0"):
                store = ray.remote(num_cpus=0, resources={node: 1e-4})(LocalStore)\
                    .options(name=name, lifetime="detached").remote()
            else:
                store = ray.remote(num_cpus=0, resources={node: 1e-4})(LocalStore) \
                    .options(name=name).remote()
            partition_stores[name] = store

        # Actor creation is async; this is to make sure they have all been started.
        ray.get([v.get_partitions.remote() for v in partition_stores.values()])
        partition_store_names = list(partition_stores.keys())
        result_rdd = spark_xshards.rdd.mapPartitionsWithIndex(
            lambda idx, part: write_to_ray(idx, part, address, password,
                                           partition_store_names)).cache()
        result = result_rdd.collect()

        id2ip = {}
        id2store_name = {}
        for idx, ip, local_store_name in result:
            id2ip[idx] = ip
            id2store_name[idx] = local_store_name

        return RayXShards(uuid_str, result_rdd, partition_stores)
Example #12
    def test_auto_shard_tf(self):
        # file 1 contains all 0s, file 2 contains all 1s
        # If shard by files, then each model will
        # see the same records in the same batch.
        # If shard by records, then each batch
        # will have different records.
        # The loss func is constructed such that
        # the former case will return 0, and the latter
        # case will return non-zero.

        ray_ctx = RayContext.get()
        trainer = Estimator.from_keras(model_creator=auto_shard_model_creator,
                                       verbose=True,
                                       backend="tf2",
                                       workers_per_node=2)
        stats = trainer.fit(create_auto_shard_datasets,
                            epochs=1,
                            batch_size=4,
                            steps_per_epoch=2)
        assert stats["train_loss"] == 0.0
Example #13
 def from_partition_refs(ip2part_id, part_id2ref, old_rdd):
     ray_ctx = RayContext.get()
     uuid_str = str(uuid.uuid4())
     id2store_name = {}
     partition_stores = {}
     part_id2ip = {}
     result = []
     for node, part_ids in ip2part_id.items():
         name = f"partition:{uuid_str}:{node}"
         store = ray.remote(num_cpus=0, resources={f"node:{node}": 1e-4})(LocalStore) \
             .options(name=name).remote()
         partition_stores[name] = store
         for idx in part_ids:
             result.append(
                 store.upload_partition.remote(idx, part_id2ref[idx]))
             id2store_name[idx] = name
             part_id2ip[idx] = node
     ray.get(result)
     new_id_ip_store_rdd = old_rdd.mapPartitionsWithIndex(lambda idx, _: [(
         idx, part_id2ip[idx], id2store_name[idx])]).cache()
     return RayXShards(uuid_str, new_id_ip_store_rdd, partition_stores)
Example #14
    def __init__(self,
                 model_creator,
                 compile_args_creator=None,
                 config=None,
                 verbose=False,
                 backend="tf2",
                 workers_per_node=1,
                 cpu_binding=False):
        self.model_creator = model_creator
        self.compile_args_creator = compile_args_creator
        self.config = {} if config is None else config
        self.verbose = verbose

        ray_ctx = RayContext.get()
        if "batch_size" in self.config:
            raise Exception(
                "Please do not specify batch_size in config. Input batch_size in the"
                " fit/evaluate function of the estimator instead.")

        if "inter_op_parallelism" not in self.config:
            self.config["inter_op_parallelism"] = 1

        if "intra_op_parallelism" not in self.config:
            self.config[
                "intra_op_parallelism"] = ray_ctx.ray_node_cpu_cores // workers_per_node

        if backend == "horovod":
            assert compile_args_creator is not None, "compile_args_creator should not be None," \
                                                     " when backend is set to horovod"

        params = {
            "model_creator": model_creator,
            "compile_args_creator": compile_args_creator,
            "config": self.config,
            "verbose": self.verbose,
        }

        if backend == "tf2":
            cores_per_node = ray_ctx.ray_node_cpu_cores // workers_per_node
            num_nodes = ray_ctx.num_ray_nodes * workers_per_node

            self.cluster = RayDLCluster(num_workers=num_nodes,
                                        worker_cores=cores_per_node,
                                        worker_cls=TFRunner,
                                        worker_param=params,
                                        cpu_binding=cpu_binding)
            self.remote_workers = self.cluster.get_workers()
            ips = ray.get([
                worker.get_node_ip.remote() for worker in self.remote_workers
            ])
            ports = ray.get([
                worker.find_free_port.remote()
                for worker in self.remote_workers
            ])

            urls = [
                "{ip}:{port}".format(ip=ips[i], port=ports[i])
                for i in range(len(self.remote_workers))
            ]
            ray.get([worker.setup.remote() for worker in self.remote_workers])
            # Get setup tasks in order to throw errors on failure
            ray.get([
                worker.setup_distributed.remote(urls, i,
                                                len(self.remote_workers))
                for i, worker in enumerate(self.remote_workers)
            ])
        elif backend == "horovod":
            # It is necessary to call horovod_runner.run first to set up the Horovod environment.
            from bigdl.orca.learn.horovod.horovod_ray_runner import HorovodRayRunner
            horovod_runner = HorovodRayRunner(
                ray_ctx,
                worker_cls=TFRunner,
                worker_param=params,
                workers_per_node=workers_per_node)
            horovod_runner.run(lambda: print("worker initialized"))
            self.remote_workers = horovod_runner.remote_workers
            ray.get([worker.setup.remote() for worker in self.remote_workers])
            ray.get([
                worker.setup_horovod.remote()
                for i, worker in enumerate(self.remote_workers)
            ])
        else:
            raise Exception("Only \"tf2\" and \"horovod\" are legal "
                            "values of backend, but got {}".format(backend))

        self.num_workers = len(self.remote_workers)
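A construction sketch, assuming an active RayContext (for example via init_orca_context) and the public Estimator.from_keras factory that forwards these arguments; the model_creator below is illustrative:

import tensorflow as tf
from bigdl.orca.learn.tf2 import Estimator

def model_creator(config):
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(10,))])
    model.compile(optimizer="adam", loss="mse")
    return model

est = Estimator.from_keras(model_creator=model_creator,
                           backend="tf2",
                           workers_per_node=2)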
Example #15
    def __init__(self,
                 *,
                 model_creator,
                 optimizer_creator,
                 loss_creator=None,
                 metrics=None,
                 scheduler_creator=None,
                 training_operator_cls=TrainingOperator,
                 initialization_hook=None,
                 config=None,
                 scheduler_step_freq="batch",
                 use_tqdm=False,
                 backend="torch_distributed",
                 workers_per_node=1,
                 sync_stats=True,
                 log_level=logging.INFO):
        if config is not None and "batch_size" in config:
            raise Exception(
                "Please do not specify batch_size in config. Input batch_size in the"
                " fit/evaluate/predict function of the estimator instead.")

        # todo remove ray_ctx to run on workers
        ray_ctx = RayContext.get()
        if not (isinstance(model_creator, types.FunctionType) and isinstance(
                optimizer_creator,
                types.FunctionType)):  # Torch model is also callable.
            raise ValueError(
                "Must provide a function for both model_creator and optimizer_creator"
            )

        self.model_creator = model_creator
        self.optimizer_creator = optimizer_creator
        self.loss_creator = loss_creator
        self.scheduler_creator = scheduler_creator
        self.training_operator_cls = training_operator_cls
        self.scheduler_step_freq = scheduler_step_freq
        self.use_tqdm = use_tqdm
        self.sync_stats = sync_stats

        if not training_operator_cls and not loss_creator:
            raise ValueError("If a loss_creator is not provided, you must "
                             "provide a custom training operator.")

        self.initialization_hook = initialization_hook
        self.config = {} if config is None else config
        worker_config = self.config.copy()
        params = dict(model_creator=self.model_creator,
                      optimizer_creator=self.optimizer_creator,
                      loss_creator=self.loss_creator,
                      scheduler_creator=self.scheduler_creator,
                      training_operator_cls=self.training_operator_cls,
                      scheduler_step_freq=self.scheduler_step_freq,
                      use_tqdm=self.use_tqdm,
                      config=worker_config,
                      metrics=metrics,
                      sync_stats=sync_stats,
                      log_level=log_level)

        if backend == "torch_distributed":
            cores_per_node = ray_ctx.ray_node_cpu_cores // workers_per_node
            num_nodes = ray_ctx.num_ray_nodes * workers_per_node
            RemoteRunner = ray.remote(
                num_cpus=cores_per_node)(PytorchRayWorker)
            self.remote_workers = [
                RemoteRunner.remote(**params) for i in range(num_nodes)
            ]
            ray.get([
                worker.setup.remote(cores_per_node)
                for i, worker in enumerate(self.remote_workers)
            ])

            head_worker = self.remote_workers[0]
            address = ray.get(head_worker.setup_address.remote())

            logger.info(f"initializing pytorch process group on {address}")

            ray.get([
                worker.setup_torch_distribute.remote(address, i, num_nodes)
                for i, worker in enumerate(self.remote_workers)
            ])

        elif backend == "horovod":
            from bigdl.orca.learn.horovod.horovod_ray_runner import HorovodRayRunner
            self.horovod_runner = HorovodRayRunner(
                ray_ctx,
                worker_cls=PytorchRayWorker,
                worker_param=params,
                workers_per_node=workers_per_node)
            self.remote_workers = self.horovod_runner.remote_workers
            cores_per_node = self.horovod_runner.cores_per_node
            ray.get([
                worker.setup.remote(cores_per_node)
                for i, worker in enumerate(self.remote_workers)
            ])

            ray.get([
                worker.setup_horovod.remote()
                for i, worker in enumerate(self.remote_workers)
            ])
        else:
            raise Exception(
                "Only \"torch_distributed\" and \"horovod\" are supported "
                "values of backend, but got {}".format(backend))
        self.num_workers = len(self.remote_workers)
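A construction sketch for the PyTorch path, assuming an active RayContext and that the public Estimator.from_torch factory in bigdl.orca.learn.pytorch forwards these arguments; the creator functions are illustrative:

import torch
import torch.nn as nn
from bigdl.orca.learn.pytorch import Estimator

def model_creator(config):
    return nn.Linear(10, 1)

def optimizer_creator(model, config):
    return torch.optim.Adam(model.parameters(), lr=config.get("lr", 1e-3))

def loss_creator(config):
    return nn.MSELoss()

est = Estimator.from_torch(model=model_creator,
                           optimizer=optimizer_creator,
                           loss=loss_creator,
                           backend="torch_distributed",
                           workers_per_node=2)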
Example #16
    def __init__(self,
                 config,
                 model_creator,
                 loss_creator=None,
                 eval_metrics_creator=None,
                 validation_metrics_creator=None,
                 num_workers=None,
                 num_servers=None,
                 runner_cores=None):
        ray_ctx = RayContext.get()
        if not num_workers:
            num_workers = ray_ctx.num_ray_nodes
        self.config = {} if config is None else config
        assert isinstance(config, dict), "config must be a dict"
        for param in ["optimizer", "optimizer_params", "log_interval"]:
            assert param in config, param + " must be specified in config"
        self.model_creator = model_creator
        self.loss_creator = loss_creator
        self.validation_metrics_creator = validation_metrics_creator
        self.eval_metrics_creator = eval_metrics_creator
        self.num_workers = num_workers
        self.num_servers = num_servers if num_servers else self.num_workers

        # Generate actor class
        # Add dummy custom resources _mxnet_worker and _mxnet_server to differentiate workers
        # from servers if runner_cores is specified, so that we can place one worker and one
        # server on each node for better performance.
        Worker = ray.remote(num_cpus=runner_cores, resources={"_mxnet_worker": 1})(MXNetRunner) \
            if runner_cores else ray.remote(MXNetRunner)
        Server = ray.remote(num_cpus=runner_cores, resources={"_mxnet_server": 1})(MXNetRunner) \
            if runner_cores else ray.remote(MXNetRunner)

        # Start runners: workers followed by servers
        self.workers = [Worker.remote() for i in range(self.num_workers)]
        self.servers = [Server.remote() for i in range(self.num_servers)]
        self.runners = self.workers + self.servers

        env = {
            "DMLC_PS_ROOT_URI": str(get_host_ip()),
            "DMLC_PS_ROOT_PORT": str(find_free_port()),
            "DMLC_NUM_SERVER": str(self.num_servers),
            "DMLC_NUM_WORKER": str(self.num_workers),
        }
        envs = []
        for i in range(self.num_workers):
            current_env = env.copy()
            current_env['DMLC_ROLE'] = 'worker'
            envs.append(current_env)
        for i in range(self.num_servers):
            current_env = env.copy()
            current_env['DMLC_ROLE'] = 'server'
            envs.append(current_env)

        env['DMLC_ROLE'] = 'scheduler'
        modified_env = os.environ.copy()
        modified_env.update(env)
        # Need to include the system env to run bash
        # TODO: Need to kill this process manually?
        subprocess.Popen("python -c 'import mxnet'",
                         shell=True,
                         env=modified_env)

        ray.get([
            runner.setup_distributed.remote(envs[i], self.config,
                                            self.model_creator,
                                            self.loss_creator,
                                            self.validation_metrics_creator,
                                            self.eval_metrics_creator)
            for i, runner in enumerate(self.runners)
        ])
Example #17
 def get_ray_context():
     from bigdl.orca.ray import RayContext
     return RayContext.get()
Example #18
    def impl_test_fit_and_evaluate(self, backend):
        import tensorflow as tf
        ray_ctx = RayContext.get()
        batch_size = 32
        global_batch_size = batch_size * ray_ctx.num_ray_nodes

        if backend == "horovod":
            trainer = Estimator.from_keras(model_creator=simple_model,
                                           compile_args_creator=compile_args,
                                           verbose=True,
                                           config=None,
                                           backend=backend)
        else:

            trainer = Estimator.from_keras(model_creator=model_creator,
                                           verbose=True,
                                           config=None,
                                           backend=backend,
                                           workers_per_node=2)

        # model baseline performance
        start_stats = trainer.evaluate(create_test_dataset,
                                       batch_size=global_batch_size,
                                       num_steps=NUM_TEST_SAMPLES //
                                       global_batch_size)
        print(start_stats)

        def scheduler(epoch):
            if epoch < 2:
                return 0.001
            else:
                return 0.001 * tf.math.exp(0.1 * (2 - epoch))

        scheduler = tf.keras.callbacks.LearningRateScheduler(scheduler,
                                                             verbose=1)
        # train for 2 epochs
        trainer.fit(create_train_datasets,
                    epochs=2,
                    batch_size=global_batch_size,
                    steps_per_epoch=10,
                    callbacks=[scheduler])
        trainer.fit(create_train_datasets,
                    epochs=2,
                    batch_size=global_batch_size,
                    steps_per_epoch=10,
                    callbacks=[scheduler])

        # model performance after training (should improve)
        end_stats = trainer.evaluate(create_test_dataset,
                                     batch_size=global_batch_size,
                                     num_steps=NUM_TEST_SAMPLES //
                                     global_batch_size)
        print(end_stats)

        # sanity check that training worked
        dloss = end_stats["validation_loss"] - start_stats["validation_loss"]
        dmse = (end_stats["validation_mean_squared_error"] -
                start_stats["validation_mean_squared_error"])
        print(f"dLoss: {dloss}, dMSE: {dmse}")

        assert dloss < 0 and dmse < 0, "training sanity check failed. loss increased!"
Example #19
parser.add_argument("--slave_num",
                    type=int,
                    default=2,
                    help="The number of slave nodes to be used in the cluster."
                    "You can change it depending on your own cluster setting.")
parser.add_argument(
    "--cores",
    type=int,
    default=8,
    help="The number of cpu cores you want to use on each node. "
    "You can change it depending on your own cluster setting.")
parser.add_argument(
    "--memory",
    type=str,
    default="10g",
    help="The size of slave(executor)'s memory you want to use."
    "You can change it depending on your own cluster setting.")

if __name__ == "__main__":

    args = parser.parse_args()
    num_nodes = 1 if args.cluster_mode == "local" else args.slave_num
    init_orca_context(cluster_mode=args.cluster_mode,
                      cores=args.cores,
                      num_nodes=num_nodes,
                      memory=args.memory)

    runner = HorovodRayRunner(RayContext.get())
    runner.run(func=run_horovod)
    stop_orca_context()
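The run_horovod function referenced above is not shown; the following is a hypothetical sketch of what a function passed to HorovodRayRunner.run might look like (each worker initializes Horovod itself):

def run_horovod():
    import horovod.tensorflow.keras as hvd
    hvd.init()
    # Every Ray worker executes this function through HorovodRayRunner.run.
    print(f"Horovod worker {hvd.rank()} of {hvd.size()} initialized")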