def fit(self,
        train_ds: MLDataset,
        evaluate_ds: Optional[MLDataset] = None,
        num_steps=None,
        profile=False,
        reduce_results=True,
        max_retries=3,
        info=None) -> None:
    """Train the model on ``train_ds`` for ``self._num_epochs`` epochs.

    Converts the input MLDatasets into TF datasets, builds the trainer via
    ``self._create_trainer``, then runs one ``train()`` call per epoch and,
    if an evaluation dataset was given, a final ``validate()`` pass.

    Args:
        train_ds: training data; batched with ``self._batch_size``.
        evaluate_ds: optional evaluation data; batched the same way.
        num_steps: forwarded to ``trainer.train``/``trainer.validate``
            (presumably steps per epoch; None means the trainer's default).
        profile: forwarded to the trainer; enables its profiling output.
        reduce_results: forwarded to the trainer; whether per-worker
            results are reduced into a single stats dict.
        max_retries: forwarded to ``trainer.train`` (worker-failure retries).
        info: opaque dict forwarded to the trainer hooks.

    Returns:
        None. (Fixed: the previous ``-> NoReturn`` annotation was incorrect —
        ``NoReturn`` is reserved for functions that never return normally.)
    """
    super().fit(train_ds, evaluate_ds)

    # Batch first, then convert to the TF-compatible dataset wrapper.
    train_ds = train_ds.batch(self._batch_size)
    train_tf_ds = self._create_tf_ds(train_ds)

    if evaluate_ds is not None:
        evaluate_ds = evaluate_ds.batch(self._batch_size)
        evaluate_tf_ds = self._create_tf_ds(evaluate_ds)
    else:
        evaluate_tf_ds = None

    # _create_trainer is expected to populate self._trainer as a side effect.
    self._create_trainer(train_tf_ds, evaluate_tf_ds)
    assert self._trainer is not None

    for i in range(self._num_epochs):
        stats = self._trainer.train(
            num_steps=num_steps,
            profile=profile,
            reduce_results=reduce_results,
            max_retries=max_retries,
            info=info)
        print(f"Epoch-{i}: {stats}")

    # One validation pass after all epochs, mirroring the train() arguments
    # (validate takes no max_retries).
    if evaluate_tf_ds is not None:
        print(self._trainer.validate(num_steps, profile, reduce_results, info))
def fit(self,
        train_ds: MLDataset,
        evaluate_ds: Optional[MLDataset] = None) -> None:
    """Train for ``self._num_epochs`` epochs using a Ray ``TFTrainer``.

    Builds two closures — ``model_creator`` (deserializes and compiles the
    Keras model on each worker) and ``data_creator`` (produces the per-worker
    shard of the train/eval tf.data pipelines) — and hands them to
    ``TFTrainer``, then runs one ``train()`` call per epoch and an optional
    final ``validate()``.

    Args:
        train_ds: training data; batched with ``self._batch_size``.
        evaluate_ds: optional evaluation data; batched the same way.

    Returns:
        None. (Fixed: the previous ``-> NoReturn`` annotation was incorrect —
        ``NoReturn`` is reserved for functions that never return normally.)
    """
    super().fit(train_ds, evaluate_ds)

    def model_creator(config):
        """Rebuild and compile the Keras model on a worker process."""
        # Import inside the worker; see
        # https://github.com/ray-project/ray/issues/5914
        import tensorflow.keras as keras  # pylint: disable=C0415, W0404
        model: keras.Model = keras.models.model_from_json(
            self._serialized_model)
        optimizer = keras.optimizers.get(self._serialized_optimizer)
        loss = keras.losses.get(self._serialized_loss)
        metrics = [keras.metrics.get(m) for m in self._serialized_metrics]
        model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
        return model

    train_ds = train_ds.batch(self._batch_size)
    train_tf_ds = self._create_tf_ds(train_ds)
    if evaluate_ds is not None:
        evaluate_ds = evaluate_ds.batch(self._batch_size)
        evaluate_tf_ds = self._create_tf_ds(evaluate_ds)
    else:
        evaluate_tf_ds = None

    def data_creator(config):
        """Return (train_data, evaluate_data) for this worker's shard."""
        # Under MultiWorkerMirroredStrategy, TF exposes the worker index
        # via the TF_CONFIG env var; -1 means "not distributed".
        if "TF_CONFIG" in os.environ:
            tf_config = json.loads(os.environ["TF_CONFIG"])
            world_rank = tf_config["task"]["index"]
        else:
            world_rank = -1
        batch_size = config["batch_size"]
        # Copy before mutating: the original wrote "shuffle" straight into
        # the dict held by config, mutating caller-owned state in place.
        get_shard_config = dict(config.get("get_shard", {}))
        if "shuffle" in config:
            get_shard_config["shuffle"] = config["shuffle"]
        train_data = train_tf_ds.get_shard(
            world_rank, **get_shard_config).repeat().batch(batch_size)
        # Sharding is already done via get_shard; stop TF from auto-sharding
        # the dataset a second time.
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = \
            tf.data.experimental.AutoShardPolicy.OFF
        train_data = train_data.with_options(options)
        evaluate_data = None
        if evaluate_tf_ds is not None:
            # Note: evaluation data is not repeat()-ed — a single pass.
            evaluate_data = evaluate_tf_ds.get_shard(
                world_rank, **get_shard_config).batch(batch_size)
            evaluate_data = evaluate_data.with_options(options)
        return train_data, evaluate_data

    self._trainer = TFTrainer(model_creator=model_creator,
                              data_creator=data_creator,
                              num_replicas=self._num_workers,
                              **self._extra_config)
    for i in range(self._num_epochs):
        stats = self._trainer.train()
        print(f"Epoch-{i}: {stats}")

    if evaluate_tf_ds is not None:
        print(self._trainer.validate())