def test_sparkxshards(self):
    """Smoke-test Estimator.fit/evaluate on an XShards of random data.

    Builds a 100-sample shard of random features with binary labels,
    trains for one epoch, then evaluates on the same shard.
    """
    train_data_shard = XShards.partition({
        "x": np.random.randn(100, 1),
        # randint's upper bound is exclusive: (0, 2) yields a genuine mix
        # of 0/1 labels; the original (0, 1) produced all zeros.
        "y": np.random.randint(0, 2, size=(100,)),
    })
    config = {"batch_size": 4, "lr": 0.8}
    trainer = Estimator(model_creator=model_creator,
                        verbose=True,
                        config=config,
                        workers_per_node=2)
    # 100 samples / batch_size 4 = 25 steps covers the shard once per epoch.
    trainer.fit(train_data_shard, epochs=1, steps_per_epoch=25)
    trainer.evaluate(train_data_shard, steps=25)
def impl_test_fit_and_evaluate(self, backend):
    """Train a simple model for 2 epochs on the given backend and assert
    that both validation loss and validation MSE decrease.

    Args:
        backend: backend identifier forwarded to ``Estimator``.

    Raises:
        AssertionError: if training did not reduce loss and MSE.
    """
    import tensorflow as tf
    ray_ctx = RayContext.get()
    batch_size = 32
    # Scale the per-worker batch size by the number of Ray nodes so each
    # worker sees ``batch_size`` samples per step.
    global_batch_size = batch_size * ray_ctx.num_ray_nodes
    config = {"batch_size": global_batch_size}
    trainer = Estimator(
        model_creator=simple_model,
        compile_args_creator=compile_args,
        verbose=True,
        config=config,
        backend=backend)

    # model baseline performance
    start_stats = trainer.evaluate(
        create_test_dataset, steps=NUM_TEST_SAMPLES // global_batch_size)
    print(start_stats)

    def schedule(epoch):
        # Constant LR for the first two epochs, then exponential decay.
        if epoch < 2:
            return 0.001
        return 0.001 * tf.math.exp(0.1 * (2 - epoch))

    # Distinct name: the original rebound ``scheduler`` from the schedule
    # function to the callback object, shadowing it.
    lr_callback = tf.keras.callbacks.LearningRateScheduler(schedule, verbose=1)

    # train for 2 epochs (the original called fit twice in a row — a
    # duplicated line that trained 4 epochs despite this comment)
    trainer.fit(create_train_datasets, epochs=2, callbacks=[lr_callback])

    # model performance after training (should improve)
    end_stats = trainer.evaluate(
        create_test_dataset, steps=NUM_TEST_SAMPLES // global_batch_size)
    print(end_stats)

    # sanity check that training worked
    dloss = end_stats["validation_loss"] - start_stats["validation_loss"]
    dmse = (end_stats["validation_mean_squared_error"]
            - start_stats["validation_mean_squared_error"])
    print(f"dLoss: {dloss}, dMSE: {dmse}")
    assert dloss < 0 and dmse < 0, "training sanity check failed. loss increased!"