Example #1
    @staticmethod
    def forward(ctx, input, input_lengths, graphs, leaky_coefficient=1e-5):
        # forward pass of a torch.autograd.Function: runs PyChain's
        # forward-backward kernel and saves the input gradient for backward()
        input = input.contiguous().clamp(
            -30, 30)  # clamp for both the denominator and the numerator
        B = input.size(0)
        if B != graphs.batch_size:
            raise ValueError(
                "input batch size ({}) does not equal to graph batch size ({})"
                .format(B, graphs.batch_size))
        input_lengths = input_lengths.cpu()  # lengths must be on CPU for packing
        packed_data = torch.nn.utils.rnn.pack_padded_sequence(
            input,
            input_lengths,
            batch_first=True,
        )
        batch_sizes = packed_data.batch_sizes
        if not graphs.log_domain:  # usually for the denominator
            exp_input = input.exp()
            objf, input_grad, ok = pychain_C.forward_backward(
                graphs.forward_transitions,
                graphs.forward_transition_indices,
                graphs.forward_transition_probs,
                graphs.backward_transitions,
                graphs.backward_transition_indices,
                graphs.backward_transition_probs,
                graphs.leaky_probs,
                graphs.initial_probs,
                graphs.final_probs,
                graphs.start_state,
                exp_input,
                batch_sizes,
                input_lengths,
                graphs.num_states,
                leaky_coefficient,
            )
        else:  # usually for the numerator
            objf, log_probs_grad, ok = pychain_C.forward_backward_log_domain(
                graphs.forward_transitions,
                graphs.forward_transition_indices,
                graphs.forward_transition_probs,
                graphs.backward_transitions,
                graphs.backward_transition_indices,
                graphs.backward_transition_probs,
                graphs.initial_probs,
                graphs.final_probs,
                graphs.start_state,
                input,
                batch_sizes,
                input_lengths,
                graphs.num_states,
            )
            input_grad = log_probs_grad.exp()

        ctx.save_for_backward(input_grad)
        return objf
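
Because `forward` stashes `input_grad` with `ctx.save_for_backward`, the matching `backward` only needs to scale that tensor by the incoming objective gradient. Below is a minimal, self-contained sketch of the same save_for_backward/backward pattern, with the `pychain_C` kernel replaced by a dummy objective so it runs anywhere; the class name and stub values are illustrative, not taken from the example above.

import torch


class DummyChainFunction(torch.autograd.Function):
    """Same save_for_backward/backward pairing as Example #1,
    with the forward-backward kernel stubbed out."""

    @staticmethod
    def forward(ctx, input):
        input = input.contiguous().clamp(-30, 30)
        objf = input.sum()                   # stand-in for the chain objective
        input_grad = torch.ones_like(input)  # stand-in for the kernel's gradient
        ctx.save_for_backward(input_grad)
        return objf

    @staticmethod
    def backward(ctx, objf_grad):
        # chain rule: scale the stashed gradient by the upstream gradient
        input_grad, = ctx.saved_tensors
        return input_grad * objf_grad


x = torch.randn(2, 5, requires_grad=True)
DummyChainFunction.apply(x).backward()
print(x.grad)  # all ones: the stashed gradient scaled by 1.0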
Example #2
    @staticmethod
    def forward(ctx,
                input,
                input_lengths,
                num_graphs,
                den_graphs,
                leaky_coefficient=1e-5):
        # forward pass of a torch.autograd.Function: combines the numerator and
        # denominator forward-backward passes into the chain (LF-MMI) loss
        try:
            import pychain_C
        except ImportError:
            raise ImportError(
                "Please install OpenFST and PyChain by `make openfst pychain` "
                "after entering espresso/tools")

        input = input.contiguous().clamp(
            -30, 30)  # clamp for both the denominator and the numerator
        B = input.size(0)
        if B != num_graphs.batch_size or B != den_graphs.batch_size:
            raise ValueError(
                "input batch size ({}) does not equal to num graph batch size ({}) "
                "or den graph batch size ({})".format(B, num_graphs.batch_size,
                                                      den_graphs.batch_size))
        input_lengths = input_lengths.cpu()  # lengths must be on CPU for packing
        packed_data = torch.nn.utils.rnn.pack_padded_sequence(
            input,
            input_lengths,
            batch_first=True,
        )
        batch_sizes = packed_data.batch_sizes

        exp_input = input.exp()
        den_objf, input_grad, denominator_ok = pychain_C.forward_backward(
            den_graphs.forward_transitions,
            den_graphs.forward_transition_indices,
            den_graphs.forward_transition_probs,
            den_graphs.backward_transitions,
            den_graphs.backward_transition_indices,
            den_graphs.backward_transition_probs,
            den_graphs.leaky_probs,
            den_graphs.initial_probs,
            den_graphs.final_probs,
            den_graphs.start_state,
            exp_input,
            batch_sizes,
            input_lengths,
            den_graphs.num_states,
            leaky_coefficient,
        )
        denominator_ok = denominator_ok.item()

        assert num_graphs.log_domain
        num_objf, log_probs_grad, numerator_ok = pychain_C.forward_backward_log_domain(
            num_graphs.forward_transitions,
            num_graphs.forward_transition_indices,
            num_graphs.forward_transition_probs,
            num_graphs.backward_transitions,
            num_graphs.backward_transition_indices,
            num_graphs.backward_transition_probs,
            num_graphs.initial_probs,
            num_graphs.final_probs,
            num_graphs.start_state,
            input,
            batch_sizes,
            input_lengths,
            num_graphs.num_states,
        )
        numerator_ok = numerator_ok.item()

        loss = -num_objf + den_objf

        if (loss - loss) != 0.0 or not denominator_ok or not numerator_ok:
            # (loss - loss) != 0.0 is true iff the loss is NaN or infinite
            default_loss = 10
            input_grad = torch.zeros_like(input)
            logger.warning(
                f"Loss is {loss}; denominator computation "
                f"(if done) returned {denominator_ok} "
                f"and numerator computation returned {numerator_ok}, "
                f"setting loss to {default_loss} per frame")
            loss = torch.full_like(num_objf,
                                   default_loss * input_lengths.sum())
        else:
            num_grad = log_probs_grad.exp()
            input_grad -= num_grad

        ctx.save_for_backward(input_grad)
        return loss
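
The `(loss - loss) != 0.0` guard above is a compact finiteness check: for any finite loss the difference is exactly 0.0, while inf - inf and NaN - NaN are both NaN, which compares unequal to everything. A short demonstration, equivalent to `not torch.isfinite(loss)`:

import torch

for v in [1.5, float('inf'), float('nan')]:
    t = torch.tensor(v)
    # the guard fires exactly when the value is non-finite
    print(v, bool((t - t) != 0.0), not torch.isfinite(t).item())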
Example #3
import time

import torch
import pychain_C  # PyChain's compiled C++/CUDA extension


def main(N, S, B):
    # benchmark the log-domain and leaky forward-backward kernels on dense
    # inputs; make_hmm is assumed to be defined elsewhere in the same script
    lang = 'python'
    precision = 'single'

    graphs = make_hmm(S, B)
    S = graphs.num_states - 1
    data = torch.zeros(B, N, S, dtype=torch.float32).contiguous()
    print(data.shape)
    #lengths = list(reversed(range(N-B+1, N+1)))
    lengths = [N for i in range(B)]
    print(lengths)
    data_lengths = torch.tensor(lengths, dtype=torch.int32)

    # run the GPU pass only when CUDA is available
    devices = ['cpu'] + (['cuda:0'] if torch.cuda.is_available() else [])
    for device in devices:

        packed_data = torch.nn.utils.rnn.pack_padded_sequence(
            data,
            data_lengths.cpu(),  # lengths must be on CPU for packing
            batch_first=True,
        )
        batch_sizes = packed_data.batch_sizes

        graphs = make_hmm(S, B)
        forward_transitions = graphs.forward_transitions.to(device)
        forward_transition_indices = graphs.forward_transition_indices.to(
            device)
        forward_transition_probs = graphs.forward_transition_probs.to(device)
        backward_transitions = graphs.backward_transitions.to(device)
        backward_transition_indices = graphs.backward_transition_indices.to(
            device)
        backward_transition_probs = graphs.backward_transition_probs.to(device)
        initial_probs = graphs.initial_probs.to(device)
        final_probs = graphs.final_probs.to(device)
        start_state = graphs.start_state.to(device)
        data = data.to(device)
        data_lengths = data_lengths.to(device)

        t1 = time.time()
        objf, log_probs_grad, ok = pychain_C.forward_backward_log_domain(
            forward_transitions,
            forward_transition_indices,
            forward_transition_probs,
            backward_transitions,
            backward_transition_indices,
            backward_transition_probs,
            initial_probs,
            final_probs,
            start_state,
            data,
            batch_sizes,
            data_lengths,
            graphs.num_states,
        )
        t2 = time.time()
        dev = 'gpu' if device == 'cuda:0' else 'cpu'
        print(
            f'{lang}\t{precision}\t{B}\t{S}\t{N}\tpychain_log\tdense\t{dev}\t{t2 - t1}'
        )

        data.exp_()  # the non-log-domain kernel expects exponentiated scores
        graphs = make_hmm(S, B, log_domain=False)
        forward_transitions = graphs.forward_transitions.to(device)
        forward_transition_indices = graphs.forward_transition_indices.to(
            device)
        forward_transition_probs = graphs.forward_transition_probs.to(device)
        backward_transitions = graphs.backward_transitions.to(device)
        backward_transition_indices = graphs.backward_transition_indices.to(
            device)
        backward_transition_probs = graphs.backward_transition_probs.to(device)
        leaky_probs = graphs.leaky_probs.to(device)
        initial_probs = graphs.initial_probs.to(device)
        final_probs = graphs.final_probs.to(device)
        start_state = graphs.start_state.to(device)
        data = data.to(device)
        data_lengths = data_lengths.to(device)

        t1 = time.time()
        objf, input_grad, ok = pychain_C.forward_backward(
            forward_transitions, forward_transition_indices,
            forward_transition_probs, backward_transitions,
            backward_transition_indices, backward_transition_probs,
            leaky_probs, initial_probs, final_probs, start_state, data,
            batch_sizes, data_lengths, graphs.num_states, 1e-3)
        t2 = time.time()
        dev = 'gpu' if device == 'cuda:0' else 'cpu'
        print(
            f'{lang}\t{precision}\t{B}\t{S}\t{N}\tpychain_leaky\tdense\t{dev}\t{t2 - t1}'
        )
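
A plausible driver for this benchmark, printing a header that matches the tab-separated rows emitted inside `main`; the argument values are illustrative, not from the original script:

if __name__ == '__main__':
    print('lang\tprecision\tB\tS\tN\tkernel\tlayout\tdev\tseconds')
    for B in [1, 16, 64]:       # batch sizes to sweep (illustrative)
        main(N=500, S=20, B=B)  # N frames per sequence, S states per graph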