def test_train(ray_start_2_cpus, num_replicas):  # noqa: F811
    """Check that a second train/validate round does not increase the loss.

    Runs one ``train()`` + ``validate()`` pass, then a second one, and asserts
    that both the reported train loss and validation loss are non-increasing.

    Args:
        ray_start_2_cpus: Not used in the body; presumably a pytest fixture
            that starts Ray for the test (confirm against its definition).
        num_replicas: Forwarded to ``PyTorchTrainer`` as ``num_replicas``.
    """
    trainer = PyTorchTrainer(
        model_creator,
        data_creator,
        optimizer_creator,
        num_replicas=num_replicas)
    try:
        train_loss1 = trainer.train()["train_loss"]
        validation_loss1 = trainer.validate()["validation_loss"]

        train_loss2 = trainer.train()["train_loss"]
        validation_loss2 = trainer.validate()["validation_loss"]
    finally:
        # Release the trainer even when a train/validate call raises, so a
        # failing run does not leave it alive for the next test.
        trainer.shutdown()

    print(train_loss1, train_loss2)
    print(validation_loss1, validation_loss2)

    assert train_loss2 <= train_loss1
    assert validation_loss2 <= validation_loss1
def train_example(num_replicas=1,
                  num_epochs=5,
                  use_gpu=False,
                  use_fp16=False,
                  test_mode=False):
    """Train ``ResNet18`` with a ``PyTorchTrainer`` and print the results.

    Calls ``train()`` once per epoch, printing the returned stats each time,
    then prints the result of a single ``validate()`` call. The trainer is
    always shut down, including when training or validation raises.

    Args:
        num_replicas: Forwarded to ``PyTorchTrainer`` as ``num_replicas``.
        num_epochs: Number of ``train()`` calls to make.
        use_gpu: Forwarded to the trainer; also selects the ``"nccl"``
            backend when true and ``"gloo"`` otherwise.
        use_fp16: Forwarded to the trainer as ``use_fp16``.
        test_mode: Stored in the trainer config under ``TEST_MODE``; also
            selects a batch size of 16 instead of 512.
    """
    config = {TEST_MODE: test_mode}
    trainer1 = PyTorchTrainer(
        ResNet18,
        cifar_creator,
        optimizer_creator,
        nn.CrossEntropyLoss,
        scheduler_creator=scheduler_creator,
        initialization_hook=initialization_hook,
        num_replicas=num_replicas,
        config=config,
        use_gpu=use_gpu,
        batch_size=16 if test_mode else 512,
        backend="nccl" if use_gpu else "gloo",
        scheduler_step_freq="epoch",
        use_fp16=use_fp16)
    try:
        for _ in range(num_epochs):
            # Increase `max_retries` to turn on fault tolerance.
            stats = trainer1.train(max_retries=0)
            print(stats)

        print(trainer1.validate())
    finally:
        # Shut down even if an epoch fails, so the trainer is not left
        # running after an error.
        trainer1.shutdown()
    print("success!")
def test_train(ray_start_2_cpus, num_replicas):  # noqa: F811
    """Check that further training does not increase the reported losses.

    Trains for three ``train()`` calls and validates, then trains for three
    more calls and validates again. Asserts that the last train loss and the
    validation loss of the second round are no larger than those of the
    first round.

    Args:
        ray_start_2_cpus: Not used in the body; presumably a pytest fixture
            that starts Ray for the test (confirm against its definition).
        num_replicas: Forwarded to ``PyTorchTrainer`` as ``num_replicas``.
    """
    trainer = PyTorchTrainer(
        model_creator,
        data_creator,
        optimizer_creator,
        loss_creator=lambda config: nn.MSELoss(),
        num_replicas=num_replicas)
    try:
        # Only the loss of the final call in each round is kept.
        for _ in range(3):
            train_loss1 = trainer.train()["train_loss"]
        validation_loss1 = trainer.validate()["validation_loss"]

        for _ in range(3):
            train_loss2 = trainer.train()["train_loss"]
        validation_loss2 = trainer.validate()["validation_loss"]
    finally:
        # Release the trainer even when a train/validate call raises, so a
        # failing run does not leave it alive for the next test.
        trainer.shutdown()

    print(train_loss1, train_loss2)
    print(validation_loss1, validation_loss2)

    assert train_loss2 <= train_loss1
    assert validation_loss2 <= validation_loss1
def test_test_mode(ray_start_2_cpus):  # noqa: F811
    """Check that ``TEST_MODE`` limits train and validate to a single batch.

    Builds a single-replica trainer with ``{TEST_MODE: True}`` in its config
    and asserts that both ``train()`` and ``validate()`` report a
    ``BATCH_COUNT`` of 1.

    Args:
        ray_start_2_cpus: Not used in the body; presumably a pytest fixture
            that starts Ray for the test (confirm against its definition).
    """
    trainer = PyTorchTrainer(
        model_creator,
        data_creator,
        optimizer_creator,
        loss_creator=lambda config: nn.MSELoss(),
        config={TEST_MODE: True},
        num_replicas=1)
    try:
        metrics = trainer.train()
        assert metrics[BATCH_COUNT] == 1

        val_metrics = trainer.validate()
        assert val_metrics[BATCH_COUNT] == 1
    finally:
        # Release the trainer whether or not the assertions pass.
        trainer.shutdown()
def train_example(num_replicas=1, use_gpu=False, test_mode=False):
    """Train ``ResNet18`` for five epochs with custom train/validate functions.

    Calls ``train()`` five times, printing the returned stats each time, then
    prints the result of a single ``validate()`` call. The trainer is always
    shut down, including when training or validation raises.

    Args:
        num_replicas: Forwarded to ``PyTorchTrainer`` as ``num_replicas``.
        use_gpu: Forwarded to the trainer; also selects the ``"nccl"``
            backend when true and ``"gloo"`` otherwise.
        test_mode: Stored in the trainer config under ``"test_mode"``; also
            selects a batch size of 16 instead of 512.
    """
    config = {"test_mode": test_mode}
    trainer1 = PyTorchTrainer(
        ResNet18,
        cifar_creator,
        optimizer_creator,
        nn.CrossEntropyLoss,
        initialization_hook=initialization_hook,
        train_function=train,
        validation_function=validate,
        num_replicas=num_replicas,
        config=config,
        use_gpu=use_gpu,
        batch_size=16 if test_mode else 512,
        backend="nccl" if use_gpu else "gloo")
    try:
        for _ in range(5):
            stats = trainer1.train()
            print(stats)

        print(trainer1.validate())
    finally:
        # Shut down even if an epoch fails, so the trainer is not left
        # running after an error.
        trainer1.shutdown()
    print("success!")