def get_beam_candidates(selection_model, beams, beam_width, placement_hidden,
                        step_length, num_arrows, hold_indices, chart_type,
                        special_tokens, delta_time, hold_filters, empty_filters,
                        device):
    # store expanded seqs, scores, + original beam idx
    candidates, curr_states = [], []

    for z, (seq, beam_score, hidden, cell) in enumerate(beams):
        curr_token = step_index_to_features(seq[-1], chart_type, special_tokens, device)
        curr_token = curr_token.unsqueeze(0).unsqueeze(0).float()
        curr_token = torch.cat((curr_token, delta_time), dim=-1)

        logits, (hidden, cell) = selection_model(curr_token, placement_hidden,
                                                 hidden, cell, step_length)
        curr_states.append((hidden, cell))

        curr_candidates = expand_beam(logits.squeeze(), seq, beam_score, z,
                                      hold_indices, num_arrows, hold_filters,
                                      empty_filters)
        candidates.extend(curr_candidates)

    # sort by beam score (accumulated (abs) log probs) -> keep lowest b candidates
    candidates = sorted(candidates, key=lambda x: x[1])[:beam_width]

    return candidates, curr_states
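# --- hedged sketch (not part of the original code) ---------------------------
# get_beam_candidates defers the actual expansion to expand_beam. The sketch
# below shows one plausible implementation under these assumptions: expand_beam
# turns the logits for a single beam into log-probabilities, keeps the top
# beam_width next tokens, and returns (sequence, accumulated |log p|,
# originating beam index) tuples so the caller can sort ascending and keep the
# lowest-scoring (i.e. most probable) candidates. Step masking via
# hold_filters/empty_filters is omitted here for brevity.
import torch.nn.functional as F

def expand_beam_sketch(logits, seq, beam_score, beam_idx, beam_width):
    log_probs = F.log_softmax(logits, dim=-1)
    top_log_probs, top_indices = log_probs.topk(beam_width)
    candidates = []
    for log_p, idx in zip(top_log_probs.tolist(), top_indices.tolist()):
        # scores accumulate |log p|, so a lower total means a more likely sequence
        candidates.append((seq + [idx], beam_score + abs(log_p), beam_idx))
    return candidates
# -----------------------------------------------------------------------------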
def test_tokenization_double(self):
    print(f'Double mode vocab size: {SELECTION_VOCAB_SIZES["pump-double"]}')

    for i in range(SELECTION_VOCAB_SIZES['pump-double']):
        feats = step_index_to_features(i, 'pump-double', None, d).unsqueeze(0)
        self.assertEqual(
            step_sequence_to_targets(feats, 'pump-double', None)[0].item(), i)
def save_best_beam(best, placement_times, chart_data, chart_type, special_tokens, device):
    seq, _, _ = best
    hold_indices = set()

    # seq[0] is the initial beam seed; each later token pairs with placement_times[m - 1]
    for m in range(1, len(seq)):
        token_feats = step_index_to_features(seq[m], chart_type, special_tokens, device)
        token_str = step_features_to_str(token_feats)
        token_str = filter_steps(token_str, hold_indices)

        chart_data.append([placement_times[m - 1], token_str])
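# --- hedged sketch (assumption, not the actual filter_steps) -----------------
# save_best_beam above relies on filter_steps to keep hold state consistent
# across consecutive steps. The sketch below assumes hold_indices carries the
# set of arrow columns currently held, 'H' marks a held column, 'W' releases
# it, and contradictory characters are coerced to something legal. The
# character semantics ('X' tap, 'H' hold, 'W' release, '.' empty) are read off
# the test data in this file and may not match the real implementation exactly.
def filter_steps_sketch(token_str, hold_indices):
    out = []
    for col, ch in enumerate(token_str):
        if ch == 'H':
            hold_indices.add(col)           # column is (or stays) held
        elif ch == 'W':
            if col in hold_indices:
                hold_indices.discard(col)   # legal release of an active hold
            else:
                ch = '.'                    # release without a hold -> empty
        elif col in hold_indices:
            ch = 'H'                        # a held column must keep holding
        out.append(ch)
    return ''.join(out)
# -----------------------------------------------------------------------------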
def test_special_tokens(self):
    db_vs = SELECTION_VOCAB_SIZES['pump-double']
    special = {db_vs: 'XXXXXXXXXX', (db_vs + 1): '..HHXHH..'}

    for key, val in special.items():
        feats = step_index_to_features(key, 'pump-double', special, d)
        self.assertEqual(step_features_to_str(feats[0]), val)
        self.assertEqual(
            step_sequence_to_targets(feats, 'pump-double', special)[0].item(), key)

    new_feats = sequence_to_tensor(['XXWWXXWWXX', 'XX..XX..XX'])
    new_targets, new_tokens = step_sequence_to_targets(
        new_feats, 'pump-double', special)

    self.assertEqual(new_targets[0].item(), db_vs + 2)
    self.assertEqual(new_targets[1].item(), db_vs + 3)
    self.assertEqual(new_tokens, 2)
    self.assertEqual(special[db_vs + 3], 'XX..XX..XX')
    self.assertEqual(special[db_vs + 2], 'XXWWXXWWXX')
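# --- hedged sketch (assumption, not the repo's tokenizer) --------------------
# The test above relies on step_sequence_to_targets growing the `special` dict
# on the fly: a step pattern outside the base vocabulary is assigned the next
# free index (base vocab size + number of special tokens so far), recorded in
# the dict, and counted as a newly added token. A minimal version of that
# bookkeeping, with illustrative names only:
def register_special_token_sketch(step_str, special, base_vocab_size):
    # reuse the existing index if this pattern was registered before
    for idx, pattern in special.items():
        if pattern == step_str:
            return idx, 0
    new_idx = base_vocab_size + len(special)
    special[new_idx] = step_str
    return new_idx, 1   # 1 newly registered token
# -----------------------------------------------------------------------------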
def generate_steps(selection_model, placements, placement_hiddens, vocab_size,
                   n_step_features, chart_type, sample_rate, special_tokens,
                   sampling, k, p, b, device=torch.device('cpu')):
    placement_frames = (placements == 1).nonzero(as_tuple=False).flatten()
    num_placements = int(placements.sum().item())
    print(f'{num_placements} placements were chosen. Now selecting steps...')

    # store pairs of time (s) and step strings (ucs format)
    chart_data = []

    num_arrows = (n_step_features - TIME_FEATURES) // NUM_ARROW_STATES
    hold_filters, empty_filters = get_filter_indices(chart_type)

    # Start generating the sequence of steps
    step_length = torch.ones(1, dtype=torch.long, device=device)
    hold_indices = [set() for _ in range(b)] if sampling == 'beam-search' else set()
    conditioned = placement_hiddens is not None

    selection_model.eval()

    # for beam search, track the b most likely token sequences +
    # their (log) probabilities at each step
    beams = []
    placement_times = []

    with torch.no_grad():
        start_token = torch.zeros(1, 1, n_step_features - 1, device=device)
        hidden, cell = selection_model.initStates(batch_size=1, device=device)

        for i in trange(num_placements):
            placement_melframe = placement_frames[i].item()
            placement_time = train_util.convert_melframe_to_secs(
                placement_melframe, sample_rate)
            placement_times.append(placement_time)

            # time since the previous placement (zero for the first placement)
            delta_time = torch.tensor(
                [placement_times[-1] - placement_times[-2]] if i > 0
                else [0.0]).unsqueeze(0).unsqueeze(0).to(device)

            placement_hidden = placement_hiddens[i].unsqueeze(0) if conditioned else None

            if sampling == 'beam-search':
                if i == 0:
                    beams.append([[0], 0.0, hidden.clone(), cell.clone()])

                candidates, curr_states = get_beam_candidates(
                    selection_model, beams, b, placement_hidden, step_length,
                    num_arrows, hold_indices, chart_type, special_tokens,
                    delta_time, hold_filters, empty_filters, device)

                # if this is the last placement, keep the candidate with the best score
                if i == num_placements - 1:
                    save_best_beam(candidates[0], placement_times, chart_data,
                                   chart_type, special_tokens, device)
                else:
                    for seq, score, beam_idx in candidates:
                        curr_hidden, curr_cell = curr_states[beam_idx]
                        beams[beam_idx] = [seq, score, curr_hidden, curr_cell]
            else:
                curr_token = (next_token_feats.unsqueeze(0).unsqueeze(0).float()
                              if i > 0 else start_token)
                curr_token = torch.cat((curr_token, delta_time), dim=-1)

                logits, (hidden, cell) = selection_model(curr_token, placement_hidden,
                                                         hidden, cell, step_length)

                next_token_idx = predict_step(logits.squeeze(), sampling, k, p,
                                              hold_indices, num_arrows,
                                              hold_filters, empty_filters)

                # convert token index -> feature tensor -> str [ucs] representation
                next_token_feats = step_index_to_features(
                    next_token_idx, chart_type, special_tokens, device)
                next_token_str = step_features_to_str(next_token_feats)
                next_token_str = filter_steps(next_token_str, hold_indices)

                chart_data.append([placement_time, next_token_str])

    return chart_data
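# --- hedged usage sketch (illustrative only) ---------------------------------
# A hypothetical call site for generate_steps. The arguments are assumed to
# come from earlier stages of the pipeline (a trained selection model, a
# binary placement vector, and optional per-placement hidden states for
# conditioning); hyper-parameter values here are illustrative, not the
# project's defaults. Only 'beam-search' is checked explicitly above; other
# sampling modes are handled by predict_step.
import torch

def generate_chart_sketch(selection_model, placements, placement_hiddens,
                          n_step_features, sample_rate):
    return generate_steps(
        selection_model, placements, placement_hiddens,
        vocab_size=SELECTION_VOCAB_SIZES['pump-double'],
        n_step_features=n_step_features,
        chart_type='pump-double',
        sample_rate=sample_rate,
        special_tokens={},                    # grows as unseen patterns appear
        sampling='beam-search', k=10, p=0.9, b=4,
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    # returns a list of [time_in_seconds, ucs_step_string] pairs
# -----------------------------------------------------------------------------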