# Module-level imports assumed by this excerpt. The `utils` and `ddata` names used
# below refer to the project's internal helper modules (scoring-mask creation,
# numpy/torch conversion, trial loading) and are assumed to be imported elsewhere
# at the top of the file.
import numpy as np
import torch
from torch import nn
from scipy.special import logit


def compute_loss(llrs, metadata=None, ptar=0.01, mask=None, loss_type='cross_entropy',
                 return_info=False, enrollment_ids=None, ids_to_idxs=None):

    if metadata is None:
        assert mask is not None
        # The mask is assumed to have been generated using the Key class,
        # so targets are labeled with 1, impostors with -1 and non-scored trials with 0.
        valid = mask != 0
        same_spk = mask == 1
    else:
        assert metadata is not None
        if enrollment_ids is None:
            # The llrs are assumed to correspond to all vs all trials
            same_spk, valid = utils.create_scoring_masks(metadata)
        else:
            # The llrs are assumed to correspond to all samples. The speaker ids in this
            # case are indices that have to be mapped to enrollment indexes.
            spk_ids = metadata['speaker_id'].type(torch.int)
            same_spk = utils.onehot_for_class_ids(spk_ids, enrollment_ids, ids_to_idxs).T
            # All samples are valid.
            valid = torch.ones_like(llrs).type(torch.bool)

    # Select the valid llrs and shift them to convert them to logits
    llrs = llrs[valid]
    logits = llrs + logit(ptar)
    labels = same_spk[valid]
    labels = labels.type(logits.type())

    # The loss will be given by tar_weight * tar_loss + imp_weight * imp_loss
    ptart = torch.as_tensor(ptar)
    tar_weight = ptart / torch.sum(labels == 1)
    imp_weight = (1 - ptart) / torch.sum(labels == 0)

    # Finally, compute the loss and multiply it by the weight that corresponds to the impostors.
    # Loss types are taken from Niko Brummer's paper: "Likelihood-ratio calibration using
    # prior-weighted proper scoring rules".
    if loss_type == "cross_entropy":
        criterion = nn.BCEWithLogitsLoss(pos_weight=tar_weight / imp_weight, reduction='sum')
        baseline_loss = -ptar * np.log(ptar) - (1 - ptar) * np.log(1 - ptar)
        loss = criterion(logits, labels) * imp_weight / baseline_loss
    elif loss_type == "brier":
        baseline_loss = ptar * (1 - ptar)**2 + (1 - ptar) * ptar**2
        posteriors = torch.sigmoid(logits)
        loss = torch.sum(labels * tar_weight * (1 - posteriors)**2
                         + (1 - labels) * imp_weight * posteriors**2) / baseline_loss
    else:
        raise ValueError("Unknown loss_type %s" % loss_type)

    if return_info:
        return loss, llrs, labels
    else:
        return loss
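
# Illustrative sanity check (not part of the original module; the helper name below is
# made up for this example). With the Key-style mask convention above, a system that
# outputs llr = 0 for every trial yields posteriors equal to ptar, and its prior-weighted
# loss equals exactly the baseline_loss used for normalization, so compute_loss should
# return a value very close to 1.0 for both loss types.
def _check_zero_llr_loss(ptar=0.01):
    n_enroll, n_test = 8, 10
    llrs = torch.zeros(n_enroll, n_test)
    # Key-style mask: 1 = target, -1 = impostor, 0 = not scored.
    mask = -torch.ones(n_enroll, n_test, dtype=torch.int8)
    mask[::2, ::2] = 1   # mark some trials as targets
    mask[0, 1] = 0       # leave one trial unscored
    for lt in ("cross_entropy", "brier"):
        loss = compute_loss(llrs, mask=mask, ptar=ptar, loss_type=lt)
        print(lt, float(loss))   # both values should be very close to 1.0
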
def compute_loss(llrs, metadata=None, ptar=0.01, mask=None, loss_type='cross_entropy',
                 return_info=False):

    if metadata is None:
        assert mask is not None
        # The mask is assumed to have been generated using the Key class,
        # so targets are labeled with 1, impostors with -1 and non-scored trials with 0.
        valid = mask != 0
        same_spk = mask == 1
    else:
        assert metadata is not None
        same_spk, valid = utils.create_scoring_masks(metadata)

    # Shift the llrs to convert them to logits and select the valid trials
    logits = llrs + logit(ptar)
    logits = logits[valid]
    labels = same_spk[valid]
    labels = labels.type(logits.type())

    # The loss will be given by tar_weight * tar_loss + imp_weight * imp_loss
    ptart = torch.as_tensor(ptar)
    tar_weight = ptart / torch.sum(labels == 1)
    imp_weight = (1 - ptart) / torch.sum(labels == 0)

    # Finally, compute the loss and multiply it by the weight that corresponds to the impostors.
    # Loss types are taken from Niko Brummer's paper: "Likelihood-ratio calibration using
    # prior-weighted proper scoring rules".
    if loss_type == "cross_entropy":
        criterion = nn.BCEWithLogitsLoss(pos_weight=tar_weight / imp_weight, reduction='sum')
        baseline_loss = -ptar * np.log(ptar) - (1 - ptar) * np.log(1 - ptar)
        loss = criterion(logits, labels) * imp_weight / baseline_loss
    elif loss_type == "brier":
        baseline_loss = ptar * (1 - ptar)**2 + (1 - ptar) * ptar**2
        posteriors = torch.sigmoid(logits)
        loss = torch.sum(labels * tar_weight * (1 - posteriors)**2
                         + (1 - labels) * imp_weight * posteriors**2) / baseline_loss
    else:
        raise ValueError("Unknown loss_type %s" % loss_type)

    if return_info:
        return loss, llrs, labels
    else:
        return loss
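
# Hypothetical stand-in (not the project's implementation) illustrating the interface that
# utils.create_scoring_masks is expected to provide, inferred only from how it is used in
# this section: given per-sample metadata, it returns two boolean matrices with the shape
# of the all-vs-all score matrix, same_spk marking same-speaker trials and valid marking
# the trials to be scored. The real function may apply further rules (for instance,
# discarding same-session trials), so this sketch is only for illustration.
def _toy_scoring_masks(metadata):
    spk = metadata['speaker_id'].view(-1, 1)
    same_spk = spk == spk.T                              # same-speaker (target) trials
    valid = ~torch.eye(spk.shape[0], dtype=torch.bool)   # score everything but self-trials
    return same_spk, valid
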
def init_params_with_data(self, dataset, config, device=None, subset=None):

    balance_by_domain = config.balance_batches_by_domain
    assert 'init_params' in config
    init_params = config.init_params

    # The code here repeats the steps in forward above, but adds the steps necessary for
    # initialization. I chose to keep these two methods separate to leave the forward
    # small and easy to read.
    with torch.no_grad():

        x, meta, _ = dataset.get_data_and_meta(subset)
        speaker_ids = meta['speaker_id']
        domain_ids = meta['domain_id']
        x_torch = utils.np_to_torch(x, device)

        if init_params.get("random"):
            self.lda_stage.init_random(init_params.get('stdev', 0.1))
        else:
            self.lda_stage.init_with_lda(x, speaker_ids, init_params, sec_ids=domain_ids)

        x2_torch = self.lda_stage(x_torch)
        x2 = x2_torch.cpu().numpy()

        if hasattr(self, 'si_stage1'):

            if self.si_input == 'main_input':
                si_input_torch = x_torch
                si_input = x
            else:
                si_input_torch = x2_torch
                si_input = x2

            if init_params.get("random"):
                self.si_stage1.init_random(init_params.get('stdev', 0.1))
            else:
                self.si_stage1.init_with_lda(si_input, speaker_ids, init_params,
                                             sec_ids=domain_ids, complement=True)

            s2_torch = self.si_stage1(si_input_torch)

            if init_params.get('init_si_stage2_with_domain_gb', False):
                # Initialize the second stage of the si-extractor to be a gaussian backend
                # that predicts the posterior of each domain. In this case, the number of
                # domains has to coincide with the dimension of the side info vector.
                assert self.si_dim == len(np.unique(domain_ids))
                self.si_stage2.init_with_lda(s2_torch.cpu().numpy(), domain_ids, init_params,
                                             sec_ids=speaker_ids, gaussian_backend=True)
            else:
                # This is the only component that is initialized randomly unless otherwise
                # indicated by the variable "init_si_stage2_with_domain_gb".
                self.si_stage2.init_random(init_params.get('w_init', 0.5),
                                           init_params.get('b_init', 0.0),
                                           init_params.get('type', 'normal'))

            if hasattr(self, 'shift_selector'):
                # Initialize the shifts as the mean of the lda outputs weighted by the si
                si_torch = self.si_stage2(s2_torch)
                si = si_torch.cpu().numpy()
                if init_params.get("random"):
                    self.shift_selector.init_random(init_params.get('stdev', 0.1))
                else:
                    self.shift_selector.init_with_weighted_means(x2, si)
                x2_torch -= self.shift_selector(si_torch)
                x2 = x2_torch.cpu().numpy()

        if init_params.get("random"):
            self.plda_stage.init_random(init_params.get('stdev', 0.1))
        else:
            self.plda_stage.init_with_plda(x2, speaker_ids, init_params, domain_ids=domain_ids)

        # Since the training data is usually large, we cannot create all possible trials for x3.
        # So, to create a bunch of trials, we just create a trial loader with a large batch size.
        # This means we need to rerun lda again, but it is a small price to pay for the
        # convenience of reusing the machinery of trial creation in the TrialLoader.
        loader = ddata.TrialLoader(dataset, device, seed=0, batch_size=2000, num_batches=1,
                                   balance_by_domain=balance_by_domain, subset=subset)
        x_torch, meta_batch = next(iter(loader))
        x2_torch = self.lda_stage(x_torch)
        scrs_torch = self.plda_stage(x2_torch)
        same_spk_torch, valid_torch = utils.create_scoring_masks(meta_batch)
        scrs, same_spk, valid = [v.detach().cpu().numpy()
                                 for v in [scrs_torch, same_spk_torch, valid_torch]]

        if init_params.get("random"):
            self.cal_stage.init_random(init_params.get('stdev', 0.1))
        else:
            self.cal_stage.init_with_logreg(scrs, same_spk, valid, config.ptar,
                                            std_for_mats=init_params.get('std_for_cal_matrices', 0))

        # Run the freshly initialized calibration stage with dummy duration and side-info
        # inputs to obtain the llrs for this batch.
        dummy_durs = torch.ones(scrs.shape[0]).to(device)
        dummy_si = torch.zeros(scrs.shape[0], self.si_dim).to(device)
        llrs_torch = self.cal_stage(scrs_torch, dummy_durs, dummy_si)

        # Encode the boolean masks into a Key-style mask: 1 for targets, -1 for impostors,
        # 0 for trials that should not be scored.
        mask = np.ones_like(same_spk, dtype=int)
        mask[~same_spk] = -1
        mask[~valid] = 0

        return compute_loss(llrs_torch, mask=utils.np_to_torch(mask, device),
                            ptar=config.ptar, loss_type=config.loss)
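
# Illustrative sketch (not part of the original module; the helper name is made up) of the
# Key-style mask round trip used above: boolean same_spk / valid matrices are encoded into
# a single integer mask (1 = target, -1 = impostor, 0 = not scored), and compute_loss
# decodes it back with mask == 1 and mask != 0.
def _demo_key_mask_roundtrip():
    # Toy all-vs-all block: two samples of speaker 0 and one of speaker 1,
    # with self-comparisons left unscored.
    spk = np.array([0, 0, 1])
    same_spk = spk[:, None] == spk[None, :]
    valid = ~np.eye(len(spk), dtype=bool)
    mask = np.ones_like(same_spk, dtype=int)
    mask[~same_spk] = -1   # impostor trials
    mask[~valid] = 0       # unscored trials
    # Decoding recovers the original masks on every scored trial.
    assert np.array_equal(mask != 0, valid)
    assert np.array_equal((mask == 1)[valid], same_spk[valid])
    return mask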