def network_initializers(self, hr_shape, use_LeakyReLU_Mish=False):
    """Build the three ESRGAN networks and move them to ``self.device``.

    Args:
        hr_shape: (height, width) of high-resolution images; combined with
            ``self.opt.channels`` to form the discriminator's input shape.
        use_LeakyReLU_Mish: forwarded to the generator and discriminator
            constructors to select their activation variant.

    Returns:
        Tuple ``(discriminator, feature_extractor, generator)``.
    """
    disc = Discriminator(
        input_shape=(self.opt.channels, *hr_shape),
        use_LeakyReLU_Mish=use_LeakyReLU_Mish,
    ).to(self.device, non_blocking=True)

    gen = GeneratorRRDB(
        self.opt.channels,
        filters=64,
        num_res_blocks=self.opt.residual_blocks,
        use_LeakyReLU_Mish=use_LeakyReLU_Mish,
    ).to(self.device, non_blocking=True)

    # The feature extractor is only used for the perceptual/content loss,
    # so it is kept permanently in inference mode.
    extractor = FeatureExtractor().to(self.device, non_blocking=True)
    extractor.eval()

    return disc, extractor, gen
def _set_model(self, device, hr_shape):
    """Initialize generator, discriminator, feature extractor and losses.

    Fix: the original body read a module-level ``opt``; it now reads
    ``self.opt`` so the method uses the options object the instance was
    constructed with (consistent with ``network_initializers``, which
    already accesses ``self.opt``).

    Args:
        device: torch device the networks and criteria are moved to.
        hr_shape: (height, width) of high-resolution images, combined with
            ``self.opt.channels`` for the discriminator input shape.
    """
    # Initialize generator and discriminator
    self.generator = GeneratorRRDB(
        self.opt.channels,
        filters=64,
        num_res_blocks=self.opt.residual_blocks).to(device)
    self.discriminator = Discriminator(
        input_shape=(self.opt.channels, *hr_shape)).to(device)
    self.feature_extractor = FeatureExtractor().to(device)
    # Set feature extractor to inference mode — it only serves the content loss.
    self.feature_extractor.eval()
    # Losses
    self.criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
    self.criterion_content = torch.nn.L1Loss().to(device)
    self.criterion_pixel = torch.nn.L1Loss().to(device)
def print_network():
    """Instantiate the SRGAN networks, restore any saved weights, and print
    a per-layer summary of each on a dummy 3x32x32 input."""
    opt = setup()

    net_g = Generator(16, opt.upSampling)
    if opt.generatorWeights != '':
        net_g.load_state_dict(torch.load(opt.generatorWeights))

    net_d = Discriminator()
    if opt.discriminatorWeights != '':
        net_d.load_state_dict(torch.load(opt.discriminatorWeights))

    net_f = FeatureExtractor(torchvision.models.vgg19(pretrained=True))

    # Same label/summary pattern for all three networks.
    for label, net in (('generator', net_g),
                       ('discriminator', net_d),
                       ('feature_extractor', net_f)):
        printer(label)
        summary(net.cuda(), (3, 32, 32))
def get_prediction(model_checkpoint, resnet_type):
    """Run the feature-extractor + label-predictor pair over ``target_loader``
    (a module-level DataLoader) and return all predicted class indices.

    Args:
        model_checkpoint: path to a checkpoint dict with keys
            'feature_extractor' and 'label_predictor'.
        resnet_type: backbone selector forwarded to both model constructors.

    Returns:
        1-D numpy array of argmax class indices for the whole target set.
    """
    extractor = FeatureExtractor(resnet=resnet_type).to(device)
    classifier = LabelPredictor(resnet=resnet_type).to(device)

    state = torch.load(model_checkpoint)
    extractor.load_state_dict(state['feature_extractor'])
    classifier.load_state_dict(state['label_predictor'])

    # Inference only.
    extractor.eval()
    classifier.eval()

    batch_preds = []
    for idx, (batch, _) in enumerate(target_loader):
        print(idx + 1, len(target_loader), end='\r')
        batch = batch.to(device)
        logits = classifier(extractor(batch))
        batch_preds.append(
            torch.argmax(logits, dim=1).cpu().detach().numpy())

    # Free GPU memory held by the models before returning.
    del extractor
    del classifier
    torch.cuda.empty_cache()

    return np.concatenate(batch_preds)
def __init__(self):
    """Assemble the A3C agent: shared conv feature extractor plus separate
    actor and critic heads, one RMSprop optimizer over all parameters, and
    per-episode rollout storage."""
    super(A3Cagent, self).__init__()
    self.Conv = FeatureExtractor()
    self.A, self.C = Actor(), Critic()
    # Try loading checkpoints; LOAD_CHECKPOINTS is a module-level flag.
    if LOAD_CHECKPOINTS:
        self.load_weights()
    # Optimizer must be created AFTER the submodules above so that
    # self.parameters() covers Conv, A and C.
    self.opt = torch.optim.RMSprop(self.parameters(), lr=LEARNING_RATE)
    # Stores log_probs, values, rewards during episode (cleared per episode
    # elsewhere — TODO confirm against the training loop).
    self.mem = [[], [], []]
    self.total_entropy = 0
    self.steps = 0
def init(opt):
    """One-shot SRGAN training setup from parsed CLI options.

    Builds the transforms, dataset/dataloader, networks (optionally restored
    from checkpoints), criteria, optimizers, logging, and a pre-allocated
    low-resolution buffer, and returns all of them as a 12-tuple.
    Relies on module-level imports (os, sys, torch, transforms, datasets,
    nn, optim, Variable, configure, Generator, Discriminator,
    FeatureExtractor, torchvision).
    """
    # [folder] create folder for checkpoints (ignore "already exists")
    try:
        os.makedirs(opt.out)
    except OSError:
        pass
    # [cuda] warn when a CUDA device is present but --cuda was not passed
    if torch.cuda.is_available() and not opt.cuda:
        sys.stdout.write('[WARNING] : You have a CUDA device, so you should probably run with --cuda')
    # [normalization] ImageNet mean/std normalization for tensors
    normalize = transforms.Normalize(
        mean = [0.485, 0.456, 0.406],
        std = [0.229, 0.224, 0.225])
    # [scale] resize to the low-resolution size, then normalize
    scale = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(opt.imageSize),
        transforms.ToTensor(),
        transforms.Normalize(
            mean = [0.485, 0.456, 0.406],
            std = [0.229, 0.224, 0.225])])
    # [transform] random-crop at the high-resolution size (imageSize * upSampling)
    transform = transforms.Compose([
        transforms.RandomCrop((opt.imageSize[0] * opt.upSampling,
                               opt.imageSize[1] * opt.upSampling)),
        transforms.ToTensor()])
    # [dataset] training dataset; note cifar100 uses download=False — TODO confirm intended
    if opt.dataset == 'folder':
        dataset = datasets.ImageFolder(root = opt.dataroot, transform = transform)
    elif opt.dataset == 'cifar10':
        dataset = datasets.CIFAR10(root = opt.dataroot, train = True,
                                   download = True, transform = transform)
    elif opt.dataset == 'cifar100':
        dataset = datasets.CIFAR100(root = opt.dataroot, train = True,
                                    download = False, transform = transform)
    assert dataset
    # [dataloader] shuffled training loader
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size = opt.batchSize,
        shuffle = True, num_workers = int(opt.workers))
    # [generator] generator of the GAN, optionally restored from checkpoint
    generator = Generator(16, opt.upSampling)
    if opt.generatorWeights != '' and os.path.exists(opt.generatorWeights):
        generator.load_state_dict(torch.load(opt.generatorWeights))
    # [discriminator] discriminator of the GAN, optionally restored
    discriminator = Discriminator()
    if opt.discriminatorWeights != '' and os.path.exists(opt.discriminatorWeights):
        discriminator.load_state_dict(torch.load(opt.discriminatorWeights))
    # [extractor] VGG19-based feature extractor for the content loss
    feature_extractor = FeatureExtractor(torchvision.models.vgg19(pretrained = True))
    # [loss] content (MSE on VGG features) + adversarial (BCE) criteria
    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()
    # "real" labels for the generator's adversarial loss
    ones_const = Variable(torch.ones(opt.batchSize, 1))
    # [cuda] move everything to GPU when requested
    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        ones_const = ones_const.cuda()
    # [optimizer] one Adam optimizer per network
    optim_generator = optim.Adam(generator.parameters(), lr = opt.generatorLR)
    optim_discriminator = optim.Adam(discriminator.parameters(), lr = opt.discriminatorLR)
    # record configure: tensorboard_logger run directory keyed by hyperparameters
    configure('logs/{}-{}-{} -{}'.format(opt.dataset,
                                         str(opt.batchSize),
                                         str(opt.generatorLR),
                                         str(opt.discriminatorLR)),
              flush_secs = 5)
    # visualizer = Visualizer(image_size = (opt.imageSize[0] * opt.upSampling, opt.imageSize[1] * opt.upSampling))
    # pre-allocated buffer for low-resolution batches
    low_res = torch.FloatTensor(opt.batchSize, 3, opt.imageSize[0], opt.imageSize[1])
    return normalize,\
        scale,\
        dataloader,\
        generator,\
        discriminator,\
        feature_extractor,\
        content_criterion,\
        adversarial_criterion,\
        ones_const,\
        optim_generator,\
        optim_discriminator,\
        low_res
    # NOTE(review): tail of a sampling function whose `def` is outside this
    # view — draws one action index from the categorical distribution `policy`.
    return np.random.choice(len(policy), 1, p=policy)[0]


def to_tensor(x, dtype=None):
    """Convert ``x`` to a tensor and prepend a batch dimension of size 1."""
    return torch.tensor(x, dtype=dtype).unsqueeze(0)


