def forward(ctx, input, input_lengths, graphs, leaky_coefficient):
    # Clamp to keep exp() within float32 range before moving to the
    # probability domain.
    exp_input = input.clamp(-30, 30).exp()
    B = input.size(0)
    if B != graphs.batch_size:
        raise ValueError(
            "input batch size {0} does not equal graph batch size {1}".format(
                B, graphs.batch_size))
    forward_transitions = graphs.forward_transitions
    forward_transition_indices = graphs.forward_transition_indices
    forward_transition_probs = graphs.forward_transition_probs
    backward_transitions = graphs.backward_transitions
    backward_transition_indices = graphs.backward_transition_indices
    backward_transition_probs = graphs.backward_transition_probs
    leaky_probs = graphs.leaky_probs
    num_states = graphs.num_states
    final_probs = graphs.final_probs
    start_state = graphs.start_state
    # Packing is only used to obtain per-timestep batch sizes; both packing
    # and the C extension expect lengths on CPU.
    input_lengths = input_lengths.cpu()
    packed_data = torch.nn.utils.rnn.pack_padded_sequence(
        input, input_lengths, batch_first=True)
    batch_sizes = packed_data.batch_sizes
    objf, input_grad, _ = pychain_C.forward_backward(
        forward_transitions, forward_transition_indices,
        forward_transition_probs, backward_transitions,
        backward_transition_indices, backward_transition_probs, leaky_probs,
        final_probs, start_state, exp_input, batch_sizes, input_lengths,
        num_states, leaky_coefficient)
    ctx.save_for_backward(input_grad)
    return objf
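# Side note on the clamp range used above: float32 exp() overflows to inf for
# large inputs and underflows to zero for very negative ones, so clamping
# logits to [-30, 30] keeps both exp(30) ~ 1.07e13 and exp(-30) ~ 9.36e-14
# comfortably finite before the probability-domain recursion. A quick
# self-contained check (illustrative only, not from the source):
import torch

x = torch.tensor([-200.0, 200.0])
print(x.exp())                 # tensor([0., inf]) -- underflow and overflow
print(x.clamp(-30, 30).exp())  # tensor([9.3576e-14, 1.0686e+13]) -- finite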
def forward(ctx, input, input_lengths, graphs, leaky_coefficient=1e-5):
    input = input.contiguous().clamp(
        -30, 30)  # clamp for both the denominator and the numerator
    B = input.size(0)
    if B != graphs.batch_size:
        raise ValueError(
            "input batch size ({}) does not equal graph batch size ({})"
            .format(B, graphs.batch_size))
    # Packing and the C extension expect lengths on CPU.
    input_lengths = input_lengths.cpu()
    packed_data = torch.nn.utils.rnn.pack_padded_sequence(
        input, input_lengths, batch_first=True,
    )
    batch_sizes = packed_data.batch_sizes
    if not graphs.log_domain:  # usually for the denominator
        exp_input = input.exp()
        objf, input_grad, ok = pychain_C.forward_backward(
            graphs.forward_transitions,
            graphs.forward_transition_indices,
            graphs.forward_transition_probs,
            graphs.backward_transitions,
            graphs.backward_transition_indices,
            graphs.backward_transition_probs,
            graphs.leaky_probs,
            graphs.initial_probs,
            graphs.final_probs,
            graphs.start_state,
            exp_input,
            batch_sizes,
            input_lengths,
            graphs.num_states,
            leaky_coefficient,
        )
    else:  # usually for the numerator
        objf, log_probs_grad, ok = pychain_C.forward_backward_log_domain(
            graphs.forward_transitions,
            graphs.forward_transition_indices,
            graphs.forward_transition_probs,
            graphs.backward_transitions,
            graphs.backward_transition_indices,
            graphs.backward_transition_probs,
            graphs.initial_probs,
            graphs.final_probs,
            graphs.start_state,
            input,
            batch_sizes,
            input_lengths,
            graphs.num_states,
        )
        input_grad = log_probs_grad.exp()
    ctx.save_for_backward(input_grad)
    return objf
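# A minimal sketch of the matching backward pass (assumed, not from the source,
# but it follows from `ctx.save_for_backward(input_grad)` above: the C
# extension already returns d(objf)/d(input), so backward only rescales the
# saved gradient by the incoming objective gradient). One None is returned per
# non-tensor forward argument (input_lengths, graphs, leaky_coefficient).
@staticmethod
def backward(ctx, objf_grad):
    input_grad, = ctx.saved_tensors
    input_grad = torch.mul(input_grad, objf_grad)
    return input_grad, None, None, None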
def forward(ctx, input, graphs):
    # This variant takes no sequence lengths: all sequences in the batch are
    # assumed to have equal length, and the full dense input is used.
    exp_input = input.clamp(-30, 30).exp()
    B = input.size(0)
    if B != graphs.batch_size:
        raise ValueError(
            "input batch size {0} does not equal graph batch size {1}".format(
                B, graphs.batch_size))
    forward_transitions = graphs.forward_transitions
    forward_transition_indices = graphs.forward_transition_indices
    forward_transition_probs = graphs.forward_transition_probs
    backward_transitions = graphs.backward_transitions
    backward_transition_indices = graphs.backward_transition_indices
    backward_transition_probs = graphs.backward_transition_probs
    initial_probs = graphs.initial_probs
    num_states = graphs.num_states
    objf, input_grad, _ = pychain_C.forward_backward(
        forward_transitions, forward_transition_indices,
        forward_transition_probs, backward_transitions,
        backward_transition_indices, backward_transition_probs, initial_probs,
        exp_input, num_states)
    ctx.save_for_backward(input_grad)
    return objf
def forward(ctx, input, input_lengths, num_graphs, den_graphs, leaky_coefficient=1e-5):
    try:
        import pychain_C
    except ImportError:
        raise ImportError(
            "Please install OpenFST and PyChain by `make openfst pychain` "
            "after entering espresso/tools")
    input = input.contiguous().clamp(
        -30, 30)  # clamp for both the denominator and the numerator
    B = input.size(0)
    if B != num_graphs.batch_size or B != den_graphs.batch_size:
        raise ValueError(
            "input batch size ({}) does not equal num graph batch size ({}) "
            "or den graph batch size ({})".format(
                B, num_graphs.batch_size, den_graphs.batch_size))
    # Packing and the C extension expect lengths on CPU.
    input_lengths = input_lengths.cpu()
    packed_data = torch.nn.utils.rnn.pack_padded_sequence(
        input, input_lengths, batch_first=True,
    )
    batch_sizes = packed_data.batch_sizes
    # Denominator: computed in the probability domain on exponentiated inputs.
    exp_input = input.exp()
    den_objf, input_grad, denominator_ok = pychain_C.forward_backward(
        den_graphs.forward_transitions,
        den_graphs.forward_transition_indices,
        den_graphs.forward_transition_probs,
        den_graphs.backward_transitions,
        den_graphs.backward_transition_indices,
        den_graphs.backward_transition_probs,
        den_graphs.leaky_probs,
        den_graphs.initial_probs,
        den_graphs.final_probs,
        den_graphs.start_state,
        exp_input,
        batch_sizes,
        input_lengths,
        den_graphs.num_states,
        leaky_coefficient,
    )
    denominator_ok = denominator_ok.item()
    # Numerator: computed in the log domain for numerical stability.
    assert num_graphs.log_domain
    num_objf, log_probs_grad, numerator_ok = pychain_C.forward_backward_log_domain(
        num_graphs.forward_transitions,
        num_graphs.forward_transition_indices,
        num_graphs.forward_transition_probs,
        num_graphs.backward_transitions,
        num_graphs.backward_transition_indices,
        num_graphs.backward_transition_probs,
        num_graphs.initial_probs,
        num_graphs.final_probs,
        num_graphs.start_state,
        input,
        batch_sizes,
        input_lengths,
        num_graphs.num_states,
    )
    numerator_ok = numerator_ok.item()
    loss = -num_objf + den_objf
    # (loss - loss) != 0.0 catches NaN/inf: any finite loss satisfies
    # loss - loss == 0.
    if (loss - loss) != 0.0 or not denominator_ok or not numerator_ok:
        default_loss = 10
        input_grad = torch.zeros_like(input)
        logger.warning(
            f"Loss is {loss} and denominator computation (if done) "
            f"returned {denominator_ok} and numerator computation "
            f"returned {numerator_ok}, setting loss to {default_loss} "
            f"per frame")
        loss = torch.full_like(num_objf, default_loss * input_lengths.sum())
    else:
        num_grad = log_probs_grad.exp()
        input_grad -= num_grad
    ctx.save_for_backward(input_grad)
    return loss
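# A hypothetical nn.Module wrapper showing how the autograd function above
# might be called during training; `ChainLossFunction` is an assumed name for
# the torch.autograd.Function hosting that forward(), and the per-frame
# normalization is a common convention rather than something taken from the
# source.
class ChainLoss(torch.nn.Module):

    def __init__(self, leaky_coefficient=1e-5):
        super().__init__()
        self.leaky_coefficient = leaky_coefficient

    def forward(self, net_output, input_lengths, num_graphs, den_graphs):
        # net_output: (B, T, num_pdfs) network log-probabilities.
        loss = ChainLossFunction.apply(
            net_output, input_lengths, num_graphs, den_graphs,
            self.leaky_coefficient)
        return loss / input_lengths.sum()  # normalize to a per-frame loss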
def main(N, S, B):
    lang = 'python'
    precision = 'single'
    graphs = make_hmm(S, B)
    S = graphs.num_states - 1  # actual number of emitting states
    data = torch.zeros(B, N, S, dtype=torch.float32).contiguous()
    print(data.shape)
    # lengths = list(reversed(range(N - B + 1, N + 1)))
    lengths = [N for i in range(B)]
    print(lengths)
    data_lengths = torch.tensor(lengths, dtype=torch.int32)
    for device in ['cpu', 'cuda:0']:
        packed_data = torch.nn.utils.rnn.pack_padded_sequence(
            data, data_lengths, batch_first=True,
        )
        batch_sizes = packed_data.batch_sizes

        # Benchmark the log-domain forward-backward.
        graphs = make_hmm(S, B)
        forward_transitions = graphs.forward_transitions.to(device)
        forward_transition_indices = graphs.forward_transition_indices.to(device)
        forward_transition_probs = graphs.forward_transition_probs.to(device)
        backward_transitions = graphs.backward_transitions.to(device)
        backward_transition_indices = graphs.backward_transition_indices.to(device)
        backward_transition_probs = graphs.backward_transition_probs.to(device)
        initial_probs = graphs.initial_probs.to(device)
        final_probs = graphs.final_probs.to(device)
        start_state = graphs.start_state.to(device)
        data = data.to(device)
        data_lengths = data_lengths.to(device)
        t1 = time.time()
        objf, log_probs_grad, ok = pychain_C.forward_backward_log_domain(
            forward_transitions, forward_transition_indices,
            forward_transition_probs, backward_transitions,
            backward_transition_indices, backward_transition_probs,
            initial_probs, final_probs, start_state, data, batch_sizes,
            data_lengths, graphs.num_states,
        )
        t2 = time.time()
        dev = 'gpu' if device == 'cuda:0' else 'cpu'
        print(
            f'{lang}\t{precision}\t{B}\t{S}\t{N}\tpychain_log\tdense\t{dev}\t{t2 - t1}'
        )

        # Benchmark the probability-domain (leaky) forward-backward; only the
        # timing matters here, so exponentiate the same buffer in place.
        data.exp_()
        graphs = make_hmm(S, B, log_domain=False)
        forward_transitions = graphs.forward_transitions.to(device)
        forward_transition_indices = graphs.forward_transition_indices.to(device)
        forward_transition_probs = graphs.forward_transition_probs.to(device)
        backward_transitions = graphs.backward_transitions.to(device)
        backward_transition_indices = graphs.backward_transition_indices.to(device)
        backward_transition_probs = graphs.backward_transition_probs.to(device)
        leaky_probs = graphs.leaky_probs.to(device)
        initial_probs = graphs.initial_probs.to(device)
        final_probs = graphs.final_probs.to(device)
        start_state = graphs.start_state.to(device)
        data = data.to(device)
        data_lengths = data_lengths.to(device)
        t1 = time.time()
        objf, input_grad, ok = pychain_C.forward_backward(
            forward_transitions, forward_transition_indices,
            forward_transition_probs, backward_transitions,
            backward_transition_indices, backward_transition_probs,
            leaky_probs, initial_probs, final_probs, start_state, data,
            batch_sizes, data_lengths, graphs.num_states, 1e-3)
        t2 = time.time()
        dev = 'gpu' if device == 'cuda:0' else 'cpu'
        print(
            f'{lang}\t{precision}\t{B}\t{S}\t{N}\tpychain_leaky\tdense\t{dev}\t{t2 - t1}'
        )
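# A hedged sketch of a command-line driver for the benchmark above; the flag
# names and defaults are assumptions, not from the source.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='benchmark pychain forward-backward kernels')
    parser.add_argument('--num-frames', '-N', type=int, default=500)
    parser.add_argument('--num-states', '-S', type=int, default=100)
    parser.add_argument('--batch-size', '-B', type=int, default=32)
    args = parser.parse_args()
    main(args.num_frames, args.num_states, args.batch_size)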