def compile_valid(self):
    """
    Get validation graph.
    """
    self._decoder_updates = []
    src_vars, src_mask, tgt_vars, tgt_mask = T.vars(
        'imatrix', 'matrix', 'imatrix', 'matrix')
    encoder_outputs = MapDict(self.encode(src_vars, src_mask))
    decoder_outputs = self.decode(encoder_outputs, tgt_vars, input_mask=src_mask)
    sampled_outputs = self.decode(encoder_outputs, tgt_vars, input_mask=src_mask,
                                  sampling=True)
    output_vars = self.expand(decoder_outputs)
    sampled_output_vars = self.expand(sampled_outputs)
    cost = T.costs.cross_entropy(output_vars, tgt_vars, mask=tgt_mask)
    accuracy = T.costs.accuracy(output_vars.argmax(axis=2), tgt_vars, mask=tgt_mask)
    return D.graph.compile(
        input_vars=[src_vars, src_mask, tgt_vars, tgt_mask],
        cost=cost,
        updates=self._decoder_updates,
        outputs={
            "acc": accuracy,
            "outputs": sampled_output_vars.argmax(axis=2)
        })
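# Minimal numpy sketch of the masked accuracy monitored above, assuming
# T.costs.accuracy averages argmax/target matches over positions where the
# target mask is 1. The helper name and arrays are illustrative only.
import numpy as np

def _masked_accuracy_sketch(pred_ids, tgt_ids, tgt_mask):
    # pred_ids, tgt_ids: (batch, time) int arrays; tgt_mask: (batch, time) 0/1 floats
    correct = (pred_ids == tgt_ids).astype("float32") * tgt_mask
    return correct.sum() / tgt_mask.sum()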
def __init__(self, model, source_vocab, target_vocab,
             start_token="<s>", end_token="</s>", beam_size=5,
             opts=None, unk_replace=False, alignment_path=None):
    assert isinstance(model, EncoderDecoderModel)
    self.model = model
    self.source_vocab = source_vocab
    self.target_vocab = target_vocab
    self.start_token = start_token
    self.end_token = end_token
    self.start_token_id = self.source_vocab.encode_token(start_token)
    self.end_token_id = self.target_vocab.encode_token(end_token)
    self.opts = MapDict(opts) if opts else opts
    self.beam_size = beam_size
    self.unk_replace = unk_replace
    self.align_table = None
    if alignment_path:
        self.align_table = cPickle.load(open(alignment_path))
    self.prepare()
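# Illustrative sketch of the token-id bookkeeping done in the constructor,
# assuming a vocabulary object whose encode_token() maps a token string to an
# integer id. _ToyVocab is a hypothetical stand-in, not part of the library.
class _ToyVocab(object):
    def __init__(self, tokens):
        self._ids = {tok: i for i, tok in enumerate(tokens)}

    def encode_token(self, token):
        return self._ids[token]

_src_vocab = _ToyVocab(["<s>", "</s>", "the", "cat"])
_tgt_vocab = _ToyVocab(["<s>", "</s>", "le", "chat"])
assert _src_vocab.encode_token("<s>") == 0    # start_token_id
assert _tgt_vocab.encode_token("</s>") == 1   # end_token_id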
def decode(self, encoder_outputs, target_vars, input_mask=None,
           sampling=False, extra_outputs=None):
    """
    Decoding graph.
    """
    encoder_states = encoder_outputs.encoder_states
    batch_size = encoder_states.shape[0]
    feedbacks = T.concat(
        [T.ones((batch_size, 1), dtype="int32"), target_vars[:, :-1]], axis=1)
    feedback_embeds = self.lookup_feedback(feedbacks)
    # Process initial states
    decoder_outputs = {"t": T.constant(0, dtype="int32")}
    for state_name, size in zip(self._decoder_states, self._decoder_state_sizes):
        if "init_{}".format(state_name) in encoder_outputs:
            decoder_outputs[state_name] = encoder_outputs["init_{}".format(state_name)]
        else:
            decoder_outputs[state_name] = T.zeros((batch_size, size))
    if extra_outputs:
        decoder_outputs.update(extra_outputs)
    # Process non-sequences
    non_sequences = {"input_mask": input_mask}
    for k, val in encoder_outputs.items():
        if not k.startswith("init_"):
            non_sequences[k] = val
    loop = D.graph.loop(
        sequences={
            "feedback_token": feedbacks.dimshuffle((1, 0)),
            "feedback": feedback_embeds.dimshuffle((1, 0, 2))
        },
        outputs=decoder_outputs,
        non_sequences=non_sequences)
    with loop as vars:
        if sampling:
            self.sample_step(vars)
        else:
            self.decode_step(vars)
        vars.t += 1
    output_map = MapDict()
    for state_name in decoder_outputs:
        if loop.outputs[state_name].ndim == 2:
            output_map[state_name] = loop.outputs[state_name].dimshuffle((1, 0))
        elif loop.outputs[state_name].ndim == 3:
            output_map[state_name] = loop.outputs[state_name].dimshuffle((1, 0, 2))
        else:
            output_map[state_name] = loop.outputs[state_name]
    if loop.updates:
        self._decoder_updates.extend(loop.updates)
    return output_map
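# Numpy sketch of the teacher-forcing feedback built above: each decoder step
# is fed the previous target token, with a column of ones (assumed to be the
# start-symbol id) prepended and the final target column dropped. Arrays are
# illustrative data.
import numpy as np

_targets = np.array([[4, 7, 2],
                     [5, 9, 2]], dtype="int32")                    # (batch, time)
_start_col = np.ones((_targets.shape[0], 1), dtype="int32")
_feedbacks = np.concatenate([_start_col, _targets[:, :-1]], axis=1)
# _feedbacks == [[1, 4, 7],
#                [1, 5, 9]]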
def __exit__(self, exc_type, exc_val, exc_tb):
    from neural_var import NeuralVariable
    output_tensors = []
    for k in self._ordered_out_keys:
        if k not in self._loop_vars:
            raise Exception("{} cannot be found in loop vars.".format(k))
        output_tensors.append(self._loop_vars[k].tensor)
    result_tensors, updates = finish_scan(output_tensors, self._scan_local_vars)
    if self._block and updates:
        if type(updates) == dict:
            updates = updates.items()
        self._block.register_updates(*updates)
    outputs = MapDict()
    for k, tensor in zip(self._ordered_out_keys, result_tensors):
        out_var = NeuralVariable(tensor)
        if self._outputs[k] is not None:
            out_var.output_dim = self._outputs[k].dim()
        outputs[k] = out_var
    self._scan_outputs = outputs
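# Plain-Python sketch of the update normalization above: scan may return
# updates either as a mapping {shared_var: new_value} or as a list of
# (shared_var, new_value) pairs, and register_updates is assumed to take
# individual pairs, so a dict is flattened into items first. The values here
# are ordinary stand-ins, not Theano shared variables.
def _normalize_updates_sketch(updates):
    if type(updates) == dict:
        updates = list(updates.items())
    return updates

assert _normalize_updates_sketch({"step": 5}) == [("step", 5)]
assert _normalize_updates_sketch([("step", 5)]) == [("step", 5)]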
def compile_train(self):
    """
    Get training graph.
    """
    self._decoder_updates = []
    src_vars, src_mask, tgt_vars, tgt_mask = T.vars(
        'imatrix', 'matrix', 'imatrix', 'matrix')
    encoder_outputs = MapDict(self.encode(src_vars, src_mask))
    decoder_outputs = self.decode(encoder_outputs, tgt_vars, input_mask=src_mask)
    output_vars = self.expand(decoder_outputs)
    cost = T.costs.cross_entropy(output_vars, tgt_vars, mask=tgt_mask)
    accuracy = T.costs.accuracy(output_vars.argmax(axis=2), tgt_vars, mask=tgt_mask)
    model_params = D.graph.new_block(*self._layers)
    return D.graph.compile(
        input_vars=[src_vars, src_mask, tgt_vars, tgt_mask],
        blocks=[model_params],
        cost=cost,
        monitors={"acc": accuracy},
        updates=self._decoder_updates)
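# Numpy sketch of the masked cross-entropy cost above, assuming
# T.costs.cross_entropy averages the negative log-probability of each target
# token over positions where tgt_mask == 1. output_probs stands in for the
# softmax output of expand(); all names are illustrative.
import numpy as np

def _masked_cross_entropy_sketch(output_probs, tgt_ids, tgt_mask):
    # output_probs: (batch, time, vocab); tgt_ids, tgt_mask: (batch, time)
    batch_idx, time_idx = np.indices(tgt_ids.shape)
    token_probs = output_probs[batch_idx, time_idx, tgt_ids]
    nll = -np.log(token_probs) * tgt_mask
    return nll.sum() / tgt_mask.sum()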
def export_test_components(self):
    """
    Export encoder, decoder and expander for test.
    """
    self.test_exporting = True
    # Encoder
    input_var = T.var('imatrix')
    encoder_outputs = MapDict(self.encode(input_var))
    encoder_graph = D.graph.compile(input_vars=[input_var],
                                    outputs=encoder_outputs)
    # Decoder
    t_var, feedback_var = T.vars('iscalar', 'ivector')
    state_var = T.var('matrix', test_shape=[3, self.decoder_hidden_size()])
    feedback_embeds = self.lookup_feedback(feedback_var)
    # Trick to prevent warning of unused inputs
    feedback_embeds = T.ifelse(t_var == 0, feedback_embeds, feedback_embeds)
    vars = MapDict({"feedback": feedback_embeds, "t": t_var})
    first_encoder_outputs = MapDict([
        (k, v[0]) for (k, v) in encoder_outputs.items()
    ])
    for i, state_name in enumerate(self._decoder_states):
        state_val = state_var[:, sum(self._decoder_state_sizes[:i], 0):
                                 sum(self._decoder_state_sizes[:i + 1], 0)]
        if "init_{}".format(state_name) in first_encoder_outputs:
            state_val = T.ifelse(
                t_var == 0,
                T.repeat(first_encoder_outputs["init_{}".format(state_name)][None, :],
                         state_var.shape[0], axis=0),
                state_val)
        vars[state_name] = state_val
    vars.update(first_encoder_outputs)
    self.decode_step(vars)
    state_output = T.concatenate([vars[k] for k in self._decoder_states], axis=1)
    decoder_inputs = [t_var, state_var, feedback_var] + [
        p[1] for p in sorted(first_encoder_outputs.items())
    ]
    decoder_graph = D.graph.compile(input_vars=decoder_inputs,
                                    output=state_output)
    # Expander
    decoder_state = T.var('matrix',
                          test_shape=[3, sum(self._decoder_state_sizes)])
    decoder_outputs = MapDict()
    for i, state_name in enumerate(self._decoder_states):
        decoder_outputs[state_name] = decoder_state[:, self._hidden_size * i:
                                                       self._hidden_size * (i + 1)]
    prob = T.nnet.softmax(self.expand(decoder_outputs))
    expander_graph = D.graph.compile(input_vars=[decoder_state], output=prob)
    self.test_exporting = False
    return encoder_graph, decoder_graph, expander_graph
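# Sketch of how the exported decoder and expander graphs could be chained for
# greedy decoding at test time. It assumes each compiled graph is callable like
# a Theano function on numpy arrays (an assumption about D.graph.compile), that
# state_size equals sum(self._decoder_state_sizes), and that extra_inputs holds
# the per-sentence encoder outputs in the same sorted-key order used for
# decoder_inputs above. All names and default ids below are illustrative.
import numpy as np

def _greedy_decode_sketch(decoder_fn, expander_fn, extra_inputs, state_size,
                          start_id=1, end_id=2, max_steps=50):
    state = np.zeros((1, state_size), dtype="float32")   # concatenated decoder states
    feedback = np.array([start_id], dtype="int32")
    tokens = []
    for t in range(max_steps):
        state = decoder_fn(np.int32(t), state, feedback, *extra_inputs)
        token = int(expander_fn(state).argmax(axis=1)[0])
        if token == end_id:
            break
        tokens.append(token)
        feedback = np.array([token], dtype="int32")
    return tokens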