if __name__ == '__main__':
    env = gym.make('CartPole-v1')
    # Actor Critic networks (CartPole: 4-dim observation space)
    actor = Actor(n_actions=env.action_space.n, space_dims=4, hidden_dims=32)
    critic = Critic(space_dims=4, hidden_dims=32)
    # ICM (intrinsic curiosity module) components
    feature_extractor = FeatureExtractor(env.observation_space.shape[0], 32)
    forward_model = ForwardModel(env.action_space.n, 32)
    inverse_model = InverseModel(env.action_space.n, 32)
    # Actor Critic optimizers (one per network)
    a_optim = torch.optim.Adam(actor.parameters(), lr=args.lr_actor)
    c_optim = torch.optim.Adam(critic.parameters(), lr=args.lr_critic)
    # ICM: a single optimizer over all three ICM components
    icm_params = list(feature_extractor.parameters()) + list(
        forward_model.parameters()) + list(inverse_model.parameters())
    icm_optim = torch.optim.Adam(icm_params, lr=args.lr_icm)
    pg_loss = PGLoss()
    mse_loss = nn.MSELoss()
    xe_loss = nn.CrossEntropyLoss()
def __init__(self, env, gamma, tau, v_lr, q_lr, policy_lr, buffer_maxlen):
    """Soft Actor-Critic agent over image observations.

    Builds a conv feature extractor, value/target-value networks, twin soft
    Q-networks, a policy network, their Adam optimizers, and a replay buffer.

    Args:
        env: gym-style environment with a continuous action space.
        gamma: discount factor.
        tau: soft-update coefficient for the target value network.
        v_lr, q_lr, policy_lr: learning rates for the respective networks.
        buffer_maxlen: replay buffer capacity.
    """
    self.device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")

    self.firsttime = 0

    self.env = env
    self.action_range = [env.action_space.low, env.action_space.high]
    #self.obs_dim = env.observation_space.shape[0]
    self.action_dim = env.action_space.shape[0]  #1
    self.conv_channels = 4
    self.kernel_size = (3, 3)
    # Hard-coded observation size (H, W, C) — presumably matches the env's
    # rendered frames; TODO confirm against the environment.
    self.img_size = (500, 500, 3)

    print("Diagnostics:")
    print(f"action_range: {self.action_range}")
    #print(f"obs_dim: {self.obs_dim}")
    print(f"action_dim: {self.action_dim}")

    # hyperparameters
    self.gamma = gamma
    self.tau = tau
    self.update_step = 0
    self.delay_step = 2

    # initialize networks; the feature net's flattened output size feeds
    # every downstream network
    self.feature_net = FeatureExtractor(self.img_size[2], self.conv_channels,
                                        self.kernel_size).to(self.device)
    print("Feature net init'd successfully")

    input_dim = self.feature_net.get_output_size(self.img_size)
    self.input_size = input_dim[0] * input_dim[1] * input_dim[2]
    print(f"input_size: {self.input_size}")

    self.value_net = ValueNetwork(self.input_size, 1).to(self.device)
    self.target_value_net = ValueNetwork(self.input_size, 1).to(self.device)
    self.q_net1 = SoftQNetwork(self.input_size, self.action_dim).to(self.device)
    self.q_net2 = SoftQNetwork(self.input_size, self.action_dim).to(self.device)
    self.policy_net = PolicyNetwork(self.input_size, self.action_dim).to(self.device)
    print("Finished initing all nets")

    # copy params to target param so both value nets start identical
    for target_param, param in zip(self.target_value_net.parameters(),
                                   self.value_net.parameters()):
        target_param.data.copy_(param)
    print("Finished copying targets")

    # initialize optimizers (one per trainable network)
    self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=v_lr)
    self.q1_optimizer = optim.Adam(self.q_net1.parameters(), lr=q_lr)
    self.q2_optimizer = optim.Adam(self.q_net2.parameters(), lr=q_lr)
    self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=policy_lr)
    print("Finished initing optimizers")

    self.replay_buffer = BasicBuffer(buffer_maxlen)
    print("End of init")
