Example #1
    def multi_decode(self, batch):
        """ A helper for beam """
        (fr, fr_len) = batch.fr
        (en, en_len) = batch.en
        (de, de_len) = batch.de
        batch_size = len(batch)

        fr_hid = self.fr_en.enc(fr[:, 1:], fr_len - 1)
        en_msgs_ = self.fr_en.dec.beam(fr_hid,
                                       fr_len - 1)  # (5, batch_size, en_len)
        en_msgs, en_lens = [], []
        for en in en_msgs_:
            en = [ee[1:] for ee in en]
            en_len = [len(x) for x in en]
            max_len = max(en_len)
            en_len = cuda(torch.LongTensor(en_len))
            en_lens.append(en_len)

            en = [
                np.lib.pad(xx, (0, max_len - len(xx)),
                           'constant',
                           constant_values=(0, 0)) for xx in en
            ]
            en = cuda(torch.LongTensor(np.array(en)))
            en_msgs.append(en)

        en_hids = [
            self.en_de.enc(en_msg, en_len)
            for (en_msg, en_len) in zip(en_msgs, en_lens)
        ]
        de_msg = self.en_de.dec.multi_decode(en_hids, en_lens)

        return en_msgs, de_msg
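All of these examples call a `cuda` helper that the page never shows. A minimal sketch, assuming it does nothing more than move a tensor to the GPU when one is available:

import torch

def cuda(tensor):
    # Hypothetical helper (not shown in the original source): move the
    # tensor to the default CUDA device if available, else leave it on CPU.
    return tensor.cuda() if torch.cuda.is_available() else tensor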
Example #2
def valid_model(args, model, dev_it, dev_metrics, iters, loss_names,
                monitor_names, extra_input):
    with torch.no_grad():
        model.eval()

        for j, dev_batch in enumerate(dev_it):
            img_input = None if args.no_img else cuda(
                extra_input["img"]['multi30k'][1].index_select(
                    dim=0, index=dev_batch.idx.cpu()))
            en, en_len = dev_batch.en
            decoded = model(en, img_input)
            R = {}
            R["nll"] = F.cross_entropy(decoded,
                                       en[:, 1:].contiguous().view(-1),
                                       ignore_index=0)
            #R["nll_cur"] = F.cross_entropy( decoded, en[:,:-1].contiguous().view(-1), ignore_index=0 )

            if not (img_input is None):
                idx = cuda(torch.randperm(img_input.size(0)))
                img_input = img_input.index_select(0, idx)
                decoded = model(en, img_input)
            R["nll_rnd"] = F.cross_entropy(decoded,
                                           en[:, 1:].contiguous().view(-1),
                                           ignore_index=0)

            dev_metrics.accumulate(
                len(dev_batch),
                *[R[name].item() for name in loss_names + monitor_names])

        args.logger.info(dev_metrics)
Example #3
    def multi_decode(self, src_hids, src_lens):
        # src_hid : (batch_size, x_seq_len, D_hid * n_dir)
        # src_len : (batch_size)
        batch_size = src_hids[0].size(0)
        x_seq_lens = [src_hid.size(1) for src_hid in src_hids]
        src_masks = [ xlen_to_inv_mask(src_len, seq_len=x_seq_len) for (src_len, x_seq_len) \
                     in zip(src_lens, x_seq_lens) ] # (batch_size, x_seq_len)
        y_seq_len = self.max_len_gen
        z_ts = []

        for src_len, src_hid in zip(src_lens, src_hids):
            h_index = (src_len - 1)[:, None,
                                    None].repeat(1, 1, src_hid.size(2))
            z_t = torch.cumsum(src_hid,
                               dim=1)  # (batch_size, x_seq_len, D_hid * n_dir)
            z_t = z_t.gather(dim=1, index=h_index).view(
                batch_size, -1)  # (batch_size, D_hid * n_dir)
            z_t = torch.div(z_t, src_len[:, None].float())
            z_ts.append(z_t)

        y_emb = cuda(torch.full((batch_size, ), self.init_token).long())
        y_emb = self.emb(y_emb)  # (batch_size, D_emb)
        input_feeds = [
            y_emb.data.new(batch_size, self.D_hid).zero_() for ii in range(5)
        ]

        done = cuda(torch.zeros(batch_size).long())
        eos_tensor = cuda(torch.zeros(batch_size).fill_(self.eos_token)).long()

        msg = []
        for idx in range(y_seq_len):
            probs = []
            for which in range(5):
                input = torch.cat([y_emb, input_feeds[which]],
                                  dim=1)  # (batch_size * width, D_emb + D_hid)
                trg_hid = self.layers[0](
                    input, z_ts[which])  # (batch_size * width, D_hid)
                z_ts[which] = trg_hid  # (batch_size * width, D_hid)

                out, _ = self.attention(
                    trg_hid, src_hids[which],
                    src_masks[which])  # (batch_size, D_hid)
                input_feeds[which] = out
                logit = self.out(out)  # (batch_size, voc_sz_trg)
                probs.append(F.softmax(logit, -1))

            prob = torch.stack(probs, dim=-1).mean(-1)
            tokens = prob.max(dim=1)[1]  # (batch_size)
            msg.append(tokens)

            is_next_eos = (tokens == eos_tensor).long()  # (batch_size)
            done = (done + is_next_eos).clamp(min=0, max=1).long()
            if done.sum() == batch_size:
                break

            y_emb = self.emb(tokens)  # (batch_size, D_emb)

        msg = torch.stack(msg, dim=1)  # (batch_size, y_seq_len)
        return msg
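`multi_decode` masks attention over padded positions with `xlen_to_inv_mask`, which is defined elsewhere in the codebase. A plausible reconstruction, assuming the mask is 1 at padded positions and 0 at valid ones:

import torch

def xlen_to_inv_mask(x_len, seq_len=None):
    # Hypothetical reconstruction: entry (b, t) is 1 iff t >= x_len[b],
    # i.e. position t of example b is padding.
    seq_len = seq_len if seq_len is not None else int(x_len.max().item())
    positions = torch.arange(seq_len, device=x_len.device)[None, :]  # (1, seq_len)
    return (positions >= x_len[:, None]).long()  # (batch_size, seq_len)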
Example #4
    def forward(self, x, x_len, h_0=None):
        # NOTE x_len.max() == x.size(1)
        # x : (batch_size, x_seq_len) or (batch_size, x_seq_len, vocab_size)
        # x_len : (batch_size)
        batch_size, x_seq_len = x.size()[:2]  # NOTE dim==3 for gumbel softmax

        if h_0 is None:
            h_0 = cuda(
                torch.FloatTensor(self.n_layers * self.n_dir, batch_size,
                                  self.D_hid).zero_())
        if len(x.shape) == 2:
            input = self.emb(x)
        elif len(x.shape) == 3:
            # Gumbel
            input = torch.matmul(x, self.emb.weight)
        else:
            raise ValueError
        input = F.dropout(input, p=self.drop_ratio, training=self.training)

        # input (batch_size, x_seq_len, D_emb)
        # h_0 (n_layers * n_dir, batch_size, D_hid)
        output, _ = self.rnn(input, h_0)
        # output (batch_size, x_seq_len, n_dir * D_hid)
        # h_n (n_layers * n_dir, batch_size, D_hid)
        """
        if self.model == "RNN" : # RNN
            out = take_last(output, x_len, self.n_dir, self.D_hid)
            return out # (batch_size, n_dir * D_hid)

        else: # RNNAttn
            #inv_mask = xlen_to_inv_mask(x_len)[:,:,None] # (batch_size, x_seq_len, 1)
            #output.masked_fill_(inv_mask, 0)
        """
        return output.contiguous()  # (batch_size, x_seq_len, n_dir * D_hid)
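The three-dimensional branch above treats a soft one-hot distribution over the vocabulary (e.g. a Gumbel-softmax sample) as a weighted average of embedding rows; `torch.matmul(x, self.emb.weight)` reduces to an ordinary lookup when `x` is exactly one-hot. A tiny self-contained check of that equivalence:

import torch
import torch.nn as nn
import torch.nn.functional as F

emb = nn.Embedding(10, 4)                      # toy vocab of 10, D_emb = 4
ids = torch.tensor([[1, 3, 7]])                # (batch_size=1, x_seq_len=3)
one_hot = F.one_hot(ids, num_classes=10).float()
# Matmul with the embedding matrix equals a plain embedding lookup.
assert torch.allclose(emb(ids), torch.matmul(one_hot, emb.weight))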
Example #5
    def batch_cap_rep(self, sents, sent_lens):
        """
        :param sents: [NB_SENT, len]
        :param sent_lens: [NB_SENT]
        :return: [NB_SENT, D_hid]
        """
        batch_size = 64
        start = 0
        result = []
        while start < sents.shape[0]:
            end = start + batch_size
            batch_sent = cuda(sents[start: end])
            batch_len = cuda(sent_lens[start: end])
            sent_enc = self.get_cap_rep(batch_sent, batch_len)
            result.append(sent_enc)
            start = end
        result = torch.cat(result, dim=0)
        return normf(result)
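`batch_cap_rep` ends with `normf(result)`, a helper not shown here. A minimal sketch, assuming it L2-normalizes each row so that the dot products used for retrieval become cosine similarities:

import torch.nn.functional as F

def normf(x, dim=-1, eps=1e-8):
    # Hypothetical helper: L2-normalize along `dim`, guarding against
    # zero-norm rows with a small epsilon.
    return F.normalize(x, p=2, dim=dim, eps=eps)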
Example #6
def _make_sure_message_valid(msg, msg_len, init_token):
    # Add BOS
    msg = torch.cat(
        [cuda(torch.full((msg.shape[0], 1), init_token)).long(), msg], dim=1)
    msg_len += 1

    # Make sure padding entries are all zeros
    #inv_mask = xlen_to_inv_mask(msg_len, seq_len=msg.shape[1])
    #msg.masked_fill_(mask=inv_mask.bool(), value=0)
    return msg, msg_len
Example #7
    def get_cap_rep(self, x, x_len):
        batch_size, seq_len = x.shape
        hidden = cuda(torch.zeros(self.n_layers, batch_size, self.D_hid))
        emb = F.dropout( self.emb( x ), p=self.drop_ratio, training=self.training )

        output, _ = self.rnn(emb, hidden)
        # (batch, seq_len, D_hid)
        f_out = output.view(batch_size, seq_len, self.D_hid)
        f_idx = (x_len - 1)[:,None,None].repeat(1, 1, self.D_hid)
        f_out = f_out.gather(dim=1, index=f_idx).view(batch_size, -1) # (batch_size, D_hid)
        return f_out
Example #8
def valid_model(args, model, valid_img_feats, valid_caps, valid_lens):
    model.eval()
    batch_size = 32
    start = 0
    val_metrics = Metrics('val_loss', 'loss', data_type="avg")
    with torch.no_grad():
        while start < valid_img_feats.shape[0]:
            cap_id = random.randint(0, 4)
            end = start + batch_size
            batch_img_feat = cuda(valid_img_feats[start: end])
            batch_ens = cuda(valid_caps[cap_id][start: end])
            batch_lens = cuda(valid_lens[cap_id][start: end])
            R = model(batch_ens[:, 1:], batch_lens - 1, batch_img_feat)
            if args.img_pred_loss == "vse":
                R['loss'] = R['loss'].sum()
            elif args.img_pred_loss == "mse":
                R['loss'] = R['loss'].mean()
            else:
                raise ValueError
            val_metrics.accumulate(batch_ens.shape[0], R['loss'])  # actual (possibly partial) batch size
            start = end
        return val_metrics
Example #9
def valid_model(args, model, dev_it, extra_input):
    with torch.no_grad():
        model.eval()
        img_features, cap_features = [], []
        for j, dev_batch in enumerate(dev_it):
            img = cuda(extra_input["img"]["multi30k"][1].index_select(
                dim=0, index=dev_batch.idx.cpu()))  # (batch_size, D_img)
            en, en_len = dev_batch.en
            cap_rep = model.get_cap_rep(en[:, 1:], en_len - 1)
            cap_features.append(normf(cap_rep))
            img = model.img_enc(img)  # (batch_size, D_hid)
            img_features.append(normf(img))
        img_features = torch.cat(img_features, dim=0)  # (5000, D_hid)
        cap_features = torch.cat(cap_features, dim=0)  # (5000, D_hid)
        scores = torch.mm(img_features, cap_features.t())  # (5000, 5000)
        _, cap_idx = torch.sort(scores, dim=1, descending=True)  # (5000, 5000)
        _, img_idx = torch.sort(scores, dim=0, descending=True)  # (5000, 5000)
        query = cuda(torch.arange(5000))
        cap_r1, cap_r5, cap_r10 = [retrieval(idx, query, 1) for idx in \
                                   [ cap_idx[:,:1], cap_idx[:,:5], cap_idx[:,:10] ] ]
        img_r1, img_r5, img_r10 = [retrieval(idx, query, 0) for idx in \
                                   [ img_idx[:1,:], img_idx[:5,:], img_idx[:10,:] ] ]
        return (cap_r1, cap_r5, cap_r10, img_r1, img_r5, img_r10)
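The recall numbers above come from a `retrieval` helper that is not included on this page. A hedged reconstruction, assuming `idx` holds top-K ranked indices (per row when `dim == 1`, per column when `dim == 0`) and a query counts as a hit when its own index appears among its top K:

import torch

def retrieval(idx, query, dim):
    # Hypothetical reconstruction: recall@K in percent.
    if dim == 1:
        hits = (idx == query[:, None]).any(dim=1)  # idx: (n_queries, K)
    else:
        hits = (idx == query[None, :]).any(dim=0)  # idx: (K, n_queries)
    return hits.float().mean().item() * 100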
Example #10
    def batch_enc_img(self, img_feat):
        """
        :param img_feat: [NB_img, D_img]
        :return: [NB_img, D_hid]
        """
        batch_size = 64
        result = []
        start = 0
        while start < img_feat.shape[0]:
            end = start + batch_size
            batch_img_feat = cuda(img_feat[start: end])
            batch_img_feat = F.dropout(batch_img_feat,
                                       p=self.drop_ratio,
                                       training=self.training)
            batch_img_feat = self.img_enc(batch_img_feat)
            result.append(batch_img_feat)
            start = end
        result = torch.cat(result, dim=0)
        return normf(result)
Example #11
def get_retrieve_result(args, model, valid_img_feats, valid_caps, valid_lens):
    with torch.no_grad():
        model.eval()
        query = cuda(torch.arange(valid_img_feats.shape[0]))
        img_features = model.batch_enc_img(valid_img_feats)
        cap_features = []
        for valid_cap, valid_len in zip(valid_caps, valid_lens):
            cap_features.append(model.batch_cap_rep(valid_cap[:, 1:], valid_len - 1))
        scores = [torch.mm(img_features, cap.t()) for cap in cap_features] # [ (img 1000, cap_i 1000) x 5 ]
        img_idxs = [torch.sort( sc, dim=0, descending=True )[1] for sc in scores] # [ (img 1000, cap_i 1000) x 5 ]
        img_r1  = np.mean([retrieval(img_idx[:1, :], query, 0) for img_idx in img_idxs])
        img_r5  = np.mean([retrieval(img_idx[:5, :], query, 0) for img_idx in img_idxs])
        img_r10 = np.mean([retrieval(img_idx[:10, :], query, 0) for img_idx in img_idxs])

        scores = torch.cat(scores, dim=1 ) # (img 1000, cap 5000)
        cap_idx = torch.sort(scores, dim=1, descending=True )[1] # (img 1000, cap 5000)
        cap_idx = torch.remainder(cap_idx, 1000 ) # (img 1000, cap 5000)
        cap_r1, cap_r5, cap_r10 = [retrieval(idx, query, 1) for idx in \
                                   [ cap_idx[:,:1], cap_idx[:,:5], cap_idx[:,:10] ] ]
        return cap_r1, cap_r5, cap_r10, img_r1, img_r5, img_r10
Example #12
    def forward(self, x, x_len, img):
        # x : (batch_size, seq_len)
        # x_len : (batch_size)
        # img : (batch_size, D_img)
        batch_size, seq_len = x_len.size(0), x_len.max().item()
        x_enc = self.get_cap_rep(x, x_len)
        R = {}

        if self.img_pred_loss == "mse":
            x_enc = F.dropout( x_enc, p=self.drop_ratio, training=self.training ) # (batch_size, D_hid)
            x_enc = self.img_enc( x_enc ) # (batch_size, D_img)
            #loss = F.mse_loss(x_enc, img, reduction='none')
            loss = ( (x_enc - img) ** 2 ).mean(1)

        elif self.img_pred_loss == "vse":
            img = F.dropout( img, p=self.drop_ratio, training=self.training ) # (batch_size, D_img)
            img = self.img_enc( img ) # (batch_size, D_hid)

            x_enc = normf(x_enc)
            img = normf(img)

            scores = torch.mm( img, x_enc.t() ) # (batch_size, batch_size)
            diagonal = scores.diag().view(batch_size, -1) # (batch_size, 1)
            pos_cap_scores = diagonal.expand_as(scores) # (batch_size, batch_size)
            pos_img_scores = diagonal.t().expand_as(scores) # (batch_size, batch_size)

            cost_cap = (self.margin + scores - pos_cap_scores).clamp(min=0)
            cost_img = (self.margin + scores - pos_img_scores).clamp(min=0)

            mask = cuda( torch.eye( batch_size ) > .5 ) # remove diagonal
            cost_cap = cost_cap.masked_fill(mask, 0.0) # (batch_size, batch_size)
            cost_img = cost_img.masked_fill(mask, 0.0) # (batch_size, batch_size)

            if self.training:
                loss = cost_cap.max(1)[0] + cost_img.max(0)[0] # (batch_size)
            else:
                loss = cost_cap.mean(dim=1) + cost_img.mean(dim=0) # (batch_size)
        R['loss'] = loss
        return R
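To make the `vse` branch concrete, here is a tiny standalone walk-through of the same max-of-hinges ranking loss on a hand-picked 2x2 score matrix (the numbers are arbitrary):

import torch

# Rows are images, columns are captions; the diagonal holds matching pairs.
scores = torch.tensor([[0.9, 0.4],
                       [0.3, 0.8]])
margin = 0.2
diagonal = scores.diag().view(2, 1)

cost_cap = (margin + scores - diagonal.expand_as(scores)).clamp(min=0)
cost_img = (margin + scores - diagonal.t().expand_as(scores)).clamp(min=0)
mask = torch.eye(2) > .5  # remove diagonal
cost_cap = cost_cap.masked_fill(mask, 0.0)
cost_img = cost_img.masked_fill(mask, 0.0)

# Training-time max-violation variant: hardest negative per row/column.
loss = cost_cap.max(1)[0] + cost_img.max(0)[0]
# Here loss == tensor([0., 0.]) because every negative sits more than
# `margin` below its positive.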
Example #13
    def forward(self, msg, img):
        # msg : (batch_size, seq_len) or (batch_size, seq_len, vocab_size)
        # img : (batch_size, D_img)
        batch_size = msg.size(0)
        if self.no_img:
            assert (img is None)
            hidden = cuda( torch.zeros(self.n_layers, batch_size, self.D_hid) )
        else:
            assert not (img is None)
            img = F.dropout( img, p=self.drop_ratio, training=self.training )
            hidden = self.img_enc( img )[None,:,:] # (1, batch_size, D_hid)
            hidden = hidden.repeat(self.n_layers, 1, 1)
        input, target = msg[:,:-1], msg[:,1:]
        if len(input.shape) == 2:
            emb = F.dropout(self.encoder( input ), p=self.drop_ratio, training=self.training)
        elif len(input.shape) == 3:
            emb = torch.matmul(input, self.encoder.weight)
            emb = F.dropout(emb, p=self.drop_ratio, training=self.training)
        else:
            raise ValueError
        output, _ = self.rnn(emb, (hidden, hidden))
        output = F.dropout( output, p=self.drop_ratio, training=self.training )
        decoded = self.decoder(output).view(-1, self.voc_sz)
        return decoded
Example #14
    def send(self,
             src_hid,
             src_len,
             trg_len,
             send_method,
             value_fn=None,
             gumbel_temp=1,
             dot_token=None):
        # src_hid : (batch_size, x_seq_len, D_hid * n_dir)
        # src_len : (batch_size)
        batch_size, x_seq_len = src_hid.size()[:2]
        src_mask = xlen_to_inv_mask(
            src_len, seq_len=x_seq_len)  # (batch_size, x_seq_len)
        y_seq_len = math.floor(trg_len.max().item() * self.msg_len_ratio) \
            if self.msg_len_ratio > 0 else self.max_len_gen

        h_index = (src_len - 1)[:, None, None].repeat(1, 1, src_hid.size(2))
        z_t = torch.cumsum(src_hid,
                           dim=1)  # (batch_size, x_seq_len, D_hid * n_dir)
        z_t = z_t.gather(dim=1,
                         index=h_index).view(batch_size,
                                             -1)  # (batch_size, D_hid * n_dir)
        z_t = torch.div(z_t, src_len[:, None].float())

        y_emb = cuda(torch.full((batch_size, ), self.init_token).long())
        y_emb = self.emb(y_emb)  # (batch_size, D_emb)
        #y_emb = F.dropout( y_emb, p=self.drop_ratio, training=self.training )

        prev_trg_hids = [z_t for _ in range(self.n_layers)]
        trg_hid = prev_trg_hids[0]
        input_feed = y_emb.data.new(batch_size, self.D_hid).zero_()

        done = cuda(torch.zeros(batch_size).long())
        seq_lens = cuda(
            torch.zeros(batch_size).fill_(y_seq_len).long())  # (max_len)
        max_seq_lens = cuda(
            torch.zeros(batch_size).fill_(y_seq_len).long())  # (max_len)
        eos_tensor = cuda(torch.zeros(batch_size).fill_(self.eos_token)).long()

        self.log_probs, self.R_b, self.neg_Hs, self.gumbel_tokens = [], [], [], []
        msg = []

        for idx in range(y_seq_len):
            if self.input_feeding:
                input = torch.cat([y_emb, input_feed],
                                  dim=1)  # (batch_size, D_emb + D_hid)
            else:
                input = y_emb

            if send_method == "reinforce" and value_fn:
                if self.input_feeding:
                    value_fn_input = [y_emb, input_feed]
                else:
                    value_fn_input = [y_emb, trg_hid]
                value_fn_input = torch.cat(value_fn_input, dim=1).detach()
                self.R_b.append(
                    value_fn(value_fn_input).view(-1))  # (batch_size)

            for i, rnn in enumerate(self.layers):
                trg_hid = rnn(input, prev_trg_hids[i])  # (batch_size, D_hid)
                #input = F.dropout(trg_hid, p=self.drop_ratio, training=self.training)
                input = trg_hid
                prev_trg_hids[i] = trg_hid

            if self.attention is not None:
                out, attn_scores = self.attention(
                    trg_hid, src_hid, src_mask)  # (batch_size, D_hid)
            else:
                out = trg_hid

            input_feed = out
            logit = self.out(out)  # (batch_size, voc_sz_trg)

            if send_method == "argmax":
                tokens = logit.max(dim=1)[1]  # (batch_size)
                tok_dist = Categorical(logits=logit)
                self.neg_Hs.append(-1 * tok_dist.entropy())

            elif send_method == "reinforce":
                tok_dist = Categorical(logits=logit)
                tokens = tok_dist.sample()
                self.log_probs.append(
                    tok_dist.log_prob(tokens))  # (batch_size)
                self.neg_Hs.append(-1 * tok_dist.entropy())

            elif send_method == "gumbel":
                y = gumbel_softmax(logit, gumbel_temp)
                tok_dist = Categorical(probs=y)
                self.neg_Hs.append(-1 * tok_dist.entropy())
                tokens = torch.argmax(y, dim=1)
                tokens_oh = cuda(torch.zeros(y.size())).scatter_(
                    1, tokens.unsqueeze(-1), 1)
                self.gumbel_tokens.append((tokens_oh - y).detach() + y)
            else:
                raise ValueError

            msg.append(tokens)
            is_next_eos = (tokens == eos_tensor).long()  # (batch_size)
            new_seq_lens = max_seq_lens.clone().masked_fill_(
                mask=is_next_eos.bool(),
                value=float(idx +
                            1))  # NOTE idx+1 ensures this is valid length
            seq_lens = torch.min(seq_lens, new_seq_lens)  # (contains lengths)

            done = (done + is_next_eos).clamp(min=0, max=1).long()
            if done.sum() == batch_size:
                break

            y_emb = self.emb(tokens)  # (batch_size, D_emb)

        msg = torch.stack(msg, dim=1)  # (batch_size, y_seq_len)
        self.neg_Hs = torch.stack(self.neg_Hs,
                                  dim=1)  # (batch_size, y_seq_len)
        if send_method == "reinforce":
            self.log_probs = torch.stack(self.log_probs,
                                         dim=1)  # (batch_size, y_seq_len)
            self.R_b = torch.stack(
                self.R_b,
                dim=1) if value_fn else self.R_b  # (batch_size, y_seq_len)

        if send_method == 'gumbel' and len(self.gumbel_tokens) > 0:
            self.gumbel_tokens = torch.stack(self.gumbel_tokens, dim=1)

        result = {"msg": msg.clone(), "new_seq_lens": seq_lens.clone()}

        # Trim sequence length at the first dot token
        if dot_token is not None:
            seq_lens = first_appear_indice(msg, dot_token) + 1

        # Jason's trick: trim the sentence length based on the ground-truth length
        if self.msg_len_ratio > 0:
            en_ref_len = torch.floor(trg_len.float() *
                                     self.msg_len_ratio).long()
            seq_lens = torch.min(seq_lens, en_ref_len)

        # Make length larger than min len and valid
        seq_lens = torch.max(
            seq_lens,
            seq_lens.new(seq_lens.size()).fill_(self.min_len_gen))
        seq_lens = torch.min(seq_lens,
                             seq_lens.new(seq_lens.size()).fill_(msg.shape[1]))

        # Make sure message is valid
        ends_with_eos = (msg.gather(
            dim=1, index=(seq_lens - 1)[:,
                                        None]).view(-1) == eos_tensor).long()
        eos_or_pad = ends_with_eos * self.pad_token + (
            1 - ends_with_eos) * self.eos_token
        msg = torch.cat(
            [msg, msg.new(batch_size, 1).fill_(self.pad_token)], dim=1)
        if send_method == 'gumbel' and len(self.gumbel_tokens) > 0:
            pad_gumbel_tokens = cuda(
                torch.zeros(msg.shape[0], 1, self.gumbel_tokens.shape[2]))
            pad_gumbel_tokens[:, :, self.pad_token] = 1
            self.gumbel_tokens = torch.cat(
                [self.gumbel_tokens, pad_gumbel_tokens], dim=1)

        # (batch_size, y_seq_len + 1)
        msg.scatter_(dim=1, index=seq_lens[:, None], src=eos_or_pad[:, None])
        if send_method == 'gumbel' and len(self.gumbel_tokens) > 0:
            batch_ids = cuda(torch.arange(msg.shape[0]))
            self.gumbel_tokens[batch_ids, seq_lens] = 0.
            self.gumbel_tokens[batch_ids, seq_lens, eos_or_pad] = 1.

        seq_lens = seq_lens + (1 - ends_with_eos)
        msg_mask = xlen_to_inv_mask(
            seq_lens, seq_len=msg.size(1))  # (batch_size, x_seq_len)
        msg.masked_fill_(msg_mask.bool(), self.pad_token)
        result.update({"msg": msg, "new_seq_lens": seq_lens})
        return result
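`send` relies on a `gumbel_softmax` helper defined elsewhere; the straight-through trick itself is visible inline as `(tokens_oh - y).detach() + y`. A minimal sketch of the sampler, assuming the standard Gumbel-softmax relaxation:

import torch
import torch.nn.functional as F

def gumbel_softmax(logit, temp):
    # Hypothetical reconstruction: perturb logits with Gumbel(0, 1) noise
    # and soften with a temperature; lower `temp` approaches a hard argmax.
    gumbel = -torch.log(-torch.log(torch.rand_like(logit) + 1e-20) + 1e-20)
    return F.softmax((logit + gumbel) / temp, dim=-1)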
Example #15
def train_model(args, model):

    resnet = torchvision.models.resnet152(pretrained=True)
    resnet = nn.Sequential(*list(resnet.children())[:-1])
    resnet = nn.DataParallel(resnet).cuda()
    resnet.eval()

    if not args.debug:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter( args.event_path + args.id_str)

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9, lr=args.lr)
    else:
        raise NotImplementedError

    loss_names, loss_cos = ["loss"], {"loss":1.0}
    monitor_names = "cap_r1 cap_r5 cap_r10 img_r1 img_r5 img_r10".split()

    train_metrics = Metrics('train_loss', *loss_names, data_type="avg")
    best = Best(max, 'r1', 'iters', model=model, opt=opt, path=args.model_path + args.id_str, \
                gpu=args.gpu, debug=args.debug)

    # Train dataset
    args.logger.info("Loading train imgs...")
    train_dataset = ImageFolderWithPaths(os.path.join(args.data_dir, 'flickr30k'), preprocess_rc)
    train_imgs = open(os.path.join(args.data_dir, 'flickr30k/train.txt'), 'r').readlines()
    train_imgs = [x.strip() for x in train_imgs if x.strip() != ""]
    train_dataset.samples = [x for x in train_dataset.samples if x[0].split("/")[-1] in train_imgs]
    train_dataset.imgs = [x for x in train_dataset.imgs if x[0].split("/")[-1] in train_imgs]
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8,
                                               pin_memory=False)
    args.logger.info("Train loader built!")

    en_vocab = TextVocab(counter=torch.load(os.path.join(args.data_dir, 'bpe/vocab.en.pth')))
    word2idx = en_vocab.stoi
    train_en = [open(os.path.join(args.data_dir, 'flickr30k/caps', 'train.{}.bpe'.format(idx+1))).readlines() for idx in range(5)]
    train_en = [[["<bos>"] + sentence.strip().split() + ["<eos>"] for sentence in doc if sentence.strip() != "" ] for doc in train_en]
    train_en = [[[word2idx[word] for word in sentence] for sentence in doc] for doc in train_en]

    args.logger.info("Train corpus built!")

    # Valid dataset
    valid_img_feats = torch.tensor(torch.load(os.path.join(args.data_dir, 'flickr30k/val_feat.pth')))
    valid_ens = []
    valid_en_lens = []
    for idx in range(5):
        valid_en = []
        with open(os.path.join(args.data_dir, 'flickr30k/caps', 'val.{}.bpe'.format(idx+1))) as f:
            for line in f:
                line = line.strip()
                if line == "":
                    continue

                words = ["<bos>"] + line.split() + ["<eos>"]
                words = [word2idx[word] for word in words]
                valid_en.append(words)

        # Pad
        valid_en_len = [len(sent) for sent in valid_en]
        valid_en = [np.lib.pad(xx, (0, max(valid_en_len) - len(xx)), 'constant', constant_values=(0, 0)) for xx in valid_en]
        valid_ens.append(torch.tensor(valid_en).long())
        valid_en_lens.append(torch.tensor(valid_en_len).long())
    args.logger.info("Valid corpus built!")

    iters = -1
    should_stop = False
    for epoch in range(999999999):
        if should_stop:
            break

        for idx, (train_img, lab, path) in enumerate(train_loader):
            iters += 1
            if iters > args.max_training_steps:
                should_stop = True
                break

            if iters % args.eval_every == 0:
                res = get_retrieve_result(args, model, valid_caps=valid_ens, valid_lens=valid_en_lens,
                                          valid_img_feats=valid_img_feats)
                val_metrics = valid_model(args, model, valid_img_feats=valid_img_feats,
                                          valid_caps=valid_ens, valid_lens=valid_en_lens)
                args.logger.info("[VALID] update {} : {}".format(iters, str(val_metrics)))
                if not args.debug:
                    write_tb(writer, monitor_names, res, iters, prefix="dev/")
                    write_tb(writer, loss_names, [val_metrics.__getattr__(name) for name in loss_names],
                             iters, prefix="dev/")

                best.accumulate((res[0]+res[3])/2, iters)
                args.logger.info('model:' + args.prefix + args.hp_str)
                args.logger.info('epoch {} iters {}'.format(epoch, iters))
                args.logger.info(best)

                if args.early_stop and (iters - best.iters) // args.eval_every > args.patience:
                    args.logger.info("Early stopping.")
                    return

            model.train()

            def get_lr_anneal(iters):
                lr_end = args.lr_min
                return max( 0, (args.lr - lr_end) * (args.linear_anneal_steps - iters) /
                           args.linear_anneal_steps ) + lr_end

            if args.lr_anneal == "linear":
                opt.param_groups[0]['lr'] = get_lr_anneal(iters)

            opt.zero_grad()
            batch_size = len(path)
            path = [p.split("/")[-1] for p in path]
            sentence_idx = [train_imgs.index(p) for p in path]
            en = [train_en[random.randint(0, 4)][sentence_i] for sentence_i in sentence_idx]
            en_len = [len(x) for x in en]

            en = [ np.lib.pad( xx, (0, max(en_len) - len(xx)), 'constant', constant_values=(0,0) ) for xx in en ]
            en = cuda( torch.LongTensor( np.array(en).tolist() ) )
            en_len = cuda( torch.LongTensor( en_len ) )

            with torch.no_grad():
                train_img = resnet(train_img).view(batch_size, -1)
            R = model(en[:,1:], en_len-1, train_img)
            if args.img_pred_loss == "vse":
                R['loss'] = R['loss'].sum()
            elif args.img_pred_loss == "mse":
                R['loss'] = R['loss'].mean()
            else:
                raise Exception()

            total_loss = 0
            for loss_name in loss_names:
                total_loss += R[loss_name] * loss_cos[loss_name]

            train_metrics.accumulate(batch_size, *[R[name].item() for name in loss_names])

            total_loss.backward()
            if args.plot_grad:
                plot_grad(writer, model, iters)

            if args.grad_clip > 0:
                nn.utils.clip_grad_norm_(params, args.grad_clip)

            opt.step()

            if iters % args.eval_every == 0:
                args.logger.info("update {} : {}".format(iters, str(train_metrics)))

            if iters % args.eval_every == 0 and not args.debug:
                write_tb(writer, loss_names, [train_metrics.__getattr__(name) for name in loss_names], \
                         iters, prefix="train/")
                write_tb(writer, ['lr'], [opt.param_groups[0]['lr']], iters, prefix="train/")
                train_metrics.reset()
Example #16
def train_model(args, model, iterators, extra_input):
    (train_its, dev_its) = iterators

    if not args.debug:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(args.event_path + args.id_str)

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9, lr=args.lr)
    else:
        raise NotImplementedError

    loss_names, loss_cos = ["nll"], {"nll": 1.0}
    monitor_names = ["nll_rnd"]
    """
    if args.rep_pen_co > 0.0:
        loss_names.append("nll_cur")
        loss_cos["nll_cur"] = -1 * args.rep_pen_co
    else:
        monitor_names.append("nll_cur")
    """

    train_metrics = Metrics('train_loss', *loss_names, data_type="avg")
    dev_metrics = Metrics('dev_loss',
                          *loss_names,
                          *monitor_names,
                          data_type="avg")
    best = Best(min, 'loss', 'iters', model=model, opt=opt, path=args.model_path + args.id_str, \
                gpu=args.gpu, debug=args.debug)

    iters = 0
    should_stop = False
    for epoch in range(999999999):
        if should_stop:
            break

        for dataset in args.dataset.split("_"):
            if should_stop:
                break

            train_it = train_its[dataset]
            for _, train_batch in enumerate(train_it):
                if iters >= args.max_training_steps:
                    args.logger.info(
                        'stopping training after {} training steps'.format(
                            args.max_training_steps))
                    should_stop = True
                    break

                if iters % args.eval_every == 0:
                    dev_metrics.reset()
                    valid_model(args, model, dev_its['multi30k'], dev_metrics,
                                iters, loss_names, monitor_names, extra_input)
                    if not args.debug:
                        write_tb(writer, loss_names, [dev_metrics.__getattr__(name) for name in loss_names], \
                                 iters, prefix="dev/")
                        write_tb(writer, monitor_names, [dev_metrics.__getattr__(name) for name in monitor_names], \
                                 iters, prefix="dev/")
                    best.accumulate(dev_metrics.nll, iters)

                    args.logger.info('model:' + args.prefix + args.hp_str)
                    args.logger.info('epoch {} dataset {} iters {}'.format(
                        epoch, dataset, iters))
                    args.logger.info(best)

                    if args.early_stop and (
                            iters -
                            best.iters) // args.eval_every > args.patience:
                        args.logger.info("Early stopping.")
                        return

                model.train()

                def get_lr_anneal(iters):
                    lr_end = args.lr_min
                    return max(0, (args.lr - lr_end) *
                               (args.linear_anneal_steps - iters) /
                               args.linear_anneal_steps) + lr_end

                if args.lr_anneal == "linear":
                    opt.param_groups[0]['lr'] = get_lr_anneal(iters)

                opt.zero_grad()

                batch_size = len(train_batch)
                img_input = None if args.no_img else cuda(
                    extra_input["img"][dataset][0].index_select(
                        dim=0, index=train_batch.idx.cpu()))
                if dataset == "coco":
                    en, en_len = train_batch.__dict__[
                        "_" + str(random.randint(1, 5))]
                elif dataset == "multi30k":
                    en, en_len = train_batch.en

                decoded = model(en, img_input)
                R = {}
                R["nll"] = F.cross_entropy(decoded,
                                           en[:, 1:].contiguous().view(-1),
                                           ignore_index=0)
                #R["nll_cur"] = F.cross_entropy( decoded, en[:,:-1].contiguous().view(-1), ignore_index=0 )

                total_loss = 0
                for loss_name in loss_names:
                    total_loss += R[loss_name] * loss_cos[loss_name]

                train_metrics.accumulate(
                    batch_size, *[R[name].item() for name in loss_names])

                total_loss.backward()
                if args.plot_grad:
                    plot_grad(writer, model, iters)

                if args.grad_clip > 0:
                    nn.utils.clip_grad_norm_(params, args.grad_clip)

                opt.step()
                iters += 1

                if iters % args.eval_every == 0:
                    args.logger.info("update {} : {}".format(
                        iters, str(train_metrics)))

                if iters % args.eval_every == 0 and not args.debug:
                    write_tb(writer, loss_names, [train_metrics.__getattr__(name) for name in loss_names], \
                             iters, prefix="train/")
                    write_tb(writer, ['lr'], [opt.param_groups[0]['lr']],
                             iters,
                             prefix="train/")
                    train_metrics.reset()
Example #17
    def get_grounding(self,
                      en_msg,
                      en_msg_len,
                      batch,
                      en_lm=None,
                      all_img=None,
                      ranker=None,
                      use_gumbel_tokens=False):
        """ Forward speaker with English sentence to get grounding loss and rewards """
        results = {}
        rewards = {}
        batch_size = en_msg.shape[0]
        # NOTE add <BOS> to beginning
        en_msg_ = torch.cat([
            cuda(torch.full((batch_size, 1), self.init_token)).long(), en_msg
        ],
                            dim=1)
        gumbel_tokens = None
        if use_gumbel_tokens:
            gumbel_tokens = self.fr_en.dec.gumbel_tokens
            init_tokens = torch.zeros(
                [gumbel_tokens.shape[0], 1, gumbel_tokens.shape[2]])
            init_tokens = init_tokens.to(device=gumbel_tokens.device)
            init_tokens[:, :, self.init_token] = 1
            gumbel_tokens = torch.cat([init_tokens, gumbel_tokens], dim=1)

        if self.use_en_lm:  # monitor EN LM NLL
            if "wiki" in self.en_lm_dataset:
                if use_gumbel_tokens:
                    raise NotImplementedError
                en_nll_lm = en_lm.get_nll(en_msg_)  # (batch_size, en_msg_len)
                if self.train_en_lm:
                    en_nll_lm = sum_reward(en_nll_lm,
                                           en_msg_len + 1)  # (batch_size)
                    rewards['lm'] = -1 * en_nll_lm.detach()
                    # R = R + -1 * en_nll_lm.detach() * self.en_lm_nll_co # (batch_size)
                results.update({"en_nll_lm": en_nll_lm.mean()})

            elif self.en_lm_dataset in ["coco", "multi30k"]:
                if use_gumbel_tokens:
                    en_lm.train()
                    en_nll_lm = en_lm.get_loss_oh(gumbel_tokens, None)
                    en_lm.eval()
                else:
                    en_nll_lm = en_lm.get_loss(
                        en_msg_, None)  # (batch_size, en_msg_len)
                if self.train_en_lm:
                    en_nll_lm = sum_reward(en_nll_lm,
                                           en_msg_len + 1)  # (batch_size)
                    rewards['lm'] = -1 * en_nll_lm.detach()
                results.update({"en_nll_lm": en_nll_lm.mean()})
            else:
                raise Exception()

        if self.use_ranker:  # NOTE Experiment 3 : Reward = NLL_DE + NLL_EN_LM + NLL_IMG_PRED
            if use_gumbel_tokens and self.train_ranker:
                raise NotImplementedError
            ranker.eval()
            img = cuda(all_img.index_select(
                dim=0, index=batch.idx.cpu()))  # (batch_size, D_img)

            if self.img_pred_loss == "nll":
                img_pred_loss = ranker.get_loss(
                    en_msg_, img)  # (batch_size, en_msg_len)
                img_pred_loss = sum_reward(img_pred_loss,
                                           en_msg_len + 1)  # (batch_size)
            else:
                with torch.no_grad():
                    img_pred_loss = ranker(en_msg, en_msg_len, img)["loss"]

            if self.train_ranker:
                rewards['img_pred'] = -1 * img_pred_loss.detach()
            results.update({
                "img_pred_loss_{}".format(self.img_pred_loss):
                img_pred_loss.mean()
            })

            # Get ranker retrieval result
            with torch.no_grad():
                K = 19
                # Randomly select K distractor image
                random_idx = torch.randint(all_img.shape[0],
                                           size=[batch_size, K])
                wrong_img = cuda(
                    all_img.index_select(dim=0, index=random_idx.view(-1)))
                wrong_img_feat = ranker.batch_enc_img(wrong_img).view(
                    batch_size, K, -1)
                right_img_feat = ranker.batch_enc_img(img)

                # [bsz, K+1, hid_size]
                all_feat = torch.cat(
                    [right_img_feat.unsqueeze(1), wrong_img_feat], dim=1)

                # [bsz, hid_size]
                cap_feats = ranker.batch_cap_rep(en_msg, en_msg_len)
                scores = (cap_feats.unsqueeze(1) * all_feat).sum(-1)
                r1_acc = (torch.argmax(scores, -1) == 0).float().mean()
                results['r1_acc'] = r1_acc
        return results, rewards
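`get_grounding` collapses per-token NLLs into one scalar per sentence with `sum_reward`, which is not shown. A plausible sketch, assuming it masks positions past each sequence's length before summing:

import torch

def sum_reward(reward, lens):
    # Hypothetical reconstruction.
    # reward : (batch_size, seq_len), lens : (batch_size)
    positions = torch.arange(reward.size(1), device=reward.device)[None, :]
    mask = (positions < lens[:, None]).float()
    return (reward * mask).sum(dim=1)  # (batch_size)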
Example #18
def train_model(args, model, iterators, extra_input):
    (train_its, dev_its) = iterators

    if not args.debug:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(args.event_path + args.id_str)

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9, lr=args.lr)
    else:
        raise NotImplementedError

    loss_names, loss_cos = ["loss"], {"loss": 1.0}
    monitor_names = "cap_r1 cap_r5 cap_r10 img_r1 img_r5 img_r10".split()

    train_metrics = Metrics('train_loss', *loss_names, data_type="avg")
    best = Best(max, 'r1', 'iters', model=model, opt=opt, path=args.model_path + args.id_str, \
                gpu=args.gpu, debug=args.debug)

    iters = 0
    for epoch in range(999999999):
        for dataset in args.dataset.split("_"):
            train_it = train_its[dataset]
            for _, train_batch in enumerate(train_it):
                iters += 1

                if iters % args.eval_every == 0:
                    R = valid_model(args, model, dev_its['multi30k'],
                                    extra_input)
                    if not args.debug:
                        write_tb(writer,
                                 monitor_names,
                                 R,
                                 iters,
                                 prefix="dev/")
                    best.accumulate((R[0] + R[3]) / 2, iters)

                    args.logger.info('model:' + args.prefix + args.hp_str)
                    args.logger.info('epoch {} dataset {} iters {}'.format(
                        epoch, dataset, iters))
                    args.logger.info(best)

                    if args.early_stop and (
                            iters -
                            best.iters) // args.eval_every > args.patience:
                        args.logger.info("Early stopping.")
                        return

                model.train()

                def get_lr_anneal(iters):
                    lr_end = args.lr_min
                    return max(0, (args.lr - lr_end) *
                               (args.linear_anneal_steps - iters) /
                               args.linear_anneal_steps) + lr_end

                if args.lr_anneal == "linear":
                    opt.param_groups[0]['lr'] = get_lr_anneal(iters)

                opt.zero_grad()

                batch_size = len(train_batch)
                img = extra_input["img"][dataset][0].index_select(
                    dim=0, index=train_batch.idx.cpu())  # (batch_size, D_img)
                en, en_len = train_batch.__dict__["_" +
                                                  str(random.randint(1, 5))]
                R = model(en[:, 1:], en_len - 1, cuda(img))
                R['loss'] = R['loss'].mean()

                total_loss = 0
                for loss_name in loss_names:
                    total_loss += R[loss_name] * loss_cos[loss_name]

                train_metrics.accumulate(
                    batch_size, *[R[name].item() for name in loss_names])

                total_loss.backward()
                if args.plot_grad:
                    plot_grad(writer, model, iters)

                if args.grad_clip > 0:
                    nn.utils.clip_grad_norm_(params, args.grad_clip)

                opt.step()

                if iters % args.eval_every == 0:
                    args.logger.info("update {} : {}".format(
                        iters, str(train_metrics)))

                if iters % args.eval_every == 0 and not args.debug:
                    write_tb(writer, loss_names, [train_metrics.__getattr__(name) for name in loss_names], \
                             iters, prefix="train/")
                    write_tb(writer, ['lr'], [opt.param_groups[0]['lr']],
                             iters,
                             prefix="train/")
                    train_metrics.reset()
Example #19
    def beam_search(self, h, h_len, width): # (batch_size, x_seq_len, D_hid * n_dir)
        voc_size, batch_size, x_seq_len = self.voc_sz_trg, h.size()[0], h.size()[1]
        live = [ [ ( 0.0, [ self.init_token ], 2 ) ] for ii in range(batch_size) ]
        dead = [ [] for ii in range(batch_size) ]
        n_dead = [0 for ii in range(batch_size)]
        xmask = xlen_to_inv_mask(h_len)[:,None,:] # (batch_size, 1, x_seq_len)
        max_len_gen = self.max_len_gen if self.msg_len_ratio < 0.0 else int(x_seq_len * self.msg_len_ratio)
        max_len_gen = np.clip(max_len_gen, 2, self.max_len_gen).item()

        h_index = (h_len - 1)[:,None,None].repeat(1, 1, h.size(2))
        z_t = torch.cumsum(h, dim=1) # (batch_size, x_seq_len, D_hid * n_dir)
        z_t = z_t.gather(dim=1, index=h_index).view(batch_size, -1) # (batch_size, D_hid * n_dir)
        z_t = torch.div( z_t, h_len[:, None].float() )
        z_t = self.ctx_to_z0(z_t) # (batch_size, n_layers * D_hid)
        z_t = z_t.view(batch_size, self.n_layers, self.D_hid).transpose(0,1).contiguous()
        # (n_layers, batch_size, D_hid)

        input = cuda( torch.full((batch_size, 1), self.init_token).long() )
        input = self.emb( input ) # (batch_size, 1, D_emb)

        h_big = h.contiguous().view(-1, self.D_hid * self.n_dir )
        # (batch_size * x_seq_len, D_hid * n_dir)
        ctx_h = self.h_to_att( h_big ).view(batch_size, 1, x_seq_len, self.D_att)
        # NOTE (batch_size, 1, x_seq_len, D_att)

        for tidx in range(max_len_gen):
            cwidth = 1 if tidx == 0 else width
            # input (batch_size * width, 1, D_emb)
            # z_t (n_layers, batch_size * width, D_hid)
            _, z_t_ = self.rnn( input, z_t )
            # out (batch_size * width, 1, D_hid)
            # z_t (n_layers, batch_size * width, D_hid)
            ctx_z_t_ = z_t_.transpose(0,1).contiguous().view(batch_size * cwidth, -1)
            # (batch_size * width, n_layers * D_hid)

            ctx_y = self.y_to_att( input.view(-1, self.D_emb) ).view(batch_size, cwidth, 1, self.D_att)
            ctx_s = self.z_to_att( ctx_z_t_ ).view(batch_size, cwidth, 1, self.D_att)
            ctx = F.tanh(ctx_y + ctx_s + ctx_h) # (batch_size, cwidth, x_seq_len, D_att)
            ctx = ctx.view(batch_size * cwidth * x_seq_len, self.D_att)

            score = self.att_to_score(ctx).view(batch_size, -1, x_seq_len) # (batch_size, cwidth, x_seq_len)
            score.masked_fill_(xmask.repeat(1, cwidth, 1), -float('inf'))
            score = F.softmax( score.view(-1, x_seq_len), dim=1 ).view(batch_size, -1, x_seq_len)
            score = score.view(batch_size, cwidth, x_seq_len, 1) # (batch_size, width, x_seq_len, 1)

            c_t = torch.mul( h.view(batch_size, 1, x_seq_len, -1), score ) # (batch_size, width, x_seq_len, D_hid * n_dir)
            c_t = torch.sum( c_t, 2).view(batch_size * cwidth, -1) # (batch_size * width, D_hid * n_dir)
            # c_t (batch_size * width, 1, D_hid * n_dir)
            # z_t (n_layers, batch_size * width, D_hid)
            out, z_t = self.rnn2( c_t[:,None,:], z_t_ )
            # out (batch_size * width, 1, D_hid)
            # z_t (n_layers, batch_size * width, D_hid)

            fin_y = self.y_to_out( input.view(-1, self.D_emb) ) # (batch_size * width, D_out)
            fin_c = self.c_to_out( c_t ) # (batch_size * width, D_out)
            fin_s = self.z_to_out( out.view(-1, self.D_hid) ) # (batch_size * width, D_out)
            fin = F.tanh( fin_y + fin_c + fin_s )

            cur_prob = F.log_softmax( self.out( fin.view(-1, self.D_out) ), dim=1 )\
                    .view(batch_size, cwidth, voc_size).data # (batch_size, width, voc_sz_trg)
            pre_prob = cuda( torch.FloatTensor( [ [ x[0] for x in ee ] for ee in live ] ).view(batch_size, cwidth, 1) )
            # (batch_size, width, 1)
            total_prob = cur_prob + pre_prob # (batch_size, cwidth, voc_size)
            total_prob = total_prob.view(batch_size, -1)

            _, topi_s = total_prob.topk(width, dim=1)
            topv_s = cur_prob.view(batch_size, -1).gather(1, topi_s)
            # (batch_size, width)

            new_live = [ [] for ii in range(batch_size) ]
            for bidx in range(batch_size):
                n_live = width - n_dead[bidx]
                if n_live > 0:
                    tis = topi_s[bidx][:n_live]
                    tvs = topv_s[bidx][:n_live]
                    for eidx, (topi, topv) in enumerate(zip(tis, tvs)): # NOTE max width times
                        if topi % voc_size == self.eoz_token :
                            dead[bidx].append( (  live[bidx][ topi // voc_size ][0] + topv, \
                                                  live[bidx][ topi // voc_size ][1] + [ topi % voc_size ],\
                                                  topi) )
                            n_dead[bidx] += 1
                        else:
                            new_live[bidx].append( (    live[bidx][ topi // voc_size ][0] + topv, \
                                                        live[bidx][ topi // voc_size ][1] + [ topi % voc_size ],\
                                                        topi) )
                while len(new_live[bidx]) < width:
                    new_live[bidx].append( (-99999999999, [0], 0) )
            live = new_live

            if n_dead == [width for ii in range(batch_size)]:
                break

            in_vocab_idx = [ [ x[2] % voc_size for x in ee ] for ee in live ] # NOTE batch_size first
            input = self.emb( cuda( torch.LongTensor( in_vocab_idx ) ).view(-1) )\
                    .view(-1, 1, self.D_emb) # input (batch_size * width, 1, D_emb)

            in_width_idx = [ [ x[2] // voc_size + bbidx * cwidth for x in ee ] for bbidx, ee in enumerate(live) ]
            # live : (batch_size, width)
            z_t = z_t.index_select( 1, cuda( torch.LongTensor( in_width_idx ).view(-1) ) ).\
                    view(self.n_layers, batch_size * width, self.D_hid)
            # h_0 (n_layers, batch_size * width, D_hid)

        for bidx in range(batch_size):
            if n_dead[bidx] < width:
                for didx in range( width - n_dead[bidx] ):
                    (a, b, c) = live[bidx][didx]
                    dead[bidx].append( (a, b, c)  )

        dead_ = [ [ ( a / ( math.pow(5+len(b), self.beam_alpha) / math.pow(5+1, self.beam_alpha) ) , b, c) for (a,b,c) in ee] for ee in dead]
        ans = []
        for dd_ in dead_:
            dd = sorted( dd_, key=operator.itemgetter(0), reverse=True )
            ans.append( dd[0][1] )
        return ans
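The rescoring of `dead_` at the end of `beam_search` (and again in `beam` further down) is the GNMT length penalty: each accumulated log probability is divided by (5 + |y|)^alpha / 6^alpha so longer hypotheses are not unfairly punished. Restated as a standalone helper for clarity:

import math

def length_penalty(log_prob, length, alpha):
    # GNMT-style length normalization, matching the `dead_` rescoring above.
    return log_prob / (math.pow(5 + length, alpha) / math.pow(5 + 1, alpha))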
Example #20
    def send(self, h, h_len, y_len, send_method):
        # h : (batch_size, x_seq_len, D_hid * n_dir)
        # h_len : (batch_size)
        batch_size, x_seq_len = h.size()[:2]
        xmask = xlen_to_inv_mask(h_len) # (batch_size, x_seq_len)
        max_len_gen = self.max_len_gen if self.msg_len_ratio < 0.0 else int(y_len.max().item() * self.msg_len_ratio)
        max_len_gen = np.clip(max_len_gen, 2, self.max_len_gen).item()

        h_index = (h_len - 1)[:,None,None].repeat(1, 1, h.size(2))
        z_t = torch.cumsum(h, dim=1) # (batch_size, x_seq_len, D_hid * n_dir)
        z_t = z_t.gather(dim=1, index=h_index).view(batch_size, -1) # (batch_size, D_hid * n_dir)
        z_t = torch.div( z_t, h_len[:, None].float() )
        z_t = self.ctx_to_z0(z_t) # (batch_size, n_layers * D_hid)
        z_t = z_t.view(batch_size, self.n_layers, self.D_hid).transpose(0,1).contiguous()
        # (n_layers, batch_size, D_hid)

        y_emb = cuda( torch.full((batch_size, 1), self.init_token).long() )
        y_emb = self.emb( y_emb ) # (batch_size, 1, D_emb)
        y_emb = F.dropout( y_emb, p=self.drop_ratio, training=self.training )

        h_big = h.view(-1, h.size(2) )
        # (batch_size * x_seq_len, D_hid * n_dir)
        ctx_h = self.h_to_att( h_big ).view(batch_size, x_seq_len, self.D_att)

        done = cuda( torch.zeros(batch_size).long() )
        seq_lens = cuda( torch.zeros(batch_size).fill_(max_len_gen).long() ) # (max_len)
        max_seq_lens = cuda( torch.zeros(batch_size).fill_(max_len_gen).long() ) # (max_len)
        eos_tensor = cuda( torch.zeros(batch_size).fill_( self.eoz_token )).long()
        msg = []
        self.log_probs = [] # (batch_size, seq_len)
        batch_logits = 0

        for idx in range(max_len_gen):
            # in (batch_size, 1, D_emb)
            # z_t (n_layers, batch_size, D_hid)
            _, z_t_ = self.rnn( y_emb, z_t )
            # out (batch_size, 1, D_hid)
            # z_t (n_layers, batch_size, D_hid)
            ctx_z_t_ = z_t_.transpose(0,1).contiguous().view(batch_size, self.n_layers * self.D_hid)
            # (batch_size, n_layers * D_hid)

            ctx_y = self.y_to_att( y_emb.view(batch_size, self.D_emb) )[:,None,:]
            ctx_s = self.z_to_att( ctx_z_t_ )[:,None,:]
            ctx = F.tanh(ctx_y + ctx_s + ctx_h)
            ctx = ctx.view(batch_size * x_seq_len, self.D_att)

            score = self.att_to_score(ctx).view(batch_size, -1) # (batch_size, x_seq_len)
            score.masked_fill_(xmask, -float('inf'))
            score = F.softmax( score, dim=1 )
            score = score[:,:,None] # (batch_size, x_seq_len, 1)

            c_t = torch.mul( h, score ) # (batch_size, x_seq_len, D_hid * n_dir)
            c_t = torch.sum( c_t, 1) # (batch_size, D_hid * n_dir)
            # in (batch_size, 1, D_hid * n_dir)
            # z_t (n_layers, batch_size, D_hid)
            out, z_t = self.rnn2( c_t[:,None,:], z_t_ )
            # out (batch_size, 1, D_hid)
            # z_t (n_layers, batch_size, D_hid)

            fin_y = self.y_to_out( y_emb.view(batch_size, self.D_emb) ) # (batch_size, D_out)
            fin_c = self.c_to_out( c_t ) # (batch_size, D_out)
            fin_s = self.z_to_out( out.view(batch_size, self.D_hid) ) # (batch_size, D_out)
            fin = F.tanh( fin_y + fin_c + fin_s )

            logit = self.out( fin ) # (batch_size, voc_sz_trg)

            if send_method == "argmax":
                tokens = logit.data.max(dim=1)[1] # (batch_size)
                tokens_idx = tokens

            elif send_method == "gumbel":
                tokens = gumbel_softmax_hard(logit, self.temp, self.st)\
                        .view(batch_size, self.voc_sz_trg) # (batch_size, voc_sz_trg)
                tokens_idx = (tokens * cuda(torch.arange(self.voc_sz_trg))[None,:]).sum(dim=1).long()

            elif send_method == "reinforce":
                cat = Categorical(logits=logit)
                tokens = cat.sample()
                tokens_idx = tokens
                self.log_probs.append( cat.log_prob(tokens) )

                if self.entropy:
                    batch_logits += (logit * (1-done)[:,None].float()).sum(dim=0)

            msg.append(tokens.unsqueeze(dim=1))

            is_next_eos = ( tokens_idx == eos_tensor ).long() # (batch_size), 1 if eos else 0
            done = (done + is_next_eos).clamp(min=0, max=1).long()
            new_seq_lens = max_seq_lens.clone().masked_fill_(mask=is_next_eos.byte(), \
                                                             value=float(idx+1)) # either max or idx if next is eos
            # max_seq_lens : (batch_size)
            seq_lens = torch.min(seq_lens, new_seq_lens)

            if done.sum() == batch_size:
                break

            y_emb = self.emb(tokens)[:,None,:] # (batch_size, 1, D_emb)

        if self.msg_len_ratio > 0.0 :
            seq_lens = torch.clamp( (y_len.float() * self.msg_len_ratio).floor_(), 1, len(msg) ).long()

        msg = torch.cat(msg, dim=1)
        if send_method == "reinforce": # want to sum per-token log prob to yield log prob for the whole message sentence
            self.log_probs = torch.stack(self.log_probs, dim=1) # (batch_size, seq_len)

        # msg : (batch_size, seq_len) if argmax or reinforce,
        #       (batch_size, seq_len, voc_sz_trg) if gumbel
        results = {"msg":msg, "seq_lens":seq_lens}
        if send_method == "reinforce" and self.entropy:
            results.update( {"batch_logits":batch_logits} )
        return results
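Downstream training code (not part of this example) would typically turn the stored `log_probs` into a REINFORCE loss. A hedged usage sketch, assuming one scalar reward per sequence and `seq_lens` marking valid tokens:

import torch

def reinforce_loss(log_probs, seq_lens, reward):
    # Hypothetical sketch: mask per-token log probs beyond each sequence's
    # length, sum them, and weight by a detached reward.
    # log_probs : (batch_size, seq_len); seq_lens, reward : (batch_size)
    positions = torch.arange(log_probs.size(1), device=log_probs.device)[None, :]
    mask = (positions < seq_lens[:, None]).float()
    return -((log_probs * mask).sum(dim=1) * reward.detach()).mean()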
Example #21
    def beam(self, src_hid, src_len):
        # src_hid : (batch_size, x_seq_len, D_hid * n_dir)
        # src_len : (batch_size)
        batch_size, x_seq_len = src_hid.size()[:2]
        src_mask = xlen_to_inv_mask(
            src_len, seq_len=x_seq_len)  # (batch_size, x_seq_len)
        voc_size, width = self.voc_sz_trg, self.beam_width
        y_seq_len = self.max_len_gen

        h_index = (src_len - 1)[:, None, None].repeat(1, 1, src_hid.size(2))
        z_t = torch.cumsum(src_hid,
                           dim=1)  # (batch_size, x_seq_len, D_hid * n_dir)
        z_t = z_t.gather(dim=1,
                         index=h_index).view(batch_size,
                                             -1)  # (batch_size, D_hid * n_dir)
        z_t = torch.div(z_t, src_len[:, None].float())

        y_emb = cuda(torch.full((batch_size, ), self.init_token).long())
        y_emb = self.emb(y_emb)  # (batch_size, D_emb)

        input_feed = y_emb.data.new(batch_size, self.D_hid).zero_()

        live = [[(0.0, [self.init_token], 2)] for ii in range(batch_size)]
        dead = [[] for ii in range(batch_size)]
        n_dead = [0 for ii in range(batch_size)]
        src_hid_ = src_hid
        src_mask_ = src_mask

        for idx in range(y_seq_len):
            cwidth = 1 if idx == 0 else width
            input = torch.cat([y_emb, input_feed],
                              dim=1)  # (batch_size * width, D_emb + D_hid)
            trg_hid = self.layers[0](input, z_t)  # (batch_size * width, D_hid)
            z_t = trg_hid  # (batch_size * width, D_hid)

            out, attn_scores = self.attention(
                trg_hid, src_hid_, src_mask_)  # (batch_size * width, D_hid)
            input_feed = out  # (batch_size * width, D_hid)

            logit = self.out(out)  # (batch_size * width, voc_sz_trg)
            cur_prob = F.log_softmax(logit,
                                     dim=1).view(batch_size, cwidth,
                                                 self.voc_sz_trg)
            pre_prob = cuda( torch.FloatTensor( [ [ x[0] for x in ee ] for ee in live ] )\
                            .view(batch_size, cwidth, 1) ) # (batch_size, width, 1)
            total_prob = cur_prob + pre_prob  # (batch_size, cwidth, voc_sz)
            total_prob = total_prob.view(batch_size,
                                         -1)  # (batch_size, cwidth * voc_sz)

            topi_s = total_prob.topk(width, dim=1)[1]
            topv_s = cur_prob.view(batch_size, -1).gather(1, topi_s)

            new_live = [[] for ii in range(batch_size)]
            for bidx in range(batch_size):
                n_live = width - n_dead[bidx]
                if n_live > 0:
                    tis = topi_s[bidx][:n_live].cpu().numpy().tolist()
                    tvs = topv_s[bidx][:n_live].cpu().numpy().tolist()
                    for eidx, (topi, topv) in enumerate(zip(tis, tvs)):
                        if topi % voc_size == self.eos_token:
                            dead[bidx].append(
                                (live[bidx][topi // voc_size][0] + topv,
                                 live[bidx][topi // voc_size][1] +
                                 [topi % voc_size], topi))
                            n_dead[bidx] += 1
                        else:
                            new_live[bidx].append(
                                (live[bidx][topi // voc_size][0] + topv,
                                 live[bidx][topi // voc_size][1] +
                                 [topi % voc_size], topi))
                while len(new_live[bidx]) < width:
                    new_live[bidx].append((-99999999999, [0], 0))
            live = new_live
            if n_dead == [width for ii in range(batch_size)]:
                break

            in_vocab_idx = np.array([[x[2] % voc_size for x in ee]
                                     for ee in live])  # NOTE batch_size first
            y_emb = self.emb( cuda( torch.LongTensor( in_vocab_idx ) ).view(-1) )\
                    .view(-1, self.D_emb) # input (batch_size * width, 1, D_emb)

            in_width_idx = np.array( [ [ x[2] // voc_size + bbidx * cwidth for x in ee ] \
                            for bbidx, ee in enumerate(live) ] )
            in_width_idx = cuda(torch.LongTensor(in_width_idx).view(-1))
            z_t = z_t.index_select(0,
                                   in_width_idx).view(batch_size * width,
                                                      self.D_hid)
            input_feed = input_feed.index_select(0, in_width_idx).view(
                batch_size * width, self.D_hid)
            src_hid_ = src_hid_.index_select(0, in_width_idx).view(
                batch_size * width, x_seq_len, self.D_hid)
            src_mask_ = src_mask_.index_select(0, in_width_idx).view(
                batch_size * width, x_seq_len)

        for bidx in range(batch_size):
            if n_dead[bidx] < width:
                for didx in range(width - n_dead[bidx]):
                    (a, b, c) = live[bidx][didx]
                    dead[bidx].append((a, b, c))

        dead_ = [ [ ( a / ( math.pow(5+len(b), self.beam_alpha) / math.pow(5+1, self.beam_alpha) ) , b, c)\
                   for (a,b,c) in ee] for ee in dead]
        ans = [[], [], [], [], []]
        for dd_ in dead_:
            dd = sorted(dd_, key=operator.itemgetter(0), reverse=True)
            for idx in range(5):
                ans[idx].append(dd[idx][1])
            #ans.append( dd[0][1] )
        return ans