Example #1
    def __init__(self, config):
        super().__init__()

        self.embedding = torch.nn.Embedding(len(config.tokenizer.vocab),
                                            config.emb_size)
        self.in_fc = nn.Linear(config.emb_size, config.d_model)
        self.transformer = TransformerEncoder(config)
        self.fc_dropout = nn.Dropout(config.fc_dropout)
        self.out_fc = nn.Linear(config.d_model, len(config.label2id))
        self.crf = CRF(num_tags=len(config.label2id), batch_first=True)
        self.apply(self.init_model_weights)
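
Example #1 only shows the constructor. A forward pass for this kind of CRF tagger could look like the sketch below, assuming the custom TransformerEncoder maps the projected embeddings to hidden states of shape (batch, seq_len, d_model) and that CRF follows the pytorch-crf interface (a log-likelihood from the call, tag sequences from decode). The method body and the mask argument are illustrative, not the repository's actual code.

    def forward(self, input_ids, mask, tags=None):
        # Hypothetical forward pass: embed tokens, project to d_model,
        # run the encoder, then score tags with the linear + CRF head.
        x = self.embedding(input_ids)                 # (batch, seq, emb_size)
        x = self.in_fc(x)                             # (batch, seq, d_model)
        x = self.transformer(x)                       # assumed (batch, seq, d_model)
        emissions = self.out_fc(self.fc_dropout(x))   # (batch, seq, num_tags)
        if tags is not None:
            # pytorch-crf returns a log-likelihood; negate it for a loss.
            return -self.crf(emissions, tags, mask=mask)
        return self.crf.decode(emissions, mask=mask)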
Example #2
    def __init__(self, num_embed, d_model=64, nhead=4, dim_feedforward=256, num_layers=2, out_dim=2, drop_rate=0.5):
        super(TransformerModel, self).__init__()
        self.embed = nn.Embedding(num_embed, d_model)
        self.encoder = TransformerEncoder(TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=drop_rate,
        ), num_layers=num_layers)
        self.linear = nn.Linear(d_model, out_dim)
        self.dropout = nn.Dropout(p=drop_rate)
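
As in Example #1, only the constructor is shown; a matching forward pass might look like the sketch below. The transpose reflects the default (seq_len, batch, d_model) layout of nn.TransformerEncoderLayer, and the mean pooling before the classifier is an assumption.

    def forward(self, token_ids):
        # token_ids: (batch, seq_len); the default encoder layout is
        # (seq_len, batch, d_model), hence the transpose.
        x = self.embed(token_ids).transpose(0, 1)
        x = self.encoder(x)            # (seq_len, batch, d_model)
        x = x.mean(dim=0)              # mean pooling over the sequence (assumed)
        return self.linear(self.dropout(x))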
Example #3
    def get_encoder(lang):
        if lang not in lang_encoders:
            if shared_encoder_embed_tokens is not None:
                encoder_embed_tokens = shared_encoder_embed_tokens
            else:
                encoder_embed_tokens = build_embedding(
                    task.dicts[lang], args.encoder_embed_dim,
                    args.encoder_embed_path)
            lang_encoders[lang] = TransformerEncoder(
                args, task.dicts[lang], encoder_embed_tokens)
        return lang_encoders[lang]
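
Example #3 comes from a multilingual setup: encoders are built lazily per language, cached in lang_encoders, and can all share one embedding table. Stripped of the fairseq specifics, the caching pattern looks like the sketch below; the builder callables are placeholders, not fairseq APIs.

lang_encoders = {}

def get_or_build_encoder(lang, build_encoder, build_embedding, shared_embed_tokens=None):
    # Reuse a shared embedding table when one is provided,
    # otherwise build a per-language one; cache encoders by language.
    if lang not in lang_encoders:
        embed_tokens = (shared_embed_tokens if shared_embed_tokens is not None
                        else build_embedding(lang))
        lang_encoders[lang] = build_encoder(lang, embed_tokens)
    return lang_encoders[lang]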
Example #4
    def __init__(self, ner_processor, config):
        super().__init__()

        vocab_size = len(ner_processor.vocab)
        num_labels = len(ner_processor.idx2label)
        self.embedding = torch.nn.Embedding(vocab_size, config.emb_size)
        nn.init.normal_(self.embedding.weight, 0.0, 0.02)
        self.embed_size = config.emb_size
        self.in_fc = nn.Linear(config.emb_size, config.d_model)
        self.transformer = TransformerEncoder(config)
        self.fc_dropout = nn.Dropout(config.fc_dropout)
        self.out_fc = nn.Linear(config.d_model, num_labels)
        self.crf = CRF(num_tags=num_labels, batch_first=True)
Example #5
    def __init__(self, vocab_size, tagset_size, config):
        super(TransformerCRF, self).__init__()
        self.device = config.device
        self.vocab_size = vocab_size
        self.tagset_size = tagset_size + 2
        self.batch_size = config.batch_size

        config.src_vocab_size = vocab_size
        config.emb_size = config.embedding_dim
        config.hidden_size = config.hidden_dim
        self.transformer = TransformerEncoder(config=config, padding_idx=utils.PAD)

        self.hidden2tag = nn.Linear(config.hidden_dim, self.tagset_size)

        self.crf = CRF(self.tagset_size, config)
Example #6
def make_transformer(config, device):
    INPUT_DIM = src_vocab_length()
    OUTPUT_DIM = trg_vocab_length()

    enc = TransformerEncoder(INPUT_DIM,
                             config.hid_dim,
                             config.enc_layers,
                             config.enc_heads,
                             config.enc_pf_dim,
                             config.enc_dropout,
                             device)

    dec = TransformerDecoder(OUTPUT_DIM,
                             config.hid_dim,
                             config.dec_layers,
                             config.dec_heads,
                             config.dec_pf_dim,
                             config.dec_dropout,
                             device)
    return Seq2Seq(enc, dec, device).to(device)
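
Example #6 just wires encoder and decoder hyperparameters from a config object into a Seq2Seq wrapper; the TransformerEncoder, TransformerDecoder, and Seq2Seq classes are defined elsewhere in that repository. As a rough, self-contained analogue, the same wiring with PyTorch's built-in nn.Transformer would look like this (the field values are assumptions; only the field names mirror the config above):

import torch.nn as nn
from types import SimpleNamespace

config = SimpleNamespace(hid_dim=256, enc_layers=3, dec_layers=3,
                         enc_heads=8, enc_pf_dim=512, enc_dropout=0.1)

# Built-in analogue of the encoder/decoder pair constructed above.
transformer = nn.Transformer(d_model=config.hid_dim,
                             nhead=config.enc_heads,
                             num_encoder_layers=config.enc_layers,
                             num_decoder_layers=config.dec_layers,
                             dim_feedforward=config.enc_pf_dim,
                             dropout=config.enc_dropout)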
Example #7
                                        TransformerEncoderQualitativeEvaluator,
                                        playMidi, init_performance_generation)
from src.models.model_run_job import ModelJobParams
from src.constants import PRODUCTION_DATA_DIR, DEVELOPMENT_DATA_DIR, CACHE_MODEL_DIR
from src.neptune import get_experiment_by_id

# %%
exp = init_performance_generation('THESIS-40', 'transformer.py', is_dev=True)

# %%
from models.transformer import TransformerEncoder
from src.models.model_writer_reader import read_params

# %%
hyper_params = read_params('artifacts/params.pickle')
model = TransformerEncoder(hyper_params)

# %%
params = QualitativeEvaluatorParams(is_dev=True)
qualitative_evaluator = TransformerEncoderQualitativeEvaluator(params, model)

# %%
model_path = './artifacts/model_dev_best.pth'
qualitative_evaluator.generate_performances(model_path)

# %%
qualitative_evaluator.generate_performance_for_file(
    xml_file_path=xml_file_path,
    midi_file_path=midi_file_path,
    plot_path=plot_file_path,
    composer_name='Bach',
)
Example #8
    def __init__(self,
                 args,
                 device,
                 vocab_size,
                 review_count,
                 product_size,
                 user_size,
                 review_words,
                 vocab_words,
                 word_dists=None):
        super(ProductRanker, self).__init__()
        self.args = args
        self.device = device
        self.train_review_only = args.train_review_only
        self.embedding_size = args.embedding_size
        self.vocab_words = vocab_words
        self.word_dists = None
        if word_dists is not None:
            self.word_dists = torch.tensor(word_dists, device=device)
        self.prod_pad_idx = product_size
        self.user_pad_idx = user_size
        self.word_pad_idx = vocab_size - 1
        self.seg_pad_idx = 3
        self.review_pad_idx = review_count - 1
        self.emb_dropout = args.dropout
        self.review_encoder_name = args.review_encoder_name
        self.fix_emb = args.fix_emb

        padded_review_words = review_words
        if not self.args.do_subsample_mask:
            #otherwise, review_words should be already padded
            padded_review_words = pad(review_words,
                                      pad_id=self.word_pad_idx,
                                      width=args.review_word_limit)
        self.review_words = torch.tensor(padded_review_words, device=device)

        self.pretrain_emb_dir = None
        if os.path.exists(args.pretrain_emb_dir):
            self.pretrain_emb_dir = args.pretrain_emb_dir
        self.pretrain_up_emb_dir = None
        if os.path.exists(args.pretrain_up_emb_dir):
            self.pretrain_up_emb_dir = args.pretrain_up_emb_dir
        self.dropout_layer = nn.Dropout(p=args.dropout)

        if self.args.use_user_emb:
            if self.pretrain_up_emb_dir is None:
                self.user_emb = nn.Embedding(user_size + 1,
                                             self.embedding_size,
                                             padding_idx=self.user_pad_idx)
            else:
                pretrain_user_emb_path = os.path.join(self.pretrain_up_emb_dir,
                                                      "user_emb.txt")
                pretrained_weights = load_user_item_embeddings(
                    pretrain_user_emb_path)
                pretrained_weights.append([0.] * len(pretrained_weights[0]))
                assert len(pretrained_weights[0]) == self.embedding_size
                self.user_emb = nn.Embedding.from_pretrained(
                    torch.FloatTensor(pretrained_weights),
                    padding_idx=self.user_pad_idx)

        if self.args.use_item_emb:
            if self.pretrain_up_emb_dir is None:
                self.product_emb = nn.Embedding(product_size + 1,
                                                self.embedding_size,
                                                padding_idx=self.prod_pad_idx)
            else:
                pretrain_product_emb_path = os.path.join(
                    self.pretrain_up_emb_dir, "product_emb.txt")
                pretrained_weights = load_user_item_embeddings(
                    pretrain_product_emb_path)
                pretrained_weights.append([0.] * len(pretrained_weights[0]))
                self.product_emb = nn.Embedding.from_pretrained(
                    torch.FloatTensor(pretrained_weights),
                    padding_idx=self.prod_pad_idx)

        if self.pretrain_emb_dir is not None:
            #word_emb_fname = "word_emb.txt.gz" #for query and target words in pv and pvc
            word_emb_fname = "context_emb.txt.gz" if args.review_encoder_name == "pvc" else "word_emb.txt.gz"  #for query and target words in pv and pvc
            pretrain_word_emb_path = os.path.join(self.pretrain_emb_dir,
                                                  word_emb_fname)
            word_index_dic, pretrained_weights = load_pretrain_embeddings(
                pretrain_word_emb_path)
            word_indices = torch.tensor(
                [0] + [word_index_dic[x]
                       for x in self.vocab_words[1:]] + [self.word_pad_idx])
            #print(len(word_indices))
            #print(word_indices.cpu().tolist())
            pretrained_weights = torch.FloatTensor(pretrained_weights)
            self.word_embeddings = nn.Embedding.from_pretrained(
                pretrained_weights[word_indices],
                padding_idx=self.word_pad_idx)
            #vectors of padding idx will not be updated
        else:
            self.word_embeddings = nn.Embedding(vocab_size,
                                                self.embedding_size,
                                                padding_idx=self.word_pad_idx)

        if self.fix_emb and args.review_encoder_name == "pvc":
            #if review embeddings are fixed, just load the aggregated embeddings which include all the words in the review
            #otherwise the reviews are cut off at review_word_limit
            self.review_encoder_name = "pv"

        self.transformer_encoder = TransformerEncoder(self.embedding_size,
                                                      args.ff_size, args.heads,
                                                      args.dropout,
                                                      args.inter_layers)

        if self.review_encoder_name == "pv":
            pretrain_emb_path = None
            if self.pretrain_emb_dir is not None:
                pretrain_emb_path = os.path.join(self.pretrain_emb_dir,
                                                 "doc_emb.txt.gz")
            self.review_encoder = ParagraphVector(self.word_embeddings,
                                                  self.word_dists,
                                                  review_count,
                                                  self.emb_dropout,
                                                  pretrain_emb_path,
                                                  fix_emb=self.fix_emb)
        elif self.review_encoder_name == "pvc":
            pretrain_emb_path = None
            #if self.pretrain_emb_dir is not None:
            #    pretrain_emb_path = os.path.join(self.pretrain_emb_dir, "context_emb.txt.gz")
            self.review_encoder = ParagraphVectorCorruption(
                self.word_embeddings,
                self.word_dists,
                args.corrupt_rate,
                self.emb_dropout,
                pretrain_emb_path,
                self.vocab_words,
                fix_emb=self.fix_emb)
        elif self.review_encoder_name == "fs":
            self.review_encoder = FSEncoder(self.embedding_size,
                                            self.emb_dropout)
        else:
            self.review_encoder = AVGEncoder(self.embedding_size,
                                             self.emb_dropout)

        if args.query_encoder_name == "fs":
            self.query_encoder = FSEncoder(self.embedding_size,
                                           self.emb_dropout)
        else:
            self.query_encoder = AVGEncoder(self.embedding_size,
                                            self.emb_dropout)
        self.seg_embeddings = nn.Embedding(4,
                                           self.embedding_size,
                                           padding_idx=self.seg_pad_idx)
        #for each q,u,i
        #Q, previous purchases of u, current available reviews for i, padding value
        #self.logsoftmax = torch.nn.LogSoftmax(dim = -1)
        #self.bce_logits_loss = torch.nn.BCEWithLogitsLoss(reduction='none')#by default it's mean

        self.review_embeddings = None
        if self.fix_emb:
            #self.word_embeddings.weight.requires_grad = False
            #embeddings of query words need to be updated
            #self.emb_dropout = 0
            self.get_review_embeddings()  #get model.review_embeddings

        self.initialize_parameters(logger)  #logger
        self.to(device)  #change model in place
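
The constructor above relies on a pad(review_words, pad_id, width) helper before turning the reviews into a tensor. A minimal implementation compatible with that call, truncating or right-padding each review to a fixed width, is sketched below; the project's own helper may differ.

def pad(sequences, pad_id, width):
    # Truncate or right-pad each list of word ids to exactly `width` entries.
    return [seq[:width] + [pad_id] * max(0, width - len(seq))
            for seq in sequences]

# e.g. pad([[1, 2], [3, 4, 5, 6]], pad_id=0, width=3) -> [[1, 2, 0], [3, 4, 5]]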
Example #9
class ProductRanker(nn.Module):
    def __init__(self,
                 args,
                 device,
                 vocab_size,
                 review_count,
                 product_size,
                 user_size,
                 review_words,
                 vocab_words,
                 word_dists=None):
        super(ProductRanker, self).__init__()
        self.args = args
        self.device = device
        self.train_review_only = args.train_review_only
        self.embedding_size = args.embedding_size
        self.vocab_words = vocab_words
        self.word_dists = None
        if word_dists is not None:
            self.word_dists = torch.tensor(word_dists, device=device)
        self.prod_pad_idx = product_size
        self.user_pad_idx = user_size
        self.word_pad_idx = vocab_size - 1
        self.seg_pad_idx = 3
        self.review_pad_idx = review_count - 1
        self.emb_dropout = args.dropout
        self.review_encoder_name = args.review_encoder_name
        self.fix_emb = args.fix_emb

        padded_review_words = review_words
        if not self.args.do_subsample_mask:
            #otherwise, review_words should be already padded
            padded_review_words = pad(review_words,
                                      pad_id=self.word_pad_idx,
                                      width=args.review_word_limit)
        self.review_words = torch.tensor(padded_review_words, device=device)

        self.pretrain_emb_dir = None
        if os.path.exists(args.pretrain_emb_dir):
            self.pretrain_emb_dir = args.pretrain_emb_dir
        self.pretrain_up_emb_dir = None
        if os.path.exists(args.pretrain_up_emb_dir):
            self.pretrain_up_emb_dir = args.pretrain_up_emb_dir
        self.dropout_layer = nn.Dropout(p=args.dropout)

        if self.args.use_user_emb:
            if self.pretrain_up_emb_dir is None:
                self.user_emb = nn.Embedding(user_size + 1,
                                             self.embedding_size,
                                             padding_idx=self.user_pad_idx)
            else:
                pretrain_user_emb_path = os.path.join(self.pretrain_up_emb_dir,
                                                      "user_emb.txt")
                pretrained_weights = load_user_item_embeddings(
                    pretrain_user_emb_path)
                pretrained_weights.append([0.] * len(pretrained_weights[0]))
                assert len(pretrained_weights[0]) == self.embedding_size
                self.user_emb = nn.Embedding.from_pretrained(
                    torch.FloatTensor(pretrained_weights),
                    padding_idx=self.user_pad_idx)

        if self.args.use_item_emb:
            if self.pretrain_up_emb_dir is None:
                self.product_emb = nn.Embedding(product_size + 1,
                                                self.embedding_size,
                                                padding_idx=self.prod_pad_idx)
            else:
                pretrain_product_emb_path = os.path.join(
                    self.pretrain_up_emb_dir, "product_emb.txt")
                pretrained_weights = load_user_item_embeddings(
                    pretrain_product_emb_path)
                pretrained_weights.append([0.] * len(pretrained_weights[0]))
                self.product_emb = nn.Embedding.from_pretrained(
                    torch.FloatTensor(pretrained_weights),
                    padding_idx=self.prod_pad_idx)

        if self.pretrain_emb_dir is not None:
            #word_emb_fname = "word_emb.txt.gz" #for query and target words in pv and pvc
            word_emb_fname = "context_emb.txt.gz" if args.review_encoder_name == "pvc" else "word_emb.txt.gz"  #for query and target words in pv and pvc
            pretrain_word_emb_path = os.path.join(self.pretrain_emb_dir,
                                                  word_emb_fname)
            word_index_dic, pretrained_weights = load_pretrain_embeddings(
                pretrain_word_emb_path)
            word_indices = torch.tensor(
                [0] + [word_index_dic[x]
                       for x in self.vocab_words[1:]] + [self.word_pad_idx])
            #print(len(word_indices))
            #print(word_indices.cpu().tolist())
            pretrained_weights = torch.FloatTensor(pretrained_weights)
            self.word_embeddings = nn.Embedding.from_pretrained(
                pretrained_weights[word_indices],
                padding_idx=self.word_pad_idx)
            #vectors of padding idx will not be updated
        else:
            self.word_embeddings = nn.Embedding(vocab_size,
                                                self.embedding_size,
                                                padding_idx=self.word_pad_idx)

        if self.fix_emb and args.review_encoder_name == "pvc":
            #if review embeddings are fixed, just load the aggregated embeddings which include all the words in the review
            #otherwise the reviews are cut off at review_word_limit
            self.review_encoder_name = "pv"

        self.transformer_encoder = TransformerEncoder(self.embedding_size,
                                                      args.ff_size, args.heads,
                                                      args.dropout,
                                                      args.inter_layers)

        if self.review_encoder_name == "pv":
            pretrain_emb_path = None
            if self.pretrain_emb_dir is not None:
                pretrain_emb_path = os.path.join(self.pretrain_emb_dir,
                                                 "doc_emb.txt.gz")
            self.review_encoder = ParagraphVector(self.word_embeddings,
                                                  self.word_dists,
                                                  review_count,
                                                  self.emb_dropout,
                                                  pretrain_emb_path,
                                                  fix_emb=self.fix_emb)
        elif self.review_encoder_name == "pvc":
            pretrain_emb_path = None
            #if self.pretrain_emb_dir is not None:
            #    pretrain_emb_path = os.path.join(self.pretrain_emb_dir, "context_emb.txt.gz")
            self.review_encoder = ParagraphVectorCorruption(
                self.word_embeddings,
                self.word_dists,
                args.corrupt_rate,
                self.emb_dropout,
                pretrain_emb_path,
                self.vocab_words,
                fix_emb=self.fix_emb)
        elif self.review_encoder_name == "fs":
            self.review_encoder = FSEncoder(self.embedding_size,
                                            self.emb_dropout)
        else:
            self.review_encoder = AVGEncoder(self.embedding_size,
                                             self.emb_dropout)

        if args.query_encoder_name == "fs":
            self.query_encoder = FSEncoder(self.embedding_size,
                                           self.emb_dropout)
        else:
            self.query_encoder = AVGEncoder(self.embedding_size,
                                            self.emb_dropout)
        self.seg_embeddings = nn.Embedding(4,
                                           self.embedding_size,
                                           padding_idx=self.seg_pad_idx)
        #for each q,u,i
        #Q, previous purchases of u, current available reviews for i, padding value
        #self.logsoftmax = torch.nn.LogSoftmax(dim = -1)
        #self.bce_logits_loss = torch.nn.BCEWithLogitsLoss(reduction='none')#by default it's mean

        self.review_embeddings = None
        if self.fix_emb:
            #self.word_embeddings.weight.requires_grad = False
            #embeddings of query words need to be updated
            #self.emb_dropout = 0
            self.get_review_embeddings()  #get model.review_embeddings

        self.initialize_parameters(logger)  #logger
        self.to(device)  #change model in place

    def load_cp(self, pt, strict=True):
        self.load_state_dict(pt['model'], strict=strict)

    def clear_review_embbeddings(self):
        #otherwise review_embeddings are always the same
        if not self.fix_emb:
            self.review_embeddings = None
            #del self.review_embeddings
            torch.cuda.empty_cache()

    def get_review_embeddings(self, batch_size=128):
        if hasattr(self,
                   "review_embeddings") and self.review_embeddings is not None:
            return  #if already computed and not deleted
        if self.review_encoder_name == "pv":
            self.review_embeddings = self.review_encoder.review_embeddings.weight
        else:
            review_count = self.review_pad_idx
            seg_count = int((review_count - 1) / batch_size) + 1
            self.review_embeddings = torch.zeros(review_count + 1,
                                                 self.embedding_size,
                                                 device=self.device)
            #The last one is always 0
            for i in range(seg_count):
                slice_reviews = self.review_words[i * batch_size:(i + 1) *
                                                  batch_size]
                if self.review_encoder_name == "pvc":
                    self.review_encoder.set_to_evaluation_mode()
                    slice_review_emb = self.review_encoder.get_para_vector(
                        slice_reviews)
                    self.review_encoder.set_to_train_mode()
                else:  #fs or avg
                    slice_rword_emb = self.word_embeddings(slice_reviews)
                    slice_review_emb = self.review_encoder(
                        slice_rword_emb, slice_reviews.ne(self.word_pad_idx))
                self.review_embeddings[i * batch_size:(i + 1) *
                                       batch_size] = slice_review_emb

    def test(self, batch_data):
        query_word_idxs = batch_data.query_word_idxs
        candi_prod_ridxs = batch_data.candi_prod_ridxs
        candi_seg_idxs = batch_data.candi_seg_idxs
        candi_seq_item_idxs = batch_data.candi_seq_item_idxs
        candi_seq_user_idxs = batch_data.candi_seq_user_idxs
        query_word_emb = self.word_embeddings(query_word_idxs)
        query_emb = self.query_encoder(query_word_emb,
                                       query_word_idxs.ne(self.word_pad_idx))
        batch_size, candi_k, candi_rcount = candi_prod_ridxs.size()
        candi_review_emb = self.review_embeddings[candi_prod_ridxs]

        #concat query_emb with pos_review_emb and candi_review_emb
        query_mask = torch.ones(batch_size,
                                1,
                                dtype=torch.uint8,
                                device=query_word_idxs.device)
        candi_prod_ridx_mask = candi_prod_ridxs.ne(self.review_pad_idx)
        candi_review_mask = torch.cat([
            query_mask.unsqueeze(1).expand(-1, candi_k, -1),
            candi_prod_ridx_mask
        ],
                                      dim=2)
        #batch_size, 1, embedding_size
        candi_sequence_emb = torch.cat((query_emb.unsqueeze(1).expand(
            -1, candi_k, -1).unsqueeze(2), candi_review_emb),
                                       dim=2)
        #batch_size, candi_k, max_review_count+1, embedding_size
        candi_seg_emb = self.seg_embeddings(
            candi_seg_idxs
        )  #batch_size, candi_k, max_review_count+1, embedding_size
        if self.args.use_seg_emb:
            candi_sequence_emb += candi_seg_emb
        if self.args.use_user_emb:
            candi_seq_user_emb = self.user_emb(candi_seq_user_idxs)
            candi_sequence_emb += candi_seq_user_emb
        if self.args.use_item_emb:
            candi_seq_item_emb = self.product_emb(candi_seq_item_idxs)
            candi_sequence_emb += candi_seq_item_emb

        candi_scores = self.transformer_encoder(
            candi_sequence_emb.view(batch_size * candi_k, candi_rcount + 1,
                                    -1),
            candi_review_mask.view(batch_size * candi_k, candi_rcount + 1),
            use_pos=self.args.use_pos_emb)
        candi_scores = candi_scores.view(batch_size, candi_k)
        return candi_scores

    def forward(self, batch_data, train_pv=True):
        query_word_idxs = batch_data.query_word_idxs
        pos_prod_ridxs = batch_data.pos_prod_ridxs
        pos_seg_idxs = batch_data.pos_seg_idxs
        pos_prod_rword_idxs = batch_data.pos_prod_rword_idxs
        pos_prod_rword_masks = batch_data.pos_prod_rword_masks
        neg_prod_ridxs = batch_data.neg_prod_ridxs
        neg_seg_idxs = batch_data.neg_seg_idxs
        pos_user_idxs = batch_data.pos_user_idxs
        neg_user_idxs = batch_data.neg_user_idxs
        pos_item_idxs = batch_data.pos_item_idxs
        neg_item_idxs = batch_data.neg_item_idxs
        neg_prod_rword_idxs = batch_data.neg_prod_rword_idxs
        neg_prod_rword_masks = batch_data.neg_prod_rword_masks
        pos_prod_rword_idxs_pvc = batch_data.pos_prod_rword_idxs_pvc
        neg_prod_rword_idxs_pvc = batch_data.neg_prod_rword_idxs_pvc
        query_word_emb = self.word_embeddings(query_word_idxs)
        query_emb = self.query_encoder(query_word_emb,
                                       query_word_idxs.ne(self.word_pad_idx))
        batch_size, pos_rcount, posr_word_limit = pos_prod_rword_idxs.size()
        _, neg_k, neg_rcount = neg_prod_ridxs.size()
        posr_word_emb = self.word_embeddings(
            pos_prod_rword_idxs.view(-1, posr_word_limit))
        update_pos_prod_rword_masks = pos_prod_rword_masks.view(
            -1, posr_word_limit)
        pv_loss = None
        if "pv" in self.review_encoder_name:
            if train_pv:
                if self.review_encoder_name == "pv":
                    pos_review_emb, pos_prod_loss = self.review_encoder(
                        pos_prod_ridxs.view(-1), posr_word_emb,
                        update_pos_prod_rword_masks, self.args.neg_per_pos)
                elif self.review_encoder_name == "pvc":
                    pos_review_emb, pos_prod_loss = self.review_encoder(
                        posr_word_emb, update_pos_prod_rword_masks,
                        pos_prod_rword_idxs_pvc.view(
                            -1, pos_prod_rword_idxs_pvc.size(-1)),
                        self.args.neg_per_pos)
                sample_count = pos_prod_ridxs.ne(
                    self.review_pad_idx).float().sum(-1)
                # it won't be less than batch_size since no sequence consists entirely of padding indices
                #sample_count = sample_count.masked_fill(sample_count.eq(0),1)
                pv_loss = pos_prod_loss.sum() / sample_count.sum()
            else:
                if self.fix_emb:
                    pos_review_emb = self.review_embeddings[pos_prod_ridxs]
                else:
                    if self.review_encoder_name == "pv":
                        pos_review_emb = self.review_encoder.get_para_vector(
                            pos_prod_ridxs)
                    elif self.review_encoder_name == "pvc":
                        pos_review_emb = self.review_encoder.get_para_vector(
                            #pos_prod_rword_idxs_pvc.view(-1, pos_prod_rword_idxs_pvc.size(-1)))
                            pos_prod_rword_idxs.view(
                                -1, pos_prod_rword_idxs.size(-1)))
            if self.fix_emb:
                neg_review_emb = self.review_embeddings[neg_prod_ridxs]
            else:
                if self.review_encoder_name == "pv":
                    neg_review_emb = self.review_encoder.get_para_vector(
                        neg_prod_ridxs)
                elif self.review_encoder_name == "pvc":
                    if not train_pv:
                        neg_prod_rword_idxs_pvc = neg_prod_rword_idxs
                    neg_review_emb = self.review_encoder.get_para_vector(
                        neg_prod_rword_idxs_pvc.view(
                            -1, neg_prod_rword_idxs_pvc.size(-1)))
            pos_review_emb = self.dropout_layer(pos_review_emb)
            neg_review_emb = self.dropout_layer(neg_review_emb)
        else:
            negr_word_limit = neg_prod_rword_idxs.size()[-1]
            negr_word_emb = self.word_embeddings(
                neg_prod_rword_idxs.view(-1, negr_word_limit))
            pos_review_emb = self.review_encoder(posr_word_emb,
                                                 update_pos_prod_rword_masks)
            neg_review_emb = self.review_encoder(
                negr_word_emb, neg_prod_rword_masks.view(-1, negr_word_limit))

        pos_review_emb = pos_review_emb.view(batch_size, pos_rcount, -1)
        neg_review_emb = neg_review_emb.view(batch_size, neg_k, neg_rcount, -1)

        #concat query_emb with pos_review_emb and neg_review_emb
        query_mask = torch.ones(batch_size,
                                1,
                                dtype=torch.uint8,
                                device=query_word_idxs.device)
        pos_review_mask = torch.cat(
            [query_mask, pos_prod_ridxs.ne(self.review_pad_idx)],
            dim=1)  #batch_size, 1+max_review_count
        neg_prod_ridx_mask = neg_prod_ridxs.ne(self.review_pad_idx)
        neg_review_mask = torch.cat([
            query_mask.unsqueeze(1).expand(-1, neg_k, -1), neg_prod_ridx_mask
        ],
                                    dim=2)
        #batch_size, 1, embedding_size
        pos_sequence_emb = torch.cat((query_emb.unsqueeze(1), pos_review_emb),
                                     dim=1)
        pos_seg_emb = self.seg_embeddings(
            pos_seg_idxs)  #batch_size, max_review_count+1, embedding_size
        neg_sequence_emb = torch.cat((query_emb.unsqueeze(1).expand(
            -1, neg_k, -1).unsqueeze(2), neg_review_emb),
                                     dim=2)
        #batch_size, neg_k, max_review_count+1, embedding_size
        neg_seg_emb = self.seg_embeddings(
            neg_seg_idxs
        )  #batch_size, neg_k, max_review_count+1, embedding_size
        if self.args.use_seg_emb:
            pos_sequence_emb += pos_seg_emb
            neg_sequence_emb += neg_seg_emb
        if self.args.use_item_emb:
            pos_seq_item_emb = self.product_emb(pos_item_idxs)
            neg_seq_item_emb = self.product_emb(neg_item_idxs)
            pos_sequence_emb += pos_seq_item_emb
            neg_sequence_emb += neg_seq_item_emb
        if self.args.use_user_emb:
            pos_seq_user_emb = self.user_emb(pos_user_idxs)
            neg_seq_user_emb = self.user_emb(neg_user_idxs)
            pos_sequence_emb += pos_seq_user_emb
            neg_sequence_emb += neg_seq_user_emb

        pos_scores = self.transformer_encoder(pos_sequence_emb,
                                              pos_review_mask,
                                              use_pos=self.args.use_pos_emb)
        neg_scores = self.transformer_encoder(
            neg_sequence_emb.view(batch_size * neg_k, neg_rcount + 1, -1),
            neg_review_mask.view(batch_size * neg_k, neg_rcount + 1),
            use_pos=self.args.use_pos_emb)
        neg_scores = neg_scores.view(batch_size, neg_k)
        pos_weight = 1
        if self.args.pos_weight:
            pos_weight = self.args.neg_per_pos
        prod_mask = torch.cat(
            [
                torch.ones(batch_size,
                           1,
                           dtype=torch.uint8,
                           device=query_word_idxs.device) * pos_weight,
                neg_prod_ridx_mask.sum(-1).ne(0)
            ],
            dim=-1)  #batch_size, neg_k (valid products, some are padded)
        #TODO: this mask does not reflect true neg prods; when reviews are randomly selected, all of them should be valid since there is no need for padding
        prod_scores = torch.cat([pos_scores.unsqueeze(-1), neg_scores], dim=-1)
        target = torch.cat([
            torch.ones(batch_size, 1, device=query_word_idxs.device),
            torch.zeros(batch_size, neg_k, device=query_word_idxs.device)
        ],
                           dim=-1)
        #ps_loss = self.bce_logits_loss(prod_scores, target, weight=prod_mask.float())
        ps_loss = nn.functional.binary_cross_entropy_with_logits(
            prod_scores, target, weight=prod_mask.float(), reduction='none')

        ps_loss = ps_loss.sum(-1).mean()
        loss = ps_loss + pv_loss if pv_loss is not None else ps_loss
        return loss

    def initialize_parameters(self, logger=None):
        if logger:
            logger.info(" ProductRanker initialization started.")
        if self.pretrain_emb_dir is None:
            nn.init.normal_(self.word_embeddings.weight)
        nn.init.normal_(self.seg_embeddings.weight)
        self.review_encoder.initialize_parameters(logger)
        self.query_encoder.initialize_parameters(logger)
        self.transformer_encoder.initialize_parameters(logger)
        if logger:
            logger.info(" ProductRanker initialization finished.")
Example #10
    def __init__(self,
                 args,
                 device,
                 vocab_size,
                 product_size,
                 vocab_words,
                 word_dists=None):
        super(ItemTransformerRanker, self).__init__()
        self.args = args
        self.device = device
        self.train_review_only = args.train_review_only
        self.embedding_size = args.embedding_size
        self.vocab_words = vocab_words
        self.word_dists = None
        if word_dists is not None:
            self.word_dists = torch.tensor(word_dists, device=device)
        self.prod_dists = torch.ones(product_size, device=device)
        self.prod_pad_idx = product_size
        self.word_pad_idx = vocab_size - 1
        self.seg_pad_idx = 3
        self.emb_dropout = args.dropout
        self.pretrain_emb_dir = None
        if os.path.exists(args.pretrain_emb_dir):
            self.pretrain_emb_dir = args.pretrain_emb_dir
        self.pretrain_up_emb_dir = None
        if os.path.exists(args.pretrain_up_emb_dir):
            self.pretrain_up_emb_dir = args.pretrain_up_emb_dir
        self.dropout_layer = nn.Dropout(p=args.dropout)

        self.product_emb = nn.Embedding(product_size + 1,
                                        self.embedding_size,
                                        padding_idx=self.prod_pad_idx)
        if args.sep_prod_emb:
            self.hist_product_emb = nn.Embedding(product_size + 1,
                                                 self.embedding_size,
                                                 padding_idx=self.prod_pad_idx)
        '''
        else:
            pretrain_product_emb_path = os.path.join(self.pretrain_up_emb_dir, "product_emb.txt")
            pretrained_weights = load_user_item_embeddings(pretrain_product_emb_path)
            pretrained_weights.append([0.] * len(pretrained_weights[0]))
            self.product_emb = nn.Embedding.from_pretrained(torch.FloatTensor(pretrained_weights), padding_idx=self.prod_pad_idx)
        '''
        self.product_bias = nn.Parameter(torch.zeros(product_size + 1),
                                         requires_grad=True)
        self.word_bias = nn.Parameter(torch.zeros(vocab_size),
                                      requires_grad=True)

        if self.pretrain_emb_dir is not None:
            word_emb_fname = "word_emb.txt.gz"  #for query and target words in pv and pvc
            pretrain_word_emb_path = os.path.join(self.pretrain_emb_dir,
                                                  word_emb_fname)
            word_index_dic, pretrained_weights = load_pretrain_embeddings(
                pretrain_word_emb_path)
            word_indices = torch.tensor(
                [0] + [word_index_dic[x]
                       for x in self.vocab_words[1:]] + [self.word_pad_idx])
            #print(len(word_indices))
            #print(word_indices.cpu().tolist())
            pretrained_weights = torch.FloatTensor(pretrained_weights)
            self.word_embeddings = nn.Embedding.from_pretrained(
                pretrained_weights[word_indices],
                padding_idx=self.word_pad_idx)
            #vectors of padding idx will not be updated
        else:
            self.word_embeddings = nn.Embedding(vocab_size,
                                                self.embedding_size,
                                                padding_idx=self.word_pad_idx)
        if self.args.model_name == "item_transformer":
            self.transformer_encoder = TransformerEncoder(
                self.embedding_size, args.ff_size, args.heads, args.dropout,
                args.inter_layers)
        #if self.args.model_name == "ZAM" or self.args.model_name == "AEM":
        else:
            self.attention_encoder = MultiHeadedAttention(
                args.heads, self.embedding_size, args.dropout)

        if args.query_encoder_name == "fs":
            self.query_encoder = FSEncoder(self.embedding_size,
                                           self.emb_dropout)
        else:
            self.query_encoder = AVGEncoder(self.embedding_size,
                                            self.emb_dropout)
        self.seg_embeddings = nn.Embedding(4,
                                           self.embedding_size,
                                           padding_idx=self.seg_pad_idx)
        #for each q,u,i
        #Q, previous purchases of u, current available reviews for i, padding value
        #self.logsoftmax = torch.nn.LogSoftmax(dim = -1)
        self.bce_logits_loss = torch.nn.BCEWithLogitsLoss(
            reduction='none')  #by default it's mean

        self.initialize_parameters(logger)  #logger
        self.to(device)  #change model in place
        self.item_loss = 0
        self.ps_loss = 0
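
Both rankers remap a pretrained embedding matrix onto the model's own vocabulary order before freezing it with nn.Embedding.from_pretrained. The toy sketch below isolates that reindex-then-wrap step; the word list, weights, and index mapping are stand-ins for whatever load_pretrain_embeddings returns.

import torch
import torch.nn as nn

# Toy stand-ins for the data load_pretrain_embeddings would provide.
word_index_dic = {"<pad>": 0, "good": 1, "bad": 2}
pretrained_weights = torch.randn(3, 4)            # (pretrained_vocab, emb_size)

vocab_words = ["<unk>", "bad", "good"]            # model's own vocabulary order
word_pad_idx = len(vocab_words)                   # last row reserved for padding

# Row 0 for <unk>, known words in model order, then one row for padding.
word_indices = torch.tensor([0] + [word_index_dic[w] for w in vocab_words[1:]] + [0])
word_embeddings = nn.Embedding.from_pretrained(pretrained_weights[word_indices],
                                               padding_idx=word_pad_idx)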
Example #11
class ItemTransformerRanker(nn.Module):
    def __init__(self,
                 args,
                 device,
                 vocab_size,
                 product_size,
                 vocab_words,
                 word_dists=None):
        super(ItemTransformerRanker, self).__init__()
        self.args = args
        self.device = device
        self.train_review_only = args.train_review_only
        self.embedding_size = args.embedding_size
        self.vocab_words = vocab_words
        self.word_dists = None
        if word_dists is not None:
            self.word_dists = torch.tensor(word_dists, device=device)
        self.prod_dists = torch.ones(product_size, device=device)
        self.prod_pad_idx = product_size
        self.word_pad_idx = vocab_size - 1
        self.seg_pad_idx = 3
        self.emb_dropout = args.dropout
        self.pretrain_emb_dir = None
        if os.path.exists(args.pretrain_emb_dir):
            self.pretrain_emb_dir = args.pretrain_emb_dir
        self.pretrain_up_emb_dir = None
        if os.path.exists(args.pretrain_up_emb_dir):
            self.pretrain_up_emb_dir = args.pretrain_up_emb_dir
        self.dropout_layer = nn.Dropout(p=args.dropout)

        self.product_emb = nn.Embedding(product_size + 1,
                                        self.embedding_size,
                                        padding_idx=self.prod_pad_idx)
        if args.sep_prod_emb:
            self.hist_product_emb = nn.Embedding(product_size + 1,
                                                 self.embedding_size,
                                                 padding_idx=self.prod_pad_idx)
        '''
        else:
            pretrain_product_emb_path = os.path.join(self.pretrain_up_emb_dir, "product_emb.txt")
            pretrained_weights = load_user_item_embeddings(pretrain_product_emb_path)
            pretrained_weights.append([0.] * len(pretrained_weights[0]))
            self.product_emb = nn.Embedding.from_pretrained(torch.FloatTensor(pretrained_weights), padding_idx=self.prod_pad_idx)
        '''
        self.product_bias = nn.Parameter(torch.zeros(product_size + 1),
                                         requires_grad=True)
        self.word_bias = nn.Parameter(torch.zeros(vocab_size),
                                      requires_grad=True)

        if self.pretrain_emb_dir is not None:
            word_emb_fname = "word_emb.txt.gz"  #for query and target words in pv and pvc
            pretrain_word_emb_path = os.path.join(self.pretrain_emb_dir,
                                                  word_emb_fname)
            word_index_dic, pretrained_weights = load_pretrain_embeddings(
                pretrain_word_emb_path)
            word_indices = torch.tensor(
                [0] + [word_index_dic[x]
                       for x in self.vocab_words[1:]] + [self.word_pad_idx])
            #print(len(word_indices))
            #print(word_indices.cpu().tolist())
            pretrained_weights = torch.FloatTensor(pretrained_weights)
            self.word_embeddings = nn.Embedding.from_pretrained(
                pretrained_weights[word_indices],
                padding_idx=self.word_pad_idx)
            #vectors of padding idx will not be updated
        else:
            self.word_embeddings = nn.Embedding(vocab_size,
                                                self.embedding_size,
                                                padding_idx=self.word_pad_idx)
        if self.args.model_name == "item_transformer":
            self.transformer_encoder = TransformerEncoder(
                self.embedding_size, args.ff_size, args.heads, args.dropout,
                args.inter_layers)
        #if self.args.model_name == "ZAM" or self.args.model_name == "AEM":
        else:
            self.attention_encoder = MultiHeadedAttention(
                args.heads, self.embedding_size, args.dropout)

        if args.query_encoder_name == "fs":
            self.query_encoder = FSEncoder(self.embedding_size,
                                           self.emb_dropout)
        else:
            self.query_encoder = AVGEncoder(self.embedding_size,
                                            self.emb_dropout)
        self.seg_embeddings = nn.Embedding(4,
                                           self.embedding_size,
                                           padding_idx=self.seg_pad_idx)
        #for each q,u,i
        #Q, previous purchases of u, current available reviews for i, padding value
        #self.logsoftmax = torch.nn.LogSoftmax(dim = -1)
        self.bce_logits_loss = torch.nn.BCEWithLogitsLoss(
            reduction='none')  #by default it's mean

        self.initialize_parameters(logger)  #logger
        self.to(device)  #change model in place
        self.item_loss = 0
        self.ps_loss = 0

    def clear_loss(self):
        self.item_loss = 0
        self.ps_loss = 0

    def load_cp(self, pt, strict=True):
        self.load_state_dict(pt['model'], strict=strict)

    def test(self, batch_data):
        if self.args.model_name == "item_transformer":
            if self.args.use_dot_prod:
                return self.test_dotproduct(batch_data)
            else:
                return self.test_trans(batch_data)
        else:
            return self.test_attn(batch_data)

    def test_dotproduct(self, batch_data):
        query_word_idxs = batch_data.query_word_idxs
        target_prod_idxs = batch_data.target_prod_idxs
        u_item_idxs = batch_data.u_item_idxs
        candi_prod_idxs = batch_data.candi_prod_idxs
        batch_size, prev_item_count = u_item_idxs.size()
        _, candi_k = candi_prod_idxs.size()
        query_word_emb = self.word_embeddings(query_word_idxs)
        query_emb = self.query_encoder(query_word_emb,
                                       query_word_idxs.ne(self.word_pad_idx))
        column_mask = torch.ones(batch_size,
                                 1,
                                 dtype=torch.uint8,
                                 device=query_word_idxs.device)
        u_item_mask = u_item_idxs.ne(self.prod_pad_idx)
        u_item_mask = u_item_mask.unsqueeze(1).expand(-1, candi_k, -1)
        column_mask = column_mask.unsqueeze(1).expand(-1, candi_k, -1)
        candi_item_seq_mask = torch.cat([column_mask, u_item_mask], dim=2)
        candi_item_emb = self.product_emb(
            candi_prod_idxs)  #batch_size, candi_k, embedding_size
        if self.args.sep_prod_emb:
            u_item_emb = self.hist_product_emb(u_item_idxs)
        else:
            u_item_emb = self.product_emb(u_item_idxs)
        candi_sequence_emb = torch.cat([
            query_emb.unsqueeze(1).expand(-1, candi_k, -1).unsqueeze(2),
            u_item_emb.unsqueeze(1).expand(-1, candi_k, -1, -1)
        ],
                                       dim=2)

        out_pos = -1 if self.args.use_item_pos else 0
        top_vecs = self.transformer_encoder.encode(
            candi_sequence_emb.view(batch_size * candi_k, prev_item_count + 1,
                                    -1),
            candi_item_seq_mask.view(batch_size * candi_k,
                                     prev_item_count + 1),
            use_pos=self.args.use_pos_emb)
        candi_out_emb = top_vecs[:, out_pos, :]
        candi_scores = torch.bmm(
            candi_out_emb.unsqueeze(1),
            candi_item_emb.view(batch_size * candi_k, -1).unsqueeze(2))
        candi_scores = candi_scores.view(batch_size, candi_k)
        if self.args.sim_func == "bias_product":
            candi_bias = self.product_bias[candi_prod_idxs.view(-1)].view(
                batch_size, candi_k)
            candi_scores += candi_bias
        return candi_scores

    def test_attn(self, batch_data):
        query_word_idxs = batch_data.query_word_idxs
        target_prod_idxs = batch_data.target_prod_idxs
        u_item_idxs = batch_data.u_item_idxs
        candi_prod_idxs = batch_data.candi_prod_idxs
        batch_size, prev_item_count = u_item_idxs.size()
        _, candi_k = candi_prod_idxs.size()
        query_word_emb = self.word_embeddings(query_word_idxs)
        query_emb = self.query_encoder(query_word_emb,
                                       query_word_idxs.ne(self.word_pad_idx))
        embed_size = query_emb.size()[-1]
        candi_item_emb = self.product_emb(
            candi_prod_idxs)  #batch_size, candi_k, embedding_size
        if self.args.model_name == "QEM":
            candi_out_emb = query_emb.unsqueeze(1).expand(
                -1, candi_k, -1).contiguous().view(batch_size * candi_k,
                                                   embed_size)
        else:  #if self.args.model_name == "AEM" or self.args.model_name == "ZAM":
            u_item_mask = u_item_idxs.ne(self.prod_pad_idx)
            candi_item_seq_mask = u_item_mask.unsqueeze(1).expand(
                -1, candi_k, -1)
            if self.args.sep_prod_emb:
                u_item_emb = self.hist_product_emb(u_item_idxs)
            else:
                u_item_emb = self.product_emb(u_item_idxs)

            candi_sequence_emb = u_item_emb.unsqueeze(1).expand(
                -1, candi_k, -1, -1)

            if self.args.model_name == "ZAM":
                zero_column = torch.zeros(batch_size,
                                          1,
                                          embed_size,
                                          device=query_word_idxs.device)
                column_mask = torch.ones(batch_size,
                                         1,
                                         dtype=torch.uint8,
                                         device=query_word_idxs.device)
                column_mask = column_mask.unsqueeze(1).expand(-1, candi_k, -1)
                candi_item_seq_mask = torch.cat(
                    [column_mask, candi_item_seq_mask], dim=2)
                pos_sequence_emb = torch.cat([zero_column, u_item_emb], dim=1)
                candi_sequence_emb = torch.cat([
                    zero_column.expand(-1, candi_k, -1).unsqueeze(2),
                    candi_sequence_emb
                ],
                                               dim=2)

            candi_item_seq_mask = candi_item_seq_mask.contiguous().view(
                batch_size * candi_k, 1, -1)
            out_pos = 0
            candi_sequence_emb = candi_sequence_emb.contiguous().view(
                batch_size * candi_k, -1, embed_size)
            query_emb = query_emb.unsqueeze(1).expand(
                -1, candi_k, -1).contiguous().view(batch_size * candi_k, 1,
                                                   embed_size)
            top_vecs = self.attention_encoder(candi_sequence_emb,
                                              candi_sequence_emb,
                                              query_emb,
                                              mask=1 - candi_item_seq_mask)
            candi_out_emb = 0.5 * top_vecs[:,
                                           out_pos, :] + 0.5 * query_emb.squeeze(
                                               1)

        candi_scores = torch.bmm(
            candi_out_emb.unsqueeze(1),
            candi_item_emb.view(batch_size * candi_k, -1).unsqueeze(2))
        candi_scores = candi_scores.view(batch_size, candi_k)

        if self.args.sim_func == "bias_product":
            candi_bias = self.product_bias[candi_prod_idxs.view(-1)].view(
                batch_size, candi_k)
            candi_scores += candi_bias
        return candi_scores

    def test_trans(self, batch_data):
        query_word_idxs = batch_data.query_word_idxs
        target_prod_idxs = batch_data.target_prod_idxs
        u_item_idxs = batch_data.u_item_idxs
        candi_prod_idxs = batch_data.candi_prod_idxs
        batch_size, prev_item_count = u_item_idxs.size()
        _, candi_k = candi_prod_idxs.size()
        query_word_emb = self.word_embeddings(query_word_idxs)
        query_emb = self.query_encoder(query_word_emb,
                                       query_word_idxs.ne(self.word_pad_idx))
        column_mask = torch.ones(batch_size,
                                 1,
                                 dtype=torch.uint8,
                                 device=query_word_idxs.device)
        u_item_mask = u_item_idxs.ne(self.prod_pad_idx)
        u_item_mask = u_item_mask.unsqueeze(1).expand(-1, candi_k, -1)
        column_mask = column_mask.unsqueeze(1).expand(-1, candi_k, -1)
        candi_item_seq_mask = torch.cat(
            [column_mask, u_item_mask, column_mask], dim=2)
        candi_seg_idxs = torch.cat([
            column_mask * 0,
            column_mask.expand(-1, -1, prev_item_count), column_mask * 2
        ],
                                   dim=2)
        candi_item_emb = self.product_emb(
            candi_prod_idxs)  #batch_size, candi_k, embedding_size
        if self.args.sep_prod_emb:
            u_item_emb = self.hist_product_emb(u_item_idxs)
        else:
            u_item_emb = self.product_emb(u_item_idxs)
        candi_sequence_emb = torch.cat([
            query_emb.unsqueeze(1).expand(-1, candi_k, -1).unsqueeze(2),
            u_item_emb.unsqueeze(1).expand(-1, candi_k, -1, -1),
            candi_item_emb.unsqueeze(2)
        ],
                                       dim=2)
        candi_seg_emb = self.seg_embeddings(candi_seg_idxs.long(
        ))  #batch_size, candi_k, max_prev_item_count+1, embedding_size
        candi_sequence_emb += candi_seg_emb

        out_pos = -1 if self.args.use_item_pos else 0
        candi_scores = self.transformer_encoder(
            candi_sequence_emb.view(batch_size * candi_k, prev_item_count + 2,
                                    -1),
            candi_item_seq_mask.view(batch_size * candi_k,
                                     prev_item_count + 2),
            use_pos=self.args.use_pos_emb,
            out_pos=out_pos)
        candi_scores = candi_scores.view(batch_size, candi_k)
        return candi_scores

    def test_seq(self, batch_data):
        query_word_idxs = batch_data.query_word_idxs
        candi_seg_idxs = batch_data.neg_seg_idxs
        candi_seq_item_idxs = batch_data.neg_seq_item_idxs
        batch_size, candi_k, prev_item_count = batch_data.neg_seq_item_idxs.size(
        )
        query_word_emb = self.word_embeddings(query_word_idxs)
        query_emb = self.query_encoder(query_word_emb,
                                       query_word_idxs.ne(self.word_pad_idx))
        candi_seq_item_emb = self.product_emb(
            candi_seq_item_idxs
        )  #batch_size, candi_k, max_prev_item_count, embedding_size
        #concat query_emb with pos_review_emb and candi_review_emb
        query_mask = torch.ones(batch_size,
                                1,
                                dtype=torch.uint8,
                                device=query_word_idxs.device)
        candi_prod_idx_mask = candi_seq_item_idxs.ne(self.prod_pad_idx)
        candi_seq_item_mask = torch.cat([
            query_mask.unsqueeze(1).expand(-1, candi_k, -1),
            candi_prod_idx_mask
        ],
                                        dim=2)
        #batch_size, 1, embedding_size
        candi_sequence_emb = torch.cat((query_emb.unsqueeze(1).expand(
            -1, candi_k, -1).unsqueeze(2), candi_seq_item_emb),
                                       dim=2)
        #batch_size, candi_k, max_review_count+1, embedding_size
        candi_seg_emb = self.seg_embeddings(
            candi_seg_idxs
        )  #batch_size, candi_k, max_review_count+1, embedding_size
        candi_sequence_emb += candi_seg_emb

        candi_scores = self.transformer_encoder(
            candi_sequence_emb.view(batch_size * candi_k, prev_item_count + 1,
                                    -1),
            candi_seq_item_mask.view(batch_size * candi_k,
                                     prev_item_count + 1))
        candi_scores = candi_scores.view(batch_size, candi_k)
        return candi_scores

    def item_to_words(self, target_prod_idxs, target_word_idxs, n_negs):
        batch_size, pv_window_size = target_word_idxs.size()
        prod_emb = self.product_emb(target_prod_idxs)
        target_word_emb = self.word_embeddings(target_word_idxs)

        #for each target word, k negative words are sampled
        #vocab_size = self.word_embeddings.weight.size() - 1
        #compute the loss of review generating positive and negative words
        neg_sample_idxs = torch.multinomial(self.word_dists,
                                            batch_size * pv_window_size *
                                            n_negs,
                                            replacement=True)
        neg_sample_emb = self.word_embeddings(
            neg_sample_idxs.view(batch_size, -1))
        output_pos = torch.bmm(
            target_word_emb,
            prod_emb.unsqueeze(2))  # batch_size, pv_window_size, 1
        output_neg = torch.bmm(neg_sample_emb, prod_emb.unsqueeze(2)).view(
            batch_size, pv_window_size, -1)
        pos_bias = self.word_bias[target_word_idxs.view(-1)].view(
            batch_size, pv_window_size, 1)
        neg_bias = self.word_bias[neg_sample_idxs].view(
            batch_size, pv_window_size, -1)
        output_pos += pos_bias
        output_neg += neg_bias

        scores = torch.cat((output_pos, output_neg),
                           dim=-1)  #batch_size, pv_window_size, 1+n_negs
        target = torch.cat(
            (torch.ones(output_pos.size(), device=scores.device),
             torch.zeros(output_neg.size(), device=scores.device)),
            dim=-1)
        loss = self.bce_logits_loss(scores, target).sum(
            -1)  #batch_size, pv_window_size
        loss = get_vector_mean(loss.unsqueeze(-1),
                               target_word_idxs.ne(self.word_pad_idx))
        loss = loss.mean()
        return loss

    def forward_trans(self, batch_data, train_pv=False):
        query_word_idxs = batch_data.query_word_idxs
        target_prod_idxs = batch_data.target_prod_idxs
        u_item_idxs = batch_data.u_item_idxs
        batch_size, prev_item_count = u_item_idxs.size()
        neg_k = self.args.neg_per_pos
        pos_iword_idxs = batch_data.pos_iword_idxs
        neg_item_idxs = torch.multinomial(self.prod_dists,
                                          batch_size * neg_k,
                                          replacement=True)
        neg_item_idxs = neg_item_idxs.view(batch_size, -1)
        query_word_emb = self.word_embeddings(query_word_idxs)
        query_emb = self.query_encoder(query_word_emb,
                                       query_word_idxs.ne(self.word_pad_idx))
        column_mask = torch.ones(batch_size,
                                 1,
                                 dtype=torch.uint8,
                                 device=query_word_idxs.device)
        u_item_mask = u_item_idxs.ne(self.prod_pad_idx)
        pos_item_seq_mask = torch.cat([column_mask, u_item_mask, column_mask],
                                      dim=1)  #batch_size, prev_item_count + 2 (query + history items + target)

        pos_seg_idxs = torch.cat([
            column_mask * 0,
            column_mask.expand(-1, prev_item_count), column_mask * 2
        ],
                                 dim=1)
        column_mask = column_mask.unsqueeze(1).expand(-1, neg_k, -1)
        neg_item_seq_mask = torch.cat([
            column_mask,
            u_item_mask.unsqueeze(1).expand(-1, neg_k, -1), column_mask
        ],
                                      dim=2)
        neg_seg_idxs = torch.cat([
            column_mask * 0,
            column_mask.expand(-1, -1, prev_item_count), column_mask * 2
        ],
                                 dim=2)
        target_item_emb = self.product_emb(target_prod_idxs)
        neg_item_emb = self.product_emb(
            neg_item_idxs)  #batch_size, neg_k, embedding_size
        if self.args.sep_prod_emb:
            u_item_emb = self.hist_product_emb(u_item_idxs)
        else:
            u_item_emb = self.product_emb(u_item_idxs)
        pos_sequence_emb = torch.cat(
            [query_emb.unsqueeze(1), u_item_emb,
             target_item_emb.unsqueeze(1)],
            dim=1)
        pos_seg_emb = self.seg_embeddings(pos_seg_idxs.long())
        neg_sequence_emb = torch.cat([
            query_emb.unsqueeze(1).expand(-1, neg_k, -1).unsqueeze(2),
            u_item_emb.unsqueeze(1).expand(-1, neg_k, -1, -1),
            neg_item_emb.unsqueeze(2)
        ],
                                     dim=2)
        neg_seg_emb = self.seg_embeddings(
            neg_seg_idxs.long())  #batch_size, neg_k, prev_item_count+2, embedding_size
        pos_sequence_emb += pos_seg_emb
        neg_sequence_emb += neg_seg_emb

        out_pos = -1 if self.args.use_item_pos else 0
        pos_scores = self.transformer_encoder(pos_sequence_emb,
                                              pos_item_seq_mask,
                                              use_pos=self.args.use_pos_emb,
                                              out_pos=out_pos)

        neg_scores = self.transformer_encoder(
            neg_sequence_emb.view(batch_size * neg_k, prev_item_count + 2, -1),
            neg_item_seq_mask.view(batch_size * neg_k, prev_item_count + 2),
            use_pos=self.args.use_pos_emb,
            out_pos=out_pos)
        neg_scores = neg_scores.view(batch_size, neg_k)
        pos_weight = 1
        if self.args.pos_weight:
            pos_weight = self.args.neg_per_pos
        prod_mask = torch.cat([
            torch.ones(batch_size,
                       1,
                       dtype=torch.uint8,
                       device=query_word_idxs.device) * pos_weight,
            torch.ones(batch_size,
                       neg_k,
                       dtype=torch.uint8,
                       device=query_word_idxs.device)
        ],
                              dim=-1)
        prod_scores = torch.cat([pos_scores.unsqueeze(-1), neg_scores], dim=-1)
        target = torch.cat([
            torch.ones(batch_size, 1, device=query_word_idxs.device),
            torch.zeros(batch_size, neg_k, device=query_word_idxs.device)
        ],
                           dim=-1)
        #ps_loss = self.bce_logits_loss(prod_scores, target, weight=prod_mask.float())
        #each positive item is paired with neg_k sampled negative items
        ps_loss = nn.functional.binary_cross_entropy_with_logits(
            prod_scores, target, weight=prod_mask.float(), reduction='none')
        ps_loss = ps_loss.sum(-1).mean()
        item_loss = self.item_to_words(target_prod_idxs, pos_iword_idxs,
                                       self.args.neg_per_pos)
        self.ps_loss += ps_loss.item()
        self.item_loss += item_loss.item()
        #logger.info("ps_loss:{} item_loss:{}".format(ps_loss.item(), item_loss.item()))

        return ps_loss + item_loss
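
forward_trans scores the sequence [query, purchased items, candidate item] and tells the three roles apart through segment embeddings: index 0 for the query slot, 1 for every history item, 2 for the candidate slot. A short sketch of that segment-id construction for one positive candidate, with illustrative sizes:

import torch

batch_size, prev_item_count = 2, 5            # illustrative sizes
column = torch.ones(batch_size, 1, dtype=torch.long)

# segment ids: 0 -> query slot, 1 -> each history item, 2 -> candidate slot
pos_seg_idxs = torch.cat([column * 0,
                          column.expand(-1, prev_item_count),
                          column * 2], dim=1)
print(pos_seg_idxs[0])                        # tensor([0, 1, 1, 1, 1, 1, 2])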

    def forward(self, batch_data, train_pv=False):
        if self.args.model_name == "item_transformer":
            if self.args.use_dot_prod:
                return self.forward_dotproduct(batch_data)
            else:
                return self.forward_trans(batch_data)
        else:
            return self.forward_attn(batch_data)

    def forward_attn(self, batch_data, train_pv=False):
        query_word_idxs = batch_data.query_word_idxs
        target_prod_idxs = batch_data.target_prod_idxs
        u_item_idxs = batch_data.u_item_idxs
        batch_size, prev_item_count = u_item_idxs.size()
        neg_k = self.args.neg_per_pos
        pos_iword_idxs = batch_data.pos_iword_idxs
        neg_item_idxs = torch.multinomial(self.prod_dists,
                                          batch_size * neg_k,
                                          replacement=True)
        neg_item_idxs = neg_item_idxs.view(batch_size, -1)
        query_word_emb = self.word_embeddings(query_word_idxs)
        query_emb = self.query_encoder(query_word_emb,
                                       query_word_idxs.ne(self.word_pad_idx))
        embed_size = query_emb.size()[-1]
        target_item_emb = self.product_emb(target_prod_idxs)
        neg_item_emb = self.product_emb(
            neg_item_idxs)  #batch_size, neg_k, embedding_size
        if self.args.model_name == "QEM":
            pos_out_emb = query_emb  #batch_size, embedding_size
            neg_out_emb = query_emb.unsqueeze(1).expand(-1, neg_k,
                                                        -1).contiguous().view(
                                                            batch_size * neg_k,
                                                            embed_size)
        else:  #if self.args.model_name == "ZAM" or self.args.model_name == "AEM":
            u_item_mask = u_item_idxs.ne(self.prod_pad_idx)
            if self.args.sep_prod_emb:
                u_item_emb = self.hist_product_emb(u_item_idxs)
            else:
                u_item_emb = self.product_emb(u_item_idxs)
            pos_sequence_emb = u_item_emb
            neg_sequence_emb = u_item_emb.unsqueeze(1).expand(
                -1, neg_k, -1, -1)
            pos_item_seq_mask = u_item_mask
            neg_item_seq_mask = u_item_mask.unsqueeze(1).expand(-1, neg_k, -1)
            if self.args.model_name == "ZAM":
                zero_column = torch.zeros(batch_size,
                                          1,
                                          embed_size,
                                          device=query_word_idxs.device)
                column_mask = torch.ones(batch_size,
                                         1,
                                         dtype=torch.uint8,
                                         device=query_word_idxs.device)
                pos_item_seq_mask = torch.cat(
                    [column_mask, u_item_mask],
                    dim=1)  #batch_size, 1 + prev_item_count
                column_mask = column_mask.unsqueeze(1).expand(-1, neg_k, -1)
                neg_item_seq_mask = torch.cat([
                    column_mask,
                    u_item_mask.unsqueeze(1).expand(-1, neg_k, -1)
                ],
                                              dim=2)
                pos_sequence_emb = torch.cat([zero_column, u_item_emb], dim=1)
                neg_sequence_emb = torch.cat([
                    zero_column.expand(-1, neg_k, -1).unsqueeze(2),
                    neg_sequence_emb
                ],
                                             dim=2)

            pos_item_seq_mask = pos_item_seq_mask.unsqueeze(1)
            neg_item_seq_mask = neg_item_seq_mask.contiguous().view(
                batch_size * neg_k, 1, -1)
            out_pos = 0
            top_vecs = self.attention_encoder(pos_sequence_emb,
                                              pos_sequence_emb,
                                              query_emb.unsqueeze(1),
                                              mask=1 - pos_item_seq_mask)
            pos_out_emb = 0.5 * top_vecs[:,
                                         out_pos, :] + 0.5 * query_emb  #batch_size, embedding_size
            neg_sequence_emb = neg_sequence_emb.contiguous().view(
                batch_size * neg_k, -1, embed_size)
            query_emb = query_emb.unsqueeze(1).expand(
                -1, neg_k, -1).contiguous().view(batch_size * neg_k, 1,
                                                 embed_size)
            top_vecs = self.attention_encoder(neg_sequence_emb,
                                              neg_sequence_emb,
                                              query_emb,
                                              mask=1 - neg_item_seq_mask)
            neg_out_emb = 0.5 * top_vecs[:,
                                         out_pos, :] + 0.5 * query_emb.squeeze(
                                             1)

        pos_scores = torch.bmm(pos_out_emb.unsqueeze(1),
                               target_item_emb.unsqueeze(2)).squeeze()
        neg_scores = torch.bmm(
            neg_out_emb.unsqueeze(1),
            neg_item_emb.view(batch_size * neg_k, -1).unsqueeze(2))
        neg_scores = neg_scores.view(batch_size, neg_k)
        if self.args.sim_func == "bias_product":
            pos_bias = self.product_bias[target_prod_idxs.view(-1)].view(
                batch_size)
            neg_bias = self.product_bias[neg_item_idxs.view(-1)].view(
                batch_size, neg_k)
            pos_scores += pos_bias
            neg_scores += neg_bias
        pos_weight = 1
        if self.args.pos_weight:
            pos_weight = self.args.neg_per_pos
        prod_mask = torch.cat([
            torch.ones(batch_size,
                       1,
                       dtype=torch.uint8,
                       device=query_word_idxs.device) * pos_weight,
            torch.ones(batch_size,
                       neg_k,
                       dtype=torch.uint8,
                       device=query_word_idxs.device)
        ],
                              dim=-1)
        prod_scores = torch.cat([pos_scores.unsqueeze(-1), neg_scores], dim=-1)
        target = torch.cat([
            torch.ones(batch_size, 1, device=query_word_idxs.device),
            torch.zeros(batch_size, neg_k, device=query_word_idxs.device)
        ],
                           dim=-1)
        #ps_loss = self.bce_logits_loss(prod_scores, target, weight=prod_mask.float())
        #each positive item is paired with neg_k sampled negative items
        ps_loss = nn.functional.binary_cross_entropy_with_logits(
            prod_scores, target, weight=prod_mask.float(), reduction='none')
        ps_loss = ps_loss.sum(-1).mean()
        item_loss = self.item_to_words(target_prod_idxs, pos_iword_idxs,
                                       self.args.neg_per_pos)
        self.ps_loss += ps_loss.item()
        self.item_loss += item_loss.item()
        #logger.info("ps_loss:{} item_loss:{}".format(ps_loss.item(), item_loss.item()))

        return ps_loss + item_loss
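
In the ZAM branch above, a zero vector is prepended to the purchase history so the attention can effectively attend to nothing, and the final user representation is an even mix of the attended history and the query. A rough sketch of that combination using plain scaled dot-product attention (the actual code goes through self.attention_encoder, whose definition is not part of this snippet):

import torch

batch_size, hist_len, emb_size = 2, 4, 8      # illustrative sizes
query = torch.randn(batch_size, 1, emb_size)
history = torch.randn(batch_size, hist_len, emb_size)

# ZAM: prepend a zero "item" so attention mass can be spent on nothing
history = torch.cat([torch.zeros(batch_size, 1, emb_size), history], dim=1)

attn = torch.softmax(torch.bmm(query, history.transpose(1, 2)) / emb_size ** 0.5,
                     dim=-1)                  # batch_size, 1, hist_len + 1
attended = torch.bmm(attn, history)           # batch_size, 1, emb_size
user_emb = 0.5 * attended.squeeze(1) + 0.5 * query.squeeze(1)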

    def forward_dotproduct(self, batch_data, train_pv=False):
        query_word_idxs = batch_data.query_word_idxs
        target_prod_idxs = batch_data.target_prod_idxs
        u_item_idxs = batch_data.u_item_idxs
        batch_size, prev_item_count = u_item_idxs.size()
        neg_k = self.args.neg_per_pos
        pos_iword_idxs = batch_data.pos_iword_idxs
        neg_item_idxs = torch.multinomial(self.prod_dists,
                                          batch_size * neg_k,
                                          replacement=True)
        neg_item_idxs = neg_item_idxs.view(batch_size, -1)
        query_word_emb = self.word_embeddings(query_word_idxs)
        query_emb = self.query_encoder(query_word_emb,
                                       query_word_idxs.ne(self.word_pad_idx))
        column_mask = torch.ones(batch_size,
                                 1,
                                 dtype=torch.uint8,
                                 device=query_word_idxs.device)
        u_item_mask = u_item_idxs.ne(self.prod_pad_idx)
        #pos_item_seq_mask = torch.cat([column_mask, u_item_mask, column_mask], dim=1) #batch_size, 1+max_review_count
        pos_item_seq_mask = torch.cat([column_mask, u_item_mask],
                                      dim=1)  #batch_size, prev_item_count + 1 (query + history items)

        #pos_seg_idxs = torch.cat(
        #        [column_mask*0, column_mask.expand(-1, prev_item_count), column_mask*2], dim=1)
        column_mask = column_mask.unsqueeze(1).expand(-1, neg_k, -1)
        #neg_item_seq_mask = torch.cat([column_mask, u_item_mask.unsqueeze(1).expand(-1,neg_k,-1), column_mask], dim=2)
        neg_item_seq_mask = torch.cat(
            [column_mask,
             u_item_mask.unsqueeze(1).expand(-1, neg_k, -1)],
            dim=2)
        #neg_seg_idxs = torch.cat([column_mask*0,
        #    column_mask.expand(-1, -1, prev_item_count),
        #    column_mask*2], dim = 2)
        target_item_emb = self.product_emb(target_prod_idxs)
        neg_item_emb = self.product_emb(
            neg_item_idxs)  #batch_size, neg_k, embedding_size
        if self.args.sep_prod_emb:
            u_item_emb = self.hist_product_emb(u_item_idxs)
        else:
            u_item_emb = self.product_emb(u_item_idxs)
        #pos_sequence_emb = torch.cat([query_emb.unsqueeze(1), u_item_emb, target_item_emb.unsqueeze(1)], dim=1)
        pos_sequence_emb = torch.cat([query_emb.unsqueeze(1), u_item_emb],
                                     dim=1)
        #pos_seg_emb = self.seg_embeddings(pos_seg_idxs.long())
        neg_sequence_emb = torch.cat([
            query_emb.unsqueeze(1).expand(-1, neg_k, -1).unsqueeze(2),
            u_item_emb.unsqueeze(1).expand(-1, neg_k, -1, -1),
        ],
                                     dim=2)
        #neg_item_emb.unsqueeze(2)], dim=2)
        #neg_seg_emb = self.seg_embeddings(neg_seg_idxs.long()) #batch_size, neg_k, max_prev_item_count+1, embedding_size
        #pos_sequence_emb += pos_seg_emb
        #neg_sequence_emb += neg_seg_emb

        out_pos = -1 if self.args.use_item_pos else 0
        top_vecs = self.transformer_encoder.encode(
            pos_sequence_emb, pos_item_seq_mask, use_pos=self.args.use_pos_emb)
        pos_out_emb = top_vecs[:, out_pos, :]  #batch_size, embedding_size
        pos_scores = torch.bmm(pos_out_emb.unsqueeze(1),
                               target_item_emb.unsqueeze(2)).view(
                                   batch_size)  #in case batch_size=1
        top_vecs = self.transformer_encoder.encode(
            #neg_sequence_emb.view(batch_size*neg_k, prev_item_count+2, -1),
            #neg_item_seq_mask.view(batch_size*neg_k, prev_item_count+2),
            neg_sequence_emb.view(batch_size * neg_k, prev_item_count + 1, -1),
            neg_item_seq_mask.view(batch_size * neg_k, prev_item_count + 1),
            use_pos=self.args.use_pos_emb)
        neg_out_emb = top_vecs[:, out_pos, :]
        neg_scores = torch.bmm(
            neg_out_emb.unsqueeze(1),
            neg_item_emb.view(batch_size * neg_k, -1).unsqueeze(2))
        neg_scores = neg_scores.view(batch_size, neg_k)
        if self.args.sim_func == "bias_product":
            pos_bias = self.product_bias[target_prod_idxs.view(-1)].view(
                batch_size)
            neg_bias = self.product_bias[neg_item_idxs.view(-1)].view(
                batch_size, neg_k)
            pos_scores += pos_bias
            neg_scores += neg_bias
        pos_weight = 1
        if self.args.pos_weight:
            pos_weight = self.args.neg_per_pos
        prod_mask = torch.cat([
            torch.ones(batch_size,
                       1,
                       dtype=torch.uint8,
                       device=query_word_idxs.device) * pos_weight,
            torch.ones(batch_size,
                       neg_k,
                       dtype=torch.uint8,
                       device=query_word_idxs.device)
        ],
                              dim=-1)
        prod_scores = torch.cat([pos_scores.unsqueeze(-1), neg_scores], dim=-1)
        target = torch.cat([
            torch.ones(batch_size, 1, device=query_word_idxs.device),
            torch.zeros(batch_size, neg_k, device=query_word_idxs.device)
        ],
                           dim=-1)
        #ps_loss = self.bce_logits_loss(prod_scores, target, weight=prod_mask.float())
        #each positive item is paired with neg_k sampled negative items
        ps_loss = nn.functional.binary_cross_entropy_with_logits(
            prod_scores, target, weight=prod_mask.float(), reduction='none')
        ps_loss = ps_loss.sum(-1).mean()
        item_loss = self.item_to_words(target_prod_idxs, pos_iword_idxs,
                                       self.args.neg_per_pos)
        self.ps_loss += ps_loss.item()
        self.item_loss += item_loss.item()
        #logger.info("ps_loss:{} item_loss:{}".format(ps_loss.item(), item_loss.item()))

        return ps_loss + item_loss
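
All of these branches end the same way: the positive score and the neg_k negative scores are concatenated, labelled 1 and 0, and fed to binary cross-entropy with logits, with an optional per-column weight that up-weights the single positive by neg_per_pos. A minimal sketch of just that loss with made-up scores:

import torch
import torch.nn as nn

batch_size, neg_k = 2, 3                      # illustrative sizes
pos_scores = torch.randn(batch_size, 1)
neg_scores = torch.randn(batch_size, neg_k)

prod_scores = torch.cat([pos_scores, neg_scores], dim=-1)
target = torch.cat([torch.ones(batch_size, 1),
                    torch.zeros(batch_size, neg_k)], dim=-1)
weight = torch.cat([torch.full((batch_size, 1), float(neg_k)),   # pos_weight = neg_per_pos
                    torch.ones(batch_size, neg_k)], dim=-1)

ps_loss = nn.functional.binary_cross_entropy_with_logits(
    prod_scores, target, weight=weight, reduction='none').sum(-1).mean()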

    def forward_seq(self, batch_data, train_pv=False):
        query_word_idxs = batch_data.query_word_idxs
        target_prod_idxs = batch_data.target_prod_idxs
        pos_seg_idxs = batch_data.pos_seg_idxs
        neg_seg_idxs = batch_data.neg_seg_idxs
        pos_seq_item_idxs = batch_data.pos_seq_item_idxs
        neg_seq_item_idxs = batch_data.neg_seq_item_idxs
        pos_iword_idxs = batch_data.pos_iword_idxs
        batch_size, neg_k, prev_item_count = neg_seq_item_idxs.size()
        query_word_emb = self.word_embeddings(query_word_idxs)
        query_emb = self.query_encoder(query_word_emb,
                                       query_word_idxs.ne(self.word_pad_idx))
        query_mask = torch.ones(batch_size,
                                1,
                                dtype=torch.uint8,
                                device=query_word_idxs.device)
        pos_seq_item_mask = torch.cat(
            [query_mask, pos_seq_item_idxs.ne(self.prod_pad_idx)],
            dim=1)  #batch_size, 1 + prev_item_count
        neg_prod_idx_mask = neg_seq_item_idxs.ne(self.prod_pad_idx)
        neg_seq_item_mask = torch.cat(
            [query_mask.unsqueeze(1).expand(-1, neg_k, -1), neg_prod_idx_mask],
            dim=2)
        #batch_size, 1, embedding_size
        pos_seq_item_emb = self.product_emb(pos_seq_item_idxs)
        pos_sequence_emb = torch.cat(
            (query_emb.unsqueeze(1), pos_seq_item_emb), dim=1)
        pos_seg_emb = self.seg_embeddings(pos_seg_idxs)
        neg_seq_item_emb = self.product_emb(
            neg_seq_item_idxs
        )  #batch_size, neg_k, max_prev_item_count, embedding_size
        neg_sequence_emb = torch.cat((query_emb.unsqueeze(1).expand(
            -1, neg_k, -1).unsqueeze(2), neg_seq_item_emb),
                                     dim=2)
        #batch_size, neg_k, prev_item_count+1, embedding_size
        neg_seg_emb = self.seg_embeddings(
            neg_seg_idxs
        )  #batch_size, neg_k, max_prev_item_count+1, embedding_size
        pos_sequence_emb += pos_seg_emb
        neg_sequence_emb += neg_seg_emb

        pos_scores = self.transformer_encoder(pos_sequence_emb,
                                              pos_seq_item_mask,
                                              use_pos=self.args.use_pos_emb)
        neg_scores = self.transformer_encoder(
            neg_sequence_emb.view(batch_size * neg_k, prev_item_count + 1, -1),
            neg_seq_item_mask.view(batch_size * neg_k, prev_item_count + 1),
            use_pos=self.args.use_pos_emb)
        neg_scores = neg_scores.view(batch_size, neg_k)
        pos_weight = 1
        if self.args.pos_weight:
            pos_weight = self.args.neg_per_pos
        prod_mask = torch.cat([
            torch.ones(batch_size,
                       1,
                       dtype=torch.uint8,
                       device=query_word_idxs.device) * pos_weight,
            torch.ones(batch_size,
                       neg_k,
                       dtype=torch.uint8,
                       device=query_word_idxs.device)
        ],
                              dim=-1)
        prod_scores = torch.cat([pos_scores.unsqueeze(-1), neg_scores], dim=-1)
        target = torch.cat([
            torch.ones(batch_size, 1, device=query_word_idxs.device),
            torch.zeros(batch_size, neg_k, device=query_word_idxs.device)
        ],
                           dim=-1)
        #ps_loss = self.bce_logits_loss(prod_scores, target, weight=prod_mask.float())
        #each positive item is paired with neg_k sampled negative items
        ps_loss = nn.functional.binary_cross_entropy_with_logits(
            prod_scores, target, weight=prod_mask.float(), reduction='none')
        ps_loss = ps_loss.sum(-1).mean()
        item_loss = self.item_to_words(target_prod_idxs, pos_iword_idxs,
                                       self.args.neg_per_pos)
        self.ps_loss += ps_loss.item()
        self.item_loss += item_loss.item()
        #logger.info("ps_loss:{} item_loss:{}".format(ps_loss.item(), item_loss.item()))

        return ps_loss + item_loss
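
Unlike forward_trans, forward_seq expects the positive and negative item sequences and their segment ids to arrive precomputed on batch_data. The sketch below shows the tensor shapes such a batch would need; the SimpleNamespace container, the random indices, and the 0/1 segment layout are assumptions for illustration, and only the field names are taken from the attribute accesses above.

import torch
from types import SimpleNamespace

batch_size, neg_k, prev_item_count, max_q_len, pv_window = 2, 3, 5, 4, 6   # illustrative

batch_data = SimpleNamespace(
    query_word_idxs=torch.randint(1, 100, (batch_size, max_q_len)),
    target_prod_idxs=torch.randint(1, 50, (batch_size,)),
    pos_seq_item_idxs=torch.randint(1, 50, (batch_size, prev_item_count)),
    neg_seq_item_idxs=torch.randint(1, 50, (batch_size, neg_k, prev_item_count)),
    # assumed layout: segment 0 for the query slot, 1 for each previous item
    pos_seg_idxs=torch.cat([torch.zeros(batch_size, 1, dtype=torch.long),
                            torch.ones(batch_size, prev_item_count, dtype=torch.long)],
                           dim=1),
    neg_seg_idxs=torch.cat([torch.zeros(batch_size, neg_k, 1, dtype=torch.long),
                            torch.ones(batch_size, neg_k, prev_item_count, dtype=torch.long)],
                           dim=2),
    pos_iword_idxs=torch.randint(1, 100, (batch_size, pv_window)),
)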

    def initialize_parameters(self, logger=None):
        if logger:
            logger.info(" ItemTransformerRanker initialization started.")
        if self.pretrain_emb_dir is None:
            nn.init.normal_(self.word_embeddings.weight)
        nn.init.normal_(self.seg_embeddings.weight)
        self.query_encoder.initialize_parameters(logger)
        if self.args.model_name == "item_transformer":
            self.transformer_encoder.initialize_parameters(logger)
        if logger:
            logger.info(" ItemTransformerRanker initialization finished.")

    def __init__(self,
                 vocab: Vocabulary,
                 embed_size: int = 30,
                 char_emb_size: int = 30,
                 word_dropout: float = 0,
                 dropout: float = 0,
                 pool_method: str = 'max',
                 activation='relu',
                 min_char_freq: int = 2,
                 requires_grad=True,
                 include_word_start_end=True,
                 char_attn_type='adatrans',
                 char_n_head=3,
                 char_dim_ffn=60,
                 char_scale=False,
                 char_pos_embed=None,
                 char_dropout=0.15,
                 char_after_norm=False):
        """
        :param vocab: 词表
        :param embed_size: TransformerCharEmbed的输出维度。默认值为50.
        :param char_emb_size: character的embedding的维度。默认值为50. 同时也是Transformer的d_model大小
        :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
        :param dropout: 以多大概率drop character embedding的输出以及最终的word的输出。
        :param pool_method: 支持'max', 'avg'。
        :param activation: 激活函数,支持'relu', 'sigmoid', 'tanh', 或者自定义函数.
        :param min_char_freq: character的最小出现次数。默认值为2.
        :param requires_grad:
        :param include_word_start_end: 是否使用特殊的tag标记word的开始与结束
        :param char_attn_type: adatrans or naive.
        :param char_n_head: 多少个head
        :param char_dim_ffn: transformer中ffn中间层的大小
        :param char_scale: 是否使用scale
        :param char_pos_embed: None, 'fix', 'sin'. What kind of position embedding. When char_attn_type=relative, None is
            ok
        :param char_dropout: Dropout in Transformer encoder
        :param char_after_norm: the normalization place.
        """
        super(TransformerCharEmbed, self).__init__(vocab,
                                                   word_dropout=word_dropout,
                                                   dropout=dropout)

        assert char_emb_size % char_n_head == 0, "char_emb_size must be divisible by char_n_head."

        assert pool_method in ('max', 'avg')
        self.pool_method = pool_method
        # activation function
        if isinstance(activation, str):
            if activation.lower() == 'relu':
                self.activation = F.relu
            elif activation.lower() == 'sigmoid':
                self.activation = F.sigmoid
            elif activation.lower() == 'tanh':
                self.activation = F.tanh
        elif activation is None:
            self.activation = lambda x: x
        elif callable(activation):
            self.activation = activation
        else:
            raise Exception(
                "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]"
            )

        logger.info("Start constructing character vocabulary.")
        # build the character vocabulary
        self.char_vocab = _construct_char_vocab_from_vocab(
            vocab,
            min_freq=min_char_freq,
            include_word_start_end=include_word_start_end)
        self.char_pad_index = self.char_vocab.padding_idx
        logger.info(
            f"In total, there are {len(self.char_vocab)} distinct characters.")
        # index every word in the vocab as a sequence of character ids
        max_word_len = max(map(lambda x: len(x[0]), vocab))
        if include_word_start_end:
            max_word_len += 2
        self.register_buffer(
            'words_to_chars_embedding',
            torch.full((len(vocab), max_word_len),
                       fill_value=self.char_pad_index,
                       dtype=torch.long))
        self.register_buffer('word_lengths', torch.zeros(len(vocab)).long())
        for word, index in vocab:
            # if index != vocab.padding_idx:  # previously the pad row was left as the pad value; changed so pad is not treated specially
            if include_word_start_end:
                word = ['<bow>'] + list(word) + ['<eow>']
            self.words_to_chars_embedding[index, :len(word)] = \
                torch.LongTensor([self.char_vocab.to_index(c) for c in word])
            self.word_lengths[index] = len(word)

        self.char_embedding = get_embeddings(
            (len(self.char_vocab), char_emb_size))
        self.transformer = TransformerEncoder(1,
                                              char_emb_size,
                                              char_n_head,
                                              char_dim_ffn,
                                              dropout=char_dropout,
                                              after_norm=char_after_norm,
                                              attn_type=char_attn_type,
                                              pos_embed=char_pos_embed,
                                              scale=char_scale)
        self.fc = nn.Linear(char_emb_size, embed_size)

        self._embed_size = embed_size

        self.requires_grad = requires_grad
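
A hedged usage sketch for the constructor above: it assumes TransformerCharEmbed is importable in the surrounding project and that vocab follows fastNLP's Vocabulary interface (add_word_lst, to_index); the forward pass is defined elsewhere in the class and is only assumed here to map word ids of shape (batch, seq_len) to embeddings of shape (batch, seq_len, embed_size).

import torch
from fastNLP import Vocabulary      # assumption: the project uses fastNLP's Vocabulary

# build a tiny word vocabulary (illustrative)
vocab = Vocabulary()
vocab.add_word_lst("the cat sat on the mat".split())

# the constructor only needs the word vocab; the character vocab is derived internally
char_embed = TransformerCharEmbed(vocab, embed_size=30, char_emb_size=30, char_n_head=3)

words = torch.LongTensor([[vocab.to_index(w) for w in "the cat sat".split()]])
out = char_embed(words)             # expected shape: (1, 3, 30)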