예제 #1
0
    def get_band_cover(self, variables):
        """Download the artist's cover image from the Zune catalog.

        The image is saved as ``<variables.dirs.artists_cover>/<band_id>.jpg``.
        Nothing is fetched when ``utils.check_if_artist_cover_exists`` reports
        that the cover is already present. Whenever Zune cannot provide an
        image (non-200 response, no matching artist entry, no image URL) this
        falls back to ``self.get_band_cover_from_lastfm(variables)``.

        Args:
            variables: object carrying ``band_name``, ``band_id`` and
                ``dirs.artists_cover``.
        """
        if utils.check_if_artist_cover_exists(variables):
            print('[*] ' + variables.band_name + ' cover images already exists')
            return
        zune_root = 'http://catalog.zune.net/v3.2/en-US/music/artist'
        artist_cover_path = os.path.join(variables.dirs.artists_cover,
                                         str(variables.band_id)) + '.jpg'
        # Namespaces used in the Zune catalog XML.
        ns = {'a': 'http://www.w3.org/2005/Atom',
              'zune': 'http://schemas.zune.net/catalog/music/2007/10'}

        response = requests.get(zune_root, params={'q': variables.band_name})
        if response.status_code != 200:
            # A failed search used to yield no cover at all, unlike every
            # other failure path; fall back to last.fm here as well.
            self.get_band_cover_from_lastfm(variables)
            return

        xml_tree = ElementTree.fromstring(response.content)
        entry = xml_tree.find('a:entry', ns)
        id_node = entry.find('a:id', ns) if entry is not None else None
        # The first 9 characters of the id are dropped; presumably a
        # 'urn:uuid:' prefix -- TODO confirm against a live response.
        uuid = id_node.text[9:] if id_node is not None and id_node.text else ''
        if not uuid:
            self.get_band_cover_from_lastfm(variables)
            return

        response = requests.get(zune_root + '/' + uuid + '/images')
        if response.status_code != 200:
            self.get_band_cover_from_lastfm(variables)
            return
        xml_tree = ElementTree.fromstring(response.content)

        # Take the URL of the first image instance that has one; no size
        # comparison is performed.
        url = None
        for image_entry in xml_tree.findall('a:entry', ns):
            url_node = image_entry.find(
                'zune:instances/zune:imageInstance/zune:url', ns)
            if url_node is not None and url_node.text:
                url = url_node.text
                break
        if not url:
            self.get_band_cover_from_lastfm(variables)
            return
        utils.save_image(url, artist_cover_path)
        print('[+] Added ' + variables.band_name + ' cover')
예제 #2
0
def stylize(args):
    """Stylize every image file in ``args.content_image`` with one model.

    Each image in the ``args.content_image`` directory that passes
    ``is_image_file`` is run through the ``TransformerNet`` checkpoint at
    ``args.model``. The input and the stylized output are concatenated along
    tensor dim 2 and saved to
    ``<args.output_image>/<model basename>_result_<image name>``.

    Args:
        args: namespace with ``content_image`` (input directory), ``model``
            (checkpoint path), ``output_image`` (output directory),
            ``content_scale`` and ``cuda``.
    """
    img_list = os.listdir(args.content_image)
    epoch_name = os.path.basename(args.model).split('.')[0]
    # os.makedirs also creates the parent (experiment) directory, and avoids
    # handing an unescaped path to the shell as os.system('mkdir ...') did.
    if not os.path.exists(args.output_image):
        os.makedirs(args.output_image)

    # Both the transform and the model are identical for every image, so
    # build/load them once instead of once per file.
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style_model = TransformerNet()
    style_model.load_state_dict(torch.load(args.model))
    if args.cuda:
        style_model.cuda()

    for img in img_list:
        if not is_image_file(img):
            continue
        content_image = utils.load_image(os.path.join(args.content_image, img),
                                         scale=args.content_scale)
        content_image = content_image.convert('RGB')
        # ToTensor gives CHW; unsqueeze adds the batch dimension.
        content_image = content_transform(content_image).unsqueeze(0)
        if args.cuda:
            content_image = content_image.cuda()
        content_image = Variable(content_image, volatile=True)

        output = style_model(content_image)
        if args.cuda:
            output = output.cpu()
            content_image = content_image.cpu()
        output_data = output.data[0]
        content_image_data = content_image.data[0]
        # Side-by-side: input first, stylized result second (dim 2 of CHW).
        output_data = torch.cat([content_image_data, output_data], 2)
        output_name = os.path.join(args.output_image,
                                   epoch_name + "_result_" + img)
        utils.save_image(output_name, output_data)
예제 #3
0
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
    """Write every visual to disk and add it to an HTML results page.

    Parameters:
        webpage (the HTML class) -- page object that stores these images (see html.py for more details)
        visuals (OrderedDict)    -- ordered mapping of name -> image (tensor or numpy array)
        image_path (str)         -- sequence whose first entry names the source image; its
                                    base name without extension prefixes every output file
        aspect_ratio (float)     -- aspect ratio passed to util.save_image
        width (int)              -- display width passed to webpage.add_images

    Each visual is saved as ``<stem>_<label>.png`` in the page's image directory and
    shown under a header named ``<stem>``; each thumbnail links to its own file.
    """
    target_dir = webpage.get_image_dir()
    stem, _ = os.path.splitext(ntpath.basename(image_path[0]))
    webpage.add_header(stem)

    saved_names = []
    labels = []
    for label, im_data in visuals.items():
        image = util.tensor2im(im_data)
        file_name = '%s_%s.png' % (stem, label)
        util.save_image(image, os.path.join(target_dir, file_name), aspect_ratio=aspect_ratio)
        saved_names.append(file_name)
        labels.append(label)
    # Links point at the same files as the thumbnails (separate list object).
    webpage.add_images(saved_names, labels, list(saved_names), width=width)
예제 #4
0
    def test(self,
             session,
             step,
             summary_writer=None,
             print_rate=1,
             sample_dir=None,
             meta=None):
        """Evaluate the model on one batch and optionally print/save samples.

        Runs the loss, RMSE and summary ops, then ``self.sample``, on a single
        feed from ``self.get_feed_dict_and_orig(session)``. Both the original
        and the forecast sequences are denormalized with ``meta``. Every
        ``print_rate`` steps the metrics are printed and both sequences are
        saved to ``sample_dir``.

        Args:
            session: TensorFlow session used to run the graph.
            step: global step, used for the summary and sample names.
            summary_writer: optional summary writer; when ``None`` the
                summary is not recorded (previously this crashed despite
                ``None`` being the default).
            print_rate: print/save samples when ``step % print_rate == 0``.
            sample_dir: directory passed to ``save_image``.
            meta: normalization metadata passed to ``denormalize``.

        Returns:
            ``(rmse_all, costs, diff)``:
            ``rmse_all`` is ``[temp, cc, sh, sp, geo]`` RMSE values;
            ``costs`` is ``[g_loss_pure, g_reg, d_loss_val, d_pen]``;
            ``diff`` holds ``(orig - gen)[:, 1, :, :, :]`` for each pair
            obtained by zipping the denormalized original and forecast.
        """
        feed_dict, original_sequence = self.get_feed_dict_and_orig(session)

        (g_loss_pure, g_reg, d_loss_val, d_pen, rmse_temp, rmse_cc, rmse_sh,
         rmse_sp, rmse_geo, summary) = session.run(
            [
                self.g_cost_pure, self.gen_reg, self.d_cost, self.d_penalty,
                self.rmse_temp, self.rmse_cc, self.rmse_sh, self.rmse_sp,
                self.rmse_geo, self.summary_op
            ],
            feed_dict=feed_dict)

        if summary_writer is not None:
            summary_writer.add_summary(summary, step)
        original_sequence = original_sequence.reshape([
            1, self.frame_size, self.crop_size, self.crop_size, self.channels
        ])
        # Forecast from the same feed that produced the metrics above.
        forecast = session.run(self.sample, feed_dict=feed_dict)

        denorm_original_sequence = denormalize(original_sequence, self.wvars,
                                               self.crop_size, self.frame_size,
                                               self.channels, meta)
        denorm_forecast = denormalize(forecast, self.wvars, self.crop_size,
                                      self.frame_size, self.channels, meta)

        # NOTE(review): index 1 on the second axis of each difference; its
        # meaning depends on denormalize's output layout -- confirm.
        diff = [(orig - gen)[:, 1, :, :, :]
                for orig, gen in zip(denorm_original_sequence, denorm_forecast)]

        if step % print_rate == 0:
            print(
                "Step: %d, generator loss: (%g + %g), discriminator_loss: (%g + %g)"
                % (step, g_loss_pure, g_reg, d_loss_val, d_pen))
            print("RMSE - Temp: %g, CC: %g, SH: %g, SP: %g, Geo: %g" %
                  (rmse_temp, rmse_cc, rmse_sh, rmse_sp, rmse_geo))

            print('saving original')
            save_image(denorm_original_sequence, sample_dir,
                       'init_%d_image' % step)
            print('saving forecast / fakes')
            save_image(denorm_forecast, sample_dir, 'gen_%d_future' % step)

        rmse_all = [rmse_temp, rmse_cc, rmse_sh, rmse_sp, rmse_geo]
        costs = [g_loss_pure, g_reg, d_loss_val, d_pen]

        return rmse_all, costs, diff
예제 #5
0
            def thread_function(input_image, width, height, out_path):
                """Load one image, optionally resize it and strip its background, then save it.

                ``resize_images`` and ``should_remove_backgrounds`` come from the
                enclosing scope.
                """
                picture = utils.read_image(input_image)
                if resize_images is not None:
                    picture = utils.resize_image(picture, width, height)
                if should_remove_backgrounds:
                    # Background removal runs on the (possibly resized) image.
                    picture = utils.remove_background(picture)
                utils.save_image(out_path, picture)
예제 #6
0
 def get_band_cover_from_lastfm(self, variables):
     """Fetch the artist's cover image through ``variables.network`` (last.fm).

     The image is saved as ``<variables.dirs.artists_cover>/<band_id>.jpg``.
     Any error while fetching or saving is reported and swallowed, so this
     remains a best-effort fallback.
     """
     print(' [+] Cover pic not found in Zune, trying to fetch from lastfm')
     artist_object = variables.network.get_artist(variables.band_name)
     artist_id = str(variables.band_id)
     # Cover image path inside the artists' cover directory.
     artist_cover_path = os.path.join(variables.dirs.artists_cover, artist_id) + '.jpg'
     # get_cover_image is called without a ``size`` argument, so the
     # library's default size is used.
     try:
         utils.save_image(artist_object.get_cover_image(), artist_cover_path)
     except Exception as e:
         # str(e) replaces e.message, which is deprecated since Python 2.6
         # and does not exist in Python 3.
         print("[-] Exception while fetching %s's cover pic: %s" % (variables.band_name, e))
         return
     print("[+] Added %s's cover" % (variables.band_name))
예제 #7
0
    def evaluate(self, epoch, step, counter, val_data_iterator):
        """Reconstruct the whole validation split and save the results.

        For each validation batch:
          * write the labels plus an integer-label column to
            ``<result_dir>label_test_<idx>.csv``;
          * run ``self.out`` and the merged summary, and log the summary at
            ``counter``;
          * keep the first ``manifold_h * manifold_w`` reconstructions.

        Afterwards each batch's reconstructions are saved as a
        ``manifold_h x manifold_w`` grid to ``im_<idx>.png`` under
        ``get_eval_result_dir(self.result_dir, epoch, step)``, and the
        iterator's ``"val"`` counter is reset.
        """
        print("Running evaluation after epoch:{:02d} and step:{:04d} ".format(
            epoch, step))
        reconstructed_images = []
        num_eval_batches = val_data_iterator.get_num_samples(
            "val") // self.batch_size
        # Grid layout and CSV header do not depend on the batch; compute once.
        manifold_w = 4
        tot_num_samples = min(self.sample_num, self.batch_size)
        manifold_h = tot_num_samples // manifold_w
        columns = [str(i) for i in range(10)] + ["label"]

        for _idx in range(num_eval_batches):
            batch_eval_images, batch_eval_labels, manual_labels = val_data_iterator.get_next_batch(
                "val")
            # Index of the first 1 in each label row (labels presumably
            # one-hot). reshape(-1, 1) replaces a hard-coded 64 so any batch
            # size works.
            integer_label = np.asarray([
                np.where(r == 1)[0][0] for r in batch_eval_labels
            ]).reshape([-1, 1])
            batch_eval_labels = np.concatenate(
                [batch_eval_labels, integer_label], axis=1)
            pd.DataFrame(batch_eval_labels,
                         columns=columns)\
                .to_csv(self.result_dir + "label_test_{:02d}.csv".format(_idx),
                        index=False)

            batch_z = prior.gaussian(self.batch_size, self.z_dim)
            reconstructed_image, summary = self.sess.run(
                [self.out, self.merged_summary_op],
                feed_dict={
                    self.inputs: batch_eval_images,
                    self.labels: manual_labels[:, :10],
                    self.is_manual_annotated: manual_labels[:, 10],
                    self.standard_normal: batch_z
                })

            self.writer_v.add_summary(summary, counter)
            reconstructed_images.append(
                reconstructed_image[:manifold_h * manifold_w, :, :, :])
        print("epoch:{} step:{}".format(epoch, step))
        reconstructed_dir = get_eval_result_dir(self.result_dir, epoch, step)
        print(reconstructed_dir)

        for _idx, reconstructed in enumerate(reconstructed_images):
            file_name = "im_" + str(_idx) + ".png"
            save_image(reconstructed, [manifold_h, manifold_w],
                       reconstructed_dir + file_name)
        val_data_iterator.reset_counter("val")

        print("Evaluation completed")
예제 #8
0
 def get_album_thumbnail(self, variables):
     """Fetch the album's cover image through ``variables.network`` as its thumbnail.

     Skips the download when ``utils.check_if_album_thumbnail_exists``
     reports that it is already present. The image is saved as
     ``<variables.dirs.albums_thumbnail>/<album_id>.jpg``. Errors while
     fetching or saving are reported and swallowed.
     """
     if utils.check_if_album_thumbnail_exists(variables):
         print('[*] ' + variables.album_name + ' thumbnail already exists')
         return
     album_object = variables.network.get_album(variables.band_name, variables.album_name)
     album_id = str(variables.album_id)
     album_image_path = os.path.join(variables.dirs.albums_thumbnail, album_id) + '.jpg'
     try:
         album_cover_image_url = album_object.get_cover_image()
         utils.save_image(album_cover_image_url, album_image_path)
     except Exception as e:
         # Best effort: report and move on. str(e) replaces e.message, which
         # is deprecated since Python 2.6 and does not exist in Python 3.
         print("[-] Exception while fetching %s's thumbnail:%s" % (variables.album_name, e))
         return
     print('[+] Added ' + variables.album_name + ' thumbnail')
