import copy

import numpy as np
import torch
from torch.distributions.bernoulli import Bernoulli


def get_device():
    """Device helper (assumed standard CUDA-if-available pattern)."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def sequential_data_preparation(
    input_batch,
    input_keep=1,
    start_index=2,
    end_index=3,
    dropout_index=1,
    device=get_device()
):
    """
    Sequential Training Data Builder.

    Args:
        input_batch (torch.Tensor): Batch of padded sequences, output of
            nn.utils.rnn.pad_sequence(batch), of size
            `[sequence length, batch_size]`.
        input_keep (float): The probability to keep an input sequence token;
            tokens are dropped with probability 1 - input_keep according to
            a Bernoulli distribution (word dropout). Defaults to 1.
        start_index (int): The index of the sequence start token.
        end_index (int): The index of the sequence end token.
        dropout_index (int): The index of the dropout token. Defaults to 1.
        device (torch.device): Device to be used.

    Returns:
        (torch.Tensor, torch.Tensor, torch.Tensor): encoder_seq, decoder_seq,
            target_seq

        encoder_seq is a batch of padded input sequences starting with the
        start_index, of size `[sequence length + 1, batch_size]`.
        decoder_seq is like encoder_seq, but with word dropout applied (so
        if input_keep == 1, then decoder_seq == encoder_seq).
        target_seq is a batch of padded target sequences ending in the
        end_index, of size `[sequence length + 1, batch_size]`.
    """
    batch_size = input_batch.shape[1]
    input_batch = input_batch.long().to(device)
    decoder_batch = input_batch.clone()
    # Apply token dropout if input_keep != 1.
    if input_keep != 1:
        # Build a row of dropout tokens to write into the dropped timesteps
        # (created on `device` so the indexed assignment below cannot mix
        # CPU and GPU tensors).
        dropout_indices = torch.full(
            (1, batch_size), dropout_index, dtype=torch.long, device=device
        )
        # Sample a per-timestep keep (1) / drop (0) mask.
        mask = Bernoulli(input_keep).sample((input_batch.shape[0], ))
        dropout_loc = np.where(mask.numpy() == 0)[0]
        # Replace every token of the dropped timesteps by the dropout token.
        decoder_batch[dropout_loc] = dropout_indices

    # Shift the input left by one step and zero-pad, so the target at
    # position t is the input token at position t + 1.
    end_padding = torch.zeros(1, batch_size, dtype=torch.long, device=device)
    target_seq = torch.cat((input_batch[1:, :], end_padding), dim=0)
    target_seq = copy.deepcopy(target_seq).to(device)

    return input_batch, decoder_batch, target_seq
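# Usage sketch for sequential_data_preparation. The toy sequences and token
# conventions below (0 = pad, 1 = dropout token, 2 = start, 3 = end) are
# assumptions matching the argument defaults, not data from the project.
def _demo_sequential_data_preparation():
    from torch.nn.utils.rnn import pad_sequence

    seqs = [torch.tensor([2, 5, 7, 3]), torch.tensor([2, 9, 3])]
    batch = pad_sequence(seqs)  # [sequence length, batch_size]
    encoder_seq, decoder_seq, target_seq = sequential_data_preparation(
        batch, input_keep=0.85
    )
    # target_seq equals encoder_seq shifted left by one step and zero-padded,
    # so the decoder at position t is trained to predict token t + 1;
    # decoder_seq has ~15% of its timesteps replaced by the dropout token.
    print(encoder_seq.shape, decoder_seq.shape, target_seq.shape)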
def _process_sample(sample):
    """Prepare one 1D token sequence as (input, decoder, target) tensors.

    Note: this helper is meant to be defined inside a data preparation
    routine, closing over `input_keep`, `dropout_index` and `device` from
    the enclosing scope.
    """
    if len(sample.shape) != 1:
        raise ValueError('Sample must be a 1D sequence of token indices.')
    input = sample.long().to(device)
    decoder = input.clone()
    # Apply token dropout if input_keep != 1.
    if input_keep != 1:
        # Sample a per-token keep (1) / drop (0) mask.
        mask = Bernoulli(input_keep).sample(input.shape)
        dropout_loc = np.where(mask.numpy() == 0)[0]
        decoder[dropout_loc] = dropout_index
    # Use .detach().clone() rather than just .clone(), since .clone() alone
    # would propagate gradients from the target back into the graph.
    target = torch.cat(
        [input[1:].detach().clone(),
         torch.zeros(1, dtype=torch.long, device=device)]
    )
    return input, decoder, target.to(device)
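# Usage sketch for _process_sample: the helper closes over `input_keep`,
# `dropout_index` and `device`, so a module-level demo must bind those names
# first. The sample sequence below is made up for illustration.
if __name__ == '__main__':
    input_keep, dropout_index, device = 0.9, 1, get_device()
    sample = torch.tensor([2, 5, 7, 3])
    inp, dec, tgt = _process_sample(sample)
    # dec equals inp except that each token was independently replaced by
    # dropout_index with probability 1 - input_keep; tgt is inp shifted
    # left by one step with a single zero (padding) appended.
    print(inp.tolist(), dec.tolist(), tgt.tolist())

    _demo_sequential_data_preparation()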