Пример #1
0
def validate(dataloader, net, criterion_MSE, args):
    """Evaluate ``net`` on ``dataloader`` with gradients disabled.

    Args:
        dataloader: Sized iterable of batches. Each batch is indexed with the
            keys ``'input_spectrum'`` and ``'output_spectrum'``.
        net: Callable applied to the prepared input batch.
        criterion_MSE: Callable ``(output, target)`` whose result exposes
            ``.item()``.
        args: Object whose ``gpu`` attribute is forwarded to ``.cuda()``.

    Returns:
        The ``avg`` attribute of the loss meter (presumably the running mean
        weighted by batch size -- depends on ``utilities.AverageMeter``).

    NOTE(review): the network's train/eval mode is not touched here; the
    caller is presumably responsible for calling ``net.eval()`` -- confirm.
    """
    timer = utilities.AverageMeter('Time', ':6.3f')
    mse_meter = utilities.AverageMeter('Loss', ':.4e')
    reporter = utilities.ProgressMeter(
        len(dataloader),
        [timer, mse_meter],
        prefix='Validation: ',
    )

    with torch.no_grad():
        tick = time.time()
        for step, batch in enumerate(dataloader):
            # Cast to float32 and move both tensors to the configured GPU.
            spectra = batch['input_spectrum'].float().cuda(args.gpu)
            reference = batch['output_spectrum'].float().cuda(args.gpu)

            prediction = net(spectra)

            # Each batch is weighted by its size (first dimension of the input).
            mse_value = criterion_MSE(prediction, reference).item()
            mse_meter.update(mse_value, spectra.size(0))

            # Wall-clock time per iteration, including data loading.
            timer.update(time.time() - tick)
            tick = time.time()

            # Print progress on every 400th batch, starting with batch 0.
            if not step % 400:
                reporter.display(step)

    return mse_meter.avg
Пример #2
0
def train(dataloader, net, optimizer, scheduler, criterion, criterion_MSE, epoch, args):
    """Run one pass over ``dataloader``, updating ``net`` after every batch.

    Args:
        dataloader: Sized iterable of batches. Each batch is indexed with the
            keys ``'input_spectrum'`` and ``'output_spectrum'``.
        net: Callable applied to the prepared input batch.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        scheduler: Object providing ``step()``; it is stepped once per batch,
            and only when ``args.scheduler`` is ``"cyclic-lr"`` or
            ``"one-cycle-lr"``.
        criterion: Loss that is back-propagated.
        criterion_MSE: Loss that is only recorded for reporting.
        epoch: Value interpolated into the progress prefix.
        args: Object with ``gpu`` and ``scheduler`` attributes.

    Returns:
        The ``avg`` attribute of the MSE meter (presumably the running mean
        weighted by batch size -- depends on ``utilities.AverageMeter``).

    NOTE(review): ``net.train()`` is not called here; the caller presumably
    sets the mode -- confirm.
    """
    timer = utilities.AverageMeter('Time', ':6.3f')
    mse_meter = utilities.AverageMeter('Loss', ':.4e')
    epoch_prefix = "Epoch: [{}]".format(epoch)
    reporter = utilities.ProgressMeter(
        len(dataloader),
        [timer, mse_meter],
        prefix=epoch_prefix,
    )

    tick = time.time()
    for step, batch in enumerate(dataloader):
        # Cast to float32 and move both tensors to the configured GPU.
        spectra = batch['input_spectrum'].float().cuda(args.gpu)
        reference = batch['output_spectrum'].float().cuda(args.gpu)

        prediction = net(spectra)

        # Standard update: clear old gradients, back-propagate, apply the step.
        optimizer.zero_grad()
        train_loss = criterion(prediction, reference)
        train_loss.backward()
        optimizer.step()

        # These two policies advance per batch rather than per epoch.
        if args.scheduler in ("cyclic-lr", "one-cycle-lr"):
            scheduler.step()

        # The reported metric is always the MSE, whatever `criterion` is.
        mse_value = criterion_MSE(prediction, reference).item()
        mse_meter.update(mse_value, spectra.size(0))

        # Wall-clock time per iteration, including data loading.
        timer.update(time.time() - tick)
        tick = time.time()

        # Print progress on every 400th batch, starting with batch 0.
        if not step % 400:
            reporter.display(step)

    return mse_meter.avg
Пример #3
0
def validate(dataloader, net, criterion_MSE, args):
    """Evaluate ``net`` on image batches, tracking MSE, PSNR and SSIM.

    Args:
        dataloader: Sized iterable of batches. Each batch is indexed with the
            keys ``'input_image'`` and ``'output_image'``.
        net: Callable applied to the prepared input batch.
        criterion_MSE: Callable ``(output, target)`` whose result exposes
            ``.item()``.
        args: Object whose ``gpu`` attribute is forwarded to ``.cuda()``.

    Returns:
        A 3-tuple ``(loss_avg, psnr_avg, ssim_avg)`` holding the ``avg``
        attribute of each meter (presumably running means weighted by batch
        size -- depends on ``utilities.AverageMeter``).

    NOTE(review): the network's train/eval mode is not touched here; the
    caller is presumably responsible for calling ``net.eval()`` -- confirm.
    """
    timer = utilities.AverageMeter('Time', ':6.3f')
    mse_meter = utilities.AverageMeter('Loss', ':.4e')
    psnr_meter = utilities.AverageMeter('PSNR', ':.4f')
    ssim_meter = utilities.AverageMeter('SSIM', ':.4f')
    # The loss meter is tracked and returned but not shown in the progress line.
    reporter = utilities.ProgressMeter(
        len(dataloader),
        [timer, psnr_meter, ssim_meter],
        prefix='Validation: ',
    )

    with torch.no_grad():
        tick = time.time()
        for step, batch in enumerate(dataloader):
            # Cast to float32 and move both tensors to the configured GPU.
            images = batch['input_image'].float().cuda(args.gpu)
            reference = batch['output_image'].float().cuda(args.gpu)

            prediction = net(images)

            # Every meter weights the batch by its size (first input dimension).
            mse_value = criterion_MSE(prediction, reference).item()
            mse_meter.update(mse_value, images.size(0))

            psnr_meter.update(utilities.calc_psnr(prediction, reference), images.size(0))

            ssim_meter.update(utilities.calc_ssim(prediction, reference), images.size(0))

            # Wall-clock time per iteration, including data loading.
            timer.update(time.time() - tick)
            tick = time.time()

            # Print progress on every 20th batch, starting with batch 0.
            if not step % 20:
                reporter.display(step)

    return mse_meter.avg, psnr_meter.avg, ssim_meter.avg