def evaluate(self, data_loader, epoch, history, **kwargs):
    """Run one validation pass over *data_loader* without gradient tracking.

    For every batch: generate, score both networks (no optimizer steps),
    accumulate metrics into *history* with a ``_val`` suffix, and log
    sample images. Epoch-level scalars are flushed once at the end.
    NOTE(review): batch-image logging per iteration assumes the logger
    rate-limits internally — confirm against the logger implementation.
    """
    progress = tqdm.tqdm(data_loader, leave=False)
    with torch.no_grad():
        for _, (source, target) in enumerate(progress):
            device = onegan.device()
            source, target = source.to(device), target.to(device)
            output = self.gnet(source)
            # Losses are computed only for reporting; values are discarded.
            _, d_terms = self.forward_d(source, output.detach(), target)
            _, g_terms = self.forward_g(source, output, target)
            acc = self.metric(output, target)
            progress.set_description('Evaluate')
            stats = history.add({**g_terms, **d_terms, 'acc/psnr': acc},
                                log_suffix='_val')
            progress.set_postfix(stats)
            self.logger.image(
                {'input': source.data,
                 'output': output.data,
                 'target': target.data},
                epoch=epoch, prefix='val_')
        self.logger.scalar(history.metric(), epoch)
def gradient_penalty(dnet, target, pred): w = torch.rand(target.size(0), 1, 1, 1, device=onegan.device()).expand_as(target) interp = torch.tensor(w * target + (1 - w) * pred, requires_grad=True, device=onegan.device()) output = dnet(interp) grads = grad(outputs=output, inputs=interp, grad_outputs=torch.ones(output.size(), device=onegan.device()), create_graph=True, retain_graph=True)[0] return ((grads.view(grads.size(0), -1).norm(dim=1) - 1) ** 2).mean()
def test_device():
    """Exercise onegan's global device switch: default, cpu, and cuda forms."""
    # Default follows hardware availability.
    expected = 'cuda' if torch.cuda.is_available() else 'cpu'
    assert onegan.device().type == expected
    # Force CPU.
    onegan.set_device('cpu')
    assert onegan.device().type == 'cpu'
    if torch.cuda.is_available():
        # Restore GPU by bare name...
        onegan.set_device('cuda')
        assert onegan.device().type == 'cuda'
        # ...and by explicit index.
        onegan.set_device('cuda:0')
        dev = onegan.device()
        assert dev.type == 'cuda' and dev.index == 0
def train(self, data_loader, epoch, history, **kwargs):
    """Run one training epoch: alternate a discriminator and a generator
    optimizer step per batch, accumulate metrics into *history*, and log
    sample images; epoch-level scalars are flushed once at the end.

    The discriminator is trained on ``output.detach()`` so its backward pass
    does not propagate into the generator.
    """
    progress = tqdm.tqdm(data_loader)
    for _, (source, target) in enumerate(progress):
        device = onegan.device()
        source, target = source.to(device), target.to(device)
        output = self.gnet(source)
        d_loss, d_terms = self.forward_d(source, output.detach(), target)
        g_loss, g_terms = self.forward_g(source, output, target)
        acc = self.metric(output, target)
        # Discriminator update.
        self.d_optim.zero_grad()
        d_loss.backward()
        self.d_optim.step()
        # Generator update.
        self.g_optim.zero_grad()
        g_loss.backward()
        self.g_optim.step()
        progress.set_description('Epoch#%d' % (epoch + 1))
        stats = history.add({**g_terms, **d_terms, 'acc/psnr': acc})
        progress.set_postfix(stats)
        self.logger.image(
            {'input': source.data,
             'output': output.data,
             'target': target.data},
            epoch=epoch, prefix='train_')
    self.logger.scalar(history.metric(), epoch)
def fetch_data():
    """Pull the next ``(source, target)`` pair from the enclosing `progress`
    iterator and move both tensors onto the active onegan device.

    Raises StopIteration when the iterator is exhausted.
    """
    src, tgt = next(progress)
    device = onegan.device()
    return src.to(device), tgt.to(device)