class TVsSearchLogger(PlumObject):

    file_prefix = HP(type=props.STRING)
    search_fields = HP()
    input_fields = HP(required=False)
    reference_fields = HP(required=False)

    def __pluminit__(self):
        self._epoch = 0
        self._log_dir = None
        self._file = None
        self._fp = None

    def set_log_dir(self, log_dir):
        self._log_dir = Path(log_dir)
        self._log_dir.mkdir(exist_ok=True, parents=True)

    def __call__(self, forward_state, batch):
        search = resolve_getters(self.search_fields, forward_state)
        inputs = resolve_getters(self.input_fields, batch)
        references = resolve_getters(self.reference_fields, batch)

        for i, output in enumerate(search.output()):
            if inputs:
                print("inputs:", file=self._fp)
                # Log the raw meaning representation rather than the resolved
                # input fields.
                print(batch["mr"][i], file=self._fp)
            if references:
                print("references:", file=self._fp)
                if isinstance(references[i], (list, tuple)):
                    print("\n".join(references[i]), file=self._fp)
                else:
                    print(references[i], file=self._fp)
            print("hypothesis:", file=self._fp)
            print(preproc.lexicalize(" ".join(output), inputs[i]),
                  file=self._fp)
            print(file=self._fp)

    def next_epoch(self):
        self.close()
        self._epoch += 1
        self._file = self._log_dir / "{}.{}.log".format(
            self.file_prefix, self._epoch)
        self._fp = self._file.open("w")

    def close(self):
        if self._fp:
            self._fp.close()
class ThresholdFeature(PlumObject):

    thresholds = HP()

    def __len__(self):
        return len(self.thresholds) + 1

    def __call__(self, value):
        if not isinstance(value, (int, float)):
            raise TypeError("Expecting a numerical value (int or float).")
        # Linear scan over the thresholds; a binary search would be faster,
        # but the number of bins is small.
        bin = 0
        while bin != len(self.thresholds) and value > self.thresholds[bin]:
            bin += 1
        return bin
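def _threshold_feature_example():
    # Hedged usage sketch for ThresholdFeature. It assumes plum hyperparameters
    # can be passed as keyword arguments at construction; if the project builds
    # objects from config files instead, only the construction line changes.
    feat = ThresholdFeature(thresholds=[10, 20, 30])
    assert len(feat) == 4       # 3 thresholds induce 4 bins
    assert feat(5) == 0         # 5 <= 10
    assert feat(15) == 1        # 10 < 15 <= 20
    assert feat(100) == 3       # 100 > 30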
class TVPredict(PlumObject): checkpoint = HP(required=False) beam_size = HP(default=1, type=props.INTEGER) source_vocab = HP() target_vocab = HP() input_path = HP(type=props.EXISTING_PATH) filename = HP() def run(self, env, verbose=False): output_path = env["proj_dir"] / "output" / self.filename output_path.parent.mkdir(exist_ok=True, parents=True) if self.checkpoint is None: ckpt = self._get_default_checkpoint(env) else: ckpt = self.checkpoint if ckpt is None: raise RuntimeError("No checkpoints found!") ckpt_path = env["checkpoints"][ckpt]["path"] if verbose: print("Loading model from {}".format(ckpt_path)) model = plum.load(ckpt_path).eval() if env["gpu"] > -1: model.cuda(env["gpu"]) self._gpu = env["gpu"] with open(self.input_path, 'r') as fp, \ open(output_path, "w") as out_fp: for line in fp: data = json.loads(line) gen_input = self.make_generator_inputs(data) tokens, text = self._get_outputs(model, gen_input) data = json.dumps({ "mr": data["mr"], "tokens": tokens, "text": text, }) print(data, file=out_fp, flush=True) def make_generator_inputs(self, data): tokens = [self.source_vocab.start_token] + data["source"] \ + [self.source_vocab.stop_token] inputs = Variable(torch.LongTensor( [[self.source_vocab[t] for t in tokens]]).t(), lengths=torch.LongTensor([len(tokens)]), length_dim=0, batch_dim=1, pad_value=self.source_vocab.pad_index) if self._gpu > -1: inputs = inputs.cuda(self._gpu) return {"source_inputs": inputs} def _get_outputs(self, model, inputs): state = model.encode(inputs) if self.beam_size > 1: search = plum.seq2seq.search.BeamSearch(max_steps=100, beam_size=self.beam_size, vocab=self.target_vocab) else: search = plum.seq2seq.search.GreedySearch(max_steps=100, vocab=self.target_vocab) search(model.decoder, state) outputs = search.output() raw_tokens = outputs[0][:-1] text = " ".join(raw_tokens) #postedited_output = postedit.detokenize(raw_tokens) #postedited_output = postedit.lexicalize(postedited_output, labels) return raw_tokens, text def _labels2input(self, labels): inputs = [self.source_vocab.start_token] for field in self.FIELDS: value = labels.get(field, "N/A").replace(" ", "_") inputs.append(field.replace("_", "").upper() + "_" + value) if self.delex: if inputs[2] != "NEAR_N/A": inputs[2] = "NEAR_PRESENT" inputs.pop(-1) inputs.append(self.source_vocab.stop_token) return inputs def _batch_labels(self, labels_batch): input_tokens = torch.LongTensor( [[self.source_vocab[tok] for tok in self._labels2input(labels)] for labels in labels_batch]) length = input_tokens.size(1) inputs = Variable(input_tokens.t(), lengths=torch.LongTensor([length] * len(labels_batch)), length_dim=0, batch_dim=1, pad_value=-1) if self._gpu > -1: inputs = inputs.cuda(self._gpu) return {"source_inputs": inputs} def _get_default_checkpoint(self, env): for ckpt, md in env["checkpoints"].items(): if md.get("default", False): return ckpt return ckpt
class S2SEvaluator(PlumObject):

    batches = HP()
    searches = HP(default={})
    metrics = HP(default={})
    loggers = HP(default={})
    loss_function = HP(required=False)
    checkpoint = HP(required=False)

    def _get_default_checkpoint(self, env):
        for ckpt, md in env["checkpoints"].items():
            if md.get("default", False):
                return ckpt
        return ckpt

    def run(self, env, verbose=False):
        if self.checkpoint is None:
            ckpt = self._get_default_checkpoint(env)
        else:
            ckpt = self.checkpoint
        if ckpt is None:
            raise RuntimeError("No checkpoints found!")

        ckpt_path = env["checkpoints"][ckpt]["path"]
        if verbose:
            print("Reading checkpoint from {}".format(ckpt_path))
        model = plum.load(ckpt_path).eval()

        self.preflight_checks(model, env, verbose=verbose)

        if self.loss_function is not None:
            self.loss_function.reset()
        if self.metrics is not None:
            self.metrics.reset()
        self.reset_loggers(self.loggers)

        num_batches = len(self.batches)
        for step, batch in enumerate(self.batches, 1):
            forward_state = model(batch)
            if self.loss_function is not None:
                self.loss_function(forward_state, batch)
            if self.metrics is not None:
                self.metrics(forward_state, batch)
            self.apply_loggers(forward_state, batch, self.loggers)

            loss = (self.loss_function.scalar_result()
                    if self.loss_function is not None else float("nan"))
            print("eval: {}/{} loss={:7.6f}".format(step, num_batches, loss),
                  end="\r" if step < num_batches else "\n", flush=True)

        if self.metrics is not None:
            print(self.metrics.pretty_result())

        result = {
            "loss": {
                "combined": self.loss_function.scalar_result(),
                "detail": self.loss_function.compute(),
            } if self.loss_function is not None else None,
            "metrics": self.metrics.compute() if self.metrics is not None
                       else None,
        }
        self.log_results(result)
        print()

        self.close_loggers()

    def preflight_checks(self, model, env, verbose=False):
        if env["gpu"] > -1:
            if verbose:
                print("Moving model to device: {}".format(env["gpu"]))
            model.cuda(env["gpu"])
            if verbose:
                print("Moving batches to device: {}".format(env["gpu"]))
            self.batches.gpu = env["gpu"]

        model.search_algos.update(self.searches)

        log_dir = env["proj_dir"] / "logging.valid"
        if verbose:
            print("Setting log directory: {}".format(log_dir))
        for logger in self.loggers.values():
            logger.set_log_dir(log_dir)

    def apply_loggers(self, forward_state, batch, loggers):
        for logger in loggers.values():
            logger(forward_state, batch)

    def reset_loggers(self, loggers):
        for logger in loggers.values():
            logger.next_epoch()

    def close_loggers(self):
        for logger in self.loggers.values():
            logger.close()
class E2EPredict(PlumObject): checkpoint = HP(required=False) beam_size = HP(default=1, type=props.INTEGER) source_vocab = HP() target_vocab = HP() input_path = HP(type=props.EXISTING_PATH) filename = HP() delex = HP(default=False, type=props.BOOLEAN) FIELDS = [ "eat_type", "near", "area", "family_friendly", "customer_rating", "price_range", "food", "name" ] FIELD_DICT = { "food": [ 'French', 'Japanese', 'Chinese', 'English', 'Indian', 'Fast food', 'Italian' ], "family_friendly": ['no', 'yes'], "area": ['city centre', 'riverside'], "near": [ 'Café Adriatic', 'Café Sicilia', 'Yippee Noodle Bar', 'Café Brazil', 'Raja Indian Cuisine', 'Ranch', 'Clare Hall', 'The Bakers', 'The Portland Arms', 'The Sorrento', 'All Bar One', 'Avalon', 'Crowne Plaza Hotel', 'The Six Bells', 'Rainbow Vegetarian Café', 'Express by Holiday Inn', 'The Rice Boat', 'Burger King', 'Café Rouge', ], "eat_type": ['coffee shop', 'pub', 'restaurant'], "customer_rating": ['3 out of 5', '5 out of 5', 'high', 'average', 'low', '1 out of 5'], "price_range": [ 'more than £30', 'high', '£20-25', 'cheap', 'less than £20', 'moderate' ], "name": [ 'Cocum', 'Midsummer House', 'The Golden Curry', 'The Vaults', 'The Cricketers', 'The Phoenix', 'The Dumpling Tree', 'Bibimbap House', 'The Golden Palace', 'Wildwood', 'The Eagle', 'Taste of Cambridge', 'Clowns', 'Strada', 'The Mill', 'The Waterman', 'Green Man', 'Browns Cambridge', 'Cotto', 'The Olive Grove', 'Giraffe', 'Zizzi', 'Alimentum', 'The Punter', 'Aromi', 'The Rice Boat', 'Fitzbillies', 'Loch Fyne', 'The Cambridge Blue', 'The Twenty Two', 'Travellers Rest Beefeater', 'Blue Spice', 'The Plough', 'The Wrestlers', ], } def run(self, env, verbose=False): output_path = env["proj_dir"] / "output" / self.filename output_path.parent.mkdir(exist_ok=True, parents=True) if self.checkpoint is None: ckpt = self._get_default_checkpoint(env) else: ckpt = self.checkpoint if ckpt is None: raise RuntimeError("No checkpoints found!") ckpt_path = env["checkpoints"][ckpt]["path"] if verbose: print("Loading model from {}".format(ckpt_path)) model = plum.load(ckpt_path).eval() if env["gpu"] > -1: model.cuda(env["gpu"]) self._gpu = env["gpu"] with open(self.input_path, 'r') as fp, \ open(output_path, "w") as out_fp: for line in fp: labels = json.loads(line)["labels"] tokens, text = self._get_outputs(model, labels) data = json.dumps({ "labels": labels, "tokens": tokens, "text": text, }) print(data, file=out_fp) def _get_outputs(self, model, labels): batch = self._batch_labels([labels]) state = model.encode(batch) if self.beam_size > 1: search = plum.seq2seq.search.BeamSearch(max_steps=100, beam_size=self.beam_size, vocab=self.target_vocab) else: search = plum.seq2seq.search.GreedySearch(max_steps=100, vocab=self.target_vocab) search(model.decoder, state) outputs = search.output() raw_tokens = outputs[0][:-1] postedited_output = postedit.detokenize(raw_tokens) postedited_output = postedit.lexicalize(postedited_output, labels) return raw_tokens, postedited_output def _labels2input(self, labels): inputs = [self.source_vocab.start_token] for field in self.FIELDS: value = labels.get(field, "N/A").replace(" ", "_") inputs.append(field.replace("_", "").upper() + "_" + value) if self.delex: if inputs[2] != "NEAR_N/A": inputs[2] = "NEAR_PRESENT" inputs.pop(-1) inputs.append(self.source_vocab.stop_token) return inputs def _batch_labels(self, labels_batch): input_tokens = torch.LongTensor( [[self.source_vocab[tok] for tok in self._labels2input(labels)] for labels in labels_batch]) length = input_tokens.size(1) inputs 
= Variable(input_tokens.t(), lengths=torch.LongTensor([length] * len(labels_batch)), length_dim=0, batch_dim=1, pad_value=-1) if self._gpu > -1: inputs = inputs.cuda(self._gpu) return {"source_inputs": inputs} def _get_default_checkpoint(self, env): for ckpt, md in env["checkpoints"].items(): if md.get("default", False): return ckpt return ckpt
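def _e2e_linearization_example():
    # Hedged sketch of the slot linearization performed by _labels2input above,
    # written as a standalone function so it can be inspected without a vocab.
    # The start/stop token strings are assumptions; everything else mirrors the
    # method: missing slots become N/A, spaces in values become underscores.
    fields = ["eat_type", "near", "area", "family_friendly",
              "customer_rating", "price_range", "food", "name"]
    labels = {"name": "The Punter", "eat_type": "pub", "near": "Café Rouge"}
    tokens = ["<sos>"]
    for field in fields:
        value = labels.get(field, "N/A").replace(" ", "_")
        tokens.append(field.replace("_", "").upper() + "_" + value)
    tokens.append("<eos>")
    # tokens == ['<sos>', 'EATTYPE_pub', 'NEAR_Café_Rouge', 'AREA_N/A',
    #            'FAMILYFRIENDLY_N/A', 'CUSTOMERRATING_N/A', 'PRICERANGE_N/A',
    #            'FOOD_N/A', 'NAME_The_Punter', '<eos>']
    return tokens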
class SearchOutputLogger(PlumObject): file_prefix = HP(type=props.STRING) search_fields = HP() input_fields = HP(required=False) reference_fields = HP(required=False) def __pluminit__(self): self._epoch = 0 self._log_dir = None self._file = None self._fp = None def set_log_dir(self, log_dir): self._log_dir = Path(log_dir) self._log_dir.mkdir(exist_ok=True, parents=True) def _apply_fields(self, item, fields): if not isinstance(fields, (list, tuple)): fields = [fields] for field in fields: if hasattr(field, "__call__"): item = field(item) else: item = item[field] return item def __call__(self, forward_state, batch): search = self._apply_fields(forward_state, self.search_fields) if self.input_fields: inputs = self._apply_fields(batch, self.input_fields) else: inputs = None if self.reference_fields: references = self._apply_fields(batch, self.reference_fields) else: references = None for i, output in enumerate(search.output()): if inputs: print("inputs:", file=self._fp) print(inputs[i], file=self._fp) if references: print("references:", file=self._fp) print(references[i], file=self._fp) print("hypothesis:", file=self._fp) print(" ".join(output), file=self._fp) print(file=self._fp) def next_epoch(self): self.close() self._epoch += 1 self._file = self._log_dir / "{}.{}.log".format( self.file_prefix, self._epoch) self._fp = self._file.open("w") def close(self): if self._fp: self._fp.close()
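def _search_output_logger_example(model, batches):
    # Hedged sketch of the logger lifecycle as driven by S2SEvaluator above:
    # next_epoch() opens <log_dir>/<file_prefix>.<epoch>.log, each call writes
    # one inputs/references/hypothesis block per batch item, and close()
    # releases the file. The field names ("greedy_search", "references") are
    # assumptions about the forward_state and batch layout.
    logger = SearchOutputLogger(file_prefix="valid.search",
                                search_fields="greedy_search",
                                reference_fields="references")
    logger.set_log_dir("logging.valid")
    logger.next_epoch()
    for batch in batches:
        forward_state = model(batch)
        logger(forward_state, batch)
    logger.close()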
class BeamSearch(PlumObject):

    max_steps = HP(default=999999, type=props.INTEGER)
    beam_size = HP(default=4, type=props.INTEGER)
    vocab = HP()

    def __pluminit__(self):
        self.reset()

    def reset(self):
        self.is_finished = False
        self.steps = 0
        self._states = []
        self._outputs = []

    def init_state(self, batch_size, encoder_state):
        beam_state = {}
        n, bs, ds = encoder_state.size()
        assert batch_size == bs
        beam_state["decoder_state"] = encoder_state.unsqueeze(2)\
            .repeat(1, 1, self.beam_size, 1)\
            .view(n, batch_size * self.beam_size, ds)

        beam_state["output"] = Variable(
            torch.LongTensor(
                [self.vocab.start_index] * batch_size * self.beam_size)\
                .view(1, -1),
            lengths=torch.LongTensor([1] * batch_size * self.beam_size),
            length_dim=0, batch_dim=1, pad_value=self.vocab.pad_index)

        if str(encoder_state.device) != "cpu":
            beam_state["output"] = beam_state["output"].cuda(
                encoder_state.device)

        # Start the first beam of each batch with 0 log prob, and all others
        # with -inf.
        beam_state["accum_log_prob"] = (self._init_accum_log_probs(
            batch_size, encoder_state.device))

        # At the first time step no sequences have been terminated so this
        # mask is all 0s.
        beam_state["terminal_mask"] = (encoder_state.new().byte().new(
            1, batch_size * self.beam_size, 1).fill_(0))

        return beam_state

    def _init_accum_log_probs(self, batch_size, device):
        lp = torch.FloatTensor(1, batch_size, self.beam_size, 1)
        if "cuda" in str(device):
            lp = lp.cuda(device)
        lp.data.fill_(0)
        lp.data[:, :, 1:].fill_(float("-inf"))
        return lp.view(1, batch_size * self.beam_size, 1)

    def init_context(self, encoder_state):
        n, bs, ds = encoder_state["output"].size()
        beam_encoder_output = encoder_state["output"].repeat_batch_dim(
            self.beam_size)
        return {"encoder_output": beam_encoder_output}

    def __call__(self, decoder, encoder_state, controls=None):
        self.reset()

        # TODO get batch size in a more reliable way. This will probably break
        # for cnn or transformer based models.
        batch_size = encoder_state["state"].size(1)
        search_state = self.init_state(batch_size, encoder_state["state"])
        search_context = self.init_context(encoder_state)
        active_items = search_state["decoder_state"].new(batch_size).byte() \
            .fill_(1)

        if controls is not None:
            controls = controls.repeat_batch_dim(self.beam_size)

        self._beam_scores = [list() for _ in range(batch_size)]
        self._num_complete = search_state["decoder_state"].new()\
            .long().new(batch_size).fill_(0)
        self._terminal_info = [list() for _ in range(batch_size)]

        # Perform search until we either trigger a termination condition for
        # each batch item or we reach the maximum number of search steps.
        while self.steps < self.max_steps and not self.is_finished:
            search_state = self.next_state(decoder, batch_size, search_state,
                                           search_context, active_items,
                                           controls)
            active_items = self.check_termination(search_state, active_items)
            self.is_finished = torch.all(~active_items)
            self._states.append(search_state)
            self.steps += 1

        # Finish the search by collecting final sequences and other stats.
        self._collect_search_states(active_items)
        self._incomplete_items = active_items
        self._is_finished = True

        return self

    def next_state(self, decoder, batch_size, prev_state, context,
                   active_items, controls):

        # Get next state from the decoder.
        next_state = decoder.next_state(prev_state, context,
                                        controls=controls)

        # Compute the top beam_size next outputs for each beam item.
        # topk_lps (1 x batch size x beam size x beam size)
        # candidate_outputs (1 x batch size x beam size x beam size)
        topk_lps, candidate_outputs = torch.topk(
            next_state["log_probs"].data \
                .view(1, batch_size, self.beam_size, -1),
            k=self.beam_size, dim=3)

        # If any sequence was completed last step, we should mask its log
        # prob so that we don't generate from the terminal token.
        # slp (1 x batch_size x beam size x 1)
        slp = prev_state["accum_log_prob"] \
            .masked_fill(prev_state["terminal_mask"], float("-inf")) \
            .view(1, batch_size, self.beam_size, 1)

        # Combine next step log probs with the previous sequences' cumulative
        # log probs, i.e.
        #     log P(y_<=t) = log P(y_<t) + log P(y_t)
        # candidate_log_probs (1 x batch size x beam size x beam size)
        candidate_log_probs = slp + topk_lps

        # Rerank and select the beam_size best options from the available
        # beam_size ** 2 candidates.
        # b_seq_lps (1 x (batch size * beam size) x 1)
        # b_scores (1 x (batch size * beam size) x 1)
        # b_output (1 x (batch size * beam size))
        # b_indices ((batch size * beam size))
        b_seq_lps, b_scores, b_output, b_indices = self._next_candidates(
            batch_size, candidate_log_probs, candidate_outputs)

        # TODO re-implement this behavior
        # next_state.stage_indexing("batch", b_indices)
        next_state = {
            "decoder_state": next_state["decoder_state"]\
                .index_select(1, b_indices),
            "output": b_output,
            "accum_log_prob": b_seq_lps,
            "beam_score": b_scores,
            "beam_indices": b_indices,
        }

        return next_state

    def _next_candidates(self, batch_size, log_probs, candidates):
        # TODO seq_lps should really be called cumulative log probs.

        # flat_beam_lps (batch size x (beam size ** 2))
        flat_beam_lps = log_probs.view(batch_size, -1)

        flat_beam_scores = flat_beam_lps / (self.steps + 1)

        # beam_seq_scores (batch size x beam size)
        # relative_indices (batch size x beam size)
        beam_seq_scores, relative_indices = torch.topk(
            flat_beam_scores, k=self.beam_size, dim=1)

        # beam_seq_lps (batch size x beam size)
        beam_seq_lps = flat_beam_lps.gather(1, relative_indices)

        # TODO make these ahead of time.
        offset1 = (torch.arange(batch_size, device=beam_seq_lps.device)
                   * self.beam_size).view(batch_size, 1)
        offset2 = offset1 * self.beam_size

        beam_indexing = ((relative_indices // self.beam_size) + offset1) \
            .view(-1)

        # beam_seq_lps (1 x (batch_size * beam_size) x 1)
        beam_seq_lps = beam_seq_lps \
            .view(1, batch_size * self.beam_size, 1)

        # beam_seq_scores (1 x (batch_size * beam_size) x 1)
        beam_seq_scores = beam_seq_scores \
            .view(1, batch_size * self.beam_size, 1)

        # next_output (1 x (batch size * beam size))
        next_candidate_indices = (relative_indices + offset2).view(-1)
        next_output = Variable(
            candidates.view(-1)[next_candidate_indices].view(1, -1),
            lengths=candidates.new().long().new(batch_size * self.beam_size)\
                .fill_(1),
            length_dim=0, batch_dim=1, pad_value=self.vocab.pad_index)

        return beam_seq_lps, beam_seq_scores, next_output, beam_indexing

    def check_termination(self, next_state, active_items):

        # view as batch size x beam size
        next_output = next_state["output"].data \
            .view(-1, self.beam_size)
        batch_size = next_output.size(0)

        is_complete = next_output.eq(self.vocab.stop_index)
        complete_indices = np.where(is_complete.cpu().data.numpy())
        for batch, beam in zip(*complete_indices):
            if self._num_complete[batch] == self.beam_size:
                continue
            else:
                self._num_complete[batch] += 1
                # Store step and beam that finished so we can retrace it
                # later and recover arbitrary search state items.
                self._terminal_info[batch].append(
                    (self.steps, beam + batch * self.beam_size))
                IDX = batch * self.beam_size + beam
                self._beam_scores[batch].append(
                    next_state["beam_score"][0, IDX, 0].view(1))

        next_state["terminal_mask"] = (is_complete.view(
            1, batch_size * self.beam_size, 1))

        active_items = self._num_complete < self.beam_size
        return active_items

    def _collect_search_states(self, active_items):
        batch_size = active_items.size(0)
        last_state = self._states[-1]
        last_step = self.steps - 1
        for batch in range(batch_size):
            beam = 0
            while len(self._beam_scores[batch]) < self.beam_size:
                IDX = batch * self.beam_size + beam
                self._beam_scores[batch].append(
                    last_state["beam_score"][0, IDX, 0].view(1))
                self._terminal_info[batch].append(
                    (last_step, beam + batch * self.beam_size))
                beam += 1

        # TODO consider removing beam indices from state
        beam_indices = torch.stack(
            [state["beam_indices"] for state in self._states])

        self._beam_scores = torch.stack(
            [torch.cat(bs) for bs in self._beam_scores])

        lengths = self._states[0]["output"].new(
            [[step + 1 for step, beam in self._terminal_info[batch]]
             for batch in range(batch_size)])

        selector = self._states[0]["output"].new(batch_size, self.beam_size,
                                                 lengths.max())
        mask = selector.new().byte().new(selector.size()).fill_(1)

        for batch in range(batch_size):
            for beam in range(self.beam_size):
                step, real_beam = self._terminal_info[batch][beam]
                mask[batch, beam, :step + 1].fill_(0)
                self._collect_beam(batch, real_beam, step, beam_indices,
                                   selector[batch, beam])

        selector = selector.view(batch_size * self.beam_size, -1)

        # Re-sort beams by score.
        # TODO make sorting by score an option again (i.e. self.sort_by_score).
        if True:
            self._beam_scores, I = torch.sort(self._beam_scores, dim=1,
                                              descending=True)
            offset1 = (torch.arange(batch_size, device=I.device)
                       * self.beam_size).view(batch_size, 1)
            II = I + offset1
            selector = selector[II.view(-1)]
            mask = mask.view(batch_size * self.beam_size, -1)[II]\
                .view(batch_size, self.beam_size, -1)
            lengths = lengths.gather(1, I)

        # TODO reimplement staged indexing
        # for step, sel_step in enumerate(selector.split(1, dim=1)):
        #     self._states[step].stage_indexing("batch", sel_step.view(-1))

        self._output = []
        for step, sel_step in enumerate(selector.split(1, dim=1)):
            self._output.append(self._states[step]["output"].index_select(
                1, sel_step.view(-1)))

        self._output = plum.cat([o.data for o in self._output], 0).t()\
            .view(batch_size, self.beam_size, -1)
        for i in range(batch_size):
            for j in range(self.beam_size):
                self._output[i, j, lengths[i, j]:].fill_(self.vocab.pad_index)

        self._lengths = lengths

        return self

    def _collect_beam(self, batch, beam, step, beam_indices, selector_out):
        selection = [0] * beam_indices.size(0)
        selector_out[step + 1:].fill_(0)
        while step >= 0:
            selection[step] = beam
            selector_out[step].fill_(beam)
            next_beam = beam_indices[step, beam].item()
            beam = next_beam
            step -= 1

    def output(self, as_indices=False, n_best=-1):
        if n_best < 1:
            o = self._output[:, 0]
            if as_indices:
                return o
            tokens = []
            for row in o:
                tokens.append(
                    [self.vocab[t] for t in row
                     if t != self.vocab.pad_index])
            return tokens
        elif n_best < self.beam_size:
            o = self._output[:, :n_best]
        else:
            o = self._output

        if as_indices:
            return o

        beams = []
        for beam in o:
            tokens = []
            for row in beam:
                tokens.append(
                    [self.vocab[t] for t in row
                     if t != self.vocab.pad_index])
            beams.append(tokens)
        return beams
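def _beam_rerank_example():
    # Minimal sketch of the flat-index arithmetic in _next_candidates above:
    # pick the beam_size best of the beam_size**2 continuations per batch item
    # and recover (a) which parent beam each winner extends and (b) its token.
    # Shapes and names are illustrative, not part of the class API.
    import torch
    batch_size, beam_size, vocab_size = 2, 3, 11
    log_probs = torch.randn(1, batch_size, beam_size, beam_size)
    candidates = torch.randint(vocab_size, (1, batch_size, beam_size, beam_size))

    flat_lps = log_probs.view(batch_size, -1)                 # (batch, beam**2)
    scores, rel = torch.topk(flat_lps, k=beam_size, dim=1)

    offset1 = torch.arange(batch_size).view(-1, 1) * beam_size
    parent_beams = rel // beam_size + offset1                 # flat parent beam index
    token_index = rel + offset1 * beam_size                   # flat candidate index
    next_tokens = candidates.view(-1)[token_index.view(-1)]   # (batch * beam,)
    return parent_beams, next_tokens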
class GreedyNPAD(PlumObject): # Greedy Noisy Parallel Approximate Decoding for conditional recurrent # language model, Kyunghyun Cho 2016. max_steps = HP(default=999999, type=props.INTEGER) samples = HP(default=25, type=props.INTEGER) std = HP(default=0.1, type=props.REAL) mean = HP(default=0.0, type=props.REAL) vocab = HP() def __pluminit__(self): self.reset() def reset(self): self.is_finished = False self.steps = 0 self._states = [] self._outputs = [] def init_state_context(self, encoder_state): batch_size = encoder_state["state"].size(1) * self.samples output = Variable(torch.LongTensor([self.vocab.start_index] * batch_size).view(1, -1), lengths=torch.LongTensor([1] * batch_size), length_dim=0, batch_dim=1, pad_value=self.vocab.pad_index) if str(encoder_state["state"].device) != "cpu": output = output.cuda(encoder_state["state"].device) layers = encoder_state["state"].size(0) decoder_state = encoder_state["state"].unsqueeze(2)\ .repeat(1, 1, self.samples, 1).view(layers, batch_size, -1) search_state = {"output": output, "decoder_state": decoder_state} encoder_output = encoder_state["output"].repeat_batch_dim(self.samples) context = {"encoder_output": encoder_output} return search_state, context def next_state(self, decoder, prev_state, context, active_items, controls): # Draw noise sample and inject into decoder hidden state. eps = prev_state["decoder_state"].new( prev_state["decoder_state"].size()).normal_( self.mean, self.std / self.steps) prev_state["decoder_state"] = prev_state["decoder_state"] + eps # Get next state from the decoder. next_state = decoder.next_state(prev_state, context, controls=controls) output_log_probs, output = next_state["log_probs"].max(2) next_state["output_log_probs"] = output_log_probs next_state["output"] = output # Mask outputs if we have already completed that batch item. next_state["output"].data.view(-1).masked_fill_( ~active_items, self.vocab.pad_index) return next_state def check_termination(self, next_state, active_items): # Check for stop tokens and batch item inactive if so. nonstop_tokens = next_state["output"].data.view(-1).ne( self.vocab.stop_index) active_items = active_items.data.mul_(nonstop_tokens) return active_items def _collect_search_states(self, active_items): # TODO implement search states api. #search_state = self._states[0] #for next_state in self._states[1:]: # search_state.append(next_state) #self._states = search_state self._outputs = torch.cat([o.data for o in self._outputs], dim=0) self._output_log_probs = torch.cat( [state["output_log_probs"].data for state in self._states], 0) self._output_log_probs = self._output_log_probs.masked_fill( self._mask_T, 0) avg_log_probs = (self._output_log_probs.sum(0) / (~self._mask_T).float().sum(0)).view( -1, self.samples) avg_log_probs, argsort = avg_log_probs.sort(1, descending=True) batch_size = avg_log_probs.size(0) offsets = (torch.arange(0, batch_size, device=argsort.device) \ * self.samples).view(-1, 1) reindex = argsort + offsets self._outputs = self._outputs.index_select(1, reindex.view(-1))\ .view(-1, batch_size, self.samples).permute(1, 2, 0) self._output_log_probs = self._output_log_probs.index_select( 1, reindex.view(-1)).view(-1, batch_size, self.samples)\ .permute(1, 2, 0) self._mask_T = None self._mask = None self._avg_log_probs = avg_log_probs def __call__(self, decoder, encoder_state, controls=None): self.reset() # TODO get batch size in a more reliable way. This will probably break # for cnn or transformer based models. 
batch_size = encoder_state["state"].size(1) search_state, context = self.init_state_context(encoder_state) active_items = search_state["decoder_state"]\ .new(batch_size * self.samples).byte().fill_(1) step_masks = [] # Perform search until we either trigger a termination condition for # each batch item or we reach the maximum number of search steps. while self.steps < self.max_steps and not self.is_finished: inactive_items = ~active_items # Mask any inputs that are finished, so that greedy would # be identitcal to forward passes. search_state["output"].data.view(-1).masked_fill_( inactive_items, self.vocab.pad_index) step_masks.append(inactive_items) self.steps += 1 search_state = self.next_state(decoder, search_state, context, active_items, controls) self._states.append(search_state) self._outputs.append(search_state["output"].clone()) active_items = self.check_termination(search_state, active_items) self.is_finished = torch.all(~active_items) # Finish the search by collecting final sequences, and other # stats. self._mask_T = torch.stack(step_masks) self._mask = self._mask_T.t().contiguous() self._collect_search_states(active_items) self._incomplete_items = active_items self._is_finished = True return self def __getitem__(self, key): if key == "output": return self._outputs def output(self, as_indices=False, n_best=-1): if n_best < 1: o = self._outputs[:, 0] if as_indices: return o tokens = [] for row in o: tokens.append( [self.vocab[t] for t in row if t != self.vocab.pad_index]) return tokens elif n_best < self.samples: o = self._outputs[:, :n_best] else: o = self._outputs if as_indices: return o beams = [] for beam in o: tokens = [] for row in beam: tokens.append( [self.vocab[t] for t in row if t != self.vocab.pad_index]) beams.append(tokens) return beams
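def _npad_noise_example():
    # Hedged sketch of the NPAD perturbation in GreedyNPAD.next_state above:
    # at step t a Gaussian sample with std sigma/t is added to the decoder
    # state of every parallel sample, so later steps are perturbed less.
    # The shapes here are made up for illustration.
    import torch
    mean, std, step = 0.0, 0.1, 3
    decoder_state = torch.zeros(1, 4, 8)      # (layers, batch * samples, hidden)
    eps = decoder_state.new(decoder_state.size()).normal_(mean, std / step)
    return decoder_state + eps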
class Constant(PlumObject):

    value = HP(type=props.REAL)

    def __call__(self, tensor):
        torch.nn.init.constant_(tensor, self.value)
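def _constant_init_example():
    # Hedged usage sketch (keyword-argument construction is an assumption):
    # fill a parameter tensor in place with a fixed value.
    import torch
    init = Constant(value=0.0)
    bias = torch.empty(16)
    init(bias)          # same effect as torch.nn.init.constant_(bias, 0.0)
    return bias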
class ClassificationLogger(PlumObject): file_prefix = HP(type=props.STRING) input_fields = HP() output_fields = HP() target_fields = HP(required=False) vocab = HP(required=False) log_every = HP(default=1, type=props.INTEGER) def __pluminit__(self): self._epoch = 0 self._log_dir = None self._file = None self._fp = None self._steps = 0 def set_log_dir(self, log_dir): self._log_dir = Path(log_dir) self._log_dir.mkdir(exist_ok=True, parents=True) def __call__(self, forward_state, batch): self._steps += 1 if self._steps % self.log_every != 0: return ref_inputs = resolve_getters(self.input_fields, batch) pred_labels = resolve_getters(self.output_fields, forward_state)\ .tolist() target = resolve_getters(self.target_fields, batch) for i, (ref_input, pred_label) in enumerate(zip(ref_inputs, pred_labels)): if not isinstance(pred_label, str) and self.vocab is not None: pred_label = self.vocab[pred_label] print("input: {}".format(ref_input), file=self._fp) if target is not None: true_label = target[i] if not isinstance(true_label, str) and self.vocab is not None: true_label = self.vocab[true_label] print("pred_label: {} target_label: {}".format( pred_label, true_label), file=self._fp) else: print("pred_label: {}".format(pred_label), file=self._fp) print(file=self._fp) def next_epoch(self): self.close() self._steps = 0 self._epoch += 1 self._file = self._log_dir / "{}.{}.log".format( self.file_prefix, self._epoch) self._fp = self._file.open("w") def close(self): if self._fp: self._fp.close() def __del__(self): self.close()
class GreedySearch(PlumObject): max_steps = HP(default=999999, type=props.INTEGER) vocab = HP() def __pluminit__(self): self.reset() def reset(self): self.is_finished = False self.steps = 0 self._states = [] self._outputs = [] def init_state(self, batch_size, encoder_state): output = Variable( torch.LongTensor([self.vocab.start_index] * batch_size)\ .view(1, -1), lengths=torch.LongTensor([1] * batch_size), length_dim=0, batch_dim=1, pad_value=self.vocab.pad_index) if str(encoder_state.device) != "cpu": output = output.cuda(encoder_state.device) return {"output": output, "decoder_state": encoder_state} def next_state(self, decoder, prev_state, context, active_items, controls): # Get next state from the decoder. next_state = decoder.next_state(prev_state, context, controls=controls) # Mask outputs if we have already completed that batch item. next_state["output"].data.view(-1).masked_fill_( ~active_items, self.vocab.pad_index) return next_state def check_termination(self, next_state, active_items): # Check for stop tokens and batch item inactive if so. nonstop_tokens = next_state["output"].data.view(-1).ne( self.vocab.stop_index) active_items = active_items.data.mul_(nonstop_tokens) return active_items def _collect_search_states(self, active_items): # TODO implement search states api. #search_state = self._states[0] #for next_state in self._states[1:]: # search_state.append(next_state) #self._states = search_state self._outputs = torch.cat([o.data for o in self._outputs], dim=0) def __call__(self, decoder, encoder_state, controls=None): self.reset() # TODO get batch size in a more reliable way. This will probably break # for cnn or transformer based models. batch_size = encoder_state["state"].size(1) search_state = self.init_state(batch_size, encoder_state["state"]) context = { "encoder_output": encoder_state["output"], } active_items = search_state["decoder_state"].new(batch_size).byte() \ .fill_(1) step_masks = [] # Perform search until we either trigger a termination condition for # each batch item or we reach the maximum number of search steps. while self.steps < self.max_steps and not self.is_finished: inactive_items = ~active_items # Mask any inputs that are finished, so that greedy would # be identitcal to forward passes. search_state["output"].data.view(-1).masked_fill_( inactive_items, self.vocab.pad_index) step_masks.append(inactive_items) self.steps += 1 search_state = self.next_state(decoder, search_state, context, active_items, controls) self._states.append(search_state) self._outputs.append(search_state["output"].clone()) active_items = self.check_termination(search_state, active_items) self.is_finished = torch.all(~active_items) # Finish the search by collecting final sequences, and other # stats. self._collect_search_states(active_items) self._incomplete_items = active_items self._is_finished = True self._mask_T = torch.stack(step_masks) self._mask = self._mask_T.t().contiguous() return self def __getitem__(self, key): if key == "output": return self._outputs def output(self, as_indices=False): if as_indices: return self._outputs.t() tokens = [] for output in self._outputs.t(): tokens.append([ self.vocab[index] for index in output if index != self.vocab.pad_index ]) return tokens
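def _greedy_search_example(model, source_inputs, target_vocab):
    # Hedged usage sketch mirroring the predictors elsewhere in this repo: the
    # encoder is assumed to return a dict with "state" and "output" entries,
    # and callers typically strip the trailing stop token from each hypothesis.
    search = GreedySearch(max_steps=100, vocab=target_vocab)
    encoder_state = model.encode({"source_inputs": source_inputs})
    search(model.decoder, encoder_state)
    hypotheses = [tokens[:-1] for tokens in search.output()]
    return hypotheses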
class LaptopSystematicSelect(PlumObject): checkpoint = HP(required=False) beam_size = HP(default=1, type=props.INTEGER) source_vocab = HP() target_vocab = HP() filename = HP() def run(self, env, verbose=False): output_path = env["proj_dir"] / "output" / self.filename output_path.parent.mkdir(exist_ok=True, parents=True) if self.checkpoint is None: ckpt = self._get_default_checkpoint(env) else: ckpt = self.checkpoint if ckpt is None: raise RuntimeError("No checkpoints found!") ckpt_path = env["checkpoints"][ckpt]["path"] if verbose: print("Loading model from {}".format(ckpt_path)) model = plum.load(ckpt_path).eval() if env["gpu"] > -1: model.cuda(env["gpu"]) self._gpu = env["gpu"] samples = self.make_samples() with open(output_path, "w") as out_fp: for i, mr in enumerate(samples, 1): print("{}/{}".format(i, len(samples)), end="\r" if i < len(samples) else "\n", flush=True) gen_input = self.make_generator_inputs(mr) source = preproc.mr2source_inputs(mr) tokens = self._get_outputs(model, gen_input) data = json.dumps({ "source": source, "mr": mr, "text": " ".join(tokens), }) print(data, file=out_fp, flush=True) def make_samples(self): mrs = [] for field in VALUES[:-1]: opts = [ ('', 'dontcare'), ('', ''), ('dontcare', ''), ] for opt in opts: mr = {"da": "?select", "fields": {}} item1 = {} if opt[0] == '': item1[field] = {"lex_value": "PLACEHOLDER"} else: item1[field] = {"no_lex_value": "dontcare"} item2 = {} if opt[1] == '': item2[field] = {"lex_value": "PLACEHOLDER"} else: item2[field] = {"no_lex_value": "dontcare"} mr["fields"]["item1"] = item1 mr["fields"]["item2"] = item2 mrs.append(mr) opts = [("true", "true"), ('true', 'false'), ('true', 'dontcare'), ("false", "true"), ('false', 'false'), ('false', 'dontcare'), ("dontcare", "true"), ('dontcare', 'false'), ('dontcare', 'dontcare')] for opt in opts: mr = { "da": "?select", "fields": { "item1": { "isforbusinesscomputing": { "no_lex_value": opt[0] }, }, "item2": { "isforbusinesscomputing": { "no_lex_value": opt[1] }, }, }, } mrs.append(mr) return mrs def make_generator_inputs(self, data): source = preproc.mr2source_inputs(data) tokens = [self.source_vocab.start_token] + source \ + [self.source_vocab.stop_token] inputs = Variable(torch.LongTensor( [[self.source_vocab[t] for t in tokens]]).t(), lengths=torch.LongTensor([len(tokens)]), length_dim=0, batch_dim=1, pad_value=self.source_vocab.pad_index) if self._gpu > -1: inputs = inputs.cuda(self._gpu) return {"source_inputs": inputs} def _get_outputs(self, model, inputs): state = model.encode(inputs) if self.beam_size > 1: search = plum.seq2seq.search.BeamSearch(max_steps=100, beam_size=self.beam_size, vocab=self.target_vocab) else: search = plum.seq2seq.search.GreedySearch(max_steps=100, vocab=self.target_vocab) search(model.decoder, state) outputs = search.output() raw_tokens = outputs[0][:-1] return raw_tokens def _get_default_checkpoint(self, env): for ckpt, md in env["checkpoints"].items(): if md.get("default", False): return ckpt return ckpt
class SequenceClassificationError(PlumModule): input_vocab = HP() classifier = HP() gpu = HP(default=-1) target_fields = HP() search_fields = HP() def __pluminit__(self, classifier, gpu): if gpu > -1: classifier.cuda(gpu) def reset(self): self._errors = 0 self._total = 0 def _make_classifier_inputs(self, batch): lens = [] clf_inputs = [] for out in batch: clf_inputs.append([self.input_vocab.start_index] + [self.input_vocab[t] for t in out[:-1]] + [self.input_vocab.stop_index]) lens.append(len(out) + 1) lens = torch.LongTensor(lens) max_len = lens.max().item() clf_inputs = torch.LongTensor([ inp + [self.input_vocab.pad_index] * (max_len - len(inp)) for inp in clf_inputs ]).t() clf_inputs = Variable(clf_inputs, lengths=lens, batch_dim=1, length_dim=0, pad_value=self.input_vocab.pad_index) if self.gpu > -1: return clf_inputs.cuda(self.gpu) else: return clf_inputs def forward(self, forward_state, batch): self.classifier.eval() search_outputs = resolve_getters(self.search_fields, forward_state).output() clf_inputs = self._make_classifier_inputs(search_outputs) fs = self.classifier({"inputs": clf_inputs}) targets = resolve_getters(self.target_fields, batch) if self.gpu > -1: targets = targets.cuda(self.gpu) errors = (fs["output"] != targets).long().sum().item() total = targets.size(0) self._errors += errors self._total += total def compute(self): return { "rate": self._errors / self._total if self._total > 0 else 0., "count": self._errors } def pretty_result(self): return str(self.compute())
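def _classifier_input_padding_example():
    # Sketch of the padding scheme in _make_classifier_inputs above: wrap each
    # token-index sequence with start/stop markers, right-pad to the batch
    # maximum, and transpose to (length, batch). The index values are made up.
    import torch
    START, STOP, PAD = 1, 2, 0
    sequences = [[5, 6, 7], [8, 9]]
    wrapped = [[START] + seq + [STOP] for seq in sequences]
    max_len = max(len(seq) for seq in wrapped)
    padded = [seq + [PAD] * (max_len - len(seq)) for seq in wrapped]
    return torch.LongTensor(padded).t()       # (length, batch)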
class E2ESystematicGeneration(PlumObject): FIELDS = [ "eat_type", "near", "area", "family_friendly", "customer_rating", "price_range", "food", "name" ] FIELD_DICT = { "food": [ 'French', 'Japanese', 'Chinese', 'English', 'Indian', 'Fast food', 'Italian' ], "family_friendly": ['no', 'yes'], "area": ['city centre', 'riverside'], "near": [ 'Café Adriatic', 'Café Sicilia', 'Yippee Noodle Bar', 'Café Brazil', 'Raja Indian Cuisine', 'Ranch', 'Clare Hall', 'The Bakers', 'The Portland Arms', 'The Sorrento', 'All Bar One', 'Avalon', 'Crowne Plaza Hotel', 'The Six Bells', 'Rainbow Vegetarian Café', 'Express by Holiday Inn', 'The Rice Boat', 'Burger King', 'Café Rouge', ], "eat_type": ['coffee shop', 'pub', 'restaurant'], "customer_rating": ['3 out of 5', '5 out of 5', 'high', 'average', 'low', '1 out of 5'], "price_range": [ 'more than £30', 'high', '£20-25', 'cheap', 'less than £20', 'moderate' ], "name": [ 'Cocum', 'Midsummer House', 'The Golden Curry', 'The Vaults', 'The Cricketers', 'The Phoenix', 'The Dumpling Tree', 'Bibimbap House', 'The Golden Palace', 'Wildwood', 'The Eagle', 'Taste of Cambridge', 'Clowns', 'Strada', 'The Mill', 'The Waterman', 'Green Man', 'Browns Cambridge', 'Cotto', 'The Olive Grove', 'Giraffe', 'Zizzi', 'Alimentum', 'The Punter', 'Aromi', 'The Rice Boat', 'Fitzbillies', 'Loch Fyne', 'The Cambridge Blue', 'The Twenty Two', 'Travellers Rest Beefeater', 'Blue Spice', 'The Plough', 'The Wrestlers', ], } #batches = HP() checkpoint = HP(required=False) mr_size = HP(type=props.INTEGER, required=True) batch_size = HP(default=8, type=props.INTEGER) beam_size = HP(default=1, type=props.INTEGER) source_vocab = HP() target_vocab = HP() filename = HP() delex = HP(default=False, type=props.BOOLEAN) def _get_default_checkpoint(self, env): for ckpt, md in env["checkpoints"].items(): if md.get("default", False): return ckpt return ckpt def _field_subsets_iter(self, size): for subset in combinations(self.FIELDS[:-1], size - 1): yield ("name", ) + subset def _instance_iter(self, fields): options = [self.FIELD_DICT[f] for f in fields] for values in product(*options): yield {field: value for field, value in zip(fields, values)} @property def total_subsets(self): return int(comb(7, self.mr_size - 1)) def total_settings(self, fields): return np.prod([len(self.FIELD_DICT[f]) for f in fields]) def _labels2input(self, labels): inputs = [self.source_vocab.start_token] for field in self.FIELDS: value = labels.get(field, "N/A").replace(" ", "_") inputs.append(field.replace("_", "").upper() + "_" + value) if self.delex: if inputs[2] != "NEAR_N/A": inputs[2] = "NEAR_PRESENT" inputs.pop(-1) inputs.append(self.source_vocab.stop_token) return inputs def _batch_labels(self, labels_batch): input_tokens = torch.LongTensor( [[self.source_vocab[tok] for tok in self._labels2input(labels)] for labels in labels_batch]) length = input_tokens.size(1) inputs = Variable(input_tokens.t(), lengths=torch.LongTensor([length] * len(labels_batch)), length_dim=0, batch_dim=1, pad_value=-1) if self._gpu > -1: inputs = inputs.cuda(self._gpu) return {"source_inputs": inputs} def _get_outputs(self, model, labels): batch = self._batch_labels(labels) state = model.encode(batch) if self.beam_size > 1: search = plum.seq2seq.search.BeamSearch(max_steps=100, beam_size=self.beam_size, vocab=self.target_vocab) else: search = plum.seq2seq.search.GreedySearch(max_steps=100, vocab=self.target_vocab) search(model.decoder, state) outputs = search.output() raw_tokens = [] postedited_outputs = [] for i, output in enumerate(outputs): 
raw_tokens.append(output[:-1]) output = postedit.detokenize(output) output = postedit.lexicalize(output, labels[i]) # print(labels[i]) # print(output) # input() postedited_outputs.append(output) # print() #input() return raw_tokens, postedited_outputs def run(self, env, verbose=False): output_path = env["proj_dir"] / "output" / self.filename output_path.parent.mkdir(exist_ok=True, parents=True) if self.checkpoint is None: ckpt = self._get_default_checkpoint(env) else: ckpt = self.checkpoint if ckpt is None: raise RuntimeError("No checkpoints found!") ckpt_path = env["checkpoints"][ckpt]["path"] if verbose: print("Loading model from {}".format(ckpt_path)) model = plum.load(ckpt_path).eval() if env["gpu"] > -1: model.cuda(env["gpu"]) self._gpu = env["gpu"] with output_path.open("w") as fp: field_subsets = self._field_subsets_iter(self.mr_size) for i, field_subset in enumerate(field_subsets, 1): print("slot subset {}/{}".format(i, self.total_subsets)) if verbose: print(" slots: {}".format(field_subset)) total_mrs = self.total_settings(field_subset) batch = [] inst_iter = self._instance_iter(field_subset) for j, labels in enumerate(inst_iter, 1): print("setting {}/{}".format(j, total_mrs), end="\r" if j < total_mrs else "\n", flush=True) batch.append(labels) if len(batch) == self.batch_size: output_tokens, output_strings = self._get_outputs( model, batch) for labels, tokens, string in zip( batch, output_tokens, output_strings): data = json.dumps({ "labels": labels, "tokens": tokens, "text": string }) print(data, file=fp) batch = [] if len(batch) > 0: output_tokens, output_strings = self._get_outputs( model, batch) for labels, tokens, string in zip(batch, output_tokens, output_strings): data = json.dumps({ "labels": labels, "tokens": tokens, "text": string }) print(data, file=fp)
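def _e2e_subset_enumeration_example():
    # Sketch of the MR enumeration above: every slot subset of size mr_size
    # always contains "name", and the number of subsets is C(7, mr_size - 1)
    # because the remaining seven slots are chosen freely.
    from itertools import combinations
    fields = ["eat_type", "near", "area", "family_friendly",
              "customer_rating", "price_range", "food", "name"]
    mr_size = 3
    subsets = [("name",) + s for s in combinations(fields[:-1], mr_size - 1)]
    assert len(subsets) == 21                  # C(7, 2)
    return subsets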
class TVSystematicInformNoMatch(PlumObject): checkpoint = HP(required=False) beam_size = HP(default=1, type=props.INTEGER) source_vocab = HP() target_vocab = HP() filename = HP() def run(self, env, verbose=False): output_path = env["proj_dir"] / "output" / self.filename output_path.parent.mkdir(exist_ok=True, parents=True) if self.checkpoint is None: ckpt = self._get_default_checkpoint(env) else: ckpt = self.checkpoint if ckpt is None: raise RuntimeError("No checkpoints found!") ckpt_path = env["checkpoints"][ckpt]["path"] if verbose: print("Loading model from {}".format(ckpt_path)) model = plum.load(ckpt_path).eval() if env["gpu"] > -1: model.cuda(env["gpu"]) self._gpu = env["gpu"] samples = self.make_samples() with open(output_path, "w") as out_fp: for i, mr in enumerate(samples, 1): print("{}/{}".format(i, len(samples)), end="\r" if i < len(samples) else "\n", flush=True) gen_input = self.make_generator_inputs(mr) tokens = self._get_outputs(model, gen_input) source = preproc.mr2source_inputs(mr) data = json.dumps({ "source": source, "mr": mr, "text": " ".join(tokens), }) print(data, file=out_fp, flush=True) def make_samples(self): settings = [] for size in [1, 2, 3]: for fields in combinations(VALUES, size): fields = fields if 'hasusbport' in fields: idx = fields.index('hasusbport') for val in ['_true', '_false']: out = list(fields) out[idx] = out[idx] + val settings.append(out) else: settings.append(fields) mrs = [] for setting in settings: mr = {"da": "inform_no_match", "fields": {}} for field in setting: if field.startswith("hasusbport"): f, v = field.split("_") mr["fields"][f] = {"no_lex_value": v} else: mr["fields"][field] = {"lex_value": "PLACEHOLDER"} mrs.append(mr) return mrs def make_generator_inputs(self, data): source = preproc.mr2source_inputs(data) tokens = [self.source_vocab.start_token] + source \ + [self.source_vocab.stop_token] inputs = Variable(torch.LongTensor( [[self.source_vocab[t] for t in tokens]]).t(), lengths=torch.LongTensor([len(tokens)]), length_dim=0, batch_dim=1, pad_value=self.source_vocab.pad_index) if self._gpu > -1: inputs = inputs.cuda(self._gpu) return {"source_inputs": inputs} def _get_outputs(self, model, inputs): state = model.encode(inputs) if self.beam_size > 1: search = plum.seq2seq.search.BeamSearch(max_steps=100, beam_size=self.beam_size, vocab=self.target_vocab) else: search = plum.seq2seq.search.GreedySearch(max_steps=100, vocab=self.target_vocab) search(model.decoder, state) outputs = search.output() raw_tokens = outputs[0][:-1] return raw_tokens def _get_default_checkpoint(self, env): for ckpt, md in env["checkpoints"].items(): if md.get("default", False): return ckpt return ckpt
class TVMetrics(PlumModule): path = HP(type=props.EXISTING_PATH) search_fields = HP() references_fields = HP() def __pluminit__(self): self._cache = None self._queue = Queue(maxsize=0) self._thread = None self._thread = Thread(target=self._process_result) self._thread.setDaemon(True) self._thread.start() self._hyp_fp = NamedTemporaryFile("w") self._ref_fp = NamedTemporaryFile("w") def postprocess(self, tokens, mr): # TODO right now this is specific to the e2e dataset. Need to # generalize how to do post processing. tokens = [t for t in tokens if t[0] != "<" and t[-1] != ">"] text = " ".join(tokens) return preproc.lexicalize(text, mr) def _process_result(self): while True: hyp, refs, mr = self._queue.get() print(self.postprocess(hyp, mr), file=self._hyp_fp) #print(" ".join(hyp), file=self._hyp_fp) if isinstance(refs, (list, tuple)): refs = "\n".join(refs) print(refs, file=self._ref_fp, end="\n\n") self._queue.task_done() def reset(self): self._cache = None while not self._queue.empty(): self._queue.get() self._queue.task_done() self._hyp_fp = NamedTemporaryFile("w") self._ref_fp = NamedTemporaryFile("w") def apply_fields(self, fields, obj): if not isinstance(fields, (list, tuple)): fields = [fields] for field in fields: if hasattr(field, "__call__"): obj = field(obj) else: obj = obj[field] return obj def forward(self, forward_state, batch): search = self.apply_fields(self.search_fields, forward_state) hypotheses = search.output() reference_sets = self.apply_fields(self.references_fields, batch) for i, (hyp, refs) in enumerate(zip(hypotheses, reference_sets)): self._queue.put([hyp, refs, batch["mr"][i]]) def run_script(self): self._queue.join() self._ref_fp.flush() self._hyp_fp.flush() script_path = Path(self.path).resolve() result_bytes = check_output( [str(script_path), self._hyp_fp.name, self._ref_fp.name]) result = json.loads(result_bytes.decode("utf8")) self._cache = result self._ref_fp = None self._hyp_fp = None def compute(self): if self._cache is None: self.run_script() return self._cache def pretty_result(self): return str(self.compute())
class FeedForwardAttention(PlumModule): query_net = SM(default=Identity()) key_net = SM(default=Identity()) value_net = SM(default=Identity()) hidden_size = HP(type=props.INTEGER) weight = P("hidden_size", tags=["weight", "fully_connected"]) def forward(self, query, key, value=None): if value is None: value = key if isinstance(query, Variable): return self._variable_forward(query, key, value) else: return self._tensor_forward(query, key, value) def _variable_forward(self, query, key, value): key = self.key_net(key) query = self.query_net(query) key = key.permute_as_sequence_batch_features() query = query.permute_as_sequence_batch_features() assert key.dim() == query.dim() == 3 # TODO use named tensors to allow aribitrary seq dims with torch.no_grad(): mask = ~torch.einsum("qbh,kbh->qkb", [(~query.mask).float(), (~key.mask).float()]).byte() query_uns = query.data.unsqueeze(query.length_dim + 1) key_uns = key.data.unsqueeze(query.length_dim) hidden = torch.tanh(key_uns + query_uns) scores = hidden.matmul(self.weight) scores = scores.masked_fill(mask, float("-inf")) attention = torch.softmax(scores.transpose(1, 2), dim=2) attention = attention.masked_fill(attention != attention, 0.) comp = torch.einsum("ijk,kjh->ijh", [attention, self.value_net(value).data]) comp = Variable(comp, lengths=query.lengths, length_dim=0, batch_dim=1) return {"attention": attention, "output": comp} def _tensor_forward(self, query, key, value): key = self.key_net(key) query = self.query_net(query) key = key.unsqueeze(0) query = query.unsqueeze(1) hidden = torch.tanh(key + query) scores = hidden.matmul(self.weight) attention_batch_last = torch.softmax(scores, dim=1) attention_batch_second = attention_batch_last.transpose(1, 2) attention_batch_first = attention_batch_last.permute(2, 0, 1) value_batch_first = value.transpose(1, 0) result = LazyDict(attention=attention_batch_second) comp = _curry_composition(attention_batch_first, value_batch_first, self.value_net) result.lazy_set("output", comp) return result
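def _additive_attention_example():
    # Minimal sketch of the unmasked scoring path in _tensor_forward above:
    # broadcast-add keys and queries, squash with tanh, project onto a learned
    # vector, and normalize over the key dimension. Shapes are illustrative.
    import torch
    q_len, k_len, batch, hidden = 2, 5, 3, 8
    query = torch.randn(q_len, batch, hidden)
    key = torch.randn(k_len, batch, hidden)
    weight = torch.randn(hidden)

    scores = torch.tanh(key.unsqueeze(0) + query.unsqueeze(1)).matmul(weight)
    attention = torch.softmax(scores, dim=1)   # (q_len, k_len, batch)
    return attention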