def get_pm_loss(self, topk, grad_estimator,
                grad_estimator_kwargs={'grad_estimator_kwargs': None}):
    # returns the pseudo-loss on random toy data; calling .backward()
    # on the result gives an estimate of the gradient
    data = torch.rand((1, 5, 5))
    log_class_weights = self.get_log_q()

    return rb_lib.get_raoblackwell_ps_loss(self.get_f_z, log_class_weights,
                                           topk, grad_estimator,
                                           grad_estimator_kwargs=grad_estimator_kwargs,
                                           data=data)

def get_pm_loss(self, topk, grad_estimator, n_samples=1):
    # returns the pseudo-loss: when backward() is called, this gives
    # an estimate of the gradient
    log_q = self.get_log_q()

    pm_loss = 0.0
    for i in range(n_samples):
        pm_loss += rb_lib.get_raoblackwell_ps_loss(self.f_z, log_q, topk,
                                                   grad_estimator)

    return pm_loss / n_samples

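# Usage sketch (not from the repo): a toy model wired directly into
# get_raoblackwell_ps_loss, following the call pattern of get_pm_loss
# above. The ToyModel class and its fixed target are hypothetical, and
# the module paths below are an assumption about the repo layout; rb_lib
# and bs_lib are the aliases used throughout these files.
import torch
import torch.nn as nn

import rao_blackwellization_lib as rb_lib
import baselines_lib as bs_lib


class ToyModel(nn.Module):
    def __init__(self, n_classes=10):
        super().__init__()
        self.logits = nn.Parameter(torch.randn(1, n_classes))
        # fixed one-hot target for the conditional loss
        self.register_buffer('target', torch.eye(n_classes)[3].unsqueeze(0))

    def get_log_q(self):
        # log class weights of the categorical distribution q
        return torch.log_softmax(self.logits, dim=-1)

    def f_z(self, z):
        # conditional loss given a one-hot sample z; one value per row
        return ((z - self.target) ** 2).sum(dim=-1)


model = ToyModel()
ps_loss = rb_lib.get_raoblackwell_ps_loss(model.f_z, model.get_log_q(),
                                          topk=3,
                                          grad_estimator=bs_lib.reinforce)
ps_loss.backward()  # model.logits.grad now holds the gradient estimate
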
def get_rb_loss(self, image, grad_estimator,
                grad_estimator_kwargs={'grad_estimator_kwargs': None},
                epoch=None,
                topk=0,
                n_samples=1,
                true_pixel_2d=None):

    if true_pixel_2d is None:
        log_class_weights = self.pixel_attention(image)
        class_weights = torch.exp(log_class_weights)
    else:
        class_weights = self._get_class_weights_from_pixel_2d(true_pixel_2d)
        log_class_weights = torch.log(class_weights)

    # kl term
    kl_pixel_probs = (class_weights * log_class_weights).sum()

    self.cache_id_conv_image(image)
    f_pixel = lambda i: self.get_loss_cond_pixel_1d(i, image,
                                                    use_cached_image=True)

    avg_pm_loss = 0.0
    # TODO: n_samples would be more elegant as an
    # argument to get_partial_marginal_loss
    for k in range(n_samples):
        pm_loss = rb_lib.get_raoblackwell_ps_loss(f_pixel,
                                                  log_class_weights,
                                                  topk,
                                                  grad_estimator,
                                                  grad_estimator_kwargs,
                                                  epoch,
                                                  data=image)
        avg_pm_loss += pm_loss / n_samples

    map_locations = torch.argmax(log_class_weights.detach(), dim=1)
    one_hot_map_locations = get_one_hot_encoding_from_int(map_locations,
                                                          self.full_slen ** 2)
    map_cond_losses = f_pixel(one_hot_map_locations).sum()

    return avg_pm_loss + image.shape[0] * kl_pixel_probs, map_cond_losses

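# get_one_hot_encoding_from_int is referenced above but not defined in
# this file. A minimal sketch of what it presumably does, assuming it
# maps a batch of integer locations to float one-hot rows of width
# n_classes (here self.full_slen ** 2):
def get_one_hot_encoding_from_int(z, n_classes):
    # z: long tensor of shape (batch,) -> float tensor (batch, n_classes)
    return torch.nn.functional.one_hot(z, n_classes).float()
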
def eval_semisuper_vae(vae, classifier, loader_unlabeled, super_loss,
                       loader_labeled=[None],
                       train=False,
                       optimizer=None,
                       topk=0,
                       grad_estimator=bs_lib.reinforce,
                       grad_estimator_kwargs={'grad_estimator_kwargs': None},
                       n_samples=1,
                       train_labeled_only=False,
                       epoch=0,
                       baseline_optimizer=None,
                       normalizer='softmax'):

    if train:
        assert optimizer is not None
        vae.train()
        classifier.train()
    else:
        vae.eval()
        classifier.eval()

    sum_loss = 0.0
    num_images = 0.0
    total_nz = 0.0

    for labeled_data, unlabeled_data in zip(cycle(loader_labeled),
                                            loader_unlabeled):

        unlabeled_image = unlabeled_data['image'].to(device)

        if labeled_data is not None:
            labeled_image = labeled_data['image'].to(device)
            true_labels = labeled_data['label'].to(device)

            # get loss on labeled images
            supervised_loss = get_supervised_loss(
                vae, classifier, labeled_image, true_labels,
                super_loss).sum()

            num_labeled = len(loader_labeled.sampler)
            num_labeled_batch = labeled_image.shape[0]
        else:
            supervised_loss = 0.0
            num_labeled = 0.0
            num_labeled_batch = 1.0

        # run through classifier
        scores = classifier.forward(unlabeled_image)

        if normalizer == 'softmax':
            class_weights = torch.softmax(scores, dim=-1)
        elif normalizer == 'entmax15':
            class_weights = entmax15(scores, dim=-1)
        elif normalizer == 'sparsemax':
            class_weights = sparsemax(scores, dim=-1)
        else:
            raise NameError("%s is not a valid normalizer!" % (normalizer, ))

        # get a mask of nonzeros
        nz = (class_weights > 0).to(class_weights.device)

        if train:
            train_labeled_only_bool = 1.
            if train_labeled_only:
                n_samples = 0
                train_labeled_only_bool = 0.

            # flush gradients
            optimizer.zero_grad()

            # get unlabeled pseudo-loss: here we use our
            # Rao-Blackwellization or some other gradient estimator
            f_z = lambda z: vae_utils.get_loss_from_one_hot_label(
                vae, unlabeled_image, z)

            unlabeled_ps_loss = 0.0
            for i in range(n_samples):
                unlabeled_ps_loss_ = rb_lib.get_raoblackwell_ps_loss(
                    f_z, class_weights, topk=topk, epoch=epoch,
                    data=unlabeled_image,
                    grad_estimator=grad_estimator,
                    grad_estimator_kwargs=grad_estimator_kwargs)
                unlabeled_ps_loss += unlabeled_ps_loss_
            unlabeled_ps_loss = unlabeled_ps_loss / max(n_samples, 1)

            # negative entropy of q, summed over its support only so
            # that log(0) terms from sparse normalizers are skipped
            kl_q = torch.sum(class_weights[nz] *
                             torch.log(class_weights[nz]))

            total_ps_loss = \
                (unlabeled_ps_loss + kl_q) * train_labeled_only_bool * \
                len(loader_unlabeled.sampler) / unlabeled_image.shape[0] + \
                supervised_loss * num_labeled / num_labeled_batch

            # backprop gradients from pseudo-loss
            total_ps_loss.backward(retain_graph=True)
            optimizer.step()

            if baseline_optimizer is not None:
                # for RELAX: train the control variate to minimize the
                # squared norm of the gradient estimate

                # flush gradients
                optimizer.zero_grad()
                baseline_optimizer.zero_grad()

                loss_grads = grad(total_ps_loss, classifier.parameters(),
                                  create_graph=True)
                gn2 = sum([grd.norm() ** 2 for grd in loss_grads])
                gn2.backward()
                baseline_optimizer.step()

        # loss at MAP value of z
        loss = vae_utils.get_labeled_loss(
            vae, unlabeled_image,
            torch.argmax(scores, dim=1)).detach().sum()

        sum_loss += loss
        num_images += unlabeled_image.shape[0]
        total_nz += nz.sum().item()

    return sum_loss / num_images, total_nz / num_images

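# Reference sketch of sparsemax over the last dimension (Martins &
# Astudillo, 2016), included to show what the `sparsemax` normalizer
# above computes; the call above presumably comes from the entmax
# package, whose implementation also supplies the proper backward pass,
# so this sketch is illustrative only.
def sparsemax_sketch(scores):
    # scores: (batch, n_classes) -> sparse probabilities summing to 1
    z, _ = torch.sort(scores, dim=-1, descending=True)
    k = torch.arange(1, scores.size(-1) + 1,
                     device=scores.device, dtype=scores.dtype)
    cssv = z.cumsum(dim=-1) - 1.0            # cumulative sums minus 1
    support = (k * z) > cssv                 # sorted entries in the support
    k_z = support.sum(dim=-1, keepdim=True)  # support size per row
    tau = cssv.gather(-1, k_z - 1) / k_z.to(scores.dtype)
    return torch.clamp(scores - tau, min=0.0)
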
def eval_semisuper_vae(vae, classifier, loader_unlabeled,
                       loader_labeled=[None],
                       train=False,
                       optimizer=None,
                       topk=0,
                       grad_estimator=bs_lib.reinforce,
                       grad_estimator_kwargs={'grad_estimator_kwargs': None},
                       n_samples=1,
                       train_labeled_only=False,
                       epoch=0,
                       baseline_optimizer=None):

    if train:
        assert optimizer is not None
        vae.train()
        classifier.train()
    else:
        vae.eval()
        classifier.eval()

    sum_loss = 0.0
    num_images = 0.0

    for labeled_data, unlabeled_data in zip(cycle(loader_labeled),
                                            loader_unlabeled):

        unlabeled_image = unlabeled_data['image'].to(device)

        if labeled_data is not None:
            labeled_image = labeled_data['image'].to(device)
            true_labels = labeled_data['label'].to(device)

            # get labeled portion of loss
            supervised_loss = get_supervised_loss(
                vae, classifier, labeled_image, true_labels).sum()

            num_labeled = len(loader_labeled.sampler)
            num_labeled_batch = labeled_image.shape[0]
        else:
            supervised_loss = 0.0
            num_labeled = 0.0
            num_labeled_batch = 1.0

        # run through classifier
        log_q = classifier.forward(unlabeled_image)

        if train:
            train_labeled_only_bool = 1.
            if train_labeled_only:
                n_samples = 0
                train_labeled_only_bool = 0.

            # flush gradients
            optimizer.zero_grad()

            # get unlabeled pseudo-loss
            f_z = lambda z: vae_utils.get_loss_from_one_hot_label(
                vae, unlabeled_image, z)

            unlabeled_ps_loss = 0.0
            for i in range(n_samples):
                unlabeled_ps_loss_ = rb_lib.get_raoblackwell_ps_loss(
                    f_z, log_q, topk=topk, epoch=epoch,
                    data=unlabeled_image,
                    grad_estimator=grad_estimator,
                    grad_estimator_kwargs=grad_estimator_kwargs)
                unlabeled_ps_loss += unlabeled_ps_loss_
            unlabeled_ps_loss = unlabeled_ps_loss / max(n_samples, 1)

            # negative entropy of q
            kl_q = torch.sum(torch.exp(log_q) * log_q)

            total_ps_loss = \
                (unlabeled_ps_loss + kl_q) * train_labeled_only_bool * \
                len(loader_unlabeled.sampler) / unlabeled_image.shape[0] + \
                supervised_loss * num_labeled / num_labeled_batch

            # backprop gradients from pseudo-loss
            total_ps_loss.backward(retain_graph=True)
            optimizer.step()

            if baseline_optimizer is not None:
                # train the baseline to minimize the squared norm of
                # the gradient estimate

                # flush gradients
                optimizer.zero_grad()
                baseline_optimizer.zero_grad()

                loss_grads = grad(total_ps_loss, classifier.parameters(),
                                  create_graph=True)
                gn2 = sum([grd.norm() ** 2 for grd in loss_grads])
                gn2.backward()
                baseline_optimizer.step()

        # loss at MAP value of z
        loss = vae_utils.get_labeled_loss(
            vae, unlabeled_image,
            torch.argmax(log_q, dim=1)).detach().sum()

        sum_loss += loss
        num_images += unlabeled_image.shape[0]

    return sum_loss / num_images

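# Hypothetical driver loop (not from the repo) showing how
# eval_semisuper_vae is typically wired up: one call per epoch with
# train=True to take gradient steps, and one with train=False to report
# the loss at the MAP label. `vae`, `classifier`, `loader_labeled`,
# `loader_unlabeled`, `loader_test`, and `n_epochs` are all assumed to
# be constructed elsewhere.
optimizer = torch.optim.Adam(list(vae.parameters()) +
                             list(classifier.parameters()), lr=1e-3)

for epoch in range(n_epochs):
    # training pass: samples z, forms the Rao-Blackwellized pseudo-loss,
    # and backprops through it
    eval_semisuper_vae(vae, classifier, loader_unlabeled,
                       loader_labeled=loader_labeled,
                       train=True,
                       optimizer=optimizer,
                       topk=5,
                       grad_estimator=bs_lib.reinforce,
                       epoch=epoch)

    # evaluation pass: returns the average loss at the MAP value of z
    test_loss = eval_semisuper_vae(vae, classifier, loader_test)
    print('epoch {}: test loss = {:.4f}'.format(epoch, float(test_loss)))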