예제 #9
0
 def get_band_thumbnail(self, variables):
     """Fetch a small artist image through ``variables.network`` as the thumbnail.

     Skips the download when ``utils.check_if_artist_thumbnail_exists``
     reports that it is already present. The image is saved as
     ``<variables.dirs.artist_thumbnail>/<band_id>.jpg``. Errors while
     fetching or saving are reported and swallowed.
     """
     if utils.check_if_artist_thumbnail_exists(variables):
         print('[*] ' + variables.band_name + ' thumbnail already exists')
         return
     artist_object = variables.network.get_artist(variables.band_name)
     artist_id = str(variables.band_id)
     # Save the artist thumbnails
     artist_thumbnail_path = os.path.join(variables.dirs.artist_thumbnail, artist_id) + '.jpg'
     # Note the size argument which returns the url for a smaller image
     try:
         utils.save_image(artist_object.get_cover_image(size=2), artist_thumbnail_path)
     except Exception as e:
         # str(e) replaces e.message, which is deprecated since Python 2.6
         # and does not exist in Python 3.
         print("[-] Exception while fetching %s's thumbnail: %s" % (variables.band_name, e))
         return
     print('[+] Added ' + variables.band_name + ' thumbnail')
예제 #10
0
def generate(
    cfg,
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    device: torch.device,
    logger=None,
    *args,
    **kwargs,
):
    """Run ``model`` over ``data_loader`` and save each output as a PNG.

    Images are written to ``<cfg.SAVE.DIR>/results/<img_idx>.png`` through
    ``utils.save_image`` (with ``cfg.DATA.MEAN`` / ``cfg.DATA.NORM``). When
    ``utils.save_image`` reports failure, ``utils.notify`` is called.

    Args:
        cfg: config with ``SAVE.DIR``, ``DATA.MEAN`` and ``DATA.NORM``.
        model: network to run; switched to eval mode.
        data_loader: yields dicts containing at least ``"img_idx"``.
        device: device passed to ``utils.inference``.
        logger: optional logger forwarded to ``utils.log_info``.
    """
    model.eval()
    # Same output directory for every image: create it once.
    save_dir = os.path.join(cfg.SAVE.DIR, "results")
    with utils.log_info(msg="Generate results", level="INFO", state=True, logger=logger):
        os.makedirs(save_dir, exist_ok=True)
        pbar = tqdm(total=len(data_loader), dynamic_ncols=True)
        try:
            for data in data_loader:
                output, *_ = utils.inference(model=model, data=data, device=device)

                for i in range(output.shape[0]):
                    path2file = os.path.join(save_dir, data["img_idx"][i] + ".png")
                    succeed = utils.save_image(output[i].detach().cpu().numpy(), cfg.DATA.MEAN, cfg.DATA.NORM, path2file)
                    if not succeed:
                        utils.notify("Cannot save image to {}".format(path2file))

                pbar.update()
        finally:
            # Close the progress bar even if inference or saving raises.
            pbar.close()
예제 #11
0
def generate(
    cfg,
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    device: torch.device,
    phase,
    logger=None,
    *args,
    **kwargs,
):
    """Run ``model`` over ``data_loader`` and save every output image.

    Outputs are written to ``<cfg.SAVE.DIR>/<phase>/<img_idx>_g.png`` through
    ``utils.save_image`` (with ``cfg.DATA.MEAN`` / ``cfg.DATA.NORM``). The
    average inference time per image is logged at the end.

    Args:
        cfg: config with ``SAVE.DIR``, ``DATA.MEAN`` and ``DATA.NORM``.
        model: network to run; switched to eval mode.
        data_loader: yields dicts containing at least ``"img_idx"``.
        device: device passed to ``utils.inference``.
        phase: name of the output sub-directory.
        logger: optional logger exposing ``log_info``; ``print`` otherwise.
    """
    model.eval()
    # Prepare to log info.
    log_info = print if logger is None else logger.log_info
    inference_time = []
    num_images = 0
    # Same output directory for every image: create it once.
    save_dir = os.path.join(cfg.SAVE.DIR, phase)
    with utils.log_info(msg="Generate results", level="INFO", state=True, logger=logger):
        os.makedirs(save_dir, exist_ok=True)
        pbar = tqdm(total=len(data_loader), dynamic_ncols=True)
        for data in data_loader:
            start_time = time.time()
            output = utils.inference(model=model, data=data, device=device)
            inference_time.append(time.time() - start_time)
            num_images += output.shape[0]

            for i in range(output.shape[0]):
                path2file = os.path.join(save_dir, data["img_idx"][i] + "_g.png")
                succeed = utils.save_image(output[i].detach().cpu().numpy(), cfg.DATA.MEAN, cfg.DATA.NORM, path2file)
                if not succeed:
                    log_info("Cannot save image to {}".format(path2file))
            pbar.update()
        pbar.close()
    # inference_time has one entry per batch, so divide by the number of
    # images, not batches. Skip the line when nothing ran (avoids ZeroDivisionError).
    if num_images:
        log_info("Runtime per image: {:<5} seconds.".format(round(sum(inference_time) / num_images, 4)))
예제 #12
0
    def train(self,
              session,
              step,
              summary_writer=None,
              log_summary=False,
              sample_dir=None,
              generate_sample=False,
              meta=None):
        """Run one training iteration, optionally logging and saving samples.

        Performs ``self.critic_iterations`` ``d_adam`` updates, each on a
        freshly fetched feed. It then runs one ``g_adam_gan`` and one
        ``g_adam_first`` update on a shared feed.

        With ``log_summary`` it evaluates and prints the losses, RMSEs and
        the fake min/max, and records the summary when ``summary_writer`` is
        given. With ``generate_sample`` it forecasts from frame 0 of
        ``self.videos`` and saves the denormalized original and forecast
        sequences to ``sample_dir``.
        """
        if log_summary:
            start_time = time.time()

        # Discriminator/critic updates, each on a new batch.
        for _ in range(self.critic_iterations):
            session.run(self.d_adam, feed_dict=self.get_feed_dict(session))

        # Both generator updates share one batch.
        feed_dict = self.get_feed_dict(session)
        session.run(self.g_adam_gan, feed_dict=feed_dict)
        session.run(self.g_adam_first, feed_dict=feed_dict)

        if log_summary:
            (g_loss_pure, g_reg, d_loss_val, d_pen, rmse_temp, rmse_cc, rmse_sh,
             rmse_sp, rmse_geo, fake_min, fake_max, summary) = session.run(
                [self.g_cost_pure, self.gen_reg, self.d_cost, self.d_penalty,
                 self.rmse_temp, self.rmse_cc, self.rmse_sh, self.rmse_sp,
                 self.rmse_geo, self.fake_min, self.fake_max, self.summary_op],
                feed_dict=feed_dict)
            # summary_writer defaults to None; only record when one is given.
            if summary_writer is not None:
                summary_writer.add_summary(summary, step)
            print("Time: %g/itr, Step: %d, generator loss: (%g + %g), discriminator_loss: (%g + %g)" % (
                time.time() - start_time, step, g_loss_pure, g_reg, d_loss_val, d_pen))
            print("RMSE - Temp: %g, CC: %g, SH: %g, SP: %g, Geo: %g" % (rmse_temp, rmse_cc, rmse_sh, rmse_sp, rmse_geo))
            print("Fake_vid min: %g, max: %g" % (fake_min, fake_max))

        if generate_sample:
            original_sequence = session.run(self.videos)
            original_sequence = original_sequence.reshape([self.batch_size, self.frame_size, self.crop_size, self.crop_size, self.channels])
            print(original_sequence.shape)
            # Frame 0 of each sequence is the starting state for the forecast.
            images = original_sequence[:, 0, :, :, :]
            forecast = session.run(self.sample, feed_dict={self.input_images: images})

            original_sequence = denormalize(original_sequence, self.wvars, self.crop_size, self.frame_size, self.channels, meta)
            print('saving original')
            save_image(original_sequence, sample_dir, 'init_%d_image' % step)

            forecast = denormalize(forecast, self.wvars, self.crop_size, self.frame_size, self.channels, meta)
            print('saving forecast / fakes')
            save_image(forecast, sample_dir, 'gen_%d_future' % step)
예제 #13
0
    def save_images(self, webpage, visuals, image_path):
        """Save each visual as ``<stem>_<label>.png`` and list them on ``webpage``.

        ``stem`` is the base name, without extension, of ``image_path[0]``;
        it is also used as the page header. Thumbnails are shown at
        ``self.win_size`` width and link to their own files.
        """
        out_dir = webpage.get_image_dir()
        stem = os.path.splitext(ntpath.basename(image_path[0]))[0]
        webpage.add_header(stem)

        names = []
        labels = []
        for label, image_numpy in visuals.items():
            file_name = '%s_%s.png' % (stem, label)
            utils.save_image(image_numpy, os.path.join(out_dir, file_name))
            names.append(file_name)
            labels.append(label)
        # Link targets are the image files themselves (a separate list copy).
        webpage.add_images(names, labels, names[:], width=self.win_size)
예제 #14
0
    def save_images(self, webpage, visuals, image_path):
        """Write every visual to the page's image directory and add it to the page.

        Files are named ``<title>_<label>.png``, where ``title`` is the
        extension-less base name of ``image_path[0]`` (also the page header).
        They are displayed at ``self.win_size`` width, each linking to itself.
        """
        image_dir = webpage.get_image_dir()
        base_name = ntpath.basename(image_path[0])
        title, _ext = os.path.splitext(base_name)
        webpage.add_header(title)

        entries = []
        for tag, array in visuals.items():
            png_name = '%s_%s.png' % (title, tag)
            utils.save_image(array, os.path.join(image_dir, png_name))
            entries.append((png_name, tag))

        ims = [name for name, _ in entries]
        txts = [tag for _, tag in entries]
        webpage.add_images(ims, txts, list(ims), width=self.win_size)
예제 #15
0
def evaluate(
    epoch: int,
    cfg,
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    device: torch.device,
    loss_fn,
    metrics_logger,
    phase="valid",
    logger=None,
    save=False,
    *args,
    **kwargs,
):
    """Evaluate ``model`` on ``data_loader`` for one epoch and log metrics.

    For every batch:
      * run ``utils.inference_and_cal_loss``;
      * record the loss and the metrics from ``utils.cal_and_record_metrics``
        in ``metrics_logger`` under ``(phase, epoch)``;
      * if ``save`` is true, write each output to
        ``<cfg.SAVE.DIR>/<phase>/<img_idx>_g.png``.

    Afterwards it logs the average inference time per image and the mean
    SSIM / PSNR / MAE / loss returned by ``metrics_logger.mean``.

    Args:
        epoch: epoch number used for logging and metric keys.
        cfg: config with ``SAVE.DIR``, ``DATA.MEAN`` and ``DATA.NORM``.
        model: network to evaluate; switched to eval mode.
        data_loader: yields dicts with at least ``"img_idx"`` and ``"target"``.
        device: device passed to the inference helper.
        loss_fn: loss passed to ``utils.inference_and_cal_loss``.
        metrics_logger: object with ``record`` and ``mean``.
        phase: metric/phase name and output sub-directory.
        logger: optional logger exposing ``log_info``; ``print`` otherwise.
        save: whether to write output images.
    """
    model.eval()
    # Prepare to log info.
    log_info = print if logger is None else logger.log_info
    total_loss = []
    inference_time = []
    num_images = 0
    save_dir = os.path.join(cfg.SAVE.DIR, phase)
    # Read data and evaluate and record info.
    with utils.log_info(msg="{} at epoch: {}".format(phase.upper(), str(epoch).zfill(3)), level="INFO", state=True, logger=logger):
        if save:
            # Same directory for every image: create it once up front.
            os.makedirs(save_dir, exist_ok=True)
        pbar = tqdm(total=len(data_loader), dynamic_ncols=True)
        for data in data_loader:
            start_time = time.time()
            out, loss = utils.inference_and_cal_loss(model=model, data=data, loss_fn=loss_fn, device=device)
            inference_time.append(time.time() - start_time)
            num_images += out.shape[0]
            # Fetch the scalar loss to the host once and reuse it below.
            loss_value = loss.detach().cpu().item()
            total_loss.append(loss_value)

            if save:
                # Save results to directory.
                for i in range(out.shape[0]):
                    path2file = os.path.join(save_dir, data["img_idx"][i] + "_g.png")
                    succeed = utils.save_image(out[i].detach().cpu().numpy(), cfg.DATA.MEAN, cfg.DATA.NORM, path2file)
                    if not succeed:
                        log_info("Cannot save image to {}".format(path2file))

            metrics_logger.record(phase, epoch, "loss", loss_value)
            output = out.detach().cpu()
            target = data["target"]
            utils.cal_and_record_metrics(phase, epoch, output, target, metrics_logger, logger=logger)

            pbar.set_description("Epoch: {:<3}, avg loss: {:<5}, cur loss: {:<5}".format(epoch, round(sum(total_loss)/len(total_loss), 5), round(total_loss[-1], 5)))
            pbar.update()
        pbar.close()
    # inference_time has one entry per batch, so divide by the number of
    # images, not batches. Skip the line when nothing ran (avoids ZeroDivisionError).
    if num_images:
        log_info("Runtime per image: {:<5} seconds.".format(round(sum(inference_time) / num_images, 4)))
    mean_metrics = metrics_logger.mean(phase, epoch)
    log_info("SSIM: {:<5}, PSNR: {:<5}, MAE: {:<5}, Loss: {:<5}".format(
        mean_metrics["SSIM"], mean_metrics["PSNR"], mean_metrics["MAE"], mean_metrics["loss"],
    ))
예제 #16
0
def stylize(content_image, style):
    """Apply the saved ``TransformerNet`` checkpoint ``style`` to one image on the CPU.

    The result is also written to ``./style/static/style-sample-images/temp.jpg``.

    Args:
        content_image: input accepted by ``transforms.ToTensor`` (PIL image
            or ndarray).
        style: checkpoint file name inside ``./saved_models/``.

    Returns:
        The first element of the model's output batch (a CPU tensor).
    """
    cpu = torch.device("cpu")
    to_tensor = transforms.Compose([
        transforms.ToTensor()
        # transforms.Lambda(lambda x: x.mul(255))
    ])
    batch = to_tensor(content_image).unsqueeze(0).to(cpu)

    with torch.no_grad():
        net = TransformerNet()
        # NOTE(review): torch.load unpickles the file; make sure ``style``
        # cannot be attacker-controlled (e.g. contain '..').
        weights = torch.load("./saved_models/" + style)
        # Drop the deprecated InstanceNorm running_* keys stored in older
        # checkpoints (deleted in place to keep the dict's metadata).
        stale_keys = [name for name in list(weights.keys())
                      if re.search(r'in\d+\.running_(mean|var)$', name)]
        for name in stale_keys:
            del weights[name]
        net.load_state_dict(weights)
        net.to(cpu)
        stylized = net(batch).cpu()
        utils.save_image("./style/static/style-sample-images/temp.jpg",
                         stylized[0])
        return stylized[0]
