コード例 #1
0
    def _forward(self, batch_sizes, emission_probs):
        """Run the forward recursion over a packed batch of sequences.

        Only one ``alpha`` row per sequence is kept and it is overwritten in
        place at every timestep, so the returned ``alpha`` holds, for each
        sequence, the value computed at the last timestep that updated it.

        Args:
            batch_sizes: 1-D tensor with the number of sequences that are
                still active at each timestep. Assumed to be non-increasing,
                as in a ``PackedSequence`` -- TODO confirm against callers.
            emission_probs: Emission probabilities in packed order; rows
                ``[offset, offset + batch_sizes[t])`` belong to timestep ``t``.
                Presumably one column per hidden state -- confirm.

        Returns:
            A pair ``(alpha, nll)`` where ``alpha`` is the normalized forward
            variable per sequence and ``nll`` is the negative log-likelihood
            summed over all sequences.
        """
        # Work with plain Python ints for sizes and offsets, as
        # `_forward_backward` does, instead of 0-dim tensors.
        num_sequences = batch_sizes[0].item()
        max_length = batch_sizes.size(0)

        # Accumulates the log of every normalization constant per sequence;
        # this is what carries the likelihood since alpha itself is rescaled.
        log_denom = torch.zeros(num_sequences, 1, device=emission_probs.device)

        # 1) Initialize alpha from the initial state distribution
        alpha, denom = normalize(
            self.markov.initial_probs * emission_probs[:num_sequences],
            return_denom=True)
        data_offset = num_sequences
        log_denom += denom.log()

        # 2) Run forward pass for emissions
        for t in range(1, max_length):
            bs = batch_sizes[t].item()
            probs = emission_probs[data_offset:data_offset + bs]
            # Only the first `bs` sequences are still running at this step;
            # the remaining rows keep the value from their final timestep.
            alpha[:bs], denom = normalize(
                probs * (alpha[:bs] @ self.markov.transition_probs),
                return_denom=True)
            data_offset += bs
            log_denom[:bs] += denom.log()

        # 3) Compute log-likelihood
        log_likeli = (alpha.log() + log_denom).logsumexp(-1).sum()

        return alpha, -log_likeli
コード例 #2
0
ファイル: engine.py プロジェクト: groadabike/pycave
    def after_epoch(self, _):
        """Set the model's parameters from the cached counts and stop training.

        Reads ``initial_counts``, ``transition_counts``, ``symmetric`` and
        ``teleport_alpha`` from ``self.cache``, writes the resulting
        probabilities into ``self.model`` and then aborts the training loop.

        Args:
            _: Unused callback argument.

        Raises:
            xnn.CallbackException: Always, once the parameters have been set,
                because this model does not need further iterations.
        """
        # Initial probs
        initial_probs = normalize(self.cache['initial_counts'])
        self.model.initial_probs.set_(initial_probs)

        # Transition probs
        tr_size = self.model.transition_probs.size()
        transition_counts = self.cache['transition_counts'].view(*tr_size)

        # Symmetric
        if self.cache['symmetric']:
            # Out-of-place on purpose: `x += x.t()` reads and writes the same
            # memory through two differently-strided views (unsupported by
            # PyTorch) and would also modify the cached counts via the view.
            transition_counts = transition_counts + transition_counts.t()
        transition_probs = normalize(transition_counts)

        # Teleportation
        teleport_alpha = self.cache['teleport_alpha']
        if teleport_alpha > 0:
            # The original code referred to an undefined `tr_matrix` here, so
            # this branch raised NameError and never affected the model. The
            # formula is kept exactly as written and applied to the matrix
            # that is actually stored below.
            # NOTE(review): this is the inverse of the usual mixing
            # `(1 - a) * P + a / N`; confirm that this direction is intended.
            teleport_factor = teleport_alpha / transition_probs.size(0)
            teleport_matrix = torch.ones_like(transition_probs)
            beta = 1 - teleport_alpha
            transition_probs = \
                (transition_probs - teleport_factor * teleport_matrix) / beta

        self.model.transition_probs.set_(transition_probs)

        # Terminate training
        raise xnn.CallbackException('MarkovModel training does not need iterations')
