Example 1
    def _train_amp(self, epoch):
        """
        Same training procedure as _train, but runs the forward pass in mixed
        (half) precision via torch.cuda.amp to speed up training on GPUs.
        Requires a CUDA device; torch.cuda.amp was introduced in PyTorch 1.6.
        """
        self.model.train()
        loss_sum = 0.0
        loss_count = 0
        print_every = max(
            1, (len(self.train_loader.dataset) // self.args.batch_size) // 20)
        for i, x in enumerate(self.train_loader):

            # Cast operations to mixed precision
            if self.args.super_resolution or self.args.conditional:
                batch_size = len(x[0])
                with torch.cuda.amp.autocast():
                    loss = cond_elbo_bpd(self.model,
                                         x[0].to(self.args.device),
                                         context=x[1].to(self.args.device))
            else:
                batch_size = len(x)
                with torch.cuda.amp.autocast():
                    loss = elbo_bpd(self.model, x.to(self.args.device))

            # Scale loss and call backward() to create scaled gradients
            self.scaler.scale(loss).backward()

            if self.max_grad_norm > 0:
                # Unscales the gradients of optimizer's assigned params in-place
                self.scaler.unscale_(self.optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.max_grad_norm)

            # step() unscales the gradients (if not already unscaled above) and
            # calls optimizer.step(), skipping it if any gradient is inf/NaN
            self.scaler.step(self.optimizer)
            self.scaler.update()

            if self.scheduler_iter:
                self.scheduler_iter.step()

            self.optimizer.zero_grad(set_to_none=True)

            # accumulate loss and report
            loss_sum += loss.detach().cpu().item() * batch_size
            loss_count += batch_size
            if i % print_every == 0 or i == (len(self.train_loader) - 1):
                self.log_epoch("Training", loss_count,
                               len(self.train_loader.dataset), loss_sum)

        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()

        return {'bpd': loss_sum / loss_count}
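
This method assumes a self.scaler created once before training, typically in
the experiment's constructor. A minimal sketch of that setup (the class name
and constructor signature here are assumptions for illustration, not taken
from the source):

import torch

class Experiment:
    def __init__(self, model, optimizer, args):
        self.model = model.to(args.device)
        self.optimizer = optimizer
        self.args = args
        # GradScaler multiplies the loss before backward() so float16
        # gradients do not underflow, then unscales them before step().
        self.scaler = torch.cuda.amp.GradScaler()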
Example 2
    def eval_fn(self, epoch):
        self.model.eval()
        with torch.no_grad():
            loss_sum = 0.0
            loss_count = 0
            for x in self.eval_loader:
                loss = elbo_bpd(self.model, x.to(self.args.device))
                loss_sum += loss.detach().cpu().item() * len(x)
                loss_count += len(x)
                print('Evaluating. Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'.format(
                    epoch + 1, self.args.epochs, loss_count,
                    len(self.eval_loader.dataset), loss_sum / loss_count), end='\r')
            print('')
        return {'bpd': loss_sum / loss_count}
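
The elbo_bpd helper is imported from the surrounding project. As a rough
sketch of what such a helper typically computes (assuming model.log_prob
returns a per-sample ELBO in nats; the name elbo_bpd_sketch is hypothetical):

import math

def elbo_bpd_sketch(model, x):
    # Negative ELBO in nats, converted to bits and normalized by the
    # number of dimensions per datapoint: bpd = -ELBO / (D * ln 2).
    elbo_nats = model.log_prob(x)   # shape: (batch,)
    dims = x[0].numel()             # dimensions per datapoint, e.g. C*H*W
    return -elbo_nats.mean() / (math.log(2) * dims)

This is why the running averages in these loops are reported as "Bits/dim":
the loss is already normalized per dimension, making values comparable
across datasets of different resolutions.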
Example 3
    def train_fn(self, epoch):
        self.model.train()
        loss_sum = 0.0
        loss_count = 0
        for x in self.train_loader:
            self.optimizer.zero_grad()
            loss = elbo_bpd(self.model, x.to(self.args.device))
            loss.backward()
            self.optimizer.step()
            if self.scheduler_iter:
                self.scheduler_iter.step()
            loss_sum += loss.detach().cpu().item() * len(x)
            loss_count += len(x)
            print('Training. Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'.format(
                epoch + 1, self.args.epochs, loss_count,
                len(self.train_loader.dataset), loss_sum / loss_count), end='\r')
        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()
        return {'bpd': loss_sum / loss_count}
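
A hypothetical driver loop showing how the train_fn/eval_fn pair is meant to
be called once per epoch (exp stands in for the experiment object that owns
these methods; it is not defined in the example):

for epoch in range(exp.args.epochs):
    train_metrics = exp.train_fn(epoch)
    eval_metrics = exp.eval_fn(epoch)
    print('Epoch {}: train bpd {:.3f}, eval bpd {:.3f}'.format(
        epoch + 1, train_metrics['bpd'], eval_metrics['bpd']))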
Example 4
    def eval_fn(self, epoch):
        self.model.eval()
        with torch.no_grad():
            loss_sum = 0.0
            loss_count = 0
            for x in self.eval_loader:
                if self.args.super_resolution or self.args.conditional:
                    batch_size = len(x[0])
                    loss = cond_elbo_bpd(self.model,
                                         x[0].to(self.args.device),
                                         context=x[1].to(self.args.device))
                else:
                    batch_size = len(x)
                    loss = elbo_bpd(self.model, x.to(self.args.device))

                loss_sum += loss.detach().cpu().item() * batch_size
                loss_count += batch_size

            self.log_epoch("Evaluating", loss_count,
                           len(self.eval_loader.dataset), loss_sum)
            print('')
        return {'bpd': loss_sum / loss_count}
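
In the conditional branch, each batch x is a (data, context) pair. A sketch
of a dataset that would produce such pairs for the super-resolution case
(the class SuperResolutionPairs is hypothetical, assuming the base dataset
yields image tensors of shape (C, H, W)):

import torch.nn.functional as F
from torch.utils.data import Dataset

class SuperResolutionPairs(Dataset):
    def __init__(self, base_dataset, scale=2):
        self.base = base_dataset
        self.scale = scale

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        x = self.base[idx].float()
        # Downsample to build the low-resolution conditioning context,
        # matching the x[0]/x[1] unpacking in the loop above.
        context = F.avg_pool2d(x.unsqueeze(0), self.scale).squeeze(0)
        return x, context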
Example 5
    def _train(self, epoch):
        self.model.train()
        loss_sum = 0.0
        loss_count = 0
        print_every = max(
            1, (len(self.train_loader.dataset) // self.args.batch_size) // 20)
        for i, x in enumerate(self.train_loader):
            self.optimizer.zero_grad()

            if self.args.super_resolution or self.args.conditional:
                batch_size = len(x[0])
                loss = cond_elbo_bpd(self.model,
                                     x[0].to(self.args.device),
                                     context=x[1].to(self.args.device))
            else:
                batch_size = len(x)
                loss = elbo_bpd(self.model, x.to(self.args.device))

            loss.backward()

            if self.max_grad_norm > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.max_grad_norm)

            self.optimizer.step()
            if self.scheduler_iter:
                self.scheduler_iter.step()

            loss_sum += loss.detach().cpu().item() * batch_size
            loss_count += batch_size
            if i % print_every == 0 or i == (len(self.train_loader) - 1):
                self.log_epoch("Training", loss_count,
                               len(self.train_loader.dataset), loss_sum)

        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()

        return {'bpd': loss_sum / loss_count}
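
clip_grad_norm_ returns the total gradient norm before clipping, which this
loop stores in grad_norm but never uses. One way it could be surfaced (a
sketch; the writer attribute is an assumption, e.g. a
torch.utils.tensorboard.SummaryWriter):

if self.max_grad_norm > 0:
    grad_norm = torch.nn.utils.clip_grad_norm_(
        self.model.parameters(), self.max_grad_norm)
    if hasattr(self, 'writer'):
        self.writer.add_scalar('train/grad_norm', grad_norm.item(), i)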
Example 6
###########
## Train ##
###########

print('Training...')
for epoch in range(epochs):
    loss_sum = 0.0
    for i, x in enumerate(train_loader):
        if i == 0 and epoch == 0:
            print("Data:", x.size(), "min =",
                  x.min().item(), "max =",
                  x.max().item())

        optimizer.zero_grad()
        # Alternative objectives:
        # loss = elbo_nats(model, x.to(device))
        # loss = -model.log_prob(x.to(device)).mean()
        loss = elbo_bpd(model, x.to(device))
        loss.backward()
        optimizer.step()
        loss_sum += loss.detach().cpu().item()
        print('Epoch: {}/{}, Iter: {}/{}, Bits/dim: {:.3f}'.format(
            epoch + 1, epochs, i + 1, len(train_loader),
            loss_sum / (i + 1)), end='\r')
    print('')

##########
## Test ##
##########

print('Testing...')
with torch.no_grad():
    loss_sum = 0.0
    loss_count = 0
    # test_loader is assumed to be defined analogously to train_loader
    for x in test_loader:
        loss = elbo_bpd(model, x.to(device))
        loss_sum += loss.detach().cpu().item() * len(x)
        loss_count += len(x)
    print('Test bits/dim: {:.3f}'.format(loss_sum / loss_count))