import torch
import torch.nn.functional as F


def compute_losses(output, minibatch, z, lambda_factor=1):
    """Reconstruction loss plus an MMD penalty pulling the latent codes z toward a N(0, I) prior."""

    def compute_kernel(x, y):
        # Gaussian kernel evaluated between every pair of rows in x and y.
        x_size = x.size(0)
        y_size = y.size(0)
        dim = x.size(1)
        x = x.unsqueeze(1)  # (x_size, 1, dim)
        y = y.unsqueeze(0)  # (1, y_size, dim)
        tiled_x = x.expand(x_size, y_size, dim)
        tiled_y = y.expand(x_size, y_size, dim)
        kernel_input = (tiled_x - tiled_y).pow(2).mean(2) / float(dim)
        return torch.exp(-kernel_input)  # (x_size, y_size)

    def compute_mmd(x, y):
        # Maximum mean discrepancy between the two samples.
        x_kernel = compute_kernel(x, x)
        y_kernel = compute_kernel(y, y)
        xy_kernel = compute_kernel(x, y)
        mmd = x_kernel.mean() + y_kernel.mean() - 2 * xy_kernel.mean()
        return mmd

    def W_loss(z):
        # MMD between the latent codes and samples drawn from a standard normal prior.
        return compute_mmd(z, torch.randn_like(z))

    rec_loss = F.binary_cross_entropy(output, minibatch.unsqueeze(1))
    prior_loss = lambda_factor * W_loss(z)
    return rec_loss, prior_loss
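# Hypothetical usage sketch (not from the original code): how compute_losses could be
# called in a WAE-style training step. The encoder/decoder modules, their names, and the
# optimizer wiring below are assumptions for illustration only.
def wae_train_step_example(encoder, decoder, minibatch, optimizer, lambda_factor=1):
    optimizer.zero_grad()
    z = encoder(minibatch)        # latent codes, shape (batch, latent_dim)
    output = decoder(z)           # reconstructions in [0, 1], matching minibatch.unsqueeze(1)
    rec_loss, prior_loss = compute_losses(output, minibatch, z, lambda_factor)
    loss = rec_loss + prior_loss
    loss.backward()
    optimizer.step()
    return loss.item()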
def loss_vae(recon_x, x, mu, logvar, type="BCE", h1=0.1, h2=0.1):
    """
    See Appendix B of the VAE paper:
    Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014.
    https://arxiv.org/abs/1312.6114
    KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)

    :param recon_x: reconstructed batch
    :param x: original batch
    :param mu, logvar: VAE posterior parameters
    :param type: reconstruction loss, one of BCE, L1, L2
    :param h1: reconstruction hyperparameter
    :param h2: KL-divergence hyperparameter
    :return: total loss of the VAE
    """
    batch = recon_x.shape[0]
    assert recon_x.size() == x.size()
    rec_flat = recon_x.view(batch, -1)
    x_flat = x.view(batch, -1)

    if type == "BCE":
        loss_rec = F.binary_cross_entropy(rec_flat, x_flat, reduction='sum')
    elif type == "L1":
        loss_rec = torch.sum(torch.abs(rec_flat - x_flat))
    elif type == "L2":
        # Sum of squared errors.
        loss_rec = torch.sum((rec_flat - x_flat).pow(2))
    else:
        raise ValueError("type must be one of BCE, L1, L2")

    # KL divergence between N(mu, sigma^2) and the standard normal prior.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return loss_rec * h1 + KLD * h2
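# Hypothetical usage sketch (the model interface below is an assumption): calling loss_vae
# in a VAE training step where model(x) returns (recon_x, mu, logvar).
def vae_train_step_example(model, x, optimizer):
    optimizer.zero_grad()
    recon_x, mu, logvar = model(x)
    loss = loss_vae(recon_x, x, mu, logvar, type="BCE", h1=1.0, h2=1.0)
    loss.backward()
    optimizer.step()
    return loss.item()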
def forward(self, input, target):
    # Class-balanced binary cross-entropy for a segmentation map.
    target = target.float()
    # beta is the fraction of background (target == 0) pixels in this batch.
    beta = 1 - torch.mean(target)
    # Use the foreground channel of the prediction.
    input = input[:, 0, :, :]
    # target pixel = 1 -> weight beta
    # target pixel = 0 -> weight 1 - beta
    weights = 1 - beta + (2 * beta - 1) * target
    return F.binary_cross_entropy(input, target, weights, reduction='mean')
def forward(self, inputs, targets):
    # Focal loss: down-weight well-classified examples by (1 - pt) ** gamma.
    if self.logits:
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
    else:
        BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
    # pt is the model's probability for the true class.
    pt = torch.exp(-BCE_loss)
    F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
    if self.reduce:
        return torch.mean(F_loss)
    return F_loss
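# Hypothetical sketch (not from the original code) of the module the focal-loss forward
# above would belong to; the constructor arguments and default values are assumptions.
import torch.nn as nn


class FocalLossExample(nn.Module):
    def __init__(self, alpha=1.0, gamma=2.0, logits=False, reduce=True):
        super().__init__()
        self.alpha = alpha    # scalar weighting factor
        self.gamma = gamma    # focusing parameter; gamma = 0 reduces to plain BCE
        self.logits = logits  # True if inputs are raw scores rather than probabilities
        self.reduce = reduce  # return the mean if True, per-element losses otherwise
    # forward(self, inputs, targets) as defined above.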
from tqdm import tqdm


def train_model(model, epochs, train_loader, optimizer):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        total_correct = 0
        total = 0
        with tqdm(train_loader, desc='Train Epoch #{}'.format(epoch)) as t:
            for data, target in t:
                data, target = data.to(DEVICE), target.to(DEVICE)
                optimizer.zero_grad()
                output = model(data)
                loss = F.binary_cross_entropy(output, target)
                loss.backward()
                optimizer.step()
                # Running statistics for the progress bar.
                total += len(data)
                total_correct += output.round().eq(target).sum().item()
                total_loss += loss.item() * len(data)
                t.set_postfix(loss='{:.4f}'.format(total_loss / total),
                              accuracy='{:.4f}'.format(total_correct / total))
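# Hypothetical usage sketch (the toy model, synthetic data, and hyperparameters below are
# assumptions, not from the original code): wiring train_model up with a DataLoader and optimizer.
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def example_training_run():
    # Toy binary-classification data: 256 samples with 20 features each.
    x = torch.randn(256, 20)
    y = (x.sum(dim=1, keepdim=True) > 0).float()
    train_loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)
    # The model must output probabilities in [0, 1] for F.binary_cross_entropy.
    model = nn.Sequential(nn.Linear(20, 1), nn.Sigmoid()).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    train_model(model, epochs=3, train_loader=train_loader, optimizer=optimizer)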
def gan_loss(self, y_hat, y):
    # Standard (non-saturating) GAN discriminator/generator loss as binary cross-entropy.
    return F.binary_cross_entropy(y_hat, y)
def compute_loss(minibatch, output, mu, logvar):
    rec_loss = F.binary_cross_entropy(output, minibatch.unsqueeze(1), reduction="sum")
    # KL divergence to the standard normal prior; note the 0.25 weight in place of the usual 0.5.
    kl_loss = 0.25 * torch.sum(torch.exp(logvar) + mu ** 2 - 1. - logvar)
    loss = rec_loss + kl_loss
    return loss