def _get_lstm_inputs(self, root, inputs):
    """Assemble per-timestep LSTM inputs from the leaf segments of a tree.

    Traverses ``root`` breadth-first and, for every leaf segment
    (``depth == 0``), writes that segment's start/goal encodings into
    full-sequence buffers over the integer time steps the segment covers.
    The cell input is the linear interpolation between the start and goal
    encodings across each segment's time range.

    :param root: tree whose leaves carry per-example fractional boundaries
        ``start_ind``/``end_ind`` (float tensors of shape ``[batch]`` —
        assumed; TODO confirm upstream) and encodings ``e_0``/``e_g``.
    :param inputs: AttrDict providing ``reference_tensor`` (used only for
        its device) and ``enc_e_0`` (used only for its trailing shape).
    :return: AttrDict with
        ``reset_indicator`` — ``[batch, time]`` uint8, 1 at each leaf's
        rounded start step;
        ``cell_input`` — encoding linearly interpolated from e_0 to e_g;
        ``reset_input`` — goal and start encodings concatenated on dim 2.
    """
    device = inputs.reference_tensor.device
    batch_size, time = self._hp.batch_size, self._hp.max_seq_len
    fullseq_shape = [batch_size, time] + list(inputs.enc_e_0.shape[1:])
    lstm_inputs = AttrDict()

    # Collect start and end indexes and values of all segments.
    e_0s = torch.zeros(fullseq_shape, dtype=torch.float32, device=device)
    e_gs = torch.zeros(fullseq_shape, dtype=torch.float32, device=device)
    start_inds = torch.zeros((batch_size, time), dtype=torch.float32, device=device)
    end_inds = torch.zeros((batch_size, time), dtype=torch.float32, device=device)
    reset_indicator = torch.zeros((batch_size, time), dtype=torch.uint8, device=device)

    for segment in root.full_tree():  # traversing the tree in breadth-first order
        if segment.depth == 0:  # leaf node
            # Round the fractional boundaries to the integer steps the
            # segment fully covers.  `.long()` keeps the index tensors on
            # `device`; the previous `.type(torch.LongTensor)` silently
            # moved them to the CPU, forcing mixed-device indexing.
            start_ind = torch.ceil(segment.start_ind).long()
            end_ind = torch.floor(segment.end_ind).long()
            batchwise_assign(reset_indicator, start_ind, 1)
            # TODO iterating over batch must be gone
            for ex in range(self._hp.batch_size):
                if start_ind[ex] > end_ind[ex]:
                    continue  # happens if start and end floats have no int in between
                # +1 so the end_ind frame itself is included.
                e_0s[ex, start_ind[ex]:end_ind[ex] + 1] = segment.e_0[ex]
                e_gs[ex, start_ind[ex]:end_ind[ex] + 1] = segment.e_g[ex]
                start_inds[ex, start_ind[ex]:end_ind[ex] + 1] = segment.start_ind[ex]
                end_inds[ex, start_ind[ex]:end_ind[ex] + 1] = segment.end_ind[ex]

    # Linear interpolation weight within each segment; the epsilon guards
    # against division by zero for zero-length segments.
    time_steps = torch.arange(time, dtype=torch.float, device=device)
    inter = (time_steps - start_inds) / (end_inds - start_inds + 1e-7)
    lstm_inputs.reset_indicator = reset_indicator
    lstm_inputs.cell_input = (e_gs - e_0s) * broadcast_final(inter, e_gs) + e_0s
    lstm_inputs.reset_input = torch.cat([e_gs, e_0s], dim=2)
    return lstm_inputs
def _get_lstm_inputs(self, root, inputs):
    """Assemble per-timestep LSTM inputs from the leaf segments of a tree.

    NOTE(review): this reads as an earlier, incomplete draft of the sibling
    implementation of the same name — the reset indicator is allocated but
    never written, the cell input is just the goal encoding (no
    interpolation), and the remaining work is tracked in the TODOs below.

    :param root: tree whose leaves (``depth == 0``) carry per-example
        boundaries ``start_ind``/``end_ind`` and encodings ``e_0``/``e_g``.
    :param inputs: AttrDict providing ``reference_tensor`` (used for its
        device) and ``enc_e_0`` (used for its trailing shape).
    :return: AttrDict with ``cell_input``, ``reset_indicator`` and
        ``reset_input`` (goal and start encodings concatenated on dim 2).
    """
    device = inputs.reference_tensor.device
    batch_size, time = self._hp.batch_size, self._hp.max_seq_len
    fullseq_shape = [batch_size, time] + list(inputs.enc_e_0.shape[1:])
    lstm_inputs = AttrDict()
    # Full-sequence buffers for the start (e_0) and goal (e_g) encodings.
    e_0s = torch.zeros(fullseq_shape, dtype=torch.float32, device=device)
    e_gs = torch.zeros(fullseq_shape, dtype=torch.float32, device=device)
    # NOTE(review): stays all-zero — nothing in this version writes into it.
    reset_indicator = torch.zeros((batch_size, time), dtype=torch.uint8, device=device)
    for segment in root.full_tree():  # traversing the tree in breadth-first order.
        if segment.depth == 0:  # if leaf-node
            e_0, e_g = segment.e_0, segment.e_g
            # TODO iterating over batch must be gone
            for ex in range(self._hp.batch_size):
                # NOTE(review): start_ind/end_ind look like float tensors
                # (the sibling version rounds them to long first); slicing
                # with floats would raise — confirm their dtype upstream.
                e_0s[ex, segment.start_ind[ex]:segment.end_ind[ex]] = e_0[ex]
                e_gs[ex, segment.start_ind[ex]:segment.end_ind[ex]] = e_g[ex]
    lstm_inputs.cell_input = e_gs
    lstm_inputs.reset_indicator = reset_indicator
    lstm_inputs.reset_input = torch.cat([e_gs, e_0s], dim=2)
    # TODO compute the latent variable with another LSTM
    # TODO the lstm has to observe timesteps
    # TODO aggregate the latent variables in the loop above
    # TODO add the latent variable to reset inputs
    return lstm_inputs