예제 #17
0
def test(test_loader, model, device, args):
    """Evaluate a blur -> sharp model on ``test_loader`` and save its outputs.

    Each batch supplies three blurred inputs and three sharp targets; only
    the model's first prediction is scored. For every batch it:
      * computes PSNR and SSIM against the first sharp target;
      * prints them;
      * writes ``<i>_sharp.png``, ``<i>_blur.png`` and ``<i>_pred.png`` to
        ``args.save_dir``.
    The average PSNR and SSIM are printed at the end.

    The ``squeeze().transpose(1, 2, 0)`` conversion only works for a batch
    size of 1.
    """
    test_psnr = AverageMeter('PSNR')
    test_ssim = AverageMeter('SSIM')

    def to_image(tensor):
        # (1, C, H, W) tensor -> (H, W, C) float array shifted by +0.5
        # (inputs presumably normalized to [-0.5, 0.5] -- TODO confirm).
        return tensor.detach().clone().cpu().numpy().squeeze().transpose(1, 2, 0) + 0.5

    model.eval()
    with torch.no_grad():
        # ``blurs``/``sharps`` instead of ``input``/``label``: don't shadow builtins.
        for i, (blurs, sharps) in enumerate(test_loader, 1):
            blur1, blur2, blur3 = blurs
            sharp1, _, _ = sharps
            blur1, blur2, blur3 = blur1.to(device), blur2.to(device), blur3.to(device)

            pred1, _, _ = model(blur1, blur2, blur3)

            sharp_img = to_image(sharp1)
            pred_img = to_image(pred1)
            blur_img = to_image(blur1)

            psnr = peak_signal_noise_ratio(
                sharp_img, pred_img,
                data_range=1.
            )
            # NOTE(review): ``multichannel`` is deprecated since
            # scikit-image 0.19 in favour of ``channel_axis=-1``; kept here
            # for compatibility with older versions.
            ssim = structural_similarity(
                sharp_img, pred_img,
                multichannel=True,
                gaussian_weights=True,
                use_sample_covariance=False,
                data_range=1.
            )

            test_psnr.update(psnr)
            test_ssim.update(ssim)

            print('{:d}/{:d} | PSNR (dB) {:.2f} | SSIM {:.4f}'.format(i, len(test_loader), psnr, ssim))

            save_image(sharp_img, os.path.join(args.save_dir, '{:d}_sharp.png'.format(i)))
            save_image(blur_img, os.path.join(args.save_dir, '{:d}_blur.png'.format(i)))
            save_image(pred_img, os.path.join(args.save_dir, '{:d}_pred.png'.format(i)))

    # SSIM lies in [0, 1]; report it with the same 4 decimals as the
    # per-image line instead of 2.
    print('>> Avg PSNR, SSIM: {:.2f}, {:.4f}'.format(test_psnr.avg, test_ssim.avg))
예제 #18
0
            def thread_function(input_image, width, height, out_path):
                """Read ``input_image``, resize it to ``width`` x ``height`` and save it to ``out_path``."""
                utils.save_image(
                    out_path,
                    utils.resize_image(utils.read_image(input_image), width, height),
                )
