def compile_valid(self):
    """Build and compile the validation graph.

    Compiles a graph over (source tokens, source mask, target tokens,
    target mask) that reports the cross-entropy cost, the token-level
    accuracy under "acc", and greedy argmax outputs of a sampling decode
    pass under "outputs".
    """
    self._decoder_updates = []
    # Placeholders: integer token matrices plus float masks.
    src_tokens, src_weights, tgt_tokens, tgt_weights = T.vars(
        'imatrix', 'matrix', 'imatrix', 'matrix')
    enc_out = MapDict(self.encode(src_tokens, src_weights))
    # One teacher-forced pass for cost/accuracy ...
    forced_out = self.decode(enc_out, tgt_tokens, input_mask=src_weights)
    # ... and one sampling pass so validation can inspect generations.
    sampled_out = self.decode(enc_out, tgt_tokens, input_mask=src_weights,
                              sampling=True)
    forced_logits = self.expand(forced_out)
    sampled_logits = self.expand(sampled_out)
    valid_cost = T.costs.cross_entropy(forced_logits, tgt_tokens,
                                       mask=tgt_weights)
    valid_acc = T.costs.accuracy(forced_logits.argmax(axis=2), tgt_tokens,
                                 mask=tgt_weights)
    return D.graph.compile(
        input_vars=[src_tokens, src_weights, tgt_tokens, tgt_weights],
        cost=valid_cost,
        updates=self._decoder_updates,
        outputs={"acc": valid_acc,
                 "outputs": sampled_logits.argmax(axis=2)})
def export_test_components(self):
    """Export encoder, decoder and expander graphs for test-time decoding.

    Returns:
        A ``(encoder_graph, decoder_graph, expander_graph)`` tuple of
        compiled graphs: the encoder maps a token matrix to encoder
        outputs; the decoder performs one decoding step on a packed
        state matrix; the expander maps a packed state matrix to
        softmax output probabilities.
    """
    self.test_exporting = True
    # --- Encoder ---
    input_var = T.var('imatrix')
    encoder_outputs = MapDict(self.encode(input_var))
    encoder_graph = D.graph.compile(input_vars=[input_var],
                                    outputs=encoder_outputs)
    # --- Decoder (one step) ---
    t_var, feedback_var = T.vars('iscalar', 'ivector')
    state_var = T.var('matrix', test_shape=[3, self.decoder_hidden_size()])
    feedback_embeds = self.lookup_feedback(feedback_var)
    # Trick to prevent warning of unused inputs
    feedback_embeds = T.ifelse(t_var == 0, feedback_embeds, feedback_embeds)
    vars = MapDict({"feedback": feedback_embeds, "t": t_var})
    # Keep only the first timestep of every encoder output (e.g. the
    # "init_*" initial-state vectors) for the per-step decoder.
    first_encoder_outputs = MapDict([
        (k, v[0]) for (k, v) in encoder_outputs.items()
    ])
    for i, state_name in enumerate(self._decoder_states):
        # Slice this state's span out of the packed state matrix using
        # cumulative per-state sizes.
        state_val = state_var[:, sum(self._decoder_state_sizes[:i], 0):
                              sum(self._decoder_state_sizes[:i + 1], 0)]
        if "init_{}".format(state_name) in first_encoder_outputs:
            # At t == 0, substitute the encoder-provided initial state,
            # broadcast over the batch dimension.
            state_val = T.ifelse(
                t_var == 0,
                T.repeat(
                    first_encoder_outputs[
                        "init_{}".format(state_name)][None, :],
                    state_var.shape[0], axis=0),
                state_val)
        vars[state_name] = state_val
    vars.update(first_encoder_outputs)
    self.decode_step(vars)
    # Re-pack the updated states in the same order/sizes they were
    # unpacked with above.
    state_output = T.concatenate(
        [vars[k] for k in self._decoder_states], axis=1)
    decoder_inputs = [t_var, state_var, feedback_var] + [
        p[1] for p in sorted(first_encoder_outputs.items())
    ]
    decoder_graph = D.graph.compile(input_vars=decoder_inputs,
                                    output=state_output)
    # --- Expander ---
    decoder_state = T.var('matrix',
                          test_shape=[3, sum(self._decoder_state_sizes)])
    decoder_outputs = MapDict()
    for i, state_name in enumerate(self._decoder_states):
        # FIX: slice by cumulative _decoder_state_sizes (matching the
        # packing used by the decoder above) instead of assuming every
        # state has width self._hidden_size — the two only coincide when
        # all state sizes equal the hidden size.
        start = sum(self._decoder_state_sizes[:i], 0)
        stop = sum(self._decoder_state_sizes[:i + 1], 0)
        decoder_outputs[state_name] = decoder_state[:, start:stop]
    prob = T.nnet.softmax(self.expand(decoder_outputs))
    expander_graph = D.graph.compile(input_vars=[decoder_state],
                                     output=prob)
    self.test_exporting = False
    return encoder_graph, decoder_graph, expander_graph
def compile_train(self):
    """Build and compile the training graph.

    Compiles a graph over (source tokens, source mask, target tokens,
    target mask) whose cost is masked cross-entropy, monitoring token
    accuracy under "acc", with all layers registered as one trainable
    parameter block.
    """
    self._decoder_updates = []
    # Placeholders: integer token matrices plus float masks.
    src_tokens, src_weights, tgt_tokens, tgt_weights = T.vars(
        'imatrix', 'matrix', 'imatrix', 'matrix')
    enc_out = MapDict(self.encode(src_tokens, src_weights))
    dec_out = self.decode(enc_out, tgt_tokens, input_mask=src_weights)
    logits = self.expand(dec_out)
    train_cost = T.costs.cross_entropy(logits, tgt_tokens,
                                       mask=tgt_weights)
    train_acc = T.costs.accuracy(logits.argmax(axis=2), tgt_tokens,
                                 mask=tgt_weights)
    # Register every layer in a single parameter block for the trainer.
    param_block = D.graph.new_block(*self._layers)
    return D.graph.compile(
        input_vars=[src_tokens, src_weights, tgt_tokens, tgt_weights],
        blocks=[param_block],
        cost=train_cost,
        monitors={"acc": train_acc},
        updates=self._decoder_updates)