def as_infinite_iterator(
        dataset: tf.data.Dataset,
        steps_per_epoch: Optional[int] = None) -> Tuple[tf.data.Iterator, int]:
    """Get an iterator over an infinitely repeated dataset and steps_per_epoch.

    Args:
        dataset: possibly infinite dataset.
        steps_per_epoch: number of steps per epoch if `dataset` has infinite
            cardinality, otherwise `None` or `dataset`'s cardinality.

    Returns:
        iterator: tf.data.Iterator of the possibly repeated `dataset`.
        steps_per_epoch: number of elements of the iterator considered one epoch.

    Raises:
        ValueError: if `steps_per_epoch` is None and `dataset` has infinite
            cardinality.
        AssertionError: if `dataset` has finite cardinality inconsistent with
            `steps_per_epoch`.
    """
    cardinality = tf.keras.backend.get_value(dataset.cardinality())
    if steps_per_epoch is None:
        steps_per_epoch = cardinality
        if cardinality == tf.data.INFINITE_CARDINALITY:
            raise ValueError(
                "steps_per_epoch must be provided if dataset has infinite "
                "cardinality")
        dataset = dataset.repeat()
    elif cardinality != tf.data.INFINITE_CARDINALITY:
        assert cardinality == steps_per_epoch
        dataset = dataset.repeat()
    return iter(dataset), steps_per_epoch
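# Usage sketch for as_infinite_iterator, assuming TensorFlow 2.x; the toy
# dataset below is illustrative and not part of the original source.
import tensorflow as tf

ds = tf.data.Dataset.range(10)
it, steps = as_infinite_iterator(ds)  # steps == 10; the iterator never exhausts
first_epoch = [int(next(it)) for _ in range(steps)]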
def _train_bert_multitask_keras_model(
        train_dataset: tf.data.Dataset,
        eval_dataset: tf.data.Dataset,
        model: tf.keras.Model,
        params: BaseParams,
        mirrored_strategy: tf.distribute.MirroredStrategy = None):
    # Can't save the whole model with the model subclassing API due to a TF
    # bug; see:
    # https://github.com/tensorflow/tensorflow/issues/42741
    # https://github.com/tensorflow/tensorflow/issues/40366
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(params.ckpt_dir, 'model'),
        save_weights_only=True,
        monitor='val_mean_acc',
        mode='auto',
        save_best_only=True)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=params.ckpt_dir)
    if mirrored_strategy is not None:
        with mirrored_strategy.scope():
            model.fit(
                x=train_dataset.repeat(),
                validation_data=eval_dataset,
                epochs=params.train_epoch,
                callbacks=[model_checkpoint_callback, tensorboard_callback],
                steps_per_epoch=params.train_steps_per_epoch)
    else:
        model.fit(
            x=train_dataset.repeat(),
            validation_data=eval_dataset,
            epochs=params.train_epoch,
            callbacks=[model_checkpoint_callback, tensorboard_callback],
            steps_per_epoch=params.train_steps_per_epoch)
    model.summary()
def pipeline(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
    """Build a pipeline fetching, shuffling, and preprocessing the dataset.

    Args:
        dataset: A `tf.data.Dataset` that loads raw files.

    Returns:
        A TensorFlow dataset outputting batched images and labels.
    """
    if self._num_gpus > 1:
        dataset = dataset.shard(self._num_gpus, hvd.rank())

    if self.is_training:
        # Shuffle the input files. The result must be reassigned;
        # `Dataset.shuffle` does not operate in place.
        dataset = dataset.shuffle(buffer_size=self._file_shuffle_buffer_size)

    if self.is_training and not self._cache:
        dataset = dataset.repeat()

    # Read the data from disk in parallel.
    dataset = dataset.interleave(
        tf.data.TFRecordDataset,
        cycle_length=10,
        block_length=1,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self._cache:
        dataset = dataset.cache()

    if self.is_training:
        dataset = dataset.shuffle(self._shuffle_buffer_size)
        dataset = dataset.repeat()

    # Parse, preprocess, and batch the data in parallel.
    preprocess = self.parse_record
    dataset = dataset.map(preprocess,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self._num_gpus > 1:
        # The batch size of the dataset will be multiplied by the number of
        # replicas automatically when strategy.distribute_datasets_from_function
        # is called, so we use the local batch size here.
        dataset = dataset.batch(self.local_batch_size,
                                drop_remainder=self.is_training)
    else:
        dataset = dataset.batch(self.global_batch_size,
                                drop_remainder=self.is_training)

    # Apply mixup.
    mixup_alpha = self.mixup_alpha if self.is_training else 0.0
    dataset = dataset.map(
        functools.partial(self.mixup, self.local_batch_size, mixup_alpha),
        num_parallel_calls=64)

    # Prefetch overlaps in-feed with training.
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    return dataset
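# Sketch of the shard semantics used in `pipeline` above, assuming two
# workers: each worker keeps every num_shards-th element, so the shards
# partition the data without overlap. Toy values for illustration only.
import tensorflow as tf

ds = tf.data.Dataset.range(8).shard(num_shards=2, index=0)
print(list(ds.as_numpy_iterator()))  # [0, 2, 4, 6]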
def preprocess(dataset: tf.data.Dataset,
               feature_layer: tf.keras.layers.Layer,
               target_feature: str,
               num_epochs: int,
               shuffle_buffer: int,
               batch_size: int,
               batches_to_take=None):
    """Preprocess data with a single-element label (label of length one).

    :param dataset: the dataset to preprocess.
    :param feature_layer: feature layer used to preprocess the data.
    :param target_feature: the name of the target feature (used to extract the
        correct element from the input observations).
    :param num_epochs: number of epochs to repeat the dataset for.
    :param shuffle_buffer: buffer size used for shuffling.
    :param batch_size: batch size.
    :param batches_to_take: if given, number of batches to keep; otherwise the
        full preprocessed dataset is returned.
    :return: the preprocessed dataset.
    """

    def element_fn(element):
        # element_fn extracts feature and label vectors from each element;
        # the 'x' and 'y' names are required by Keras.
        feature_vector = feature_layer(element)
        return collections.OrderedDict([
            ('x', tf.reshape(feature_vector, [feature_vector.shape[1]])),
            ('y', tf.reshape(element[target_feature], [1])),
        ])

    preprocessed_dataset = dataset.repeat(num_epochs).map(element_fn).shuffle(
        shuffle_buffer).batch(batch_size)
    if not batches_to_take:
        return preprocessed_dataset
    return preprocessed_dataset.take(batches_to_take)
def get_augmented_data(
    dataset: tf.data.Dataset,
    batch_size: int,
    map_func: Callable,
    shuffle_buffer: Optional[int] = None,
    shuffle_seed: Optional[int] = None,
    augment_seed: Optional[int] = None,
    use_stateless_map: bool = False,
) -> RepeatedData:
    if shuffle_buffer is not None:
        dataset = dataset.shuffle(shuffle_buffer, seed=shuffle_seed)
    dataset = dataset.batch(batch_size)
    steps_per_epoch = tf.keras.backend.get_value(dataset.cardinality())
    # Repeat before map so the stateless map differs across epochs.
    dataset = dataset.repeat()
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    if use_stateless_map:
        dataset = dataset.apply(
            tfrng.data.stateless_map(
                map_func,
                seed=augment_seed,
                num_parallel_calls=AUTOTUNE,
            ))
    else:
        # If map_func has random elements this won't be deterministic.
        dataset = dataset.map(map_func, num_parallel_calls=AUTOTUNE)
    dataset = dataset.prefetch(AUTOTUNE)
    return RepeatedData(dataset, steps_per_epoch)
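# Sketch of the ordering note above: with `repeat()` before a position-keyed
# stateless map, epoch 2 sees new element indices (4..7 below) and hence new
# random draws; repeating after the map would replay the same ones. The toy
# dataset is illustrative only.
import tensorflow as tf

ds = tf.data.Dataset.range(4).repeat().enumerate()
print([int(i) for i, _ in ds.take(8)])  # [0, 1, 2, 3, 4, 5, 6, 7]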
def __init__(self,
             dataset: tf.data.Dataset,
             steps_per_epoch: Optional[int] = None):
    cardinality = tf.keras.backend.get_value(dataset.cardinality())
    if steps_per_epoch is None:
        steps_per_epoch = cardinality
        if cardinality == tf.data.INFINITE_CARDINALITY:
            raise ValueError(
                "steps_per_epoch must be provided if dataset has infinite "
                "cardinality")
        dataset = dataset.repeat()
    elif cardinality != tf.data.INFINITE_CARDINALITY:
        assert cardinality == steps_per_epoch
        dataset = dataset.repeat()
    self._dataset = dataset
    self._steps_per_epoch = steps_per_epoch
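# Usage sketch, assuming the __init__ above belongs to the RepeatedData
# container returned by get_augmented_data (the class name is an assumption):
# a finite dataset is repeated indefinitely and paired with its original
# cardinality as steps_per_epoch.
import tensorflow as tf

ds = tf.data.Dataset.range(20).batch(4)  # cardinality 5
repeated = RepeatedData(ds)              # hypothetical class name
# A Keras-style loop would then treat every 5 steps as one epoch.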
def process(self, dataset: tf.data.Dataset, batch_size: int):
    dataset = dataset.map(self.parse, num_parallel_calls=AUTOTUNE)

    if self.cache:
        dataset = dataset.cache()

    if self.shuffle:
        dataset = dataset.shuffle(self.buffer_size,
                                  reshuffle_each_iteration=True)

    if self.indefinite:
        dataset = dataset.repeat()

    # Pad each batch to the longest element along every dynamic dimension.
    dataset = dataset.padded_batch(
        batch_size=batch_size,
        padded_shapes=(
            tf.TensorShape([]),
            tf.TensorShape(self.speech_featurizer.shape),
            tf.TensorShape([]),
            tf.TensorShape(self.text_featurizer.shape),
            tf.TensorShape([]),
            tf.TensorShape(self.text_featurizer.prepand_shape),
            tf.TensorShape([]),
        ),
        padding_values=(None, 0., 0, self.text_featurizer.blank, 0,
                        self.text_featurizer.blank, 0),
        drop_remainder=self.drop_remainder)

    # Prefetch to overlap input production with model execution.
    dataset = dataset.prefetch(AUTOTUNE)

    self.total_steps = get_num_batches(
        self.total_steps, batch_size, drop_remainders=self.drop_remainder)
    return dataset
def iterator_from_dataset(
    dataset: tf.data.Dataset,
    batch_size: int,
    repeat: bool = True,
    prefetch_size: int = 0,
    devices: Optional[Sequence[Any]] = None,
):
    """Create a data iterator that returns JAX arrays from a TF dataset.

    Args:
        dataset: the dataset to iterate over.
        batch_size: the batch size the iterator should return.
        repeat: whether the iterator should repeat the dataset.
        prefetch_size: the number of batches to prefetch to device.
        devices: the devices to prefetch to.

    Returns:
        An iterator that returns data batches.
    """
    if repeat:
        dataset = dataset.repeat()

    if batch_size > 0:
        dataset = dataset.batch(batch_size)
        it = map(prepare_tf_data, dataset)
    else:
        it = map(prepare_tf_data_unbatched, dataset)

    if prefetch_size > 0:
        it = jax_utils.prefetch_to_device(it, prefetch_size, devices)

    return it
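# `prepare_tf_data` is not defined in this snippet; a plausible sketch,
# following the common Flax convention of converting each TF tensor in a
# (possibly nested) batch to a NumPy array for JAX. The definition here is
# an assumption, not the source's implementation.
import jax

def prepare_tf_data(batch):
    # Convert every leaf tensor of the batch structure to NumPy.
    return jax.tree_util.tree_map(lambda t: t.numpy(), batch)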
def get_test_tfdataset(self, test_dataset: tf.data.Dataset):
    """
    Returns a test :class:`~tf.data.Dataset`, the number of steps, and the
    number of examples.

    Args:
        test_dataset (:class:`~tf.data.Dataset`):
            The dataset to use.

    Subclass and override this method if you want to inject some custom behavior.
    """
    num_examples = tf.data.experimental.cardinality(test_dataset).numpy()

    if num_examples < 0:
        raise ValueError("The test dataset must have an asserted cardinality")

    approx = math.floor if self.args.dataloader_drop_last else math.ceil
    steps = approx(num_examples / self.args.eval_batch_size)
    ds = (test_dataset.repeat()
          .batch(self.args.eval_batch_size,
                 drop_remainder=self.args.dataloader_drop_last)
          .prefetch(tf.data.experimental.AUTOTUNE))

    return self.args.strategy.experimental_distribute_dataset(
        ds), steps, num_examples
def preprocess(dataset: tf.data.Dataset, num_epoch: int,
               batch_size: int) -> tf.data.Dataset:

    def batch_format_fn(element: Dict[str, tf.Tensor]):
        return (tf.expand_dims(element["pixels"], axis=-1), element["label"])

    return dataset.repeat(num_epoch).shuffle(100).batch(batch_size).map(
        batch_format_fn)
def run_distilibert(strategy: tf.distribute.TPUStrategy, x_train: np.array,
                    x_valid: np.array, _y_train: np.array, y_valid: np.array,
                    train_dataset: tf.data.Dataset,
                    valid_dataset: tf.data.Dataset,
                    test_dataset: tf.data.Dataset, max_len: int, epochs: int,
                    batch_size: int) -> tf.keras.models.Model:
    """Create and run DistilBERT on the training and testing data."""
    logger.info('build distilbert')
    with strategy.scope():
        transformer_layer = TFDistilBertModel.from_pretrained(MODEL)
        model = build_model(transformer_layer, max_len=max_len)
    model.summary()

    # Train the model on the training set.
    n_steps = x_train.shape[0] // batch_size
    history = model.fit(train_dataset,
                        steps_per_epoch=n_steps,
                        validation_data=valid_dataset,
                        epochs=epochs)
    plot_train_val_loss(history, 'distilbert')

    # Continue training on the (repeated) validation set.
    n_steps = x_valid.shape[0] // batch_size
    _train_history_2 = model.fit(valid_dataset.repeat(),
                                 steps_per_epoch=n_steps,
                                 epochs=epochs * 2)

    scores = model.predict(test_dataset, verbose=1)
    logger.info(f"AUC: {roc_auc(scores, y_valid):.4f}")
    return model
def mixup(
    ds: tf.data.Dataset,
    postmix_fn: typing.Callable[..., typing.Any] = None,
    num_parallel_calls: int = None,
):
    """Mixup on tf.data: <https://arxiv.org/abs/1710.09412>

    Args:
        ds: the original dataset
        postmix_fn: processing applied after mixup
        num_parallel_calls: parallelism of the mixup map
    """

    @tf.function
    def mixup_fn(*data):
        # Draw the mixing ratio and blend each pair of consecutive elements.
        r = _tf_random_beta(alpha=0.2, beta=0.2)
        data = [
            tf.cast(d[0], tf.float32) * r + tf.cast(d[1], tf.float32) * (1 - r)
            for d in data
        ]
        return data if postmix_fn is None else postmix_fn(*data)

    ds = ds.repeat()
    ds = ds.batch(2)  # pair up consecutive examples for mixing
    ds = ds.map(
        mixup_fn,
        num_parallel_calls=num_parallel_calls,
        deterministic=None if num_parallel_calls is None else False,
    )
    return ds
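# `_tf_random_beta` is external to this snippet; a plausible sketch draws two
# Gamma variates and normalizes, using the identity Beta(a, b) ~ Ga / (Ga + Gb)
# for Ga ~ Gamma(a, 1), Gb ~ Gamma(b, 1). Name and signature are assumptions
# for illustration.
import tensorflow as tf

def _tf_random_beta(alpha: float, beta: float) -> tf.Tensor:
    ga = tf.random.gamma([], alpha)
    gb = tf.random.gamma([], beta)
    return ga / (ga + gb)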
def train(self, dataset: tf.data.Dataset, nr_records: int):
    dataset = dataset.batch(self.batch_size).map(self.transform_example)
    dataset = dataset.repeat()
    # Note: shuffling after batching shuffles whole batches, not examples.
    dataset = dataset.shuffle(1000)
    self.model.fit(dataset,
                   epochs=self.epochs,
                   steps_per_epoch=nr_records // self.batch_size)
def get_test_tfdataset(self, test_dataset: tf.data.Dataset):
    """
    Returns a test :class:`~tf.data.Dataset`, the number of steps, and the
    number of examples.

    Args:
        test_dataset (:class:`~tf.data.Dataset`):
            The dataset to use. The dataset should yield tuples of
            ``(features, labels)`` where ``features`` is a dict of input
            features and ``labels`` is the labels. If ``labels`` is a tensor,
            the loss is calculated by the model by calling ``model(features,
            labels=labels)``. If ``labels`` is a dict, such as when using a
            QuestionAnswering head model with multiple targets, the loss is
            instead calculated by calling ``model(features, **labels)``.

    Subclass and override this method if you want to inject some custom behavior.
    """
    num_examples = tf.data.experimental.cardinality(test_dataset).numpy()

    if num_examples < 0:
        raise ValueError("The test dataset must have an asserted cardinality")

    approx = math.floor if self.args.dataloader_drop_last else math.ceil
    steps = approx(num_examples / self.args.eval_batch_size)
    ds = (test_dataset.repeat()
          .batch(self.args.eval_batch_size,
                 drop_remainder=self.args.dataloader_drop_last)
          .prefetch(tf.data.experimental.AUTOTUNE))

    return self.args.strategy.experimental_distribute_dataset(
        ds), steps, num_examples
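# Worked example of the steps calculation above, with assumed toy numbers:
# 10 examples and eval_batch_size=4.
import math

assert math.ceil(10 / 4) == 3   # dataloader_drop_last=False keeps the partial batch
assert math.floor(10 / 4) == 2  # dataloader_drop_last=True drops it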
def preprocess_dataset(dataset: tf.data.Dataset, batch_size: int,
                       n_step_returns: int, discount: float):
    # Count the episodes so the shuffle buffer can hold the full dataset.
    d_len = sum(1 for _ in dataset)
    dataset = dataset.map(lambda *x: n_step_transition_from_episode(
        *x, n_step=n_step_returns, additional_discount=discount))
    dataset = dataset.repeat().shuffle(d_len).batch(batch_size,
                                                    drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset
def prepare_for_testing(data_set: tf.data.Dataset, batch_size, cache_path=''):
    if cache_path != '':
        cache_filename = 'dataset_test.tfcache'
        data_set = data_set.cache(os.path.join(cache_path, cache_filename))
    data_set = data_set.repeat()
    data_set = data_set.batch(batch_size=batch_size)
    return data_set
def _prepare_test_dataset(dataset: tf.data.Dataset, batch_size, cache_path=''):
    if cache_path != '':
        cache_filename = 'dataset_test.tfcache'
        dataset = dataset.cache(
            os.path.join(opt.data_path, cache_path, cache_filename))
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size=batch_size)
    return dataset
def preprocess_fn(dataset: tf.data.Dataset) -> tf.data.Dataset:
    if shuffle_buffer_size > 1:
        dataset = dataset.shuffle(shuffle_buffer_size, seed=debug_seed)
    if preprocess_spec.num_epochs > 1:
        dataset = dataset.repeat(preprocess_spec.num_epochs)
    if preprocess_spec.max_elements is not None:
        dataset = dataset.take(preprocess_spec.max_elements)
    dataset = dataset.batch(preprocess_spec.batch_size, drop_remainder=False)
    return dataset.map(
        mapping_fn,
        num_parallel_calls=num_parallel_calls,
        deterministic=debug_seed is not None)
def train_fn(ds: tf.data.Dataset, batch_size=1, shuffle=10000,
             repeat: int = None):
    '''Create an input function for training, prediction, or evaluation.'''
    if shuffle:
        ds = ds.shuffle(shuffle)
    ds = ds.batch(batch_size)
    if repeat != 1:
        ds = ds.repeat(repeat)
    # Note: `make_one_shot_iterator` is TF1-style; under TF2 the equivalent
    # is `tf.compat.v1.data.make_one_shot_iterator(ds)`.
    return lambda: ds.make_one_shot_iterator().get_next()
def fit(self,
        train: tf.data.Dataset,
        valid: Optional[tf.data.Dataset] = None,
        valid_freq=500,
        valid_interval=0,
        optimizer='adam',
        learning_rate=1e-3,
        clipnorm=None,
        epochs=-1,
        max_iter=1000,
        sample_shape=(),  # for ELBO
        analytic=False,  # for ELBO
        iw=False,  # for ELBO
        callback=lambda: None,
        compile_graph=True,
        autograph=False,
        logging_interval=2,
        skip_fitted=False,
        log_tag='',
        log_path=None):
    if self.is_fitted and skip_fitted:
        return self
    from odin.exp.trainer import Trainer
    trainer = Trainer()
    self.trainer = trainer
    # Create the optimizer.
    if optimizer is not None and self.optimizer is None:
        self.optimizer = _to_optimizer(optimizer, learning_rate, clipnorm)
    if self.optimizer is None:
        raise RuntimeError("No optimizer found!")
    self._trainstep_kw = dict(sample_shape=sample_shape,
                              iw=iw,
                              elbo_kw=dict(analytic=analytic))
    # Repeat the dataset for the requested number of epochs
    # (epochs=-1 repeats indefinitely).
    if hasattr(train, 'repeat'):
        train = train.repeat(int(epochs))
    trainer.fit(train_ds=train,
                optimize=self.optimize,
                valid_ds=valid,
                valid_freq=valid_freq,
                valid_interval=valid_interval,
                compile_graph=compile_graph,
                autograph=autograph,
                logging_interval=logging_interval,
                log_tag=log_tag,
                log_path=log_path,
                max_iter=max_iter,
                callback=callback)
    self._trainstep_kw = dict()
    return self
def _prepare_dataset(
        dataset: tf.data.Dataset,
        global_batch_size: int,
        shuffle: bool,
        rng: np.ndarray,
        preprocess_fn: Optional[Callable[[Any], Any]] = None,
        num_epochs: Optional[int] = None,
        filter_fn: Optional[Callable[[Any], Any]] = None) -> tf.data.Dataset:
    """Batches, shuffles, prefetches and preprocesses a dataset.

    Args:
        dataset: The dataset to prepare.
        global_batch_size: The global batch size to use.
        shuffle: Whether to shuffle the data on the example level.
        rng: PRNG for seeding the shuffle operations.
        preprocess_fn: Preprocessing function that will be applied to every
            example.
        num_epochs: Number of epochs to repeat the dataset.
        filter_fn: Function that filters samples according to some criteria.

    Returns:
        The dataset.
    """
    if shuffle and rng is None:
        raise ValueError("Shuffling without RNG is not supported.")

    if global_batch_size % jax.host_count() != 0:
        raise ValueError(f"Batch size {global_batch_size} not divisible by "
                         f"number of hosts ({jax.host_count()}).")

    local_batch_size = global_batch_size // jax.host_count()
    batch_dims = [jax.local_device_count(), local_batch_size]

    # tf.data uses single integers as seed.
    if rng is not None:
        rng = rng[0]

    ds = dataset.repeat(num_epochs)
    if shuffle:
        ds = ds.shuffle(1024, seed=rng)

    if preprocess_fn is not None:
        ds = ds.map(preprocess_fn,
                    num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if filter_fn is not None:
        ds = ds.filter(filter_fn)

    for batch_size in reversed(batch_dims):
        ds = ds.batch(batch_size, drop_remainder=True)

    return ds.prefetch(tf.data.experimental.AUTOTUNE)
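# Illustration of the double batching above, assuming a single host with 2
# local devices and local_batch_size=3 (toy numbers): batching by
# reversed([2, 3]) yields elements shaped [2, 3, ...], one leading axis per
# device.
import tensorflow as tf

ds = tf.data.Dataset.range(12)
for b in reversed([2, 3]):
    ds = ds.batch(b, drop_remainder=True)
print(next(iter(ds)).shape)  # (2, 3)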
def transform_dataset(self, ds_input: tf.data.Dataset) -> tf.data.Dataset:
    """Create a dataset with repeated loops over the input elements.

    Args:
        ds_input: Any dataset.

    Returns:
        A `tf.data.Dataset` with elements containing the same keys, but
        repeated for `epochs` iterations.
    """
    if self.repeat:
        return ds_input.repeat(count=self.epochs)
    else:
        return ds_input
def fit(self,
        data: tf.data.Dataset,
        epochs=1,
        steps_per_epoch=1,
        validation_data=None,
        validation_steps=1,
        **flow_kwargs):
    # `Dataset.repeat` returns a new dataset; the result must be assigned.
    data = data.repeat(epochs)
    if validation_data is not None:
        validation_data = validation_data.repeat(epochs)
    test_hist = dict()
    for epoch in range(epochs):
        train_hist = dict()
        with tqdm(total=steps_per_epoch,
                  desc=f'train, epoch {epoch+1}/{epochs}') as prog:
            for i, (x, y) in enumerate(data.take(steps_per_epoch)):
                loss, nll = self.train_batch(x, y, **flow_kwargs)
                utils.update_metrics(train_hist, loss=loss.numpy(),
                                     nll=nll.numpy())
                prog.update(1)
                prog.set_postfix(utils.get_metrics(train_hist))
        if validation_data is None:
            continue
        with tqdm(total=validation_steps,
                  desc=f'test, epoch {epoch+1}/{epochs}') as prog:
            for i, (x, y) in enumerate(
                    validation_data.take(validation_steps)):
                nll = self.eval_batch(x, y, **flow_kwargs)
                utils.update_metrics(test_hist, nll=nll.numpy())
                prog.update(1)
                prog.set_postfix(utils.get_metrics(test_hist))
    return test_hist
def repeat(data: tf.data.Dataset) -> tf.data.Dataset:
    """
    Repeat dataset

    Parameters
    ----------
    data
        tensorflow dataset to repeat

    Returns
    -------
    data_repeated
        repeated data
    """
    data = data.repeat()
    return data
def prepare_dataset(
    dataset: tf.data.Dataset,
    model_image_size: Tuple[int, int],
    augmentation_fn: Optional[ImageDataMapFn] = None,
    num_epochs: Optional[int] = None,
    batch_size: Optional[int] = None,
    shuffle_buffer_size: Optional[int] = None,
    num_parallel_calls: Optional[int] = None,
    prefetch_buffer_size: Optional[int] = None,
    prefetch_to_device: Optional[str] = None,
) -> tf.data.Dataset:
    # Apply data augmentation.
    if augmentation_fn is not None:
        dataset = dataset.map(
            map_image_data(augmentation_fn),
            num_parallel_calls=num_parallel_calls,
        )

    if shuffle_buffer_size is not None:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)

    dataset = dataset.repeat(num_epochs)
    dataset = dataset.map(
        map_image_data(prepare_for_batching(model_image_size)),
        num_parallel_calls=num_parallel_calls,
    )

    # Batch and pad.
    if batch_size is not None:
        dataset = dataset.padded_batch(
            batch_size=batch_size,
            padded_shapes=get_padding_shapes(
                dataset, spatial_image_shape=model_image_size),
            drop_remainder=True,
        )

    # Try to prefetch the dataset onto a specific device.
    if prefetch_to_device is not None:
        prefetch_fn = tf.data.experimental.prefetch_to_device(
            device=prefetch_to_device, buffer_size=prefetch_buffer_size)
        dataset = dataset.apply(prefetch_fn)
    elif prefetch_buffer_size is not None:
        dataset = dataset.prefetch(buffer_size=prefetch_buffer_size)
    return dataset
def get_tfds_data_loader(data: tf.data.Dataset, data_subset_mode='train',
                         batch_size=32, num_samples=100, num_classes=19,
                         infinite=True, augment=True, seed=2836):

    def encode_example(x, y):
        x = tf.image.convert_image_dtype(x, tf.float32) * 255.0
        y = _encode_label(y, num_classes=num_classes)
        return x, y

    data = data.shuffle(buffer_size=num_samples) \
               .cache() \
               .map(encode_example, num_parallel_calls=AUTOTUNE)
    data = data.map(preprocess_input, num_parallel_calls=AUTOTUNE)

    if data_subset_mode == 'train':
        data = data.shuffle(buffer_size=100, seed=seed)
        augmentor = TRAIN_image_augmentor
    elif data_subset_mode == 'val':
        augmentor = VAL_image_augmentor
    elif data_subset_mode == 'test':
        augmentor = TEST_image_augmentor
    else:
        raise ValueError(f'Unknown data_subset_mode: {data_subset_mode}')

    if augment:
        data = augmentor.apply_augmentations(data)

    data = data.batch(batch_size, drop_remainder=True)

    if infinite:
        data = data.repeat()

    return data.prefetch(AUTOTUNE)
def prepare_for_training(data_set: tf.data.Dataset, batch_size,
                         cache_path=None, shuffle_buffer_size=1000):
    # Both the default `None` and the empty string skip caching; a bare
    # `cache_path != ''` check would crash on the `None` default.
    if cache_path:
        cache_filename = 'dataset_train.tfcache'
        data_set = data_set.cache(os.path.join(cache_path, cache_filename))
    data_set = data_set.shuffle(buffer_size=shuffle_buffer_size)
    # Repeat forever.
    data_set = data_set.repeat()
    data_set = data_set.batch(batch_size=batch_size)
    # `prefetch` lets the dataset fetch batches in the background
    # while the model is training.
    data_set = data_set.prefetch(buffer_size=AUTOTUNE)
    return data_set
def preprocess_tf_dataset(dataset: tf.data.Dataset,
                          hparams: ClientDataHParams) -> tf.data.Dataset:
    """Preprocesses the dataset according to the dataset hyperparameters.

    Args:
        dataset: Dataset with a mapping element structure.
        hparams: Hyperparameters for dataset preparation.

    Returns:
        Preprocessed dataset.
    """
    dataset = dataset.repeat(hparams.num_epochs)
    if hparams.shuffle_buffer_size:
        dataset = dataset.shuffle(hparams.shuffle_buffer_size)
    dataset = (dataset.batch(hparams.batch_size,
                             drop_remainder=hparams.drop_remainder)
               .prefetch(1))
    return dataset.take(hparams.num_batches)
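# `ClientDataHParams` is defined elsewhere; a minimal sketch consistent with
# the fields used above. Names and defaults here are assumptions for
# illustration only (`take(-1)` keeps every element in tf.data).
import dataclasses

@dataclasses.dataclass
class ClientDataHParams:
    num_epochs: int = 1
    shuffle_buffer_size: int = 0
    batch_size: int = 32
    drop_remainder: bool = False
    num_batches: int = -1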
def profile_model(
    model: tf.keras.Model,
    dataset: tf.data.Dataset,
    inference_only: bool = False,
    **kwargs,
):
    if dataset.cardinality() != tf.data.INFINITE_CARDINALITY:
        dataset = dataset.repeat()
    it = iter(dataset)
    model_func = (model.make_predict_function()
                  if inference_only else model.make_train_function())

    def func():
        return model_func(it)

    return profile_func(func, **kwargs,
                        name="predict" if inference_only else "train")
def benchmark_model(model: tf.keras.Model,
                    dataset: tf.data.Dataset,
                    inference_only=False,
                    **kwargs):
    if dataset.cardinality() != tf.data.INFINITE_CARDINALITY:
        dataset = dataset.repeat()
    inputs, labels, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(
        as_inputs(dataset))
    if inference_only:
        op = model(inputs)
    else:
        variables = model.trainable_variables
        with tf.GradientTape() as tape:
            predictions = model(inputs)
            loss = model.loss(labels, predictions, sample_weight=sample_weight)
        grads = tape.gradient(loss, variables)
        op = model.optimizer.apply_gradients(zip(grads, variables))
    return benchmark_op(op, **kwargs)
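# `as_inputs` is not defined in this snippet; one plausible sketch, assuming
# a graph-mode benchmark where a single symbolic element is drawn from the
# dataset (name and behavior are assumptions, not the source's definition).
import tensorflow as tf

def as_inputs(dataset: tf.data.Dataset):
    return tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()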