Example #1
 def link_loss(self, input, target, neighbors=8):
     # Class-balanced link loss: one 2-way softmax per neighbor direction,
     # with positive and negative link weights normalized per image.
     batch_size = input.shape[0]
     self.pos_link_weight = (target
                             == 1).astype('float32') * nd.broadcast_axes(
                                 nd.expand_dims(self.pixel_weight, axis=1),
                                 axis=1,
                                 size=neighbors)  # (2, 8, 256, 256)
     self.neg_link_weight = (target
                             == 0).astype('float32') * nd.broadcast_axes(
                                 nd.expand_dims(self.pixel_weight, axis=1),
                                 axis=1,
                                 size=neighbors)
     sum_pos_link_weight = nd.sum(nd.reshape(self.pos_link_weight,
                                             (batch_size, -1)),
                                  axis=1)
     sum_neg_link_weight = nd.sum(nd.reshape(self.neg_link_weight,
                                             (batch_size, -1)),
                                  axis=1)
     self.link_cross_entropy = []
     for i in range(neighbors):
         assert input.shape[1] == 16
         this_input = input[:, [2 * i, 2 * i + 1]]
         this_target = target[:, i]
         self.link_cross_entropy.append(
             self.link_cross_entropy_layer(this_input, this_target)[1])
     self.link_cross_entropy = nd.concat(*self.link_cross_entropy,
                                         dim=1)  # (2, 8, 256, 256)
     loss_link_pos = []
     loss_link_neg = []
     ctx = try_gpu()
     for i in range(batch_size):
         if sum_pos_link_weight[i].asscalar() == 0:
             loss_link_pos_temp = nd.zeros(self.pos_link_weight[0].shape,
                                           ctx, 'float32')
             loss_link_pos.append(nd.expand_dims(loss_link_pos_temp,
                                                 axis=0))
         else:
             loss_link_pos_temp = self.pos_link_weight[
                 i] * self.link_cross_entropy[i] / sum_pos_link_weight[i]
             loss_link_pos.append(nd.expand_dims(loss_link_pos_temp,
                                                 axis=0))
         if sum_neg_link_weight[i].asscalar() == 0:
             loss_link_neg_temp = nd.zeros(self.neg_link_weight[0].shape,
                                           ctx, 'float32')
             loss_link_neg.append(nd.expand_dims(loss_link_neg_temp,
                                                 axis=0))
         else:
             loss_link_neg_temp = self.neg_link_weight[
                 i] * self.link_cross_entropy[i] / sum_neg_link_weight[
                     i]  # (8, 256, 256)
             loss_link_neg.append(nd.expand_dims(loss_link_neg_temp,
                                                 axis=0))
     loss_link_pos = nd.concat(*loss_link_pos, dim=0)
     loss_link_neg = nd.concat(*loss_link_neg, dim=0)  # (2, 8, 256, 256)
     loss_link_pos = nd.sum(nd.reshape(loss_link_pos, (batch_size, -1)),
                            axis=1)
     loss_link_neg = nd.sum(nd.reshape(loss_link_neg, (batch_size, -1)),
                            axis=1)
     return nd.mean(loss_link_pos), nd.mean(loss_link_neg)
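A minimal sketch of the broadcast pattern used above (not from the original source; shapes are illustrative): expand_dims adds a size-1 channel axis to a per-pixel weight map, and broadcast_axes replicates it across the neighbors link channels.

from mxnet import nd

pixel_weight = nd.ones((2, 256, 256))        # (batch, H, W), illustrative
w = nd.expand_dims(pixel_weight, axis=1)     # (2, 1, 256, 256)
w = nd.broadcast_axes(w, axis=1, size=8)     # (2, 8, 256, 256)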
Example #2
def getSelfMask(q_seq):
    batch_size, seq_len = q_seq.shape
    mask_matrix = np.ones(shape=(seq_len, seq_len), dtype=np.float32)
    mask = np.tril(mask_matrix, k=0)
    mask = nd.expand_dims(nd.array(mask, ctx=ghp.ctx), axis=0)
    mask = nd.broadcast_axes(mask, axis=0, size=batch_size)
    return mask
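For reference, a self-contained sketch of the same masking idea (toy sizes, no external ctx): a single lower-triangular matrix is tiled across the batch so each decoder position can attend only to itself and earlier positions.

import numpy as np
from mxnet import nd

seq_len, batch_size = 4, 2                               # toy sizes
tril = np.tril(np.ones((seq_len, seq_len), dtype=np.float32))
mask = nd.expand_dims(nd.array(tril), axis=0)            # (1, 4, 4)
mask = nd.broadcast_axes(mask, axis=0, size=batch_size)  # (2, 4, 4)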
Example #3
 def _get_self_tril_mask(self, dec_idx):
     batch_size, seq_len = dec_idx.shape
     mask_matrix = np.ones(shape=(seq_len, seq_len))
     mask = np.tril(mask_matrix, k=0)
     mask = nd.expand_dims(nd.array(mask, ctx=self._ctx), axis=0)
     mask = nd.broadcast_axes(mask, axis=0, size=batch_size)
     return mask
Example #4
def eval(en_bert, mt_model, en_vocab, ch_vocab, dev_dataiter, logger, ctx):
    # Greedily decode the dev set, then average smoothed sentence-level
    # BLEU over all reference/hypothesis pairs.
    references = []
    hypothesis = []
    score = 0
    chencherry = SmoothingFunction()
    for trans, _, label, trans_valid_len, label_valid_len in tqdm(
            dev_dataiter):
        trans = trans.as_in_context(ctx)
        trans_valid_len = trans_valid_len.as_in_context(ctx)
        batch_size = trans.shape[0]

        trans_token_type = nd.zeros_like(trans)
        en_bert_outputs = en_bert(trans, trans_token_type, trans_valid_len)

        ch_sentences = [BOS]
        aim = ch_vocab[ch_sentences]
        aim = nd.array([aim], ctx=ctx)
        aim = nd.broadcast_axes(aim, axis=0, size=batch_size)

        for n in range(0, args.max_ch_len):
            mt_outputs = mt_model(en_bert_outputs, trans, aim)
            predicts = nd.argmax(nd.softmax(mt_outputs, axis=-1), axis=-1)
            final_predict = predicts[:, -1:]
            aim = nd.concat(aim, final_predict, dim=1)

        label = label.asnumpy().tolist()
        predict_valid_len = nd.sum(nd.not_equal(
            predicts, ch_vocab(ch_vocab.padding_token)),
                                   axis=-1).asnumpy().tolist()
        predicts = aim[:, 1:].asnumpy().tolist()
        label_valid_len = label_valid_len.asnumpy().tolist()

        for refer, hypoth, l_v_len, p_v_len in zip(label, predicts,
                                                   label_valid_len,
                                                   predict_valid_len):
            l_v_len = int(l_v_len)
            p_v_len = int(p_v_len)
            refer = refer[:l_v_len]
            refer_str = [ch_vocab.idx_to_token[int(idx)] for idx in refer]
            hypoth_str = [ch_vocab.idx_to_token[int(idx)] for idx in hypoth]
            hypoth_str_valid = []
            for token in hypoth_str:
                if token == EOS:
                    hypoth_str_valid.append(token)
                    break
                hypoth_str_valid.append(token)
            references.append(refer_str)
            hypothesis.append(hypoth_str_valid)

    for refer, hypoth in zip(references, hypothesis):
        score += sentence_bleu([refer],
                               hypoth,
                               smoothing_function=chencherry.method1)
    logger.info("dev sample:")
    logger.info("refer :{}".format(" ".join(references[0]).replace(
        EOS, "[EOS]").replace(ch_vocab.padding_token, "")))
    logger.info("hypoth:{}".format(" ".join(hypothesis[0]).replace(
        EOS, "[EOS]")))
    return score / len(references)
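The decoding loop above seeds every sequence in the batch with the same BOS index; a stripped-down sketch of that tiling step (the index value is a placeholder, not the real vocabulary id):

from mxnet import nd

batch_size = 4
bos = nd.array([[101]])                                # (1, 1), placeholder id
aim = nd.broadcast_axes(bos, axis=0, size=batch_size)  # (4, 1)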
Example #5
    def forward(self, src_idx, tgt_idx):
        # compute encoder mask
        key_mask = self._get_key_mask(src_idx,
                                      src_idx,
                                      pad_idx=self.src_pad_idx)
        src_non_pad_mask = self._get_non_pad_mask(src_idx,
                                                  pad_idx=self.src_pad_idx)

        # compute decoder mask
        self_tril_mask = self._get_self_tril_mask(tgt_idx)
        self_key_mask = self._get_key_mask(tgt_idx,
                                           tgt_idx,
                                           pad_idx=self.tgt_pad_idx)
        self_att_mask = nd.greater((self_key_mask + self_tril_mask), 1)

        context_att_mask = self._get_key_mask(src_idx,
                                              tgt_idx,
                                              pad_idx=self.src_pad_idx)
        tgt_non_pad_mask = self._get_non_pad_mask(tgt_idx,
                                                  pad_idx=self.tgt_pad_idx)

        # Encoder
        position = nd.array(self._position_encoding_init(
            src_idx.shape[1], self._model_dim),
                            ctx=src_idx.context)
        position = nd.expand_dims(position, axis=0)
        position = nd.broadcast_axes(position, axis=0, size=tgt_idx.shape[0])
        position = position * src_non_pad_mask
        src_emb = self.embedding(src_idx)
        enc_output = self.encoder(src_emb, position, key_mask,
                                  src_non_pad_mask)

        # Decoder
        position = nd.array(self._position_encoding_init(
            tgt_idx.shape[1], self._model_dim),
                            ctx=src_idx.context)
        position = nd.expand_dims(position, axis=0)
        position = nd.broadcast_axes(position, axis=0, size=tgt_idx.shape[0])
        position = position * tgt_non_pad_mask
        tgt_emb = self.embedding(tgt_idx)

        outputs = self.decoder(enc_output, tgt_emb, position, self_att_mask,
                               context_att_mask, tgt_non_pad_mask)
        outputs = self.linear(outputs)
        return outputs
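Both encoder and decoder above tile one positional-encoding table over the batch before masking padded positions. A minimal sketch with assumed dimensions:

from mxnet import nd

seq_len, model_dim, batch_size = 5, 8, 3                  # assumed dims
pos = nd.random.uniform(shape=(seq_len, model_dim))       # stands in for the table
pos = nd.expand_dims(pos, axis=0)                         # (1, 5, 8)
pos = nd.broadcast_axes(pos, axis=0, size=batch_size)     # (3, 5, 8)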
Example #6
def getMask(q_seq, k_seq):
    # q_seq shape : (batch_size, q_seq_len)
    # k_seq shape : (batch_size, k_seq_len)
    q_len = q_seq.shape[1]
    pad_mask = nd.not_equal(k_seq, 0)
    pad_mask = nd.expand_dims(pad_mask, axis=1)
    pad_mask = nd.broadcast_axes(pad_mask, axis=1, size=q_len)

    return pad_mask
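In isolation (toy values), the key-padding mask above is built like this: mark non-pad key positions, then repeat the row once per query position so it lines up with (batch, q_len, k_len) attention scores.

from mxnet import nd

k_seq = nd.array([[5, 3, 0, 0]])                    # (1, 4), 0 = padding
pad = nd.not_equal(k_seq, 0)                        # (1, 4)
pad = nd.expand_dims(pad, axis=1)                   # (1, 1, 4)
pad = nd.broadcast_axes(pad, axis=1, size=3)        # (1, 3, 4), q_len = 3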
Example #7
 def _get_key_mask(self, enc_idx, dec_idx, pad_idx=None):
     seq_len = dec_idx.shape[1]
     if pad_idx is not None:
         pad_mask = nd.not_equal(enc_idx, pad_idx)
     else:
         pad_mask = nd.not_equal(enc_idx, 0)
     pad_mask = nd.expand_dims(pad_mask, axis=1)
     pad_mask = nd.broadcast_axes(pad_mask, axis=1, size=seq_len)
     return pad_mask
Example #8
def batch_process(seq, ctx):
    # Build sliding-window (region-sized) views of each sequence and a
    # padding mask broadcast to the embedding width.
    seq = np.array(seq)
    aligned_seq = np.zeros(
        (max_sequence_length - 2 * region_radius, batch_size, region_size))
    for i in range(region_radius, max_sequence_length - region_radius):
        aligned_seq[i - region_radius] = seq[:, i - region_radius:i -
                                             region_radius + region_size]
    aligned_seq = nd.array(aligned_seq, ctx)
    batch_sequence = nd.array(seq, ctx)
    trimed_seq = batch_sequence[:, region_radius:max_sequence_length -
                                region_radius]
    mask = nd.broadcast_axes(nd.greater(trimed_seq, 0).reshape(
        (batch_size, -1, 1)),
                             axis=2,
                             size=128)
    return aligned_seq, nd.array(trimed_seq, ctx), mask
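The mask built at the end expands a (batch, seq, 1) indicator to the embedding width (hard-coded 128 above) so it can gate region embeddings element-wise. A small sketch with assumed shapes:

from mxnet import nd

trimed = nd.array([[4, 7, 0]])                       # (1, 3), 0 = padding
mask = nd.greater(trimed, 0).reshape((1, -1, 1))     # (1, 3, 1)
mask = nd.broadcast_axes(mask, axis=2, size=128)     # (1, 3, 128)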
Example #9
def batch_process(seq, isContextWord, ctx):
    seq = np.array(seq)
    aligned_seq = np.zeros(
        (max_sequence_length - 2 * region_radius, batch_size, region_size))
    for i in range(region_radius, max_sequence_length - region_radius):
        aligned_seq[i - region_radius] = seq[:, i - region_radius:i -
                                             region_radius + region_size]
    if isContextWord:
        unit_id_bias = np.array([i * vocab_size for i in range(region_size)])
        aligned_seq = aligned_seq.transpose((1, 0, 2)) + unit_id_bias
    aligned_seq = nd.array(aligned_seq, ctx)
    batch_sequence = nd.array(seq, ctx)
    trimed_seq = batch_sequence[:, region_radius:max_sequence_length -
                                region_radius]
    mask = nd.broadcast_axes(nd.greater(trimed_seq, 0).reshape(
        (batch_size, -1, 1)),
                             axis=2,
                             size=128)
    return aligned_seq, nd.array(trimed_seq, ctx), mask
Example #10
    def forward(self, en_bert_output, en_idx, ch_idx):
        self_tril_mask = self._get_self_tril_mask(ch_idx)
        self_key_mask = self._get_key_mask(ch_idx,
                                           ch_idx,
                                           pad_idx=self.ch_pad_idx)
        self_att_mask = nd.greater((self_key_mask + self_tril_mask), 1)

        context_att_mask = self._get_key_mask(en_idx,
                                              ch_idx,
                                              pad_idx=self.en_pad_idx)
        non_pad_mask = self._get_non_pad_mask(ch_idx, pad_idx=self.ch_pad_idx)

        position = nd.array(self._position_encoding_init(
            ch_idx.shape[1], self._model_dim),
                            ctx=self._ctx)
        position = nd.expand_dims(position, axis=0)
        position = nd.broadcast_axes(position, axis=0, size=ch_idx.shape[0])
        position = position * non_pad_mask
        ch_emb = self.ch_embedding(ch_idx)
        outputs = self.decoder(en_bert_output, ch_emb, position, self_att_mask,
                               context_att_mask, non_pad_mask)
        outputs = self.linear(outputs)
        return outputs
Example #11
def translate(args):
    # Interactive loop: encode the input with BERT, then beam-search
    # decode translations with the hybrid MT model.
    gpu_idx = args.gpu
    if not gpu_idx:
        ctx = mx.cpu()
    else:
        ctx = mx.gpu(gpu_idx - 1)
    en_bert, en_vocab = gluonnlp.model.get_model(
        args.bert_model,
        dataset_name=args.en_bert_dataset,
        pretrained=True,
        ctx=ctx,
        use_pooler=False,
        use_decoder=False,
        use_classifier=False)
    _, ch_vocab = gluonnlp.model.get_model(args.bert_model,
                                           dataset_name=args.ch_bert_dataset,
                                           pretrained=True,
                                           ctx=ctx,
                                           use_pooler=False,
                                           use_decoder=False,
                                           use_classifier=False)

    mt_model = MTModel_Hybird(en_vocab=en_vocab,
                              ch_vocab=ch_vocab,
                              embedding_dim=args.mt_emb_dim,
                              model_dim=args.mt_model_dim,
                              head_num=args.mt_head_num,
                              layer_num=args.mt_layer_num,
                              ffn_dim=args.mt_ffn_dim,
                              dropout=args.mt_dropout,
                              att_dropout=args.mt_att_dropout,
                              ffn_dropout=args.mt_ffn_dropout,
                              ctx=ctx)

    en_bert.load_parameters(args.en_bert_model_params_path, ctx=ctx)
    mt_model.load_parameters(args.mt_model_params_path, ctx=ctx)

    en_bert_tokenzier = BERTTokenizer(en_vocab)
    ch_bert_tokenzier = BERTTokenizer(ch_vocab)

    while True:
        trans = input("input:")

        trans = en_bert_tokenzier(trans)
        trans = [en_vocab.cls_token] + \
            trans + [en_vocab.sep_token]

        trans_valid_len = len(trans)

        if args.max_en_len and len(trans) > args.max_en_len:
            trans = trans[0:args.max_en_len]

        aim = [BOS]

        trans = en_vocab[trans]
        aim = ch_vocab[aim]

        aim = nd.array([aim], ctx=ctx)

        trans = nd.array([trans], ctx=ctx)
        trans_valid_len = nd.array([trans_valid_len], ctx=ctx)
        trans_token_types = nd.zeros_like(trans)

        batch_size = 1
        beam_size = 6

        en_bert_outputs = en_bert(trans, trans_token_types, trans_valid_len)
        mt_outputs = mt_model(en_bert_outputs, trans, aim)

        en_bert_outputs = nd.broadcast_axes(en_bert_outputs,
                                            axis=0,
                                            size=beam_size)
        trans = nd.broadcast_axes(trans, axis=0, size=beam_size)
        targets = None
        for n in range(0, args.max_ch_len):
            aim, targets = beam_search(mt_outputs[:, n, :],
                                       targets=targets,
                                       max_seq_len=args.max_ch_len,
                                       ctx=ctx,
                                       beam_width=beam_size)
            mt_outputs = mt_model(en_bert_outputs, trans, aim)

        predict = aim.asnumpy().tolist()
        predict_strs = []
        for pred in predict:
            predict_token = [ch_vocab.idx_to_token[int(idx)] for idx in pred]
            predict_str = ""
            sub_token = []
            for token in predict_token:
                # if token in ["[CLS]", EOS, "[SEP]"]:
                #     continue
                if len(sub_token) == 0:
                    sub_token.append(token)
                elif token[:2] != "##" and len(sub_token) != 0:
                    predict_str += "".join(sub_token) + " "
                    sub_token = []
                    sub_token.append(token)
                else:
                    if token[:2] == "##":
                        token = token.replace("##", "")
                    sub_token.append(token)
                if token == EOS:
                    if len(sub_token) != 0:
                        predict_str += "".join(sub_token) + " "
                    break
            predict_strs.append(
                predict_str.replace("[SEP]", "").replace("[CLS]",
                                                         "").replace(EOS, ""))
        for predict_str in predict_strs:
            print(predict_str)
Example #12
def translate(args):
    gpu_idx = args.gpu
    if not gpu_idx:
        ctx = mx.cpu()
    else:
        ctx = mx.gpu(gpu_idx - 1)
    src_bert, src_vocab = gluonnlp.model.get_model(args.bert_model,
                                                   dataset_name=args.src_bert_dataset,
                                                   pretrained=True,
                                                   ctx=ctx,
                                                   use_pooler=False,
                                                   use_decoder=False,
                                                   use_classifier=False)
    _, tgt_vocab = gluonnlp.model.get_model(args.bert_model,
                                            dataset_name=args.tgt_bert_dataset,
                                            pretrained=True,
                                            ctx=ctx,
                                            use_pooler=False,
                                            use_decoder=False,
                                            use_classifier=False)

    mt_model = MTModel_Hybird(src_vocab=src_vocab,
                              tgt_vocab=tgt_vocab,
                              embedding_dim=args.mt_emb_dim,
                              model_dim=args.mt_model_dim,
                              head_num=args.mt_head_num,
                              layer_num=args.mt_layer_num,
                              ffn_dim=args.mt_ffn_dim,
                              dropout=args.mt_dropout,
                              att_dropout=args.mt_att_dropout,
                              ffn_dropout=args.mt_ffn_dropout,
                              ctx=ctx)

    src_bert.load_parameters(args.bert_model_params_path, ctx=ctx)
    mt_model.load_parameters(args.mt_model_params_path, ctx=ctx)

    src_bert_tokenzier = BERTTokenizer(src_vocab)
    tgt_bert_tokenzier = BERTTokenizer(tgt_vocab)

    while True:
        src = input("input:")

        src = src_bert_tokenzier(src)
        src = [src_vocab.cls_token] + \
            src + [src_vocab.sep_token]

        src_valid_len = len(src)

        if args.max_src_len and len(src) > args.max_src_len:
            src = src[0:args.max_src_len]

        tgt = [BOS]

        src = src_vocab[src]
        tgt = tgt_vocab[tgt]

        tgt = nd.array([tgt], ctx=ctx)

        src = nd.array([src], ctx=ctx)
        src_valid_len = nd.array([src_valid_len], ctx=ctx)
        src_token_types = nd.zeros_like(src)

        beam_size = 6

        src_bert_outputs = src_bert(src, src_token_types, src_valid_len)
        mt_outputs = mt_model(src_bert_outputs, src, tgt)

        src_bert_outputs = nd.broadcast_axes(
            src_bert_outputs, axis=0, size=beam_size)
        src = nd.broadcast_axes(src, axis=0, size=beam_size)
        targets = None
        for n in range(0, args.max_tgt_len):
            tgt, targets = beam_search(
                mt_outputs[:, n, :], targets=targets, max_seq_len=args.max_tgt_len, ctx=ctx, beam_width=beam_size)
            mt_outputs = mt_model(src_bert_outputs, src, tgt)

        predict = tgt.asnumpy().tolist()
        predict_strs = []
        for pred in predict:
            predict_token = [tgt_vocab.idx_to_token[int(idx)] for idx in pred]
            predict_str = ""
            sub_token = []
            for token in predict_token:
                # if token in ["[CLS]", EOS, "[SEP]"]:
                #     continue
                if len(sub_token) == 0:
                    sub_token.append(token)
                elif token[:2] != "##" and len(sub_token) != 0:
                    predict_str += "".join(sub_token) + " "
                    sub_token = []
                    sub_token.append(token)
                else:
                    if token[:2] == "##":
                        token = token.replace("##", "")
                    sub_token.append(token)
                if token == EOS:
                    if len(sub_token) != 0:
                        predict_str += "".join(sub_token) + " "
                    break
            predict_strs.append(predict_str.replace(
                "[SEP]", "").replace("[CLS]", "").replace(EOS, ""))
        for predict_str in predict_strs:
            print(predict_str)
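Both translate variants tile the batch-1 encoder outputs (and source ids) along axis 0 so every beam hypothesis shares the same source encoding. A sketch of that step with assumed sizes:

from mxnet import nd

beam_size = 6
enc = nd.random.uniform(shape=(1, 10, 768))           # (1, src_len, hidden), assumed
enc = nd.broadcast_axes(enc, axis=0, size=beam_size)  # (6, 10, 768)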