Example no. 1
0
    def fit(self,
            train_ds: MLDataset,
            evaluate_ds: Optional[MLDataset] = None,
            num_steps=None,
            profile=False,
            reduce_results=True,
            max_retries=3,
            info=None) -> None:
        """Train the model for ``self._num_epochs`` epochs on ``train_ds``.

        Args:
            train_ds: dataset to train on; batched with ``self._batch_size``
                before being converted to a TF dataset.
            evaluate_ds: optional dataset for validation after training.
            num_steps, profile, reduce_results, max_retries, info: passed
                through to ``self._trainer.train`` each epoch (and, except
                ``max_retries``, to ``validate``).

        Per-epoch training stats and the final validation result (if any)
        are printed to stdout.
        """
        super().fit(train_ds, evaluate_ds)
        train_ds = train_ds.batch(self._batch_size)
        train_tf_ds = self._create_tf_ds(train_ds)

        if evaluate_ds is not None:
            evaluate_ds = evaluate_ds.batch(self._batch_size)
            evaluate_tf_ds = self._create_tf_ds(evaluate_ds)
        else:
            evaluate_tf_ds = None

        # _create_trainer is expected to set self._trainer as a side effect.
        self._create_trainer(train_tf_ds, evaluate_tf_ds)
        assert self._trainer is not None
        for i in range(self._num_epochs):
            stats = self._trainer.train(
                num_steps=num_steps,
                profile=profile,
                reduce_results=reduce_results,
                max_retries=max_retries,
                info=info)
            print(f"Epoch-{i}: {stats}")

        if evaluate_tf_ds is not None:
            print(self._trainer.validate(num_steps, profile, reduce_results, info))
Example no. 2
0
    def fit(self,
            train_ds: MLDataset,
            evaluate_ds: Optional[MLDataset] = None) -> None:
        """Train a Keras model on ``train_ds`` with a distributed TFTrainer.

        The model/optimizer/loss/metrics are reconstructed on each worker
        from their serialized forms held on ``self``. Trains for
        ``self._num_epochs`` epochs, printing per-epoch stats, then prints
        validation results if ``evaluate_ds`` was given.

        Args:
            train_ds: dataset to train on; batched with ``self._batch_size``.
            evaluate_ds: optional dataset for validation after training.
        """

        def model_creator(config):
            """Rebuild and compile the Keras model on a worker process."""
            # Import inside the worker to avoid serialization issues:
            # https://github.com/ray-project/ray/issues/5914
            import tensorflow.keras as keras  # pylint: disable=C0415, W0404

            model: keras.Model = keras.models.model_from_json(
                self._serialized_model)
            optimizer = keras.optimizers.get(self._serialized_optimizer)
            loss = keras.losses.get(self._serialized_loss)
            metrics = [keras.metrics.get(m) for m in self._serialized_metrics]
            model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
            return model

        train_ds = train_ds.batch(self._batch_size)
        train_tf_ds = self._create_tf_ds(train_ds)

        if evaluate_ds is not None:
            evaluate_ds = evaluate_ds.batch(self._batch_size)
            evaluate_tf_ds = self._create_tf_ds(evaluate_ds)
        else:
            evaluate_tf_ds = None

        def data_creator(config):
            """Return (train, evaluate) tf.data pipelines for one worker."""
            # Worker rank comes from TF_CONFIG when running distributed;
            # -1 signals single-process / local mode to get_shard.
            if "TF_CONFIG" in os.environ:
                tf_config = json.loads(os.environ["TF_CONFIG"])
                world_rank = tf_config["task"]["index"]
            else:
                world_rank = -1
            batch_size = config["batch_size"]
            get_shard_config = config.get("get_shard", {})
            if "shuffle" in config:
                get_shard_config["shuffle"] = config["shuffle"]
            train_data = train_tf_ds.get_shard(
                world_rank, **get_shard_config).repeat().batch(batch_size)
            # Sharding is already done via get_shard; disable TF's own
            # auto-sharding so the data is not split a second time.
            options = tf.data.Options()
            options.experimental_distribute.auto_shard_policy = \
                tf.data.experimental.AutoShardPolicy.OFF
            train_data = train_data.with_options(options)
            evaluate_data = None
            if evaluate_tf_ds is not None:
                evaluate_data = evaluate_tf_ds.get_shard(
                    world_rank, **get_shard_config).batch(batch_size)
                evaluate_data = evaluate_data.with_options(options)
            return train_data, evaluate_data

        self._trainer = TFTrainer(model_creator=model_creator,
                                  data_creator=data_creator,
                                  num_replicas=self._num_workers,
                                  **self._extra_config)
        for i in range(self._num_epochs):
            stats = self._trainer.train()
            print(f"Epoch-{i}: {stats}")

        if evaluate_tf_ds is not None:
            print(self._trainer.validate())