    def get_pm_loss(self, topk, grad_estimator,
                    grad_estimator_kwargs={'grad_estimator_kwargs': None}):
        data = torch.rand((1, 5, 5))
        log_class_weights = self.get_log_q()
        return rb_lib.get_raoblackwell_ps_loss(self.get_f_z, log_class_weights,
                                               topk, grad_estimator,
                                               grad_estimator_kwargs=grad_estimator_kwargs,
                                               data=data)
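
# For context: rb_lib.get_raoblackwell_ps_loss builds on the score-function
# (REINFORCE) pseudo-loss, whose key property is that calling backward() on
# log q(z) * f(z).detach() + f(z) gives an unbiased estimate of the gradient
# of E_q[f(z)]. A minimal self-contained sketch of that baseline estimator;
# every name below is illustrative, not part of rb_lib:
import torch

def reinforce_ps_loss(f_z, log_q):
    # log_q: (batch, K) log-probabilities over K discrete categories
    q = log_q.detach().exp()
    z = torch.multinomial(q, num_samples=1).squeeze(-1)      # z ~ q
    one_hot = torch.nn.functional.one_hot(z, log_q.shape[-1]).float()
    f = f_z(one_hot)                                         # conditional loss, shape (batch,)
    log_q_z = log_q.gather(1, z.unsqueeze(-1)).squeeze(-1)   # log q(z)
    # the score-function term carries the gradient through q; detaching f
    # there prevents a second, incorrect gradient path
    return (log_q_z * f.detach() + f).sum()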
    def get_pm_loss(self, topk, grad_estimator, n_samples=1):
        # Returns the pseudo-loss: calling backward() on it produces
        # an estimate of the gradient.

        log_q = self.get_log_q()

        pm_loss = 0.0
        for i in range(n_samples):
            pm_loss += rb_lib.get_raoblackwell_ps_loss(self.f_z, log_q, topk,
                                                       grad_estimator)

        return pm_loss / n_samples
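
# A typical way the pseudo-loss is consumed in a training step (sketch;
# `model` and `optimizer` are hypothetical stand-ins, bs_lib.reinforce is
# the default estimator used in the functions further below):
optimizer.zero_grad()
ps_loss = model.get_pm_loss(topk=5, grad_estimator=bs_lib.reinforce,
                            n_samples=10)
ps_loss.backward()   # gradients now hold the averaged estimate
optimizer.step()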
    def get_rb_loss(self,
                    image,
                    grad_estimator,
                    grad_estimator_kwargs={'grad_estimator_kwargs': None},
                    epoch=None,
                    topk=0,
                    n_samples=1,
                    true_pixel_2d=None):

        if true_pixel_2d is None:
            log_class_weights = self.pixel_attention(image)
            class_weights = torch.exp(log_class_weights)
        else:
            class_weights = self._get_class_weights_from_pixel_2d(
                true_pixel_2d)
            log_class_weights = torch.log(class_weights)

        # KL term: sum q * log q is the negative entropy of the pixel
        # attention distribution (KL to a uniform prior, up to a constant)
        kl_pixel_probs = (class_weights * log_class_weights).sum()

        self.cache_id_conv_image(image)
        f_pixel = lambda i: self.get_loss_cond_pixel_1d(i, image,
                                                        use_cached_image=True)

        avg_pm_loss = 0.0
        # TODO: n_samples would be more elegant as an
        # argument to get_partial_marginal_loss
        for k in range(n_samples):
            pm_loss = rb_lib.get_raoblackwell_ps_loss(f_pixel,
                                                      log_class_weights,
                                                      topk,
                                                      grad_estimator,
                                                      grad_estimator_kwargs,
                                                      epoch,
                                                      data=image)

            avg_pm_loss += pm_loss / n_samples

        map_locations = torch.argmax(log_class_weights.detach(), dim=1)
        one_hot_map_locations = get_one_hot_encoding_from_int(
            map_locations, self.full_slen**2)
        map_cond_losses = f_pixel(one_hot_map_locations).sum()

        return avg_pm_loss + image.shape[0] * kl_pixel_probs, map_cond_losses
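
# get_one_hot_encoding_from_int is not defined in these snippets; a minimal
# sketch of what it plausibly does, inferred from how it is called above:
def get_one_hot_encoding_from_int(z, n_classes):
    # z: LongTensor of shape (batch,) with values in [0, n_classes);
    # returns a float (batch, n_classes) matrix with a 1 at each index
    return torch.nn.functional.one_hot(z, n_classes).float()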
def eval_semisuper_vae(vae,
                       classifier,
                       loader_unlabeled,
                       super_loss,
                       loader_labeled=[None],
                       train=False,
                       optimizer=None,
                       topk=0,
                       grad_estimator=bs_lib.reinforce,
                       grad_estimator_kwargs={'grad_estimator_kwargs': None},
                       n_samples=1,
                       train_labeled_only=False,
                       epoch=0,
                       baseline_optimizer=None,
                       normalizer='softmax'):

    if train:
        assert optimizer is not None
        vae.train()
        classifier.train()

    else:
        vae.eval()
        classifier.eval()

    sum_loss = 0.0
    num_images = 0.0
    total_nz = 0.0

    for labeled_data, unlabeled_data in zip(cycle(loader_labeled), \
                                                loader_unlabeled):

        unlabeled_image = unlabeled_data['image'].to(device)

        if labeled_data is not None:
            labeled_image = labeled_data['image'].to(device)
            true_labels = labeled_data['label'].to(device)

            # get loss on labeled images
            supervised_loss = \
                get_supervised_loss(vae, classifier, labeled_image, true_labels, super_loss).sum()

            num_labeled = len(loader_labeled.sampler)
            num_labeled_batch = labeled_image.shape[0]

        else:
            supervised_loss = 0.0
            num_labeled = 0.0
            num_labeled_batch = 1.0

        # run through classifier
        scores = classifier.forward(unlabeled_image)

        if normalizer == 'softmax':
            class_weights = torch.softmax(scores, dim=-1)
        elif normalizer == 'entmax15':
            class_weights = entmax15(scores, dim=-1)
        elif normalizer == 'sparsemax':
            class_weights = sparsemax(scores, dim=-1)
        else:
            raise ValueError("%s is not a valid normalizer!" % (normalizer, ))

        # get a mask of nonzero class weights (the sparse normalizers can
        # assign exactly zero probability to some categories)
        nz = class_weights > 0

        if train:

            train_labeled_only_bool = 1.
            if train_labeled_only:
                n_samples = 0
                train_labeled_only_bool = 0.

            # flush gradients
            optimizer.zero_grad()

            # get unlabeled pseudoloss: here we use our
            # Rao-Blackwellization or some other gradient estimator
            f_z = lambda z: vae_utils.get_loss_from_one_hot_label(
                vae, unlabeled_image, z)
            unlabeled_ps_loss = 0.0
            for i in range(n_samples):
                unlabeled_ps_loss_ = rb_lib.get_raoblackwell_ps_loss(
                    f_z,
                    class_weights,
                    topk=topk,
                    epoch=epoch,
                    data=unlabeled_image,
                    grad_estimator=grad_estimator,
                    grad_estimator_kwargs=grad_estimator_kwargs)

                unlabeled_ps_loss += unlabeled_ps_loss_

            unlabeled_ps_loss = unlabeled_ps_loss / max(n_samples, 1)

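            # negative entropy of q, restricted to its nonzero support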
            kl_q = torch.sum(class_weights[nz] * torch.log(class_weights[nz]))

            # scale batch averages up to full-dataset sums
            total_ps_loss = \
                (unlabeled_ps_loss + kl_q) * train_labeled_only_bool * \
                len(loader_unlabeled.sampler) / unlabeled_image.shape[0] + \
                supervised_loss * num_labeled / num_labeled_batch

            # backprop gradients from pseudo loss
            total_ps_loss.backward(retain_graph=True)
            optimizer.step()

            if baseline_optimizer is not None:
                # for RELAX: the baseline (control variate) is trained to
                # minimize the squared norm of the gradient estimate
                # flush gradients
                optimizer.zero_grad()
                baseline_optimizer.zero_grad()
                loss_grads = grad(total_ps_loss,
                                  classifier.parameters(),
                                  create_graph=True)
                gn2 = sum([grd.norm()**2 for grd in loss_grads])
                gn2.backward()
                baseline_optimizer.step()

        # loss at MAP value of z
        loss = \
            vae_utils.get_labeled_loss(vae, unlabeled_image,
                                torch.argmax(scores, dim = 1)).detach().sum()

        sum_loss += loss
        num_images += unlabeled_image.shape[0]

        total_nz += nz.sum().item()

    return sum_loss / num_images, total_nz / num_images
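
# The three normalizers above differ in support: softmax is always dense,
# while sparsemax and entmax15 can put exactly zero probability on some
# categories, which is what the `nz` mask counts. A quick standalone
# comparison (assumes the `entmax` package, which provides both functions):
import torch
from entmax import sparsemax, entmax15

scores = torch.tensor([[2.0, 1.0, -1.0, -3.0]])
print(torch.softmax(scores, dim=-1))   # dense: every entry nonzero
print(entmax15(scores, dim=-1))        # typically sparser than softmax
print(sparsemax(scores, dim=-1))       # sparsest here: tensor([[1., 0., 0., 0.]])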
def eval_semisuper_vae(vae,
                       classifier,
                       loader_unlabeled,
                       loader_labeled=[None],
                       train=False,
                       optimizer=None,
                       topk=0,
                       grad_estimator=bs_lib.reinforce,
                       grad_estimator_kwargs={'grad_estimator_kwargs': None},
                       n_samples=1,
                       train_labeled_only=False,
                       epoch=0,
                       baseline_optimizer=None):

    if train:
        assert optimizer is not None
        vae.train()
        classifier.train()

    else:
        vae.eval()
        classifier.eval()

    sum_loss = 0.0
    num_images = 0.0

    for labeled_data, unlabeled_data in zip(cycle(loader_labeled), \
                                                loader_unlabeled):

        unlabeled_image = unlabeled_data['image'].to(device)

        if labeled_data is not None:
            labeled_image = labeled_data['image'].to(device)
            true_labels = labeled_data['label'].to(device)

            # get labeled portion of loss
            supervised_loss = get_supervised_loss(
                vae, classifier, labeled_image, true_labels).sum()

            num_labeled = len(loader_labeled.sampler)
            num_labeled_batch = labeled_image.shape[0]

        else:
            supervised_loss = 0.0
            num_labeled = 0.0
            num_labeled_batch = 1.0

        # run through classifier
        log_q = classifier.forward(unlabeled_image)

        if train:

            train_labeled_only_bool = 1.
            if train_labeled_only:
                n_samples = 0
                train_labeled_only_bool = 0.

            # flush gradients
            optimizer.zero_grad()

            # get unlabeled pseudoloss
            f_z = lambda z: vae_utils.get_loss_from_one_hot_label(
                vae, unlabeled_image, z)
            unlabeled_ps_loss = 0.0
            for i in range(n_samples):
                unlabeled_ps_loss_ = rb_lib.get_raoblackwell_ps_loss(
                    f_z,
                    log_q,
                    topk=topk,
                    epoch=epoch,
                    data=unlabeled_image,
                    grad_estimator=grad_estimator,
                    grad_estimator_kwargs=grad_estimator_kwargs)

                unlabeled_ps_loss += unlabeled_ps_loss_

            unlabeled_ps_loss = unlabeled_ps_loss / max(n_samples, 1)

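            # negative entropy of q; exp(log_q) recovers the class probabilities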
            kl_q = torch.sum(torch.exp(log_q) * log_q)

            # scale batch averages up to full-dataset sums
            total_ps_loss = \
                (unlabeled_ps_loss + kl_q) * train_labeled_only_bool * \
                len(loader_unlabeled.sampler) / unlabeled_image.shape[0] + \
                supervised_loss * num_labeled / num_labeled_batch

            # backprop gradients from pseudo loss
            total_ps_loss.backward(retain_graph=True)
            optimizer.step()

            if baseline_optimizer is not None:
                # flush gradients
                optimizer.zero_grad()
                baseline_optimizer.zero_grad()
                loss_grads = grad(total_ps_loss,
                                  classifier.parameters(),
                                  create_graph=True)
                gn2 = sum([grd.norm()**2 for grd in loss_grads])
                gn2.backward()
                baseline_optimizer.step()

        # loss at MAP value of z
        loss = \
            vae_utils.get_labeled_loss(vae, unlabeled_image,
                                torch.argmax(log_q, dim = 1)).detach().sum()

        sum_loss += loss
        num_images += unlabeled_image.shape[0]

    return sum_loss / num_images
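
# How the labeled/unlabeled loaders are paired: cycle() repeats the (usually
# much shorter) labeled loader so it can be zipped against the unlabeled one,
# and the default loader_labeled=[None] yields None forever, which disables
# the supervised term. A standalone illustration:
from itertools import cycle

labeled_batches = ['L0', 'L1']                      # stand-ins for batches
unlabeled_batches = ['U0', 'U1', 'U2', 'U3', 'U4']
for lab, unlab in zip(cycle(labeled_batches), unlabeled_batches):
    print(lab, unlab)   # -> L0 U0, L1 U1, L0 U2, L1 U3, L0 U4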