def validate(epoch):
    """Evaluate the model on the validation set and log epoch-level metrics.

    Every batch of ``validation_data_loader`` is passed through ``model``
    with gradient tracking disabled. The configured ``loss_function``, the
    SSIM term and the PSNR are accumulated per batch, averaged over the
    number of batches, and written to ``board_writer`` under the
    ``data/epoch_validation_loss``, ``data/epoch_ssmi`` and
    ``data/epoch_psnr`` tags. The mean validation loss is also printed.

    The model is put into evaluation mode for the duration of the pass and
    its previous training/evaluation mode is restored afterwards, even if a
    batch raises.

    Args:
        epoch: Step value passed to ``board_writer.add_scalar`` for each
            logged metric.

    Returns:
        None. If the validation loader yields no batches, nothing is logged.
    """
    print("===> Running validation...")
    ssmi = loss.SsimLoss()
    iters = len(validation_data_loader)
    if iters == 0:
        # Averaging over zero batches would raise ZeroDivisionError below.
        print("===> Validation skipped: no validation batches")
        return

    valid_loss, valid_ssmi, valid_psnr = 0, 0, 0

    # Validation must not run layers such as dropout / batch-norm in training
    # mode; remember the current mode so the caller's state is left untouched.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in validation_data_loader:
                inputs, target = batch[0].to(device), batch[1].to(device)
                output = model(inputs)
                valid_loss += loss_function(output, target).item()
                # The SSIM term is subtracted, not added.
                # NOTE(review): presumably SsimLoss returns a negated SSIM so
                # this yields a positive score -- confirm against loss.SsimLoss.
                valid_ssmi -= ssmi(output, target).item()
                valid_psnr += psnr(output, target).item()
    finally:
        model.train(was_training)

    # Per-batch averages.
    valid_loss /= iters
    valid_ssmi /= iters
    valid_psnr /= iters

    board_writer.add_scalar('data/epoch_validation_loss', valid_loss, epoch)
    board_writer.add_scalar('data/epoch_ssmi', valid_ssmi, epoch)
    board_writer.add_scalar('data/epoch_psnr', valid_psnr, epoch)
    print("===> Validation loss: {:.4f}".format(valid_loss))
# Create the network: load it from the configured file when
# config.START_FROM_EXISTING_MODEL is set, otherwise construct a new Net.
if config.START_FROM_EXISTING_MODEL is not None:
    print(f'===> Loading pre-trained model: {config.START_FROM_EXISTING_MODEL}')
    model = Net.from_file(config.START_FROM_EXISTING_MODEL)
else:
    print('===> Building model...')
    model = Net()
model.to(device)

# Select the training criterion named by config.LOSS. Any value other than
# the four listed below is rejected with a ValueError at start-up.
if config.LOSS == "l1":
    loss_function = nn.L1Loss()
elif config.LOSS == "vgg":
    loss_function = loss.VggLoss()
elif config.LOSS == "ssim":
    loss_function = loss.SsimLoss()
elif config.LOSS == "l1+vgg":
    # NOTE(review): presumably CombinedLoss is the L1 + VGG combination the
    # config key suggests -- confirm against loss.CombinedLoss.
    loss_function = loss.CombinedLoss()
else:
    raise ValueError(f"Unknown loss: {config.LOSS}")

# Adamax over all model parameters with a fixed learning rate of 0.001.
optimizer = optim.Adamax(model.parameters(), lr=0.001)
# Scalar writer used for metric logging; constructed with no arguments, so
# the log location is whatever the writer's default is.
board_writer = SummaryWriter()

# ----------------------------------------------------------------------


def train(epoch):
    print("===> Training...")
    # Copy of every parameter tensor taken before the training pass.
    # NOTE(review): presumably compared with the post-pass values later in
    # this function -- confirm.
    before_pass = [p.data.clone() for p in model.parameters()]
    epoch_loss = 0