Example #1
def main():
    """Meta-train an SRGAN generator initialisation, then fine-tune it.

    Phase 1 (outer loop): ``train`` updates ``G_init`` and the
    discriminator until the global ``STEPS`` counter reaches its budget;
    Ctrl-C skips straight to phase 2.  Phase 2: a deep copy of
    ``G_init`` is fine-tuned for one pass over the data with an MSE/VGG
    loss plus an L2 penalty keeping parameters near the initialisation.

    NOTE(review): depends on module-level names defined elsewhere in
    this file (``parser``, ``option``, ``STEPS``, ``create_dataset``,
    ``create_dataloader``, ``create_val_ims``, ``_NetG``, ``_NetD``,
    ``GeneratorLoss``, ``train``, ``save_checkpoint``); ``STEPS`` is
    presumably an int initialised before this runs — confirm.
    """

    global opt, model_G, model_D, netContent, writer, STEPS

    opt = parser.parse_args()
    options = option.parse(opt.options)
    print(opt)

    # Encode the key hyper-parameters in the run-folder name so every
    # configuration gets its own logs / samples / checkpoints.
    out_folder = "steps({})_lrIN({})_lrOUT({})_lambda(mseIN={},mseOUT={},vgg={},adv={},preserve={})".format(
        opt.inner_loop_steps, opt.lr_inner, opt.lr_outer,
        opt.mse_loss_coefficient_inner, opt.mse_loss_coefficient_outer,
        opt.vgg_loss_coefficient, opt.adversarial_loss_coefficient,
        opt.preservation_loss_coefficient)

    writer = SummaryWriter(logdir=os.path.join(opt.logs_dir, out_folder),
                           comment="-srgan-")

    opt.sample_dir = os.path.join(opt.sample_dir, out_folder)
    opt.fine_sample_dir = os.path.join(opt.fine_sample_dir, out_folder)

    # Separate checkpoint paths for the initialisation, the final outer
    # model and the fine-tuned model.
    opt.checkpoint_file_init = os.path.join(opt.checkpoint_dir,
                                            "init/" + out_folder)
    opt.checkpoint_file_final = os.path.join(opt.checkpoint_dir,
                                             "final/" + out_folder)
    opt.checkpoint_file_fine = os.path.join(opt.fine_checkpoint_dir,
                                            out_folder)

    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
    if not torch.cuda.is_available():
        raise Exception(
            "No GPU found or Wrong gpu id, please run without --cuda")

    # Random but logged seed: a run can be reproduced from the printout.
    opt.seed = random.randint(1, 10000)
    print("Random Seed: ", opt.seed)
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)

    cudnn.benchmark = True

    print("===> Loading datasets")
    dataset_opt = options['datasets']['train']
    dataset_opt['batch_size'] = opt.batchSize
    print(dataset_opt)
    train_set = create_dataset(dataset_opt)
    training_data_loader = create_dataloader(train_set, dataset_opt)
    print('===> Train Dataset: %s   Number of images: [%d]' %
          (train_set.name(), len(train_set)))
    if training_data_loader is None:
        raise ValueError("[Error] The training data does not exist")

    # VGG19 feature extractor for the perceptual (content) loss; prefer
    # a local weight file, fall back to downloading from the model zoo.
    print('===> Loading VGG model')
    netVGG = models.vgg19()
    if os.path.isfile('data/vgg19-dcbb9e9d.pth'):
        netVGG.load_state_dict(torch.load('data/vgg19-dcbb9e9d.pth'))
    else:
        netVGG.load_state_dict(
            model_zoo.load_url(
                'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'))

    class _content_model(nn.Module):
        # Wraps all VGG19 feature layers except the last one.
        def __init__(self):
            super(_content_model, self).__init__()
            self.feature = nn.Sequential(
                *list(netVGG.features.children())[:-1])

        def forward(self, x):
            out = self.feature(x)
            return out

    G_init = _NetG(opt).cuda()
    model_D = _NetD().cuda()
    netContent = _content_model().cuda()
    criterion_G = GeneratorLoss(netContent, writer, STEPS).cuda()
    criterion_D = nn.BCELoss().cuda()

    # optionally copy weights from a checkpoint
    if opt.pretrained:
        assert os.path.isfile(opt.pretrained)
        print("=> loading model '{}'".format(opt.pretrained))
        weights = torch.load(opt.pretrained)
        # The checkpoint stores a whole model object, not a bare state dict.
        G_init.load_state_dict(weights['model'].state_dict())

    print("===> Setting Optimizer")
    # The outer-loop optimiser updates the shared initialisation G_init.
    optimizer_G_outer = optim.Adam(G_init.parameters(), lr=opt.lr_outer)
    optimizer_D = optim.Adam(model_D.parameters(), lr=opt.lr_disc)

    print("===> Pre-fetching validation data for monitoring training")
    test_dump_file = 'data/dump/Test5.pickle'

    # Validation images: load a cached pickle if present, else build them.
    if os.path.isfile(test_dump_file):
        with open(test_dump_file, 'rb') as p:
            images_test = pickle.load(p)
        images_hr = images_test['images_hr']
        images_lr = images_test['images_lr']
        print("===>Loading Checkpoint Test images")
    else:
        images_hr, images_lr = create_val_ims()
        print("===>Creating Checkpoint Test images")

    print("===> Training")
    epoch = opt.start_epoch
    try:
        # Outer (meta) training loop; ``train`` is presumably expected to
        # advance the global STEPS counter — confirm, otherwise this loops
        # forever.  KeyboardInterrupt drops through to fine-tuning.
        while STEPS < (opt.inner_loop_steps + 1) * opt.max_updates:
            last_model_G = train(training_data_loader, optimizer_G_outer,
                                 optimizer_D, G_init, model_D, criterion_G,
                                 criterion_D, epoch, STEPS, writer)
            assert last_model_G is not None
            save_checkpoint(images_hr, images_lr, G_init, last_model_G, epoch)
            epoch += 1
    except KeyboardInterrupt:
        print("KeyboardInterrupt HANDLED! Running the final epoch on G_init")
    epoch += 1
    # Step-dependent fine-tune learning rate: halved at 50k, 100k, 200k,
    # 400k and 800k outer steps.
    if STEPS < 5e4:
        lr_finetune = opt.lr_inner
    elif STEPS < 1e5:
        lr_finetune = opt.lr_inner / 2
    elif STEPS < 2e5:
        lr_finetune = opt.lr_inner / 4
    elif STEPS < 4e5:
        lr_finetune = opt.lr_inner / 8
    elif STEPS < 8e5:
        lr_finetune = opt.lr_inner / 16
    else:
        lr_finetune = opt.lr_inner / 32

    # Fine-tune a copy so the meta-initialisation itself stays untouched.
    model_G = deepcopy(G_init)
    optimizer_G_inner = optim.Adam(model_G.parameters(), lr=lr_finetune)
    model_G.train()
    optimizer_G_inner.zero_grad()
    # Flattened snapshot of the initial parameters, used below by the
    # preservation penalty.
    init_parameters = torch.cat(
        [p.view(-1) for k, p in G_init.named_parameters() if p.requires_grad])

    # Fine-tuning uses MSE + VGG losses only; no adversarial term.
    opt.adversarial_loss = False
    opt.vgg_loss = True
    opt.mse_loss_coefficient = opt.mse_loss_coefficient_inner

    start_time = dt.datetime.now()
    total_num_examples = len(training_data_loader)
    for iteration, batch in enumerate(training_data_loader, 1):
        # NOTE(review): division by 255 assumes 8-bit image tensors,
        # scaled here to [0, 1] — confirm against the dataset.
        input, target = Variable(batch['LR']), Variable(batch['HR'],
                                                        requires_grad=False)
        input = input.cuda() / 255
        target = target.cuda() / 255
        STEPS += 1
        output = model_G(input)
        fake_out = None  # no discriminator output during fine-tuning
        optimizer_G_inner.zero_grad()
        loss_g_inner = criterion_G(fake_out, output, target, opt)
        # L2 distance between the current parameters and the snapshot
        # keeps the fine-tuned model close to the meta-learned weights.
        curr_parameters = torch.cat([
            p.view(-1) for k, p in model_G.named_parameters()
            if p.requires_grad
        ])
        preservation_loss = ((Variable(init_parameters).detach() -
                              curr_parameters)**2).sum()
        loss_g_inner += preservation_loss
        loss_g_inner.backward()
        optimizer_G_inner.step()
        writer.add_scalar("Loss_G_finetune", loss_g_inner.item(), STEPS)
        if iteration % 5 == 0:
            # Every 5 iterations: dump an output/target sample grid, log
            # progress and write a fine-tune checkpoint.
            fine_sample_img = torch_utils.make_grid(torch.cat(
                [output.detach().clone(), target], dim=0),
                                                    padding=2,
                                                    normalize=False)
            if not os.path.exists(opt.fine_sample_dir):
                os.makedirs(opt.fine_sample_dir)
            torch_utils.save_image(fine_sample_img,
                                   os.path.join(
                                       opt.fine_sample_dir,
                                       "Epoch-{}--Iteration-{}.png".format(
                                           epoch, iteration)),
                                   padding=5)

            # NOTE(review): the format string has 4 placeholders but 5
            # arguments — the elapsed-seconds value is silently dropped.
            print("===> Finetuning Epoch[{}]({}/{}): G_Loss(finetune): {:.3}".
                  format(epoch, iteration, total_num_examples,
                         loss_g_inner.item(),
                         (dt.datetime.now() - start_time).seconds))
            start_time = dt.datetime.now()
            save_checkpoint(images_hr,
                            images_lr,
                            None,
                            model_G,
                            iteration,
                            finetune=True)
    save_checkpoint(images_hr,
                    images_lr,
                    None,
                    model_G,
                    total_num_examples,
                    finetune=True)
