def _forward(self, batch_sizes, emission_probs):
    """
    Run the rescaled forward recursion over a packed batch of sequences.

    Parameters
    ----------
    batch_sizes:
        Tensor whose entry ``t`` is the number of sequences that still have a datapoint at
        timestep ``t``. ``batch_sizes[0]`` is used as the number of sequences and
        ``batch_sizes.size(0)`` as the maximum sequence length.
    emission_probs:
        Emission probabilities, consumed in consecutive slices of ``batch_sizes[t]`` rows
        (assumes the packed-sequence layout with one column per state -- TODO confirm
        against the caller).

    Returns
    -------
    tuple
        ``(alpha, -log_likeli)``: the last forward message computed for every sequence and the
        negative log-likelihood summed over all sequences.
    """
    device = emission_probs.device
    num_sequences = batch_sizes[0]
    max_length = batch_sizes.size(0)
    # Start from a plain int so that the `+=` below creates a fresh value instead of modifying
    # `batch_sizes[0]` in place.
    data_offset = 0
    log_denom = torch.zeros(num_sequences, 1, device=device)

    # 1) Initialize alpha from the initial distribution and the first emissions.
    #    `normalize(..., return_denom=True)` presumably returns the rescaled tensor together
    #    with the denominator it divided by -- verify against its definition.
    alpha, denom = normalize(
        self.markov.initial_probs * emission_probs[:num_sequences], return_denom=True)
    data_offset += num_sequences
    log_denom += denom.log()

    # 2) Run forward pass for emissions. Only the first `bs` rows are updated at each step, so
    #    rows of sequences that have already ended keep their last message.
    for t in range(max_length - 1):
        bs = batch_sizes[t + 1]
        probs = emission_probs[data_offset:data_offset + bs]
        alpha[:bs], denom = normalize(
            probs * (alpha[:bs] @ self.markov.transition_probs), return_denom=True)
        data_offset += bs
        log_denom[:bs] += denom.log()

    # 3) Compute log-likelihood: add the accumulated log-denominators back onto the rescaled
    #    messages before reducing over states and sequences.
    log_likeli = (alpha.log() + log_denom).logsumexp(-1).sum()
    return alpha, -log_likeli
def after_epoch(self, _):
    """
    Set the model's parameters from the cached counts and immediately terminate training.

    Reads ``initial_counts``, ``transition_counts``, ``symmetric`` and ``teleport_alpha`` from
    ``self.cache`` and writes the resulting tensors into ``self.model.initial_probs`` and
    ``self.model.transition_probs`` via ``set_``.

    Raises
    ------
    xnn.CallbackException
        Always, once the parameters have been written, to stop the training loop.
    """
    # Initial probs
    initial_probs = normalize(self.cache['initial_counts'])
    self.model.initial_probs.set_(initial_probs)

    # Transition probs: reshape the cached counts to the shape of the model's matrix.
    tr_size = self.model.transition_probs.size()
    transition_counts = self.cache['transition_counts'].view(*tr_size)

    # Symmetric
    if self.cache['symmetric']:
        # Out-of-place on purpose: `transition_counts` is a view of the cached tensor and its
        # transpose shares the same memory, so an in-place `+=` would write into memory it is
        # still reading from and would also modify the cached counts.
        transition_counts = transition_counts + transition_counts.t()

    transition_probs = normalize(transition_counts)

    # Teleportation
    teleport_alpha = self.cache['teleport_alpha']
    if teleport_alpha > 0:
        teleport_factor = teleport_alpha / transition_probs.size(0)
        teleport_matrix = torch.ones_like(transition_probs)
        beta = 1 - teleport_alpha
        # NOTE(review): the arithmetic is kept exactly as it was written (it previously
        # referenced an undefined `tr_matrix` and its result was never used). It subtracts the
        # uniform teleportation mass and rescales, which can yield negative entries wherever
        # `transition_probs` is below `teleport_factor`. Confirm whether the mixture
        # `beta * P + teleport_factor` was intended instead.
        transition_probs = (transition_probs - teleport_factor * teleport_matrix) / beta

    self.model.transition_probs.set_(transition_probs)

    # Terminate training
    raise xnn.CallbackException('MarkovModel training does not need iterations')
