    def make_masks(self):
        """
        :rtype: Variable(shape=(1, bilstm_dim), dtype=np.float32), Variable(shape=(1, bilstm_dim), dtype=np.float32)
        """
        mask_bwd = utils.convert_ndarray_to_variable(self.mask_bwd, seq=False) # (1, bilstm_dim)
        mask_fwd = utils.convert_ndarray_to_variable(self.mask_fwd, seq=False) # (1, bilstm_dim)
        return mask_bwd, mask_fwd

    def pad_edu_vectors(self, edu_vectors):
        """
        :type edu_vectors: Variable(shape=(n_edus, bilstm_dim), dtype=np.float32)
        :rtype: Variable(shape=(n_edus+2, bilstm_dim), dtype=np.float32)
        """
        start_id = utils.convert_ndarray_to_variable(self.START_ID, seq=False) # (1,)
        stop_id = utils.convert_ndarray_to_variable(self.STOP_ID, seq=False) # (1,)
        start_vector = self.embed_boundary(start_id) # (1, bilstm_dim)
        stop_vector = self.embed_boundary(stop_id) # (1, bilstm_dim)
        padded_edu_vectors = F.vstack([start_vector, edu_vectors, stop_vector]) # (n_edus+2, bilstm_dim)
        return padded_edu_vectors
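
A minimal NumPy sketch (toy values, not the learned boundary embeddings) of what the padding above produces: the START and STOP rows are stacked around the EDU vectors, so EDU i ends up at row i+1 of the padded matrix.

import numpy as np

bilstm_dim = 4
edu_vectors = np.arange(3 * bilstm_dim, dtype=np.float32).reshape(3, bilstm_dim)  # (n_edus=3, bilstm_dim)
start_vector = np.full((1, bilstm_dim), -1.0, dtype=np.float32)  # stand-in for embed_boundary(START_ID)
stop_vector = np.full((1, bilstm_dim), -2.0, dtype=np.float32)   # stand-in for embed_boundary(STOP_ID)

padded_edu_vectors = np.vstack([start_vector, edu_vectors, stop_vector])  # (n_edus+2, bilstm_dim)
assert np.array_equal(padded_edu_vectors[1], edu_vectors[0])  # EDU i sits at padded row i+1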
Example 3
    def forward(self, xs, train):
        """
        xs: list of list of int (word IDs)
        ys: list of length T of (N, T) score matrices
        """
        xs, ms = utils.padding(xs, head=True, with_mask=True)  # (N, T), (N, T)
        xs = utils.convert_ndarray_to_variable(xs, seq=True,
                                               train=train)  # T x (N,)
        ms = utils.convert_ndarray_to_variable(ms, seq=True,
                                               train=train)  # T x (N,)

        es = self.embed_words(xs, train=train)
        g = self.aggregate(es, ms)
        # Note: Here, we assume that "es" follows the original order
        ys, outputs = self.reorder(es, g, ms, train=train)  # bottleneck
        return ys, outputs
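
The shape comments above say the padded (N, T) id matrix is turned into a time-major list of T arrays of shape (N,). The helper below is only a guess at that contract, written as a stand-in for utils.convert_ndarray_to_variable(..., seq=True), not the project's actual function.

import numpy as np

def convert_to_time_major_sketch(xs_padded):
    # Hypothetical: split an (N, T) matrix column-wise into a length-T list of (N,) arrays.
    return [xs_padded[:, t] for t in range(xs_padded.shape[1])]

xs_padded = np.array([[1, 2, 3],
                      [4, 5, 0]], dtype=np.int32)      # (N=2, T=3), 0 as a toy pad id
xs_list = convert_to_time_major_sketch(xs_padded)      # T x (N,)
print(len(xs_list), xs_list[0].shape)                  # 3 (2,)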
    def compute_span_vectors(
                self,
                edus,
                edus_postag,
                sbnds,
                pbnds,
                padded_edu_vectors,
                mask_bwd,
                mask_fwd,
                batch_spans):
        """
        :type edus: list of list of str
        :type edus_postag: list of list of str
        :type sbnds: list of (int, int)
        :type pbnds: list of (int, int)
        :type padded_edu_vectors: Variable(shape=(n_edus+2, bilstm_dim), dtype=np.float32)
        :type mask_bwd: Variable(shape=(1, bilstm_dim), dtype=np.float32)
        :type mask_fwd: Variable(shape=(1, bilstm_dim), dtype=np.float32)
        :type batch_spans: list of list of (int, int)
        :rtype: Variable(shape=(batch_size * n_spans, bilstm_dim + tempfeat_dim), dtype=np.float32)
        """
        batch_size = len(batch_spans)
        n_spans = len(batch_spans[0])
        total_spans = batch_size * n_spans
        for spans in batch_spans:
            assert len(spans) == n_spans

        # Reshape
        flatten_batch_spans = utils.flatten_lists(batch_spans) # total_spans * (int, int)
        # NOTE that indices in batch_spans should be shifted by +1 due to the boundary padding
        bm1_indices = [(b-1)+1 for b,e in flatten_batch_spans] # total_spans * int
        b_indices = [b+1 for b,e in flatten_batch_spans] # total_spans * int
        e_indices = [e+1 for b,e in flatten_batch_spans] # total_spans * int
        ep1_indices = [(e+1)+1 for b,e in flatten_batch_spans] # total_spans * int

        # Feature extraction
        bm1_padded_edu_vectors = F.get_item(padded_edu_vectors, bm1_indices) # (total_spans, bilstm_dim)
        b_padded_edu_vectors = F.get_item(padded_edu_vectors, b_indices) # (total_spans, bilstm_dim)
        e_padded_edu_vectors = F.get_item(padded_edu_vectors, e_indices) # (total_spans, bilstm_dim)
        ep1_padded_edu_vectors = F.get_item(padded_edu_vectors, ep1_indices) # (total_spans, bilstm_dim)
        mask_bwd = F.broadcast_to(mask_bwd, (total_spans, self.bilstm_dim)) # (total_spans, bilstm_dim)
        mask_fwd = F.broadcast_to(mask_fwd, (total_spans, self.bilstm_dim)) # (total_spans, bilstm_dim)
        span_vectors = mask_bwd * (e_padded_edu_vectors - bm1_padded_edu_vectors) \
                        + mask_fwd * (b_padded_edu_vectors - ep1_padded_edu_vectors) # (total_spans, bilstm_dim)

        # Template features
        tempfeat_vectors = self.template_feature_extractor.extract_batch_features(
                                        edus=edus,
                                        edus_postag=edus_postag,
                                        sbnds=sbnds,
                                        pbnds=pbnds,
                                        spans=flatten_batch_spans) # (total_spans, tempfeat_dim)
        tempfeat_vectors = utils.convert_ndarray_to_variable(tempfeat_vectors, seq=False) # (total_spans, tempfeat_dim)
        span_vectors = F.concat([span_vectors, tempfeat_vectors], axis=1) # (total_spans, bilstm_dim + tempfeat_dim)

        return span_vectors
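
A NumPy-only sketch of the subtraction-based span feature computed above: for a 0-based EDU span (b, e), one mask keeps h[e] - h[b-1] and the other keeps h[b] - h[e+1], with every index shifted by +1 to account for the START row. The random vectors and the half-and-half masks are toy stand-ins for the model's padded EDU vectors and mask_bwd/mask_fwd.

import numpy as np

bilstm_dim, n_edus = 6, 4
padded = np.random.randn(n_edus + 2, bilstm_dim).astype(np.float32)  # rows: START, EDU_0..EDU_3, STOP
mask_bwd = np.array([[1, 1, 1, 0, 0, 0]], dtype=np.float32)  # toy stand-in for model.mask_bwd
mask_fwd = np.array([[0, 0, 0, 1, 1, 1]], dtype=np.float32)  # toy stand-in for model.mask_fwd

def span_vector(b, e):
    # Mirrors the index arithmetic above: bm1 -> b, b -> b+1, e -> e+1, ep1 -> e+2 after the +1 shift.
    return (mask_bwd * (padded[e + 1] - padded[(b - 1) + 1])
            + mask_fwd * (padded[b + 1] - padded[(e + 1) + 1]))  # (1, bilstm_dim)

print(span_vector(1, 2).shape)  # (1, 6)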
Example 5
def forward(
        model, batch_sents, batch_labels,
        lmd, identity_penalty,
        train):
    ys, _ = model.forward(batch_sents, train=train) # T x (N,T)    
    ys = F.concat(ys, axis=0) # => (T*N, T)

    ts, M = utils.padding(batch_labels, head=True, with_mask=True) # => (N, T), (N, T)
    ts = ts.T # => (T, N)
    ts = ts.reshape(-1,) # => (T*N,)
    M = M[:,None,:] * M[:,:,None] # => (N, T, T)
    ts = utils.convert_ndarray_to_variable(ts, seq=False, train=train) # => (T*N,)
    M = utils.convert_ndarray_to_variable(M, seq=False, train=train) # => (N, T, T)

    loss = F.softmax_cross_entropy(ys, ts)
    acc = F.accuracy(ys, ts, ignore_label=-1)

    if identity_penalty:
        loss_id = loss_identity_penalty(ys, M, train=train)
        loss = loss + lmd * loss_id
    return loss, acc
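
A quick NumPy check of the reshaping convention in this loss: ys is a length-T list of (N, T) score matrices concatenated along axis 0 (time-major), so the labels must be transposed to (T, N) and flattened so that row t*N + n of ys aligns with label ts[n, t]. The mask product builds an (N, T, T) pairwise mask that is 1 only where both positions are real (non-padded) tokens. Toy, already-padded arrays:

import numpy as np

N, T = 2, 3
ts = np.array([[10, 11, 12],
               [20, 21, 22]], dtype=np.int32)    # (N, T), already padded
flat = ts.T.reshape(-1)                          # (T*N,) -> [10, 20, 11, 21, 12, 22]
t, n = 1, 1
assert flat[t * N + n] == ts[n, t]               # label of batch item n at time step t

M = np.array([[1, 1, 0],
              [1, 1, 1]], dtype=np.float32)      # (N, T) padding mask
pairwise = M[:, None, :] * M[:, :, None]         # (N, T, T)
print(pairwise.shape)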
Example 6
def train(model, decoder, sampler, max_epoch, n_init_epochs, negative_size,
          batch_size, weight_decay, gradient_clipping, optimizer_name,
          train_dataset, dev_dataset, path_train, path_valid, path_snapshot,
          path_pred, path_gold):
    """
    :type model: SpanBasedModel
    :type decoder: IncrementalCKYDecoder
    :type sampler: TreeSampler
    :type max_epoch: int
    :type n_init_epochs: int
    :type negative_size: int
    :type batch_size: int
    :type weight_decay: float
    :type gradient_clipping: float
    :type optimizer_name: str
    :type train_dataset: numpy.ndarray
    :type dev_dataset: numpy.ndarray
    :type path_train: str
    :type path_valid: str
    :type path_snapshot: str
    :type path_pred: str
    :type path_gold: str
    :rtype: None
    """
    writer_train = jsonlines.Writer(open(path_train, "w"), flush=True)
    if dev_dataset is not None:
        writer_valid = jsonlines.Writer(open(path_valid, "w"), flush=True)

    boundary_flags = [(True, False)]
    assert negative_size >= len(boundary_flags)
    negative_tree_sampler = treesamplers.NegativeTreeSampler()

    # Optimizer preparation
    if optimizer_name == "adam":
        opt = optimizers.Adam()
    else:
        raise ValueError("Invalid optimizer_name=%s" % optimizer_name)

    opt.setup(model)

    if weight_decay > 0.0:
        opt.add_hook(chainer.optimizer.WeightDecay(weight_decay))
    if gradient_clipping:
        opt.add_hook(chainer.optimizer.GradientClipping(gradient_clipping))

    n_train = len(train_dataset)
    it = 0
    bestscore_holder = utils.BestScoreHolder(scale=100.0)
    bestscore_holder.init()

    if dev_dataset is not None:
        # Initial validation
        with chainer.using_config("train", False), chainer.no_backprop_mode():
            parse(model=model,
                  decoder=decoder,
                  dataset=dev_dataset,
                  path_pred=path_pred)
            scores = metrics.rst_parseval(pred_path=path_pred,
                                          gold_path=path_gold)
            old_scores = metrics.old_rst_parseval(pred_path=path_pred,
                                                  gold_path=path_gold)
            out = {
                "epoch": 0,
                "Morey2018": {
                    "Unlabeled Precision": scores["S"]["Precision"] * 100.0,
                    "Precision_info": scores["S"]["Precision_info"],
                    "Unlabeled Recall": scores["S"]["Recall"] * 100.0,
                    "Recall_info": scores["S"]["Recall_info"],
                    "Micro F1": scores["S"]["Micro F1"] * 100.0
                },
                "Marcu2000": {
                    "Unlabeled Precision":
                    old_scores["S"]["Precision"] * 100.0,
                    "Precision_info": old_scores["S"]["Precision_info"],
                    "Unlabeled Recall": old_scores["S"]["Recall"] * 100.0,
                    "Recall_info": old_scores["S"]["Recall_info"],
                    "Micro F1": old_scores["S"]["Micro F1"] * 100.0
                }
            }
            writer_valid.write(out)
            utils.writelog(utils.pretty_format_dict(out))
        # Saving
        bestscore_holder.compare_scores(scores["S"]["Micro F1"], step=0)
        serializers.save_npz(path_snapshot, model)
        utils.writelog("Saved the model to %s" % path_snapshot)
    else:
        # Saving
        serializers.save_npz(path_snapshot, model)
        utils.writelog("Saved the model to %s" % path_snapshot)

    for epoch in range(1, max_epoch + 1):

        perm = np.random.permutation(n_train)

        ########## E-Step (BEGIN) ##########
        utils.writelog("E step ===>")

        prog_bar = pyprind.ProgBar(n_train)

        for inst_i in range(0, n_train, batch_size):

            ### Mini batch

            for data in train_dataset[inst_i:inst_i + batch_size]:

                ### One data instance

                edu_ids = data.edu_ids
                edus = data.edus
                edus_postag = data.edus_postag
                edus_head = data.edus_head
                sbnds = data.sbnds
                pbnds = data.pbnds

                with chainer.using_config("train",
                                          False), chainer.no_backprop_mode():

                    # Feature extraction
                    edu_vectors = model.forward_edus(
                        edus, edus_postag, edus_head)  # (n_edus, bilstm_dim)
                    padded_edu_vectors = model.pad_edu_vectors(
                        edu_vectors)  # (n_edus+2, bilstm_dim)
                    mask_bwd, mask_fwd = model.make_masks(
                    )  # (1, bilstm_dim), (1, bilstm_dim)

                    # Positive tree
                    if epoch <= n_init_epochs:
                        pos_sexp = sampler.sample(inputs=edu_ids,
                                                  edus=edus,
                                                  edus_head=edus_head,
                                                  sbnds=sbnds,
                                                  pbnds=pbnds)
                    else:
                        span_scores = precompute_all_span_scores(
                            model=model,
                            edus=edus,
                            edus_postag=edus_postag,
                            sbnds=sbnds,
                            pbnds=pbnds,
                            padded_edu_vectors=padded_edu_vectors,
                            mask_bwd=mask_bwd,
                            mask_fwd=mask_fwd)
                        pos_sexp = decoder.decode(span_scores=span_scores,
                                                  inputs=edu_ids,
                                                  sbnds=sbnds,
                                                  pbnds=pbnds,
                                                  use_sbnds=True,
                                                  use_pbnds=True)
                    pos_tree = treetk.sexp2tree(pos_sexp,
                                                with_nonterminal_labels=False,
                                                with_terminal_labels=False)
                    pos_tree.calc_spans()
                    pos_spans = treetk.aggregate_spans(
                        pos_tree, include_terminal=False,
                        order="post-order")  # list of (int, int)
                data.pos_spans = pos_spans  # NOTE: cache the positive spans on the data instance for reuse in the M-step below
                    prog_bar.update()
        ########## E-Step (END) ##########

        ########## M-Step (BEGIN) ##########
        utils.writelog("M step ===>")

        for inst_i in range(0, n_train, batch_size):

            # Processing one mini-batch

            # Init
            loss_bracketing, acc_bracketing = 0.0, 0.0
            actual_batchsize = 0

            for data in train_dataset[perm[inst_i:inst_i + batch_size]]:

                # Processing one instance

                edu_ids = data.edu_ids
                edus = data.edus
                edus_postag = data.edus_postag
                edus_head = data.edus_head
                sbnds = data.sbnds
                pbnds = data.pbnds
                pos_spans = data.pos_spans  # NOTE: positive spans cached during the E-step above

                # Feature extraction
                edu_vectors = model.forward_edus(
                    edus, edus_postag, edus_head)  # (n_edus, bilstm_dim)
                padded_edu_vectors = model.pad_edu_vectors(
                    edu_vectors)  # (n_edus+2, bilstm_dim)
                mask_bwd, mask_fwd = model.make_masks(
                )  # (1, bilstm_dim), (1, bilstm_dim)

                # Negative trees
                pos_neg_spans = []
                margins = []
                pos_neg_spans.append(pos_spans)
                with chainer.using_config("train",
                                          False), chainer.no_backprop_mode():
                    for use_sbnds, use_pbnds in boundary_flags:
                        span_scores = precompute_all_span_scores(
                            model=model,
                            edus=edus,
                            edus_postag=edus_postag,
                            sbnds=sbnds,
                            pbnds=pbnds,
                            padded_edu_vectors=padded_edu_vectors,
                            mask_bwd=mask_bwd,
                            mask_fwd=mask_fwd)
                        neg_bin_sexp = decoder.decode(
                            span_scores=span_scores,
                            inputs=edu_ids,
                            sbnds=sbnds,
                            pbnds=pbnds,
                            use_sbnds=use_sbnds,
                            use_pbnds=use_pbnds,
                            gold_spans=pos_spans)  # list of str
                        neg_tree = treetk.sexp2tree(
                            neg_bin_sexp,
                            with_nonterminal_labels=False,
                            with_terminal_labels=False)
                        neg_tree.calc_spans()
                        neg_spans = treetk.aggregate_spans(
                            neg_tree,
                            include_terminal=False,
                            order="pre-order")  # list of (int, int)
                        margin = compute_tree_distance(pos_spans,
                                                       neg_spans,
                                                       coef=1.0)
                        pos_neg_spans.append(neg_spans)
                        margins.append(margin)
                for _ in range(negative_size - len(boundary_flags)):
                    neg_bin_sexp = negative_tree_sampler.sample(inputs=edu_ids,
                                                                sbnds=sbnds,
                                                                pbnds=pbnds)
                    neg_tree = treetk.sexp2tree(neg_bin_sexp,
                                                with_nonterminal_labels=False,
                                                with_terminal_labels=False)
                    neg_tree.calc_spans()
                    neg_spans = treetk.aggregate_spans(
                        neg_tree, include_terminal=False,
                        order="pre-order")  # list of (int, int)
                    margin = compute_tree_distance(pos_spans,
                                                   neg_spans,
                                                   coef=1.0)
                    pos_neg_spans.append(neg_spans)
                    margins.append(margin)

                # Scoring
                pred_scores = model.forward_spans_for_bracketing(
                    edus=edus,
                    edus_postag=edus_postag,
                    sbnds=sbnds,
                    pbnds=pbnds,
                    padded_edu_vectors=padded_edu_vectors,
                    mask_bwd=mask_bwd,
                    mask_fwd=mask_fwd,
                    batch_spans=pos_neg_spans,
                    aggregate=True)  # (1+negative_size, 1)

                # Bracketing Loss
                for neg_i in range(negative_size):
                    loss_bracketing += F.clip(
                        pred_scores[1 + neg_i] + margins[neg_i] -
                        pred_scores[0], 0.0, 10000000.0)

                # Ranking Accuracy
                pred_scores = F.reshape(
                    pred_scores,
                    (1, 1 + negative_size))  # (1, 1+negative_size)
                gold_scores = np.zeros((1, ), dtype=np.int32)  # (1,)
                gold_scores = utils.convert_ndarray_to_variable(
                    gold_scores, seq=False)  # (1,)
                acc_bracketing += F.accuracy(pred_scores, gold_scores)

                actual_batchsize += 1

            # Backward & Update
            actual_batchsize = float(actual_batchsize)
            loss_bracketing = loss_bracketing / actual_batchsize
            acc_bracketing = acc_bracketing / actual_batchsize
            loss = loss_bracketing
            model.zerograds()
            loss.backward()
            opt.update()
            it += 1

            # Write log
            loss_bracketing_data = float(cuda.to_cpu(loss_bracketing.data))
            acc_bracketing_data = float(cuda.to_cpu(acc_bracketing.data))
            out = {
                "iter": it,
                "epoch": epoch,
                "progress": "%d/%d" % (inst_i + actual_batchsize, n_train),
                "progress_ratio":
                float(inst_i + actual_batchsize) / n_train * 100.0,
                "Bracketing Loss": loss_bracketing_data,
                "Ranking Accuracy": acc_bracketing_data * 100.0
            }
            writer_train.write(out)
            utils.writelog(utils.pretty_format_dict(out))
        ########## M-Step (END) ##########

        if dev_dataset is not None:
            # Validation
            with chainer.using_config("train",
                                      False), chainer.no_backprop_mode():
                parse(model=model,
                      decoder=decoder,
                      dataset=dev_dataset,
                      path_pred=path_pred)
                scores = metrics.rst_parseval(pred_path=path_pred,
                                              gold_path=path_gold)
                old_scores = metrics.old_rst_parseval(pred_path=path_pred,
                                                      gold_path=path_gold)
                out = {
                    "epoch": epoch,
                    "Morey2018": {
                        "Unlabeled Precision":
                        scores["S"]["Precision"] * 100.0,
                        "Precision_info": scores["S"]["Precision_info"],
                        "Unlabeled Recall": scores["S"]["Recall"] * 100.0,
                        "Recall_info": scores["S"]["Recall_info"],
                        "Micro F1": scores["S"]["Micro F1"] * 100.0
                    },
                    "Marcu2000": {
                        "Unlabeled Precision":
                        old_scores["S"]["Precision"] * 100.0,
                        "Precision_info": old_scores["S"]["Precision_info"],
                        "Unlabeled Recall": old_scores["S"]["Recall"] * 100.0,
                        "Recall_info": old_scores["S"]["Recall_info"],
                        "Micro F1": old_scores["S"]["Micro F1"] * 100.0
                    }
                }
                writer_valid.write(out)
                utils.writelog(utils.pretty_format_dict(out))
            # Saving
            did_update = bestscore_holder.compare_scores(
                scores["S"]["Micro F1"], epoch)
            if did_update:
                serializers.save_npz(path_snapshot, model)
                utils.writelog("Saved the model to %s" % path_snapshot)
            # Finished?
            if bestscore_holder.ask_finishing(max_patience=10):
                utils.writelog(
                    "Patience %d is over. Training finished successfully." %
                    bestscore_holder.patience)
                writer_train.close()
                if dev_dataset is not None:
                    writer_valid.close()
                return
        else:
            # No validation
            # Saving
            serializers.save_npz(path_snapshot, model)
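
The bracketing loss in the M-step above is a margin-based ranking (hinge) loss: each negative tree's score plus a tree-distance margin has to stay below the positive tree's score. A self-contained sketch, where compute_tree_distance_sketch is a hypothetical stand-in (the real compute_tree_distance is not shown in this snippet):

def compute_tree_distance_sketch(pos_spans, neg_spans, coef=1.0):
    # Hypothetical: margin grows with the number of gold spans the negative tree misses.
    return coef * len(set(pos_spans) - set(neg_spans))

def hinge_ranking_loss(pos_score, neg_scores, margins):
    # Mirrors the F.clip(neg + margin - pos, 0, ...) terms accumulated above.
    return sum(max(0.0, neg + margin - pos_score) for neg, margin in zip(neg_scores, margins))

pos_spans = [(0, 3), (0, 1), (2, 3)]
neg_spans = [(0, 3), (1, 3), (1, 2)]
margin = compute_tree_distance_sketch(pos_spans, neg_spans)                    # 2.0
print(hinge_ranking_loss(pos_score=5.0, neg_scores=[4.5], margins=[margin]))  # 1.5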
    def forward_edus(self, edus, edus_postag, edus_head):
        """
        :type edus: list of list of str
        :type edus_postag: list of list of str
        :type edus_head: list of (str, str, str)
        :rtype: Variable(shape=(n_edus, bilstm_dim), dtype=np.float32)
        """
        with chainer.using_config("train", False), chainer.no_backprop_mode():
            #################
            # TODO?
            # Bag-of-word embedding
            # word_ids = [[self.vocab_word.get(w, self.unk_word_id) for w in edu]
            #             for edu in edus] # n_edus * length * int
            # word_ids, mask = utils.padding(word_ids, head=True, with_mask=True) # (n_edus, max_length), (n_edus, max_length)
            # n_edus, max_length = word_ids.shape
            # word_ids = utils.convert_ndarray_to_variable(word_ids, seq=False) # (n_edus, max_length)
            # mask = utils.convert_ndarray_to_variable(mask, seq=False) # (n_edus, max_length)
            # word_ids = F.reshape(word_ids, (n_edus * max_length,)) # (n_edus * max_length,)
            # word_vectors = F.dropout(self.embed_word(word_ids), ratio=0.2) # (n_edus * max_length, word_dim)
            # word_vectors = F.reshape(word_vectors, (n_edus, max_length, self.word_dim)) # (n_edus, max_length, word_dim)
            # mask = F.broadcast_to(mask[:,:,None], (n_edus, max_length, self.word_dim)) # (n_edus, max_length, word_dim)
            # word_vectors = word_vectors * mask # (n_edus, max_length, word_dim)
            # bow_vectors = F.sum(word_vectors, axis=1) # (n_edus, word_dim)
            #################

            # Beginning-word embedding
            begin_word_ids = [
                self.vocab_word.get(edu[0], self.unk_word_id) for edu in edus
            ]  # n_edus * int
            begin_word_ids = np.asarray(begin_word_ids,
                                        dtype=np.int32)  # (n_edus,)
            begin_word_ids = utils.convert_ndarray_to_variable(
                begin_word_ids, seq=False)  # (n_edus,)
            begin_word_vectors = F.dropout(self.embed_word(begin_word_ids),
                                           ratio=0.2)  # (n_edus, word_dim)

            # End-word embedding
            end_word_ids = [
                self.vocab_word.get(edu[-1], self.unk_word_id) for edu in edus
            ]  # n_edus * int
            end_word_ids = np.asarray(end_word_ids,
                                      dtype=np.int32)  # (n_edus,)
            end_word_ids = utils.convert_ndarray_to_variable(
                end_word_ids, seq=False)  # (n_edus,)
            end_word_vectors = F.dropout(self.embed_word(end_word_ids),
                                         ratio=0.2)  # (n_edus, word_dim)

            # Head-word embedding
            head_word_ids = [
                self.vocab_word.get(head_word, self.unk_word_id)
                for (head_word, head_postag, head_deprel) in edus_head
            ]  # n_edus * int
            head_word_ids = np.asarray(head_word_ids,
                                       dtype=np.int32)  # (n_edus,)
            head_word_ids = utils.convert_ndarray_to_variable(
                head_word_ids, seq=False)  # (n_edus,)
            head_word_vectors = F.dropout(self.embed_word(head_word_ids),
                                          ratio=0.2)  # (n_edus, word_dim)

        # Beginning-postag embedding
        begin_postag_ids = [
            self.vocab_postag[edu_postag[0]] for edu_postag in edus_postag
        ]  # n_edus * int
        begin_postag_ids = np.asarray(begin_postag_ids,
                                      dtype=np.int32)  # (n_edus,)
        begin_postag_ids = utils.convert_ndarray_to_variable(
            begin_postag_ids, seq=False)  # (n_edus,)
        begin_postag_vectors = F.dropout(self.embed_postag(begin_postag_ids),
                                         ratio=0.2)  # (n_edus, postag_dim)

        # End-postag embedding
        end_postag_ids = [
            self.vocab_postag[edu_postag[-1]] for edu_postag in edus_postag
        ]  # n_edus * int
        end_postag_ids = np.asarray(end_postag_ids,
                                    dtype=np.int32)  # (n_edus,)
        end_postag_ids = utils.convert_ndarray_to_variable(
            end_postag_ids, seq=False)  # (n_edus,)
        end_postag_vectors = F.dropout(self.embed_postag(end_postag_ids),
                                       ratio=0.2)  # (n_edus, postag_dim)

        # Head-postag embedding
        head_postag_ids = [
            self.vocab_postag[head_postag]
            for (head_word, head_postag, head_deprel) in edus_head
        ]  # n_edus * int
        head_postag_ids = np.asarray(head_postag_ids,
                                     dtype=np.int32)  # (n_edus,)
        head_postag_ids = utils.convert_ndarray_to_variable(
            head_postag_ids, seq=False)  # (n_edus,)
        head_postag_vectors = F.dropout(self.embed_postag(head_postag_ids),
                                        ratio=0.2)  # (n_edus, postag_dim)

        # Head-deprel embedding
        head_deprel_ids = [
            self.vocab_deprel.get(head_deprel, self.unk_deprel_id)
            for (head_word, head_postag, head_deprel) in edus_head
        ]  # n_edus * int
        head_deprel_ids = np.asarray(head_deprel_ids,
                                     dtype=np.int32)  # (n_edus,)
        head_deprel_ids = utils.convert_ndarray_to_variable(
            head_deprel_ids, seq=False)  # (n_edus,)
        head_deprel_vectors = F.dropout(self.embed_deprel(head_deprel_ids),
                                        ratio=0.2)  # (n_edus, deprel_dim)

        # Concat
        edu_vectors = F.concat(
            [
                begin_word_vectors, end_word_vectors, head_word_vectors,
                begin_postag_vectors, end_postag_vectors, head_postag_vectors,
                head_deprel_vectors
            ],
            axis=1)  # (n_edus, 3 * word_dim + 3 * postag_dim + deprel_dim)
        edu_vectors = F.relu(self.W_edu(edu_vectors))  # (n_edus, word_dim)

        # BiLSTM
        h_init, c_init = None, None
        _, _, states = self.bilstm(hx=h_init, cx=c_init,
                                   xs=[edu_vectors])  # (1, n_edus, bilstm_dim)
        edu_vectors = states[0]  # (n_edus, bilstm_dim)

        return edu_vectors
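
For reference, a minimal stand-alone sketch of the final projection + BiLSTM step in forward_edus, run on a random feature matrix with Chainer's NStepBiLSTM; the toy dimensions and the single-layer setting are assumptions, not values taken from this snippet.

import numpy as np
import chainer.functions as F
import chainer.links as L

n_edus, feat_dim, word_dim, lstm_dim = 5, 10, 8, 6
W_edu = L.Linear(feat_dim, word_dim)
bilstm = L.NStepBiLSTM(n_layers=1, in_size=word_dim, out_size=lstm_dim, dropout=0.0)

feats = np.random.randn(n_edus, feat_dim).astype(np.float32)   # stand-in for the concatenated EDU features
edu_vectors = F.relu(W_edu(feats))                             # (n_edus, word_dim)
_, _, states = bilstm(hx=None, cx=None, xs=[edu_vectors])      # single sequence in the batch
edu_vectors = states[0]                                        # (n_edus, 2 * lstm_dim)
print(edu_vectors.shape)                                       # (5, 12)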