def test_one_epoch():
    model.eval()
    mean_loss = 0
    count = 0
    for idx, data in enumerate(test_dataloader):
        sample_path, in_LDRs, in_HDRs, in_exps, ref_HDRs = data
        sample_path = sample_path[0]
        in_LDRs = in_LDRs.to(device)
        in_HDRs = in_HDRs.to(device)
        ref_HDRs = ref_HDRs.to(device)

        # Forward pass and loss, with gradients disabled for evaluation
        with torch.no_grad():
            res = model(in_LDRs, in_HDRs)
            loss = criterion(tonemap(res), tonemap(ref_HDRs))

        dump_sample(sample_path, res.detach().cpu().numpy())
        print('--------------- Test Batch %d ---------------' % (idx + 1))
        print('loss: %.12f' % loss.item())
        mean_loss += loss.item()
        count += 1

    mean_loss = mean_loss / count
    return mean_loss
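
# The `tonemap` used in the loss above is not defined in this excerpt. A common
# choice for HDR merging networks is differentiable mu-law range compression;
# the sketch below assumes that choice, and the function body and mu value of
# 5000 are illustrative rather than taken from the original code.
import math

def tonemap(hdr, mu=5000.0):
    # Compress the HDR range with the mu-law so the loss is computed in a
    # perceptually more uniform domain; assumes hdr is already scaled to [0, 1].
    return torch.log(1.0 + mu * hdr) / math.log(1.0 + mu)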
def train_one_epoch():
    model.train()
    for idx, data in enumerate(train_dataloader):
        in_LDRs, ref_LDRs, in_HDRs, ref_HDRs, in_exps, ref_exps = data
        in_LDRs = in_LDRs.to(device)
        in_HDRs = in_HDRs.to(device)
        ref_HDRs = ref_HDRs.to(device)

        # Forward pass
        result = model(in_LDRs, in_HDRs)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss = criterion(tonemap(result), tonemap(ref_HDRs))
        loss.backward()
        optimizer.step()

        print('--------------- Train Batch %d ---------------' % (idx + 1))
        print('loss: %.12f' % loss.item())
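
# For context, a minimal driver that alternates the two routines once per epoch
# could look like the sketch below. The epoch count, the checkpoint filename,
# and the best-loss tracking are illustrative assumptions, not part of the
# original script.
def run_training(num_epochs=100):
    best_loss = float('inf')
    for epoch in range(1, num_epochs + 1):
        print('=============== Epoch %d ===============' % epoch)
        train_one_epoch()
        test_loss = test_one_epoch()
        print('mean test loss: %.12f' % test_loss)

        # Keep the weights that achieve the lowest mean test loss so far
        if test_loss < best_loss:
            best_loss = test_loss
            torch.save(model.state_dict(), 'checkpoint_best.pth')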