def main():
    """SRGAN inference script: super-resolve every image in a dataset and
    save the generated high-resolution outputs, preserving the source
    directory layout under output/high_res_fake/A/."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='folder', help='cifar10 | cifar100 | folder')
    parser.add_argument('--dataroot', type=str, default='./data', help='path to dataset')
    parser.add_argument('--workers', type=int, default=2, help='number of data loading workers')
    parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
    parser.add_argument('--upSampling', type=int, default=4, help='low to high resolution scaling factor')
    parser.add_argument('--cuda', action='store_true', help='enables cuda')
    parser.add_argument('--nGPU', type=int, default=2, help='number of GPUs to use')
    parser.add_argument('--generatorWeights', type=str, default='checkpoints/generator_final.pth', help="path to generator weights (to continue training)")
    parser.add_argument('--discriminatorWeights', type=str, default='checkpoints/discriminator_final.pth', help="path to discriminator weights (to continue training)")
    opt = parser.parse_args()
    print(opt)

    if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    transform = transforms.Compose([transforms.ToTensor()])
    # Per-image ImageNet normalization (applied later, per tensor)
    normalize = transforms.Compose([transforms.ToPILImage(),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                                         std = [0.229, 0.224, 0.225])])
    # Equivalent to un-normalizing ImageNet (for correct visualization)
    unnormalize = transforms.Normalize(mean = [-2.118, -2.036, -1.804],
                                       std = [4.367, 4.464, 4.444])

    if opt.dataset == 'folder':
        # folder dataset
        dataset = datasets.ImageFolder(root=opt.dataroot, transform=transform)
    elif opt.dataset == 'cifar10':
        dataset = datasets.CIFAR10(root=opt.dataroot, download=True, train=False, transform=transform)
    elif opt.dataset == 'cifar100':
        dataset = datasets.CIFAR100(root=opt.dataroot, download=True, train=False, transform=transform)
    assert dataset
    #print(dataset)
    # list of (path, class) pairs; used below to reconstruct output paths.
    # Only valid for the 'folder' dataset — CIFAR datasets have no .imgs.
    image_name = dataset.imgs # image path

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                             shuffle=False, num_workers=int(opt.workers))

    generator = Generator(16, opt.upSampling)
    if opt.generatorWeights != '':
        generator.load_state_dict(torch.load(opt.generatorWeights))
    print(generator)

    discriminator = Discriminator()
    if opt.discriminatorWeights != '':
        discriminator.load_state_dict(torch.load(opt.discriminatorWeights))
    print(discriminator)

    # For the content loss
    feature_extractor = FeatureExtractor(torchvision.models.vgg19(pretrained=True))
    print(feature_extractor)

    # if gpu is to be used: wrap in DataParallel over hard-coded GPU ids
    if opt.cuda:
        #generator.cuda()
        #discriminator.cuda()
        #feature_extractor.cuda()
        gpu_ids = [0,2]
        torch.cuda.set_device(gpu_ids[0])
        generator = torch.nn.DataParallel(generator, device_ids=gpu_ids).cuda()
        discriminator = torch.nn.DataParallel(discriminator, device_ids=gpu_ids).cuda()
        feature_extractor = torch.nn.DataParallel(feature_extractor, device_ids=gpu_ids).cuda()

    print('Test started...')

    # Set evaluation mode (not training)
    generator.eval()
    discriminator.eval()

    #print(len(dataloader))
    for i, data in enumerate(dataloader):
        # Generate data
        low_res, _ = data
        #print(low_res.shape)
        #print(image_name[i])
        # e.g. image_type_path, image_detail_name = bounding_box_test, -1_c1s3_065901_04.jpg
        image_type_path, image_detail_name = [],[]
        for j in range(len(low_res)):
            # len(low_res) (not opt.batchSize) means the final, possibly
            # smaller batch is never skipped.
            # 'replace' normalizes Windows path separators.
            image_type_path.append(image_name[i*opt.batchSize+j][0].replace('\\','/').split('/')[-2])
            image_detail_name.append(image_name[i*opt.batchSize+j][0].replace('\\','/').split('/')[-1])
        #print(len(image_type_path), len(image_detail_name))

        for j in range(len(low_res)):  # never skip final batch
            low_res[j] = normalize(low_res[j])

        # Generate real and fake inputs
        if opt.cuda:
            high_res_fake = generator(Variable(low_res).cuda())
        else:
            high_res_fake = generator(Variable(low_res))
        # high_res_fake = high_res_fake.to(torch.device('cuda:0'))

        for j in range(len(low_res)):  # never skip final batch
            print(image_type_path[j],image_detail_name[j])
            if not os.path.exists('output/high_res_fake/A/{}'.format(image_type_path[j])):
                os.makedirs('output/high_res_fake/A/{}'.format(image_type_path[j]))
            #print(high_res_fake[j])
            #print(low_res[j])
            # NOTE(review): unnormalize requires a CPU tensor here, hence the
            # .cpu() call below — a CUDA tensor raised an error previously.
            # Option 1 (commented) saves numbered test images; option 2 (active)
            # writes the full dataset preserving the original directory layout.
            # 1. when only testing some images:
            ##################################################################
            #save_image(unnormalize(high_res_fake[j].cpu()), 'output/high_res_fake/' + str(i * opt.batchSize + j) + '.png')
            #save_image(unnormalize(low_res[j]), 'output/low_res/' + str(i*opt.batchSize + j) + '.png') # save raw low_res images
            ##################################################################
            # 2. when exporting the full dataset:
            ##################################################################
            save_image(unnormalize(high_res_fake[j].cpu()),
                       'output/high_res_fake/A/{}/{}'.format(image_type_path[j], image_detail_name[j]))
def upsampling(path, picture_name, upsampling):
    """Super-resolve a single image file with the trained SRGAN generator.

    Loads the image at ``path``, runs it through the generator, writes the
    result to output/result/<picture_name>, and returns a log string
    describing the input and reconstructed sizes.

    NOTE(review): despite the names, the image is fed at its original size
    (no downscaling happens before the generator here); the dataset built
    below is never consumed. ``upsampling`` is only used for the log string.
    """
    opt = setup()
    # image = Image.open(os.getcwd() + r'\images\\' + path)
    image = Image.open(path)
    # PIL .size is (width, height); opt.imageSize becomes (height, width).
    opt.imageSize = (image.size[1], image.size[0])
    log = '>>> process image : {} size : ({}, {}) sr_reconstruct size : ({}, {})'.format(
        picture_name, image.size[0], image.size[1],
        image.size[0] * upsampling, image.size[1] * upsampling)
    # Windows-style output folder; ignore "already exists".
    try:
        os.makedirs(os.getcwd() + r'\output\result')
    except OSError:
        pass
    if torch.cuda.is_available() and not opt.cuda:
        print(
            '[WARNING] : You have a CUDA device, so you should probably run with --cuda'
        )
    transform = transforms.Compose([
        transforms.RandomCrop(opt.imageSize),
        transforms.Pad(padding=0),
        transforms.ToTensor()
    ])
    # ImageNet normalization and its inverse (for correct visualization).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # Equivalent to un-normalizing ImageNet (for correct visualization)
    unnormalize = transforms.Normalize(mean=[-2.118, -2.036, -1.804],
                                       std=[4.367, 4.464, 4.444])
    # Resize to opt.imageSize (the image's own size) and normalize.
    scale = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(opt.imageSize),
        transforms.Pad(padding=0),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    # Dataset is constructed but unused below — kept for parity with the
    # training script.
    if opt.dataset == 'folder':
        # folder dataset
        dataset = datasets.ImageFolder(root=opt.dataroot, transform=transform)
    elif opt.dataset == 'cifar10':
        dataset = datasets.CIFAR10(root=opt.dataroot, download=True,
                                   train=False, transform=transform)
    elif opt.dataset == 'cifar100':
        dataset = datasets.CIFAR100(root=opt.dataroot, download=True,
                                    train=False, transform=transform)
    assert dataset
    # NOTE(review): "dataloader" is actually a ToTensor transform here.
    dataloader = transforms.Compose([transforms.ToTensor()])
    image = dataloader(image)
    # loading paras from networks
    generator = Generator(16, opt.upSampling)
    if opt.generatorWeights != '':
        generator.load_state_dict(torch.load(opt.generatorWeights))
    discriminator = Discriminator()
    if opt.discriminatorWeights != '':
        discriminator.load_state_dict(torch.load(opt.discriminatorWeights))
    # For the content loss
    feature_extractor = FeatureExtractor(torchvision.models.vgg19(pretrained=True))
    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()
    target_real = Variable(torch.ones(opt.batchSize, 1))
    target_fake = Variable(torch.zeros(opt.batchSize, 1))
    # if gpu is to be used
    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        target_real = target_real.cuda()
        target_fake = target_fake.cuda()
    low_res = torch.FloatTensor(opt.batchSize, 3, opt.imageSize[0], opt.imageSize[1])
    # Set evaluation mode (not training)
    generator.eval()
    discriminator.eval()
    # Generate data
    high_res_real = image
    # "Downsample" to low resolution (scale resizes to the image's own size).
    low_res = scale(high_res_real)
    low_res = torch.tensor([np.array(low_res)])
    high_res_real = normalize(high_res_real)
    high_res_real = torch.tensor([np.array(high_res_real)])
    # Generate real and fake inputs
    if opt.cuda:
        high_res_real = Variable(high_res_real.cuda())
        high_res_fake = generator(Variable(low_res).cuda())
    else:
        high_res_real = Variable(high_res_real)
        high_res_fake = generator(Variable(low_res))
    save_image(unnormalize(high_res_fake[0]), './output/result/' + picture_name)
    return log
# Init os.makedirs("saved_models", exist_ok=True) device = "cuda" if torch.cuda.is_available() else "cpu" writer = SummaryWriter() # Get models hr_shape = (opt.hr_height, opt.hr_width) generator = Generator(filters=64, num_res_blocks=opt.residual_blocks, num_upsample=opt.num_upsample) \ .to(device).train() discriminator = Discriminator() \ .to(device).train() feature_extractor = FeatureExtractor() \ .to(device).eval() if opt.netG_checkpoint: try: generator.load_state_dict( torch.load(opt.netG_checkpoint, map_location="cpu")) print( f"[x] Restored generator weights from: {opt.netG_checkpoint}") except: print("[!] Generator weights from scratch.") if opt.netD_checkpoint: try: discriminator.load_state_dict( torch.load(opt.netD_checkpoint, map_location="cpu")) print( f"[x] Restored discriminator weights from: {opt.netD_checkpoint}"
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # Fetch Data dataset = datasets.ImageFolder(root="./data", transform=transform) dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize, shuffle=True, num_workers=int(opt.workers)) #FIXME devicedataloader -> do we need it? generator = Generator(opt.resBlocks, opt.upSampling) discriminator = Discriminator() feature_extractor = FeatureExtractor( torchvision.models.vgg19(pretrained=True)) content_criterion = nn.MSELoss() adversarial_criterion = nn.BCELoss() generator = nn.DataParallel(generator) generator.to(device) discriminator = nn.DataParallel(discriminator) discriminator.to(device) #feature_extractor = nn.DataParallel(feature_extractor) feature_extractor.to(device) #content_criterion = nn.DataParallel(content_criterion) content_criterion.to(device)
class ESRGAN():
    """ESRGAN training wrapper.

    Builds the generator/discriminator/VGG feature extractor and losses on
    construction, and runs the relativistic-average-GAN training loop with a
    pixel-loss-only warm-up phase.

    Fixes over the original:
    - ``_set_model``, ``_set_param`` and ``_load_weigth`` read ``self.opt``
      instead of a module-level ``opt``.
    - ``train`` references the networks, criteria and optimizers through
      ``self`` — the original used bare names (``generator``,
      ``optimizer_G``, ...) that are only ever assigned as instance
      attributes, which would raise NameError unless identical globals
      happened to exist.
    Module-level names still relied on: ``weight_save_dir``,
    ``image_train_save_dir``, ``denormalize``, ``mlflow`` — TODO confirm
    they are defined in this module.
    """

    def __init__(self, opt):
        self.opt = opt
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        hr_shape = (self.opt.hr_height, self.opt.hr_width)
        self._set_model(device, hr_shape)

    def _set_model(self, device, hr_shape):
        """Initialize generator, discriminator, feature extractor and losses."""
        self.generator = GeneratorRRDB(
            self.opt.channels,
            filters=64,
            num_res_blocks=self.opt.residual_blocks).to(device)
        self.discriminator = Discriminator(
            input_shape=(self.opt.channels, *hr_shape)).to(device)
        self.feature_extractor = FeatureExtractor().to(device)
        # Feature extractor only serves the content loss: inference mode.
        self.feature_extractor.eval()
        # Losses
        self.criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
        self.criterion_content = torch.nn.L1Loss().to(device)
        self.criterion_pixel = torch.nn.L1Loss().to(device)

    def _set_param(self):
        """Log every option value as an mlflow run parameter."""
        for key, value in vars(self.opt).items():
            mlflow.log_param(key, value)

    def _load_weigth(self):
        """Optionally restore checkpoints, then build the optimizers.

        Must run before ``train`` so that ``optimizer_G``/``optimizer_D`` exist.
        """
        if self.opt.epoch != 0:
            # Load pretrained models saved at the requested epoch.
            load_g_weight_path = osp.join(weight_save_dir,
                                          "generator_%d.pth" % self.opt.epoch)
            load_d_weight_path = osp.join(weight_save_dir,
                                          "discriminator_%d.pth" % self.opt.epoch)
            self.generator.load_state_dict(torch.load(load_g_weight_path))
            self.discriminator.load_state_dict(torch.load(load_d_weight_path))
        # Optimizers
        self.optimizer_G = torch.optim.Adam(self.generator.parameters(),
                                            lr=self.opt.lr,
                                            betas=(self.opt.b1, self.opt.b2))
        self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(),
                                            lr=self.opt.lr,
                                            betas=(self.opt.b1, self.opt.b2))

    # ----------
    #  Training
    # ----------
    def train(self, dataloader, opt):
        """Run the full training loop.

        Args:
            dataloader: yields dicts with "lr" and "hr" image batches.
            opt: options namespace (epoch, n_epochs, warmup_batches,
                lambda_adv, lambda_pixel, sample_interval,
                checkpoint_interval, ...).
        """
        for epoch in range(opt.epoch + 1, opt.n_epochs + 1):
            for batch_num, imgs in enumerate(dataloader):
                Tensor = torch.cuda.FloatTensor if torch.cuda.is_available(
                ) else torch.Tensor
                batches_done = (epoch - 1) * len(dataloader) + batch_num

                # Configure model input
                imgs_lr = Variable(imgs["lr"].type(Tensor))
                imgs_hr = Variable(imgs["hr"].type(Tensor))

                # Adversarial ground truths
                valid = Variable(Tensor(
                    np.ones((imgs_lr.size(0), *self.discriminator.output_shape))),
                    requires_grad=False)
                fake = Variable(Tensor(
                    np.zeros((imgs_lr.size(0), *self.discriminator.output_shape))),
                    requires_grad=False)

                # ------------------
                #  Train Generators
                # ------------------
                self.optimizer_G.zero_grad()

                # Generate a high resolution image from low resolution input
                gen_hr = self.generator(imgs_lr)

                # Measure pixel-wise loss against ground truth
                loss_pixel = self.criterion_pixel(gen_hr, imgs_hr)

                if batches_done <= opt.warmup_batches:
                    # Warm-up (pixel-wise loss only)
                    loss_pixel.backward()
                    self.optimizer_G.step()
                    log_info = "[Epoch {}/{}] [Batch {}/{}] [G pixel: {}]".format(
                        epoch, opt.n_epochs, batch_num, len(dataloader),
                        loss_pixel.item())
                    sys.stdout.write("\r{}".format(log_info))
                    sys.stdout.flush()
                    mlflow.log_metric('train_{}'.format('loss_pixel'),
                                      loss_pixel.item(), step=batches_done)
                else:
                    # Extract validity predictions from discriminator
                    pred_real = self.discriminator(imgs_hr).detach()
                    pred_fake = self.discriminator(gen_hr)

                    # Adversarial loss (relativistic average GAN)
                    loss_GAN = self.criterion_GAN(
                        pred_fake - pred_real.mean(0, keepdim=True), valid)

                    # Content loss on VGG features
                    gen_features = self.feature_extractor(gen_hr)
                    real_features = self.feature_extractor(imgs_hr).detach()
                    loss_content = self.criterion_content(gen_features, real_features)

                    # Total generator loss
                    loss_G = loss_content + opt.lambda_adv * loss_GAN + opt.lambda_pixel * loss_pixel
                    loss_G.backward()
                    self.optimizer_G.step()

                    # ---------------------
                    #  Train Discriminator
                    # ---------------------
                    self.optimizer_D.zero_grad()

                    pred_real = self.discriminator(imgs_hr)
                    pred_fake = self.discriminator(gen_hr.detach())

                    # Adversarial loss for real and fake images
                    # (relativistic average GAN)
                    loss_real = self.criterion_GAN(
                        pred_real - pred_fake.mean(0, keepdim=True), valid)
                    loss_fake = self.criterion_GAN(
                        pred_fake - pred_real.mean(0, keepdim=True), fake)

                    # Total loss
                    loss_D = (loss_real + loss_fake) / 2
                    loss_D.backward()
                    self.optimizer_D.step()

                    # --------------
                    #  Log Progress
                    # --------------
                    log_info = "[Epoch {}/{}] [Batch {}/{}] [D loss: {}] [G loss: {}, content: {}, adv: {}, pixel: {}]".format(
                        epoch, opt.n_epochs, batch_num, len(dataloader),
                        loss_D.item(), loss_G.item(), loss_content.item(),
                        loss_GAN.item(), loss_pixel.item(),
                    )
                    if batch_num == 1:
                        sys.stdout.write("\n{}".format(log_info))
                    else:
                        sys.stdout.write("\r{}".format(log_info))
                    sys.stdout.flush()

                    if batches_done % opt.sample_interval == 0:
                        # Save image grid with upsampled inputs and ESRGAN outputs
                        imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
                        img_grid = denormalize(torch.cat((imgs_lr, gen_hr), -1))
                        image_batch_save_dir = osp.join(
                            image_train_save_dir, '{:07}'.format(batches_done))
                        os.makedirs(osp.join(image_batch_save_dir, "hr_image"),
                                    exist_ok=True)
                        save_image(img_grid,
                                   osp.join(image_batch_save_dir, "hr_image",
                                            "%d.png" % batches_done),
                                   nrow=1, normalize=False)

                    if batches_done % opt.checkpoint_interval == 0:
                        # Save model checkpoints
                        torch.save(
                            self.generator.state_dict(),
                            osp.join(weight_save_dir, "generator_%d.pth" % epoch))
                        torch.save(
                            self.discriminator.state_dict(),
                            osp.join(weight_save_dir, "discriminator_%d.pth" % epoch))

                    mlflow.log_metric('train_{}'.format('loss_D'),
                                      loss_D.item(), step=batches_done)
                    mlflow.log_metric('train_{}'.format('loss_G'),
                                      loss_G.item(), step=batches_done)
                    mlflow.log_metric('train_{}'.format('loss_content'),
                                      loss_content.item(), step=batches_done)
                    mlflow.log_metric('train_{}'.format('loss_GAN'),
                                      loss_GAN.item(), step=batches_done)
                    mlflow.log_metric('train_{}'.format('loss_pixel'),
                                      loss_pixel.item(), step=batches_done)
def down_and_up_sampling(image, save_name, upsampling):
    """Downscale a PIL image by opt.upSampling, super-resolve it back with
    the trained SRGAN generator, and save real / fake / low-res versions
    under output/ as <save_name>.

    NOTE(review): the ``upsampling`` parameter is never read — the scale
    factor actually used is ``opt.upSampling`` from ``setup()``.
    """
    opt = setup()
    # create output folders (ignore "already exists")
    try:
        os.makedirs('output/high_res_fake')
        os.makedirs('output/high_res_real')
        os.makedirs('output/low_res')
    except OSError:
        pass
    if torch.cuda.is_available() and not opt.cuda:
        print('[WARNING]: You have a CUDA device, so you should probably run with --cuda')
    # PIL .size is (width, height).
    transform = transforms.Compose([transforms.RandomCrop((
        image.size[0], image.size[1])),
        transforms.Pad(padding = 0),
        transforms.ToTensor()])
    normalize = transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                     std = [0.229, 0.224, 0.225])
    # [down sampling] resize to 1/upSampling of the original, then normalize
    scale = transforms.Compose([transforms.ToPILImage(),
                                transforms.Resize((int(image.size[1] / opt.upSampling),
                                                   int(image.size[0] / opt.upSampling))),
                                transforms.Pad(padding=0),
                                transforms.ToTensor(),
                                transforms.Normalize(
                                    mean = [0.485, 0.456, 0.406],
                                    std = [0.229, 0.224, 0.225])])
    # Equivalent to un-normalizing ImageNet (for correct visualization)
    unnormalize = transforms.Normalize(
        mean = [-2.118, -2.036, -1.804],
        std = [4.367, 4.464, 4.444])
    # Dataset/dataloader are built but unused below (only ``image`` is processed).
    if opt.dataset == 'folder':
        # folder dataset
        dataset = datasets.ImageFolder(root = opt.dataroot, transform = transform)
    elif opt.dataset == 'cifar10':
        dataset = datasets.CIFAR10(root = opt.dataroot, download = True,
                                   train = False, transform = transform)
    elif opt.dataset == 'cifar100':
        dataset = datasets.CIFAR100(root = opt.dataroot, download = True,
                                    train = False, transform = transform)
    assert dataset
    dataloader = torch.utils.data.DataLoader(dataset, batch_size = opt.batchSize,
                                             shuffle = False,
                                             num_workers = int(opt.workers))
    my_loader = transforms.Compose([transforms.ToTensor()])
    image = my_loader(image)
    # [paras] loading paras from .pth files
    generator = Generator(16, opt.upSampling)
    if opt.generatorWeights != '':
        generator.load_state_dict(torch.load(opt.generatorWeights))
    discriminator = Discriminator()
    if opt.discriminatorWeights != '':
        discriminator.load_state_dict(torch.load(opt.discriminatorWeights))
    # For the content loss
    feature_extractor = FeatureExtractor(torchvision.models.vgg19(pretrained = True))
    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()
    target_real = Variable(torch.ones(opt.batchSize, 1))
    target_fake = Variable(torch.zeros(opt.batchSize, 1))
    # if gpu is to be used
    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        target_real = target_real.cuda()
        target_fake = target_fake.cuda()
    low_res = torch.FloatTensor(opt.batchSize, 3, opt.imageSize[0], opt.imageSize[1])
    # print('Test started...')
    # Loss accumulators (unused in this single-image path).
    mean_generator_content_loss = 0.0
    mean_generator_adversarial_loss = 0.0
    mean_generator_total_loss = 0.0
    mean_discriminator_loss = 0.0
    # Set evaluation mode (not training)
    generator.eval()
    discriminator.eval()
    data = image
    for i in range(1):
        # Generate data
        high_res_real = data
        # Downsample to low resolution, then wrap as a 1-image batch.
        low_res = scale(high_res_real)
        low_res = torch.tensor([np.array(low_res)])
        high_res_real = normalize(high_res_real)
        high_res_real = torch.tensor([np.array(high_res_real)])
        # Generate real and fake inputs
        if opt.cuda:
            high_res_real = Variable(high_res_real.cuda())
            high_res_fake = generator(Variable(low_res).cuda())
        else:
            high_res_real = Variable(high_res_real)
            high_res_fake = generator(Variable(low_res))
        # >>> create hr images
        save_image(unnormalize(high_res_real[0]), 'output/high_res_real/' + save_name)
        save_image(unnormalize(high_res_fake[0]), 'output/high_res_fake/' + save_name)
        save_image(unnormalize(low_res[0]), 'output/low_res/' + save_name)
help='Pass 1 to load checkpoint') parser.add_argument('--b', default=16, type=int, help='number of residual blocks in generator') args = parser.parse_args() # Load data dataset = TrainDataset(args.root_dir) dataloader = DataLoader(dataset, args.batch_size, True, num_workers=args.num_workers) # Initialize models vgg = models.vgg19(pretrained=True) feature_extractor = FeatureExtractor(vgg, 5, 4) if torch.cuda.device_count() > 1: feature_extractor = nn.DataParallel(feature_extractor) feature_extractor = feature_extractor.to(device) disc = Discriminator() if torch.cuda.device_count() > 1: disc = nn.DataParallel(disc) disc = disc.to(device) if args.load_checkpoint == 1 and os.path.exists('disc.pt'): disc.load_state_dict(torch.load('disc.pt')) print(disc) gen = Generator(args.b) if torch.cuda.device_count() > 1: gen = nn.DataParallel(gen)
transforms.ToTensor(), transforms.Normalize([0.5], [0.5]) ]) source_dataset = ImageFolder(args.source, transform=transform_source) target_dataset = ImageFolder(args.target, transform=transform_target) source_loader = DataLoader(source_dataset, batch_size=args.batch_size, shuffle=True) target_loader = DataLoader(target_dataset, batch_size=args.batch_size, shuffle=True) # models F = FeatureExtractor(resnet=args.resnet_type).to(device) C = LabelPredictor(resnet=args.resnet_type).to(device) D = DomainClassifier(resnet=args.resnet_type).to(device) class_criterion = nn.CrossEntropyLoss() domain_criterion = nn.BCEWithLogitsLoss() opt_F = optim.AdamW(F.parameters()) opt_C = optim.AdamW(C.parameters()) opt_D = optim.AdamW(D.parameters()) # train F.train() D.train() C.train() lamb, p, gamma, now, tot = 0, 0, 10, 0, len(source_loader) * args.n_epoch
seed = random.randint(1, 10000) print("Random Seed: ", seed) torch.manual_seed(seed) if opt.cuda: torch.cuda.manual_seed(seed) # build network print('==>building network...') generator = Generator(in_nc=opt.in_nc, mid_nc=opt.mid_nc, out_nc=opt.out_nc, scale_factor=opt.scale_factor, num_RRDBS=opt.num_RRDBs) discriminator = Discriminator() feature_extractor = FeatureExtractor() # loss # content loss if opt.content_loss_type == 'L1_Charbonnier': content_loss = L1_Charbonnier_loss() elif opt.content_loss_type == 'L1': content_loss = torch.nn.L1Loss() elif opt.content_loss_type == 'L2': content_loss = torch.nn.MSELoss() # pixel loss if opt.pixel_loss_type == 'L1': pixel_loss = torch.nn.L1Loss() elif opt.pixel_loss_type == 'L2':
def main():
    """Train SRGAN end to end.

    Phase 1 pre-trains the generator with a raw MSE loss; phase 2 runs
    adversarial training with a VGG19 perceptual (content) loss plus a BCE
    adversarial loss, using label smoothing on the discriminator targets.
    Checkpoints are written to --out; scalars go to tensorboard_logger.

    NOTE(review): uses the legacy `Variable` API and `transforms.Scale`
    (deprecated in favor of `transforms.Resize`) — presumably written for an
    old (<0.4) PyTorch; confirm the target torch/torchvision versions.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='cifar100', help='cifar10 | cifar100 | folder')
    parser.add_argument('--dataroot', type=str, default='./data', help='path to dataset')
    parser.add_argument('--workers', type=int, default=2, help='number of data loading workers')
    parser.add_argument('--batchSize', type=int, default=16, help='input batch size')
    parser.add_argument('--imageSize', type=int, default=15, help='the low resolution image size')
    parser.add_argument('--upSampling', type=int, default=2, help='low to high resolution scaling factor')
    parser.add_argument('--nEpochs', type=int, default=100, help='number of epochs to train for')
    parser.add_argument('--nPreEpochs', type=int, default=2, help='number of epochs to pre-train Generator')
    parser.add_argument('--generatorLR', type=float, default=0.0001, help='learning rate for generator')
    parser.add_argument('--discriminatorLR', type=float, default=0.0001, help='learning rate for discriminator')
    parser.add_argument('--cuda', action='store_true', help='enables cuda')
    parser.add_argument('--nGPU', type=int, default=1, help='number of GPUs to use')
    parser.add_argument('--generatorWeights', type=str, default='', help="path to generator weights (to continue training)")
    parser.add_argument('--discriminatorWeights', type=str, default='', help="path to discriminator weights (to continue training)")
    parser.add_argument('--out', type=str, default='checkpoints', help='folder to output model checkpoints')
    opt = parser.parse_args()
    print(opt)

    # Best-effort creation of the checkpoint directory (already existing is fine).
    try:
        os.makedirs(opt.out)
    except OSError:
        pass

    if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    # HR crop is imageSize * upSampling; `scale` downsamples back to imageSize
    # to produce the LR input, and both paths are ImageNet-normalized.
    transform = transforms.Compose([transforms.RandomCrop(opt.imageSize*opt.upSampling),
                                    transforms.ToTensor()])
    normalize = transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                     std = [0.229, 0.224, 0.225])
    scale = transforms.Compose([transforms.ToPILImage(),
                                transforms.Scale(opt.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                                     std = [0.229, 0.224, 0.225])
                                ])

    if opt.dataset == 'folder':
        # folder dataset
        dataset = datasets.ImageFolder(root=opt.dataroot, transform=transform)
    elif opt.dataset == 'cifar10':
        dataset = datasets.CIFAR10(root=opt.dataroot, train=True, download=True, transform=transform)
    elif opt.dataset == 'cifar100':
        dataset = datasets.CIFAR100(root=opt.dataroot, train=True, download=True, transform=transform)
    assert dataset

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                             shuffle=True, num_workers=int(opt.workers))

    generator = Generator(16, opt.upSampling)
    if opt.generatorWeights != '':
        generator.load_state_dict(torch.load(opt.generatorWeights))
    print(generator)

    discriminator = Discriminator()
    if opt.discriminatorWeights != '':
        discriminator.load_state_dict(torch.load(opt.discriminatorWeights))
    print(discriminator)

    # For the content loss
    feature_extractor = FeatureExtractor(torchvision.models.vgg19(pretrained=True))
    print(feature_extractor)
    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()

    # All-ones target: used as the "real" label when training the generator.
    ones_const = Variable(torch.ones(opt.batchSize, 1))

    # if gpu is to be used
    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        ones_const = ones_const.cuda()

    optim_generator = optim.Adam(generator.parameters(), lr=opt.generatorLR)
    optim_discriminator = optim.Adam(discriminator.parameters(), lr=opt.discriminatorLR)

    configure('logs/' + opt.dataset + '-' + str(opt.batchSize) + '-' + str(opt.generatorLR) + '-' + str(opt.discriminatorLR),
              flush_secs=5)
    visualizer = Visualizer(image_size=opt.imageSize*opt.upSampling)

    # Reusable CPU buffer for one batch of low-resolution inputs.
    low_res = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)

    # Pre-train generator using raw MSE loss
    print('Generator pre-training')
    for epoch in range(opt.nPreEpochs):
        mean_generator_content_loss = 0.0

        for i, data in enumerate(dataloader):
            # Generate data
            high_res_real, _ = data

            # Downsample images to low resolution
            if len(high_res_real) < opt.batchSize:  # skip final batch , len = batchsize if not last batch else len < batchsize
                continue
            for j in range(opt.batchSize):
                low_res[j] = scale(high_res_real[j])
                high_res_real[j] = normalize(high_res_real[j])

            # Generate real and fake inputs
            if opt.cuda:
                high_res_real = Variable(high_res_real.cuda())
                high_res_fake = generator(Variable(low_res).cuda())
            else:
                high_res_real = Variable(high_res_real)
                high_res_fake = generator(Variable(low_res))

            ######### Train generator #########
            generator.zero_grad()

            generator_content_loss = content_criterion(high_res_fake, high_res_real)
            mean_generator_content_loss += generator_content_loss.data

            generator_content_loss.backward()
            optim_generator.step()

            ######### Status and display #########
            sys.stdout.write('\r[%d/%d][%d/%d] Generator_MSE_Loss: %.4f' % (epoch, opt.nPreEpochs, i, len(dataloader), generator_content_loss.data))
            visualizer.show(low_res, high_res_real.cpu().data, high_res_fake.cpu().data)

        # NOTE(review): the total-epochs field is hard-coded to 2 here instead
        # of opt.nPreEpochs — the progress line is wrong for other settings.
        sys.stdout.write('\r[%d/%d][%d/%d] Generator_MSE_Loss: %.4f\n' % (epoch, 2, i, len(dataloader), mean_generator_content_loss/len(dataloader)))
        log_value('generator_mse_loss', mean_generator_content_loss/len(dataloader), epoch)
        # Do checkpointing every epoch
        # torch.save(generator.state_dict(), '%s/generator_pretrain_%s.pth' %(opt.out,str(epoch)))

    # Do checkpointing
    torch.save(generator.state_dict(), '%s/generator_pretrain.pth' % opt.out)

    # SRGAN training
    # Fresh optimizers at 1/10th the pre-training learning rates.
    optim_generator = optim.Adam(generator.parameters(), lr=opt.generatorLR*0.1)
    optim_discriminator = optim.Adam(discriminator.parameters(), lr=opt.discriminatorLR*0.1)

    print('SRGAN training')
    for epoch in range(opt.nEpochs):
        mean_generator_content_loss = 0.0
        mean_generator_adversarial_loss = 0.0
        mean_generator_total_loss = 0.0
        mean_discriminator_loss = 0.0

        for i, data in enumerate(dataloader):
            # Generate data
            high_res_real, _ = data

            # Downsample images to low resolution
            if len(high_res_real) < opt.batchSize:  # skip final batch , len = batchsize if not last batch else len < batchsize
                continue
            for j in range(opt.batchSize):
                low_res[j] = scale(high_res_real[j])
                high_res_real[j] = normalize(high_res_real[j])

            # Generate real and fake inputs (with label smoothing on the targets)
            if opt.cuda:
                high_res_real = Variable(high_res_real.cuda())
                high_res_fake = generator(Variable(low_res).cuda())
                target_real = Variable(torch.rand(opt.batchSize,1)*0.5 + 0.7).cuda()  # size: opt.batchSize*1, and element is in 0.7~1.2
                target_fake = Variable(torch.rand(opt.batchSize,1)*0.3).cuda()  # size: opt.batchSize*1, and element is in 0~0.3
            else:
                high_res_real = Variable(high_res_real)
                high_res_fake = generator(Variable(low_res))
                target_real = Variable(torch.rand(opt.batchSize,1)*0.5 + 0.7)
                target_fake = Variable(torch.rand(opt.batchSize,1)*0.3)

            ######### Train discriminator #########
            discriminator.zero_grad()

            # Variable(high_res_fake.data): detaches the fake image so the
            # discriminator loss does not backprop into the generator.
            discriminator_loss = adversarial_criterion(discriminator(high_res_real), target_real) + \
                                 adversarial_criterion(discriminator(Variable(high_res_fake.data)), target_fake)
            mean_discriminator_loss += discriminator_loss.data

            discriminator_loss.backward()
            optim_discriminator.step()

            ######### Train generator #########
            generator.zero_grad()

            # Real features are detached (taken as a fixed target).
            real_features = Variable(feature_extractor(high_res_real).data)
            fake_features = feature_extractor(high_res_fake)

            # for content loss, we use total images' pixel-wise MSE loss and 0.006* VggLoss, which VggLoss is actual
            # MSE loss of some layers result(feature) in VggNet
            generator_content_loss = content_criterion(high_res_fake, high_res_real) + 0.006*content_criterion(fake_features, real_features)
            mean_generator_content_loss += generator_content_loss.data
            generator_adversarial_loss = adversarial_criterion(discriminator(high_res_fake), ones_const)
            mean_generator_adversarial_loss += generator_adversarial_loss.data

            generator_total_loss = generator_content_loss + 1e-3*generator_adversarial_loss
            mean_generator_total_loss += generator_total_loss.data

            generator_total_loss.backward()
            optim_generator.step()

            ######### Status and display #########
            sys.stdout.write('\r[%d/%d][%d/%d] Discriminator_Loss: %.4f Generator_Loss (Content/Advers/Total): %.4f/%.4f/%.4f' % (epoch, opt.nEpochs, i, len(dataloader),
            discriminator_loss.data, generator_content_loss.data, generator_adversarial_loss.data, generator_total_loss.data))
            visualizer.show(low_res, high_res_real.cpu().data, high_res_fake.cpu().data)

        sys.stdout.write('\r[%d/%d][%d/%d] Discriminator_Loss: %.4f Generator_Loss (Content/Advers/Total): %.4f/%.4f/%.4f\n' % (epoch, opt.nEpochs, i, len(dataloader),
        mean_discriminator_loss/len(dataloader), mean_generator_content_loss/len(dataloader),
        mean_generator_adversarial_loss/len(dataloader), mean_generator_total_loss/len(dataloader)))

        log_value('generator_content_loss', mean_generator_content_loss/len(dataloader), epoch)
        log_value('generator_adversarial_loss', mean_generator_adversarial_loss/len(dataloader), epoch)
        log_value('generator_total_loss', mean_generator_total_loss/len(dataloader), epoch)
        log_value('discriminator_loss', mean_discriminator_loss/len(dataloader), epoch)

        # Do checkpointing every epoch
        torch.save(generator.state_dict(), '%s/generator_final.pth' % opt.out)
        torch.save(discriminator.state_dict(), '%s/discriminator_final.pth' % opt.out)

    # Avoid closing
    print("train is over, and here can kill off threading after you watch the control log...")
    # NOTE(review): deliberate busy-wait to keep the process (and its logs)
    # alive — burns a CPU core; a sleep loop or input() would be gentler.
    while True:
        pass
