def transform_batch(self, batch_sources):
    """Implements transformer hook to add padding and masks to datastream.

    batch_sources is a tuple containing all the sources from datastream.
    In our case it's just one source, so we handle it accordingly.
    """
    return pad_mask(batch_sources[0])
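# pad_mask is called throughout this file but not defined here. Below is a minimal
# sketch, assuming it maps a batch of lengths to a seq x batch 0/1 mask (callers in
# train/validate/forward below do pad_mask(lengths).permute(1, 0) to get batch x seq,
# and the beam-search code passes a device and max_seqlen). The repo's actual helper,
# including the one transform_batch above relies on, may differ.
import torch


def pad_mask(lengths, device=None, max_seqlen=None):
    """Hypothetical helper: 1 at valid positions, 0 at padding; shape seq x batch."""
    if max_seqlen is None:
        max_seqlen = int(lengths.max())
    if device is None:
        device = lengths.device
    positions = torch.arange(max_seqlen, device=device).unsqueeze(1)  # seq x 1
    return (positions < lengths.to(device).unsqueeze(0)).long()       # seq x batch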
def train(model, dataset, optimizer, criterion, epoch, args, data_start_index):
    model.train()
    if data_start_index == 0:
        dataset.shuffle('train', seed=epoch + args.seed)
    if args.epoch_max_len is not None:
        data_end_index = min(data_start_index + args.epoch_max_len,
                             len(dataset.splits['train']))
        loader = dataset.loader('train',
                                num_workers=args.num_workers,
                                indices=list(range(data_start_index, data_end_index)))
        data_start_index = data_end_index if data_end_index < len(
            dataset.splits['train']) else 0
    else:
        loader = dataset.loader('train', num_workers=args.num_workers)
    loss_meter = AverageMeter('loss', ':6.4f')
    total_length = len(loader)
    progress = ProgressMeter(total_length, [loss_meter], prefix='Training: ')
    for batch_num, batch in enumerate(tqdm(loader, total=len(loader))):
        batch = [tensor.to(args.device) for tensor in batch]
        inputs, lengths, future_words, log_probs, labels, classification_targets, \
            syllables_to_go, future_word_num_syllables, rhyme_group_index = batch
        if args.task not in ['formality', 'iambic']:
            if not args.debug and len(inputs) != args.batch_size:  # it'll screw up the bias...?
                continue
        scores = model(inputs, lengths, future_words, log_probs, syllables_to_go,
                       future_word_num_syllables, rhyme_group_index,
                       run_classifier=True)
        if args.task == 'formality':
            # we're learning for all positions at once. scores are batch x seq
            expanded_labels = classification_targets.unsqueeze(1).expand(
                -1, scores.shape[1])  # batch x seq
            length_mask = pad_mask(lengths).permute(1, 0)  # batch x seq
            loss = criterion(
                scores.flatten()[length_mask.flatten() == 1],
                expanded_labels.flatten().float()[length_mask.flatten() == 1])
        elif args.task in ['iambic', 'newline']:
            use_indices = classification_targets.flatten() != -1
            loss = criterion(scores.flatten()[use_indices],
                             classification_targets.flatten().float()[use_indices])
        else:  # topic, rhyme
            loss = criterion(scores.flatten(), labels.flatten().float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_meter.update(loss.detach(), len(labels))
        if batch_num % args.train_print_freq == 0:
            progress.display(batch_num)
    progress.display(total_length)
    return data_start_index
def validate(model, dataset, criterion, epoch, args):
    model.eval()
    random.seed(0)
    loader = dataset.loader('val', num_workers=args.num_workers)
    loss_meter = AverageMeter('loss', ':6.4f')
    total_length = len(loader)
    progress = ProgressMeter(total_length, [loss_meter], prefix='Validation: ')
    with torch.no_grad():
        for batch_num, batch in enumerate(tqdm(loader, total=len(loader))):
            batch = [tensor.to(args.device) for tensor in batch]
            inputs, lengths, future_words, log_probs, labels, classification_targets, \
                syllables_to_go, future_word_num_syllables, rhyme_group_index = batch
            if args.task not in ['formality', 'iambic']:  # topic predictor
                if not args.debug and len(inputs) != args.batch_size:
                    continue
            scores = model(inputs, lengths, future_words, log_probs, syllables_to_go,
                           future_word_num_syllables, rhyme_group_index,
                           run_classifier=True)
            if args.task == 'formality':
                # we're learning for all positions at once. scores are batch x seq
                expanded_labels = classification_targets.unsqueeze(1).expand(
                    -1, scores.shape[1])  # batch x seq
                length_mask = pad_mask(lengths).permute(1, 0)  # batch x seq
                loss = criterion(
                    scores.flatten()[length_mask.flatten() == 1],
                    expanded_labels.flatten().float()[length_mask.flatten() == 1])
            elif args.task == 'intent':
                # we're learning for all positions at once. scores are batch x seq x 4
                expanded_labels = classification_targets.unsqueeze(1).expand(
                    -1, scores.shape[1], -1)  # batch x seq x 4
                expanded_labels = expanded_labels.contiguous().view(-1, 4)  # batch*seq x 4
                scores = scores.contiguous().view(-1, 4)
                loss = criterion(scores, expanded_labels.float())
            elif args.task in ['iambic', 'newline']:
                use_indices = classification_targets.flatten() != -1
                loss = criterion(scores.flatten()[use_indices],
                                 classification_targets.flatten().float()[use_indices])
            else:  # topic, rhyme
                loss = criterion(scores.flatten(), labels.flatten().float())
            loss_meter.update(loss.detach(), len(labels))
            if batch_num % args.train_print_freq == 0:
                progress.display(batch_num)
    progress.display(total_length)
    return loss_meter.avg
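# Hypothetical driver loop (not from the repo): it shows how the train() and validate()
# functions above are meant to interact across epochs. args.epochs and args.ckpt_path are
# assumed names, and model/dataset/optimizer/criterion are assumed to be built elsewhere.
import torch

best_val_loss = float('inf')
data_start_index = 0
for epoch in range(args.epochs):
    # train() returns the next shard start so epoch_max_len-limited epochs resume where they left off
    data_start_index = train(model, dataset, optimizer, criterion, epoch, args,
                             data_start_index)
    val_loss = validate(model, dataset, criterion, epoch, args)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({'epoch': epoch, 'state_dict': model.state_dict()}, args.ckpt_path)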
def forward(self,
            inputs,
            lengths=None,
            future_words=None,
            log_probs=None,
            syllables_to_go=None,
            future_word_num_syllables=None,
            rhyme_group_index=None,
            run_classifier=False):
    """
    inputs: token ids, batch x seq, right-padded with 0s
    lengths: lengths of inputs; batch
    future_words: batch x N words to check if not predict next token, else batch
    log_probs: N
    syllables_to_go: batch
    """
    if self.topic:
        inputs = self.gpt_embed(inputs)  # batch x seq x 300
        inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(),
                                      enforce_sorted=False)
        rnn_output, _ = self.rnn(inputs)
        rnn_output, _ = pad_packed_sequence(rnn_output)
        rnn_output = rnn_output.permute(1, 0, 2)  # batch x seq x 300
        hidden = rnn_output
        attention_mask = pad_mask(lengths).permute(1, 0)  # batch x seq
        embed = self.word_embed(future_words)  # batch x N x 300
        embed_query = self.embed_key_linear(embed)
        attention_tensor = self.attention_linear(hidden).unsqueeze(
            2) * embed_query.unsqueeze(1)  # batch x seq x N x 300
        attention_weights = F.softmax(attention_tensor.sum(dim=3), dim=1)  # batch x seq x N
        attention_weights = attention_weights * attention_mask.unsqueeze(2)
        hidden = self.attention_value_linear(hidden)
        weighted_hidden = (hidden.unsqueeze(2) * attention_weights.unsqueeze(3)).sum(
            dim=1)  # batch x seq x N x 768 -> batch x N x 768
        unnormalized_scores = (self.out_linear(weighted_hidden) *
                               self.out_embed_linear(embed))  # batch x N x 300
        unnormalized_scores = torch.cat([unnormalized_scores, embed], dim=2)
        unnormalized_scores = self.nonlinear(
            self.out_linear2(self.nonlinear(unnormalized_scores)))
        unnormalized_scores = self.out_linear3(unnormalized_scores)
        scores = unnormalized_scores.squeeze(2) - log_probs.unsqueeze(0)
        return scores  # batch x N of normalized scores or batch x
    elif self.formality:
        inputs = self.marian_embed(inputs)
        inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(),
                                      enforce_sorted=False)
        rnn_output, _ = self.rnn(inputs)
        rnn_output, _ = pad_packed_sequence(rnn_output)
        rnn_output = rnn_output.permute(1, 0, 2)  # batch x seq x 300
        return self.out_linear(rnn_output).squeeze(2)
    elif self.iambic:
        inputs = self.gpt_embed(inputs)
        inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(),
                                      enforce_sorted=False)
        rnn_output, _ = self.rnn(inputs)
        rnn_output, _ = pad_packed_sequence(rnn_output)
        rnn_output = rnn_output.permute(1, 0, 2)  # batch x seq x 300
        return self.out_linear(rnn_output).squeeze(2)
    elif self.rhyme:
        inputs = self.gpt_embed(inputs)  # batch x seq x 300
        inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(),
                                      enforce_sorted=False)
        rnn_output, _ = self.rnn(inputs)
        rnn_output, _ = pad_packed_sequence(rnn_output)
        rnn_output = rnn_output.permute(1, 0, 2)  # batch x seq x 300
        hidden = rnn_output
        attention_mask = pad_mask(lengths).permute(1, 0)  # batch x seq
        embed = self.word_embed(future_words)  # batch x N x 300
        embedded_syllables_to_go = self.count_syllable_embed(
            syllables_to_go).unsqueeze(1).expand(-1, embed.shape[1], -1)  # batch x N x 100
        auxiliary_embed = embedded_syllables_to_go
        embed_query = self.embed_key_linear(
            torch.cat([embed, auxiliary_embed], dim=2))
        attention_tensor = self.attention_linear(hidden).unsqueeze(
            2) * embed_query.unsqueeze(1)  # batch x seq x N x 300
        attention_weights = F.softmax(attention_tensor.sum(dim=3), dim=1)  # batch x seq x N
        attention_weights = attention_weights * attention_mask.unsqueeze(2)
        hidden = self.attention_value_linear(hidden)
        weighted_hidden = (hidden.unsqueeze(2) * attention_weights.unsqueeze(3)).sum(
            dim=1)  # batch x seq x N x 768 -> batch x N x 768
        unnormalized_scores = (self.out_linear(weighted_hidden) *
                               self.out_embed_linear(embed))  # batch x N x 300
        unnormalized_scores = torch.cat(
            [unnormalized_scores, embed, auxiliary_embed], dim=2)
        unnormalized_scores = self.nonlinear(
            self.out_linear2(self.nonlinear(unnormalized_scores)))
        unnormalized_scores = self.out_linear3(unnormalized_scores)
        scores = unnormalized_scores.squeeze(2) - log_probs.unsqueeze(0)
        return scores  # batch x N of normalized scores or batch x
    elif self.newline:
        inputs = self.gpt_embed(inputs)  # batch x seq x 300
        inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(),
                                      enforce_sorted=False)
        rnn_output, _ = self.rnn(inputs)
        rnn_output, _ = pad_packed_sequence(rnn_output)
        rnn_output = rnn_output.permute(1, 0, 2)  # batch x seq x 300
        hidden = torch.cat([
            rnn_output,
            self.count_syllable_embed(syllables_to_go).unsqueeze(1).expand(
                -1, rnn_output.shape[1], -1)
        ], dim=2)
        return self.out_linear3(
            self.nonlinear(self.out_linear2(self.nonlinear(
                self.out_linear(hidden))))).squeeze(2)
    else:
        raise NotImplementedError
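# Illustration (not part of the model): the topic/rhyme branches above pool the RNN hidden
# states against N candidate future-word embeddings via multiplicative attention. This is a
# standalone sketch of that pattern with made-up sizes and random tensors in place of the
# model's learned projections.
import torch
import torch.nn.functional as F

batch, seq, N, hid = 2, 7, 5, 300
hidden = torch.randn(batch, seq, hid)       # RNN outputs: batch x seq x hid
embed_query = torch.randn(batch, N, hid)    # one query per candidate future word
attention_mask = torch.ones(batch, seq)     # 1 at real tokens, 0 at padding

# score every (position, candidate) pair, then normalize over positions
attention_tensor = hidden.unsqueeze(2) * embed_query.unsqueeze(1)   # batch x seq x N x hid
attention_weights = F.softmax(attention_tensor.sum(dim=3), dim=1)   # batch x seq x N
attention_weights = attention_weights * attention_mask.unsqueeze(2)

# pool the hidden states into one vector per candidate word
weighted_hidden = (hidden.unsqueeze(2) * attention_weights.unsqueeze(3)).sum(dim=1)  # batch x N x hid
print(weighted_hidden.shape)  # torch.Size([2, 5, 300])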
def variable_beam_stream_fast(sg,
                              model,
                              tokenized_sentences,
                              k=5,
                              max_length=100,
                              rp=0.6,
                              ap=2.5,
                              rpl=0.02,
                              mc=3,
                              find_top_z=1,
                              max_indices=32,
                              encode_batch_size=64,
                              max_si_tokens=7168,
                              bos_token=None,
                              len_penalty=1,
                              one_batch=False):
    ensemble_size = len(model.models)
    BOS_ID = sg.eos if bos_token is None else bos_token
    EOS_ID = sg.eos
    if one_batch:
        full_data_size = tokenized_sentences['net_input']['src_tokens'].shape[0]
    else:
        full_data_size = len(tokenized_sentences)
        batch_iterator = model._build_batches(tokenized_sentences, False)  # not streaming
    master_done_beams = [[] for _ in range(full_data_size)]
    master_batch_ids = [None for _ in range(full_data_size)]
    parent_model = model
    model = model.models

    master_decoded_indices = torch.zeros(1, 0, k).long().to(parent_model.device)  # seq, batch, k
    master_log_probs = torch.zeros(0, k).to(parent_model.device)  # batch x k
    master_enc_out = []
    master_state = IncrementalState(0, k, ensemble_size, parent_model.device)  # init incremental state
    master_valid_beam_mask = torch.zeros(0, k).to(parent_model.device)  # batch x k
    master_num_valid_beams = torch.zeros(0).long().to(parent_model.device)  # batch
    master_index = torch.zeros(0).long().to(parent_model.device)  # batch
    master_src_lengths = torch.zeros(0).long().to(parent_model.device)
    master_progress = torch.zeros(0).long().to(parent_model.device)  # batch
    master_end_found = torch.zeros(0, k).long().to(parent_model.device)  # batch x k
    master_done_lengths = torch.zeros(0).long().to(parent_model.device)  # batch
    master_best_finished_log_probs = torch.zeros(0).to(parent_model.device) - 1e8  # batch

    current_idx = 0
    has_more_batches = True
    decode_calls = 0
    n_expansions = 0
    master_remove_indices = torch.zeros(0).long().to(parent_model.device)
    num_pad = 0
    reselect = True
    while True:
        while has_more_batches and master_src_lengths.sum(
        ) <= max_si_tokens - parent_model.args.max_tokens:  # token-based limit
            assert reselect
            if one_batch:  # not streaming
                batch = tokenized_sentences
                has_more_batches = False
            else:
                try:
                    batch = next(batch_iterator)
                except StopIteration:
                    has_more_batches = False
                    break
            batch = utils.apply_to_sample(lambda t: t.to(parent_model.device), batch)
            for i, id in enumerate(batch['id'].tolist()):
                master_batch_ids[current_idx + i] = id
            net_input = batch["net_input"]
            src_tokens = net_input["src_tokens"]
            num_new_sources = len(src_tokens)

            # encode and add the next batch of source infos; update the index
            encoder_outs = sg.model.forward_encoder(net_input)

            # concatenate to the current master tensors
            # decoded_indices; note these are left padded
            current_seqlen = master_decoded_indices.size(0)
            master_decoded_indices = torch.cat([
                master_decoded_indices,
                pad_to_length(torch.zeros(1, num_new_sources, k) + BOS_ID,
                              current_seqlen, 0, side='left',
                              value=0).long().to(parent_model.device)
            ], dim=1)
            # log_probs
            master_log_probs = torch.cat([
                master_log_probs,
                torch.cat([
                    torch.zeros(num_new_sources, 1),
                    torch.zeros(num_new_sources, k - 1) - 1e8
                ], dim=1).to(parent_model.device)
            ], dim=0)
            if len(master_enc_out) == 0:
                assert current_idx == 0
                master_enc_out = encoder_outs
            else:
                assert len(master_enc_out) == len(encoder_outs)
                for i in range(len(master_enc_out)):
                    meo, eo = master_enc_out[i], encoder_outs[i]
                    max_seq = max(meo.encoder_out.shape[0], eo.encoder_out.shape[0])
                    new_eo = EncoderOut(
                        encoder_out=torch.cat([
                            pad_to_length(meo.encoder_out, max_seq, 0, side='left', value=0),
                            pad_to_length(eo.encoder_out, max_seq, 0, side='left', value=0)
                        ], dim=1),
                        encoder_padding_mask=torch.cat([
                            pad_to_length(meo.encoder_padding_mask, max_seq, 1, side='left', value=True),
                            pad_to_length(eo.encoder_padding_mask, max_seq, 1, side='left', value=True)
                        ], dim=0),
                        encoder_embedding=torch.cat([
                            pad_to_length(meo.encoder_embedding, max_seq, 1, side='left', value=0),
                            pad_to_length(eo.encoder_embedding, max_seq, 1, side='left', value=0)
                        ], dim=0),
                        encoder_states=None,
                        src_tokens=None,
                        src_lengths=None)
                    master_enc_out[i] = new_eo
            if not one_batch:
                # get the encoder attention keys
                sg.model.incremental_states = [{} for _ in range(ensemble_size)]
                sg.model.forward_decoder(
                    (torch.zeros(num_new_sources) + BOS_ID).long().to(
                        parent_model.device).unsqueeze(1), encoder_outs, sg.temperature)
                dummy_state = sg.model.incremental_states
                master_state.append_new_incremental_state(
                    num_new_sources, dummy_state,
                    torch.arange(num_new_sources).long().to(parent_model.device) + current_idx)
            master_valid_beam_mask = torch.cat([
                master_valid_beam_mask,
                torch.cat([
                    torch.ones(num_new_sources, 1),
                    torch.zeros(num_new_sources, k - 1)
                ], dim=1).to(parent_model.device)
            ], dim=0)
            # print(net_input['src_lengths'].max())
            master_src_lengths = torch.cat([master_src_lengths, net_input['src_lengths']], dim=0)
            # num_valid_beams
            master_num_valid_beams = torch.cat([
                master_num_valid_beams,
                torch.ones(num_new_sources).long().to(parent_model.device)
            ], dim=0)
            # index
            master_index = torch.cat([
                master_index,
                current_idx + torch.arange(num_new_sources).to(parent_model.device)
            ], dim=0)
            # progress
            master_progress = torch.cat([
                master_progress,
                torch.zeros(num_new_sources).long().to(parent_model.device)
            ], dim=0)
            # end_found
            master_end_found = torch.cat([
                master_end_found,
                torch.zeros(num_new_sources, k).long().to(parent_model.device)
            ], dim=0)
            # done lengths
            master_done_lengths = torch.cat([
                master_done_lengths,
                torch.zeros(num_new_sources).long().to(parent_model.device)
            ], dim=0)
            # best done log probs
            master_best_finished_log_probs = torch.cat([
                master_best_finished_log_probs,
                torch.zeros(num_new_sources).to(parent_model.device) - 1e8
            ], dim=0)
            current_idx += num_new_sources
            # break  # for debugging

        # break if none left
        if not has_more_batches and len(master_index) == 0:
            break

        # based on max_bs and source_info, select which indices to use (sort source_info), then create:
        selected_indices, unselected_indices, prog_min = select_source_indices(
            master_num_valid_beams, master_progress, master_index, max_indices,
            reverse=False, sort=False)
        if one_batch:
            assert len(unselected_indices) == 0  # for debugging
        selected_master_indices = master_index[selected_indices]
        batch_size = len(selected_indices)
        selected_enc_out = sg.model.reorder_encoder_out(
            master_enc_out, selected_indices.unsqueeze(1).expand(-1, k).flatten())
        # if decode_calls % 50 == 0:
        #     print(decode_calls)
        valid_beam_mask = master_valid_beam_mask[selected_indices]
        valid_beam_indices = valid_beam_mask.flatten().nonzero().flatten()  # idk why need to flatten again
        reverse_idx = (torch.cumsum(valid_beam_mask.flatten(), dim=0) *
                       valid_beam_mask.flatten()).long() - 1  # it's fine to select whatever position for padding as they'll be removed later
        if num_pad > 0:
            if num_pad >= len(master_decoded_indices):
                # edge case: we previously ran out of beams, and we are starting fresh now
                assert num_pad == len(master_decoded_indices)
                num_pad -= 1
            master_decoded_indices = master_decoded_indices[num_pad:]
            master_state.clean_padding(num_pad)
        if reselect:
            selected_state_master_indices, selected_state = master_state.select_incremental_state(
                selected_master_indices, master_remove_indices, prog_min)
            master_state.num_sources -= len(master_remove_indices)
            sg.model.incremental_states = selected_state
        log_probs = master_log_probs[selected_indices]
        progress = master_progress[selected_indices]
        decoded_indices = master_decoded_indices[-progress.max() - 1:, selected_indices, :]
        end_found = master_end_found[selected_indices]
        done_lengths = master_done_lengths[selected_indices]
        best_finished_log_probs = master_best_finished_log_probs[selected_indices]
        # flattened_indices = last_indices.flatten().unsqueeze(0) # 1 x batch*k
        # create valid beam indices from valid beam mask
        if one_batch and decode_calls == 0:
            selected_state_master_indices = master_index.clone()
        assert len(selected_state_master_indices) == len(valid_beam_indices)
        decode_calls += 1
        n_expansions += len(valid_beam_indices)

        # use valid_beam_mask to select valid indices out of decoded_indices, encoder_outs, model incremental state
        decoding_selected_indices = decoded_indices.flatten(1)[:, valid_beam_indices]  # seq x selected
        selected_enc_out = sg.model.reorder_encoder_out(selected_enc_out, valid_beam_indices)
        assert torch.all(decoding_selected_indices.flatten(1).permute(1, 0)[:, 0] == 2)
        next_log_probs, _ = sg.model.forward_decoder(
            decoding_selected_indices.flatten(1).permute(1, 0)[:, :master_progress.max() + 1],
            selected_enc_out, sg.temperature)
        # remake next_scores, state with dummies
        next_log_probs = next_log_probs[reverse_idx].view(1, batch_size, k, -1)
        # reorder incremental model state
        reorder_idx = reverse_idx
        next_log_probs = next_log_probs.view(1, batch_size, k, -1)

        # for edge case where EOS_ID appears later down in the beam but still needs to be dealt with correctly on the next step!
        end_found = end_found.unsqueeze(0).unsqueeze(3)  # batch_size x k x 1 of whether end index is in tgt_idx already; if so, make prob of padding 1
        end_found = (end_found + (progress + 1 == max_length).long().view(1, -1, 1, 1)).clamp(max=1)
        end_found_scores = torch.zeros_like(next_log_probs).to(parent_model.device) - 1e8
        end_found_scores[:, :, :, EOS_ID] = 0
        # make it so you only pick eos for the sequences that are already done, and don't duplicate them, by making other probs -inf
        next_log_probs = end_found * end_found_scores + (1 - end_found) * next_log_probs
        # ~ is for inverting the mask
        next_log_probs = next_log_probs - 1e8 * (1 - valid_beam_mask.unsqueeze(0).unsqueeze(3))  # get rid of padding positions
        next_log_probs = next_log_probs + log_probs.unsqueeze(0).unsqueeze(3)  # 1, batch, k, vocab
        mc_probs, mc_indices = next_log_probs.topk(mc, dim=3)  # 1, batch, k, mc
        top_log_probs, top_indices = mc_probs.flatten(2).topk(k, dim=2)  # 1, batch, k
        mc_vocab_indices = top_indices % mc
        beam_indices = top_indices // mc  # 1, batch, k
        vocab_indices = torch.gather(
            mc_indices.flatten(2).flatten(0, 1), 1,
            (mc_vocab_indices + beam_indices * mc).flatten(0, 1)).unsqueeze(0)  # 1, batch, k

        # check which vocab_indices are done (in the first beam position), and add the corresponding beam to an array of done predictions
        newly_done_all = (vocab_indices == EOS_ID).long()  # 1, batch, k
        newly_done = torch.cumprod(newly_done_all, dim=2)  # keep on beam if there's something above it that's not done yet
        done_lengths += newly_done.sum(dim=2).flatten()  # update this one before others since we'll need it earlier
        newly_done_indices = newly_done.flatten().nonzero()  # batch*k
        for j in newly_done_indices:
            source_idx = j // k
            # add to some master list with an entry for each source
            if len(master_done_beams[selected_master_indices[source_idx]]) < find_top_z:
                finished_cand = decoded_indices[:, source_idx,
                                                beam_indices[0, source_idx, j % k]].flatten()
                finished_cand_length = progress[source_idx] + 1
                while len(finished_cand) > 0 and finished_cand[-1] == EOS_ID:
                    finished_cand = finished_cand[:-1]
                    finished_cand_length -= 1
                if len(finished_cand) > 0:  # avoid length 0
                    master_done_beams[selected_master_indices[source_idx]].append({
                        'tokens': finished_cand.cpu(),
                        'score': (top_log_probs.flatten()[j] /
                                  ((finished_cand_length)**len_penalty)).item()
                    })
                    best_finished_log_probs[source_idx] = max(
                        best_finished_log_probs[source_idx], top_log_probs.flatten()[j])
                else:
                    # rarely with greedy search (beam size k = 1) you get stuff with length 0...
                    # so avoid crashing but give it low score
                    master_done_beams[selected_master_indices[source_idx]].append({
                        'tokens': finished_cand.cpu(),
                        'score': -1e8
                    })

        # then, shift log_probs and beam_indices for those beams and delete that beam(s);
        # put in placeholder beam and log_prob at the k^th position
        # need to shift top_log_probs, beam_indices, vocab_indices accordingly
        top_log_probs = torch.cat([
            top_log_probs,
            torch.zeros_like(top_log_probs).to(parent_model.device) - 1e8
        ], dim=2)  # 1, batch, 2k
        shift_indices = newly_done.sum(dim=2).unsqueeze(2) + torch.arange(k).to(
            parent_model.device).unsqueeze(0).unsqueeze(1)  # 1, batch, k
        top_log_probs = torch.gather(top_log_probs, 2, shift_indices)
        shift_indices = shift_indices.clamp(max=k - 1)
        beam_indices = torch.gather(beam_indices, 2, shift_indices)
        vocab_indices = torch.gather(vocab_indices, 2, shift_indices)
        newly_done_all = torch.gather(newly_done_all, 2, shift_indices)
        log_probs = top_log_probs.squeeze(0)
        state_indices = (beam_indices + k * torch.arange(batch_size).to(
            parent_model.device).unsqueeze(1).repeat(1, k)).flatten()
        reorder_idx = reorder_idx[state_indices]

        # update valid beam mask
        ap_thresholds = (torch.max(log_probs[:, 0], best_finished_log_probs) - ap).unsqueeze(1)  # batch x 1
        valid_beam_mask = (log_probs > ap_thresholds).float()  # batch x k
        # update valid beam mask based on how many beams are left for each source
        done_mask = pad_mask(k - done_lengths, parent_model.device, max_seqlen=k).permute(
            1, 0)  # batch x k of beams to keep, up to k - num done already
        all_low_prob_mask = 1 - valid_beam_mask.max(dim=1)[0]
        # NOTE since we filter out by the absolute threshold including previously finished beams,
        # we could get < k finished candidates, but always at least 1
        found_z_mask = (all_low_prob_mask.bool() | (done_lengths >= find_top_z)).unsqueeze(1)
        valid_beam_mask = valid_beam_mask * done_mask * (1 - found_z_mask.long())

        # filter the done ones out of all the master tensors
        keep_indices = (~found_z_mask).flatten().nonzero().flatten().long()
        remove_indices = (found_z_mask).flatten().nonzero().flatten().long()
        keep_indices = torch.cat([selected_indices[keep_indices], unselected_indices], dim=0)
        master_remove_indices = master_index[selected_indices[remove_indices]]

        # update these quantities in their respective source_info objects after computing them
        # just deleting/concatenating to a single master tensor
        # master_decoded_indices seq x batch x k
        new_master_indices = torch.zeros(1, master_decoded_indices.size(1), k).long().to(
            parent_model.device)  # 1 x batch x k
        new_master_indices[:, selected_indices] = vocab_indices
        master_decoded_indices[:, selected_indices] = torch.gather(
            master_decoded_indices[:, selected_indices], 2,
            beam_indices.expand(master_decoded_indices[:, selected_indices].shape))
        master_decoded_indices = torch.cat([master_decoded_indices, new_master_indices], dim=0)
        if prog_min + 2 >= master_decoded_indices.shape[0]:
            master_decoded_indices = torch.cat([
                torch.zeros(1, master_decoded_indices.size(1), k).long().to(parent_model.device),
                master_decoded_indices
            ], dim=0)
        master_decoded_indices[:, selected_indices] = torch.roll(
            master_decoded_indices[:, selected_indices], -1, 0)
        master_decoded_indices = master_decoded_indices[:-1]
        # master_log_probs batch x k
        master_log_probs[selected_indices] = log_probs
        # master_valid_beam_mask batch x k
        master_valid_beam_mask[selected_indices] = valid_beam_mask
        # master_num_valid_beams batch
        master_num_valid_beams = master_valid_beam_mask.sum(dim=1).long()
        # master_progress batch
        master_progress[selected_indices] += 1
        # master_end_found batch x k
        master_end_found[selected_indices] = (
            torch.gather(end_found.squeeze(3), 2, beam_indices) | newly_done_all[0, :, :]).squeeze(0)
        # master_done_lengths batch
        master_done_lengths[selected_indices] = done_lengths
        # master_best_finished_log_probs batch
        master_best_finished_log_probs[selected_indices] = best_finished_log_probs
        # update master versions of sg.model state
        reorder_idx = reorder_idx[valid_beam_mask.flatten().nonzero().flatten()]
        selected_state_master_indices = selected_state_master_indices[reorder_idx]
        reorder_incremental_state(sg.model, reorder_idx)
        master_src_lengths = master_src_lengths[keep_indices]
        if master_src_lengths.sum() <= max_si_tokens - parent_model.args.max_tokens:
            reselect = True
        elif len(progress) < (master_progress == prog_min + 1).sum():
            reselect = True
        else:
            reselect = False
        if reselect:
            # if not one_batch:
            #     print('reselect', decode_calls)
            master_state.recache(selected_state_master_indices, sg.model.incremental_states)
            master_decoded_indices = master_decoded_indices[:, keep_indices, :]
            master_log_probs = master_log_probs[keep_indices]
            master_enc_out = sg.model.reorder_encoder_out(master_enc_out, keep_indices)
            master_valid_beam_mask = master_valid_beam_mask[keep_indices]
            master_num_valid_beams = master_num_valid_beams[keep_indices]
            master_index = master_index[keep_indices]
            master_progress = master_progress[keep_indices]
            master_end_found = master_end_found[keep_indices]
            master_done_lengths = master_done_lengths[keep_indices]
            master_best_finished_log_probs = master_best_finished_log_probs[keep_indices]
            # delete any unnecessary padding so we don't keep increasing padding
            num_pad = (master_decoded_indices.sum(dim=1).sum(dim=1) == 0).sum(dim=0)
        if not reselect:
            assert num_pad == 0

    assert all([bid is not None for bid in master_batch_ids])
    for i in range(len(master_done_beams)):
        master_done_beams[i] = sorted(master_done_beams[i],
                                      key=lambda x: x['score'],
                                      reverse=True)
    if one_batch:
        return master_done_beams, decode_calls, n_expansions
    else:
        return master_batch_ids, master_done_beams, decode_calls, n_expansions
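# pad_to_length is used repeatedly above but not defined in this file. Below is a minimal
# sketch consistent with those calls (pad a tensor along one dimension to a target length,
# on the left or right, with a fill value); the repo's actual helper may differ.
import torch
import torch.nn.functional as F


def pad_to_length(tensor, length, dim, side='right', value=0):
    """Hypothetical helper: pad `tensor` with `value` along `dim` up to `length`."""
    pad_amount = length - tensor.shape[dim]
    if pad_amount <= 0:
        return tensor
    # F.pad takes (left, right) pairs starting from the last dimension
    pad = [0, 0] * (tensor.dim() - dim - 1)
    pad = pad + ([pad_amount, 0] if side == 'left' else [0, pad_amount])
    return F.pad(tensor, pad, value=value)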
def train(model, dataset, optimizer, criterion, epoch, args, data_start_index):
    model.train()
    if args.iw:
        loader = DataLoader(dataset, batch_size=args.batch_size // 2)
        loss_meter = AverageMeter('loss', ':6.4f')
        acc_meter = AverageMeter('acc', ':6.4f')
        total_length = len(loader)
        progress = ProgressMeter(total_length, [loss_meter, acc_meter], prefix='Training: ')
        for batch_num, ((x_source, source_lengths),
                        (x_target, target_lengths)) in enumerate(
                            tqdm(iter(loader), total=len(loader), leave=False)):
            x = torch.cat([x_source, x_target]).to(args.device).long()
            y_source = torch.cat([
                torch.cat([
                    torch.zeros(source_lengths[i]),
                    -torch.ones(100 - source_lengths[i])
                ]).unsqueeze(0) for i in range(source_lengths.shape[0])
            ], dim=0)
            y_target = torch.cat([
                torch.cat([
                    torch.ones(target_lengths[i]),
                    -torch.ones(100 - target_lengths[i])
                ]).unsqueeze(0) for i in range(target_lengths.shape[0])
            ], dim=0)
            y = torch.cat([y_source, y_target]).float().to(args.device)
            lengths = torch.cat([source_lengths, target_lengths]).squeeze().to(args.device)
            scores = model(x, lengths)
            use_indices = y.flatten() != -1
            loss = criterion(scores.flatten()[use_indices],
                             y.flatten().float()[use_indices])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            preds = torch.round(torch.sigmoid(scores.flatten()[use_indices]))
            accs = sum(preds == y.flatten()[use_indices]) / y.flatten()[use_indices].shape[0]
            acc_meter.update(accs.detach(), y.shape[0])
            loss_meter.update(loss.detach(), y.shape[0])
            if batch_num % args.train_print_freq == 0:
                progress.display(batch_num)
        progress.display(total_length)
    else:
        if data_start_index == 0:
            dataset.shuffle('train', seed=epoch + args.seed)
        if args.epoch_max_len is not None:
            data_end_index = min(data_start_index + args.epoch_max_len,
                                 len(dataset.splits['train']))
            loader = dataset.loader('train',
                                    num_workers=args.num_workers,
                                    indices=list(range(data_start_index, data_end_index)))
            data_start_index = data_end_index if data_end_index < len(
                dataset.splits['train']) else 0
        else:
            loader = dataset.loader('train', num_workers=args.num_workers)
        loss_meter = AverageMeter('loss', ':6.4f')
        total_length = len(loader)
        progress = ProgressMeter(total_length, [loss_meter], prefix='Training: ')
        for batch_num, batch in enumerate(tqdm(loader, total=len(loader))):
            batch = [tensor.to(args.device) for tensor in batch]
            inputs, lengths, future_words, log_probs, labels, classification_targets, \
                syllables_to_go, future_word_num_syllables, rhyme_group_index = batch
            if args.task not in ['formality', 'iambic']:
                if not args.debug and len(inputs) != args.batch_size:  # it'll screw up the bias...?
                    continue
            scores = model(inputs, lengths, future_words, log_probs, syllables_to_go,
                           future_word_num_syllables, rhyme_group_index,
                           run_classifier=True)
            if args.task == 'formality':
                # we're learning for all positions at once. scores are batch x seq
                expanded_labels = classification_targets.unsqueeze(1).expand(
                    -1, scores.shape[1])  # batch x seq
                length_mask = pad_mask(lengths).permute(1, 0)  # batch x seq
                loss = criterion(
                    scores.flatten()[length_mask.flatten() == 1],
                    expanded_labels.flatten().float()[length_mask.flatten() == 1])
            elif args.task in ['iambic', 'newline']:
                use_indices = classification_targets.flatten() != -1
                loss = criterion(scores.flatten()[use_indices],
                                 classification_targets.flatten().float()[use_indices])
            else:  # topic, rhyme
                loss = criterion(scores.flatten(), labels.flatten().float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_meter.update(loss.detach(), len(labels))
            if batch_num % args.train_print_freq == 0:
                progress.display(batch_num)
        progress.display(total_length)
    return data_start_index
def validate(model, dataset, criterion, epoch, args):
    model.eval()
    random.seed(0)
    if args.iw:
        loader = DataLoader(SplitDataset(args, split='test'), batch_size=args.batch_size // 2)
        loss_meter = AverageMeter('loss', ':6.4f')
        acc_meter = AverageMeter('acc', ':6.4f')
        total_length = len(loader)
        progress = ProgressMeter(total_length, [loss_meter, acc_meter], prefix='Validation: ')
        with torch.no_grad():
            for batch_num, ((x_source, source_lengths),
                            (x_target, target_lengths)) in enumerate(
                                tqdm(iter(loader), total=len(loader), leave=False)):
                x = torch.cat([x_source, x_target]).to(args.device).long()
                y_source = torch.cat([
                    torch.cat([
                        torch.zeros(source_lengths[i]),
                        -torch.ones(100 - source_lengths[i])
                    ]).unsqueeze(0) for i in range(source_lengths.shape[0])
                ], dim=0)
                y_target = torch.cat([
                    torch.cat([
                        torch.ones(target_lengths[i]),
                        -torch.ones(100 - target_lengths[i])
                    ]).unsqueeze(0) for i in range(target_lengths.shape[0])
                ], dim=0)
                y = torch.cat([y_source, y_target]).float().to(args.device)
                lengths = torch.cat([source_lengths, target_lengths]).squeeze().to(args.device)
                scores = model(x, lengths)
                use_indices = y.flatten() != -1
                loss = criterion(scores.flatten()[use_indices],
                                 y.flatten().float()[use_indices])
                preds = torch.round(torch.sigmoid(scores.flatten()[use_indices]))
                accs = sum(preds == y.flatten()[use_indices]) / y.flatten()[use_indices].shape[0]
                acc_meter.update(accs.detach(), y.shape[0])
                loss_meter.update(loss.detach(), y.shape[0])
                if batch_num % args.train_print_freq == 0:
                    progress.display(batch_num)
        progress.display(total_length)
        return loss_meter.avg
    else:
        loader = dataset.loader('val', num_workers=args.num_workers)
        loss_meter = AverageMeter('loss', ':6.4f')
        total_length = len(loader)
        progress = ProgressMeter(total_length, [loss_meter], prefix='Validation: ')
        with torch.no_grad():
            for batch_num, batch in enumerate(tqdm(loader, total=len(loader))):
                batch = [tensor.to(args.device) for tensor in batch]
                inputs, lengths, future_words, log_probs, labels, classification_targets, \
                    syllables_to_go, future_word_num_syllables, rhyme_group_index = batch
                if args.task not in ['formality', 'iambic']:  # topic predictor
                    if not args.debug and len(inputs) != args.batch_size:
                        continue
                scores = model(inputs, lengths, future_words, log_probs, syllables_to_go,
                               future_word_num_syllables, rhyme_group_index,
                               run_classifier=True)
                if args.task == 'formality':
                    # we're learning for all positions at once. scores are batch x seq
                    expanded_labels = classification_targets.unsqueeze(1).expand(
                        -1, scores.shape[1])  # batch x seq
                    length_mask = pad_mask(lengths).permute(1, 0)  # batch x seq
                    loss = criterion(
                        scores.flatten()[length_mask.flatten() == 1],
                        expanded_labels.flatten().float()[length_mask.flatten() == 1])
                elif args.task in ['iambic', 'newline']:
                    use_indices = classification_targets.flatten() != -1
                    loss = criterion(scores.flatten()[use_indices],
                                     classification_targets.flatten().float()[use_indices])
                else:  # topic, rhyme
                    loss = criterion(scores.flatten(), labels.flatten().float())
                loss_meter.update(loss.detach(), len(labels))
                if batch_num % args.train_print_freq == 0:
                    progress.display(batch_num)
        progress.display(total_length)
        return loss_meter.avg