예제 #1
0
 def check_grads_disabled(self, activations, labels, input_length):
     """
     Verify that gradient computation is switched off when only the cost
     is requested.

     Compiles a function that outputs just the CTC cost and asserts that
     every ``ConnectionistTemporalClassification`` node left in the
     compiled graph has ``compute_grad`` set to ``False``. If the graph
     contains no such node, the check passes without asserting anything.
     """
     cost_only = ctc(activations, labels, input_length)
     compiled = aesara.function([], [cost_only])
     for apply_node in compiled.maker.fgraph.apply_nodes:
         # Only CTC apply nodes are relevant to this check.
         if not isinstance(apply_node.op, ConnectionistTemporalClassification):
             continue
         assert apply_node.op.compute_grad is False
예제 #2
0
 def setup_cpu_op(
     self,
     activations,
     labels,
     input_length,
     compute_grad=True,
     mode=mode_without_gpu,
 ):
     """
     Compile a no-input function that evaluates the CTC cost.

     When ``compute_grad`` is true, the compiled function also returns the
     symbolic gradient of the mean cost with respect to ``activations`` as
     a second output. The function is compiled with the given ``mode``.
     """
     cost = ctc(activations, labels, input_length)
     if not compute_grad:
         return aesara.function([], [cost], mode=mode)

     # Gradient of the mean CTC cost with respect to the activations
     cost_gradient = grad(mean(cost), activations)
     return aesara.function([], [cost, cost_gradient], mode=mode)
예제 #3
0
    def run_ctc(self, activations, labels, input_length, expected_costs,
                expected_grads):
        """
        Evaluate the CTC cost and its gradient and compare with expectations.

        The inputs are wrapped in shared variables, and a function returning
        the cost and the gradient of the mean cost with respect to the
        activations is compiled and run. The results are compared with
        ``expected_grads`` and ``expected_costs``. Finally the
        gradient-disabling check is run on the same shared variables.
        """
        # Wrap the raw inputs as shared variables
        shared_acts = aesara.shared(activations, name="activations")
        shared_lengths = aesara.shared(input_length, name="activation_times")
        shared_labels = aesara.shared(labels, name="labels")

        # Build the cost and the gradient of its mean w.r.t. the activations
        cost_expr = ctc(shared_acts, shared_labels, shared_lengths)
        mean_cost = tt.mean(cost_expr)
        grad_expr = tt.grad(mean_cost, shared_acts)

        # Compile and evaluate both outputs in one call
        evaluate = aesara.function([], [cost_expr, grad_expr])
        cost_val, grad_val = evaluate()

        # The gradient is of the mean cost, so the expected gradients are
        # scaled by the number of cost entries before comparing.
        n_costs = cost_val.shape[0]
        utt.assert_allclose(expected_grads / n_costs, grad_val)
        utt.assert_allclose(expected_costs, cost_val)

        self.check_grads_disabled(shared_acts, shared_labels, shared_lengths)
예제 #4
0
 def wrapper(acts):
     """
     Return the CTC cost for ``acts``.

     ``in_lengths`` and ``labels`` are taken from the enclosing scope and
     wrapped in shared variables, so ``acts`` is the only argument.
     """
     # Auxiliary symbolic inputs built from the enclosing-scope values
     shared_lengths = aesara.shared(in_lengths, name="activation_times")
     shared_labels = aesara.shared(labels, name="labels")
     return ctc(acts, shared_labels, shared_lengths)