def main(args):
    """TextureGAN-style training driver.

    Builds data loaders, a pretrained VGG-19 feature model, the generator,
    global discriminator, local (patch) discriminator and their criteria and
    optimizers, then runs `train(...)` once per epoch.  Loss histories are
    plotted through a visdom server on args.display_port.
    """
    #with torch.cuda.device(args.gpu):
    # VGG-19 `features` module indices for the named relu layers used by the
    # content/style extractors below.
    layers_map = {
        'relu4_2': '22',
        'relu2_2': '8',
        'relu3_2': '13',
        'relu1_2': '4'
    }

    vis = visdom.Visdom(port=args.display_port)
    # Per-component loss history, keyed by short name, filled in by train().
    loss_graph = {
        "g": [],
        "gd": [],
        "gf": [],
        "gpl": [],
        "gpab": [],
        "gs": [],
        "d": [],
        "gdl": [],
        "dl": [],
    }

    # for rgb the change is to feed 3 channels to D instead of just 1. and feed 3 channels to vgg.
    # can leave pixel separate between r and gb for now. assume user use the same weights
    transforms = get_transforms(args)

    if args.color_space == 'rgb':
        # In RGB mode both pixel-loss weights collapse to the single RGB weight.
        args.pixel_weight_ab = args.pixel_weight_rgb
        args.pixel_weight_l = args.pixel_weight_rgb

    # NOTE(review): `rgbify` is never used below — confirm whether train()
    # was meant to receive it.
    rgbify = custom_transforms.toRGB()

    train_dataset = ImageFolder('train', args.data_path, transforms)
    train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True)

    # Fixed random subset of validation images, used only for display.
    val_dataset = ImageFolder('val', args.data_path, transforms)
    indices = torch.randperm(len(val_dataset))
    val_display_size = args.batch_size
    val_display_sampler = SequentialSampler(indices[:val_display_size])
    val_loader = DataLoader(dataset=val_dataset, batch_size=val_display_size, sampler=val_display_sampler)

    # renormalize = transforms.Normalize(mean=[+0.5+0.485, +0.5+0.456, +0.5+0.406], std=[0.229, 0.224, 0.225])

    feat_model = models.vgg19(pretrained=True)
    netG, netD, netD_local = get_models(args)
    criterion_gan, criterion_pixel_l, criterion_pixel_ab, criterion_style, criterion_feat, criterion_texturegan = get_criterions(
        args)

    real_label = 1
    fake_label = 0

    optimizerD = optim.Adam(netD.parameters(), lr=args.learning_rate_D, betas=(0.5, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=args.learning_rate, betas=(0.5, 0.999))
    optimizerD_local = optim.Adam(netD_local.parameters(), lr=args.learning_rate_D_local, betas=(0.5, 0.999))

    with torch.cuda.device(args.gpu):
        netG.cuda()
        netD.cuda()
        netD_local.cuda()
        feat_model.cuda()
        criterion_gan.cuda()
        criterion_pixel_l.cuda()
        criterion_pixel_ab.cuda()
        criterion_feat.cuda()
        # NOTE(review): criterion_style is never moved to the GPU here, unlike
        # its siblings — confirm whether that is intentional.
        criterion_texturegan.cuda()

        # Empty GPU tensors that train() fills in-place each batch.
        input_stack = torch.FloatTensor().cuda()
        target_img = torch.FloatTensor().cuda()
        target_texture = torch.FloatTensor().cuda()
        segment = torch.FloatTensor().cuda()
        label = torch.FloatTensor(args.batch_size).cuda()
        label_local = torch.FloatTensor(args.batch_size).cuda()

        # Content uses a single layer; style may use several (comma-separated).
        extract_content = FeatureExtractor(feat_model.features, [layers_map[args.content_layers]])
        extract_style = FeatureExtractor(
            feat_model.features, [layers_map[x.strip()] for x in args.style_layers.split(',')])

        model = {
            "netG": netG,
            "netD": netD,
            "netD_local": netD_local,
            "criterion_gan": criterion_gan,
            "criterion_pixel_l": criterion_pixel_l,
            "criterion_pixel_ab": criterion_pixel_ab,
            "criterion_feat": criterion_feat,
            "criterion_style": criterion_style,
            "criterion_texturegan": criterion_texturegan,
            "real_label": real_label,
            "fake_label": fake_label,
            "optimizerD": optimizerD,
            "optimizerD_local": optimizerD_local,
            "optimizerG": optimizerG
        }

        # Resume support: epochs start at args.load_epoch.
        for epoch in range(args.load_epoch, args.num_epoch):
            train(model, train_loader, val_loader, input_stack, target_img,
                  target_texture, segment, label, label_local, extract_content,
                  extract_style, loss_graph, vis, epoch, args)
model_D = network.discriminator_snIns().cuda() #model_local_D = SNnetwork.Discriminator(3, 64).cuda() model_local_D = network.discriminator_snIns().cuda() elif network_type == 'nlayerD': model_G = network.generator().cuda() model_D = network.NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).cuda() model_local_D = network.NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).cuda() else: model_G = network.generator().cuda() model_D = network.discriminator().cuda() model_local_D = network.discriminator().cuda() feat_model = tmodels.vgg19(pretrained=True).cuda() extract_content = FeatureExtractor( feat_model.features, [layers_map[x.strip()] for x in content_layers.split(',')]) extract_style = FeatureExtractor( feat_model.features, [layers_map[x.strip()] for x in style_layers.split(',')]) # loss criterion BCE_loss = nn.BCELoss().cuda() MSE_loss = nn.MSELoss().cuda() TV_loss = TVLoss().cuda() criterion = nn.L1Loss().cuda() # Adam optimizer #G_optimizer = torch.optim.Adam(model_G.parameters(), lr=learning_rate, weight_decay=1e-5) G_optimizer = torch.optim.Adam(model_G.parameters(), lr=learning_rate,
batch_size=opt.batchSize, shuffle=True, num_workers=int(opt.workers)) generator = Generator(16, opt.upSampling) if opt.generatorWeights != '': generator.load_state_dict(torch.load(opt.generatorWeights)) print generator discriminator = Discriminator() if opt.discriminatorWeights != '': discriminator.load_state_dict(torch.load(opt.discriminatorWeights)) print discriminator # For the content loss feature_extractor = FeatureExtractor(torchvision.models.vgg19(pretrained=True)) print feature_extractor content_criterion = nn.MSELoss() adversarial_criterion = nn.BCELoss() ones_const = Variable(torch.ones(opt.batchSize, 1)) # if gpu is to be used if opt.cuda: generator.cuda() discriminator.cuda() feature_extractor.cuda() content_criterion.cuda() adversarial_criterion.cuda() ones_const = ones_const.cuda()
# ESRGAN-style training setup: DIV2K loaders, generator (resumed from a
# pre-trained checkpoint), discriminator, frozen perceptual feature extractor,
# losses, and tensorboard logging.
# NOTE(review): fragment from a larger script — top-level indentation assumed.
train_dataset = TrainDatasetFromFolder('data/DIV2K_train_HR/Train_HR', crop_size=opt.crop_size,
                                       upscale_factor=opt.upSampling)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=opt.batchSize, shuffle=True, num_workers=4)
