def main():
    """Train a binary model on the complete training set and save it.

    Builds the run configuration, overrides a handful of hyper-parameters,
    trains a ``'roberta'``-type model loaded from the
    ``StackOBERTflow-comments-small-v1`` path on the complete training
    dataloader of a binary ``TDDataset``, and stores the trained model
    under the ``satd_complete_binary`` models path.
    """
    cfg = get_config()

    # Overrides are applied inside the config's own context manager
    # (presumably this is what permits mutation -- confirm in config class).
    with cfg:
        cfg.logging_steps = 400
        cfg.train_epochs = 2
        cfg.lr = 4e-5  # disabled alternative: 1e-4
        cfg.model_type = 'roberta'
        cfg.model_path = util.models_path('StackOBERTflow-comments-small-v1')
        # disabled alternative: cfg.train_head_only = True

    dataset = TDDataset(cfg, binary=True)
    tok = tu.load_tokenizer(cfg)

    # The returned class is not used below; the call itself is kept so the
    # sequence of helper invocations stays exactly as before.
    tu.get_model_cls(cfg)

    complete_train_loader = dataset.get_complete_train_dataloader(tok)

    net = tu.load_model(cfg)
    net.to(cfg.device)
    util.set_seed(cfg)

    run = Experiment(cfg, net, tok)
    # train() yields a (global_step, loss) pair; neither value is needed here.
    _step, _loss = run.train(complete_train_loader)

    output_path = util.models_path('satd_complete_binary')
    run.save_model(output_path)
def main(config, results):
    """Train with a train/validation split, save the model, return results.

    Args:
        config: Run configuration; passed to the model/tokenizer loaders,
            the dataset and the experiment, and read for ``device``.
        results: Results object forwarded to ``Experiment``.

    Returns:
        The ``results`` attribute of the experiment after training.
    """
    mdl_config = tu.load_model_config(config)
    tok = tu.load_tokenizer(config, mdl_config)

    dataset = Dataset(config, tok)
    labels = dataset.label_names
    train_loader, valid_loader = dataset.get_train_valid_dataloaders()

    net = tu.load_model(config, mdl_config)
    net.to(config.device)
    util.set_seed(config)

    run = Experiment(config, net, tok, label_names=labels, results=results)
    # train() yields a (global_step, loss) pair; neither value is needed here.
    _step, _loss = run.train(train_loader, valid_dataloader=valid_loader)

    # Capture the results before saving, matching the original ordering.
    outcome = run.results
    run.save_model(util.models_path('comment_code_shuffle'))
    return outcome
def main(config, results):
    """Train, evaluate, save, then reload the model and evaluate again.

    The model is trained with the dataset's "fake" validation dataloader,
    evaluated on the test and validation dataloaders, and saved to
    ``test_model_complexity``. The saved model is then loaded back through
    ``config.model_path`` and evaluated on the same dataloaders under the
    ``*_reloaded`` names (presumably a save/load round-trip check --
    confirm with the author).

    Args:
        config: Run configuration; passed to the loaders, the dataset and
            the experiments. Its ``model_path`` is overwritten with the
            save path before the reload.
        results: Results object forwarded to both ``Experiment`` instances.

    Returns:
        The ``results`` attribute of the second (reloaded-model) experiment.
    """
    # Single source of truth for the save/reload location, so the path
    # written by save_model and the path read back cannot drift apart.
    saved_model_path = 'test_model_complexity'

    model_config = tu.load_model_config(config)
    tokenizer = tu.load_tokenizer(config, model_config)

    ds = Dataset(config, tokenizer)
    label_names = ds.label_names

    train_dataloader = ds.get_train_dataloader()
    fake_valid_dataloader = ds.get_fake_valid_dataloader()

    model = tu.load_model(config, model_config)
    model.to(config.device)
    util.set_seed(config)

    experiment = Experiment(config,
                            model,
                            tokenizer,
                            label_names=label_names,
                            results=results)
    # train() yields a (global_step, loss) pair; neither value is needed here.
    _global_step, _tr_loss = experiment.train(
        train_dataloader, valid_dataloader=fake_valid_dataloader)

    valid_dataloader = ds.get_valid_dataloader()
    test_dataloader = ds.get_test_dataloader()
    experiment.evaluate('test_final', test_dataloader)
    experiment.evaluate('valid_final', valid_dataloader)
    experiment.save_model(saved_model_path)

    # Point the config at the freshly saved model and load it back.
    with config:
        config.model_path = saved_model_path
    model = tu.load_model(config, model_config)
    model.to(config.device)

    # Visual separator in the log between the original and reloaded runs.
    # Logger.warn is a deprecated alias; Logger.warning is the supported API.
    logger.warning('#################################### =========================')

    experiment = Experiment(config,
                            model,
                            tokenizer,
                            label_names=label_names,
                            results=results)
    experiment.evaluate('test_final_reloaded', test_dataloader)
    experiment.evaluate('valid_final_reloaded', valid_dataloader)

    return experiment.results