Example #1
            # move the batch onto the training device
            noisy_input, noisy_target, clean = \
                noisy_input.to(device), noisy_target.to(device), clean.to(device)

            if args.transform == 'four_crop':
                # fuse the batch and four-crop dimensions: (B, 4, C, H, W) -> (B*4, C, H, W)
                noisy_input = noisy_input.view(-1, *noisy_input.shape[2:])
                noisy_target = noisy_target.view(-1, *noisy_target.shape[2:])
                clean = clean.view(-1, *clean.shape[2:])

            model.zero_grad()
            denoised = model(noisy_input)
            loss = F.mse_loss(denoised, noisy_target, reduction='sum')
            loss.backward()

            step = epoch * len(train_loader) + batch_idx + 1
            pct = step / total_steps
            lr = scheduler.step(pct)
            adjust_learning_rate(optimizer, lr)

            optimizer.step()

            mse += loss.item()
            with torch.no_grad():
                psnr += cal_psnr(clean, denoised.detach()).sum().item()
            if step % args.print_freq == 0:
                print(f'[{batch_idx+1}|{len(train_loader)}]'
                      f'[{epoch}|{args.epochs}] training PSNR: '
                      f'{(psnr / (batch_idx+1) / args.batch_size / multiplier):.6f}')
        print(f'Epoch {epoch}, lr {lr}')

        # epoch-level metrics: mean PSNR per sample, RMSE over all training pixels
        psnr = psnr / n_train_samples
        rmse = np.sqrt(mse / n_train_pixels)
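
The `scheduler` object, `adjust_learning_rate`, and `cal_psnr` are defined outside this snippet. A minimal sketch of what they are assumed to look like (per-group LR update, a cosine-annealing stand-in for `scheduler.step(pct)`, and PSNR for images scaled to [0, 1]):

import math

import torch


def adjust_learning_rate(optimizer, lr):
    # assumed helper: write the scheduled lr into every parameter group
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


class CosineScheduler:
    # hypothetical stand-in for the scheduler used above:
    # step(pct) maps training progress pct in [0, 1] to a learning rate
    def __init__(self, lr_max, div_factor=25.0):
        self.lr_max = lr_max
        self.lr_min = lr_max / div_factor

    def step(self, pct):
        # cosine annealing from lr_max at pct=0 down to lr_min at pct=1
        return self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1.0 + math.cos(math.pi * pct))


def cal_psnr(clean, denoised, max_val=1.0):
    # assumed helper: per-image PSNR for tensors of shape (N, C, H, W) in [0, max_val]
    mse = ((clean - denoised) ** 2).flatten(1).mean(dim=1)
    return 10.0 * torch.log10(max_val ** 2 / mse)

With this sketch, something like `scheduler = CosineScheduler(lr_max=args.lr)` would reproduce the calls made in the loop above; the real repository may use a different (e.g. one-cycle) schedule.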
Example #2
            # split the two output channels (ux, uy) and compute the
            # per-sample relative L2 error over the spatial dimensions
            output_x = output[:, :1]
            output_y = output[:, 1:]
            target_x = target[:, :1]
            target_y = target[:, 1:]
            relative_x = torch.sum((output_x - target_x)**2, [-1, -2]) \
                / torch.sum(target_x**2, [-1, -2])
            relative_y = torch.sum((output_y - target_y)**2, [-1, -2]) \
                / torch.sum(target_y**2, [-1, -2])
            return relative_x, relative_y

        # lr scheduling
        step = (epoch - 1) * len(train_loader) + batch_idx
        pct = step / total_steps
        lr = scheduler.step(pct)
        adjust_learning_rate(optimizer, lr)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()
        loss_train += loss.item()

    loss_train /= batch_idx

    rel2_cat = torch.cat(relative_l2, 0)  # torch.Size([1344, 2])
    re_l2 = to_numpy(torch.mean(rel2_cat, 0))
    relative_ux = re_l2[0]
    relative_uy = re_l2[1]

    print(
        f'Epoch {epoch}: training loss: {loss_train:.6f} '
        f'relative-ux: {relative_ux: .5f}, relative-uy: {relative_uy: .5f}')
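
The inner function above computes a per-sample relative L2 error for the two output channels (ux, uy). A self-contained sketch of the same metric on dummy tensors (the shapes here are hypothetical):

import torch


def relative_l2(output, target):
    # sum((out - tgt)^2) / sum(tgt^2) over the spatial dims -> shape (batch, channels)
    err = torch.sum((output - target) ** 2, dim=[-1, -2])
    norm = torch.sum(target ** 2, dim=[-1, -2])
    return err / norm


output = torch.rand(4, 2, 64, 64)   # hypothetical batch of predicted (ux, uy) fields
target = torch.rand(4, 2, 64, 64)
print(relative_l2(output, target).mean(0))  # mean relative error per channel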