Exemple #1
0
    def forward(self, ref_im, seed, loss_str, eps, noise_type,
                num_trainable_noise_layers, tile_latent, bad_noise_layers,
                opt_name, learning_rate, steps, lr_schedule, save_intermediate,
                **kwargs):
        """Optimise latent (and optionally noise) tensors to match ``ref_im``.

        This method is a generator. When ``save_intermediate`` is set it yields
        the current best ``(HR, LR)`` pair on every step; after the loop it
        yields one final ``(HR, LR)`` pair if the smallest ``'L2'`` loss seen
        is ``<= eps``, otherwise it prints a failure message.

        Args:
            ref_im: Reference batch; ``ref_im.shape[0]`` is the batch size and
                the tensor is handed to ``LossBuilder``.
            seed: If truthy, seeds torch (CPU and CUDA) and sets
                ``cudnn.deterministic``. A seed of ``0`` is ignored.
            loss_str: Loss specification forwarded to ``LossBuilder``.
            eps: Success threshold on the minimum ``'L2'`` loss (also
                forwarded to ``LossBuilder``).
            noise_type: ``'zero'``, ``'fixed'`` or ``'trainable'``.
            num_trainable_noise_layers: With ``'trainable'`` noise, only layers
                whose index is below this value are optimised.
            tile_latent: If true, optimise one 512-d latent per sample and
                expand it to 18 layers; otherwise optimise 18 latents.
            bad_noise_layers: Dot-separated layer indices (e.g. ``"3.17"``)
                whose noise is forced to zero. Empty segments are ignored.
            opt_name: ``'sgd'``, ``'adam'``, ``'sgdm'`` or ``'adamax'``.
            learning_rate: Base learning rate.
            steps: Number of optimisation steps.
            lr_schedule: ``'fixed'``, ``'linear1cycle'`` or
                ``'linear1cycledrop'``.
            save_intermediate: Yield the current best image every step.
            **kwargs: Accepted and ignored.

        Yields:
            tuple: ``(HR, LR)`` CPU tensors clamped to ``[0, 1]``.

        Raises:
            ValueError: If ``noise_type`` is not one of the known values.
        """
        if seed:
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True

        batch_size = ref_im.shape[0]

        # Generate latent tensor: a single latent per sample when tiling
        # (expanded to 18 layers in the loop), otherwise one per layer.
        latent_layers = 1 if tile_latent else 18
        latent = torch.randn((batch_size, latent_layers, 512),
                             dtype=torch.float,
                             requires_grad=True,
                             device='cuda')

        # Parse the zeroed layer indices once. With 'zero' noise every layer is
        # zeroed anyway, so the string is not parsed at all in that case.
        if noise_type == 'zero' or not bad_noise_layers:
            bad_layers = set()
        else:
            bad_layers = {int(layer) for layer in bad_noise_layers.split('.')
                          if layer.strip()}

        # Generate list of noise tensors
        noise = []  # stores all of the noise tensors
        noise_vars = []  # stores the noise tensors that we want to optimize on

        for i in range(18):
            # dimension of the ith noise tensor: spatial size starts at 4x4 and
            # doubles every two layers
            res = (batch_size, 1, 2**(i // 2 + 2), 2**(i // 2 + 2))

            if noise_type == 'zero' or i in bad_layers:
                new_noise = torch.zeros(res, dtype=torch.float, device='cuda')
                new_noise.requires_grad = False
            elif noise_type == 'fixed':
                new_noise = torch.randn(res, dtype=torch.float, device='cuda')
                new_noise.requires_grad = False
            elif noise_type == 'trainable':
                new_noise = torch.randn(res, dtype=torch.float, device='cuda')
                if i < num_trainable_noise_layers:
                    new_noise.requires_grad = True
                    noise_vars.append(new_noise)
                else:
                    new_noise.requires_grad = False
            else:
                raise ValueError("unknown noise type")

            noise.append(new_noise)

        var_list = [latent] + noise_vars

        opt_dict = {
            'sgd': torch.optim.SGD,
            'adam': torch.optim.Adam,
            'sgdm': partial(torch.optim.SGD, momentum=0.9),
            'adamax': torch.optim.Adamax
        }
        opt_func = opt_dict[opt_name]
        opt = SphericalOptimizer(opt_func, var_list, lr=learning_rate)

        schedule_dict = {
            'fixed':
            lambda x: 1,
            'linear1cycle':
            lambda x: (9 * (1 - np.abs(x / steps - 1 / 2) * 2) + 1) / 10,
            'linear1cycledrop':
            lambda x: (9 * (1 - np.abs(x /
                                       (0.9 * steps) - 1 / 2) * 2) + 1) / 10
            if x < 0.9 * steps else 1 / 10 + (x - 0.9 * steps) /
            (0.1 * steps) * (1 / 1000 - 1 / 10),
        }
        schedule_func = schedule_dict[lr_schedule]
        scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, schedule_func)

        loss_builder = LossBuilder(ref_im, loss_str, eps).cuda()

        min_loss = np.inf
        min_l2 = np.inf
        best_summary = ""
        start_t = time.time()
        gen_im = None

        if self.verbose:
            print("Optimizing")
        for j in range(steps):
            opt.opt.zero_grad()

            # Duplicate latent in case tile_latent = True
            latent_in = latent.expand(-1, 18, -1) if tile_latent else latent

            # Apply learned linear mapping to match latent distribution to that of the mapping network
            latent_in = self.lrelu(latent_in * self.gaussian_fit["std"] +
                                   self.gaussian_fit["mean"])

            # Normalize image to [0,1] instead of [-1,1]
            gen_im = (self.synthesis(latent_in, noise) + 1) / 2

            # Calculate Losses
            loss, loss_dict = loss_builder(latent_in, gen_im)
            loss_dict['TOTAL'] = loss

            # Save best summary for log
            if loss < min_loss:
                min_loss = loss
                best_summary = f'BEST ({j+1}) | ' + ' | '.join(
                    [f'{x}: {y:.4f}' for x, y in loss_dict.items()])
                # Only the pixels are needed; detaching avoids keeping this
                # step's autograd graph alive and makes D(best_im) graph-free.
                best_im = gen_im.detach().clone()

            loss_l2 = loss_dict['L2']
            if loss_l2 < min_l2:
                min_l2 = loss_l2

            # Save intermediate HR and LR images
            if save_intermediate:
                yield (best_im.cpu().detach().clamp(0, 1),
                       loss_builder.D(best_im).cpu().detach().clamp(0, 1))

            loss.backward()
            opt.step()
            scheduler.step()

        total_t = time.time() - start_t
        # Use `steps` instead of the loop index, which is unbound when steps == 0.
        it_per_s = steps / total_t if total_t > 0 else 0.0
        current_info = f' | time: {total_t:.1f} | it/s: {it_per_s:.2f} | batchsize: {batch_size}'
        if self.verbose:
            print(best_summary + current_info)
        if min_l2 <= eps:
            # NOTE(review): HR is the last iterate while LR is downscaled from
            # the best-loss image; kept as in the original — confirm intended.
            yield (gen_im.clone().cpu().detach().clamp(0, 1),
                   loss_builder.D(best_im).cpu().detach().clamp(0, 1))
        else:
            print(
                "Could not find a face that downscales correctly within epsilon"
            )
Exemple #2
0
    def forward(
        self,
        ref_im,
        loss_str,
        eps,
        noise_type,
        num_trainable_noise_layers,
        tile_latent,
        bad_noise_layers,
        opt_name,
        learning_rate,
        steps,
        lr_schedule,
        save_intermediate,
        seed=0,
        var_list_initial_values=None,
        step_postprocess=None,
        psi=1.0,
        **kwargs,
    ):
        """Optimise latent (and optionally noise) tensors to match ``ref_im``.

        This method is a generator that yields dicts. When
        ``save_intermediate`` is set, it yields
        ``{final=False, min_l2, HR, LR}`` on every step. After the loop it
        always yields
        ``{final=True, min_l2, success, HR, LR, var_list, loss_dict}``, where
        ``success`` is ``min_l2 <= eps``.

        Args:
            ref_im: Reference batch; ``ref_im.shape[0]`` is the batch size and
                the tensor is handed to ``LossBuilder``.
            loss_str: Loss specification forwarded to ``LossBuilder``.
            eps: Success threshold on the minimum ``'L2'`` loss (also
                forwarded to ``LossBuilder``).
            noise_type: ``'zero'``, ``'fixed'`` or ``'trainable'``.
            num_trainable_noise_layers: With ``'trainable'`` noise, only layers
                whose index is below this value are optimised.
            tile_latent: If true, optimise one 512-d latent per sample and
                expand it to 18 layers; otherwise optimise 18 latents.
            bad_noise_layers: Dot-separated layer indices (e.g. ``"3.17"``)
                whose noise is forced to zero. Empty segments are ignored.
            opt_name: ``'sgd'``, ``'adam'``, ``'sgdm'``, ``'adamax'`` or
                ``'custom'`` (AdamW with betas ``(0.9, 0.99)``).
            learning_rate: Base learning rate.
            steps: Number of optimisation steps.
            lr_schedule: ``'fixed'``, ``'linear1cycle'`` or
                ``'linear1cycledrop'``.
            save_intermediate: Yield the current best image every step.
            seed: If truthy, seeds torch (CPU and CUDA) and sets
                ``cudnn.deterministic``. ``0`` (the default) means unseeded.
            var_list_initial_values: Optional values copied into
                ``[latent] + trainable noise`` before optimisation; must have
                the same length as that list.
            step_postprocess: If given, ``StepPostProcessOptimizer`` is used
                with this callback instead of ``SphericalOptimizer``.
            psi: Factor applied to the random initial latent.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``noise_type`` is not one of the known values.
        """

        if seed:
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True

        batch_size = ref_im.shape[0]

        # Generate latent tensor: a single latent per sample when tiling
        # (expanded to 18 layers in the loop), otherwise one per layer.
        latent = torch.randn(
            (batch_size, 1 if tile_latent else 18, 512),
            dtype=torch.float,
            requires_grad=True,
            device="cuda",
        )
        # In-place scaling of a leaf that requires grad must bypass autograd.
        with torch.no_grad():
            latent *= psi

        # Parse the zeroed layer indices once. With "zero" noise every layer is
        # zeroed anyway, so the string is not parsed at all in that case.
        if noise_type == "zero" or not bad_noise_layers:
            bad_layers = set()
        else:
            bad_layers = {
                int(layer) for layer in bad_noise_layers.split(".")
                if layer.strip()
            }

        # Generate list of noise tensors
        noise = []  # stores all of the noise tensors
        noise_vars = []  # stores the noise tensors that we want to optimize on

        for i in range(18):
            # dimension of the ith noise tensor: spatial size starts at 4x4 and
            # doubles every two layers
            res = (batch_size, 1, 2**(i // 2 + 2), 2**(i // 2 + 2))

            if noise_type == "zero" or i in bad_layers:
                new_noise = torch.zeros(res, dtype=torch.float, device="cuda")
                new_noise.requires_grad = False
            elif noise_type == "fixed":
                new_noise = torch.randn(res, dtype=torch.float, device="cuda")
                new_noise.requires_grad = False
            elif noise_type == "trainable":
                new_noise = torch.randn(res, dtype=torch.float, device="cuda")
                if i < num_trainable_noise_layers:
                    new_noise.requires_grad = True
                    noise_vars.append(new_noise)
                else:
                    new_noise.requires_grad = False
            else:
                raise ValueError("unknown noise type")

            noise.append(new_noise)

        var_list = [latent] + noise_vars

        if var_list_initial_values is not None:
            assert len(var_list) == len(var_list_initial_values)
            # Overwrites the (psi-scaled) random initialisation above.
            with torch.no_grad():
                for var, initial_value in zip(var_list,
                                              var_list_initial_values):
                    var.copy_(initial_value)

        opt_dict = {
            "sgd": torch.optim.SGD,
            "adam": torch.optim.Adam,
            "sgdm": partial(torch.optim.SGD, momentum=0.9),
            "adamax": torch.optim.Adamax,
            "custom": partial(torch.optim.AdamW, betas=(0.9, 0.99)),
        }
        opt_func = opt_dict[opt_name]
        if step_postprocess is None:
            opt = SphericalOptimizer(opt_func, var_list, lr=learning_rate)
        else:
            opt = StepPostProcessOptimizer(opt_func,
                                           var_list,
                                           step_postprocess=step_postprocess,
                                           lr=learning_rate)

        schedule_dict = {
            "fixed":
            lambda x: 1,
            "linear1cycle":
            lambda x: (9 * (1 - np.abs(x / steps - 1 / 2) * 2) + 1) / 10,
            "linear1cycledrop":
            lambda x: (9 * (1 - np.abs(x /
                                       (0.9 * steps) - 1 / 2) * 2) + 1) / 10
            if x < 0.9 * steps else 1 / 10 + (x - 0.9 * steps) /
            (0.1 * steps) * (1 / 1000 - 1 / 10),
        }
        schedule_func = schedule_dict[lr_schedule]
        scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, schedule_func)

        loss_builder = LossBuilder(ref_im, loss_str, eps).cuda()

        min_loss = np.inf
        min_l2 = np.inf
        best_summary = ""
        start_t = time.time()
        gen_im = None

        if self.verbose:
            print("Optimizing")
        for j in range(steps):
            opt.opt.zero_grad()

            # Duplicate latent in case tile_latent = True
            latent_in = latent.expand(-1, 18, -1) if tile_latent else latent

            # Apply learned linear mapping to match latent distribution to that of the mapping network
            latent_in = self.lrelu(latent_in * self.gaussian_fit["std"] +
                                   self.gaussian_fit["mean"])

            # Normalize image to [0,1] instead of [-1,1]
            gen_im = (self.synthesis(latent_in, noise) + 1) / 2

            # Calculate Losses
            loss, loss_dict = loss_builder(latent_in, gen_im)
            loss_dict["TOTAL"] = loss

            # Save best summary for log
            if loss < min_loss:
                min_loss = loss
                best_summary = f"BEST ({j+1}) | " + " | ".join(
                    [f"{x}: {y:.4f}" for x, y in loss_dict.items()])
                # Only the pixels are needed; detaching avoids keeping this
                # step's autograd graph alive and makes D(best_im) graph-free.
                best_im = gen_im.detach().clone()

            loss_l2 = loss_dict["L2"]
            if loss_l2 < min_l2:
                min_l2 = loss_l2

            # Save intermediate HR and LR images
            if save_intermediate:
                yield dict(
                    final=False,
                    min_l2=min_l2,
                    HR=best_im.cpu().detach().clamp(0, 1),
                    LR=loss_builder.D(best_im).cpu().detach().clamp(0, 1),
                )

            loss.backward()
            opt.step()
            scheduler.step()

        total_t = time.time() - start_t
        # Use `steps` instead of the loop index, which is unbound when steps == 0.
        it_per_s = steps / total_t if total_t > 0 else 0.0
        current_info = f" | time: {total_t:.1f} | it/s: {it_per_s:.2f} | batchsize: {batch_size}"
        if self.verbose:
            print(best_summary + current_info)
        # NOTE(review): HR is the last iterate while LR is downscaled from the
        # best-loss image; kept as in the original — confirm this is intended.
        yield dict(
            final=True,
            min_l2=min_l2,
            success=min_l2 <= eps,
            HR=gen_im.clone().cpu().detach().clamp(0, 1),
            LR=loss_builder.D(best_im).cpu().detach().clamp(0, 1),
            var_list=var_list,
            loss_dict=loss_dict,
        )
Exemple #3
0
    def forward(self, ref_im, loss_str, eps, tile_latent, opt_name, steps,
                learning_rate, lr_schedule, save_intermediate):
        """Optimise a latent tensor so the synthesised image matches ``ref_im``.

        This method is a generator. When ``save_intermediate`` is set, it
        yields the current best ``(HR, LR)`` pair on every step. After the
        loop it yields one final ``(HR, LR)`` pair if the smallest ``'L2'``
        loss seen is ``<= eps``; otherwise it prints a failure message.

        Args:
            ref_im: Reference batch; moved to CUDA and handed to
                ``LossBuilder``. ``ref_im.shape[0]`` is the batch size.
            loss_str: Loss specification forwarded to ``LossBuilder``.
            eps: Success threshold on the minimum ``'L2'`` loss (also
                forwarded to ``LossBuilder``).
            tile_latent: If true, optimise one 512-d latent per sample and
                expand it to 18 layers; otherwise optimise 18 latents.
            opt_name: ``'sgd'``, ``'adam'``, ``'sgdm'`` or ``'adamax'``.
            steps: Number of optimisation steps.
            learning_rate: Base learning rate.
            lr_schedule: ``'fixed'``, ``'linear1cycle'`` or
                ``'linear1cycledrop'``.
            save_intermediate: Yield the current best image every step.

        Yields:
            tuple: ``(HR, LR)`` CPU tensors clamped to ``[0, 1]``.
        """

        # Mean/std used to map the Gaussian latent (fetched once per call).
        gaussian_fit = self.Gaussian_fit()

        batch_size = ref_im.shape[0]

        # One latent per sample when tiling (expanded in the loop), else 18.
        latent = torch.randn((batch_size, 1 if tile_latent else 18, 512),
                             dtype=torch.float,
                             requires_grad=True,
                             device='cuda')

        var_list = [latent]

        opt_dict = {
            'sgd': torch.optim.SGD,
            'adam': torch.optim.Adam,
            'sgdm': partial(torch.optim.SGD, momentum=0.9),
            'adamax': torch.optim.Adamax
        }
        opt_func = opt_dict[opt_name]
        opt = SphericalOptimizer(opt_func, var_list, lr=learning_rate)

        schedule_dict = {
            'fixed':
            lambda x: 1,
            'linear1cycle':
            lambda x: (9 * (1 - np.abs(x / steps - 1 / 2) * 2) + 1) / 10,
            'linear1cycledrop':
            lambda x: (9 * (1 - np.abs(x /
                                       (0.9 * steps) - 1 / 2) * 2) + 1) / 10
            if x < 0.9 * steps else 1 / 10 + (x - 0.9 * steps) /
            (0.1 * steps) * (1 / 1000 - 1 / 10),
        }
        schedule_func = schedule_dict[lr_schedule]
        scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, schedule_func)

        ref_im = ref_im.cuda()

        loss_builder = LossBuilder(ref_im, loss_str, eps).cuda()

        min_loss = np.inf
        min_l2 = np.inf
        best_summary = ""
        start_t = time.time()
        gen_im = None

        print("Optimizing")

        for j in range(steps):
            opt.opt.zero_grad()

            # Duplicate latent in case tile_latent = True
            latent_in = latent.expand(-1, 18, -1) if tile_latent else latent

            # Apply learned linear mapping to match latent distribution to that of the mapping network
            latent_in = self.lrelu(latent_in * gaussian_fit["std"] +
                                   gaussian_fit["mean"])

            # Normalize image to [0,1] instead of [-1,1]
            gen_im = (self.synthesizer(latent_in) + 1) / 2

            # Calculate Losses
            loss, loss_dict = loss_builder(latent_in, gen_im)
            loss_dict['TOTAL'] = loss

            # Save best summary for log
            if loss < min_loss:
                min_loss = loss
                best_summary = f'BEST ({j+1}) | ' + ' | '.join(
                    [f'{x}: {y:.4f}' for x, y in loss_dict.items()])
                # Only the pixels are needed; detaching avoids keeping this
                # step's autograd graph alive and makes D(best_im) graph-free.
                best_im = gen_im.detach().clone()

            loss_l2 = loss_dict['L2']
            if loss_l2 < min_l2:
                min_l2 = loss_l2

            # Save intermediate HR and LR images
            if save_intermediate:
                yield (best_im.cpu().detach().clamp(0, 1),
                       loss_builder.D(best_im).cpu().detach().clamp(0, 1))

            loss.backward()
            opt.step()
            scheduler.step()

        total_t = time.time() - start_t
        # Use `steps` instead of the loop index, which is unbound when steps == 0.
        it_per_s = steps / total_t if total_t > 0 else 0.0
        current_info = f' | time: {total_t:.1f} | it/s: {it_per_s:.2f} | batchsize: {batch_size}'
        print(best_summary + current_info)
        if min_l2 <= eps:
            # NOTE(review): HR is the last iterate while LR is downscaled from
            # the best-loss image; kept as in the original — confirm intended.
            yield (gen_im.clone().cpu().detach().clamp(0, 1),
                   loss_builder.D(best_im).cpu().detach().clamp(0, 1))
        else:
            print(
                "Could not find a face that downscales correctly within epsilon"
            )
class Projector(nn.Module):
    """
    Projects data to latent space and noise tensors.
    Arguments:
        G (Generator)
        dlatent_avg_samples (int): Number of dlatent samples
            to collect to find the mean and std.
            Default value is 10 000.
        dlatent_avg_label (int, torch.Tensor, optional): The label to
            use when gathering dlatent statistics.
        dlatent_device (int, str, torch.device, optional): Device to use
            for gathering statistics of dlatents. By default uses
            the device of the first parameter of `G.G_mapping`.
        dlatent_batch_size (int): The batch size to sample
            dlatents with. Default value is 1024.
        lpips_model (nn.Module): A model that returns the feature distance
            between two inputs. Default value is the LPIPS VGG16 model.
        lpips_size (int, optional): Resize any data fed to `lpips_model` by scaling
            the data so that its smallest side is the same size as this
            argument. Only has a default value of 256 if `lpips_model` is unspecified.
        verbose (bool): Write progress of dlatent statistics gathering to stdout.
            Default value is True.
    """
    def __init__(self,
                 G,
                 dlatent_avg_samples=10000,
                 dlatent_avg_label=None,
                 dlatent_device=None,
                 dlatent_batch_size=1024,
                 lpips_model=None,
                 lpips_size=None,
                 verbose=True):
        super(Projector, self).__init__()
        assert isinstance(G, models.Generator)
        # No projection job exists until `start()` is called; `_check_job`
        # relies on this being None rather than missing.
        self._job = None
        G.eval().requires_grad_(False)

        self.G_synthesis = G.G_synthesis

        G_mapping = G.G_mapping

        dlatent_batch_size = min(dlatent_batch_size, dlatent_avg_samples)

        if dlatent_device is None:
            # `Tensor.device` is an attribute, not a method.
            dlatent_device = next(G_mapping.parameters()).device
        else:
            dlatent_device = torch.device(dlatent_device)

        G_mapping.to(dlatent_device)

        latents = torch.empty(
            dlatent_avg_samples, G_mapping.latent_size).normal_()
        dlatents = []

        labels = None
        if dlatent_avg_label is not None:
            labels = torch.tensor(dlatent_avg_label).to(dlatent_device).long().view(-1).repeat(dlatent_batch_size)

        if verbose:
            progress = utils.ProgressWriter(np.ceil(dlatent_avg_samples / dlatent_batch_size))
            progress.write('Gathering dlatents...', step=False)

        for i in range(0, dlatent_avg_samples, dlatent_batch_size):
            batch_latents = latents[i: i + dlatent_batch_size].to(dlatent_device)
            batch_labels = None
            if labels is not None:
                # The last batch may be shorter than dlatent_batch_size.
                batch_labels = labels[:len(batch_latents)]
            with torch.no_grad():
                dlatents.append(G_mapping(batch_latents, labels=batch_labels).cpu())
            if verbose:
                progress.step()

        if verbose:
            progress.write('Done!', step=False)
            progress.close()

        dlatents = torch.cat(dlatents, dim=0)

        self.register_buffer(
            '_dlatent_avg',
            dlatents.mean(dim=0).view(1, 1, -1)
        )
        # A single scalar spread over all samples and dimensions, shaped for
        # broadcasting; 1e-8 keeps the sqrt away from zero.
        self.register_buffer(
            '_dlatent_std',
            torch.sqrt(
                torch.sum((dlatents - self._dlatent_avg) ** 2) / dlatent_avg_samples + 1e-8
            ).view(1, 1, 1)
        )

        if lpips_model is None:
            warnings.warn(
                'Using default LPIPS distance metric based on VGG 16. ' + \
                'This metric will only work on image data where values are in ' + \
                'the range [-1, 1], please specify an lpips module if you want ' + \
                'to use other kinds of data formats.'
            )
            lpips_model = lpips.LPIPS_VGG16(pixel_min=-1, pixel_max=1)
            lpips_size = 256
        self.lpips_model = lpips_model.eval().requires_grad_(False)
        self.lpips_size = lpips_size

        self.to(dlatent_device)

    def _scale_for_lpips(self, data):
        """Resize `data` so its smallest non-batch/channel side equals `lpips_size`.

        Returns `data` unchanged when `lpips_size` is falsy or no scaling is
        needed; upscales with 'nearest' and downscales with 'area'.
        """
        if not self.lpips_size:
            return data
        scale_factor = self.lpips_size / min(data.size()[2:])
        if scale_factor == 1:
            return data
        mode = 'nearest'
        if scale_factor < 1:
            mode = 'area'
        return F.interpolate(data, scale_factor=scale_factor, mode=mode)

    def _check_job(self):
        """Assert that `start()` has been called and sync the job's device."""
        assert getattr(self, '_job', None) is not None, 'Call `start()` first to set up target.'
        # device of dlatent param will not change with the rest of the models
        # and buffers of this class as it was never registered as a buffer or
        # parameter. Same goes for optimizer. Make sure it is on the correct device.
        if self._job.dlatent_param.device != self._dlatent_avg.device:
            # NOTE(review): `.to()` returns a new tensor, so `self._job.opt`
            # still holds the old object in its param groups — confirm.
            self._job.dlatent_param = self._job.dlatent_param.to(self._dlatent_avg)
            self._job.opt.load_state_dict(
                utils.move_to_device(self._job.opt.state_dict(), self._dlatent_avg.device)[0])

    def generate(self):
        """
        Generate an output with the current dlatent and noise values.
        Returns:
            output (torch.Tensor)
        """
        self._check_job()
        with torch.no_grad():
            return self.G_synthesis(self._job.dlatent_param)

    def get_dlatent(self):
        """
        Get a copy of the current dlatent values.
        Returns:
            dlatents (torch.Tensor)
        """
        self._check_job()
        return self._job.dlatent_param.data.clone()

    def get_noise(self):
        """
        Get a copy of the current noise values.
        Returns:
            noise_tensors (list)
        """
        self._check_job()
        return [noise.data.clone() for noise in self._job.noise_params]

    def start(self,
              target,
              num_steps=1000,
              initial_learning_rate=0.1,
              initial_noise_factor=0.05,
              lr_rampdown_length=0.25,
              lr_rampup_length=0.05,
              noise_ramp_length=0.75,
              regularize_noise_weight=1e5,
              verbose=True,
              verbose_prefix='',
              noise_layers=5):
        """
        Set up a target and its projection parameters.
        Every local variable of this method (arguments included) is stored on
        the job object, so later steps can read them.
        Arguments:
            target (torch.Tensor): The data target. This should
                already be preprocessed (scaled to correct value range).
                A missing batch dimension is added automatically.
            num_steps (int): Number of optimization steps. Default
                value is 1000. Also sets the length of the learning-rate
                schedule.
            initial_learning_rate (float): Default value is 0.1.
            initial_noise_factor (float): Default value is 0.05.
            lr_rampdown_length (float): Default value is 0.25.
            lr_rampup_length (float): Default value is 0.05.
            noise_ramp_length (float): Default value is 0.75.
            regularize_noise_weight (float): Default value is 1e5.
            verbose (bool): Write progress to stdout every time
                `step()` is called.
            verbose_prefix (str, optional): This is written before
                any other output to stdout.
            noise_layers (int): Noise tensors whose index is greater than
                this value get `requires_grad = True`. Default value is 5.
        """
        if target.dim() == self.G_synthesis.dim + 1:
            target = target.unsqueeze(0)
        assert target.dim() == self.G_synthesis.dim + 2, \
            'Number of dimensions of target data is incorrect.'

        target = target.to(self._dlatent_avg)
        target_scaled = self._scale_for_lpips(target)

        # Start every sample and every synthesis layer at the mean dlatent.
        dlatent_param = nn.Parameter(self._dlatent_avg.clone().repeat(target.size(0), len(self.G_synthesis), 1))
        noise_params = self.G_synthesis.static_noise(trainable=True)
        for i_n, n in enumerate(noise_params):
            if i_n > noise_layers:
                # NOTE(review): if `static_noise(trainable=True)` already returns
                # tensors requiring grad, this is a no-op; if the intent is to
                # train only the first `noise_layers` layers it should set
                # False — confirm.
                n.requires_grad = True
        params = [dlatent_param] + noise_params
        opt = SphericalOptimizer(torch.optim.Adam, params, lr=initial_learning_rate)
        # LR multiplier: over the first 90% of steps it rises linearly from 0.1
        # to 1 and back to 0.1; afterwards it decays linearly towards 0.001.
        schedule_func = lambda x: (
            (9 * (1 - np.abs(x / (0.9 * num_steps) - 1 / 2) * 2) + 1) / 10
            if x < 0.9 * num_steps
            else 1 / 10 + (x - 0.9 * num_steps) / (0.1 * num_steps) * (1 / 1000 - 1 / 10)
        )

        # Build the loss on the target's device rather than the default CUDA
        # device, so a non-default `dlatent_device` stays consistent.
        loss_builder = LossBuilder(target, "100*L2+0.1*GEOCROSS", 1e-3).to(target.device)
        scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, schedule_func)

        noise_tensor = torch.empty_like(dlatent_param)

        if verbose:
            progress = utils.ProgressWriter(num_steps)
            value_tracker = utils.ValueTracker()

        self._job = utils.AttributeDict(**locals())
        self._job.current_step = 0
        self._job.min_loss = np.inf
        self._job.min_l2 = np.inf
        self._job.best_output = None
