def set_ensemble_train_func(self):
    """Compile the Theano training function for the span-based ensemble model.

    Sets up ``self.optimizer`` over ``self.model.params``, optionally restoring
    saved optimizer state from ``self.argv.load_opt_param``. It then builds
    the objective (``self.model.calc_loss`` plus an L2 penalty weighted by
    ``self.argv.reg``) and its gradient updates, and stores the compiled
    function in ``self.train_func``.

    The compiled function takes ``self.model.inputs + [span_true]``, returns
    ``[objective, span_pred]``, and applies the optimizer updates on each call.
    """
    write('Building an ensemble training function...')

    self.optimizer = get_optimizer(self.argv)
    self.optimizer.set_params(self.model.params)
    if self.argv.load_opt_param:
        # Announce the restore step, consistent with set_train_func.
        write('\tLoading optimization params...')
        self.optimizer.load_params(self.argv.load_opt_param)

    # 1D: batch_size * n_spans, 2D: [batch index, label id, span index]
    span_true = T.imatrix('span_true')

    # 1D: batch_size, 2D: n_spans, 3D: 2 * hidden_dim
    h_span = self.model.feat_layer.forward(self.model.inputs, self.experts)
    # 1D: batch_size, 2D: n_labels, 3D: n_spans; score
    logits = self.model.feat_layer.calc_logit_scores(h=h_span)
    # 1D: batch_size, 2D: n_labels; span index
    span_pred = self.model.argmax_span(logits)

    nll = self.model.calc_loss(logits, span_true)
    l2_reg = L2Regularizer()
    objective = nll + l2_reg(alpha=self.argv.reg, params=self.model.params)

    grads = T.grad(cost=objective, wrt=self.model.params)
    updates = self.optimizer(grads=grads, params=self.model.params)

    self.train_func = theano.function(
        inputs=self.model.inputs + [span_true],
        outputs=[objective, span_pred],
        updates=updates,
        mode='FAST_RUN',
    )
def set_train_func(self):
    """Compile the Theano training function for the sequence-labeling model.

    Creates ``self.optimizer`` for the model parameters and restores saved
    optimizer state when ``self.argv.load_opt_param`` is set. The loss is
    the negated mean of ``label_layer.get_y_path_proba`` over the batch
    plus an L2 penalty weighted by ``self.argv.reg``. The compiled
    function is stored in ``self.train_func``.

    That function takes ``self.model.inputs + [y_true]`` and returns
    ``[cost, y_pred]``, applying one optimizer update per call. Unused
    inputs produce a warning rather than an error.
    """
    write('Building a training function...')
    model = self.model

    self.optimizer = get_optimizer(self.argv)
    self.optimizer.set_params(model.params)
    saved_opt_state = self.argv.load_opt_param
    if saved_opt_state:
        write('\tLoading optimization params...')
        self.optimizer.load_params(saved_opt_state)

    # Gold label matrix; the Theano variable keeps the name 'y'.
    gold_labels = T.imatrix('y')

    # 1D: batch_size, 2D: n_words, 3D: output_dim
    scores = model.get_emit_scores()
    # 1D: batch_size, 2D: n_words; elem=label id
    predicted_labels = model.label_layer.get_y_pred(scores)
    # 1D: batch_size; elem=log proba
    path_log_proba = model.label_layer.get_y_path_proba(scores, gold_labels)

    regularizer = L2Regularizer()
    neg_log_likelihood = -T.mean(path_log_proba)
    penalty = regularizer(alpha=self.argv.reg, params=model.params)
    loss = neg_log_likelihood + penalty

    gradients = T.grad(cost=loss, wrt=model.params)
    param_updates = self.optimizer(grads=gradients, params=model.params)

    self.train_func = theano.function(
        inputs=model.inputs + [gold_labels],
        outputs=[loss, predicted_labels],
        updates=param_updates,
        on_unused_input='warn',
        mode='FAST_RUN',
    )