Example 1
    def __init__(self,
                 vocab_size,
                 embedding_size,
                 char_vocab_size,
                 char_embedding_size,
                 num_filter,
                 ngram_filter_size,
                 num_classes,
                 bert_weight_path=False):
        super().__init__()

        self.char_embedding = nn.Embedding(char_vocab_size,
                                           char_embedding_size)
        init.uniform_(self.char_embedding.weight, -0.1, 0.1)

        if bert_weight_path:
            self.bert = PretrainedBertEmbedder(bert_weight_path)
        else:
            self.embedding = nn.Embedding(vocab_size,
                                          embedding_dim=embedding_size)
            init.uniform_(self.embedding.weight, -0.1, 0.1)
            self.bert = None
        self.cnn_encoder = CnnEncoder(char_embedding_size,
                                      num_filters=num_filter,
                                      ngram_filter_sizes=ngram_filter_size)
        self.char_encoder = TokenCharactersEncoder(self.char_embedding,
                                                   self.cnn_encoder)
        if bert_weight_path:
            embedding_size = 768
        self.linear_layer = nn.Linear(embedding_size + num_filter, num_classes)
        init.xavier_normal_(self.linear_layer.weight)
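A minimal forward-pass sketch for a module initialized this way (assuming the non-BERT branch, a single n-gram filter size so the char-CNN output width equals num_filter, and hypothetical token_ids / char_ids inputs of shape (batch, num_tokens) and (batch, num_tokens, num_chars)):

    def forward(self, token_ids, char_ids):
        word_emb = self.embedding(token_ids)                # (batch, num_tokens, embedding_size)
        char_emb = self.char_encoder(char_ids)              # (batch, num_tokens, num_filter)
        features = torch.cat([word_emb, char_emb], dim=-1)  # matches linear_layer input size
        return self.linear_layer(features)                  # (batch, num_tokens, num_classes)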
Example 2
    def __init__(self, args, num_authors: int, out_sz: int,
                 vocab: Vocabulary):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder({"tokens": bert_embedder},
                                                      # we'll be ignoring masks so we'll need to set this to True
                                                      allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = 20, 768
        self.author_embeddings = nn.Parameter(torch.randn(num_authors, self.sk_dim), requires_grad=True)  # (m, d)

        self.attention = nn.Parameter(torch.randn(self.word_embeddings.get_output_dim(), self.sk_dim), requires_grad=True)
        # nn.Linear(self.word_embeddings.get_output_dim(), self.sk_dim)

        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=2)
        self.sigmoid = nn.Sigmoid()

        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        # self.loss = nn.CrossEntropyLoss()

        # loss related
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
Example 3
    def __init__(self, num_authors: int, out_sz: int, vocab: Vocabulary):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(
            vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim, self.time_dim = 20, 768, 32
        self.author_embeddings = nn.Parameter(torch.randn(
            num_authors, self.num_sk, self.sk_dim),
                                              requires_grad=True)  # (m, k, d)

        self.multihead_att = TempCtxAttention(h=8, d_model=self.sk_dim)

        self.attention = nn.Parameter(torch.randn(
            self.word_embeddings.get_output_dim(), self.sk_dim),
                                      requires_grad=True)
        # nn.Linear(self.word_embeddings.get_output_dim(), self.sk_dim)

        self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
Example 4
    def __init__(self, num_authors: int, out_sz: int, vocab: Vocabulary,
                 date_span: Any):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.date_span = date_span
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(
            vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        # self.num_sk, self.sk_dim, self.time_dim = 20, 768, 32
        self.num_sk, self.sk_dim, self.time_dim = 20, 768, 768

        # self.author_dim = self.sk_dim + self.time_dim
        self.author_dim = self.sk_dim

        self.author_embeddings = nn.Parameter(torch.randn(
            num_authors, self.author_dim),
                                              requires_grad=True)  # (m, d)

        # self.ctx_attention = MultiHeadCtxAttention(h=8, d_model=self.sk_dim + self.time_dim)
        self.temp_ctx_attention_ns = TempCtxAttentionNS(
            h=8,
            d_model=self.author_dim,
            d_query=self.sk_dim,
            d_time=self.time_dim)

        # temporal context
        self.time_encoder = TimeEncoder(self.time_dim,
                                        dropout=0.1,
                                        span=1,
                                        date_range=date_span)

        # layer_norm
        self.ctx_layer_norm = LayerNorm(self.author_dim)

        # loss related
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.htemp_loss = HTempLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)

        self.coherence_func = CoherenceInnerProd()
Example 5
 def create_module(self):
     self.embedding = PretrainedBertEmbedder(
         pretrained_model=self.pretrained_model, )
     return BasicTextFieldEmbedder({self.key: self.embedding},
                                   embedder_to_indexer_map={
                                       self.key: [
                                           self.key,
                                           "{}-offsets".format(self.key),
                                           "{}-type-ids".format(self.key)
                                       ]
                                   },
                                   allow_unmatched_keys=True)
Example 6
def bert_embeddings(pretrained_model: Path, training: bool = False,
                    top_layer_only: bool = True
                    ) -> BasicTextFieldEmbedder:
    "Pre-trained embeddings using BERT"
    bert = PretrainedBertEmbedder(
        requires_grad=training,
        pretrained_model=pretrained_model,
        top_layer_only=top_layer_only
    )
    word_embeddings = BasicTextFieldEmbedder(
        token_embedders={'tokens': bert},
        embedder_to_indexer_map={'tokens': ['tokens', 'tokens-offsets']},
        allow_unmatched_keys=True)
    return word_embeddings
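The 'tokens' / 'tokens-offsets' entries above only line up if the text field is indexed with a PretrainedBertIndexer registered under the same 'tokens' key. A minimal pairing sketch, assuming the same AllenNLP 0.8/0.9-era API (the model name is just an example):

from allennlp.data.token_indexers import PretrainedBertIndexer

token_indexers = {
    # produces the "tokens", "tokens-offsets" and "tokens-type-ids" arrays
    "tokens": PretrainedBertIndexer(pretrained_model="bert-base-uncased",
                                    do_lowercase=True)
}
word_embeddings = bert_embeddings(pretrained_model="bert-base-uncased")
print(word_embeddings.get_output_dim())  # 768 for bert-base models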
Example 7
def get_pretrained_bert(model_name: str = 'bert-base-uncased',
                        cache_dir: str = PATH_ALLENNLP_CACHE,
                        top_layer_only=True,
                        requires_grad=True,
                        **kwargs):
    model_path = path.join(cache_dir, 'bert', f'{model_name}.tar.gz')
    msgex.assert_path_exist(
        path_str=model_path,
        arg_name='model_path',
        extra_msg=f"the specified BERT model '{model_name}' is not found")
    model = PretrainedBertEmbedder(pretrained_model=model_path,
                                   top_layer_only=top_layer_only,
                                   requires_grad=requires_grad,
                                   **kwargs)
    return model, model.output_dim
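get_pretrained_bert loads only from a local archive at <cache_dir>/bert/<model_name>.tar.gz. A hypothetical usage sketch (the cache path below is an assumption, not from the source), wiring the returned embedder under a 'bert' key as in the other examples:

from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder

bert, bert_dim = get_pretrained_bert(model_name='bert-base-uncased',
                                     cache_dir='/data/allennlp_cache',  # assumed location
                                     requires_grad=False)               # freeze BERT weights
word_embeddings = BasicTextFieldEmbedder(
    token_embedders={'bert': bert},
    embedder_to_indexer_map={'bert': ['bert', 'bert-offsets']},
    allow_unmatched_keys=True)
assert word_embeddings.get_output_dim() == bert_dim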
Example 8
    def __init__(self, args, out_sz: int,
                 vocab: Vocabulary):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.word_embeddings = BasicTextFieldEmbedder({"tokens": bert_embedder},
                                                      # we'll be ignoring masks so we'll need to set this to True
                                                      allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(vocab, self.word_embeddings.get_output_dim())
        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        self.loss = nn.CrossEntropyLoss()
Example 9
 def __init__(
     self,
     vocab: Vocabulary,
     dropout: float,
     pool: Text = "cls",
     model_name_or_path: Text = "bert-base-uncased",
     label_namespace: str = "page_labels",
 ):
     bert = PretrainedBertEmbedder(model_name_or_path, requires_grad=True)
     super().__init__(
         vocab=vocab,
         dropout=dropout,
         label_namespace=label_namespace,
         hidden_dim=bert.get_output_dim(),
     )
     self._model_name_or_path = model_name_or_path
     self._bert = bert
     self._pool = pool
Example 10
    def __init__(self, num_authors: int, out_sz: int, vocab: Vocabulary,
                 date_span: Any, num_shift: int, span: int):
        super().__init__(vocab)

        # init word embedding
        bert_embedder = PretrainedBertEmbedder(
            pretrained_model="bert-base-uncased",
            top_layer_only=True,  # conserve memory
        )
        self.date_span = date_span
        self.word_embeddings = BasicTextFieldEmbedder(
            {"tokens": bert_embedder},
            # we'll be ignoring masks so we'll need to set this to True
            allow_unmatched_keys=True)

        self.encoder = BertSentencePooler(
            vocab, self.word_embeddings.get_output_dim())

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = 20, 768
        self.author_embeddings = nn.Parameter(torch.randn(
            num_authors, self.num_sk, self.sk_dim),
                                              requires_grad=True)  # (m, k, d)

        self.ctx_attention = TempCtxAttention(h=8, d_model=self.sk_dim)
        # layer_norm
        self.ctx_layer_norm = nn.LayerNorm(self.sk_dim)

        self.shift_temp_att = ShiftTempAttention(self.num_authors, self.sk_dim,
                                                 date_span, num_shift, span)

        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.temp_loss = TemporalLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)

        self.weight_temp = 0.3
Example 11
    def __init__(self,
                 num_authors: int,
                 out_sz: int,
                 vocab: Vocabulary,
                 date_span: Any,
                 num_shift: int,
                 spans: List,
                 encoder: Any,
                 max_vocab_size: int,
                 ignore_time: bool,
                 ns_mode: bool = False,
                 num_sk: int = 20):
        super().__init__(vocab)

        self.date_span = date_span

        self.num_authors = num_authors

        # skills dim
        self.num_sk, self.sk_dim = num_sk, 768
        self.ignore_time = ignore_time
        self.ns_mode = ns_mode
        if self.ns_mode:
            self.author_embeddings = nn.Parameter(torch.randn(
                num_authors, self.sk_dim),
                                                  requires_grad=True)  # (m, d)
        else:
            self.author_embeddings = nn.Parameter(
                torch.randn(num_authors, self.num_sk, self.sk_dim),
                requires_grad=True)  # (m, k, d)
        self.encode_type = encoder
        if self.encode_type == "bert":
            # init word embedding
            bert_embedder = PretrainedBertEmbedder(
                pretrained_model="bert-base-uncased",
                top_layer_only=True,  # conserve memory
            )
            self.word_embeddings = BasicTextFieldEmbedder(
                {"tokens": bert_embedder},
                # we'll be ignoring masks so we'll need to set this to True
                allow_unmatched_keys=True)
            self.encoder = BertSentencePooler(
                vocab, self.word_embeddings.get_output_dim())
        else:
            # prepare embeddings
            token_embedding = Embedding(num_embeddings=max_vocab_size + 2,
                                        embedding_dim=300,
                                        padding_index=0)
            self.word_embeddings: TextFieldEmbedder = BasicTextFieldEmbedder(
                {"tokens": token_embedding})

            self.encoder: Seq2VecEncoder = PytorchSeq2VecWrapper(
                nn.LSTM(self.word_embeddings.get_output_dim(),
                        hidden_size=int(self.sk_dim / 2),
                        bidirectional=True,
                        batch_first=True))

        self.ctx_attention = TempCtxAttention(h=8, d_model=self.sk_dim)
        self.ctx_layer_norm = nn.LayerNorm(self.sk_dim)  # layer_norm

        # shifted temporal attentions
        self.spans = spans
        self.span_temp_atts = nn.ModuleList()
        for span in self.spans:
            self.span_temp_atts.append(
                ShiftTempAttention(self.num_authors, self.sk_dim, date_span,
                                   num_shift, span, self.ignore_time))
        self.span_projection = nn.Linear(len(spans), 1)
        self.num_shift = num_shift

        # temporal encoder: used only for adding temporal information into token embedding
        self.time_encoder = TimeEncoder(self.sk_dim,
                                        dropout=0.1,
                                        span=spans[0],
                                        date_range=date_span)

        # loss
        # self.cohere_loss = CoherenceLoss(self.encoder.get_output_dim(), out_sz)
        # self.triplet_loss = TripletLoss(self.encoder.get_output_dim(), out_sz)
        self.temp_loss = TemporalLoss(self.encoder.get_output_dim(), out_sz)
        self.rank_loss = MarginRankLoss(self.encoder.get_output_dim(), out_sz)

        self.weight_temp = 0.3
        self.visual_id = 0
Example 12
    def __init__(
        self,
        vocab: Vocabulary,
        task_type: str,
        model_type: str,
        random_init_bert: bool,  # set True to shuffle the BERT encoder and get random init
        initializer: InitializerApplicator = InitializerApplicator()
    ) -> None:
        super().__init__(vocab)

        assert task_type in ["unary", "binary"]  # unary or binary edges
        assert model_type in ["clf", "reg"]  # classification or regression

        self.task_type = task_type
        self.model_type = model_type

        mix_params = None

        if self.task_type == "binary":  # for binary tasks train two separate mixes
            self.bert_embedder = PretrainedBertEmbedderSplitMix(
                BERT_MODEL_NAME,
                requires_grad=False,
                top_layer_only=False,
                scalar_mix_parameters=mix_params)
        else:  # for unary task train a single mix
            self.bert_embedder = PretrainedBertEmbedder(
                BERT_MODEL_NAME,
                requires_grad=False,
                top_layer_only=False,
                scalar_mix_parameters=mix_params)

        if random_init_bert:
            self.bert_embedder.bert_model.apply(init_weights)

        self.vocab = vocab

        self.num_classes = self.vocab.get_vocab_size("labels")
        self.num_classes = self.num_classes if self.num_classes > 0 else 1

        self.span_projection_dim = self.bert_embedder.output_dim

        # represent each span by its first wordpiece token
        self.span_extractor = EndpointSpanExtractor(self.span_projection_dim,
                                                    combination="x")

        if self.task_type == "binary":
            clf_input_dim = self.span_projection_dim * 2
        else:
            clf_input_dim = self.span_projection_dim

        # just a linear tag projection layer
        self.classifier = Linear(clf_input_dim, self.num_classes)

        if self.model_type == "clf":
            self.loss = torch.nn.CrossEntropyLoss()  # cross-entropy for classification
        else:
            self.loss = torch.nn.SmoothL1Loss()  # smooth L1 for regression

        self.m_acc = CategoricalAccuracy()
        self.m_fmicro = FBetaMeasure(average="micro")
        self.mse = MeanSquaredError()

        initializer(self)
Example 13
def main():
    parser = argparse.ArgumentParser(description='BM25 Pipeline reader')
    parser.add_argument(
        '--k',
        type=int,
        default=1,
        help=
        'number of evidence paragraphs to pick from the classifier (default: 1)'
    )
    parser.add_argument('--probs',
                        type=str,
                        default=None,
                        help='Pickled sentence probs file (default: None)')
    args = parser.parse_args()

    with torch.no_grad():
        bert_token_indexer = {
            'bert': PretrainedBertIndexer('scibert/vocab.txt', max_pieces=512)
        }

        pipeline_train = pickle.load(open('data/train_instances.p', 'rb'))
        pipeline_val = pickle.load(open('data/val_instances.p', 'rb'))
        pipeline_test = pickle.load(open('data/test_instances.p', 'rb'))

        pipeline_reader = PipelineDatasetReader(bert_token_indexer)
        p_train = pipeline_reader.read(pipeline_train)
        p_val = pipeline_reader.read(pipeline_val)
        p_test = pipeline_reader.read(pipeline_test)

        p_vocab = Vocabulary.from_instances(p_train + p_val + p_test)

        bert_token_embedding = PretrainedBertEmbedder('scibert/weights.tar.gz',
                                                      requires_grad=False)

        word_embeddings = BasicTextFieldEmbedder(
            {"bert": bert_token_embedding}, {"bert": ['bert']},
            allow_unmatched_keys=True)

        predictor = Oracle(word_embeddings=word_embeddings, vocab=p_vocab)

        cuda_device = 0

        if torch.cuda.is_available():
            predictor = predictor.cuda()
        else:
            cuda_device = -1

        predictor.load_state_dict(
            torch.load('model_checkpoints/f_oracle_full/best.th'))

        logger.info('Predictor model loaded successfully')
        predictor.eval()

        iterator = BasicIterator(batch_size=256)
        iterator.index_with(p_vocab)

        top_k_sentences = []
        prob_counter = 0
        for i in range(len(pipeline_test)):
            sentences = [
                ' '.join(pipeline_test[i]['sentence_span'][k][0] +
                         pipeline_test[i]['sentence_span'][k + 1][0] +
                         pipeline_test[i]['sentence_span'][k + 2][0]).lower().
                split()
                for k in range(len(pipeline_test[i]['sentence_span']) - 2)
            ]

            bm25 = BM25Okapi(sentences)

            prompt = pipeline_test[i]['I'] + pipeline_test[i]['C'] + pipeline_test[i]['O'] + \
                     ['no', 'significant', 'difference']

            doc_scores = np.array(bm25.get_scores(prompt))

            probs = list(doc_scores)
            prob_counter += len(sentences)
            sorted_sentences = sorted(zip(sentences, probs),
                                      key=lambda x: x[1],
                                      reverse=True)
            top_k = [' '.join(s[0]) for s in sorted_sentences[:args.k]]
            top_k_sentences.append({
                'I': pipeline_test[i]['I'],
                'C': pipeline_test[i]['C'],
                'O': pipeline_test[i]['O'],
                'y_label': pipeline_test[i]['y'][0][0],
                'evidence': ' '.join(top_k)
            })

        logger.info('Obtained the top sentences from the bm25 classifier')

        predictor_reader = EIDatasetReader(bert_token_indexer)
        predictor_test = predictor_reader.read(top_k_sentences)

        test_metrics = evaluate(predictor,
                                predictor_test,
                                iterator,
                                cuda_device=cuda_device,
                                batch_weight_key="")

        print('Test Data statistics:')
        for key, value in test_metrics.items():
            print(str(key) + ': ' + str(value))
Example 14
def main():
    parser = create_parser()
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    model_id = create_model_id(args)

    if not path.exists(args.out_dir):
        print("# Create directory: {}".format(args.out_dir))
        os.mkdir(args.out_dir)

    # log file
    out_dir = path.join(args.out_dir, "out-" + model_id)
    print("# Create output directory: {}".format(out_dir))
    os.mkdir(out_dir)
    log = StandardLogger(path.join(out_dir, "log-" + model_id + ".txt"))
    log.write(args=args)
    write_args_log(args, path.join(out_dir, "args.json"))

    # dataset reader
    token_indexers = {
        "tokens": SingleIdTokenIndexer(),
        "elmo": ELMoTokenCharactersIndexer(),
        "bert": PretrainedBertIndexer(BERT_MODEL, use_starting_offsets=True),
        "xlnet": PretrainedTransformerIndexer(XLNET_MODEL, do_lowercase=False)
    }

    reader = SrlDatasetReader(token_indexers)

    # dataset
    train_dataset = reader.read_with_ratio(args.train, args.data_ratio)
    validation_dataset = reader.read_with_ratio(args.dev, 100)
    pseudo_dataset = reader.read_with_ratio(
        args.pseudo, args.data_ratio) if args.pseudo else []
    all_dataset = train_dataset + validation_dataset + pseudo_dataset
    if args.test:
        test_dataset = reader.read_with_ratio(args.test, 100)
        all_dataset += test_dataset

    vocab = Vocabulary.from_instances(all_dataset)

    # embedding
    input_size = args.binary_dim * 2 if args.multi_predicate else args.binary_dim
    if args.glove:
        token_embedding = Embedding(
            num_embeddings=vocab.get_vocab_size('tokens'),
            embedding_dim=GLOVE_DIM,
            trainable=True,
            pretrained_file=GLOVE)
        input_size += GLOVE_DIM
    else:
        token_embedding = Embedding(
            num_embeddings=vocab.get_vocab_size('tokens'),
            embedding_dim=args.embed_dim,
            trainable=True)
        input_size += args.embed_dim
    token_embedders = {"tokens": token_embedding}

    if args.elmo:
        elmo_embedding = ElmoTokenEmbedder(options_file=ELMO_OPT,
                                           weight_file=ELMO_WEIGHT)
        token_embedders["elmo"] = elmo_embedding
        input_size += ELMO_DIM

    if args.bert:
        bert_embedding = PretrainedBertEmbedder(BERT_MODEL)
        token_embedders["bert"] = bert_embedding
        input_size += BERT_DIM

    if args.xlnet:
        xlnet_embedding = PretrainedTransformerEmbedder(XLNET_MODEL)
        token_embedders["xlnet"] = xlnet_embedding
        input_size += XLNET_DIM

    word_embeddings = BasicTextFieldEmbedder(token_embedders=token_embedders,
                                             allow_unmatched_keys=True,
                                             embedder_to_indexer_map={
                                                 "bert":
                                                 ["bert", "bert-offsets"],
                                                 "elmo": ["elmo"],
                                                 "tokens": ["tokens"],
                                                 "xlnet": ["xlnet"]
                                             })
    # encoder
    if args.highway:
        lstm = PytorchSeq2SeqWrapper(
            StackedAlternatingLstm(input_size=input_size,
                                   hidden_size=args.hidden_dim,
                                   num_layers=args.n_layers,
                                   recurrent_dropout_probability=args.dropout))
    else:
        pytorch_lstm = torch.nn.LSTM(input_size=input_size,
                                     hidden_size=args.hidden_dim,
                                     num_layers=int(args.n_layers / 2),
                                     batch_first=True,
                                     dropout=args.dropout,
                                     bidirectional=True)
        # initialize
        for name, param in pytorch_lstm.named_parameters():
            if 'weight_ih' in name:
                torch.nn.init.xavier_uniform_(param.data)
            elif 'weight_hh' in name:
                # Wii, Wif, Wic, Wio
                for n in range(4):
                    torch.nn.init.orthogonal_(
                        param.data[args.hidden_dim * n:args.hidden_dim *
                                   (n + 1)])
            elif 'bias' in name:
                param.data.fill_(0)

        lstm = PytorchSeq2SeqWrapper(pytorch_lstm)

    # model
    hidden_dim = args.hidden_dim if args.highway else args.hidden_dim * 2  # torch.nn.LSTM concatenates both directions, so double the size
    model = SemanticRoleLabelerWithAttention(
        vocab=vocab,
        text_field_embedder=word_embeddings,
        encoder=lstm,
        binary_feature_dim=args.binary_dim,
        embedding_dropout=args.embed_dropout,
        attention_dropout=0.0,
        use_attention=args.attention,
        use_multi_predicate=args.multi_predicate,
        hidden_dim=hidden_dim)

    if args.model:
        print("# Load model parameter: {}".format(args.model))
        with open(args.model, 'rb') as f:
            state_dict = torch.load(f, map_location='cpu')
            model.load_state_dict(state_dict)

    if torch.cuda.is_available():
        cuda_device = 0
        model = model.cuda(cuda_device)
    else:
        cuda_device = -1

    # optimizer
    if args.optimizer == "Adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    elif args.optimizer == "SGD":
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
    elif args.optimizer == "Adadelta":
        optimizer = torch.optim.Adadelta(model.parameters(), rho=0.95)
    else:
        raise ValueError("unsupported value: '{}'".format(args.optimizer))

    # iterator
    # iterator = BucketIterator(batch_size=args.batch, sorting_keys=[("tokens", "num_tokens")])
    iterator = BasicIterator(batch_size=args.batch)
    iterator.index_with(vocab)

    if not args.test_only:
        # Train
        print("# Train Method: {}".format(args.train_method))
        print("# Start Train", flush=True)
        if args.train_method == "concat":
            trainer = Trainer(model=model,
                              optimizer=optimizer,
                              iterator=iterator,
                              train_dataset=train_dataset + pseudo_dataset,
                              validation_dataset=validation_dataset,
                              validation_metric="+f1-measure-overall",
                              patience=args.early_stopping,
                              num_epochs=args.max_epoch,
                              num_serialized_models_to_keep=5,
                              grad_clipping=args.grad_clipping,
                              serialization_dir=out_dir,
                              cuda_device=cuda_device)
            trainer.train()
        elif args.train_method == "pre-train":
            pre_train_out_dir = path.join(out_dir, "pre-train")
            fine_tune_out_dir = path.join(out_dir, "fine-tune")
            os.mkdir(pre_train_out_dir)
            os.mkdir(fine_tune_out_dir)

            trainer = Trainer(model=model,
                              optimizer=optimizer,
                              iterator=iterator,
                              train_dataset=pseudo_dataset,
                              validation_dataset=validation_dataset,
                              validation_metric="+f1-measure-overall",
                              patience=args.early_stopping,
                              num_epochs=args.max_epoch,
                              num_serialized_models_to_keep=3,
                              grad_clipping=args.grad_clipping,
                              serialization_dir=pre_train_out_dir,
                              cuda_device=cuda_device)
            trainer.train()

            if args.optimizer == "Adam":
                optimizer = torch.optim.Adam(model.parameters(),
                                             lr=args.learning_rate)
            elif args.optimizer == "SGD":
                optimizer = torch.optim.SGD(model.parameters(),
                                            lr=args.learning_rate)
            elif args.optimizer == "Adadelta":
                optimizer = torch.optim.Adadelta(model.parameters(), rho=0.95)
            else:
                raise ValueError("unsupported value: '{}'".format(
                    args.optimizer))
            trainer = Trainer(model=model,
                              optimizer=optimizer,
                              iterator=iterator,
                              train_dataset=train_dataset,
                              validation_dataset=validation_dataset,
                              validation_metric="+f1-measure-overall",
                              patience=args.early_stopping,
                              num_epochs=args.max_epoch,
                              num_serialized_models_to_keep=3,
                              grad_clipping=args.grad_clipping,
                              serialization_dir=fine_tune_out_dir,
                              cuda_device=cuda_device)
            trainer.train()
        else:
            raise ValueError("Unsupported Value '{}'".format(
                args.train_method))

    # Test
    if args.test:
        print("# Test")
        result = evaluate(model=model,
                          instances=test_dataset,
                          data_iterator=iterator,
                          cuda_device=cuda_device,
                          batch_weight_key="")
        with open(path.join(out_dir, "test.score"), 'w') as fo:
            json.dump(result, fo)

    log.write_endtime()
Example 15
def main():
    parser = argparse.ArgumentParser(description='Evidence sentence classifier')
    parser.add_argument('--k', type=int, default=1,
                        help='number of evidence paragraphs to pick from the classifier (default: 1)')
    parser.add_argument('--probs', type=str, default=None,
                        help='Pickled sentence probs file (default: None)')
    args = parser.parse_args()

    with torch.no_grad():
        bert_token_indexer = {'bert': PretrainedBertIndexer('scibert/vocab.txt', max_pieces=512)}

        pipeline_train = pickle.load(open('data/train_instances.p', 'rb'))
        pipeline_val = pickle.load(open('data/val_instances.p', 'rb'))
        pipeline_test = pickle.load(open('data/test_instances.p', 'rb'))

        pipeline_reader = PipelineDatasetReader(bert_token_indexer)
        p_train = pipeline_reader.read(pipeline_train)
        p_val = pipeline_reader.read(pipeline_val)
        p_test = pipeline_reader.read(pipeline_test)

        p_vocab = Vocabulary.from_instances(p_train + p_val + p_test)

        bert_token_embedding = PretrainedBertEmbedder(
            'scibert/weights.tar.gz', requires_grad=False
        )

        word_embeddings = BasicTextFieldEmbedder(
            {"bert": bert_token_embedding},
            {"bert": ['bert']},
            allow_unmatched_keys=True
        )

        ev_classifier = Classifier(word_embeddings=word_embeddings,
                                   vocab=p_vocab,
                                   loss='bce',
                                   hinge_margin=0)
        predictor = Oracle(word_embeddings=word_embeddings,
                           vocab=p_vocab)

        cuda_device = 0

        if torch.cuda.is_available():
            ev_classifier = ev_classifier.cuda()
            predictor = predictor.cuda()
        else:
            cuda_device = -1

        ev_classifier.load_state_dict(torch.load('model_checkpoints/f_evidence_sentence_classifier_para/best.th'))
        predictor.load_state_dict(torch.load('model_checkpoints/f_oracle_full/best.th'))

        logger.info('Classifier and Predictor models loaded successfully')
        ev_classifier.eval()
        predictor.eval()

        iterator = BasicIterator(batch_size=256)
        iterator.index_with(p_vocab)

        if args.probs is None:
            iterator_obj = iterator(p_test, num_epochs=1, shuffle=False)
            generator_tqdm = Tqdm.tqdm(iterator_obj, total=iterator.get_num_batches(p_test))

            output_probs = []
            for batch in generator_tqdm:
                batch = nn_util.move_to_device(batch, cuda_device)
                probs = ev_classifier.predict_evidence_probs(**batch)
                probs = probs.cpu().numpy()
                output_probs.append(probs)

            output_probs = [i for item in output_probs for i in item]
            logger.info('Obtained all sentence evidence probabilities - total {}'.format(len(output_probs)))
            pickle.dump(output_probs, open('sentence_ev_probs.p', 'wb'))

        else:
            output_probs = pickle.load(open(args.probs, 'rb'))

        top_k_sentences = []
        prob_counter = 0
        for i in range(len(pipeline_test)):
            sentences = [' '.join(pipeline_test[i]['sentence_span'][k][0] + pipeline_test[i]['sentence_span'][k + 1][0]
                                  + pipeline_test[i]['sentence_span'][k + 2][0])
                         for k in range(len(pipeline_test[i]['sentence_span']) - 2)]
            probs = list(output_probs[prob_counter: prob_counter + len(sentences)])
            prob_counter += len(sentences)
            sorted_sentences = sorted(zip(sentences, probs), key=lambda x: x[1], reverse=True)
            top_k = [s[0] for s in sorted_sentences[:args.k]]
            top_k_sentences.append({'I': pipeline_test[i]['I'],
                                    'C': pipeline_test[i]['C'],
                                    'O': pipeline_test[i]['O'],
                                    'y_label': pipeline_test[i]['y'][0][0],
                                    'evidence': ' '.join(top_k)})

        logger.info('Obtained the top sentences from the evidence classifier')

        predictor_reader = EIDatasetReader(bert_token_indexer)
        predictor_test = predictor_reader.read(top_k_sentences)

        test_metrics = evaluate(predictor, predictor_test, iterator,
                                cuda_device=cuda_device,
                                batch_weight_key="")

        print('Test Data statistics:')
        for key, value in test_metrics.items():
            print(str(key) + ': ' + str(value))
Example 16
def main():
    parser = argparse.ArgumentParser(description='Evidence Inference experiments')
    parser.add_argument('--cuda_device', type=int, default=0,
                        help='GPU number (default: 0)')
    parser.add_argument('--epochs', type=int, default=2,
                        help='upper epoch limit (default: 2)')
    parser.add_argument('--patience', type=int, default=1,
                        help='trainer patience  (default: 1)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='batch size (default: 32)')
    parser.add_argument('--dropout', type=float, default=0.2,
                        help='dropout for the model (default: 0.2)')
    parser.add_argument('--model_name', type=str, default='baseline',
                        help='model name (default: baseline)')
    parser.add_argument('--tunable', action='store_true',
                        help='tune the underlying embedding model (default: False)')
    args = parser.parse_args()

    annotations = pd.read_csv('data/data/annotations_merged.csv')
    prompts = pd.read_csv('data/data/prompts_merged.csv')

    feature_dictionary = {}
    prompts_dictionary = {}

    for index, row in prompts.iterrows():
        prompts_dictionary[row['PromptID']] = [row['Outcome'], row['Intervention'], row['Comparator']]

    for index, row in annotations.iterrows():
        if row['PMCID'] not in feature_dictionary:
            feature_dictionary[row['PMCID']] = []
        feature_dictionary[row['PMCID']].append([row['Annotations'], row['Label']]
                                                + prompts_dictionary[row['PromptID']])

    train = []
    valid = []
    test = []

    with open('data/splits/train_article_ids.txt') as train_file:
        for line in train_file:
            train.append(int(line.strip()))

    with open('data/splits/validation_article_ids.txt') as valid_file:
        for line in valid_file:
            valid.append(int(line.strip()))

    with open('data/splits/test_article_ids.txt') as test_file:
        for line in test_file:
            test.append(int(line.strip()))

    bert_token_indexer = {'bert': PretrainedBertIndexer('scibert/vocab.txt', max_pieces=512)}

    reader = EIDatasetReader(bert_token_indexer, feature_dictionary)
    train_data = reader.read(train)
    valid_data = reader.read(valid)
    test_data = reader.read(test)

    vocab = Vocabulary.from_instances(train_data + valid_data + test_data)

    bert_token_embedding = PretrainedBertEmbedder(
        'scibert/weights.tar.gz', requires_grad=args.tunable
    )

    word_embeddings = BasicTextFieldEmbedder(
        {"bert": bert_token_embedding},
        {"bert": ['bert']},
        allow_unmatched_keys=True
    )

    model = Baseline(word_embeddings, vocab)

    cuda_device = args.cuda_device

    if torch.cuda.is_available():
        model = model.cuda(cuda_device)
    else:
        cuda_device = -1

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    iterator = BucketIterator(batch_size=args.batch_size,
                              sorting_keys=[('intervention', 'num_tokens')],
                              padding_noise=0.1)
    iterator.index_with(vocab)

    serialization_dir = 'model_checkpoints/' + args.model_name

    trainer = Trainer(model=model,
                      optimizer=optimizer,
                      iterator=iterator,
                      train_dataset=train_data,
                      validation_dataset=test_data,
                      patience=args.patience,
                      validation_metric='+accuracy',
                      num_epochs=args.epochs,
                      cuda_device=cuda_device,
                      serialization_dir=serialization_dir)

    result = trainer.train()
    for key in result:
        print(str(key) + ': ' + str(result[key]))

    test_metrics = evaluate(trainer.model, test_data, iterator,
                            cuda_device=cuda_device,
                            batch_weight_key="")

    print('Test Data statistics:')
    for key, value in test_metrics.items():
        print(str(key) + ': ' + str(value))
Example 17
def main():
    parser = argparse.ArgumentParser(description='Evidence oracle QA')
    parser.add_argument('--epochs', type=int, default=5,
                        help='upper epoch limit (default: 5)')
    parser.add_argument('--patience', type=int, default=1,
                        help='trainer patience  (default: 1)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='batch size (default: 32)')
    parser.add_argument('--model_name', type=str, default='sentence_oracle_bert',
                        help='model name (default: sentence_oracle_bert)')
    parser.add_argument('--tunable', action='store_true',
                        help='tune the underlying embedding model (default: False)')
    parser.add_argument('--ev_type', type=str, default='sentence',
                        help='how to train the oracle - sentence or full (evidence) (default: sentence)')
    args = parser.parse_args()

    if args.ev_type == 'sentence':
        train = pickle.load(open('data/oracle_train.p', 'rb'))
        valid = pickle.load(open('data/oracle_val.p', 'rb'))
        test = pickle.load(open('data/oracle_test.p', 'rb'))
    elif args.ev_type == 'full':
        train = pickle.load(open('data/oracle_full_train.p', 'rb'))
        valid = pickle.load(open('data/oracle_full_val.p', 'rb'))
        test = pickle.load(open('data/oracle_full_test.p', 'rb'))
    else:
        print('ev_type should be either sentence or full')
        return

    bert_token_indexer = {'bert': PretrainedBertIndexer('scibert/vocab.txt', max_pieces=512)}

    pipeline_train = pickle.load(open('data/train_instances.p', 'rb'))
    pipeline_val = pickle.load(open('data/val_instances.p', 'rb'))
    pipeline_test = pickle.load(open('data/test_instances.p', 'rb'))

    pipeline_reader = PipelineDatasetReader(bert_token_indexer)
    p_train = pipeline_reader.read(pipeline_train)
    p_val = pipeline_reader.read(pipeline_val)
    p_test = pipeline_reader.read(pipeline_test)

    p_vocab = Vocabulary.from_instances(p_train + p_val + p_test)

    reader = EIDatasetReader(bert_token_indexer)
    train_data = reader.read(train)
    valid_data = reader.read(valid)
    test_data = reader.read(test)

    bert_token_embedding = PretrainedBertEmbedder(
        'scibert/weights.tar.gz', requires_grad=args.tunable
    )

    word_embeddings = BasicTextFieldEmbedder(
        {"bert": bert_token_embedding},
        {"bert": ['bert']},
        allow_unmatched_keys=True
    )

    model = Oracle(word_embeddings, p_vocab)

    cuda_device = list(range(torch.cuda.device_count()))

    if torch.cuda.is_available():
        model = model.cuda()
    else:
        cuda_device = -1

    t_total = len(train_data) // args.epochs

    optimizer = BertAdam(model.parameters(), lr=1e-5, warmup=0.05, t_total=t_total)

    iterator = BucketIterator(batch_size=args.batch_size,
                              sorting_keys=[('comb_prompt_ev', 'num_tokens')],
                              padding_noise=0.1)
    iterator.index_with(p_vocab)

    serialization_dir = 'model_checkpoints/' + args.model_name

    trainer = Trainer(model=model,
                      optimizer=optimizer,
                      iterator=iterator,
                      train_dataset=train_data,
                      validation_dataset=valid_data,
                      patience=args.patience,
                      validation_metric='+accuracy',
                      num_epochs=args.epochs,
                      cuda_device=cuda_device,
                      serialization_dir=serialization_dir)

    result = trainer.train()
    for key in result:
        print(str(key) + ': ' + str(result[key]))

    if cuda_device != -1:
        cuda_device = 0
    test_metrics = evaluate(trainer.model, test_data, iterator,
                            cuda_device=cuda_device,
                            batch_weight_key="")

    print('Test Data statistics:')
    for key, value in test_metrics.items():
        print(str(key) + ': ' + str(value))
Example 18
    iterator = BucketIterator(
        batch_size=batch_size,
        # This is for testing. To see how big of batch size the GPU can handle.
        biggest_batch_first=True,
        sorting_keys=[("tokens", "num_tokens")],
    )
    iterator.index_with(vocab)

    linguistic_features_embedding = Embedding(
        num_embeddings=max_vocab_size + 2,
        embedding_dim=linguistic_features_embedding_dim,
        # padding_index=0 I do not understand what it does
    )
    bert_embedder = PretrainedBertEmbedder(
        pretrained_model=bert_mode,
        top_layer_only=False,
        requires_grad=bert_finetuning,
    )
    word_embedder = BasicTextFieldEmbedder(
        {
            "bert": bert_embedder,
            "deps": linguistic_features_embedding,
            "ner": linguistic_features_embedding,
            "pos": linguistic_features_embedding,
            "lang": linguistic_features_embedding,
        }, {
            "bert": {
                "input_ids": "bert",
                "offsets": "bert-offsets"
            },
            "deps": {
def get_model(pretrained_file: str, WORD_EMB_DIM: int, vocab: Vocabulary,
              num_tags: int):
    """
    This creates a new model and returns it along with some other variables.
    :param pretrained_file:
    :param WORD_EMB_DIM:
    :param vocab:
    :param num_tags:
    :return:
    """

    CNN_EMB_DIM = 128
    CHAR_EMB_DIM = 16

    weight = _read_pretrained_embeddings_file(pretrained_file, WORD_EMB_DIM,
                                              vocab, "tokens")
    token_embedding = Embedding(num_embeddings=weight.shape[0],
                                embedding_dim=weight.shape[1],
                                weight=weight,
                                vocab_namespace="tokens")
    char_embedding = Embedding(
        num_embeddings=vocab.get_vocab_size("token_characters"),
        embedding_dim=CHAR_EMB_DIM,
        vocab_namespace="token_characters")

    char_encoder = CnnEncoder(
        embedding_dim=CHAR_EMB_DIM,
        num_filters=CNN_EMB_DIM,
        ngram_filter_sizes=[3],
        conv_layer_activation=Activation.by_name("relu")())
    token_characters_embedding = TokenCharactersEncoder(
        embedding=char_embedding, encoder=char_encoder)

    if USING_BERT:
        print("USING BERT EMBEDDINGS")
        bert_emb = PretrainedBertEmbedder("bert-base-multilingual-cased")
        tfe = BasicTextFieldEmbedder(
            {
                "bert": bert_emb,
                "token_characters": token_characters_embedding
            },
            embedder_to_indexer_map={
                "bert": ["bert", "bert-offsets"],
                "token_characters": ["token_characters"]
            },
            allow_unmatched_keys=True)

        EMBEDDING_DIM = CNN_EMB_DIM + 768
    else:
        EMBEDDING_DIM = CNN_EMB_DIM + WORD_EMB_DIM
        tfe = BasicTextFieldEmbedder({
            "tokens":
            token_embedding,
            "token_characters":
            token_characters_embedding
        })

    HIDDEN_DIM = 256

    encoder = PytorchSeq2SeqWrapper(
        torch.nn.LSTM(EMBEDDING_DIM,
                      HIDDEN_DIM,
                      batch_first=True,
                      bidirectional=True,
                      dropout=0.5,
                      num_layers=2))

    model = MarginalCrfTagger(vocab,
                              tfe,
                              encoder,
                              num_tags,
                              include_start_end_transitions=False,
                              calculate_span_f1=True,
                              dropout=0.5,
                              label_encoding="BIOUL",
                              constrain_crf_decoding=True)

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    if torch.cuda.is_available():
        print("Using GPU")
        cuda_device = 0
        model = model.cuda(cuda_device)
    else:
        cuda_device = -1

    return model, optimizer, cuda_device
Example 20
dev_dataset = reader.read('data/stanfordSentimentTreebank/trees/dev.txt')

# You can optionally specify the minimum count of tokens/labels.
# `min_count={'tokens':3}` here means that any tokens that appear less than three times
# will be ignored and not included in the vocabulary.
vocab = Vocabulary.from_instances(train_dataset + dev_dataset,
                                  min_count={'tokens': 3})

# token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
#                             embedding_dim=EMBEDDING_DIM)

# BasicTextFieldEmbedder takes a dict - we need an embedding just for tokens,
# not for labels, which are used as-is as the "answer" of the sentence classification
# word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})

bert_embedder = PretrainedBertEmbedder(pretrained_model='bert-base-uncased',
                                       top_layer_only=True)

word_embeddings: TextFieldEmbedder = BasicTextFieldEmbedder(
    {"tokens": bert_embedder}, allow_unmatched_keys=True)

# Seq2VecEncoder is a neural network abstraction that takes a sequence of something
# (usually a sequence of embedded word vectors), processes it, and returns a single
# vector. Oftentimes this is an RNN-based architecture (e.g., LSTM or GRU), but
# AllenNLP also supports CNNs and other simple architectures (for example,
# just averaging over the input vectors).
# encoder = PytorchSeq2VecWrapper(
#     torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))

# HW
encoder = TransformerSeq2VecEncoder(EMBEDDING_DIM,
                                    HIDDEN_DIM,
Example 21
    def __init__(
        self,
        vocab: Vocabulary,
        attention: Attention,
        beam_size: int,
        max_decoding_steps: int,
        target_embedding_dim: int = 30,
        copy_token: str = "@COPY@",
        source_namespace: str = "bert",
        target_namespace: str = "target_tokens",
        tensor_based_metric: Metric = None,
        token_based_metric: Metric = None,
        initializer: InitializerApplicator = InitializerApplicator(),
    ) -> None:
        super().__init__(vocab)
        self._source_namespace = source_namespace
        self._target_namespace = target_namespace
        self._src_start_index = self.vocab.get_token_index(
            START_SYMBOL, self._source_namespace)
        self._src_end_index = self.vocab.get_token_index(
            END_SYMBOL, self._source_namespace)
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        self._oov_index = self.vocab.get_token_index(self.vocab._oov_token,
                                                     self._target_namespace)
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                     self._target_namespace)
        self._copy_index = self.vocab.add_token_to_namespace(
            copy_token, self._target_namespace)

        self._tensor_based_metric = tensor_based_metric or BLEU(
            exclude_indices={
                self._pad_index, self._end_index, self._start_index
            })
        self._token_based_metric = token_based_metric

        self._target_vocab_size = self.vocab.get_vocab_size(
            self._target_namespace)

        # Encoding modules.
        bert_token_embedding = PretrainedBertEmbedder('bert-base-uncased',
                                                      requires_grad=True)

        self._source_embedder = bert_token_embedding
        self._encoder = PassThroughEncoder(
            input_dim=self._source_embedder.get_output_dim())

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        # We arbitrarily set the decoder's input dimension to be the same as the output dimension.
        self.encoder_output_dim = self._encoder.get_output_dim()
        self.decoder_output_dim = self.encoder_output_dim
        self.decoder_input_dim = self.decoder_output_dim

        target_vocab_size = self.vocab.get_vocab_size(self._target_namespace)

        # The decoder input will be a function of the embedding of the previous predicted token,
        # an attended encoder hidden state called the "attentive read", and another
        # weighted sum of the encoder hidden state called the "selective read".
        # While the weights for the attentive read are calculated by an `Attention` module,
        # the weights for the selective read are simply the predicted probabilities
        # corresponding to each token in the source sentence that matches the target
        # token from the previous timestep.
        self._target_embedder = Embedding(target_vocab_size,
                                          target_embedding_dim)
        self._attention = attention
        self._input_projection_layer = Linear(
            target_embedding_dim + self.encoder_output_dim * 2,
            self.decoder_input_dim)

        # We then run the projected decoder input through an LSTM cell to produce
        # the next hidden state.
        self._decoder_cell = LSTMCell(self.decoder_input_dim,
                                      self.decoder_output_dim)

        # We create a "generation" score for each token in the target vocab
        # with a linear projection of the decoder hidden state.
        self._output_generation_layer = Linear(self.decoder_output_dim,
                                               target_vocab_size)

        # We create a "copying" score for each source token by applying a non-linearity
        # (tanh) to a linear projection of the encoded hidden state for that token,
        # and then taking the dot product of the result with the decoder hidden state.
        self._output_copying_layer = Linear(self.encoder_output_dim,
                                            self.decoder_output_dim)

        # At prediction time, we'll use a beam search to find the best target sequence.
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        initializer(self)
Example 22
def main():
    parser = argparse.ArgumentParser(
        description='Evidence sentence classifier')
    parser.add_argument('--epochs',
                        type=int,
                        default=5,
                        help='upper epoch limit (default: 5)')
    parser.add_argument('--patience',
                        type=int,
                        default=1,
                        help='trainer patience  (default: 1)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=8,
                        help='batch size (default: 8)')
    parser.add_argument(
        '--loss',
        type=str,
        default='hinge',
        help=
        'loss function to train the model - choose bce or hinge (default: hinge)'
    )
    parser.add_argument(
        '--hinge_margin',
        type=float,
        default=0.5,
        help='the margin for the hinge loss, if used (default: 0.5)')
    parser.add_argument('--model_name',
                        type=str,
                        default='ev_classifier_bert',
                        help='model name (default: ev_classifier_bert)')
    parser.add_argument(
        '--tunable',
        action='store_true',
        help='tune the underlying embedding model (default: False)')
    args = parser.parse_args()

    if args.loss not in ['bce', 'hinge']:
        print('Loss must be bce or hinge')
        return

    bert_token_indexer = {
        'bert': PretrainedBertIndexer('scibert/vocab.txt', max_pieces=512)
    }

    pipeline_train = pickle.load(open('data/train_instances.p', 'rb'))
    pipeline_val = pickle.load(open('data/val_instances.p', 'rb'))
    pipeline_test = pickle.load(open('data/test_instances.p', 'rb'))

    pipeline_reader = PipelineDatasetReader(bert_token_indexer)
    p_train = pipeline_reader.read(pipeline_train)
    p_val = pipeline_reader.read(pipeline_val)
    p_test = pipeline_reader.read(pipeline_test)

    p_vocab = Vocabulary.from_instances(p_train + p_val + p_test)

    classifier_train = pickle.load(open('data/classifier_train.p', 'rb'))
    classifier_val = pickle.load(open('data/classifier_val.p', 'rb'))

    reader = EvidenceDatasetReader(bert_token_indexer)
    train_data = reader.read(classifier_train)
    valid_data = reader.read(classifier_val)

    bert_token_embedding = PretrainedBertEmbedder('scibert/weights.tar.gz',
                                                  requires_grad=args.tunable)

    word_embeddings = BasicTextFieldEmbedder({"bert": bert_token_embedding},
                                             {"bert": ['bert']},
                                             allow_unmatched_keys=True)

    model = Classifier(word_embeddings=word_embeddings,
                       vocab=p_vocab,
                       loss=args.loss,
                       hinge_margin=args.hinge_margin)

    cuda_device = list(range(torch.cuda.device_count()))

    if torch.cuda.is_available():
        model = model.cuda()
    else:
        cuda_device = -1

    t_total = len(train_data) // args.epochs

    optimizer = BertAdam(model.parameters(),
                         lr=2e-5,
                         warmup=0.1,
                         t_total=t_total)

    iterator = BucketIterator(batch_size=args.batch_size,
                              sorting_keys=[('comb_evidence', 'num_tokens')],
                              padding_noise=0.1,
                              biggest_batch_first=True)
    iterator.index_with(p_vocab)

    serialization_dir = 'model_checkpoints/' + args.model_name

    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        iterator=iterator,
        train_dataset=train_data,
        validation_dataset=valid_data,
        patience=args.patience,
        validation_metric='+accuracy',
        num_epochs=args.epochs,
        cuda_device=cuda_device,
        # learning_rate_scheduler=scheduler,
        serialization_dir=serialization_dir)

    result = trainer.train()
    for key in result:
        print(str(key) + ': ' + str(result[key]))