Example #1
0
#   masked = ground * (1-mask)

### evaluate
# Walk every checkpoint (*.pt) saved under args.modelpath, load its
# generator weights, and run it over the evaluation loader.
# NOTE(review): fragment is truncated here — the per-batch metric
# computation that presumably follows is not visible in this chunk.
model_path = args.modelpath
metrics = {}
out = {}
ground = None
mask = None
masked = None

for m in os.listdir(model_path):
    if '.pt' in m:
        # Checkpoint filenames look like "epoch<N>_G.pt"; pull out N.
        # NOTE(review): pattern is a non-raw string — '\d' / '\_' emit
        # DeprecationWarnings on newer Pythons; prefer r'epoch(\d+?)_'.
        epoch = int(re.search('epoch(\d+?)\_', m).group(1))
        full_path = os.path.join(model_path, m)
        print('eval', m)
        net_G = networks.get_network('generator', args.generator)
        G_statedict = torch.load(full_path, map_location=device)
        net_G.load_state_dict(G_statedict)

        # batch_ssim = torch.tensor([],dtype=torch.float32)
        # batch_l2 = torch.tensor([],dtype=torch.float32)
        # batch_ssim_local = torch.tensor([],dtype=torch.float32)
        # batch_l2_local = torch.tensor([],dtype=torch.float32)

        # NOTE(review): `m` here shadows the checkpoint filename from the
        # outer loop — harmless only because the filename is not used
        # again below; worth renaming to avoid future breakage.
        for gt, m, _ in eval_loader:

            ground = gt
            mask = m

            # 'real'/'extra' CSVs ship masks with inverted polarity:
            # multiply by mask directly instead of (1 - mask).
            if 'real' in args.csvfile or 'extra' in args.csvfile:
                masked = ground * mask