val_dataset = ValDatasetFromFolder('data/DIV2K_valid_HR/Val_HR', upscale_factor=opt.upSampling)
# Batch loader: pulls one batch of files at a time from the validation set.
val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=1, shuffle=False, num_workers=4)

generator = Generator(3, filters=64, num_res_blocks=opt.residual_blocks, up_scale=opt.upSampling).to(device)
# load pretrain model
checkpoint = torch.load(opt.generator_pretrainWeights)
generator.load_state_dict(checkpoint['generator_model_pre'])
print('Load Generator pre successfully!')

discriminator = Discriminator(in_channels=3, out_filters=64).to(device)
feature_extractor = FeatureExtractor().to(device)
# The feature extractor is only used to compute the perceptual loss; eval mode.
feature_extractor.eval()

# Content loss and adversarial loss
criterion_pixel = torch.nn.L1Loss().to(device)  # absolute pixel difference
content_criterion = torch.nn.L1Loss().to(device)
adversarial_criterion = torch.nn.BCEWithLogitsLoss().to(device)  # cross entropy

Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor

# tensorboard --logdir=logs
# NOTE(review): the batch-size slot of the log path is hard-coded to 64
# rather than opt.batchSize — confirm intent.
configure(
    'logs/' + opt.train_dataroot + '-' + str(64) + '-' + str(opt.generatorLR) + '-' + str(opt.discriminatorLR),
    flush_secs=5)
