Code example #1
    def test_sparkxshards(self):
        """Smoke-test Estimator fit/evaluate on an in-memory SparkXShards dataset."""
        # 100 samples of 1-D features. NOTE(review): randint(0, 1) only ever
        # produces 0 (upper bound is exclusive) — presumably acceptable for a
        # smoke test, but confirm 0/1 labels were not intended.
        data = {
            "x": np.random.randn(100, 1),
            "y": np.random.randint(0, 1, size=(100)),
        }
        shards = XShards.partition(data)

        estimator_config = {"batch_size": 4, "lr": 0.8}
        estimator = Estimator(model_creator=model_creator,
                              verbose=True,
                              config=estimator_config,
                              workers_per_node=2)

        # One epoch of 25 steps, then evaluate on the same shards.
        estimator.fit(shards, epochs=1, steps_per_epoch=25)
        estimator.evaluate(shards, steps=25)
Code example #2
    def impl_test_fit_and_evaluate(self, backend):
        """Train an Estimator on the given backend and assert loss/MSE improve.

        Evaluates once before training to get baseline metrics, trains with a
        learning-rate schedule, evaluates again, and asserts both validation
        loss and validation MSE decreased.
        """
        import tensorflow as tf

        ray_ctx = RayContext.get()
        per_worker_batch = 32
        # Global batch scales with the number of Ray nodes in the cluster.
        total_batch = per_worker_batch * ray_ctx.num_ray_nodes
        estimator_config = {"batch_size": total_batch}

        estimator = Estimator(model_creator=simple_model,
                              compile_args_creator=compile_args,
                              verbose=True,
                              config=estimator_config,
                              backend=backend)

        eval_steps = NUM_TEST_SAMPLES // total_batch

        # Baseline performance before any training.
        start_stats = estimator.evaluate(create_test_dataset, steps=eval_steps)
        print(start_stats)

        def lr_schedule(epoch):
            # Constant LR for the first two epochs, then exponential decay.
            if epoch >= 2:
                return 0.001 * tf.math.exp(0.1 * (2 - epoch))
            return 0.001

        lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_schedule,
                                                               verbose=1)

        # Two fit calls of 2 epochs each (matches the original's repeated fit).
        for _ in range(2):
            estimator.fit(create_train_datasets, epochs=2,
                          callbacks=[lr_callback])

        # Performance after training (should improve).
        end_stats = estimator.evaluate(create_test_dataset, steps=eval_steps)
        print(end_stats)

        # Sanity check: both deltas must be negative (metrics decreased).
        dloss = end_stats["validation_loss"] - start_stats["validation_loss"]
        dmse = (end_stats["validation_mean_squared_error"] -
                start_stats["validation_mean_squared_error"])
        print(f"dLoss: {dloss}, dMSE: {dmse}")

        assert dloss < 0 and dmse < 0, "training sanity check failed. loss increased!"