def test_taps_error(self):
    """scan_checkpoints must raise when ``taps`` appear in ``outputs_info``.

    The checkpointing wrapper only supports the default taps, so passing an
    explicit ``taps`` entry is expected to fail with ``RuntimeError``.
    """
    step_fn = lambda: None
    tapped_info = {"initial": self.A, "taps": [-2]}
    with pytest.raises(RuntimeError):
        theano.scan_checkpoints(step_fn, [], tapped_info)
def setUp(self):
    """Build matching ``scan`` and ``scan_checkpoints`` graphs and gradients.

    Both loops compute the running elementwise product of ``A`` with itself
    for ``k`` steps; the results and their gradients w.r.t. ``A`` are stored
    so tests can compare the checkpointed loop against the reference one.
    """
    self.k = T.iscalar("k")
    self.A = T.vector("A")

    def step(prior_result, A):
        # One loop iteration: multiply the accumulator by A again.
        return prior_result * A

    reference, _ = theano.scan(
        fn=step,
        outputs_info=T.ones_like(self.A),
        non_sequences=self.A,
        n_steps=self.k)
    checkpointed, _ = theano.scan_checkpoints(
        fn=step,
        outputs_info=T.ones_like(self.A),
        non_sequences=self.A,
        n_steps=self.k,
        save_every_N=100)

    self.result = reference[-1]
    self.result_check = checkpointed[-1]
    self.grad_A = T.grad(self.result.sum(), self.A)
    self.grad_A_check = T.grad(self.result_check.sum(), self.A)
def setUp(self):
    """Construct equivalent plain and checkpointed scan graphs for tests.

    Stores the final outputs (``result`` / ``result_check``) and the
    gradients of their sums w.r.t. ``A`` for later comparison.
    """
    self.k = T.iscalar("k")
    self.A = T.vector("A")
    # Each iteration multiplies the accumulator elementwise by A.
    power_step = lambda acc, A: acc * A
    outputs, _ = theano.scan(
        fn=power_step,
        outputs_info=T.ones_like(self.A),
        non_sequences=self.A,
        n_steps=self.k)
    outputs_chk, _ = theano.scan_checkpoints(
        fn=power_step,
        outputs_info=T.ones_like(self.A),
        non_sequences=self.A,
        n_steps=self.k,
        save_every_N=100)
    self.result = outputs[-1]
    self.result_check = outputs_chk[-1]
    self.grad_A = T.grad(self.result.sum(), self.A)
    self.grad_A_check = T.grad(self.result_check.sum(), self.A)
def crf_loss0(uniaries, transition, targets, masks):
    """
    Compute the minus log-likelihood of a linear-chain CRF (the CRF loss).

    :param uniaries: Theano 3D tensor
        unary energies of each step, shape [batch_size, n_time_steps, num_labels].
    :param transition: Theano 2D tensor
        pairwise energies, shape [num_labels, num_labels], where the pad
        label index is last.
    :param targets: Theano 2D tensor
        gold labels, shape [batch_size, n_time_steps].
    :param masks: Theano 2D tensor
        validity masks, shape [batch_size, n_time_steps].
    :return: Theano 1D tensor
        an expression for the minus log-likelihood loss, one value per
        batch element.
    """
    assert transition.ndim == 2
    assert targets.ndim == 2
    assert masks.ndim == 2
    assert uniaries.ndim == 3

    def inner_function(uniaries_one_step, targets_one_step, mask_one_step,
                       prior_partition, prev_label, tg_energy, transition):
        """
        One forward-recursion step of the CRF.

        :param uniaries_one_step: [batch_size, t] unary scores at this step
        :param targets_one_step: [batch_size] gold labels at this step
        :param mask_one_step: [batch_size] step validity mask
        :param prior_partition: [batch_size, t] running log partition
        :param prev_label: [batch_size] previous gold labels
        :param tg_energy: [batch_size] running gold-path energy
        :param transition: [t, t] pairwise energies
        :return: [partition_t, targets_one_step, tg_energy_t]
        """
        partition_shuffled = prior_partition.dimshuffle(0, 1, 'x')
        uniaries_one_step_shuffled = uniaries_one_step.dimshuffle(0, 'x', 1)
        # Masked steps keep the previous partition / energy unchanged.
        partition_t = T.switch(
            mask_one_step.dimshuffle(0, 'x'),
            theano_logsumexp(uniaries_one_step_shuffled
                             + transition.dimshuffle('x', 0, 1)
                             + partition_shuffled,
                             axis=1),
            prior_partition)
        tg_energy_t = T.switch(
            mask_one_step,
            tg_energy
            + uniaries_one_step[T.arange(uniaries_one_step.shape[0]),
                                targets_one_step]
            + transition[prev_label, targets_one_step],
            tg_energy)
        return [partition_t, targets_one_step, tg_energy_t]

    # Inputs arrive as (n_batch, n_time_steps, ...) but scan iterates over
    # the leading dimension, so move time to the front.
    uniaries_shuffled = uniaries.dimshuffle(1, 0, 2)
    targets_shuffled = targets.dimshuffle(1, 0)
    masks_shuffled = masks.dimshuffle(1, 0)

    # Initial state at t = 0: transitions out of the pad label (last row of
    # `transition`) into each real label.  `init_label` marks "previous
    # label = pad" with index -1.
    init_label = T.cast(T.fill(uniaries[:, 0, 0], -1), 'int32')
    target_time0 = targets_shuffled[0]
    uniary_time0 = uniaries_shuffled[0]
    energy_time0 = transition[-1, :-1]
    initials = [
        uniary_time0[:, :] + energy_time0.dimshuffle('x', 0),
        target_time0,
        uniary_time0[T.arange(target_time0.shape[0]), target_time0]
        + transition[init_label, target_time0]
    ]
    [partitions, _, target_energies], _ = theano.scan_checkpoints(
        fn=inner_function,
        outputs_info=initials,
        sequences=[uniaries_shuffled[1:],
                   targets_shuffled[1:],
                   masks_shuffled[1:]],
        non_sequences=[transition[:-1, :-1]])
    # BUGFIX: these two assignments were commented out, leaving `partition`
    # and `target_energy` undefined (NameError) in the loss expression below.
    # The loss uses the state after the final time step.
    partition = partitions[-1]
    target_energy = target_energies[-1]
    loss = theano_logsumexp(partition, axis=1) - target_energy
    return loss