batch_size=opt.batchSize, shuffle=True, num_workers=int(opt.workers)) G = Generator(10, opt.upSampling) if opt.generatorWeights != '': G.load_state_dict(torch.load(opt.generatorWeights)) print(G) D = Discriminator() if opt.discriminatorWeights != '': D.load_state_dict(torch.load(opt.discriminatorWeights)) print(D) # For the content loss FE = FeatureExtractor(torchvision.models.vgg19(pretrained=True)) print(FE) content_criterion = nn.MSELoss() adversarial_criterion = nn.BCELoss() ones_const = Variable(torch.ones(opt.batchSize, 1)) # if gpu is to be used if opt.cuda: G.cuda() D.cuda() FE.cuda() content_criterion.cuda() adversarial_criterion.cuda() ones_const = ones_const.cuda()
# NOTE(review): this chunk starts mid-function — the two indented lines below
# are the tail of a definition whose header is outside this view (`d`, `y`,
# and `logloss` come from that missing context).
    loss = logloss(d.unsqueeze(1), y)
    return loss


def get_sync_loss(mel, g):
    """Sync loss between a mel-spectrogram chunk and generated frames.

    Slices away half of the last spatial dimension of `g` (presumably keeping
    the mouth region — confirm against the frame layout), folds the
    `syncnet_T` time steps into the channel dimension, and scores the pair
    with the pretrained `syncnet` via `cosine_loss` against an all-ones
    target (audio and video should match).
    """
    g = g[:, :, :, g.size(3) // 2:]
    g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1)  # B, 3 * T, H//2, W
    a, v = syncnet(mel, g)
    y = torch.ones(g.size(0), 1).float().to(device)
    return cosine_loss(a, v, y)


