def run_training(args: dict):
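    """Train and evaluate a Pyramid NER model end-to-end from ``args``.

    Loads the word lexicon and datasets, builds the Pyramid model, trains it
    with the default SGD optimiser, plots the training curves and writes out
    the final test report.
    """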
    word_lexicon = load_vocab(args['lex_size'])

    train_data, dev_data, test_data, entities_lexicon = instantiate_datasets(
        word_lexicon, args
    )

    print("Instantiating model")
    pyramid_ner = Pyramid(
        word_lexicon=word_lexicon,
        entities_lexicon=entities_lexicon,
        classifier_type=args['classifier_type'],
        word_embeddings=args['word_embeddings'],
        language_model=args['language_model'],
        char_embeddings_dim=args['char_embeddings_dim'],
        encoder_hidden_size=args['encoder_hidden_size'],
        embedding_encoder_type=args['embedding_encoder_type'],
        embedding_encoder_output_size=args['embedding_encoder_output_size'],
        decoder_hidden_size=args['decoder_hidden_size'],
        inverse_pyramid=args['inverse_pyramid'],
        custom_tokenizer=args['custom_tokenizer'],
        pyramid_max_depth=args['pyramid_max_depth'],
        decoder_dropout=args['decoder_dropout'],
        encoder_dropout=args['encoder_dropout'],
        use_pre=args['use_pre'],
        use_post=args['use_post'],
        hidden_size=args['hidden_size'],
        rnn_layers=args['rnn_layers'],
        reproject_words=args['reproject_words'],
        reproject_words_dimension=args['reproject_words_dimension'],
        bidirectional=args['bidirectional'],
        dropout=args['dropout'],
        word_dropout=args['word_dropout'],
        locked_dropout=args['locked_dropout'],
        device=DEVICE
    )

    print(pyramid_ner.nnet)

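    # Train with the default SGD optimiser/scheduler; ``patience`` and
    # ``restore_weights_on`` control early stopping against the dev set.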
    trainer = MultiLabelTrainer(pyramid_ner)
    optimizer, scheduler = get_default_sgd_optim(pyramid_ner.nnet.parameters())
    ner_model, train_report = trainer.train(
        train_data,
        optimizer=optimizer,
        scheduler=scheduler,
        dev_data=dev_data,
        epochs=args['epochs'],
        patience=args['patience'],
        grad_clip=args['grad_clip'],
        restore_weights_on=args['restore_weights_on']
    )
    train_report.plot_loss_report()
    train_report.plot_custom_report('micro_f1')

    formatted_report = trainer.test_model(test_data, out_dict=False)
    write_report(args, train_report, formatted_report)


def run_sigmoid_test(args: dict):
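    """Smoke-test the sigmoid multi-label Pyramid variant on small data slices.

    Reads a few hundred sentences from the concatenated NNE corpus, builds the
    sigmoid multi-label datasets and dataloaders, instantiates a
    SigmoidMultiLabelPyramid and runs a short training/evaluation pass.
    """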
    word_lexicon = load_vocab(args['lex_size'])

    print("Loading data")
    train_data = list(
        itertools.islice(wrg_reader("data/nne_concat/train.txt"), 256))
    dev_data = list(itertools.islice(wrg_reader("data/nne_concat/dev.txt"),
                                     64))
    test_data = list(
        itertools.islice(wrg_reader("data/nne_concat/test.txt"), 64))

    print("Generating entity lexicon")
    entities_lexicon = sorted({
        entity.name
        for data in (train_data, dev_data, test_data)
        for data_point in data
        for entity in data_point.entities
    })
    print(entities_lexicon)

    print("Instantiating train dataset")
    train_dataset = SigmoidMultiLabelNerDataset(
        train_data,
        pyramid_max_depth=args['pyramid_max_depth'],
        token_lexicon=word_lexicon,
        entities_lexicon=entities_lexicon,
        custom_tokenizer=None,
        char_vectorizer=True,
    )
    print("Instantiating dev dataset")
    dev_dataset = SigmoidMultiLabelNerDataset(
        dev_data,
        pyramid_max_depth=args['pyramid_max_depth'],
        token_lexicon=word_lexicon,
        entities_lexicon=entities_lexicon,
        custom_tokenizer=None,
        char_vectorizer=True,
    )
    print("Instantiating test dataset")
    test_dataset = SigmoidMultiLabelNerDataset(
        test_data,
        pyramid_max_depth=args['pyramid_max_depth'],
        token_lexicon=word_lexicon,
        entities_lexicon=entities_lexicon,
        custom_tokenizer=None,
        char_vectorizer=True,
    )

    dev_dataloader, train_dataloader, test_dataloader = get_data_loader(
        train_dataset, dev_dataset, test_dataset, args)

    print("Instantiating SigmoidMultiLabelPyramid")
    pyramid_ner = SigmoidMultiLabelPyramid(
        word_lexicon=word_lexicon,
        entities_lexicon=entities_lexicon,
        classifier_type=args['classifier_type'],
        word_embeddings=args['word_embeddings'],
        language_model=args['language_model'],
        char_embeddings_dim=args['char_embeddings_dim'],
        encoder_hidden_size=args['encoder_hidden_size'],
        decoder_hidden_size=args['decoder_hidden_size'],
        inverse_pyramid=args['inverse_pyramid'],
        custom_tokenizer=args['custom_tokenizer'],
        pyramid_max_depth=args['pyramid_max_depth'],
        decoder_dropout=args['decoder_dropout'],
        encoder_dropout=args['encoder_dropout'],
        device=DEVICE)

    print(pyramid_ner.nnet)
    # Train and evaluate inline, mirroring run_training above (this assumes
    # MultiLabelTrainer and its train/test_model methods accept the DataLoader
    # objects and the sigmoid multi-label model directly).
    trainer = MultiLabelTrainer(pyramid_ner)
    optimizer, scheduler = get_default_sgd_optim(pyramid_ner.nnet.parameters())
    _, train_report = trainer.train(
        train_dataloader,
        optimizer=optimizer,
        scheduler=scheduler,
        dev_data=dev_dataloader,
        epochs=args['epochs'],
        patience=args['patience'],
        grad_clip=args['grad_clip'],
        restore_weights_on=args['restore_weights_on']
    )

    formatted_report = trainer.test_model(test_dataloader, out_dict=False)
    write_report(args, train_report, formatted_report)


def run_token_window_test(args: dict):
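    """Smoke-test the token-window multi-label setup on small data slices.

    Reads token windows from the raw NNE corpus, builds the token-window
    datasets and dataloaders, then runs the document-RNN and the
    sentence-transformer experiments on the same data.
    """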
    word_lexicon = load_vocab(args['lex_size'])

    print("Loading data")
    train_data = list(
        itertools.islice(
            wrg_token_window_reader("data/nne_raw/train/",
                                    args['token_window']), 256))
    dev_data = list(
        itertools.islice(
            wrg_token_window_reader("data/nne_raw/dev/", args['token_window']),
            64 * 10))
    test_data = list(
        itertools.islice(
            wrg_token_window_reader("data/nne_raw/test/",
                                    args['token_window']), 64))

    print("Generating entity lexicon")
    entities_lexicon = sorted({
        entity.name
        for data in (train_data, dev_data, test_data)
        for data_point in data
        for entity in data_point.entities
    })
    print(entities_lexicon)

    print("Instantiating train dataset")
    train_dataset = TokenWindowMultiLabelNerDataset(
        train_data,
        pyramid_max_depth=args['pyramid_max_depth'],
        token_lexicon=word_lexicon,
        entities_lexicon=entities_lexicon,
        custom_tokenizer=None,
        char_vectorizer=True,
    )
    print("Instantiating dev dataset")
    dev_dataset = TokenWindowMultiLabelNerDataset(
        dev_data,
        pyramid_max_depth=args['pyramid_max_depth'],
        token_lexicon=word_lexicon,
        entities_lexicon=entities_lexicon,
        custom_tokenizer=None,
        char_vectorizer=True,
    )
    print("Instantiating test dataset")
    test_dataset = TokenWindowMultiLabelNerDataset(
        test_data,
        pyramid_max_depth=args['pyramid_max_depth'],
        token_lexicon=word_lexicon,
        entities_lexicon=entities_lexicon,
        custom_tokenizer=None,
        char_vectorizer=True,
    )

    dev_dataloader, train_dataloader, test_dataloader = get_data_loader(
        train_dataset, dev_dataset, test_dataset, args)
    test_document_rnn(word_lexicon, train_dataloader, dev_dataloader,
                      test_dataloader, entities_lexicon, args)

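    # Fresh dataloaders for the second experiment.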
    dev_dataloader, train_dataloader, test_dataloader = get_data_loader(
        train_dataset, dev_dataset, test_dataset, args)
    test_sentence_transformer(word_lexicon, train_dataloader, dev_dataloader,
                              test_dataloader, entities_lexicon, args)