コード例 #3
0
    def train_batch(self, data, eps=0.01, patience=0):
        """Process one batch: E-step plus the per-batch part of the M-step.

        The results are stored in ``self.cache``; the model itself is not
        modified here.

        Args:
            data: Batch providing ``.data`` and ``.batch_sizes`` (presumably a
                ``PackedSequence`` already on the right device -- confirm).
            eps: Convergence threshold, stored in the cache as metadata.
            patience: Early-stopping patience, stored in the cache as metadata.
        """
        # --- E-step --------------------------------------------------------
        emission_probs, alpha, beta, nll = self.model(
            data, smooth=True, return_emission=True)
        sizes = data.batch_sizes
        gamma = normalize(alpha * beta)

        # Pair every timestep's alpha with the following timestep's beta and
        # emission probabilities.
        xi = self._compute_xi(
            packed_drop_last(alpha, sizes),
            packed_drop_first(beta, sizes),
            packed_drop_first(emission_probs, sizes))

        # --- M-step (per-batch quantities) ---------------------------------
        first_step_gamma = packed_get_first(gamma, sizes)
        initial_probs = first_step_gamma.mean(0)
        transition_probs_num = xi.sum(0)
        transition_probs = normalize(transition_probs_num)
        output_update = self.model.emission.maximize(
            data.data, gamma, self.requires_batching)
        nll = nll.item()

        # --- Cache update --------------------------------------------------
        if self.requires_batching:
            # Several batches per epoch: merge this batch into the running
            # statistics.
            num_seqs = sizes[0].item()
            prv_wgt, cur_wgt = self._weights_for_updated_count(
                'num_sequences', num_seqs)

            blended_initial = \
                self.cache['initial_probs'] * prv_wgt + initial_probs * cur_wgt
            self.cache['initial_probs'] = blended_initial
            self.cache['tr_probs_num'] += transition_probs_num
            self.cache['output_update'] = self.model.emission.update(
                output_update, self.cache['output_update'])
            self.cache['nll_sum'] += nll
            self.cache['num_datapoints'] += data.data.size(0)
        else:
            # Exactly one batch per epoch: the cache is simply replaced.
            self.cache = {
                'initial_probs': initial_probs,
                'transition_probs': transition_probs,
                'output_update': output_update,
                'neg_log_likelihood': nll / data.data.size(0)
            }

        # Metadata is written on every batch; later batches overwrite it with
        # the same values.
        self.cache['eps'] = eps
        self.cache['patience'] = patience
コード例 #4
0
    def after_epoch(self, _):
        """Write the cached estimates into the model and check convergence.

        With batching, the accumulated transition numerators are normalized
        here; without batching, the cached values are applied as they are.

        Args:
            _: Unused callback argument.

        Raises:
            xnn.CallbackException: If the improvement over the best negative
                log-likelihood is below the cached ``eps`` and the patience
                counter has already reached the cached ``patience``.
        """
        batched = self.requires_batching
        markov = self.model.markov

        # Parameter update; only the source of the transition matrix differs
        # between the two modes.
        markov.initial_probs.set_(self.cache['initial_probs'])
        if batched:
            new_transition_probs = normalize(self.cache['tr_probs_num'])
        else:
            new_transition_probs = self.cache['transition_probs']
        markov.transition_probs.set_(new_transition_probs)
        self.model.emission.apply(self.cache['output_update'])

        # Negative log-likelihood per datapoint for this epoch
        if batched:
            nll = self.cache['nll_sum'] / self.cache['num_datapoints']
        else:
            nll = self.cache['neg_log_likelihood']

        # This metadata field is always present
        eps = self.cache['eps']
        patience = self.cache['patience']

        # Early stopping. The comparison is deliberately kept in its `<` form
        # so that a NaN difference is treated like an improvement.
        stalled = self.best_nll - nll < eps
        if not stalled:
            self.best_nll = nll
            self.patience = 0
            return

        if self.patience < patience:
            self.patience += 1
            return

        raise xnn.CallbackException(
            f'Training converged after {self.epoch} iterations.')
