Example #1
    def generateSampleBatch(
        self,
        model: typing.TypeVar("model.BertPreTrainedModel"),
        device: torch.device,
        input_ids: torch.LongTensor,
        input_features: torch.LongTensor,
        prediction_scores: torch.FloatTensor,
        position_ids: torch.LongTensor,
        is_live: bool,
    ) -> typing.Tuple[torch.LongTensor, torch.LongTensor, None]:
        """
    Get a batch of input ids and iteratively fill the holes and return a batch of samples.
    """
        batch_size, sequence_length = tuple(input_ids.shape)
        input_idxs = torch.arange(batch_size).to(device)
        sample_indices = torch.full((batch_size, sequence_length),
                                    self.tokenizer.padToken).to(device)

        res_idx = 0
        samples = torch.zeros_like(input_ids)

        new_holes = self.step_batch(input_ids, input_idxs, sample_indices,
                                    None, prediction_scores, device)
        open_holes = torch.where(new_holes)[0]
        closed_holes = torch.where(~new_holes)[0]

        samples[res_idx:res_idx + len(closed_holes)] = input_ids[closed_holes]
        res_idx += len(closed_holes)
        input_ids = torch.index_select(input_ids, 0, open_holes.to(device))
        attention_mask = (input_ids != self.tokenizer.padToken)

        while torch.any(new_holes):

            prediction_scores, _, _, _ = model.get_output(
                input_ids,
                attention_mask,
                position_ids[:len(input_ids)],
                input_features,
            )

            new_holes = self.step_batch(input_ids, input_idxs, sample_indices,
                                        None, prediction_scores, device)
            open_holes = torch.where(new_holes)[0]
            closed_holes = torch.where(~new_holes)[0]

            samples[res_idx:res_idx +
                    len(closed_holes)] = input_ids[closed_holes]
            res_idx += len(closed_holes)
            input_ids = torch.index_select(input_ids, 0, open_holes.to(device))
            attention_mask = (input_ids != self.tokenizer.padToken)

        return samples, sample_indices, None
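
The loop above relies on a small partitioning idiom: a boolean tensor marks which sequences still contain holes, finished rows are copied into the result buffer, and the working batch shrinks to the open rows. A minimal, self-contained sketch of that idiom, with a hypothetical pad id PAD standing in for self.tokenizer.padToken:

    import torch

    PAD = 0  # hypothetical pad-token id

    batch = torch.tensor([[5, 6, PAD], [7, 8, 9], [1, 2, 3]])
    still_open = torch.tensor([True, False, True])

    open_rows = torch.where(still_open)[0]
    closed_rows = torch.where(~still_open)[0]

    results = torch.zeros_like(batch)
    results[:len(closed_rows)] = batch[closed_rows]  # finished sequences
    batch = torch.index_select(batch, 0, open_rows)  # keep sampling these
    attention_mask = batch != PAD                    # recompute the mask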
Example #2
    def StepHoleSeq(
        self,
        batch: torch.LongTensor,
        batch_idxs: torch.LongTensor,
        sample_indices: torch.LongTensor,
        indices_lengths: torch.LongTensor,
        prediction_scores: torch.FloatTensor,
        device,
    ) -> torch.BoolTensor:
        """
    Applies sample step with hole predictions to input batch.

    !!!!!!WARNING!!!!!
    This function works appropriately ONLY for 1 [HOLE] per sequence.
    If more HOLES existed, then further operations would be needed to
    re-calculate the proceeding hole indices, which would lead to unnecessary
    operations. Removing this feature keeps things faster for 1 hole scenario.
    """
        endTokens = self.tokenizer.metaTokenValues
        # Boolean tensor marking which sequences still contain a hole.
        new_hole = torch.zeros(len(batch), dtype=torch.bool)

        # [seq_idx, hole_idx] of batch.
        idxs, targets = torch.where(batch == self.tokenizer.holeToken)
        # Predictions for these indices.
        predictions = self.argmax(prediction_scores[(idxs, targets)])

        for seq_idx, el_idx in zip(idxs, targets):
            # seq_idx -> indices within the batch
            # el_idx  -> element index within a sequence
            if int(predictions[seq_idx]) in endTokens:
                # Close hole, shift left one position, add pad to the end.
                batch[seq_idx] = torch.cat(
                    (batch[seq_idx][:el_idx], batch[seq_idx][el_idx + 1:],
                     torch.LongTensor([self.tokenizer.padToken]).to(device)),
                    0)
            elif int(batch[seq_idx][-1]) != self.tokenizer.padToken or (
                    indices_lengths is not None and indices_lengths[seq_idx] >=
                    FLAGS.sample_indices_limit - 1):
                # No pads remaining to the right, or the index limit is reached:
                # replace the hole with the prediction without inserting a new hole.
                batch[seq_idx][el_idx] = predictions[seq_idx]
            else:
                # Replace with prediction and keep hole.
                batch[seq_idx] = torch.cat((batch[seq_idx][:el_idx],
                                            predictions[seq_idx].unsqueeze(0),
                                            batch[seq_idx][el_idx:][:-1]), 0)
                new_hole[seq_idx] = True
            q_idx = batch_idxs[seq_idx]
            sample_indices[q_idx][el_idx] = predictions[seq_idx]
            if indices_lengths is not None:
                indices_lengths[seq_idx] += 1

        return new_hole
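
To make the three branches above concrete, here is a toy walk-through under assumed token ids HOLE = 1, PAD = 0 and a predicted token 9 (illustrative values, not the codebase's tokenizer):

    import torch

    HOLE, PAD, PRED = 1, 0, 9
    seq = torch.tensor([4, HOLE, 5, PAD])
    el_idx = 1  # position of the hole

    # Case 1: prediction is an end token -> delete the hole, pad the right.
    closed = torch.cat((seq[:el_idx], seq[el_idx + 1:], torch.tensor([PAD])))
    # -> tensor([4, 5, 0, 0])

    # Case 2: no pads remain -> overwrite the hole in place, hole is gone.
    in_place = seq.clone()
    in_place[el_idx] = PRED
    # -> tensor([4, 9, 5, 0])

    # Case 3: room remains -> insert the prediction, keep the hole after it.
    kept = torch.cat((seq[:el_idx], torch.tensor([PRED]), seq[el_idx:][:-1]))
    # -> tensor([4, 9, 1, 5])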
Example #3
    def StepMaskSeq(
        self,
        batch: torch.LongTensor,
        batch_idxs: torch.LongTensor,
        sample_indices: torch.LongTensor,
        indices_lengths: torch.LongTensor,
        prediction_scores: torch.FloatTensor,
        device,
    ) -> torch.BoolTensor:
        """
    Applies sample step with mask predictions to input batch.
    """
        # [seq_idx, mask_idx] of batch.
        idxs, targets = torch.where(batch == self.tokenizer.maskToken)
        # Predictions for these indices.
        predictions = self.argmax(prediction_scores[(idxs, targets)])
        for p_idx, (seq_idx, el_idx) in enumerate(
                zip(idxs.flip(dims=(0, )), targets.flip(dims=(0, )))):
            # seq_idx -> indices within the batch
            # el_idx  -> element index within a sequence
            if int(predictions[idxs.size(0) - 1 -
                               p_idx]) in self.tokenizer.metaTokenValues:
                # Meta token predicted: drop the [MASK], shift left one position, pad the end.
                batch[seq_idx] = torch.cat(
                    (batch[seq_idx][:el_idx], batch[seq_idx][el_idx + 1:],
                     torch.LongTensor([self.tokenizer.padToken]).to(device)),
                    0)
            else:
                # Simply replace the [MASK] with the single predicted token.
                batch[seq_idx][el_idx] = predictions[idxs.size(0) - 1 - p_idx]
            q_idx = batch_idxs[seq_idx]
            sample_indices[q_idx][el_idx] = predictions[idxs.size(0) - 1 -
                                                        p_idx]
            if indices_lengths is not None:
                indices_lengths[seq_idx] += 1

        return torch.zeros(len(batch), dtype=torch.bool)
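
The flip(dims=(0,)) is what allows several [MASK] tokens per sequence here: closing a mask shifts every later element one slot left, so the positions returned by torch.where stay valid only if edits proceed right-to-left, and predictions[idxs.size(0) - 1 - p_idx] maps the reversed loop position back to the original prediction order. A toy demonstration under assumed token ids MASK = 2, PAD = 0:

    import torch

    MASK, PAD = 2, 0
    seq = torch.tensor([MASK, 7, MASK, PAD])
    positions = torch.where(seq == MASK)[0]   # tensor([0, 2])
    for el_idx in positions.flip(dims=(0,)):  # process index 2, then index 0
        seq = torch.cat((seq[:el_idx], seq[el_idx + 1:], torch.tensor([PAD])))
    # Both masks removed, no index invalidated: tensor([7, 0, 0, 0])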
Example #4
    def generateSampleWorkload(
        self,
        model: typing.TypeVar("model.BertPreTrainedModel"),
        device: torch.device,
        workload_input_ids: torch.LongTensor,
        workload_attention_mask: torch.LongTensor,
        workload_input_features: torch.LongTensor,
        prediction_scores: torch.FloatTensor,
        position_ids: torch.LongTensor,
        bar: 'tqdm.tqdm' = None,
    ) -> typing.Tuple[torch.LongTensor, torch.LongTensor]:
        """
    This function receives a full workload of input ids to be sampled.
    Heavy optimisations are perfmormed to keep the GPU busy at all times.

    The workload is streamed online and when a sequence is finished it is replaced
    with a new one from the workload queue.

    Returns a fullworkload of sampled instances.
    """
        # [workload_size x batch_size x sequence_length]
        wload_size, batch_size, sequence_length = tuple(
            workload_input_ids.shape)
        # Number of sequences
        nseq = wload_size * batch_size
        # Iteration idx of workload
        w_idx = batch_size

        # Get current input_ids - attention mask.
        input_ids = workload_input_ids[0]
        input_idxs = torch.arange(batch_size).to(device)
        attention_mask = workload_attention_mask[0]
        if workload_input_features is not None:
            input_features = workload_input_features[0]
        else:
            input_features = None
        # sample indices array that will be returned.
        sample_indices = torch.full((nseq, sequence_length),
                                    self.tokenizer.padToken).to(device)

        if FLAGS.sample_indices_limit is not None:
            sidx_length = torch.full((batch_size, 1), 0).to(device)

        # Workload of input_ids and attention_mask pairs.
        # queue input_idxs ensure direct ordering from inputs -> outputs.
        queue_input_ids = torch.reshape(workload_input_ids,
                                        (1, nseq, sequence_length)).squeeze()
        queue_input_idxs = torch.arange(nseq).to(device)
        queue_attention_mask = torch.reshape(
            workload_attention_mask, (1, nseq, sequence_length)).squeeze()
        if workload_input_features is not None:
            queue_input_features = torch.reshape(
                workload_input_features, (1, nseq, sequence_length)).squeeze()

        #! This is the return queue [nseq x sequence_length].
        queue = torch.zeros(tuple(queue_input_ids.shape), dtype=torch.long).to(device)

        new_holes = self.step_batch(
            input_ids, input_idxs, sample_indices,
            sidx_length if FLAGS.sample_indices_limit else None,
            prediction_scores, device)
        open_holes = torch.where(new_holes)[0].to(device)
        closed_holes = torch.where(~new_holes)[0]

        for i in closed_holes:
            queue[input_idxs[i]] = input_ids[i]
            if bar:
                bar.update(1)

        input_ids = torch.index_select(input_ids, 0, open_holes)
        input_idxs = torch.index_select(input_idxs, 0, open_holes)
        attention_mask = (input_ids != self.tokenizer.padToken)
        if FLAGS.sample_indices_limit:
            sidx_length = torch.index_select(sidx_length, 0, open_holes)

        res = batch_size - len(input_ids)
        if res > 0:
            input_ids = torch.cat(
                (input_ids, queue_input_ids[w_idx:w_idx + res]), 0)
            input_idxs = torch.cat(
                (input_idxs, queue_input_idxs[w_idx:w_idx + res]), 0)
            attention_mask = torch.cat(
                (attention_mask, queue_attention_mask[w_idx:w_idx + res]), 0)
            if input_features is not None:
                input_features = torch.cat(
                    (input_features, queue_input_features[w_idx:w_idx + res]),
                    0)
            if FLAGS.sample_indices_limit:
                sidx_length = torch.cat((sidx_length, torch.full(
                    (res, 1), 0).to(device)), 0)
            w_idx += res

        while w_idx < nseq or torch.any(new_holes):

            prediction_scores, _, _, _ = model.get_output(
                input_ids, attention_mask, position_ids[:len(input_ids)],
                input_features)
            # Array of new hole existence per seq idx
            new_holes = self.step_batch(
                input_ids, input_idxs, sample_indices,
                sidx_length if FLAGS.sample_indices_limit else None,
                prediction_scores, device)
            # Sequences that still contain a hole.
            open_holes = torch.where(new_holes)[0].to(device)
            # These are done.
            closed_holes = torch.where(~new_holes)[0]

            # Add to return queue those that have finished.
            for i in closed_holes:
                queue[input_idxs[i]] = input_ids[i]
                if bar:
                    bar.update(1)

            input_ids = torch.index_select(input_ids, 0, open_holes)
            input_idxs = torch.index_select(input_idxs, 0, open_holes)
            attention_mask = (input_ids != self.tokenizer.padToken)
            if FLAGS.sample_indices_limit:
                sidx_length = torch.index_select(sidx_length, 0, open_holes)

            res = batch_size - len(input_ids)
            if res > 0:
                input_ids = torch.cat(
                    (input_ids, queue_input_ids[w_idx:w_idx + res]), 0)
                input_idxs = torch.cat(
                    (input_idxs, queue_input_idxs[w_idx:w_idx + res]), 0)
                attention_mask = torch.cat(
                    (attention_mask, queue_attention_mask[w_idx:w_idx + res]),
                    0)
                if input_features is not None:
                    input_features = torch.cat(
                        (input_features,
                         queue_input_features[w_idx:w_idx + res]), 0)
                if FLAGS.sample_indices_limit:
                    sidx_length = torch.cat(
                        (sidx_length, torch.full((res, 1), 0).to(device)), 0)
                w_idx += res
        return queue, sample_indices
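
The core of the streaming optimisation above is the batch top-up: whenever sequences finish, the working batch is refilled from the workload queue so the model always runs on a full batch. A minimal sketch with toy sizes (pending plays the role of queue_input_ids; the sizes and names are illustrative):

    import torch

    batch_size, seq_len = 4, 3
    pending = torch.arange(8 * seq_len).reshape(8, seq_len)  # queued sequences
    w_idx = batch_size                                       # next unread row
    work = pending[:batch_size].clone()

    still_open = torch.tensor([True, False, True, False])
    work = work[torch.where(still_open)[0]]                  # drop finished rows

    res = batch_size - len(work)                             # free slots to refill
    if res > 0:
        work = torch.cat((work, pending[w_idx:w_idx + res]), 0)
        w_idx += res                                         # advance the queue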
Example #5
    def StepTrainingSeq(
        self,
        seq: torch.LongTensor,
        prediction_scores: torch.FloatTensor,
    ) -> typing.Tuple[bool, torch.LongTensor, torch.BoolTensor]:
        """
    Applies step predictions to input sequence.
    Specifically optimized for training; does not compute sample indices for speed-up.
    """
        seq_length = tuple(seq.shape)[0]
        allowed_incr = (seq_length -
                        int(torch.where(seq == self.tokenizer.padToken)[0][0])
                        if self.tokenizer.padToken in seq else 0)

        endTokens = self.tokenizer.metaTokenValues
        closed_hole = np.zeros(seq_length, dtype=bool)
        new_hole = np.zeros(seq_length, dtype=bool)
        temp_seq = seq.numpy().copy()

        for target_idx in torch.where((seq == self.tokenizer.holeToken)
                                      | (seq == self.tokenizer.maskToken))[0]:
            idx = int(target_idx)
            prediction = int(self.argmax(prediction_scores[target_idx]))
            is_hole = temp_seq[idx] == self.tokenizer.holeToken

            if prediction in endTokens:
                # The model predicted a token that closes the hole.
                closed_hole[idx] = True
                continue

            # We replace the hole with a prediction
            temp_seq[idx] = prediction
            rem_adds = allowed_incr + np.sum(closed_hole) - np.sum(new_hole)
            if is_hole and rem_adds:
                # if this was a hole and we have more empty space, reinsert the hole
                new_hole[idx] = True

        new_seq = np.full(seq_length, self.tokenizer.padToken, dtype=np.int64)
        new_idx = 0
        for idx, t in enumerate(temp_seq):
            if closed_hole[idx]:
                continue
            try:
                new_seq[new_idx] = t
            except IndexError:
                l.logger().info("seq: {}".format(
                    self.tokenizer.tokensToString(
                        [x for x in seq.cpu().numpy()])))
                l.logger().info("temp_seq {}".format(
                    self.tokenizer.tokensToString([x for x in temp_seq])))
                l.logger().info("pred idx: {}".format(
                    torch.where((seq == self.tokenizer.holeToken)
                                | (seq == self.tokenizer.maskToken))[0]))
                l.logger().info("pred_toks {}".format(
                    self.tokenizer.tokensToString([
                        int(self.argmax(prediction_scores[idx])) for idx in
                        torch.where((seq == self.tokenizer.holeToken)
                                    | (seq == self.tokenizer.maskToken))[0]
                    ])))
                l.logger().info("allowed_incr: {}".format(allowed_incr))
                l.logger().info("new_hole: {}".format(new_hole))
                l.logger().info("closed_hole: {}".format(closed_hole))
            new_idx += 1
            if new_hole[idx]:
                try:
                    new_seq[new_idx] = self.tokenizer.holeToken
                except IndexError:
                    l.logger().warn("seq: {}".format(
                        self.tokenizer.tokensToString(
                            [x for x in seq.cpu().numpy()])))
                    l.logger().warn("temp_seq {}".format(
                        self.tokenizer.tokensToString([x for x in temp_seq])))
                    l.logger().warn("pred idx: {}".format(
                        torch.where((seq == self.tokenizer.holeToken)
                                    | (seq == self.tokenizer.maskToken))[0]))
                    l.logger().warn("pred_toks {}".format(
                        self.tokenizer.tokensToString([
                            int(self.argmax(prediction_scores[idx])) for idx in
                            torch.where((seq == self.tokenizer.holeToken)
                                        | (seq == self.tokenizer.maskToken))[0]
                        ])))
                    l.logger().warn("allowed_incr: {}".format(allowed_incr))
                    l.logger().warn("new_hole: {}".format(new_hole))
                    l.logger().warn("closed_hole: {}".format(closed_hole))
                new_idx += 1
            if new_idx >= seq_length:
                break

        new_seq = torch.LongTensor([new_seq])
        attention_mask = (new_seq != self.tokenizer.padToken)
        return np.any(new_hole), new_seq, attention_mask
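
To make the rebuild pass concrete, here is a toy version of the final loop: tokens whose holes closed are dropped, a kept hole is re-inserted right after its prediction, and the result is right-padded to the fixed length (the PAD and HOLE ids are assumptions):

    import numpy as np

    PAD, HOLE = 0, 1
    temp_seq = np.array([4, 9, 5, 7])       # predictions already written in
    closed_hole = np.array([False, False, True, False])
    new_hole = np.array([False, True, False, False])
    seq_length = len(temp_seq)

    new_seq = np.full(seq_length, PAD, dtype=np.int64)
    new_idx = 0
    for idx, tok in enumerate(temp_seq):
        if closed_hole[idx]:
            continue                        # this token closed its hole: drop it
        new_seq[new_idx] = tok
        new_idx += 1
        if new_hole[idx] and new_idx < seq_length:
            new_seq[new_idx] = HOLE         # reopen the hole after the prediction
            new_idx += 1
        if new_idx >= seq_length:
            break
    # new_seq -> array([4, 9, 1, 7])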