def load(self):
    """Restore the newest generator/discriminator/optimizer checkpoint.

    The epoch tag is read from the last line of ``latest.ckpt`` under
    ``config['save_dir']`` when that file exists; otherwise it is
    recovered from the ``gen_*.pth`` files on disk.  When no checkpoint
    is found, a warning is printed (rank 0 only) and the freshly
    initialized weights are kept.
    """
    model_path = self.config['save_dir']
    latest_file = os.path.join(model_path, 'latest.ckpt')
    if os.path.isfile(latest_file):
        # One epoch tag per line; the newest is appended last.
        # Use a context manager so the file handle is always closed
        # (the original leaked it via open(...).read()).
        with open(latest_file, 'r') as f:
            latest_epoch = f.read().splitlines()[-1]
    else:
        # Fall back to scanning generator checkpoints only.  Scanning
        # all *.pth files would also match 'dis_*'/'opt_*' names, and
        # their lexicographic max ('opt_XXXXX') would then be glued
        # into a non-existent 'gen_opt_XXXXX.pth' path below.
        tags = [
            os.path.basename(p)[len('gen_'):-len('.pth')]
            for p in glob.glob(os.path.join(model_path, 'gen_*.pth'))
        ]
        tags.sort()
        latest_epoch = tags[-1] if tags else None
    if latest_epoch is None:
        if self.config['global_rank'] == 0:
            print('Warning: There is no trained model found. '
                  'An initialized model will be used.')
        return
    tag = str(latest_epoch).zfill(5)
    gen_path = os.path.join(model_path, 'gen_{}.pth'.format(tag))
    dis_path = os.path.join(model_path, 'dis_{}.pth'.format(tag))
    opt_path = os.path.join(model_path, 'opt_{}.pth'.format(tag))
    if self.config['global_rank'] == 0:
        print('Loading model from {}...'.format(gen_path))
    # map_location routes every tensor through set_device so the
    # checkpoint lands on this process's device.
    data = torch.load(
        gen_path, map_location=lambda storage, loc: set_device(storage))
    self.netG.load_state_dict(data['netG'])
    data = torch.load(
        dis_path, map_location=lambda storage, loc: set_device(storage))
    self.netD.load_state_dict(data['netD'])
    data = torch.load(
        opt_path, map_location=lambda storage, loc: set_device(storage))
    self.optimG.load_state_dict(data['optimG'])
    self.optimD.load_state_dict(data['optimD'])
    # The optimizer checkpoint also stores the training progress counters.
    self.epoch = data['epoch']
    self.iteration = data['iteration']
