def load(self):
    """Restore the latest generator/discriminator/optimizer checkpoint.

    The epoch tag is taken from the last line of ``latest.ckpt`` in the
    save directory; if that file is absent, the newest ``gen_*.pth`` file
    is used instead. When no checkpoint exists at all, the freshly
    initialized networks are kept and a warning is printed on rank 0.
    """
    model_path = self.config['save_dir']
    ckpt_file = os.path.join(model_path, 'latest.ckpt')
    if os.path.isfile(ckpt_file):
        # Last line of latest.ckpt records the most recent epoch tag.
        with open(ckpt_file, 'r') as f:
            latest_epoch = f.read().splitlines()[-1]
    else:
        # Fallback: derive the tag from saved generator checkpoints.
        # (Strip the 'gen_' prefix so the tag can be re-used to build
        # gen_/dis_/opt_ paths below.)
        ckpts = [
            os.path.basename(i).split('.pth')[0].replace('gen_', '', 1)
            for i in glob.glob(os.path.join(model_path, 'gen_*.pth'))
        ]
        ckpts.sort()
        latest_epoch = ckpts[-1] if len(ckpts) > 0 else None
    if latest_epoch is not None:
        # All three files share the same zero-padded epoch tag.
        tag = str(latest_epoch).zfill(5)
        gen_path = os.path.join(model_path, 'gen_{}.pth'.format(tag))
        dis_path = os.path.join(model_path, 'dis_{}.pth'.format(tag))
        opt_path = os.path.join(model_path, 'opt_{}.pth'.format(tag))
        if self.config['global_rank'] == 0:
            print('Loading model from {}...'.format(gen_path))
        data = torch.load(
            gen_path, map_location=lambda storage, loc: set_device(storage))
        self.netG.load_state_dict(data['netG'])
        data = torch.load(
            dis_path, map_location=lambda storage, loc: set_device(storage))
        self.netD.load_state_dict(data['netD'])
        data = torch.load(
            opt_path, map_location=lambda storage, loc: set_device(storage))
        self.optimG.load_state_dict(data['optimG'])
        self.optimD.load_state_dict(data['optimD'])
        # The optimizer checkpoint also carries training progress counters.
        self.epoch = data['epoch']
        self.iteration = data['iteration']
    else:
        if self.config['global_rank'] == 0:
            print(
                'Warning: There is no trained model found. '
                'An initialized model will be used.'
            )
def main_worker(gpu, ngpus_per_node, config):
    """Run inference for one GPU's shard of the test split.

    Loads the latest generator checkpoint, assigns this worker a contiguous
    slice of the dataset (len/ngpus_per_node items), and writes orig / mask /
    pred / comp PNGs into a results directory named after the checkpoint
    epoch and test level. Relies on the module-level ``args`` (model_name,
    level) and ``BATCH_SIZE``.
    """
    torch.cuda.set_device(gpu)
    set_seed(config['seed'])
    # Model and version
    net = importlib.import_module('model.' + args.model_name)
    model = set_device(net.InpaintGenerator())
    # Last line of latest.ckpt names the checkpoint epoch to evaluate.
    with open(os.path.join(config['save_dir'], 'latest.ckpt'), 'r') as f:
        latest_epoch = f.read().splitlines()[-1]
    # NOTE(review): checkpoints elsewhere are addressed with a zfill(5)
    # tag; this assumes latest.ckpt already stores the padded form — confirm.
    path = os.path.join(config['save_dir'],
                        'gen_{}.pth'.format(latest_epoch))
    data = torch.load(path,
                      map_location=lambda storage, loc: set_device(storage))
    model.load_state_dict(data['netG'])
    model.eval()

    # prepare dataset: each GPU processes one contiguous slice of the split
    dataset = Dataset(config['data_loader'], debug=False, split='test',
                      level=args.level)
    step = math.ceil(len(dataset) / ngpus_per_node)
    dataset.set_subset(gpu * step, min(gpu * step + step, len(dataset)))
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=config['trainer']['num_workers'],
                            pin_memory=True)

    path = os.path.join(
        config['save_dir'],
        'results_{}_level_{}'.format(str(latest_epoch).zfill(5),
                                     str(args.level).zfill(2)))
    os.makedirs(path, exist_ok=True)
    # iteration through datasets
    for idx, (images, masks, names) in enumerate(dataloader):
        print('[{}] GPU{} {}/{}: {}'.format(
            datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            gpu, idx, len(dataloader), names[0]))
        images, masks = set_device([images, masks])
        # Hole pixels are replaced by 1 before feeding the model.
        images_masked = images * (1 - masks) + masks
        with torch.no_grad():
            _, output = model(torch.cat((images_masked, masks), dim=1), masks)
        orig_imgs = postprocess(images)
        mask_imgs = postprocess(images_masked)
        comp_imgs = postprocess((1 - masks) * images + masks * output)
        pred_imgs = postprocess(output)
        for i in range(len(orig_imgs)):
            # One file stem per sample; extension of the source name dropped.
            stem = names[i].split('.')[0]
            Image.fromarray(pred_imgs[i]).save(
                os.path.join(path, '{}_pred.png'.format(stem)))
            Image.fromarray(orig_imgs[i]).save(
                os.path.join(path, '{}_orig.png'.format(stem)))
            Image.fromarray(comp_imgs[i]).save(
                os.path.join(path, '{}_comp.png'.format(stem)))
            Image.fromarray(mask_imgs[i]).save(
                os.path.join(path, '{}_mask.png'.format(stem)))
    print('Finish in {}'.format(path))
