Example #1
0
def test_method_get_trainer():
    """get_trainer must return an object exposing the Trainer interface
    (model/tokenizer attributes, a train method, and save_pretrained on both
    the wrapped model and tokenizer)."""
    codes = Code.load_from_json(cfg.FILE_INDEX_JSON_PATH)
    tokenizer = get_tokenier(cfg.MODEL_NAME, cfg.SPECIAL_TOKENS)
    model = get_model(cfg.MODEL_NAME, tokenizer, cfg.SPECIAL_TOKENS)
    dataset = CodeDataset(
        codes=codes,
        tokenizer=tokenizer,
        control_tokens=cfg.CONTROL_TOKENS,
        max_length=cfg.MAX_LENGTH,
        num_description_sentences=cfg.NUM_DESCRIPTION_SENTENCES,
    )

    # Hold out a fixed 20% of the samples for validation.
    n_val = int(0.2 * len(dataset))
    train_ds, val_ds = random_split(dataset, [len(dataset) - n_val, n_val])

    trainer = get_trainer(model, tokenizer, train_ds, val_ds, cfg)

    expectations = (
        (trainer, "model", "Trainer object must have a model attribute."),
        (trainer, "tokenizer", "Trainer object must have a tokenizer attribute."),
        (trainer, "train", "Trainer object must have a train method."),
        (
            trainer.model,
            "save_pretrained",
            "Trainer.model object must have a save_pretrained method.",
        ),
        (
            trainer.tokenizer,
            "save_pretrained",
            "Trainer.tokenizer object must have a save_pretrained method.",
        ),
    )
    for obj, attr, msg in expectations:
        assert hasattr(obj, attr), msg
Example #2
0
def test_method_get_model():
    """get_model must yield a sized tokenizer and a model exposing forward()."""
    tok = get_tokenier(cfg.MODEL_NAME, cfg.SPECIAL_TOKENS)
    mdl = get_model(cfg.MODEL_NAME, tok, cfg.SPECIAL_TOKENS)

    assert hasattr(tok, "__len__"), (
        "The tokenizer object must have a __len__ property."
    )
    assert hasattr(mdl, "forward"), (
        "The model is a subclass of pytorch.nn.module, hence must have a forward method."
    )
Example #3
0
def run_training(cfg, checkpoint_path: Union[Path, str, None] = None) -> None:
    """Run the full training pipeline and persist the resulting artifacts.

    Loads code instances from ``cfg.FILE_INDEX_JSON_PATH``, builds the
    tokenizer, model and dataset, splits off ``cfg.VAL_PCT`` of the data for
    validation, trains (optionally resuming from a checkpoint), then saves the
    trained model and tokenizer to ``cfg.MODEL_PATH`` / ``cfg.TOKENIZER_PATH``.

    Args:
        cfg: Configuration object supplying the paths, model name, token
            settings and VAL_PCT used throughout this module.
        checkpoint_path: Optional trainer checkpoint to resume from; when
            falsy, training starts from scratch. (Annotation fixed: the
            original ``Union[Path, str] = None`` relied on implicit-Optional
            semantics, which PEP 484 deprecates — ``None`` is now explicit.)

    Returns:
        None. Side effects: prints progress and writes model/tokenizer files.
    """
    codes = Code.load_from_json(cfg.FILE_INDEX_JSON_PATH)
    print(
        f"loaded {len(codes)} code instances from",
        formatter(cfg.FILE_INDEX_JSON_PATH, "g", True),
    )
    tokenizer = get_tokenier(cfg.MODEL_NAME, cfg.SPECIAL_TOKENS)
    model = get_model(cfg.MODEL_NAME, tokenizer, cfg.SPECIAL_TOKENS)
    print(
        f"loaded tokenizer and model  for {formatter(cfg.MODEL_NAME, 'g', True)} model"
    )

    ds = CodeDataset(
        codes,
        tokenizer,
        cfg.CONTROL_TOKENS,
        cfg.MAX_LENGTH,
        cfg.NUM_DESCRIPTION_SENTENCES,
    )

    print("dataset loaded successfully")

    # Compute deterministic split sizes; random_split shuffles which samples
    # land in each partition.
    val_sz = int(len(ds) * cfg.VAL_PCT)
    train_sz = len(ds) - val_sz
    train_ds, val_ds = random_split(ds, [train_sz, val_sz])

    print("len(training dataset):", len(train_ds))
    print("len(validation dataset):", len(val_ds))

    trainer = get_trainer(model, tokenizer, train_ds, val_ds, cfg)

    if checkpoint_path:
        print(formatter("resuming training", "g", True))
        trainer.train(checkpoint_path)
    else:
        print(formatter("initializing training", "g", True))
        trainer.train()

    save_transformers(
        cfg.MODEL_PATH,
        cfg.TOKENIZER_PATH,
        trainer.model,
        trainer.tokenizer,
    )
Example #4
0
def test_module_code_dataset():
    """CodeDataset must act like a torch Dataset: sized, indexable into dict
    samples, and able to render a sample as text via get_string."""
    tokenizer = get_tokenier(cfg.MODEL_NAME, cfg.SPECIAL_TOKENS)
    dataset = CodeDataset(
        codes=Code.load_from_json(cfg.FILE_INDEX_JSON_PATH),
        tokenizer=tokenizer,
        control_tokens=cfg.CONTROL_TOKENS,
        max_length=cfg.MAX_LENGTH,
        num_description_sentences=cfg.NUM_DESCRIPTION_SENTENCES,
    )

    assert hasattr(dataset, "__len__"), "Dataset must have a __len__ attribute."
    assert hasattr(dataset, "__getitem__"), "Dataset must have a __getitem__ attribute."

    first_sample = dataset[0]
    assert isinstance(first_sample, dict), (
        "CodeDataset should return a dictionary object when indexed."
    )
    assert isinstance(dataset.get_string(0), str), (
        "CodeDataset.get_string must return the string format text."
    )
Example #5
0
def test_method_load_save_model():
    """Round-trip a model/tokenizer through save_transformers/load_transformers.

    Saves to throwaway "__test" directories, reloads, and checks the reloaded
    objects expose the expected interface. Cleanup now runs in a ``finally``
    block so the temp directories are removed even when the save/load
    round-trip raises (the original leaked them on any failure).
    """
    tokenizer = get_tokenier(cfg.MODEL_NAME, cfg.SPECIAL_TOKENS)
    model = get_model(cfg.MODEL_NAME, tokenizer, cfg.SPECIAL_TOKENS)

    model_path = Path(str(cfg.MODEL_PATH) + "__test")
    tokenizer_path = Path(str(cfg.TOKENIZER_PATH) + "__test")

    try:
        save_transformers(model_path, tokenizer_path, model, tokenizer)
        model, tokenizer = load_transformers(model_path, tokenizer_path)
    finally:
        # ignore_errors: a directory may not exist if save failed early.
        rmtree(model_path, ignore_errors=True)
        rmtree(tokenizer_path, ignore_errors=True)

    assert hasattr(
        tokenizer,
        "__len__"), "The tokenizer object must have a __len__ property."
    assert hasattr(
        model, "forward"
    ), "The model is a subclass of pytorch.nn.module, hence must have a forward method."
Example #6
0
def test_method_get_tokenizer():
    """The tokenizer built by get_tokenier must support len()."""
    tok = get_tokenier(cfg.MODEL_NAME, cfg.SPECIAL_TOKENS)
    assert hasattr(tok, "__len__"), (
        "The tokenizer object must have a __len__ property."
    )