def train_batch(self, data, eps=0.01, patience=0):
    """
    Run the "expect" step of the Baum-Welch algorithm on one batch and stage the intermediate
    results needed for the M-step in ``self.cache``.

    ``data`` is expected to be a packed sequence on the correct device (its ``data`` and
    ``batch_sizes`` attributes are used). ``eps`` and ``patience`` are only stored in the cache
    here.
    """
    sizes = data.batch_sizes

    # 1) Expect step
    emission_probs, alpha, beta, nll = self.model(data, smooth=True, return_emission=True)
    gamma = normalize(alpha * beta)
    xi = self._compute_xi(
        packed_drop_last(alpha, sizes),
        packed_drop_first(beta, sizes),
        packed_drop_first(emission_probs, sizes),
    )

    # 2) Maximize step
    first_gamma = packed_get_first(gamma, sizes)
    initial_probs = first_gamma.mean(0)
    transition_numerator = xi.sum(0)
    transition_probs = normalize(transition_numerator)
    output_update = self.model.emission.maximize(data.data, gamma, self.requires_batching)
    nll_value = nll.item()

    # 3) Update cache depending on multi-batch or single-batch setting
    if self.requires_batching:
        # Fold this batch into the running statistics of the epoch.
        num_seqs = sizes[0].item()
        prv_wgt, cur_wgt = self._weights_for_updated_count('num_sequences', num_seqs)
        blended_initial = self.cache['initial_probs'] * prv_wgt + initial_probs * cur_wgt
        self.cache['initial_probs'] = blended_initial
        self.cache['tr_probs_num'] += transition_numerator
        merged_update = self.model.emission.update(output_update, self.cache['output_update'])
        self.cache['output_update'] = merged_update
        self.cache['nll_sum'] += nll_value
        self.cache['num_datapoints'] += data.data.size(0)
    else:
        # We know that we only get a single batch, so the cache is replaced wholesale.
        self.cache = {
            'initial_probs': initial_probs,
            'transition_probs': transition_probs,
            'output_update': output_update,
            'neg_log_likelihood': nll_value / data.data.size(0)
        }

    # 4) Add metadata to cache (stored again on every call)
    self.cache['eps'] = eps
    self.cache['patience'] = patience
def after_epoch(self, _):
    """
    Apply the cached parameter updates to the model and check for convergence.

    With batching, this is the final "maximize" step of the Baum-Welch algorithm: the
    accumulated transition numerator is normalized and the epoch's mean negative log-likelihood
    is formed from the cached sum and datapoint count. Without batching, the values cached by
    the single batch are applied as they are.

    Raises
    ------
    xnn.CallbackException
        If the negative log-likelihood improved by less than ``eps`` and the patience counter
        has already reached the cached ``patience``.
    """
    cache = self.cache
    markov = self.model.markov

    # The initial distribution and the emission update are applied identically in both modes;
    # only the source of the transition matrix and of the NLL differs.
    markov.initial_probs.set_(cache['initial_probs'])
    if self.requires_batching:
        new_transitions = normalize(cache['tr_probs_num'])
    else:
        new_transitions = cache['transition_probs']
    markov.transition_probs.set_(new_transitions)
    self.model.emission.apply(cache['output_update'])

    if self.requires_batching:
        nll = cache['nll_sum'] / cache['num_datapoints']
    else:
        nll = cache['neg_log_likelihood']

    # This metadata field is always present
    eps = cache['eps']
    patience = cache['patience']

    # Check for early stopping
    stalled = self.best_nll - nll < eps
    if not stalled:
        # Sufficient improvement: remember it and reset the patience counter.
        self.best_nll = nll
        self.patience = 0
        return

    if self.patience < patience:
        self.patience += 1
        return

    raise xnn.CallbackException(
        f'Training converged after {self.epoch} iterations.')
def _compute_xi(self, alpha_, beta_, emission_probs_):
    """
    Compute the normalized pairwise term ``xi`` used for the transition update.

    For every row, the outer product of ``alpha_`` with ``beta_ * emission_probs_`` is formed,
    weighted element-wise by the model's transition probabilities and passed to ``normalize``
    with the last two dimensions (presumably so that each K x K slice sums to one -- verify
    against ``normalize``).
    """
    num_states = self.model.num_states

    # Column vectors (N, K, 1) times row vectors (N, 1, K) give N outer products (N, K, K).
    outgoing = alpha_.reshape(-1, num_states, 1)
    incoming = (beta_ * emission_probs_).view(-1, 1, num_states)
    pairwise = torch.bmm(outgoing, incoming)

    weighted = pairwise * self.model.markov.transition_probs
    return normalize(weighted.view(-1, num_states, num_states), [-1, -2])
def _rearrange_prediction_sequence(self, item):
    """
    Turn one packed prediction result into a list of per-sequence tensors.

    ``item['out'][0] * item['out'][1]`` is normalized (presumably the forward and backward
    messages -- confirm with the producer of ``item``), re-packed with ``item['bs']`` as batch
    sizes and padded. Each returned tensor is cut to its sequence's length. The sequences are
    returned in the order given by ``item['idx']``, or in padded order if it is ``None``.
    """
    out = item['out']
    gamma = normalize(out[0] * out[1])
    padded, lengths = pad_packed_sequence(
        PackedSequence(data=gamma, batch_sizes=item['bs']), batch_first=True)

    order = item['idx']
    if order is None:
        order = range(lengths.size(0))
    return [padded[i, :lengths[i]] for i in order]