Example #2
0
def begin(state, loaders):
  """Train an inpainting GAN with the minimax (BCE) adversarial loss plus
  an L1 reconstruction loss.

  Writes generator checkpoints and pickled loss/metric histories under
  /home/s2125048/thesis/model/<title>/ and logs to log/<title>.log.

  Args:
    state: experiment configuration dict; read keys include 'imagedim',
      'batchsize', 'updatediscevery', 'evalevery', 'numepoch',
      'saveevery', 'generator', 'discriminator', 'train_fid',
      'test_fid', 'inception_model'. A local copy is made so the
      caller's dict is not mutated ('title' is overwritten here).
    loaders: dict with 'train', 'test' and 'extra' DataLoaders yielding
      (ground, mask, _) batches.
  """
  ### EXPERIMENT CONFIGURATION
  state = state.copy()
  state.update(
    {
      'title': 'minimaxgan_l1'
    }
  )

  train_loader, test_loader, extra_loader = loaders['train'], loaders['test'], loaders['extra']

  ### Create save directory
  experiment_dir = "/home/s2125048/thesis/model/{}/".format(state['title'])

  if not os.path.exists(experiment_dir):
    os.makedirs(experiment_dir)

  ### Setup logger
  logger = logging.getLogger(state['title'])
  hdlr = logging.FileHandler('log/{}.log'.format(state['title']))
  formatter = logging.Formatter('%(asctime)s %(message)s')
  hdlr.setFormatter(formatter)
  logger.addHandler(hdlr)
  logger.setLevel(logging.INFO)

  ### settings (image_target_size/batch_size/shuffle/update_d_every are
  ### kept for parity with the other experiment scripts; unused here)
  image_target_size = (state['imagedim'],state['imagedim'])
  batch_size = state['batchsize']
  shuffle = True

  update_d_every = state['updatediscevery']
  evaluate_every = state['evalevery']
  num_epochs = state['numepoch']
  save_every = state['saveevery']

  device = torch.device('cpu')
  if torch.cuda.is_available():
    device = torch.device("cuda:0")

  ### BUGFIX: logging uses lazy %-style arguments; the original
  ### logger.info('using device', device) supplied an arg with no %s
  ### placeholder, so the device was never logged (logging swallows the
  ### resulting formatting error).
  logger.info('using device %s', device)

  ### networks
  net_G = networks.get_network('generator',state['generator']).to(device)
  net_D_global = networks.get_network('discriminator',state['discriminator']).to(device)

  rmse_criterion = loss.RMSELoss()
  bce_criterion = nn.BCELoss()
  l1_criterion = nn.L1Loss()

  G_optimizer = optim.Adam(net_G.parameters(), lr=0.0002, betas=(0.5, 0.999))
  D_optimizer = optim.Adam(net_D_global.parameters(), lr=0.0002, betas=(0.5, 0.999))

  training_epoc_hist = []
  eval_hist = []

  ### (removed dead Wasserstein leftovers: the `one`/`mone` backward
  ### seeds, `update_g_every` and the G/D iteration counters were never
  ### used by this minimax variant)

  for epoch in range(num_epochs + 1):
    start = time.time()

    # Per-epoch running sums; divided by update_count at epoch end.
    epoch_g_loss = {
      'total': 0,
      'recon': 0,
      'adv': 0,
      'update_count': 0
    }
    epoch_d_loss = {
        'total': 0,
        'adv_real': 0,
        'adv_fake': 0,
        'update_count': 0
    }
    gradient_hist = {
        'avg_g': {},
        'avg_d': {},
    }

    # Track mean |grad| per non-bias parameter for both networks.
    for n, p in net_G.named_parameters():
      if("bias" not in n):
          gradient_hist['avg_g'][n] = 0
    for n, p in net_D_global.named_parameters():
      if("bias" not in n):
          gradient_hist['avg_d'][n] = 0

    for current_batch_index,(ground,mask,_) in enumerate(train_loader):

      curr_batch_size = ground.shape[0]
      ground = ground.to(device)
      # ceil() binarizes soft mask values to {0, 1}.
      mask = torch.ceil(mask.to(device))

      ## Inpaint masked images (mask==1 marks the hole region)
      masked = ground * (1-mask)

      inpainted = net_G(masked)

      ### only replace the masked area
      inpainted = masked + inpainted * mask

      ###
      # 1: Discriminator maximize [ log(D(x)) + log(1 - D(G(x))) ]
      ###
      # Update Discriminator networks.

      util.set_requires_grad([net_D_global],True)
      D_optimizer.zero_grad()

      ## We want the discriminator to be able to identify fake and real images+ add_label_noise(noise_std,curr_batch_size)

      d_pred_real = net_D_global(ground).view(-1)
      d_loss_real = bce_criterion(d_pred_real, torch.ones(len(d_pred_real)).to(device) )
      d_loss_real.backward()

      D_x = d_loss_real.mean().item()

      # detach() so the generator receives no gradient from the D update
      d_pred_fake = net_D_global(inpainted.detach()).view(-1)
      d_loss_fake = bce_criterion(d_pred_fake, torch.zeros(len(d_pred_fake)).to(device) )
      d_loss_fake.backward()

      D_G_z1 = d_pred_fake.mean().item()

      d_loss = d_loss_real + d_loss_fake

      D_optimizer.step()

      ###
      # 2: Generator maximize [ log(D(G(x))) ]
      ###

      util.set_requires_grad([net_D_global],False)
      G_optimizer.zero_grad()

      d_pred_fake = net_D_global(inpainted).view(-1)

      ### we want the generator to be able to fool the discriminator,
      ### thus, the goal is to enable the discriminator to always predict the inpainted images as real
      ### 1 = real, 0 = fake
      g_adv_loss = bce_criterion(d_pred_fake, torch.ones(len(d_pred_fake)).to(device))

      ### the inpainted image should be close to ground truth

      recon_loss = l1_criterion(ground,inpainted)

      g_loss = g_adv_loss + recon_loss

      g_loss.backward()

      D_G_z2 = d_pred_fake.mean().item()
      G_optimizer.step()

      epoch_g_loss['total'] += g_loss.item()
      epoch_g_loss['recon'] += recon_loss.item()
      epoch_g_loss['adv'] += g_adv_loss.item()
      epoch_g_loss['update_count'] += 1

      for n, p in net_G.named_parameters():
        if("bias" not in n):
          gradient_hist['avg_g'][n] += p.grad.abs().mean().item()

      logger.info('[epoch %d/%d][batch %d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
              % (epoch, num_epochs, current_batch_index, len(train_loader),
                  d_loss.item(), g_adv_loss.item()))

      ### later will be divided by amount of training set
      ### in a minibatch, number of output may differ
      epoch_d_loss['total'] += d_loss.item()
      epoch_d_loss['adv_real'] += d_loss_real.item()
      epoch_d_loss['adv_fake'] += d_loss_fake.item()
      epoch_d_loss['update_count'] += 1

      for n, p in net_D_global.named_parameters():
        if("bias" not in n):
          gradient_hist['avg_d'][n] += p.grad.abs().mean().item()

    ### get gradient and epoch loss for generator
    try:
      epoch_g_loss['total'] = epoch_g_loss['total'] / epoch_g_loss['update_count']
      epoch_g_loss['recon'] = epoch_g_loss['recon'] / epoch_g_loss['update_count']
      epoch_g_loss['adv'] = epoch_g_loss['adv'] / epoch_g_loss['update_count']

      for n, p in net_G.named_parameters():
        if("bias" not in n):
          gradient_hist['avg_g'][n] = gradient_hist['avg_g'][n] / epoch_g_loss['update_count']

    ### narrowed from a bare `except:` — the only expected failure here
    ### is dividing by update_count == 0 when G was never updated
    except ZeroDivisionError:

      ### set invalid values when G is not updated
      epoch_g_loss['total'] = -7777
      epoch_g_loss['recon'] = -7777
      epoch_g_loss['adv'] = -7777

      for n, p in net_G.named_parameters():
        if("bias" not in n):
          gradient_hist['avg_g'][n] = -7777

    ### get gradient and epoch loss for discriminator
    epoch_d_loss['total'] = epoch_d_loss['total'] / epoch_d_loss['update_count']
    epoch_d_loss['adv_real'] = epoch_d_loss['adv_real'] / epoch_d_loss['update_count']
    epoch_d_loss['adv_fake'] = epoch_d_loss['adv_fake'] / epoch_d_loss['update_count']

    for n, p in net_D_global.named_parameters():
      if("bias" not in n):
        gradient_hist['avg_d'][n] = gradient_hist['avg_d'][n] / epoch_d_loss['update_count']

    training_epoc_hist.append({
      "g": epoch_g_loss,
      "d": epoch_d_loss,
      "gradients": gradient_hist
    })

    if epoch % evaluate_every == 0 and epoch > 0:
      train_metric = evaluate.calculate_metric(device,train_loader,net_G,state['train_fid'],mode='train',inception_model=state['inception_model'])
      test_metric = evaluate.calculate_metric(device,test_loader,net_G,state['test_fid'],mode='test',inception_model=state['inception_model'])
      eval_hist.append({
        'train': train_metric,
        'test': test_metric
      })

      logger.info("VALIDATION: train - {}, test - {}".format(str(train_metric),str(test_metric)))

      ### save training history
      with open(os.path.join(experiment_dir,'training_epoch_history.obj'),'wb') as handle:
        pickle.dump(training_epoc_hist, handle, protocol=pickle.HIGHEST_PROTOCOL)

      ### save evaluation history
      with open(os.path.join(experiment_dir,'eval_history.obj'),'wb') as handle:
        pickle.dump(eval_hist, handle, protocol=pickle.HIGHEST_PROTOCOL)

    if epoch % save_every == 0 and epoch > 0:
      try:
        torch.save(net_G.state_dict(), os.path.join(experiment_dir, "epoch{}_G.pt".format(epoch)))
      ### narrowed from a bare `except:` — best-effort save, but log it
      except Exception:
        logger.error(traceback.format_exc())

    elapsed = time.time() - start
    logger.info(f'epoch: {epoch}, time: {elapsed:.3f}s, D total: {epoch_d_loss["total"]:.10f}, D real: {epoch_d_loss["adv_real"]:.10f}, D fake: {epoch_d_loss["adv_fake"]:.10f}, G total: {epoch_g_loss["total"]:.3f}, G adv: {epoch_g_loss["adv"]:.10f}, G recon: {epoch_g_loss["recon"]:.3f}')