def main_worker(gpu, ngpus_per_node, config):
    """Run single-GPU inference over this worker's slice of the test set.

    Each worker (one per GPU) loads the latest generator checkpoint,
    takes a contiguous 1/ngpus_per_node shard of the dataset, and writes
    the original / masked / predicted / composited images to
    ``<save_dir>/results_<epoch>_level_<level>/``.

    Args:
        gpu: CUDA device index (also this worker's shard index).
        ngpus_per_node: total number of workers the dataset is split over.
        config: experiment configuration dict ('save_dir', 'seed',
            'data_loader', 'trainer' keys are used here).
    """
    torch.cuda.set_device(gpu)
    set_seed(config['seed'])

    # Model: import the generator module named on the CLI and restore the
    # newest checkpoint (last line of latest.ckpt is the newest tag).
    net = importlib.import_module('model.' + args.model_name)
    model = set_device(net.InpaintGenerator())
    # Context manager fixes the leaked file handle of open(...).read().
    with open(os.path.join(config['save_dir'], 'latest.ckpt'), 'r') as f:
        latest_epoch = f.read().splitlines()[-1]
    ckpt_path = os.path.join(config['save_dir'],
                             'gen_{}.pth'.format(latest_epoch))
    data = torch.load(
        ckpt_path, map_location=lambda storage, loc: set_device(storage))
    model.load_state_dict(data['netG'])
    model.eval()

    # Dataset: give this GPU a contiguous shard of the test split.
    dataset = Dataset(config['data_loader'], debug=False, split='test',
                      level=args.level)
    step = math.ceil(len(dataset) / ngpus_per_node)
    dataset.set_subset(gpu * step, min(gpu * step + step, len(dataset)))
    dataloader = DataLoader(dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=False,
                            num_workers=config['trainer']['num_workers'],
                            pin_memory=True)

    out_dir = os.path.join(
        config['save_dir'],
        'results_{}_level_{}'.format(str(latest_epoch).zfill(5),
                                     str(args.level).zfill(2)))
    os.makedirs(out_dir, exist_ok=True)

    # Inference: fill the hole region with 1s, inpaint, save all variants.
    for idx, (images, masks, names) in enumerate(dataloader):
        print('[{}] GPU{} {}/{}: {}'.format(
            datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            gpu, idx, len(dataloader), names[0]))
        images, masks = set_device([images, masks])
        images_masked = images * (1 - masks) + masks
        with torch.no_grad():
            _, output = model(torch.cat((images_masked, masks), dim=1), masks)
        orig_imgs = postprocess(images)
        mask_imgs = postprocess(images_masked)
        # Composite keeps known pixels, takes predictions inside the hole.
        comp_imgs = postprocess((1 - masks) * images + masks * output)
        pred_imgs = postprocess(output)
        for i in range(len(orig_imgs)):
            stem = names[i].split('.')[0]  # file name without extension
            Image.fromarray(pred_imgs[i]).save(
                os.path.join(out_dir, '{}_pred.png'.format(stem)))
            Image.fromarray(orig_imgs[i]).save(
                os.path.join(out_dir, '{}_orig.png'.format(stem)))
            Image.fromarray(comp_imgs[i]).save(
                os.path.join(out_dir, '{}_comp.png'.format(stem)))
            Image.fromarray(mask_imgs[i]).save(
                os.path.join(out_dir, '{}_mask.png'.format(stem)))
    print('Finish in {}'.format(out_dir))
    def __init__(self, config, debug=False):
        """Build the full training state from *config*.

        Constructs the train dataset/loader (sharded with a
        DistributedSampler when distributed), the adversarial + L1
        losses, TensorBoard writers (rank 0 only), the generator and
        discriminator with their Adam optimizers, restores the latest
        checkpoint via ``self.load()``, and finally wraps both nets in
        DDP when running distributed.

        Args:
            config (dict): nested experiment configuration (keys used
                here: 'data_loader', 'trainer', 'losses', 'model_name',
                'seed', 'save_dir', 'distributed', 'world_size',
                'global_rank').
            debug (bool): when True, shrink save/valid frequency to
                every 5 iterations so those code paths run quickly.
        """
        self.config = config
        self.epoch = 0
        self.iteration = 0
        if debug:
            self.config['trainer']['save_freq'] = 5
            self.config['trainer']['valid_freq'] = 5

        # setup data set and data loader
        self.train_dataset = Dataset(config['data_loader'],
                                     debug=debug,
                                     split='train')
        # Re-seed each DataLoader worker deterministically from the base seed.
        worker_init_fn = partial(set_seed, base=config['seed'])
        self.train_sampler = None
        if config['distributed']:
            # Each rank iterates a disjoint shard of the training set.
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=config['world_size'],
                rank=config['global_rank'])
        self.train_loader = DataLoader(
            self.train_dataset,
            # The configured batch size is global; split it across ranks.
            batch_size=config['trainer']['batch_size'] // config['world_size'],
            # The sampler already shuffles; only shuffle when there is none.
            shuffle=(self.train_sampler is None),
            num_workers=config['trainer']['num_workers'],
            pin_memory=True,
            sampler=self.train_sampler,
            worker_init_fn=worker_init_fn)

        # set up losses and metrics
        self.adversarial_loss = set_device(
            AdversarialLoss(type=self.config['losses']['gan_type']))
        self.l1_loss = nn.L1Loss()
        self.dis_writer = None
        self.gen_writer = None
        self.summary = {}
        # Only rank 0 writes TensorBoard summaries (or always, when not
        # running distributed); other ranks keep the writers as None.
        if self.config['global_rank'] == 0 or (not config['distributed']):
            self.dis_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'dis'))
            self.gen_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'gen'))
        self.train_args = self.config['trainer']

        # Dynamically import the architecture module model.<model_name>.
        net = importlib.import_module('model.' + config['model_name'])
        self.netG = set_device(net.InpaintGenerator())
        # Hinge GAN loss expects raw logits, so sigmoid is disabled for it.
        self.netD = set_device(
            net.Discriminator(
                in_channels=3,
                use_sigmoid=config['losses']['gan_type'] != 'hinge'))
        self.optimG = torch.optim.Adam(self.netG.parameters(),
                                       lr=config['trainer']['lr'],
                                       betas=(self.config['trainer']['beta1'],
                                              self.config['trainer']['beta2']))
        # Discriminator LR is scaled relative to the generator's by d2glr.
        self.optimD = torch.optim.Adam(self.netD.parameters(),
                                       lr=config['trainer']['lr'] *
                                       config['trainer']['d2glr'],
                                       betas=(self.config['trainer']['beta1'],
                                              self.config['trainer']['beta2']))
        # Restore the latest checkpoint (if any) BEFORE wrapping in DDP so
        # the state_dict keys match the unwrapped modules.
        self.load()
        if config['distributed']:
            self.netG = DDP(self.netG,
                            device_ids=[config['global_rank']],
                            output_device=config['global_rank'],
                            broadcast_buffers=True,
                            find_unused_parameters=False)
            self.netD = DDP(self.netD,
                            device_ids=[config['global_rank']],
                            output_device=config['global_rank'],
                            broadcast_buffers=True,
                            find_unused_parameters=False)
    def _train_epoch(self):
        """Run one pass over the training loader.

        For each batch: update the discriminator on real vs. composited
        images, then update the generator with adversarial + masked/
        unmasked L1 (+ optional feature-pyramid) losses; log learning
        rates and losses to TensorBoard and a progress bar, and
        periodically save checkpoints / run validation.
        """
        progbar = Progbar(len(self.train_dataset),
                          width=20,
                          stateful_metrics=['epoch', 'iter'])
        mae = 0
        for images, masks, _ in self.train_loader:
            self.iteration += 1
            self.adjust_learning_rate()
            end = time.time()  # start of the timed (compute) section
            images, masks = set_device([images, masks])
            # Hole pixels are replaced by 1 (the mask value) in the input.
            images_masked = (images * (1 - masks).float()) + masks
            inputs = torch.cat((images_masked, masks), dim=1)
            feats, pred_img = self.netG(inputs,
                                        masks)  # in: [rgb(3) + edge(1)]
            # Composite: keep known pixels, take predictions in the hole.
            comp_img = (1 - masks) * images + masks * pred_img
            self.add_summary(self.dis_writer, 'lr/dis_lr',
                             self.get_lr(type='D'))
            self.add_summary(self.gen_writer, 'lr/gen_lr',
                             self.get_lr(type='G'))

            gen_loss = 0
            dis_loss = 0
            # image discriminator loss
            dis_real_feat = self.netD(images)
            # detach() keeps D's update from backpropagating into G.
            dis_fake_feat = self.netD(comp_img.detach())
            dis_real_loss = self.adversarial_loss(dis_real_feat, True, True)
            dis_fake_loss = self.adversarial_loss(dis_fake_feat, False, True)
            dis_loss += (dis_real_loss + dis_fake_loss) / 2
            self.add_summary(self.dis_writer, 'loss/dis_fake_loss',
                             dis_fake_loss.item())
            self.optimD.zero_grad()
            dis_loss.backward()
            self.optimD.step()

            # generator adversarial loss
            gen_fake_feat = self.netD(comp_img)  # in: [rgb(3)]
            gen_fake_loss = self.adversarial_loss(gen_fake_feat, True, False)
            gen_loss += gen_fake_loss * self.config['losses'][
                'adversarial_weight']
            self.add_summary(self.gen_writer, 'loss/gen_fake_loss',
                             gen_fake_loss.item())

            # generator l1 loss
            # Dividing by mean mask area makes the loss hole-size invariant.
            hole_loss = self.l1_loss(pred_img * masks,
                                     images * masks) / torch.mean(masks)
            gen_loss += hole_loss * self.config['losses']['hole_weight']
            self.add_summary(self.gen_writer, 'loss/hole_loss',
                             hole_loss.item())
            valid_loss = self.l1_loss(pred_img *
                                      (1 - masks), images *
                                      (1 - masks)) / torch.mean(1 - masks)
            gen_loss += valid_loss * self.config['losses']['valid_weight']
            self.add_summary(self.gen_writer, 'loss/valid_loss',
                             valid_loss.item())
            # Optional multi-scale loss on intermediate generator features,
            # each compared against a bilinearly resized ground truth.
            if feats is not None:
                pyramid_loss = 0
                for _, f in enumerate(feats):
                    pyramid_loss += self.l1_loss(
                        f,
                        F.interpolate(images,
                                      size=f.size()[2:4],
                                      mode='bilinear',
                                      align_corners=True))
                gen_loss += pyramid_loss * self.config['losses'][
                    'pyramid_weight']
                self.add_summary(self.gen_writer, 'loss/pyramid_loss',
                                 pyramid_loss.item())

            # generator backward
            self.optimG.zero_grad()
            gen_loss.backward()
            self.optimG.step()

            # logs
            new_mae = (torch.mean(torch.abs(images - pred_img)) /
                       torch.mean(masks)).item()
            # Exponential-style smoothing of MAE across iterations.
            mae = new_mae if mae == 0 else (new_mae + mae) / 2
            speed = images.size(0) / (time.time() -
                                      end) * self.config['world_size']
            logs = [("epoch", self.epoch), ("iter", self.iteration),
                    ("lr", self.get_lr()), ('mae', mae), ('samples/s', speed)]
            if self.config['global_rank'] == 0:
                progbar.add(len(images)*self.config['world_size'], values=logs \
                  if self.train_args['verbosity'] else [x for x in logs if not x[0].startswith('l_')])

            # saving and evaluating
            if self.iteration % self.train_args['save_freq'] == 0:
                self.save(int(self.iteration // self.train_args['save_freq']))
            if self.iteration % self.train_args['valid_freq'] == 0:
                # NOTE(review): the index passed to _test_epoch divides by
                # save_freq, not valid_freq — confirm this is intentional.
                self._test_epoch(
                    int(self.iteration // self.train_args['save_freq']))
                if self.config['global_rank'] == 0:
                    print('[**] Training till {} in Rank {}\n'.format(
                        datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        self.config['global_rank']))
            if self.iteration > self.config['trainer']['iterations']:
                break
# ---- Example 5 ----
from numpy import random
import numpy as np

from core import metric as module_metric
from core.utils import set_device
from core.inception import InceptionV3
from core.metric import calculate_activation_statistics, calculate_frechet_distance

# ---- script setup (runs at import time) ----
# CLI: -r/--resume points at a results directory containing the
# *_orig.png / *_comp.png pairs produced by the inference step.
parser = argparse.ArgumentParser(description='PyTorch Template')
parser.add_argument('-r', '--resume', required=True, type=str)
args = parser.parse_args()

# InceptionV3 feature extractor used for FID (2048-dim pooling block).
dims = 2048
batch_size = 4
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
model = set_device(InceptionV3([block_idx]))

def main():
  real_names = list(glob.glob('{}/*_orig.png'.format(args.resume)))
  fake_names = list(glob.glob('{}/*_comp.png'.format(args.resume)))
  real_names.sort()
  fake_names.sort()
  # metrics prepare for image assesments
  metrics = {met: getattr(module_metric, met) for met in ['mae', 'psnr', 'ssim']}
  # infer through videos
  real_images = []
  fake_images = []
  evaluation_scores = {key: 0 for key,val in metrics.items()}
  for rname, fname in zip(real_names, fake_names):
    rimg = Image.open(rname)
    fimg = Image.open(fname)