Example #1
def train(dataloader, validloader, net, nepoch=10):
    # Standard supervised training loop: Adam + MSE loss, optional StepLR
    # schedule, periodic validation and checkpointing; settings come from the
    # module-level `opt` options object.

    start_epoch = 0
    loss_function = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=opt.lr)

    useGPU = torch.cuda.is_available() and not opt.cpu

    if useGPU:
        loss_function.cuda()

    if len(opt.weights) > 0:  # load previous weights?
        checkpoint = torch.load(opt.weights)
        print('loading checkpoint', opt.weights)
        if opt.undomulti:
            # strip the 'module.' prefix added by nn.DataParallel
            checkpoint['state_dict'] = remove_dataparallel_wrapper(
                checkpoint['state_dict'])
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']

    if len(opt.scheduler) > 0:
        stepsize, gamma = int(opt.scheduler.split(',')[0]), float(
            opt.scheduler.split(',')[1])
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              stepsize,
                                              gamma=gamma,
                                              last_epoch=start_epoch - 1)

    count = 0
    opt.t0 = time.perf_counter()

    for epoch in range(start_epoch, nepoch):
        mean_loss = 0

        for i, bat in enumerate(dataloader):
            lr, hr = bat[0], bat[1]
            optimizer.zero_grad()

            if useGPU:
                sr = net(lr.cuda())
                loss = loss_function(sr, hr.cuda())
            else:
                sr = net(lr)
                loss = loss_function(sr, hr)

            loss.backward()
            optimizer.step()

            ######### Status and display #########
            mean_loss += loss.data.item()
            print(
                '\r[%d/%d][%d/%d] Loss: %0.6f' %
                (epoch + 1, nepoch, i + 1, len(dataloader), loss.data.item()),
                end='')

            count += 1
            if opt.log and count * opt.batchSize // 1000 > 0:
                t1 = time.perf_counter() - opt.t0
                mem = torch.cuda.memory_allocated()
                print(epoch,
                      count * opt.batchSize,
                      t1,
                      mem,
                      mean_loss / count,
                      file=opt.train_stats)
                opt.train_stats.flush()
                count = 0

        # ---------------- Scheduler -----------------
        if len(opt.scheduler) > 0:
            scheduler.step()
            for param_group in optimizer.param_groups:
                print('\nLearning rate', param_group['lr'])
                break

        # ---------------- Printing -----------------
        print('\nEpoch %d done, %0.6f' % (epoch,
                                          (mean_loss / len(dataloader))))
        print('\nEpoch %d done, %0.6f' % (epoch,
                                          (mean_loss / len(dataloader))),
              file=opt.fid)
        opt.fid.flush()
        if opt.log:
            opt.writer.add_scalar('data/mean_loss',
                                  mean_loss / len(dataloader), epoch)

        # ---------------- TEST -----------------
        if (epoch + 1) % opt.testinterval == 0:
            testAndMakeCombinedPlots(net, validloader, opt, epoch)
            # if opt.scheduler:
            # scheduler.step(mean_loss / len(dataloader))

        if (epoch + 1) % opt.saveinterval == 0:
            # torch.save(net.state_dict(), opt.out + '/prelim.pth')
            checkpoint = {
                'epoch': epoch + 1,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(checkpoint, opt.out + '/prelim.pth')

    checkpoint = {
        'epoch': nepoch,
        'state_dict': net.state_dict(),
        'optimizer': optimizer.state_dict()
    }
    torch.save(checkpoint, opt.out + '/final.pth')
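
All of the examples on this page rely on a module-level `opt` options object and on project helpers such as remove_dataparallel_wrapper, GetDataloaders, GetModel and testAndMakeCombinedPlots that are defined elsewhere in the repository; file handles like opt.fid and opt.train_stats and the TensorBoard opt.writer are attached to opt at runtime, as Examples #2 and #3 show. The sketch below lists the imports and the opt fields that Example #1 reads; the field values are purely illustrative assumptions (the real project builds opt with argparse in options.py):

import time

import torch
import torch.nn as nn
import torch.optim as optim
from argparse import Namespace

# Illustrative stand-in for the project's argparse options; values are examples only.
opt = Namespace(
    lr=1e-4,           # Adam learning rate
    cpu=False,         # force CPU even when CUDA is available
    weights='',        # checkpoint to resume from ('' = start from scratch)
    undomulti=False,   # strip the nn.DataParallel 'module.' prefix from loaded weights
    scheduler='',      # 'stepsize,gamma' for StepLR, '' disables the scheduler
    log=False,         # write CSV stats and TensorBoard scalars
    batchSize=16,
    testinterval=1,    # validate every N epochs
    saveinterval=10,   # checkpoint every N epochs
    out='results',     # output directory
)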
Example #2
    dataloader, validloader = GetDataloaders(opt)
    net = GetModel(opt)

    if opt.log:
        opt.writer = SummaryWriter(
            comment='_%s_%s' %
            (opt.out.replace('\\', '/').split('/')[-1], opt.model))
        opt.train_stats = open(
            opt.out.replace('\\', '/') + '/train_stats.csv', 'w')
        opt.test_stats = open(
            opt.out.replace('\\', '/') + '/test_stats.csv', 'w')
        print('iter,nsample,time,memory,meanloss', file=opt.train_stats)
        print('iter,time,memory,psnr,ssim', file=opt.test_stats)

    import time
    t0 = time.perf_counter()
    if not opt.test:
        train(dataloader, validloader, net, nepoch=opt.nepoch)
    else:
        if len(opt.weights) > 0:  # load previous weights?
            checkpoint = torch.load(opt.weights)
            print('loading checkpoint', opt.weights)
            if opt.undomulti:
                checkpoint['state_dict'] = remove_dataparallel_wrapper(
                    checkpoint['state_dict'])
            net.load_state_dict(checkpoint['state_dict'])
            print('time: ', time.perf_counter() - t0)
        testAndMakeCombinedPlots(net, validloader, opt)
    print('time: ', time.perf_counter() - t0)
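
remove_dataparallel_wrapper is a project helper that is not shown on this page. A plausible minimal implementation, assuming it only strips the 'module.' prefix that nn.DataParallel prepends to parameter names (the project's own version may differ):

from collections import OrderedDict

def remove_dataparallel_wrapper(state_dict):
    # nn.DataParallel saves parameters as 'module.<name>'; strip that prefix
    # so the weights load into a plain, unwrapped model.
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len('module.'):] if key.startswith('module.') else key
        new_state_dict[new_key] = value
    return new_state_dict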
Example #3
def main(opt):
    # Entry point: prepares the output directory and log files, builds the
    # dataloaders and the model, then either trains or runs the test pass.
    opt.device = torch.device(
        'cuda' if torch.cuda.is_available() and not opt.cpu else 'cpu')

    os.makedirs(opt.out, exist_ok=True)
    shutil.copy2('options.py', opt.out)

    opt.fid = open(opt.out + '/log.txt', 'w')

    ostr = 'ARGS: ' + ' '.join(sys.argv[:])
    print(opt, '\n')
    print(opt, '\n', file=opt.fid)
    print('\n%s\n' % ostr)
    print('\n%s\n' % ostr, file=opt.fid)

    print('getting dataloader', opt.root)
    dataloader, validloader = GetDataloaders(opt)

    if opt.log:
        opt.writer = SummaryWriter(
            log_dir=opt.out,
            comment='_%s_%s' %
            (opt.out.replace('\\', '/').split('/')[-1], opt.model))
        opt.train_stats = open(
            opt.out.replace('\\', '/') + '/train_stats.csv', 'w')
        opt.test_stats = open(
            opt.out.replace('\\', '/') + '/test_stats.csv', 'w')
        print('iter,nsample,time,memory,meanloss', file=opt.train_stats)
        print('iter,time,memory,psnr,ssim', file=opt.test_stats)

    t0 = time.perf_counter()
    net = GetModel(opt)

    if not opt.test:
        train(opt, dataloader, validloader, net)
        # torch.save(net.state_dict(), opt.out + '/final.pth')
    else:
        if len(opt.weights) > 0:  # load previous weights?
            checkpoint = torch.load(opt.weights)
            print('loading checkpoint', opt.weights)
            net.load_state_dict(checkpoint['state_dict'])
            print('time: %0.1f' % (time.perf_counter() - t0))
        testAndMakeCombinedPlots(net, validloader, opt)

    opt.fid.close()
    if not opt.test:
        generate_convergence_plots(opt, opt.out + '/log.txt')

    print('time: %0.1f' % (time.perf_counter() - t0))

    # optional clean up
    if opt.disposableTrainingData and not opt.test:
        print('deleting training data')
        # preserve a few samples
        os.makedirs('%s/training_data_subset' % opt.out, exist_ok=True)

        samplecount = 0
        for file in glob.glob('%s/*' % opt.root):
            if os.path.isfile(file):
                basename = os.path.basename(file)
                shutil.copy2(
                    file, '%s/training_data_subset/%s' % (opt.out, basename))
                samplecount += 1
                if samplecount == 10:
                    break
        shutil.rmtree(opt.root)
Example #4
def train(opt, dataloader, validloader, net):
    # Training loop driven entirely by the `opt` options object: cross-entropy
    # loss for segmentation/classification, MSE otherwise, with optional
    # resume-from-checkpoint and StepLR schedule.
    start_epoch = 0
    if opt.task == 'segment' or opt.task == 'classification':
        loss_function = nn.CrossEntropyLoss()
    else:
        loss_function = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=opt.lr)
    loss_function.cuda()
    if len(opt.weights) > 0:  # load previous weights?
        checkpoint = torch.load(opt.weights)
        print('loading checkpoint', opt.weights)

        net.load_state_dict(checkpoint['state_dict'])
        if opt.lr == 1:  # continue as it was
            optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']

    if len(opt.scheduler) > 0:
        # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True, threshold=0.0001, threshold_mode='rel', cooldown=5, min_lr=0, eps=1e-08)
        stepsize, gamma = int(opt.scheduler.split(',')[0]), float(
            opt.scheduler.split(',')[1])
        scheduler = optim.lr_scheduler.StepLR(optimizer, stepsize, gamma=gamma)
        if len(opt.weights) > 0:
            if 'scheduler' in checkpoint:
                scheduler.load_state_dict(checkpoint['scheduler'])

    opt.t0 = time.perf_counter()

    for epoch in range(start_epoch, opt.nepoch):
        count = 0
        mean_loss = 0

        # for param_group in optimizer.param_groups:
        #     print('\nLearning rate', param_group['lr'])

        for i, bat in enumerate(dataloader):
            lr, hr = bat[0], bat[1]

            optimizer.zero_grad()

            sr = net(lr.to(opt.device))
            loss = loss_function(sr, hr.to(opt.device))

            loss.backward()
            optimizer.step()

            ######### Status and display #########
            mean_loss += loss.data.item()
            print('\r[%d/%d][%d/%d] Loss: %0.6f' %
                  (epoch + 1, opt.nepoch, i + 1, len(dataloader),
                   loss.data.item()),
                  end='')

            count += 1
            if opt.log and count * opt.batchSize // 1000 > 0:
                t1 = time.perf_counter() - opt.t0
                mem = torch.cuda.memory_allocated()
                opt.writer.add_scalar('data/mean_loss_per_1000',
                                      mean_loss / count, epoch)
                opt.writer.add_scalar('data/time_per_1000', t1, epoch)
                print(epoch,
                      count * opt.batchSize,
                      t1,
                      mem,
                      mean_loss / count,
                      file=opt.train_stats)
                opt.train_stats.flush()
                count = 0

        # ---------------- Scheduler -----------------
        if len(opt.scheduler) > 0:
            scheduler.step()
            for param_group in optimizer.param_groups:
                print('\nLearning rate', param_group['lr'])
                break

        # ---------------- Printing -----------------
        mean_loss = mean_loss / len(dataloader)
        t1 = time.perf_counter() - opt.t0
        eta = (opt.nepoch - (epoch + 1)) * t1 / (epoch + 1)
        ostr = '\nEpoch [%d/%d] done, mean loss: %0.6f, time spent: %0.1fs, ETA: %0.1fs' % (
            epoch + 1, opt.nepoch, mean_loss, t1, eta)
        print(ostr)
        print(ostr, file=opt.fid)
        opt.fid.flush()
        if opt.log:
            # mean_loss is already averaged over the epoch above
            opt.writer.add_scalar('data/mean_loss', mean_loss, epoch)

        # ---------------- TEST -----------------
        if (epoch + 1) % opt.testinterval == 0:
            testAndMakeCombinedPlots(net, validloader, opt, epoch)

        if (epoch + 1) % opt.saveinterval == 0:
            # torch.save(net.state_dict(), opt.out + '/prelim.pth')
            checkpoint = {
                'epoch': epoch + 1,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            if len(opt.scheduler) > 0:
                checkpoint['scheduler'] = scheduler.state_dict()
            torch.save(checkpoint, '%s/prelim%d.pth' % (opt.out, epoch + 1))

    checkpoint = {
        'epoch': opt.nepoch,
        'state_dict': net.state_dict(),
        'optimizer': optimizer.state_dict()
    }
    if len(opt.scheduler) > 0:
        checkpoint['scheduler'] = scheduler.state_dict()
    torch.save(checkpoint, opt.out + '/final.pth')
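
The opt.scheduler option used above is a 'stepsize,gamma' string; for example '20,0.5' halves the learning rate every 20 epochs. A small self-contained illustration of the parsing and of the resulting schedule (the '20,0.5' value is only an example):

import torch
import torch.optim as optim

scheduler_arg = '20,0.5'  # example value; normally supplied on the command line
stepsize, gamma = int(scheduler_arg.split(',')[0]), float(scheduler_arg.split(',')[1])

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = optim.Adam(params, lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, stepsize, gamma=gamma)

for epoch in range(60):
    optimizer.step()   # stands in for the training work of one epoch
    scheduler.step()
    # learning rate: 1e-4 for epochs 1-20, 5e-5 for 21-40, 2.5e-5 for 41-60

print(optimizer.param_groups[0]['lr'])  # 1.25e-05 after 60 epochs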
Example #5
def train(dataloader, validloader, net, nepoch=10):
    # Variant of the training loop with task-specific branches (ffdnet,
    # residual denoising, segmentation) and optional partial loading of a
    # pretrained model.

    start_epoch = 0
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=opt.lr)
    loss_function.cuda()

    loss_function_custom = nn.MSELoss()
    loss_function_custom.cuda()

    if len(opt.weights) > 0:  # load previous weights?
        checkpoint = torch.load(opt.weights)
        print('loading checkpoint', opt.weights)
        if opt.undomulti:
            checkpoint['state_dict'] = remove_dataparallel_wrapper(
                checkpoint['state_dict'])
        if opt.modifyPretrainedModel:
            pretrained_dict = checkpoint['state_dict']
            model_dict = net.state_dict()
            # 1. filter out unnecessary keys: list what the checkpoint
            #    contains, then drop the last two entries (typically the final
            #    layer's weight and bias) so that layer is trained from scratch
            for k in pretrained_dict:
                print(k)
            pretrained_dict = dict(list(pretrained_dict.items())[:-2])
            # 2. overwrite entries in the existing state dict
            model_dict.update(pretrained_dict)
            # 3. load the new state dict
            net.load_state_dict(model_dict)

            # optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
        else:
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']

    if len(opt.scheduler) > 0:
        stepsize, gamma = int(opt.scheduler.split(',')[0]), float(
            opt.scheduler.split(',')[1])
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              stepsize,
                                              gamma=gamma,
                                              last_epoch=start_epoch - 1)

    count = 0
    opt.t0 = time.perf_counter()

    for epoch in range(start_epoch, nepoch):
        mean_loss = 0

        for i, bat in enumerate(dataloader):
            lr, hr = bat[0], bat[1]

            optimizer.zero_grad()
            if opt.model == 'ffdnet':
                # FFDNet-style conditioning: estimate a per-sample noise level
                # from the residual between input and target, then train the
                # network to predict the noise itself
                stdvec = torch.zeros(lr.shape[0])
                for j in range(lr.shape[0]):
                    noise = lr[j] - hr[j]
                    stdvec[j] = torch.std(noise)
                noise = net(lr.cuda(), stdvec.cuda())
                sr = torch.clamp(lr.cuda() - noise, 0, 1)
                gt_noise = lr.cuda() - hr.cuda()
                loss = loss_function(noise, gt_noise)
            elif opt.task == 'residualdenoising':
                noise = net(lr.cuda())
                gt_noise = lr.cuda() - hr.cuda()
                loss = loss_function(noise, gt_noise)
            else:
                sr = net(lr.cuda())
                if opt.task == 'segment':
                    hr_classes = torch.round(2 * hr).long()
                    loss = loss_function(sr.squeeze(),
                                         hr_classes.squeeze().cuda())
                else:
                    loss = loss_function(sr, hr.cuda())

            loss.backward()
            optimizer.step()

            ######### Status and display #########
            mean_loss += loss.data.item()
            print(
                '\r[%d/%d][%d/%d] Loss: %0.6f' %
                (epoch + 1, nepoch, i + 1, len(dataloader), loss.data.item()),
                end='')

            count += 1
            if opt.log and count * opt.batchSize // 1000 > 0:
                t1 = time.perf_counter() - opt.t0
                mem = torch.cuda.memory_allocated()
                print(epoch,
                      count * opt.batchSize,
                      t1,
                      mem,
                      mean_loss / count,
                      file=opt.train_stats)
                opt.train_stats.flush()
                count = 0

        # ---------------- Scheduler -----------------
        if len(opt.scheduler) > 0:
            scheduler.step()
            for param_group in optimizer.param_groups:
                print('\nLearning rate', param_group['lr'])
                break

        # ---------------- Printing -----------------
        print('\nEpoch %d done, %0.6f' % (epoch,
                                          (mean_loss / len(dataloader))))
        print('\nEpoch %d done, %0.6f' % (epoch,
                                          (mean_loss / len(dataloader))),
              file=opt.fid)
        opt.fid.flush()

        # ---------------- TEST -----------------
        if (epoch + 1) % opt.testinterval == 0:
            testAndMakeCombinedPlots(net, validloader, opt, epoch)
            # if opt.scheduler:
            # scheduler.step(mean_loss / len(dataloader))

        if (epoch + 1) % opt.saveinterval == 0:
            # torch.save(net.state_dict(), opt.out + '/prelim.pth')
            checkpoint = {
                'epoch': epoch + 1,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(checkpoint, opt.out + '/prelim.pth')

    checkpoint = {
        'epoch': nepoch,
        'state_dict': net.state_dict(),
        'optimizer': optimizer.state_dict()
    }
    torch.save(checkpoint, opt.out + '/final.pth')
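
In the 'segment' branch above, the floating-point target hr (values in [0, 1]) is turned into integer class labels with torch.round(2 * hr).long(), mapping 0 to class 0, 0.5 to class 1 and 1.0 to class 2 so that nn.CrossEntropyLoss can be applied. A minimal self-contained illustration of that conversion (shapes and values are made up for the example):

import torch
import torch.nn as nn

hr = torch.tensor([[0.0, 0.5, 1.0]])      # target with three intensity levels
hr_classes = torch.round(2 * hr).long()   # tensor([[0, 1, 2]])

# CrossEntropyLoss expects raw scores of shape (N, C, ...) and integer labels
logits = torch.randn(1, 3, 3)             # batch of 1, 3 classes, 3 "pixels"
loss = nn.CrossEntropyLoss()(logits, hr_classes)
print(hr_classes.tolist(), loss.item())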