def test_auto_shard_tf(self):
    """Check that the tf2 backend shards input by files rather than by records.

    File 1 contains all 0s and file 2 contains all 1s. If the data is
    sharded by files, each model sees the same records in the same batch;
    if it is sharded by records, each batch holds different records. The
    loss function is constructed so that the former case returns 0 and the
    latter returns a non-zero loss, so a zero training loss means the data
    was sharded by files.
    """
    # The returned context is not used; the call is kept for its side effect.
    # NOTE(review): presumably ensures a Ray context is active before the
    # Estimator launches its workers — confirm against RayContext.get().
    RayContext.get()
    trainer = Estimator(model_creator=auto_shard_model_creator,
                        verbose=True,
                        config={"batch_size": 4},
                        backend="tf2",
                        workers_per_node=2)
    stats = trainer.fit(create_auto_shard_datasets, epochs=1, steps_per_epoch=2)
    # Exactly 0 only in the shard-by-files case (see docstring); include the
    # full stats dict in the failure message to aid debugging.
    assert stats["train_loss"] == 0.0, stats
"data_dir": args.data_dir, "bf16": args.use_bf16, "lr": initial_lr, } trainer = Estimator(model_creator=model_creator, compile_args_creator=compile_args_creator, verbose=True, config=config, backend="horovod") if args.benchmark: trainer.fit( data_creator=train_data_creator if not args.use_dummy_data else dummy_data_creator, epochs=3, steps_per_epoch=20, callbacks=callbacks, ) else: epoch = 0 for i in range(5): dummy = args.use_dummy_data results = trainer.fit( data_creator=train_data_creator if not dummy else dummy_data_creator, epochs=18, validation_data_creator=val_data_creator if not dummy else dummy_data_creator, steps_per_epoch=_NUM_IMAGES['train'] // global_batch_size,
# Per-epoch learning rate comes from `schedule(epoch, lr_multiplier)`;
# verbose=1 makes Keras print update messages when it sets the rate.
lr_schedule = tf.keras.callbacks.LearningRateScheduler(
    lambda epoch: schedule(epoch, lr_multiplier), verbose=1)
config = {
    "momentum": 0.9,
    "wd": 0.00005,
    "batch_size": args.batch_size_per_worker,
    "val_batch_size": args.batch_size_per_worker,
    "warmup_epoch": 5,
    "num_worker": args.worker_num,
    "data_dir": args.data_dir,
}
# Run stop_orca_context() even when building the estimator or training
# raises, so cleanup is not skipped on the error path.
try:
    trainer = Estimator(model_creator=model_creator,
                        compile_args_creator=compile_args_creator,
                        verbose=True,
                        config=config,
                        backend="horovod")
    # Steps per epoch = total images in each split // global batch size.
    results = trainer.fit(
        data_creator=train_data_creator,
        epochs=90,
        validation_data_creator=val_data_creator,
        steps_per_epoch=_NUM_IMAGES['train'] // global_batch_size,
        callbacks=[lr_schedule],
        validation_steps=_NUM_IMAGES['validation'] // global_batch_size,
    )
    print(results)
finally:
    stop_orca_context()