Code example #1
File: builder.py  Project: dupeljan/nncf_for_tf
    def load_tfds(self):
        logger.info('Using TFDS to load data.')

        set_hard_limit_num_open_files()

        self._builder = tfds.builder(self._dataset_name,
                                     data_dir=self._dataset_dir)

        self._builder.download_and_prepare()

        decoders = {}

        if self._skip_decoding:
            decoders['image'] = tfds.decode.SkipDecoding()

        read_config = tfds.ReadConfig(interleave_cycle_length=64,
                                      interleave_block_length=1)

        dataset = self._builder.as_dataset(split=self._split,
                                           as_supervised=True,
                                           shuffle_files=True,
                                           decoders=decoders,
                                           read_config=read_config)

        return dataset
Code example #2
File: main.py  Project: dupeljan/nncf_for_tf
def resume_from_checkpoint(checkpoint_manager, ckpt_path, steps_per_epoch):
    if load_checkpoint(checkpoint_manager.checkpoint, ckpt_path) == 0:
        return 0
    optimizer = checkpoint_manager.checkpoint.optimizer
    initial_epoch = optimizer.iterations.numpy() // steps_per_epoch
    logger.info('Resuming from epoch {}'.format(initial_epoch))
    return int(initial_epoch)
Code example #3
def build_scheduler(config, epoch_size, batch_size, steps):
    optimizer_config = config.get('optimizer', {})
    schedule_type = optimizer_config.get('schedule_type',
                                         'exponential').lower()
    schedule_params = optimizer_config.get('schedule_params', {})

    if schedule_type == 'exponential':
        decay_rate = schedule_params.get('decay_rate', None)
        if decay_rate is None:
            raise ValueError('decay_rate parameter must be specified '
                             'for the exponential scheduler')

        initial_lr = schedule_params.get('initial_lr', None)
        if initial_lr is None:
            raise ValueError('initial_lr parameter must be specified '
                             'for the exponential scheduler')

        decay_epochs = schedule_params.get('decay_epochs', None)
        decay_steps = decay_epochs * steps if decay_epochs is not None else 0

        logger.info(
            'Using exponential learning rate with: '
            'initial_learning_rate: {initial_lr}, decay_steps: {decay_steps}, '
            'decay_rate: {decay_rate}'.format(initial_lr=initial_lr,
                                              decay_steps=decay_steps,
                                              decay_rate=decay_rate))
        lr = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=initial_lr,
            decay_steps=decay_steps,
            decay_rate=decay_rate)
    elif schedule_type == 'piecewise_constant':
        boundaries = schedule_params.get('boundaries', None)
        if boundaries is None:
            raise ValueError('boundaries parameter must be specified '
                             'for the piecewise_constant scheduler')

        values = schedule_params.get('values', None)
        if values is None:
            raise ValueError('values parameter must be specified '
                             'for the piecewise_constant')

        logger.info(
            'Using Piecewise constant decay with warmup. '
            'Parameters: batch_size: {batch_size}, epoch_size: {epoch_size}, '
            'boundaries: {boundaries}, values: {values}'.format(
                batch_size=batch_size,
                epoch_size=epoch_size,
                boundaries=boundaries,
                values=values))
        steps_per_epoch = epoch_size // batch_size
        boundaries = [steps_per_epoch * x for x in boundaries]
        lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            boundaries, values)
    elif schedule_type == 'step':
        lr = StepLearningRateWithLinearWarmup(steps, schedule_params)
    else:
        raise ValueError('Unknown schedule type: {}'.format(schedule_type))

    return lr
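
A rough usage sketch of the helper above. Only a plain dict with the keys that build_scheduler() reads is needed for this call; the learning-rate numbers and step counts are illustrative assumptions, not values taken from the project.

# Illustrative config; only the keys read by build_scheduler() are set.
config = {
    'optimizer': {
        'schedule_type': 'exponential',
        'schedule_params': {
            'initial_lr': 0.1,    # assumed starting learning rate
            'decay_rate': 0.96,   # assumed decay factor
            'decay_epochs': 2,    # decay every two epochs' worth of steps
        },
    },
}

# Placeholder dataset/batch sizes; steps is the number of steps per epoch.
lr_schedule = build_scheduler(config, epoch_size=50000, batch_size=128, steps=390)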
Code example #4
File: builder.py  Project: dupeljan/nncf_for_tf
    def load_tfrecords(self) -> tf.data.Dataset:
        logger.info('Using TFRecords to load data.')

        if self._dataset_name in records_dataset.__dict__:
            self._builder = records_dataset.__dict__[self._dataset_name](
                config=self.config, is_train=self.is_train)
        else:
            raise Exception('Undefined dataset name: {}'.format(
                self._dataset_name))

        dataset = self._builder.as_dataset()

        return dataset
Code example #5
File: main.py  Project: dupeljan/nncf_for_tf
def export(config):
    model_builder = retinanet_model.RetinanetModel(config)
    model = model_builder.build_model(
        pretrained=config.get('pretrained', True),
        weights=config.get('weights', None),
        mode=ModeKeys.PREDICT_WITH_GT)

    compression_ctrl, compress_model = create_compressed_model(model, config)

    if config.ckpt_path:
        checkpoint = tf.train.Checkpoint(model=compress_model)
        load_checkpoint(checkpoint, config.ckpt_path)

    save_path, save_format = get_saving_parameters(config)
    compression_ctrl.export_model(save_path, save_format)
    logger.info("Saved to {}".format(save_path))
Code example #6
    def build_model(self, pretrained=True, weights=None, mode=None):
        if self._keras_model is None:
            with keras_utils.maybe_enter_backend_graph():
                outputs = self.model_outputs(self._input_layer, mode)

                model = tf.keras.models.Model(inputs=self._input_layer,
                                              outputs=outputs,
                                              name='retinanet')
                assert model is not None, 'Failed to build tf.keras.Model.'
                self._keras_model = model

            if pretrained:
                logger.info('Init backbone')
                init_checkpoint_fn = self.make_restore_checkpoint_fn()
                init_checkpoint_fn(self._keras_model)

            if weights:
                logger.info(
                    'Loading pretrained weights from {}'.format(weights))
                self._keras_model.load_weights(weights)

        return self._keras_model
Code example #7
    def _restore_checkpoint_fn(keras_model):
        """Loads pretrained model through scaffold function."""
        if not checkpoint_path:
            logger.info('checkpoint_path is empty')
            return

        var_prefix = prefix

        if prefix and not prefix.endswith('/'):
            var_prefix += '/'

        var_to_shape_map = _get_checkpoint_map(checkpoint_path)
        assert var_to_shape_map, 'var_to_shape_map should not be empty'

        vars_to_load = _build_assignment_map(keras_model,
                                             prefix=var_prefix,
                                             skip_variables_regex=skip_regex,
                                             var_to_shape_map=var_to_shape_map)

        if not vars_to_load:
            raise ValueError('Variables to load is empty.')

        tf.compat.v1.train.init_from_checkpoint(checkpoint_path, vars_to_load)
Code example #8
def export(config):
    raise NotImplementedError('Experimental code, please use train + export mode, '
                              'do not use export-only mode')
    model = tf.keras.Sequential(
        hub.KerasLayer("https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/classification/4",
                       trainable=True))
    model.build([None, 224, 224, 3])

    compression_ctrl, compress_model = create_compressed_model(model, config)

    metrics = get_metrics()
    loss_obj = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)

    compress_model.compile(loss=loss_obj,
                           metrics=metrics)
    compress_model.summary()

    if config.ckpt_path is not None:
        load_checkpoint(model=compress_model,
                        ckpt_path=config.ckpt_path)

    save_path, save_format = get_saving_parameters(config)
    compression_ctrl.export_model(save_path, save_format)
    logger.info("Saved to {}".format(save_path))
Code example #9
    def evaluate(self):
        """Evaluates with detections from all images with COCO API.

        Returns:
            metrics_dict: a dictionary mapping metric names to the coco-style
              evaluation metrics (box and mask) as float numpy scalars.
        """
        if not self._annotation_file:
            logger.info('There is no annotation_file in COCOEvaluator.')
            gt_dataset = coco_utils.convert_groundtruths_to_coco_dataset(self._groundtruths)
            coco_gt = coco_utils.COCOWrapper(eval_type='box', gt_dataset=gt_dataset)
        else:
            logger.info('Using annotation file: %s', self._annotation_file)
            coco_gt = self._coco_gt

        coco_predictions = coco_utils.convert_predictions_to_coco_annotations(self._predictions)
        coco_dt = coco_gt.load_res(predictions=coco_predictions)
        image_ids = [ann['image_id'] for ann in coco_predictions]

        coco_eval = cocoeval.COCOeval(coco_gt, coco_dt, iouType='bbox')
        coco_eval.params.imgIds = image_ids
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
        coco_metrics = coco_eval.stats

        metrics = coco_metrics

        # Cleans up the internal variables in order for a fresh eval next time.
        self.reset()

        metrics_dict = {}
        for i, name in enumerate(self._metric_names):
            metrics_dict[name] = metrics[i].astype(np.float32)

        return metrics_dict
Code example #10
File: builder.py  Project: dupeljan/nncf_for_tf
    def __init__(self, config, num_channels, image_size, num_devices, one_hot,
                 is_train):
        self.config = config

        self._dataset_name = config.get('dataset', 'imagenet2012')
        self._dataset_type = config.get('dataset_type', 'tfrecords')
        self._dataset_dir = config.dataset_dir
        self._num_devices = num_devices
        self._batch_size = config.batch_size
        self._dtype = config.get('dtype', 'float32')
        self._num_preprocess_workers = config.get(
            'workers', tf.data.experimental.AUTOTUNE)

        self._split = 'train' if is_train else 'test'
        self._image_size = image_size
        self._num_channels = num_channels
        self._is_train = is_train
        self._one_hot = one_hot

        self._cache = False
        self._builder = None
        self._skip_decoding = True
        self._shuffle_buffer_size = 10000
        self._deterministic_train = False
        self._use_slack = True

        self._mean_subtract = False
        self._standardize = False

        augmenter_config = self.config.get('augmenter', None)
        if augmenter_config is not None:
            logger.info('Using augmentation: %s', augmenter_config.name)
            self._augmenter = create_augmenter(
                augmenter_config.name, augmenter_config.get('params', {}))
        else:
            self._augmenter = None
Code example #11
File: main.py  Project: dupeljan/nncf_for_tf
def load_checkpoint(checkpoint, ckpt_path):
    logger.info('Load from checkpoint is enabled')
    if tf.io.gfile.isdir(ckpt_path):
        path_to_checkpoint = tf.train.latest_checkpoint(ckpt_path)
        logger.info('Latest checkpoint: {}'.format(path_to_checkpoint))
    else:
        path_to_checkpoint = ckpt_path if tf.io.gfile.exists(
            ckpt_path + '.index') else None
        logger.info('Provided checkpoint: {}'.format(path_to_checkpoint))

    if not path_to_checkpoint:
        logger.info('No checkpoint detected')
        return 0

    logger.info(
        'Checkpoint file {} found and restoring from checkpoint'.format(
            path_to_checkpoint))
    status = checkpoint.restore(path_to_checkpoint)
    status.expect_partial()
    logger.info('Completed loading from checkpoint')

    return None
Code example #12
File: main.py  Project: dupeljan/nncf_for_tf
def train_test_export(config):
    strategy = get_distribution_strategy(config)
    strategy_scope = get_strategy_scope(strategy)

    # Training parameters
    NUM_EXAMPLES_TRAIN = 118287
    NUM_EXAMPLES_EVAL = 5000
    epochs = config.epochs
    batch_size = config.batch_size  # per replica batch size
    num_devices = strategy.num_replicas_in_sync if strategy else 1
    global_batch_size = batch_size * num_devices
    steps_per_epoch = NUM_EXAMPLES_TRAIN // global_batch_size

    # Create Dataset
    train_input_fn = input_reader.InputFn(
        file_pattern=config.train_file_pattern,
        params=config,
        mode=input_reader.ModeKeys.TRAIN,
        batch_size=global_batch_size)

    eval_input_fn = input_reader.InputFn(
        file_pattern=config.eval_file_pattern,
        params=config,
        mode=input_reader.ModeKeys.PREDICT_WITH_GT,
        batch_size=global_batch_size,
        num_examples=NUM_EXAMPLES_EVAL)

    train_dist_dataset = strategy.experimental_distribute_dataset(
        train_input_fn())
    test_dist_dataset = strategy.experimental_distribute_dataset(
        eval_input_fn())

    # Create model builder
    mode = ModeKeys.TRAIN if 'train' in config.mode else ModeKeys.PREDICT_WITH_GT
    model_builder = retinanet_model.RetinanetModel(config)
    eval_metric = model_builder.eval_metrics

    with strategy_scope:
        model = model_builder.build_model(
            pretrained=config.get('pretrained', True),
            weights=config.get('weights', None),
            mode=mode)

        compression_ctrl, compress_model = create_compressed_model(
            model, config)
        # compression_callbacks = create_compression_callbacks(compression_ctrl, config.log_dir)

        scheduler = build_scheduler(config=config,
                                    epoch_size=NUM_EXAMPLES_TRAIN,
                                    batch_size=global_batch_size,
                                    steps=steps_per_epoch)

        optimizer = build_optimizer(config=config, scheduler=scheduler)

        eval_metric = model_builder.eval_metrics()
        loss_fn = model_builder.build_loss_fn()
        predict_post_process_fn = model_builder.post_processing

        checkpoint = tf.train.Checkpoint(model=compress_model,
                                         optimizer=optimizer)
        checkpoint_manager = tf.train.CheckpointManager(
            checkpoint, config.checkpoint_save_dir, max_to_keep=None)

        logger.info('initialization...')
        compression_ctrl.initialize(dataset=train_input_fn())

        initial_epoch = 0
        if config.ckpt_path:
            initial_epoch = resume_from_checkpoint(checkpoint_manager,
                                                   config.ckpt_path,
                                                   steps_per_epoch)

    train_step = create_train_step_fn(strategy, compress_model, loss_fn,
                                      optimizer)
    test_step = create_test_step_fn(strategy, compress_model,
                                    predict_post_process_fn)

    if 'train' in config.mode:
        logger.info('Training...')
        train(train_step, test_step, eval_metric, train_dist_dataset,
              test_dist_dataset, initial_epoch, epochs, steps_per_epoch,
              checkpoint_manager, compression_ctrl, config.log_dir, optimizer)

    logger.info('Evaluation...')
    metric_result = evaluate(test_step, eval_metric, test_dist_dataset)
    logger.info('Validation metric = {}'.format(metric_result))

    if 'export' in config.mode:
        save_path, save_format = get_saving_parameters(config)
        compression_ctrl.export_model(save_path, save_format)
        logger.info("Saved to {}".format(save_path))
Code example #13
File: main.py  Project: dupeljan/nncf_for_tf
def train(train_step,
          test_step,
          eval_metric,
          train_dist_dataset,
          test_dist_dataset,
          initial_epoch,
          epochs,
          steps_per_epoch,
          checkpoint_manager,
          compression_ctrl,
          log_dir,
          optimizer,
          save_checkpoint_freq=200):

    train_summary_writer = SummaryWriter(log_dir, 'eval_train')
    test_summary_writer = SummaryWriter(log_dir, 'eval_test')

    logger.info('Training started')
    for epoch in range(initial_epoch, epochs):
        logger.info('Epoch {}/{}'.format(epoch, epochs))

        statistics = compression_ctrl.statistics()
        train_summary_writer(metrics=statistics,
                             step=optimizer.iterations.numpy())
        logger.info('Compression statistics = {}'.format(statistics))

        for step, x in enumerate(train_dist_dataset):
            if step == steps_per_epoch:
                save_path = checkpoint_manager.save()
                logger.info(
                    'Saved checkpoint for epoch={} step={}: {}'.format(
                        epoch, step, save_path))
                break

            train_loss = train_step(x)
            train_metric_result = tf.nest.map_structure(
                lambda s: s.numpy().astype(float), train_loss)

            if np.isnan(train_metric_result['total_loss']):
                raise ValueError('total loss is NaN')

            train_metric_result.update(
                {'learning_rate': optimizer.lr(optimizer.iterations).numpy()})

            train_summary_writer(metrics=train_metric_result,
                                 step=optimizer.iterations.numpy())

            if step % 100 == 0:
                logger.info('Step {}/{}'.format(step, steps_per_epoch))
                logger.info('Training metric = {}'.format(train_metric_result))

            if step % save_checkpoint_freq == 0:
                save_path = checkpoint_manager.save()
                logger.info(
                    'Saved checkpoint for epoch={} step={}: {}'.format(
                        epoch, step, save_path))

        compression_ctrl.scheduler.epoch_step(epoch)

        logger.info('Evaluation...')
        test_metric_result = evaluate(test_step, eval_metric,
                                      test_dist_dataset)
        test_summary_writer(metrics=test_metric_result,
                            step=optimizer.iterations.numpy())
        eval_metric.reset_states()
        logger.info('Validation metric = {}'.format(test_metric_result))

    train_summary_writer.close()
    test_summary_writer.close()
Code example #14
def resume_from_checkpoint(model, ckpt_path, train_steps):
    if load_checkpoint(model, ckpt_path) == 0:
        return 0
    initial_epoch = model.optimizer.iterations // train_steps
    logger.info('Resuming from epoch %d', initial_epoch)
    return int(initial_epoch)
Code example #15
def load_checkpoint(model, ckpt_path):
    logger.info('Load from checkpoint is enabled.')
    if tf.io.gfile.isdir(ckpt_path):
        checkpoint = tf.train.latest_checkpoint(ckpt_path)
        logger.info('Latest checkpoint: {}'.format(checkpoint))
    else:
        checkpoint = ckpt_path if tf.io.gfile.exists(ckpt_path + '.index') else None
        logger.info('Provided checkpoint: {}'.format(checkpoint))

    if not checkpoint:
        logger.info('No checkpoint detected.')
        return 0

    logger.info('Checkpoint file {} found and restoring from checkpoint'
                .format(checkpoint))
    model.load_weights(checkpoint).expect_partial()
    logger.info('Completed loading from checkpoint.')
    return None
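
A short sketch of how the two Keras-weights helpers above (code examples #14 and #15) are combined when training resumes; compress_model stands in for the compiled Keras model from the training script, and the checkpoint directory and step count are placeholders.

# Hypothetical values; in the project these come from config.ckpt_path and the dataset builder.
ckpt_dir = '/tmp/checkpoints'
train_steps = 1000  # assumed steps per epoch

initial_epoch = 0
if ckpt_dir is not None:
    initial_epoch = resume_from_checkpoint(model=compress_model,
                                           ckpt_path=ckpt_dir,
                                           train_steps=train_steps)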
Code example #16
def _build_assignment_map(keras_model,
                          prefix='',
                          skip_variables_regex=None,
                          var_to_shape_map=None):
    """Compute an assignment mapping for loading older checkpoints into a Keras

    model. Variable names are remapped from the original TPUEstimator model to
    the new Keras name.

    Args:
      keras_model: tf.keras.Model object to provide variables to assign.
      prefix: prefix in the variable name to be remove for alignment with names in
        the checkpoint.
      skip_variables_regex: regular expression to math the names of variables that
        do not need to be assign.
      var_to_shape_map: variable name to shape mapping from the checkpoint.

    Returns:
      The variable assignment map.
    """
    assignment_map = {}

    checkpoint_names = None
    if var_to_shape_map:
        predicate = lambda x: not x.endswith('Momentum') and not x.endswith(
            'global_step')
        checkpoint_names = list(filter(predicate, var_to_shape_map.keys()))

    for var in keras_model.variables:
        var_name = var.name

        if skip_variables_regex and re.match(skip_variables_regex, var_name):
            continue

        # Trim the index of the variable.
        if ':' in var_name:
            var_name = var_name[:var_name.rindex(':')]
        if var_name.startswith(prefix):
            var_name = var_name[len(prefix):]

        if not var_to_shape_map:
            assignment_map[var_name] = var
            continue

        # Match name with variables in the checkpoint.
        match_names = []
        for x in checkpoint_names:
            if x.endswith(var_name):
                match_names.append(x)

        try:
            if match_names:
                assert len(match_names) == 1, \
                    'more than one match for {}: {}'.format(var_name, match_names)
                checkpoint_names.remove(match_names[0])
                assignment_map[match_names[0]] = var
            else:
                logger.info('Variable not found in checkpoint: %s', var_name)
        except Exception as ex:
            logger.info('Error removing the match_name: %s', match_names)
            logger.info('Exception: %s', ex)
            raise

    logger.info('Variables found in checkpoint: %d', len(assignment_map))

    return assignment_map
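
For context, a hedged sketch of how such an assignment map is usually produced and consumed (compare _restore_checkpoint_fn in code example #7); checkpoint_path, keras_model and the 'retinanet/' prefix are assumptions for illustration.

# Read the variable-name -> shape mapping directly from the checkpoint.
reader = tf.train.load_checkpoint(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()

assignment_map = _build_assignment_map(keras_model,
                                       prefix='retinanet/',
                                       skip_variables_regex=None,
                                       var_to_shape_map=var_to_shape_map)

# Assign the matched checkpoint tensors to the Keras variables.
tf.compat.v1.train.init_from_checkpoint(checkpoint_path, assignment_map)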
Code example #17
def train_test_export(config):
    strategy = get_distribution_strategy(config)
    strategy_scope = get_strategy_scope(strategy)


    builders = get_dataset_builders(config, strategy)
    datasets = [builder.build() for builder in builders]

    train_builder, validation_builder = builders
    train_dataset, validation_dataset = datasets

    train_epochs = config.epochs
    train_steps = train_builder.num_steps
    validation_steps = validation_builder.num_steps

    if config.model_type == ModelType.KerasLayer:
        args = get_KerasLayer_model()
    else:
        args = None

    with strategy_scope:
        from op_insertion import NNCFWrapperCustom
        if not args:
            args = get_model(config.model_type)

        model = tf.keras.Sequential([
            tf.keras.layers.Input(shape=(224, 224, 3)),
            NNCFWrapperCustom(*args)
        ])
        if SAVE_MODEL_WORKAROUND:
            path = '/tmp/model.pb'
            model.save(path, save_format='tf')
            model = tf.keras.models.load_model(path)


        compression_ctrl, compress_model = create_compressed_model(model, config)
        compression_callbacks = create_compression_callbacks(compression_ctrl, config.log_dir)

        scheduler = build_scheduler(
            config=config,
            epoch_size=train_builder.num_examples,
            batch_size=train_builder.global_batch_size,
            steps=train_steps)
        config['optimizer'] = {'type': 'sgd'}
        optimizer = build_optimizer(
            config=config,
            scheduler=scheduler)

        metrics = get_metrics()
        loss_obj = get_loss()

        compress_model.compile(optimizer=optimizer,
                               loss=loss_obj,
                               metrics=metrics,
                               run_eagerly=config.get('eager_mode', False))

        compress_model.summary()

        logger.info('initialization...')
        compression_ctrl.initialize(dataset=train_dataset)

        initial_epoch = 0
        if config.ckpt_path is not None:
            initial_epoch = resume_from_checkpoint(model=compress_model,
                                                   ckpt_path=config.ckpt_path,
                                                   train_steps=train_steps)

    callbacks = get_callbacks(
        model_checkpoint=True,
        include_tensorboard=True,
        time_history=True,
        track_lr=True,
        write_model_weights=False,
        initial_step=initial_epoch * train_steps,
        batch_size=train_builder.global_batch_size,
        log_steps=100,
        model_dir=config.log_dir)

    callbacks.extend(compression_callbacks)

    validation_kwargs = {
        'validation_data': validation_dataset,
        'validation_steps': validation_steps,
        'validation_freq': 1,
    }

    if 'train' in config.mode:
        logger.info('training...')
        compress_model.fit(
            train_dataset,
            epochs=train_epochs,
            steps_per_epoch=train_steps,
            initial_epoch=initial_epoch,
            callbacks=callbacks,
            **validation_kwargs)

    logger.info('evaluation...')
    compress_model.evaluate(
        validation_dataset,
        steps=validation_steps,
        verbose=1)

    if 'export' in config.mode:
        save_path, save_format = get_saving_parameters(config)
        compression_ctrl.export_model(save_path, save_format)
        logger.info("Saved to {}".format(save_path))
Code example #18
def build_optimizer(config, scheduler):
    optimizer_config = config.get('optimizer', {})

    optimizer_type = optimizer_config.get('type', 'adam').lower()
    optimizer_params = optimizer_config.get('optimizer_params', {})

    logger.info('Building %s optimizer with params %s', optimizer_type,
                optimizer_params)

    if optimizer_type == 'sgd':
        logger.info('Using SGD optimizer')
        nesterov = optimizer_params.get('nesterov', False)
        optimizer = tf.keras.optimizers.SGD(learning_rate=scheduler,
                                            nesterov=nesterov)
    elif optimizer_type == 'momentum':
        logger.info('Using momentum optimizer')
        nesterov = optimizer_params.get('nesterov', False)
        momentum = optimizer_params.get('momentum', 0.9)
        optimizer = tf.keras.optimizers.SGD(learning_rate=scheduler,
                                            momentum=momentum,
                                            nesterov=nesterov)
    elif optimizer_type == 'rmsprop':
        logger.info('Using RMSProp')
        rho = optimizer_params.get('rho', 0.9)
        momentum = optimizer_params.get('momentum', 0.9)
        epsilon = optimizer_params.get('epsilon', 1e-07)
        optimizer = tf.keras.optimizers.RMSprop(learning_rate=scheduler,
                                                rho=rho,
                                                momentum=momentum,
                                                epsilon=epsilon)
    elif optimizer_type == 'adam':
        logger.info('Using Adam')
        beta_1 = optimizer_params.get('beta_1', 0.9)
        beta_2 = optimizer_params.get('beta_2', 0.999)
        epsilon = optimizer_params.get('epsilon', 1e-07)
        optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler,
                                             beta_1=beta_1,
                                             beta_2=beta_2,
                                             epsilon=epsilon)
    elif optimizer_type == 'adamw':
        raise RuntimeError()
        logger.info('Using AdamW')
        weight_decay = optimizer_params.get('weight_decay', 0.01)
        beta_1 = optimizer_params.get('beta_1', 0.9)
        beta_2 = optimizer_params.get('beta_2', 0.999)
        epsilon = optimizer_params.get('epsilon', 1e-07)
        optimizer = tfa.optimizers.AdamW(weight_decay=weight_decay,
                                         learning_rate=scheduler,
                                         beta_1=beta_1,
                                         beta_2=beta_2,
                                         epsilon=epsilon)
    else:
        raise ValueError('Unknown optimizer %s' % optimizer_type)

    moving_average_decay = optimizer_params.get('moving_average_decay', 0.)
    if moving_average_decay > 0.:
        logger.info('Including moving average decay.')
        optimizer = tfa.optimizers.MovingAverage(
            optimizer, average_decay=moving_average_decay, num_updates=None)
    if optimizer_params.get('lookahead', None):
        raise RuntimeError()
        logger.info('Using lookahead optimizer.')
        optimizer = tfa.optimizers.Lookahead(optimizer)

    return optimizer
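
A minimal usage sketch combining this with build_scheduler from code example #3; the optimizer choice, momentum value, boundaries and learning rates are illustrative assumptions, not values from the project.

config = {
    'optimizer': {
        'type': 'momentum',
        'optimizer_params': {'momentum': 0.9, 'nesterov': True},
        'schedule_type': 'piecewise_constant',
        'schedule_params': {'boundaries': [30, 60, 80],
                            'values': [0.1, 0.01, 0.001, 0.0001]},
    },
}

lr_schedule = build_scheduler(config, epoch_size=50000, batch_size=128, steps=390)
optimizer = build_optimizer(config, scheduler=lr_schedule)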