Exemple #5
0
def main(config, mode, weights):
    # Generate configuration
    cfg = kaptan.Kaptan(handler='yaml')
    config = cfg.import_config(config)

    # Generate logger
    MODEL_SAVE_NAME, MODEL_SAVE_FOLDER, LOGGER_SAVE_NAME, CHECKPOINT_DIRECTORY = utils.generate_save_names(
        config)
    os.makedirs(MODEL_SAVE_FOLDER, exist_ok=True)
    logger = utils.generate_logger(MODEL_SAVE_FOLDER, LOGGER_SAVE_NAME)

    logger.info("*" * 40)
    logger.info("")
    logger.info("")
    logger.info("Using the following configuration:")
    logger.info(config.export("yaml", indent=4))
    logger.info("")
    logger.info("")
    logger.info("*" * 40)

    NORMALIZATION_MEAN, NORMALIZATION_STD, RANDOM_ERASE_VALUE = utils.fix_generator_arguments(
        config)
    TRAINDATA_KWARGS = {
        "rea_value": config.get("TRANSFORMATION.RANDOM_ERASE_VALUE")
    }
    """ MODEL PARAMS """
    from utils import model_weights

    MODEL_WEIGHTS = None
    if config.get("MODEL.MODEL_BASE") in model_weights:
        if mode == "train":
            if os.path.exists(
                    model_weights[config.get("MODEL.MODEL_BASE")][1]):
                pass
            else:
                logger.info(
                    "Model weights file {} does not exist. Downloading.".
                    format(model_weights[config.get("MODEL.MODEL_BASE")][1]))
                utils.web.download(
                    model_weights[config.get("MODEL.MODEL_BASE")][1],
                    model_weights[config.get("MODEL.MODEL_BASE")][0])
            MODEL_WEIGHTS = model_weights[config.get("MODEL.MODEL_BASE")][1]
    else:
        raise NotImplementedError(
            "Model %s is not available. Please choose one of the following: %s"
            % (config.get("MODEL.MODEL_BASE"), str(model_weights.keys())))

    # ------------------ LOAD SAVED LOGGER IF EXISTS ----------------------------
    DRIVE_BACKUP = config.get("SAVE.DRIVE_BACKUP")
    if DRIVE_BACKUP:  # Find backed-up log file,  if it exists, and copy to local
        backup_logger = os.path.join(CHECKPOINT_DIRECTORY, LOGGER_SAVE_NAME)
        logger_save_path = os.path.join(MODEL_SAVE_FOLDER, LOGGER_SAVE_NAME)
        if os.path.exists(backup_logger):
            shutil.copy2(backup_logger, logger_save_path)
    else:
        backup_logger = None

    NUM_GPUS = torch.cuda.device_count()
    if NUM_GPUS > 1:
        raise RuntimeError(
            "Not built for multi-GPU. Please start with single-GPU.")
    logger.info("Found %i GPUs" % NUM_GPUS)

    # --------------------- BUILD GENERATORS ------------------------
    data_crawler = utils.dynamic_import(cfg=config,
                                        module_name="crawlers",
                                        import_name="EXECUTION.CRAWLER")
    data_generator = utils.dynamic_import(cfg=config,
                                          module_name="generators",
                                          import_name="EXECUTION.GENERATOR")

    crawler = data_crawler(data_folder=config.get("DATASET.ROOT_DATA_FOLDER"),
                           train_folder=config.get("DATASET.TRAIN_FOLDER"),
                           test_folder=config.get("DATASET.TEST_FOLDER"),
                           query_folder=config.get("DATASET.QUERY_FOLDER"),
                           **{"logger": logger})

    train_mode = config.get("EXECUTION.TRAIN_MODE", "train")  # or "train-gzsl"
    train_generator = data_generator(
        gpus=NUM_GPUS,
        i_shape=config.get("DATASET.SHAPE"),
        normalization_mean=NORMALIZATION_MEAN,
        normalization_std=NORMALIZATION_STD,
        normalization_scale=1. /
        config.get("TRANSFORMATION.NORMALIZATION_SCALE"),
        h_flip=config.get("TRANSFORMATION.H_FLIP"),
        t_crop=config.get("TRANSFORMATION.T_CROP"),
        rea=config.get("TRANSFORMATION.RANDOM_ERASE"),
        **TRAINDATA_KWARGS)
    train_generator.setup(crawler,
                          mode=train_mode,
                          batch_size=config.get("TRANSFORMATION.BATCH_SIZE"),
                          instance=config.get("TRANSFORMATION.INSTANCES"),
                          workers=config.get("TRANSFORMATION.WORKERS"))

    logger.info("Generated training data generator")
    TRAIN_CLASSES = train_generator.num_entities
    test_mode = config.get("EXECUTION.TEST_MODE", "zsl")  # or "gzsl"
    test_generator = data_generator(
        gpus=NUM_GPUS,
        i_shape=config.get("DATASET.SHAPE"),
        normalization_mean=NORMALIZATION_MEAN,
        normalization_std=NORMALIZATION_STD,
        normalization_scale=1. /
        config.get("TRANSFORMATION.NORMALIZATION_SCALE"),
        h_flip=0,
        t_crop=False,
        rea=False)
    test_generator.setup(crawler,
                         mode=test_mode,
                         batch_size=config.get("TRANSFORMATION.BATCH_SIZE"),
                         instance=config.get("TRANSFORMATION.INSTANCES"),
                         workers=config.get("TRANSFORMATION.WORKERS"))
    TEST_CLASSES = test_generator.num_entities
    logger.info("Generated validation data/query generator")

    # --------------------- INSTANTIATE MODEL ------------------------
    model_builder = __import__("models", fromlist=["*"])
    model_builder = getattr(model_builder,
                            config.get("EXECUTION.MODEL_BUILDER"))
    logger.info("Loaded {} from {} to build CarZam model".format(
        config.get("EXECUTION.MODEL_BUILDER"), "models"))

    carzam_model = model_builder(   arch=config.get("MODEL.MODEL_ARCH"), \
                                    base=config.get("MODEL.MODEL_BASE"), \
                                    weights=MODEL_WEIGHTS, \
                                    embedding_dimensions = config.get("MODEL.EMBEDDING_DIMENSIONS"), \
                                    normalization = config.get("MODEL.MODEL_NORMALIZATION"), \
                                    **json.loads(config.get("MODEL.MODEL_KWARGS")))
    logger.info("Finished instantiating model with {} architecture".format(
        config.get("MODEL.MODEL_ARCH")))

    if mode == "test":
        carzam_model.load_state_dict(torch.load(weights))
        carzam_model.cuda()
        carzam_model.eval()
    else:
        if weights != "":  # Load weights if train and starting from a another model base...
            logger.info(
                "Commencing partial model load from {}".format(weights))
            carzam_model.partial_load(weights)
            logger.info("Completed partial model load from {}".format(weights))
        carzam_model.cuda()
        logger.info(
            torchsummary.summary(carzam_model,
                                 input_size=(3, *config.get("DATASET.SHAPE"))))

    # --------------------- INSTANTIATE LOSS ------------------------
    from loss import CarZamLossBuilder as LossBuilder
    loss_function = LossBuilder(loss_functions=config.get("LOSS.LOSSES"),
                                loss_lambda=config.get("LOSS.LOSS_LAMBDAS"),
                                loss_kwargs=config.get("LOSS.LOSS_KWARGS"),
                                **{"logger": logger})
    logger.info("Built loss function")

    # --------------------- INSTANTIATE LOSS OPTIMIZER --------------
    # Imported under its own class name so the class is not shadowed by the
    # built optimizer instance, which is bound to `loss_optimizer` below.
    from optimizer.StandardLossOptimizer import StandardLossOptimizer

    def _loss_opt_setting(key):
        # Value of LOSS_OPTIMIZER.<key>, falling back to OPTIMIZER.<key>.
        return config.get("LOSS_OPTIMIZER." + key,
                          config.get("OPTIMIZER." + key))

    loss_optimizer_builder = StandardLossOptimizer(
        base_lr=_loss_opt_setting("BASE_LR"),
        lr_bias=_loss_opt_setting("LR_BIAS_FACTOR"),
        weight_decay=_loss_opt_setting("WEIGHT_DECAY"),
        weight_bias=_loss_opt_setting("WEIGHT_BIAS_FACTOR"),
        gpus=NUM_GPUS)
    # NOTE(review): build() presumably returns None when the loss has no
    # trainable parameters; downstream code checks `loss_optimizer is not None`.
    loss_optimizer = loss_optimizer_builder.build(
        loss_builder=loss_function,
        name=_loss_opt_setting("OPTIMIZER_NAME"),
        **json.loads(_loss_opt_setting("OPTIMIZER_KWARGS")))
    logger.info("Built loss optimizer")

    # --------------------- INSTANTIATE OPTIMIZER ------------------------
    # EXECUTION.OPTIMIZER_BUILDER names an attribute of the `optimizer`
    # package; it is instantiated with the OPTIMIZER.* settings and its
    # build() produces the optimizer for the model.
    builder_name = config.get("EXECUTION.OPTIMIZER_BUILDER")
    optimizer_package = __import__("optimizer", fromlist=["*"])
    optimizer_builder = getattr(optimizer_package, builder_name)
    logger.info("Loaded {} from {} to build Optimizer model".format(
        builder_name, "optimizer"))

    opt_builder = optimizer_builder(
        base_lr=config.get("OPTIMIZER.BASE_LR"),
        lr_bias=config.get("OPTIMIZER.LR_BIAS_FACTOR"),
        weight_decay=config.get("OPTIMIZER.WEIGHT_DECAY"),
        weight_bias=config.get("OPTIMIZER.WEIGHT_BIAS_FACTOR"),
        gpus=NUM_GPUS,
    )
    optimizer_name = config.get("OPTIMIZER.OPTIMIZER_NAME")
    optimizer_kwargs = json.loads(config.get("OPTIMIZER.OPTIMIZER_KWARGS"))
    optimizer = opt_builder.build(carzam_model, optimizer_name,
                                  **optimizer_kwargs)
    logger.info("Built optimizer")

    # --------------------- INSTANTIATE SCHEDULER ------------------------
    # SCHEDULER.LR_SCHEDULER is first looked up among torch's built-in
    # schedulers; if it is not there, it is loaded from scheduler/<name>,
    # taking the attribute of the same name.
    scheduler_name = config.get("SCHEDULER.LR_SCHEDULER")
    try:
        torch_schedulers = __import__('torch.optim.lr_scheduler',
                                      fromlist=['lr_scheduler'])
        scheduler_cls = getattr(torch_schedulers, scheduler_name)
    except (ModuleNotFoundError, AttributeError):
        local_module = __import__("scheduler." + scheduler_name,
                                  fromlist=[scheduler_name])
        scheduler_cls = getattr(local_module, scheduler_name)
    scheduler_kwargs = json.loads(config.get("SCHEDULER.LR_KWARGS"))
    scheduler = scheduler_cls(optimizer, last_epoch=-1, **scheduler_kwargs)
    logger.info("Built scheduler")

    # ------------------- INSTANTIATE LOSS SCHEDULER ---------------------
    # Only built when a loss optimizer exists (the loss may have no
    # differentiable parameters); otherwise loss_scheduler stays None.
    # LOSS_SCHEDULER.* settings fall back to the SCHEDULER.* ones.
    loss_scheduler = None
    if loss_optimizer is not None:
        loss_scheduler_name = config.get("LOSS_SCHEDULER.LR_SCHEDULER",
                                         config.get("SCHEDULER.LR_SCHEDULER"))
        try:  # Prefer one of torch's built-in schedulers...
            torch_schedulers = __import__('torch.optim.lr_scheduler',
                                          fromlist=['lr_scheduler'])
            loss_scheduler_cls = getattr(torch_schedulers, loss_scheduler_name)
        except (ModuleNotFoundError, AttributeError):
            # ...otherwise use the same-named attribute of scheduler/<name>.
            local_module = __import__("scheduler." + loss_scheduler_name,
                                      fromlist=[loss_scheduler_name])
            loss_scheduler_cls = getattr(local_module, loss_scheduler_name)
        loss_scheduler = loss_scheduler_cls(
            loss_optimizer,
            last_epoch=-1,
            **json.loads(
                config.get("LOSS_SCHEDULER.LR_KWARGS",
                           config.get("SCHEDULER.LR_KWARGS"))))
        logger.info("Built loss scheduler")

    # --------------------- DRIVE BACKUP ------------------------
    # Find the highest N among saved "...epoch<N>.pth" checkpoints so training
    # can resume at N + 1 (or at 0 when none exist). Checkpoints are searched
    # in CHECKPOINT_DIRECTORY when DRIVE_BACKUP is set, else MODEL_SAVE_FOLDER.
    search_dir = CHECKPOINT_DIRECTORY if DRIVE_BACKUP else MODEL_SAVE_FOLDER
    epoch_re = re.compile(r'.*epoch([0-9]+)\.pth')
    saved_epochs = []
    for ckpt_path in glob.glob(os.path.join(search_dir, "*.pth")):
        found = epoch_re.search(ckpt_path)
        if found is not None:
            saved_epochs.append(int(found.group(1)))

    if saved_epochs:
        previous_stop = max(saved_epochs) + 1
        logger.info(
            "Previous stop detected. Will attempt to resume from epoch %i" %
            previous_stop)
    else:
        previous_stop = 0
        logger.info("No previous stop detected. Will start from epoch 0")

    # --------------------- PERFORM TRAINING ------------------------
    # Reject unsupported modes before the trainer is constructed and set up,
    # instead of after all of that work has been done.
    if mode not in ('train', 'test'):
        raise NotImplementedError(
            "Unsupported mode {!r}; expected 'train' or 'test'".format(mode))

    # EXECUTION.TRAINER names an attribute of the `trainer` package.
    trainer = __import__("trainer", fromlist=["*"])
    trainer = getattr(trainer, config.get("EXECUTION.TRAINER"))
    logger.info("Loaded {} from {} to build Trainer".format(
        config.get("EXECUTION.TRAINER"), "trainer"))

    loss_stepper = trainer(model=carzam_model,
                           loss_fn=loss_function,
                           optimizer=optimizer,
                           loss_optimizer=loss_optimizer,
                           scheduler=scheduler,
                           loss_scheduler=loss_scheduler,
                           train_loader=train_generator.dataloader,
                           test_loader=test_generator.dataloader,
                           queries=TEST_CLASSES,
                           epochs=config.get("EXECUTION.EPOCHS"),
                           logger=logger,
                           test_mode=config.get("EXECUTION.TEST_MODE",
                                                "zsl"))  # or "gzsl"
    loss_stepper.setup(step_verbose=config.get("LOGGING.STEP_VERBOSE"),
                       save_frequency=config.get("SAVE.SAVE_FREQUENCY"),
                       test_frequency=config.get("EXECUTION.TEST_FREQUENCY"),
                       save_directory=MODEL_SAVE_FOLDER,
                       save_backup=DRIVE_BACKUP,
                       backup_directory=CHECKPOINT_DIRECTORY,
                       gpus=NUM_GPUS,
                       fp16=config.get("OPTIMIZER.FP16"),
                       model_save_name=MODEL_SAVE_NAME,
                       logger_file=LOGGER_SAVE_NAME)
    if mode == 'train':
        # Resume after the last checkpointed epoch (0 if none were found).
        loss_stepper.train(continue_epoch=previous_stop)
    else:  # mode == 'test', guaranteed by the check above
        loss_stepper.evaluate()