Example #1
def naive_train(cfg, logger, distributed, local_rank):
    trainer = Trainer(cfg, distributed, local_rank)
    dataset = prepare_multiple_dataset(cfg, logger)
    if cfg.MODEL.PRETRAIN_CHOICE == 'finetune':
        load_checkpoint(trainer.encoder, cfg.MODEL.PRETRAIN_PATH)
    trainer.do_train(dataset)
    torch.cuda.empty_cache()
def inference(cfg, logger):
    testset = cfg.DATASETS.TEST[0]
    dataset = init_dataset(testset, root=cfg.DATASETS.ROOT_DIR)
    dataset.print_dataset_statistics(dataset.train, dataset.query,
                                     dataset.gallery, logger)
    test_loader = get_test_loader(dataset.query + dataset.gallery, cfg)

    model = Encoder(cfg).cuda()

    logger.info("loading model from {}".format(cfg.TEST.WEIGHT))
    load_checkpoint(model, cfg.TEST.WEIGHT)

    feats, pids, camids, img_paths = [], [], [], []
    with torch.no_grad():
        model.eval()
        for batch in test_loader:
            data, pid, camid, img_path = batch
            data = data.cuda()
            feat = model(data)
            if cfg.TEST.FLIP_TEST:
                data_flip = data.flip(dims=[3])  # NCHW
                feat_flip = model(data_flip)
                feat = (feat + feat_flip) / 2
            feats.append(feat)
            pids.extend(pid)
            camids.extend(camid)
            img_paths.extend(img_path)
    del model
    torch.cuda.empty_cache()

    feats = torch.cat(feats, dim=0)
    feats = torch.nn.functional.normalize(feats, dim=1, p=2)

    return [feats, pids, camids, img_paths], len(dataset.query)
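
A note on using the output: inference() returns L2-normalized features together with the number of query images, so a typical next step is to split query and gallery features and rank the gallery by cosine similarity. A minimal sketch of that ranking step; the helper name and its arguments are illustrative, not part of the original code:

import torch

def rank_gallery(feats, num_query):
    """Rank gallery images for each query, given the L2-normalized features
    and query count returned by inference() above (sketch only)."""
    query_feats, gallery_feats = feats[:num_query], feats[num_query:]
    # With unit-norm features the inner product equals cosine similarity,
    # and 2 - 2 * sim is the squared Euclidean distance.
    distmat = 2 - 2 * torch.mm(query_feats, gallery_feats.t())
    return distmat.argsort(dim=1)  # per-query ranking of gallery indices
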
def inference(cfg, logger):
    testset = 'visda20'
    #testset = 'personx'
    dataset = init_dataset(testset, root=cfg.DATASETS.ROOT_DIR)
    dataset.print_dataset_statistics(dataset.train, dataset.query,
                                     dataset.gallery, logger)

    dataset = dataset.train
    #dataset = dataset.query + dataset.gallery
    test_loader = get_test_loader(dataset, cfg)

    model = Encoder(cfg.MODEL.BACKBONE, cfg.MODEL.PRETRAIN_PATH,
                    cfg.MODEL.PRETRAIN_CHOICE).cuda()

    logger.info("loading model from {}".format(cfg.TEST.WEIGHT))
    load_checkpoint(model, cfg.TEST.WEIGHT)

    feats, pids, camids, img_paths = [], [], [], []
    with torch.no_grad():
        model.eval()
        for batch in test_loader:
            data, pid, camid, img_path = batch
            data = data.cuda()
            feat = model(data)
            feats.append(feat)
            pids.extend(pid)
            camids.extend(camid)
            img_paths.extend(img_path)
    del model
    torch.cuda.empty_cache()

    feats = torch.cat(feats, dim=0)
    feats = torch.nn.functional.normalize(feats, dim=1, p=2)
    return feats, dataset
def iterative_dbscan(cfg, logger):
    # trainer = Trainer(cfg)
    src_dataset = init_dataset('personx', root=cfg.DATASETS.ROOT_DIR)
    target_dataset = init_dataset('visda20', root=cfg.DATASETS.ROOT_DIR)

    dataset = BaseImageDataset()
    dataset.query = src_dataset.query
    dataset.gallery = src_dataset.gallery

    iteration = 8
    pseudo_label_dataset = []
    best_mAP = 0
    for i in range(iteration):
        dataset.train = merge_datasets(
            [src_dataset.train, pseudo_label_dataset])
        dataset.relabel_train()
        dataset.print_dataset_statistics(dataset.train, dataset.query,
                                         dataset.gallery, logger)

        trainer = Trainer(cfg)
        # from scratch
        load_checkpoint(trainer.encoder.base, cfg.MODEL.PRETRAIN_PATH)
        trainer.encoder.reset_bn()
        #if os.path.exists(os.path.join(cfg.OUTPUT_DIR, 'best.pth')):
        #    load_checkpoint(trainer.encoder, os.path.join(cfg.OUTPUT_DIR, 'best.pth'))
        trainer.do_train(dataset)
        torch.cuda.empty_cache()
        test_loader = get_test_loader(target_dataset.train, cfg)
        feats = trainer.extract_feature(test_loader)
        pseudo_label_dataset = DBSCAN_cluster(feats, target_dataset.train,
                                              logger)
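
DBSCAN_cluster above converts the extracted target-domain features into a pseudo-labeled training list. A minimal sketch of that idea, assuming L2-normalized features and the usual (img_path, pid, camid) tuples; the eps, min_samples, and cosine-distance choices here are assumptions, not the original helper:

import numpy as np
from sklearn.cluster import DBSCAN

def dbscan_cluster_sketch(feats, dataset, eps=0.6, min_samples=4):
    """Assign pseudo identity labels to target-domain images by density clustering."""
    feats = feats.cpu().numpy()
    dist = 1.0 - feats @ feats.T          # cosine distance for L2-normalized features
    labels = DBSCAN(eps=eps, min_samples=min_samples,
                    metric='precomputed').fit_predict(dist)

    pseudo_dataset = []
    for (img_path, _, camid), label in zip(dataset, labels):
        if label == -1:                    # DBSCAN marks outliers with -1; drop them
            continue
        pseudo_dataset.append((img_path, int(label), camid))
    return pseudo_dataset
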
    def __init__(self, cfg):
        super(Encoder, self).__init__()
        self.base, self.in_planes = build_backbone(cfg.MODEL.BACKBONE)

        if cfg.MODEL.PRETRAIN_CHOICE == 'imagenet':
            load_checkpoint(self.base, cfg.MODEL.PRETRAIN_PATH)
            print('Loading pretrained ImageNet model......')

        if cfg.MODEL.POOLING == 'GeM':
            print('using GeM')
            self.gap = GeM()
        else:
            self.gap = nn.AdaptiveAvgPool2d(1)
        if cfg.MODEL.REDUCE:
            self.bottleneck = nn.Sequential(nn.Linear(self.in_planes, cfg.MODEL.REDUCE_DIM),
                                            nn.BatchNorm1d(cfg.MODEL.REDUCE_DIM))
            self.in_planes = cfg.MODEL.REDUCE_DIM
        else:
            self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.apply(weights_init_kaiming)
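
GeM here refers to generalized-mean pooling. A common formulation is sketched below for reference; the exact GeM class used by the repository may differ, for example in its default exponent:

import torch
import torch.nn as nn
import torch.nn.functional as F

class GeM(nn.Module):
    def __init__(self, p=3.0, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)   # learnable pooling exponent
        self.eps = eps

    def forward(self, x):
        # Generalized mean over the spatial dims: ((1/HW) * sum x^p)^(1/p);
        # clamping avoids a zero base for the fractional power.
        x = x.clamp(min=self.eps).pow(self.p)
        return F.avg_pool2d(x, (x.size(-2), x.size(-1))).pow(1.0 / self.p)
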
Example #6
    def __init__(self, cfg, distributed=False, local_rank=0):
        self.encoder = Encoder(cfg).cuda()
        self.cfg = cfg

        self.distributed = distributed
        self.local_rank = local_rank

        if cfg.MODEL.PRETRAIN_CHOICE == 'self':
            load_checkpoint(self.encoder, cfg.MODEL.PRETRAIN_PATH)

        self.logger = logging.getLogger("reid_baseline.train")
        self.best_mAP = 0
        self.use_memory_bank = False
        if cfg.MODEL.MEMORY_BANK:
            self.encoder_ema = Encoder(cfg).cuda()
            for param_q, param_k in zip(self.encoder.parameters(),
                                        self.encoder_ema.parameters()):
                param_k.requires_grad = False  # not update by gradient
                param_k.data.copy_(param_q.data)

            self.memory_bank = VanillaMemoryBank(self.encoder.in_planes,
                                                 cfg.MODEL.MEMORY_SIZE)
            self.use_memory_bank = True
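
encoder_ema above is created as a frozen copy of the encoder (requires_grad is disabled and the weights are copied), which implies it is kept up to date with a momentum (EMA) update during training. A minimal sketch of that update; the momentum value 0.999 is an assumption:

import torch

@torch.no_grad()
def momentum_update(encoder, encoder_ema, m=0.999):
    # Exponential moving average: ema = m * ema + (1 - m) * online weights.
    for param_q, param_k in zip(encoder.parameters(), encoder_ema.parameters()):
        param_k.data.mul_(m).add_(param_q.data, alpha=1.0 - m)
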
Example #7
def main(cfg, config_name):
    """
    Main training function: after preparing the data loaders, model, optimizer,
    and trainer, it starts the training process.

    Args:
        cfg (dict): current configuration parameters
        config_name (str): path to the config file
    """

    # Create the output dir if it does not exist
    if not os.path.exists(cfg['misc']['log_dir']):
        os.makedirs(cfg['misc']['log_dir'])

    # Initialize the model
    model = config.get_model(cfg)
    model = model.cuda()

    # Get data loader
    train_loader = make_data_loader(cfg, phase='train')
    val_loader = make_data_loader(cfg, phase='val')

    # Log directory
    dataset_name = cfg["data"]["dataset"]

    now = datetime.now().strftime("%y_%m_%d-%H_%M_%S_%f")
    now += "__Method_" + str(cfg['method']['backbone'])
    now += "__Pretrained_" if cfg['network']['use_pretrained'] and cfg[
        'network']['pretrained_path'] else ''
    if cfg['method']['flow']: now += "__Flow_"
    if cfg['method']['ego_motion']: now += "__Ego_"
    if cfg['method']['semantic']: now += "__Sem_"
    now += "__Rem_Ground_" if cfg['data']['remove_ground'] else ''
    now += "__VoxSize_" + str(cfg['misc']["voxel_size"])
    now += "__Pts_" + str(cfg['misc']["num_points"])
    path2log = os.path.join(cfg['misc']['log_dir'], "logs_" + dataset_name,
                            now)

    logger, checkpoint_dir = prepare_logger(cfg, path2log)
    tboard_logger = SummaryWriter(path2log)

    # Output number of model parameters
    logger.info("Parameter Count: {:d}".format(n_model_parameters(model)))

    # Output torch and cuda version
    logger.info('Torch version: {}'.format(torch.__version__))
    logger.info('CUDA version: {}'.format(torch.version.cuda))

    # Save config file that was used for this experiment
    with open(os.path.join(path2log,
                           config_name.split(os.sep)[-1]), 'w') as outfile:
        yaml.dump(cfg, outfile, default_flow_style=False, allow_unicode=True)

    # Get optimizer and trainer
    optimizer = config.get_optimizer(cfg, model)
    scheduler = config.get_scheduler(cfg, optimizer)

    # Parameters determining the saving and validation interval (a positive value is interpreted as iterations, a negative one as epochs)
    stat_interval = cfg['train']['stat_interval']
    stat_interval = stat_interval if stat_interval > 0 else abs(
        stat_interval * len(train_loader))

    chkpt_interval = cfg['train']['chkpt_interval']
    chkpt_interval = chkpt_interval if chkpt_interval > 0 else abs(
        chkpt_interval * len(train_loader))

    val_interval = cfg['train']['val_interval']
    val_interval = val_interval if val_interval > 0 else abs(val_interval *
                                                             len(train_loader))

    # If no pretrained model is loaded, the epoch and iteration counters start at -1
    metric_val_best = np.inf
    running_metrics = {}
    running_losses = {}
    epoch_it = -1
    total_it = -1

    # Load the pretrained weights
    if cfg['network']['use_pretrained'] and cfg['network']['pretrained_path']:
        model, optimizer, scheduler, epoch_it, total_it, metric_val_best = load_checkpoint(
            model,
            optimizer,
            scheduler,
            filename=cfg['network']['pretrained_path'])

        # Find previous tensorboard files and copy them
        tb_files = glob.glob(
            os.sep.join(cfg['network']['pretrained_path'].split(os.sep)[:-1]) +
            '/events.*')
        for tb_file in tb_files:
            shutil.copy(tb_file,
                        os.path.join(path2log,
                                     tb_file.split(os.sep)[-1]))

    # Initialize the trainer
    device = torch.device('cuda' if (
        torch.cuda.is_available() and cfg['misc']['use_gpu']) else 'cpu')
    trainer = config.get_trainer(cfg, model, device)
    acc_iter_size = cfg['train']['acc_iter_size']

    # Training loop
    while epoch_it < cfg['train']['max_epoch']:
        epoch_it += 1
        lr = scheduler.get_last_lr()
        logger.info('Training epoch: {}, LR: {} '.format(epoch_it, lr))
        gc.collect()

        train_loader_iter = iter(train_loader)
        start = time.time()
        tbar = tqdm(total=len(train_loader) // acc_iter_size, ncols=100)

        for it in range(len(train_loader) // acc_iter_size):
            optimizer.zero_grad()
            total_it += 1
            batch_metrics = {}
            batch_losses = {}

            for iter_idx in range(acc_iter_size):

                batch = next(train_loader_iter)

                dict_all_to_device(batch, device)
                losses, metrics, total_loss = trainer.train_step(batch)

                total_loss.backward()

                # Save the running metrics and losses
                if not batch_metrics:
                    batch_metrics = copy.deepcopy(metrics)
                else:
                    for key, value in metrics.items():
                        batch_metrics[key] += value

                if not batch_losses:
                    batch_losses = copy.deepcopy(losses)
                else:
                    for key, value in losses.items():
                        batch_losses[key] += value

            # Compute the mean value of the metrics and losses of the batch
            for key, value in batch_metrics.items():
                batch_metrics[key] = value / acc_iter_size

            for key, value in batch_losses.items():
                batch_losses[key] = value / acc_iter_size

            optimizer.step()
            torch.cuda.empty_cache()

            tbar.set_description('Loss: {:.3g}'.format(
                batch_losses['total_loss']))
            tbar.update(1)

            # Save the running metrics and losses
            if not running_metrics:
                running_metrics = copy.deepcopy(batch_metrics)
            else:
                for key, value in batch_metrics.items():
                    running_metrics[key] += value

            if not running_losses:
                running_losses = copy.deepcopy(batch_losses)
            else:
                for key, value in batch_losses.items():
                    running_losses[key] += value

            # Logs
            if total_it % stat_interval == stat_interval - 1:
                # Print / save logs
                logger.info("Epoch {0:d} - It. {1:d}: loss = {2:.3f}".format(
                    epoch_it, total_it,
                    running_losses['total_loss'] / stat_interval))

                for key, value in running_losses.items():
                    tboard_logger.add_scalar("Train/{}".format(key),
                                             value / stat_interval, total_it)
                    # Reinitialize the values
                    running_losses[key] = 0

                for key, value in running_metrics.items():
                    tboard_logger.add_scalar("Train/{}".format(key),
                                             value / stat_interval, total_it)
                    # Reinitialize the values
                    running_metrics[key] = 0

                start = time.time()

            # Run validation
            if total_it % val_interval == val_interval - 1:
                logger.info("Starting the validation")
                val_losses, val_metrics = trainer.validate(val_loader)

                for key, value in val_losses.items():
                    tboard_logger.add_scalar("Val/{}".format(key), value,
                                             total_it)

                for key, value in val_metrics.items():
                    tboard_logger.add_scalar("Val/{}".format(key), value,
                                             total_it)

                logger.info(
                    "VALIDATION -It. {0:d}: total loss: {1:.3f}.".format(
                        total_it, val_losses['total_loss']))

                if val_losses['total_loss'] < metric_val_best:
                    metric_val_best = val_losses['total_loss']
                    logger.info('New best model (loss: {:.4f})'.format(
                        metric_val_best))

                    save_checkpoint(os.path.join(path2log, 'model_best.pt'),
                                    epoch=epoch_it,
                                    it=total_it,
                                    model=model,
                                    optimizer=optimizer,
                                    scheduler=scheduler,
                                    config=cfg,
                                    best_val=metric_val_best)
                else:
                    save_checkpoint(os.path.join(
                        path2log, 'model_{}.pt'.format(total_it)),
                                    epoch=epoch_it,
                                    it=total_it,
                                    model=model,
                                    optimizer=optimizer,
                                    scheduler=scheduler,
                                    config=cfg,
                                    best_val=val_losses['total_loss'])

        # After the epoch is finished, update the scheduler
        scheduler.step()

    # Quit after the maximum number of epochs is reached
    logger.info(
        'Training completed after {} epochs ({} it) with best validation loss = {}'
        .format(epoch_it, it, metric_val_best))
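
save_checkpoint and load_checkpoint are project helpers not shown in this example. A minimal save_checkpoint compatible with the keyword arguments used above might simply bundle all states into one file; this is a sketch under that assumption, and the real helper may store more:

import torch

def save_checkpoint(path, epoch, it, model, optimizer, scheduler, config, best_val):
    # Bundle everything needed to resume: counters, module/optimizer/scheduler
    # states, the config, and the best validation value seen so far.
    torch.save({
        'epoch': epoch,
        'it': it,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'config': config,
        'best_val': best_val,
    }, path)
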
Example #8
train_loader = InfiniteDataLoader(train_data,
                                  batch_size=batch_size,
                                  num_workers=1)
validation_loader = DataLoader(val_data,
                               batch_size=batch_size * 2,
                               shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size * 2, shuffle=False)
test_loader_aug = DataLoader(test_data_aug,
                             batch_size=batch_size * 2,
                             shuffle=False)

miner = miners.DistanceWeightedMiner()

if has_checkpoint() and not args.debug:
    state = load_checkpoint()
    model.load_state_dict(state['model_dict'])
    optimizer.load_state_dict(state['optimizer_dict'])
    lr_scheduler.load_state_dict(state['scheduler_dict'])
    train_loader.sampler.load_state_dict(state['sampler_dict'])
    start_step = state['start_step']
    es = state['es']
    torch.random.set_rng_state(state['rng'])
    print("Loaded checkpoint at step %s" % start_step)
else:
    start_step = 0

for step in range(start_step, n_steps):
    if es.early_stop:
        break
    data, target, meta = next(iter(train_loader))
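
The resume block above reads model, optimizer, scheduler, sampler, and RNG state from a checkpoint; the corresponding save (not shown in this example) would have to write back the same keys. A hypothetical sketch that relies on the objects defined above; the helper name and file handling are assumptions:

def save_training_state(path, step, es):
    # Mirror of the keys consumed by the resume block above.
    torch.save({
        'model_dict': model.state_dict(),
        'optimizer_dict': optimizer.state_dict(),
        'scheduler_dict': lr_scheduler.state_dict(),
        'sampler_dict': train_loader.sampler.state_dict(),
        'start_step': step + 1,
        'es': es,
        'rng': torch.random.get_rng_state(),
    }, path)
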
Example #9
def main(cfg, logger):
    """
    Main function of this evaluation software: after preparing the data loaders and the model, it starts the evaluation process.
    Args:
        cfg (dict): current configuration parameters
    """

    # Create the output dir if it does not exist 
    if not os.path.exists(cfg['test']['results_dir']):
        os.makedirs(cfg['test']['results_dir'])

    # Get model
    model = config.get_model(cfg)
    device = torch.device('cuda' if (torch.cuda.is_available() and cfg['misc']['use_gpu']) else 'cpu') 

    # Get data loader
    eval_loader = make_data_loader(cfg, phase='test')

    # Log directory
    dataset_name = cfg["data"]["dataset"]

    path2log = os.path.join(cfg['test']['results_dir'], dataset_name, '{}_{}'.format(cfg['method']['backbone'], cfg['misc']['num_points']))

    logger, checkpoint_dir = prepare_logger(cfg, path2log)

    # Output torch and cuda version 
    
    logger.info('Torch version: {}'.format(torch.__version__))
    logger.info('CUDA version: {}'.format(torch.version.cuda))
    logger.info('Starting evaluation of the method {} on {} dataset'.format(cfg['method']['backbone'], dataset_name))

    # Save config file that was used for this experiment
    with open(os.path.join(path2log, "config.yaml"),'w') as outfile:
        yaml.dump(cfg, outfile, default_flow_style=False, allow_unicode=True)


    logger.info("Parameter Count: {:d}".format(n_model_parameters(model)))
    
    # Load the pretrained weights
    if cfg['network']['use_pretrained'] and cfg['network']['pretrained_path']:
        model, optimizer, scheduler, epoch_it, total_it, metric_val_best = load_checkpoint(model, None, None, filename=cfg['network']['pretrained_path'])

    else:
        logger.warning('MODEL RUNS IN EVAL MODE, BUT NO PRETRAINED WEIGHTS WERE LOADED!!!!')


    # Initialize the trainer
    trainer = config.get_trainer(cfg, model, device)

    # Accumulate the per-batch evaluation metrics
    eval_metrics = defaultdict(list)    
    start = time.time()
    
    for it, batch in enumerate(tqdm(eval_loader)):
        # Put all the tensors to the designated device
        dict_all_to_device(batch, device)
        

        metrics = trainer.eval_step(batch)
        
        for key in metrics:
            eval_metrics[key].append(metrics[key])


    stop = time.time()

    # Compute mean values of the evaluation statistics
    result_string = ''

    for key, value in eval_metrics.items():
        if key not in ['true_p', 'true_n', 'false_p', 'false_n']:
            result_string += '{}: {:.3f}; '.format(key, np.mean(value))
    
    if 'true_p' in eval_metrics:
        tp = np.sum(eval_metrics['true_p'])
        tn = np.sum(eval_metrics['true_n'])
        fp = np.sum(eval_metrics['false_p'])
        fn = np.sum(eval_metrics['false_n'])

        result_string += 'dataset_precision_f: {:.3f}; '.format(tp / (tp + fp))
        result_string += 'dataset_recall_f: {:.3f}; '.format(tp / (tp + fn))
        result_string += 'dataset_precision_b: {:.3f}; '.format(tn / (tn + fn))
        result_string += 'dataset_recall_b: {:.3f}; '.format(tn / (tn + fp))


    logger.info('Outputting the evaluation metric for: {} {} {} '.format('Flow, ' if cfg['metrics']['flow'] else '', 'Ego-Motion, ' if cfg['metrics']['ego_motion'] else '', 'Bckg. Segmentation' if cfg['metrics']['semantic'] else ''))
    logger.info(result_string)
    logger.info('Evaluation completed in {}s [{}s per scene]'.format((stop - start), (stop - start)/len(eval_loader)))     
Example #10
    def warmup_generator(self, generator):
        """ Training on L1 Loss to warmup the Generator.

    Minimizing the L1 loss increases the Peak Signal to Noise Ratio (PSNR)
    of the images produced by the generator.
    This warmed-up generator is then used to bootstrap the training of the
    GAN, providing better starting images than random noise.
    Args:
      generator: Model Object for the Generator
    """
        # Loading up phase parameters
        warmup_num_iter = self.settings.get("warmup_num_iter", None)
        phase_args = self.settings["train_psnr"]
        decay_params = phase_args["adam"]["decay"]
        decay_step = decay_params["step"]
        decay_factor = decay_params["factor"]
        total_steps = phase_args["num_steps"]
        metric = tf.keras.metrics.Mean()
        psnr_metric = tf.keras.metrics.Mean()
        # Generator Optimizer
        G_optimizer = tf.optimizers.Adam(
            learning_rate=phase_args["adam"]["initial_lr"],
            beta_1=phase_args["adam"]["beta_1"],
            beta_2=phase_args["adam"]["beta_2"])
        checkpoint = tf.train.Checkpoint(G=generator, G_optimizer=G_optimizer)

        status = utils.load_checkpoint(checkpoint, "phase_1", self.model_dir)
        logging.debug("phase_1 status object: {}".format(status))
        previous_loss = 0
        start_time = time.time()

        # Training starts

        def _step_fn(image_lr, image_hr):
            logging.debug("Starting Distributed Step")
            with tf.GradientTape() as tape:
                fake = generator.unsigned_call(image_lr)
                loss = utils.pixel_loss(image_hr,
                                        fake) * (1.0 / self.batch_size)
            psnr_metric(
                tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=256.0)))
            gen_vars = list(set(generator.trainable_variables))
            gradient = tape.gradient(loss, gen_vars)
            G_optimizer.apply_gradients(zip(gradient, gen_vars))
            mean_loss = metric(loss)
            logging.debug("Ending Distributed Step")
            return tf.cast(G_optimizer.iterations, tf.float32)

        @tf.function
        def train_step(image_lr, image_hr):
            distributed_metric = self.strategy.experimental_run_v2(
                _step_fn, args=[image_lr, image_hr])
            mean_metric = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                               distributed_metric,
                                               axis=None)
            return mean_metric

        while True:
            image_lr, image_hr = next(self.dataset)
            num_steps = train_step(image_lr, image_hr)

            if num_steps >= total_steps:
                return
            if status:
                status.assert_consumed()
                logging.info("consumed checkpoint for phase_1 successfully")
                status = None

            if not num_steps % decay_step:  # Decay Learning Rate
                logging.debug("Learning Rate: %s" %
                              G_optimizer.learning_rate.numpy)
                G_optimizer.learning_rate.assign(G_optimizer.learning_rate *
                                                 decay_factor)
                logging.debug("Decayed Learning Rate by %f."
                              "Current Learning Rate %s" %
                              (decay_factor, G_optimizer.learning_rate))
            with self.summary_writer.as_default():
                tf.summary.scalar("warmup_loss",
                                  metric.result(),
                                  step=G_optimizer.iterations)
                tf.summary.scalar("mean_psnr", psnr_metric.result(),
                                  G_optimizer.iterations)

            if not num_steps % self.settings["print_step"]:
                logging.info("[WARMUP] Step: {}\tGenerator Loss: {}"
                             "\tPSNR: {}\tTime Taken: {} sec".format(
                                 num_steps, metric.result(),
                                 psnr_metric.result(),
                                 time.time() - start_time))
                if psnr_metric.result() > previous_loss:
                    utils.save_checkpoint(checkpoint, "phase_1",
                                          self.model_dir)
                previous_loss = psnr_metric.result()
                start_time = time.time()
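
On the docstring above: PSNR is 10 * log10(MAX^2 / MSE), so a smaller reconstruction error yields a larger PSNR, which is why warming up on a pixel loss raises PSNR. A small sketch of that relation; the training code uses an L1 pixel loss, which correlates with but is not identical to MSE:

import tensorflow as tf

def psnr_from_mse(image_hr, fake, max_val=255.0):
    # PSNR = 10 * log10(MAX^2 / MSE): lower pixel error => higher PSNR.
    mse = tf.reduce_mean(tf.square(image_hr - fake))
    return 10.0 * tf.math.log(max_val ** 2 / mse) / tf.math.log(10.0)
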
Example #11
    def train_gan(self, generator, discriminator):
        """ Implements Training routine for ESRGAN
        Args:
          generator: Model object for the Generator
          discriminator: Model object for the Discriminator
    """
        phase_args = self.settings["train_combined"]
        decay_args = phase_args["adam"]["decay"]
        decay_factor = decay_args["factor"]
        decay_steps = decay_args["step"]
        lambda_ = phase_args["lambda"]
        hr_dimension = self.settings["dataset"]["hr_dimension"]
        eta = phase_args["eta"]
        total_steps = phase_args["num_steps"]
        optimizer = partial(tf.optimizers.Adam,
                            learning_rate=phase_args["adam"]["initial_lr"],
                            beta_1=phase_args["adam"]["beta_1"],
                            beta_2=phase_args["adam"]["beta_2"])

        G_optimizer = optimizer()
        D_optimizer = optimizer()

        ra_gen = utils.RelativisticAverageLoss(discriminator, type_="G")
        ra_disc = utils.RelativisticAverageLoss(discriminator, type_="D")

        # The weights of the generator trained during Phase #1
        # are used to initialize or "hot start" the generator
        # for Phase #2 of training.
        status = None
        checkpoint = tf.train.Checkpoint(G=generator,
                                         G_optimizer=G_optimizer,
                                         D=discriminator,
                                         D_optimizer=D_optimizer)
        if not tf.io.gfile.exists(
                os.path.join(self.model_dir,
                             self.settings["checkpoint_path"]["phase_2"],
                             "checkpoint")):
            hot_start = tf.train.Checkpoint(G=generator,
                                            G_optimizer=G_optimizer)
            status = utils.load_checkpoint(hot_start, "phase_1",
                                           self.model_dir)
            # consuming variable from checkpoint
            G_optimizer.learning_rate.assign(phase_args["adam"]["initial_lr"])
        else:
            status = utils.load_checkpoint(checkpoint, "phase_2",
                                           self.model_dir)

        logging.debug("phase status object: {}".format(status))

        gen_metric = tf.keras.metrics.Mean()
        disc_metric = tf.keras.metrics.Mean()
        psnr_metric = tf.keras.metrics.Mean()
        logging.debug("Loading Perceptual Model")
        perceptual_loss = utils.PerceptualLoss(
            weights="imagenet",
            input_shape=[hr_dimension, hr_dimension, 3],
            loss_type=phase_args["perceptual_loss_type"])
        logging.debug("Loaded Model")

        def _step_fn(image_lr, image_hr):
            logging.debug("Starting Distributed Step")
            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                fake = generator.unsigned_call(image_lr)
                logging.debug("Fetched Generator Fake")
                fake = utils.preprocess_input(fake)
                image_lr = utils.preprocess_input(image_lr)
                image_hr = utils.preprocess_input(image_hr)
                percep_loss = tf.reduce_mean(perceptual_loss(image_hr, fake))
                logging.debug("Calculated Perceptual Loss")
                l1_loss = utils.pixel_loss(image_hr, fake)
                logging.debug("Calculated Pixel Loss")
                loss_RaG = ra_gen(image_hr, fake)
                logging.debug("Calculated Relativistic"
                              "Averate (RA) Loss for Generator")
                disc_loss = ra_disc(image_hr, fake)
                logging.debug("Calculated RA Loss Discriminator")
                gen_loss = percep_loss + lambda_ * loss_RaG + eta * l1_loss
                logging.debug("Calculated Generator Loss")
                disc_metric(disc_loss)
                gen_metric(gen_loss)
                gen_loss = gen_loss * (1.0 / self.batch_size)
                disc_loss = disc_loss * (1.0 / self.batch_size)
                psnr_metric(
                    tf.reduce_mean(tf.image.psnr(fake, image_hr,
                                                 max_val=256.0)))
            disc_grad = disc_tape.gradient(disc_loss,
                                           discriminator.trainable_variables)
            logging.debug("Calculated gradient for Discriminator")
            D_optimizer.apply_gradients(
                zip(disc_grad, discriminator.trainable_variables))
            logging.debug("Applied gradients to Discriminator")
            gen_grad = gen_tape.gradient(gen_loss,
                                         generator.trainable_variables)
            logging.debug("Calculated gradient for Generator")
            G_optimizer.apply_gradients(
                zip(gen_grad, generator.trainable_variables))
            logging.debug("Applied gradients to Generator")

            return tf.cast(D_optimizer.iterations, tf.float32)

        @tf.function
        def train_step(image_lr, image_hr):
            distributed_iterations = self.strategy.experimental_run_v2(
                _step_fn, args=(image_lr, image_hr))
            num_steps = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                             distributed_iterations,
                                             axis=None)
            return num_steps

        start = time.time()
        last_psnr = 0
        while True:
            image_lr, image_hr = next(self.dataset)
            num_step = train_step(image_lr, image_hr)
            if num_step >= total_steps:
                return
            if status:
                status.assert_consumed()
                logging.info("consumed checkpoint successfully!")
                status = None
            # Decaying Learning Rate
            for _step in decay_steps.copy():
                if num_step >= _step:
                    decay_steps.pop(0)
                    g_current_lr = self.strategy.reduce(
                        tf.distribute.ReduceOp.MEAN,
                        G_optimizer.learning_rate,
                        axis=None)

                    d_current_lr = self.strategy.reduce(
                        tf.distribute.ReduceOp.MEAN,
                        D_optimizer.learning_rate,
                        axis=None)

                    logging.debug("Current LR: G = %s, D = %s" %
                                  (g_current_lr, d_current_lr))
                    logging.debug("[Phase 2] Decayed Learing Rate by %f." %
                                  decay_factor)
                    G_optimizer.learning_rate.assign(
                        G_optimizer.learning_rate * decay_factor)
                    D_optimizer.learning_rate.assign(
                        D_optimizer.learning_rate * decay_factor)

            # Writing Summary
            with self.summary_writer_2.as_default():
                tf.summary.scalar("gen_loss",
                                  gen_metric.result(),
                                  step=D_optimizer.iterations)
                tf.summary.scalar("disc_loss",
                                  disc_metric.result(),
                                  step=D_optimizer.iterations)
                tf.summary.scalar("mean_psnr",
                                  psnr_metric.result(),
                                  step=D_optimizer.iterations)

            # Logging and Checkpointing
            if not num_step % self.settings["print_step"]:
                logging.info("Step: {}\tGen Loss: {}\tDisc Loss: {}"
                             "\tPSNR: {}\tTime Taken: {} sec".format(
                                 num_step, gen_metric.result(),
                                 disc_metric.result(), psnr_metric.result(),
                                 time.time() - start))
                # if psnr_metric.result() > last_psnr:
                last_psnr = psnr_metric.result()
                utils.save_checkpoint(checkpoint, "phase_2", self.model_dir)
                start = time.time()
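
utils.RelativisticAverageLoss used above is the relativistic average GAN loss from ESRGAN. One common formulation is sketched below, assuming the discriminator returns raw logits; the actual helper's interface and details may differ:

import tensorflow as tf

def relativistic_average_loss(discriminator, type_="G"):
    """Return a loss(real, fake) callable for the relativistic average GAN loss."""
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    def loss(real, fake):
        d_real = discriminator(real)
        d_fake = discriminator(fake)
        # Relativistic logits: how much more realistic a sample looks than the
        # average sample of the other class.
        real_logits = d_real - tf.reduce_mean(d_fake)
        fake_logits = d_fake - tf.reduce_mean(d_real)
        if type_ == "G":   # generator: fakes should look more real than reals
            return (bce(tf.zeros_like(real_logits), real_logits) +
                    bce(tf.ones_like(fake_logits), fake_logits))
        # discriminator: the opposite assignment
        return (bce(tf.ones_like(real_logits), real_logits) +
                bce(tf.zeros_like(fake_logits), fake_logits))

    return loss
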
Example #12
    def warmup_generator(self, generator):
        """ Training on L1 Loss to warmup the Generator.

    Minimizing the L1 loss increases the Peak Signal to Noise Ratio (PSNR)
    of the images produced by the generator.
    This warmed-up generator is then used to bootstrap the training of the
    GAN, providing better starting images than random noise.
    Args:
      generator: Model Object for the Generator
    """
        # Loading up phase parameters
        warmup_num_iter = self.settings.get("warmup_num_iter", None)
        phase_args = self.settings["train_psnr"]
        decay_params = phase_args["adam"]["decay"]
        decay_step = decay_params["step"]
        decay_factor = decay_params["factor"]

        metric = tf.keras.metrics.Mean()
        psnr_metric = tf.keras.metrics.Mean()
        tf.summary.experimental.set_step(tf.Variable(0, dtype=tf.int64))
        # Generator Optimizer
        G_optimizer = tf.optimizers.Adam(
            learning_rate=phase_args["adam"]["initial_lr"],
            beta_1=phase_args["adam"]["beta_1"],
            beta_2=phase_args["adam"]["beta_2"])
        checkpoint = tf.train.Checkpoint(
            G=generator,
            G_optimizer=G_optimizer,
            summary_step=tf.summary.experimental.get_step())

        status = utils.load_checkpoint(checkpoint, "phase_1")
        logging.debug("phase_1 status object: {}".format(status))
        previous_loss = float("inf")
        start_time = time.time()
        # Training starts
        for epoch in range(self.iterations):
            metric.reset_states()
            psnr_metric.reset_states()
            for image_lr, image_hr in self.dataset:
                step = tf.summary.experimental.get_step()
                if warmup_num_iter and step % warmup_num_iter:
                    return

                with tf.GradientTape() as tape:
                    fake = generator(image_lr)
                    loss = utils.pixel_loss(image_hr, fake)
                psnr = psnr_metric(
                    tf.reduce_mean(tf.image.psnr(fake, image_hr,
                                                 max_val=256.0)))
                gradient = tape.gradient(loss, generator.trainable_variables)
                G_optimizer.apply_gradients(
                    zip(gradient, generator.trainable_variables))
                mean_loss = metric(loss)

                if status:
                    status.assert_consumed()
                    logging.info(
                        "consumed checkpoint for phase_1 successfully")
                    status = None

                if not step % decay_step and step:  # Decay Learning Rate
                    logging.debug("Learning Rate: %f" %
                                  G_optimizer.learning_rate.numpy())
                    G_optimizer.learning_rate.assign(
                        G_optimizer.learning_rate * decay_factor)
                    logging.debug(
                        "Decayed Learning Rate by %f. Current Learning Rate %f"
                        % (decay_factor, G_optimizer.learning_rate.numpy()))
                with self.summary_writer.as_default():
                    tf.summary.scalar("warmup_loss", mean_loss, step=step)
                    tf.summary.scalar("mean_psnr", psnr, step=step)
                    step.assign_add(1)

                if not step % self.settings["print_step"]:
                    with self.summary_writer.as_default():
                        tf.summary.image(
                            "fake_image",
                            tf.cast(tf.clip_by_value(fake[:1], 0, 255),
                                    tf.uint8),
                            step=step)
                        tf.summary.image("hr_image",
                                         tf.cast(image_hr[:1], tf.uint8),
                                         step=step)

                    logging.info(
                        "[WARMUP] Epoch: {}\tBatch: {}\tGenerator Loss: {}\tPSNR: {}\tTime Taken: {} sec"
                        .format(epoch, step.numpy() // (epoch + 1),
                                mean_loss.numpy(), psnr.numpy(),
                                time.time() - start_time))
                    if mean_loss < previous_loss:
                        utils.save_checkpoint(checkpoint, "phase_1")
                    previous_loss = mean_loss
                    start_time = time.time()
Example #13
    def train_gan(self, generator, discriminator):
        """ Implements Training routine for ESRGAN
        Args:
          generator: Model object for the Generator
          discriminator: Model object for the Discriminator
    """
        phase_args = self.settings["train_combined"]
        decay_args = phase_args["adam"]["decay"]
        decay_factor = decay_args["factor"]
        decay_steps = decay_args["step"]
        lambda_ = phase_args["lambda"]
        hr_dimension = self.settings["dataset"]["hr_dimension"]
        eta = phase_args["eta"]
        tf.summary.experimental.set_step(tf.Variable(0, dtype=tf.int64))
        optimizer = partial(tf.optimizers.Adam,
                            learning_rate=phase_args["adam"]["initial_lr"],
                            beta_1=phase_args["adam"]["beta_1"],
                            beta_2=phase_args["adam"]["beta_2"])

        G_optimizer = optimizer()
        D_optimizer = optimizer()

        ra_gen = utils.RelativisticAverageLoss(discriminator, type_="G")
        ra_disc = utils.RelativisticAverageLoss(discriminator, type_="D")

        # The weights of the generator trained during Phase #1
        # are used to initialize or "hot start" the generator
        # for Phase #2 of training.
        status = None
        checkpoint = tf.train.Checkpoint(
            G=generator,
            G_optimizer=G_optimizer,
            D=discriminator,
            D_optimizer=D_optimizer,
            summary_step=tf.summary.experimental.get_step())

        if not tf.io.gfile.exists(
                os.path.join(self.settings["checkpoint_path"]["phase_2"],
                             "checkpoint")):
            hot_start = tf.train.Checkpoint(
                G=generator,
                G_optimizer=G_optimizer,
                summary_step=tf.summary.experimental.get_step())
            status = utils.load_checkpoint(hot_start, "phase_1")
            # consuming variable from checkpoint
            tf.summary.experimental.get_step()

            tf.summary.experimental.set_step(tf.Variable(0, dtype=tf.int64))
        else:
            status = utils.load_checkpoint(checkpoint, "phase_2")

        logging.debug("phase status object: {}".format(status))

        gen_metric = tf.keras.metrics.Mean()
        disc_metric = tf.keras.metrics.Mean()
        psnr_metric = tf.keras.metrics.Mean()
        perceptual_loss = utils.PerceptualLoss(
            weights="imagenet",
            input_shape=[hr_dimension, hr_dimension, 3],
            loss_type=phase_args["perceptual_loss_type"])
        for epoch in range(self.iterations):
            # Resetting Metrics
            gen_metric.reset_states()
            disc_metric.reset_states()
            psnr_metric.reset_states()
            start = time.time()
            for (image_lr, image_hr) in self.dataset:
                step = tf.summary.experimental.get_step()

                # Calculating the losses and applying gradients
                with tf.GradientTape() as gen_tape, tf.GradientTape(
                ) as disc_tape:
                    fake = generator(image_lr)
                    percep_loss = perceptual_loss(image_hr, fake)
                    l1_loss = utils.pixel_loss(image_hr, fake)
                    loss_RaG = ra_gen(image_hr, fake)
                    disc_loss = ra_disc(image_hr, fake)
                    gen_loss = percep_loss + lambda_ * loss_RaG + eta * l1_loss
                    disc_metric(disc_loss)
                    gen_metric(gen_loss)
                psnr = psnr_metric(
                    tf.reduce_mean(tf.image.psnr(fake, image_hr,
                                                 max_val=256.0)))
                disc_grad = disc_tape.gradient(
                    disc_loss, discriminator.trainable_variables)
                gen_grad = gen_tape.gradient(gen_loss,
                                             generator.trainable_variables)
                D_optimizer.apply_gradients(
                    zip(disc_grad, discriminator.trainable_variables))
                G_optimizer.apply_gradients(
                    zip(gen_grad, generator.trainable_variables))

                if status:
                    status.assert_consumed()
                    logging.info("consumed checkpoint successfully!")
                    status = None

                # Decaying Learning Rate
                for _step in decay_steps.copy():
                    if (step - 1) >= _step:
                        decay_steps.pop(0)
                        logging.debug("[Phase 2] Decayed Learning Rate by %f." %
                                      decay_factor)
                        G_optimizer.learning_rate.assign(
                            G_optimizer.learning_rate * decay_factor)
                        D_optimizer.learning_rate.assign(
                            D_optimizer.learning_rate * decay_factor)

                # Writing Summary
                with self.summary_writer.as_default():
                    tf.summary.scalar("gen_loss",
                                      gen_metric.result(),
                                      step=step)
                    tf.summary.scalar("disc_loss",
                                      disc_metric.result(),
                                      step=step)
                    tf.summary.scalar("mean_psnr", psnr, step=step)
                    step.assign_add(1)

                # Logging and Checkpointing
                if not step % self.settings["print_step"]:
                    with self.summary_writer.as_default():
                        resized_lr = tf.cast(
                            tf.clip_by_value(
                                tf.image.resize(image_lr[:1],
                                                [hr_dimension, hr_dimension],
                                                method=self.settings["dataset"]
                                                ["scale_method"]), 0, 255),
                            tf.uint8)
                        tf.summary.image("lr_image", resized_lr, step=step)
                        tf.summary.image(
                            "fake_image",
                            tf.cast(tf.clip_by_value(fake[:1], 0, 255),
                                    tf.uint8),
                            step=step)
                        tf.summary.image("hr_image",
                                         tf.cast(image_hr[:1], tf.uint8),
                                         step=step)
                    logging.info(
                        "Epoch: {}\tBatch: {}\tGen Loss: {}\tDisc Loss: {}\tPSNR: {}\tTime Taken: {} sec"
                        .format((epoch + 1),
                                step.numpy() // (epoch + 1),
                                gen_metric.result().numpy(),
                                disc_metric.result().numpy(), psnr.numpy(),
                                time.time() - start))
                    utils.save_checkpoint(checkpoint, "phase_2")
                    start = time.time()