Example #1
import numpy as np
import torch
import torch.nn as nn
from scipy.special import logit  # assumed source of `logit`; the original module may define its own

import utils  # repo-local helpers (create_scoring_masks, onehot_for_class_ids); import path assumed


def compute_loss(llrs,
                 metadata=None,
                 ptar=0.01,
                 mask=None,
                 loss_type='cross_entropy',
                 return_info=False,
                 enrollment_ids=None,
                 ids_to_idxs=None):

    if metadata is None:
        assert mask is not None
        # The mask is assumed to have been generated using the Key class,
        # so targets are labeled with 1, impostors with -1 and non-scored trials with 0.
        valid = mask != 0
        same_spk = mask == 1
    else:
        if enrollment_ids is None:
            # The llrs are assumed to correspond to all vs all trials
            same_spk, valid = utils.create_scoring_masks(metadata)
        else:
            # The llrs are assumed to correspond to all samples. The speaker ids in this case
            # are class indices that have to be mapped to enrollment indices.
            spk_ids = metadata['speaker_id'].type(torch.int)
            same_spk = utils.onehot_for_class_ids(spk_ids, enrollment_ids,
                                                  ids_to_idxs).T
            # All samples are valid.
            valid = torch.ones_like(llrs).type(torch.bool)

    # Select the valid llrs and shift them to convert them to logits
    llrs = llrs[valid]
    logits = llrs + logit(ptar)
    labels = same_spk[valid]
    labels = labels.type(logits.type())

    # The loss will be given by tar_weight * tar_loss + imp_weight * imp_loss
    ptart = torch.as_tensor(ptar)
    tar_weight = ptart / torch.sum(labels == 1)
    imp_weight = (1 - ptart) / torch.sum(labels == 0)

    # Finally, compute the loss and multiply it by the weight that corresponds to the impostors
    # Loss types are taken from Niko Brummer's paper: "Likelihood-ratio calibration using prior-weighted proper scoring rules"
    if loss_type == "cross_entropy":
        criterion = nn.BCEWithLogitsLoss(pos_weight=tar_weight / imp_weight,
                                         reduction='sum')
        baseline_loss = -ptar * np.log(ptar) - (1 - ptar) * np.log(1 - ptar)
        loss = criterion(logits, labels) * imp_weight / baseline_loss
    elif loss_type == "brier":
        baseline_loss = ptar * (1 - ptar)**2 + (1 - ptar) * ptar**2
        posteriors = torch.sigmoid(logits)
        loss = torch.sum(labels * tar_weight * (1 - posteriors)**2 +
                         (1 - labels) * imp_weight *
                         posteriors**2) / baseline_loss
    else:
        raise ValueError(f"Unknown loss_type: {loss_type}")

    if return_info:
        return loss, llrs, labels
    else:
        return loss
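
A quick sanity check for the function above (a minimal sketch; the mask below is made up): for an uninformative system whose LLRs are all zero, every trial's posterior equals ptar, so the prior-weighted cross-entropy matches the baseline entropy of the prior and the normalized loss comes out as exactly 1.0.

import torch

# Hypothetical 4x4 trial mask in the Key convention described above:
# 1 = target, -1 = impostor, 0 = not scored.
mask = torch.tensor([[ 1, -1, -1,  0],
                     [-1,  1, -1, -1],
                     [-1, -1,  1, -1],
                     [ 0, -1, -1,  1]])
llrs = torch.zeros(4, 4)  # all-zero LLRs: the system always outputs the prior
loss = compute_loss(llrs, mask=mask, ptar=0.01, loss_type='cross_entropy')
print(loss)  # tensor(1.0000)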
Example #2
import numpy as np
import torch
import torch.nn as nn
from scipy.special import logit  # assumed source of `logit`; the original module may define its own

import utils  # repo-local helper providing create_scoring_masks; import path assumed


def compute_loss(llrs, metadata=None, ptar=0.01, mask=None, loss_type='cross_entropy', return_info=False):

    if metadata is None:
        assert mask is not None
        # The mask is assumed to have been generated using the Key class, 
        # so targets are labeled with 1, impostors with -1 and non-scored trials with 0.
        valid = mask != 0
        same_spk = mask == 1
    else:
        same_spk, valid = utils.create_scoring_masks(metadata)

    # Shift the llrs to convert them to logits
    logits = llrs + logit(ptar)
    logits = logits[valid]
    labels = same_spk[valid]
    labels = labels.type(logits.type())

    # The loss will be given by tar_weight * tar_loss + imp_weight * imp_loss
    ptart = torch.as_tensor(ptar)
    tar_weight = ptart/torch.sum(labels==1)
    imp_weight = (1-ptart)/torch.sum(labels==0)

    # Finally, compute the loss and multiply it by the weight that corresponds to the impostors
    # Loss types are taken from Niko Brummer's paper: "Likelihood-ratio calibration using prior-weighted proper scoring rules"
    if loss_type == "cross_entropy":
        criterion = nn.BCEWithLogitsLoss(pos_weight=tar_weight/imp_weight, reduction='sum')
        baseline_loss = -ptar*np.log(ptar) - (1-ptar)*np.log(1-ptar)
        loss = criterion(logits, labels)*imp_weight/baseline_loss
    elif loss_type == "brier":
        baseline_loss = ptar * (1-ptar)**2 + (1-ptar) * ptar**2
        posteriors = torch.sigmoid(logits)
        loss = torch.sum(labels*tar_weight*(1-posteriors)**2 + (1-labels)*imp_weight*posteriors**2)/baseline_loss
    else:
        raise ValueError(f"Unknown loss_type: {loss_type}")

    if return_info:
        return loss, llrs, labels
    else:
        return loss
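
The same check works for the Brier branch of this version (again a sketch with made-up trials): when all posteriors equal ptar, the weighted Brier score reduces to its baseline ptar*(1-ptar)**2 + (1-ptar)*ptar**2, so the normalized loss is 1.0 as well.

import torch

mask = torch.tensor([[ 1, -1, -1],
                     [-1,  1, -1],
                     [-1, -1,  1]])
llrs = torch.zeros(3, 3)  # prior-only system
loss = compute_loss(llrs, mask=mask, ptar=0.01, loss_type='brier')
print(loss)  # tensor(1.0000)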
Example #3
    def init_params_with_data(self, dataset, config, device=None, subset=None):

        balance_by_domain = config.balance_batches_by_domain
        assert 'init_params' in config
        init_params = config.init_params
        
        # The code here repeats the steps in forward above, but adds the steps necessary for initialization.
        # I chose to keep these two methods separate to leave the forward small and easy to read.
        with torch.no_grad():

            x, meta, _ = dataset.get_data_and_meta(subset)
            speaker_ids = meta['speaker_id']
            domain_ids  = meta['domain_id']
            x_torch = utils.np_to_torch(x, device)
            
            if init_params.get("random"):
                self.lda_stage.init_random(init_params.get('stdev',0.1))
            else:
                self.lda_stage.init_with_lda(x, speaker_ids, init_params, sec_ids=domain_ids)

            x2_torch = self.lda_stage(x_torch)
            x2 = x2_torch.cpu().numpy()

            if hasattr(self,'si_stage1'):
                if self.si_input == 'main_input':
                    si_input_torch = x_torch
                    si_input = x
                else:
                    si_input_torch = x2_torch
                    si_input = x2

                if init_params.get("random"):
                    self.si_stage1.init_random(init_params.get('stdev',0.1))
                else:
                    self.si_stage1.init_with_lda(si_input, speaker_ids, init_params, sec_ids=domain_ids, complement=True)

                s2_torch = self.si_stage1(si_input_torch)
                    
                if init_params.get('init_si_stage2_with_domain_gb', False):
                    # Initialize the second stage of the si-extractor to be a gaussian backend that predicts
                    # the posterior of each domain. In this case, the number of domains has to coincide with the
                    # dimension of the side info vector
                    assert self.si_dim == len(np.unique(domain_ids))
                    self.si_stage2.init_with_lda(s2_torch.cpu().numpy(), domain_ids, init_params, sec_ids=speaker_ids, gaussian_backend=True)
    
                else:
                    # This is the only component that is initialized randomly unless otherwise indicated by the variable "init_si_stage2_with_domain_gb"
                    self.si_stage2.init_random(init_params.get('w_init', 0.5), init_params.get('b_init', 0.0), init_params.get('type', 'normal'))

                if hasattr(self,'shift_selector'):
                    # Initialize the shifts as the mean of the lda outputs weighted by the si
                    si_torch = self.si_stage2(s2_torch)
                    si = si_torch.cpu().numpy()
                    if init_params.get("random"):
                        self.shift_selector.init_random(init_params.get('stdev',0.1))
                    else:
                        self.shift_selector.init_with_weighted_means(x2, si)
                    x2_torch -= self.shift_selector(si_torch)
                    x2 = x2_torch.cpu().numpy()

            if init_params.get("random"):
                self.plda_stage.init_random(init_params.get('stdev',0.1))
            else:    
                self.plda_stage.init_with_plda(x2, speaker_ids, init_params, domain_ids=domain_ids)

            # Since the training data is usually large, we cannot create all possible trials
            # for initializing the calibration stage. So, to create a bunch of trials, we just
            # create a trial loader with a large batch size. This means we need to rerun the
            # lda stage, but it is a small price to pay for the convenience of reusing the
            # machinery of trial creation in the TrialLoader.
            loader = ddata.TrialLoader(dataset, device, seed=0, batch_size=2000, num_batches=1, balance_by_domain=balance_by_domain, subset=subset)
            x_torch, meta_batch = next(iter(loader))
            x2_torch = self.lda_stage(x_torch)
            scrs_torch = self.plda_stage(x2_torch)
            same_spk_torch, valid_torch = utils.create_scoring_masks(meta_batch)
            scrs, same_spk, valid = [v.detach().cpu().numpy() for v in [scrs_torch, same_spk_torch, valid_torch]]

            if init_params.get("random"):
                self.cal_stage.init_random(init_params.get('stdev',0.1))
            else:
                self.cal_stage.init_with_logreg(scrs, same_spk, valid, config.ptar, std_for_mats=init_params.get('std_for_cal_matrices',0))

            dummy_durs = torch.ones(scrs.shape[0]).to(device) 
            dummy_si = torch.zeros(scrs.shape[0], self.si_dim).to(device)
            llrs_torch = self.cal_stage(scrs_torch, dummy_durs, dummy_si)
            mask = np.ones_like(same_spk, dtype=int)
            mask[~same_spk] = -1
            mask[~valid] = 0
            
            return compute_loss(llrs_torch, mask=utils.np_to_torch(mask, device), ptar=config.ptar, loss_type=config.loss)
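
For context, a hypothetical call sketch (the model class, dataset, and device below are assumptions, not taken from the snippet): initialization runs once before training and returns the loss of the freshly initialized model, which serves as a reference point for the first training epochs.

# All names here are illustrative; only init_params_with_data comes from the snippet above.
model = Backend(config)  # hypothetical model class that defines init_params_with_data
init_loss = model.init_params_with_data(dataset, config, device='cuda')
print(f"loss after initialization: {init_loss.item():.4f}")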