예제 #19
0
def main(data_dir):
    """Train PRNet (``ResFCN256``) on 300W-LP and periodically evaluate it.

    Each epoch:
      * trains over the hard-coded 300W-LP directories;
      * logs the mean loss and the learning rate to TensorBoard;
      * steps ``ReduceLROnPlateau`` on the mean loss.

    Every ``FLAGS["save_interval"]`` epochs it also:
      * computes the AFLW2000 NME;
      * saves a sample prediction grid;
      * writes a resumable checkpoint to ``<FLAGS['model_path']>/latest.pth``.

    Args:
        data_dir: currently unused; the training directories are hard-coded.
    """
    # 0) TensorBoard writer.
    writer = SummaryWriter(FLAGS['summary_path'])

    # 1) Create Dataset of 300_WLP.
    train_data_dir = [
        '/home/beitadoge/Github/PRNet_PyTorch/Data/PRNet_PyTorch_Data/300WLP_AFW_HELEN_LFPW',
        '/home/beitadoge/Github/PRNet_PyTorch/Data/PRNet_PyTorch_Data/300WLP_AFW_HELEN_LFPW_Flip',
        '/home/beitadoge/Github/PRNet_PyTorch/Data/PRNet_PyTorch_Data/300WLP_IBUG_Src_Flip'
    ]
    wlp300 = PRNetDataset(root_dir=train_data_dir,
                          transform=transforms.Compose([
                              ToTensor(),
                              ToNormalize(FLAGS["normalize_mean"],
                                          FLAGS["normalize_std"])
                          ]))

    # 2) Create DataLoader.
    wlp300_dataloader = DataLoaderX(dataset=wlp300,
                                    batch_size=FLAGS['batch_size'],
                                    shuffle=True,
                                    num_workers=4)

    # 3) Create PRNet model.
    start_epoch, target_epoch = FLAGS['start_epoch'], FLAGS['target_epoch']
    model = ResFCN256()

    # GPU
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.to("cuda")

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=FLAGS["lr"],
                                 betas=(0.5, 0.999))
    # scheduler_MultiStepLR = torch.optim.lr_scheduler.MultiStepLR(optimizer,[11],gamma=0.5, last_epoch=-1)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='min',
                                                           factor=0.1,
                                                           patience=5,
                                                           min_lr=1e-6,
                                                           verbose=False)
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
    # scheduler_StepLR = torch.optim.lr_scheduler.StepLR(optimizer,step_size=5,gamma=0.5, last_epoch=-1)

    # apex mixed-precision training (disabled)
    # from apex import amp
    # model , optimizer = amp.initialize(model,optimizer,opt_level="O1",verbosity=0)

    # Loss
    stat_loss = SSIM(mask_path=FLAGS["mask_path"], gauss=FLAGS["gauss_kernel"])
    loss = WeightMaskLoss(mask_path=FLAGS["mask_path"])

    # Resume from the latest checkpoint: a dict with keys 'prnet',
    # 'optimizer' and 'start_epoch', written at the end of each save interval
    # below. The original saved to 'lastest.pth' under key 'model', so it
    # could never be resumed from.
    checkpoint_path = os.path.join(FLAGS['model_path'], "latest.pth")
    if FLAGS['resume'] and os.path.exists(checkpoint_path):
        state = torch.load(checkpoint_path)
        model.load_state_dict(state['prnet'])
        optimizer.load_state_dict(state['optimizer'])
        # amp.load_state_dict(state['amp'])
        start_epoch = state['start_epoch']
        INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
    else:
        start_epoch = 0
        INFO(
            "Pre-trained weight cannot load successfully, train from scratch!")

    # TensorBoard graph. The original wrote ``writer.add_graph = (...)``,
    # which replaced the method with a tuple instead of calling it. The dummy
    # input must be on the model's device. Graph logging is best-effort and
    # must not abort training.
    model_input = torch.rand(FLAGS['batch_size'], 3, 256, 256)
    graph_model = model.module if isinstance(model, torch.nn.DataParallel) else model
    try:
        writer.add_graph(graph_model, model_input.to("cuda"))
    except Exception as err:
        INFO("Could not add the model graph to TensorBoard:", err)

    nme_mean = 999

    for ep in range(start_epoch, target_epoch):
        bar = tqdm(wlp300_dataloader)
        Loss_list, Stat_list = deque(maxlen=len(bar)), deque(maxlen=len(bar))

        model.train()
        for i, sample in enumerate(bar):
            uv_map, origin = sample['uv_map'].to(
                FLAGS['device']), sample['origin'].to(FLAGS['device'])

            # Inference.
            uv_map_result = model(origin)

            # Loss & ssim stat.
            logit_loss = loss(uv_map_result, uv_map)
            stat_logit = stat_loss(uv_map_result, uv_map)

            # Record Loss.
            Loss_list.append(logit_loss.item())
            Stat_list.append(stat_logit.item())

            # Update.
            optimizer.zero_grad()
            logit_loss.backward()
            # with amp.scale_loss(logit_loss,optimizer) as scaled_loss:
            #     scaled_loss.backward()
            optimizer.step()
            lr = optimizer.param_groups[0]['lr']
            bar.set_description(
                " {} lr {} [Loss(Paper)] {:.5f} [SSIM({})] {:.5f}".format(
                    ep, lr, np.mean(Loss_list), FLAGS["gauss_kernel"],
                    np.mean(Stat_list)))

        # After each epoch, write the mean loss to TensorBoard.
        loss_mean = np.mean(Loss_list)
        writer.add_scalar("Original Loss", loss_mean, ep)

        lr = optimizer.param_groups[0]['lr']
        writer.add_scalar("lr", lr, ep)
        # scheduler_StepLR.step()
        scheduler.step(loss_mean)

        del Loss_list
        del Stat_list

        # Test && Cal AFLW2000's NME
        model.eval()
        if ep % FLAGS["save_interval"] == 0:
            with torch.no_grad():
                nme_mean = cal_aflw2000_nme(
                    model, '/home/beitadoge/Data/PRNet_PyTorch_Data/AFLW2000')
                print("NME IS {}".format(nme_mean))

                writer.add_scalar("Aflw2000_nme", nme_mean, ep)

                origin = cv2.imread("./test_data/obama_origin.jpg")
                gt_uv_map = cv2.imread("./test_data/obama_uv_posmap.jpg")
                origin, gt_uv_map = test_data_preprocess(
                    origin), test_data_preprocess(gt_uv_map)
                origin_in = F.normalize(origin, FLAGS["normalize_mean"],
                                        FLAGS["normalize_std"],
                                        False).unsqueeze_(0)
                pred_uv_map = model(origin_in).detach().cpu()

                save_image(
                    [origin.cpu(),
                     gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                    os.path.join(FLAGS['model_path'],
                                 str(ep) + '.png'),
                    nrow=1,
                    normalize=True)

            checkpoint = {
                'prnet': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                # 'amp': amp.state_dict(),
                # Epoch ``ep`` is complete, so a resumed run starts at the next one.
                'start_epoch': ep + 1
            }
            torch.save(checkpoint, checkpoint_path)

        # adjust_learning_rate(lr , ep, optimizer)
        # scheduler.step(nme_mean)

    writer.close()
예제 #20
0
from utils.foreground_background import estimate_foreground_background
from utils.utils import save_image, show_image, stack_alpha
import matplotlib.pyplot as plt

if __name__ == '__main__':
    # --- what to process ------------------------------------------------
    # Image to cut out, and the sub-folder holding the alpha matte that was
    # estimated for it (same file name in both places).
    file_name = 'elephant.png'
    alpha_sub_dir = 'test/Trimap1/'

    # --- where things live ----------------------------------------------
    alpha_dir = './out/'
    image_dir = './data/input_lowres/'
    out_dir = './out/cut_out/'

    # --- load the composite image and its alpha matte --------------------
    image_path = image_dir + file_name
    alpha_path = alpha_dir + alpha_sub_dir + file_name
    image = plt.imread(image_path)
    alpha = plt.imread(alpha_path)

    # --- split into foreground / background colours ----------------------
    foreground, background = estimate_foreground_background(
        image, alpha, print_info=True)

    # --- attach alpha to the foreground and write the cut-out ------------
    cutout = stack_alpha(foreground, alpha)
    save_image(cutout, out_dir + alpha_sub_dir, file_name)

    # To preview the result interactively instead, use:
    # show_image(cutout)
예제 #21
0
def main_generate_datapoints():
    r"""
    Run the main script to generate the dataset.

    A square window of side ``STEP`` is slid over the map with stride
    ``STEP / 4`` in both x and y (x4 translation augmentation). For each
    window position, the candidate road lines are clipped to the window,
    crossings are split and near-duplicate nodes / near-straight chains are
    merged, and the result is converted to a node/edge graph with coordinates
    normalized to [-1, +1].

    A window is accepted when:

    * its edge count is in ``[MIN_NUM_EDGES, MAX_NUM_EDGES]``;
    * its node count is in ``[MIN_NUM_NODES, MAX_NUM_NODES]``;
    * no road id occurs 10 or more times among its lines.

    For each accepted window, an image is written to ``PATH_IMAGES[split]``
    and a datapoint dict is stored in ``datasets[split]``. All splits are
    saved with ``save_dataset(datasets, PATH_FILES)`` at the end.

    The output of generate_bins.py must already exist in ./raw/bins.pickle.
    :return: None
    """
    start_time = time.time()
    # One dict per split; index 3 is reserved for the (optional) augmented split.
    datasets = [defaultdict(dict), defaultdict(dict), defaultdict(dict), defaultdict(dict)]
    # Load pre-generated bins of road segments. The context manager closes the
    # file, which the former ``pickle.load(open(...))`` left open.
    # NOTE(review): these names are bound locally only; if helpers such as
    # get_possible_lines read module-level bins_dict / N_X_BINS / N_Y_BINS,
    # this load does not populate them — confirm.
    with open("raw/bins.pickle", "rb") as bins_file:
        bins_dict, (N_X_BINS, N_Y_BINS) = pickle.load(bins_file)
    tot_count = 0  # number of datapoints proposal explored
    count = 0  # number of accepted datapoints
    count_split = [0, 0, 0]  # number of accepted datapoints per split
    count_augment = 0  # number of accepted datapoints when using rotate and flip augmentation
    this_square = Square(0., 0., 0., 0.)  # square object representing the current datapoint's coordinates

    # to generate the default dataset we use translation augmentation of factor x4 for both x and y dimensions
    # NOTE(review): the x range uses GLOBAL_MIN.x / GLOBAL_MAX.x, while the log
    # line uses GLOBAL_X_MAX and the y loop uses GLOBAL_Y_MIN / GLOBAL_Y_MAX —
    # confirm these refer to the same map extents.
    for cx in np.arange(GLOBAL_MIN.x + STEP, GLOBAL_MAX.x - STEP, STEP / 4):

        # Log progress. tot_count only advances in the inner loop, so this
        # fires on x columns where the running total is a multiple of 100.
        if tot_count % 100 == 0:
            print(f"{int(100 * tot_count / N_POSSIBLE_DATAPOINTS)}%,"
                  f" {int(time.time() - start_time)}s, x={cx}/{GLOBAL_X_MAX} \t accepted:{count} / proposed:{tot_count}")

        # approximate to 5 decimals to avoid numerical instabilities,
        # convert coordinate to bin index and update Square representation
        cx = round(cx, 5)
        cx_idx = to_x_idx(cx)
        this_square.left = cx - (STEP / 2)
        this_square.right = cx + (STEP / 2)

        for cy in np.arange(GLOBAL_Y_MIN + STEP, GLOBAL_Y_MAX - STEP, STEP / 4):
            # get split for this map tile
            this_split = check_split(cx, cy)

            # check if square is in the conflict area between two splits, then skip datapoint
            if this_split == -1:
                continue

            # approximate to 5 decimals to avoid numerical instabilities,
            # convert coordinate to bin index and update Square representation
            tot_count += 1
            cy = round(cy, 5)
            cy_idx = to_y_idx(cy)
            this_square.up = cy + (STEP / 2)
            this_square.down = cy - (STEP / 2)

            this_point = Point(cx, cy)
            this_bin = Bin(cx_idx, cy_idx)

            # get all lines that could possibly intersect the current square from the bin where this data point lies
            lines = get_possible_lines(this_bin)

            # get only lines in this square, and handle lines intersecting the borders
            valid_lines, valid_id_roads = get_valid_lines(lines, this_square)

            # handle intersections by substituting intersecting lines with new lines terminating in the intersection
            # point
            valid_lines, valid_id_roads = handle_crossings(valid_lines, valid_id_roads)

            # change data format to nodes(x, y), and edges(node_a, node_b)
            nodes, edges = to_nodes_edges(valid_lines, valid_id_roads)

            # merge duplicate nodes (the number of deleted nodes is not needed here)
            nodes, edges, _ = merge_duplicate_nodes(nodes, edges)

            # merge consecutive (almost) straight lines
            nodes, edges, _ = merge_straight_lines(nodes, edges)

            # normalize coordinates to [-1, +1]
            nodes = normalize_coordinates(nodes, this_point)

            # filter for graph size
            if MIN_NUM_EDGES <= len(edges) <= MAX_NUM_EDGES and MIN_NUM_NODES <= len(nodes) <= MAX_NUM_NODES:
                longest_road = max(Counter(valid_id_roads).values())
                # optionally, filter out very long roads
                if longest_road < 10:
                    # get graph representation with adjacency lists per every node_id
                    adj_lists = edges_to_adjacency_lists(edges)

                    # compute BFS, DFS in different formats
                    bfs_nodes, bfs_edges = generate_bfs(copy.deepcopy(nodes), adj_lists)
                    dfs_nodes, dfs_edges = generate_dfs(copy.deepcopy(nodes), adj_lists)

                    # plot DFS/BFS
                    # for k in range(1, len(dfs_edges) + 1):
                    #     path_image = PATH_IMAGES[this_split] + "{:0>7d}_{}.png".format(count, k)
                    #     save_image_bfs(square_origin, dfs_edges[:k], path_image, plot_nodes=True)

                    # plot network for current datapoint and save it
                    path_image = PATH_IMAGES[this_split] + "{:0>7d}.png".format(count)
                    save_image(nodes, edges, path_image)

                    # Optionally, plot the graphs before pre-processing or by coloring differently each road, and
                    # highlighting nodes:
                    # # a) save before preprocessing
                    # path_image2 = PATH_IMAGES[this_split] + "extra_plots/" + "{:0>7d}b.png".format(count)
                    # save_image_by_lines(this_square, valid_lines, path_image2, id_roads=valid_id_roads, plot_nodes=True)
                    # # b) save with colored edges and nodes
                    # path_image3 = PATH_IMAGES[this_split] + "extra_plots/" + "{:0>7d}c.png".format(count)
                    # save_image_by_nodes_edges(UNIT_SQUARE, nodes, edges, path_image3, plot_nodes=True)

                    # generate the representation of this datapoint, and store it in its dataset split dictionary
                    current_dict = generate_datapoint(count, count, nodes, edges, adj_lists, bfs_edges, bfs_nodes,
                                                      dfs_edges, dfs_nodes, this_split, this_point)
                    datasets[this_split][count] = current_dict

                    # If this datapoint belongs to the training set, possibly augment with flip and rotation.
                    # Then, generate the datapoint and save the image with semantic segmentation.
                    # The type of augmentation used is stored as an attribute in data points in the augment split.
                    # if this_split == 0:
                    #     nodes_list = augment(nodes)
                    #     for id_augment, nodes in enumerate(nodes_list):
                    #         # plot augmented datapoint
                    #         path_image = PATH_IMAGES[3] + "{:0>7d}.png".format(count_augment)
                    #         save_image(nodes, edges, path_image)
                    #
                    #         # generate the representation of this datapoint,
                    #         # and store it in its dataset split dictionary
                    #         current_dict = generate_datapoint(count_augment, count, nodes, edges, adj_lists, bfs_edges,
                    #                                           bfs_nodes, dfs_edges, dfs_nodes, 3, this_point,
                    #                                           id_augment=id_augment)
                    #         datasets[3][count_augment] = current_dict
                    #         count_augment += 1

                    count += 1
                    count_split[this_split] += 1

    # finally save all the splits in the generated dataset and plot the size of each
    save_dataset(datasets, PATH_FILES)
    print(
        f"Final count: {count} | augmented: {count_augment} | count per split: {count_split} | squares explored: {tot_count}")
예제 #22
0
 def save_test_images(self, idx):
     """Write the six test-time visuals for sample ``idx`` as PNG files.

     Each tensor is converted with ``tensor2im`` and written under
     ``self.opt.img_save_path`` as ``img_<idx:04d>_<tag>.png``. The tags are
     real_A, rec_A, trans_A2B, real_B, rec_B and trans_B2A.
     """
     visuals = (
         (self.real_A, f"/img_{idx:04d}_real_A.png"),
         (self.rec_A, f"/img_{idx:04d}_rec_A.png"),
         (self.fake_B, f"/img_{idx:04d}_trans_A2B.png"),
         (self.real_B, f"/img_{idx:04d}_real_B.png"),
         (self.rec_B, f"/img_{idx:04d}_rec_B.png"),
         (self.fake_A, f"/img_{idx:04d}_trans_B2A.png"),
     )
     for tensor, file_suffix in visuals:
         save_image(tensor2im(tensor), self.opt.img_save_path + file_suffix)
예제 #23
0
def main(data_dir):
    """Train PRNet (ResFCN256) on 300W-LP with an adversarial loss.

    For each batch, ``model`` predicts a UV position map from the face crop
    and is updated so that ``discriminator`` labels the prediction as real.
    The discriminator is then updated to tell ground-truth UV maps (label 1)
    from predictions (label 0). Losses and image grids are logged to
    TensorBoard.

    Every ``FLAGS["save_interval"]`` epochs:

    * a prediction for a fixed test image is saved under ``FLAGS['images']``;
    * a ``latest.pth`` checkpoint is written;
    * the generator learning-rate scheduler is stepped.

    Args:
        data_dir: Root directory handed to ``PRNetDataset``.
    """
    device = FLAGS["device"]

    # 0) Tensorboard Writer.
    writer = SummaryWriter(FLAGS['summary_path'])
    origin_img, uv_map_gt, uv_map_predicted = None, None, None
    graph_logged = False

    if not os.path.exists(FLAGS['images']):
        os.mkdir(FLAGS['images'])

    # 1) Create Dataset of 300_WLP & Dataloader.
    wlp300 = PRNetDataset(root_dir=data_dir,
                          transform=transforms.Compose([
                              ToTensor(),
                              ToNormalize(FLAGS["normalize_mean"],
                                          FLAGS["normalize_std"])
                          ]))

    # drop_last=True drops the trailing partial batch for any dataset size.
    # It replaces the former hard-coded ``if i == 111: break``, which only
    # did so for one specific dataset.
    wlp300_dataloader = DataLoader(dataset=wlp300,
                                   batch_size=FLAGS['batch_size'],
                                   shuffle=True,
                                   num_workers=4,
                                   drop_last=True)

    # 2) Intermediate Processing.
    transform_img = transforms.Compose([
        transforms.Normalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
    ])

    # 3) Create PRNet model.
    start_epoch, target_epoch = FLAGS['start_epoch'], FLAGS['target_epoch']
    model = ResFCN256()
    discriminator = Discriminator1()

    # Load the pre-trained weight
    if FLAGS['resume'] and os.path.exists(
            os.path.join(FLAGS['images'], "latest.pth")):
        state = torch.load(os.path.join(FLAGS['images'], "latest.pth"))
        model.load_state_dict(state['prnet'])
        start_epoch = state['start_epoch']
        INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
    else:
        start_epoch = 0
        INFO(
            "Pre-trained weight cannot load successfully, train from scratch!")

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.to(device)
    discriminator.to(device)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=FLAGS["lr"],
                                 betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=FLAGS["lr"])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

    stat_loss = SSIM(mask_path=FLAGS["mask_path"], gauss=FLAGS["gauss_kernel"])
    # NOTE(review): `loss` is constructed but not used by this trainer.
    loss = WeightMaskLoss(mask_path=FLAGS["mask_path"])
    bce_loss = torch.nn.BCEWithLogitsLoss()
    bce_loss.to(device)

    for ep in range(start_epoch, target_epoch):
        bar = tqdm(wlp300_dataloader)
        Loss_list, Stat_list = [], []
        Loss_list_D = []
        for i, sample in enumerate(bar):
            uv_map, origin = sample['uv_map'].to(
                device), sample['origin'].to(device)

            # Inference.
            uv_map_result = model(origin)

            # Generator loss: how well the prediction fools the discriminator.
            # Labels follow the discriminator's output shape instead of a
            # hard-coded [16, 512], so every batch size works.
            d_fake_for_g = discriminator(uv_map_result)
            logit = bce_loss(d_fake_for_g, torch.ones_like(d_fake_for_g))

            stat_logit = stat_loss(uv_map_result, uv_map)

            # Discriminator loss. The prediction is detached *before* the
            # discriminator sees it, so the fake term still trains the
            # discriminator. Previously .detach() was applied to the
            # discriminator's output, which removed that gradient entirely.
            d_real = discriminator(uv_map)
            d_fake = discriminator(uv_map_result.detach())
            real_loss = bce_loss(d_real, torch.ones_like(d_real))
            fake_loss = bce_loss(d_fake, torch.zeros_like(d_fake))
            d_loss = (real_loss + fake_loss) / 2

            # Record Loss.
            Loss_list.append(logit.item())
            Loss_list_D.append(d_loss.item())
            Stat_list.append(stat_logit.item())

            # Update generator.
            optimizer.zero_grad()
            logit.backward()
            optimizer.step()
            bar.set_description(
                " {} [Loss(Paper)] {} [Loss(D)] {} [SSIM({})] {}".format(
                    ep, Loss_list[-1], Loss_list_D[-1], FLAGS["gauss_kernel"],
                    Stat_list[-1]))

            # Update discriminator. zero_grad also discards the gradients that
            # logit.backward() accumulated in the discriminator. d_loss's graph
            # does not reach the generator, so no retain_graph is needed.
            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

            # Record Training information in Tensorboard.
            if origin_img is None and uv_map_gt is None:
                origin_img, uv_map_gt = origin, uv_map
            # Detached so the autograd graph is not kept alive for logging.
            uv_map_predicted = uv_map_result.detach()

            # NOTE(review): FLAGS["summary_step"] is used as a fixed
            # global_step, so every record lands on the same step — confirm
            # whether a running step counter was intended.
            writer.add_scalar("Original Loss", Loss_list[-1],
                              FLAGS["summary_step"])
            writer.add_scalar("D Loss", Loss_list_D[-1], FLAGS["summary_step"])
            writer.add_scalar("SSIM Loss", Stat_list[-1],
                              FLAGS["summary_step"])

            grid_1, grid_2, grid_3 = make_grid(
                origin_img, normalize=True), make_grid(uv_map_gt), make_grid(
                    uv_map_predicted)

            writer.add_image('original', grid_1, FLAGS["summary_step"])
            writer.add_image('gt_uv_map', grid_2, FLAGS["summary_step"])
            writer.add_image('predicted_uv_map', grid_3, FLAGS["summary_step"])
            # Tracing the model is expensive, so the graph is logged once.
            # It is traced with the model's real input (the face crop).
            if not graph_logged:
                writer.add_graph(model, origin)
                graph_logged = True

        if ep % FLAGS["save_interval"] == 0:

            with torch.no_grad():
                origin = cv2.imread("./test_data/obama_origin.jpg")
                gt_uv_map = np.load("./test_data/test_obama.npy")
                origin, gt_uv_map = test_data_preprocess(
                    origin), test_data_preprocess(gt_uv_map)

                origin, gt_uv_map = transform_img(origin), transform_img(
                    gt_uv_map)

                origin_in = origin.unsqueeze_(0).to(device)
                pred_uv_map = model(origin_in).detach().cpu()

                save_image(
                    [origin.cpu(),
                     gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                    os.path.join(FLAGS['images'],
                                 str(ep) + '.png'),
                    nrow=1,
                    normalize=True)

            # Save model. Unwrap DataParallel so the keys carry no "module."
            # prefix and the checkpoint loads into the bare model on resume.
            print("Save model")
            state = {
                'prnet': getattr(model, 'module', model).state_dict(),
                'Loss': Loss_list,
                'start_epoch': ep,
                'Loss_D': Loss_list_D,
            }
            torch.save(state, os.path.join(FLAGS['images'], 'latest.pth'))

            # NOTE(review): the LR only decays on save epochs — confirm intent.
            scheduler.step()

    writer.close()
def main(data_dir):
    """Train ResFCN256 at 416x416 with a selectable adversarial objective.

    ``FLAGS['gan_type']`` chooses the loss:

    * ``'GAN'`` - standard BCE GAN (real -> 1, fake -> 0).
    * any name containing ``'WGAN'`` - Wasserstein loss. If the name also
      contains ``'GP'``, a gradient penalty (weight 10) is added; otherwise
      the discriminator weights are clipped to [-1, 1].
    * ``'RGAN'`` - relativistic average GAN, as used in ESRGAN.

    Every ``FLAGS["save_interval"]`` epochs:

    * the latest losses are printed;
    * a prediction for a fixed test image is written to ``FLAGS['images']``;
    * a ``<epoch>.pth`` checkpoint is saved.

    Args:
        data_dir: Root directory handed to ``PRNetDataset``.

    Raises:
        ValueError: if ``FLAGS['gan_type']`` is none of the above (previously
            this surfaced later as an unbound ``loss_d`` / ``loss_g``).
    """
    gan_type = FLAGS['gan_type']
    if gan_type not in ('GAN', 'RGAN') and 'WGAN' not in gan_type:
        raise ValueError("Unsupported FLAGS['gan_type']: {}".format(gan_type))
    device = FLAGS["device"]

    # 0) Tensoboard Writer.
    writer = SummaryWriter(FLAGS['summary_path'])

    if not os.path.exists(FLAGS['images']):
        os.mkdir(FLAGS['images'])

    # 1) Create Dataset of 300_WLP & Dataloader.
    wlp300 = PRNetDataset(root_dir=data_dir,
                          transform=transforms.Compose([
                              ToTensor(),
                              ToResize((416, 416)),
                              ToNormalize(FLAGS["normalize_mean"],
                                          FLAGS["normalize_std"])
                          ]))

    wlp300_dataloader = DataLoader(dataset=wlp300,
                                   batch_size=FLAGS['batch_size'],
                                   shuffle=True,
                                   num_workers=1)

    # 2) Intermediate Processing.
    transform_img = transforms.Compose([
        transforms.Normalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
    ])

    # 3) Create PRNet model.
    start_epoch, target_epoch = FLAGS['start_epoch'], FLAGS['target_epoch']
    model = ResFCN256(resolution_input=416,
                      resolution_output=416,
                      channel=3,
                      size=16)
    discriminator = Discriminator()

    # Load the pre-trained weight
    if FLAGS['resume'] != "" and os.path.exists(
            os.path.join(FLAGS['pretrained'], FLAGS['resume'])):
        state = torch.load(os.path.join(FLAGS['pretrained'], FLAGS['resume']))
        model.load_state_dict(state['prnet'])
        start_epoch = state['start_epoch']
        INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
    else:
        start_epoch = 0
        INFO(
            "Pre-trained weight cannot load successfully, train from scratch!")

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.to(device)
    discriminator.to(device)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=FLAGS["lr"],
                                 betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=FLAGS["lr"])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

    stat_loss = SSIM(mask_path=FLAGS["mask_path"], gauss=FLAGS["gauss_kernel"])
    # NOTE(review): `loss` is constructed but not used by this trainer.
    loss = WeightMaskLoss(mask_path=FLAGS["mask_path"])
    bce_loss = torch.nn.BCEWithLogitsLoss()
    bce_loss.to(device)

    def real_fake_bce(real_logits, fake_logits):
        # BCE pushing ``real_logits`` towards 1 and ``fake_logits`` towards 0.
        # BCEWithLogitsLoss takes (input, target). The old calls passed a
        # second set of *logits* as the target, which is not a valid label.
        return (bce_loss(real_logits, torch.ones_like(real_logits)) +
                bce_loss(fake_logits, torch.zeros_like(fake_logits)))

    for ep in range(start_epoch, target_epoch):
        bar = tqdm(wlp300_dataloader)
        loss_list_G, stat_list = [], []
        loss_list_D = []
        for i, sample in enumerate(bar):
            uv_map, origin = sample['uv_map'].to(
                device), sample['origin'].to(device)

            # Inference.
            optimizer.zero_grad()
            uv_map_result = model(origin)

            # Update D (on a detached prediction, so G receives no gradient).
            optimizer_D.zero_grad()
            fake_detach = uv_map_result.detach()
            d_fake = discriminator(fake_detach)
            d_real = discriminator(uv_map)
            if gan_type == 'GAN':
                loss_d = real_fake_bce(d_real, d_fake)
            elif 'WGAN' in gan_type:
                loss_d = (d_fake - d_real).mean()
                if 'GP' in gan_type:
                    epsilon = torch.rand(fake_detach.shape[0]).view(
                        -1, 1, 1, 1)
                    epsilon = epsilon.to(fake_detach.device)
                    hat = fake_detach.mul(1 - epsilon) + uv_map.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = discriminator(hat)
                    gradients = torch.autograd.grad(outputs=d_hat.sum(),
                                                    inputs=hat,
                                                    retain_graph=True,
                                                    create_graph=True,
                                                    only_inputs=True)[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty
            else:  # 'RGAN'; from ESRGAN: Enhanced Super-Resolution GANs
                better_real = d_real - d_fake.mean(dim=0, keepdim=True)
                better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
                loss_d = real_fake_bce(better_real, better_fake)

            if discriminator.training:
                loss_list_D.append(loss_d.item())
                loss_d.backward()
                optimizer_D.step()

                # Weight clipping is plain WGAN's Lipschitz constraint. WGAN-GP
                # enforces it with the gradient penalty instead, so no clipping.
                if 'WGAN' in gan_type and 'GP' not in gan_type:
                    for p in discriminator.parameters():
                        p.data.clamp_(-1, 1)

            # Update G
            d_fake_bp = discriminator(
                uv_map_result)  # for backpropagation, use fake as it is
            if gan_type == 'GAN':
                loss_g = bce_loss(d_fake_bp, torch.ones_like(d_fake_bp))
            elif 'WGAN' in gan_type:
                loss_g = -d_fake_bp.mean()
            else:  # 'RGAN'
                # d_real does not depend on G. Its graph also references
                # discriminator weights that optimizer_D.step() just changed
                # in place, so backprop through it would fail. Use it detached.
                d_real_ref = d_real.detach()
                better_real = d_real_ref - d_fake_bp.mean(dim=0, keepdim=True)
                better_fake = d_fake_bp - d_real_ref.mean(dim=0, keepdim=True)
                loss_g = real_fake_bce(better_fake, better_real)

            loss_g.backward()
            loss_list_G.append(loss_g.item())
            optimizer.step()

            stat_logit = stat_loss(uv_map_result, uv_map)
            stat_list.append(stat_logit.item())

        if ep % FLAGS["save_interval"] == 0:

            with torch.no_grad():
                print(" {} [Loss(Paper)] {} [Loss(D)] {} [SSIM({})] {}".format(
                    ep, loss_list_G[-1], loss_list_D[-1],
                    FLAGS["gauss_kernel"], stat_list[-1]))
                origin = cv2.imread("./test_data/obama_origin.jpg")
                gt_uv_map = np.load("./test_data/test_obama.npy")
                origin, gt_uv_map = test_data_preprocess(
                    origin), test_data_preprocess(gt_uv_map)

                origin, gt_uv_map = transform_img(origin), transform_img(
                    gt_uv_map)

                origin_in = origin.unsqueeze_(0).to(device)
                pred_uv_map = model(origin_in).detach().cpu()

                save_image(
                    [origin.cpu(),
                     gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                    os.path.join(FLAGS['images'],
                                 str(ep) + '.png'),
                    nrow=1,
                    normalize=True)

            # Save model. Unwrap DataParallel so the checkpoint loads into the
            # bare model on resume.
            print("Save model")
            state = {
                'prnet': getattr(model, 'module', model).state_dict(),
                'Loss': loss_list_G,
                'start_epoch': ep,
                'Loss_D': loss_list_D,
            }
            torch.save(state, os.path.join(FLAGS['images'],
                                           '{}.pth'.format(ep)))

            # NOTE(review): the LR only decays on save epochs — confirm intent.
            scheduler.step()

    writer.close()
def main(data_dir):
    """Train a CycleGAN-style pair of generators between face images and UV maps.

    ``g_x`` maps face images (``sample['origin']``, domain X) to UV position
    maps (``sample['uv_map']``, domain Y), and ``g_y`` maps back. ``d_x``
    judges Y-domain samples (real UV maps vs. ``g_x`` outputs); ``d_y`` judges
    X-domain samples. The generators are trained with a BCE adversarial loss
    plus an L1 cycle-consistency loss weighted by 10.

    Every ``FLAGS["save_interval"]`` epochs:

    * the latest losses are printed;
    * a ``g_x`` prediction for a fixed test image is written to
      ``FLAGS['images']``;
    * all four networks are checkpointed to ``<epoch>.pth``.

    Args:
        data_dir: Root directory handed to ``PRNetDataset``.
    """
    device = FLAGS["device"]

    if not os.path.exists(FLAGS['images']):
        os.mkdir(FLAGS['images'])

    # 1) Create Dataset of 300_WLP & Dataloader.
    wlp300 = PRNetDataset(root_dir=data_dir,
                          transform=transforms.Compose([
                              ToTensor(),
                              ToResize((416, 416)),
                              ToNormalize(FLAGS["normalize_mean"],
                                          FLAGS["normalize_std"])
                          ]))

    wlp300_dataloader = DataLoader(dataset=wlp300,
                                   batch_size=FLAGS['batch_size'],
                                   shuffle=True,
                                   num_workers=1)

    # 2) Intermediate Processing.
    transform_img = transforms.Compose([
        transforms.Normalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
    ])

    # 3) Create PRNet model.
    start_epoch, target_epoch = FLAGS['start_epoch'], FLAGS['target_epoch']
    g_x = ResFCN256(resolution_input=416,
                    resolution_output=416,
                    channel=3,
                    size=16)
    g_y = ResFCN256(resolution_input=416,
                    resolution_output=416,
                    channel=3,
                    size=16)
    d_x = Discriminator()
    d_y = Discriminator()

    # Load the pre-trained weight
    if FLAGS['resume'] != "" and os.path.exists(
            os.path.join(FLAGS['pretrained'], FLAGS['resume'])):
        state = torch.load(os.path.join(FLAGS['pretrained'], FLAGS['resume']))
        if 'g_x' in state:
            g_x.load_state_dict(state['g_x'])
            g_y.load_state_dict(state['g_y'])
            d_x.load_state_dict(state['d_x'])
            d_y.load_state_dict(state['d_y'])
        else:
            # Checkpoint from the single-generator trainer: only g_x can be
            # initialised from it.
            g_x.load_state_dict(state['prnet'])
        start_epoch = state['start_epoch']
        INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
    else:
        start_epoch = 0
        INFO(
            "Pre-trained weight cannot load successfully, train from scratch!")

    # Previously this wrapped an undefined name ``model`` (NameError on
    # multi-GPU machines). Wrap the networks that are actually trained.
    if torch.cuda.device_count() > 1:
        g_x, g_y = torch.nn.DataParallel(g_x), torch.nn.DataParallel(g_y)
        d_x, d_y = torch.nn.DataParallel(d_x), torch.nn.DataParallel(d_y)
    g_x.to(device)
    g_y.to(device)
    d_x.to(device)
    d_y.to(device)

    def unwrapped(net):
        # Bare module behind an optional DataParallel wrapper, so saved keys
        # carry no "module." prefix and load back into unwrapped networks.
        return getattr(net, 'module', net)

    optimizer_g = torch.optim.Adam(itertools.chain(g_x.parameters(),
                                                   g_y.parameters()),
                                   lr=FLAGS["lr"],
                                   betas=(0.5, 0.999))
    optimizer_d = torch.optim.Adam(itertools.chain(d_x.parameters(),
                                                   d_y.parameters()),
                                   lr=FLAGS["lr"])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.99)

    # NOTE(review): stat_loss and loss are constructed but not used below.
    stat_loss = SSIM(mask_path=FLAGS["mask_path"], gauss=FLAGS["gauss_kernel"])
    loss = WeightMaskLoss(mask_path=FLAGS["mask_path"])
    bce_loss = torch.nn.BCEWithLogitsLoss()
    bce_loss.to(device)
    l1_loss = nn.L1Loss().to(device)
    lambda_X = 10
    lambda_Y = 10

    for ep in range(start_epoch, target_epoch):
        bar = tqdm(wlp300_dataloader)
        loss_list_cycle_x = []
        loss_list_cycle_y = []
        loss_list_d_x = []
        loss_list_d_y = []
        for i, sample in enumerate(bar):
            real_y, real_x = sample['uv_map'].to(
                device), sample['origin'].to(device)
            # Labels are built with *_like from each prediction. That keeps them
            # on the right device and shape, including the last partial batch.
            # (They used to be CPU tensors of shape [batch_size].)

            # x -> y' -> x^
            optimizer_g.zero_grad()
            fake_y = g_x(real_x)
            prediction = d_x(fake_y)
            loss_g_x = bce_loss(prediction, torch.ones_like(prediction))
            x_hat = g_y(fake_y)
            loss_cycle_x = l1_loss(x_hat, real_x) * lambda_X
            loss_x = loss_g_x + loss_cycle_x
            loss_x.backward()
            optimizer_g.step()
            loss_list_cycle_x.append(loss_x.item())

            # y -> x' -> y^
            optimizer_g.zero_grad()
            fake_x = g_y(real_y)
            prediction = d_y(fake_x)
            loss_g_y = bce_loss(prediction, torch.ones_like(prediction))
            y_hat = g_x(fake_x)
            loss_cycle_y = l1_loss(y_hat, real_y) * lambda_Y
            loss_y = loss_g_y + loss_cycle_y
            loss_y.backward()
            optimizer_g.step()
            loss_list_cycle_y.append(loss_y.item())

            # Discriminators. Generated samples are detached. The generators
            # have already been stepped in place, so backprop into their graphs
            # would fail, and it would also push D gradients into g_x / g_y.

            # d_x
            optimizer_d.zero_grad()
            pred_real = d_x(real_y)
            loss_d_x_real = bce_loss(pred_real, torch.ones_like(pred_real))
            pred_fake = d_x(fake_y.detach())
            loss_d_x_fake = bce_loss(pred_fake, torch.zeros_like(pred_fake))
            loss_d_x = (loss_d_x_real + loss_d_x_fake) * 0.5
            loss_d_x.backward()
            loss_list_d_x.append(loss_d_x.item())
            optimizer_d.step()
            if 'WGAN' in FLAGS['gan_type']:
                for p in d_x.parameters():
                    p.data.clamp_(-1, 1)

            # d_y
            optimizer_d.zero_grad()
            pred_real = d_y(real_x)
            loss_d_y_real = bce_loss(pred_real, torch.ones_like(pred_real))
            pred_fake = d_y(fake_x.detach())
            loss_d_y_fake = bce_loss(pred_fake, torch.zeros_like(pred_fake))
            loss_d_y = (loss_d_y_real + loss_d_y_fake) * 0.5
            loss_d_y.backward()
            loss_list_d_y.append(loss_d_y.item())
            optimizer_d.step()
            if 'WGAN' in FLAGS['gan_type']:
                for p in d_y.parameters():
                    p.data.clamp_(-1, 1)

        if ep % FLAGS["save_interval"] == 0:

            with torch.no_grad():
                # The generator lists used to be referenced as undefined names
                # loss_list_g_x / loss_list_g_y (NameError at the first save).
                print(
                    " {} [Loss_G_X] {} [Loss_G_Y] {} [Loss_D_X] {} [Loss_D_Y] {}"
                    .format(ep, loss_list_cycle_x[-1], loss_list_cycle_y[-1],
                            loss_list_d_x[-1], loss_list_d_y[-1]))
                origin = cv2.imread("./test_data/obama_origin.jpg")
                gt_uv_map = np.load("./test_data/test_obama.npy")
                origin, gt_uv_map = test_data_preprocess(
                    origin), test_data_preprocess(gt_uv_map)

                origin, gt_uv_map = transform_img(origin), transform_img(
                    gt_uv_map)

                origin_in = origin.unsqueeze_(0).to(device)
                pred_uv_map = g_x(origin_in).detach().cpu()

                save_image(
                    [origin.cpu(),
                     gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                    os.path.join(FLAGS['images'],
                                 str(ep) + '.png'),
                    nrow=1,
                    normalize=True)

            # Save model
            print("Save model")
            state = {
                'g_x': unwrapped(g_x).state_dict(),
                'g_y': unwrapped(g_y).state_dict(),
                'd_x': unwrapped(d_x).state_dict(),
                'd_y': unwrapped(d_y).state_dict(),
                'start_epoch': ep,
            }
            torch.save(state, os.path.join(FLAGS['images'],
                                           '{}.pth'.format(ep)))

            # NOTE(review): the LR only decays on save epochs — confirm intent.
            scheduler.step()
예제 #26
0
def main(argv=None):
    """Train the segmentation network, or run it on one validation batch.

    When ``FLAGS.is_train == 1`` the model is trained for ``MAX_ITERATION``
    steps with Adagrad and a checkpoint is written to
    ``FLAGS.save_model_dir + "model.ckpt"``. Otherwise one random validation
    batch is pushed through the graph and the input, ground-truth and
    predicted segmentation maps are saved to ``FLAGS.verify_dir``.

    Args:
        argv: Unused. Presumably present so the function can serve as a
            ``tf.app.run`` entry point -- confirm against the caller.

    Returns:
        None.
    """
    # 1. Input placeholders.
    # Batch of RGB input images.
    image = tf.placeholder(tf.float32,
                           shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3],
                           name="input_image")
    # Integer class label per pixel, with a trailing channel axis of size 1.
    label = tf.placeholder(tf.int32,
                           shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1],
                           name="label")
    # Dropout keep probability: 0.85 while training, 1.0 for evaluation.
    dropout_keep_probability = tf.placeholder(tf.float32,
                                              name="dropout_keep_probability")

    # 2. Model. pred_label is the predicted segmentation map; logits feed the
    # loss that drives optimisation.
    pred_label, logits = model(image, dropout_keep_probability)

    # 3. Loss: mean per-pixel sparse softmax cross-entropy. Sparse labels must
    # have one rank less than the logits, so the size-1 channel axis is
    # squeezed away. `axis` is the supported spelling of the deprecated
    # `squeeze_dims` keyword.
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=tf.squeeze(label, axis=[3]),
        name="entropy"))

    # 4. Optimisation.
    train_op = tf.train.AdagradOptimizer(FLAGS.learning_rate).minimize(loss)

    # 5. TensorBoard scalar for the loss. It is evaluated on both training and
    # validation batches, each written through its own writer (see below).
    loss_summary = tf.summary.scalar("loss", loss)

    # 6. Checkpoint saver.
    saver = tf.train.Saver()

    # 7. Variable initialiser.
    variable_op = tf.global_variables_initializer()

    # 8. Train or predict.
    with tf.Session() as sess:

        # 8.1 Initialise all variables.
        sess.run(variable_op)

        # 8.2 TensorBoard writers. Training and validation losses share the
        # "loss" tag, and every 500th step is also a multiple of 10. Writing
        # both through one writer therefore put two different values at the
        # same step of a single curve. Validation gets its own run directory,
        # created only when training.
        train_writer = tf.summary.FileWriter(
            FLAGS.tensorboard_dir + 'train/', sess.graph)
        valid_writer = None
        try:
            # 8.3 Load the dataset records.
            print("Setting up image reader...")
            train_records, valid_records = scene_parsing.read_dataset(
                FLAGS.data_dir)
            print(len(train_records))
            print(len(valid_records))

            print("Setting up dataset reader")
            image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
            if FLAGS.is_train == 1:
                train_dataset_reader = dataset.BatchDatset(train_records,
                                                           image_options)
                print("训练集数据准备完毕")
            validation_dataset_reader = dataset.BatchDatset(
                valid_records, image_options)
            print("测试集数据准备完毕")

            # 8.4 Training.
            if FLAGS.is_train == 1:
                valid_writer = tf.summary.FileWriter(
                    FLAGS.tensorboard_dir + 'validation/')
                print("开始迭代训练")
                for itr in range(MAX_ITERATION):
                    train_images, train_annotations = train_dataset_reader.next_batch(
                        FLAGS.batch_size)
                    feed_dict = {
                        image: train_images,
                        label: train_annotations,
                        dropout_keep_probability: 0.85
                    }

                    sess.run(train_op, feed_dict=feed_dict)

                    # 8.4.1 Training loss, every 10 steps.
                    if itr % 10 == 0:
                        train_loss, summary_str = sess.run([loss, loss_summary],
                                                           feed_dict=feed_dict)
                        print("Step: %d, Train_loss:%g" % (itr, train_loss))
                        train_writer.add_summary(summary_str, itr)

                    # 8.4.2 Validation loss on one batch, every 500 steps.
                    if itr % 500 == 0:
                        valid_images, valid_annotations = validation_dataset_reader.next_batch(
                            FLAGS.batch_size)
                        valid_loss, summary_sva = sess.run(
                            [loss, loss_summary],
                            feed_dict={
                                image: valid_images,
                                label: valid_annotations,
                                dropout_keep_probability: 1.0
                            })
                        print("%s ---> Validation_loss: %g" %
                              (datetime.datetime.now(), valid_loss))

                        # Validation loss goes to its own TensorBoard run.
                        valid_writer.add_summary(summary_sva, itr)

                # 8.4.3 Save the model once, after the last iteration.
                saver.save(sess, FLAGS.save_model_dir + "model.ckpt")

            # Prediction / visual check.
            else:
                # Run one random validation batch to inspect the model's output.
                valid_images, valid_annotations = validation_dataset_reader.get_random_batch(
                    FLAGS.batch_size)
                pred = sess.run(pred_label,
                                feed_dict={
                                    image: valid_images,
                                    label: valid_annotations,
                                    dropout_keep_probability: 1.0
                                })
                # Ground-truth segmentation maps (drop the size-1 channel axis).
                valid_annotations = np.squeeze(valid_annotations, axis=3)
                # Predicted segmentation maps (drop the size-1 channel axis).
                pred = np.squeeze(pred, axis=3)

                # File-name suffixes start at 5; the reason for this offset is
                # not visible here.
                for itr in range(FLAGS.batch_size):
                    # Input image.
                    utils.save_image(valid_images[itr].astype(np.uint8),
                                     FLAGS.verify_dir,
                                     name="inp_" + str(5 + itr))
                    # Ground-truth segmentation map.
                    utils.save_image(valid_annotations[itr].astype(np.uint8),
                                     FLAGS.verify_dir,
                                     name="gt_" + str(5 + itr))
                    # Predicted segmentation map.
                    utils.save_image(pred[itr].astype(np.uint8),
                                     FLAGS.verify_dir,
                                     name="pred_" + str(5 + itr))
                    print("Saved image: %d" % itr)
        finally:
            # FileWriter buffers events; close() flushes them to disk.
            train_writer.close()
            if valid_writer is not None:
                valid_writer.close()

    return None
