Example no. 1
0
    def __init__(self,
                 d_model,
                 heads,
                 d_ff,
                 dropout,
                 topic=False,
                 topic_dim=300,
                 split_noise=False):
        super(TransformerDecoderLayer, self).__init__()

        self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)

        self.context_attn = MultiHeadedAttention(heads,
                                                 d_model,
                                                 dropout=dropout,
                                                 topic=topic,
                                                 topic_dim=topic_dim,
                                                 split_noise=split_noise)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
        self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
        self.drop = nn.Dropout(dropout)
        mask = self._get_attn_subsequent_mask(MAX_SIZE)
        # Register self.mask as a buffer in TransformerDecoderLayer, so
        # it gets TransformerDecoderLayer's cuda behavior automatically.
        self.register_buffer('mask', mask)
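The register_buffer call above is what makes the mask follow the layer across devices without being treated as a learnable parameter. A minimal, self-contained sketch of that behavior (a toy module, not taken from this code base):

import torch
import torch.nn as nn

class MaskedLayer(nn.Module):
    # Toy module: 'mask' is registered as a buffer, so it appears in
    # state_dict and is moved by .to()/.cuda(), but it receives no gradients.
    def __init__(self, size=4):
        super(MaskedLayer, self).__init__()
        mask = torch.triu(torch.ones(size, size, dtype=torch.uint8), diagonal=1)
        self.register_buffer('mask', mask)

layer = MaskedLayer()
print(layer.mask.device)              # cpu
print('mask' in layer.state_dict())   # True
if torch.cuda.is_available():
    layer = layer.cuda()
    print(layer.mask.device)          # cuda:0 -- the buffer moved with the module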
Example no. 2
0
    def __init__(self, d_model, heads, d_ff, dropout):
        super(TransformerDecoderLayer, self).__init__()

        self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
        self.enc_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.layer_norm1 = nn.LayerNorm(d_model, eps=1e-6)
        self.layer_norm2 = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout = nn.Dropout(dropout)
Example no. 3
0
    def __init__(self,
                 args,
                 device,
                 checkpoint=None,
                 checkpoint_ext=None,
                 checkpoint_abs=None):
        super(HybridSummarizer, self).__init__()
        self.args = args
        self.device = device

        self.extractor = ExtSummarizer(args, device, checkpoint_ext)
        self.abstractor = AbsSummarizer(args, device, checkpoint_abs)
        self.context_attn = MultiHeadedAttention(
            head_count=self.args.dec_heads,
            model_dim=self.args.dec_hidden_size,
            dropout=self.args.dec_dropout,
            need_distribution=True)
        self.v = nn.Parameter(torch.Tensor(1, self.args.dec_hidden_size * 3))
        self.bv = nn.Parameter(torch.Tensor(1))
        self.attn_lin = nn.Linear(self.args.dec_hidden_size,
                                  self.args.dec_hidden_size)
        if self.args.hybrid_loss:
            self.ext_loss_fun = torch.nn.BCELoss(reduction='none')
        if self.args.hybrid_connector:
            self.p_sen = nn.Linear(self.args.dec_hidden_size, 1)

        # At test time, load the trained weights directly from the checkpoint.
        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
            print("checkpoint loaded!")
        else:
            self.attn_lin.weight.data.normal_(mean=0.0, std=0.02)
            nn.init.xavier_uniform_(self.v)
            nn.init.constant_(self.bv, 0)
            if self.args.hybrid_connector:
                for module in self.p_sen.modules():
                    # print(each)
                    if isinstance(module, (nn.Linear, nn.Embedding)):
                        module.weight.data.normal_(mean=0.0, std=0.02)
                    elif isinstance(module, nn.LayerNorm):
                        module.bias.data.zero_()
                        module.weight.data.fill_(1.0)
                    if isinstance(module,
                                  nn.Linear) and module.bias is not None:
                        module.bias.data.zero_()

            for module in self.context_attn.modules():
                # print(each)
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
        self.to(device)
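The two initialization loops above follow the usual BERT-style scheme: Linear/Embedding weights drawn from N(0, 0.02), LayerNorm reset to weight 1 and bias 0, Linear biases zeroed. A hypothetical helper that factors out this repeated pattern (init_bert_weights is not part of the original code, only a sketch):

import torch.nn as nn

def init_bert_weights(root_module):
    # Hypothetical helper: applies the same BERT-style initialization
    # that the loops above apply to p_sen and context_attn.
    for module in root_module.modules():
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()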
Example no. 4
0
    def __init__(self, d_model, heads, d_ff, dropout):
        super(TransformerDecoderLayer, self).__init__()

        self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
        self.context_attn = MultiHeadedAttention(heads,
                                                 d_model,
                                                 dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
        self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
        self.drop = nn.Dropout(dropout)

        mask = self._get_attn_subsequent_mask(5000)
        self.register_buffer('mask', mask)
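_get_attn_subsequent_mask is not shown in these snippets; in decoders of this kind it usually builds an upper-triangular mask of shape (1, size, size) that blocks attention to future positions. A sketch of one common implementation (an assumption, not this code base's own definition):

import numpy as np
import torch

def get_attn_subsequent_mask(size):
    # Assumed implementation: entries above the diagonal (future positions)
    # are 1 (masked out), everything else is 0.
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask)

print(get_attn_subsequent_mask(4)[0])
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.uint8)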
Example no. 5
0
    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings):
        super(TransformerDecoder, self).__init__()

        # Basic attributes.
        self.decoder_type = 'transformer'
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(dropout, self.embeddings.embedding_dim)
        self.context_attn_graph = MultiHeadedAttention(
            heads, d_model, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.drop_3 = nn.Dropout(dropout)
        self.layer_norm_3 = nn.LayerNorm(d_model, eps=1e-6)
        # Build TransformerDecoder.
        self.transformer_layers = nn.ModuleList(
            [TransformerDecoderLayer(d_model, heads, d_ff, dropout)
             for _ in range(num_layers)])

        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.att_weight_c = nn.Linear(self.embeddings.embedding_dim, 1)
        self.att_weight_q = nn.Linear(self.embeddings.embedding_dim, 1)
        self.att_weight_cq = nn.Linear(self.embeddings.embedding_dim, 1)
        self.graph_act = gelu
        self.graph_aware = nn.Linear(self.embeddings.embedding_dim*3, self.embeddings.embedding_dim)
        self.graph_drop = nn.Dropout(dropout)

        self.linear_filter = nn.Linear(d_model*2, 1)
        # Linearly decaying position prior over the first 512 positions,
        # expanded over 8 heads and registered as a trainable parameter.
        fix_top = torch.arange(512, 0, -1).float() / 512
        fix_top = fix_top.unsqueeze(0).unsqueeze(0).expand(8, 512, -1).contiguous()
        self.fix_top = nn.Parameter(fix_top.to(self.get_device()), requires_grad=True)
Example no. 6
0
    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings, vocab_size):
        super(Z_TransformerDecoder, self).__init__()

        # Basic attributes.
        self.decoder_type = 'transformer'
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(dropout,self.embeddings.embedding_dim)
        self.vocab_size = vocab_size

        if COPY:
            self.copy_attn = MultiHeadedAttention(
                1, d_model, dropout=dropout)

        # Build TransformerDecoder.
        self.transformer_layers = nn.ModuleList(
            [Z_TransformerDecoderLayer(d_model, heads, d_ff, dropout)
             for _ in range(num_layers)])

        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
Example no. 7
0
class HybridSummarizer(nn.Module):
    def __init__(self, args, device, checkpoint = None, checkpoint_ext = None, checkpoint_abs = None):
        super(HybridSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.extractor = ExtSummarizer(args, device, checkpoint_ext)
        # self.abstractor = PGTransformers(modules, consts, options)
        self.abstractor = AbsSummarizer(args, device, checkpoint_abs)
        self.context_attn = MultiHeadedAttention(head_count = self.args.dec_heads, model_dim =self.args.dec_hidden_size, dropout=self.args.dec_dropout, need_distribution = True)

        self.v = nn.Parameter(torch.Tensor(1, self.args.dec_hidden_size * 3))
        self.bv = nn.Parameter(torch.Tensor(1))
        self.attn_lin = nn.Linear(self.args.dec_hidden_size, self.args.dec_hidden_size)
        if self.args.hybrid_loss:
            self.ext_loss_fun = torch.nn.BCELoss(reduction='none')
        if self.args.hybrid_connector:
            self.p_sen = nn.Linear(self.args.dec_hidden_size, 1)

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
            print("checkpoint is loading !")
        else:
            self.attn_lin.weight.data.normal_(mean=0.0, std=0.02)

            nn.init.xavier_uniform_(self.v)
            nn.init.constant_(self.bv, 0)
            if self.args.hybrid_connector:
                for module in self.p_sen.modules():
                    # print(each)
                    if isinstance(module, (nn.Linear, nn.Embedding)):
                        module.weight.data.normal_(mean=0.0, std=0.02)
                    elif isinstance(module, nn.LayerNorm):
                        module.bias.data.zero_()
                        module.weight.data.fill_(1.0)
                    if isinstance(module, nn.Linear) and module.bias is not None:
                        module.bias.data.zero_()

            for module in self.context_attn.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
        self.to(device)

    def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls, labels = None):

        if labels is not None and self.args.oracle:
            ext_scores = ((labels.float() + 0.1) / 1.3) * mask_cls.float()
        else:
            if labels is None:
                with torch.no_grad():
                    ext_scores, _, sent_vec = self.extractor(src, segs, clss, mask_src, mask_cls)
            else:
                ext_scores, _, sent_vec = self.extractor(src, segs, clss, mask_src, mask_cls)
                ext_loss = self.ext_loss_fun(ext_scores, labels.float())
                ext_loss = ext_loss * mask_cls.float()

        # decoder_outputs: [batch_size, tgt_len - 1, hidden_size]; it is later
        # projected into a probability distribution over the vocabulary.
        decoder_outputs, encoder_state, y_emb = self.abstractor(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)
        src_pad_mask = (1 - mask_src).unsqueeze(1).repeat(1, tgt.size(1) - 1, 1)
        context_vector, attn_dist = self.context_attn(encoder_state, encoder_state, decoder_outputs, mask=src_pad_mask, type="context")
        if self.args.hybrid_connector:
            sorted_scores, sorted_scores_idx = torch.sort(ext_scores, dim=1, descending=True)
            # Select the top-3 sentences (or fewer if the batch has fewer sentences).
            select_num = min(3, mask_cls.size(1))
            # Weight each selected sentence vector by its extractive score
            # so the weighted vectors can be summed below.
            selected_sent_vec = tuple([(sorted_scores[i][:select_num].unsqueeze(0).transpose(0,1) * sent_vec[i,tuple(sorted_scores_idx[i][:select_num])]).unsqueeze(0) for i, each in enumerate(sorted_scores_idx)])
            selected_sent_vec = torch.cat(selected_sent_vec, dim=0)
            selected_sent_vec = selected_sent_vec.sum(dim=1)
            E_sel = self.p_sen(selected_sent_vec)
            ext_scores = ext_scores * E_sel

        g = torch.sigmoid(F.linear(torch.cat([decoder_outputs, y_emb, context_vector], -1), self.v, self.bv))
        xids = src.unsqueeze(0).repeat(tgt.size(1) - 1, 1, 1).transpose(0,1)
        xids = xids * mask_tgt.unsqueeze(2)[:,:-1,:].long()

        # Mask special tokens such as [CLS].
        len0 = src.size(1)
        len0 = torch.Tensor([[len0]]).repeat(src.size(0), 1).long().to('cuda')
        clss_up = torch.cat((clss, len0), dim=1)
        sent_len = (clss_up[:, 1:] - clss) * mask_cls.long()
        for i in range(mask_cls.size(0)):
            for j in range(mask_cls.size(1)):
                if sent_len[i][j] < 0:
                    sent_len[i][j] += src.size(1)
        ext_scores_0 = ext_scores.unsqueeze(1).transpose(1,2).repeat(1,1, src.size(1))
        for i in range(clss.size(0)):
            tmp_vec = ext_scores_0[i, 0, :sent_len[i][0].int()]

            for j in range(1, clss.size(1)):
                tmp_vec = torch.cat((tmp_vec, ext_scores_0[i, j, :sent_len[i][j].int()]), dim=0)
            if i == 0:
                ext_scores_new = tmp_vec.unsqueeze(0)
            else:
                ext_scores_new = torch.cat((ext_scores_new, tmp_vec.unsqueeze(0)), dim=0)
        ext_scores_new = ext_scores_new * mask_src.float()
        attn_dist = attn_dist * (ext_scores_new + 1).unsqueeze(1)
        # Renormalize so the re-weighted attention sums to 1 over source positions.
        attn_dist = attn_dist / attn_dist.sum(dim=2).unsqueeze(2)

        ext_dist = Variable(torch.zeros(tgt.size(0), tgt.size(1) - 1, self.abstractor.bert.model.config.vocab_size).to(self.device))
        ext_vocab_prob = ext_dist.scatter_add(2, xids, (1 - g) * mask_tgt.unsqueeze(2)[:,:-1,:].float() * attn_dist) * mask_tgt.unsqueeze(2)[:,:-1,:].float()
        if self.args.hybrid_loss:
            return decoder_outputs, None, (ext_vocab_prob, g, ext_loss)
        else:
            return decoder_outputs, None, (ext_vocab_prob, g)
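The scatter_add at the end of the forward is the copy mechanism: attention mass over source positions is poured into a zero tensor of vocabulary size, indexed by the source token ids, so repeated source tokens accumulate probability. A toy sketch of just that step, with hand-made shapes rather than the model's real ones:

import torch

vocab_size = 10
# One batch element, one decoding step, a source of 4 token ids
# (token 7 appears twice, so its copy probability accumulates).
xids = torch.tensor([[[7, 2, 7, 5]]])           # [batch, steps, src_len]
attn = torch.tensor([[[0.4, 0.3, 0.2, 0.1]]])   # attention over source positions

ext_dist = torch.zeros(1, 1, vocab_size)
copy_prob = ext_dist.scatter_add(2, xids, attn)
print(copy_prob)
# token 7 receives 0.4 + 0.2 = 0.6, token 2 receives 0.3, token 5 receives 0.1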
Example no. 8
0
class HybridSummarizer(nn.Module):
    def __init__(self,
                 args,
                 device,
                 checkpoint=None,
                 checkpoint_ext=None,
                 checkpoint_abs=None):
        super(HybridSummarizer, self).__init__()
        self.args = args
        self.device = device

        self.extractor = ExtSummarizer(args, device, checkpoint_ext)
        self.abstractor = AbsSummarizer(args, device, checkpoint_abs)
        self.context_attn = MultiHeadedAttention(
            head_count=self.args.dec_heads,
            model_dim=self.args.dec_hidden_size,
            dropout=self.args.dec_dropout,
            need_distribution=True)
        self.v = nn.Parameter(torch.Tensor(1, self.args.dec_hidden_size * 3))
        self.bv = nn.Parameter(torch.Tensor(1))
        self.attn_lin = nn.Linear(self.args.dec_hidden_size,
                                  self.args.dec_hidden_size)
        if self.args.hybrid_loss:
            self.ext_loss_fun = torch.nn.BCELoss(reduction='none')
        if self.args.hybrid_connector:
            self.p_sen = nn.Linear(self.args.dec_hidden_size, 1)

        # At test time, load the trained weights directly from the checkpoint.
        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
            print("checkpoint loaded!")
        else:
            self.attn_lin.weight.data.normal_(mean=0.0, std=0.02)
            nn.init.xavier_uniform_(self.v)
            nn.init.constant_(self.bv, 0)
            if self.args.hybrid_connector:
                for module in self.p_sen.modules():
                    # print(each)
                    if isinstance(module, (nn.Linear, nn.Embedding)):
                        module.weight.data.normal_(mean=0.0, std=0.02)
                    elif isinstance(module, nn.LayerNorm):
                        module.bias.data.zero_()
                        module.weight.data.fill_(1.0)
                    if isinstance(module,
                                  nn.Linear) and module.bias is not None:
                        module.bias.data.zero_()

            for module in self.context_attn.modules():
                # print(each)
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
        self.to(device)

    def forward(self,
                src,
                tgt,
                segs,
                clss,
                mask_src,
                mask_tgt,
                mask_cls,
                labels=None):

        if labels is not None and self.args.oracle:
            ext_scores = ((labels.float() + 0.1) / 1.3) * mask_cls.float()
        else:
            if labels is None:
                with torch.no_grad():
                    ext_scores, _, sent_vec = self.extractor(
                        src, segs, clss, mask_src, mask_cls)
            else:
                ext_scores, _, sent_vec = self.extractor(
                    src, segs, clss, mask_src, mask_cls)
                ext_loss = self.ext_loss_fun(ext_scores, labels.float())
                ext_loss = ext_loss * mask_cls.float()
        decoder_outputs, encoder_state, y_emb = self.abstractor(
            src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)
        src_pad_mask = (1 - mask_src).unsqueeze(1).repeat(
            1,
            tgt.size(1) - 1, 1)
        context_vector, attn_dist = self.context_attn(encoder_state,
                                                      encoder_state,
                                                      decoder_outputs,
                                                      mask=src_pad_mask,
                                                      type="context")
        if self.args.hybrid_connector:
            sorted_scores, sorted_scores_idx = torch.sort(ext_scores,
                                                          dim=1,
                                                          descending=True)
            # Select the top-3 sentences (or fewer if the batch has fewer sentences).
            select_num = min(3, mask_cls.size(1))
            selected_sent_vec = tuple([
                (sorted_scores[i][:select_num].unsqueeze(0).transpose(0, 1) *
                 sent_vec[i, tuple(sorted_scores_idx[i][:select_num])]
                 ).unsqueeze(0) for i, each in enumerate(sorted_scores_idx)
            ])

            selected_sent_vec = torch.cat(selected_sent_vec, dim=0)
            selected_sent_vec = selected_sent_vec.sum(dim=1)
            E_sel = self.p_sen(selected_sent_vec)
            ext_scores = ext_scores * E_sel

        if torch.isnan(decoder_outputs[0][0][0]):
            print("ops, decoder_outputs!")
            print("src = ", src.size())
            print(src)
            print("tgt = ", tgt.size())
            print(tgt)
            # # # segs是每个词属于哪句话
            print("segs = ", segs.size())
            print(segs)
            # # clss 是每个句子的起点位置
            print("clss = ", clss.size())
            print(clss)
            print("mask_src = ", mask_src.size())
            print(mask_src)
            print("mask_cls = ", mask_cls.size())
            print(mask_cls)
            print("decoder_outputs ", decoder_outputs.size())
            print(decoder_outputs)
            print("y_emb ", y_emb)
            print(y_emb)
            print("context_vector ", context_vector.size())
            print(context_vector)
            exit()

        if torch.isnan(y_emb[0][0][0]):
            print("ops, yemb!")
            print("src = ", src.size())
            print(src)
            print("tgt = ", tgt.size())
            print(tgt)
            # # # segs是每个词属于哪句话
            print("segs = ", segs.size())
            print(segs)
            # # clss 是每个句子的起点位置
            print("clss = ", clss.size())
            print(clss)
            print("mask_src = ", mask_src.size())
            print(mask_src)
            print("mask_cls = ", mask_cls.size())
            print(mask_cls)
            print("decoder_outputs ", decoder_outputs.size())
            print(decoder_outputs)
            print("y_emb ", y_emb)
            print(y_emb)
            print("context_vector ", context_vector.size())
            print(context_vector)
            exit()

        if torch.isnan(context_vector[0][0][0]):
            print("ops, context_vector!")
            print("src = ", src.size())
            print(src)
            print("tgt = ", tgt.size())
            print(tgt)
            # # # segs是每个词属于哪句话
            print("segs = ", segs.size())
            print(segs)
            # # clss 是每个句子的起点位置
            print("clss = ", clss.size())
            print(clss)
            print("mask_src = ", mask_src.size())
            print(mask_src)
            print("mask_cls = ", mask_cls.size())
            print(mask_cls)
            print("decoder_outputs ", decoder_outputs.size())
            print(decoder_outputs)
            print("y_emb ", y_emb)
            print(y_emb)
            print("context_vector ", context_vector.size())
            print(context_vector)
            exit()

        g = torch.sigmoid(
            F.linear(torch.cat([decoder_outputs, y_emb, context_vector], -1),
                     self.v, self.bv))
        if torch.isnan(g[0][0]):
            print("ops!, g")
            print("src = ", src.size())
            print(src)
            print("tgt = ", tgt.size())
            print(tgt)
            # # # segs是每个词属于哪句话
            print("segs = ", segs.size())
            print(segs)
            # # clss 是每个句子的起点位置
            print("clss = ", clss.size())
            print(clss)
            print("mask_src = ", mask_src.size())
            print(mask_src)
            print("mask_cls = ", mask_cls.size())
            print(mask_cls)
            print("decoder_outputs ", decoder_outputs.size())
            print(decoder_outputs)
            print("y_emb ", y_emb)
            print(y_emb)
            print("context_vector ", context_vector.size())
            print(context_vector)
            print("g ", g.size())
            print(g)
            exit()

        xids = src.unsqueeze(0).repeat(tgt.size(1) - 1, 1, 1).transpose(0, 1)
        xids = xids * mask_tgt.unsqueeze(2)[:, :-1, :].long()
        len0 = src.size(1)
        len0 = torch.Tensor([[len0]]).repeat(src.size(0), 1).long().to('cuda')
        clss_up = torch.cat((clss, len0), dim=1)
        sent_len = (clss_up[:, 1:] - clss) * mask_cls.long()
        for i in range(mask_cls.size(0)):
            for j in range(mask_cls.size(1)):
                if sent_len[i][j] < 0:
                    sent_len[i][j] += src.size(1)
        ext_scores_0 = ext_scores.unsqueeze(1).transpose(1, 2).repeat(
            1, 1, src.size(1))
        for i in range(clss.size(0)):
            tmp_vec = ext_scores_0[i, 0, :sent_len[i][0].int()]

            for j in range(1, clss.size(1)):
                tmp_vec = torch.cat(
                    (tmp_vec, ext_scores_0[i, j, :sent_len[i][j].int()]),
                    dim=0)
            if i == 0:
                ext_scores_new = tmp_vec.unsqueeze(0)
            else:
                ext_scores_new = torch.cat(
                    (ext_scores_new, tmp_vec.unsqueeze(0)), dim=0)
        ext_scores_new = ext_scores_new * mask_src.float()
        attn_dist = attn_dist * (ext_scores_new + 1).unsqueeze(1)
        attn_dist = attn_dist / attn_dist.sum(dim=2).unsqueeze(2)
        ext_dist = Variable(
            torch.zeros(tgt.size(0),
                        tgt.size(1) - 1,
                        self.abstractor.bert.model.config.vocab_size).to(
                            self.device))
        ext_vocab_prob = ext_dist.scatter_add(
            2, xids, (1 - g) * mask_tgt.unsqueeze(2)[:, :-1, :].float() *
            attn_dist) * mask_tgt.unsqueeze(2)[:, :-1, :].float()

        if self.args.hybrid_loss:
            return decoder_outputs, None, (ext_vocab_prob, g, ext_loss)
        else:
            return decoder_outputs, None, (ext_vocab_prob, g)
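The forward returns the copy distribution ext_vocab_prob together with the gate g; the final word distribution is presumably formed downstream by mixing the generator's softmax with this copy distribution in pointer-generator fashion. A toy sketch under that assumption (p_gen_vocab and copy_attn are made-up placeholders):

import torch

# Hypothetical toy shapes: [batch=1, steps=1, vocab=5].
p_gen_vocab = torch.tensor([[[0.1, 0.4, 0.2, 0.2, 0.1]]])  # generator softmax (assumed)
copy_attn = torch.tensor([[[0.0, 0.0, 0.7, 0.3, 0.0]]])    # renormalized copy attention
g = torch.tensor([[[0.8]]])                                # gate in (0, 1)

# ext_vocab_prob in the forward above already carries the (1 - g) factor:
ext_vocab_prob = (1.0 - g) * copy_attn
# Presumed downstream mixture (pointer-generator style):
p_final = g * p_gen_vocab + ext_vocab_prob
print(p_final.sum())  # 1.0 when both inputs are proper distributions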
Example no. 9
0
    def __init__(self,
                 args,
                 device,
                 vocab_size,
                 product_size,
                 vocab_words,
                 word_dists=None):
        super(ItemTransformerRanker, self).__init__()
        self.args = args
        self.device = device
        self.train_review_only = args.train_review_only
        self.embedding_size = args.embedding_size
        self.vocab_words = vocab_words
        self.word_dists = None
        if word_dists is not None:
            self.word_dists = torch.tensor(word_dists, device=device)
        self.prod_dists = torch.ones(product_size, device=device)
        self.prod_pad_idx = product_size
        self.word_pad_idx = vocab_size - 1
        self.seg_pad_idx = 3
        self.emb_dropout = args.dropout
        self.pretrain_emb_dir = None
        if os.path.exists(args.pretrain_emb_dir):
            self.pretrain_emb_dir = args.pretrain_emb_dir
        self.pretrain_up_emb_dir = None
        if os.path.exists(args.pretrain_up_emb_dir):
            self.pretrain_up_emb_dir = args.pretrain_up_emb_dir
        self.dropout_layer = nn.Dropout(p=args.dropout)

        self.product_emb = nn.Embedding(product_size + 1,
                                        self.embedding_size,
                                        padding_idx=self.prod_pad_idx)
        if args.sep_prod_emb:
            self.hist_product_emb = nn.Embedding(product_size + 1,
                                                 self.embedding_size,
                                                 padding_idx=self.prod_pad_idx)
        '''
        else:
            pretrain_product_emb_path = os.path.join(self.pretrain_up_emb_dir, "product_emb.txt")
            pretrained_weights = load_user_item_embeddings(pretrain_product_emb_path)
            pretrained_weights.append([0.] * len(pretrained_weights[0]))
            self.product_emb = nn.Embedding.from_pretrained(torch.FloatTensor(pretrained_weights), padding_idx=self.prod_pad_idx)
        '''
        self.product_bias = nn.Parameter(torch.zeros(product_size + 1),
                                         requires_grad=True)
        self.word_bias = nn.Parameter(torch.zeros(vocab_size),
                                      requires_grad=True)

        if self.pretrain_emb_dir is not None:
            word_emb_fname = "word_emb.txt.gz"  #for query and target words in pv and pvc
            pretrain_word_emb_path = os.path.join(self.pretrain_emb_dir,
                                                  word_emb_fname)
            word_index_dic, pretrained_weights = load_pretrain_embeddings(
                pretrain_word_emb_path)
            word_indices = torch.tensor(
                [0] + [word_index_dic[x]
                       for x in self.vocab_words[1:]] + [self.word_pad_idx])
            #print(len(word_indices))
            #print(word_indices.cpu().tolist())
            pretrained_weights = torch.FloatTensor(pretrained_weights)
            self.word_embeddings = nn.Embedding.from_pretrained(
                pretrained_weights[word_indices],
                padding_idx=self.word_pad_idx)
            #vectors of padding idx will not be updated
        else:
            self.word_embeddings = nn.Embedding(vocab_size,
                                                self.embedding_size,
                                                padding_idx=self.word_pad_idx)
        if self.args.model_name == "item_transformer":
            self.transformer_encoder = TransformerEncoder(
                self.embedding_size, args.ff_size, args.heads, args.dropout,
                args.inter_layers)
        #if self.args.model_name == "ZAM" or self.args.model_name == "AEM":
        else:
            self.attention_encoder = MultiHeadedAttention(
                args.heads, self.embedding_size, args.dropout)

        if args.query_encoder_name == "fs":
            self.query_encoder = FSEncoder(self.embedding_size,
                                           self.emb_dropout)
        else:
            self.query_encoder = AVGEncoder(self.embedding_size,
                                            self.emb_dropout)
        self.seg_embeddings = nn.Embedding(4,
                                           self.embedding_size,
                                           padding_idx=self.seg_pad_idx)
        # Four segment types: the query Q, previous purchases of u,
        # currently available reviews for i, and the padding value.
        #self.logsoftmax = torch.nn.LogSoftmax(dim = -1)
        self.bce_logits_loss = torch.nn.BCEWithLogitsLoss(
            reduction='none')  # default reduction is 'mean'; keep per-element losses

        self.initialize_parameters(logger)  #logger
        self.to(device)  #change model in place
        self.item_loss = 0
        self.ps_loss = 0
	def __init__(self, args, device, checkpoint=None, bert_from_extractive=None):
		super(AbsSummarizer, self).__init__()
		self.args = args
		self.device = device
		self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

		if bert_from_extractive is not None:
			self.bert.model.load_state_dict(
				dict([(n[11:], p) for n, p in bert_from_extractive.items() if n.startswith('bert.model')]), strict=True)

		if (args.encoder == 'baseline'):
			bert_config = BertConfig(self.bert.model.config.vocab_size, hidden_size=args.enc_hidden_size,
			                         num_hidden_layers=args.enc_layers, num_attention_heads=8,
			                         intermediate_size=args.enc_ff_size,
			                         hidden_dropout_prob=args.enc_dropout,
			                         attention_probs_dropout_prob=args.enc_dropout)
			self.bert.model = BertModel(bert_config)

		if (args.max_pos > 512):
			my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
			my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
			my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][None,
			                                      :].repeat(args.max_pos - 512, 1)
			self.bert.model.embeddings.position_embeddings = my_pos_embeddings
		self.vocab_size = self.bert.model.config.vocab_size

		self.enc_out_size = self.args.dec_hidden_size
		if self.args.use_dep:
			self.enc_out_size += 2
		if self.args.use_frame:
			self.enc_frame = nn.Linear(1, 20)
			self.frame_attn = MultiHeadedAttention(1, 20, 0.1)
			self.enc_out_size += 20
		self.enc_out = nn.Linear(self.enc_out_size, self.args.dec_hidden_size)
		self.drop = nn.Dropout(self.args.enc_dropout)
		self.layer_norm = nn.LayerNorm(self.args.dec_hidden_size, eps=1e-6)

		tgt_embeddings = nn.Embedding(self.vocab_size, self.args.dec_hidden_size, padding_idx=0)
		if (self.args.share_emb):
			tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)

		self.decoder = TransformerDecoder(
			self.args.dec_layers,
			self.args.dec_hidden_size, heads=self.args.dec_heads,
			d_ff=self.args.dec_ff_size, dropout=self.args.dec_dropout, embeddings=tgt_embeddings)

		self.generator = get_generator(self.vocab_size, self.args.dec_hidden_size, device)
		self.generator[0].weight = self.decoder.embeddings.weight

		if checkpoint is not None:
			self.load_state_dict(checkpoint['model'], strict=True)
		else:
			for module in self.decoder.modules():
				if isinstance(module, (nn.Linear, nn.Embedding)):
					module.weight.data.normal_(mean=0.0, std=0.02)
				elif isinstance(module, nn.LayerNorm):
					module.bias.data.zero_()
					module.weight.data.fill_(1.0)
				if isinstance(module, nn.Linear) and module.bias is not None:
					module.bias.data.zero_()
			for p in self.generator.parameters():
				if p.dim() > 1:
					xavier_uniform_(p)
				else:
					p.data.zero_()
			if (args.use_bert_emb):
				tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
				tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)
				self.decoder.embeddings = tgt_embeddings
				self.generator[0].weight = self.decoder.embeddings.weight

		self.to(device)
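The max_pos > 512 branch above extends BERT's position embeddings by copying the 512 pretrained rows and repeating the last row for the extra positions. A small standalone sketch of the same idea, using toy sizes instead of BERT's 512 positions and 768 hidden units:

import torch
import torch.nn as nn

old_max_pos, hidden, new_max_pos = 4, 3, 6   # toy sizes
old_pos_emb = nn.Embedding(old_max_pos, hidden)

new_pos_emb = nn.Embedding(new_max_pos, hidden)
# Copy the pretrained rows, then repeat the last pretrained row for the
# positions beyond the original maximum (the same trick as in __init__ above).
new_pos_emb.weight.data[:old_max_pos] = old_pos_emb.weight.data
new_pos_emb.weight.data[old_max_pos:] = \
    old_pos_emb.weight.data[-1][None, :].repeat(new_max_pos - old_max_pos, 1)

print(new_pos_emb.weight.data)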