예제 #1
0
    def _get_lstm_inputs(self, root, inputs):
        """Assemble dense, per-timestep LSTM inputs from the leaf segments of ``root``.

        Each leaf segment (``depth == 0``) writes its ``e_0`` / ``e_g`` values and
        its fractional start / end indices into the integer frames it covers,
        i.e. ``ceil(start_ind)`` through ``floor(end_ind)`` inclusive. The cell
        input is then a linear blend between ``e_0`` and ``e_g`` driven by each
        frame's relative position inside its segment.

        :param root: object whose ``full_tree()`` yields segments exposing
            ``depth``, ``start_ind``, ``end_ind``, ``e_0`` and ``e_g``.
        :param inputs: object providing ``reference_tensor`` (its device is used
            for every tensor created here) and ``enc_e_0`` (its shape after the
            first dimension is used as the per-frame feature shape).
        :return: AttrDict holding ``reset_indicator``, ``cell_input`` and
            ``reset_input`` (``e_gs`` and ``e_0s`` concatenated along dim 2).
        """
        hp = self._hp
        device = inputs.reference_tensor.device
        batch_size, time = hp.batch_size, hp.max_seq_len
        fullseq_shape = [batch_size, time] + list(inputs.enc_e_0.shape[1:])
        frame_grid = (batch_size, time)
        float_kwargs = dict(dtype=torch.float32, device=device)

        # Dense buffers, filled in segment by segment below.
        e_0s = torch.zeros(fullseq_shape, **float_kwargs)
        e_gs = torch.zeros(fullseq_shape, **float_kwargs)
        start_inds = torch.zeros(frame_grid, **float_kwargs)
        end_inds = torch.zeros(frame_grid, **float_kwargs)
        reset_indicator = torch.zeros(frame_grid,
                                      dtype=torch.uint8,
                                      device=device)

        # The original comment states full_tree() walks the tree breadth-first.
        for segment in root.full_tree():
            if segment.depth != 0:
                continue  # only leaf segments contribute frames

            # Round inwards so only frames fully inside the segment are written.
            first = torch.ceil(segment.start_ind).type(torch.LongTensor)
            last = torch.floor(segment.end_ind).type(torch.LongTensor)
            # Presumably flags each example's first frame of the segment;
            # confirm against batchwise_assign.
            batchwise_assign(reset_indicator, first, 1)

            # TODO iterating over batch must be gone
            for ex in range(batch_size):
                if first[ex] > last[ex]:
                    # No integer frame lies between the float start and end.
                    continue
                frames = slice(first[ex], last[ex] + 1)  # +1 keeps the end frame
                e_0s[ex, frames] = segment.e_0[ex]
                e_gs[ex, frames] = segment.e_g[ex]
                start_inds[ex, frames] = segment.start_ind[ex]
                end_inds[ex, frames] = segment.end_ind[ex]

        # Linear interpolation weight of every frame within its segment; the
        # small epsilon keeps the division finite for zero-length spans.
        time_steps = torch.arange(time, dtype=torch.float, device=device)
        elapsed = time_steps - start_inds
        span = end_inds - start_inds + 1e-7
        inter = elapsed / span

        lstm_inputs = AttrDict()
        lstm_inputs.reset_indicator = reset_indicator
        delta = e_gs - e_0s
        lstm_inputs.cell_input = delta * broadcast_final(inter, e_gs) + e_0s
        lstm_inputs.reset_input = torch.cat([e_gs, e_0s], dim=2)
        return lstm_inputs
예제 #2
0
    def _get_lstm_inputs(self, root, inputs):
        """Assemble dense, per-timestep LSTM inputs from the leaf segments of ``root``.

        Each leaf segment (``depth == 0``) copies its ``e_0`` / ``e_g`` values into
        the frames ``start_ind[ex]:end_ind[ex]`` (end exclusive) of every example.
        No interpolation is performed: the cell input is ``e_gs`` itself.

        Nothing in this method writes to ``reset_indicator``, so it is returned
        as all zeros.

        :param root: object whose ``full_tree()`` yields segments exposing
            ``depth``, ``start_ind``, ``end_ind``, ``e_0`` and ``e_g``.
        :param inputs: object providing ``reference_tensor`` (its device is used
            for every tensor created here) and ``enc_e_0`` (its shape after the
            first dimension is used as the per-frame feature shape).
        :return: AttrDict holding ``cell_input``, ``reset_indicator`` and
            ``reset_input`` (``e_gs`` and ``e_0s`` concatenated along dim 2).
        """
        hp = self._hp
        device = inputs.reference_tensor.device
        batch_size, time = hp.batch_size, hp.max_seq_len
        fullseq_shape = [batch_size, time] + list(inputs.enc_e_0.shape[1:])
        float_kwargs = dict(dtype=torch.float32, device=device)

        # Dense buffers, filled in segment by segment below.
        e_0s = torch.zeros(fullseq_shape, **float_kwargs)
        e_gs = torch.zeros(fullseq_shape, **float_kwargs)
        reset_indicator = torch.zeros((batch_size, time),
                                      dtype=torch.uint8,
                                      device=device)

        # The original comment states full_tree() walks the tree breadth-first.
        for segment in root.full_tree():
            if segment.depth != 0:
                continue  # only leaf segments contribute frames
            seg_e_0, seg_e_g = segment.e_0, segment.e_g

            # TODO iterating over batch must be gone
            for ex in range(batch_size):
                # Plain slice: the end index is exclusive here.
                frames = slice(segment.start_ind[ex], segment.end_ind[ex])
                e_0s[ex, frames] = seg_e_0[ex]
                e_gs[ex, frames] = seg_e_g[ex]

        lstm_inputs = AttrDict()
        lstm_inputs.cell_input = e_gs
        lstm_inputs.reset_indicator = reset_indicator
        lstm_inputs.reset_input = torch.cat([e_gs, e_0s], dim=2)

        # TODO compute the latent variable with another LSTM
        #      TODO the lstm has to observe timesteps
        #      TODO aggregate the latent variables in the loop above
        # TODO add the latent variable to reset inputs
        return lstm_inputs