Example #1
0
def test_save_and_restore(ray_start_2_cpus, num_workers, use_local,
                          tmp_path):  # noqa: F811
    """Verify that a checkpoint written by one trainer restores into another.

    A first trainer is trained once, saved to ``<tmp_path>/checkpoint`` and
    shut down. A second trainer, built with the same arguments, loads that
    checkpoint. The test then asserts that:

    * both models expose the same state-dict keys,
    * every state-dict tensor is equal between the two models, and
    * the ``rand_int`` of every operator's model in the second trainer
      equals the ``rand_int`` of the first operator's model in the first
      trainer.

    ``ray_start_2_cpus`` is not referenced in the body; it is presumably a
    fixture that brings up Ray for the test -- verify against the conftest.
    """
    trainer_kwargs = dict(training_operator_cls=Operator,
                          num_workers=num_workers,
                          use_local=use_local)

    # Phase 1: train, checkpoint and capture the reference values.
    source_trainer = TorchTrainer(**trainer_kwargs)
    source_trainer.train()
    checkpoint_path = os.path.join(tmp_path, "checkpoint")
    source_trainer.save(checkpoint_path)

    saved_model = source_trainer.get_model()
    source_ints = source_trainer.apply_all_operators(
        lambda op: op.get_model().rand_int)
    expected_int = source_ints[0]

    source_trainer.shutdown()

    # Phase 2: restore into a brand-new trainer.
    restored_trainer = TorchTrainer(**trainer_kwargs)
    restored_trainer.load(checkpoint_path)

    restored_model = restored_trainer.get_model()
    restored_ints = restored_trainer.apply_all_operators(
        lambda op: op.get_model().rand_int)

    saved_state = saved_model.state_dict()
    restored_state = restored_model.state_dict()

    # Phase 3: compare weights and the per-operator integers.
    assert set(saved_state.keys()) == set(restored_state.keys())

    for key, saved_tensor in saved_state.items():
        assert torch.equal(saved_tensor, restored_state[key])
    for restored_int in restored_ints:
        assert restored_int == expected_int
    restored_trainer.shutdown()
Example #2
0
def train_mnist(num_workers=1, use_gpu=False, num_epochs=5):
    """Train ``LitMNIST`` through a ``TorchTrainer`` and save a checkpoint.

    Args:
        num_workers: Forwarded to ``TorchTrainer`` as ``num_workers``.
        use_gpu: Forwarded to ``TorchTrainer`` as ``use_gpu``.
        num_epochs: Number of times ``trainer.train()`` is called.

    The stats returned by every ``train()`` call and by the final
    ``validate()`` call are printed. The trainer state is written to
    ``./model.pt`` before the trainer is shut down.
    """
    operator_cls = TrainingOperator.from_ptl(LitMNIST)
    trainer_config = {"lr": 1e-3, "batch_size": 64}
    trainer = TorchTrainer(
        training_operator_cls=operator_cls,
        num_workers=num_workers,
        config=trainer_config,
        use_gpu=use_gpu,
        use_tqdm=True,
    )

    # The epoch index itself is not needed; only the call count matters.
    for _ in range(num_epochs):
        print(trainer.train())

    validation_stats = trainer.validate()
    print(validation_stats)

    print("Saving model checkpoint to ./model.pt")
    trainer.save("./model.pt")
    print("Model Checkpointed!")

    trainer.shutdown()
    print("success!")
Example #3
0
def main(args):
    """Train ``CIFAR10Module`` with a ``TorchTrainer``, checkpointing each epoch.

    After every epoch three files are written into
    ``/root/volume/Paper/MLVC_Internship/<checkpoint_dir>/<model_name>_<trial>``:

    * ``epoch_<n>.ray``  -- trainer state via ``trainer.save``
    * ``epoch_<n>.loss`` -- ``[train_loss, val_loss]`` histories
    * ``epoch_<n>.acc``  -- ``[val_acc]`` history

    Args:
        args: Parsed options object. The attributes read here are
            ``smoke_test``, ``address``, ``num_workers``, ``learning_rate``,
            ``lr_decay``, ``eps``, ``momentum``, ``wd``, ``data_dir``,
            ``batch_size``, ``use_gpu``, ``fp16``, ``checkpoint_dir``,
            ``model_name``, ``trial`` and ``max_epochs``.
    """
    if args.smoke_test:
        ray.init(num_cpus=4)
    else:
        # NOTE(review): some Ray versions reject ``num_cpus`` when ``address``
        # points at an existing cluster -- confirm for the Ray version in use.
        ray.init(address=args.address,
                 num_cpus=args.num_workers,
                 log_to_driver=True)

    # Trainer Initialization
    trainer = TorchTrainer(training_operator_cls=CIFAR10Module,
                           num_workers=args.num_workers,
                           config={
                               "lr": args.learning_rate,
                               "lr_decay": args.lr_decay,
                               "eps": args.eps,
                               "momentum": args.momentum,
                               "wd": args.wd,
                               "data_dir": args.data_dir,
                               "batch_size": args.batch_size,
                               "num_workers": args.num_workers,
                               "smoke_test": args.smoke_test
                           },
                           use_gpu=args.use_gpu,
                           scheduler_step_freq="epoch",
                           use_fp16=args.fp16,
                           use_tqdm=False)

    train_loss = []
    val_loss = []
    val_acc = []
    # Defined up front so the final print works even when max_epochs is 0.
    val_stats = None

    # Ensure the workers are released even if an epoch or a save fails.
    try:
        # Single source of truth for the output directory: the same path is
        # used both to create the directory and to write every checkpoint,
        # so a non-default ``checkpoint_dir`` no longer writes elsewhere.
        path = os.path.join("/root/volume/Paper/MLVC_Internship",
                            args.checkpoint_dir,
                            args.model_name + "_" + str(args.trial))
        # makedirs creates missing parents and tolerates an existing dir.
        os.makedirs(path, exist_ok=True)

        pbar = trange(args.max_epochs, unit="epoch")
        for it in pbar:
            stats = trainer.train(max_retries=1,
                                  info=dict(epoch_idx=it,
                                            num_epochs=args.max_epochs))
            train_loss.append(stats["train_loss"])
            val_stats = trainer.validate()
            val_loss.append(val_stats["val_loss"])
            # Record accuracy so the saved ``.acc`` history is not empty.
            val_acc.append(val_stats["val_accuracy"])
            pbar.set_postfix(dict(acc=val_stats["val_accuracy"]))

            trainer.save(os.path.join(path, "epoch_{}.ray".format(it)))
            torch.save([train_loss, val_loss],
                       os.path.join(path, "epoch_{}.loss".format(it)))
            torch.save([val_acc],
                       os.path.join(path, "epoch_{}.acc".format(it)))

        print(val_stats)
    finally:
        trainer.shutdown()
    print("success!")