예제 #27
0
    def display_current_results(self, visuals, epoch, save_result):
        """Show *visuals* in visdom and/or write them to the HTML results page.

        Args:
            visuals: Mapping of label -> image array. Each array is transposed
                with ``[2, 0, 1]`` before display, so an HWC layout is assumed.
            epoch: Current epoch; used in saved file names and page headers.
            save_result: Force writing this epoch's images to disk. Images are
                also written on the first call, while ``self.saved`` is False.
        """
        if self.display_id > 0:  # show images in the browser
            ncols = self.opt.display_single_pane_ncols
            # With an empty `visuals`, next(iter(...)) below would raise
            # StopIteration. Empty input falls through to the per-image
            # branch instead, which simply shows nothing.
            if ncols > 0 and visuals:
                # Table cells are sized after the first visual (shape[:2] = h, w).
                h, w = next(iter(visuals.values())).shape[:2]
                table_css = """<style>
                        table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
                        table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
                        </style>""" % (w, h)
                title = self.name
                label_html = ''
                label_html_row = ''
                images = []
                idx = 0
                for label, image_numpy in visuals.items():
                    label_html_row += '<td>%s</td>' % label
                    images.append(image_numpy.transpose([2, 0, 1]))
                    idx += 1
                    if idx % ncols == 0:
                        label_html += '<tr>%s</tr>' % label_html_row
                        label_html_row = ''
                # Pad the last row with white tiles, shaped like the last
                # image, so the grid has exactly `ncols` columns.
                white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
                while idx % ncols != 0:
                    images.append(white_image)
                    label_html_row += '<td></td>'
                    idx += 1
                if label_html_row != '':
                    label_html += '<tr>%s</tr>' % label_html_row
                # pane col = image row
                self.vis.images(images, nrow=ncols, win=self.display_id + 1,
                                padding=2, opts=dict(title=title + ' images'))
                label_html = '<table>%s</table>' % label_html
                self.vis.text(table_css + label_html, win=self.display_id + 2,
                              opts=dict(title=title + ' labels'))
            else:
                # One visdom window per visual: display_id + 1, + 2, ...
                for idx, (label, image_numpy) in enumerate(visuals.items(), start=1):
                    self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
                                   win=self.display_id + idx)

        if self.use_html and (save_result or not self.saved):  # save images to a html file
            self.saved = True
            for label, image_numpy in visuals.items():
                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
                utils.save_image(image_numpy, img_path)
            # Rebuild the whole page, newest epoch first. Every epoch's files
            # are linked by name, which assumes earlier calls saved them.
            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
            for n in range(epoch, 0, -1):
                webpage.add_header('epoch [%d]' % n)
                ims = []
                txts = []
                links = []

                for label in visuals:
                    img_path = 'epoch%.3d_%s.png' % (n, label)
                    ims.append(img_path)
                    txts.append(label)
                    links.append(img_path)
                webpage.add_images(ims, txts, links, width=self.win_size)
            webpage.save()