コード例 #5
0
 def _compute_xi(self, alpha_, beta_, emission_probs_):
     """Compute normalized pairwise state terms from forward/backward values.

     For every row, the outer product of ``alpha_`` with
     ``beta_ * emission_probs_`` is weighted elementwise by the transition
     matrix; the result is passed to ``normalize`` with dims ``[-1, -2]``.

     Args:
         alpha_: Forward values, reshaped to ``(-1, K, 1)``.
         beta_: Backward values, same shape as ``emission_probs_``.
         emission_probs_: Emission probabilities aligned with ``beta_``.

     Returns:
         Whatever ``normalize`` returns for the ``(-1, K, K)`` numerator
         (presumably each K x K slice scaled to sum to one -- confirm).
     """
     num_states = self.model.num_states

     # Column vector per row for alpha, row vector per row for beta * emission
     from_state = alpha_.reshape(-1, num_states, 1)
     to_state = (beta_ * emission_probs_).view(-1, 1, num_states)

     outer = torch.bmm(from_state, to_state)
     weighted = outer * self.model.markov.transition_probs
     numerator = weighted.view(-1, num_states, num_states)
     return normalize(numerator, [-1, -2])
コード例 #6
0
 def _rearrange_prediction_sequence(self, item):
     """Turn one packed prediction into a list of per-sequence tensors.

     Args:
         item: Mapping with keys ``'out'`` (at least two tensors whose
             product is normalized; presumably alpha and beta -- confirm),
             ``'bs'`` (batch sizes of the packed data) and ``'idx'``
             (iterable of row indices giving the output order, or ``None``).

     Returns:
         A list of tensors, each cut to the length of its sequence. Rows are
         taken in ``item['idx']`` order if given, else in padded-batch order.
     """
     first_out = item['out'][0]
     second_out = item['out'][1]
     gamma = normalize(first_out * second_out)

     padded, lengths = pad_packed_sequence(
         PackedSequence(data=gamma, batch_sizes=item['bs']),
         batch_first=True)

     order = item['idx']
     if order is None:
         order = range(lengths.size(0))
     return [padded[row, :lengths[row]] for row in order]
コード例 #7
0
    def train_batch(self, data, eps=0.01, reg=1e-6):
        """Process one batch and fold its statistics into ``self.cache``.

        Args:
            data: Batch tensor; ``data.size(0)`` is used as the number of
                datapoints in the batch.
            eps: Convergence threshold, stored in the cache as metadata.
            reg: Regularization value forwarded to the gaussian's
                ``maximize`` call.
        """
        # E-step: responsibilities and mean negative log-likelihood
        responsibilities, nll = self.model(data)
        batch_nll = nll.mean().item()

        # M-step for this batch
        gaussian_max = self.model.gaussian.maximize(
            data, responsibilities, self.requires_batching, reg=reg)
        component_weights = normalize(gaussian_max['state_sums'])

        # Weights for the running averages, proportional to how many
        # datapoints were seen before versus in this batch
        batch_count = data.size(0)
        seen_count = self.cache['count']
        total_count = seen_count + batch_count
        old_share = seen_count / total_count
        new_share = batch_count / total_count

        # Store in cache
        self.cache['count'] = total_count
        self.cache['gaussian'] = self.model.gaussian.update(
            gaussian_max, self.cache['gaussian'])

        averaged_weights = \
            self.cache['component_weights'] * old_share + component_weights * new_share
        self.cache['component_weights'] = averaged_weights

        averaged_nll = \
            self.cache['neg_log_likelihood'] * old_share + (batch_nll * new_share)
        self.cache['neg_log_likelihood'] = averaged_nll

        # Attach metadata
        self.cache['eps'] = eps
コード例 #8
0
 def predict_batch(self, data):
     """Return the normalized first output of the model for ``data``.

     The first element of the model's output is taken to be the component
     responsibilities; normalizing them yields a distribution over
     components.
     """
     model_output = self.model(data)
     responsibilities = model_output[0]
     return normalize(responsibilities)
