Example #1
0
                               lr=opt.lr,
                               betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

dataloader = DataLoader(
    ImageDataset("../../data/%s" % opt.dataset_name, hr_shape=hr_shape),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)

global_timer = Timer()
epoch_timer = Timer()
iter_timer = Timer()
iter_time_meter = AverageMeter()

# %% Training
global_timer.start()
for epoch in range(opt.n_epochs):
    epoch_timer.start()
    for i, imgs in enumerate(dataloader):
        if i % opt.batch_m == 0:
            iter_timer.start()
        # Configure model input
        imgs_lr = imgs["lr"].type(Tensor)
        imgs_hr = imgs["hr"].type(Tensor)

        # Adversarial ground truths
        valid = torch.ones(
            (imgs_lr.size(0), *discriminator_output_shape)).type(Tensor)
Example #2
0
def validate(val_loader,
             model,
             criterion,
             print_freq,
             epoch,
             writer,
             args=None):
    """Evaluate *model* on *val_loader* for one epoch.

    Computes average loss and top-1/top-5 accuracy, logs progress every
    *print_freq* batches on the main process, and writes epoch-level
    scalars to the TensorBoard *writer*.

    Args:
        val_loader: iterable of (input, target) batches.
        model: the network to evaluate (moved to CUDA by the caller).
        criterion: loss function applied to (output, target).
        print_freq: log every this-many batches.
        epoch: current epoch index, used as the TensorBoard x-axis.
        writer: TensorBoard SummaryWriter.
        args: namespace; ``args.distribute`` selects per-rank device
            placement and cross-process metric synchronization.

    Returns:
        The mean of the top-1 and top-5 average accuracies.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    model.eval()
    if args.distribute:
        local_rank = torch.distributed.get_rank()
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cuda")
    with torch.no_grad():
        end = time.time()
        for i, (input, target) in enumerate(val_loader):
            target = target.cuda(non_blocking=True, device=device)

            # compute output
            # BUG FIX: move the input batch to the model's device first;
            # previously the CPU tensor was fed straight into the CUDA
            # model (the sibling train() already does input.cuda(device)).
            output = model(input.cuda(device))
            loss = criterion(output, target)
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            top5.update(prec5.item(), input.size(0))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0 and is_main_process():
                logging.info(
                    ('Test: [{0}/{1}]\t'
                     'Batch {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                     'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                     'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                         i,
                         len(val_loader),
                         batch_time=batch_time,
                         top1=top1,
                         top5=top5)))

    # Aggregate the meters across ranks before reporting, so the logged
    # numbers reflect the whole validation set, not one shard.
    if args.distribute:
        losses.synchronize_between_processes()
        top1.synchronize_between_processes()
        top5.synchronize_between_processes()
    if is_main_process():
        writer.add_scalar('Test/loss', losses.avg, epoch)
        writer.add_scalar('Test/top1', top1.avg, epoch)
        logging.info((
            'Epoch {epoch} Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}'
            .format(epoch=epoch, top1=top1, top5=top5, loss=losses)))

    return (top1.avg + top5.avg) / 2
def train(gpu, args):
    """Distributed evaluation of a pretrained SRGAN generator.

    Despite the name, this function only runs inference: it loads a
    generator checkpoint, super-resolves the evaluation split shard
    assigned to this rank, accumulates PSNR/SSIM against the raw HR
    images, periodically saves comparison grids (rank-local GPU 0 only),
    and writes this rank's averaged metrics to a per-rank text file.

    Args:
        gpu: local GPU index on this node; ``gpu == 0`` does the printing
            and image sampling.
        args: namespace with nr, gpus, world_size, checkpoint_name,
            batch_size, grayscale, sample_interval.
    """
    # Global rank = node index * GPUs-per-node + local GPU index.
    rank = args.nr * args.gpus + gpu
    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=args.world_size,
                            rank=rank)
    # Fixed seed so all ranks behave identically where randomness matters.
    torch.manual_seed(0)

    generator = GeneratorResNet()

    torch.cuda.set_device(gpu)
    generator.cuda(gpu)
    # Load the pretrained checkpoint; filter_state_dict presumably strips
    # DDP "module." key prefixes -- TODO confirm against its definition.
    generator.load_state_dict(
        filter_state_dict(
            torch.load('saved_models/generator_%s.pth' %
                       args.checkpoint_name)))

    # Wrap the model
    generator = nn.parallel.DistributedDataParallel(generator,
                                                    device_ids=[gpu])

    # Dataloader: the DistributedSampler shards the dataset across ranks,
    # so the loader itself must not shuffle.
    dataset = ImageDataset("../../data/img_align_celeba_eval",
                           hr_shape=(256, 256))
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=args.world_size, rank=rank)
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=0,
                                         pin_memory=True,
                                         sampler=sampler)

    torch.autograd.set_detect_anomaly(True)
    total_step = len(loader)

    # Per-batch metric tensors, concatenated and averaged at the end.
    psnr_gen_list = []
    psnr_lr_list = []
    ssim_gen_list = []
    ssim_lr_list = []

    iter_timer = Timer()
    iter_time_meter = AverageMeter()

    for i, imgs in enumerate(loader):
        # Inference only: no gradients needed anywhere in this loop.
        with torch.no_grad():
            iter_timer.start()

            imgs_lr = imgs["lr"].cuda(non_blocking=True)
            imgs_hr = imgs["hr"].cuda(non_blocking=True)
            imgs_hr_raw = imgs['hr_raw'].cuda(non_blocking=True)

            gen_hr = generator(imgs_lr)

            # Upsample LR 4x so it is directly comparable to the HR/SR
            # images, then rescale all images to a common value range
            # before computing metrics.
            imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
            imgs_lr = minmaxscaler(imgs_lr)
            imgs_hr = minmaxscaler(imgs_hr)
            gen_hr = minmaxscaler(gen_hr)

            if args.grayscale:
                # Collapse channels to a single luminance-like channel
                # (plain mean over RGB), keeping a channel dim of 1.
                gen_hr_gray = gen_hr.mean(dim=1)[:, None, :, :]
                imgs_lr_gray = imgs_lr.mean(dim=1)[:, None, :, :]
                imgs_hr_raw_gray = imgs_hr_raw.mean(dim=1)[:, None, :, :]

                psnr_gen = psnr(gen_hr_gray, imgs_hr_raw_gray, max_val=1)
                psnr_lr = psnr(imgs_lr_gray, imgs_hr_raw_gray, max_val=1)
                # size_average=False keeps one SSIM value per image.
                ssim_gen = ssim(gen_hr_gray,
                                imgs_hr_raw_gray,
                                size_average=False)
                ssim_lr = ssim(imgs_lr_gray,
                               imgs_hr_raw_gray,
                               size_average=False)
            else:
                psnr_gen = psnr(gen_hr, imgs_hr_raw, max_val=1)
                psnr_lr = psnr(imgs_lr, imgs_hr_raw, max_val=1)
                ssim_gen = ssim(gen_hr, imgs_hr_raw, size_average=False)
                ssim_lr = ssim(imgs_lr, imgs_hr_raw, size_average=False)

            psnr_gen_list.append(psnr_gen)
            psnr_lr_list.append(psnr_lr)
            ssim_gen_list.append(ssim_gen)
            ssim_lr_list.append(ssim_lr)

            iter_time_meter.update(iter_timer.stop())
            # Only the first local GPU prints and saves sample images.
            if gpu == 0:
                batches_done = i
                print(
                    '[Batch %d/%d] time for iteration: %.4f (%.4f) (sum: %.4f)'
                    % (i, len(loader), iter_time_meter.val,
                       iter_time_meter.avg, iter_time_meter.sum))

                if batches_done % args.sample_interval == 0:
                    print('%10s %10s %10s' % ('', 'PSNR', 'SSIM'))
                    print('%10s %10.4f %10.4f' %
                          ('generator', psnr_gen.mean().item(),
                           ssim_gen.mean().item()))
                    print('%10s %10.4f %10.4f' %
                          ('low res', psnr_lr.mean().item(),
                           ssim_lr.mean().item()))

                    # Save image grid with upsampled inputs and SRGAN outputs
                    gen_hr = make_grid(gen_hr[:4], nrow=1, normalize=True)
                    imgs_lr = make_grid(imgs_lr[:4], nrow=1, normalize=True)
                    imgs_hr_raw = make_grid(imgs_hr_raw[:4],
                                            nrow=1,
                                            normalize=True)
                    img_grid = torch.cat((imgs_hr_raw, imgs_lr, gen_hr), -1)
                    save_image(img_grid,
                               "images_test_distributed/%d.png" % batches_done,
                               normalize=False)

    # Reduce per-batch metric tensors to single scalars for this rank.
    psnr_gen = torch.cat(psnr_gen_list).mean().item()
    psnr_lr = torch.cat(psnr_lr_list).mean().item()
    ssim_gen = torch.cat(ssim_gen_list).mean().item()
    ssim_lr = torch.cat(ssim_lr_list).mean().item()

    # NOTE(review): the destination is str(rank) + 'txt' (e.g. "0txt");
    # a dot ('.txt') looks intended -- verify against write()'s contract
    # before changing, since downstream tooling may expect this name.
    write(
        ' '.join(
            list(
                map(str, [
                    psnr_gen, psnr_lr, ssim_gen, ssim_lr, iter_time_meter.sum
                ]))),
        str(rank) + 'txt')
Example #4
0
def train(
    train_loader,
    model,
    criterion,
    optimizer,
    epoch,
    print_freq,
    writer,
    args=None,
):
    """Train *model* on *train_loader* for one epoch.

    Logs batch time, loss and top-1/top-5 accuracy every *print_freq*
    batches on the main process, synchronizes the meters across ranks
    when distributed, and writes epoch-level scalars to *writer*.

    Args:
        train_loader: iterable of (input, target) batches.
        model: the network to train (already on CUDA).
        criterion: loss function applied to (output, target).
        optimizer: optimizer over model parameters.
        epoch: current epoch index (for logging / TensorBoard).
        print_freq: log every this-many batches.
        writer: TensorBoard SummaryWriter.
        args: namespace; ``args.distribute`` selects per-rank device
            placement, ``args.half`` enables mixed-precision training.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    end = time.time()
    if args.distribute:
        local_rank = torch.distributed.get_rank()
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cuda")

    # BUG FIX: autocast alone is not safe for fp16 training -- small
    # gradients underflow to zero without loss scaling. Pair autocast
    # with a GradScaler as the AMP documentation prescribes.
    scaler = torch.cuda.amp.GradScaler() if args.half else None

    for i, (input, target) in enumerate(train_loader):

        data_time.update(time.time() - end)
        target = target.cuda(non_blocking=True, device=device)

        if args.half:
            with torch.cuda.amp.autocast():
                output = model(input.cuda(device))
                loss = criterion(output, target)
            optimizer.zero_grad()
            # Scale the loss, then unscale inside step(); update() adapts
            # the scale factor when inf/NaN gradients are detected.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(input.cuda(device))
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        top5.update(prec5.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        losses.update(loss.item(), input.size(0))
        if i % print_freq == 0 and is_main_process():
            logging.info(('Epoch: [{0}][{1}/{2}], lr: {lr:.8f}\t'
                          'Batch {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                          'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(
                              epoch,
                              int(i),
                              int(len(train_loader)),
                              batch_time=batch_time,
                              loss=losses,
                              top1=top1,
                              top5=top5,
                              lr=optimizer.param_groups[-1]['lr'])))
    # Aggregate the meters across ranks before reporting, so the logged
    # numbers reflect all shards, not just this process.
    if args.distribute:
        losses.synchronize_between_processes()
        top1.synchronize_between_processes()
        top5.synchronize_between_processes()
    if is_main_process():
        writer.add_scalar('Train/loss', losses.avg, epoch)
        writer.add_scalar('Train/top1', top1.avg, epoch)
        writer.add_scalar('Train/lr', optimizer.param_groups[-1]['lr'], epoch)
        logging.info((
            'Epoch {epoch} Training Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}'
            .format(epoch=epoch, top1=top1, top5=top5, loss=losses)))
Example #5
0
def train(gpu, args):
    """Train SRGAN (generator + discriminator) with DistributedDataParallel.

    Intended to be spawned once per GPU. Rank-local GPU 0 handles
    timing, console logging, periodic image sampling and checkpointing.

    Args:
        gpu: local GPU index on this node.
        args: namespace with nr, gpus, world_size, hr_height, hr_width,
            channels, lr, b1, b2, dataset_name, batch_size, n_epochs,
            sample_interval, checkpoint_name.
    """
    # Global rank = node index * GPUs-per-node + local GPU index.
    rank = args.nr * args.gpus + gpu
    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=args.world_size,
                            rank=rank)
    # Same seed on every rank so all processes start from identical weights.
    torch.manual_seed(0)

    hr_shape = (args.hr_height, args.hr_width)
    generator = GeneratorResNet()
    discriminator = Discriminator(input_shape=(args.channels, *hr_shape))
    feature_extractor = FeatureExtractor()

    torch.cuda.set_device(gpu)
    generator.cuda(gpu)
    discriminator.cuda(gpu)
    feature_extractor.cuda(gpu)

    # define loss function (criterion) and optimizer
    criterion_GAN = nn.MSELoss().cuda(gpu)
    criterion_content = nn.L1Loss().cuda(gpu)

    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=args.lr,
                                   betas=(args.b1, args.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=args.lr,
                                   betas=(args.b1, args.b2))

    # Save model attributes (erased after DDP wrapping hides them
    # behind .module).
    discriminator_output_shape = discriminator.output_shape

    # Wrap the models for synchronized multi-GPU training.
    generator = nn.parallel.DistributedDataParallel(generator,
                                                    device_ids=[gpu])
    discriminator = nn.parallel.DistributedDataParallel(
        discriminator, device_ids=[gpu], broadcast_buffers=False)
    feature_extractor = nn.parallel.DistributedDataParallel(feature_extractor,
                                                            device_ids=[gpu])

    # Data loading code: the DistributedSampler shards the dataset across
    # ranks, so the loader itself must not shuffle.
    train_dataset = ImageDataset("../../data/%s" % args.dataset_name,
                                 hr_shape=hr_shape)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=rank)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=0,
                                               pin_memory=True,
                                               sampler=train_sampler)

    torch.autograd.set_detect_anomaly(True)
    total_step = len(train_loader)

    if gpu == 0:
        global_timer = Timer()
        epoch_timer = Timer()
        iter_timer = Timer()
        iter_time_meter = AverageMeter()

        global_timer.start()

    for epoch in range(args.n_epochs):
        # BUG FIX: without set_epoch() the DistributedSampler re-seeds with
        # the same value every epoch and yields an identical shuffling
        # order, defeating per-epoch shuffling (required per PyTorch docs).
        train_sampler.set_epoch(epoch)
        if gpu == 0:
            epoch_timer.start()
        for i, imgs in enumerate(train_loader):
            if gpu == 0:
                iter_timer.start()
            imgs_lr = imgs["lr"].cuda(non_blocking=True)
            imgs_hr = imgs["hr"].cuda(non_blocking=True)

            # Adversarial ground truths matching the patch-discriminator
            # output shape.
            valid = torch.ones((imgs_lr.size(0), *discriminator_output_shape),
                               device=gpu)
            fake = torch.zeros((imgs_lr.size(0), *discriminator_output_shape),
                               device=gpu)

            # ------------------
            #  Train Generators
            # ------------------
            gen_hr = generator(imgs_lr)

            loss_GAN = criterion_GAN(discriminator(gen_hr), valid)

            # Perceptual (content) loss in the feature extractor's space;
            # real features are detached so no gradient flows to it.
            gen_features = feature_extractor(gen_hr)
            real_features = feature_extractor(imgs_hr)
            loss_content = criterion_content(gen_features,
                                             real_features.detach())

            # Total loss
            loss_G = loss_content + 1e-3 * loss_GAN

            optimizer_G.zero_grad()
            loss_G.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            # gen_hr is detached so the discriminator update does not
            # backpropagate into the generator.
            loss_real = criterion_GAN(discriminator(imgs_hr), valid)
            loss_fake = criterion_GAN(discriminator(gen_hr.detach()), fake)

            loss_D = (loss_real + loss_fake) / 2

            optimizer_D.zero_grad()
            loss_D.backward()
            optimizer_D.step()

            if gpu == 0:
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] " %
                    (epoch, args.n_epochs, i, len(train_loader), loss_D.item(),
                     loss_G.item()),
                    end='')
                iter_time_meter.update(iter_timer.stop())
                print('time for iteration: %.4f (%.4f)' %
                      (iter_time_meter.val, iter_time_meter.avg))

                batches_done = epoch * len(train_loader) + i
                if batches_done % args.sample_interval == 0:
                    # Save image grid with upsampled inputs and SRGAN outputs
                    imgs_lr = nn.functional.interpolate(imgs_lr,
                                                        scale_factor=4)
                    imgs_hr_raw = imgs['hr_raw'].cuda(non_blocking=True)
                    with torch.no_grad():
                        print(
                            '[psnr] (imgs_lr):%.4f, (gen_hr):%.4f' %
                            (psnr(minmaxscaler(imgs_lr),
                                  imgs_hr_raw,
                                  max_val=1).mean().item(),
                             psnr(minmaxscaler(gen_hr), imgs_hr_raw,
                                  max_val=1).mean().item()))

                    gen_hr = make_grid(gen_hr, nrow=1, normalize=True)
                    imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
                    img_grid = torch.cat((imgs_hr_raw, imgs_lr, gen_hr), -1)
                    save_image(img_grid,
                               "images/%d.png" % batches_done,
                               normalize=False)
        if gpu == 0:
            print('Elapsed_time for epoch(%s): %s' %
                  (epoch, epoch_timer.stop()))
    if gpu == 0:
        print("Training complete in: %s " % global_timer.stop())
        print('Average time per iteration: %s' % str(iter_time_meter.avg))
        # NOTE: these are DDP-wrapped state dicts (keys carry a "module."
        # prefix); the evaluation loader strips them via filter_state_dict.
        torch.save(generator.state_dict(),
                   "saved_models/generator_%s.pth" % args.checkpoint_name)
        torch.save(discriminator.state_dict(),
                   "saved_models/discriminator_%s.pth" % args.checkpoint_name)