示例#1
0
def train_example(num_replicas=1, batch_size=128, use_gpu=False):
    """Train one round with ``TFTrainer`` and print stats and throughput.

    Args:
        num_replicas: Number of replicas handed to ``TFTrainer``.
        batch_size: Batch size per replica (presumably); the trainer
            config is built from ``batch_size * num_replicas``.
        use_gpu: Forwarded to ``TFTrainer``.
    """
    # Combined batch size across all replicas; used both for the trainer
    # config and for the throughput estimate below.
    total_batch_size = batch_size * num_replicas
    trainer = TFTrainer(model_creator=simple_model,
                        data_creator=simple_dataset,
                        num_replicas=num_replicas,
                        use_gpu=use_gpu,
                        verbose=True,
                        config=create_config(total_batch_size))

    train_stats = trainer.train()
    # Read the training batch time before merging validation stats so a
    # same-named validation key can never skew the throughput figure.
    batch_time = train_stats["batch_time"]

    # Validate once and reuse the result: no training happens between the
    # two reports, so a second validate() call would only repeat work.
    val_stats = trainer.validate()
    train_stats.update(val_stats)
    print(train_stats)
    print("Throughput: " + str(total_batch_size / batch_time))

    print(val_stats)
    print("success!")
示例#2
0
def train_example(num_replicas=1, batch_size=128, use_gpu=False):
    """Sanity-check that training lowers validation loss and MSE.

    Validates the freshly created model, calls ``trainer.train()`` twice,
    validates again, and prints whether both ``validation_loss`` and
    ``validation_mean_squared_error`` failed to increase.

    Args:
        num_replicas: Number of replicas handed to ``TFTrainer``.
        batch_size: Passed unchanged to ``create_config``.
        use_gpu: Forwarded to ``TFTrainer``.
    """
    trainer = TFTrainer(
        model_creator=simple_model,
        data_creator=simple_dataset,
        num_replicas=num_replicas,
        use_gpu=use_gpu,
        verbose=True,
        config=create_config(batch_size))

    # Baseline performance of the untrained model.
    start_stats = trainer.validate()
    print(start_stats)

    # Train for 2 epochs.
    for _ in range(2):
        trainer.train()

    # Performance after training (expected to improve).
    end_stats = trainer.validate()
    print(end_stats)

    # Positive deltas mean the metric got worse.
    dloss = end_stats["validation_loss"] - start_stats["validation_loss"]
    dmse = (end_stats["validation_mean_squared_error"] -
            start_stats["validation_mean_squared_error"])
    print(f"dLoss: {dloss}, dMSE: {dmse}")

    # Either metric increasing fails the check, so the message names both.
    if dloss > 0 or dmse > 0:
        print("training sanity check failed. loss or MSE increased!")
    else:
        print("success!")
示例#3
0
    # Initialize Ray using the address supplied in the parsed arguments.
    ray.init(address=args.address)
    # Example counts used to derive how many steps to run per epoch/eval.
    # NOTE(review): 60000/10000 look like MNIST-style train/test sizes —
    # confirm they match what data_creator actually yields.
    data_size = 60000
    test_size = 10000
    batch_size = args.batch_size

    # Smoke tests run a fixed 10 steps; otherwise roughly one pass over
    # data_size / test_size examples (integer division drops any final
    # partial batch).
    num_train_steps = 10 if args.smoke_test else data_size // batch_size
    num_eval_steps = 10 if args.smoke_test else test_size // batch_size

    # Use the augmenting data creator only when args.augment_data is truthy.
    # Step counts are forwarded via "fit_config"/"evaluate_config".
    trainer = TFTrainer(model_creator=create_model,
                        data_creator=(data_augmentation_creator
                                      if args.augment_data else data_creator),
                        num_replicas=args.num_replicas,
                        use_gpu=args.use_gpu,
                        verbose=True,
                        config={
                            "batch_size": batch_size,
                            "fit_config": {
                                "steps_per_epoch": num_train_steps,
                            },
                            "evaluate_config": {
                                "steps": num_eval_steps,
                            }
                        })

    for i in range(3):
        # One train() call per iteration (presumably one epoch of
        # steps_per_epoch steps), with validation stats merged into the
        # same dict before printing.
        train_stats1 = trainer.train()
        train_stats1.update(trainer.validate())
        print("iter {}:".format(i), train_stats1)

    # Fetch the trained model from the trainer.
    model = trainer.get_model()