Example #3
0
def begin(state, loaders):
    """Train an inpainting WGAN (weight clipping) with RMSE reconstruction,
    face-parsing, perceptual/style and total-variation generator losses.

    Writes generator checkpoints and pickled loss/metric histories under
    /home/s2125048/thesis/model/<title>/ and logs to log/<title>.log.

    Args:
        state: experiment configuration dict; read keys include
            'imagedim', 'batchsize', 'updatediscevery', 'evalevery',
            'numepoch', 'saveevery', 'generator', 'segmentation_model',
            'train_fid', 'test_fid', 'inception_model'. A local copy is
            made ('title' is overwritten), the caller's dict is untouched.
        loaders: dict with 'train', 'test' and 'extra' DataLoaders
            yielding (ground, mask, segment) batches.

    NOTE(review): relies on module-level `unique_labels` and `tv_loss`
    defined elsewhere in this file.
    """
    ### EXPERIMENT CONFIGURATION
    state = state.copy()
    state.update({'title': 'wgan_l1'})

    train_loader, test_loader, extra_loader = loaders['train'], loaders[
        'test'], loaders['extra']

    ### Create save directory
    experiment_dir = "/home/s2125048/thesis/model/{}/".format(state['title'])

    if not os.path.exists(experiment_dir):
        os.makedirs(experiment_dir)

    ### Setup logger
    logger = logging.getLogger(state['title'])
    hdlr = logging.FileHandler('log/{}.log'.format(state['title']))
    formatter = logging.Formatter('%(asctime)s %(message)s')
    hdlr.setFormatter(formatter)
    logger.addHandler(hdlr)
    logger.setLevel(logging.INFO)

    ### settings
    image_target_size = (state['imagedim'], state['imagedim'])
    batch_size = state['batchsize']
    shuffle = True

    update_d_every = state['updatediscevery']
    evaluate_every = state['evalevery']
    num_epochs = state['numepoch']
    save_every = state['saveevery']

    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device("cuda:0")

    ### BUGFIX: logging uses lazy %-style arguments; the original call
    ### passed `device` with no %s placeholder, so the device was never
    ### logged (logging swallows the formatting error).
    logger.info('using device %s', device)

    ### networks (sigmoid=False: WGAN critic outputs an unbounded score)
    net_G = networks.get_network('generator', state['generator']).to(device)
    net_D_global = networks.PatchGANDiscriminator(sigmoid=False).to(device)
    segment_model = state['segmentation_model']

    ### criterions
    rmse_global_criterion = loss.RMSELoss()
    rmse_local_criterion = loss.LocalLoss(loss.RMSELoss())
    bce_criterion = nn.BCELoss()
    l1_criterion = nn.L1Loss()

    # Class weights for face-parsing CE: background (class 0) ignored.
    w = torch.tensor([0, 1.2, 0.7, 0.7])
    weight_ce_criterion = nn.CrossEntropyLoss(weight=w).to(device)

    ### optimizer (RMSprop, per the original WGAN recipe)
    G_optimizer = optim.RMSprop(net_G.parameters(), lr=0.00005)
    D_optimizer = optim.RMSprop(net_D_global.parameters(), lr=0.00005)

    update_g_every = 5

    training_epoc_hist = []
    eval_hist = []

    ### Wasserstein loss backward seeds.
    ### BUGFIX: torch.FloatTensor(1) allocates an *uninitialized* 1-element
    ### tensor, so the backward seeds (and hence every gradient step) were
    ### scaled by garbage memory; torch.ones(1) gives the intended +1/-1.
    one = torch.ones(1)
    mone = one * -1

    one = one.to(device)
    mone = mone.to(device)

    G_iter_count = 0
    D_iter_count = 0

    for epoch in range(num_epochs + 1):
        start = time.time()

        # Per-epoch running sums; divided by update_count at epoch end.
        epoch_g_loss = {
            'total': 0,
            'recon_global': 0,
            'recon_local': 0,
            'face_parsing': 0,
            'perceptual': 0,
            'style': 0,
            'tv': 0,
            'adv': 0,
            'update_count': 0
        }
        epoch_d_loss = {
            'total': 0,
            'adv_real': 0,
            'adv_fake': 0,
            'update_count': 0
        }
        gradient_hist = {
            'avg_g': {},
            'avg_d': {},
        }

        ### face attr segmentation metric
        classes_metric = {}
        for u in unique_labels:
            classes_metric[u] = {
                'precision': util.AverageMeter(),
                'recall': util.AverageMeter(),
                'iou': util.AverageMeter()
            }
        across_class_metric = {
            'precision': util.AverageMeter(),
            'recall': util.AverageMeter(),
            'iou': util.AverageMeter()
        }

        for n, p in net_G.named_parameters():
            if ("bias" not in n):
                gradient_hist['avg_g'][n] = 0
        for n, p in net_D_global.named_parameters():
            if ("bias" not in n):
                gradient_hist['avg_d'][n] = 0

        for current_batch_index, (ground, mask,
                                  segment) in enumerate(train_loader):

            curr_batch_size = ground.shape[0]
            ground = ground.to(device)
            segment = segment.to(device)
            # ceil() binarizes soft mask values to {0, 1}.
            mask = torch.ceil(mask.to(device))

            ## Inpaint masked images (mask==1 marks the hole region)
            masked = ground * (1 - mask)

            inpainted = net_G(masked)

            ### only replace the masked area
            inpainted = masked + inpainted * mask

            ###
            # 1: Critic update — estimate the EM distance between real
            #    and generated distributions (WGAN).
            ###
            util.set_requires_grad([net_D_global], True)

            D_optimizer.zero_grad()

            ## We want the discriminator to be able to identify fake and real images+ add_label_noise(noise_std,curr_batch_size)

            d_pred_real = net_D_global(ground)
            d_pred_fake = net_D_global(inpainted.detach())

            d_loss_real = torch.mean(d_pred_real).view(1)
            d_loss_real.backward(one)

            d_loss_fake = torch.mean(d_pred_fake).view(1)
            d_loss_fake.backward(mone)

            ### put absolute value for EM distance
            d_loss = torch.abs(d_loss_real - d_loss_fake)

            D_G_z1 = d_pred_fake.mean().item()
            D_x = d_pred_real.mean().item()
            D_optimizer.step()

            D_iter_count += 1

            # Clip weights of discriminator (Lipschitz constraint)
            for p in net_D_global.parameters():
                p.data.clamp_(-0.01, 0.01)

            ### On initial training epoch, we want D to converge as fast as possible before updating the generator
            ### that's why, G is updated every 100 D iterations
            if G_iter_count < 25 or G_iter_count % 500 == 0:
                update_G_every_batch = 140
            else:
                update_G_every_batch = update_g_every

            ### Update generator every `D_iter`
            if current_batch_index % update_G_every_batch == 0 and current_batch_index > 0:

                ###
                # 2: Generator update
                ###

                util.set_requires_grad([net_D_global], False)
                G_optimizer.zero_grad()

                d_pred_fake = net_D_global(inpainted).view(-1)

                ### we want the generator to be able to fool the discriminator,
                ### thus, the goal is to enable the discriminator to always predict the inpainted images as real
                g_adv_loss = torch.mean(d_pred_fake).view(1)

                ### the inpainted image should be close to ground truth
                recon_global_loss = rmse_global_criterion(ground, inpainted)
                recon_local_loss = rmse_local_criterion(
                    ground, inpainted, mask)

                ### face parsing loss
                inpainted_segment = segment_model(inpainted)
                g_face_parsing_loss = 0.01 * weight_ce_criterion(
                    inpainted_segment, segment)

                ### perceptual and style
                g_perceptual_loss, g_style_loss = loss.perceptual_and_style_loss(
                    inpainted, ground, weight_p=0.01, weight_s=0.01)

                ### tv
                g_tv_loss = tv_loss(inpainted, tv_weight=1)

                ### BUGFIX: original summed the undefined name `recon_loss`
                ### (NameError on the first generator update); the intended
                ### terms are the global + local RMSE reconstruction losses.
                g_loss = (g_adv_loss + recon_global_loss + recon_local_loss +
                          g_perceptual_loss + g_style_loss +
                          g_face_parsing_loss + g_tv_loss)
                g_loss.backward()

                D_G_z2 = d_pred_fake.mean().item()
                G_optimizer.step()

                G_iter_count += 1

                ### update segmentation metric
                class_m, across_class_m = evaluate.calculate_segmentation_eval_metric(
                    segment, inpainted_segment, unique_labels)
                for u in unique_labels:
                    for k, _ in classes_metric[u].items():
                        classes_metric[u][k].update(class_m[u][k],
                                                    ground.size(0))

                for k, _ in across_class_metric.items():
                    across_class_metric[k].update(across_class_m[k],
                                                  ground.size(0))

                epoch_g_loss['total'] += g_loss.item()
                epoch_g_loss['recon_global'] += recon_global_loss.item()
                epoch_g_loss['recon_local'] += recon_local_loss.item()
                epoch_g_loss['adv'] += g_adv_loss.item()
                epoch_g_loss['tv'] += g_tv_loss.item()
                ### BUGFIX: 'perceptual' was accumulating g_tv_loss
                ### (copy-paste), so the logged perceptual loss was wrong.
                epoch_g_loss['perceptual'] += g_perceptual_loss.item()
                epoch_g_loss['style'] += g_style_loss.item()
                epoch_g_loss['face_parsing'] += g_face_parsing_loss.item()
                epoch_g_loss['update_count'] += 1

                for n, p in net_G.named_parameters():
                    if ("bias" not in n):
                        gradient_hist['avg_g'][n] += p.grad.abs().mean().item()

                logger.info(
                    '[epoch %d/%d][batch %d/%d]\tLoss_D: %.4f\tLoss_G: %.4f' %
                    (epoch, num_epochs, current_batch_index, len(train_loader),
                     d_loss.item(), g_adv_loss.item()))

            ### later will be divided by amount of training set
            ### in a minibatch, number of output may differ
            epoch_d_loss['total'] += d_loss.item()
            epoch_d_loss['adv_real'] += d_loss_real.item()
            epoch_d_loss['adv_fake'] += d_loss_fake.item()
            epoch_d_loss['update_count'] += 1

            for n, p in net_D_global.named_parameters():
                if ("bias" not in n):
                    gradient_hist['avg_d'][n] += p.grad.abs().mean().item()

        ### get gradient and epoch loss for generator
        try:
            epoch_g_loss[
                'total'] = epoch_g_loss['total'] / epoch_g_loss['update_count']
            epoch_g_loss['recon_global'] = epoch_g_loss[
                'recon_global'] / epoch_g_loss['update_count']
            epoch_g_loss['recon_local'] = epoch_g_loss[
                'recon_local'] / epoch_g_loss['update_count']
            epoch_g_loss[
                'tv'] = epoch_g_loss['tv'] / epoch_g_loss['update_count']
            epoch_g_loss['perceptual'] = epoch_g_loss[
                'perceptual'] / epoch_g_loss['update_count']
            epoch_g_loss[
                'style'] = epoch_g_loss['style'] / epoch_g_loss['update_count']
            epoch_g_loss['face_parsing'] = epoch_g_loss[
                'face_parsing'] / epoch_g_loss['update_count']
            epoch_g_loss[
                'adv'] = epoch_g_loss['adv'] / epoch_g_loss['update_count']

            for n, p in net_G.named_parameters():
                if ("bias" not in n):
                    gradient_hist['avg_g'][n] = gradient_hist['avg_g'][
                        n] / epoch_g_loss['update_count']
        ### narrowed from a bare `except:` — the only expected failure is
        ### dividing by update_count == 0 when G was never updated
        except ZeroDivisionError:
            pass

        ### get gradient and epoch loss for discriminator
        epoch_d_loss[
            'total'] = epoch_d_loss['total'] / epoch_d_loss['update_count']
        epoch_d_loss['adv_real'] = epoch_d_loss['adv_real'] / epoch_d_loss[
            'update_count']
        epoch_d_loss['adv_fake'] = epoch_d_loss['adv_fake'] / epoch_d_loss[
            'update_count']

        for n, p in net_D_global.named_parameters():
            if ("bias" not in n):
                gradient_hist['avg_d'][n] = gradient_hist['avg_d'][
                    n] / epoch_d_loss['update_count']

        training_epoc_hist.append({
            "g": epoch_g_loss,
            "d": epoch_d_loss,
            "gradients": gradient_hist
        })

        if epoch % evaluate_every == 0 and epoch > 0:
            train_metric = evaluate.calculate_metric(
                device,
                train_loader,
                net_G,
                state['train_fid'],
                mode='train',
                inception_model=state['inception_model'])
            test_metric = evaluate.calculate_metric(
                device,
                test_loader,
                net_G,
                state['test_fid'],
                mode='test',
                inception_model=state['inception_model'])
            eval_hist.append({'train': train_metric, 'test': test_metric})

            logger.info("----------")
            logger.info("Validation")
            logger.info("----------")
            logger.info(
                "Train recon global: {glo: .4f}, local: {loc: .4f}, FID: {fid: .4f}"
                .format(glo=train_metric['recon_global'],
                        loc=train_metric['recon_local'],
                        fid=train_metric['fid']))
            logger.info(
                "Test recon global: {glo: .4f}, local: {loc: .4f}, FID: {fid: .4f}"
                .format(glo=test_metric['recon_rmse_global'],
                        loc=test_metric['recon_rmse_local'],
                        fid=test_metric['fid']))
            for u in unique_labels:
                logger.info(
                    "Class {}: Precision {prec: .4f}, Recall {rec: .4f}, IoU {iou: .4f}"
                    .format(u,
                            prec=test_metric['face_parsing_metric']
                            ['indv_class'][u]['precision'].avg,
                            rec=test_metric['face_parsing_metric']
                            ['indv_class'][u]['recall'].avg,
                            iou=test_metric['face_parsing_metric']
                            ['indv_class'][u]['iou'].avg))
            logger.info(
                "Across Classes: Precision {prec.avg: .4f}, Recall {rec.avg: .4f}, IoU {iou.avg: .4f}"
                .format(prec=test_metric['face_parsing_metric']
                        ['accross_class']['precision'],
                        rec=test_metric['face_parsing_metric']['accross_class']
                        ['recall'],
                        iou=test_metric['face_parsing_metric']['accross_class']
                        ['iou']))

            ### save training history
            with open(
                    os.path.join(experiment_dir, 'training_epoch_history.obj'),
                    'wb') as handle:
                pickle.dump(training_epoc_hist,
                            handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

            ### save evaluation history
            with open(os.path.join(experiment_dir, 'eval_history.obj'),
                      'wb') as handle:
                pickle.dump(eval_hist,
                            handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        if epoch % save_every == 0 and epoch > 0:
            try:
                torch.save(
                    net_G.state_dict(),
                    os.path.join(experiment_dir, "epoch{}_G.pt".format(epoch)))
            ### narrowed from a bare `except:` — best-effort save, logged
            except Exception:
                logger.error(traceback.format_exc())

        elapsed = time.time() - start

        logger.info(f'Epoch: {epoch}, time: {elapsed:.3f}s')
        logger.info('')
        logger.info('> Adversarial')
        logger.info(f'EM Distance: {epoch_d_loss["total"]:.10f}')
        logger.info('> Reconstruction')
        logger.info(
            f'Generator recon global: {epoch_g_loss["recon_global"]:.3f}, local: {epoch_g_loss["recon_local"]:.3f}'
        )
        logger.info('> Face Parsing')
        for u in unique_labels:
            logger.info(
                "Class {}: Precision {prec: .4f}, Recall {rec: .4f}, IoU {iou: .4f}"
                .format(u,
                        prec=classes_metric[u]['precision'].avg,
                        rec=classes_metric[u]['recall'].avg,
                        iou=classes_metric[u]['iou'].avg))
        logger.info(
            "Across Classes: Precision {prec.avg: .4f}, Recall {rec.avg: .4f}, IoU {iou.avg: .4f}"
            .format(prec=across_class_metric['precision'],
                    rec=across_class_metric['recall'],
                    iou=across_class_metric['iou']))
Example #4
0
def plot_inpainting_results(m_paths,csv_path):
  """Render comparison grids (ground truth | masked | per-checkpoint
  inpaintings) for a list of generator checkpoints, saved as PNGs.

  Args:
    m_paths: list of generator checkpoint file paths, ordered by epoch
      (filenames must contain 'epoch<N>').
    csv_path: CSV describing the test split for InpaintingDataset.

  NOTE(review): relies on module-level globals: image_target_size,
  device, result_root, constant_ep, InpaintingDataset, networks,
  grey2rgb — confirm they are defined before calling.
  """
  ### construct test dataset
  warnings.filterwarnings('ignore')
  test_df = pd.read_csv(csv_path)
  test_dataset = InpaintingDataset('/home/s2125048/thesis/dataset/',dataframe=test_df,
                                    transform=transforms.Compose([
                            transforms.Resize(image_target_size,),
                            transforms.ToTensor()]))

  test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=100,
    num_workers=0,
    shuffle=False
  )

  ### checkpoints are plotted in groups of `batch_size` columns
  batch_size = 8
  num_batches = math.ceil(len(m_paths)/batch_size)

  for n in range(num_batches):
    print('batch',n)
    curr_m_paths = m_paths[n*batch_size:(n+1)*batch_size]

    ### load one generator per checkpoint in this group
    net_G = {}
    for idx,path in enumerate(curr_m_paths):
      if torch.cuda.is_available():
        G_statedict = torch.load(path)
      else:
        G_statedict = torch.load(path,map_location='cpu')

      net_G[idx] = networks.get_network('generator','unet').to(device)
      net_G[idx].load_state_dict(G_statedict)

    ### epoch range used in the output filename (raw strings: '\d' in a
    ### plain string is a deprecated escape)
    start_ep = re.search(r'epoch\d+', curr_m_paths[0]).group(0)
    end_ep = re.search(r'epoch\d+', curr_m_paths[-1]).group(0)

    bcount = 0
    for test_loader_idx,(ground,mask,_) in enumerate(test_loader):

      # plot only every 3rd loader batch, at most 11 figures per group
      if not test_loader_idx % 3 == 0:
        continue

      if bcount > 10:
        break

      print(test_loader_idx)
      ground = ground.to(device)
      mask = torch.ceil(mask).to(device)
      masked = ground * (1-mask)
      out = {}

      for idx,path in enumerate(curr_m_paths):
        out[idx] = net_G[idx](masked)
        ### only replace the masked area
        out[idx] = masked + mask*out[idx]

      # rows: images 0, 10, 20, ... 90 of the (size-100) loader batch
      list_image_index = (np.array(range(0,10))*10).tolist()
      cols = 2 + batch_size
      plt.figure(figsize=(10,len(list_image_index)*3))

      for i, im_idx in enumerate(list_image_index):

        plt.subplot(len(list_image_index),cols,(i*cols+1))
        plt.imshow(ground[im_idx][0].detach().cpu().numpy(),cmap='Greys_r')
        plt.axis('off')
        if i == 0:
          plt.title('input')

        plt.subplot(len(list_image_index),cols,(i*cols+2))
        plt.imshow(masked[im_idx][0].detach().cpu().numpy(),cmap='Greys_r')
        plt.axis('off')
        if i == 0:
          plt.title('masked')

        ### BUGFIX: iterate only over the checkpoints actually loaded —
        ### the last group may hold fewer than `batch_size` paths, and
        ### range(batch_size) raised KeyError on the missing out[k].
        for k in range(len(curr_m_paths)):
          plt.subplot(len(list_image_index),cols,(i*cols+3+k))
          im = (out[k][im_idx][0].detach().cpu().numpy()*255).astype(np.uint8)
          im = grey2rgb(im)
          plt.imshow(im)
          plt.axis('off')
          if i == 0:
            plt.title('ep {}'.format(constant_ep * (n*batch_size + k + 1)))

      plt.savefig(f'{result_root}/{start_ep}-{end_ep}_batch{test_loader_idx}.png',dpi=300)
      ### BUGFIX: close the figure after saving — figures were
      ### accumulating in memory across batches
      plt.close()
      bcount += 1
Example #5
0
# Build the "extra" evaluation dataset/loader and the GAN networks.
# NOTE(review): fragment — `dataset_path`, `image_target_size`, `device`,
# `args` and `num_epochs` are defined earlier in the file, outside this
# chunk; the training loop at the bottom is cut off mid-definition.
extra_df = pd.read_csv('extra.csv')
extra_dataset = dataset.InpaintingDataset(dataset_path,dataframe=extra_df,
                                  transform=transforms.Compose([
                          transforms.Resize(image_target_size,),
                          transforms.ToTensor()]))
extra_loader = torch.utils.data.DataLoader(
  extra_dataset,
  batch_size=30,
  num_workers=0,
  shuffle=True
)

# One fixed batch of sample images for qualitative inspection.
sample_test_images = next(iter(extra_loader))

### networks
net_G = networks.get_network('generator',args.generator).to(device)
net_D_global = networks.get_network('discriminator',args.discriminator).to(device)

rmse_criterion = loss.RMSELoss()
bce_criterion = nn.BCELoss()

# Adam with beta1=0.5, the common DCGAN setting.
G_optimizer = optim.Adam(net_G.parameters(), lr=0.0002, betas=(0.5, 0.999))
D_optimizer = optim.Adam(net_D_global.parameters(), lr=0.0002, betas=(0.5, 0.999))

training_epoc_hist = []
eval_hist = []

for epoch in range(num_epochs + 1):
  start = time.time()

  epoch_g_loss = {