Exemplo n.º 1
0
def main():
    """Train the deraining ``Network`` and periodically checkpoint it.

    All configuration is read from the module-level ``opt`` object
    (``data_path``, ``batch_size``, ``use_GPU``, ``lr``, ``milestone``,
    ``save_path``, ``epochs``, ``save_freq``). If ``findLastCheckpoint``
    returns a positive epoch for ``opt.save_path``, the matching
    ``net_epoch<N>.pth`` weights are loaded and the epoch loop starts at
    index ``N``.

    The model returns ``(out, r1, r2)``. The objective is
    ``-SSIM(out) - SSIM(r1) - SSIM(r2) + L1(out)``. The SSIM terms are
    negated so that minimizing the loss maximizes them; this assumes ``SSIM``
    returns a similarity where higher is better.

    Side effects: TensorBoard scalars and images are written to
    ``opt.save_path``. ``net_latest.pth`` is saved every epoch, and
    ``net_epoch<epoch + 1>.pth`` is saved whenever the 0-based epoch index is
    a multiple of ``opt.save_freq``.
    """
    print('Loading dataset ...\n')
    dataset_train = Dataset(data_path=opt.data_path)
    loader_train = DataLoader(dataset=dataset_train,
                              num_workers=4,
                              batch_size=opt.batch_size,
                              shuffle=True)
    # len(loader_train) would be the number of batches; report samples instead.
    print("# of training samples: %d\n" % int(len(dataset_train)))

    # Build model
    model = Network(nin=64, use_GPU=opt.use_GPU)
    print_network(model)

    # SSIM on the final and both intermediate outputs, L1 on the final output.
    criterion = SSIM()
    criterion1 = nn.L1Loss()
    if opt.use_GPU:
        model = model.cuda()
        criterion.cuda()
        criterion1.cuda()

    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
    scheduler = MultiStepLR(optimizer, milestones=opt.milestone, gamma=0.2)

    # Created before the checkpoint lookup, matching the original ordering.
    writer = SummaryWriter(opt.save_path)
    try:
        initial_epoch = findLastCheckpoint(save_dir=opt.save_path)
        if initial_epoch > 0:
            print('resuming by loading epoch %d' % initial_epoch)
            # map_location='cpu' lets GPU-saved weights load on CPU-only hosts;
            # load_state_dict then copies them onto the model's own device.
            model.load_state_dict(
                torch.load(os.path.join(opt.save_path,
                                        'net_epoch%d.pth' % initial_epoch),
                           map_location='cpu'))

        step = 0
        for epoch in range(initial_epoch, opt.epochs):
            # The explicit epoch sets the LR from the milestones directly, so
            # a resumed run starts from the matching learning rate.
            scheduler.step(epoch)
            for param_group in optimizer.param_groups:
                print('learning rate %f' % param_group["lr"])

            for i, (input_train, target_train) in enumerate(loader_train, 0):
                # Re-enter train mode each step: eval() is set below for logging.
                model.train()
                optimizer.zero_grad()

                if opt.use_GPU:
                    input_train, target_train = input_train.cuda(), target_train.cuda()

                out_train, r1, r2 = model(input_train)
                pixel_metric = criterion(target_train, out_train)
                loss1 = criterion(target_train, r1)
                loss2 = criterion(target_train, r2)
                loss3 = criterion1(target_train, out_train)
                loss = -pixel_metric - loss1 - loss2 + loss3

                loss.backward()
                optimizer.step()

                # Training curve: evaluate the updated weights on the same batch.
                # no_grad avoids building an autograd graph that is never used.
                model.eval()
                with torch.no_grad():
                    out_train, _, _ = model(input_train)
                out_train = torch.clamp(out_train, 0., 1.)
                psnr_train = batch_PSNR(out_train, target_train, 1.)
                print(
                    "[epoch %d][%d/%d] loss: %.4f, pixel_metric: %.4f,"
                    "loss1: %.4f,loss2: %.4f,loss3: %.4f,PSNR: %.4f"
                    % (epoch + 1, i + 1, len(loader_train), loss.item(),
                       pixel_metric.item(), loss1.item(), loss2.item(),
                       loss3.item(), psnr_train))

                if step % 10 == 0:
                    writer.add_scalar('loss', loss.item(), step)
                    writer.add_scalar('PSNR on training data', psnr_train, step)
                step += 1

            # Log images for the last batch of the epoch.
            model.eval()
            with torch.no_grad():
                out_train, _, _ = model(input_train)
            out_train = torch.clamp(out_train, 0., 1.)
            im_target = utils.make_grid(target_train.data,
                                        nrow=8,
                                        normalize=True,
                                        scale_each=True)
            im_input = utils.make_grid(input_train.data,
                                       nrow=8,
                                       normalize=True,
                                       scale_each=True)
            im_derain = utils.make_grid(out_train.data,
                                        nrow=8,
                                        normalize=True,
                                        scale_each=True)
            writer.add_image('clean image', im_target, epoch + 1)
            writer.add_image('rainy image', im_input, epoch + 1)
            writer.add_image('deraining image', im_derain, epoch + 1)

            # Save model
            torch.save(model.state_dict(),
                       os.path.join(opt.save_path, 'net_latest.pth'))

            # NOTE(review): saves after epochs 1, save_freq + 1, ... (0-based
            # index check). The file name uses the 1-based epoch count.
            if epoch % opt.save_freq == 0:
                torch.save(
                    model.state_dict(),
                    os.path.join(opt.save_path, 'net_epoch%d.pth' % (epoch + 1)))
    finally:
        # Flush pending TensorBoard events even if training raises.
        writer.close()
