Example #1
    def step(self, step, lprobs, scores):
        super()._init_buffers(lprobs)
        bsz, beam_size, vocab_size = lprobs.size()

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam
            lprobs = lprobs[:, ::beam_size, :].contiguous()
        else:
            # make probs contain cumulative scores for each hypothesis
            lprobs.add_(scores[:, :, step - 1].unsqueeze(-1))

        torch.topk(
            lprobs.view(bsz, -1),
            k=min(
                # Take the best 2 x beam_size predictions. We'll choose the first
                # beam_size of these which don't predict eos to continue with.
                beam_size * 2,
                lprobs.view(bsz, -1).size(1) - 1,  # -1 so we never select pad
            ),
            out=(self.scores_buf, self.indices_buf),
        )
        torch.div(self.indices_buf, vocab_size, out=self.beams_buf)
        self.indices_buf.fmod_(vocab_size)
        return self.scores_buf, self.indices_buf, self.beams_buf
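A note on the pattern above: the scores are viewed as (bsz, beam_size * vocab_size) so a single topk ranks every (beam, token) pair per sentence at once, and integer division / modulo by vocab_size recover the beam id and the token id. A minimal standalone sketch of that index arithmetic (shapes and names here are illustrative, not taken from the example):

import torch

bsz, beam_size, vocab_size = 2, 3, 7
lprobs = torch.randn(bsz, beam_size, vocab_size)

# rank all (beam, token) pairs jointly for each sentence
scores, flat_idx = torch.topk(lprobs.view(bsz, -1), k=2 * beam_size)

beams = torch.div(flat_idx, vocab_size, rounding_mode='floor')  # beam each candidate came from
tokens = flat_idx % vocab_size                                  # vocab entry it selects

assert int(beams.max()) < beam_size and int(tokens.max()) < vocab_size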
Example #2
def calc_precision(pred, label):
    t1 = torch.topk(pred, 1)[-1]
    t5 = torch.topk(pred, 5)[-1]
    mask_1 = torch.eq(t1, label.view(-1, 1))
    mask_5 = torch.eq(t5, label.view(-1, 1))
    t1_error = 1 - len(t1[mask_1]) / len(label)
    t5_error = 1 - len(t5[mask_5]) / len(label)
    return t1_error, t5_error
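A quick, illustrative call of calc_precision above with dummy logits and labels (batch size and class count are made up):

import torch

pred = torch.randn(8, 100)           # logits for 8 samples over 100 classes
label = torch.randint(0, 100, (8,))  # ground-truth class ids
t1_error, t5_error = calc_precision(pred, label)
print(t1_error, t5_error)            # fraction of samples missed at top-1 / top-5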
Example #3
def pick_top_n(preds, top_n=5):
    top_pred_prob, top_pred_label = torch.topk(preds, top_n, 1)
    top_pred_prob /= torch.sum(top_pred_prob)
    top_pred_prob = top_pred_prob.squeeze(0).cpu().numpy()
    top_pred_label = top_pred_label.squeeze(0).cpu().numpy()
    c = np.random.choice(top_pred_label, size=1, p=top_pred_prob)
    return c
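And a similar sketch for pick_top_n, assuming preds is a 1 x vocab row of probabilities (the vocab size is arbitrary):

import torch

preds = torch.softmax(torch.randn(1, 50), dim=1)  # 1 x vocab probability row
next_id = pick_top_n(preds, top_n=5)              # numpy array with one sampled index
print(next_id)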
Example #4
    def ohem_detect_loss(self, cls_score, rois_label, bbox_pred, rois_target, rois_inside_ws, rois_outside_ws):

        def log_sum_exp(x):
            x_max = x.data.max()
            return torch.log(torch.sum(torch.exp(x - x_max), dim=1, keepdim=True)) + x_max

        num_hard = cfg.TRAIN.BATCH_SIZE * self.batch_size
        pos_idx = rois_label > 0
        num_pos = pos_idx.int().sum()

        # classification loss
        num_classes = cls_score.size(1)
        weight = cls_score.data.new(num_classes).fill_(1.)
        weight[0] = num_pos.item() / num_hard

        conf_p = cls_score.detach()
        conf_t = rois_label.detach()

        # rank on cross_entropy loss
        loss_c = log_sum_exp(conf_p) - conf_p.gather(1, conf_t.view(-1,1))
        loss_c[pos_idx] = 100. # include all positive samples
        _, topk_idx = torch.topk(loss_c.view(-1), num_hard)
        loss_cls = F.cross_entropy(cls_score[topk_idx], rois_label[topk_idx], weight=weight)

        # bounding box regression L1 loss
        pos_idx = pos_idx.unsqueeze(1).expand_as(bbox_pred)
        loc_p = bbox_pred[pos_idx].view(-1, 4)
        loc_t = rois_target[pos_idx].view(-1, 4)
        loss_box = F.smooth_l1_loss(loc_p, loc_t)

        return loss_cls, loss_box
Example #5
def get_topk_labels(class_weights, topk):
    assert len(class_weights.shape) == 1, 'this is implemented only for a vector of class weights'
    probs, indx = torch.topk(class_weights, k = topk)

    labels = []
    for i in range(topk):
        labels.append(fine_labels_legend[indx[i]])

    return probs, labels
Example #6
 def update_beam(self,newlogprobs): # newlogprobs is beamsz,len(EN.vocab)
     newlogprobs = newlogprobs.data
     newlogprobs += self.probs.unsqueeze(1) # beamsz,len(EN.vocab)
     newlogprobs = newlogprobs.view(-1) # flatten to beamsz*len(EN.vocab) (search across all beams)
     sorte,indices = torch.topk(newlogprobs,self.beamsz) 
     # sorte and indices both have length beamsz; sorte holds the scores, indices index the flattened beamsz*len(EN.vocab) array
     self.probs = sorte
     self.oldbeamindices = indices // len(EN.vocab)  # which beam each kept candidate came from
     currbeam = indices % len(EN.vocab) # beamsz
     self.update_wordlist(currbeam)
Example #7
def _get_most_activated_channels(z, num_channels=5):
    """
    z: CxHxW
    """

    per_channel_max_activations, _ = z.max(1)[0].max(1)

    most_activated_channels = \
        torch.topk(per_channel_max_activations, num_channels)[1]

    return most_activated_channels
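Assuming z is a feature map of shape (C, H, W), a quick illustrative call of the helper above (sizes are arbitrary):

import torch

z = torch.randn(64, 14, 14)  # C x H x W activations
channels = _get_most_activated_channels(z, num_channels=5)
print(channels)              # indices of the 5 most strongly activated channels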
Example #8
def calculate_video_results(output_buffer, video_id, test_results, class_names):
    video_outputs = torch.stack(output_buffer)
    average_scores = torch.mean(video_outputs, dim=0)
    sorted_scores, locs = torch.topk(average_scores, k=10)

    video_results = []
    for i in range(sorted_scores.size(0)):
        video_results.append({
            'label': class_names[locs[i]],
            'score': sorted_scores[i]
        })

    test_results['results'][video_id] = video_results
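An illustrative call of calculate_video_results, assuming output_buffer holds per-clip score tensors and class_names covers all classes (every name below is a placeholder):

import torch

output_buffer = [torch.randn(400) for _ in range(16)]  # 16 clips, 400 classes
class_names = ['class_%d' % i for i in range(400)]
test_results = {'results': {}}
calculate_video_results(output_buffer, 'video_001', test_results, class_names)
print(test_results['results']['video_001'][0])          # best label and its score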
Example #9
    def proposal_layer(self, rpn_class, rpn_bbox):
        # handling proposals
        scores = rpn_class[:, :, 1]
        # Box deltas [batch, num_rois, 4]
        deltas_mul = Variable(torch.from_numpy(np.reshape(
            self.config.RPN_BBOX_STD_DEV, [1, 1, 4]).astype(np.float32))).cuda()
        deltas = rpn_bbox * deltas_mul

        pre_nms_limit = min(6000, self.anchors.shape[0])

        scores, ix = torch.topk(scores, pre_nms_limit, dim=-1,
                                largest=True, sorted=True)


        ix = torch.unsqueeze(ix, 2)
        ix = torch.cat([ix, ix, ix, ix], dim=2)
        deltas = torch.gather(deltas, 1, ix)

        _anchors = []
        for i in range(self.config.IMAGES_PER_GPU):
            anchors = Variable(torch.from_numpy(
                self.anchors.astype(np.float32))).cuda()
            _anchors.append(anchors)
        anchors = torch.stack(_anchors, 0) 
    
        pre_nms_anchors = torch.gather(anchors, 1, ix)
        refined_anchors = apply_box_deltas_graph(pre_nms_anchors, deltas)

        # Clip to image boundaries. [batch, N, (y1, x1, y2, x2)]
        height, width = self.config.IMAGE_SHAPE[:2]
        window = np.array([0, 0, height, width]).astype(np.float32)
        window = Variable(torch.from_numpy(window)).cuda()

        refined_anchors_clipped = clip_boxes_graph(refined_anchors, window)

        refined_proposals = []
        for i in range(self.config.IMAGES_PER_GPU):
            indices = nms(
                torch.cat([refined_anchors_clipped.data[i], scores.data[i]], 1), 0.7)
            indices = indices[:self.proposal_count]
            indices = torch.stack([indices, indices, indices, indices], dim=1)
            indices = Variable(indices).cuda()
            proposals = torch.gather(refined_anchors_clipped[i], 0, indices)
            padding = self.proposal_count - proposals.size()[0]
            proposals = torch.cat(
                [proposals, Variable(torch.zeros([padding, 4])).cuda()], 0)
            refined_proposals.append(proposals)

        rpn_rois = torch.stack(refined_proposals, 0)

        return rpn_rois
Example #10
def sample_with_temperature(logits, sampling_temp, keep_topk):
    """Select next tokens randomly from the top k possible next tokens.

    Samples from a categorical distribution over the ``keep_topk`` words using
    the category probabilities ``logits / sampling_temp``.

    Args:
        logits (FloatTensor): Shaped ``(batch_size, vocab_size)``.
            These can be logits (``(-inf, inf)``) or log-probs (``(-inf, 0]``).
            (The distribution actually uses the log-probabilities
            ``logits - logits.logsumexp(-1)``, which equals the logits if
            they are log-probabilities summing to 1.)
        sampling_temp (float): Used to scale down logits. The higher the
            value, the more likely it is that a non-max word will be
            sampled.
        keep_topk (int): This many words could potentially be chosen. The
            other logits are set to have probability 0.

    Returns:
        (LongTensor, FloatTensor):

        * topk_ids: Shaped ``(batch_size, 1)``. These are
          the sampled word indices in the output vocab.
        * topk_scores: Shaped ``(batch_size, 1)``. These
          are essentially ``(logits / sampling_temp)[topk_ids]``.
    """

    if sampling_temp == 0.0 or keep_topk == 1:
        # For temp=0.0, take the argmax to avoid divide-by-zero errors.
        # keep_topk=1 is also equivalent to argmax.
        topk_scores, topk_ids = logits.topk(1, dim=-1)
        if sampling_temp > 0:
            topk_scores /= sampling_temp
    else:
        logits = torch.div(logits, sampling_temp)

        if keep_topk > 0:
            top_values, top_indices = torch.topk(logits, keep_topk, dim=1)
            kth_best = top_values[:, -1].view([-1, 1])
            kth_best = kth_best.repeat([1, logits.shape[1]]).float()

            # Set all logits that are not in the top-k to -10000.
            # This puts the probabilities close to 0.
            ignore = torch.lt(logits, kth_best)
            logits = logits.masked_fill(ignore, -10000)

        dist = torch.distributions.Multinomial(
            logits=logits, total_count=1)
        topk_ids = torch.argmax(dist.sample(), dim=1, keepdim=True)
        topk_scores = logits.gather(dim=1, index=topk_ids)
    return topk_ids, topk_scores
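A small usage sketch of the sampler above (batch size, vocabulary size, and the sampling settings are invented for the demo):

import torch

logits = torch.randn(4, 1000)   # batch of 4, vocab of 1000
ids, scores = sample_with_temperature(logits, sampling_temp=0.8, keep_topk=50)
print(ids.shape, scores.shape)  # both are (4, 1)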
Example #11
def run_epoch(loader, model, criterion, optimizer, epoch=0, n_epochs=0, train=True):
    time_meter = Meter(name='Time', cum=True)
    loss_meter = Meter(name='Loss', cum=False)
    error_meter = Meter(name='Error', cum=False)

    if train:
        model.train()
        print('Training')
    else:
        model.eval()
        print('Evaluating')

    end = time.time()
    for i, (input, target) in enumerate(loader):
        if train:
            model.zero_grad()
            optimizer.zero_grad()

        # Forward pass
        input_var = Variable(input, volatile=(not train)).cuda(non_blocking=True)
        target_var = Variable(target, volatile=(not train), requires_grad=False).cuda(non_blocking=True)
        output_var = model(input_var)
        loss = criterion(output_var, target_var)

        # Backward pass
        if train:
            loss.backward()
            optimizer.step()
            optimizer.n_iters = optimizer.n_iters + 1 if hasattr(optimizer, 'n_iters') else 1

        # Accounting
        _, predictions_var = torch.topk(output_var, 1)
        error = 1 - torch.eq(predictions_var, target_var).float().mean()
        batch_time = time.time() - end
        end = time.time()

        # Log errors
        time_meter.update(batch_time)
        loss_meter.update(loss)
        error_meter.update(error)
        print('  '.join([
            '%s: (Epoch %d of %d) [%04d/%04d]' % ('Train' if train else 'Eval',
                epoch, n_epochs, i + 1, len(loader)),
            str(time_meter),
            str(loss_meter),
            str(error_meter),
        ]))

    return time_meter.value(), loss_meter.value(), error_meter.value()
Example #12
def predict(line, max_predictions):
    """Give continuation of the line with at most max_predictions BPE tokens. Returns line extended with predictions of
     the model."""

    line_encoded = enc.encode(line)
    line_encoded = torch.tensor(line_encoded)
    line_encoded = line_encoded.unsqueeze_(0) # batch of size 1
    line_encoded_list = list(line_encoded[0].numpy())
    line_encoded = line_encoded.to(device)
    state = None

    for i in range(max_predictions):
        with timeit('forward'):
            logits, state = model(line_encoded, past=state)
        
        #        predicted = argmax(logits[0,-1,:])

        # [[idx1, idx2, ...]]
        with timeit('topk'):
            _, line_encoded_candidates = torch.topk(logits[:,-1,:], k=beam_width, dim=-1)

        # determine which candidates are stopwords by decoding them and
        # comparing against NLTK stopword list
        
        line_encoded_candidates = to_list(line_encoded_candidates[0])
        is_stopword = []
        for s in line_encoded_candidates:
            is_stopword.append(enc.decode([s.item()]).strip() in stopwords)

            
        # find first prediction which is not a stopword
        predicted = None
        for (idx, candidate) in enumerate(line_encoded_candidates):
            if is_stopword[idx]:
                #                print('skipping stopword ', idx)
                continue
            else:
                predicted = candidate
                break
        assert predicted is not None
        line_encoded = torch.tensor([[predicted]]).to(device)
        line_encoded_list.append(predicted)

    return enc.decode(line_encoded_list)
Example #13
    def classifier(self, xs):
        """
        classify an image (or a batch of images)

        :param xs: a batch of scaled vectors of pixels from an image
        :return: a batch of the corresponding class labels (as one-hots)
        """
        # use the trained model q(y|x) = categorical(alpha(x))
        # compute all class probabilities for the image(s)
        alpha = self.encoder_y.forward(xs)

        # get the index (digit) that corresponds to
        # the maximum predicted class probability
        res, ind = torch.topk(alpha, 1)

        # convert the digit(s) to one-hot tensor(s)
        ys = Variable(torch.zeros(alpha.size()))
        ys = ys.scatter_(1, ind, 1.0)
        return ys
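The scatter_ call above is the usual way to turn top-1 indices into one-hot rows; a minimal self-contained version of that pattern (shapes invented):

import torch

alpha = torch.softmax(torch.randn(5, 10), dim=1)    # 5 samples, 10 classes
_, ind = torch.topk(alpha, 1)                        # (5, 1) argmax indices
ys = torch.zeros_like(alpha).scatter_(1, ind, 1.0)   # one-hot predictions
print(ys.sum(dim=1))                                 # each row sums to 1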
Example #14
def beam_search(decoder, decoder_input, encoder_outputs, hidden, max_length, k, target_lang):
    
    candidates = [(decoder_input, 0, hidden)]
    potential_candidates = []
    completed_translations = []

    # put a cap on the length of generated sentences
    for m in range(max_length):
        for c in candidates:
            # unpack the tuple
            c_sequence = c[0]
            c_score = c[1]
            c_hidden = c[2]
            # EOS token
            if c_sequence[-1] == EOS_token:
                completed_translations.append((c_sequence, c_score))
                k = k - 1
            else:
                # pdb.set_trace()
                next_word_probs, hidden = decoder(torch.cuda.LongTensor([c_sequence[-1]]).view(1, 1), torch.cuda.FloatTensor(c_hidden), encoder_outputs, attn_mask = None)
                next_word_probs = next_word_probs[0]
                # in the worst case, one sequence holds all k highest probabilities,
                # so to save computation, only grab the k most probable words from each candidate sequence
                top_probs, top_idx = torch.topk(next_word_probs, k)
                for i in range(len(top_probs)):
                    word = top_idx[i].reshape(1, 1).to(device)
                    new_score = c_score + top_probs[i]
                    potential_candidates.append((torch.cat((c_sequence, word)).to(device), new_score, hidden))

        candidates = sorted(potential_candidates, key= lambda x: x[1], reverse=True)[0:k] 
        potential_candidates = []

    completed = completed_translations + candidates
    completed = sorted(completed, key= lambda x: x[1], reverse=True)[0] 
    final_translation = []
    for x in completed[0]:
        final_translation.append(target_lang.index2word[x.squeeze().item()])
    return final_translation
Example #15
def get_concentrated_mask(class_weights, alpha, topk):
    # returns a logical mask that is 1 where class_weights >= alpha
    # AND class_weights is one of the k largest entries.

    # NOTE: this only works for a vector of class_weights at the moment.

    # boolean vector for where class_weights > alpha
    mask_alpha = (class_weights >= alpha).float().detach()

    # but if there are more than k, only take the topk
    mask_topk = torch.zeros(class_weights.shape).to(device)

    seq_tensor = torch.LongTensor([i for i in range(class_weights.shape[0])])

    if topk > 0:
        _, topk_domain = torch.topk(class_weights, topk)
        # mask_topk[topk_domain] = 1
        # print(topk_domain)
        for i in range(topk):
            mask_topk[seq_tensor, topk_domain[:, i]] = 1
    else:
        topk_domain = None

    return mask_alpha * mask_topk, topk_domain, seq_tensor
Example #16
def train_batch(input_variable, input_lengths, target_variable, topics, model,
                teacher_forcing_ratio):
    loss_list = []
    # Forward propagation
    prev_generated_seq = None
    target_variable_reshaped = target_variable[:, 1:].contiguous().view(-1)

    for i in range(config.num_exams):
        topics = topics if config.use_topics else None
        decoder_outputs, _, other = \
            model(input_variable, prev_generated_seq, input_lengths,
                   target_variable, teacher_forcing_ratio, topics)

        decoder_outputs_reshaped = decoder_outputs.view(-1, vocab_size)
        lossi = criterion(decoder_outputs_reshaped, target_variable_reshaped)
        loss_list.append(lossi.item())
        if model.training:
            model.zero_grad()
            lossi.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            optimizer.step()
        prev_generated_seq = torch.squeeze(torch.topk(decoder_outputs, 1, dim=2)[1]).view(-1, decoder_outputs.size(1))
        prev_generated_seq = _mask(prev_generated_seq)
    return loss_list
Example #17
File: rnn.py Project: Pinafore/qb
        def get_highlights():
            questions = [request.form['text']]
            examples = [self.text_field.preprocess(q) for q in questions]
            padded_examples, lengths = self.text_field.pad(examples)
            padded_examples = np.array(padded_examples, dtype=object)
            lengths = np.array(lengths)
            order = np.argsort(-lengths)
            # rev_order = np.argsort(order)
            ordered_examples = padded_examples[order]
            ordered_lengths = lengths[order]
            text, lengths = self.text_field.numericalize((ordered_examples, ordered_lengths), device=-1, train=False)
            lengths = list(lengths.cpu().numpy())

            qanta_ids = self.qanta_id_field.process([0 for _ in questions])  # .cuda()
            hidden_init = self.model.init_hidden(len(questions))
            text = Variable(text.data, volatile=False)

            out, _ = self.model(text, lengths, hidden_init, qanta_ids, extract_grad_hook('embed'))

            guessForEvidence = request.form['guessForEvidence']
            guessForEvidence = guessForEvidence.split("style=\"color:blue\">")[1].split("</a>")[0].lower()
            indicator = -1

            guess = str(guessForEvidence)
            guesses = self.guess([request.form['text']], 500)[0]
            for index, (g, s) in enumerate(guesses):
                print(g.lower().replace("_", " ")[0:25])
                print(guessForEvidence)
                if g.lower().replace("_", " ")[0:25] == guessForEvidence:
                    print("INDICATOR SET")
                    indicator = index
                    guess = g.lower().replace("_", " ")[0:25]
                    break
            if indicator == -1:
                highlights = {
                    'wiki': ['No Evidence', 'No Evidence'],
                    'qb': ['No Evidence', 'No Evidence'],
                    'guess': guess,
                    'visual': 'No Evidence'
                }
                return jsonify(highlights)

            # label = torch.max(out,1)[1]
            label = torch.topk(out, k=500, dim=1)
            label = label[1][0][indicator]  # [0]

            criterion = nn.CrossEntropyLoss()
            loss = criterion(out, label)
            self.model.zero_grad()
            loss.backward()

            grads = extracted_grads['embed'].transpose(0, 1)
            grads = grads.data.cpu()
            scores = grads.sum(dim=2).numpy()
            grads = grads.numpy()
            text = text.transpose(0, 1).data.cpu().numpy()

            scores = scores.tolist()

            normalized_scores = scores
            # normalize scores across the words, handling positives and negatives separately
            # final scores should be in the range [0, 1]: 0 is dark red, 1 is dark blue, 0.5 is no highlight
            total_score_pos = 1e-6    # 1e-6 for case where all positive/neg scores are 0
            total_score_neg = 1e-6
            for idx, s in enumerate(normalized_scores):
                s[0] = s[0] * s[0] * s[0] / 5
                if s[0] < 0:
                    total_score_neg = total_score_neg + math.fabs(s[0])
                else:
                    total_score_pos = total_score_pos + s[0]
            for idx, s in enumerate(normalized_scores):
                if s[0] < 0:
                    normalized_scores[idx] = (s[0] / total_score_neg) / 2   # / by 2 to get max of -0.5
                else:
                    normalized_scores[idx] = 0.0
            normalized_scores = [0.5 + n for n in normalized_scores]  # center scores

            returnVal = ""
            for s in normalized_scores:
                returnVal = returnVal + ' ' + str(s)

            localPreprocess = create_qb_tokenizer()
            examples = [localPreprocess(q) for q in questions]
            words = []
            for t in examples[0]:
                words.append(str(t))

            visual = colorize(words, normalized_scores, colors='RdBu')
            print("Guess", guess)
            highlights = {
                'wiki': [returnVal, returnVal],
                'qb': [returnVal, returnVal],
                'guess': guess,
                'visual': visual
            }
            return jsonify(highlights)
Example #18
    def parse(self, question, context, beam_size=5):
        table = context
        args = self.args
        src_sent_var = nn_utils.to_input_variable([question], self.vocab.source,
                                                  cuda=self.args.cuda, training=False)

        utterance_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(question)])
        dec_init_vec = self.init_decoder_state(last_state, last_cell)

        column_word_encodings, table_header_encoding, table_header_mask = self.encode_table_header([table])

        h_tm1 = dec_init_vec
        # (batch_size, query_len, hidden_size)
        utterance_encodings_att_linear = self.att_src_linear(utterance_encodings)

        zero_action_embed = Variable(self.new_tensor(self.args.action_embed_size).zero_())

        t = 0
        hypotheses = [DecodeHypothesis()]
        hyp_states = [[]]
        completed_hypotheses = []

        while len(completed_hypotheses) < beam_size and t < self.args.decode_max_time_step:
            hyp_num = len(hypotheses)

            # (hyp_num, src_sent_len, hidden_size * 2)
            exp_src_encodings = utterance_encodings.expand(hyp_num, utterance_encodings.size(1), utterance_encodings.size(2))
            # (hyp_num, src_sent_len, hidden_size)
            exp_src_encodings_att_linear = utterance_encodings_att_linear.expand(hyp_num,
                                                                                 utterance_encodings_att_linear.size(1),
                                                                                 utterance_encodings_att_linear.size(2))

            # x: [prev_action, parent_production_embed, parent_field_embed, parent_field_type_embed, parent_action_state]
            if t == 0:
                x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_(), volatile=True)

                if args.no_parent_field_type_embed is False:
                    offset = args.action_embed_size  # prev_action
                    offset += args.hidden_size * (not args.no_input_feed)
                    offset += args.action_embed_size * (not args.no_parent_production_embed)
                    offset += args.field_embed_size * (not args.no_parent_field_embed)

                    x[0, offset: offset + args.type_embed_size] = \
                        self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]]
            else:
                a_tm1_embeds = []
                for e_id, hyp in enumerate(hypotheses):
                    action_tm1 = hyp.actions[-1]
                    if action_tm1:
                        if isinstance(action_tm1, ApplyRuleAction):
                            a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]]
                        elif isinstance(action_tm1, ReduceAction):
                            a_tm1_embed = self.production_embed.weight[len(self.grammar)]
                        elif isinstance(action_tm1, WikiSqlSelectColumnAction):
                            a_tm1_embed = self.column_rnn_input(table_header_encoding[0, action_tm1.column_id])
                        elif isinstance(action_tm1, GenTokenAction):
                            a_tm1_embed = self.src_embed.weight[self.vocab.source[action_tm1.token]]
                        else:
                            raise ValueError('unknown action %s' % action_tm1)
                    else:
                        a_tm1_embed = zero_action_embed

                    a_tm1_embeds.append(a_tm1_embed)

                a_tm1_embeds = torch.stack(a_tm1_embeds)

                inputs = [a_tm1_embeds]
                if args.no_input_feed is False:
                    inputs.append(att_tm1)
                if args.no_parent_production_embed is False:
                    # frontier production
                    frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                    frontier_prod_embeds = self.production_embed(Variable(self.new_long_tensor(
                        [self.grammar.prod2id[prod] for prod in frontier_prods])))
                    inputs.append(frontier_prod_embeds)
                if args.no_parent_field_embed is False:
                    # frontier field
                    frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                    frontier_field_embeds = self.field_embed(Variable(self.new_long_tensor([
                        self.grammar.field2id[field] for field in frontier_fields])))

                    inputs.append(frontier_field_embeds)
                if args.no_parent_field_type_embed is False:
                    # frontier field type
                    frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                    frontier_field_type_embeds = self.type_embed(Variable(self.new_long_tensor([
                        self.grammar.type2id[type] for type in frontier_field_types])))
                    inputs.append(frontier_field_type_embeds)

                # parent states
                if args.no_parent_state is False:
                    p_ts = [hyp.frontier_node.created_time for hyp in hypotheses]
                    parent_states = torch.stack([hyp_states[hyp_id][p_t][0] for hyp_id, p_t in enumerate(p_ts)])
                    parent_cells = torch.stack([hyp_states[hyp_id][p_t][1] for hyp_id, p_t in enumerate(p_ts)])

                    if args.lstm == 'parent_feed':
                        h_tm1 = (h_tm1[0], h_tm1[1], parent_states, parent_cells)
                    else:
                        inputs.append(parent_states)

                x = torch.cat(inputs, dim=-1)

            (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings,
                                             exp_src_encodings_att_linear,
                                             src_token_mask=None)

            # ApplyRule action probability
            # (batch_size, grammar_size)
            apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)

            # column attention
            # (batch_size, max_head_num)
            column_attention_weights = self.column_pointer_net(table_header_encoding, table_header_mask,
                                                               att_t.unsqueeze(0)).squeeze(0)
            column_selection_log_prob = torch.log(column_attention_weights)

            # (batch_size, 2)
            primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1)

            # primitive copy prob
            # (batch_size, src_token_num)
            primitive_copy_prob = self.src_pointer_net(utterance_encodings, None,
                                                       att_t.unsqueeze(0)).squeeze(0)

            # (batch_size, primitive_vocab_size)
            primitive_gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1)

            new_hyp_meta = []

            for hyp_id, hyp in enumerate(hypotheses):
                # generate new continuations
                action_types = self.transition_system.get_valid_continuation_types(hyp)

                for action_type in action_types:
                    if action_type == ApplyRuleAction:
                        productions = self.transition_system.get_valid_continuating_productions(hyp)
                        for production in productions:
                            prod_id = self.grammar.prod2id[production]
                            prod_score = apply_rule_log_prob[hyp_id, prod_id]
                            new_hyp_score = hyp.score + prod_score

                            meta_entry = {'action_type': 'apply_rule', 'prod_id': prod_id,
                                          'score': prod_score, 'new_hyp_score': new_hyp_score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)
                    elif action_type == ReduceAction:
                        action_score = apply_rule_log_prob[hyp_id, len(self.grammar)]
                        new_hyp_score = hyp.score + action_score

                        meta_entry = {'action_type': 'apply_rule', 'prod_id': len(self.grammar),
                                      'score': action_score, 'new_hyp_score': new_hyp_score,
                                      'prev_hyp_id': hyp_id}
                        new_hyp_meta.append(meta_entry)
                    elif action_type == WikiSqlSelectColumnAction:
                        for col_id, column in enumerate(table.header):
                            col_sel_score = column_selection_log_prob[hyp_id, col_id]
                            new_hyp_score = hyp.score + col_sel_score

                            meta_entry = {'action_type': 'sel_col', 'col_id': col_id,
                                          'score': col_sel_score, 'new_hyp_score': new_hyp_score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)
                    elif action_type == GenTokenAction:
                        # remember that we can only copy stuff from the input!
                        # we only copy tokens sequentially!!
                        prev_action = hyp.action_infos[-1].action

                        valid_token_pos_list = []
                        if type(prev_action) is GenTokenAction and \
                                not prev_action.is_stop_signal():
                            token_pos = hyp.action_infos[-1].src_token_position + 1
                            if token_pos < len(question):
                                valid_token_pos_list = [token_pos]
                        else:
                            valid_token_pos_list = list(range(len(question)))

                        col_id = hyp.frontier_node['col_idx'].value
                        if table.header[col_id].type == 'real':
                            valid_token_pos_list = [i for i in valid_token_pos_list
                                                    if any(c.isdigit() for c in question[i]) or
                                                    hyp._value_buffer and question[i] in (',', '.', '-', '%')]

                        p_copies = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id]
                        for token_pos in valid_token_pos_list:
                            token = question[token_pos]
                            p_copy = p_copies[token_pos]
                            score_copy = torch.log(p_copy)

                            meta_entry = {'action_type': 'gen_token',
                                          'token': token, 'token_pos': token_pos,
                                          'score': score_copy, 'new_hyp_score': score_copy + hyp.score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)

                        # add generation probability for </primitive>
                        if hyp._value_buffer:
                            eos_prob = primitive_predictor_prob[hyp_id, 0] * \
                                       primitive_gen_from_vocab_prob[hyp_id, self.vocab.primitive['</primitive>']]
                            eos_score = torch.log(eos_prob)

                            meta_entry = {'action_type': 'gen_token',
                                          'token': '</primitive>',
                                          'score': eos_score, 'new_hyp_score': eos_score + hyp.score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)

            if not new_hyp_meta: break

            new_hyp_scores = torch.cat([x['new_hyp_score'] for x in new_hyp_meta])
            top_new_hyp_scores, meta_ids = torch.topk(new_hyp_scores,
                                                      k=min(new_hyp_scores.size(0),
                                                            beam_size - len(completed_hypotheses)))

            live_hyp_ids = []
            new_hypotheses = []
            for new_hyp_score, meta_id in zip(top_new_hyp_scores.data.cpu(), meta_ids.data.cpu()):
                action_info = ActionInfo()
                hyp_meta_entry = new_hyp_meta[meta_id]
                prev_hyp_id = hyp_meta_entry['prev_hyp_id']
                prev_hyp = hypotheses[prev_hyp_id]

                action_type_str = hyp_meta_entry['action_type']
                if action_type_str == 'apply_rule':
                    # ApplyRule action
                    prod_id = hyp_meta_entry['prod_id']
                    if prod_id < len(self.grammar):
                        production = self.grammar.id2prod[prod_id]
                        action = ApplyRuleAction(production)
                    # Reduce action
                    else:
                        action = ReduceAction()
                elif action_type_str == 'sel_col':
                    action = WikiSqlSelectColumnAction(hyp_meta_entry['col_id'])
                else:
                    action = GenTokenAction(hyp_meta_entry['token'])
                    if 'token_pos' in hyp_meta_entry:
                        action_info.copy_from_src = True
                        action_info.src_token_position = hyp_meta_entry['token_pos']

                action_info.action = action
                action_info.t = t

                if t > 0:
                    action_info.parent_t = prev_hyp.frontier_node.created_time
                    action_info.frontier_prod = prev_hyp.frontier_node.production
                    action_info.frontier_field = prev_hyp.frontier_field.field

                new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                new_hyp.score = new_hyp_score

                if new_hyp.completed:
                    completed_hypotheses.append(new_hyp)
                else:
                    new_hypotheses.append(new_hyp)
                    live_hyp_ids.append(prev_hyp_id)

            if live_hyp_ids:
                hyp_states = [hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids]
                h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
                att_tm1 = att_t[live_hyp_ids]
                hypotheses = new_hypotheses
                t += 1
            else: break

        completed_hypotheses.sort(key=lambda hyp: -hyp.score)

        return completed_hypotheses
Example #19
    def __reduce_sequences_for_bag__(self, inputs, sequence_lengths):
        """ Reduces sequences to top `n_sequences*network_config['sequence_reduction_fraction']` important sequences,
        sorted descending by importance. Reduction is performed using minibatches of network_config['reduction_mb_size']
        sequences.
        
        Parameters
        ----------
        inputs: torch.Tensor
            Input of shape (n_sequences, n_input_features, n_sequence_positions) = (d_k, 20+3, d_l)
        sequence_lengths: torch.Tensor
            Sequences lengths as tensor of dtype torch.long and shape (n_sequences,) = (d_k,)
        
        Returns
        ----------
        reduced_inputs: torch.Tensor
            Top `n_sequences*network_config['sequence_reduction_fraction']` important sequences,
            sorted descending by importance as tensor of shape
            (n_reduced_sequences, n_input_features, n_sequence_positions),
            where `n_reduced_sequences=n_sequences*network_config['sequence_reduction_fraction']`
        reduced_sequence_lengths: torch.Tensor
            Sequences lengths of `reduced_inputs` as tensor of dtype torch.long and shape (n_reduced_sequences,),
            where `n_reduced_sequences=n_sequences*network_config['sequence_reduction_fraction']`
        """
        if self.sequence_reduction_fraction != 1.0:
            # Get number of sequences to reduce to
            n_reduced_sequences = int(sequence_lengths.shape[0] *
                                      self.sequence_reduction_fraction)
            # Get number of minibatches for reduction
            n_mbs = int(np.ceil(inputs.shape[0] / self.reduction_mb_size))
            mb_is = torch.arange(start=0, end=n_mbs, dtype=torch.int)

            # Calculate attention weights for sequences (loop over minibatch of sequences)
            attention_acts = torch.jit.annotate(List[torch.Tensor], [])
            for mb_i in mb_is.unbind(dim=0):
                # Get inputs for current minibatch
                inputs_mb = inputs[mb_i * self.reduction_mb_size:(mb_i + 1) *
                                   self.reduction_mb_size].to(
                                       device=self.device, dtype=torch.float16)

                # Get sequence embedding (h_1)
                emb_seqs = self.sequence_embedding_16bit(inputs_mb).to(
                    dtype=torch.float32)

                # Calculate attention weights before softmax (h_2)
                attention_acts.append(
                    self.attention_nn(emb_seqs).squeeze(dim=-1))

            # Concatenate attention weights for all sequences
            attention_acts = torch.cat(attention_acts, dim=0)

            # Get indices of k sequences with highest attention weights
            _, used_sequences = torch.topk(attention_acts,
                                           n_reduced_sequences,
                                           dim=0,
                                           largest=True,
                                           sorted=True)

            # Get top k sequences and sequence lengths
            reduced_inputs = inputs[used_sequences.to(
                device=self.device)].detach().to(device=self.device,
                                                 dtype=torch.float16)
            reduced_sequence_lengths = \
                sequence_lengths[used_sequences.to(device=self.device)].detach().to(device=self.device,
                                                                                    dtype=torch.float16)
        else:
            with torch.no_grad():
                reduced_inputs = inputs.detach().to(device=self.device,
                                                    dtype=torch.float16)
                reduced_sequence_lengths = sequence_lengths.detach().to(
                    device=self.device, dtype=torch.float16)

        return reduced_inputs, reduced_sequence_lengths
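The reduction above boils down to keeping the rows with the largest attention scores; a toy version of that selection, independent of the network details (names and sizes are made up):

import torch

inputs = torch.randn(100, 23, 40)    # (n_sequences, n_input_features, n_sequence_positions)
attention_acts = torch.randn(100)    # one importance score per sequence
n_keep = int(inputs.shape[0] * 0.1)  # keep the top 10%

_, used_sequences = torch.topk(attention_acts, n_keep, dim=0, largest=True, sorted=True)
reduced_inputs = inputs[used_sequences]  # (10, 23, 40), sorted descending by importance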
Example #20
    def generate(self,
                 encoded,
                 lang_id,
                 max_len=200,
                 sample=False,
                 temperature=None):
        """
        Generate a sentence from a given initial state.
        Input:
            - FloatTensor of size (batch_size, hidden_dim) representing
              sentences encoded in the latent space
        Output:
            - LongTensor of size (seq_len, batch_size), word indices
            - LongTensor of size (batch_size,), sentence x_len
        """
        if self.beam_size > 0:
            return self.generate_beam(encoded, lang_id, self.beam_size,
                                      max_len, sample, temperature)

        encoder_out = encoded.dec_input
        latent = encoder_out['encoder_out']

        x_len = encoded.input_len
        is_cuda = latent.is_cuda
        one_hot = None

        # check inputs
        assert type(lang_id) is int
        assert latent.size() == (x_len.max(), x_len.size(0), self.emb_dim)
        assert (sample is True) ^ (temperature is None)

        # initialize generated sentences batch
        slen, bs = latent.size(0), latent.size(1)
        assert x_len.max() == slen and x_len.size(0) == bs
        cur_len = 1
        decoded = torch.LongTensor(max_len, bs).fill_(self.pad_index)
        unfinished_sents = torch.LongTensor(bs).fill_(1)
        lengths = torch.LongTensor(bs).fill_(1)
        if is_cuda:
            decoded = decoded.cuda()
            unfinished_sents = unfinished_sents.cuda()
            lengths = lengths.cuda()
        decoded[0] = self.bos_index[lang_id]

        incremental_state = {}
        latent_state = 0
        prev_lengths = 0
        while cur_len < max_len:

            # previous word embeddings
            prev_latent_state = copy.deepcopy(latent_state)
            prev_lengths = copy.deepcopy(lengths)
            scores = self.forward(encoded,
                                  decoded[:cur_len],
                                  lang_id,
                                  one_hot,
                                  incremental_state=None)
            latent_state = scores
            scores = scores.x.data[-1, :, :]  # T x B x V -> B x V

            # select next words: sample or one-hot
            if sample:
                next_words = torch.multinomial((scores / temperature).exp(),
                                               1).squeeze(1)
            else:
                next_words = torch.topk(scores, 1)[1].squeeze(1)
            assert next_words.size() == (bs, )
            decoded[
                cur_len] = next_words * unfinished_sents + self.pad_index * (
                    1 - unfinished_sents)
            lengths.add_(unfinished_sents)
            unfinished_sents.mul_(next_words.ne(self.eos_index).long())
            cur_len += 1

            # stop when there is a </s> in each sentence
            if unfinished_sents.max() == 0:
                break

        if cur_len == max_len:
            decoded[max_len - 1].masked_fill_(unfinished_sents.byte(),
                                              self.eos_index)
        assert (decoded == self.eos_index).sum() == bs

        if cur_len == 2:
            prev_latent_state = latent_state
        # one more round is required
        # temp_scores = self.forward(encoded, decoded[:cur_len], lang_id, one_hot, incremental_state=None)
        # now latent output of decoder is temp_score.dec_input['encoder_out']
        # padding mask is decoded[:cur_len].t().eq(self.padding_idx)
        # padding_mask = decoded[:cur_len].t().eq(self.pad_index)
        # since we are making last index of every sentence as eos, we do no need last round
        padding_mask = decoded[:cur_len - 1].t().eq(self.pad_index)
        latent_state = LatentState(
            input_len=prev_lengths,
            dec_input={
                'encoder_out': latent_state.dec_input['encoder_out'],
                'encoder_padding_mask': padding_mask,
            },
            dis_input=latent_state.dis_input,
            x=prev_latent_state.x,
        )
        # if lang_id == 0:
        #     print(lang_id, decoded[0], decoded[-1])
        #     input('Decoder aux generate')
        return latent_state, decoded[:cur_len], lengths, one_hot
Example #21
    def sample_beam(self,
                    ingr_features,
                    ingr_mask,
                    beam=3,
                    img_features=None,
                    first_token_value=0,
                    replacement=True,
                    last_token_value=0,
                    device='cpu'):
        k = beam
        alpha = 0.0
        # create dummy previous word
        if ingr_features is not None:
            fs = ingr_features.size(0)
        else:
            fs = img_features.size(0)
        first_word = torch.ones(fs) * first_token_value

        first_word = first_word.to(device).long()

        sequences = [[[first_word], 0, {}, False, 1]]
        finished = []

        for i in range(self.seq_length):
            # forward
            all_candidates = []
            for rem in range(len(sequences)):
                incremental = sequences[rem][2]
                outputs, _ = self.forward(ingr_features, ingr_mask,
                                          torch.stack(sequences[rem][0], 1),
                                          img_features, incremental)
                outputs = outputs.squeeze(1)
                if not replacement:
                    # predicted mask
                    if i == 0:
                        predicted_mask = torch.zeros(
                            outputs.shape).float().to(device)
                    else:
                        # ensure no repetitions in sampling if replacement==False
                        batch_ind = [
                            j for j in range(fs)
                            if sequences[rem][0][i][j] != 0
                        ]
                        sampled_ids_new = sequences[rem][0][i][batch_ind]
                        predicted_mask[batch_ind,
                                       sampled_ids_new] = float('-inf')

                    # mask previously selected ids
                    outputs += predicted_mask

                outputs_prob = torch.nn.functional.log_softmax(outputs, dim=-1)
                probs, indices = torch.topk(outputs_prob, beam)
                # tokens is [batch x beam ] and every element is a list
                # score is [ batch x beam ] and every element is a scalar
                # incremental is [batch x beam ] and every element is a dict

                for bid in range(beam):
                    tokens = sequences[rem][0] + [indices[:, bid]]
                    score = sequences[rem][1] + probs[:, bid].squeeze().item()
                    if indices[:, bid].item() == last_token_value:
                        finished.append([
                            tokens, score, None, True, sequences[rem][-1] + 1
                        ])
                    else:
                        all_candidates.append([
                            tokens, score, incremental, False,
                            sequences[rem][-1] + 1
                        ])

            # if all the top-k scoring beams have finished, we can return them
            ordered_all = sorted(all_candidates + finished,
                                 key=lambda tup: tup[1] /
                                 (np.power(tup[-1], alpha)),
                                 reverse=True)[:k]
            if all(el[-1] == True for el in ordered_all):
                all_candidates = []

            # order all candidates by score
            ordered = sorted(all_candidates,
                             key=lambda tup: tup[1] /
                             (np.power(tup[-1], alpha)),
                             reverse=True)
            # select k best
            sequences = ordered[:k]
            finished = sorted(finished,
                              key=lambda tup: tup[1] /
                              (np.power(tup[-1], alpha)),
                              reverse=True)[:k]

        if len(finished) != 0:
            sampled_ids = torch.stack(finished[0][0][1:], 1)
            logits = finished[0][1]
        else:
            sampled_ids = torch.stack(sequences[0][0][1:], 1)
            logits = sequences[0][1]
        return sampled_ids, logits
Example #22
    def _generate(
        self,
        encoder_inputs,
        srcs_ids,
        beam_size=None,
        maxlen=None,
        prefix_tokens=None,
        src_weights=None,
    ):
        """Generates a translation from multiple source sentences"""
        n_srcs = len(srcs_ids)
        srcs_tokens = encoder_inputs[0]
        align_src_tokens = srcs_tokens.index_select(0, srcs_ids[self.align_to])

        bsz, srclen = align_src_tokens.size()
        maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen

        # the max beam size is the dictionary size - 1, since we never select pad
        beam_size = beam_size if beam_size is not None else self.beam_size
        assert (
            beam_size < self.vocab_size
        ), "Beam size must be smaller than target vocabulary"

        # Encode
        encoder_outs = self._encode(encoder_inputs, beam_size, srcs_ids)
        incremental_states = self._init_incremental_states(n_srcs)

        # initialize buffers
        scores = align_src_tokens.new(bsz * beam_size, maxlen + 1).float().fill_(0)
        scores_buf = scores.clone()
        tokens = align_src_tokens.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = self.eos

        # may differ from input length
        src_encoding_len = encoder_outs[self.align_to][0][0].size(0)

        attn = scores.new(bsz * beam_size, src_encoding_len, maxlen + 2)
        attn_buf = attn.clone()

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        worst_finalized = [{"idx": None, "score": -math.inf} for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfinalized_scores=None):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                if self.stop_early or step == maxlen or unfinalized_scores is None:
                    return True
                # stop if the best unfinalized score is worse than the worst
                # finalized one
                best_unfinalized_score = unfinalized_scores[sent].max()
                if self.normalize_scores:
                    best_unfinalized_score /= (maxlen + 1) ** self.len_penalty
                if worst_finalized[sent]["score"] >= best_unfinalized_score:
                    return True
            return False

        def finalize_hypos(step, bbsz_idx, eos_scores, unfinalized_scores=None):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.

            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.

            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
                unfinalized_scores: A vector containing scores for all
                    unfinalized hypotheses
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[
                :, 1 : step + 2
            ]  # skip the first index, which is EOS
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (step + 1) ** self.len_penalty

            sents_seen = set()
            for i, (idx, score) in enumerate(
                zip(bbsz_idx.tolist(), eos_scores.tolist())
            ):
                sent = idx // beam_size
                sents_seen.add(sent)

                def get_hypo():
                    _, alignment = attn_clone[i].max(dim=0)
                    return {
                        "tokens": tokens_clone[i],
                        "score": score,
                        "attention": attn_clone[i],  # src_len x tgt_len
                        "alignment": alignment,
                        "positional_scores": pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())
                elif not self.stop_early and score > worst_finalized[sent]["score"]:
                    # replace worst hypo for this sentence with new/better one
                    worst_idx = worst_finalized[sent]["idx"]
                    if worst_idx is not None:
                        finalized[sent][worst_idx] = get_hypo()

                    # find new worst finalized hypo for this sentence
                    idx, s = min(
                        enumerate(finalized[sent]), key=lambda r: r[1]["score"]
                    )
                    worst_finalized[sent] = {"score": s["score"], "idx": idx}

            # return number of hypotheses finished this step
            num_finished = 0
            for sent in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step, unfinalized_scores):
                    finished[sent] = True
                    num_finished += 1
            return num_finished

        reorder_state = None
        for step in range(maxlen + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                for model_id, model in enumerate(self.models):
                    if isinstance(model.decoder, FairseqIncrementalDecoder):
                        for src_id in range(n_srcs):
                            model.decoder.reorder_incremental_state(
                                incremental_states[(src_id, model_id)], reorder_state
                            )
            # Run decoder for one step
            logprobs, avg_attn, possible_translation_tokens = self._decode(
                tokens[:, : step + 1], encoder_outs, incremental_states, n_srcs
            )

            if step == 0:
                # at the first step all hypotheses are equally likely, so use
                # only the first beam
                logprobs = logprobs.unfold(0, 1, beam_size).squeeze(2).contiguous()
                scores = scores.type_as(logprobs)
                scores_buf = scores_buf.type_as(logprobs)
            else:
                # make probs contain cumulative scores for each hypothesis
                logprobs.add_(scores[:, step - 1].view(-1, 1))
            logprobs[:, self.pad] = -math.inf  # never select pad

            # apply unk reward
            if possible_translation_tokens is None:
                unk_index = self.unk
            else:
                unk_index = torch.nonzero(possible_translation_tokens == self.unk)[0, 0]
            logprobs[:, unk_index] += self.unk_reward

            # external lexicon reward
            logprobs[:, self.lexicon_indices] += self.lexicon_reward

            logprobs += self.word_reward
            logprobs[:, self.eos] -= self.word_reward

            # Record attention scores
            attn[:, :, step + 1].copy_(avg_attn)

            cand_scores = buffer("cand_scores", type_of=scores)
            cand_indices = buffer("cand_indices")
            cand_beams = buffer("cand_beams")
            eos_bbsz_idx = buffer("eos_bbsz_idx")
            eos_scores = buffer("eos_scores", type_of=scores)
            if step < maxlen:
                if prefix_tokens is not None and step < prefix_tokens.size(1):
                    logprobs_slice = logprobs.view(bsz, -1, logprobs.size(-1))[:, 0, :]
                    cand_scores = torch.gather(
                        logprobs_slice, dim=1, index=prefix_tokens[:, step].view(-1, 1)
                    ).expand(-1, cand_size)
                    cand_indices = (
                        prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size)
                    )
                    cand_beams.resize_as_(cand_indices).fill_(0)
                else:
                    # take the best 2 x beam_size predictions. We'll choose the first
                    # beam_size of these which don't predict eos to continue with.
                    torch.topk(
                        logprobs.view(bsz, -1),
                        k=min(
                            cand_size, logprobs.view(bsz, -1).size(1) - 1
                        ),  # -1 so we never select pad
                        out=(cand_scores, cand_indices),
                    )

                    possible_tokens_size = self.vocab_size
                    if possible_translation_tokens is not None:
                        possible_tokens_size = possible_translation_tokens.size(0)
                    # cand_indices has values in [0, possible_tokens_size * beam_size)
                    # the following does Euclidean division by possible_tokens_size
                    # to retrieve the beam and word id of each candidate
                    torch.div(cand_indices, possible_tokens_size, out=cand_beams)
                    cand_indices.fmod_(possible_tokens_size)
                    # Handle vocab reduction
                    if possible_translation_tokens is not None:
                        possible_translation_tokens = possible_translation_tokens.view(
                            1, possible_tokens_size
                        ).expand(cand_indices.size(0), possible_tokens_size)
                        cand_indices = torch.gather(
                            possible_translation_tokens,
                            dim=1,
                            index=cand_indices,
                            out=cand_indices,
                        )
            else:
                # finalize all active hypotheses once we hit maxlen
                # pick the hypothesis with the highest log prob of EOS right now
                torch.sort(
                    logprobs[:, self.eos],
                    descending=True,
                    out=(eos_scores, eos_bbsz_idx),
                )
                num_remaining_sent -= finalize_hypos(step, eos_bbsz_idx, eos_scores)
                assert num_remaining_sent == 0
                break

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add_(bbsz_offsets)

            # finalize hypotheses that end in eos
            eos_mask = cand_indices.eq(self.eos)
            if step >= self.minlen:
                # only consider eos when it's among the top beam_size indices
                torch.masked_select(
                    cand_bbsz_idx[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                    out=eos_bbsz_idx,
                )
                if eos_bbsz_idx.numel() > 0:
                    torch.masked_select(
                        cand_scores[:, :beam_size],
                        mask=eos_mask[:, :beam_size],
                        out=eos_scores,
                    )
                    num_remaining_sent -= finalize_hypos(
                        step, eos_bbsz_idx, eos_scores, cand_scores
                    )

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            assert step < maxlen

            # set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After, the min values per row are the top candidate active hypos
            active_mask = buffer("active_mask")
            torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[: eos_mask.size(1)],
                out=active_mask,
            )

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, _ignore = buffer("active_hypos"), buffer("_ignore")
            torch.topk(
                active_mask,
                k=beam_size,
                dim=1,
                largest=False,
                out=(_ignore, active_hypos),
            )
            active_bbsz_idx = buffer("active_bbsz_idx")
            torch.gather(cand_bbsz_idx, dim=1, index=active_hypos, out=active_bbsz_idx)
            active_scores = torch.gather(
                cand_scores,
                dim=1,
                index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )
            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
                tokens[:, : step + 1],
                dim=0,
                index=active_bbsz_idx,
                out=tokens_buf[:, : step + 1],
            )
            torch.gather(
                cand_indices,
                dim=1,
                index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )
            if step > 0:
                torch.index_select(
                    scores[:, :step],
                    dim=0,
                    index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                )
            torch.gather(
                cand_scores,
                dim=1,
                index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
            )

            # copy attention for active hypotheses
            torch.index_select(
                attn[:, :, : step + 2],
                dim=0,
                index=active_bbsz_idx,
                out=attn_buf[:, :, : step + 2],
            )

            # swap buffers
            tokens, tokens_buf = tokens_buf, tokens
            scores, scores_buf = scores_buf, scores
            attn, attn_buf = attn_buf, attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(bsz):
            finalized[sent] = sorted(
                finalized[sent], key=lambda r: r["score"], reverse=True
            )

        return finalized
    def forward(self, x, k):
        return torch.topk(x, k)
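
A minimal sketch (with made-up sizes) of the flatten-then-topk pattern used in the generator above: torch.topk over a (bsz, beam_size * vocab_size) view returns flat candidate indices that encode both the beam and the token, and an integer division plus a modulo (the div/fmod_ pair above) recover them.

import torch

bsz, beam_size, vocab_size = 2, 3, 11          # illustrative sizes only
logprobs = torch.randn(bsz, beam_size, vocab_size)

cand_size = 2 * beam_size
cand_scores, cand_indices = torch.topk(logprobs.view(bsz, -1), k=cand_size)

cand_beams = cand_indices // vocab_size        # which beam each candidate came from
cand_tokens = cand_indices % vocab_size        # which word id within that beam

# the decomposed indices point back at exactly the selected scores
assert (logprobs[torch.arange(bsz).unsqueeze(1), cand_beams, cand_tokens]
        == cand_scores).all()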
Example #24
0
    def detect(self, img: torch.Tensor,
               num_feats: int) -> Tuple[torch.Tensor, torch.Tensor]:
        sp, sigmas, pix_dists = self.scale_pyr(img)
        all_responses = []
        all_lafs = []
        for oct_idx, octave in enumerate(sp):
            sigmas_oct = sigmas[oct_idx]
            pix_dists_oct = pix_dists[oct_idx]
            B, L, CH, H, W = octave.size()
            # Run response function
            oct_resp = self.resp(octave.view(B * L, CH, H, W),
                                 sigmas_oct.view(-1)).view(B, L, CH, H, W)

            # We want nms for scale responses, so reorder to (B, CH, L, H, W)
            oct_resp = oct_resp.permute(0, 2, 1, 3, 4)

            # Differentiable nms
            coord_max, response_max = self.nms(oct_resp)

            # Now, lets crop out some small responses
            responses_flatten = response_max.view(response_max.size(0),
                                                  -1)  # [B, N]
            max_coords_flatten = coord_max.view(response_max.size(0), 3,
                                                -1).permute(0, 2,
                                                            1)  # [B, N, 3]

            if responses_flatten.size(1) > num_feats:
                resp_flat_best, idxs = torch.topk(responses_flatten,
                                                  k=num_feats,
                                                  dim=1)
                max_coords_best = torch.gather(
                    max_coords_flatten, 1,
                    idxs.unsqueeze(-1).repeat(1, 1, 3))
            else:
                resp_flat_best = responses_flatten
                max_coords_best = max_coords_flatten
            B, N = resp_flat_best.size()

            # Converts scale level index from ConvSoftArgmax3d to the actual scale, using the sigmas
            max_coords_best = _scale_index_to_scale(max_coords_best,
                                                    sigmas_oct)

            # Create local affine frames (LAFs)
            rotmat = angle_to_rotation_matrix(
                torch.zeros(B, N).to(max_coords_best.device).to(
                    max_coords_best.dtype))
            current_lafs = torch.cat([
                self.mr_size * max_coords_best[:, :, 0].view(B, N, 1, 1) *
                rotmat, max_coords_best[:, :, 1:3].view(B, N, 2, 1)
            ],
                                     dim=3)
            # Normalize LAFs
            current_lafs = normalize_laf(
                current_lafs,
                octave[:, 0])  # We don't need # of scale levels, only shape

            all_responses.append(resp_flat_best)
            all_lafs.append(current_lafs)

        # Sort and keep best n
        responses: torch.Tensor = torch.cat(all_responses, dim=1)
        lafs: torch.Tensor = torch.cat(all_lafs, dim=1)
        responses, idxs = torch.topk(responses, k=num_feats, dim=1)
        lafs = torch.gather(
            lafs, 1,
            idxs.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 2, 3))
        return responses, denormalize_laf(lafs, img)
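
A small self-contained sketch (illustrative names and shapes) of the topk-plus-gather pattern the detector above relies on: keep the k strongest responses and gather the matching coordinate rows so the two tensors stay aligned.

import torch

B, N, k = 2, 100, 5
responses = torch.rand(B, N)       # one score per detected point
coords = torch.rand(B, N, 3)       # (scale, y, x) for each point

best_resp, idxs = torch.topk(responses, k=k, dim=1)                 # [B, k]
best_coords = torch.gather(coords, 1,
                           idxs.unsqueeze(-1).repeat(1, 1, 3))      # [B, k, 3]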
Example #25
0
        output = model(img_tens)

    fps = 1.0 / (time.perf_counter() - fps_time)
    print("Net FPS: %f" % (fps))
    tfps += fps
    count += 1

    im2 = img.copy()
    drw = ImageDraw.Draw(im2)
    pred_logits = output['pred_logits'][0]
    pred_boxes = output['pred_boxes'][0]

    for logits, box in zip(pred_logits, pred_boxes):
        m = th.nn.Softmax(dim=0)
        prob = m(logits)
        top3 = th.topk(logits, 3)
        if top3.indices[0] >= len(CLASSES) or prob[
                top3.indices[0]] < args.threshold:
            continue

        print(' ===== print top3 values =====')
        print('top3', top3)
        print('top 1: Label[%-20s]  probability[%5.3f]' %
              (CLASSES[top3.indices[0]], prob[top3.indices[0]] * 100))
        if top3.indices[1] < len(CLASSES):
            print('top 2: Label[%-20s]  probability[%5.3f]' %
                  (CLASSES[top3.indices[1]], prob[top3.indices[1]] * 100))
        if top3.indices[2] < len(CLASSES):
            print('top 3: Label[%-20s]  probability[%5.3f]' %
                  (CLASSES[top3.indices[2]], prob[top3.indices[2]] * 100))
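
torch.topk returns a named tuple, which is why the loop above can read top3.indices directly; a tiny sketch of both access styles:

import torch

logits = torch.randn(10)
top3 = torch.topk(logits, 3)
print(top3.values, top3.indices)          # attribute access, as in the loop above
values, indices = torch.topk(logits, 3)   # equivalent tuple unpacking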
Example #26
0
# encoding all unique sentences present in the training dataset
embeddings = semantic_search_model.encode(sentences, batch_size=batch_size, convert_to_tensor=True)

logging.info("Retrieve top-{} with semantic search model: {}".format(top_k, semantic_model_name))

# retrieving top-k sentences given a sentence from the dataset
progress = tqdm.tqdm(unit="docs", total=len(sent2idx))
for idx in range(len(sentences)):
    sentence_embedding = embeddings[idx]
    cos_scores = util.pytorch_cos_sim(sentence_embedding, embeddings)[0]
    cos_scores = cos_scores.cpu()
    progress.update(1)

    # We use torch.topk to find the highest top_k+1 scores (the query sentence
    # itself is among them and is skipped below)
    top_results = torch.topk(cos_scores, k=top_k+1)
    
    for score, iid in zip(top_results[0], top_results[1]):
        if iid != idx and (iid, idx) not in duplicates:
            silver_data.append((sentences[idx], sentences[iid]))
            duplicates.add((idx,iid))

progress.reset()
progress.close()

logging.info("Length of silver_dataset generated: {}".format(len(silver_data)))
logging.info("Step 2.2: Label STSbenchmark (silver dataset) with cross-encoder: {}".format(model_name))
cross_encoder = CrossEncoder(cross_encoder_path)
silver_scores = cross_encoder.predict(silver_data)

# All model predictions should be between [0,1]
Example #27
0
    def _generate(self,
                  model,
                  sample,
                  prefix_tokens=None,
                  bos_token=None,
                  **kwargs):
        if not self.retain_dropout:
            model.eval()

        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v
            for k, v in sample['net_input'].items()
            if k != 'prev_output_tokens'
        }

        src_tokens = encoder_input['src_tokens']
        src_lengths = (src_tokens.ne(self.eos)
                       & src_tokens.ne(self.pad)).long().sum(dim=1)
        input_size = src_tokens.size()
        # batch dimension goes first followed by source lengths
        bsz = input_size[0]
        src_len = input_size[1]
        beam_size = self.beam_size
        self.no_repeat_ngram_op = NGramRepeatBlock()

        if self.match_source_len:
            max_len = src_lengths.max().item()
        else:
            max_len = min(
                int(self.max_len_a * src_len + self.max_len_b),
                # exclude the EOS marker
                model.max_decoder_positions() - 1,
            )

        # compute the encoder output for each beam
        encoder_outs = model.forward_encoder(encoder_input)
        new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
        new_order = new_order.to(src_tokens.device).long()
        encoder_outs = model.reorder_encoder_out(encoder_outs, new_order)

        # initialize buffers
        scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0)
        scores_buf = scores.clone()
        tokens = src_tokens.new(bsz * beam_size,
                                max_len + 2).long().fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = self.eos if bos_token is None else bos_token
        attn, attn_buf = None, None

        # The blacklist indicates candidates that should be ignored.
        # For example, suppose we're sampling and have already finalized 2/5
        # samples. Then the blacklist would mark 2 positions as being ignored,
        # so that we only finalize the remaining 3 samples.
        blacklist = src_tokens.new_zeros(bsz, beam_size).eq(
            -1)  # forward and backward-compatible False mask

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) *
                        beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfin_idx):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size or step == max_len:
                return True
            return False

        def apply_no_repeat_ngram_cpu(tokens, lprobs, bsz, step,
                beam_size, no_repeat_ngram_size):
            """ Fairseq implementation of blocking
                repeated ngrams
            """
            banned_list = [[] for bbsz_idx in range(bsz * beam_size)]
            cpu_tokens = tokens.cpu()[:, :step + 1].numpy()
            check_start_pos = step + 2 - no_repeat_ngram_size
            for bbsz_idx in range(bsz * beam_size):
                for i in range(check_start_pos):
                    is_banned = True
                    for k in range(no_repeat_ngram_size - 1):
                        if cpu_tokens[bbsz_idx, i + k] != cpu_tokens[
                            bbsz_idx, check_start_pos + k]:
                            is_banned = False
                            break
                    if is_banned:
                        banned_list[bbsz_idx].append(
                            cpu_tokens[bbsz_idx,
                                       i + no_repeat_ngram_size - 1])

            def calculate_banned_tokens(bbsz_idx):
                """before decoding the next token, prevent decoding
                of ngrams that have already appeared
                """
                banned_tokens_per_sample = [
                    (bbsz_idx, t) for t in banned_list[bbsz_idx]
                ]
                return banned_tokens_per_sample

            banned_tokens = []
            if step + 2 - no_repeat_ngram_size >= 0:
                for bbsz_idx in range(bsz * beam_size):
                    banned_tokens.extend(calculate_banned_tokens(bbsz_idx))

            if banned_tokens:
                banned_tokens = torch.LongTensor(banned_tokens)
                lprobs.index_put_(
                    tuple(banned_tokens.t()),
                    lprobs.new_tensor([-math.inf] * len(banned_tokens)))

            return lprobs

        def finalize_hypos(step, bbsz_idx, eos_scores):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.

            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.

            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step +
                                        2]  # skip the first index, which is EOS
            assert not tokens_clone.eq(self.eos).any()
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(
                0, bbsz_idx)[:, :, 1:step + 2] if attn is not None else None

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (step + 1)**self.len_penalty

            cum_unfin = []
            prev = 0
            for f in finished:
                if f:
                    prev += 1
                else:
                    cum_unfin.append(prev)

            sents_seen = set()
            for i, (idx, score) in enumerate(
                zip(bbsz_idx.tolist(), eos_scores.tolist())):
                unfin_idx = idx // beam_size
                sent = unfin_idx + cum_unfin[unfin_idx]

                sents_seen.add((sent, unfin_idx))

                if self.match_source_len and step > src_lengths[unfin_idx]:
                    score = -math.inf

                def get_hypo():

                    if attn_clone is not None:
                        # remove padding tokens from attn scores
                        hypo_attn = attn_clone[i]
                    else:
                        hypo_attn = None

                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
                        'attention': hypo_attn,  # src_len x tgt_len
                        'alignment': None,
                        'positional_scores': pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())

            newly_finished = []
            for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step, unfin_idx):
                    finished[sent] = True
                    newly_finished.append(unfin_idx)
            return newly_finished

        reorder_state = None
        batch_idxs = None
        for step in range(max_len + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(
                        batch_idxs.numel()).type_as(batch_idxs)
                    reorder_state.view(-1, beam_size).add_(
                        corr.unsqueeze(-1) * beam_size)
                model.reorder_incremental_state(reorder_state)
                encoder_outs = model.reorder_encoder_out(
                    encoder_outs, reorder_state)

            lprobs, avg_attn_scores = model.forward_decoder(
                tokens[:, :step + 1],
                encoder_outs,
                temperature=self.temperature,
            )

            lprobs[:, self.pad] = -math.inf  # never select pad
            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            # handle max length constraint
            if step >= max_len:
                lprobs[:, :self.eos] = -math.inf
                lprobs[:, self.eos + 1:] = -math.inf

            # handle prefix tokens (possibly with different lengths)
            if prefix_tokens is not None and step < prefix_tokens.size(
                1) and step < max_len:
                prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(
                    1, beam_size).view(-1)
                prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
                prefix_mask = prefix_toks.ne(self.pad)
                lprobs[prefix_mask] = -math.inf
                lprobs[prefix_mask] = lprobs[prefix_mask].scatter_(
                    -1, prefix_toks[prefix_mask].unsqueeze(-1),
                    prefix_lprobs[prefix_mask])
                # if prefix includes eos, then we should make sure tokens and
                # scores are the same across all beams
                eos_mask = prefix_toks.eq(self.eos)
                if eos_mask.any():
                    # validate that the first beam matches the prefix
                    first_beam = tokens[eos_mask].view(
                        -1, beam_size, tokens.size(-1))[:, 0, 1:step + 1]
                    eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
                    target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
                    assert (first_beam == target_prefix).all()

                    def replicate_first_beam(tensor, mask):
                        tensor = tensor.view(-1, beam_size, tensor.size(-1))
                        tensor[mask] = tensor[mask][:, :1, :]
                        return tensor.view(-1, tensor.size(-1))

                    # copy tokens, scores and lprobs from the first beam to all beams
                    tokens = replicate_first_beam(tokens, eos_mask_batch_dim)
                    scores = replicate_first_beam(scores, eos_mask_batch_dim)
                    lprobs = replicate_first_beam(lprobs, eos_mask_batch_dim)
            elif step < self.min_len:
                # minimum length constraint (does not apply if using prefix_tokens)
                lprobs[:, self.eos] = -math.inf

            # Record attention scores
            if avg_attn_scores is not None:
                if attn is None:
                    attn = scores.new(bsz * beam_size, src_tokens.size(1),
                                      max_len + 2)
                    attn_buf = attn.clone()
                attn[:, :, step + 1].copy_(avg_attn_scores)

            scores = scores.type_as(lprobs)
            scores_buf = scores_buf.type_as(lprobs)
            eos_bbsz_idx = buffer('eos_bbsz_idx')
            eos_scores = buffer('eos_scores', type_of=scores)

            self.search.set_src_lengths(src_lengths)

            if self.no_repeat_ngram_size > 0:
                # apply the CUDA op for n-gram repeat blocking when the inputs
                # are on GPU, otherwise fall back to the CPU implementation above
                if tokens.is_cuda and lprobs.is_cuda:
                    lprobs = self.no_repeat_ngram_op(tokens, lprobs, bsz, step,
                            beam_size, self.no_repeat_ngram_size)
                else:
                    lprobs = apply_no_repeat_ngram_cpu(tokens, lprobs, bsz,
                                step, beam_size, self.no_repeat_ngram_size)

            cand_scores, cand_indices, cand_beams = self.search.step(
                step,
                lprobs.view(bsz, -1, self.vocab_size),
                scores.view(bsz, beam_size, -1)[:, :, :step],
            )

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos, except for blacklisted ones
            # or candidates with a score of -inf
            eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
            eos_mask[:, :beam_size][blacklist] = 0

            # only consider eos when it's among the top beam_size indices
            eos_bbsz_idx = torch.masked_select(
                cand_bbsz_idx[:, :beam_size],
                mask=eos_mask[:, :beam_size],
            )

            finalized_sents = set()
            if eos_bbsz_idx.numel() > 0:
                eos_scores = torch.masked_select(
                    cand_scores[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                )
                finalized_sents = finalize_hypos(step, eos_bbsz_idx,
                                                 eos_scores)
                num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            assert step < max_len

            if len(finalized_sents) > 0:
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = cand_indices.new_ones(bsz)
                batch_mask[cand_indices.new(finalized_sents)] = 0
                batch_idxs = torch.nonzero(batch_mask).squeeze(-1)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]
                if prefix_tokens is not None:
                    prefix_tokens = prefix_tokens[batch_idxs]
                src_lengths = src_lengths[batch_idxs]
                blacklist = blacklist[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(
                    new_bsz * beam_size, -1)
                scores_buf.resize_as_(scores)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(
                    new_bsz * beam_size, -1)
                tokens_buf.resize_as_(tokens)
                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(
                        new_bsz * beam_size, attn.size(1), -1)
                    attn_buf.resize_as_(attn)
                bsz = new_bsz
            else:
                batch_idxs = None

            # Set active_mask so that values > cand_size indicate eos or
            # blacklisted hypos and values < cand_size indicate candidate
            # active hypos. After this, the min values per row are the top
            # candidate active hypos.
            active_mask = buffer('active_mask')
            eos_mask[:, :beam_size] |= blacklist
            active_mask = torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[:eos_mask.size(1)],
            )

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, new_blacklist = buffer('active_hypos'), buffer(
                'new_blacklist')
            torch.topk(active_mask,
                       k=beam_size,
                       dim=1,
                       largest=False,
                       out=(new_blacklist, active_hypos))

            # update blacklist to ignore any finalized hypos
            blacklist = new_blacklist.ge(cand_size)[:, :beam_size]
            assert (~blacklist).any(dim=1).all()

            active_bbsz_idx = buffer('active_bbsz_idx')
            torch.gather(
                cand_bbsz_idx,
                dim=1,
                index=active_hypos,
                out=active_bbsz_idx,
            )
            active_scores = torch.gather(
                cand_scores,
                dim=1,
                index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
                tokens[:, :step + 1],
                dim=0,
                index=active_bbsz_idx,
                out=tokens_buf[:, :step + 1],
            )
            torch.gather(
                cand_indices,
                dim=1,
                index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )
            if step > 0:
                torch.index_select(
                    scores[:, :step],
                    dim=0,
                    index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                )
            torch.gather(
                cand_scores,
                dim=1,
                index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
            )

            # copy attention for active hypotheses
            if attn is not None:
                torch.index_select(
                    attn[:, :, :step + 2],
                    dim=0,
                    index=active_bbsz_idx,
                    out=attn_buf[:, :, :step + 2],
                )

            # swap buffers
            tokens, tokens_buf = tokens_buf, tokens
            scores, scores_buf = scores_buf, scores
            if attn is not None:
                attn, attn_buf = attn_buf, attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent_id, _ in enumerate(finalized):
            finalized[sent_id] = sorted(finalized[sent_id],
                                        key=lambda r: r['score'],
                                        reverse=True)
        return finalized
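
A compact sketch of the active-hypothesis selection above: candidates that already ended in EOS (or are blacklisted) get pushed past cand_size by adding an offset, so torch.topk with largest=False picks the beam_size left-most still-active candidates per sentence. The shapes and mask contents are made up.

import torch

bsz, beam_size = 2, 3
cand_size = 2 * beam_size
cand_offsets = torch.arange(cand_size)

# pretend these candidates already ended in EOS
eos_mask = torch.tensor([[True, False, False, True, False, False],
                         [False, False, True, False, False, True]])

active_mask = eos_mask.long() * cand_size + cand_offsets
_, active_hypos = torch.topk(active_mask, k=beam_size, dim=1, largest=False)
print(active_hypos)   # tensor([[1, 2, 4], [0, 1, 3]]) -> first non-EOS candidates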
Example #28
0
def global_topk(input, k, largest):
    # https://stackoverflow.com/questions/64241325/top-k-indices-of-a-multi-dimensional-tensor
    v, i = th.topk(input.flatten(), k, largest=largest)
    return np.array(np.unravel_index(i.cpu().numpy(), input.shape)).T.tolist()
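
A short usage sketch for the helper above (its definition is repeated so the snippet runs on its own); the input values are arbitrary.

import numpy as np
import torch as th

def global_topk(input, k, largest):   # same helper as above
    v, i = th.topk(input.flatten(), k, largest=largest)
    return np.array(np.unravel_index(i.cpu().numpy(), input.shape)).T.tolist()

x = th.tensor([[1., 9., 3.],
               [7., 2., 8.]])
print(global_topk(x, 2, largest=True))   # [[0, 1], [1, 2]] -> the 9. and the 8.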
Example #29
0
def main(args):
    data_path = args['dataset']['path']
    train_batch_size = args['model']['train_batch_size']
    val_batch_size = args['model']['val_batch_size']
    check_iter = args['model']['check_iter']
    model_save_path = args['model']['model_save_path']
    pretrained_model = args['model']['pretrained_model']
    compression_model = args['dataset']['grid_size'][2]
    grid_size = args['dataset']['grid_size']
    visibility = args['model']['visibility']
    pytorch_device = torch.device('cuda:0')
    if args['model']['polar']:
        fea_dim = 9
        circular_padding = True
    else:
        fea_dim = 7
        circular_padding = False

    #prepare miou fun
    unique_label = np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1
    unique_label_str = [SemKITTI_label_name[x] for x in unique_label + 1]

    #prepare model
    my_BEV_model = BEV_Unet(n_class=len(unique_label),
                            n_height=compression_model,
                            input_batch_norm=True,
                            dropout=0.5,
                            circular_padding=circular_padding,
                            use_vis_fea=visibility)
    my_model = ptBEVnet(my_BEV_model,
                        pt_model='pointnet',
                        grid_size=grid_size,
                        fea_dim=fea_dim,
                        max_pt_per_encode=256,
                        out_pt_fea_dim=512,
                        kernal_size=1,
                        pt_selection='random',
                        fea_compre=compression_model)
    if os.path.exists(model_save_path):
        my_model = load_pretrained_model(my_model, torch.load(model_save_path))
    elif os.path.exists(pretrained_model):
        my_model = load_pretrained_model(my_model,
                                         torch.load(pretrained_model))
    my_model.to(pytorch_device)

    optimizer = optim.Adam(my_model.parameters())
    loss_fn = panoptic_loss(
        center_loss_weight=args['model']['center_loss_weight'],
        offset_loss_weight=args['model']['offset_loss_weight'],
        center_loss=args['model']['center_loss'],
        offset_loss=args['model']['offset_loss'])

    #prepare dataset
    train_pt_dataset = SemKITTI(
        data_path + '/sequences/',
        imageset='train',
        return_ref=True,
        instance_pkl_path=args['dataset']['instance_pkl_path'])
    val_pt_dataset = SemKITTI(
        data_path + '/sequences/',
        imageset='val',
        return_ref=True,
        instance_pkl_path=args['dataset']['instance_pkl_path'])
    if args['model']['polar']:
        train_dataset = spherical_dataset(train_pt_dataset,
                                          args['dataset'],
                                          grid_size=grid_size,
                                          ignore_label=0,
                                          use_aug=True)
        val_dataset = spherical_dataset(val_pt_dataset,
                                        args['dataset'],
                                        grid_size=grid_size,
                                        ignore_label=0)
    else:
        train_dataset = voxel_dataset(train_pt_dataset,
                                      args['dataset'],
                                      grid_size=grid_size,
                                      ignore_label=0,
                                      use_aug=True)
        val_dataset = voxel_dataset(val_pt_dataset,
                                    args['dataset'],
                                    grid_size=grid_size,
                                    ignore_label=0)
    train_dataset_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=train_batch_size,
        collate_fn=collate_fn_BEV,
        shuffle=True,
        num_workers=4)
    val_dataset_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                                     batch_size=val_batch_size,
                                                     collate_fn=collate_fn_BEV,
                                                     shuffle=False,
                                                     num_workers=4)

    # training
    epoch = 0
    best_val_PQ = 0
    start_training = False
    my_model.train()
    global_iter = 0
    exce_counter = 0
    evaluator = PanopticEval(len(unique_label) + 1, None, [0], min_points=50)

    while epoch < args['model']['max_epoch']:
        pbar = tqdm(total=len(train_dataset_loader))
        for i_iter, (train_vox_fea, train_label_tensor, train_gt_center,
                     train_gt_offset, train_grid, _, _,
                     train_pt_fea) in enumerate(train_dataset_loader):
            # validation
            if global_iter % check_iter == 0:
                my_model.eval()
                evaluator.reset()
                with torch.no_grad():
                    for i_iter_val, (
                            val_vox_fea, val_vox_label, val_gt_center,
                            val_gt_offset, val_grid, val_pt_labels,
                            val_pt_ints,
                            val_pt_fea) in enumerate(val_dataset_loader):
                        val_vox_fea_ten = val_vox_fea.to(pytorch_device)
                        val_vox_label = SemKITTI2train(val_vox_label)
                        val_pt_fea_ten = [
                            torch.from_numpy(i).type(
                                torch.FloatTensor).to(pytorch_device)
                            for i in val_pt_fea
                        ]
                        val_grid_ten = [
                            torch.from_numpy(i[:, :2]).to(pytorch_device)
                            for i in val_grid
                        ]
                        val_label_tensor = val_vox_label.type(
                            torch.LongTensor).to(pytorch_device)
                        val_gt_center_tensor = val_gt_center.to(pytorch_device)
                        val_gt_offset_tensor = val_gt_offset.to(pytorch_device)

                        if visibility:
                            predict_labels, center, offset = my_model(
                                val_pt_fea_ten, val_grid_ten, val_vox_fea_ten)
                        else:
                            predict_labels, center, offset = my_model(
                                val_pt_fea_ten, val_grid_ten)

                        for count, i_val_grid in enumerate(val_grid):
                            # get foreground_mask
                            for_mask = torch.zeros(
                                1,
                                grid_size[0],
                                grid_size[1],
                                grid_size[2],
                                dtype=torch.bool).to(pytorch_device)
                            for_mask[0, val_grid[count][:, 0],
                                     val_grid[count][:, 1],
                                     val_grid[count][:, 2]] = True
                            # post processing
                            panoptic_labels, center_points = get_panoptic_segmentation(
                                torch.unsqueeze(predict_labels[count], 0),
                                torch.unsqueeze(center[count], 0),
                                torch.unsqueeze(offset[count], 0),
                                val_pt_dataset.thing_list,
                                threshold=args['model']['post_proc']['threshold'],
                                nms_kernel=args['model']['post_proc']['nms_kernel'],
                                top_k=args['model']['post_proc']['top_k'],
                                polar=circular_padding,
                                foreground_mask=for_mask)
                            panoptic_labels = panoptic_labels.cpu().detach(
                            ).numpy().astype(np.int32)
                            panoptic = panoptic_labels[0, val_grid[count][:,
                                                                          0],
                                                       val_grid[count][:, 1],
                                                       val_grid[count][:, 2]]
                            evaluator.addBatch(
                                panoptic & 0xFFFF, panoptic,
                                np.squeeze(val_pt_labels[count]),
                                np.squeeze(val_pt_ints[count]))
                        del val_vox_label, val_pt_fea_ten, val_label_tensor, val_grid_ten, val_gt_center, val_gt_center_tensor, val_gt_offset, val_gt_offset_tensor, predict_labels, center, offset, panoptic_labels, center_points
                my_model.train()
                class_PQ, class_SQ, class_RQ, class_all_PQ, class_all_SQ, class_all_RQ = evaluator.getPQ(
                )
                miou, ious = evaluator.getSemIoU()
                print('Validation per class PQ, SQ, RQ and IoU: ')
                for class_name, class_pq, class_sq, class_rq, class_iou in zip(
                        unique_label_str, class_all_PQ[1:], class_all_SQ[1:],
                        class_all_RQ[1:], ious[1:]):
                    print('%15s : %6.2f%%  %6.2f%%  %6.2f%%  %6.2f%%' %
                          (class_name, class_pq * 100, class_sq * 100,
                           class_rq * 100, class_iou * 100))
                # save model if performance is improved
                if best_val_PQ < class_PQ:
                    best_val_PQ = class_PQ
                    torch.save(my_model.state_dict(), model_save_path)
                print('Current val PQ is %.3f while the best val PQ is %.3f' %
                      (class_PQ * 100, best_val_PQ * 100))
                print('Current val miou is %.3f' % (miou * 100))

                if start_training:
                    sem_l, hm_l, os_l = np.mean(
                        loss_fn.lost_dict['semantic_loss']), np.mean(
                            loss_fn.lost_dict['heatmap_loss']), np.mean(
                                loss_fn.lost_dict['offset_loss'])
                    print(
                        'epoch %d iter %5d, loss: %.3f, semantic loss: %.3f, heatmap loss: %.3f, offset loss: %.3f\n'
                        % (epoch, i_iter, sem_l + hm_l + os_l, sem_l, hm_l,
                           os_l))
                print('%d exceptions encountered during last training\n' %
                      exce_counter)
                exce_counter = 0
                loss_fn.reset_loss_dict()

            # training
            try:
                train_vox_fea_ten = train_vox_fea.to(pytorch_device)
                train_label_tensor = SemKITTI2train(train_label_tensor)
                train_pt_fea_ten = [
                    torch.from_numpy(i).type(
                        torch.FloatTensor).to(pytorch_device)
                    for i in train_pt_fea
                ]
                train_grid_ten = [
                    torch.from_numpy(i[:, :2]).to(pytorch_device)
                    for i in train_grid
                ]
                train_label_tensor = train_label_tensor.type(
                    torch.LongTensor).to(pytorch_device)
                train_gt_center_tensor = train_gt_center.to(pytorch_device)
                train_gt_offset_tensor = train_gt_offset.to(pytorch_device)

                if args['model']['enable_SAP'] and epoch >= args['model'][
                        'SAP']['start_epoch']:
                    for fea in train_pt_fea_ten:
                        fea.requires_grad_()

                # forward
                if visibility:
                    sem_prediction, center, offset = my_model(
                        train_pt_fea_ten, train_grid_ten, train_vox_fea_ten)
                else:
                    sem_prediction, center, offset = my_model(
                        train_pt_fea_ten, train_grid_ten)
                # loss
                loss = loss_fn(sem_prediction, center, offset,
                               train_label_tensor, train_gt_center_tensor,
                               train_gt_offset_tensor)

                # self adversarial pruning
                if args['model']['enable_SAP'] and epoch >= args['model'][
                        'SAP']['start_epoch']:
                    loss.backward()
                    for i, fea in enumerate(train_pt_fea_ten):
                        fea_grad = torch.norm(fea.grad, dim=1)
                        top_k_grad, _ = torch.topk(
                            fea_grad,
                            int(args['model']['SAP']['rate'] *
                                fea_grad.shape[0]))
                        # delete high influential points
                        train_pt_fea_ten[i] = train_pt_fea_ten[i][
                            fea_grad < top_k_grad[-1]]
                        train_grid_ten[i] = train_grid_ten[i][
                            fea_grad < top_k_grad[-1]]
                    optimizer.zero_grad()

                    # second pass
                    # forward
                    if visibility:
                        sem_prediction, center, offset = my_model(
                            train_pt_fea_ten, train_grid_ten,
                            train_vox_fea_ten)
                    else:
                        sem_prediction, center, offset = my_model(
                            train_pt_fea_ten, train_grid_ten)
                    # loss
                    loss = loss_fn(sem_prediction, center, offset,
                                   train_label_tensor, train_gt_center_tensor,
                                   train_gt_offset_tensor)

                # backward + optimize
                loss.backward()
                optimizer.step()
            except Exception as error:
                if exce_counter == 0:
                    print(error)
                exce_counter += 1

            # zero the parameter gradients
            optimizer.zero_grad()
            pbar.update(1)
            start_training = True
            global_iter += 1
        pbar.close()
        epoch += 1
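
A hedged, stand-alone sketch of the self-adversarial-pruning step in the training loop above: rank per-point gradient norms with torch.topk and drop the most influential points before the second forward pass. The toy loss, the feature size, and the 10% rate are illustrative only.

import torch

points = torch.randn(1000, 7, requires_grad=True)   # stand-in per-point features
loss = (points ** 2).sum()                           # stand-in loss
loss.backward()

fea_grad = torch.norm(points.grad, dim=1)            # influence of each point
rate = 0.1
top_k_grad, _ = torch.topk(fea_grad, int(rate * fea_grad.shape[0]))

# keep only points whose gradient norm is below the k-th largest value
kept = points.detach()[fea_grad < top_k_grad[-1]]
print(points.shape, '->', kept.shape)                # roughly 10% of the points removed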
Example #30
0
    def generate_beam(self, src_enc, src_len, tgt_lang_id, beam_size, length_penalty, early_stopping, max_len=200):
        """
        Decode a sentence given initial start.
        `x`:
            - LongTensor(bs, slen)
                <EOS> W1 W2 W3 <EOS> <PAD>
                <EOS> W1 W2 W3   W4  <EOS>
        `lengths`:
            - LongTensor(bs) [5, 6]
        `positions`:
            - False, for regular "arange" positions (LM)
            - True, to reset positions from the new generation (MT)
        `langs`:
            - must be None if the model only supports one language
            - lang_id if only one language is involved (LM)
            - (lang_id1, lang_id2) if two languages are involved (MT)
        """

        # check inputs
        assert src_enc.size(0) == src_len.size(0)
        assert beam_size >= 1

        # batch size / number of words
        bs = len(src_len)
        n_words = self.n_words

        # expand to beam size the source latent representations / source lengths
        src_enc = src_enc.unsqueeze(1).expand((bs, beam_size) + src_enc.shape[1:]).contiguous().view((bs * beam_size,) + src_enc.shape[1:])
        src_len = src_len.unsqueeze(1).expand(bs, beam_size).contiguous().view(-1)

        # generated sentences (batch with beam current hypotheses)
        generated = src_len.new(max_len, bs * beam_size)  # upcoming output
        generated.fill_(self.pad_index)                   # fill upcoming output with <PAD>
        generated[0].fill_(self.eos_index)                # we use <EOS> for <BOS> everywhere

        # generated hypotheses
        generated_hyps = [BeamHypotheses(beam_size, max_len, length_penalty, early_stopping) for _ in range(bs)]

        # positions
        positions = src_len.new(max_len).long()
        positions = torch.arange(max_len, out=positions).unsqueeze(1).expand_as(generated)

        # language IDs
        langs = positions.clone().fill_(tgt_lang_id)

        # scores for each sentence in the beam
        beam_scores = src_enc.new(bs, beam_size).fill_(0)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)

        # current position
        cur_len = 1

        # cache compute states
        cache = {'slen': 0}

        # done sentences
        done = [False for _ in range(bs)]

        while cur_len < max_len:

            # compute word scores
            tensor = self.forward(
                'fwd',
                x=generated[:cur_len],
                lengths=src_len.new(bs * beam_size).fill_(cur_len),
                positions=positions[:cur_len],
                langs=langs[:cur_len],
                causal=True,
                src_enc=src_enc,
                src_len=src_len,
                cache=cache
            )
            assert tensor.size() == (1, bs * beam_size, self.dim)
            tensor = tensor.data[-1, :, :]               # (bs * beam_size, dim)
            scores = self.pred_layer.get_scores(tensor)  # (bs * beam_size, n_words)
            scores = F.log_softmax(scores, dim=-1)       # (bs * beam_size, n_words)
            assert scores.size() == (bs * beam_size, n_words)

            # select next words with scores
            _scores = scores + beam_scores[:, None].expand_as(scores)  # (bs * beam_size, n_words)
            _scores = _scores.view(bs, beam_size * n_words)            # (bs, beam_size * n_words)

            next_scores, next_words = torch.topk(_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
            assert next_scores.size() == next_words.size() == (bs, 2 * beam_size)

            # next batch beam content
            # list of (bs * beam_size) tuple(next hypothesis score, next word, current position in the batch)
            next_batch_beam = []

            # for each sentence
            for sent_id in range(bs):

                # if we are done with this sentence
                done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item())
                if done[sent_id]:
                    next_batch_beam.extend([(0, self.pad_index, 0)] * beam_size)  # pad the batch
                    continue

                # next sentence beam content
                next_sent_beam = []

                # next words for this sentence
                for idx, value in zip(next_words[sent_id], next_scores[sent_id]):

                    # get beam and word IDs
                    beam_id = idx // n_words
                    word_id = idx % n_words

                    # end of sentence, or next word
                    if word_id == self.eos_index or cur_len + 1 == max_len:
                        generated_hyps[sent_id].add(generated[:cur_len, sent_id * beam_size + beam_id].clone(), value.item())
                    else:
                        next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))

                    # the beam for next step is full
                    if len(next_sent_beam) == beam_size:
                        break

                # update next beam content
                assert len(next_sent_beam) == 0 if cur_len + 1 == max_len else beam_size
                if len(next_sent_beam) == 0:
                    next_sent_beam = [(0, self.pad_index, 0)] * beam_size  # pad the batch
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == beam_size * (sent_id + 1)

            # sanity check / prepare next batch
            assert len(next_batch_beam) == bs * beam_size
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_words = generated.new([x[1] for x in next_batch_beam])
            beam_idx = src_len.new([x[2] for x in next_batch_beam])

            # re-order batch and internal states
            generated = generated[:, beam_idx]
            generated[cur_len] = beam_words
            for k in cache.keys():
                if k != 'slen':
                    cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])

            # update current length
            cur_len = cur_len + 1

            # stop when we are done with each sentence
            if all(done):
                break

        # visualize hypotheses
        # print([len(x) for x in generated_hyps], cur_len)
        # globals().update( locals() );
        # !import code; code.interact(local=vars())
        # for ii in range(bs):
        #     for ss, ww in sorted(generated_hyps[ii].hyp, key=lambda x: x[0], reverse=True):
        #         print("%.3f " % ss + " ".join(self.dico[x] for x in ww.tolist()))
        #     print("")

        # select the best hypotheses
        tgt_len = src_len.new(bs)
        best = []

        for i, hypotheses in enumerate(generated_hyps):
            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
            best.append(best_hyp)

        # generate target batch
        decoded = src_len.new(tgt_len.max().item(), bs).fill_(self.pad_index)
        for i, hypo in enumerate(best):
            decoded[:tgt_len[i] - 1, i] = hypo
            decoded[tgt_len[i] - 1, i] = self.eos_index

        # sanity check
        assert (decoded == self.eos_index).sum() == 2 * bs

        return decoded, tgt_len
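
The generated_hyps containers above rank finished candidates by a length-normalised score; assuming the usual convention (cumulative log-probability divided by length raised to length_penalty), a small worked comparison:

length_penalty = 1.0

# (cumulative log-prob, length in tokens) of two finished hypotheses
hyps = [(-4.2, 5), (-6.0, 9)]

normalised = [s / (l ** length_penalty) for s, l in hyps]
print(normalised)                 # [-0.84, -0.666...]: the longer hypothesis wins here
best = hyps[normalised.index(max(normalised))]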
Example #31
0
    def decode_beamsearch(self, input, u_len, w_len, decode_dict):
        """
        this method is meant to be used at inference time
            input = input to the encoder
            u_len = utterance lengths
            w_len = word lengths
            decode_dict:
                - k                = beamwidth for beamsearch
                - batch_size       = batch_size
                - time_step        = max_summary_length
                - vocab_size       = 30522 for BERT
                - device           = cpu or cuda
                - start_token_id   = ID of the start token
                - stop_token_id    = ID of the stop token
                - alpha            = length normalisation
                - length_offset    = length offset
                - keypadmask_dtype = torch.bool
        """
        k = decode_dict['k']
        search_method = decode_dict['search_method']
        batch_size = decode_dict['batch_size']
        time_step = decode_dict['time_step']
        vocab_size = decode_dict['vocab_size']
        device = decode_dict['device']
        start_token_id = decode_dict['start_token_id']
        stop_token_id = decode_dict['stop_token_id']
        alpha = decode_dict['alpha']
        penalty_ug = decode_dict['penalty_ug']
        keypadmask_dtype = decode_dict['keypadmask_dtype']

        # create beam array & scores
        beams = [None for _ in range(k)]
        beam_scores = np.zeros((batch_size, k))

        # we should only feed through the encoder just once!!
        enc_output_dict = self.encoder(input, u_len, w_len)  # memory
        u_output = enc_output_dict['u_output']

        # we run the decoder time_step times (auto-regressive)
        tgt_ids = torch.zeros((batch_size, time_step),
                              dtype=torch.int64).to(device)
        tgt_ids[:, 0] = start_token_id

        for i in range(k):
            beams[i] = tgt_ids

        finished_beams = [[] for _ in range(batch_size)]

        # initial hidden state
        ht = torch.zeros((self.decoder.num_layers, batch_size,
                          self.decoder.dec_hidden_size),
                         dtype=torch.float).to(self.device)
        for bn, l in enumerate(u_len):
            ht[:, bn, :] = u_output[bn, l - 1, :].unsqueeze(0)
        beam_ht = [None for _ in range(k)]
        for _k in range(k):
            beam_ht[_k] = ht.clone()

        finish = False

        # attn_scores_array = None

        for t in range(time_step - 1):
            if finish: break
            decoder_output_t_array = torch.zeros((batch_size, k * vocab_size))

            for i, beam in enumerate(beams):

                # inference decoding
                decoder_output, beam_ht[i], attn_scores = self.decoder.forward_step(
                    beam[:, t:t + 1],
                    beam_ht[i],
                    enc_output_dict,
                    logsoftmax=True)

                # if attn_scores_array == None:
                #     enc_pos = attn_scores.size(-1)
                #     attn_scores_array = torch.zeros((k, time_step, enc_pos)) # BATCH_SIZE must be 1
                #
                # attn_scores_array[i,t,:] = attn_scores[0,0,:]

                # print("t = {}: attn_scores = {}".format(t , attn_scores))
                # import pdb; pdb.set_trace()

                # check if there is STOP_TOKEN emitted in the previous time step already
                # i.e. if the input at this time step is STOP_TOKEN
                for n_idx in range(batch_size):
                    if beam[n_idx][t] == stop_token_id:  # already stop
                        decoder_output[n_idx, :] = float('-inf')
                        decoder_output[n_idx, stop_token_id] = 0.0  # ensure STOP_TOKEN is picked again

                decoder_output_t_array[:, i * vocab_size:(i + 1) *
                                       vocab_size] = decoder_output

                # add previous beam score bias
                for n_idx in range(batch_size):
                    decoder_output_t_array[n_idx, i * vocab_size:(i + 1) *
                                           vocab_size] += beam_scores[n_idx, i]

                    if search_method == 'argmax':
                        # Penalty term for repeated uni-gram
                        unigram_dict = {}
                        for tt in range(t + 1):
                            v = beam[n_idx, tt].cpu().numpy().item()
                            if v not in unigram_dict: unigram_dict[v] = 1
                            else: unigram_dict[v] += 1
                        for vocab_id, vocab_count in unigram_dict.items():
                            decoder_output_t_array[
                                n_idx, (i * vocab_size) +
                                vocab_id] -= penalty_ug * vocab_count / (t + 1)

                # only support batch_size = 1!
                if t == 0:
                    decoder_output_t_array[n_idx, (i + 1) *
                                           vocab_size:] = float('-inf')
                    break

            if search_method == 'sampling':
                # Sampling
                scores = np.zeros((batch_size, k))
                indices = np.zeros((batch_size, k))
                pmf = np.exp(decoder_output_t_array.cpu().numpy())
                for bi in range(batch_size):
                    if pmf[bi].sum() != 1.0:
                        pmf[bi] /= pmf[bi].sum()
                    sampled_ids = np.random.choice(k * vocab_size,
                                                   size=k,
                                                   p=pmf[bi])
                    for _s, s_id in enumerate(sampled_ids):
                        scores[bi, _s] = decoder_output_t_array[bi, s_id]
                        indices[bi, _s] = s_id

            elif search_method == 'argmax':
                # Argmax
                topk_scores, topk_ids = torch.topk(decoder_output_t_array,
                                                   k,
                                                   dim=-1)
                scores = topk_scores.double().cpu().numpy()
                indices = topk_ids.double().cpu().numpy()

            new_beams = [
                torch.zeros((batch_size, time_step),
                            dtype=torch.int64).to(device) for _ in range(k)
            ]
            for r_idx, row in enumerate(indices):
                for c_idx, node in enumerate(row):
                    vocab_idx = node % vocab_size
                    beam_idx = int(node / vocab_size)

                    new_beams[c_idx][r_idx, :t +
                                     1] = beams[beam_idx][r_idx, :t + 1]
                    new_beams[c_idx][r_idx, t + 1] = vocab_idx

                    # if there is a beam that has [END_TOKEN] --- store it
                    if vocab_idx == stop_token_id:
                        finished_beams[r_idx].append(
                            new_beams[c_idx][r_idx, :t + 1 + 1])
                        scores[r_idx, c_idx] = float('-inf')

            # only support BATCH SIZE = 1
            count_stop = 0
            for ik in range(k):
                if scores[0, ik] == float('-inf'): count_stop += 1
            if count_stop == k: finish = True

            beams = new_beams
            if search_method == 'sampling':
                # normalise the scores
                scores = np.exp(scores)
                scores = scores / scores.sum(axis=-1).reshape(batch_size, 1)
                beam_scores = np.log(scores +
                                     1e-20)  # suppress warning log(zero)
            elif search_method == 'argmax':
                beam_scores = scores

            # print("=========================  t = {} =========================".format(t))
            # for ik in range(k):
            # print("beam{}: [{:.5f}]".format(ik, scores[0,ik]),bert_tokenizer.decode(beams[ik][0].cpu().numpy()[:t+2]))
            # import pdb; pdb.set_trace()

            if (t % 50) == 0:
                print("{}=".format(t), end="")
                sys.stdout.flush()
        print("{}=#".format(t))

        for bi in range(batch_size):
            if len(finished_beams[bi]) == 0:
                finished_beams[bi].append(beams[0][bi])

        summaries_id = [None for _ in range(batch_size)]
        # for j in range(batch_size): summaries_id[j] = beams[0][j].cpu().numpy()
        for j in range(batch_size):
            _scores = self.beam_scoring(finished_beams[j], enc_output_dict,
                                        alpha)
            summaries_id[j] = finished_beams[j][np.argmax(
                _scores)].cpu().numpy()
            print(bert_tokenizer.decode(summaries_id[j]))

        return summaries_id
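A minimal sketch of how decode_beamsearch might be driven at inference time. Every name below (model, input_batch, u_len, w_len) is an assumption, the special-token ids depend on how the summariser was trained, and the beam loop above only supports batch_size = 1.

import torch

decode_dict = {
    'k': 4,                          # beam width
    'search_method': 'argmax',       # or 'sampling'
    'batch_size': 1,                 # the loop above assumes a single example
    'time_step': 144,                # maximum summary length
    'vocab_size': 30522,             # BERT vocabulary size
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'start_token_id': 101,           # e.g. [CLS] for BERT; depends on training
    'stop_token_id': 102,            # e.g. [SEP] for BERT; depends on training
    'alpha': 1.25,                   # length normalisation
    'penalty_ug': 0.0,               # uni-gram repetition penalty (argmax search)
    'length_offset': 5,              # documented above; not read in this snippet
    'keypadmask_dtype': torch.bool,
}
with torch.no_grad():
    summaries_id = model.decode_beamsearch(input_batch, u_len, w_len, decode_dict)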
Example #32
0
def get_prediction(image_bytes):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tensor = data_transforms(image_bytes).view(1, 3, 224, 224).to(device)
    outputs = model.forward(tensor)
    prob, y_hat = torch.topk(outputs, k=5)
    return y_hat.tolist(), prob.tolist()
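If model returns raw logits, the values returned above are unnormalised scores. A small variant, assuming the same hypothetical model and data_transforms, applies a softmax before torch.topk so the returned values are probabilities:

import torch
import torch.nn.functional as F

def get_prediction_with_probs(image_bytes, k=5):
    # assumes the same hypothetical `model` and `data_transforms` as above
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tensor = data_transforms(image_bytes).view(1, 3, 224, 224).to(device)
    with torch.no_grad():
        probs = F.softmax(model(tensor), dim=1)   # logits -> probabilities
    top_prob, top_idx = torch.topk(probs, k=k)    # top-k classes and their probabilities
    return top_idx.tolist(), top_prob.tolist()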
Example #33
0
    def forward(self, x):
        res, ind = torch.topk(x, 3)
        return torch.sigmoid(res), ind
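For reference, torch.topk defaults to the last dimension and returns values and indices with the same leading shape, so the module above maps a (batch, n) input to two (batch, 3) outputs:

import torch

x = torch.randn(2, 10)        # batch of 2 rows with 10 scores each
res, ind = torch.topk(x, 3)   # what the forward above computes
print(res.shape, ind.shape)   # torch.Size([2, 3]) torch.Size([2, 3])
print(torch.sigmoid(res))     # the module returns sigmoid(values) and the indices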
Example #34
0
    def beam_search(self, src_sent: List[str], beam_size: int = 5,
                    max_decoding_time_step: int = 70) -> List[Hypothesis]:
        """ Given a single source sentence, perform beam search, yielding translations in the target language.
        @param src_sent (List[str]): a single source sentence (words)
        @param beam_size (int): beam size
        @param max_decoding_time_step (int): maximum number of time steps to unroll the decoding RNN
        @returns hypotheses (List[Hypothesis]): a list of hypotheses; each hypothesis has two fields:
                value: List[str]: the decoded target sentence, represented as a list of words
                score: float: the log-likelihood of the target sentence
        """
        src_sents_var = self.vocab.src.to_input_tensor([src_sent], self.device)

        src_encodings, dec_init_vec = self.encode(src_sents_var, [len(src_sent)])
        src_encodings_att_linear = self.att_projection(src_encodings)

        h_tm1 = dec_init_vec
        att_tm1 = torch.zeros(1, self.hidden_size, device=self.device)

        eos_id = self.vocab.tgt['</s>']

        hypotheses = [['<s>']]
        hyp_scores = torch.zeros(len(hypotheses), dtype=torch.float, device=self.device)
        completed_hypotheses = []

        t = 0
        while len(completed_hypotheses) < beam_size and t < max_decoding_time_step:
            t += 1
            hyp_num = len(hypotheses)

            exp_src_encodings = src_encodings.expand(hyp_num,
                                                     src_encodings.size(1),
                                                     src_encodings.size(2))

            exp_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num,
                                                                           src_encodings_att_linear.size(1),
                                                                           src_encodings_att_linear.size(2))

            y_tm1 = torch.tensor([self.vocab.tgt[hyp[-1]] for hyp in hypotheses], dtype=torch.long, device=self.device)
            y_t_embed = self.model_embeddings.target(y_tm1)

            x = torch.cat([y_t_embed, att_tm1], dim=-1)

            (h_t, cell_t), att_t, _ = self.step(x, h_tm1,
                                                exp_src_encodings, exp_src_encodings_att_linear, enc_masks=None)

            # log probabilities over target words
            log_p_t = F.log_softmax(self.target_vocab_projection(att_t), dim=-1)

            live_hyp_num = beam_size - len(completed_hypotheses)
            continuing_hyp_scores = (hyp_scores.unsqueeze(1).expand_as(log_p_t) + log_p_t).view(-1)
            top_cand_hyp_scores, top_cand_hyp_pos = torch.topk(continuing_hyp_scores, k=live_hyp_num)

            # floor division keeps the indices integral (true division would return floats)
            prev_hyp_ids = top_cand_hyp_pos // len(self.vocab.tgt)
            hyp_word_ids = top_cand_hyp_pos % len(self.vocab.tgt)

            new_hypotheses = []
            live_hyp_ids = []
            new_hyp_scores = []

            for prev_hyp_id, hyp_word_id, cand_new_hyp_score in zip(prev_hyp_ids, hyp_word_ids, top_cand_hyp_scores):
                prev_hyp_id = prev_hyp_id.item()
                hyp_word_id = hyp_word_id.item()
                cand_new_hyp_score = cand_new_hyp_score.item()

                hyp_word = self.vocab.tgt.id2word[hyp_word_id]
                new_hyp_sent = hypotheses[prev_hyp_id] + [hyp_word]
                if hyp_word == '</s>':
                    completed_hypotheses.append(Hypothesis(value=new_hyp_sent[1:-1],
                                                           score=cand_new_hyp_score))
                else:
                    new_hypotheses.append(new_hyp_sent)
                    live_hyp_ids.append(prev_hyp_id)
                    new_hyp_scores.append(cand_new_hyp_score)

            if len(completed_hypotheses) == beam_size:
                break

            live_hyp_ids = torch.tensor(live_hyp_ids, dtype=torch.long, device=self.device)
            h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
            att_tm1 = att_t[live_hyp_ids]

            hypotheses = new_hypotheses
            hyp_scores = torch.tensor(new_hyp_scores, dtype=torch.float, device=self.device)

        if len(completed_hypotheses) == 0:
            completed_hypotheses.append(Hypothesis(value=hypotheses[0][1:],
                                                   score=hyp_scores[0].item()))

        completed_hypotheses.sort(key=lambda hyp: hyp.score, reverse=True)

        return completed_hypotheses
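A typical call, assuming a trained model of this class and a whitespace-tokenised source sentence (nmt_model and the sentence are hypothetical):

src_sent = "show me the shortest route".split()
hyps = nmt_model.beam_search(src_sent, beam_size=5, max_decoding_time_step=70)
best = hyps[0]                              # hypotheses are sorted by score, best first
print(' '.join(best.value), best.score)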
Example #35
0
    def sample(self,
               ingr_features,
               ingr_mask,
               greedy=True,
               temperature=1.0,
               beam=-1,
               img_features=None,
               first_token_value=0,
               replacement=True,
               last_token_value=0,
               device='cpu'):

        incremental_state = {}

        # create dummy previous word
        if ingr_features is not None:
            fs = ingr_features.size(0)
        else:
            fs = img_features.size(0)

        if beam != -1:
            if fs == 1:
                return self.sample_beam(ingr_features, ingr_mask, beam,
                                        img_features, first_token_value,
                                        replacement, last_token_value)
            else:
                print(
                    "Beam Search can only be used with batch size of 1. Running greedy or temperature sampling..."
                )

        first_word = torch.ones(fs) * first_token_value

        first_word = first_word.to(device).long()
        sampled_ids = [first_word]
        logits = []

        for i in range(self.seq_length):
            # forward
            outputs, _ = self.forward(ingr_features,
                                      ingr_mask,
                                      torch.stack(sampled_ids, 1),
                                      img_features,
                                      incremental_state,
                                      device=device)
            outputs = outputs.squeeze(1)
            if not replacement:
                # predicted mask
                if i == 0:
                    predicted_mask = torch.zeros(
                        outputs.shape).float().to(device)
                else:
                    # ensure no repetitions in sampling if replacement==False
                    batch_ind = [
                        j for j in range(fs) if sampled_ids[i][j] != 0
                    ]
                    sampled_ids_new = sampled_ids[i][batch_ind]
                    predicted_mask[batch_ind, sampled_ids_new] = float('-inf')

                # mask previously selected ids
                outputs += predicted_mask

            logits.append(outputs)
            if greedy:
                outputs_prob = torch.nn.functional.softmax(outputs, dim=-1)
                _, predicted = outputs_prob.max(1)
                predicted = predicted.detach()
            else:
                k = 10
                outputs_prob = torch.div(outputs.squeeze(1), temperature)
                outputs_prob = torch.nn.functional.softmax(outputs_prob,
                                                           dim=-1).data

                # top k random sampling
                prob_prev_topk, indices = torch.topk(outputs_prob, k=k, dim=1)
                predicted = torch.multinomial(prob_prev_topk, 1).view(-1)
                predicted = torch.index_select(indices, dim=1,
                                               index=predicted)[:, 0].detach()

            sampled_ids.append(predicted)

        sampled_ids = torch.stack(sampled_ids[1:], 1)
        logits = torch.stack(logits, 1)

        return sampled_ids, logits
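The non-greedy branch above is top-k temperature sampling. Isolated as a standalone helper (a sketch, not this repository's API), the step looks like this:

import torch
import torch.nn.functional as F

def topk_temperature_sample(logits, k=10, temperature=1.0):
    """Sample one token id per row from the k most likely entries of logits."""
    probs = F.softmax(logits / temperature, dim=-1)
    topk_probs, topk_idx = torch.topk(probs, k=k, dim=-1)  # restrict to the top k
    choice = torch.multinomial(topk_probs, 1)               # sample within the top k
    return topk_idx.gather(-1, choice).squeeze(-1)          # map back to vocabulary ids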
Example #36
0
    def beam_search(self, batch):
        # batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            self.train.model.get_input_from_batch(batch)

        encoder_outputs, encoder_hidden, max_encoder_output = self.train.model.encoder(enc_batch, enc_lens)
        s_t_0 = self.train.model.reduce_state(encoder_hidden)

        if config.use_maxpool_init_ctx:
            c_t_0 = max_encoder_output

        dec_h, dec_c = s_t_0  # (1, 2*H)
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # decoder batch preparation, it has beam_size example initially everything is repeated
        beams = [Beam(tokens=[self.dataset.vocab.word2id(opt.BOS)],
                      log_probs=[0.0],
                      state=(dec_h[0], dec_c[0]),
                      context=c_t_0[0],
                      coverage=(coverage_t_0[0] if self.args.is_coverage else None))
                 for _ in range(self.args.beam_size)]
        results = []
        steps = 0
        while steps < self.args.max_decoder_steps and len(results) < self.args.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.dataset.vocab.vocab_size else self.dataset.vocab.word2id(opt.UNKNOWN_TOKEN)
                             for t in latest_tokens]
            y_t_1 = torch.LongTensor(latest_tokens).to(opt.device)
            all_state_h = []
            all_state_c = []
            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0), torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if self.args.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.train.model.decoder(y_t_1, s_t_1,
                                                                                          encoder_outputs,
                                                                                          enc_padding_mask,
                                                                                          c_t_1,
                                                                                          extra_zeros,
                                                                                          enc_batch_extend_vocab,
                                                                                          coverage_t_1)

            topk_log_probs, topk_ids = torch.topk(final_dist, self.args.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if self.args.is_coverage else None)

                for j in range(self.args.beam_size * 2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].data.item(),
                                        log_prob=topk_log_probs[i, j].data.item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.dataset.vocab.word2id(opt.EOS):
                    if steps >= self.args.min_decoder_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == self.args.beam_size or len(results) == self.args.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)
        return beams_sorted[0]
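The Beam objects used above are not defined in this snippet. A minimal stand-in consistent with how they are used here (token and log-prob lists, latest_token, extend) might look like the sketch below; the avg_log_prob property is an assumption about what sort_beams ranks on.

class Beam(object):
    """Minimal hypothesis container; an assumption about the interface used above."""
    def __init__(self, tokens, log_probs, state, context, coverage):
        self.tokens = tokens          # token ids decoded so far
        self.log_probs = log_probs    # per-step log probabilities
        self.state = state            # decoder (h, c) for this hypothesis
        self.context = context        # attention context vector
        self.coverage = coverage      # coverage vector (or None)

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        return sum(self.log_probs) / len(self.tokens)

    def extend(self, token, log_prob, state, context, coverage):
        return Beam(self.tokens + [token], self.log_probs + [log_prob],
                    state, context, coverage)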
Example #37
0
    def generate(self,
                 encoded,
                 lang_id,
                 max_len=200,
                 sample=False,
                 temperature=None):
        """
        Generate a sentence from a given initial state.
        Input:
            - FloatTensor of size (batch_size, hidden_dim) representing
              sentences encoded in the latent space
        Output:
            - LongTensor of size (seq_len, batch_size), word indices
            - LongTensor of size (batch_size,), sentence lengths
        """
        if self.beam_size > 0:
            return self.generate_beam(encoded, lang_id, self.beam_size,
                                      max_len, sample, temperature)

        encoder_out = encoded.dec_input
        latent = encoder_out['encoder_out']

        x_len = encoded.input_len
        is_cuda = latent.is_cuda
        one_hot = None

        # check inputs
        assert type(lang_id) is int
        assert latent.size() == (x_len.max(), x_len.size(0), self.emb_dim)
        assert (sample is True) ^ (temperature is None)

        # initialize generated sentences batch
        slen, bs = latent.size(0), latent.size(1)
        assert x_len.max() == slen and x_len.size(0) == bs
        cur_len = 1
        decoded = torch.LongTensor(max_len, bs).fill_(self.pad_index)
        unfinished_sents = torch.LongTensor(bs).fill_(1)
        lengths = torch.LongTensor(bs).fill_(1)
        if is_cuda:
            decoded = decoded.cuda()
            unfinished_sents = unfinished_sents.cuda()
            lengths = lengths.cuda()
        decoded[0] = self.bos_index[lang_id]

        incremental_state = {}
        while cur_len < max_len:

            # previous word embeddings
            scores = self.forward(encoded,
                                  decoded[:cur_len],
                                  lang_id,
                                  one_hot,
                                  incremental_state=None)
            scores = scores.data[-1, :, :]  # T x B x V -> B x V

            # select next words: sample or one-hot
            if sample:
                next_words = torch.multinomial((scores / temperature).exp(),
                                               1).squeeze(1)
            else:
                next_words = torch.topk(scores, 1)[1].squeeze(1)
            assert next_words.size() == (bs, )
            decoded[cur_len] = (next_words * unfinished_sents +
                                self.pad_index * (1 - unfinished_sents))
            lengths.add_(unfinished_sents)
            unfinished_sents.mul_(next_words.ne(self.eos_index).long())
            cur_len += 1

            # stop when there is a </s> in each sentence
            if unfinished_sents.max() == 0:
                break

        if cur_len == max_len:
            decoded[max_len - 1].masked_fill_(unfinished_sents.bool(),
                                              self.eos_index)
        assert (decoded == self.eos_index).sum() == bs

        # if lang_id == 0:
        #     print(lang_id, decoded[0], decoded[-1])
        #     input('Decoder generate')
        return decoded[:cur_len], lengths, one_hot
Example #38
0
def train_model(model,
                datasetloader_dict,
                dataset_dict,
                loss_function,
                optimizer,
                num_epochs=50):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())  # snapshot the weights (requires the copy module), not live references
    best_acc = 0.0

    for epoch in range(num_epochs):
        logging.info('Epoch {}/{}'.format(epoch, num_epochs - 1))
        logging.info('-' * 10)

        for phase in ['train', 'val']:

            if phase == 'train':
                model.train(True)
            else:
                model.train(False)

            running_loss = 0.0
            running_corrects = 0
            total = 0

            for batch_index, batch_datums in enumerate(
                    datasetloader_dict[phase]):
                optimizer.zero_grad()

                ### forward (compute loss) ###
                # change this block Model Wise
                model.hidden = model.init_hidden()
                if use_gpu:
                    actual_label_tensor = Variable(
                        batch_datums['answer_index'].cuda(),
                        requires_grad=False).view(-1)
                    token_sequence_tensor = Variable(
                        batch_datums['question_token_ids'].cuda(),
                        requires_grad=False)
                    answer_vector_tensor = Variable(
                        batch_datums['answer_vector_tensor'].cuda(),
                        requires_grad=False)
                else:
                    actual_label_tensor = Variable(
                        batch_datums['answer_index'],
                        requires_grad=False).view(-1)
                    token_sequence_tensor = Variable(
                        batch_datums['question_token_ids'], requires_grad=False)
                    answer_vector_tensor = Variable(
                        batch_datums['answer_vector_tensor'],
                        requires_grad=False)
                prediction_scores_tensor = model.forward(token_sequence_tensor)

                loss = loss_function(prediction_scores_tensor,
                                     actual_label_tensor)
                loss_value = loss.item()

                ### Statistics ###
                topK = 3
                _, predictions = torch.max(prediction_scores_tensor.data, 1)
                prediction = predictions.view(-1)
                answer_vector = answer_vector_tensor.data.view(-1)
                top_predicted_indices = torch.topk(
                    prediction_scores_tensor.data, topK, 1)[1][0].tolist()
                correctness = max([
                    int(answer_vector[prediction])
                    for prediction in top_predicted_indices
                ])
                running_corrects += correctness
                running_loss += loss_value

                ### backward + optimize (if in training phase) ###
                if phase == 'train' and epoch > 0:
                    loss.backward()
                    optimizer.step()

                ### some debug logs ###
                if total % 1000 == 0:
                    logging.info('{} datums processed'.format(int(total)))

                total += len(next(iter(batch_datums.values())))  # batch size; dict_keys is not subscriptable in Python 3

            epoch_loss = running_loss / float(total)
            epoch_acc = running_corrects / float(total)

            logging.info('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    logging.info('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    logging.info('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
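The statistics block above counts a datum as correct when the true answer appears among the top-3 scores. The same check as a small standalone helper (a sketch; the names are hypothetical):

import torch

def topk_correct(scores, answer_vector, k=3):
    """Return 1 if any of the top-k predicted indices is an acceptable answer, else 0.

    scores:        (1, num_classes) prediction scores for a single datum
    answer_vector: (num_classes,)   multi-hot vector of acceptable answers
    """
    top_indices = torch.topk(scores, k, dim=1)[1][0].tolist()
    return max(int(answer_vector[i]) for i in top_indices)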
Example #39
0
    def _viterbi_decode_nbest(self, feats, mask, nbest):
        """
            input:
                feats: (batch, seq_len, self.tag_size+2)
                mask: (batch, seq_len)
            output:
                decode_idx: (batch, seq_len, nbest) decoded sequences
                path_score: (batch, nbest) corresponding score for each sequence (to be implemented)
                nbest decoding for sentences with a single token is not well supported; to be optimized
        """
        batch_size = feats.size(0)
        seq_len = feats.size(1)
        tag_size = feats.size(2)
        assert (tag_size == self.tagset_size + 2)
        ## calculate sentence length for each sentence
        length_mask = torch.sum(mask.long(), dim=1).view(batch_size, 1).long()
        ## mask to (seq_len, batch_size)
        mask = mask.transpose(1, 0).contiguous()
        ins_num = seq_len * batch_size
        ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1)
        feats = feats.transpose(1, 0).contiguous().view(
            ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)
        ## need to consider start
        scores = feats + self.transitions.view(1, tag_size, tag_size).expand(
            ins_num, tag_size, tag_size)
        scores = scores.view(seq_len, batch_size, tag_size, tag_size)

        # build iter
        seq_iter = enumerate(scores)
        ## record the position of best score
        back_points = list()
        partition_history = list()
        ##  reverse mask (bug for mask = 1- mask, use this as alternative choice)
        # mask = 1 + (-1)*mask
        mask = (1 - mask.long()).byte()
        _, inivalues = next(
            seq_iter)  # bat_size * from_target_size * to_target_size
        # only need start from start_tag
        partition = inivalues[:, START_TAG, :].clone(
        )  # bat_size * to_target_size
        ## initial partition [batch_size, tag_size]
        partition_history.append(
            partition.view(batch_size, tag_size,
                           1).expand(batch_size, tag_size, nbest))
        # iter over last scores
        for idx, cur_values in seq_iter:
            if idx == 1:
                cur_values = cur_values.view(
                    batch_size, tag_size,
                    tag_size) + partition.contiguous().view(
                        batch_size, tag_size, 1).expand(
                            batch_size, tag_size, tag_size)
            else:
                # previous to_target is current from_target
                # partition: previous results log(exp(from_target)), #(batch_size * nbest * from_target)
                # cur_values: batch_size * from_target * to_target
                cur_values = cur_values.view(
                    batch_size, tag_size, 1, tag_size).expand(
                        batch_size, tag_size, nbest,
                        tag_size) + partition.contiguous().view(
                            batch_size, tag_size, nbest, 1).expand(
                                batch_size, tag_size, nbest, tag_size)
                ## compare all nbest and all from target
                cur_values = cur_values.view(batch_size, tag_size * nbest,
                                             tag_size)
                # print "cur size:",cur_values.size()
            partition, cur_bp = torch.topk(cur_values, nbest, 1)
            ## cur_bp/partition: [batch_size, nbest, tag_size]; ids should be normalised by nbest in the following backtrace step
            # print partition[:,0,:]
            # print cur_bp[:,0,:]
            # print "nbest, ",idx
            if idx == 1:
                cur_bp = cur_bp * nbest
            partition = partition.transpose(2, 1)
            cur_bp = cur_bp.transpose(2, 1)

            # print partition
            # exit(0)
            #partition: (batch_size * to_target * nbest)
            #cur_bp: (batch_size * to_target * nbest) Notice the cur_bp number is the whole position of tag_size*nbest, need to convert when decode
            partition_history.append(partition)
            ## cur_bp: (batch_size,nbest, tag_size) topn source score position in current tag
            ## set padded label as 0, which will be filtered in post processing
            ## mask[idx] ? mask[idx-1]
            cur_bp.masked_fill_(
                mask[idx].view(batch_size, 1,
                               1).expand(batch_size, tag_size, nbest), 0)
            # print cur_bp[0]
            back_points.append(cur_bp)
        ### add score to final STOP_TAG
        partition_history = torch.cat(partition_history, 0).view(
            seq_len, batch_size, tag_size, nbest).transpose(
                1, 0).contiguous()  ## (batch_size, seq_len, nbest, tag_size)
        ### get the last position for each sentence, and select the last partitions using gather()
        last_position = length_mask.view(batch_size, 1, 1, 1).expand(
            batch_size, 1, tag_size, nbest) - 1
        last_partition = torch.gather(partition_history, 1,
                                      last_position).view(
                                          batch_size, tag_size, nbest, 1)
        ### calculate the score from last partition to end state (and then select the STOP_TAG from it)
        last_values = last_partition.expand(
            batch_size, tag_size, nbest, tag_size) + self.transitions.view(
                1, tag_size, 1, tag_size).expand(batch_size, tag_size, nbest,
                                                 tag_size)
        last_values = last_values.view(batch_size, tag_size * nbest, tag_size)
        end_partition, end_bp = torch.topk(last_values, nbest, 1)
        ## end_partition: (batch, nbest, tag_size)
        end_bp = end_bp.transpose(2, 1)
        # end_bp: (batch, tag_size, nbest)
        pad_zero = autograd.Variable(torch.zeros(batch_size, tag_size,
                                                 nbest)).long()
        if self.gpu:
            pad_zero = pad_zero.cuda()
        back_points.append(pad_zero)
        back_points = torch.cat(back_points).view(seq_len, batch_size,
                                                  tag_size, nbest)

        ## select end ids in STOP_TAG
        pointer = end_bp[:, STOP_TAG, :]  ## (batch_size, nbest)
        insert_last = pointer.contiguous().view(
            batch_size, 1, 1, nbest).expand(batch_size, 1, tag_size, nbest)
        back_points = back_points.transpose(1, 0).contiguous()
        ## move the end ids(expand to tag_size) to the corresponding position of back_points to replace the 0 values
        # print "lp:",last_position
        # print "il:",insert_last[0]
        # exit(0)
        ## copy the ids of last position:insert_last to back_points, through the last_position index
        ## last_position includes the length of batch sentences
        # print "old:", back_points[9,0,:,:]
        back_points.scatter_(1, last_position, insert_last)
        ## back_points: [batch_size, seq_length, tag_size, nbest]
        # print "new:", back_points[9,0,:,:]
        # exit(0)
        # print pointer[2]
        '''
        back_points: in simple demonstration
        x,x,x,x,x,x,x,x,x,7
        x,x,x,x,x,4,0,0,0,0
        x,x,6,0,0,0,0,0,0,0
        '''

        back_points = back_points.transpose(1, 0).contiguous()
        # print back_points[0]
        ## back_points: (seq_len, batch, tag_size, nbest)
        ## decode from the end, padded position ids are 0, which will be filtered in following evaluation
        decode_idx = autograd.Variable(
            torch.LongTensor(seq_len, batch_size, nbest))
        if self.gpu:
            decode_idx = decode_idx.cuda()
        decode_idx[-1] = pointer.data // nbest  # floor division: flat (tag, nbest) index -> tag index
        # print "pointer-1:",pointer[2]
        # exit(0)
        # use old mask, let 0 means has token
        for idx in range(len(back_points) - 2, -1, -1):
            # print "pointer: ",idx,  pointer[3]
            # print "back:",back_points[idx][3]
            # print "mask:",mask[idx+1,3]
            new_pointer = torch.gather(
                back_points[idx].view(batch_size, tag_size * nbest), 1,
                pointer.contiguous().view(batch_size, nbest))
            decode_idx[idx] = new_pointer.data // nbest
            # # use new pointer to remember the last end nbest ids for non longest
            pointer = new_pointer + pointer.contiguous().view(
                batch_size, nbest) * mask[idx].view(batch_size, 1).expand(
                    batch_size, nbest).long()

        # exit(0)
        path_score = None
        decode_idx = decode_idx.transpose(1, 0)
        ## decode_idx: [batch, seq_len, nbest]
        # print decode_idx[:,:,0]
        # print "nbest:",nbest
        # print "diff:", decode_idx[:,:,0]- decode_idx[:,:,4]
        # print decode_idx[:,0,:]
        # exit(0)

        ### calculate probability for each sequence
        scores = end_partition[:, :, STOP_TAG]
        ## scores: [batch_size, nbest]
        max_scores, _ = torch.max(scores, 1)
        minus_scores = scores - max_scores.view(batch_size, 1).expand(
            batch_size, nbest)
        path_score = F.softmax(minus_scores, 1)
        ## path_score: [batch_size, nbest]
        # exit(0)
        return path_score, decode_idx
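Since decode_idx comes back as (batch, seq_len, nbest), the 1-best tag sequence is the first nbest column. A hypothetical usage (crf, feats and mask are assumptions):

path_score, decode_idx = crf._viterbi_decode_nbest(feats, mask, nbest=5)
best_paths = decode_idx[:, :, 0]    # (batch, seq_len) 1-best tag ids
best_scores = path_score[:, 0]      # (batch,) softmax-normalised score of the 1-best path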
Example #40
0
    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
        bsz, srclen = src_tokens.size()
        maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen

        # the max beam size is the dictionary size - 1, since we never select pad
        beam_size = beam_size if beam_size is not None else self.beam_size
        beam_size = min(beam_size, self.vocab_size - 1)

        encoder_outs = []
        incremental_states = {}
        for model in self.models:
            if not self.retain_dropout:
                model.eval()
            if isinstance(model.decoder, FairseqIncrementalDecoder):
                incremental_states[model] = {}
            else:
                incremental_states[model] = None

            # compute the encoder output for each beam
            encoder_out = model.encoder(
                src_tokens.repeat(1, beam_size).view(-1, srclen),
                src_lengths.expand(beam_size, src_lengths.numel()).t().contiguous().view(-1),
            )
            encoder_outs.append(encoder_out)

        # initialize buffers
        scores = src_tokens.data.new(bsz * beam_size, maxlen + 1).float().fill_(0)
        scores_buf = scores.clone()
        tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = self.eos
        attn, attn_buf = None, None
        nonpad_idxs = None

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        worst_finalized = [{'idx': None, 'score': -math.inf} for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfinalized_scores=None):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                if self.stop_early or step == maxlen or unfinalized_scores is None:
                    return True
                # stop if the best unfinalized score is worse than the worst
                # finalized one
                best_unfinalized_score = unfinalized_scores[sent].max()
                if self.normalize_scores:
                    best_unfinalized_score /= maxlen ** self.len_penalty
                if worst_finalized[sent]['score'] >= best_unfinalized_score:
                    return True
            return False

        def finalize_hypos(step, bbsz_idx, eos_scores, unfinalized_scores=None):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.
            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.
            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
                unfinalized_scores: A vector containing scores for all
                    unfinalized hypotheses
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step + 2]  # skip the first index, which is EOS
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2] if attn is not None else None

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (step + 1) ** self.len_penalty

            cum_unfin = []
            prev = 0
            for f in finished:
                if f:
                    prev += 1
                else:
                    cum_unfin.append(prev)

            sents_seen = set()
            for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())):
                unfin_idx = idx // beam_size
                sent = unfin_idx + cum_unfin[unfin_idx]

                sents_seen.add((sent, unfin_idx))

                def get_hypo():

                    if attn_clone is not None:
                        # remove padding tokens from attn scores
                        hypo_attn = attn_clone[i][nonpad_idxs[sent]]
                        _, alignment = hypo_attn.max(dim=0)
                    else:
                        hypo_attn = None
                        alignment = None

                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
                        'attention': hypo_attn,  # src_len x tgt_len
                        'alignment': alignment,
                        'positional_scores': pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())
                elif not self.stop_early and score > worst_finalized[sent]['score']:
                    # replace worst hypo for this sentence with new/better one
                    worst_idx = worst_finalized[sent]['idx']
                    if worst_idx is not None:
                        finalized[sent][worst_idx] = get_hypo()

                    # find new worst finalized hypo for this sentence
                    idx, s = min(enumerate(finalized[sent]), key=lambda r: r[1]['score'])
                    worst_finalized[sent] = {
                        'score': s['score'],
                        'idx': idx,
                    }

            newly_finished = []
            for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step, unfinalized_scores):
                    finished[sent] = True
                    newly_finished.append(unfin_idx)
            return newly_finished

        reorder_state = None
        batch_idxs = None
        for step in range(maxlen + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
                    reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
                for i, model in enumerate(self.models):
                    if isinstance(model.decoder, FairseqIncrementalDecoder):
                        model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
                    encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)

            lprobs, avg_attn_scores = self._decode(tokens[:, :step + 1], encoder_outs, incremental_states)

            lprobs[:, self.pad] = -math.inf  # never select pad
            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            # Record attention scores
            if avg_attn_scores is not None:
                if attn is None:
                    attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
                    attn_buf = attn.clone()
                    nonpad_idxs = src_tokens.ne(self.pad)
                attn[:, :, step + 1].copy_(avg_attn_scores)

            scores = scores.type_as(lprobs)
            scores_buf = scores_buf.type_as(lprobs)
            eos_bbsz_idx = buffer('eos_bbsz_idx')
            eos_scores = buffer('eos_scores', type_of=scores)
            if step < maxlen:
                if prefix_tokens is not None and step < prefix_tokens.size(1):
                    probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
                    cand_scores = torch.gather(
                        probs_slice, dim=1,
                        index=prefix_tokens[:, step].view(-1, 1).data
                    ).expand(-1, cand_size)
                    cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data
                    cand_beams = torch.zeros_like(cand_indices)
                else:
                    cand_scores, cand_indices, cand_beams = self.search.step(
                        step,
                        lprobs.view(bsz, -1, self.vocab_size),
                        scores.view(bsz, beam_size, -1)[:, :, :step],
                    )
            else:
                # make probs contain cumulative scores for each hypothesis
                lprobs.add_(scores[:, step - 1].unsqueeze(-1))

                # finalize all active hypotheses once we hit maxlen
                # pick the hypothesis with the highest prob of EOS right now
                torch.sort(
                    lprobs[:, self.eos],
                    descending=True,
                    out=(eos_scores, eos_bbsz_idx),
                )
                num_remaining_sent -= len(finalize_hypos(
                    step, eos_bbsz_idx, eos_scores))
                assert num_remaining_sent == 0
                break

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos
            eos_mask = cand_indices.eq(self.eos)

            finalized_sents = set()
            if step >= self.minlen:
                # only consider eos when it's among the top beam_size indices
                torch.masked_select(
                    cand_bbsz_idx[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                    out=eos_bbsz_idx,
                )
                if eos_bbsz_idx.numel() > 0:
                    torch.masked_select(
                        cand_scores[:, :beam_size],
                        mask=eos_mask[:, :beam_size],
                        out=eos_scores,
                    )
                    finalized_sents = finalize_hypos(
                        step, eos_bbsz_idx, eos_scores, cand_scores)
                    num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            assert step < maxlen

            if len(finalized_sents) > 0:
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = cand_indices.new_ones(bsz)
                batch_mask[cand_indices.new(finalized_sents)] = 0
                batch_idxs = batch_mask.nonzero().squeeze(-1)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)

                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]
                if prefix_tokens is not None:
                    prefix_tokens = prefix_tokens[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                scores_buf.resize_as_(scores)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                tokens_buf.resize_as_(tokens)
                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
                    attn_buf.resize_as_(attn)
                bsz = new_bsz
            else:
                batch_idxs = None

            # set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After, the min values per row are the top candidate active hypos
            active_mask = buffer('active_mask')
            torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[:eos_mask.size(1)],
                out=active_mask,
            )

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, _ignore = buffer('active_hypos'), buffer('_ignore')
            torch.topk(
                active_mask, k=beam_size, dim=1, largest=False,
                out=(_ignore, active_hypos)
            )

            active_bbsz_idx = buffer('active_bbsz_idx')
            torch.gather(
                cand_bbsz_idx, dim=1, index=active_hypos,
                out=active_bbsz_idx,
            )
            active_scores = torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
                tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
                out=tokens_buf[:, :step + 1],
            )
            torch.gather(
                cand_indices, dim=1, index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )
            if step > 0:
                torch.index_select(
                    scores[:, :step], dim=0, index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                )
            torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
            )

            # copy attention for active hypotheses
            if attn is not None:
                torch.index_select(
                    attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
                    out=attn_buf[:, :, :step + 2],
                )

            # swap buffers
            tokens, tokens_buf = tokens_buf, tokens
            scores, scores_buf = scores_buf, scores
            if attn is not None:
                attn, attn_buf = attn_buf, attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)

        return finalized
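The bookkeeping above flattens (sentence, beam) pairs into a single row index using bbsz_offsets. A tiny illustration of that indexing convention (not fairseq code, just the arithmetic):

import torch

bsz, beam_size = 3, 4
bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1)   # [[0], [4], [8]]

# a beam index per sentence (values in [0, beam_size)) ...
cand_beams = torch.tensor([[1, 3], [0, 2], [2, 1]])
# ... becomes a flat row index in [0, bsz * beam_size)
cand_bbsz_idx = cand_beams + bbsz_offsets
print(cand_bbsz_idx)   # tensor([[ 1,  3], [ 4,  6], [10,  9]])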
Example #41
0
    def _viterbi_decode_nbest(self, feats, mask, nbest):
        """
            input:
                feats: (batch, seq_len, self.tag_size+2)
                mask: (batch, seq_len)
            output:
                decode_idx: (batch, nbest, seq_len) decoded sequence
                path_score: (batch, nbest) corresponding score for each sequence (to be implementated)
                nbest decode for sentence with one token is not well supported, to be optimized
        """
        batch_size = feats.size(0)
        seq_len = feats.size(1)
        tag_size = feats.size(2)
        assert(tag_size == self.tagset_size+2)
        ## calculate sentence length for each sentence
        length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long()
        ## mask to (seq_len, batch_size)
        mask = mask.transpose(1,0).contiguous()
        ins_num = seq_len * batch_size
        ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1)
        feats = feats.transpose(1,0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)
        ## need to consider start
        scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size)
        scores = scores.view(seq_len, batch_size, tag_size, tag_size)

        # build iter
        seq_iter = enumerate(scores)
        ## record the position of best score
        back_points = list()
        partition_history = list()
        ##  reverse mask (bug for mask = 1- mask, use this as alternative choice)
        # mask = 1 + (-1)*mask
        mask =  (1 - mask.long()).byte()
        _, inivalues = next(seq_iter)  # bat_size * from_target_size * to_target_size
        # only need start from start_tag
        partition = inivalues[:, START_TAG, :].clone()  # bat_size * to_target_size
        ## initial partition [batch_size, tag_size]
        partition_history.append(partition.view(batch_size, tag_size, 1).expand(batch_size, tag_size, nbest))
        # iter over last scores
        for idx, cur_values in seq_iter:
            if idx == 1:
                cur_values = cur_values.view(batch_size, tag_size, tag_size) + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
            else:
                # previous to_target is current from_target
                # partition: previous results log(exp(from_target)), #(batch_size * nbest * from_target)
                # cur_values: batch_size * from_target * to_target
                cur_values = cur_values.view(batch_size, tag_size, 1, tag_size).expand(batch_size, tag_size, nbest, tag_size) + partition.contiguous().view(batch_size, tag_size, nbest, 1).expand(batch_size, tag_size, nbest, tag_size)
                ## compare all nbest and all from target
                cur_values = cur_values.view(batch_size, tag_size*nbest, tag_size)
                # print "cur size:",cur_values.size()
            partition, cur_bp = torch.topk(cur_values, nbest, 1)
            ## cur_bp/partition: [batch_size, nbest, tag_size], id should be normize through nbest in following backtrace step
            # print partition[:,0,:]
            # print cur_bp[:,0,:]
            # print "nbest, ",idx
            if idx == 1:
                cur_bp = cur_bp*nbest
            partition = partition.transpose(2,1)
            cur_bp = cur_bp.transpose(2,1)

            # print partition
            # exit(0)
            #partition: (batch_size * to_target * nbest)
            #cur_bp: (batch_size * to_target * nbest) Notice the cur_bp number is the whole position of tag_size*nbest, need to convert when decode
            partition_history.append(partition)
            ## cur_bp: (batch_size,nbest, tag_size) topn source score position in current tag
            ## set padded label as 0, which will be filtered in post processing
            ## mask[idx] ? mask[idx-1]
            cur_bp.masked_fill_(mask[idx].view(batch_size, 1, 1).expand(batch_size, tag_size, nbest), 0)
            # print cur_bp[0]
            back_points.append(cur_bp)
        ### add score to final STOP_TAG
        partition_history = torch.cat(partition_history,0).view(seq_len, batch_size, tag_size, nbest).transpose(1,0).contiguous() ## (batch_size, seq_len, nbest, tag_size)
        ### get the last position for each setences, and select the last partitions using gather()
        last_position = length_mask.view(batch_size,1,1,1).expand(batch_size, 1, tag_size, nbest) - 1
        last_partition = torch.gather(partition_history, 1, last_position).view(batch_size, tag_size, nbest, 1)
        ### calculate the score from last partition to end state (and then select the STOP_TAG from it)
        last_values = last_partition.expand(batch_size, tag_size, nbest, tag_size) + self.transitions.view(1, tag_size, 1, tag_size).expand(batch_size, tag_size, nbest, tag_size)
        last_values = last_values.view(batch_size, tag_size*nbest, tag_size)
        end_partition, end_bp = torch.topk(last_values, nbest, 1)
        ## end_partition: (batch, nbest, tag_size)
        end_bp = end_bp.transpose(2,1)
        # end_bp: (batch, tag_size, nbest)
        pad_zero = autograd.Variable(torch.zeros(batch_size, tag_size, nbest)).long()
        if self.gpu:
            pad_zero = pad_zero.cuda()
        back_points.append(pad_zero)
        back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size, nbest)

        ## select end ids in STOP_TAG
        pointer = end_bp[:, STOP_TAG, :] ## (batch_size, nbest)
        insert_last = pointer.contiguous().view(batch_size, 1, 1, nbest).expand(batch_size, 1, tag_size, nbest)
        back_points = back_points.transpose(1,0).contiguous()
        ## move the end ids(expand to tag_size) to the corresponding position of back_points to replace the 0 values
        # print "lp:",last_position
        # print "il:",insert_last[0]
        # exit(0)
        ## copy the ids of the last position (insert_last) to back_points, through the last_position index
        ## last_position includes the length of batch sentences
        # print "old:", back_points[9,0,:,:]
        back_points.scatter_(1, last_position, insert_last)
        ## back_points: [batch_size, seq_length, tag_size, nbest]
        # print "new:", back_points[9,0,:,:]
        # exit(0)
        # print pointer[2]
        '''
        back_points: a simple demonstration
        x,x,x,x,x,x,x,x,x,7
        x,x,x,x,x,4,0,0,0,0
        x,x,6,0,0,0,0,0,0,0
        '''

        back_points = back_points.transpose(1,0).contiguous()
        # print back_points[0]
        ## back_points: (seq_len, batch, tag_size, nbest)
        ## decode from the end, padded position ids are 0, which will be filtered in following evaluation
        decode_idx = autograd.Variable(torch.LongTensor(seq_len, batch_size, nbest))
        if self.gpu:
            decode_idx = decode_idx.cuda()
        decode_idx[-1] = pointer.data // nbest  # integer division recovers the tag id from the flattened tag*nbest index
        # print "pointer-1:",pointer[2]
        # exit(0)
        # use the old mask, where 0 means the position holds a token
        for idx in range(len(back_points)-2, -1, -1):
            # print "pointer: ",idx,  pointer[3]
            # print "back:",back_points[idx][3]
            # print "mask:",mask[idx+1,3]
            new_pointer = torch.gather(back_points[idx].view(batch_size, tag_size*nbest), 1, pointer.contiguous().view(batch_size,nbest))
            decode_idx[idx] = new_pointer.data // nbest
            # # use new pointer to remember the last end nbest ids for non longest
            pointer = new_pointer + pointer.contiguous().view(batch_size,nbest)*mask[idx].view(batch_size,1).expand(batch_size, nbest).long()

        # exit(0)
        path_score = None
        decode_idx = decode_idx.transpose(1,0)
        ## decode_idx: [batch, seq_len, nbest]
        # print decode_idx[:,:,0]
        # print "nbest:",nbest
        # print "diff:", decode_idx[:,:,0]- decode_idx[:,:,4]
        # print decode_idx[:,0,:]
        # exit(0)

        ### calculate probability for each sequence
        scores = end_partition[:, :, STOP_TAG]
        ## scores: [batch_size, nbest]
        max_scores,_ = torch.max(scores, 1)
        minus_scores = scores - max_scores.view(batch_size,1).expand(batch_size, nbest)
        path_score = F.softmax(minus_scores, 1)
        ## path_score: [batch_size, nbest]
        # exit(0)
        return path_score, decode_idx
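A minimal, self-contained sketch of the per-step top-k used in the decode above; the batch/tag/nbest sizes are illustrative and not taken from the class:
import torch

batch_size, tag_size, nbest = 2, 5, 3
# scores of reaching each (from_tag, nbest) state, extended over all target tags
cur_values = torch.randn(batch_size, tag_size * nbest, tag_size)
# keep the nbest best source states for every target tag
partition, cur_bp = torch.topk(cur_values, nbest, dim=1)
# partition/cur_bp: (batch_size, nbest, tag_size); cur_bp indexes the flattened
# tag_size*nbest axis, so tag = cur_bp // nbest and nbest rank = cur_bp % nbest
print(partition.shape, cur_bp.shape)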
Example #42
0
                                                       k=5)
print("Top 1 accuracy on test set is", top1_acc)

# Get the confusion matrix from test
confusion_matrix = {x: [0, 0, 0, 0, 0] for x in class_name}
# confusion_matrix = {x: [0,0,0,0,0,0,0] for x in class_name}   for smear dataset

running_top1_correct = 0
loader = dataloaders['test']
labels_array = []
pred_array = []
for inputs, labels in tqdm(loader):
    inputs = inputs.to(device)
    labels = labels.to(device)
    outputs = model(inputs)
    _, preds = torch.topk(outputs, k=1, dim=1)
    for i in range(len(labels)):
        original_label = int(labels[i])
        labels_array.append(original_label)
        pred_array.append(int(preds[i]))
        confusion_matrix[class_name[original_label]][int(preds[i])] += 1

    running_top1_correct += torch.sum(preds[:, 0] == labels.data)

precision, recall, fscore, support = score(labels_array, pred_array)

epoch_top1_acc = float(running_top1_correct.double() / len(loader.dataset))
percentage = {
    x: [y / sum(confusion_matrix[x]) for y in confusion_matrix[x]]
    for x in confusion_matrix.keys()
}
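A hedged, self-contained sketch of the top-1/top-5 accuracy pattern this example relies on; the logits and labels below are random stand-ins:
import torch

outputs = torch.randn(8, 10)         # (batch, num_classes), hypothetical logits
labels = torch.randint(0, 10, (8,))  # (batch,)

_, top5 = torch.topk(outputs, k=5, dim=1)                        # (batch, 5)
top1_acc = (top5[:, 0] == labels).float().mean()
top5_acc = (top5 == labels.view(-1, 1)).any(dim=1).float().mean()
print(float(top1_acc), float(top5_acc))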
Example #43
0
    def beam_search(self, src_sent: List[str], beam_size: int=5, max_decoding_time_step: int=70) -> List[Hypothesis]:
        """
        Given a single source sentence, perform beam search

        Args:
            src_sent: a single tokenized source sentence
            beam_size: beam size
            max_decoding_time_step: maximum number of time steps to unroll the decoding RNN

        Returns:
            hypotheses: a list of hypothesis, each hypothesis has two fields:
                value: List[str]: the decoded target sentence, represented as a list of words
                score: float: the log-likelihood of the target sentence
        """

        src_sents_var = self.vocab.src.to_input_tensor([src_sent], self.device)

        src_encodings, dec_init_vec = self.encode(src_sents_var, [len(src_sent)])
        src_encodings_att_linear = self.att_src_linear(src_encodings)

        h_tm1 = dec_init_vec
        att_tm1 = torch.zeros(1, self.hidden_size, device=self.device)

        eos_id = self.vocab.tgt['</s>']

        hypotheses = [['<s>']]
        hyp_scores = torch.zeros(len(hypotheses), dtype=torch.float, device=self.device)
        completed_hypotheses = []

        t = 0
        while len(completed_hypotheses) < beam_size and t < max_decoding_time_step:
            t += 1
            hyp_num = len(hypotheses)

            exp_src_encodings = src_encodings.expand(hyp_num,
                                                     src_encodings.size(1),
                                                     src_encodings.size(2))

            exp_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num,
                                                                           src_encodings_att_linear.size(1),
                                                                           src_encodings_att_linear.size(2))

            y_tm1 = torch.tensor([self.vocab.tgt[hyp[-1]] for hyp in hypotheses], dtype=torch.long, device=self.device)
            y_tm1_embed = self.tgt_embed(y_tm1)

            if self.input_feed:
                x = torch.cat([y_tm1_embed, att_tm1], dim=-1)
            else:
                x = y_tm1_embed

            (h_t, cell_t), att_t, alpha_t = self.step(x, h_tm1,
                                                      exp_src_encodings, exp_src_encodings_att_linear, src_sent_masks=None)

            # log probabilities over target words
            log_p_t = F.log_softmax(self.readout(att_t), dim=-1)

            live_hyp_num = beam_size - len(completed_hypotheses)
            continuing_hyp_scores = (hyp_scores.unsqueeze(1).expand_as(log_p_t) + log_p_t).view(-1)
            top_cand_hyp_scores, top_cand_hyp_pos = torch.topk(continuing_hyp_scores, k=live_hyp_num)

            prev_hyp_ids = top_cand_hyp_pos // len(self.vocab.tgt)  # integer division: index of the hypothesis being extended
            hyp_word_ids = top_cand_hyp_pos % len(self.vocab.tgt)

            new_hypotheses = []
            live_hyp_ids = []
            new_hyp_scores = []

            for prev_hyp_id, hyp_word_id, cand_new_hyp_score in zip(prev_hyp_ids, hyp_word_ids, top_cand_hyp_scores):
                prev_hyp_id = prev_hyp_id.item()
                hyp_word_id = hyp_word_id.item()
                cand_new_hyp_score = cand_new_hyp_score.item()

                hyp_word = self.vocab.tgt.id2word[hyp_word_id]
                new_hyp_sent = hypotheses[prev_hyp_id] + [hyp_word]
                if hyp_word == '</s>':
                    completed_hypotheses.append(Hypothesis(value=new_hyp_sent[1:-1],
                                                           score=cand_new_hyp_score))
                else:
                    new_hypotheses.append(new_hyp_sent)
                    live_hyp_ids.append(prev_hyp_id)
                    new_hyp_scores.append(cand_new_hyp_score)

            if len(completed_hypotheses) == beam_size:
                break

            live_hyp_ids = torch.tensor(live_hyp_ids, dtype=torch.long, device=self.device)
            h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
            att_tm1 = att_t[live_hyp_ids]

            hypotheses = new_hypotheses
            hyp_scores = torch.tensor(new_hyp_scores, dtype=torch.float, device=self.device)

        if len(completed_hypotheses) == 0:
            completed_hypotheses.append(Hypothesis(value=hypotheses[0][1:],
                                                   score=hyp_scores[0].item()))

        completed_hypotheses.sort(key=lambda hyp: hyp.score, reverse=True)

        return completed_hypotheses
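A small sketch, with made-up sizes, of the flattened top-k plus index recovery used inside the search loop above:
import torch

beam_size, vocab_size = 4, 100
hyp_scores = torch.randn(beam_size)                                 # running hypothesis scores
log_p_t = torch.log_softmax(torch.randn(beam_size, vocab_size), dim=-1)

cont_scores = (hyp_scores.unsqueeze(1) + log_p_t).view(-1)
top_scores, top_pos = torch.topk(cont_scores, k=beam_size)
prev_hyp_ids = top_pos // vocab_size   # which hypothesis each candidate extends
word_ids = top_pos % vocab_size        # which word it appends
print(prev_hyp_ids.tolist(), word_ids.tolist())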
Example #44
0
    def generate(self, src_enc, src_len, tgt_lang_id, max_len=200, sample_temperature=None):
        """
        Decode a sentence given initial start.
        `x`:
            - LongTensor(bs, slen)
                <EOS> W1 W2 W3 <EOS> <PAD>
                <EOS> W1 W2 W3   W4  <EOS>
        `lengths`:
            - LongTensor(bs) [5, 6]
        `positions`:
            - False, for regular "arange" positions (LM)
            - True, to reset positions from the new generation (MT)
        `langs`:
            - must be None if the model only supports one language
            - lang_id if only one language is involved (LM)
            - (lang_id1, lang_id2) if two languages are involved (MT)
        """

        # input batch
        bs = len(src_len)
        assert src_enc.size(0) == bs

        # generated sentences
        generated = src_len.new(max_len, bs)  # upcoming output
        generated.fill_(self.pad_index)       # fill upcoming output with <PAD>
        generated[0].fill_(self.eos_index)    # we use <EOS> for <BOS> everywhere

        # positions
        positions = src_len.new(max_len).long()
        positions = torch.arange(max_len, out=positions).unsqueeze(1).expand(max_len, bs)

        # language IDs
        langs = src_len.new(max_len).long().fill_(tgt_lang_id)
        langs = langs.unsqueeze(1).expand(max_len, bs)

        # current position / max lengths / length of generated sentences / unfinished sentences
        cur_len = 1
        gen_len = src_len.clone().fill_(1)
        unfinished_sents = src_len.clone().fill_(1)

        # cache compute states
        cache = {'slen': 0}

        while cur_len < max_len:

            # compute word scores
            tensor = self.forward(
                'fwd',
                x=generated[:cur_len],
                lengths=gen_len,
                positions=positions[:cur_len],
                langs=langs[:cur_len],
                causal=True,
                src_enc=src_enc,
                src_len=src_len,
                cache=cache
            )
            assert tensor.size() == (1, bs, self.dim), (cur_len, max_len, src_enc.size(), tensor.size(), (1, bs, self.dim))
            tensor = tensor.data[-1, :, :].type_as(src_enc)  # (bs, dim)
            scores = self.pred_layer.get_scores(tensor)      # (bs, n_words)

            # select next words: sample or greedy
            if sample_temperature is None:
                if self.mask_gen_lang is True:
                    next_words = torch.topk(scores, self.mask_topk)[1].squeeze(1)
                else:
                    next_words = torch.topk(scores, 1)[1].squeeze(1)
            else:
                if self.mask_gen_lang is True:
                    next_words = torch.multinomial(F.softmax(scores / sample_temperature, dim=1), self.mask_topk).squeeze(1)
                else:
                    next_words = torch.multinomial(F.softmax(scores / sample_temperature, dim=1), 1).squeeze(1)

            if self.mask_gen_lang is True:
                tmp_next_words = torch.zeros(bs, dtype=torch.long)
                for j, next_word in enumerate(next_words.cpu()):
                    has_tgt_id = False
                    for i, wi in enumerate(next_word):
                        if language_detect(self.dico.id2word[wi.item()], self.id2lang[tgt_lang_id]):
                            has_tgt_id = True
                            tmp_next_words[j] = wi
                            break
                    if has_tgt_id is False:
                        tmp_next_words[j] = next_words[j, 0]
                next_words = tmp_next_words.cuda()

            assert next_words.size() == (bs,)

            # update generations / lengths / finished sentences / current length
            generated[cur_len] = next_words * unfinished_sents + self.pad_index * (1 - unfinished_sents)
            gen_len.add_(unfinished_sents)
            unfinished_sents.mul_(next_words.ne(self.eos_index).long())
            cur_len = cur_len + 1

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
            if unfinished_sents.max() == 0:
                break

        # add <EOS> to unfinished sentences
        if cur_len == max_len:
            generated[-1].masked_fill_(unfinished_sents.byte(), self.eos_index)

        # sanity check
        assert (generated == self.eos_index).sum() == 2 * bs

        return generated[:cur_len], gen_len
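A hedged sketch of the two next-word selection modes used above (greedy top-1 vs. temperature sampling); the score tensor below is a random stand-in for the prediction layer's output:
import torch
import torch.nn.functional as F

scores = torch.randn(3, 50)  # (bs, n_words), hypothetical word scores
# greedy: highest-scoring word per sentence
greedy_words = torch.topk(scores, 1)[1].squeeze(1)            # (bs,)
# sampling: draw from the tempered softmax distribution
sample_temperature = 0.8
sampled_words = torch.multinomial(F.softmax(scores / sample_temperature, dim=1), 1).squeeze(1)
print(greedy_words.shape, sampled_words.shape)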
Example #45
0
    def translate(self, source: torch.tensor, beam_size=5, max_len=70):
        """single source-target sentence translate.
        :param source (1, src_len): one source sentence to be translated.
        :param beam_size : beam search size
        
        :return translate_result (list[int]): most likely candidate
        """
        source = source.t()
        src_len = torch.tensor([source.size(0)],
                               dtype=torch.int,
                               device=self.device)
        src_embedded = self.src_embedding(source)
        memory, init_dec_state, enc_atten_vec = self.encoder(
            src_embedded, src_len)

        # ----beam search----
        # search stop when number of completed sentences reaches beam_size, or get to max_len.
        completed = []
        last_out = torch.zeros(self.hidden_size,
                               dtype=torch.float,
                               device=self.device)
        candidate = [
            Pair([self.vocab.tgt.get_eos_info(0)], 0, init_dec_state, last_out)
        ]
        step = 0
        while len(completed) < beam_size and step < max_len:
            step += 1
            # generate batch input from candidate sentences
            last_outs = torch.stack([item.last_out for item in candidate],
                                    dim=0)  # (num_candidate, hidden_size)
            words = torch.tensor([item.sent[-1] for item in candidate],
                                 dtype=torch.long,
                                 device=self.device)
            words_embed = self.tgt_embedding(
                words)  # (num_candidate, embed_size)
            x_in = torch.cat(
                (words_embed, last_outs),
                dim=1).unsqueeze(0)  # (1, num_candidate, embed_size)
            last_hidden_state = torch.cat(
                [item.last_state[0] for item in candidate],
                dim=1)  # (num_layers, num_candidate, hidden_size)
            last_cell_state = torch.cat(
                [item.last_state[1] for item in candidate],
                dim=1)  # (num_layers, num_candidate, hidden_size)

            # expand memory, enc_atten_vec and src_len to match the batch size.
            batch_size = len(candidate)
            memory_expand = memory.expand(-1, batch_size, -1)
            enc_atten_vec_expand = enc_atten_vec.expand(batch_size, -1, -1)
            src_len_expand = src_len.expand(batch_size)

            # predict the next words for all of the candidates
            output, dec_state = self.decoder(
                x_in, (last_hidden_state, last_cell_state), memory_expand,
                enc_atten_vec_expand, src_len_expand)
            prob = F.log_softmax(self.proj_vocab(output),
                                 dim=-1)  # (num_candidate, vocab_size)
            scores = prob + torch.tensor([item.score for item in candidate],
                                         dtype=torch.float,
                                         device=self.device).view(-1, 1)

            # select the top 'beam_size' predicts and update candidates
            tops = torch.topk(scores.view(-1), beam_size)
            candidate_new = []
            for score, idx in zip(*tops):
                cand_id = idx.item() // len(self.vocab.tgt)
                vocab_id = idx.item() % len(self.vocab.tgt)
                if vocab_id == self.vocab.tgt.get_eos_info(1):
                    # the completed sentence doesn't include </s> token
                    completed.append(
                        Pair(candidate[cand_id].sent, score, None, None))
                else:
                    last_state = (dec_state[0][:, cand_id, :].unsqueeze(1),
                                  dec_state[1][:, cand_id, :].unsqueeze(1))
                    candidate_new.append(
                        Pair(candidate[cand_id].sent + [vocab_id], score,
                             last_state, output[cand_id]))
            candidate = candidate_new

        if len(completed) < beam_size:
            completed.extend(candidate[:beam_size - len(completed)])
        ave_scores = [item.score / (len(item.sent) - 1) for item in completed]
        translate_result = completed[np.argmax(ave_scores)].sent[1:]

        return translate_result
Example #46
0
def hard_mining(neg_output, neg_labels, num_hard):
    _, idcs = torch.topk(neg_output, min(num_hard, len(neg_output)))
    neg_output = torch.index_select(neg_output, 0, idcs)
    neg_labels = torch.index_select(neg_labels, 0, idcs)
    return neg_output, neg_labels
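A usage sketch of the hard negative mining pattern above, with random tensors standing in for the model's negative-sample scores and labels:
import torch

neg_output = torch.randn(100)   # hypothetical scores of negative samples
neg_labels = torch.zeros(100)
num_hard = 10
# keep only the hardest (highest-scoring) negatives
_, idcs = torch.topk(neg_output, min(num_hard, len(neg_output)))
hard_output = torch.index_select(neg_output, 0, idcs)
hard_labels = torch.index_select(neg_labels, 0, idcs)
print(hard_output.shape, hard_labels.shape)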
Example #47
0
    def forward(self,
                text_embedder,
                input_ids: torch.Tensor,
                mc_token_ids: Optional[torch.Tensor],
                lm_labels: Optional[torch.Tensor],
                mc_labels: Optional[torch.Tensor],
                token_type_ids: Optional[torch.Tensor],
                mode='teacher'
                ) -> List[torch.Tensor]:

        if text_embedder is not None:
            self.text_embedder = text_embedder
            self.text = self.text_embedder.tokenizer.decode
            self.eos = self.text_embedder.eos_idx
            self.usr = self.text_embedder.usr_idx
            self.sys = self.text_embedder.sys_idx
            self.pad_idx = self.text_embedder.pad_idx

        if mode == 'teacher':
            lm_loss, mc_loss, lm_logits, mc_logits, pres = self.model(input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids)
            # lm_loss () mc_loss () lm_logits (4, 2, 300, 50270) mc_logits (4, 2) pres.__len__() 12
            return lm_loss, mc_loss, lm_logits, mc_logits, pres

        elif mode == 'meta-train-query':
            lm_loss, lm_logits, mc_logits, pres = self.model(input_ids, lm_labels=lm_labels, token_type_ids=token_type_ids)
            return lm_loss, None, None, None, None

        # Inference
        elif mode == 'infer':

            # check last_turn is finished at system turn
            # last turn system -> next token type ids ==> user
            # last turn user -> next token type ids ==> system
            self.bsys_last_turn = self.text_embedder.sys_idx == token_type_ids[-1][-1].item()
            b_finish_eos = False
            response_rear_m1 = lm_labels

            # inference start!
            infer_iter = 0
            sys_utt = []
            history_len = input_ids[0].shape[0]
            # origin ii, tti
            origin_input_ids = copy.deepcopy(input_ids)
            origin_token_type_ids = copy.deepcopy(token_type_ids)

            # Inference while loop
            while infer_iter <= self.max_length-1:
                input_ids, token_type_ids = self.build_input(input_ids, token_type_ids, sys_utt)

                lm_logits, mc_logits, _ = self.model(input_ids, token_type_ids=token_type_ids)

                logits = lm_logits # (1, 46, 50270)
                logits = logits[0, -1, :] / self.temperature # (50270)
                logits = self.top_filtering(logits)
                probs = torch.nn.functional.softmax(logits, dim=-1)
                prev = torch.topk(probs, 1)[1] if self.no_sample else torch.multinomial(probs, 1)

                if infer_iter < self.min_length and prev.item() == self.eos:
                    while prev.item() == self.eos:
                        prev = torch.multinomial(probs, num_samples=1)

                if prev.item() == self.eos:
                    sys_utt.append(prev.item())
                    b_finish_eos = True
                    break

                infer_iter += 1
                sys_utt.append(prev.item())
            # End of inference

            b_full = False
            if b_finish_eos:
                # add eos
                input_ids, token_type_ids = self.build_input(input_ids, token_type_ids, sys_utt)
            else:
                b_full = True

            # valid_lm_loss
            valid_lm_loss = None
            try:
                if response_rear_m1 is not None:
                    # For gaining validation loss!
                    # 1. make solution of input_ids, token_type_ids, lm_labels
                    # 2. forward and get lm_loss!

                    # solution_token_type_ids
                    if b_full:
                        gt_token_type_ids_final = token_type_ids[0, :self.max_length].unsqueeze(0)
                    else:
                        gt_response_no_m1 = response_rear_m1[response_rear_m1 != -1] # [1, 2, 3, -1, -1, -1]-> [1, 2, 3]
                        sys_token_type_num = gt_response_no_m1.shape[0] # 4

                        assert(sys_token_type_num >= 0)
                        if self.bsys_last_turn:
                            gt_resp_token_types = torch.ones((sys_token_type_num), device='cuda', dtype=torch.int64) * self.text_embedder.usr_idx
                        else:
                            gt_resp_token_types = torch.ones((sys_token_type_num), device='cuda', dtype=torch.int64) * self.text_embedder.sys_idx
                        gt_resp_token_types = gt_resp_token_types.unsqueeze(0)

                        gt_token_type_ids = torch.cat((origin_token_type_ids, gt_resp_token_types), dim=1) # generated
                        if gt_token_type_ids.shape[1] > self.max_length:
                            gt_token_type_ids = gt_token_type_ids[:, :self.max_length]
                        assert(gt_token_type_ids.shape[1] <= self.max_length)

                        token_type_pad_num = self.max_length - gt_token_type_ids.shape[1]
                        assert(token_type_pad_num >= 0)
                        token_type_pad = torch.ones((token_type_pad_num), device='cuda', dtype=torch.int64) * self.text_embedder.pad_idx
                        token_type_pad = token_type_pad.unsqueeze(0)
                        gt_token_type_ids_final = torch.cat((gt_token_type_ids, token_type_pad), dim=1)
                        if gt_token_type_ids_final.shape[1] > self.max_length:
                            gt_token_type_ids_final = gt_token_type_ids_final[:, :self.max_length]
                        assert(gt_token_type_ids_final.shape[1] <= self.max_length)

                    assert(gt_token_type_ids_final.shape[1] == self.max_length)

                    # gt_lm_labels
                    response_front_m1_num = self.max_length - response_rear_m1.shape[1]
                    assert(response_front_m1_num >= 0)
                    resp_front_m1 = torch.ones((response_front_m1_num), device='cuda', dtype=torch.int64) * -1
                    resp_front_m1 = resp_front_m1.unsqueeze(0)
                    gt_lm_labels = torch.cat((resp_front_m1, response_rear_m1), dim=1)
                    assert(gt_lm_labels.shape[1] == self.max_length)

                    # my_models' predict_input_ids
                    if b_full:
                        predict_input_ids = input_ids[:, :self.max_length]
                    else:
                        if input_ids.shape[1] > self.max_length:
                            input_ids = input_ids[:, :self.max_length]

                        input_ids_pad_num = self.max_length - input_ids.shape[1]

                        assert(input_ids_pad_num >= 0)
                        input_ids_pad = torch.ones((input_ids_pad_num), device='cuda', dtype=torch.int64) * self.text_embedder.pad_idx
                        input_ids_pad = input_ids_pad.unsqueeze(0)
                        predict_input_ids = torch.cat((input_ids, input_ids_pad), dim=1)
                    assert(predict_input_ids.shape[1] == self.max_length)

                    assert (predict_input_ids.shape == gt_token_type_ids_final.shape == gt_lm_labels.shape)
                    valid_lm_loss, lm_logits, mc_logits, pres = self.model(predict_input_ids, token_type_ids=gt_token_type_ids_final, lm_labels=gt_lm_labels)
            except:
                import ipdb; ipdb.set_trace()

            # sentence
            sentence = self.text_embedder.tokenizer.decode(sys_utt)

            # resp_tokens
            #input_ids_list = input_ids[0].cpu().numpy().tolist()
            #input_tokens = text_embedder.tokenizer.convert_ids_to_tokens(input_ids_list, skip_special_tokens=False)
            #input_tokens = [self.transform_byte2normal(text_embedder.tokenizer, text_embedder.tokenizer.byte_decoder, token) for token in input_tokens]
            #token_type_ids_list = token_type_ids[0].cpu().numpy().tolist()
            #token_type_tokens = text_embedder.tokenizer.convert_ids_to_tokens(token_type_ids_list, skip_special_tokens=False)
            #token_type_tokens = [self.transform_byte2normal(text_embedder.tokenizer, text_embedder.tokenizer.byte_decoder, token) for token in
            #                     token_type_tokens]
            #self.print_toks(input_tokens, token_type_tokens, history_len=history_len)
            #resp_idx = len(input_tokens) - 1 - input_tokens[::-1].index('<user>')
            #resp_tokens = input_tokens[resp_idx:]
            resp_tokens = text_embedder.tokenizer.convert_ids_to_tokens(sys_utt)
            resp_tokens = [self.transform_byte2normal(text_embedder.tokenizer, text_embedder.tokenizer.byte_decoder, token) for token in resp_tokens]


            return valid_lm_loss, sentence, resp_tokens, None, None
Example #48
0
    def forward(self, x, device):

        if 0:
            # MPC
            # trained wide beam number
            k = 7
            # batch size
            batch_size = 16
            # save narrow beam prediction results
            y1 = torch.zeros((10, 16, 64)).to(device)
            # save wide beam prediction results
            y2 = torch.zeros((10, 16, 16)).to(device)

            # first loop for wide beam training numbers
            for i in range(10):

                # if i == 0, full wide beam training
                if i == 0:
                    x_test = x[:, :, i, :]

                    #CNN
                    x_test = self.bn0(x_test)
                    x_test = self.conv1(x_test)
                    x_test = self.bn1(x_test)
                    x_test = F.relu(x_test)
                    x_test = self.conv2(x_test)
                    x_test = self.bn2(x_test)
                    x_test = F.relu(x_test)
                    P_dim_size = x_test.shape[2]
                    x_test = nn.MaxPool1d(kernel_size=P_dim_size)(x_test)

                    x_test = x_test.permute(2, 0, 1)
                    # predict narrow beam
                    y_test, (hn, cn) = self.lstm1(x_test)
                    # predict wide beam
                    y_test2, (hn2, cn2) = self.lstm2(x_test)

                # else, partial wide beam training
                else:
                    # select partial beams based on MPC
                    x_test = torch.zeros((batch_size, 2, 16)).to(device)
                    for b in range(batch_size):
                        x_test[b, :, max_id[b, :]] = x[b, :, i, max_id[b, :]]

                    #CNN
                    x_test = self.bn0(x_test)
                    x_test = self.conv1(x_test)
                    x_test = self.bn1(x_test)
                    x_test = F.relu(x_test)
                    x_test = self.conv2(x_test)
                    x_test = self.bn2(x_test)
                    x_test = F.relu(x_test)

                    P_dim_size = x_test.shape[2]
                    x_test = nn.MaxPool1d(kernel_size=P_dim_size)(x_test)

                    x_test = x_test.permute(2, 0, 1)
                    # predict narrow beam
                    y_test, (hn, cn) = self.lstm1(x_test, (hn, cn))
                    # predict wide beam
                    y_test2, (hn2, cn2) = self.lstm2(x_test, (hn2, cn2))

                # predict wide beam
                y_guide = self.drop2(y_test2)
                y_guide = self.fc2(y_guide)
                # predict narrow beam
                y_test = self.drop1(y_test)
                y_test = self.fc1(y_test)
                # MPC based beam selection
                max_value, max_id = torch.topk(y_guide, k)
                max_id = torch.squeeze(max_id)
                y1[i, :, :] = y_test
                y2[i, :, :] = y_guide

        # ONC
        # code structure is similar
        #if 0:
        k = 7
        batch_size = 16
        y1 = torch.zeros((10, 16, 64)).to(device)
        y2 = torch.zeros((10, 16, 16)).to(device)
        candidate_beam = torch.linspace(0, 15, steps=16)
        candidate_beam = candidate_beam.repeat(16, 1).to(device)

        for i in range(10):

            if i == 0:
                x_test = x[:, :, i, :]

                x_test = self.bn0(x_test)
                x_test = self.conv1(x_test)
                x_test = self.bn1(x_test)
                x_test = F.relu(x_test)
                x_test = self.conv2(x_test)
                x_test = self.bn2(x_test)
                x_test = F.relu(x_test)

                P_dim_size = x_test.shape[2]
                x_test = nn.MaxPool1d(kernel_size=P_dim_size)(x_test)

                x_test = x_test.permute(2, 0, 1)
                y_test, (hn, cn) = self.lstm1(x_test)
                y_test2, (hn2, cn2) = self.lstm2(x_test)

            else:
                x_test = torch.zeros((batch_size, 2, 16)).to(device)
                # ONC based beam selection
                for b in range(batch_size):
                    x_test[b, :, max_id[b, :]] = x[b, :, i, max_id[b, :]]

                x_test = self.bn0(x_test)
                x_test = self.conv1(x_test)
                x_test = self.bn1(x_test)
                x_test = F.relu(x_test)
                x_test = self.conv2(x_test)
                x_test = self.bn2(x_test)
                x_test = F.relu(x_test)

                P_dim_size = x_test.shape[2]
                x_test = nn.MaxPool1d(kernel_size=P_dim_size)(x_test)

                x_test = x_test.permute(2, 0, 1)
                y_test, (hn, cn) = self.lstm1(x_test, (hn, cn))
                y_test2, (hn2, cn2) = self.lstm2(x_test, (hn2, cn2))

            y_guide = self.drop2(y_test2)
            y_guide = self.fc2(y_guide)
            y_test = self.drop1(y_test)
            y_test = self.fc1(y_test)
            # ONC based beam selection
            max_value, max_id = torch.topk(y_guide, 1)
            max_id = torch.squeeze(max_id)
            max_id = max_id.repeat(16, 1).T
            max_value, max_id = torch.topk(-torch.abs(candidate_beam - max_id),
                                           k)
            y1[i, :, :] = y_test
            y2[i, :, :] = y_guide

        return y1, y2
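A short sketch of the ONC-style neighbor selection used above: pick the k candidate beams closest to the predicted best beam by taking top-k of the negative absolute distance. Sizes are illustrative:
import torch

batch_size, num_wide, k = 16, 16, 7
candidate_beam = torch.linspace(0, num_wide - 1, steps=num_wide).repeat(batch_size, 1)
pred = torch.randint(0, num_wide, (batch_size, 1)).float()   # hypothetical predicted best wide beam
# the k closest candidates have the smallest |distance|, i.e. the largest negative distance
_, neighbor_ids = torch.topk(-torch.abs(candidate_beam - pred), k)
print(neighbor_ids.shape)  # (batch_size, k)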
Example #49
0
 def _sort_state(self, sort_mask: torch.Tensor = None):
     if sort_mask is None:
         _, sort_mask = torch.topk(
             self._scores, min(self._search_size, self._scores.size(0)))
     self._apply_slice_to_state(sort_mask)
Example #50
0
    def advance(self, log_probs, attn):
        vocab_size = log_probs.size(-1)

        # using integer division to get an integer _B without casting
        _B = log_probs.shape[0] // self.beam_size

        if self._stepwise_cov_pen and self._prev_penalty is not None:
            self.topk_log_probs += self._prev_penalty
            self.topk_log_probs -= self.global_scorer.cov_penalty(
                self._coverage + attn, self.global_scorer.beta).view(
                _B, self.beam_size)

        # force the output to be longer than self.min_length
        step = len(self)
        self.ensure_min_length(log_probs)

        # Multiply probs by the beam probability.
        log_probs += self.topk_log_probs.view(_B * self.beam_size, 1)

        self.block_ngram_repeats(log_probs)

        # if the sequence ends now, then the penalty is the current
        # length + 1, to include the EOS token
        length_penalty = self.global_scorer.length_penalty(
            step + 1, alpha=self.global_scorer.alpha)

        # Flatten probs into a list of possibilities.
        curr_scores = log_probs / length_penalty
        curr_scores = curr_scores.reshape(_B, self.beam_size * vocab_size)
        torch.topk(curr_scores,  self.beam_size, dim=-1,
                   out=(self.topk_scores, self.topk_ids))

        # Recover log probs.
        # Length penalty is just a scalar. It doesn't matter if it's applied
        # before or after the topk.
        torch.mul(self.topk_scores, length_penalty, out=self.topk_log_probs)

        # Resolve beam origin and map to batch index flat representation.
        torch.div(self.topk_ids, vocab_size, out=self._batch_index)
        self._batch_index += self._beam_offset[:_B].unsqueeze(1)
        self.select_indices = self._batch_index.view(_B * self.beam_size)

        self.topk_ids.fmod_(vocab_size)  # resolve true word ids

        # Append last prediction.
        self.alive_seq = torch.cat(
            [self.alive_seq.index_select(0, self.select_indices),
             self.topk_ids.view(_B * self.beam_size, 1)], -1)
        if self.return_attention or self._cov_pen:
            current_attn = attn.index_select(1, self.select_indices)
            if step == 1:
                self.alive_attn = current_attn
                # update global state (step == 1)
                if self._cov_pen:  # coverage penalty
                    self._prev_penalty = torch.zeros_like(self.topk_log_probs)
                    self._coverage = current_attn
            else:
                self.alive_attn = self.alive_attn.index_select(
                    1, self.select_indices)
                self.alive_attn = torch.cat([self.alive_attn, current_attn], 0)
                # update global state (step > 1)
                if self._cov_pen:
                    self._coverage = self._coverage.index_select(
                        1, self.select_indices)
                    self._coverage += current_attn
                    self._prev_penalty = self.global_scorer.cov_penalty(
                        self._coverage, beta=self.global_scorer.beta).view(
                            _B, self.beam_size)

        if self._vanilla_cov_pen:
            # shape: (batch_size x beam_size, 1)
            cov_penalty = self.global_scorer.cov_penalty(
                self._coverage,
                beta=self.global_scorer.beta)
            self.topk_scores -= cov_penalty.view(_B, self.beam_size)

        self.is_finished = self.topk_ids.eq(self.eos)
        self.ensure_max_length()
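A minimal sketch of the length-penalized scoring and index resolution in advance() above, using random log-probabilities and a made-up scalar penalty:
import torch

beam_size, vocab_size, _B = 3, 50, 2
log_probs = torch.randn(_B * beam_size, vocab_size)
topk_log_probs = torch.zeros(_B, beam_size)
length_penalty = 1.7                                   # hypothetical scalar penalty

log_probs += topk_log_probs.view(_B * beam_size, 1)    # add running beam scores
curr_scores = (log_probs / length_penalty).reshape(_B, beam_size * vocab_size)
topk_scores, topk_ids = torch.topk(curr_scores, beam_size, dim=-1)
topk_log_probs = topk_scores * length_penalty          # recover unpenalized log probs
beam_origin = topk_ids // vocab_size                   # which beam each pick came from
word_ids = topk_ids % vocab_size                       # true word ids
print(topk_scores.shape, beam_origin.shape, word_ids.shape)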
Example #51
0
def test_model(model, datasetloader, dataset):
    since = time.time()
    running_corrects = 0
    total = 0
    result_dict = {}
    for batch_index, batch_datums in enumerate(datasetloader):

        ### forward (compute predictions) ###
        model.hidden = model.init_hidden()
        if use_gpu:
            actual_label_tensor = Variable(
                batch_datums['answer_index'].cuda(),
                requires_grad=False).view(-1)
            token_sequence_tensor = Variable(
                batch_datums['question_token_ids'].cuda(), requires_grad=False)
            answer_vector_tensor = Variable(
                batch_datums['answer_vector_tensor'].cuda(),
                requires_grad=False)
            question_id = Variable(
                batch_datums['question_id'].cuda(), requires_grad=False)
        else:
            actual_label_tensor = Variable(
                batch_datums['answer_index'], requires_grad=False).view(-1)
            token_sequence_tensor = Variable(
                batch_datums['question_token_ids'], requires_grad=False)
            answer_vector_tensor = Variable(
                batch_datums['answer_vector_tensor'], requires_grad=False)
            question_id = Variable(
                batch_datums['question_id'], requires_grad=False)
        prediction_scores_tensor = model.forward(token_sequence_tensor)

        ### Statistics ###
        topK = 3
        _, predictions = torch.max(prediction_scores_tensor.data, 1)
        prediction = predictions.view(-1)
        answer_vector = answer_vector_tensor.data.view(-1)
        top_predicted_indices = torch.topk(prediction_scores_tensor.data, topK,
                                           1)[1][0].tolist()
        correctness = max([
            int(answer_vector[prediction])
            for prediction in top_predicted_indices
        ])
        running_corrects += correctness
        answers = [
            dataset.answer_dict.idx2word(idx) for idx in top_predicted_indices
        ]
        qid = int(question_id[0].data[0])
        dict_entry = {
            'question_id': qid,
            'answers': answers,
            'correctness': correctness
        }
        result_dict[qid] = dict_entry
        running_corrects += answer_vector[prediction]
        total += len(batch_datums[list(batch_datums.keys())[0]])  # Python 3: dict keys() is not indexable

    time_elapsed = time.time() - since
    accuracy = running_corrects / float(total)
    logging.info('Testing complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    logging.info('Acc: {:4f}'.format(accuracy))
    return result_dict, accuracy
Example #52
0
    def translate(self, x, trans_args, char_list=None, rnnlm=None, use_jit=False):
        """Translate source text.

        :param list x: input source text feature (T,)
        :param Namespace trans_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        self.eval()  # NOTE: this is important because self.encode() is not used
        assert isinstance(x, list)

        # make a utt list (1) to use the same interface for encoder
        if self.multilingual:
            x = to_device(self, torch.from_numpy(np.fromiter(map(int, x[0][1:]), dtype=np.int64)))
        else:
            x = to_device(self, torch.from_numpy(np.fromiter(map(int, x[0]), dtype=np.int64)))

        xs_pad = x.unsqueeze(0)
        tgt_lang = None
        if trans_args.tgt_lang:
            tgt_lang = char_list.index(trans_args.tgt_lang)
        xs_pad, _ = self.target_forcing(xs_pad, tgt_lang=tgt_lang)
        enc_output, _ = self.encoder(xs_pad, None)
        h = enc_output.squeeze(0)

        logging.info('input lengths: ' + str(h.size(0)))
        # search params
        beam = trans_args.beam_size
        penalty = trans_args.penalty

        # prepare sos
        y = self.sos
        vy = h.new_zeros(1).long()

        if trans_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(trans_args.maxlenratio * h.size(0)))
        minlen = int(trans_args.minlenratio * h.size(0))
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
        else:
            hyp = {'score': 0.0, 'yseq': [y]}
        hyps = [hyp]
        ended_hyps = []

        import six
        traced_decoder = None
        for i in six.moves.range(maxlen):
            logging.debug('position ' + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy.unsqueeze(1)
                vy[0] = hyp['yseq'][i]

                # get nbest local scores and their ids
                ys_mask = subsequent_mask(i + 1).unsqueeze(0)
                ys = torch.tensor(hyp['yseq']).unsqueeze(0)
                # FIXME: jit does not match non-jit result
                if use_jit:
                    if traced_decoder is None:
                        traced_decoder = torch.jit.trace(self.decoder.forward_one_step,
                                                         (ys, ys_mask, enc_output))
                    local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
                else:
                    local_att_scores = self.decoder.forward_one_step(ys, ys_mask, enc_output)[0]

                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
                    local_scores = local_att_scores + trans_args.lm_weight * local_lm_scores
                else:
                    local_scores = local_att_scores

                local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)

                for j in six.moves.range(beam):
                    new_hyp = {}
                    new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                    new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                    new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                    new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                    if rnnlm:
                        new_hyp['rnnlm_prev'] = rnnlm_state
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
            if char_list is not None:
                logging.debug(
                    'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info('adding <eos> in the last position of the loop')
                for hyp in hyps:
                    hyp['yseq'].append(self.eos)

            # add ended hypotheses to the final list, and remove them from the current hypotheses
            # (this can be a problem when the number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp['score'] += trans_args.lm_weight * rnnlm.final(
                                hyp['rnnlm_prev'])
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            from espnet.nets.e2e_asr_common import end_detect
            if end_detect(ended_hyps, i) and trans_args.maxlenratio == 0.0:
                logging.info('end detected at %d', i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug('remaining hypotheses: ' + str(len(hyps)))
            else:
                logging.info('no hypothesis. Finish decoding.')
                break

            if char_list is not None:
                for hyp in hyps:
                    logging.debug(
                        'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

            logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))

        nbest_hyps = sorted(
            ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), trans_args.nbest)]

        # check number of hypotheses
        if len(nbest_hyps) == 0:
            logging.warning('there are no N-best results, performing translation again with a smaller minlenratio.')
            # should copy because Namespace will be overwritten globally
            trans_args = Namespace(**vars(trans_args))
            trans_args.minlenratio = max(0.0, trans_args.minlenratio - 0.1)
            return self.translate(x, trans_args, char_list, rnnlm)

        logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
        logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
        return nbest_hyps
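A stripped-down sketch of the per-hypothesis expansion and pruning step in the loop above, with a random score vector standing in for the decoder and language model:
import torch

beam = 3
hyps = [{'score': 0.0, 'yseq': [1]}]
local_scores = torch.log_softmax(torch.randn(1, 20), dim=1)   # stand-in decoder output
best_scores, best_ids = torch.topk(local_scores, beam, dim=1)
new_hyps = [{'score': hyps[0]['score'] + float(best_scores[0, j]),
             'yseq': hyps[0]['yseq'] + [int(best_ids[0, j])]}
            for j in range(beam)]
new_hyps = sorted(new_hyps, key=lambda h: h['score'], reverse=True)[:beam]
print(new_hyps)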
Example #53
0
def visualize_model(model, datasetloader, dataset, num_questions=10, gui=False, req_correct=True):
    if gui:
        import matplotlib.pyplot as plt
    images_so_far = 0
    for batch_index, batch_datums in enumerate(datasetloader):

        ### forward ###
        model.hidden = model.init_hidden()
        if use_gpu:
            actual_label_tensor = Variable(
                batch_datums['answer_index'].cuda(),
                requires_grad=False).view(-1)
            token_sequence_tensor = Variable(
                batch_datums['question_token_ids'].cuda(), requires_grad=False)
            answer_vector_tensor = Variable(
                batch_datums['answer_vector_tensor'].cuda(),
                requires_grad=False)
        else:
            actual_label_tensor = Variable(
                batch_datums['answer_index'], requires_grad=False).view(-1)
            token_sequence_tensor = Variable(
                batch_datums['question_token_ids'], requires_grad=False)
            answer_vector_tensor = Variable(
                batch_datums['answer_vector_tensor'], requires_grad=False)
        prediction_scores_tensor = model.forward(token_sequence_tensor)

        #### show
        label_texts = []
        answer_index = batch_datums['answer_index'][0].numpy()
        qtoken_ids = batch_datums['question_token_ids'][0].numpy()
        answer_vector = batch_datums['answer_vector_tensor'][0].numpy().reshape(
            [-1])
        qid = int(batch_datums['question_id'][0].numpy())

        question = [
            q for q in dataset.questions_dict if q['question_id'] == qid
        ][0]
        question_text = question['question']
        title_text = "Q: {}".format(question_text)

        layout = dataset.qid2layout_dict[str(question['question_id'])]
        label_texts.append("L: {}".format(layout))

        topK = 3
        top_predicted_indices = torch.topk(prediction_scores_tensor.data, topK,
                                           1)[1][0].tolist()
        predicted_answers = [
            dataset.answer_dict.idx2word(answer_id)
            for answer_id in top_predicted_indices
        ]
        label_texts.append('A: ' + ' ; '.join(predicted_answers))

        image_id = question['image_id']
        raw_image_file = os.path.join(
            root_dir, 'raw_data/Images/%s/COCO_%s_%012d.jpg' %
            (dataset.set_name, dataset.set_name, int(image_id)))

        correctness = max([
            int(answer_vector[prediction])
            for prediction in top_predicted_indices
        ])

        if (req_correct is None) or not (bool(req_correct) ^ bool(correctness)):
            if not gui:
                print("Question:")
                print(question_text)
                print("Answer:")
                print('\n'.join(label_texts))
                print("correctness:")
                print(bool(correctness))
            else:
                plt.figure()
                image_data = plt.imread(raw_image_file, format='jpg')
                plt.title(question_text)
                plt.xlabel('\n'.join(label_texts))
                plt.imshow(image_data)
                # plt.show()

            images_so_far += 1
            if images_so_far > num_questions:
                break
Example #54
0
    def forward(self,
                hidden_states,
                start_positions=None,
                end_positions=None,
                cls_index=None,
                is_impossible=None,
                p_mask=None):
        outputs = ()

        start_logits = self.start_logits(hidden_states, p_mask=p_mask)

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, let's remove the dimension added by batch splitting
            for x in (start_positions, end_positions, cls_index,
                      is_impossible):
                if x is not None and x.dim() > 1:
                    x.squeeze_(-1)

            # during training, compute the end logits based on the ground truth of the start position
            end_logits = self.end_logits(hidden_states,
                                         start_positions=start_positions,
                                         p_mask=p_mask)

            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

            if cls_index is not None and is_impossible is not None:
                # Predict answerability from the representation of CLS and START
                cls_logits = self.answer_class(hidden_states,
                                               start_positions=start_positions,
                                               cls_index=cls_index)
                loss_fct_cls = nn.BCEWithLogitsLoss()
                cls_loss = loss_fct_cls(cls_logits, is_impossible)

                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                total_loss += cls_loss * 0.5

            outputs = (total_loss, ) + outputs

        else:
            # during inference, compute the end logits based on beam search
            bsz, slen, hsz = hidden_states.size()
            start_log_probs = F.softmax(start_logits,
                                        dim=-1)  # shape (bsz, slen)

            start_top_log_probs, start_top_index = torch.topk(
                start_log_probs, self.start_n_top,
                dim=-1)  # shape (bsz, start_n_top)
            start_top_index_exp = start_top_index.unsqueeze(-1).expand(
                -1, -1, hsz)  # shape (bsz, start_n_top, hsz)
            start_states = torch.gather(
                hidden_states, -2,
                start_top_index_exp)  # shape (bsz, start_n_top, hsz)
            start_states = start_states.unsqueeze(1).expand(
                -1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)

            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
                start_states)  # shape (bsz, slen, start_n_top, hsz)
            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
            end_logits = self.end_logits(hidden_states_expanded,
                                         start_states=start_states,
                                         p_mask=p_mask)
            end_log_probs = F.softmax(end_logits,
                                      dim=1)  # shape (bsz, slen, start_n_top)

            end_top_log_probs, end_top_index = torch.topk(
                end_log_probs, self.end_n_top,
                dim=1)  # shape (bsz, end_n_top, start_n_top)
            end_top_log_probs = end_top_log_probs.view(
                -1, self.start_n_top * self.end_n_top)
            end_top_index = end_top_index.view(
                -1, self.start_n_top * self.end_n_top)

            start_states = torch.einsum("blh,bl->bh", hidden_states,
                                        start_log_probs)
            cls_logits = self.answer_class(hidden_states,
                                           start_states=start_states,
                                           cls_index=cls_index)

            outputs = (start_top_log_probs, start_top_index, end_top_log_probs,
                       end_top_index, cls_logits) + outputs

        # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
        # or (if labels are provided) (total_loss,)
        return outputs
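An illustrative sketch of combining top-k start and end scores into span candidates. This is not the module's exact inference path (which conditions the end logits on each top start); shapes and sizes are made up:
import torch
import torch.nn.functional as F

bsz, slen, start_n_top, end_n_top = 2, 12, 5, 5
start_log_probs = F.log_softmax(torch.randn(bsz, slen), dim=-1)
end_log_probs = F.log_softmax(torch.randn(bsz, slen), dim=-1)

start_top_log_probs, start_top_index = torch.topk(start_log_probs, start_n_top, dim=-1)
end_top_log_probs, end_top_index = torch.topk(end_log_probs, end_n_top, dim=-1)
# joint score of every (start, end) candidate pair
span_scores = start_top_log_probs.unsqueeze(2) + end_top_log_probs.unsqueeze(1)
print(span_scores.shape)  # (bsz, start_n_top, end_n_top)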
Example #55
0
    def parse(self, src_sent, context=None, beam_size=5):
        """Perform beam search to infer the target AST given a source utterance

        Args:
            src_sent: list of source utterance tokens
            context: other context used for prediction
            beam_size: beam size

        Returns:
            A list of `DecodeHypothesis`, each representing an AST
        """

        args = self.args
        primitive_vocab = self.vocab.primitive

        src_sent_var = nn_utils.to_input_variable([src_sent], self.vocab.source, cuda=args.cuda, training=False)

        # Variable(1, src_sent_len, hidden_size * 2)
        src_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(src_sent)])
        # (1, src_sent_len, hidden_size)
        src_encodings_att_linear = self.att_src_linear(src_encodings)

        dec_init_vec = self.init_decoder_state(last_state, last_cell)
        if args.lstm == 'parent_feed':
            h_tm1 = dec_init_vec[0], dec_init_vec[1], \
                    Variable(self.new_tensor(args.hidden_size).zero_()), \
                    Variable(self.new_tensor(args.hidden_size).zero_())
        else:
            h_tm1 = dec_init_vec

        zero_action_embed = Variable(self.new_tensor(args.action_embed_size).zero_())

        hyp_scores = Variable(self.new_tensor([0.]), volatile=True)

        src_token_vocab_ids = [primitive_vocab[token] for token in src_sent]
        src_unk_pos_list = [pos for pos, token_id in enumerate(src_token_vocab_ids) if token_id == primitive_vocab.unk_id]
        # sometimes a word may appear multiple times in the source; in that case
        # we only copy its first occurrence, so we mask every occurrence after
        # the first with -1
        token_set = set()
        for i, tid in enumerate(src_token_vocab_ids):
            if tid in token_set:
                src_token_vocab_ids[i] = -1
            else: token_set.add(tid)

        t = 0
        hypotheses = [DecodeHypothesis()]
        hyp_states = [[]]
        completed_hypotheses = []

        while len(completed_hypotheses) < beam_size and t < args.decode_max_time_step:
            hyp_num = len(hypotheses)

            # (hyp_num, src_sent_len, hidden_size * 2)
            exp_src_encodings = src_encodings.expand(hyp_num, src_encodings.size(1), src_encodings.size(2))
            # (hyp_num, src_sent_len, hidden_size)
            exp_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num, src_encodings_att_linear.size(1), src_encodings_att_linear.size(2))

            if t == 0:
                x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_(), volatile=True)
                if args.no_parent_field_type_embed is False:
                    offset = args.action_embed_size  # prev_action
                    offset += args.att_vec_size * (not args.no_input_feed)
                    offset += args.action_embed_size * (not args.no_parent_production_embed)
                    offset += args.field_embed_size * (not args.no_parent_field_embed)

                    x[0, offset: offset + args.type_embed_size] = \
                        self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]]
            else:
                actions_tm1 = [hyp.actions[-1] for hyp in hypotheses]

                a_tm1_embeds = []
                for a_tm1 in actions_tm1:
                    if a_tm1:
                        if isinstance(a_tm1, ApplyRuleAction):
                            a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[a_tm1.production]]
                        elif isinstance(a_tm1, ReduceAction):
                            a_tm1_embed = self.production_embed.weight[len(self.grammar)]
                        else:
                            a_tm1_embed = self.primitive_embed.weight[self.vocab.primitive[a_tm1.token]]

                        a_tm1_embeds.append(a_tm1_embed)
                    else:
                        a_tm1_embeds.append(zero_action_embed)
                a_tm1_embeds = torch.stack(a_tm1_embeds)

                inputs = [a_tm1_embeds]
                if args.no_input_feed is False:
                    inputs.append(att_tm1)
                if args.no_parent_production_embed is False:
                    # frontier production
                    frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                    frontier_prod_embeds = self.production_embed(Variable(self.new_long_tensor(
                        [self.grammar.prod2id[prod] for prod in frontier_prods])))
                    inputs.append(frontier_prod_embeds)
                if args.no_parent_field_embed is False:
                    # frontier field
                    frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                    frontier_field_embeds = self.field_embed(Variable(self.new_long_tensor([
                        self.grammar.field2id[field] for field in frontier_fields])))

                    inputs.append(frontier_field_embeds)
                if args.no_parent_field_type_embed is False:
                    # frontier field type
                    frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                    frontier_field_type_embeds = self.type_embed(Variable(self.new_long_tensor([
                        self.grammar.type2id[type] for type in frontier_field_types])))
                    inputs.append(frontier_field_type_embeds)

                # parent states
                if args.no_parent_state is False:
                    p_ts = [hyp.frontier_node.created_time for hyp in hypotheses]
                    parent_states = torch.stack([hyp_states[hyp_id][p_t][0] for hyp_id, p_t in enumerate(p_ts)])
                    parent_cells = torch.stack([hyp_states[hyp_id][p_t][1] for hyp_id, p_t in enumerate(p_ts)])

                    if args.lstm == 'parent_feed':
                        h_tm1 = (h_tm1[0], h_tm1[1], parent_states, parent_cells)
                    else:
                        inputs.append(parent_states)

                x = torch.cat(inputs, dim=-1)

            if args.lstm == 'lstm_with_dropout':
                self.decoder_lstm.set_dropout_masks(hyp_num)

            (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings,
                                             exp_src_encodings_att_linear,
                                             src_token_mask=None)

            # Variable(batch_size, grammar_size)
            # apply_rule_log_prob = torch.log(F.softmax(self.production_readout(att_t), dim=-1))
            apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)

            # Variable(batch_size, src_sent_len)
            primitive_copy_prob = self.src_pointer_net(src_encodings, None, att_t.unsqueeze(0)).squeeze(0)

            # Variable(batch_size, primitive_vocab_size)
            gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1)

            # Variable(batch_size, 2)
            primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1)

            # Variable(batch_size, primitive_vocab_size)
            primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(1) * gen_from_vocab_prob
            if src_unk_pos_list:
                primitive_prob[:, primitive_vocab.unk_id] = 1.e-10

            gentoken_prev_hyp_ids = []
            gentoken_new_hyp_unks = []
            gentoken_copy_infos = []
            applyrule_new_hyp_scores = []
            applyrule_new_hyp_prod_ids = []
            applyrule_prev_hyp_ids = []

            for hyp_id, hyp in enumerate(hypotheses):
                # generate new continuations
                action_types = self.transition_system.get_valid_continuation_types(hyp)

                for action_type in action_types:
                    if action_type == ApplyRuleAction:
                        productions = self.transition_system.get_valid_continuating_productions(hyp)
                        for production in productions:
                            prod_id = self.grammar.prod2id[production]
                            prod_score = apply_rule_log_prob[hyp_id, prod_id].data[0]
                            new_hyp_score = hyp.score + prod_score

                            applyrule_new_hyp_scores.append(new_hyp_score)
                            applyrule_new_hyp_prod_ids.append(prod_id)
                            applyrule_prev_hyp_ids.append(hyp_id)
                    elif action_type == ReduceAction:
                        action_score = apply_rule_log_prob[hyp_id, len(self.grammar)].data[0]
                        new_hyp_score = hyp.score + action_score

                        applyrule_new_hyp_scores.append(new_hyp_score)
                        applyrule_new_hyp_prod_ids.append(len(self.grammar))
                        applyrule_prev_hyp_ids.append(hyp_id)
                    else:
                        # GenToken action
                        gentoken_prev_hyp_ids.append(hyp_id)
                        hyp_copy_info = dict()  # maps token -> (token_pos, copy_prob)
                        # first, we compute copy probabilities for tokens in the source sentence
                        for token_pos, token_vocab_id in enumerate(src_token_vocab_ids):
                            if args.no_copy is False and token_vocab_id != -1 and token_vocab_id != primitive_vocab.unk_id:
                                p_copy = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id, token_pos]
                                primitive_prob[hyp_id, token_vocab_id] = primitive_prob[hyp_id, token_vocab_id] + p_copy

                                token = src_sent[token_pos]
                                hyp_copy_info[token] = (token_pos, p_copy.data[0])

                        # second, add the probability of copying the most probable unk word
                        if args.no_copy is False and src_unk_pos_list:
                            unk_pos = primitive_copy_prob[hyp_id][src_unk_pos_list].data.cpu().numpy().argmax()
                            unk_pos = src_unk_pos_list[unk_pos]
                            token = src_sent[unk_pos]
                            gentoken_new_hyp_unks.append(token)

                            unk_copy_score = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id, unk_pos]
                            primitive_prob[hyp_id, primitive_vocab.unk_id] = unk_copy_score

                            hyp_copy_info[token] = (unk_pos, unk_copy_score.data[0])

                        gentoken_copy_infos.append(hyp_copy_info)

            new_hyp_scores = None
            if applyrule_new_hyp_scores:
                new_hyp_scores = Variable(self.new_tensor(applyrule_new_hyp_scores))
            if gentoken_prev_hyp_ids:
                primitive_log_prob = torch.log(primitive_prob)
                gen_token_new_hyp_scores = (hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) + primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1)

                if new_hyp_scores is None: new_hyp_scores = gen_token_new_hyp_scores
                else: new_hyp_scores = torch.cat([new_hyp_scores, gen_token_new_hyp_scores])

            top_new_hyp_scores, top_new_hyp_pos = torch.topk(new_hyp_scores,
                                                             k=min(new_hyp_scores.size(0), beam_size - len(completed_hypotheses)))

            live_hyp_ids = []
            new_hypotheses = []
            for new_hyp_score, new_hyp_pos in zip(top_new_hyp_scores.data.cpu(), top_new_hyp_pos.data.cpu()):
                action_info = ActionInfo()
                if new_hyp_pos < len(applyrule_new_hyp_scores):
                    # it's an ApplyRule or Reduce action
                    prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos]
                    prev_hyp = hypotheses[prev_hyp_id]

                    prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos]
                    # ApplyRule action
                    if prod_id < len(self.grammar):
                        production = self.grammar.id2prod[prod_id]
                        action = ApplyRuleAction(production)
                    # Reduce action
                    else:
                        action = ReduceAction()
                else:
                    # it's a GenToken action
                    token_id = (new_hyp_pos - len(applyrule_new_hyp_scores)) % primitive_prob.size(1)

                    k = (new_hyp_pos - len(applyrule_new_hyp_scores)) // primitive_prob.size(1)
                    copy_info = gentoken_copy_infos[k]
                    prev_hyp_id = gentoken_prev_hyp_ids[k]
                    prev_hyp = hypotheses[prev_hyp_id]

                    if token_id == primitive_vocab.unk_id:
                        if gentoken_new_hyp_unks:
                            token = gentoken_new_hyp_unks[k]
                        else:
                            token = primitive_vocab.id2word[primitive_vocab.unk_id]
                    else:
                        token = primitive_vocab.id2word[token_id]

                    action = GenTokenAction(token)

                    if token in copy_info:
                        action_info.copy_from_src = True
                        action_info.src_token_position = copy_info[token][0]

                action_info.action = action
                action_info.t = t
                if t > 0:
                    action_info.parent_t = prev_hyp.frontier_node.created_time
                    action_info.frontier_prod = prev_hyp.frontier_node.production
                    action_info.frontier_field = prev_hyp.frontier_field.field

                new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                new_hyp.score = new_hyp_score

                if new_hyp.completed:
                    completed_hypotheses.append(new_hyp)
                else:
                    new_hypotheses.append(new_hyp)
                    live_hyp_ids.append(prev_hyp_id)

            if live_hyp_ids:
                hyp_states = [hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids]
                h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
                att_tm1 = att_t[live_hyp_ids]
                hypotheses = new_hypotheses
                hyp_scores = Variable(self.new_tensor([hyp.score for hyp in hypotheses]))
                t += 1
            else:
                break

        completed_hypotheses.sort(key=lambda hyp: -hyp.score)

        return completed_hypotheses
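A minimal sketch of the index bookkeeping behind the torch.topk call above, using made-up scores and a hypothetical beam_size rather than the model's actual tensors: ApplyRule/Reduce candidates and flattened GenToken candidates are concatenated, pruned with a single topk, and the flat positions are mapped back.

import torch

# hypothetical scores, for illustration only
applyrule_scores = torch.tensor([-1.2, -0.7, -2.3])          # one score per (hyp, production/Reduce)
gentoken_scores = torch.randn(2, 5).log_softmax(dim=-1)      # (num_gentoken_hyps, primitive_vocab_size)
beam_size = 4

all_scores = torch.cat([applyrule_scores, gentoken_scores.view(-1)])
top_scores, top_pos = torch.topk(all_scores, k=min(beam_size, all_scores.size(0)))

for pos in top_pos.tolist():
    if pos < applyrule_scores.size(0):
        print('ApplyRule/Reduce candidate at index', pos)
    else:
        offset = pos - applyrule_scores.size(0)
        hyp_id, token_id = divmod(offset, gentoken_scores.size(1))
        print('GenToken candidate: hyp', hyp_id, 'token', token_id)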
Example #56
0
def _decode(tl_heat,
            br_heat,
            tl_tag,
            br_tag,
            tl_regr,
            br_regr,
            ct_heat,
            ct_regr,
            K=100,
            kernel=1,
            ae_threshold=1,
            num_dets=1000):
    batch, cat, height, width = tl_heat.size()

    tl_heat = torch.sigmoid(tl_heat)
    br_heat = torch.sigmoid(br_heat)
    ct_heat = torch.sigmoid(ct_heat)

    # perform nms on heatmaps
    tl_heat = _nms(tl_heat, kernel=kernel)
    br_heat = _nms(br_heat, kernel=kernel)
    ct_heat = _nms(ct_heat, kernel=kernel)

    tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = _topk(tl_heat, K=K)
    br_scores, br_inds, br_clses, br_ys, br_xs = _topk(br_heat, K=K)
    ct_scores, ct_inds, ct_clses, ct_ys, ct_xs = _topk(ct_heat, K=K)

    tl_ys = tl_ys.view(batch, K, 1).expand(batch, K, K)
    tl_xs = tl_xs.view(batch, K, 1).expand(batch, K, K)
    br_ys = br_ys.view(batch, 1, K).expand(batch, K, K)
    br_xs = br_xs.view(batch, 1, K).expand(batch, K, K)
    ct_ys = ct_ys.view(batch, 1, K).expand(batch, K, K)
    ct_xs = ct_xs.view(batch, 1, K).expand(batch, K, K)

    if tl_regr is not None and br_regr is not None:
        tl_regr = _tranpose_and_gather_feat(tl_regr, tl_inds)
        tl_regr = tl_regr.view(batch, K, 1, 2)
        br_regr = _tranpose_and_gather_feat(br_regr, br_inds)
        br_regr = br_regr.view(batch, 1, K, 2)
        ct_regr = _tranpose_and_gather_feat(ct_regr, ct_inds)
        ct_regr = ct_regr.view(batch, 1, K, 2)

        tl_xs = tl_xs + tl_regr[..., 0]
        tl_ys = tl_ys + tl_regr[..., 1]
        br_xs = br_xs + br_regr[..., 0]
        br_ys = br_ys + br_regr[..., 1]
        ct_xs = ct_xs + ct_regr[..., 0]
        ct_ys = ct_ys + ct_regr[..., 1]

    # all possible boxes based on top k corners (ignoring class)
    bboxes = torch.stack((tl_xs, tl_ys, br_xs, br_ys), dim=3)

    tl_tag = _tranpose_and_gather_feat(tl_tag, tl_inds)
    tl_tag = tl_tag.view(batch, K, 1)
    br_tag = _tranpose_and_gather_feat(br_tag, br_inds)
    br_tag = br_tag.view(batch, 1, K)
    dists = torch.abs(tl_tag - br_tag)

    tl_scores = tl_scores.view(batch, K, 1).expand(batch, K, K)
    br_scores = br_scores.view(batch, 1, K).expand(batch, K, K)
    scores = (tl_scores + br_scores) / 2

    # reject boxes based on classes
    tl_clses = tl_clses.view(batch, K, 1).expand(batch, K, K)
    br_clses = br_clses.view(batch, 1, K).expand(batch, K, K)
    cls_inds = (tl_clses != br_clses)

    # reject boxes based on distances
    dist_inds = (dists > ae_threshold)

    # reject boxes based on widths and heights
    width_inds = (br_xs < tl_xs)
    height_inds = (br_ys < tl_ys)

    scores[cls_inds] = -1
    scores[dist_inds] = -1
    scores[width_inds] = -1
    scores[height_inds] = -1

    scores = scores.view(batch, -1)
    scores, inds = torch.topk(scores, num_dets)
    scores = scores.unsqueeze(2)

    bboxes = bboxes.view(batch, -1, 4)
    bboxes = _gather_feat(bboxes, inds)

    clses = tl_clses.contiguous().view(batch, -1, 1)
    clses = _gather_feat(clses, inds).float()

    tl_scores = tl_scores.contiguous().view(batch, -1, 1)
    tl_scores = _gather_feat(tl_scores, inds).float()
    br_scores = br_scores.contiguous().view(batch, -1, 1)
    br_scores = _gather_feat(br_scores, inds).float()

    ct_xs = ct_xs[:, 0, :]
    ct_ys = ct_ys[:, 0, :]

    center = torch.cat([
        ct_xs.unsqueeze(2),
        ct_ys.unsqueeze(2),
        ct_clses.float().unsqueeze(2),
        ct_scores.unsqueeze(2)
    ],
                       dim=2)
    detections = torch.cat([bboxes, scores, tl_scores, br_scores, clses],
                           dim=2)
    return detections, center
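The _topk helper used above is defined elsewhere in the project; a plausible sketch of such a helper, assuming a (batch, num_classes, height, width) heatmap and the usual flatten-then-recover index layout, is:

import torch

def topk_heatmap(heat, K=100):
    # heat: (batch, num_classes, height, width) heatmap scores
    batch, cat, height, width = heat.size()
    scores, inds = torch.topk(heat.view(batch, -1), K)  # top K over classes and positions jointly
    clses = inds // (height * width)                     # recover the class index
    inds = inds % (height * width)                       # index within one class map
    ys = (inds // width).float()                         # row (y) coordinate
    xs = (inds % width).float()                          # column (x) coordinate
    return scores, inds, clses, ys, xs

scores, inds, clses, ys, xs = topk_heatmap(torch.rand(2, 80, 128, 128), K=10)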
Example #57
0
def sample_sequence(history,
                    graph,
                    tokenizer,
                    model,
                    args,
                    current_output=None):
    special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    padding = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1])
    if current_output is None:
        current_output = []
    if (args.flatten_KB):
        history += graph['edges']
    for i in range(args.max_length):
        instance = build_input_from_segments(args,
                                             history,
                                             current_output,
                                             graph,
                                             tokenizer,
                                             with_eos=False)
        input_ids = torch.tensor(instance["input_ids"],
                                 device=args.device).unsqueeze(0)
        token_type_ids = torch.tensor(instance["token_type_ids"],
                                      device=args.device).unsqueeze(0)
        nodes_ids = None
        if (args.graph
                or args.edge_list) and len(instance["input_graph_ids"]) > 0:
            max_c = max(len(col) for col in instance["input_graph_ids"])
            temp = []
            for clmn in instance["input_graph_ids"]:
                temp.append(clmn + [padding] * (max_c - len(clmn)))
            nodes_ids = torch.tensor([temp], device=args.device)

        att_mask = None
        if (args.unilm):
            att_mask = instance["attention_mask"].unsqueeze(0).unsqueeze(0).to(
                input_ids.device)
            if (args.graph or args.edge_list):
                att_mask = att_mask.squeeze().squeeze()
                max_l = len(instance["input_ids"]) + len(
                    instance["input_graph_ids"])
                max_r = len(instance["input_graph_ids"])
                mask_padded = torch.zeros(max_l,
                                          max_l,
                                          dtype=torch.long,
                                          device=args.device)
                mask_padded[max_r:len(att_mask[0]) + max_r,
                            max_r:len(att_mask[0]) + max_r].copy_(att_mask)
                ## add missing one for row
                row_stripe_padded = torch.ones(max_r,
                                               max_r +
                                               instance["len_token_a"] + 1,
                                               dtype=torch.long,
                                               device=args.device)
                mask_padded[:max_r, :max_r + instance["len_token_a"] +
                            1].copy_(row_stripe_padded)
                ## add missing one for clmn
                cmn_stripe_padded = torch.ones(len(att_mask[0]),
                                               max_r,
                                               dtype=torch.long,
                                               device=args.device)
                mask_padded[max_r:max_r +
                            len(att_mask[0]), :max_r].copy_(cmn_stripe_padded)
                if (args.adj_graph):
                    r_net = len(
                        instance["input_graph_networks"])  ## square matrix
                    c_net = len(
                        instance["input_graph_networks"][0])  ## square matrix
                    if (r_net and c_net):
                        mask_padded[:r_net, :r_net].copy_(
                            torch.tensor(instance["input_graph_networks"],
                                         dtype=torch.long,
                                         device=args.device))
                att_mask = mask_padded.unsqueeze(0).unsqueeze(0)

        logits = model(input_ids,
                       token_type_ids=token_type_ids,
                       nodes=nodes_ids,
                       attention_mask=att_mask)
        if isinstance(logits, tuple):  # for gpt2 and maybe others
            logits = logits[0]
        logits = logits[0, -1, :] / args.temperature
        logits = top_filtering(logits, top_k=args.top_k, top_p=args.top_p)
        probs = F.softmax(logits, dim=-1)

        prev = torch.topk(
            probs, 1)[1] if args.no_sample else torch.multinomial(probs, 1)
        if i < args.min_length and prev.item() in special_tokens_ids:
            while prev.item() in special_tokens_ids:
                if probs.max().item() == 1:
                    warnings.warn(
                        "Warning: model generating special token with probability 1."
                    )
                    break  # avoid infinitely looping over special token
                prev = torch.multinomial(probs, num_samples=1)

        if prev.item() in special_tokens_ids:
            break
        current_output.append(prev.item())

    return current_output
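top_filtering above is a helper defined elsewhere; a simplified, top-k-only sketch (the real helper also supports top-p / nucleus filtering) could look like:

import torch
import torch.nn.functional as F

def top_k_filter(logits, top_k=0, filter_value=-float('inf')):
    # mask every logit below the k-th largest so sampling only sees the top_k tokens
    if top_k > 0:
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits = torch.where(logits < kth_best, torch.full_like(logits, filter_value), logits)
    return logits

logits = torch.randn(50257)                       # hypothetical vocabulary size
probs = F.softmax(top_k_filter(logits, top_k=10), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)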
Example #58
0
    def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None, fstlm=None):
        '''beam search implementation

        :param Variable h: encoder hidden states, (T, D)
        :param lpz: CTC log probabilities, or None to disable CTC scoring
        :param Namespace recog_args: decoding options (beam_size, penalty, ctc_weight, lm_weight, ...)
        :param char_list: list of output characters, used for logging hypotheses
        :param rnnlm: optional RNN language model for shallow fusion
        :param fstlm: optional FST language model for shallow fusion
        :return: n-best list of hypothesis dicts, sorted by score
        '''
        logging.info('input lengths: ' + str(h.size(0)))
        # initialization
        c_list = [self.zero_state(h.unsqueeze(0))]
        z_list = [self.zero_state(h.unsqueeze(0))]
        for l in six.moves.range(1, self.dlayers):
            c_list.append(self.zero_state(h.unsqueeze(0)))
            z_list.append(self.zero_state(h.unsqueeze(0)))
        a = None
        self.att.reset()  # reset pre-computation of h

        # search parameters
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # prepare sos
        y = self.sos
        vy = h.new_zeros(1).long()

        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
        minlen = int(recog_args.minlenratio * h.size(0))
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list,
                   'z_prev': z_list, 'a_prev': a, 'rnnlm_prev': None}
        else:
            hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list, 'z_prev': z_list, 'a_prev': a}
        if fstlm is not None:
            hyp['fstlm_prev'] = None
            
        if lpz is not None:
            ctc_prefix_score = CTCPrefixScore(lpz.cpu().numpy(), 0, self.eos, np)
            hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
            hyp['ctc_score_prev'] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []
        
        rnnlm_state_prev = None    
        for i in six.moves.range(maxlen):
            logging.debug('position ' + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy.unsqueeze(1)
                vy[0] = hyp['yseq'][i]
                ey = self.embed(vy)           # utt list (1) x zdim
                ey.unsqueeze(0)
                att_c, att_w = self.att(h.unsqueeze(0), [h.size(0)], hyp['z_prev'][0], hyp['a_prev'])
                ey = torch.cat((ey, att_c), dim=1)   # utt(1) x (zdim + hdim)
                z_list[0], c_list[0] = self.decoder[0](ey, (hyp['z_prev'][0], hyp['c_prev'][0]))
                for l in six.moves.range(1, self.dlayers):
                    z_list[l], c_list[l] = self.decoder[l](
                        z_list[l - 1], (hyp['z_prev'][l], hyp['c_prev'][l]))
                
                if self.fusion == 'deep_fusion' and self.rnnlm is not None:
                    rnnlm_state, lm_scores = self.rnnlm.predict(rnnlm_state_prev, vy)
                    lm_state = rnnlm_state['h2']        
                    gi = F.sigmoid(self.gate_linear(lm_state))
                    output_in = torch.cat((z_list[-1], gi * lm_state), dim=1)            
                    rnnlm_state_prev = rnnlm_state  
                elif self.fusion == 'cold_fusion' and self.rnnlm is not None:
                    rnnlm_state, lm_scores = self.rnnlm.predict(rnnlm_state_prev, vy)
                    lm_state = F.relu(self.lm_linear(lm_scores))       
                    gi = F.sigmoid(self.gate_linear(torch.cat((lm_state, z_list[-1]), dim=1)))
                    output_in = torch.cat((z_list[-1], gi * lm_state), dim=1)            
                    rnnlm_state_prev = rnnlm_state                                       
                else:
                    output_in = z_list[-1]
                
                # get nbest local scores and their ids
                local_att_scores = F.log_softmax(self.output(output_in), dim=1).data
                if fstlm:
                    '''local_best_scores, local_best_ids = torch.topk(local_att_scores, kenlm_beam, dim=1)
                    kenlm_state, kenlm_scores = kenlm.predict(hyp['kenlm_prev'], local_best_ids[0])                
                    local_scores = local_att_scores[:, local_best_ids[0]] + recog_args.lm_weight * torch.from_numpy(kenlm_scores)
                    local_best_scores, joint_best_ids = torch.topk(local_scores, beam, dim=1)
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]'''
                    fstlm_state, local_lm_scores = fstlm.predict(hyp['fstlm_prev'], vy)
                    local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores                    
                elif rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
                    local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
                else:
                    local_scores = local_att_scores

                if lpz is not None:
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1)
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                    local_scores = \
                        (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                        + ctc_weight * to_cuda(self, torch.from_numpy(ctc_scores - hyp['ctc_score_prev']))
                    if rnnlm:
                        local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                    elif fstlm:
                        local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                    local_best_scores, joint_best_ids = torch.topk(local_scores, beam, dim=1)
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)

                for j in six.moves.range(beam):
                    new_hyp = {}
                    # [:] is needed!
                    new_hyp['z_prev'] = z_list[:]
                    new_hyp['c_prev'] = c_list[:]
                    new_hyp['a_prev'] = att_w[:]
                    new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                    new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                    new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                    new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                    if rnnlm:
                        new_hyp['rnnlm_prev'] = rnnlm_state
                    if fstlm:
                        new_hyp['fstlm_prev'] = fstlm_state
                    if lpz is not None:
                        new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                        new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
            logging.debug(
                'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

            # add eos in the final loop so that there is at least one ended hypothesis
            if i == maxlen - 1:
                logging.info('adding <eos> in the last position in the loop')
                for hyp in hyps:
                    hyp['yseq'].append(self.eos)

            # add ended hypotheses to the final list, and remove them from the current hypotheses
            # (note: this can leave fewer than beam hypotheses alive)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info('end detected at %d', i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug('remaining hypotheses: ' + str(len(hyps)))
            else:
                logging.info('no hypothesis. Finish decoding.')
                break

            for hyp in hyps:
                logging.debug(
                    'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

            logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))

        nbest_hyps = sorted(
            ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), recog_args.nbest)]
        logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
        logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))

        # remove sos
        return nbest_hyps
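The CTC pre-pruning above uses torch.topk twice: once to restrict the expensive CTC prefix scoring to the ctc_beam labels the attention decoder ranks highest, and once to pick the beam best joint scores. A stripped-down sketch with stand-in tensors (the real CTCPrefixScore is omitted):

import torch

vocab_size, beam, ctc_beam, ctc_weight = 1000, 5, 20, 0.3   # hypothetical sizes

att_scores = torch.randn(1, vocab_size).log_softmax(dim=1)  # attention decoder log-probs
pre_scores, pre_ids = torch.topk(att_scores, ctc_beam, dim=1)
ctc_scores = torch.randn(1, ctc_beam)                       # stand-in for CTC prefix scores
joint = (1.0 - ctc_weight) * att_scores[:, pre_ids[0]] + ctc_weight * ctc_scores
best_scores, joint_ids = torch.topk(joint, beam, dim=1)
best_ids = pre_ids[:, joint_ids[0]]                         # map back to full-vocabulary ids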
Example #59
0
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()
    best_model_wts = model.state_dict()
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_loss = 0.0
            running_loss_path = 0.0
            running_corrects1 = 0
            running_corrects3 = 0

            # Iterate over data.
            for i, data in enumerate(dataloaders[phase]):
                # get the inputs
                inputs, labels = data

                # wrap them in Variable
                if use_gpu:
                    inputs = Variable(inputs.cuda())
                    labels = Variable(labels.cuda())
                else:
                    inputs, labels = Variable(inputs), Variable(labels)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                outputs = model(inputs)

                _, preds = torch.topk(outputs.data, 3, 1)
                loss = criterion(outputs, labels)
                pred = preds.t()
                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # statistics
                running_loss += loss.data[0]
                running_loss_path += loss.data[0]
                correct = pred.eq(labels.data.view(1, -1).expand_as(pred))
                running_corrects1 += torch.sum(correct[:1].view(-1).float())
                running_corrects3 += torch.sum(correct[:3].view(-1).float())

                if i % 200 == 199:
                    print('{} iter: {} Loss_path: {:.4f}'.format(phase, i + 1, running_loss_path))
                    running_loss_path = 0.0

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc1 = running_corrects1 / dataset_sizes[phase]
            epoch_acc3 = running_corrects3 / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc1: {:.4f} Acc3: {:.4f}'.format(
                phase, epoch_loss, epoch_acc1, epoch_acc3))

            # deep copy the model
            if phase == 'val' and epoch_acc3 > best_acc:
                best_acc = epoch_acc3
                best_model_wts = model.state_dict()


        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    torch.save(best_model_wts, 'best_model_wts_vgg19_bn.pkl')
    torch.save(best_acc, 'best_acc_vgg19_bn.pkl')
    return model
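The top-1 / top-3 accuracy bookkeeping in the training loop above can be factored into a small helper; a sketch, assuming (batch, num_classes) outputs and integer class labels:

import torch

def topk_accuracy(outputs, labels, ks=(1, 3)):
    # outputs: (batch, num_classes) scores, labels: (batch,) class indices
    _, preds = torch.topk(outputs, max(ks), dim=1)
    preds = preds.t()                                        # (max_k, batch)
    correct = preds.eq(labels.view(1, -1).expand_as(preds))  # (max_k, batch) hit matrix
    return [correct[:k].reshape(-1).float().sum().item() / labels.size(0) for k in ks]

acc1, acc3 = topk_accuracy(torch.randn(8, 10), torch.randint(0, 10, (8,)))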
Example #60
0
    def beam_search(self, src_sents, decode_max_time_step, beam_size=5, to_word=True):
        """
        given a single (not batched) source sentence, perform beam search to find the n-best hypotheses
        :param src_sents: the tokenized source sentence to decode, wrapped in a singleton list
        :return: List[List[word]] (word ids if to_word is False), the top-k predicted sentences in the beam
        """
        src_sents_var = nn_utils.to_input_variable(src_sents, self.src_vocab,
                                                   cuda=self.cuda, training=False, append_boundary_sym=False)

        #TODO(junxian): check if src_sents_var(src_seq_length, embed_size) is ok
        src_encodings, (last_state, last_cell) = self.encode(src_sents_var, [len(src_sents[0])])
        # (1, query_len, hidden_size * 2)
        src_encodings = src_encodings.permute(1, 0, 2)
        src_encodings_att_linear = self.att_src_linear(src_encodings)
        h_tm1 = self.init_decoder_state(last_state, last_cell)

        # tensor constructors
        new_float_tensor = src_encodings.data.new
        if self.cuda:
            new_long_tensor = torch.cuda.LongTensor
        else:
            new_long_tensor = torch.LongTensor

        att_tm1 = Variable(torch.zeros(1, self.hidden_size), volatile=True)
        hyp_scores = Variable(torch.zeros(1), volatile=True)
        if self.cuda:
            att_tm1 = att_tm1.cuda()
            hyp_scores = hyp_scores.cuda()

        eos_id = self.tgt_vocab['</s>']
        bos_id = self.tgt_vocab['<s>']
        tgt_vocab_size = len(self.tgt_vocab)

        hypotheses = [[bos_id]]
        completed_hypotheses = []
        completed_hypothesis_scores = []

        t = 0
        while len(completed_hypotheses) < beam_size and t < decode_max_time_step:
            t += 1
            hyp_num = len(hypotheses)

            expanded_src_encodings = src_encodings.expand(hyp_num, src_encodings.size(1), src_encodings.size(2))
            expanded_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num, src_encodings_att_linear.size(1), src_encodings_att_linear.size(2))

            y_tm1 = Variable(new_long_tensor([hyp[-1] for hyp in hypotheses]), volatile=True)
            y_tm1_embed = self.tgt_embed(y_tm1)

            x = torch.cat([y_tm1_embed, att_tm1], 1)

            # h_t: (hyp_num, hidden_size)
            (h_t, cell_t), att_t, score_t = self.step(x, h_tm1,
                                                      expanded_src_encodings, expanded_src_encodings_att_linear,
                                                      src_sent_masks=None)

            p_t = F.log_softmax(score_t)

            live_hyp_num = beam_size - len(completed_hypotheses)
            new_hyp_scores = (hyp_scores.unsqueeze(1).expand_as(p_t) + p_t).view(-1)
            top_new_hyp_scores, top_new_hyp_pos = torch.topk(new_hyp_scores, k=live_hyp_num)
            prev_hyp_ids = top_new_hyp_pos / tgt_vocab_size
            word_ids = top_new_hyp_pos % tgt_vocab_size

            new_hypotheses = []

            live_hyp_ids = []
            new_hyp_scores = []
            for prev_hyp_id, word_id, new_hyp_score in zip(prev_hyp_ids.cpu().data, word_ids.cpu().data, top_new_hyp_scores.cpu().data):
                hyp_tgt_words = hypotheses[prev_hyp_id] + [word_id]
                if word_id == eos_id:
                    completed_hypotheses.append(hyp_tgt_words[1:-1])  # remove <s> and </s> in completed hypothesis
                    completed_hypothesis_scores.append(new_hyp_score)
                else:
                    new_hypotheses.append(hyp_tgt_words)
                    live_hyp_ids.append(prev_hyp_id)
                    new_hyp_scores.append(new_hyp_score)

            if len(completed_hypotheses) == beam_size:
                break

            live_hyp_ids = new_long_tensor(live_hyp_ids)
            h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
            att_tm1 = att_t[live_hyp_ids]

            hyp_scores = Variable(new_float_tensor(new_hyp_scores), volatile=True)  # new_hyp_scores[live_hyp_ids]
            hypotheses = new_hypotheses

        if len(completed_hypotheses) == 0:
            completed_hypotheses = [hypotheses[0][1:-1]]  # remove <s> and </s> in completed hypothesis
            completed_hypothesis_scores = [0.0]

        if to_word:
            for i, hyp in enumerate(completed_hypotheses):
                completed_hypotheses[i] = [self.tgt_vocab.id2word[w] for w in hyp]

        ranked_hypotheses = sorted(zip(completed_hypotheses, completed_hypothesis_scores), key=lambda x: x[1], reverse=True)

        return [hyp for hyp, score in ranked_hypotheses]
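The div/mod step above recovers, from each position in the flattened (hyp_num * vocab_size) score vector, which hypothesis is extended and which word is appended. A minimal sketch; note that on current PyTorch the floor division // should be used where this legacy code uses /:

import torch

beam_size, vocab_size = 3, 7                                  # hypothetical sizes
hyp_scores = torch.zeros(beam_size)                           # cumulative score per live hypothesis
step_log_probs = torch.randn(beam_size, vocab_size).log_softmax(dim=-1)

new_scores = (hyp_scores.unsqueeze(1) + step_log_probs).view(-1)
top_scores, top_pos = torch.topk(new_scores, k=beam_size)
prev_hyp_ids = top_pos // vocab_size                          # which hypothesis each candidate extends
word_ids = top_pos % vocab_size                               # which word gets appended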