# Perceptual (content) loss components; the extractor is frozen in eval mode.
recon_loss = nn.L1Loss()
feature_extractor = FeatureExtractor()
feature_extractor.eval()


# --------- Add content loss here ---------------
def get_content_loss(g, gt):
    """L1 distance between extractor features of generated `g` and ground-truth `gt`."""
    # NOTE(review): local name "gen_feautres" is a typo kept as-is (renaming
    # would be a code change, not a doc change).
    gen_feautres = feature_extractor(g)
    real_features = feature_extractor(gt)
    # detach(): the ground-truth features are a fixed target; no grads flow into them.
    loss_content = recon_loss(gen_feautres, real_features.detach())
    return loss_content


# NOTE(review): chunk ends mid-definition — the rest of train()'s signature
# and body lies outside this view.
def train(device, model,
def main():
    """Evaluate a trained SRGAN on the test split.

    Loads generator/discriminator weights, runs every batch through both
    networks, accumulates discriminator / content / adversarial / total
    losses, and writes the low-res inputs plus real and generated high-res
    images (un-normalized for viewing) under output/.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='folder', help='cifar10 | cifar100 | folder')
    parser.add_argument('--dataroot', type=str, default='./data', help='path to dataset')
    parser.add_argument('--workers', type=int, default=1, help='number of data loading workers')
    parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
    parser.add_argument('--imageSize', type=int, default=32, help='the low resolution image size')
    parser.add_argument('--upSampling', type=int, default=4, help='low to high resolution scaling factor')
    parser.add_argument('--cuda', action='store_true', help='enables cuda')
    parser.add_argument('--nGPU', type=int, default=1, help='number of GPUs to use')
    parser.add_argument(
        '--generatorWeights',
        type=str,
        default='checkpoints/generator_final.pth',
        help="path to generator weights (to continue training)")
    parser.add_argument(
        '--discriminatorWeights',
        type=str,
        default='checkpoints/discriminator_final.pth',
        help="path to discriminator weights (to continue training)")
    opt = parser.parse_args()
    print(opt)

    # Output folders for the three image streams.
    if not os.path.exists('output/high_res_fake'):
        os.makedirs('output/high_res_fake')
    if not os.path.exists('output/high_res_real'):
        os.makedirs('output/high_res_real')
    if not os.path.exists('output/low_res'):
        os.makedirs('output/low_res')

    if torch.cuda.is_available() and not opt.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    # HR crop is imageSize * upSampling; `scale` downsamples to the LR input.
    # NOTE(review): transforms.Scale is deprecated (use Resize on newer torchvision).
    transform = transforms.Compose([
        transforms.RandomCrop(opt.imageSize * opt.upSampling),
        transforms.ToTensor()
    ])
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    scale = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Scale(opt.imageSize),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    # Equivalent to un-normalizing ImageNet (for correct visualization)
    unnormalize = transforms.Normalize(mean=[-2.118, -2.036, -1.804],
                                       std=[4.367, 4.464, 4.444])

    if opt.dataset == 'folder':
        # folder dataset
        dataset = datasets.ImageFolder(root=opt.dataroot, transform=transform)
    elif opt.dataset == 'cifar10':
        dataset = datasets.CIFAR10(root=opt.dataroot,
                                   download=True,
                                   train=False,
                                   transform=transform)
    elif opt.dataset == 'cifar100':
        dataset = datasets.CIFAR100(root=opt.dataroot,
                                    download=True,
                                    train=False,
                                    transform=transform)
    assert dataset

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batchSize,
                                             shuffle=False,
                                             num_workers=int(opt.workers))

    generator = Generator(16, opt.upSampling)
    if opt.generatorWeights != '':
        generator.load_state_dict(torch.load(opt.generatorWeights))
    print(generator)

    discriminator = Discriminator()
    if opt.discriminatorWeights != '':
        discriminator.load_state_dict(torch.load(opt.discriminatorWeights))
    print(discriminator)

    # For the content loss
    feature_extractor = FeatureExtractor(
        torchvision.models.vgg19(pretrained=True))
    print(feature_extractor)
    content_criterion = nn.MSELoss()
    adversarial_criterion = nn.BCELoss()

    # Hard labels at test time (no label smoothing).
    target_real = Variable(torch.ones(opt.batchSize, 1))
    target_fake = Variable(torch.zeros(opt.batchSize, 1))

    # if gpu is to be used
    if opt.cuda:
        generator.cuda()
        discriminator.cuda()
        feature_extractor.cuda()
        content_criterion.cuda()
        adversarial_criterion.cuda()
        target_real = target_real.cuda()
        target_fake = target_fake.cuda()

    # Reusable CPU buffer for one batch of low-resolution inputs.
    low_res = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)

    print('Test started...')
    mean_generator_content_loss = 0.0
    mean_generator_adversarial_loss = 0.0
    mean_generator_total_loss = 0.0
    mean_discriminator_loss = 0.0

    # Set evaluation mode (not training)
    generator.eval()
    discriminator.eval()

    for i, data in enumerate(dataloader):
        # Generate data
        high_res_real, _ = data

        # Downsample images to low resolution
        if len(
                high_res_real
        ) < opt.batchSize:  # skip final batch , len = batchsize if not last batch else len < batchsize
            continue
        for j in range(opt.batchSize):
            low_res[j] = scale(high_res_real[j])
            high_res_real[j] = normalize(high_res_real[j])

        # Generate real and fake inputs
        if opt.cuda:
            high_res_real = Variable(high_res_real.cuda())
            high_res_fake = generator(Variable(low_res).cuda())
        else:
            high_res_real = Variable(high_res_real)
            high_res_fake = generator(Variable(low_res))

        ######### Test discriminator #########
        discriminator_loss = adversarial_criterion(discriminator(high_res_real), target_real) + \
            adversarial_criterion(discriminator(high_res_fake), target_fake)
        mean_discriminator_loss += discriminator_loss.data

        ######### Test generator #########
        real_features = feature_extractor(high_res_real)
        fake_features = feature_extractor(high_res_fake)

        generator_content_loss = content_criterion(
            high_res_fake, high_res_real) + 0.006 * content_criterion(
                fake_features, real_features)
        mean_generator_content_loss += generator_content_loss.data
        generator_adversarial_loss = adversarial_criterion(
            discriminator(high_res_fake), target_real)
        mean_generator_adversarial_loss += generator_adversarial_loss.data
        generator_total_loss = generator_content_loss + 1e-3 * generator_adversarial_loss
        mean_generator_total_loss += generator_total_loss.data

        ######### Status and display #########
        sys.stdout.write(
            '\r[%d/%d] Discriminator_Loss: %.4f Generator_Loss (Content/Advers/Total): %.4f/%.4f/%.4f'
            % (i, len(dataloader), discriminator_loss.data,
               generator_content_loss.data, generator_adversarial_loss.data,
               generator_total_loss.data))

        # NOTE(review): redundant — the same short-batch check already ran
        # at the top of this loop iteration and `continue`d.
        if len(
                high_res_real
        ) < opt.batchSize:  # skip final batch , len = batchsize if not last batch else len < batchsize
            continue
        for j in range(opt.batchSize):
            save_image(
                unnormalize(high_res_real[j].cpu()),
                'output/high_res_real/' + str(i * opt.batchSize + j) + '.png')
            save_image(
                unnormalize(high_res_fake[j].cpu()),
                'output/high_res_fake/' + str(i * opt.batchSize + j) + '.png')
            #save_image(high_res_real[j], 'output/high_res_real/' + str(i*opt.batchSize + j) + '.png') # without normlize, will mis-color real
            #save_image(high_res_fake[j], 'output/high_res_fake/' + str(i*opt.batchSize + j) + '.png')
            save_image(unnormalize(low_res[j]),
                       'output/low_res/' + str(i * opt.batchSize + j) + '.png')

    sys.stdout.write(
        '\r[%d/%d] Discriminator_Loss: %.4f Generator_Loss (Content/Advers/Total): %.4f/%.4f/%.4f\n'
        % (i, len(dataloader), mean_discriminator_loss / len(dataloader),
           mean_generator_content_loss / len(dataloader),
           mean_generator_adversarial_loss / len(dataloader),
           mean_generator_total_loss / len(dataloader)))
class SACAgent:
    """Soft Actor-Critic (SAC) agent learning from raw image observations.

    A convolutional feature extractor maps (H, W, C) images to a flat feature
    vector that is shared by a state-value network (plus a Polyak-averaged
    target copy), twin soft-Q networks, and a squashed-Gaussian policy.
    Policy and target-network updates are delayed (every `delay_step`
    gradient steps).
    """

    def __init__(self, env, gamma, tau, v_lr, q_lr, policy_lr, buffer_maxlen):
        """Build networks, optimizers, and the replay buffer.

        Args:
            env: Gym-style environment with a continuous action space.
            gamma: discount factor.
            tau: Polyak averaging coefficient for the target value network.
            v_lr, q_lr, policy_lr: Adam learning rates for the value, Q, and
                policy networks respectively.
            buffer_maxlen: replay buffer capacity.
        """
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.firsttime = 0  # NOTE(review): never read within this class; kept for external compatibility

        self.env = env
        self.action_range = [env.action_space.low, env.action_space.high]
        #self.obs_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]  # 1
        self.conv_channels = 4
        self.kernel_size = (3, 3)
        # Expected raw observation shape (H, W, C); get_action() rejects others.
        self.img_size = (500, 500, 3)

        print("Diagnostics:")
        print(f"action_range: {self.action_range}")
        #print(f"obs_dim: {self.obs_dim}")
        print(f"action_dim: {self.action_dim}")

        # hyperparameters
        self.gamma = gamma
        self.tau = tau
        self.update_step = 0
        self.delay_step = 2

        # initialize networks
        self.feature_net = FeatureExtractor(self.img_size[2],
                                            self.conv_channels,
                                            self.kernel_size).to(self.device)
        print("Feature net init'd successfully")

        # Flattened size of the conv feature map; every head consumes this vector.
        input_dim = self.feature_net.get_output_size(self.img_size)
        self.input_size = input_dim[0] * input_dim[1] * input_dim[2]
        print(f"input_size: {self.input_size}")

        self.value_net = ValueNetwork(self.input_size, 1).to(self.device)
        self.target_value_net = ValueNetwork(self.input_size,
                                             1).to(self.device)
        self.q_net1 = SoftQNetwork(self.input_size,
                                   self.action_dim).to(self.device)
        self.q_net2 = SoftQNetwork(self.input_size,
                                   self.action_dim).to(self.device)
        self.policy_net = PolicyNetwork(self.input_size,
                                        self.action_dim).to(self.device)
        print("Finished initing all nets")

        # Hard-copy value-net weights into the target so both start identical.
        for target_param, param in zip(self.target_value_net.parameters(),
                                       self.value_net.parameters()):
            target_param.data.copy_(param)
        print("Finished copying targets")

        # initialize optimizers
        self.value_optimizer = optim.Adam(self.value_net.parameters(),
                                          lr=v_lr)
        self.q1_optimizer = optim.Adam(self.q_net1.parameters(), lr=q_lr)
        self.q2_optimizer = optim.Adam(self.q_net2.parameters(), lr=q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(),
                                           lr=policy_lr)
        print("Finished initing optimizers")

        self.replay_buffer = BasicBuffer(buffer_maxlen)
        print("End of init")

    def get_action(self, state):
        """Sample one action for a single raw observation.

        Args:
            state: ndarray of shape self.img_size (H, W, C).

        Returns:
            Action rescaled to the environment's action range, or None when
            `state` does not match the expected image shape.
        """
        if state.shape != self.img_size:
            print(
                f"Invalid size, expected shape {self.img_size}, got {state.shape}"
            )
            return None

        # HWC -> NCHW with a singleton batch dimension.
        inp = torch.from_numpy(state).float().permute(2, 0, 1).unsqueeze(0).to(
            self.device)
        features = self.feature_net(inp)
        features = features.view(-1, self.input_size)

        mean, log_std = self.policy_net.forward(features)
        std = log_std.exp()
        normal = Normal(mean, std)
        z = normal.sample()
        # tanh squashes to [-1, 1]; rescale_action maps onto the env range.
        action = torch.tanh(z)
        action = action.cpu().detach().squeeze(0).numpy()

        return self.rescale_action(action)

    def rescale_action(self, action):
        """Affinely map an action from [-1, 1] to [action_low, action_high]."""
        return action * (self.action_range[1] - self.action_range[0]) / 2.0 +\
            (self.action_range[1] + self.action_range[0]) / 2.0

    def update(self, batch_size):
        """Run one SAC gradient step on a replay-buffer sample.

        Updates the value and both Q networks every call; the policy and the
        Polyak-averaged target value network are updated only every
        `delay_step` calls.
        """
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(
            batch_size)

        # states and next states are lists of ndarrays; np.stack converts them
        # to ndarrays of shape (batch_size, height, width, num_channels)
        states = np.stack(states)
        next_states = np.stack(next_states)

        states = torch.FloatTensor(states).permute(0, 3, 1, 2).to(self.device)
        actions = torch.FloatTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).permute(0, 3, 1,
                                                             2).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)
        dones = dones.view(dones.size(0), -1)

        # Process images into flat feature vectors.
        features = self.feature_net(states)
        next_features = self.feature_net(next_states)
        # BUG FIX: the batch dimension was hard-coded to 64
        # (torch.reshape(..., (64, input_size))), which broke every call with
        # batch_size != 64.  Derive it from the tensor instead.
        features = features.view(features.size(0), self.input_size)
        next_features = next_features.view(next_features.size(0),
                                           self.input_size)

        # Value target: min of the twin soft-Qs minus the policy log-prob.
        next_actions, next_log_pi = self.policy_net.sample(next_features)
        next_q1 = self.q_net1(next_features, next_actions)
        next_q2 = self.q_net2(next_features, next_actions)
        next_v = self.target_value_net(next_features)
        next_v_target = torch.min(next_q1, next_q2) - next_log_pi
        curr_v = self.value_net.forward(features)
        v_loss = F.mse_loss(curr_v, next_v_target.detach())

        # q loss: bootstrap from the *target* value network; (1 - dones)
        # zeroes the bootstrap term on terminal transitions.
        expected_q = rewards + (1 - dones) * self.gamma * next_v
        curr_q1 = self.q_net1.forward(features, actions)
        curr_q2 = self.q_net2.forward(features, actions)
        q1_loss = F.mse_loss(curr_q1, expected_q.detach())
        q2_loss = F.mse_loss(curr_q2, expected_q.detach())

        # update value and q networks.
        # retain_graph=True: the shared conv features feed all the losses, so
        # the graph must survive across the successive backward() calls.
        self.value_optimizer.zero_grad()
        v_loss.backward(retain_graph=True)
        self.value_optimizer.step()

        self.q1_optimizer.zero_grad()
        q1_loss.backward(retain_graph=True)
        self.q1_optimizer.step()

        self.q2_optimizer.zero_grad()
        q2_loss.backward(retain_graph=True)
        self.q2_optimizer.step()

        # delayed update for policy network and target q networks
        if self.update_step % self.delay_step == 0:
            new_actions, log_pi = self.policy_net.sample(features)
            min_q = torch.min(self.q_net1.forward(features, new_actions),
                              self.q_net2.forward(features, new_actions))
            policy_loss = (log_pi - min_q).mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward(retain_graph=True)
            self.policy_optimizer.step()

            # target networks: Polyak averaging toward the online value net.
            for target_param, param in zip(
                    self.target_value_net.parameters(),
                    self.value_net.parameters()):
                target_param.data.copy_(self.tau * param +
                                        (1 - self.tau) * target_param)

        self.update_step += 1