コード例 #9
0
    def _forward_backward(self, batch_sizes, emission_probs):
        """Run the scaled forward and backward recursions on packed data.

        Unlike a forward-only pass, the values for every timestep are kept:
        ``alpha`` and ``beta`` use the same packed row layout as
        ``emission_probs``.

        Args:
            batch_sizes: 1-D tensor with the number of active sequences per
                timestep. Assumed to be non-increasing, as in a
                ``PackedSequence`` -- TODO confirm against callers.
            emission_probs: Emission probabilities in packed order; rows
                ``[offset, offset + batch_sizes[t])`` belong to timestep ``t``.

        Returns:
            A triple ``(alpha, beta, nll)``: the normalized forward values,
            the normalized backward values (both with ``batch_sizes.sum()``
            rows and ``self.num_states`` columns) and the negative
            log-likelihood summed over all sequences.
        """
        M = self.num_states
        device = emission_probs.device
        max_length = batch_sizes.size(0)

        # 1) Initialize (empty) parameters
        alpha = torch.empty(batch_sizes.sum(), M, device=device)
        beta = torch.empty_like(alpha)
        # Running sum of log normalization constants, one row per sequence
        alpha_log_denom = torch.zeros(batch_sizes[0], 1, device=device)

        # 2) Run forward phase
        # The first timestep occupies rows [0, batch_sizes[0])
        data_offset = batch_sizes[0].item()
        alpha[:data_offset], denom = normalize(self.markov.initial_probs *
                                               emission_probs[:data_offset],
                                               return_denom=True)
        alpha_log_denom += denom.log()

        for t in range(max_length - 1):
            prev_bs = batch_sizes[t].item()
            bs = batch_sizes[t + 1].item()

            # Rows of timestep t + 1 start at `data_offset`; rows of timestep
            # t start at `data_offset - prev_bs`. Only the first `bs`
            # sequences of timestep t continue into timestep t + 1.
            probs = emission_probs[data_offset:data_offset + bs]
            previous = alpha[data_offset - prev_bs:data_offset - prev_bs + bs]

            alpha[data_offset:data_offset + bs], denom = normalize(
                probs * (previous @ self.markov.transition_probs),
                return_denom=True)

            alpha_log_denom[:bs] += denom.log()
            data_offset += bs

        # 3) Run backward phase (data_offset is sum of all batch sizes here)
        batch_sum = data_offset

        # 3.1) First, set all ones at the end of sequences
        # Walk the timesteps from last to first; `data_offset` is the end of
        # the rows of timestep t at the start of each iteration.
        for t in range(max_length - 1, -1, -1):
            # Despite its name, `prev_bs` is the batch size of the FOLLOWING
            # timestep (0 after the last one).
            prev_bs = batch_sizes[t + 1].item() if t < max_length - 1 else 0
            bs = batch_sizes[t].item()
            # Sequences active at t but not at t + 1 end at timestep t
            num_endings = bs - prev_bs

            # The ending sequences are the last `num_endings` rows of the
            # timestep (relies on the non-increasing batch sizes).
            beta[data_offset - num_endings:data_offset] = 1
            data_offset -= bs

        # 3.2) Now, iterate
        # Start at the first row of the last timestep
        data_offset = batch_sum - batch_sizes[max_length - 1].item()
        for t in range(max_length - 1, 0, -1):
            # batch size of subsequent timestep - these need to be updated
            bs = batch_sizes[t].item()
            # this is the current timestep - only for `bs` need update
            current_bs = batch_sizes[t - 1].item()

            probs = emission_probs[data_offset:data_offset + bs]
            prev_beta = beta[data_offset:data_offset + bs]

            # Write beta for timestep t - 1, but only for the `bs` sequences
            # that continue to timestep t; the rest already hold ones.
            beta[data_offset-current_bs: data_offset-current_bs+bs] = \
                normalize((probs * prev_beta) @ self.markov.transition_probs.t())

            data_offset -= current_bs

        # 4) Compute log-likelihood
        # `packed_get_last` presumably selects each sequence's row at its
        # final timestep -- confirm against the helper.
        alpha_ = packed_get_last(alpha, batch_sizes)
        log_likeli = (alpha_.log() + alpha_log_denom).logsumexp(-1).sum()

        return alpha, beta, -log_likeli