예제 #28
0
# Run the test split one step at a time until the input pipeline raises
# OutOfRangeError. Accumulate the per-parameter RMSE and diff maps, plus the
# cost terms returned by model.test.
i = 1  # 1-based index of the step about to run
while True:
    try:
        rmse_all, costs, diff = model.test(sess, i, summary_writer=summary_writer, print_rate=200, sample_dir=sample_dir, meta=meta)
    except tf.errors.OutOfRangeError:
        # Pipeline exhausted: step `i` produced no results.
        break
    for k, er, dif in zip(range(len(global_rmse)), rmse_all, diff):
        global_rmse[k] += er
        global_diff[k] += dif
    # Add this step's `costs` (not global_costs itself) to the running totals.
    for k, cost in zip(range(len(global_costs)), costs):
        global_costs[k] += cost
    i += 1

# Steps 1 .. i-1 completed, so means divide by i - 1, not by i.
num_steps = i - 1
denom = max(num_steps, 1)  # avoid ZeroDivisionError on an empty split
print('Number of steps: %d' % num_steps)
for rmse, p, dif in zip(global_rmse, weather_params, global_diff):
    print("----------- Global RMSE of %s: %g" % (p, rmse / denom))

    print('Saving global and mean diffs')
    save_image(dif, sample_dir, 'diff_global_%s' % p)
    save_image(dif / denom, sample_dir, 'diff_mean_%s' % p)

# The trailing "% " in these messages was an incomplete format specifier and
# raised ValueError; it has been dropped.
print('----------- Mean Generator cost (%g + %g)' % (global_costs[0] / denom, global_costs[1] / denom))
print('----------- Mean Diskriminator/Critic cost (%g + %g)' % (global_costs[2] / denom, global_costs[3] / denom))


