Example #1
import json

import torch
import poptorch

# Project helpers used below (parse_bert_args, BertConfig, get_options,
# PipelinedBertWithLoss, get_optimizer, get_generated_datum) come from the
# surrounding test suite and are not shown in this excerpt.


def test_checkpoint_not_in_ir():
    import warnings
    warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)

    # Config
    args = """
    --config unit_test
    --lr-schedule constant
    --layers-per-ipu 0 3
    --vocab-size 30400
    --weight-decay 0.0
    --recompute-checkpoint-every-layer False
    """.split()
    config = BertConfig(**(vars(parse_bert_args(args))))

    assert config.recompute_checkpoint_every_layer is False

    # Execution parameters
    opts = get_options(config)
    model = PipelinedBertWithLoss(config).half().train()
    optimizer = get_optimizer(config, model)
    poptorch_model = poptorch.trainingModel(model, opts, optimizer=optimizer)

    # Compile model
    datum = get_generated_datum(config)
    poptorch_model.compile(*datum)
    ir = json.loads(poptorch_model._debugGetPopartIR())
    assert not any("Checkpoint" in node["name"] for node in ir["maingraph"]), \
        "Popart IR should not contain a checkpoint"

    # Stash: 5 graph inputs, plus 1 stash for the transformer layers on IPU1
    exp_num_stash = 5 + 1
    num_stash = sum("Stash" in node["type"] for node in ir["maingraph"])
    assert num_stash == exp_num_stash, (
        "The graph inputs and the transformer stage on IPU1 should each be stashed")
    print(num_stash)
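
A complementary check (not part of the original example) would flip the flag and expect checkpoint ops to show up in the IR. The sketch below is an assumption built only from the helpers already used above; the test name and the exact expectation are illustrative, not the project's actual test.

def test_checkpoint_in_ir_sketch():
    # Hedged sketch: enable per-layer recompute checkpointing and expect
    # "Checkpoint" nodes to appear in the PopART IR.
    args = """
    --config unit_test
    --lr-schedule constant
    --layers-per-ipu 0 3
    --vocab-size 30400
    --weight-decay 0.0
    --recompute-checkpoint-every-layer True
    """.split()
    config = BertConfig(**(vars(parse_bert_args(args))))
    assert config.recompute_checkpoint_every_layer is True

    opts = get_options(config)
    model = PipelinedBertWithLoss(config).half().train()
    optimizer = get_optimizer(config, model)
    poptorch_model = poptorch.trainingModel(model, opts, optimizer=optimizer)

    poptorch_model.compile(*get_generated_datum(config))
    ir = json.loads(poptorch_model._debugGetPopartIR())
    assert any("Checkpoint" in node["name"] for node in ir["maingraph"]), \
        "Popart IR should contain at least one recompute checkpoint"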
Example #2
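# NOTE: the excerpt below is taken from mid-function in a training entry point:
# it restores training state from a checkpoint (or saves an initial one),
# compiles the model, and then enters the training loop.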
            scheduler.last_epoch = steps_finished = checkpoint["step"]
            checkpoint_metrics = checkpoint["metrics"]
        else:
            # Checkpoint model with epochs and optimizer state reset
            # for further training
            save_checkpoint(config, model, optimizer, steps_finished)
    else:
        # Checkpoint model at start of run
        save_checkpoint(config, model, optimizer, steps_finished)

    poptorch_model = trainingModel(model, opts, optimizer=optimizer)

    # Compile model
    logger("---------- Compilation/Loading from Cache Started ---------")
    start_compile = time.perf_counter()
    datum = get_generated_datum(config)
    poptorch_model.compile(*datum)
    duration_compilation = time.perf_counter() - start_compile
    logger(f"Compiled/Loaded model in {duration_compilation} secs")
    logger("-----------------------------------------------------------")

    # Training loop
    logger("--------------------- Training Started --------------------")
    factor = config.gradient_accumulation * config.batches_per_step
    start_train = time.perf_counter()
    loader = cycle(loader)
    train_iterator = tqdm(range(steps_finished, config.training_steps),
                          desc="Training",
                          disable=config.disable_progress_bar)
    for step in train_iterator:
        start_step = time.perf_counter()
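
# NOTE: the `cycle` helper called above is not defined in this excerpt. A
# minimal version, assuming it simply restarts the dataloader whenever it is
# exhausted, could look like this (a sketch, not the project's implementation):
def cycle(iterable):
    while True:
        for item in iterable:
            yield item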
Example #3
def mock_data():
    # Build a synthetic batch from the config (captured from the enclosing
    # scope), used as stand-in input data.
    return get_generated_datum(config)
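
How such a fixture is used is not shown in this excerpt; an assumed usage, mirroring the compile calls in Examples 1 and 2 (poptorch_model here refers to a compiled training model as in those examples), would be:

datum = mock_data()
poptorch_model.compile(*datum)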