def _train_amp(self, epoch):
    """
    Same training procedure, but uses half precision to speed up training
    on GPUs. Only works on some GPUs and recent versions of PyTorch.
    """
    self.model.train()
    loss_sum = 0.0
    loss_count = 0
    print_every = max(
        1, (len(self.train_loader.dataset) // self.args.batch_size) // 20)
    for i, x in enumerate(self.train_loader):
        # Cast forward-pass operations to mixed precision
        if self.args.super_resolution or self.args.conditional:
            batch_size = len(x[0])
            with torch.cuda.amp.autocast():
                loss = cond_elbo_bpd(self.model, x[0].to(self.args.device),
                                     context=x[1].to(self.args.device))
        else:
            batch_size = len(x)
            with torch.cuda.amp.autocast():
                loss = elbo_bpd(self.model, x.to(self.args.device))

        # Scale the loss and call backward() to create scaled gradients
        self.scaler.scale(loss).backward()

        if self.max_grad_norm > 0:
            # Unscale the optimizer's gradients in-place so clipping
            # operates on the true gradient magnitudes
            self.scaler.unscale_(self.optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.max_grad_norm)

        # Step the optimizer (skipped if gradients contain infs/NaNs),
        # then update the scale factor for the next iteration
        self.scaler.step(self.optimizer)
        self.scaler.update()
        if self.scheduler_iter:
            self.scheduler_iter.step()
        self.optimizer.zero_grad(set_to_none=True)

        # Accumulate loss and report progress
        loss_sum += loss.detach().cpu().item() * batch_size
        loss_count += batch_size
        if i % print_every == 0 or i == (len(self.train_loader) - 1):
            self.log_epoch("Training", loss_count,
                           len(self.train_loader.dataset), loss_sum)
    print('')
    if self.scheduler_epoch:
        self.scheduler_epoch.step()
    return {'bpd': loss_sum / loss_count}
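# _train_amp assumes the experiment's constructor created a gradient scaler
# plus the optimizer/scheduler attributes it references. A minimal setup
# sketch (hypothetical; names are inferred from their use above, not taken
# from this repo's actual __init__):
def _setup_amp_sketch(self, args):
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=args.lr)
    self.scaler = torch.cuda.amp.GradScaler()  # scales losses, unscales grads
    self.scheduler_iter = None   # optional per-iteration LR scheduler
    self.scheduler_epoch = None  # optional per-epoch LR scheduler
    self.max_grad_norm = args.max_grad_norm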
def eval_fn(self, epoch):
    self.model.eval()
    with torch.no_grad():
        loss_sum = 0.0
        loss_count = 0
        for x in self.eval_loader:
            loss = elbo_bpd(self.model, x.to(self.args.device))
            loss_sum += loss.detach().cpu().item() * len(x)
            loss_count += len(x)
            print('Evaluating. Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'.format(
                epoch + 1, self.args.epochs, loss_count,
                len(self.eval_loader.dataset), loss_sum / loss_count), end='\r')
        print('')
        return {'bpd': loss_sum / loss_count}
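# elbo_bpd (and cond_elbo_bpd below) are imported loss helpers. A hedged
# sketch of what they compute, assuming model.log_prob returns the (possibly
# conditional) ELBO in nats per datapoint; the repo's real helpers may differ:
import math

def elbo_bpd_sketch(model, x):
    # negative ELBO converted to bits per dimension, averaged over the batch
    return -model.log_prob(x).mean() / (math.log(2) * x[0].numel())

def cond_elbo_bpd_sketch(model, x, context):
    # same conversion for conditional models (e.g. super-resolution)
    return -model.log_prob(x, context=context).mean() / (math.log(2) * x[0].numel())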
def train_fn(self, epoch):
    self.model.train()
    loss_sum = 0.0
    loss_count = 0
    for x in self.train_loader:
        self.optimizer.zero_grad()
        loss = elbo_bpd(self.model, x.to(self.args.device))
        loss.backward()
        self.optimizer.step()
        if self.scheduler_iter:
            self.scheduler_iter.step()
        loss_sum += loss.detach().cpu().item() * len(x)
        loss_count += len(x)
        print('Training. Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'.format(
            epoch + 1, self.args.epochs, loss_count,
            len(self.train_loader.dataset), loss_sum / loss_count), end='\r')
    print('')
    if self.scheduler_epoch:
        self.scheduler_epoch.step()
    return {'bpd': loss_sum / loss_count}
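# How the per-epoch functions above are typically driven; an illustrative
# sketch only (the experiment's actual run loop lives elsewhere):
def run_sketch(self):
    for epoch in range(self.args.epochs):
        train_stats = self.train_fn(epoch)  # returns {'bpd': ...}
        eval_stats = self.eval_fn(epoch)    # returns {'bpd': ...}
        print('Epoch {}: train bpd {:.3f}, eval bpd {:.3f}'.format(
            epoch + 1, train_stats['bpd'], eval_stats['bpd']))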
def eval_fn(self, epoch):
    self.model.eval()
    with torch.no_grad():
        loss_sum = 0.0
        loss_count = 0
        for x in self.eval_loader:
            if self.args.super_resolution or self.args.conditional:
                batch_size = len(x[0])
                loss = cond_elbo_bpd(self.model, x[0].to(self.args.device),
                                     context=x[1].to(self.args.device))
            else:
                batch_size = len(x)
                loss = elbo_bpd(self.model, x.to(self.args.device))
            loss_sum += loss.detach().cpu().item() * batch_size
            loss_count += batch_size
            self.log_epoch("Evaluating", loss_count,
                           len(self.eval_loader.dataset), loss_sum)
        print('')
        return {'bpd': loss_sum / loss_count}
def _train(self, epoch):
    self.model.train()
    loss_sum = 0.0
    loss_count = 0
    print_every = max(
        1, (len(self.train_loader.dataset) // self.args.batch_size) // 20)
    for i, x in enumerate(self.train_loader):
        self.optimizer.zero_grad()
        if self.args.super_resolution or self.args.conditional:
            batch_size = len(x[0])
            loss = cond_elbo_bpd(self.model, x[0].to(self.args.device),
                                 context=x[1].to(self.args.device))
        else:
            batch_size = len(x)
            loss = elbo_bpd(self.model, x.to(self.args.device))
        loss.backward()
        if self.max_grad_norm > 0:
            # Clip gradients to stabilize training
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()
        if self.scheduler_iter:
            self.scheduler_iter.step()
        loss_sum += loss.detach().cpu().item() * batch_size
        loss_count += batch_size
        if i % print_every == 0 or i == (len(self.train_loader) - 1):
            self.log_epoch("Training", loss_count,
                           len(self.train_loader.dataset), loss_sum)
    print('')
    if self.scheduler_epoch:
        self.scheduler_epoch.step()
    return {'bpd': loss_sum / loss_count}
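# log_epoch is referenced by _train/_train_amp/eval_fn but defined elsewhere.
# A minimal sketch consistent with the inline prints in train_fn/eval_fn
# (hypothetical; the real helper may also report epoch numbers or timing):
def log_epoch_sketch(self, phase, count, total, loss_sum):
    print('{}. Datapoint: {}/{}, Bits/dim: {:.3f}'.format(
        phase, count, total, loss_sum / count), end='\r')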
###########
## Train ##
###########
print('Training...')
for epoch in range(epochs):
    loss_sum = 0.0
    for i, x in enumerate(train_loader):
        if i == 0 and epoch == 0:
            print("Data:", x.size(),
                  "min =", x.min().data.item(),
                  "max =", x.max().data.item())
        optimizer.zero_grad()
        # loss = elbo_nats(model, x.to(device))
        loss = elbo_bpd(model, x.to(device))
        # loss = -model.log_prob(x.to(device)).mean()
        loss.backward()
        optimizer.step()
        loss_sum += loss.detach().cpu().item()
        print('Epoch: {}/{}, Iter: {}/{}, Bits/dim: {:.3f}'.format(
            epoch + 1, epochs, i + 1, len(train_loader),
            loss_sum / (i + 1)), end='\r')
    print('')

##########
## Test ##
##########
print('Testing...')
with torch.no_grad():