Example #1
class Normal(PlumObject):

    mean = HP(type=props.REAL)
    std = HP(type=props.POSITIVE)

    def __call__(self, tensor):
        torch.nn.init.normal_(tensor, mean=self.mean, std=self.std)
Example #2
class LaptopsSearchLogger(PlumObject):

    file_prefix = HP(type=props.STRING)
    search_fields = HP()
    input_fields = HP(required=False)
    reference_fields = HP(required=False)

    def __pluminit__(self):
        self._epoch = 0
        self._log_dir = None
        self._file = None
        self._fp = None

    def set_log_dir(self, log_dir):
        self._log_dir = Path(log_dir)
        self._log_dir.mkdir(exist_ok=True, parents=True)

    def __call__(self, forward_state, batch):
        
        search = resolve_getters(self.search_fields, forward_state)
        inputs = resolve_getters(self.input_fields, batch)
        references = resolve_getters(self.reference_fields, batch)
 
        for i, output in enumerate(search.output()):
            if inputs:
                print("inputs:", file=self._fp)
                #print(inputs[i], file=self._fp)
                print(batch["mr"][i], file=self._fp)
            if references:
                print("references:", file=self._fp)
                if isinstance(references[i], (list, tuple)):
                    print("\n".join(references[i]), file=self._fp)
                else:
                    print(references[i], file=self._fp)
            print("hypothesis:", file=self._fp)
            print(preproc.lexicalize(" ".join(output), inputs[i]), 
                  file=self._fp)
            print(file=self._fp)
            
            
    def next_epoch(self):
        
        self.close()

        self._epoch += 1
        self._file = self._log_dir / "{}.{}.log".format(
            self.file_prefix, self._epoch)
        self._fp = self._file.open("w")
         
    def close(self):
        if self._fp:
            self._fp.close()
Example #3
class ThresholdFeature(PlumObject):

    thresholds = HP()

    def __len__(self):
        return len(self.thresholds) + 1

    def __call__(self, value):
        if not isinstance(value, (int, float)):
            raise TypeError("Expecting numerical values, int or float.")
        # this should be a binary search but I'm tired.
        bin = 0

        while bin != len(self.thresholds) and value > self.thresholds[bin]:
            bin += 1

        return bin
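
As the comment in __call__ notes, the linear scan can be replaced by a binary search. A minimal sketch using the standard library's bisect module, assuming thresholds is sorted in ascending order (the helper name is made up):

import bisect

def threshold_bin(thresholds, value):
    # Same result as the linear scan above: index of the first threshold
    # that value does not exceed, or len(thresholds) if it exceeds them all.
    if not isinstance(value, (int, float)):
        raise TypeError("Expecting numerical values, int or float.")
    return bisect.bisect_left(thresholds, value)

assert threshold_bin([1, 5, 10], 7) == 2   # 7 falls in the third bin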
Example #4
class S2SEvaluator(PlumObject):

    batches = HP()
    searches = HP(default={})
    metrics = HP(default={})
    loggers = HP(default={})
    loss_function = HP(required=False)
    checkpoint = HP(required=False)


    def _get_default_checkpoint(self, env):
        # Fall back to the last listed checkpoint; None if there are none.
        ckpt = None
        for ckpt, md in env["checkpoints"].items():
            if md.get("default", False):
                return ckpt
        return ckpt

    def run(self, env, verbose=False):
        if self.checkpoint is None:
            ckpt = self._get_default_checkpoint(env)
        else:
            ckpt = self.checkpoint
        if ckpt is None:
            raise RuntimeError("No checkpoints found!")
        
        ckpt_path = env["checkpoints"][ckpt]["path"]
        if verbose:
            print("Reading checkpoint from {}".format(ckpt_path))
        model = plum.load(ckpt_path).eval()
        self.preflight_checks(model, env, verbose=verbose)

        if self.loss_function is not None:
            self.loss_function.reset()

        if self.metrics is not None:
            self.metrics.reset()

        self.reset_loggers(self.loggers)
        num_batches = len(self.batches)

        for step, batch in enumerate(self.batches, 1):
           
            forward_state = model(batch)
            if self.loss_function is not None:
                self.loss_function(forward_state, batch)
            if self.metrics is not None:
                self.metrics(forward_state, batch)
            self.apply_loggers(forward_state, batch, self.loggers)

            print("eval: {}/{} loss={:7.6f}".format(
                step, num_batches, self.loss_function.scalar_result()), 
                end="\r" if step < num_batches else "\n", flush=True)
        if self.metrics is not None:
            print(self.metrics.pretty_result())
        result = {
            "loss": {
                "combined": self.loss_function.scalar_result(),
                "detail": self.loss_function.compute(),
            },
            "metrics": self.valid_metrics.compute(),
        }
        
        self.log_results(result)
        print()

        self.close_loggers()

           
 
    def preflight_checks(self, model, env, verbose=False):
        if env["gpu"] > -1:
            if verbose:
                print("Moving model to device: {}".format(env["gpu"]))
            model.cuda(env["gpu"])
            if verbose:
                print("Moving batches to device: {}".format(env["gpu"]))
            self.batches.gpu = env["gpu"]

        model.search_algos.update(self.searches)

        log_dir = env["proj_dir"] / "logging.valid" 
        if verbose:
            print("Setting log directory: {}".format(log_dir))
        for logger in self.loggers.values():
            logger.set_log_dir(log_dir)

    def apply_loggers(self, forward_state, batch, loggers):
        for logger in loggers.values():
            logger(forward_state, batch)

    def reset_loggers(self, loggers):
        for logger in loggers.values():
            logger.next_epoch()

    def close_loggers(self):
        for logger in self.loggers.values():
            logger.close()
Example #5
class E2EPredict(PlumObject):

    checkpoint = HP(required=False)
    beam_size = HP(default=1, type=props.INTEGER)
    source_vocab = HP()
    target_vocab = HP()
    input_path = HP(type=props.EXISTING_PATH)
    filename = HP()
    delex = HP(default=False, type=props.BOOLEAN)

    FIELDS = [
        "eat_type", "near", "area", "family_friendly", "customer_rating",
        "price_range", "food", "name"
    ]

    FIELD_DICT = {
        "food": [
            'French', 'Japanese', 'Chinese', 'English', 'Indian', 'Fast food',
            'Italian'
        ],
        "family_friendly": ['no', 'yes'],
        "area": ['city centre', 'riverside'],
        "near": [
            'Café Adriatic',
            'Café Sicilia',
            'Yippee Noodle Bar',
            'Café Brazil',
            'Raja Indian Cuisine',
            'Ranch',
            'Clare Hall',
            'The Bakers',
            'The Portland Arms',
            'The Sorrento',
            'All Bar One',
            'Avalon',
            'Crowne Plaza Hotel',
            'The Six Bells',
            'Rainbow Vegetarian Café',
            'Express by Holiday Inn',
            'The Rice Boat',
            'Burger King',
            'Café Rouge',
        ],
        "eat_type": ['coffee shop', 'pub', 'restaurant'],
        "customer_rating":
        ['3 out of 5', '5 out of 5', 'high', 'average', 'low', '1 out of 5'],
        "price_range": [
            'more than £30', 'high', '£20-25', 'cheap', 'less than £20',
            'moderate'
        ],
        "name": [
            'Cocum',
            'Midsummer House',
            'The Golden Curry',
            'The Vaults',
            'The Cricketers',
            'The Phoenix',
            'The Dumpling Tree',
            'Bibimbap House',
            'The Golden Palace',
            'Wildwood',
            'The Eagle',
            'Taste of Cambridge',
            'Clowns',
            'Strada',
            'The Mill',
            'The Waterman',
            'Green Man',
            'Browns Cambridge',
            'Cotto',
            'The Olive Grove',
            'Giraffe',
            'Zizzi',
            'Alimentum',
            'The Punter',
            'Aromi',
            'The Rice Boat',
            'Fitzbillies',
            'Loch Fyne',
            'The Cambridge Blue',
            'The Twenty Two',
            'Travellers Rest Beefeater',
            'Blue Spice',
            'The Plough',
            'The Wrestlers',
        ],
    }

    def run(self, env, verbose=False):

        output_path = env["proj_dir"] / "output" / self.filename
        output_path.parent.mkdir(exist_ok=True, parents=True)

        if self.checkpoint is None:
            ckpt = self._get_default_checkpoint(env)
        else:
            ckpt = self.checkpoint
        if ckpt is None:
            raise RuntimeError("No checkpoints found!")

        ckpt_path = env["checkpoints"][ckpt]["path"]
        if verbose:
            print("Loading model from {}".format(ckpt_path))
        model = plum.load(ckpt_path).eval()
        if env["gpu"] > -1:
            model.cuda(env["gpu"])
        self._gpu = env["gpu"]

        with open(self.input_path, 'r') as fp, \
                open(output_path, "w") as out_fp:
            for line in fp:
                labels = json.loads(line)["labels"]

                tokens, text = self._get_outputs(model, labels)
                data = json.dumps({
                    "labels": labels,
                    "tokens": tokens,
                    "text": text,
                })
                print(data, file=out_fp)

    def _get_outputs(self, model, labels):
        batch = self._batch_labels([labels])
        state = model.encode(batch)
        if self.beam_size > 1:
            search = plum.seq2seq.search.BeamSearch(max_steps=100,
                                                    beam_size=self.beam_size,
                                                    vocab=self.target_vocab)
        else:
            search = plum.seq2seq.search.GreedySearch(max_steps=100,
                                                      vocab=self.target_vocab)
        search(model.decoder, state)
        outputs = search.output()
        raw_tokens = outputs[0][:-1]
        postedited_output = postedit.detokenize(raw_tokens)
        postedited_output = postedit.lexicalize(postedited_output, labels)

        return raw_tokens, postedited_output

    def _labels2input(self, labels):
        inputs = [self.source_vocab.start_token]
        for field in self.FIELDS:
            value = labels.get(field, "N/A").replace(" ", "_")
            inputs.append(field.replace("_", "").upper() + "_" + value)
        if self.delex:
            if inputs[2] != "NEAR_N/A":
                inputs[2] = "NEAR_PRESENT"
            inputs.pop(-1)

        inputs.append(self.source_vocab.stop_token)
        return inputs

    def _batch_labels(self, labels_batch):
        input_tokens = torch.LongTensor(
            [[self.source_vocab[tok] for tok in self._labels2input(labels)]
             for labels in labels_batch])
        length = input_tokens.size(1)
        inputs = Variable(input_tokens.t(),
                          lengths=torch.LongTensor([length] *
                                                   len(labels_batch)),
                          length_dim=0,
                          batch_dim=1,
                          pad_value=-1)
        if self._gpu > -1:
            inputs = inputs.cuda(self._gpu)
        return {"source_inputs": inputs}

    def _get_default_checkpoint(self, env):
        # Fall back to the last listed checkpoint; None if there are none.
        ckpt = None
        for ckpt, md in env["checkpoints"].items():
            if md.get("default", False):
                return ckpt
        return ckpt
Example #6
class SearchOutputLogger(PlumObject):

    file_prefix = HP(type=props.STRING)
    search_fields = HP()
    input_fields = HP(required=False)
    reference_fields = HP(required=False)

    def __pluminit__(self):
        self._epoch = 0
        self._log_dir = None
        self._file = None
        self._fp = None

    def set_log_dir(self, log_dir):
        self._log_dir = Path(log_dir)
        self._log_dir.mkdir(exist_ok=True, parents=True)

    def _apply_fields(self, item, fields):
        if not isinstance(fields, (list, tuple)):
            fields = [fields]
        for field in fields:
            if hasattr(field, "__call__"):
                item = field(item)
            else:
                item = item[field]
        return item

    def __call__(self, forward_state, batch):

        search = self._apply_fields(forward_state, self.search_fields)
        if self.input_fields:
            inputs = self._apply_fields(batch, self.input_fields)
        else:
            inputs = None

        if self.reference_fields:
            references = self._apply_fields(batch, self.reference_fields)
        else:
            references = None

        for i, output in enumerate(search.output()):
            if inputs:
                print("inputs:", file=self._fp)
                print(inputs[i], file=self._fp)
            if references:
                print("references:", file=self._fp)
                print(references[i], file=self._fp)
            print("hypothesis:", file=self._fp)
            print(" ".join(output), file=self._fp)
            print(file=self._fp)

    def next_epoch(self):

        self.close()

        self._epoch += 1
        self._file = self._log_dir / "{}.{}.log".format(
            self.file_prefix, self._epoch)
        self._fp = self._file.open("w")

    def close(self):
        if self._fp:
            self._fp.close()
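
A short illustration of the getter convention shared by these loggers: each entry in search_fields / input_fields / reference_fields is either a dictionary key or a callable, and _apply_fields applies them left to right. A self-contained sketch with made-up data:

def apply_fields(item, fields):
    # Mirrors SearchOutputLogger._apply_fields: keys index into the item,
    # callables transform it, chained left to right.
    if not isinstance(fields, (list, tuple)):
        fields = [fields]
    for field in fields:
        item = field(item) if callable(field) else item[field]
    return item

batch = {"references": {"tokens": [["a", "fast", "laptop"]]}}
assert apply_fields(batch, ["references", "tokens"]) == [["a", "fast", "laptop"]]
assert apply_fields(batch, ["references", "tokens", len]) == 1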
Example #7
class BeamSearch(PlumObject):

    max_steps = HP(default=999999, type=props.INTEGER)
    beam_size = HP(default=4, type=props.INTEGER)
    vocab = HP()

    def __pluminit__(self):
        self.reset()

    def reset(self):
        self.is_finished = False
        self.steps = 0
        self._states = []
        self._outputs = []

    def init_state(self, batch_size, encoder_state):

        beam_state = {}

        n, bs, ds = encoder_state.size()
        assert batch_size == bs

        beam_state["decoder_state"] = encoder_state.unsqueeze(2)\
            .repeat(1, 1, self.beam_size, 1)\
            .view(n, batch_size * self.beam_size, ds)

        beam_state["output"] = Variable(
            torch.LongTensor(
                [self.vocab.start_index] * batch_size * self.beam_size)\
                .view(1, -1),
            lengths=torch.LongTensor([1] * batch_size * self.beam_size),
            length_dim=0, batch_dim=1, pad_value=self.vocab.pad_index)

        if str(encoder_state.device) != "cpu":
            beam_state["output"] = beam_state["output"].cuda(
                encoder_state.device)

        # Start the first beam of each batch with 0 log prob, and all others
        # with -inf.
        beam_state["accum_log_prob"] = (self._init_accum_log_probs(
            batch_size, encoder_state.device))

        # At the first time step no sequences have been terminated so this mask
        # is all 0s.
        beam_state["terminal_mask"] = (encoder_state.new().byte().new(
            1, batch_size * self.beam_size, 1).fill_(0))

        return beam_state

    def _init_accum_log_probs(self, batch_size, device):
        lp = torch.FloatTensor(1, batch_size, self.beam_size, 1)
        if "cuda" in str(device):
            lp = lp.cuda(device)
        lp.data.fill_(0)
        lp.data[:, :, 1:].fill_(float("-inf"))
        return lp.view(1, batch_size * self.beam_size, 1)

    def init_context(self, encoder_state):
        n, bs, ds = encoder_state["output"].size()
        beam_encoder_output = encoder_state["output"].repeat_batch_dim(
            self.beam_size)
        #.repeat(1, 1, self.beam_size, 1)\
        #.view(n, bs * self.beam_size, ds)
        return {"encoder_output": beam_encoder_output}

    def __call__(self, decoder, encoder_state, controls=None):
        self.reset()

        # TODO get batch size in a more reliable way. This will probably break
        # for cnn or transformer based models.
        batch_size = encoder_state["state"].size(1)
        search_state = self.init_state(batch_size, encoder_state["state"])
        search_context = self.init_context(encoder_state)
        active_items = search_state["decoder_state"].new(batch_size).byte() \
            .fill_(1)

        if controls is not None:
            controls = controls.repeat_batch_dim(self.beam_size)

        self._beam_scores = [list() for _ in range(batch_size)]
        self._num_complete = search_state["decoder_state"].new()\
            .long().new(batch_size).fill_(0)
        self._terminal_info = [list() for _ in range(batch_size)]

        # Perform search until we either trigger a termination condition for
        # each batch item or we reach the maximum number of search steps.
        while self.steps < self.max_steps and not self.is_finished:

            search_state = self.next_state(decoder, batch_size, search_state,
                                           search_context, active_items,
                                           controls)
            active_items = self.check_termination(search_state, active_items)
            self.is_finished = torch.all(~active_items)

            self._states.append(search_state)
            self.steps += 1

        # Finish the search by collecting final sequences, and other
        # stats.
        self._collect_search_states(active_items)
        self._incomplete_items = active_items
        self._is_finished = True

        return self

    def next_state(self, decoder, batch_size, prev_state, context,
                   active_items, controls):

        # Get next state from the decoder.
        next_state = decoder.next_state(prev_state, context, controls=controls)

        # Compute the top beam_size next outputs for each beam item.
        # topk_lps (1 x batch size x beam size x beam size)
        # candidate_outputs (1 x batch size x beam size x beam size)
        topk_lps, candidate_outputs = torch.topk(
            next_state["log_probs"].data \
                .view(1, batch_size, self.beam_size, -1),
            k=self.beam_size, dim=3)

        # If any sequence was completed last step, we should mask its log
        # prob so that we don't generate from the terminal token.
        # slp (1 x batch_size x beam size x 1)
        slp = prev_state["accum_log_prob"] \
            .masked_fill(prev_state["terminal_mask"], float("-inf")) \
            .view(1, batch_size, self.beam_size, 1)

        # Combine next step log probs with the previous sequences' cumulative
        # log probs, i.e.
        #     log P(y_t) = log P(y_<t) + log P(y_t)
        # candidate_log_probs (1 x batch size x beam size x beam size)
        candidate_log_probs = slp + topk_lps

        # Rerank and select the beam_size best options from the available
        # beam_size ** 2 candidates.
        # b_seq_lps (1 x (batch size * beam size) x 1)
        # b_scores (1 x (batch size * beam size) x 1)
        # b_output (1 x (batch size * beam size))
        # b_indices ((batch size * beam size))
        b_seq_lps, b_scores, b_output, b_indices = self._next_candidates(
            batch_size, candidate_log_probs, candidate_outputs)

        # TODO re-implement this behavior
        #next_state.stage_indexing("batch", b_indices)

        next_state = {
            "decoder_state": next_state["decoder_state"]\
                .index_select(1, b_indices),
            "output": b_output,
            "accum_log_prob": b_seq_lps,
            "beam_score": b_scores,
            "beam_indices": b_indices,
        }
        return next_state

    def _next_candidates(self, batch_size, log_probs, candidates):
        # TODO seq_lps should really be called cumulative log probs.

        # flat_beam_lps (batch size x (beam size ** 2))
        flat_beam_lps = log_probs.view(batch_size, -1)

        flat_beam_scores = flat_beam_lps / (self.steps + 1)

        # beam_seq_scores (batch size x beam size)
        # relative_indices (batch_size x beam size)
        beam_seq_scores, relative_indices = torch.topk(flat_beam_scores,
                                                       k=self.beam_size,
                                                       dim=1)

        # beam_seq_lps (batch size x beam size)
        beam_seq_lps = flat_beam_lps.gather(1, relative_indices)

        # TODO make these ahead of time.
        offset1 = (torch.arange(batch_size, device=beam_seq_lps.device) *
                   self.beam_size).view(batch_size, 1)

        offset2 = offset1 * self.beam_size

        beam_indexing = ((relative_indices // self.beam_size) + offset1) \
            .view(-1)

        # beam_seq_lps (1 x (batch_size * beam_size) x 1)
        beam_seq_lps = beam_seq_lps \
            .view(1, batch_size * self.beam_size, 1)

        # beam_seq_scores (1 x (batch_size * beam_size) x 1)
        beam_seq_scores = beam_seq_scores \
            .view(1, batch_size * self.beam_size, 1)

        # next_output (1 x (batch size * beam size))
        next_candidate_indices = (relative_indices + offset2).view(-1)
        next_output = Variable(
            candidates.view(-1)[next_candidate_indices].view(1, -1),
            lengths=candidates.new().long().new(batch_size * self.beam_size)\
                .fill_(1),
            length_dim=0, batch_dim=1, pad_value=self.vocab.pad_index)

        return beam_seq_lps, beam_seq_scores, next_output, beam_indexing

    def check_termination(self, next_state, active_items):

        # view as batch size x beam size
        next_output = next_state["output"].data \
            .view(-1, self.beam_size)
        batch_size = next_output.size(0)

        is_complete = next_output.eq(self.vocab.stop_index)
        complete_indices = np.where(is_complete.cpu().data.numpy())

        for batch, beam in zip(*complete_indices):
            if self._num_complete[batch] == self.beam_size:
                continue
            else:
                self._num_complete[batch] += 1

                # Store step and beam that finished so we can retrace it
                # later and recover arbitrary search state item.
                self._terminal_info[batch].append(
                    (self.steps, beam + batch * self.beam_size))

                IDX = batch * self.beam_size + beam
                self._beam_scores[batch].append(
                    next_state["beam_score"][0, IDX, 0].view(1))

        next_state["terminal_mask"] = (is_complete.view(
            1, batch_size * self.beam_size, 1))
        active_items = self._num_complete < self.beam_size

        return active_items

    def _collect_search_states(self, active_items):

        batch_size = active_items.size(0)

        last_state = self._states[-1]
        last_step = self.steps - 1
        for batch in range(batch_size):
            beam = 0
            while len(self._beam_scores[batch]) < self.beam_size:
                IDX = batch * self.beam_size + beam
                self._beam_scores[batch].append(
                    last_state["beam_score"][0, IDX, 0].view(1))
                self._terminal_info[batch].append(
                    (last_step, beam + batch * self.beam_size))
                beam += 1

        # TODO consider removing beam indices from state
        beam_indices = torch.stack(
            [state["beam_indices"] for state in self._states])

        self._beam_scores = torch.stack(
            [torch.cat(bs) for bs in self._beam_scores])

        lengths = self._states[0]["output"].new(
            [[step + 1 for step, beam in self._terminal_info[batch]]
             for batch in range(batch_size)])

        selector = self._states[0]["output"].new(batch_size, self.beam_size,
                                                 lengths.max())
        mask = selector.new().byte().new(selector.size()).fill_(1)

        for batch in range(batch_size):
            for beam in range(self.beam_size):
                step, real_beam = self._terminal_info[batch][beam]
                mask[batch, beam, :step + 1].fill_(0)
                self._collect_beam(batch, real_beam, step, beam_indices,
                                   selector[batch, beam])
        selector = selector.view(batch_size * self.beam_size, -1)

        ## RESORTING HERE ##
        #if self.sort_by_score:
        # TODO make this an option again
        if True:
            self._beam_scores, I = torch.sort(self._beam_scores,
                                              dim=1,
                                              descending=True)
            offset1 = (torch.arange(batch_size, device=I.device) *
                       self.beam_size).view(batch_size, 1)
            II = I + offset1
            selector = selector[II.view(-1)]
            mask = mask.view(batch_size * self.beam_size,-1)[II]\
                .view(batch_size, self.beam_size, -1)
            lengths = lengths.gather(1, I)
        ##

        # TODO reimplement staged indexing
#        for step, sel_step in enumerate(selector.split(1, dim=1)):
#            self._states[step].stage_indexing("batch", sel_step.view(-1))

        self._output = []
        for step, sel_step in enumerate(selector.split(1, dim=1)):
            self._output.append(self._states[step]["output"].index_select(
                1, sel_step.view(-1)))
        #print(self._states[0]["output"].size())
        self._output = plum.cat([o.data for o in self._output], 0).t()\
            .view(batch_size, self.beam_size, -1)

        for i in range(batch_size):
            for j in range(self.beam_size):
                self._output[i, j, lengths[i, j]:].fill_(self.vocab.pad_index)

        self._lengths = lengths

        return self

    def _collect_beam(self, batch, beam, step, beam_indices, selector_out):
        selection = [0] * beam_indices.size(0)
        selector_out[step + 1:].fill_(0)
        while step >= 0:
            selection[step] = beam
            selector_out[step].fill_(beam)
            next_beam = beam_indices[step, beam].item()
            beam = next_beam
            step -= 1

    def output(self, as_indices=False, n_best=-1):
        if n_best < 1:
            o = self._output[:, 0]
            if as_indices:
                return o
            tokens = []
            for row in o:
                tokens.append(
                    [self.vocab[t] for t in row if t != self.vocab.pad_index])
            return tokens

        elif n_best < self.beam_size:
            o = self._output[:, :n_best]
        else:
            o = self._output

        if as_indices:
            return o

        beams = []
        for beam in o:
            tokens = []
            for row in beam:
                tokens.append(
                    [self.vocab[t] for t in row if t != self.vocab.pad_index])
            beams.append(tokens)
        return beams
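
The least obvious part of BeamSearch is the index bookkeeping in _next_candidates, which flattens the beam_size ** 2 candidates per batch item, takes the top beam_size, and then recovers both the parent beam (for copying decoder state) and the position of the chosen token. A toy sketch of that arithmetic, with made-up shapes and random scores:

import torch

batch_size, beam_size = 2, 3
# beam_size ** 2 candidate log probs per batch item, flattened as in the class.
flat_lps = torch.randn(batch_size, beam_size ** 2)
scores, rel = torch.topk(flat_lps, k=beam_size, dim=1)

# offset1 shifts per-batch beam indices into the flat (batch * beam) axis;
# offset2 does the same for the flat (batch * beam ** 2) candidate axis.
offset1 = (torch.arange(batch_size) * beam_size).view(batch_size, 1)
offset2 = offset1 * beam_size

parent_beams = (rel // beam_size + offset1).view(-1)  # beams to copy state from
candidates = (rel + offset2).view(-1)                 # positions of chosen tokens

assert parent_beams.numel() == batch_size * beam_size
assert candidates.numel() == batch_size * beam_size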
Example #8
class TVMetrics(PlumModule):

    path = HP(type=props.EXISTING_PATH)
    search_fields = HP()
    references_fields = HP()

    def __pluminit__(self):
        self._cache = None
        self._queue = Queue(maxsize=0)
        self._thread = Thread(target=self._process_result, daemon=True)
        self._thread.start()
        self._hyp_fp = NamedTemporaryFile("w")
        self._ref_fp = NamedTemporaryFile("w")

    def postprocess(self, tokens, mr):
        # TODO right now this is specific to the e2e dataset. Need to 
        # generalize how to do post processing. 
        tokens = [t for t in tokens if t[0] != "<" and t[-1] != ">"]
        text = " ".join(tokens)
        return preproc.lexicalize(text, mr)



    def _process_result(self):
        while True:
            hyp, refs, mr = self._queue.get()

            print(self.postprocess(hyp, mr), file=self._hyp_fp)
            #print(" ".join(hyp), file=self._hyp_fp)
            
            if isinstance(refs, (list, tuple)):
                refs = "\n".join(refs)
            
            print(refs, file=self._ref_fp, end="\n\n")

            self._queue.task_done()

    def reset(self):
        self._cache = None
        while not self._queue.empty():
            self._queue.get()
            self._queue.task_done()
        self._hyp_fp = NamedTemporaryFile("w")
        self._ref_fp = NamedTemporaryFile("w")

    def apply_fields(self, fields, obj):
        if not isinstance(fields, (list, tuple)):
            fields = [fields]
        for field in fields:
            if hasattr(field, "__call__"):
                obj = field(obj)
            else:
                obj = obj[field]
        return obj

    def forward(self, forward_state, batch):
        search = self.apply_fields(self.search_fields, forward_state)
        hypotheses = search.output()
        reference_sets = self.apply_fields(self.references_fields, batch)

        for i, (hyp, refs) in enumerate(zip(hypotheses, reference_sets)):
            self._queue.put([hyp, refs, batch["mr"][i]])

    def run_script(self):

        self._queue.join()

        self._ref_fp.flush()
        self._hyp_fp.flush()

        script_path = Path(self.path).resolve()
        result_bytes = check_output(
            [str(script_path), self._hyp_fp.name, self._ref_fp.name])
        result = json.loads(result_bytes.decode("utf8"))
        self._cache = result

        self._ref_fp = None
        self._hyp_fp = None

    def compute(self):
        if self._cache is None:
            self.run_script()
        return self._cache

    def pretty_result(self):
        return str(self.compute())
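
TVMetrics decouples decoding from scoring: a daemon thread drains a queue into two temporary files, and compute() blocks on the queue before invoking an external scoring script. A stripped-down sketch of that pattern using only the standard library (the scorer itself is left out):

from queue import Queue
from tempfile import NamedTemporaryFile
from threading import Thread

queue = Queue()
hyp_fp = NamedTemporaryFile("w")
ref_fp = NamedTemporaryFile("w")

def worker():
    # Runs forever as a daemon; one hypothesis/reference pair per queue item.
    while True:
        hyp, ref = queue.get()
        print(hyp, file=hyp_fp)
        print(ref, file=ref_fp, end="\n\n")
        queue.task_done()

Thread(target=worker, daemon=True).start()
queue.put(("a fast laptop .", "a fast laptop for gaming ."))

queue.join()            # wait until the worker has written everything
hyp_fp.flush()
ref_fp.flush()
# The real class then calls the external scorer given by the `path`
# hyperparameter, roughly check_output([path, hyp_fp.name, ref_fp.name]),
# and parses its JSON output.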
Example #9
class GreedyNPAD(PlumObject):
    # Greedy Noisy Parallel Approximate Decoding for conditional recurrent
    # language model, Kyunghyun Cho 2016.

    max_steps = HP(default=999999, type=props.INTEGER)
    samples = HP(default=25, type=props.INTEGER)
    std = HP(default=0.1, type=props.REAL)
    mean = HP(default=0.0, type=props.REAL)
    vocab = HP()

    def __pluminit__(self):
        self.reset()

    def reset(self):
        self.is_finished = False
        self.steps = 0
        self._states = []
        self._outputs = []

    def init_state_context(self, encoder_state):
        batch_size = encoder_state["state"].size(1) * self.samples

        output = Variable(torch.LongTensor([self.vocab.start_index] *
                                           batch_size).view(1, -1),
                          lengths=torch.LongTensor([1] * batch_size),
                          length_dim=0,
                          batch_dim=1,
                          pad_value=self.vocab.pad_index)

        if str(encoder_state["state"].device) != "cpu":
            output = output.cuda(encoder_state["state"].device)

        layers = encoder_state["state"].size(0)
        decoder_state = encoder_state["state"].unsqueeze(2)\
            .repeat(1, 1, self.samples, 1).view(layers, batch_size, -1)
        search_state = {"output": output, "decoder_state": decoder_state}

        encoder_output = encoder_state["output"].repeat_batch_dim(self.samples)
        context = {"encoder_output": encoder_output}

        return search_state, context

    def next_state(self, decoder, prev_state, context, active_items, controls):

        # Draw noise sample and inject into decoder hidden state.
        eps = prev_state["decoder_state"].new(
            prev_state["decoder_state"].size()).normal_(
                self.mean, self.std / self.steps)
        prev_state["decoder_state"] = prev_state["decoder_state"] + eps

        # Get next state from the decoder.
        next_state = decoder.next_state(prev_state, context, controls=controls)
        output_log_probs, output = next_state["log_probs"].max(2)
        next_state["output_log_probs"] = output_log_probs
        next_state["output"] = output

        # Mask outputs if we have already completed that batch item.
        next_state["output"].data.view(-1).masked_fill_(
            ~active_items, self.vocab.pad_index)

        return next_state

    def check_termination(self, next_state, active_items):

        # Check for stop tokens and batch item inactive if so.
        nonstop_tokens = next_state["output"].data.view(-1).ne(
            self.vocab.stop_index)
        active_items = active_items.data.mul_(nonstop_tokens)

        return active_items

    def _collect_search_states(self, active_items):
        # TODO implement search states api.
        #search_state = self._states[0]
        #for next_state in self._states[1:]:
        #    search_state.append(next_state)
        #self._states = search_state
        self._outputs = torch.cat([o.data for o in self._outputs], dim=0)
        self._output_log_probs = torch.cat(
            [state["output_log_probs"].data for state in self._states], 0)
        self._output_log_probs = self._output_log_probs.masked_fill(
            self._mask_T, 0)
        avg_log_probs = (self._output_log_probs.sum(0) /
                         (~self._mask_T).float().sum(0)).view(
                             -1, self.samples)
        avg_log_probs, argsort = avg_log_probs.sort(1, descending=True)

        batch_size = avg_log_probs.size(0)
        offsets = (torch.arange(0, batch_size, device=argsort.device) \
            * self.samples).view(-1, 1)
        reindex = argsort + offsets

        self._outputs = self._outputs.index_select(1, reindex.view(-1))\
            .view(-1, batch_size, self.samples).permute(1, 2, 0)
        self._output_log_probs = self._output_log_probs.index_select(
            1, reindex.view(-1)).view(-1, batch_size, self.samples)\
            .permute(1, 2, 0)
        self._mask_T = None
        self._mask = None
        self._avg_log_probs = avg_log_probs

    def __call__(self, decoder, encoder_state, controls=None):

        self.reset()
        # TODO get batch size in a more reliable way. This will probably break
        # for cnn or transformer based models.
        batch_size = encoder_state["state"].size(1)
        search_state, context = self.init_state_context(encoder_state)

        active_items = search_state["decoder_state"]\
            .new(batch_size * self.samples).byte().fill_(1)

        step_masks = []
        # Perform search until we either trigger a termination condition for
        # each batch item or we reach the maximum number of search steps.
        while self.steps < self.max_steps and not self.is_finished:

            inactive_items = ~active_items

            # Mask any inputs that are finished, so that greedy would
            # be identical to forward passes.
            search_state["output"].data.view(-1).masked_fill_(
                inactive_items, self.vocab.pad_index)

            step_masks.append(inactive_items)
            self.steps += 1
            search_state = self.next_state(decoder, search_state, context,
                                           active_items, controls)

            self._states.append(search_state)
            self._outputs.append(search_state["output"].clone())

            active_items = self.check_termination(search_state, active_items)
            self.is_finished = torch.all(~active_items)

        # Finish the search by collecting final sequences, and other
        # stats.

        self._mask_T = torch.stack(step_masks)
        self._mask = self._mask_T.t().contiguous()
        self._collect_search_states(active_items)
        self._incomplete_items = active_items
        self._is_finished = True

        return self

    def __getitem__(self, key):
        if key == "output":
            return self._outputs

    def output(self, as_indices=False, n_best=-1):
        if n_best < 1:
            o = self._outputs[:, 0]
            if as_indices:
                return o
            tokens = []
            for row in o:
                tokens.append(
                    [self.vocab[t] for t in row if t != self.vocab.pad_index])
            return tokens

        elif n_best < self.samples:
            o = self._outputs[:, :n_best]
        else:
            o = self._outputs

        if as_indices:
            return o

        beams = []
        for beam in o:
            tokens = []
            for row in beam:
                tokens.append(
                    [self.vocab[t] for t in row if t != self.vocab.pad_index])
            beams.append(tokens)
        return beams
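
The defining step of NPAD (Cho, 2016) is the perturbation in next_state: every sample of every batch item gets Gaussian noise added to its decoder hidden state, with the scale annealed as std / step. A tiny illustrative sketch (shapes and values are made up; the real state comes from the decoder):

import torch

mean, std, step = 0.0, 0.1, 3
# (layers, batch * samples, hidden) replicated decoder state, as in the class.
decoder_state = torch.zeros(2, 4 * 25, 8)
eps = decoder_state.new(decoder_state.size()).normal_(mean, std / step)
decoder_state = decoder_state + eps   # noisy state fed to the decoder this step

The best of the `samples` noisy decodes is then chosen by the average per-token log probability computed in _collect_search_states.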
Example #10
class Constant(PlumObject):

    value = HP(type=props.REAL)

    def __call__(self, tensor):
        torch.nn.init.constant_(tensor, self.value)
Example #11
class ClassificationLogger(PlumObject):

    file_prefix = HP(type=props.STRING)
    input_fields = HP()
    output_fields = HP()
    target_fields = HP(required=False)
    vocab = HP(required=False)
    log_every = HP(default=1, type=props.INTEGER)

    def __pluminit__(self):
        self._epoch = 0
        self._log_dir = None
        self._file = None
        self._fp = None
        self._steps = 0

    def set_log_dir(self, log_dir):
        self._log_dir = Path(log_dir)
        self._log_dir.mkdir(exist_ok=True, parents=True)

    def __call__(self, forward_state, batch):

        self._steps += 1
        if self._steps % self.log_every != 0:
            return
        ref_inputs = resolve_getters(self.input_fields, batch)
        pred_labels = resolve_getters(self.output_fields, forward_state)\
            .tolist()

        target = resolve_getters(self.target_fields, batch)

        for i, (ref_input,
                pred_label) in enumerate(zip(ref_inputs, pred_labels)):

            if not isinstance(pred_label, str) and self.vocab is not None:
                pred_label = self.vocab[pred_label]

            print("input: {}".format(ref_input), file=self._fp)
            if target is not None:

                true_label = target[i]

                if not isinstance(true_label, str) and self.vocab is not None:
                    true_label = self.vocab[true_label]

                print("pred_label: {} target_label: {}".format(
                    pred_label, true_label),
                      file=self._fp)
            else:
                print("pred_label: {}".format(pred_label), file=self._fp)
            print(file=self._fp)

    def next_epoch(self):

        self.close()

        self._steps = 0
        self._epoch += 1
        self._file = self._log_dir / "{}.{}.log".format(
            self.file_prefix, self._epoch)
        self._fp = self._file.open("w")

    def close(self):
        if self._fp:
            self._fp.close()

    def __del__(self):
        self.close()
Example #12
class GreedySearch(PlumObject):

    max_steps = HP(default=999999, type=props.INTEGER)
    vocab = HP()

    def __pluminit__(self):
        self.reset()

    def reset(self):
        self.is_finished = False
        self.steps = 0
        self._states = []
        self._outputs = []

    def init_state(self, batch_size, encoder_state):
        output = Variable(
            torch.LongTensor([self.vocab.start_index] * batch_size)\
                .view(1, -1),
            lengths=torch.LongTensor([1] * batch_size),
            length_dim=0, batch_dim=1, pad_value=self.vocab.pad_index)

        if str(encoder_state.device) != "cpu":
            output = output.cuda(encoder_state.device)

        return {"output": output, "decoder_state": encoder_state}

    def next_state(self, decoder, prev_state, context, active_items, controls):

        # Get next state from the decoder.
        next_state = decoder.next_state(prev_state, context, controls=controls)

        # Mask outputs if we have already completed that batch item.
        next_state["output"].data.view(-1).masked_fill_(
            ~active_items, self.vocab.pad_index)

        return next_state

    def check_termination(self, next_state, active_items):

        # Check for stop tokens and batch item inactive if so.
        nonstop_tokens = next_state["output"].data.view(-1).ne(
            self.vocab.stop_index)
        active_items = active_items.data.mul_(nonstop_tokens)

        return active_items

    def _collect_search_states(self, active_items):
        # TODO implement search states api.
        #search_state = self._states[0]
        #for next_state in self._states[1:]:
        #    search_state.append(next_state)
        #self._states = search_state
        self._outputs = torch.cat([o.data for o in self._outputs], dim=0)

    def __call__(self, decoder, encoder_state, controls=None):

        self.reset()
        # TODO get batch size in a more reliable way. This will probably break
        # for cnn or transformer based models.
        batch_size = encoder_state["state"].size(1)
        search_state = self.init_state(batch_size, encoder_state["state"])
        context = {
            "encoder_output": encoder_state["output"],
        }
        active_items = search_state["decoder_state"].new(batch_size).byte() \
            .fill_(1)

        step_masks = []
        # Perform search until we either trigger a termination condition for
        # each batch item or we reach the maximum number of search steps.
        while self.steps < self.max_steps and not self.is_finished:

            inactive_items = ~active_items

            # Mask any inputs that are finished, so that greedy would
            # be identical to forward passes.
            search_state["output"].data.view(-1).masked_fill_(
                inactive_items, self.vocab.pad_index)

            step_masks.append(inactive_items)
            self.steps += 1
            search_state = self.next_state(decoder, search_state, context,
                                           active_items, controls)

            self._states.append(search_state)
            self._outputs.append(search_state["output"].clone())

            active_items = self.check_termination(search_state, active_items)
            self.is_finished = torch.all(~active_items)

        # Finish the search by collecting final sequences, and other
        # stats.
        self._collect_search_states(active_items)
        self._incomplete_items = active_items
        self._is_finished = True

        self._mask_T = torch.stack(step_masks)
        self._mask = self._mask_T.t().contiguous()
        return self

    def __getitem__(self, key):
        if key == "output":
            return self._outputs

    def output(self, as_indices=False):

        if as_indices:
            return self._outputs.t()

        tokens = []
        for output in self._outputs.t():
            tokens.append([
                self.vocab[index] for index in output
                if index != self.vocab.pad_index
            ])

        return tokens
Example #13
class LaptopSystematicSelect(PlumObject):

    checkpoint = HP(required=False)
    beam_size = HP(default=1, type=props.INTEGER)
    source_vocab = HP()
    target_vocab = HP()
    filename = HP()

    def run(self, env, verbose=False):

        output_path = env["proj_dir"] / "output" / self.filename
        output_path.parent.mkdir(exist_ok=True, parents=True)

        if self.checkpoint is None:
            ckpt = self._get_default_checkpoint(env)
        else:
            ckpt = self.checkpoint
        if ckpt is None:
            raise RuntimeError("No checkpoints found!")

        ckpt_path = env["checkpoints"][ckpt]["path"]
        if verbose:
            print("Loading model from {}".format(ckpt_path))
        model = plum.load(ckpt_path).eval()
        if env["gpu"] > -1:
            model.cuda(env["gpu"])
        self._gpu = env["gpu"]

        samples = self.make_samples()

        with open(output_path, "w") as out_fp:
            for i, mr in enumerate(samples, 1):
                print("{}/{}".format(i, len(samples)),
                      end="\r" if i < len(samples) else "\n",
                      flush=True)

                gen_input = self.make_generator_inputs(mr)

                source = preproc.mr2source_inputs(mr)
                tokens = self._get_outputs(model, gen_input)
                data = json.dumps({
                    "source": source,
                    "mr": mr,
                    "text": " ".join(tokens),
                })
                print(data, file=out_fp, flush=True)

    def make_samples(self):

        mrs = []
        for field in VALUES[:-1]:
            opts = [
                ('', 'dontcare'),
                ('', ''),
                ('dontcare', ''),
            ]
            for opt in opts:
                mr = {"da": "?select", "fields": {}}
                item1 = {}
                if opt[0] == '':
                    item1[field] = {"lex_value": "PLACEHOLDER"}
                else:
                    item1[field] = {"no_lex_value": "dontcare"}
                item2 = {}
                if opt[1] == '':
                    item2[field] = {"lex_value": "PLACEHOLDER"}
                else:
                    item2[field] = {"no_lex_value": "dontcare"}
                mr["fields"]["item1"] = item1
                mr["fields"]["item2"] = item2
                mrs.append(mr)

        opts = [("true", "true"), ('true', 'false'), ('true', 'dontcare'),
                ("false", "true"), ('false', 'false'), ('false', 'dontcare'),
                ("dontcare", "true"), ('dontcare', 'false'),
                ('dontcare', 'dontcare')]

        for opt in opts:
            mr = {
                "da": "?select",
                "fields": {
                    "item1": {
                        "isforbusinesscomputing": {
                            "no_lex_value": opt[0]
                        },
                    },
                    "item2": {
                        "isforbusinesscomputing": {
                            "no_lex_value": opt[1]
                        },
                    },
                },
            }
            mrs.append(mr)
        return mrs

    def make_generator_inputs(self, data):
        source = preproc.mr2source_inputs(data)
        tokens = [self.source_vocab.start_token] + source \
            + [self.source_vocab.stop_token]
        inputs = Variable(torch.LongTensor(
            [[self.source_vocab[t] for t in tokens]]).t(),
                          lengths=torch.LongTensor([len(tokens)]),
                          length_dim=0,
                          batch_dim=1,
                          pad_value=self.source_vocab.pad_index)
        if self._gpu > -1:
            inputs = inputs.cuda(self._gpu)
        return {"source_inputs": inputs}

    def _get_outputs(self, model, inputs):
        state = model.encode(inputs)
        if self.beam_size > 1:
            search = plum.seq2seq.search.BeamSearch(max_steps=100,
                                                    beam_size=self.beam_size,
                                                    vocab=self.target_vocab)
        else:
            search = plum.seq2seq.search.GreedySearch(max_steps=100,
                                                      vocab=self.target_vocab)
        search(model.decoder, state)
        outputs = search.output()
        raw_tokens = outputs[0][:-1]
        return raw_tokens

    def _get_default_checkpoint(self, env):
        # Fall back to the last listed checkpoint; None if there are none.
        ckpt = None
        for ckpt, md in env["checkpoints"].items():
            if md.get("default", False):
                return ckpt
        return ckpt
Example #14
class SequenceClassificationError(PlumModule):

    input_vocab = HP()
    classifier = HP()
    gpu = HP(default=-1)
    target_fields = HP()
    search_fields = HP()

    def __pluminit__(self, classifier, gpu):
        if gpu > -1:
            classifier.cuda(gpu)

    def reset(self):
        self._errors = 0
        self._total = 0

    def _make_classifier_inputs(self, batch):
        lens = []
        clf_inputs = []
        for out in batch:
            clf_inputs.append([self.input_vocab.start_index] +
                              [self.input_vocab[t] for t in out[:-1]] +
                              [self.input_vocab.stop_index])
            lens.append(len(out) + 1)
        lens = torch.LongTensor(lens)
        max_len = lens.max().item()
        clf_inputs = torch.LongTensor([
            inp + [self.input_vocab.pad_index] * (max_len - len(inp))
            for inp in clf_inputs
        ]).t()

        clf_inputs = Variable(clf_inputs,
                              lengths=lens,
                              batch_dim=1,
                              length_dim=0,
                              pad_value=self.input_vocab.pad_index)
        if self.gpu > -1:
            return clf_inputs.cuda(self.gpu)
        else:
            return clf_inputs

    def forward(self, forward_state, batch):

        self.classifier.eval()
        search_outputs = resolve_getters(self.search_fields,
                                         forward_state).output()
        clf_inputs = self._make_classifier_inputs(search_outputs)

        fs = self.classifier({"inputs": clf_inputs})
        targets = resolve_getters(self.target_fields, batch)

        if self.gpu > -1:
            targets = targets.cuda(self.gpu)

        errors = (fs["output"] != targets).long().sum().item()
        total = targets.size(0)

        self._errors += errors
        self._total += total

    def compute(self):
        return {
            "rate": self._errors / self._total if self._total > 0 else 0.,
            "count": self._errors
        }

    def pretty_result(self):
        return str(self.compute())
Example #15
class E2ESystematicGeneration(PlumObject):

    FIELDS = [
        "eat_type", "near", "area", "family_friendly", "customer_rating",
        "price_range", "food", "name"
    ]

    FIELD_DICT = {
        "food": [
            'French', 'Japanese', 'Chinese', 'English', 'Indian', 'Fast food',
            'Italian'
        ],
        "family_friendly": ['no', 'yes'],
        "area": ['city centre', 'riverside'],
        "near": [
            'Café Adriatic',
            'Café Sicilia',
            'Yippee Noodle Bar',
            'Café Brazil',
            'Raja Indian Cuisine',
            'Ranch',
            'Clare Hall',
            'The Bakers',
            'The Portland Arms',
            'The Sorrento',
            'All Bar One',
            'Avalon',
            'Crowne Plaza Hotel',
            'The Six Bells',
            'Rainbow Vegetarian Café',
            'Express by Holiday Inn',
            'The Rice Boat',
            'Burger King',
            'Café Rouge',
        ],
        "eat_type": ['coffee shop', 'pub', 'restaurant'],
        "customer_rating":
        ['3 out of 5', '5 out of 5', 'high', 'average', 'low', '1 out of 5'],
        "price_range": [
            'more than £30', 'high', '£20-25', 'cheap', 'less than £20',
            'moderate'
        ],
        "name": [
            'Cocum',
            'Midsummer House',
            'The Golden Curry',
            'The Vaults',
            'The Cricketers',
            'The Phoenix',
            'The Dumpling Tree',
            'Bibimbap House',
            'The Golden Palace',
            'Wildwood',
            'The Eagle',
            'Taste of Cambridge',
            'Clowns',
            'Strada',
            'The Mill',
            'The Waterman',
            'Green Man',
            'Browns Cambridge',
            'Cotto',
            'The Olive Grove',
            'Giraffe',
            'Zizzi',
            'Alimentum',
            'The Punter',
            'Aromi',
            'The Rice Boat',
            'Fitzbillies',
            'Loch Fyne',
            'The Cambridge Blue',
            'The Twenty Two',
            'Travellers Rest Beefeater',
            'Blue Spice',
            'The Plough',
            'The Wrestlers',
        ],
    }
    #batches = HP()
    checkpoint = HP(required=False)
    mr_size = HP(type=props.INTEGER, required=True)
    batch_size = HP(default=8, type=props.INTEGER)
    beam_size = HP(default=1, type=props.INTEGER)
    source_vocab = HP()
    target_vocab = HP()
    filename = HP()
    delex = HP(default=False, type=props.BOOLEAN)

    def _get_default_checkpoint(self, env):
        # Fall back to the last listed checkpoint; None if there are none.
        ckpt = None
        for ckpt, md in env["checkpoints"].items():
            if md.get("default", False):
                return ckpt
        return ckpt

    def _field_subsets_iter(self, size):

        for subset in combinations(self.FIELDS[:-1], size - 1):
            yield ("name", ) + subset

    def _instance_iter(self, fields):
        options = [self.FIELD_DICT[f] for f in fields]
        for values in product(*options):
            yield {field: value for field, value in zip(fields, values)}

    @property
    def total_subsets(self):
        return int(comb(7, self.mr_size - 1))

    def total_settings(self, fields):
        return np.prod([len(self.FIELD_DICT[f]) for f in fields])

    def _labels2input(self, labels):
        inputs = [self.source_vocab.start_token]
        for field in self.FIELDS:
            value = labels.get(field, "N/A").replace(" ", "_")
            inputs.append(field.replace("_", "").upper() + "_" + value)
        if self.delex:
            if inputs[2] != "NEAR_N/A":
                inputs[2] = "NEAR_PRESENT"
            inputs.pop(-1)

        inputs.append(self.source_vocab.stop_token)
        return inputs

    def _batch_labels(self, labels_batch):
        input_tokens = torch.LongTensor(
            [[self.source_vocab[tok] for tok in self._labels2input(labels)]
             for labels in labels_batch])
        length = input_tokens.size(1)
        inputs = Variable(input_tokens.t(),
                          lengths=torch.LongTensor([length] *
                                                   len(labels_batch)),
                          length_dim=0,
                          batch_dim=1,
                          pad_value=-1)
        if self._gpu > -1:
            inputs = inputs.cuda(self._gpu)
        return {"source_inputs": inputs}

    def _get_outputs(self, model, labels):
        batch = self._batch_labels(labels)
        state = model.encode(batch)
        if self.beam_size > 1:
            search = plum.seq2seq.search.BeamSearch(max_steps=100,
                                                    beam_size=self.beam_size,
                                                    vocab=self.target_vocab)
        else:
            search = plum.seq2seq.search.GreedySearch(max_steps=100,
                                                      vocab=self.target_vocab)
        search(model.decoder, state)
        outputs = search.output()

        raw_tokens = []
        postedited_outputs = []
        for i, output in enumerate(outputs):
            raw_tokens.append(output[:-1])
            output = postedit.detokenize(output)
            output = postedit.lexicalize(output, labels[i])
            postedited_outputs.append(output)

        return raw_tokens, postedited_outputs

    def run(self, env, verbose=False):

        output_path = env["proj_dir"] / "output" / self.filename
        output_path.parent.mkdir(exist_ok=True, parents=True)

        if self.checkpoint is None:
            ckpt = self._get_default_checkpoint(env)
        else:
            ckpt = self.checkpoint
        if ckpt is None:
            raise RuntimeError("No checkpoints found!")

        ckpt_path = env["checkpoints"][ckpt]["path"]
        if verbose:
            print("Loading model from {}".format(ckpt_path))
        model = plum.load(ckpt_path).eval()
        if env["gpu"] > -1:
            model.cuda(env["gpu"])
        self._gpu = env["gpu"]

        with output_path.open("w") as fp:

            field_subsets = self._field_subsets_iter(self.mr_size)
            for i, field_subset in enumerate(field_subsets, 1):
                print("slot subset {}/{}".format(i, self.total_subsets))
                if verbose:
                    print("    slots: {}".format(field_subset))

                total_mrs = self.total_settings(field_subset)

                batch = []
                inst_iter = self._instance_iter(field_subset)
                for j, labels in enumerate(inst_iter, 1):
                    print("setting {}/{}".format(j, total_mrs),
                          end="\r" if j < total_mrs else "\n",
                          flush=True)
                    batch.append(labels)
                    if len(batch) == self.batch_size:
                        output_tokens, output_strings = self._get_outputs(
                            model, batch)
                        for labels, tokens, string in zip(
                                batch, output_tokens, output_strings):
                            data = json.dumps({
                                "labels": labels,
                                "tokens": tokens,
                                "text": string
                            })
                            print(data, file=fp)
                        batch = []

                if len(batch) > 0:
                    output_tokens, output_strings = self._get_outputs(
                        model, batch)
                    for labels, tokens, string in zip(batch, output_tokens,
                                                      output_strings):
                        data = json.dumps({
                            "labels": labels,
                            "tokens": tokens,
                            "text": string
                        })
                        print(data, file=fp)
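Below is a minimal, standalone sketch of how _labels2input above linearizes a meaning representation into delexicalized FIELD_VALUE source tokens. The toy FIELDS list and the literal start/stop tokens are assumptions for this sketch only; the real ones come from the class constants and the source vocabulary.

# Standalone illustration of the FIELD_VALUE source encoding used above.
# FIELDS, START, and STOP are assumed values for this sketch only.
FIELDS = ["name", "near", "food", "area"]
START, STOP = "<sos>", "<eos>"

def labels2input(labels, delex=False):
    tokens = [START]
    for field in FIELDS:
        value = labels.get(field, "N/A").replace(" ", "_")
        tokens.append(field.replace("_", "").upper() + "_" + value)
    if delex:
        # Mirror the delexicalization branch above: mark the near slot as
        # present/absent and drop the final slot token.
        if tokens[2] != "NEAR_N/A":
            tokens[2] = "NEAR_PRESENT"
        tokens.pop(-1)
    tokens.append(STOP)
    return tokens

print(labels2input({"name": "Aromi", "near": "Crowne Plaza", "food": "Italian"}))
# ['<sos>', 'NAME_Aromi', 'NEAR_Crowne_Plaza', 'FOOD_Italian', 'AREA_N/A', '<eos>']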
Example #16
class TVPredict(PlumObject):

    checkpoint = HP(required=False)
    beam_size = HP(default=1, type=props.INTEGER)
    source_vocab = HP()
    target_vocab = HP()
    input_path = HP(type=props.EXISTING_PATH)
    filename = HP()

    def run(self, env, verbose=False):

        output_path = env["proj_dir"] / "output" / self.filename
        output_path.parent.mkdir(exist_ok=True, parents=True)

        if self.checkpoint is None:
            ckpt = self._get_default_checkpoint(env)
        else:
            ckpt = self.checkpoint
        if ckpt is None:
            raise RuntimeError("No checkpoints found!")

        ckpt_path = env["checkpoints"][ckpt]["path"]
        if verbose:
            print("Loading model from {}".format(ckpt_path))
        model = plum.load(ckpt_path).eval()
        if env["gpu"] > -1:
            model.cuda(env["gpu"])
        self._gpu = env["gpu"]

        with open(self.input_path, 'r') as fp, \
                open(output_path, "w") as out_fp:
            for line in fp:

                data = json.loads(line)
                gen_input = self.make_generator_inputs(data)
                tokens, text = self._get_outputs(model, gen_input)
                data = json.dumps({
                    "mr": data["mr"],
                    "tokens": tokens,
                    "text": text,
                })
                print(data, file=out_fp, flush=True)

    def make_generator_inputs(self, data):
        tokens = [self.source_vocab.start_token] + data["source"] \
            + [self.source_vocab.stop_token]
        inputs = Variable(torch.LongTensor(
            [[self.source_vocab[t] for t in tokens]]).t(),
                          lengths=torch.LongTensor([len(tokens)]),
                          length_dim=0,
                          batch_dim=1,
                          pad_value=self.source_vocab.pad_index)
        if self._gpu > -1:
            inputs = inputs.cuda(self._gpu)
        return {"source_inputs": inputs}

    def _get_outputs(self, model, inputs):
        state = model.encode(inputs)
        if self.beam_size > 1:
            search = plum.seq2seq.search.BeamSearch(max_steps=100,
                                                    beam_size=self.beam_size,
                                                    vocab=self.target_vocab)
        else:
            search = plum.seq2seq.search.GreedySearch(max_steps=100,
                                                      vocab=self.target_vocab)
        search(model.decoder, state)
        outputs = search.output()
        raw_tokens = outputs[0][:-1]
        text = " ".join(raw_tokens)
        #postedited_output = postedit.detokenize(raw_tokens)
        #postedited_output = postedit.lexicalize(postedited_output, labels)

        return raw_tokens, text

    # Note: _labels2input and _batch_labels below are unused by run() and
    # reference self.FIELDS / self.delex, which TVPredict does not define.
    def _labels2input(self, labels):
        inputs = [self.source_vocab.start_token]
        for field in self.FIELDS:
            value = labels.get(field, "N/A").replace(" ", "_")
            inputs.append(field.replace("_", "").upper() + "_" + value)
        if self.delex:
            if inputs[2] != "NEAR_N/A":
                inputs[2] = "NEAR_PRESENT"
            inputs.pop(-1)

        inputs.append(self.source_vocab.stop_token)
        return inputs

    def _batch_labels(self, labels_batch):
        input_tokens = torch.LongTensor(
            [[self.source_vocab[tok] for tok in self._labels2input(labels)]
             for labels in labels_batch])
        length = input_tokens.size(1)
        inputs = Variable(input_tokens.t(),
                          lengths=torch.LongTensor([length] *
                                                   len(labels_batch)),
                          length_dim=0,
                          batch_dim=1,
                          pad_value=-1)
        if self._gpu > -1:
            inputs = inputs.cuda(self._gpu)
        return {"source_inputs": inputs}

    def _get_default_checkpoint(self, env):
        ckpt = None
        for ckpt, md in env["checkpoints"].items():
            if md.get("default", False):
                return ckpt
        # Fall back on the last checkpoint seen; None if there are none.
        return ckpt
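As a rough sketch of the data flow in TVPredict.run: each input line is a JSON object with a pre-tokenized "source" sequence and a structured "mr", and each output line adds the decoded "tokens" and "text". The concrete field values below are made up for illustration.

import json

# Illustrative input record: "source" feeds make_generator_inputs, "mr" is
# copied through to the output line unchanged.
input_line = json.dumps({
    "source": ["INFORM", "FAMILY_yes", "PRICERANGE_cheap"],
    "mr": {"da": "inform", "fields": {"family": {"lex_value": "yes"}}},
})

record = json.loads(input_line)
# run() writes one such line per input, with tokens/text produced by the
# greedy or beam search decoder.
output_line = json.dumps({
    "mr": record["mr"],
    "tokens": ["the", "tv", "is", "cheap"],
    "text": "the tv is cheap",
})
print(output_line)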
Example #17
class TVSystematicInformNoMatch(PlumObject):

    checkpoint = HP(required=False)
    beam_size = HP(default=1, type=props.INTEGER)
    source_vocab = HP()
    target_vocab = HP()
    filename = HP()

    def run(self, env, verbose=False):

        output_path = env["proj_dir"] / "output" / self.filename
        output_path.parent.mkdir(exist_ok=True, parents=True)

        if self.checkpoint is None:
            ckpt = self._get_default_checkpoint(env)
        else:
            ckpt = self.checkpoint
        if ckpt is None:
            raise RuntimeError("No checkpoints found!")

        ckpt_path = env["checkpoints"][ckpt]["path"]
        if verbose:
            print("Loading model from {}".format(ckpt_path))
        model = plum.load(ckpt_path).eval()
        if env["gpu"] > -1:
            model.cuda(env["gpu"])
        self._gpu = env["gpu"]

        samples = self.make_samples()

        with open(output_path, "w") as out_fp:
            for i, mr in enumerate(samples, 1):
                print("{}/{}".format(i, len(samples)),
                      end="\r" if i < len(samples) else "\n",
                      flush=True)

                gen_input = self.make_generator_inputs(mr)

                tokens = self._get_outputs(model, gen_input)
                source = preproc.mr2source_inputs(mr)
                data = json.dumps({
                    "source": source,
                    "mr": mr,
                    "text": " ".join(tokens),
                })
                print(data, file=out_fp, flush=True)

    def make_samples(self):

        settings = []
        for size in [1, 2, 3]:
            for fields in combinations(VALUES, size):
                if 'hasusbport' in fields:
                    idx = fields.index('hasusbport')
                    for val in ['_true', '_false']:
                        out = list(fields)
                        out[idx] = out[idx] + val
                        settings.append(out)
                else:
                    settings.append(fields)

        mrs = []
        for setting in settings:
            mr = {"da": "inform_no_match", "fields": {}}
            for field in setting:
                if field.startswith("hasusbport"):
                    f, v = field.split("_")
                    mr["fields"][f] = {"no_lex_value": v}
                else:
                    mr["fields"][field] = {"lex_value": "PLACEHOLDER"}
            mrs.append(mr)
        return mrs

    def make_generator_inputs(self, data):
        source = preproc.mr2source_inputs(data)
        tokens = [self.source_vocab.start_token] + source \
            + [self.source_vocab.stop_token]
        inputs = Variable(torch.LongTensor(
            [[self.source_vocab[t] for t in tokens]]).t(),
                          lengths=torch.LongTensor([len(tokens)]),
                          length_dim=0,
                          batch_dim=1,
                          pad_value=self.source_vocab.pad_index)
        if self._gpu > -1:
            inputs = inputs.cuda(self._gpu)
        return {"source_inputs": inputs}

    def _get_outputs(self, model, inputs):
        state = model.encode(inputs)
        if self.beam_size > 1:
            search = plum.seq2seq.search.BeamSearch(max_steps=100,
                                                    beam_size=self.beam_size,
                                                    vocab=self.target_vocab)
        else:
            search = plum.seq2seq.search.GreedySearch(max_steps=100,
                                                      vocab=self.target_vocab)
        search(model.decoder, state)
        outputs = search.output()
        raw_tokens = outputs[0][:-1]
        return raw_tokens

    def _get_default_checkpoint(self, env):
        ckpt = None
        for ckpt, md in env["checkpoints"].items():
            if md.get("default", False):
                return ckpt
        # Fall back on the last checkpoint seen; None if there are none.
        return ckpt
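A self-contained sketch of the enumeration performed by make_samples above, using an assumed toy VALUES list: the boolean hasusbport attribute is expanded into explicit _true/_false variants, while every other attribute receives a PLACEHOLDER lexical value.

from itertools import combinations

# Toy attribute list; the real VALUES is defined elsewhere in the module.
VALUES = ["price", "screensize", "hasusbport"]

for fields in combinations(VALUES, 2):
    variants = [fields]
    if "hasusbport" in fields:
        idx = fields.index("hasusbport")
        variants = [tuple(f + suffix if i == idx else f
                          for i, f in enumerate(fields))
                    for suffix in ("_true", "_false")]
    for setting in variants:
        mr = {"da": "inform_no_match", "fields": {}}
        for field in setting:
            if field.startswith("hasusbport"):
                f, v = field.split("_")
                mr["fields"][f] = {"no_lex_value": v}
            else:
                mr["fields"][field] = {"lex_value": "PLACEHOLDER"}
        print(mr)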
Example #18
class FeedForwardAttention(PlumModule):

    query_net = SM(default=Identity())
    key_net = SM(default=Identity())
    value_net = SM(default=Identity())
    hidden_size = HP(type=props.INTEGER)

    weight = P("hidden_size", tags=["weight", "fully_connected"])

    def forward(self, query, key, value=None):
        if value is None:
            value = key

        if isinstance(query, Variable):
            return self._variable_forward(query, key, value)
        else:
            return self._tensor_forward(query, key, value)

    def _variable_forward(self, query, key, value):
        key = self.key_net(key)
        query = self.query_net(query)

        key = key.permute_as_sequence_batch_features()
        query = query.permute_as_sequence_batch_features()

        assert key.dim() == query.dim() == 3
        # TODO use named tensors to allow arbitrary seq dims

        with torch.no_grad():
            # Positions where either the query or the key is padded are masked
            # out; use a bool mask to avoid uint8 overflow when summing over h.
            mask = ~torch.einsum("qbh,kbh->qkb", [(~query.mask).float(),
                                                  (~key.mask).float()]).bool()

        query_uns = query.data.unsqueeze(query.length_dim + 1)
        key_uns = key.data.unsqueeze(query.length_dim)

        hidden = torch.tanh(key_uns + query_uns)
        scores = hidden.matmul(self.weight)

        scores = scores.masked_fill(mask, float("-inf"))

        attention = torch.softmax(scores.transpose(1, 2), dim=2)
        attention = attention.masked_fill(attention != attention, 0.)

        comp = torch.einsum("ijk,kjh->ijh",
                            [attention, self.value_net(value).data])
        comp = Variable(comp, lengths=query.lengths, length_dim=0, batch_dim=1)

        return {"attention": attention, "output": comp}

    def _tensor_forward(self, query, key, value):

        key = self.key_net(key)
        query = self.query_net(query)

        key = key.unsqueeze(0)
        query = query.unsqueeze(1)

        hidden = torch.tanh(key + query)
        scores = hidden.matmul(self.weight)

        attention_batch_last = torch.softmax(scores, dim=1)
        attention_batch_second = attention_batch_last.transpose(1, 2)
        attention_batch_first = attention_batch_last.permute(2, 0, 1)

        value_batch_first = value.transpose(1, 0)

        result = LazyDict(attention=attention_batch_second)
        comp = _curry_composition(attention_batch_first, value_batch_first,
                                  self.value_net)
        result.lazy_set("output", comp)

        return result
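The dense-tensor branch above implements additive (feed-forward) attention: each query/key pair is scored by projecting tanh(key + query) onto a learned vector, the scores are softmaxed over the key positions, and the output is a weighted sum of the values. A minimal sketch with plain tensors and illustrative sizes follows.

import torch

# Illustrative sizes; layout is (length, batch, hidden) as in the
# batch-second convention used by _variable_forward above.
q_len, k_len, batch, hidden = 3, 5, 2, 8
query = torch.randn(q_len, batch, hidden)
key = torch.randn(k_len, batch, hidden)
value = torch.randn(k_len, batch, hidden)
weight = torch.randn(hidden)  # stands in for the learned P("hidden_size")

# Broadcast to (q, k, b, h), score each query/key pair, then normalize
# over the key dimension.
hidden_states = torch.tanh(key.unsqueeze(0) + query.unsqueeze(1))
scores = hidden_states.matmul(weight)                      # (q, k, b)
attention = torch.softmax(scores.transpose(1, 2), dim=2)   # (q, b, k)

# Weighted sum of values: (q, b, k) x (k, b, h) -> (q, b, h).
output = torch.einsum("qbk,kbh->qbh", [attention, value])
print(output.shape)  # torch.Size([3, 2, 8])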