def train_batch(self, data, eps=0.01, reg=1e-6):
    """
    Run one EM step on a batch and merge its results into the running values in ``self.cache``.

    The component weights and the negative log-likelihood are kept as count-weighted running
    averages over the rows seen so far; the Gaussian statistics are merged via
    ``self.model.gaussian.update``. ``eps`` is only stored in the cache here.
    """
    # E-step: compute responsibilities
    responsibilities, nll = self.model(data)
    batch_nll = nll.mean().item()

    # M-step: maximize
    gaussian_max = self.model.gaussian.maximize(
        data, responsibilities, self.requires_batching, reg=reg)
    component_weights = normalize(gaussian_max['state_sums'])

    # Weights of the running average: previous rows versus the rows of this batch.
    batch_rows = data.size(0)
    seen_before = self.cache['count']
    new_count = seen_before + batch_rows
    prev_weight = seen_before / new_count
    cur_weight = batch_rows / new_count

    # Store in cache
    self.cache['count'] = new_count
    self.cache['gaussian'] = self.model.gaussian.update(gaussian_max, self.cache['gaussian'])
    averaged_weights = \
        self.cache['component_weights'] * prev_weight + component_weights * cur_weight
    self.cache['component_weights'] = averaged_weights
    averaged_nll = self.cache['neg_log_likelihood'] * prev_weight + (batch_nll * cur_weight)
    self.cache['neg_log_likelihood'] = averaged_nll

    # Attach metadata
    self.cache['eps'] = eps
def predict_batch(self, data):
    """
    Return the normalized first output of the model for ``data``.

    The first element of the model's output holds the responsibilities; normalizing them
    yields a distribution over components.
    """
    outputs = self.model(data)
    responsibilities = outputs[0]
    return normalize(responsibilities)
def _forward_backward(self, batch_sizes, emission_probs):
    """
    Run the rescaled forward and backward recursions over a packed batch of sequences.

    Parameters
    ----------
    batch_sizes:
        Tensor whose entry ``t`` is the number of sequences with a datapoint at timestep ``t``.
        The offset arithmetic below assumes these are non-increasing, as in a packed sequence
        -- TODO confirm against the caller.
    emission_probs:
        Emission probabilities, consumed in consecutive slices of ``batch_sizes[t]`` rows.

    Returns
    -------
    tuple
        ``(alpha, beta, -log_likeli)``: forward and backward messages with one row per packed
        datapoint (``batch_sizes.sum()`` rows, ``num_states`` columns) and the negative
        log-likelihood summed over all sequences.
    """
    M = self.num_states
    device = emission_probs.device
    max_length = batch_sizes.size(0)

    # 1) Initialize (empty) parameters
    alpha = torch.empty(batch_sizes.sum(), M, device=device)
    beta = torch.empty_like(alpha)
    # Running sum of the log-denominators from the forward pass, one row per sequence.
    alpha_log_denom = torch.zeros(batch_sizes[0], 1, device=device)

    # 2) Run forward phase
    # `data_offset` always points at the first row of the timestep that is written next.
    data_offset = batch_sizes[0].item()
    alpha[:data_offset], denom = normalize(
        self.markov.initial_probs * emission_probs[:data_offset], return_denom=True)
    alpha_log_denom += denom.log()

    for t in range(max_length - 1):
        prev_bs = batch_sizes[t].item()
        bs = batch_sizes[t + 1].item()
        probs = emission_probs[data_offset:data_offset + bs]
        # The previous timestep's block starts `prev_bs` rows earlier; only its first `bs` rows
        # belong to sequences that continue into this timestep.
        previous = alpha[data_offset - prev_bs:data_offset - prev_bs + bs]
        alpha[data_offset:data_offset + bs], denom = normalize(
            probs * (previous @ self.markov.transition_probs), return_denom=True)
        alpha_log_denom[:bs] += denom.log()
        data_offset += bs

    # 3) Run backward phase (data_offset is sum of all batch sizes here)
    batch_sum = data_offset

    # 3.1) First, set all ones at the end of sequences
    # Walking backwards, `data_offset` is the end (exclusive) of timestep t's block. The
    # sequences ending at t are the last `bs - prev_bs` rows of that block.
    for t in range(max_length - 1, -1, -1):
        prev_bs = batch_sizes[t + 1].item() if t < max_length - 1 else 0
        bs = batch_sizes[t].item()
        num_endings = bs - prev_bs
        beta[data_offset - num_endings:data_offset] = 1
        data_offset -= bs

    # 3.2) Now, iterate
    # `data_offset` is the first row of timestep t's block, starting with the last timestep.
    data_offset = batch_sum - batch_sizes[max_length - 1].item()
    for t in range(max_length - 1, 0, -1):
        # batch size at timestep t, whose beta values are propagated backwards
        bs = batch_sizes[t].item()
        # batch size at timestep t - 1; only its first `bs` rows are updated here
        current_bs = batch_sizes[t - 1].item()
        probs = emission_probs[data_offset:data_offset + bs]
        prev_beta = beta[data_offset:data_offset + bs]
        beta[data_offset-current_bs: data_offset-current_bs+bs] = \
            normalize((probs * prev_beta) @ self.markov.transition_probs.t())
        data_offset -= current_bs

    # 4) Compute log-likelihood
    # `packed_get_last` presumably selects each sequence's final alpha row -- verify against
    # its definition.
    alpha_ = packed_get_last(alpha, batch_sizes)
    log_likeli = (alpha_.log() + alpha_log_denom).logsumexp(-1).sum()
    return alpha, beta, -log_likeli