def __init__(self, config, debug=False):
    """Build data pipeline, losses, writers, networks and optimizers.

    Args:
        config: parsed configuration dict (expects 'data_loader', 'trainer',
            'losses', 'save_dir', 'seed', 'model_name', plus distributed
            fields 'distributed', 'world_size', 'global_rank').
        debug: when True, shrink save/valid frequencies to 5 iterations for
            quick smoke runs.
    """
    self.config = config
    self.epoch = 0
    self.iteration = 0
    if debug:
        self.config['trainer']['save_freq'] = 5
        self.config['trainer']['valid_freq'] = 5

    # setup data set and data loader
    self.train_dataset = Dataset(config['data_loader'], debug=debug, split='train')
    worker_init_fn = partial(set_seed, base=config['seed'])
    self.train_sampler = None
    if config['distributed']:
        self.train_sampler = DistributedSampler(
            self.train_dataset,
            num_replicas=config['world_size'],
            rank=config['global_rank'])
    # Per-process batch size = global batch size divided across ranks;
    # shuffling is delegated to the sampler when one is present.
    self.train_loader = DataLoader(
        self.train_dataset,
        batch_size=config['trainer']['batch_size'] // config['world_size'],
        shuffle=(self.train_sampler is None),
        num_workers=config['trainer']['num_workers'],
        pin_memory=True,
        sampler=self.train_sampler,
        worker_init_fn=worker_init_fn)

    # set up losses and metrics
    self.adversarial_loss = set_device(
        AdversarialLoss(type=self.config['losses']['gan_type']))
    self.l1_loss = nn.L1Loss()
    self.dis_writer = None
    self.gen_writer = None
    self.summary = {}
    # Only rank 0 writes tensorboard summaries in the distributed case.
    if self.config['global_rank'] == 0 or (not config['distributed']):
        self.dis_writer = SummaryWriter(
            os.path.join(config['save_dir'], 'dis'))
        self.gen_writer = SummaryWriter(
            os.path.join(config['save_dir'], 'gen'))
    self.train_args = self.config['trainer']

    # Model module is resolved dynamically from config['model_name'].
    net = importlib.import_module('model.' + config['model_name'])
    self.netG = set_device(net.InpaintGenerator())
    # Hinge-GAN discriminators skip the final sigmoid.
    self.netD = set_device(
        net.Discriminator(
            in_channels=3,
            use_sigmoid=config['losses']['gan_type'] != 'hinge'))
    self.optimG = torch.optim.Adam(
        self.netG.parameters(),
        lr=config['trainer']['lr'],
        betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2']))
    # Discriminator LR is scaled by the d2glr ratio from config.
    self.optimD = torch.optim.Adam(
        self.netD.parameters(),
        lr=config['trainer']['lr'] * config['trainer']['d2glr'],
        betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2']))
    # Checkpoints are restored before DDP wrapping: load() calls
    # load_state_dict on the bare (unwrapped) modules.
    self.load()
    if config['distributed']:
        self.netG = DDP(
            self.netG,
            device_ids=[config['global_rank']],
            output_device=config['global_rank'],
            broadcast_buffers=True,
            find_unused_parameters=False)
        self.netD = DDP(
            self.netD,
            device_ids=[config['global_rank']],
            output_device=config['global_rank'],
            broadcast_buffers=True,
            find_unused_parameters=False)
def _train_epoch(self):
    """Run one training epoch: alternating D and G updates per batch.

    For every batch: builds the masked input, runs the generator, updates
    the discriminator on real vs. detached-composite images, then updates
    the generator with adversarial + hole/valid L1 (+ optional pyramid)
    losses. Logs to tensorboard, maintains a running MAE, and triggers
    checkpoint saving / validation on the configured frequencies. Breaks
    out early once the iteration budget is exceeded.
    """
    progbar = Progbar(len(self.train_dataset), width=20,
                      stateful_metrics=['epoch', 'iter'])
    mae = 0
    for images, masks, _ in self.train_loader:
        self.iteration += 1
        self.adjust_learning_rate()
        end = time.time()
        images, masks = set_device([images, masks])
        # Hole pixels are filled with 1s before concatenating the mask.
        images_masked = (images * (1 - masks).float()) + masks
        inputs = torch.cat((images_masked, masks), dim=1)
        feats, pred_img = self.netG(inputs, masks)  # in: [rgb(3) + mask(1)]
        # Composite keeps ground truth outside the hole, prediction inside.
        comp_img = (1 - masks) * images + masks * pred_img
        self.add_summary(self.dis_writer, 'lr/dis_lr', self.get_lr(type='D'))
        self.add_summary(self.gen_writer, 'lr/gen_lr', self.get_lr(type='G'))
        gen_loss = 0
        dis_loss = 0

        # image discriminator loss; composite is detached so D's backward
        # does not reach the generator
        dis_real_feat = self.netD(images)
        dis_fake_feat = self.netD(comp_img.detach())
        dis_real_loss = self.adversarial_loss(dis_real_feat, True, True)
        dis_fake_loss = self.adversarial_loss(dis_fake_feat, False, True)
        dis_loss += (dis_real_loss + dis_fake_loss) / 2
        self.add_summary(self.dis_writer, 'loss/dis_fake_loss',
                         dis_fake_loss.item())
        self.optimD.zero_grad()
        dis_loss.backward()
        self.optimD.step()

        # generator adversarial loss (non-detached composite this time)
        gen_fake_feat = self.netD(comp_img)  # in: [rgb(3)]
        gen_fake_loss = self.adversarial_loss(gen_fake_feat, True, False)
        gen_loss += gen_fake_loss * self.config['losses']['adversarial_weight']
        self.add_summary(self.gen_writer, 'loss/gen_fake_loss',
                         gen_fake_loss.item())

        # generator l1 loss, normalized by the hole / valid area fraction
        hole_loss = self.l1_loss(pred_img * masks,
                                 images * masks) / torch.mean(masks)
        gen_loss += hole_loss * self.config['losses']['hole_weight']
        self.add_summary(self.gen_writer, 'loss/hole_loss', hole_loss.item())
        valid_loss = self.l1_loss(pred_img * (1 - masks),
                                  images * (1 - masks)) / torch.mean(1 - masks)
        gen_loss += valid_loss * self.config['losses']['valid_weight']
        self.add_summary(self.gen_writer, 'loss/valid_loss',
                         valid_loss.item())
        # optional multi-scale feature supervision against resized targets
        if feats is not None:
            pyramid_loss = 0
            for _, f in enumerate(feats):
                pyramid_loss += self.l1_loss(
                    f, F.interpolate(images, size=f.size()[2:4],
                                     mode='bilinear', align_corners=True))
            gen_loss += pyramid_loss * self.config['losses']['pyramid_weight']
            self.add_summary(self.gen_writer, 'loss/pyramid_loss',
                             pyramid_loss.item())

        # generator backward
        self.optimG.zero_grad()
        gen_loss.backward()
        self.optimG.step()

        # logs: MAE is a running average of per-batch values
        new_mae = (torch.mean(torch.abs(images - pred_img)) /
                   torch.mean(masks)).item()
        mae = new_mae if mae == 0 else (new_mae + mae) / 2
        speed = images.size(0) / (time.time() - end) * self.config['world_size']
        logs = [("epoch", self.epoch), ("iter", self.iteration),
                ("lr", self.get_lr()), ('mae', mae), ('samples/s', speed)]
        if self.config['global_rank'] == 0:
            progbar.add(len(images) * self.config['world_size'],
                        values=logs if self.train_args['verbosity']
                        else [x for x in logs if not x[0].startswith('l_')])

        # saving and evaluating
        if self.iteration % self.train_args['save_freq'] == 0:
            self.save(int(self.iteration // self.train_args['save_freq']))
        if self.iteration % self.train_args['valid_freq'] == 0:
            # NOTE(review): the index passed here divides by save_freq while
            # the trigger uses valid_freq — confirm the mismatch is intended.
            self._test_epoch(
                int(self.iteration // self.train_args['save_freq']))
            if self.config['global_rank'] == 0:
                print('[**] Training till {} in Rank {}\n'.format(
                    datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    self.config['global_rank']))
        if self.iteration > self.config['trainer']['iterations']:
            break
from numpy import random import numpy as np from core import metric as module_metric from core.utils import set_device from core.inception import InceptionV3 from core.metric import calculate_activation_statistics, calculate_frechet_distance parser = argparse.ArgumentParser(description='PyTorch Template') parser.add_argument('-r', '--resume', required=True, type=str) args = parser.parse_args() dims = 2048 batch_size = 4 block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] model = set_device(InceptionV3([block_idx])) def main(): real_names = list(glob.glob('{}/*_orig.png'.format(args.resume))) fake_names = list(glob.glob('{}/*_comp.png'.format(args.resume))) real_names.sort() fake_names.sort() # metrics prepare for image assesments metrics = {met: getattr(module_metric, met) for met in ['mae', 'psnr', 'ssim']} # infer through videos real_images = [] fake_images = [] evaluation_scores = {key: 0 for key,val in metrics.items()} for rname, fname in zip(real_names, fake_names): rimg = Image.open(rname) fimg = Image.open(fname)