Example #1
    def network_initializers(self, hr_shape, use_LeakyReLU_Mish=False):
        generator = GeneratorRRDB(self.opt.channels,
                                  filters=64,
                                  num_res_blocks=self.opt.residual_blocks,
                                  use_LeakyReLU_Mish=use_LeakyReLU_Mish).to(
                                      self.device, non_blocking=True)
        discriminator = Discriminator(
            input_shape=(self.opt.channels, *hr_shape),
            use_LeakyReLU_Mish=use_LeakyReLU_Mish).to(self.device,
                                                      non_blocking=True)
        feature_extractor = FeatureExtractor().to(self.device,
                                                  non_blocking=True)
        # Set feature extractor to inference mode
        feature_extractor.eval()
        return discriminator, feature_extractor, generator
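Every example in this list imports its project's own FeatureExtractor. As a point of reference, here is a minimal sketch of the common VGG19-based variant used for perceptual/content loss (an assumed illustration; each repository picks its own layers and truncation point):

import torch.nn as nn
import torchvision

class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = torchvision.models.vgg19(pretrained=True)
        # Keep convolutional features up to an intermediate layer; the cut-off
        # (here 35, ESRGAN-style) differs between projects.
        self.features = nn.Sequential(*list(vgg.features.children())[:35])
        for p in self.parameters():
            p.requires_grad = False  # frozen: used only to compute losses

    def forward(self, img):
        return self.features(img)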
Example #2
    def _set_model(self, device, hr_shape):
        # Initialize generator and discriminator
        self.generator = GeneratorRRDB(
            self.opt.channels, filters=64,
            num_res_blocks=self.opt.residual_blocks).to(device)
        self.discriminator = Discriminator(input_shape=(self.opt.channels,
                                                        *hr_shape)).to(device)
        self.feature_extractor = FeatureExtractor().to(device)

        # Set feature extractor to inference mode
        self.feature_extractor.eval()

        # Losses
        self.criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
        self.criterion_content = torch.nn.L1Loss().to(device)
        self.criterion_pixel = torch.nn.L1Loss().to(device)
Example #3
def print_network():
    opt = setup()
    generator = Generator(16, opt.upSampling)
    if opt.generatorWeights != '':
        generator.load_state_dict(torch.load(opt.generatorWeights))

    discriminator = Discriminator()
    if opt.discriminatorWeights != '':
        discriminator.load_state_dict(torch.load(opt.discriminatorWeights))

    feature_extractor = FeatureExtractor(
        torchvision.models.vgg19(pretrained=True))

    printer('generator')
    summary(generator.cuda(), (3, 32, 32))
    printer('discriminator')
    summary(discriminator.cuda(), (3, 32, 32))
    printer('feature_extractor')
    summary(feature_extractor.cuda(), (3, 32, 32))
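The summary(...) calls in Example #3 match the torchsummary API, where the tuple is the per-sample input shape without the batch dimension; assuming that library, the import not shown in the snippet would be:

from torchsummary import summary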
Example #4
def get_prediction(model_checkpoint, resnet_type):
    # models
    F = FeatureExtractor(resnet=resnet_type).to(device)
    C = LabelPredictor(resnet=resnet_type).to(device)

    checkpoint = torch.load(model_checkpoint)
    F.load_state_dict(checkpoint['feature_extractor'])
    C.load_state_dict(checkpoint['label_predictor'])

    # predict
    F.eval()
    C.eval()
    result = []
    for i, (data, _) in enumerate(target_loader):
        print(i + 1, len(target_loader), end='\r')
        data = data.to(device)

        logits = C(F(data))

        x = torch.argmax(logits, dim=1).cpu().detach().numpy()
        result.append(x)

    # delete model
    del F
    del C
    torch.cuda.empty_cache()

    return np.concatenate(result)
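Example #4 relies on .detach() and never calls backward(), but the forward passes still build an autograd graph. A slightly leaner variant of the same loop (same assumed F/C interfaces) wraps inference in torch.no_grad():

with torch.no_grad():
    for data, _ in target_loader:
        logits = C(F(data.to(device)))
        result.append(torch.argmax(logits, dim=1).cpu().numpy())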
Example #5
    def __init__(self):
        super(A3Cagent, self).__init__()

        self.Conv = FeatureExtractor()

        self.A, self.C = Actor(), Critic()

        # Try loading checkpoints
        if LOAD_CHECKPOINTS:
            self.load_weights()

        self.opt = torch.optim.RMSprop(self.parameters(), lr=LEARNING_RATE)

        self.mem = [[], [],
                    []]  # Stores log_probs, values, rewards during episode
        self.total_entropy = 0
        self.steps = 0
Example #6
def init(opt):
    # [folder] create folder for checkpoints
    try:
        os.makedirs(opt.out)
    except OSError:
        pass

    # [cuda] check cuda, if cuda is available, then display warning
    if torch.cuda.is_available() and not opt.cuda:
        sys.stdout.write('[WARNING] : You have a CUDA device, so you should probably run with --cuda\n')

    # [normalization] __return__ normalize images, set up mean and std
    normalize = transforms.Normalize(
                                        mean = [0.485, 0.456, 0.406],
                                        std = [0.229, 0.224, 0.225])
    # [scale] __return__
    scale = transforms.Compose([
                                    transforms.ToPILImage(),
                                    transforms.Resize(opt.imageSize),
                                    transforms.ToTensor(),
                                    transforms.Normalize(
                                                            mean = [0.485, 0.456, 0.406],
                                                            std = [0.229, 0.224, 0.225])])

    # [transform] up sampling transforms
    transform = transforms.Compose([transforms.RandomCrop((opt.imageSize[0] * opt.upSampling,
                                                           opt.imageSize[1] * opt.upSampling)),
                                    transforms.ToTensor()])
    # [dataset] training dataset
    if opt.dataset == 'folder':
        dataset = datasets.ImageFolder(root = opt.dataroot, transform = transform)
    elif opt.dataset == 'cifar10':
        dataset = datasets.CIFAR10(root = opt.dataroot, train = True, download = True, transform = transform)
    elif opt.dataset == 'cifar100':
        dataset = datasets.CIFAR100(root = opt.dataroot, train = True, download = False, transform = transform)
    assert dataset
    
    # [dataloader] __return__ loading dataset
    dataloader = torch.utils.data.DataLoader(
                                                 dataset,
                                                 batch_size = opt.batchSize,
                                                 shuffle = True,
                                                 num_workers = int(opt.workers))
    # [generator] __return__ generator of GAN
    generator = Generator(16, opt.upSampling)
    if opt.generatorWeights != '' and os.path.exists(opt.generatorWeights):
        generator.load_state_dict(torch.load(opt.generatorWeights))

    # [discriminator] __return__ discriminator of GAN
    discriminator = Discriminator()
    if opt.discriminatorWeights != '' and os.path.exists(opt.discriminatorWeights):
        discriminator.load_state_dict(torch.load(opt.discriminatorWeights))

    # [extractor] __return__ feature extractor of GAN
    # For the content loss
    feature_extractor = FeatureExtractor(torchvision.models.vgg19(pretrained = True))

    # [loss] __return__ loss function
    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()
    ones_const = Variable(torch.ones(opt.batchSize, 1))

    # [cuda] if gpu is to be used
    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        ones_const = ones_const.cuda()

    # [optimizer] __return__ Optimizer for GAN 
    optim_generator = optim.Adam(generator.parameters(), lr = opt.generatorLR)
    optim_discriminator = optim.Adam(discriminator.parameters(), lr = opt.discriminatorLR)

    # record configure
    configure('logs/{}-{}-{}-{}'.format(opt.dataset, str(opt.batchSize), str(opt.generatorLR), str(opt.discriminatorLR)), flush_secs = 5)
    # visualizer = Visualizer(image_size = (opt.imageSize[0] * opt.upSampling, opt.imageSize[1] * opt.upSampling))

    # __return__ low resolution images
    low_res = torch.FloatTensor(opt.batchSize, 3, opt.imageSize[0], opt.imageSize[1])

    return normalize,\
           scale,\
           dataloader,\
           generator,\
           discriminator,\
           feature_extractor,\
           content_criterion,\
           adversarial_criterion,\
           ones_const,\
           optim_generator,\
           optim_discriminator,\
           low_res
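For reference, a hypothetical caller unpacking init's twelve return values in order (names mirror the function body above):

(normalize, scale, dataloader, generator, discriminator, feature_extractor,
 content_criterion, adversarial_criterion, ones_const,
 optim_generator, optim_discriminator, low_res) = init(opt)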
Example #7
    return np.random.choice(len(policy), 1, p=policy)[0]


def to_tensor(x, dtype=None):
    return torch.tensor(x, dtype=dtype).unsqueeze(0)


if __name__ == '__main__':
    env = gym.make('CartPole-v1')

    # Actor Critic
    actor = Actor(n_actions=env.action_space.n, space_dims=4, hidden_dims=32)
    critic = Critic(space_dims=4, hidden_dims=32)

    # ICM
    feature_extractor = FeatureExtractor(env.observation_space.shape[0], 32)
    forward_model = ForwardModel(env.action_space.n, 32)
    inverse_model = InverseModel(env.action_space.n, 32)

    # Actor Critic
    a_optim = torch.optim.Adam(actor.parameters(), lr=args.lr_actor)
    c_optim = torch.optim.Adam(critic.parameters(), lr=args.lr_critic)

    # ICM
    icm_params = list(feature_extractor.parameters()) + list(
        forward_model.parameters()) + list(inverse_model.parameters())
    icm_optim = torch.optim.Adam(icm_params, lr=args.lr_icm)

    pg_loss = PGLoss()
    mse_loss = nn.MSELoss()
    xe_loss = nn.CrossEntropyLoss()
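Example #7 only constructs the ICM pieces; a hypothetical sketch of how they typically combine in one update (the ForwardModel/InverseModel call signatures and the state/action tensors are assumptions, not code from the source):

phi_s = feature_extractor(to_tensor(state, dtype=torch.float))           # phi(s)
phi_next = feature_extractor(to_tensor(next_state, dtype=torch.float))   # phi(s')

pred_phi_next = forward_model(action_onehot, phi_s)   # predict phi(s')
pred_action = inverse_model(phi_s, phi_next)          # predict a from (s, s')

forward_loss = mse_loss(pred_phi_next, phi_next.detach())  # also the curiosity bonus
inverse_loss = xe_loss(pred_action, action_tensor)
icm_optim.zero_grad()
(forward_loss + inverse_loss).backward()
icm_optim.step()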
Example #8
    def __init__(self, env, gamma, tau, v_lr, q_lr, policy_lr, buffer_maxlen):
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.firsttime = 0

        self.env = env
        self.action_range = [env.action_space.low, env.action_space.high]
        #self.obs_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]  #1

        self.conv_channels = 4
        self.kernel_size = (3, 3)

        self.img_size = (500, 500, 3)

        print("Diagnostics:")
        print(f"action_range: {self.action_range}")
        #print(f"obs_dim: {self.obs_dim}")
        print(f"action_dim: {self.action_dim}")

        # hyperparameters
        self.gamma = gamma
        self.tau = tau
        self.update_step = 0
        self.delay_step = 2

        # initialize networks
        self.feature_net = FeatureExtractor(self.img_size[2],
                                            self.conv_channels,
                                            self.kernel_size).to(self.device)
        print("Feature net init'd successfully")

        input_dim = self.feature_net.get_output_size(self.img_size)
        self.input_size = input_dim[0] * input_dim[1] * input_dim[2]
        print(f"input_size: {self.input_size}")

        self.value_net = ValueNetwork(self.input_size, 1).to(self.device)
        self.target_value_net = ValueNetwork(self.input_size,
                                             1).to(self.device)
        self.q_net1 = SoftQNetwork(self.input_size,
                                   self.action_dim).to(self.device)
        self.q_net2 = SoftQNetwork(self.input_size,
                                   self.action_dim).to(self.device)
        self.policy_net = PolicyNetwork(self.input_size,
                                        self.action_dim).to(self.device)

        print("Finished initing all nets")

        # copy params to target param
        for target_param, param in zip(self.target_value_net.parameters(),
                                       self.value_net.parameters()):
            target_param.data.copy_(param)

        print("Finished copying targets")

        # initialize optimizers
        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=v_lr)
        self.q1_optimizer = optim.Adam(self.q_net1.parameters(), lr=q_lr)
        self.q2_optimizer = optim.Adam(self.q_net2.parameters(), lr=q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(),
                                           lr=policy_lr)

        print("Finished initing optimizers")

        self.replay_buffer = BasicBuffer(buffer_maxlen)
        print("End of init")
Example #9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='folder', help='cifar10 | cifar100 | folder')
    parser.add_argument('--dataroot', type=str, default='./data', help='path to dataset')
    parser.add_argument('--workers', type=int, default=2, help='number of data loading workers')
    parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
    parser.add_argument('--upSampling', type=int, default=4, help='low to high resolution scaling factor')
    parser.add_argument('--cuda', action='store_true', help='enables cuda')
    parser.add_argument('--nGPU', type=int, default=2, help='number of GPUs to use')
    parser.add_argument('--generatorWeights', type=str, default='checkpoints/generator_final.pth', help="path to generator weights (to continue training)")
    parser.add_argument('--discriminatorWeights', type=str, default='checkpoints/discriminator_final.pth', help="path to discriminator weights (to continue training)")

    opt = parser.parse_args()
    print(opt)

    if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    transform = transforms.Compose([transforms.ToTensor()])

    normalize = transforms.Compose([transforms.ToPILImage(),
                                transforms.ToTensor(),
                                transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                                    std = [0.229, 0.224, 0.225])
                                ])

    # Equivalent to un-normalizing ImageNet (for correct visualization)
    unnormalize = transforms.Normalize(mean = [-2.118, -2.036, -1.804], std = [4.367, 4.464, 4.444])
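    # These constants invert the ImageNet normalization: inv_std = 1/std and
    # inv_mean = -mean/std, e.g. 1/0.229 ≈ 4.367 and -0.485/0.229 ≈ -2.118.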

    if opt.dataset == 'folder':
        # folder dataset
        dataset = datasets.ImageFolder(root=opt.dataroot, transform=transform)
    elif opt.dataset == 'cifar10':
        dataset = datasets.CIFAR10(root=opt.dataroot, download=True, train=False, transform=transform)
    elif opt.dataset == 'cifar100':
        dataset = datasets.CIFAR100(root=opt.dataroot, download=True, train=False, transform=transform)
    assert dataset
    
    #print(dataset)
    image_name = dataset.imgs  # image path
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                             shuffle=False, num_workers=int(opt.workers))

    generator = Generator(16, opt.upSampling)
    if opt.generatorWeights != '':
        generator.load_state_dict(torch.load(opt.generatorWeights))
    print(generator)

    discriminator = Discriminator()
    if opt.discriminatorWeights != '':
        discriminator.load_state_dict(torch.load(opt.discriminatorWeights))
    print(discriminator)

    # For the content loss
    feature_extractor = FeatureExtractor(torchvision.models.vgg19(pretrained=True))
    print(feature_extractor)

    # if gpu is to be used
    if opt.cuda:
        #generator.cuda()
        #discriminator.cuda()
        #feature_extractor.cuda()
        gpu_ids = [0,2] 
        torch.cuda.set_device(gpu_ids[0])
        generator = torch.nn.DataParallel(generator, device_ids=gpu_ids).cuda()
        discriminator = torch.nn.DataParallel(discriminator, device_ids=gpu_ids).cuda()
        feature_extractor = torch.nn.DataParallel(feature_extractor, device_ids=gpu_ids).cuda()

    print('Test started...')

    # Set evaluation mode (not training)
    generator.eval()
    discriminator.eval()
    #print(len(dataloader))
    for i, data in enumerate(dataloader):
        # Generate data
        low_res, _ = data
        #print(low_res.shape)
        #print(image_name[i])
        # eg: image_type_path, image_detail_name = bounding_box_test , -1_c1s3_065901_04.jpg
        image_type_path, image_detail_name = [],[]
        for j in range(len(low_res)):  # len(low_res), not opt.batchSize, so the final (smaller) batch is kept
            # .replace() normalizes Windows path separators
            image_type_path.append(image_name[i*opt.batchSize+j][0].replace('\\','/').split('/')[-2])
            image_detail_name.append(image_name[i*opt.batchSize+j][0].replace('\\','/').split('/')[-1])
        #print(len(image_type_path), len(image_detail_name))
        for j in range(len(low_res)):  # never skip final batch
            low_res[j] = normalize(low_res[j])

        # Generate real and fake inputs
        if opt.cuda:
            high_res_fake = generator(Variable(low_res).cuda())
        else:
            high_res_fake = generator(Variable(low_res))
        
        # high_res_fake = high_res_fake.to(torch.device('cuda:0'))
        for j in range(len(low_res)):  # len(low_res), not opt.batchSize, so the final (smaller) batch is kept
            print(image_type_path[j], image_detail_name[j])
            if not os.path.exists('output/high_res_fake/A/{}'.format(image_type_path[j])):
                os.makedirs('output/high_res_fake/A/{}'.format(image_type_path[j]))
            #print(high_res_fake[j])
            #print(low_res[j])
            # unnormalize used to fail here: the error claimed `high_res_real[j]` was
            # not a CUDA tensor even though printing showed it was; moving the tensor
            # to CPU first (see .cpu() below) avoids the issue.

            # Comment out block 1 and uncomment block 2 when processing the full Market
            # high-res dataset; uncomment block 1 and comment out block 2 when testing
            # only a few images.
            # 1. when testing only a few images
            ##################################################################################################
            #save_image(unnormalize(high_res_fake[j].cpu()), 'output/high_res_fake/' + str(i * opt.batchSize + j) + '.png')  
            #save_image(unnormalize(low_res[j]), 'output/low_res/' + str(i*opt.batchSize + j) + '.png')  # save raw low_res images
            ##################################################################################################

            # 2. when processing the full dataset
            ##################################################################################################
            save_image(unnormalize(high_res_fake[j].cpu()), 'output/high_res_fake/A/{}/{}'.format(image_type_path[j], image_detail_name[j]))
Example #10
def upsampling(path, picture_name, upsampling):
    opt = setup()
    # image = Image.open(os.getcwd() + r'\images\\' + path)
    image = Image.open(path)
    opt.imageSize = (image.size[1], image.size[0])

    log = '>>> process image : {} size : ({}, {}) sr_reconstruct size : ({}, {})'.format(
        picture_name, image.size[0], image.size[1], image.size[0] * upsampling,
        image.size[1] * upsampling)
    try:
        os.makedirs(os.getcwd() + r'\output\result')
    except OSError:
        pass

    if torch.cuda.is_available() and not opt.cuda:
        print(
            '[WARNING] : You have a CUDA device, so you should probably run with --cuda'
        )

    transform = transforms.Compose([
        transforms.RandomCrop(opt.imageSize),
        transforms.Pad(padding=0),
        transforms.ToTensor()
    ])
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # Equivalent to un-normalizing ImageNet (for correct visualization)
    unnormalize = transforms.Normalize(mean=[-2.118, -2.036, -1.804],
                                       std=[4.367, 4.464, 4.444])
    scale = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(opt.imageSize),
        transforms.Pad(padding=0),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    if opt.dataset == 'folder':
        # folder dataset
        dataset = datasets.ImageFolder(root=opt.dataroot, transform=transform)
    elif opt.dataset == 'cifar10':
        dataset = datasets.CIFAR10(root=opt.dataroot,
                                   download=True,
                                   train=False,
                                   transform=transform)
    elif opt.dataset == 'cifar100':
        dataset = datasets.CIFAR100(root=opt.dataroot,
                                    download=True,
                                    train=False,
                                    transform=transform)
    assert dataset

    # Not a DataLoader: convert the single input image to a tensor
    to_tensor = transforms.Compose([transforms.ToTensor()])
    image = to_tensor(image)

    # load network weights from the checkpoint files
    generator = Generator(16, opt.upSampling)
    if opt.generatorWeights != '':
        generator.load_state_dict(torch.load(opt.generatorWeights))

    discriminator = Discriminator()
    if opt.discriminatorWeights != '':
        discriminator.load_state_dict(torch.load(opt.discriminatorWeights))

    # For the content loss
    feature_extractor = FeatureExtractor(
        torchvision.models.vgg19(pretrained=True))

    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()

    target_real = Variable(torch.ones(opt.batchSize, 1))
    target_fake = Variable(torch.zeros(opt.batchSize, 1))

    # if gpu is to be used
    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        target_real = target_real.cuda()
        target_fake = target_fake.cuda()

    low_res = torch.FloatTensor(opt.batchSize, 3, opt.imageSize[0],
                                opt.imageSize[1])

    # Set evaluation mode (not training)
    generator.eval()
    discriminator.eval()

    # Generate data
    high_res_real = image

    # Downsample images to low resolution
    low_res = scale(high_res_real)
    low_res = torch.tensor([np.array(low_res)])

    high_res_real = normalize(high_res_real)
    high_res_real = torch.tensor([np.array(high_res_real)])

    # Generate real and fake inputs
    if opt.cuda:
        high_res_real = Variable(high_res_real.cuda())
        high_res_fake = generator(Variable(low_res).cuda())
    else:
        high_res_real = Variable(high_res_real)
        high_res_fake = generator(Variable(low_res))

    save_image(unnormalize(high_res_fake[0]),
               './output/result/' + picture_name)
    return log
Example #11
    # Init

    os.makedirs("saved_models", exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    writer = SummaryWriter()

    # Get models

    hr_shape = (opt.hr_height, opt.hr_width)
    generator = Generator(filters=64,
                          num_res_blocks=opt.residual_blocks,
                          num_upsample=opt.num_upsample) \
        .to(device).train()
    discriminator = Discriminator() \
        .to(device).train()
    feature_extractor = FeatureExtractor() \
        .to(device).eval()

    if opt.netG_checkpoint:
        try:
            generator.load_state_dict(
                torch.load(opt.netG_checkpoint, map_location="cpu"))
            print(
                f"[x] Restored generator weights from: {opt.netG_checkpoint}")
        except Exception:
            print("[!] Generator weights from scratch.")
    if opt.netD_checkpoint:
        try:
            discriminator.load_state_dict(
                torch.load(opt.netD_checkpoint, map_location="cpu"))
            print(
                f"[x] Restored discriminator weights from: {opt.netD_checkpoint}")
        except Exception:
            print("[!] Discriminator weights from scratch.")
Example #12
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Fetch Data
    dataset = datasets.ImageFolder(root="./data", transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batchSize,
                                             shuffle=True,
                                             num_workers=int(opt.workers))

    #FIXME devicedataloader -> do we need it?

    generator = Generator(opt.resBlocks, opt.upSampling)
    discriminator = Discriminator()
    feature_extractor = FeatureExtractor(
        torchvision.models.vgg19(pretrained=True))

    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()

    generator = nn.DataParallel(generator)
    generator.to(device)

    discriminator = nn.DataParallel(discriminator)
    discriminator.to(device)

    #feature_extractor = nn.DataParallel(feature_extractor)
    feature_extractor.to(device)

    #content_criterion = nn.DataParallel(content_criterion)
    content_criterion.to(device)
Example #13
class ESRGAN:
    def __init__(self, opt):
        self.opt = opt
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        hr_shape = (self.opt.hr_height, self.opt.hr_width)
        self._set_model(device, hr_shape)

    def _set_model(self, device, hr_shape):
        # Initialize generator and discriminator
        self.generator = GeneratorRRDB(
            self.opt.channels, filters=64,
            num_res_blocks=self.opt.residual_blocks).to(device)
        self.discriminator = Discriminator(input_shape=(self.opt.channels,
                                                        *hr_shape)).to(device)
        self.feature_extractor = FeatureExtractor().to(device)

        # Set feature extractor to inference mode
        self.feature_extractor.eval()

        # Losses
        self.criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
        self.criterion_content = torch.nn.L1Loss().to(device)
        self.criterion_pixel = torch.nn.L1Loss().to(device)

    def _set_param(self):
        for key, value in vars(self.opt).items():
            mlflow.log_param(key, value)

    def _load_weight(self):
        if self.opt.epoch != 0:
            # Load pretrained models
            load_g_weight_path = osp.join(weight_save_dir,
                                          "generator_%d.pth" % self.opt.epoch)
            load_d_weight_path = osp.join(weight_save_dir,
                                          "discriminator_%d.pth" % self.opt.epoch)

            self.generator.load_state_dict(torch.load(load_g_weight_path))
            self.discriminator.load_state_dict(torch.load(load_d_weight_path))

        # Optimizers
        self.optimizer_G = torch.optim.Adam(self.generator.parameters(),
                                            lr=self.opt.lr,
                                            betas=(self.opt.b1, self.opt.b2))
        self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(),
                                            lr=self.opt.lr,
                                            betas=(self.opt.b1, self.opt.b2))

    # ----------
    #  Training
    # ----------
    def train(self, dataloader, opt):
        for epoch in range(opt.epoch + 1, opt.n_epochs + 1):
            for batch_num, imgs in enumerate(dataloader):
                Tensor = torch.cuda.FloatTensor if torch.cuda.is_available(
                ) else torch.Tensor
                batches_done = (epoch - 1) * len(dataloader) + batch_num

                # Configure model input
                imgs_lr = Variable(imgs["lr"].type(Tensor))
                imgs_hr = Variable(imgs["hr"].type(Tensor))

                # Adversarial ground truths
                valid = Variable(Tensor(
                    np.ones((imgs_lr.size(0), *self.discriminator.output_shape))),
                                 requires_grad=False)
                fake = Variable(Tensor(
                    np.zeros((imgs_lr.size(0), *self.discriminator.output_shape))),
                                requires_grad=False)

                # ------------------
                #  Train Generators
                # ------------------

                self.optimizer_G.zero_grad()

                # Generate a high resolution image from low resolution input
                gen_hr = self.generator(imgs_lr)

                # Measure pixel-wise loss against ground truth
                loss_pixel = self.criterion_pixel(gen_hr, imgs_hr)

                # Warm-up (pixel-wise loss only)
                if batches_done <= opt.warmup_batches:
                    loss_pixel.backward()
                    self.optimizer_G.step()
                    log_info = "[Epoch {}/{}] [Batch {}/{}] [G pixel: {}]".format(
                        epoch, opt.n_epochs, batch_num, len(dataloader),
                        loss_pixel.item())

                    sys.stdout.write("\r{}".format(log_info))
                    sys.stdout.flush()

                    mlflow.log_metric('train_{}'.format('loss_pixel'),
                                      loss_pixel.item(),
                                      step=batches_done)
                else:
                    # Extract validity predictions from discriminator
                    pred_real = self.discriminator(imgs_hr).detach()
                    pred_fake = self.discriminator(gen_hr)

                    # Adversarial loss (relativistic average GAN)
                    loss_GAN = self.criterion_GAN(
                        pred_fake - pred_real.mean(0, keepdim=True), valid)

                    # Content loss
                    gen_features = self.feature_extractor(gen_hr)
                    real_features = self.feature_extractor(imgs_hr).detach()
                    loss_content = self.criterion_content(gen_features,
                                                          real_features)

                    # Total generator loss
                    loss_G = (loss_content + opt.lambda_adv * loss_GAN +
                              opt.lambda_pixel * loss_pixel)

                    loss_G.backward()
                    self.optimizer_G.step()

                    # ---------------------
                    #  Train Discriminator
                    # ---------------------

                    self.optimizer_D.zero_grad()

                    pred_real = self.discriminator(imgs_hr)
                    pred_fake = self.discriminator(gen_hr.detach())

                    # Adversarial loss for real and fake images (relativistic average GAN)
                    loss_real = self.criterion_GAN(
                        pred_real - pred_fake.mean(0, keepdim=True), valid)
                    loss_fake = self.criterion_GAN(
                        pred_fake - pred_real.mean(0, keepdim=True), fake)

                    # Total loss
                    loss_D = (loss_real + loss_fake) / 2

                    loss_D.backward()
                    self.optimizer_D.step()

                    # --------------
                    #  Log Progress
                    # --------------

                    log_info = "[Epoch {}/{}] [Batch {}/{}] [D loss: {}] [G loss: {}, content: {}, adv: {}, pixel: {}]".format(
                        epoch,
                        opt.n_epochs,
                        batch_num,
                        len(dataloader),
                        loss_D.item(),
                        loss_G.item(),
                        loss_content.item(),
                        loss_GAN.item(),
                        loss_pixel.item(),
                    )

                    if batch_num == 1:
                        sys.stdout.write("\n{}".format(log_info))
                    else:
                        sys.stdout.write("\r{}".format(log_info))

                    sys.stdout.flush()

                    # import pdb; pdb.set_trace()

                    if batches_done % opt.sample_interval == 0:
                        # Save image grid with upsampled inputs and ESRGAN outputs
                        imgs_lr = nn.functional.interpolate(imgs_lr,
                                                            scale_factor=4)
                        img_grid = denormalize(torch.cat((imgs_lr, gen_hr),
                                                         -1))

                        image_batch_save_dir = osp.join(
                            image_train_save_dir, '{:07}'.format(batches_done))
                        os.makedirs(osp.join(image_batch_save_dir, "hr_image"),
                                    exist_ok=True)
                        save_image(img_grid,
                                   osp.join(image_batch_save_dir, "hr_image",
                                            "%d.png" % batches_done),
                                   nrow=1,
                                   normalize=False)

                    if batches_done % opt.checkpoint_interval == 0:
                        # Save model checkpoints
                        torch.save(
                            self.generator.state_dict(),
                            osp.join(weight_save_dir,
                                     "generator_%d.pth" % epoch))
                        torch.save(
                            self.discriminator.state_dict(),
                            osp.join(weight_save_dir,
                                     "discriminator_%d.pth" % epoch))

                    mlflow.log_metric('train_{}'.format('loss_D'),
                                      loss_D.item(),
                                      step=batches_done)
                    mlflow.log_metric('train_{}'.format('loss_G'),
                                      loss_G.item(),
                                      step=batches_done)
                    mlflow.log_metric('train_{}'.format('loss_content'),
                                      loss_content.item(),
                                      step=batches_done)
                    mlflow.log_metric('train_{}'.format('loss_GAN'),
                                      loss_GAN.item(),
                                      step=batches_done)
                    mlflow.log_metric('train_{}'.format('loss_pixel'),
                                      loss_pixel.item(),
                                      step=batches_done)
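Example #13 calls a denormalize helper at the image-grid save step that is not shown. A plausible definition, assuming it simply inverts the ImageNet normalization used throughout these examples:

import torch

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def denormalize(tensors):
    # Undo channel-wise normalization in place, then clamp into display range
    for c in range(3):
        tensors[:, c].mul_(std[c]).add_(mean[c])
    return torch.clamp(tensors, 0, 1)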
Example #14
def down_and_up_sampling(image, save_name, upsampling):
    
    opt = setup()
    # create output folder
    try:
        os.makedirs('output/high_res_fake')
        os.makedirs('output/high_res_real')
        os.makedirs('output/low_res')
    except OSError:
        pass

    if torch.cuda.is_available() and not opt.cuda:
        print('[WARNING]: You have a CUDA device, so you should probably run with --cuda')

    transform = transforms.Compose([transforms.RandomCrop((
                                                                image.size[0],
                                                                image.size[1])),
                                    transforms.Pad(padding = 0),
                                    transforms.ToTensor()])
    normalize = transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                     std = [0.229, 0.224, 0.225])

    # [down sampling] down-sampling part
    scale = transforms.Compose([transforms.ToPILImage(),
                                transforms.Resize((int(image.size[1] / opt.upSampling), int(image.size[0] / opt.upSampling))),
                                transforms.Pad(padding=0),
                                transforms.ToTensor(),
                                transforms.Normalize(
                                                        mean = [0.485, 0.456, 0.406],
                                                        std = [0.229, 0.224, 0.225])])
    
    # Equivalent to un-normalizing ImageNet (for correct visualization)
    unnormalize = transforms.Normalize(
                                            mean = [-2.118, -2.036, -1.804],
                                            std = [4.367, 4.464, 4.444])

    if opt.dataset == 'folder':
        # folder dataset
        dataset = datasets.ImageFolder(root = opt.dataroot, transform = transform)
    elif opt.dataset == 'cifar10':
        dataset = datasets.CIFAR10(root = opt.dataroot, download = True, train = False, transform = transform)
    elif opt.dataset == 'cifar100':
        dataset = datasets.CIFAR100(root = opt.dataroot, download = True, train = False, transform = transform)
    assert dataset

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size = opt.batchSize,
                                             shuffle = False,
                                             num_workers = int(opt.workers))

    my_loader = transforms.Compose([transforms.ToTensor()])
    image = my_loader(image)

    # [paras] loading paras from .pth files
    generator = Generator(16, opt.upSampling)
    if opt.generatorWeights != '':
        generator.load_state_dict(torch.load(opt.generatorWeights))

    discriminator = Discriminator()
    if opt.discriminatorWeights != '':
        discriminator.load_state_dict(torch.load(opt.discriminatorWeights))

    # For the content loss
    feature_extractor = FeatureExtractor(torchvision.models.vgg19(pretrained = True))

    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()

    target_real = Variable(torch.ones(opt.batchSize, 1))
    target_fake = Variable(torch.zeros(opt.batchSize, 1))

    # if gpu is to be used
    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        target_real = target_real.cuda()
        target_fake = target_fake.cuda()

    low_res = torch.FloatTensor(opt.batchSize, 3, opt.imageSize[0], opt.imageSize[1])

    # print('Test started...')
    mean_generator_content_loss = 0.0
    mean_generator_adversarial_loss = 0.0
    mean_generator_total_loss = 0.0
    mean_discriminator_loss = 0.0

    # Set evaluation mode (not training)
    generator.eval()
    discriminator.eval()

    data = image
    for i in range(1):
        # Generate data
        high_res_real = data
        low_res = scale(high_res_real)
        low_res = torch.tensor([np.array(low_res)])
        high_res_real = normalize(high_res_real)
        high_res_real = torch.tensor([np.array(high_res_real)])
            
        # Generate real and fake inputs
        if opt.cuda:
            high_res_real = Variable(high_res_real.cuda())
            high_res_fake = generator(Variable(low_res).cuda())
        else:
            high_res_real = Variable(high_res_real)
            high_res_fake = generator(Variable(low_res)) # >>> create hr images

        save_image(unnormalize(high_res_real[0]), 'output/high_res_real/' + save_name)
        save_image(unnormalize(high_res_fake[0]), 'output/high_res_fake/' + save_name)
        save_image(unnormalize(low_res[0]), 'output/low_res/' + save_name)
Example #15
                    help='Pass 1 to load checkpoint')
parser.add_argument('--b',
                    default=16,
                    type=int,
                    help='number of residual blocks in generator')
args = parser.parse_args()

# Load data
dataset = TrainDataset(args.root_dir)
dataloader = DataLoader(dataset,
                        batch_size=args.batch_size,
                        shuffle=True,
                        num_workers=args.num_workers)
# Initialize models
vgg = models.vgg19(pretrained=True)
feature_extractor = FeatureExtractor(vgg, 5, 4)
if torch.cuda.device_count() > 1:
    feature_extractor = nn.DataParallel(feature_extractor)
feature_extractor = feature_extractor.to(device)

disc = Discriminator()
if torch.cuda.device_count() > 1:
    disc = nn.DataParallel(disc)
disc = disc.to(device)
if args.load_checkpoint == 1 and os.path.exists('disc.pt'):
    disc.load_state_dict(torch.load('disc.pt'))
print(disc)

gen = Generator(args.b)
if torch.cuda.device_count() > 1:
    gen = nn.DataParallel(gen)
Example #16
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

source_dataset = ImageFolder(args.source, transform=transform_source)
target_dataset = ImageFolder(args.target, transform=transform_target)

source_loader = DataLoader(source_dataset,
                           batch_size=args.batch_size,
                           shuffle=True)
target_loader = DataLoader(target_dataset,
                           batch_size=args.batch_size,
                           shuffle=True)

# models
F = FeatureExtractor(resnet=args.resnet_type).to(device)
C = LabelPredictor(resnet=args.resnet_type).to(device)
D = DomainClassifier(resnet=args.resnet_type).to(device)

class_criterion = nn.CrossEntropyLoss()
domain_criterion = nn.BCEWithLogitsLoss()

opt_F = optim.AdamW(F.parameters())
opt_C = optim.AdamW(C.parameters())
opt_D = optim.AdamW(D.parameters())

# train
F.train()
D.train()
C.train()
lamb, p, gamma, now, tot = 0, 0, 10, 0, len(source_loader) * args.n_epoch
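The last line above initializes the bookkeeping for the DANN adaptation-factor schedule. A sketch of how these variables are typically advanced inside the training loop (an assumption; the loop body is not part of the snippet):

import numpy as np

p = now / tot                                  # training progress in [0, 1]
lamb = 2.0 / (1.0 + np.exp(-gamma * p)) - 1.0  # ramps smoothly from 0 toward 1
now += 1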
Example #17
seed = random.randint(1, 10000)
print("Random Seed: ", seed)
torch.manual_seed(seed)
if opt.cuda:
    torch.cuda.manual_seed(seed)

# build network
print('==>building network...')
generator = Generator(in_nc=opt.in_nc,
                      mid_nc=opt.mid_nc,
                      out_nc=opt.out_nc,
                      scale_factor=opt.scale_factor,
                      num_RRDBS=opt.num_RRDBs)
discriminator = Discriminator()
feature_extractor = FeatureExtractor()

# loss

# content loss
if opt.content_loss_type == 'L1_Charbonnier':
    content_loss = L1_Charbonnier_loss()
elif opt.content_loss_type == 'L1':
    content_loss = torch.nn.L1Loss()
elif opt.content_loss_type == 'L2':
    content_loss = torch.nn.MSELoss()

# pixel loss
if opt.pixel_loss_type == 'L1':
    pixel_loss = torch.nn.L1Loss()
elif opt.pixel_loss_type == 'L2':
    pixel_loss = torch.nn.MSELoss()
Example #18
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='cifar100', help='cifar10 | cifar100 | folder')
    parser.add_argument('--dataroot', type=str, default='./data', help='path to dataset')
    parser.add_argument('--workers', type=int, default=2, help='number of data loading workers')
    parser.add_argument('--batchSize', type=int, default=16, help='input batch size')
    parser.add_argument('--imageSize', type=int, default=15, help='the low resolution image size')
    parser.add_argument('--upSampling', type=int, default=2, help='low to high resolution scaling factor')
    parser.add_argument('--nEpochs', type=int, default=100, help='number of epochs to train for')
    parser.add_argument('--nPreEpochs', type=int, default=2, help='number of epochs to pre-train Generator')
    parser.add_argument('--generatorLR', type=float, default=0.0001, help='learning rate for generator')
    parser.add_argument('--discriminatorLR', type=float, default=0.0001, help='learning rate for discriminator')
    parser.add_argument('--cuda', action='store_true', help='enables cuda')
    parser.add_argument('--nGPU', type=int, default=1, help='number of GPUs to use')
    parser.add_argument('--generatorWeights', type=str, default='', help="path to generator weights (to continue training)")
    parser.add_argument('--discriminatorWeights', type=str, default='', help="path to discriminator weights (to continue training)")
    parser.add_argument('--out', type=str, default='checkpoints', help='folder to output model checkpoints')

    opt = parser.parse_args()
    print(opt)

    try:
        os.makedirs(opt.out)
    except OSError:
        pass

    if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    transform = transforms.Compose([transforms.RandomCrop(opt.imageSize*opt.upSampling),
                                    transforms.ToTensor()])

    normalize = transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                    std = [0.229, 0.224, 0.225])

    scale = transforms.Compose([transforms.ToPILImage(),
                                transforms.Resize(opt.imageSize),  # Scale is the old torchvision name for Resize
                                transforms.ToTensor(),
                                transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                                    std = [0.229, 0.224, 0.225])
                                ])

    if opt.dataset == 'folder':
        # folder dataset
        dataset = datasets.ImageFolder(root=opt.dataroot, transform=transform)
    elif opt.dataset == 'cifar10':
        dataset = datasets.CIFAR10(root=opt.dataroot, train=True, download=True, transform=transform)
    elif opt.dataset == 'cifar100':
        dataset = datasets.CIFAR100(root=opt.dataroot, train=True, download=True, transform=transform)
    assert dataset

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                             shuffle=True, num_workers=int(opt.workers))

    generator = Generator(16, opt.upSampling)
    if opt.generatorWeights != '':
        generator.load_state_dict(torch.load(opt.generatorWeights))
    print(generator)

    discriminator = Discriminator()
    if opt.discriminatorWeights != '':
        discriminator.load_state_dict(torch.load(opt.discriminatorWeights))
    print(discriminator)

    # For the content loss
    feature_extractor = FeatureExtractor(torchvision.models.vgg19(pretrained=True))
    print(feature_extractor)
    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()

    ones_const = Variable(torch.ones(opt.batchSize, 1))

    # if gpu is to be used
    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        ones_const = ones_const.cuda()

    optim_generator = optim.Adam(generator.parameters(), lr=opt.generatorLR)
    optim_discriminator = optim.Adam(discriminator.parameters(), lr=opt.discriminatorLR)

    configure('logs/' + opt.dataset + '-' + str(opt.batchSize) + '-' + str(opt.generatorLR) + '-' + str(opt.discriminatorLR), flush_secs=5)
    visualizer = Visualizer(image_size=opt.imageSize*opt.upSampling)

    low_res = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)

    # Pre-train generator using raw MSE loss
    print('Generator pre-training')
    for epoch in range(opt.nPreEpochs):
        mean_generator_content_loss = 0.0

        for i, data in enumerate(dataloader):
            # Generate data
            high_res_real, _ = data

            # Downsample images to low resolution
            if len(high_res_real) < opt.batchSize:  # skip the final, smaller batch (len < batchSize only on the last batch)
                continue
            for j in range(opt.batchSize):  
                low_res[j] = scale(high_res_real[j])
                high_res_real[j] = normalize(high_res_real[j])

            # Generate real and fake inputs
            if opt.cuda:
                high_res_real = Variable(high_res_real.cuda())
                high_res_fake = generator(Variable(low_res).cuda())
            else:
                high_res_real = Variable(high_res_real)
                high_res_fake = generator(Variable(low_res))

            ######### Train generator #########
            generator.zero_grad()

            generator_content_loss = content_criterion(high_res_fake, high_res_real)

            mean_generator_content_loss += generator_content_loss.data

            generator_content_loss.backward()
            optim_generator.step()

            ######### Status and display #########
            sys.stdout.write('\r[%d/%d][%d/%d] Generator_MSE_Loss: %.4f' % (epoch, opt.nPreEpochs, i, len(dataloader), generator_content_loss.data))
            visualizer.show(low_res, high_res_real.cpu().data, high_res_fake.cpu().data)

        sys.stdout.write('\r[%d/%d][%d/%d] Generator_MSE_Loss: %.4f\n' % (epoch, opt.nPreEpochs, i, len(dataloader), mean_generator_content_loss/len(dataloader)))
        log_value('generator_mse_loss', mean_generator_content_loss/len(dataloader), epoch)
        
        # Do checkpointing every epoch
        # torch.save(generator.state_dict(), '%s/generator_pretrain_%s.pth' %(opt.out,str(epoch)))

    # Do checkpointing
    torch.save(generator.state_dict(), '%s/generator_pretrain.pth' % opt.out)

    # SRGAN training
    optim_generator = optim.Adam(generator.parameters(), lr=opt.generatorLR*0.1)
    optim_discriminator = optim.Adam(discriminator.parameters(), lr=opt.discriminatorLR*0.1)

    print('SRGAN training')
    for epoch in range(opt.nEpochs):
        mean_generator_content_loss = 0.0
        mean_generator_adversarial_loss = 0.0
        mean_generator_total_loss = 0.0
        mean_discriminator_loss = 0.0

        for i, data in enumerate(dataloader):
            # Generate data
            high_res_real, _ = data

            # Downsample images to low resolution
            if len(high_res_real) < opt.batchSize:  # skip the final, smaller batch (len < batchSize only on the last batch)
                continue
            for j in range(opt.batchSize): 
                low_res[j] = scale(high_res_real[j])
                high_res_real[j] = normalize(high_res_real[j])

            # Generate real and fake inputs
            if opt.cuda:
                high_res_real = Variable(high_res_real.cuda())
                high_res_fake = generator(Variable(low_res).cuda())
                target_real = Variable(torch.rand(opt.batchSize,1)*0.5 + 0.7).cuda()
                # size: opt.batchSize*1, and element is in 0.7~1.2
                target_fake = Variable(torch.rand(opt.batchSize,1)*0.3).cuda()
                # size: opt.batchSize*1, and element is in 0~0.3
            else:
                high_res_real = Variable(high_res_real)
                high_res_fake = generator(Variable(low_res))
                target_real = Variable(torch.rand(opt.batchSize,1)*0.5 + 0.7)
                target_fake = Variable(torch.rand(opt.batchSize,1)*0.3)

            ######### Train discriminator #########
            discriminator.zero_grad()

            discriminator_loss = adversarial_criterion(discriminator(high_res_real), target_real) + \
                                 adversarial_criterion(discriminator(Variable(high_res_fake.data)), target_fake)
            mean_discriminator_loss += discriminator_loss.data

            discriminator_loss.backward()
            optim_discriminator.step()

            ######### Train generator #########
            generator.zero_grad()

            real_features = Variable(feature_extractor(high_res_real).data)
            fake_features = feature_extractor(high_res_fake)

            # Content loss = pixel-wise MSE over the full images plus 0.006 * VGG loss,
            # where the VGG loss is the MSE between feature maps from selected VGG layers
            generator_content_loss = content_criterion(high_res_fake, high_res_real) + 0.006*content_criterion(fake_features, real_features)
            mean_generator_content_loss += generator_content_loss.data
            generator_adversarial_loss = adversarial_criterion(discriminator(high_res_fake), ones_const)
            mean_generator_adversarial_loss += generator_adversarial_loss.data

            generator_total_loss = generator_content_loss + 1e-3*generator_adversarial_loss
            mean_generator_total_loss += generator_total_loss.data

            generator_total_loss.backward()
            optim_generator.step()

            ######### Status and display #########
            sys.stdout.write('\r[%d/%d][%d/%d] Discriminator_Loss: %.4f Generator_Loss (Content/Advers/Total): %.4f/%.4f/%.4f' % (epoch, opt.nEpochs, i, len(dataloader),
                discriminator_loss.data, generator_content_loss.data, generator_adversarial_loss.data, generator_total_loss.data))
            visualizer.show(low_res, high_res_real.cpu().data, high_res_fake.cpu().data)

        sys.stdout.write('\r[%d/%d][%d/%d] Discriminator_Loss: %.4f Generator_Loss (Content/Advers/Total): %.4f/%.4f/%.4f\n' % (epoch, opt.nEpochs, i, len(dataloader),
        mean_discriminator_loss/len(dataloader), mean_generator_content_loss/len(dataloader),
        mean_generator_adversarial_loss/len(dataloader), mean_generator_total_loss/len(dataloader)))

        log_value('generator_content_loss', mean_generator_content_loss/len(dataloader), epoch)
        log_value('generator_adversarial_loss', mean_generator_adversarial_loss/len(dataloader), epoch)
        log_value('generator_total_loss', mean_generator_total_loss/len(dataloader), epoch)
        log_value('discriminator_loss', mean_discriminator_loss/len(dataloader), epoch)

        # Do checkpointing every epoch
        torch.save(generator.state_dict(), '%s/generator_final.pth' % opt.out)
        torch.save(discriminator.state_dict(), '%s/discriminator_final.pth' % opt.out)

    # Avoid closing
    print("Training is over; kill the process once you have reviewed the logs...")
    while True:
        pass  # busy-wait keeps the process (and its logging threads) alive
Example #19
def main(args):
    #with torch.cuda.device(args.gpu):
    layers_map = {
        'relu4_2': '22',
        'relu2_2': '8',
        'relu3_2': '13',
        'relu1_2': '4'
    }

    vis = visdom.Visdom(port=args.display_port)

    loss_graph = {
        "g": [],
        "gd": [],
        "gf": [],
        "gpl": [],
        "gpab": [],
        "gs": [],
        "d": [],
        "gdl": [],
        "dl": [],
    }

    # For RGB, the change is to feed 3 channels to D (instead of just 1) and 3
    # channels to VGG. The pixel losses can stay split (l vs. ab) for now; we
    # assume the user passes the same weight for both.
    transforms = get_transforms(args)

    if args.color_space == 'rgb':
        args.pixel_weight_ab = args.pixel_weight_rgb
        args.pixel_weight_l = args.pixel_weight_rgb

    rgbify = custom_transforms.toRGB()

    train_dataset = ImageFolder('train', args.data_path, transforms)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)

    val_dataset = ImageFolder('val', args.data_path, transforms)
    indices = torch.randperm(len(val_dataset))
    val_display_size = args.batch_size
    val_display_sampler = SequentialSampler(indices[:val_display_size])
    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=val_display_size,
                            sampler=val_display_sampler)
    # renormalize = transforms.Normalize(mean=[+0.5+0.485, +0.5+0.456, +0.5+0.406], std=[0.229, 0.224, 0.225])

    feat_model = models.vgg19(pretrained=True)
    netG, netD, netD_local = get_models(args)

    criterion_gan, criterion_pixel_l, criterion_pixel_ab, criterion_style, criterion_feat, criterion_texturegan = get_criterions(
        args)

    real_label = 1
    fake_label = 0

    optimizerD = optim.Adam(netD.parameters(),
                            lr=args.learning_rate_D,
                            betas=(0.5, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=args.learning_rate,
                            betas=(0.5, 0.999))
    optimizerD_local = optim.Adam(netD_local.parameters(),
                                  lr=args.learning_rate_D_local,
                                  betas=(0.5, 0.999))

    with torch.cuda.device(args.gpu):
        netG.cuda()
        netD.cuda()
        netD_local.cuda()
        feat_model.cuda()
        criterion_gan.cuda()
        criterion_pixel_l.cuda()
        criterion_pixel_ab.cuda()
        criterion_feat.cuda()
        criterion_texturegan.cuda()

        input_stack = torch.FloatTensor().cuda()
        target_img = torch.FloatTensor().cuda()
        target_texture = torch.FloatTensor().cuda()
        segment = torch.FloatTensor().cuda()
        label = torch.FloatTensor(args.batch_size).cuda()
        label_local = torch.FloatTensor(args.batch_size).cuda()
        extract_content = FeatureExtractor(feat_model.features,
                                           [layers_map[args.content_layers]])
        extract_style = FeatureExtractor(
            feat_model.features,
            [layers_map[x.strip()] for x in args.style_layers.split(',')])

        model = {
            "netG": netG,
            "netD": netD,
            "netD_local": netD_local,
            "criterion_gan": criterion_gan,
            "criterion_pixel_l": criterion_pixel_l,
            "criterion_pixel_ab": criterion_pixel_ab,
            "criterion_feat": criterion_feat,
            "criterion_style": criterion_style,
            "criterion_texturegan": criterion_texturegan,
            "real_label": real_label,
            "fake_label": fake_label,
            "optimizerD": optimizerD,
            "optimizerD_local": optimizerD_local,
            "optimizerG": optimizerG
        }

        for epoch in range(args.load_epoch, args.num_epoch):
            train(model, train_loader, val_loader, input_stack, target_img,
                  target_texture, segment, label, label_local, extract_content,
                  extract_style, loss_graph, vis, epoch, args)
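Examples #19 and #20 use a FeatureExtractor with a different signature: it wraps a model's feature stack plus a list of layer indices and returns the intermediate activations. A minimal sketch of that multi-layer variant (assumed implementation):

import torch.nn as nn

class FeatureExtractor(nn.Module):
    def __init__(self, features, layer_ids):
        super().__init__()
        self.features = features          # e.g. vgg19.features
        self.layer_ids = set(layer_ids)   # e.g. {'22'} for relu4_2

    def forward(self, x):
        outputs = []
        # vgg19.features names its submodules '0', '1', ...; collect the requested ones
        for name, layer in self.features._modules.items():
            x = layer(x)
            if name in self.layer_ids:
                outputs.append(x)
        return outputs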
Example #20
        model_D = network.discriminator_snIns().cuda()
        #model_local_D = SNnetwork.Discriminator(3, 64).cuda()
        model_local_D = network.discriminator_snIns().cuda()
elif network_type == 'nlayerD':
    model_G = network.generator().cuda()
    model_D = network.NLayerDiscriminator(input_nc=3, ndf=64,
                                          n_layers=3).cuda()
    model_local_D = network.NLayerDiscriminator(input_nc=3, ndf=64,
                                                n_layers=3).cuda()
else:
    model_G = network.generator().cuda()
    model_D = network.discriminator().cuda()
    model_local_D = network.discriminator().cuda()
feat_model = tmodels.vgg19(pretrained=True).cuda()
extract_content = FeatureExtractor(
    feat_model.features,
    [layers_map[x.strip()] for x in content_layers.split(',')])
extract_style = FeatureExtractor(
    feat_model.features,
    [layers_map[x.strip()] for x in style_layers.split(',')])

# loss criterion
BCE_loss = nn.BCELoss().cuda()
MSE_loss = nn.MSELoss().cuda()
TV_loss = TVLoss().cuda()
criterion = nn.L1Loss().cuda()

# Adam optimizer
#G_optimizer = torch.optim.Adam(model_G.parameters(), lr=learning_rate, weight_decay=1e-5)
G_optimizer = torch.optim.Adam(model_G.parameters(),
                               lr=learning_rate,
Ejemplo n.º 21
0
                                         batch_size=opt.batchSize,
                                         shuffle=True,
                                         num_workers=int(opt.workers))

generator = Generator(16, opt.upSampling)
if opt.generatorWeights != '':
    generator.load_state_dict(torch.load(opt.generatorWeights))
print(generator)

discriminator = Discriminator()
if opt.discriminatorWeights != '':
    discriminator.load_state_dict(torch.load(opt.discriminatorWeights))
print(discriminator)

# For the content loss
feature_extractor = FeatureExtractor(torchvision.models.vgg19(pretrained=True))
print(feature_extractor)
content_criterion = nn.MSELoss()
adversarial_criterion = nn.BCELoss()

ones_const = Variable(torch.ones(opt.batchSize, 1))

# if gpu is to be used
if opt.cuda:
    generator.cuda()
    discriminator.cuda()
    feature_extractor.cuda()
    content_criterion.cuda()
    adversarial_criterion.cuda()
    ones_const = ones_const.cuda()
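
For context, a hedged sketch of how `ones_const` typically enters the generator objective in this SRGAN-style setup (the training loop is truncated from the example; `low_res` and `high_res_real` are assumed names):

# Sketch only: SRGAN-style generator step using the objects defined above.
fake_high_res = generator(Variable(low_res))
# All-ones target: the generator is rewarded when the discriminator
# classifies its outputs as real.
adversarial_loss = adversarial_criterion(discriminator(fake_high_res), ones_const)
content_loss = content_criterion(feature_extractor(fake_high_res),
                                 feature_extractor(high_res_real).detach())
total_loss = content_loss + 1e-3 * adversarial_loss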
Ejemplo n.º 22
0
    train_dataset = TrainDatasetFromFolder('data/DIV2K_train_HR/Train_HR', crop_size=opt.crop_size,
                                           upscale_factor=opt.upSampling)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=opt.batchSize, shuffle=True,
                                  num_workers=4)
    val_dataset = ValDatasetFromFolder('data/DIV2K_valid_HR/Val_HR', upscale_factor=opt.upSampling)

    # Use a DataLoader to process one batch of files at a time (batch loader)
    val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=1, shuffle=False, num_workers=4)

    generator = Generator(3, filters=64, num_res_blocks=opt.residual_blocks, up_scale=opt.upSampling).to(device)
    # load pretrain model
    checkpoint = torch.load(opt.generator_pretrainWeights)
    generator.load_state_dict(checkpoint['generator_model_pre'])
    print('Load Generator pre successfully!')
    discriminator = Discriminator(in_channels=3, out_filters=64).to(device)
    feature_extractor = FeatureExtractor().to(device)

    feature_extractor.eval()

    # Content loss and adversarial loss
    criterion_pixel = torch.nn.L1Loss().to(device)  # mean absolute pixel difference
    content_criterion = torch.nn.L1Loss().to(device)
    adversarial_criterion = torch.nn.BCEWithLogitsLoss().to(device)  # binary cross-entropy with logits

    Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor

    # tensorboard --logdir=logs
    configure(
        'logs/' + opt.train_dataroot + '-' + str(64) + '-' + str(opt.generatorLR) + '-' + str(opt.discriminatorLR),
        flush_secs=5)
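
ESRGAN pairs `BCEWithLogitsLoss` with a relativistic average (RaGAN) discriminator. A sketch of that term using the criteria above (`imgs_hr` and `gen_hr` are assumed names; the training loop is truncated here):

# Sketch of the RaGAN discriminator loss used by ESRGAN (names assumed).
pred_real = discriminator(imgs_hr)
pred_fake = discriminator(gen_hr.detach())
valid = torch.ones_like(pred_real)
fake = torch.zeros_like(pred_fake)
# Each side is judged relative to the average score of the other side.
loss_real = adversarial_criterion(pred_real - pred_fake.mean(0, keepdim=True), valid)
loss_fake = adversarial_criterion(pred_fake - pred_real.mean(0, keepdim=True), fake)
loss_D = (loss_real + loss_fake) / 2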
Ejemplo n.º 23
0
                                         batch_size=opt.batchSize,
                                         shuffle=True,
                                         num_workers=int(opt.workers))

G = Generator(10, opt.upSampling)
if opt.generatorWeights != '':
    G.load_state_dict(torch.load(opt.generatorWeights))
print(G)

D = Discriminator()
if opt.discriminatorWeights != '':
    D.load_state_dict(torch.load(opt.discriminatorWeights))
print(D)

# For the content loss
FE = FeatureExtractor(torchvision.models.vgg19(pretrained=True))
print(FE)
content_criterion = nn.MSELoss()
adversarial_criterion = nn.BCELoss()

ones_const = Variable(torch.ones(opt.batchSize, 1))

# if gpu is to be used
if opt.cuda:
    G.cuda()
    D.cuda()
    FE.cuda()
    content_criterion.cuda()
    adversarial_criterion.cuda()
    ones_const = ones_const.cuda()
Ejemplo n.º 24
0
    loss = logloss(d.unsqueeze(1), y)

    return loss


def get_sync_loss(mel, g):
    g = g[:, :, :, g.size(3) // 2:]  # keep the lower half of the face (mouth region)
    g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1)
    # B, 3 * T, H//2, W
    a, v = syncnet(mel, g)
    y = torch.ones(g.size(0), 1).float().to(device)
    return cosine_loss(a, v, y)


recon_loss = nn.L1Loss()
feature_extractor = FeatureExtractor()
feature_extractor.eval()


# --------- Add content loss here ---------------
def get_content_loss(g, gt):

    gen_features = feature_extractor(g)
    real_features = feature_extractor(gt)
    loss_content = recon_loss(gen_features, real_features.detach())

    return loss_content
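
A hedged sketch of how these pieces might combine into one objective (the weights and the mixing scheme are assumptions, not taken from this example):

# Sketch only: combined reconstruction + sync + perceptual objective
# (weights assumed, not from this example).
def combined_loss(mel, g, gt, sync_wt=0.03, content_wt=0.01):
    l1 = recon_loss(g, gt)
    sync = get_sync_loss(mel, g)
    content = get_content_loss(g, gt)
    return (1 - sync_wt - content_wt) * l1 + sync_wt * sync + content_wt * content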


def train(device,
          model,
Ejemplo n.º 25
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset',
                        type=str,
                        default='folder',
                        help='cifar10 | cifar100 | folder')
    parser.add_argument('--dataroot',
                        type=str,
                        default='./data',
                        help='path to dataset')
    parser.add_argument('--workers',
                        type=int,
                        default=1,
                        help='number of data loading workers')
    parser.add_argument('--batchSize',
                        type=int,
                        default=1,
                        help='input batch size')
    parser.add_argument('--imageSize',
                        type=int,
                        default=32,
                        help='the low resolution image size')
    parser.add_argument('--upSampling',
                        type=int,
                        default=4,
                        help='low to high resolution scaling factor')
    parser.add_argument('--cuda', action='store_true', help='enables cuda')
    parser.add_argument('--nGPU',
                        type=int,
                        default=1,
                        help='number of GPUs to use')
    parser.add_argument(
        '--generatorWeights',
        type=str,
        default='checkpoints/generator_final.pth',
        help="path to generator weights (to continue training)")
    parser.add_argument(
        '--discriminatorWeights',
        type=str,
        default='checkpoints/discriminator_final.pth',
        help="path to discriminator weights (to continue training)")

    opt = parser.parse_args()
    print(opt)

    if not os.path.exists('output/high_res_fake'):
        os.makedirs('output/high_res_fake')
    if not os.path.exists('output/high_res_real'):
        os.makedirs('output/high_res_real')
    if not os.path.exists('output/low_res'):
        os.makedirs('output/low_res')

    if torch.cuda.is_available() and not opt.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    transform = transforms.Compose([
        transforms.RandomCrop(opt.imageSize * opt.upSampling),
        transforms.ToTensor()
    ])

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    scale = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(opt.imageSize),  # Resize replaces the deprecated transforms.Scale
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Equivalent to un-normalizing ImageNet (for correct visualization)
    unnormalize = transforms.Normalize(mean=[-2.118, -2.036, -1.804],
                                       std=[4.367, 4.464, 4.444])
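    # Sanity check on those constants (sketch): the inverse of
    # Normalize(mean=m, std=s) is Normalize(mean=-m/s, std=1/s);
    # e.g. for the red channel, -0.485 / 0.229 = -2.118 and 1 / 0.229 = 4.367.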

    if opt.dataset == 'folder':
        # folder dataset
        dataset = datasets.ImageFolder(root=opt.dataroot, transform=transform)
    elif opt.dataset == 'cifar10':
        dataset = datasets.CIFAR10(root=opt.dataroot,
                                   download=True,
                                   train=False,
                                   transform=transform)
    elif opt.dataset == 'cifar100':
        dataset = datasets.CIFAR100(root=opt.dataroot,
                                    download=True,
                                    train=False,
                                    transform=transform)
    assert dataset

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batchSize,
                                             shuffle=False,
                                             num_workers=int(opt.workers))

    generator = Generator(16, opt.upSampling)
    if opt.generatorWeights != '':
        generator.load_state_dict(torch.load(opt.generatorWeights))
    print(generator)

    discriminator = Discriminator()
    if opt.discriminatorWeights != '':
        discriminator.load_state_dict(torch.load(opt.discriminatorWeights))
    print(discriminator)

    # For the content loss
    feature_extractor = FeatureExtractor(
        torchvision.models.vgg19(pretrained=True))
    print(feature_extractor)
    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()

    target_real = Variable(torch.ones(opt.batchSize, 1))
    target_fake = Variable(torch.zeros(opt.batchSize, 1))

    # if gpu is to be used
    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        target_real = target_real.cuda()
        target_fake = target_fake.cuda()

    low_res = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)

    print('Test started...')
    mean_generator_content_loss = 0.0
    mean_generator_adversarial_loss = 0.0
    mean_generator_total_loss = 0.0
    mean_discriminator_loss = 0.0

    # Set evaluation mode (not training)
    generator.eval()
    discriminator.eval()

    for i, data in enumerate(dataloader):
        # Load a batch of high-resolution images
        high_res_real, _ = data

        # Skip the final batch if it is smaller than batch_size
        if len(high_res_real) < opt.batchSize:
            continue

        # Downsample images to low resolution
        for j in range(opt.batchSize):
            low_res[j] = scale(high_res_real[j])
            high_res_real[j] = normalize(high_res_real[j])

        # Generate real and fake inputs
        if opt.cuda:
            high_res_real = Variable(high_res_real.cuda())
            high_res_fake = generator(Variable(low_res).cuda())
        else:
            high_res_real = Variable(high_res_real)
            high_res_fake = generator(Variable(low_res))

        ######### Test discriminator #########

        discriminator_loss = adversarial_criterion(discriminator(high_res_real), target_real) + \
                                adversarial_criterion(discriminator(high_res_fake), target_fake)
        mean_discriminator_loss += discriminator_loss.data

        ######### Test generator #########

        real_features = feature_extractor(high_res_real)
        fake_features = feature_extractor(high_res_fake)

        generator_content_loss = content_criterion(
            high_res_fake, high_res_real) + 0.006 * content_criterion(
                fake_features, real_features)
        mean_generator_content_loss += generator_content_loss.data
        generator_adversarial_loss = adversarial_criterion(
            discriminator(high_res_fake), target_real)
        mean_generator_adversarial_loss += generator_adversarial_loss.data

        generator_total_loss = generator_content_loss + 1e-3 * generator_adversarial_loss
        mean_generator_total_loss += generator_total_loss.data

        ######### Status and display #########
        sys.stdout.write(
            '\r[%d/%d] Discriminator_Loss: %.4f Generator_Loss (Content/Advers/Total): %.4f/%.4f/%.4f'
            % (i, len(dataloader), discriminator_loss.data,
               generator_content_loss.data, generator_adversarial_loss.data,
               generator_total_loss.data))

        # Redundant guard: the short final batch was already skipped above
        if len(high_res_real) < opt.batchSize:
            continue
        for j in range(opt.batchSize):
            save_image(
                unnormalize(high_res_real[j].cpu()),
                'output/high_res_real/' + str(i * opt.batchSize + j) + '.png')
            save_image(
                unnormalize(high_res_fake[j].cpu()),
                'output/high_res_fake/' + str(i * opt.batchSize + j) + '.png')
            # Saving without unnormalize would mis-color the images:
            #save_image(high_res_real[j], 'output/high_res_real/' + str(i*opt.batchSize + j) + '.png')
            #save_image(high_res_fake[j], 'output/high_res_fake/' + str(i*opt.batchSize + j) + '.png')
            save_image(unnormalize(low_res[j]),
                       'output/low_res/' + str(i * opt.batchSize + j) + '.png')

    sys.stdout.write(
        '\r[%d/%d] Discriminator_Loss: %.4f Generator_Loss (Content/Advers/Total): %.4f/%.4f/%.4f\n'
        % (i, len(dataloader), mean_discriminator_loss / len(dataloader),
           mean_generator_content_loss / len(dataloader),
           mean_generator_adversarial_loss / len(dataloader),
           mean_generator_total_loss / len(dataloader)))
Ejemplo n.º 26
0
class SACAgent:
    def __init__(self, env, gamma, tau, v_lr, q_lr, policy_lr, buffer_maxlen):
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.firsttime = 0

        self.env = env
        self.action_range = [env.action_space.low, env.action_space.high]
        #self.obs_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]  #1

        self.conv_channels = 4
        self.kernel_size = (3, 3)

        self.img_size = (500, 500, 3)

        print("Diagnostics:")
        print(f"action_range: {self.action_range}")
        #print(f"obs_dim: {self.obs_dim}")
        print(f"action_dim: {self.action_dim}")

        # hyperparameters
        self.gamma = gamma
        self.tau = tau
        self.update_step = 0
        self.delay_step = 2

        # initialize networks
        self.feature_net = FeatureExtractor(self.img_size[2],
                                            self.conv_channels,
                                            self.kernel_size).to(self.device)
        print("Feature net init'd successfully")

        input_dim = self.feature_net.get_output_size(self.img_size)
        self.input_size = input_dim[0] * input_dim[1] * input_dim[2]
        print(f"input_size: {self.input_size}")

        self.value_net = ValueNetwork(self.input_size, 1).to(self.device)
        self.target_value_net = ValueNetwork(self.input_size,
                                             1).to(self.device)
        self.q_net1 = SoftQNetwork(self.input_size,
                                   self.action_dim).to(self.device)
        self.q_net2 = SoftQNetwork(self.input_size,
                                   self.action_dim).to(self.device)
        self.policy_net = PolicyNetwork(self.input_size,
                                        self.action_dim).to(self.device)

        print("Finished initing all nets")

        # copy params to target param
        for target_param, param in zip(self.target_value_net.parameters(),
                                       self.value_net.parameters()):
            target_param.data.copy_(param)

        print("Finished copying targets")

        # initialize optimizers
        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=v_lr)
        self.q1_optimizer = optim.Adam(self.q_net1.parameters(), lr=q_lr)
        self.q2_optimizer = optim.Adam(self.q_net2.parameters(), lr=q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(),
                                           lr=policy_lr)

        print("Finished initing optimizers")

        self.replay_buffer = BasicBuffer(buffer_maxlen)
        print("End of init")

    def get_action(self, state):
        if state.shape != self.img_size:
            print(
                f"Invalid size, expected shape {self.img_size}, got {state.shape}"
            )
            return None

        inp = torch.from_numpy(state).float().permute(2, 0, 1).unsqueeze(0).to(
            self.device)
        features = self.feature_net(inp)
        features = features.view(-1, self.input_size)

        mean, log_std = self.policy_net.forward(features)
        std = log_std.exp()

        normal = Normal(mean, std)
        z = normal.sample()
        action = torch.tanh(z)
        action = action.cpu().detach().squeeze(0).numpy()

        return self.rescale_action(action)

    def rescale_action(self, action):
        return action * (self.action_range[1] - self.action_range[0]) / 2.0 +\
            (self.action_range[1] + self.action_range[0]) / 2.0

    def update(self, batch_size):
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(
            batch_size)

        # states and next states are lists of ndarrays, np.stack converts them to
        # ndarrays of shape (batch_size, height, width, num_channels)
        states = np.stack(states)
        next_states = np.stack(next_states)

        states = torch.FloatTensor(states).permute(0, 3, 1, 2).to(self.device)
        actions = torch.FloatTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).permute(0, 3, 1,
                                                             2).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)
        dones = dones.view(dones.size(0), -1)

        # Process images
        features = self.feature_net(
            states)  #.contiguous() # Properly shaped due to batching
        next_features = self.feature_net(next_states)  #.contiguous()

        # Flatten with the actual batch size (was hard-coded to 64)
        features = torch.reshape(features, (-1, self.input_size))
        next_features = torch.reshape(next_features, (-1, self.input_size))

        next_actions, next_log_pi = self.policy_net.sample(next_features)
        next_q1 = self.q_net1(next_features, next_actions)
        next_q2 = self.q_net2(next_features, next_actions)
        next_v = self.target_value_net(next_features)

        # Soft state-value target: V(s') = min_i Q_i(s', a') - log pi(a'|s')
        next_v_target = torch.min(next_q1, next_q2) - next_log_pi
        curr_v = self.value_net.forward(features)
        v_loss = F.mse_loss(curr_v, next_v_target.detach())

        # q loss
        expected_q = rewards + (1 - dones) * self.gamma * next_v
        curr_q1 = self.q_net1.forward(features, actions)
        curr_q2 = self.q_net2.forward(features, actions)
        q1_loss = F.mse_loss(curr_q1, expected_q.detach())
        q2_loss = F.mse_loss(curr_q2, expected_q.detach())

        # update value and q networks
        self.value_optimizer.zero_grad()
        v_loss.backward(retain_graph=True)
        self.value_optimizer.step()

        self.q1_optimizer.zero_grad()
        q1_loss.backward(retain_graph=True)
        self.q1_optimizer.step()

        self.q2_optimizer.zero_grad()
        q2_loss.backward(retain_graph=True)
        self.q2_optimizer.step()

        # delayed update for policy network and target q networks
        if self.update_step % self.delay_step == 0:
            new_actions, log_pi = self.policy_net.sample(features)
            min_q = torch.min(self.q_net1.forward(features, new_actions),
                              self.q_net2.forward(features, new_actions))
            policy_loss = (log_pi - min_q).mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward(retain_graph=True)
            self.policy_optimizer.step()

            # target networks
            for target_param, param in zip(self.target_value_net.parameters(),
                                           self.value_net.parameters()):
                target_param.data.copy_(self.tau * param +
                                        (1 - self.tau) * target_param)

        self.update_step += 1
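
A hedged sketch of an interaction loop driving this agent (the `BasicBuffer` push/len API and Gym-style `env` methods are assumptions):

# Sketch only: training loop for SACAgent (buffer and env APIs assumed).
agent = SACAgent(env, gamma=0.99, tau=0.005, v_lr=3e-4, q_lr=3e-4,
                 policy_lr=3e-4, buffer_maxlen=100000)
batch_size = 64
state = env.reset()
for step in range(100000):
    action = agent.get_action(state)
    next_state, reward, done, _ = env.step(action)
    agent.replay_buffer.push(state, action, reward, next_state, done)
    if len(agent.replay_buffer) > batch_size:
        agent.update(batch_size)
    state = env.reset() if done else next_state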