print('Testo donezo')
#
# Shut everything down: release the session's resources.
#
sess.close()
예제 #29
0
def main(data_dir):
    """Train PRNet's ``ResFCN256`` on 300W-LP with a weighted-mask loss.

    Args:
        data_dir: Root directory handed to ``PRNetDataset``.

    Each iteration optimises ``WeightMaskLoss`` with Adam, tracks an SSIM
    statistic, and logs both as scalars, together with image grids, to
    TensorBoard. Every ``FLAGS["save_interval"]`` epochs it:

    * writes a prediction for the bundled Obama test sample to
      ``FLAGS['images']/<epoch>.png``;
    * writes the weights to ``FLAGS['images']/latest.pth``.
    """
    device = FLAGS['device']

    # 0) TensorBoard writer and the fixed reference batch shown in it.
    writer = SummaryWriter(FLAGS['summary_path'])
    origin_img, uv_map_gt = None, None
    # Per-iteration step for TensorBoard. FLAGS["summary_step"] is never
    # advanced here, so using it as the step stacked every point on one x.
    global_step = 0
    graph_logged = False

    if not os.path.exists(FLAGS['images']):
        os.mkdir(FLAGS['images'])

    # 1) Create Dataset of 300_WLP & Dataloader.
    wlp300 = PRNetDataset(root_dir=data_dir,
                          transform=transforms.Compose([ToTensor(),
                                                        ToNormalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])]))

    wlp300_dataloader = DataLoader(dataset=wlp300, batch_size=FLAGS['batch_size'], shuffle=True, num_workers=4)

    # 2) Create PRNet model.
    start_epoch, target_epoch = FLAGS['start_epoch'], FLAGS['target_epoch']
    model = ResFCN256()

    # Load the pre-trained weight. This happens before any DataParallel
    # wrapping, so the checkpoint must hold the bare module's state dict.
    checkpoint_path = os.path.join(FLAGS['images'], "latest.pth")
    if FLAGS['resume'] and os.path.exists(checkpoint_path):
        state = torch.load(checkpoint_path)
        model.load_state_dict(state['prnet'])
        start_epoch = state['start_epoch']
        INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
    else:
        start_epoch = 0
        INFO("Pre-trained weight cannot load successfully, train from scratch!")

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    # Same device as the batches below (previously hard-coded to "cuda").
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=FLAGS["lr"], betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

    stat_loss = SSIM(mask_path=FLAGS["mask_path"], gauss=FLAGS["gauss_kernel"])
    loss = WeightMaskLoss(mask_path=FLAGS["mask_path"])

    try:
        for ep in range(start_epoch, target_epoch):
            bar = tqdm(wlp300_dataloader)
            Loss_list, Stat_list = [], []
            for sample in bar:
                uv_map = sample['uv_map'].to(device)
                origin = sample['origin'].to(device)

                # Inference.
                uv_map_result = model(origin)

                # Loss (optimised) & SSIM (reported only).
                logit = loss(uv_map_result, uv_map)
                stat_logit = stat_loss(uv_map_result, uv_map)

                # Record Loss.
                Loss_list.append(logit.item())
                Stat_list.append(stat_logit.item())

                # Update.
                optimizer.zero_grad()
                logit.backward()
                optimizer.step()
                bar.set_description(" {} [Loss(Paper)] {} [SSIM({})] {}".format(ep, Loss_list[-1], FLAGS["gauss_kernel"], Stat_list[-1]))

                # Record training information in TensorBoard. The first batch
                # is kept as a fixed reference for the input / GT grids.
                if origin_img is None and uv_map_gt is None:
                    origin_img, uv_map_gt = origin, uv_map
                uv_map_predicted = uv_map_result.detach()

                writer.add_scalar("Original Loss", Loss_list[-1], global_step)
                writer.add_scalar("SSIM Loss", Stat_list[-1], global_step)

                grid_1, grid_2, grid_3 = make_grid(origin_img, normalize=True), make_grid(uv_map_gt), make_grid(uv_map_predicted)

                writer.add_image('original', grid_1, global_step)
                writer.add_image('gt_uv_map', grid_2, global_step)
                writer.add_image('predicted_uv_map', grid_3, global_step)
                # The graph is static: trace it once, with the model's real
                # input (the face image, not the UV map).
                if not graph_logged:
                    writer.add_graph(model, origin)
                    graph_logged = True
                global_step += 1

            if ep % FLAGS["save_interval"] == 0:
                with torch.no_grad():
                    origin = cv2.imread("./test_data/obama_origin.jpg")
                    gt_uv_map = np.load("./test_data/test_obama.npy")
                    origin, gt_uv_map = test_data_preprocess(origin), test_data_preprocess(gt_uv_map)

                    # unsqueeze_ adds the batch axis in place, so origin.cpu()
                    # below is batched too.
                    origin_in = origin.unsqueeze_(0).to(device)
                    pred_uv_map = model(origin_in).detach().cpu()

                    save_image([origin.cpu(), gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                               os.path.join(FLAGS['images'], str(ep) + '.png'), nrow=1, normalize=True)

                # Save the unwrapped module so keys carry no "module." prefix
                # and load into the bare ResFCN256 on resume.
                net = model.module if isinstance(model, torch.nn.DataParallel) else model
                state = {
                    'prnet': net.state_dict(),
                    'Loss': Loss_list,
                    # Epoch to resume from: `ep` itself has just completed.
                    'start_epoch': ep + 1,
                }
                torch.save(state, os.path.join(FLAGS['images'], 'latest.pth'))

                # NOTE(review): the LR only decays on save epochs -- confirm
                # this coupling to save_interval is intended.
                scheduler.step()
    finally:
        writer.close()
예제 #30
0
 def save_train_images(self, epoch):
     """Write the real, reconstructed and identity images of both domains.

     Each tensor attribute (``real_A`` .. ``idt_B``) is converted with
     ``tensor2im`` and stored as
     ``<self.opt.img_save_path>/<name>_epoch_<epoch>.png``.
     """
     for tag in ("real_A", "real_B", "rec_A", "rec_B", "idt_A", "idt_B"):
         picture = tensor2im(getattr(self, tag))
         target = self.opt.img_save_path + f"/{tag}_epoch_{epoch}.png"
         save_image(picture, target)
예제 #31
0
import time

from IFM.ifm import information_flow_matting
from utils.utils import save_image, show_image

if __name__ == '__main__':
    file_name = 'pineapple.png'
    trimap_sub_dir = 'Trimap1'

    input_dir = './data/input_lowres/'
    trimap_dir = './data/trimap_lowres/'
    out_dir = './out/test/'

    # Images manually identified as highly transparent. For these, the third
    # argument of information_flow_matting is False.
    high_transparency_images = ('net.png', 'plasticbag.png')

    # perf_counter is monotonic, so the measured duration is unaffected by
    # wall-clock adjustments during the run.
    time_start = time.perf_counter()
    # matting
    alpha_matte = information_flow_matting(
        input_dir + file_name, trimap_dir + trimap_sub_dir + '/' + file_name,
        file_name not in high_transparency_images)
    time_end = time.perf_counter()
    print('cost: {:.2f}s'.format(time_end - time_start))
    # save
    save_image(alpha_matte, out_dir + trimap_sub_dir + '/', file_name, True)

    # Uncomment to display the result interactively:
    # show_image(alpha_matte)
예제 #32
0
def main():
    """Train the FCRN depth network, optionally with adversarial training.

    Configuration comes from the module-level ``args``. The function updates
    the module-level ``best_result`` and ``output_directory``. Each epoch:

    * trains and then validates;
    * records the best RMSE in ``best.txt``;
    * saves a checkpoint via ``utils.save_checkpoint``.

    Raises:
        AssertionError: if ``args.resume`` names a file that does not exist.
        ValueError: for an unsupported optimizer, attack or loss, or when
            adversarial training is requested without an attack.
    """
    global args, best_result, output_directory

    print(torch.__version__)

    # set random seed
    torch.manual_seed(args.manual_seed)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        print(torch.version.cuda)
        print(torch.cuda.device_count())
        print(torch.cuda.is_available())
        print()

        # Presumably scaled so each GPU keeps the configured per-GPU batch
        # once DataParallel splits it.
        args.batch_size = args.batch_size * torch.cuda.device_count()
    else:
        print("Let's use GPU ", torch.cuda.current_device())

    train_loader, val_loader = utils.create_loader(args)

    # NOTE: `assert (False, msg)` asserts a non-empty tuple, which is always
    # true, so the previous validity checks never fired. They now raise.

    # load model
    if args.resume:
        assert os.path.isfile(args.resume), \
            "=> no checkpoint found at '{}'".format(args.resume)
        print("=> loading checkpoint '{}'".format(args.resume))
        # The checkpoint holds whole pickled objects (model, optimizer,
        # best_result), not just state dicts. NOTE(review): only load trusted
        # files. Where torch.load defaults to weights_only=True, this needs
        # weights_only=False.
        checkpoint = torch.load(args.resume)

        start_epoch = checkpoint['epoch'] + 1
        best_result = checkpoint['best_result']
        optimizer = checkpoint['optimizer']

        # solve 'out of memory'
        model = checkpoint['model']

        print("=> loaded checkpoint (epoch {})".format(
            checkpoint['epoch']), flush=True)

        # clear memory
        del checkpoint
        torch.cuda.empty_cache()
    else:
        print("=> creating Model from scratch")
        model = FCRN.ResNet(layers=args.resnet_layers,
                            output_size=train_loader.dataset.output_size)
        start_epoch = 0

        # different modules have different learning rate
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        if args.optim == 'sgd':
            optimizer = torch.optim.SGD(
                train_params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
        elif args.optim == 'adam':
            optimizer = torch.optim.Adam(
                train_params, lr=args.lr, weight_decay=args.weight_decay)
        else:
            raise ValueError("{} optim not supported".format(args.optim))

        # You can use DataParallel() whether you use Multi-GPUs or not
        model = nn.DataParallel(model).cuda()

    attacker = None
    if args.adv_training:
        # Adversarial training restarts the epoch count and best-result tracking.
        start_epoch = 0
        best_result.set_to_worst()
        if not args.attack:
            raise ValueError("You must supply an attack for adversarial training")

        if args.attack == 'mifgsm':
            attacker = MIFGSM(model, "cuda:0", args.loss,
                              eps=mifgsm_params['eps'],
                              steps=mifgsm_params['steps'],
                              decay=mifgsm_params['decay'],
                              alpha=mifgsm_params['alpha'],
                              TI=mifgsm_params['TI'],
                              k_=mifgsm_params['k'],
                              targeted=args.targeted,
                              test=args.model)
        elif args.attack == 'pgd':
            # NOTE(review): k_ is read from mifgsm_params, not pgd_params --
            # confirm this is intended.
            attacker = PGD(model, "cuda:0", args.loss,
                           norm=pgd_params['norm'],
                           eps=pgd_params['eps'],
                           alpha=pgd_params['alpha'],
                           iters=pgd_params['iterations'],
                           TI=pgd_params['TI'],
                           k_=mifgsm_params['k'],
                           test=args.model)
        else:
            raise ValueError("{} attack not supported".format(args.attack))

        print('performing adversarial training with {} attack and {} loss'.format(
            args.attack, args.loss), flush=True)
    else:
        print('performing standard training with {} loss'.format(args.loss))

    # Learning-rate scheduler (None when args.scheduler is unrecognised).
    if args.scheduler == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', patience=args.lr_patience)
    elif args.scheduler == 'cyclic':
        scheduler = lr_scheduler.CyclicLR(
            optimizer, base_lr=args.lr, max_lr=args.lr * 100)
    elif args.scheduler == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=2, eta_min=0.000001, T_mult=2)
    else:
        scheduler = None

    # loss function
    if args.loss == 'l1':
        criterion = criteria.MaskedL1Loss()
    elif args.loss == 'l2':
        criterion = criteria.MaskedMSELoss()
    elif args.loss == 'berhu':
        criterion = criteria.berHuLoss()
    else:
        raise ValueError('{} loss not supported'.format(args.loss))

    # create directory path
    output_directory = utils.get_output_directory(args)
    os.makedirs(output_directory, exist_ok=True)
    best_txt = os.path.join(output_directory, 'best.txt')
    config_txt = os.path.join(output_directory, 'config.txt')

    # Write the training parameters once, as "key:value,\t\n" lines.
    if not os.path.exists(config_txt):
        with open(config_txt, 'w') as txtfile:
            txtfile.write(''.join(str(k) + ':' + str(v) + ',\t\n'
                                  for k, v in vars(args).items()))

    # create log
    log_path = os.path.join(output_directory, 'logs',
                            datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    if os.path.isdir(log_path):
        shutil.rmtree(log_path)
    os.makedirs(log_path)

    # save every epoch if doing adversarial training
    save_every_epoch = args.adv_training

    for epoch in range(start_epoch, args.epochs):
        train(train_loader, model, criterion, optimizer,
              epoch, attacker)  # train for one epoch
        # evaluate on validation set
        result, img_merge = validate(val_loader, model, epoch)

        # remember best rmse and save checkpoint
        is_best = result.rmse < best_result.rmse
        if is_best:
            best_result = result
            with open(best_txt, 'w') as txtfile:
                txtfile.write(
                    "epoch={}, rmse={:.3f}, rml={:.3f}, log10={:.3f}, d1={:.3f}, d2={:.3f}, dd31={:.3f}, "
                    "t_gpu={:.4f}".
                    format(epoch, result.rmse, result.absrel, result.lg10, result.delta1, result.delta2,
                           result.delta3,
                           result.gpu_time))
            if img_merge is not None:
                img_filename = os.path.join(output_directory, 'comparison_best.png')
                utils.save_image(img_merge, img_filename)

        # save checkpoint for each epoch
        utils.save_checkpoint({
            'args': args,
            'epoch': epoch,
            'model': model,
            'best_result': best_result,
            'optimizer': optimizer,
        }, is_best, epoch, output_directory, save_every_epoch)

        # Plateau reacts to the validation RMSE; cyclic/cosine step once per epoch.
        if scheduler is not None:
            if args.scheduler == 'plateau':
                scheduler.step(result.rmse)
            else:
                scheduler.step()
예제 #33
0
def main(data_dir):
    """Train ``ResFCN256`` on 300W-LP UV position maps with BCE-with-logits.

    Args:
        data_dir: Root directory passed to ``PRNetDataset``.

    Every ``FLAGS["save_interval"]`` epochs it:

    * prints the latest batch loss;
    * saves a prediction for the bundled Obama test sample as ``<epoch>.png``
      in ``FLAGS['images']``;
    * saves the weights as ``<epoch>.pth`` in the same directory.
    """
    device = FLAGS["device"]

    if not os.path.exists(FLAGS['images']):
        os.mkdir(FLAGS['images'])

    # 1) Create Dataset of 300_WLP & Dataloader.
    wlp300 = PRNetDataset(root_dir=data_dir,
                          transform=transforms.Compose([
                              ToTensor(),
                              ToResize((256, 256)),
                              ToNormalize(FLAGS["normalize_mean"],
                                          FLAGS["normalize_std"])
                          ]))

    wlp300_dataloader = DataLoader(dataset=wlp300,
                                   batch_size=FLAGS['batch_size'],
                                   shuffle=True,
                                   num_workers=1)

    # 2) Normalisation for the fixed test sample. There is no ToTensor step,
    # so test_data_preprocess presumably already returns tensors -- confirm.
    transform_img = transforms.Compose([
        transforms.Normalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
    ])

    # 3) Create PRNet model.
    start_epoch, target_epoch = FLAGS['start_epoch'], FLAGS['target_epoch']
    model = ResFCN256(resolution_input=256,
                      resolution_output=256,
                      channel=3,
                      size=16)

    # Load the pre-trained weight. This happens before any DataParallel
    # wrapping, so the checkpoint must hold the bare module's state dict.
    if FLAGS['resume'] != "" and os.path.exists(
            os.path.join(FLAGS['pretrained'], FLAGS['resume'])):
        state = torch.load(os.path.join(FLAGS['pretrained'], FLAGS['resume']))
        model.load_state_dict(state['prnet'])
        start_epoch = state['start_epoch']
        INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
    else:
        start_epoch = 0
        INFO(
            "Pre-trained weight cannot load successfully, train from scratch!")

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=FLAGS["lr"],
                                 betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

    # Loss on the raw network output, treated as logits.
    # NOTE(review): BCEWithLogitsLoss expects targets in [0, 1], but uv_map is
    # normalised with normalize_mean/std -- confirm the targets are in range.
    bce_loss = torch.nn.BCEWithLogitsLoss()
    bce_loss.to(device)

    for ep in range(start_epoch, target_epoch):
        bar = tqdm(wlp300_dataloader)
        loss_list_G = []
        for sample in bar:
            uv_map = sample['uv_map'].to(device)
            origin = sample['origin'].to(device)

            # Inference + update.
            optimizer.zero_grad()
            uv_map_result = model(origin)
            loss_g = bce_loss(uv_map_result, uv_map)
            loss_g.backward()
            loss_list_G.append(loss_g.item())
            optimizer.step()

        if ep % FLAGS["save_interval"] == 0:

            with torch.no_grad():
                # An empty dataloader leaves loss_list_G empty.
                last_loss = loss_list_G[-1] if loss_list_G else float('nan')
                print(" {} [BCE ({})]".format(ep, last_loss))
                origin = cv2.imread("./test_data/obama_origin.jpg")
                gt_uv_map = np.load("./test_data/test_obama.npy")
                origin, gt_uv_map = test_data_preprocess(
                    origin), test_data_preprocess(gt_uv_map)

                origin, gt_uv_map = transform_img(origin), transform_img(
                    gt_uv_map)

                # unsqueeze_ adds the batch axis in place, so origin.cpu()
                # below is batched too.
                origin_in = origin.unsqueeze_(0).to(device)
                pred_uv_map = model(origin_in).detach().cpu()

                save_image(
                    [origin.cpu(),
                     gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                    os.path.join(FLAGS['images'],
                                 str(ep) + '.png'),
                    nrow=1,
                    normalize=True)

            # Save model. Unwrap DataParallel so keys carry no "module."
            # prefix and load into the bare ResFCN256 on resume.
            print("Save model")
            net = model.module if isinstance(model, torch.nn.DataParallel) else model
            state = {
                'prnet': net.state_dict(),
                'Loss': loss_list_G,
                # Epoch to resume from: `ep` itself has just completed.
                'start_epoch': ep + 1
            }
            torch.save(state, os.path.join(FLAGS['images'],
                                           '{}.pth'.format(ep)))

            # NOTE(review): the LR only decays on save epochs -- confirm this
            # coupling to save_interval is intended.
            scheduler.step()
def reconstruct_image_from_representation(config):
    """Reconstruct an image from noise by matching a CNN representation.

    Depending on ``config['should_reconstruct_content']`` this either matches
    the content feature map of the content image, or the Gram matrices of the
    style layers of the style image. Optionally visualizes (and saves) the
    target representation first, then optimizes a white-noise image with the
    optimizer named by ``config['optimizer']`` ('adam' or 'lbfgs'), saving
    intermediate results through ``utils.save_and_maybe_display``.

    Config keys read here: should_reconstruct_content,
    should_visualize_representation, output_img_dir, optimizer,
    content_img_name, style_img_name, content_images_dir, style_images_dir,
    height, model, img_format (indexed as [0] = zero-pad width,
    [1] = filename suffix).

    Returns:
        The directory the results were written to:
        ``<output_img_dir>/<c|s>_reconstruction_<optimizer>/<image stem>``,
        where the stem is the image name up to its first '.'.

    NOTE(review): if ``config['optimizer']`` is neither 'adam' nor 'lbfgs',
    no optimization runs and the (created) dump directory is returned as is.
    """
    should_reconstruct_content = config['should_reconstruct_content']
    should_visualize_representation = config['should_visualize_representation']
    # Output dir: '<c|s>_reconstruction_<optimizer>/<name before first dot>'.
    dump_path = os.path.join(config['output_img_dir'],
                             ('c' if should_reconstruct_content else 's') +
                             '_reconstruction_' + config['optimizer'])
    dump_path = os.path.join(
        dump_path, config['content_img_name'].split('.')[0] if
        should_reconstruct_content else config['style_img_name'].split('.')[0])
    os.makedirs(dump_path, exist_ok=True)

    content_img_path = os.path.join(config['content_images_dir'],
                                    config['content_img_name'])
    style_img_path = os.path.join(config['style_images_dir'],
                                  config['style_img_name'])
    # Only one of the two images is used, depending on the mode.
    img_path = content_img_path if should_reconstruct_content else style_img_path

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    img = utils.prepare_img(img_path, config['height'], device)

    # NOTE(review): gaussian_noise_img is never used below, but the draw still
    # advances numpy's global RNG before the uniform draw that is used.
    gaussian_noise_img = np.random.normal(loc=0, scale=90.,
                                          size=img.shape).astype(np.float32)
    # Starting point of the optimization: uniform noise in [-90, 90).
    white_noise_img = np.random.uniform(-90., 90.,
                                        img.shape).astype(np.float32)
    init_img = torch.from_numpy(white_noise_img).float().to(device)
    # The noise image itself is the parameter being optimized.
    optimizing_img = Variable(init_img, requires_grad=True)

    # indices pick relevant feature maps (say conv4_1, relu1_1, etc.)
    # presumably content_feature_maps_index_name = (index, layer name) and
    # style_feature_maps_indices_names = (indices, layer names), given how
    # [0] is used for indexing and [1] for titles below — verify in utils.
    neural_net, content_feature_maps_index_name, style_feature_maps_indices_names = utils.prepare_model(
        config['model'], device)

    # don't want to expose everything that's not crucial so some things are hardcoded
    num_of_iterations = {'adam': 3000, 'lbfgs': 350}

    # Feature maps of the real image; the reconstruction targets come from here.
    set_of_feature_maps = neural_net(img)

    #
    # Visualize feature maps and Gram matrices (depending whether you're reconstructing content or style img)
    #
    if should_reconstruct_content:
        # Target: the content-layer feature maps with axis 0 squeezed out.
        target_content_representation = set_of_feature_maps[
            content_feature_maps_index_name[0]].squeeze(axis=0)
        if should_visualize_representation:
            num_of_feature_maps = target_content_representation.size()[0]
            print(f'Number of feature maps: {num_of_feature_maps}')
            # Show and save every channel of the content layer separately.
            for i in range(num_of_feature_maps):
                feature_map = target_content_representation[i].to(
                    'cpu').numpy()
                feature_map = np.uint8(utils.get_uint8_range(feature_map))
                plt.imshow(feature_map)
                plt.title(
                    f'Feature map {i+1}/{num_of_feature_maps} from layer {content_feature_maps_index_name[1]} (model={config["model"]}) for {config["content_img_name"]} image.'
                )
                plt.show()
                filename = f'fm_{config["model"]}_{content_feature_maps_index_name[1]}_{str(i).zfill(config["img_format"][0])}{config["img_format"][1]}'
                utils.save_image(feature_map,
                                 os.path.join(dump_path, filename))
    else:
        # Target: one Gram matrix per selected style layer.
        target_style_representation = [
            utils.gram_matrix(fmaps)
            for i, fmaps in enumerate(set_of_feature_maps)
            if i in style_feature_maps_indices_names[0]
        ]
        if should_visualize_representation:
            num_of_gram_matrices = len(target_style_representation)
            print(f'Number of Gram matrices: {num_of_gram_matrices}')
            for i in range(num_of_gram_matrices):
                Gram_matrix = target_style_representation[i].squeeze(
                    axis=0).to('cpu').numpy()
                Gram_matrix = np.uint8(utils.get_uint8_range(Gram_matrix))
                plt.imshow(Gram_matrix)
                plt.title(
                    f'Gram matrix from layer {style_feature_maps_indices_names[1][i]} (model={config["model"]}) for {config["style_img_name"]} image.'
                )
                plt.show()
                filename = f'gram_{config["model"]}_{style_feature_maps_indices_names[1][i]}_{str(i).zfill(config["img_format"][0])}{config["img_format"][1]}'
                utils.save_image(Gram_matrix,
                                 os.path.join(dump_path, filename))

    #
    # Start of optimization procedure
    #
    if config['optimizer'] == 'adam':
        optimizer = Adam((optimizing_img, ))
        target_representation = target_content_representation if should_reconstruct_content else target_style_representation
        # make_tuning_step (defined elsewhere) builds the per-iteration
        # update; it returns a tuple whose first element is the loss.
        tuning_step = make_tuning_step(neural_net, optimizer,
                                       target_representation,
                                       should_reconstruct_content,
                                       content_feature_maps_index_name[0],
                                       style_feature_maps_indices_names[0])
        for it in range(num_of_iterations[config['optimizer']]):
            loss, _ = tuning_step(optimizing_img)
            with torch.no_grad():
                print(
                    f'Iteration: {it}, current {"content" if should_reconstruct_content else "style"} loss={loss:10.8f}'
                )
                # Saves (without displaying) the current image every iteration.
                utils.save_and_maybe_display(
                    optimizing_img,
                    dump_path,
                    config,
                    it,
                    num_of_iterations[config['optimizer']],
                    should_display=False)
    elif config['optimizer'] == 'lbfgs':
        # Counts closure evaluations. L-BFGS with a line search can call the
        # closure more than once per step, so this is not a step counter.
        cnt = 0

        # closure is a function required by L-BFGS optimizer
        # `optimizer` is bound below, before optimizer.step(closure) runs it.
        def closure():
            nonlocal cnt
            optimizer.zero_grad()
            loss = 0.0
            if should_reconstruct_content:
                # Mean squared error between target and current content maps.
                loss = torch.nn.MSELoss(reduction='mean')(
                    target_content_representation, neural_net(optimizing_img)[
                        content_feature_maps_index_name[0]].squeeze(axis=0))
            else:
                current_set_of_feature_maps = neural_net(optimizing_img)
                current_style_representation = [
                    utils.gram_matrix(fmaps)
                    for i, fmaps in enumerate(current_set_of_feature_maps)
                    if i in style_feature_maps_indices_names[0]
                ]
                # Average over layers of the summed squared error between
                # the first entries ([0]) of target and current Gram matrices.
                for gram_gt, gram_hat in zip(target_style_representation,
                                             current_style_representation):
                    loss += (1 / len(target_style_representation)
                             ) * torch.nn.MSELoss(reduction='sum')(gram_gt[0],
                                                                   gram_hat[0])
            loss.backward()
            with torch.no_grad():
                print(
                    f'Iteration: {cnt}, current {"content" if should_reconstruct_content else "style"} loss={loss.item()}'
                )
                utils.save_and_maybe_display(
                    optimizing_img,
                    dump_path,
                    config,
                    cnt,
                    num_of_iterations[config['optimizer']],
                    should_display=False)
                cnt += 1
            return loss

        # A single step() call runs up to max_iter inner iterations.
        optimizer = torch.optim.LBFGS(
            (optimizing_img, ),
            max_iter=num_of_iterations[config['optimizer']],
            line_search_fn='strong_wolfe')
        optimizer.step(closure)

    return dump_path
예제 #35
0
    def display_current_results(self, visuals, epoch, save_result):
        """Show ``visuals`` in visdom and optionally save them to the HTML log.

        Args:
            visuals: mapping of label -> image array. Each array is shown
                after ``transpose([2, 0, 1])``, so it is presumably an HWC
                image (TODO confirm against callers).
            epoch: current epoch. It is used in saved file names, and one HTML
                section is built per epoch from ``epoch`` down to 1.
            save_result: if true, or if nothing has been saved yet, and
                ``self.use_html`` is set, write the images under
                ``self.img_dir`` and regenerate the web page.
        """
        # An empty ``visuals`` has nothing to show. Skipping it also avoids
        # the StopIteration the grid path raised from ``next(iter(...))``.
        if self.display_id > 0 and visuals:  # show images in the browser
            ncols = self.opt.display_single_pane_ncols
            if ncols > 0:
                # Grid mode: all images go in one visdom pane, with labels in
                # a companion HTML table that uses the same column count.
                h, w = next(iter(visuals.values())).shape[:2]
                table_css = """<style>
                        table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
                        table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
                        </style>""" % (w, h)
                title = self.name
                label_html = ''
                label_html_row = ''
                images = []
                idx = 0
                for label, image_numpy in visuals.items():
                    label_html_row += '<td>%s</td>' % label
                    images.append(image_numpy.transpose([2, 0, 1]))
                    idx += 1
                    if idx % ncols == 0:
                        label_html += '<tr>%s</tr>' % label_html_row
                        label_html_row = ''
                # Pad the last row with white tiles shaped like the last
                # image so the grid stays rectangular.
                white_image = np.ones_like(images[-1]) * 255
                while idx % ncols != 0:
                    images.append(white_image)
                    label_html_row += '<td></td>'
                    idx += 1
                if label_html_row != '':
                    label_html += '<tr>%s</tr>' % label_html_row
                # pane col = image row
                self.vis.images(images,
                                nrow=ncols,
                                win=self.display_id + 1,
                                padding=2,
                                opts=dict(title=title + ' images'))
                label_html = '<table>%s</table>' % label_html
                self.vis.text(table_css + label_html,
                              win=self.display_id + 2,
                              opts=dict(title=title + ' labels'))
            else:
                # One visdom window per image, starting at display_id + 1.
                for idx, (label, image_numpy) in enumerate(visuals.items(),
                                                           start=1):
                    self.vis.image(image_numpy.transpose([2, 0, 1]),
                                   opts=dict(title=label),
                                   win=self.display_id + idx)

        if self.use_html and (save_result
                              or not self.saved):  # save images to a html file
            self.saved = True
            for label, image_numpy in visuals.items():
                img_path = os.path.join(self.img_dir,
                                        'epoch%.3d_%s.png' % (epoch, label))
                utils.save_image(image_numpy, img_path)
            # Rebuild the page newest-first. It references the image files
            # for every epoch n..1 under the current labels; files from
            # earlier epochs are not checked to exist.
            webpage = html.HTML(self.web_dir,
                                'Experiment name = %s' % self.name,
                                reflesh=1)
            for n in range(epoch, 0, -1):
                webpage.add_header('epoch [%d]' % n)
                ims = ['epoch%.3d_%s.png' % (n, label) for label in visuals]
                txts = list(visuals)
                webpage.add_images(ims, txts, list(ims), width=self.win_size)
            webpage.save()
예제 #36
0
def main(opts):
    """Train the StyleGAN generator/discriminator pair with Paddle dygraph.

    Args:
        opts: parsed options. Fields read here:
            - ``path``: dataset root passed to ``data_loader``.
            - ``device``: 'GPU' selects CUDA place 0; anything else selects
              the CPU.
            - ``resume``: checkpoint path, loaded if it exists.
            - ``batch_size``: size of the fixed sample latent batch.
            - ``epoch``: exclusive upper bound of the epoch loop.
            - ``r1_gamma`` / ``r2_gamma``: gradient-penalty weights; 0
              disables the penalty.
            - ``det``: output root that contains ``images/`` and ``models/``.

    Side effects:
        - Writes one sample image per epoch to ``<det>/images/<ep>.png``.
        - Overwrites ``<det>/models/latest.pp`` after every epoch.
        - Finally plots the per-epoch mean losses with ``plotLossCurve``.
    """
    loader = data_loader(opts.path)

    # fluid.CPUPlace() takes no arguments; CPUPlace(0) raised a TypeError.
    device = fluid.CUDAPlace(0) if opts.device == 'GPU' else fluid.CPUPlace()
    with fluid.dygraph.guard(device):
        # Create the model
        start_epoch = 0
        G = StyleGenerator()
        D = StyleDiscriminator()

        # Load the pre-trained weight
        if os.path.exists(opts.resume):
            INFO("Load the pre-trained weight!")
            state = load_checkpoint(opts.resume)
            G.load_dict(state['G'])
            D.load_dict(state['D'])
            # NOTE(review): checkpoints store the epoch that was just
            # finished, so a resumed run repeats that epoch once.
            start_epoch = state['start_epoch']
        else:
            INFO(
                "Pre-trained weight cannot load successfully, train from scratch!"
            )

        scheduler_D = exponential_decay(learning_rate=0.00001,
                                        decay_steps=1000,
                                        decay_rate=0.99)
        scheduler_G = exponential_decay(learning_rate=0.00001,
                                        decay_steps=1000,
                                        decay_rate=0.99)
        optim_D = optim.Adam(parameter_list=D.parameters(),
                             learning_rate=scheduler_D)
        optim_G = optim.Adam(parameter_list=G.parameters(),
                             learning_rate=scheduler_G)

        # Fixed latent batch, used to render comparable samples every epoch.
        # Cast to float32 to match the per-iteration latents fed to G below
        # (np.random.randn returns float64).
        fix_z = np.random.randn(opts.batch_size, 512).astype('float32')
        fix_z = dygraph.to_variable(fix_z)
        softplus = SoftPlus()
        # Seeded with a dummy 0.0 that is stripped before plotting.
        Loss_D_list = [0.0]
        Loss_G_list = [0.0]
        D.train()
        G.train()
        for ep in range(start_epoch, opts.epoch):
            bar = tqdm(loader())
            loss_D_list = []
            loss_G_list = []
            for i, data in enumerate(bar):
                # =======================================================================================================
                #   (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                # =======================================================================================================
                real_img = np.array([item for item in data],
                                    dtype='float32').reshape(
                                        (-1, 3, 1024, 1024))

                D.clear_gradients()
                real_img = dygraph.to_variable(real_img)
                real_logit = D(real_img)

                z = np.float32(np.random.randn(real_img.shape[0], 512))
                fake_img = G(dygraph.to_variable(z))
                # The discriminator step must not backpropagate into G, so it
                # only sees a detached copy. The graph through G stays intact
                # for the generator step below.
                fake_img_d = fake_img.detach()
                fake_logit = D(fake_img_d)

                d_loss = layers.mean(softplus(fake_logit))
                d_loss = d_loss + layers.mean(softplus(-real_logit))

                if opts.r1_gamma != 0.0:
                    r1_penalty = R1Penalty(real_img, D)
                    d_loss = d_loss + r1_penalty * (opts.r1_gamma * 0.5)

                if opts.r2_gamma != 0.0:
                    r2_penalty = R2Penalty(fake_img_d, D)
                    d_loss = d_loss + r2_penalty * (opts.r2_gamma * 0.5)

                loss_D_list.append(d_loss.numpy())

                # Update discriminator
                d_loss.backward()
                optim_D.minimize(d_loss)

                # =======================================================================================================
                #   (2) Update G network: maximize log(D(G(z)))
                # =======================================================================================================
                if i % CRITIC_ITER == 0:
                    G.clear_gradients()
                    # No detach here: previously D(fake_img.detach()) cut the
                    # gradient path, so G never received gradients.
                    fake_logit = D(fake_img)
                    g_loss = layers.mean(softplus(-fake_logit))
                    loss_G_list.append(g_loss.numpy())

                    # Update generator (only G's parameters are in optim_G).
                    g_loss.backward()
                    optim_G.minimize(g_loss)

                # Output training stats (52000 is a hard-coded nominal total,
                # presumably the dataset size).
                bar.set_description("Epoch {} [{}, {}] [G]: {} [D]: {}".format(
                    ep, i + 1, 52000, loss_G_list[-1], loss_D_list[-1]))

            # Save the result
            Loss_G_list.append(np.mean(loss_G_list))
            Loss_D_list.append(np.mean(loss_D_list))

            # Check how the generator is doing by saving G's output on fix_z
            G.eval()
            fake_img = G(fix_z).numpy().squeeze()
            log(f"fake_img.shape: {fake_img.shape}")
            save_image(fake_img,
                       os.path.join(opts.det, 'images',
                                    str(ep) + '.png'))
            G.train()

            # Save model
            states = {
                'G': G.state_dict(),
                'D': D.state_dict(),
                'Loss_G': Loss_G_list,
                'Loss_D': Loss_D_list,
                'start_epoch': ep,
            }
            save_checkpoint(states,
                            os.path.join(opts.det, 'models', 'latest.pp'))

        # Plot the total loss curve (drop the dummy 0.0 seeds).
        Loss_D_list = Loss_D_list[1:]
        Loss_G_list = Loss_G_list[1:]
        plotLossCurve(opts, Loss_D_list, Loss_G_list)