Пример #1
0
    def test_auto_shard_tf(self):
        """Check that the tf2 Estimator shards its input by file.

        File 1 holds only 0s and file 2 holds only 1s. When the data is
        sharded by file, every model sees the same records within a batch;
        when it is sharded by record, each batch holds different records.
        The loss function is built so that the first case yields 0 and the
        second case yields a non-zero value.
        """
        # Fetched before the Estimator is built; the returned context is not
        # used any further inside this test.
        active_ray_ctx = RayContext.get()

        estimator_options = dict(
            model_creator=auto_shard_model_creator,
            verbose=True,
            config={"batch_size": 4},
            backend="tf2",
            workers_per_node=2,
        )
        estimator = Estimator(**estimator_options)

        fit_stats = estimator.fit(
            create_auto_shard_datasets,
            epochs=1,
            steps_per_epoch=2,
        )

        # A training loss of exactly zero is the signature of file-level
        # sharding (see the docstring above).
        train_loss = fit_stats["train_loss"]
        assert train_loss == 0.0
Пример #2
0
        "data_dir": args.data_dir,
        "bf16": args.use_bf16,
        "lr": initial_lr,
    }

    trainer = Estimator(model_creator=model_creator,
                        compile_args_creator=compile_args_creator,
                        verbose=True,
                        config=config,
                        backend="horovod")

    if args.benchmark:
        trainer.fit(
            data_creator=train_data_creator
            if not args.use_dummy_data else dummy_data_creator,
            epochs=3,
            steps_per_epoch=20,
            callbacks=callbacks,
        )
    else:
        epoch = 0
        for i in range(5):
            dummy = args.use_dummy_data

            results = trainer.fit(
                data_creator=train_data_creator
                if not dummy else dummy_data_creator,
                epochs=18,
                validation_data_creator=val_data_creator
                if not dummy else dummy_data_creator,
                steps_per_epoch=_NUM_IMAGES['train'] // global_batch_size,
Пример #3
0
    # Keras callback that sets the learning rate for each epoch to the value
    # returned by schedule(epoch, lr_multiplier); verbose=1 makes the callback
    # print a message on every update.
    # NOTE(review): "lr_schdule" looks like a typo for "lr_schedule" -- left
    # unchanged because the name may be referenced outside this excerpt.
    lr_schdule = tf.keras.callbacks.LearningRateScheduler(
        lambda epoch: schedule(epoch, lr_multiplier), verbose=1)

    # Settings passed to the Estimator through its `config` argument.
    # Presumably read by the creator functions -- verify against their bodies.
    config = {
        "momentum": 0.9,
        "wd": 0.00005,  # presumably weight decay -- TODO confirm
        # Training and validation use the same per-worker batch size.
        "batch_size": args.batch_size_per_worker,
        "val_batch_size": args.batch_size_per_worker,
        "warmup_epoch": 5,
        "num_worker": args.worker_num,
        "data_dir": args.data_dir,
    }

    # Build the Estimator on the "horovod" backend from the model and
    # compile-argument creator callables.
    trainer = Estimator(model_creator=model_creator,
                        compile_args_creator=compile_args_creator,
                        verbose=True,
                        config=config,
                        backend="horovod")

    # Train for 90 epochs with validation. The step counts are the
    # _NUM_IMAGES entries floor-divided by global_batch_size, so any
    # remainder smaller than one global batch is not counted as a step.
    results = trainer.fit(
        data_creator=train_data_creator,
        epochs=90,
        validation_data_creator=val_data_creator,
        steps_per_epoch=_NUM_IMAGES['train'] // global_batch_size,
        callbacks=[lr_schdule],
        validation_steps=_NUM_IMAGES['validation'] // global_batch_size,
    )

    # Print whatever fit() returned, then call stop_orca_context()
    # (presumably tears down the Orca runtime -- confirm).
    print(results)
    stop_orca_context()