Example #2
def main():
    """Two-phase unpaired SR training around a pretrained EDSR backbone.

    Assembles six networks into an ``nn.ModuleList`` (indices: 0
    denoiser GDN, 1 EDSR GHR, 2 downsampler GLR, 3 LR discriminator
    DLR, 4 HR discriminator DHR, 5 noise generator GNO), restores the
    EDSR baseline weights and any ``model_total_*.pth`` checkpoint, then
    runs training step 1 or step 2 depending on the restored epoch.

    NOTE(review): depends on module-level names defined elsewhere in
    this file (``opt``, ``args``, ``data``, ``_NetG_DOWN``, ``EDSR``,
    ``_NetD``, ``train``, ``test``, ``save_checkpoint``).
    """
    global opt, model
    if opt.cuda:
        print("=> use gpu id: '{}'".format(opt.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpus)
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    # Random but logged seed so runs can be reproduced from the printout.
    opt.seed = random.randint(1, 10000)
    print("Random Seed: ", opt.seed)
    torch.manual_seed(opt.seed)
    if opt.cuda:
        torch.cuda.manual_seed(opt.seed)

    cudnn.benchmark = True
    scale = int(args.scale[0])
    print("===> Loading datasets")

    # Two disjoint 400-image training splits: the "high" split is offset
    # by 400 images from the first.
    opt.n_train = 400
    loader = data.Data(opt)
    opt_high = copy.deepcopy(opt)
    opt_high.offset_train = 400
    opt_high.n_train = 400

    loader_high = data.Data(opt_high)

    training_data_loader = loader.loader_train
    training_high_loader = loader_high.loader_train
    test_data_loader = loader.loader_test

    print("===> Building model")
    GLR = _NetG_DOWN(stride=2)  # generator: downsampler
    GHR = EDSR(args)  # super-resolution backbone
    GDN = _NetG_DOWN(stride=1)  # generator: denoiser
    DLR = _NetD(stride=1)  # discriminator on LR images
    DHR = _NetD(stride=2)  # discriminator on HR images
    GNO = _NetG_DOWN(stride=1)  # generator: noise model

    # Start from the pretrained EDSR baseline for the requested scale.
    Loaded = torch.load(
        '../experiment/model/EDSR_baseline_x{}.pt'.format(scale))
    GHR.load_state_dict(Loaded)

    model = nn.ModuleList()

    model.append(GDN)  # [0] denoiser
    model.append(GHR)  # [1] EDSR
    model.append(GLR)  # [2] downsampler
    model.append(DLR)  # [3] LR discriminator
    model.append(DHR)  # [4] HR discriminator
    model.append(GNO)  # [5] noise generator

    cudnn.benchmark = True

    print("===> Setting GPU")
    if opt.cuda:
        model = model.cuda()

    # Generator optimiser covers modules 0, 1, 2 and 5; discriminator
    # optimiser covers 3 and 4.
    optG = torch.optim.Adam(
        list(model[0].parameters()) + list(model[1].parameters()) +
        list(model[2].parameters()) + list(model[5].parameters()),
        lr=opt.lr,
        weight_decay=0)
    optD = torch.optim.Adam(list(model[3].parameters()) +
                            list(model[4].parameters()),
                            lr=opt.lr,
                            weight_decay=0)

    # Optionally resume: restores epoch, both optimizers and all weights.
    opt.resume = 'model_total_{}.pth'.format(scale)
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            opt.start_epoch = checkpoint["epoch"] + 1

            optG.load_state_dict(checkpoint['optimizer'][0])
            optD.load_state_dict(checkpoint['optimizer'][1])
            model.load_state_dict(checkpoint["model"].state_dict())
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # opt.start_epoch = 401
    # Enter step 2 once the checkpoint is past the step-1 epoch budget.
    step = 2 if opt.start_epoch > opt.epochs else 1

    # model.load_state_dict(torch.load('backup.pt'))

    optimizer = [optG, optD]

    # print("===> Setting Optimizer")

    if opt.test_only:
        print('===> Testing')
        test(test_data_loader, model, opt.start_epoch)
        return

    if step == 1:
        print("===> Training Step 1.")
        for epoch in range(opt.start_epoch, opt.epochs + 1):
            train(training_data_loader, training_high_loader, model, optimizer,
                  epoch, False)
            save_checkpoint(model, optimizer, epoch, scale)
            test(test_data_loader, model, epoch)
        # Keep a snapshot of the step-1 weights before step 2 runs.
        torch.save(model.state_dict(), 'backup.pt')
    elif step == 2:
        print("===> Training Step 2.")
        opt.lr = 1e-4
        for epoch in range(opt.start_epoch + 1, opt.epochs * 2 + 1):
            train(training_data_loader, training_high_loader, model, optimizer,
                  epoch, True)
            save_checkpoint(model, optimizer, epoch, scale)
            test(test_data_loader, model, epoch)
Example #3
def main():
    """Unpaired SR training variant where ``train`` owns the optimisers.

    Builds the same six-network ``nn.ModuleList`` as the other
    ModuleList-based ``main`` in this file (0 denoiser, 1 EDSR, 2
    downsampler, 3 LR discriminator, 4 HR discriminator, 5 noise
    generator), restores EDSR baseline weights and any total
    checkpoint, then runs training step 1 followed by step 2.  Unlike
    the other variant, no optimizer objects are created here — ``train``
    is presumably responsible for them; confirm against its definition.

    NOTE(review): depends on module-level names defined elsewhere in
    this file (``opt``, ``args``, ``data``, ``_NetG_DOWN``, ``EDSR``,
    ``_NetD``, ``train``, ``test``, ``save_checkpoint``).
    """
    global opt, model
    if opt.cuda:
        print("=> use gpu id: '{}'".format(opt.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpus)
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    # Random but logged seed so runs can be reproduced from the printout.
    opt.seed = random.randint(1, 10000)
    print("Random Seed: ", opt.seed)
    torch.manual_seed(opt.seed)
    if opt.cuda:
        torch.cuda.manual_seed(opt.seed)

    cudnn.benchmark = True
    scale = int(args.scale[0])
    print("===> Loading datasets")

    # Two disjoint 400-image training splits: the "high" split is offset
    # by 400 images from the first.
    opt.n_train = 400
    loader = data.Data(opt)
    opt_high = copy.deepcopy(opt)
    opt_high.offset_train = 400
    opt_high.n_train = 400

    loader_high = data.Data(opt_high)

    training_data_loader = loader.loader_train
    training_high_loader = loader_high.loader_train
    test_data_loader = loader.loader_test

    print("===> Building model")
    GLR = _NetG_DOWN(stride=2)  # generator: downsampler
    GHR = EDSR(
        args)  # super-resolution backbone
    GDN = _NetG_DOWN(stride=1)  # generator: denoiser
    DLR = _NetD(
        stride=1
    )  # discriminator on LR images
    DHR = _NetD(stride=2)  # discriminator on HR images
    GNO = _NetG_DOWN(stride=1)  # generator: noise model

    # Start from the pretrained EDSR baseline for the requested scale.
    Loaded = torch.load(
        '../experiment/model/EDSR_baseline_x{}.pt'.format(scale))
    GHR.load_state_dict(Loaded)

    model = nn.ModuleList()

    model.append(GDN)  # [0] denoiser
    model.append(GHR)  # [1] EDSR
    model.append(GLR)  # [2] downsampler
    model.append(DLR)  # [3] LR discriminator
    model.append(DHR)  # [4] HR discriminator
    model.append(GNO)  # [5] noise generator

    print(model)

    cudnn.benchmark = True
    # Optionally resume: restores the epoch counter and all weights.
    opt.resume = 'model_total_{}.pth'.format(scale)
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            opt.start_epoch = checkpoint["epoch"] + 1
            model.load_state_dict(checkpoint["model"].state_dict())
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # model[4] = _NetD_(4, True)#, True, 4)

    print("===> Setting GPU")
    if opt.cuda:
        model = model.cuda()

    print("===> Setting Optimizer")
    # optimizer = optim.Adam(model.parameters(), lr=opt.lr)#, momentum=opt.momentum, weight_decay=opt.weight_decay)

    if opt.test_only:
        print('===> Testing')
        test(test_data_loader, model, opt.start_epoch)
        return

    print("===> Training Step 1.")
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        train(training_data_loader, training_high_loader, model, epoch, False)
        save_checkpoint(model, epoch, scale)
        test(test_data_loader, model, epoch)

    print("===> Training Step 2.")
    opt.lr = 1e-4
    # NOTE(review): step 2 reuses the same epoch range as step 1 (only
    # the boolean flag to ``train`` differs) — confirm this is intended.
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        train(training_data_loader, training_high_loader, model, epoch, True)
        save_checkpoint(model, epoch, scale)
        test(test_data_loader, model, epoch)
def main():
    """Train an SRResNet generator with MSE loss (optionally plus VGG loss).

    Parses CLI options, loads an HDF5 training set, optionally builds a
    VGG19 feature extractor for the perceptual loss, restores or resumes
    weights, and trains for ``opt.nEpochs`` epochs.

    NOTE(review): depends on module-level names defined elsewhere in
    this file (``parser``, ``DatasetFromHdf5``, ``_NetG``, ``train``,
    ``save_checkpoint``).
    """

    global opt, model, netContent
    opt = parser.parse_args()
    print(opt)

    cuda = opt.cuda
    if cuda:
        print("=> use gpu id: '{}'".format(opt.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    # Random but logged seed so a run can be reproduced from the printout.
    opt.seed = random.randint(1, 10000)
    print("Random Seed: ", opt.seed)
    torch.manual_seed(opt.seed)
    if cuda:
        torch.cuda.manual_seed(opt.seed)

    cudnn.benchmark = True

    print("===> Loading datasets")
    train_set = DatasetFromHdf5(
        "/path/to/your/hdf5/data/like/rgb_srresnet_x4.h5")
    training_data_loader = DataLoader(dataset=train_set,
                                      num_workers=opt.threads,
                                      batch_size=opt.batchSize,
                                      shuffle=True)

    if opt.vgg_loss:
        # BUGFIX: this block was commented out, yet the later
        # ``netContent = netContent.cuda()`` branch still ran whenever
        # --vgg_loss was set and crashed with a NameError.  Restored the
        # VGG19 content-feature extractor it depends on.
        print('===> Loading VGG model')
        netVGG = models.vgg19()
        netVGG.load_state_dict(
            model_zoo.load_url(
                'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'))

        class _content_model(nn.Module):
            # All VGG19 feature layers except the last one.
            def __init__(self):
                super(_content_model, self).__init__()
                self.feature = nn.Sequential(
                    *list(netVGG.features.children())[:-1])

            def forward(self, x):
                return self.feature(x)

        netContent = _content_model()

    print("===> Building model")
    # BUGFIX: the generator was assigned to ``model_G`` while every later
    # statement (GPU transfer, checkpoint loading, optimizer, training)
    # referenced the undefined global ``model``; bind it under the name
    # that is actually used.  The unused ``_NetD`` discriminator and its
    # extra MSE criterion were dead code and have been dropped.
    model = _NetG()
    criterion = nn.MSELoss()

    print("===> Setting GPU")
    if cuda:
        model = model.cuda()
        criterion = criterion.cuda()
        if opt.vgg_loss:
            netContent = netContent.cuda()

    # Optionally resume from a checkpoint: restores epoch and weights.
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            opt.start_epoch = checkpoint["epoch"] + 1
            model.load_state_dict(checkpoint["model"].state_dict())
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Optionally copy weights from a pretrained model (weights only, no
    # epoch counter).  BUGFIX: this header line was missing its leading
    # '#', which made the whole module a SyntaxError.
    if opt.pretrained:
        if os.path.isfile(opt.pretrained):
            print("=> loading model '{}'".format(opt.pretrained))
            weights = torch.load(opt.pretrained)
            model.load_state_dict(weights['model'].state_dict())
        else:
            print("=> no model found at '{}'".format(opt.pretrained))

    print("===> Setting Optimizer")
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)

    print("===> Training")
    for epoch in range(opt.start_epoch, opt.nEpochs + 1):
        train(training_data_loader, optimizer, model, criterion, epoch)
        save_checkpoint(model, epoch)