Exemplo n.º 2
0
def main():
    """Train the ``DRN`` deraining model and periodically checkpoint it.

    All configuration is read from the module-level ``opt`` object
    (``data_path``, ``batchSize``, ``inter_iter``, ``intra_iter``,
    ``use_GPU``, ``lr``, ``milestone``, ``save_folder``, ``epochs``,
    ``save_freq``). If ``findLastCheckpoint`` returns a positive epoch for
    ``opt.save_folder``, the matching ``net_epoch<N>.pth`` weights are
    loaded and the epoch loop starts at index ``N``.

    The model returns ``(out, outs)``, where the first ``opt.inter_iter``
    entries of ``outs`` are supervised. The objective is
    ``-SSIM(out) - sum_k w_k * SSIM(outs[k])`` with ``w_k`` = 0.1, 0.2, ...,
    so later intermediate outputs are weighted more heavily. The SSIM terms
    are negated so that minimizing the loss maximizes them; this assumes
    ``SSIM`` returns a similarity where higher is better.

    Side effects: TensorBoard scalars and images are written to
    ``opt.save_folder``. ``net_latest.pth`` is saved every epoch, and
    ``net_epoch<epoch + 1>.pth`` is saved whenever the 0-based epoch index is
    a multiple of ``opt.save_freq``.
    """
    print('Loading dataset ...\n')
    dataset_train = Dataset(train=True, data_path=opt.data_path)
    loader_train = DataLoader(dataset=dataset_train,
                              num_workers=4,
                              batch_size=opt.batchSize,
                              shuffle=True)
    print("# of training samples: %d\n" % int(len(dataset_train)))

    # Build model
    model = DRN(channel=3,
                inter_iter=opt.inter_iter,
                intra_iter=opt.intra_iter,
                use_GPU=opt.use_GPU)
    print_network(model)

    criterion = SSIM()
    if opt.use_GPU:
        model = model.cuda()
        criterion.cuda()

    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
    scheduler = MultiStepLR(optimizer, milestones=opt.milestone, gamma=0.5)

    # Created before the checkpoint lookup, matching the original ordering.
    writer = SummaryWriter(opt.save_folder)
    try:
        step = 0
        initial_epoch = findLastCheckpoint(save_dir=opt.save_folder)
        if initial_epoch > 0:
            print('resuming by loading epoch %03d' % initial_epoch)
            # map_location='cpu' lets GPU-saved weights load on CPU-only hosts;
            # load_state_dict then copies them onto the model's own device.
            model.load_state_dict(
                torch.load(os.path.join(opt.save_folder,
                                        'net_epoch%d.pth' % initial_epoch),
                           map_location='cpu'))

        for epoch in range(initial_epoch, opt.epochs):
            # The explicit epoch sets the LR from the milestones directly, so
            # a resumed run starts from the matching learning rate.
            scheduler.step(epoch)
            for param_group in optimizer.param_groups:
                print('learning rate %f' % param_group["lr"])

            for i, (input_batch, target_batch) in enumerate(loader_train, 0):
                # Re-enter train mode each step: eval() is set below for logging.
                model.train()
                optimizer.zero_grad()

                # Move tensors only when GPU use is enabled so CPU-only runs work.
                if opt.use_GPU:
                    input_train, target_train = input_batch.cuda(), target_batch.cuda()
                else:
                    input_train, target_train = input_batch, target_batch

                out_train, outs = model(input_train)

                pixel_loss = criterion(target_train, out_train)
                loss = -pixel_loss
                loss_list = []
                weight = 0.1
                for stage in range(opt.inter_iter):
                    stage_loss = criterion(target_train, outs[stage])
                    loss_list.append(stage_loss)
                    loss = loss - weight * stage_loss
                    weight = weight + 0.1

                loss.backward()
                optimizer.step()

                # Training curve: evaluate the updated weights on the same batch.
                # no_grad avoids building an autograd graph that is never used.
                model.eval()
                with torch.no_grad():
                    out_train, _ = model(input_train)
                out_train = torch.clamp(out_train, 0., 1.)
                psnr_train = batch_PSNR(out_train, target_train, 1.)
                # One "lossK" entry per supervised intermediate output, so any
                # opt.inter_iter value prints correctly (not just exactly 4).
                stage_msg = ''.join(', loss%d: %.4f' % (k + 1, term.item())
                                    for k, term in enumerate(loss_list))
                print("[epoch %d][%d/%d] loss: %.4f%s, PSNR_train: %.4f"
                      % (epoch + 1, i + 1, len(loader_train), loss.item(),
                         stage_msg, psnr_train))

                if step % 10 == 0:
                    writer.add_scalar('loss', loss.item(), step)
                    writer.add_scalar('PSNR on training data', psnr_train, step)
                step += 1

            # Log images for the last batch of the epoch.
            model.eval()
            with torch.no_grad():
                out_train, _ = model(input_train)
            out_train = torch.clamp(out_train, 0., 1.)
            Img = utils.make_grid(target_train.data,
                                  nrow=8,
                                  normalize=True,
                                  scale_each=True)
            Imgn = utils.make_grid(input_train.data,
                                   nrow=8,
                                   normalize=True,
                                   scale_each=True)
            Irecon = utils.make_grid(out_train.data,
                                     nrow=8,
                                     normalize=True,
                                     scale_each=True)
            writer.add_image('clean image', Img, epoch)
            writer.add_image('noisy image', Imgn, epoch)
            writer.add_image('reconstructed image', Irecon, epoch)

            # Save model
            torch.save(model.state_dict(),
                       os.path.join(opt.save_folder, 'net_latest.pth'))

            # NOTE(review): saves after epochs 1, save_freq + 1, ... (0-based
            # index check). The file name uses the 1-based epoch count.
            if epoch % opt.save_freq == 0:
                torch.save(
                    model.state_dict(),
                    os.path.join(opt.save_folder, 'net_epoch%d.pth' % (epoch + 1)))
    finally:
        # Flush pending TensorBoard events even if training raises.
        writer.close()