Code Example #1
def predict_on_batch(model, x):
    """Returns predictions for a single batch of samples.

  Arguments:
      model: The model to predict with.
      x: Input data. It could be:
        - A Numpy array (or array-like), or a list of arrays
          (in case the model has multiple inputs).
        - A TensorFlow tensor, or a list of tensors
          (in case the model has multiple inputs).
        - A `tf.data` dataset.

  Returns:
      Numpy array(s) of predictions.

  Raises:
      ValueError: In case of mismatch between given number of inputs and
        expectations of the model.
  """
    # TODO(scottzhu): Standardization should happen in the data handlers,
    # not on a per-batch basis in the *_on_batch methods.
    # Validate and standardize user data.
    inputs, _, _ = model._standardize_user_data(
        x, extract_tensors_from_dataset=True)

    # If `model._distribution_strategy` is True, then we are in a replica context
    # at this point.
    inputs = training_utils.cast_if_floating_dtype(inputs)
    if isinstance(inputs, collections.abc.Sequence):
        # Unwrap lists with only one input, as we do when training on batch
        if len(inputs) == 1:
            inputs = inputs[0]

    with backend.eager_learning_phase_scope(0):
        return model(inputs)  # pylint: disable=not-callable
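
The internal helper above is what backs single-batch prediction in this Keras version. For context, a minimal sketch of the corresponding public API, Model.predict_on_batch, on a small made-up model (the model and data here are illustrative, not from the original source):

import numpy as np
import tensorflow as tf

# Tiny illustrative model and one batch of inputs.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation='relu', input_shape=(8,)),
    tf.keras.layers.Dense(1),
])

x_batch = np.random.rand(32, 8).astype('float32')

# Returns predictions for exactly this batch, with no internal re-batching.
preds = model.predict_on_batch(x_batch)
print(preds.shape)  # (32, 1)
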
Code Example #2
File: training_eager.py  Project: stegon98/tensorflow
def test_on_batch(model,
                  inputs,
                  targets,
                  sample_weights=None,
                  output_loss_metrics=None):
    """Calculates the loss for one input batch.

  Arguments:
      model: Model whose loss has to be calculated.
      inputs: Input batch data.
      targets: Target batch data.
      sample_weights: Sample weight batch data.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.

  Returns:
      total loss, loss and metrics associated with each output.
  """
    if isinstance(inputs, collections.abc.Sequence):
        # Both branches of the original check on inputs[0] performed the same
        # casts, so the redundant if/else has been collapsed.
        inputs = training_utils.cast_if_floating_to_model_input_dtypes(
            inputs, model)
        if targets:
            targets = training_utils.cast_if_floating_dtype(targets)
    if sample_weights:
        sample_weights = [
            training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
            if val is not None else None for val in sample_weights
        ]
    with backend.eager_learning_phase_scope(0):
        outs, total_loss, output_losses, masks = (_model_loss(
            model,
            inputs,
            targets,
            sample_weights=sample_weights,
            training=False,
            output_loss_metrics=output_loss_metrics))
    if not isinstance(outs, list):
        outs = [outs]
    metrics_results = _eager_metrics_fn(model,
                                        outs,
                                        targets,
                                        sample_weights=sample_weights,
                                        masks=masks)
    total_loss = nest.flatten(total_loss)
    results = total_loss + output_losses + metrics_results

    return results
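
The public counterpart of this internal helper is Model.test_on_batch, which evaluates loss and metrics on a single batch without updating weights. A small hedged sketch with standard tf.keras (model and data are made up for illustration):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='sgd', loss='mse', metrics=['mae'])

x = np.random.rand(16, 4).astype('float32')
y = np.random.rand(16, 1).astype('float32')

# Evaluates loss and metrics on this batch only; no weight update happens.
results = model.test_on_batch(x, y)
print(dict(zip(model.metrics_names, results)))  # e.g. {'loss': ..., 'mae': ...}
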
Code Example #3
def _process_single_batch(model,
                          inputs,
                          targets,
                          output_loss_metrics=None,
                          sample_weights=None,
                          training=False):
    """Calculate the loss and gradient for one input batch.

     The model weights are updated if training is set to True.

  Arguments:
      model: Model whose loss has to be calculated.
      inputs: List of input arrays.
      targets: List of target arrays.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.
      sample_weights: Optional list of sample weight arrays.
      training: Whether the weights of the model should be updated. 'fit'
        methods set this to True, while 'evaluate' methods set it to False.

  Returns:
      output of the model, total loss, the loss and the mask
      associated with each output.

  Raises:
      ValueError: If the model has no loss to optimize.
  """
    with backend.eager_learning_phase_scope(1 if training else 0):
        with GradientTape() as tape:
            outs, total_loss, output_losses, aggregated_output_losses, masks = (
                _model_loss(model,
                            inputs,
                            targets,
                            output_loss_metrics=output_loss_metrics,
                            sample_weights=sample_weights,
                            training=training))
            if total_loss is None:
                raise ValueError('The model cannot be run '
                                 'because it has no loss to optimize.')
        if training:
            if not model._collected_trainable_weights:
                logging.warning(
                    'The list of trainable weights is empty. Make sure that'
                    ' you are not setting model.trainable to False before '
                    'compiling the model.')
            else:
                grads = tape.gradient(total_loss,
                                      model._collected_trainable_weights)
                model.optimizer.apply_gradients(
                    zip(grads, model._collected_trainable_weights))
        return outs, total_loss, output_losses, aggregated_output_losses, masks
Code Example #4
def _process_single_batch(model,
                          inputs,
                          targets,
                          output_loss_metrics=None,
                          sample_weights=None,
                          training=False):
  """Calculate the loss and gradient for one input batch.

     The model weights are updated if training is set to True.

  Arguments:
      model: Model whose loss has to be calculated.
      inputs: List of input arrays.
      targets: List of target arrays.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.
      sample_weights: Optional list of sample weight arrays.
      training: Whether the weights of the model should be updated. 'fit'
        methods set this to True, while 'evaluate' methods set it to False.

  Returns:
      output of the model, total loss, the loss and the mask
      associated with each output.

  Raises:
      ValueError: If the model has no loss to optimize.
  """
  with backend.eager_learning_phase_scope(1 if training else 0):
    with GradientTape() as tape:
      outs, total_loss, output_losses, aggregated_output_losses, masks = (
          _model_loss(
              model,
              inputs,
              targets,
              output_loss_metrics=output_loss_metrics,
              sample_weights=sample_weights,
              training=training))
      if total_loss is None:
        raise ValueError('The model cannot be run '
                         'because it has no loss to optimize.')
    if training:
      if not model.trainable_weights:
        logging.warning('The list of trainable weights is empty. Make sure that'
                        ' you are not setting model.trainable to False before '
                        'compiling the model.')
      else:
        grads = tape.gradient(total_loss, model.trainable_weights)
        model.optimizer.apply_gradients(zip(grads,
                                            model.trainable_weights))
    return outs, total_loss, output_losses, aggregated_output_losses, masks
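
Both versions of _process_single_batch above follow the standard eager training-step pattern: run the forward pass and loss under a GradientTape, then apply gradients only when training. A self-contained sketch of the same pattern using only public tf.keras APIs (the model, loss, and data are illustrative assumptions, not the internal _model_loss machinery):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(1),
])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
loss_fn = tf.keras.losses.MeanSquaredError()


def process_single_batch(x, y, training=False):
    # Record the forward pass and loss on the tape.
    with tf.GradientTape() as tape:
        outs = model(x, training=training)
        total_loss = loss_fn(y, outs)
    if training:
        # Mirrors the helpers above: skip the update if nothing is trainable.
        if model.trainable_weights:
            grads = tape.gradient(total_loss, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return outs, total_loss


x = np.random.rand(16, 4).astype('float32')
y = np.random.rand(16, 1).astype('float32')
_, loss = process_single_batch(x, y, training=True)
print(float(loss))
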
Code Example #5
import numpy as np
import scipy.stats
from tensorflow.keras import backend as K
# Assumes the internal Keras symbol is available in this TF version.
from tensorflow.python.keras.backend import eager_learning_phase_scope


def predict_with_uncertainty(model, testdata, ci=0.95, n_iter=100):
    # Backend function mapping model inputs to outputs.
    func = K.function([model.input], [model.output])
    # Learning phase 1 keeps dropout active, so each pass is stochastic.
    with eager_learning_phase_scope(value=1):
        result = []
        for i in range(n_iter):
            result.append(func([testdata]))

        result = np.array(result)
        predmean = result.mean(axis=0).reshape(-1, )
        predsd = result.std(axis=0).reshape(-1, )
        # Two-sided normal confidence interval around the mean prediction.
        lowerCI = predmean - scipy.stats.norm.ppf(1 - 0.5 * (1 - ci)) * predsd
        upperCI = predmean + scipy.stats.norm.ppf(1 - 0.5 * (1 - ci)) * predsd
        # Exponentiate before returning (outputs assumed to be log-scale).
        return np.exp(predmean), np.exp(lowerCI), np.exp(upperCI)
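
The helper above pins the learning phase to 1 so that dropout stays active at prediction time (Monte Carlo dropout). In current tf.keras the same effect can be sketched without the internal learning-phase scope by calling the model with training=True; the model and data below are illustrative assumptions:

import numpy as np
import scipy.stats
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(8,)),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1),
])

testdata = np.random.rand(32, 8).astype('float32')
ci, n_iter = 0.95, 100

# training=True keeps dropout active, so each forward pass is stochastic.
samples = np.stack(
    [model(testdata, training=True).numpy() for _ in range(n_iter)])
predmean = samples.mean(axis=0).reshape(-1)
predsd = samples.std(axis=0).reshape(-1)

# Two-sided normal interval, e.g. z is about 1.96 for ci = 0.95.
z = scipy.stats.norm.ppf(1 - 0.5 * (1 - ci))
lower, upper = predmean - z * predsd, predmean + z * predsd
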
Code Example #6
File: training_eager_v1.py  Project: mrax714/nearme
def test_on_batch(model,
                  inputs,
                  targets,
                  sample_weights=None,
                  output_loss_metrics=None):
    """Calculates the loss for one input batch.

  Args:
      model: Model whose loss has to be calculated.
      inputs: Input batch data.
      targets: Target batch data.
      sample_weights: Sample weight batch data.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.

  Returns:
      Dict with three items:
        'total_loss': single tensor for overall loss,
        'output_losses': list of tensors for the loss corresponding to each
          model output. May be an empty list when the model has only one output.
        'metrics': list of tensors for the metrics specified.
  """
    inputs = training_utils_v1.cast_to_model_input_dtypes(inputs, model)

    with backend.eager_learning_phase_scope(0):
        outs, total_loss, output_losses, masks = (_model_loss(
            model,
            inputs,
            targets,
            sample_weights=sample_weights,
            training=False,
            output_loss_metrics=output_loss_metrics))
    if not isinstance(outs, list):
        outs = [outs]
    metrics_results = _eager_metrics_fn(model,
                                        outs,
                                        targets,
                                        sample_weights=sample_weights,
                                        masks=masks)
    total_loss = nest.flatten(total_loss)

    return {
        'total_loss': total_loss,
        'output_losses': output_losses,
        'metrics': metrics_results
    }
Code Example #7
def get_embeddings_low_mem(model, seq_input, chrom_input):
    f = keras_extract_fn(model)

    embedding_list_by_batch = []
    # iterate in batches for processing large datasets.
    for batch_start_idx in range(0, len(seq_input), 500):
        batch_end_idx = min(batch_start_idx + 500, len(seq_input))
        current_batch_seq = seq_input[batch_start_idx:batch_end_idx]
        current_batch_chrom = chrom_input[batch_start_idx:batch_end_idx]
        with eager_learning_phase_scope(value=0):
            sn_activations = np.array(
                f([current_batch_seq, current_batch_chrom]))
        activations_rs = np.reshape(sn_activations,
                                    (sn_activations.shape[1], 2))
        activations_rs = activations_rs.astype(np.float64)
        embedding_list_by_batch.append(activations_rs)

    activations = np.vstack(embedding_list_by_batch)
    w, b = model.layers[-1].get_weights()
    w = np.reshape(w, (2, ))
    weighted_embeddings = activations * w
    return weighted_embeddings
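
Example #7 batches inputs manually through a backend extraction function to keep memory bounded. A similar effect can be sketched with a plain tf.keras sub-model and predict's built-in batching; the layer name, shapes, and architecture below are illustrative assumptions, not taken from the original project:

import numpy as np
import tensorflow as tf

# Assume a two-input model whose penultimate layer produces the "embedding".
seq_in = tf.keras.Input(shape=(100,), name='seq')
chrom_in = tf.keras.Input(shape=(10,), name='chrom')
h = tf.keras.layers.Concatenate()([seq_in, chrom_in])
h = tf.keras.layers.Dense(2, name='embedding')(h)
out = tf.keras.layers.Dense(1)(h)
model = tf.keras.Model([seq_in, chrom_in], out)

# Sub-model that stops at the embedding layer.
extractor = tf.keras.Model(model.inputs, model.get_layer('embedding').output)

seq_input = np.random.rand(1200, 100).astype('float32')
chrom_input = np.random.rand(1200, 10).astype('float32')

# predict() iterates in batches internally, keeping large inputs memory-friendly.
activations = extractor.predict([seq_input, chrom_input], batch_size=500)

# Weight the embeddings by the final layer's kernel, as the helper above does.
w, b = model.layers[-1].get_weights()
weighted_embeddings = activations * np.reshape(w, (2,))
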
Code Example #8
    def predict(self, x, n_iter=100):
        """
        Args:
            x: XRD spectrum to be classified
        Returns:
            prediction: distribution of probabilities associated with reference phases
            len(certainties): number of phases with probabilities > 10%
            certainties: associated probabilities
        """

        x = [[val] for val in x]
        x = np.array([x])
        result = []
        with eager_learning_phase_scope(value=1):
            for _ in range(n_iter):
                result.append(self.f(x))

        result = np.array([
            list(np.array(sublist).flatten()) for sublist in result
        ])  ## Individual predictions
        prediction = result.mean(axis=0)  ## Average prediction

        all_preds = [np.argmax(pred) for pred in result
                     ]  ## Individual max indices (associated with phases)

        counts = []
        for index in set(all_preds):
            counts.append(all_preds.count(
                index))  ## Tabulate how many times each prediction arises

        certainties = []
        for each_count in counts:
            conf = each_count / sum(counts)
            if conf >= 0.1:  ## If prediction occurs at least 10% of the time
                certainties.append(conf)
        certainties = sorted(certainties, reverse=True)

        return prediction, len(certainties), certainties
Code Example #9
File: training_eager.py  Project: zwcdp/tensorflow
def test_on_batch(model,
                  inputs,
                  targets,
                  sample_weights=None,
                  output_loss_metrics=None):
  """Calculates the loss for one input batch.

  Arguments:
      model: Model whose loss has to be calculated.
      inputs: Input batch data.
      targets: Target batch data.
      sample_weights: Sample weight batch data.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.

  Returns:
      total loss, loss and metrics associated with each output.
  """
  inputs = training_utils.cast_to_model_input_dtypes(inputs, model)

  with backend.eager_learning_phase_scope(0):
    outs, total_loss, output_losses, masks = (
        _model_loss(
            model,
            inputs,
            targets,
            sample_weights=sample_weights,
            training=False,
            output_loss_metrics=output_loss_metrics))
  if not isinstance(outs, list):
    outs = [outs]
  metrics_results = _eager_metrics_fn(
      model, outs, targets, sample_weights=sample_weights, masks=masks)
  total_loss = nest.flatten(total_loss)
  results = total_loss + output_losses + metrics_results

  return results
Code Example #10
File: training_eager_v1.py  Project: mrax714/nearme
def _process_single_batch(model,
                          inputs,
                          targets,
                          output_loss_metrics=None,
                          sample_weights=None,
                          training=False):
    """Calculate the loss and gradient for one input batch.

     The model weights are updated if training is set to True.

  Args:
      model: Model whose loss has to be calculated.
      inputs: List of input arrays.
      targets: List of target arrays.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.
      sample_weights: Optional list of sample weight arrays.
      training: Whether the weights of the model should be updated. 'fit'
        methods set this to True, while 'evaluate' methods set it to False.

  Returns:
      output of the model, total loss, the loss and the mask
      associated with each output.

  Raises:
      ValueError: If the model has no loss to optimize.
  """
    with backend.eager_learning_phase_scope(1 if training else 0), \
        training_utils.RespectCompiledTrainableState(model):
        with GradientTape() as tape:
            outs, total_loss, output_losses, masks = (_model_loss(
                model,
                inputs,
                targets,
                output_loss_metrics=output_loss_metrics,
                sample_weights=sample_weights,
                training=training))
            if isinstance(model.optimizer,
                          loss_scale_optimizer.LossScaleOptimizer):
                scaled_total_loss = model.optimizer.get_scaled_loss(total_loss)
            else:
                scaled_total_loss = total_loss
        if training:
            trainable_weights = model.trainable_weights
            if trainable_weights:
                # TODO(tanzheny) b/132690565: Provide mechanism for user to override
                # model.train_on_batch.
                if hasattr(model, '_backwards'):
                    model._backwards(tape, scaled_total_loss)
                else:
                    grads = tape.gradient(scaled_total_loss, trainable_weights)
                    if isinstance(model.optimizer,
                                  loss_scale_optimizer.LossScaleOptimizer):
                        grads = model.optimizer.get_unscaled_gradients(grads)
                    model.optimizer.apply_gradients(
                        zip(grads, trainable_weights))
            else:
                logging.warning(
                    'The list of trainable weights is empty. Make sure that'
                    ' you are not setting model.trainable to False before '
                    'compiling the model.')
        return outs, total_loss, output_losses, masks
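
Example #10 adds loss scaling for mixed precision: the loss is scaled before the backward pass and the gradients are unscaled before being applied. A rough sketch of that pattern with the public mixed-precision API follows; exact module paths vary across TF versions, and the model, optimizer, and data below are illustrative assumptions:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(1),
])
loss_fn = tf.keras.losses.MeanSquaredError()

# Wrap the inner optimizer so small gradients do not underflow in float16.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(learning_rate=0.01))

x = np.random.rand(16, 4).astype('float32')
y = np.random.rand(16, 1).astype('float32')

with tf.GradientTape() as tape:
    outs = model(x, training=True)
    total_loss = loss_fn(y, outs)
    # Scale the loss before differentiating, as in the helper above.
    scaled_total_loss = optimizer.get_scaled_loss(total_loss)

grads = tape.gradient(scaled_total_loss, model.trainable_weights)
# Undo the scaling before the weight update.
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(zip(grads, model.trainable_weights))
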
Code Example #11
def model_iteration(model,
                    data,
                    steps_per_epoch=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    validation_data=None,
                    validation_steps=None,
                    validation_freq=1,
                    class_weight=None,
                    max_queue_size=10,
                    workers=1,
                    use_multiprocessing=False,
                    shuffle=False,
                    initial_epoch=0,
                    mode=ModeKeys.TRAIN,
                    batch_size=None,
                    steps_name='steps',
                    **kwargs):
    """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.

  Arguments:
      model: Keras Model instance.
      data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or `(x, y)` or
        `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      steps_per_epoch: Total number of steps (batches of samples) before
        declaring one epoch finished and starting the next epoch. Ignored with
        the default value of `None`.
      epochs: Number of times to iterate over the data.
      verbose: 0, 1, or 2. Verbosity mode.
        0 = silent, 1 = progress bar, 2 = one line per epoch.
        Note that the progress bar is not particularly useful when
        logged to a file, so verbose=2 is recommended when not running
        interactively (e.g., in a production environment).
      callbacks: List of callbacks to be called during training.
      validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or
        `(x, y)` or `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      validation_steps: Total number of steps (batches of samples) before
        declaring validation finished.
      validation_freq: Only relevant if validation data is provided. Integer or
        `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
        integer, specifies how many training epochs to run before a new
        validation run is performed, e.g. `validation_freq=2` runs
        validation every 2 epochs. If a Container, specifies the epochs on
        which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
        validation at the end of the 1st, 2nd, and 10th epochs.
      class_weight: Dictionary mapping class indices to a weight for the class.
      max_queue_size: Integer. Maximum size for the generator queue. If
        unspecified, `max_queue_size` will default to 10.
      workers: Integer. Maximum number of processes to spin up when using
        process-based threading. If unspecified, `workers` will default to 1. If
        0, will execute the generator on the main thread.
      use_multiprocessing: Boolean. If `True`, use process-based threading. If
        unspecified, `use_multiprocessing` will default to `False`. Note that
        because this implementation relies on multiprocessing, you should not
        pass non-picklable arguments to the generator as they can't be passed
        easily to child processes.
      shuffle: Boolean. Whether to shuffle the order of the batches at the
        beginning of each epoch. Only used with instances of `Sequence`
        (`keras.utils.Sequence`). Has no effect when `steps_per_epoch` is not
        `None`.
      initial_epoch: Epoch at which to start training (useful for resuming a
        previous training run).
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
      batch_size: Integer batch size or None if unknown. Will only be used if
        `data` is in NumPy/Tensor format.
      steps_name: The string name of the steps argument, either `steps`,
        `validation_steps`, or `steps_per_epoch`. Only used for error message
        formatting.
      **kwargs: Additional arguments for backwards compatibility. `steps` is
        accepted as an alias for `steps_per_epoch`.

  Returns:
      - In TRAIN mode: `History` object.
      - In TEST mode: Evaluation metrics.
      - In PREDICT mode: Outputs of the Model called on inputs.

  Raises:
      ValueError: in case of invalid arguments.
  """
    if 'steps' in kwargs:
        steps_per_epoch = kwargs['steps']

    # Determine the number of steps per epoch and whether we should reset the
    # dataset at the end of each epoch.
    reset_dataset_after_each_epoch = False
    original_dataset = None
    is_dataset = isinstance(data,
                            (dataset_ops.DatasetV2, dataset_ops.DatasetV1))
    if is_dataset:
        original_dataset = data
        if steps_per_epoch is None:
            reset_dataset_after_each_epoch = True
            steps_per_epoch = training_utils.infer_steps_for_dataset(
                model,
                data,
                steps_per_epoch,
                epochs=epochs,
                steps_name=steps_name)

    # Convert to a format that supports `next(generator)`.
    generator, steps_per_epoch = convert_to_generator_like(
        data,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        epochs=epochs - initial_epoch,
        shuffle=shuffle)

    do_validation = validation_data is not None
    is_sequence = isinstance(generator, data_utils.Sequence)
    _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
                        steps_per_epoch, validation_data, validation_steps,
                        mode, kwargs)

    batch_function = _make_execution_function(model,
                                              mode,
                                              class_weight=class_weight)

    # Create the queue for the generator.
    enqueuer = None
    if not is_dataset:
        generator, enqueuer = _make_enqueued_generator(
            generator,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            max_queue_size=max_queue_size,
            shuffle=shuffle)

    num_samples_or_steps, use_steps = _get_num_samples_or_steps(
        data, steps_per_epoch)

    count_mode = 'steps' if use_steps else 'samples'
    callbacks = cbks.configure_callbacks(callbacks,
                                         model,
                                         do_validation=do_validation,
                                         epochs=epochs,
                                         steps_per_epoch=steps_per_epoch,
                                         batch_size=batch_size,
                                         samples=num_samples_or_steps,
                                         count_mode=count_mode,
                                         verbose=verbose,
                                         mode=mode)

    if mode == ModeKeys.PREDICT:
        aggregator = training_utils.OutputsAggregator(True,
                                                      steps=steps_per_epoch)
    else:
        aggregator = training_utils.MetricsAggregator(True,
                                                      steps=steps_per_epoch)

    should_set_learning_phase = context.executing_eagerly(
    ) and model.run_eagerly
    if should_set_learning_phase:
        learning_phase_scope = backend.eager_learning_phase_scope(
            1 if mode == ModeKeys.TRAIN else 0)
        learning_phase_scope.__enter__()

    callbacks.model.stop_training = False
    callbacks._call_begin_hook(mode)

    initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
        initial_epoch, mode)

    for epoch in range(initial_epoch, epochs):
        if callbacks.model.stop_training:
            break

        # Setup work for each epoch.
        model.reset_metrics()
        epoch_logs = {}
        if mode == ModeKeys.TRAIN:
            callbacks.on_epoch_begin(epoch, epoch_logs)

        if steps_per_epoch is None:
            # Loop over dataset until `OutOfRangeError` is raised.
            target_steps = np.inf
        else:
            # Loop over dataset for the specified number of steps.
            target_steps = steps_per_epoch

        step = 0
        while step < target_steps:
            batch_data = _get_next_batch(generator)
            if batch_data is None:
                if is_dataset:
                    # The dataset passed by the user ran out of batches.
                    # Now we know the cardinality of the dataset.
                    # If steps_per_epoch was specified, then running out of data is
                    # unexpected, so we stop training and inform the user.
                    if steps_per_epoch:
                        callbacks.model.stop_training = True
                        logging.warning(
                            'Your dataset ran out of data; interrupting training. '
                            'Make sure that your dataset can generate at least '
                            '`%s * epochs` batches (in this case, %d batches). '
                            'You may need to use the repeat() function when '
                            'building your dataset.' %
                            (steps_name, steps_per_epoch * epochs))
                    elif step > 0:
                        steps_per_epoch = step
                        aggregator.steps = steps_per_epoch
                else:
                    # We ran out of batches while the user passed an iterator (legacy).
                    callbacks.model.stop_training = True
                    logging.warning(
                        'Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your iterator '
                        'can generate at least `%s * epochs` '
                        'batches (in this case, %d batches). You may need to '
                        'use the repeat() function when building your '
                        'dataset.' % (steps_name, steps_per_epoch * epochs))
                break

            # `batch_size` used for validation data if validation
            # data is NumPy/EagerTensors.
            batch_size = int(nest.flatten(batch_data)[0].shape[0])

            # Callbacks batch begin.
            batch_logs = {'batch': step, 'size': batch_size}
            callbacks._call_batch_hook(mode, 'begin', step, batch_logs)

            is_deferred = not model._is_compiled
            batch_outs = batch_function(*batch_data)
            if not isinstance(batch_outs, list):
                batch_outs = [batch_outs]

            if step == 0:
                aggregator.create(batch_outs)

                if is_deferred:
                    # Set callbacks params. We do this here when model is compiled only
                    # in the first iteration of this loop (deferred build scenario).
                    cbks.set_callback_parameters(
                        callbacks,
                        model,
                        do_validation=do_validation,
                        batch_size=batch_size,
                        epochs=epochs,
                        steps_per_epoch=steps_per_epoch,
                        samples=num_samples_or_steps,
                        verbose=verbose,
                        mode=mode)

            # Aggregate results.
            aggregator.aggregate(batch_outs)

            # Callbacks batch end.
            batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
            callbacks._call_batch_hook(mode, 'end', step, batch_logs)
            step += 1

            if callbacks.model.stop_training:
                break

        aggregator.finalize()
        results = aggregator.results
        epoch_logs = cbks.make_logs(model, epoch_logs, results, mode)
        if len(results) == 1:
            results = results[0]

        # Run the test loop every epoch during training.
        if (do_validation and training_utils.should_run_validation(
                validation_freq, epoch) and not callbacks.model.stop_training):
            val_results = model_iteration(
                model,
                validation_data,
                steps_per_epoch=validation_steps,
                batch_size=batch_size,
                class_weight=class_weight,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                max_queue_size=max_queue_size,
                callbacks=callbacks,
                verbose=verbose,
                mode=ModeKeys.TEST,
                steps_name='validation_steps')

            if not isinstance(val_results, list):
                val_results = [val_results]
            epoch_logs = cbks.make_logs(model,
                                        epoch_logs,
                                        val_results,
                                        mode,
                                        prefix='val_')

        if mode == ModeKeys.TRAIN:
            # Epochs only apply to `fit`.
            callbacks.on_epoch_end(epoch, epoch_logs)

        # Recreate dataset iterator for the next epoch.
        if reset_dataset_after_each_epoch and epoch < epochs - 1:
            generator = dataset_ops.make_one_shot_iterator(original_dataset)

    model._successful_loop_finish = True
    callbacks._call_end_hook(mode)

    if enqueuer is not None:
        enqueuer.stop()

    if should_set_learning_phase:
        learning_phase_scope.__exit__(None, None, None)

    if mode == ModeKeys.TRAIN:
        return model.history
    return results
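
model_iteration is the internal loop behind fit, evaluate, and predict for generator and dataset inputs in this Keras version; end users normally reach it through the public fit API. A small hedged sketch of the corresponding public call on a tf.data dataset (model, dataset, and numbers are illustrative):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

x = np.random.rand(256, 4).astype('float32')
y = np.random.rand(256, 1).astype('float32')
train_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(32).repeat()
val_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)

# steps_per_epoch, validation_data, and validation_freq map onto the
# model_iteration arguments documented above.
history = model.fit(train_ds,
                    epochs=4,
                    steps_per_epoch=8,
                    validation_data=val_ds,
                    validation_freq=2,
                    verbose=2)
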
Code Example #12
def model_iteration(model,
                    data,
                    steps_per_epoch=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    validation_data=None,
                    validation_steps=None,
                    validation_freq=1,
                    train_class_weight=None,
                    val_class_weight=None,
                    max_queue_size=10,
                    workers=1,
                    use_multiprocessing=False,
                    shuffle=False,
                    initial_epoch=0,
                    mode=ModeKeys.TRAIN,
                    batch_size=None,
                    steps_name='steps',
                    **kwargs):

    if 'steps' in kwargs:
        steps_per_epoch = kwargs['steps']

    # Determine the number of steps per epoch and whether we should reset the
    # dataset at the end of each epoch.
    reset_dataset_after_each_epoch = False
    original_dataset = None
    is_dataset = isinstance(data,
                            (dataset_ops.DatasetV2, dataset_ops.DatasetV1))
    if is_dataset:
        original_dataset = data
        if steps_per_epoch is None:
            reset_dataset_after_each_epoch = True
            steps_per_epoch = training_utils.infer_steps_for_dataset(
                data, steps_per_epoch, epochs=epochs, steps_name=steps_name)

    # Convert to a format that supports `next(generator)`.
    generator, steps_per_epoch = convert_to_generator_like(
        data,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        epochs=epochs - initial_epoch,
        shuffle=shuffle)

    do_validation = validation_data is not None
    is_sequence = isinstance(generator, data_utils.Sequence)
    _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
                        steps_per_epoch, validation_data, validation_steps,
                        mode, kwargs)

    # print(train_class_weight, 'before make execution')
    ######################################################################
    batch_function = _make_execution_function(
        model,
        mode,
        train_class_weight=train_class_weight,
        val_class_weight=val_class_weight)
    ######################################################################

    # Create the queue for the generator.
    enqueuer = None
    if not is_dataset:
        generator, enqueuer = _make_enqueued_generator(
            generator,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            max_queue_size=max_queue_size,
            shuffle=shuffle)

    num_samples_or_steps, use_steps = _get_num_samples_or_steps(
        data, steps_per_epoch)

    count_mode = 'steps' if use_steps else 'samples'
    callbacks = cbks.configure_callbacks(callbacks,
                                         model,
                                         do_validation=do_validation,
                                         epochs=epochs,
                                         steps_per_epoch=steps_per_epoch,
                                         batch_size=batch_size,
                                         samples=num_samples_or_steps,
                                         verbose=verbose,
                                         count_mode=count_mode,
                                         mode=mode)

    if mode == ModeKeys.PREDICT:
        aggregator = training_utils.OutputsAggregator(True,
                                                      steps=steps_per_epoch)
    else:
        aggregator = training_utils.MetricsAggregator(True,
                                                      steps=steps_per_epoch)

    should_set_learning_phase = context.executing_eagerly(
    ) and model.run_eagerly
    if should_set_learning_phase:
        learning_phase_scope = backend.eager_learning_phase_scope(
            1 if mode == ModeKeys.TRAIN else 0)
        learning_phase_scope.__enter__()

    callbacks.model.stop_training = False
    callbacks._call_begin_hook(mode)

    print(initial_epoch, mode)
    # TODO: mode is a bug?
    # https://github.com/tensorflow/tensorflow/blob/r2.2/tensorflow/python/keras/engine/training.py
    initial_epoch = model._maybe_load_initial_epoch_from_ckpt(initial_epoch)

    for epoch in range(initial_epoch, epochs):
        if callbacks.model.stop_training:
            break

        # Setup work for each epoch.
        model.reset_metrics()
        epoch_logs = {}
        if mode == ModeKeys.TRAIN:
            callbacks.on_epoch_begin(epoch, epoch_logs)

        if steps_per_epoch is None:
            # Loop over dataset until `OutOfRangeError` is raised.
            target_steps = np.inf
        else:
            # Loop over dataset for the specified number of steps.
            target_steps = steps_per_epoch

        step = 0
        while step < target_steps:
            batch_data = _get_next_batch(generator)
            if batch_data is None:
                if is_dataset:
                    # The dataset passed by the user ran out of batches.
                    # Now we know the cardinality of the dataset.
                    # If steps_per_epoch was specified, then running out of data is
                    # unexpected, so we stop training and inform the user.
                    if steps_per_epoch:
                        callbacks.model.stop_training = True
                        logging.warning(
                            'Your dataset ran out of data; interrupting training. '
                            'Make sure that your dataset can generate at least '
                            '`%s * epochs` batches (in this case, %d batches). '
                            'You may need to use the repeat() function when '
                            'building your dataset.' %
                            (steps_name, steps_per_epoch * epochs))
                    elif step > 0:
                        steps_per_epoch = step
                        aggregator.steps = steps_per_epoch
                else:
                    # We ran out of batches while the user passed an iterator (legacy).
                    callbacks.model.stop_training = True
                    logging.warning(
                        'Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your iterator '
                        'can generate at least `%s * epochs` '
                        'batches (in this case, %d batches). You may need to '
                        'use the repeat() function when building your '
                        'dataset.' % (steps_name, steps_per_epoch * epochs))
                break

            # `batch_size` used for validation data if validation
            # data is NumPy/EagerTensors.
            batch_size = int(nest.flatten(batch_data)[0].shape[0])

            # Callbacks batch begin.
            batch_logs = {'batch': step, 'size': batch_size}
            callbacks._call_batch_hook(mode, 'begin', step, batch_logs)

            is_deferred = not model._is_compiled
            ######################################################
            batch_outs = batch_function(*batch_data)
            ######################################################
            if not isinstance(batch_outs, list):
                batch_outs = [batch_outs]

            if step == 0:
                aggregator.create(batch_outs)

                if is_deferred:
                    # Set callbacks params. We do this here when model is compiled only
                    # in the first iteration of this loop (deferred build scenario).
                    cbks.set_callback_parameters(
                        callbacks,
                        model,
                        do_validation=do_validation,
                        batch_size=batch_size,
                        epochs=epochs,
                        steps_per_epoch=steps_per_epoch,
                        samples=num_samples_or_steps,
                        verbose=verbose,
                        mode=mode)

            # Aggregate results.
            aggregator.aggregate(batch_outs)

            # Callbacks batch end.
            batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
            callbacks._call_batch_hook(mode, 'end', step, batch_logs)
            step += 1

            if callbacks.model.stop_training:
                break

        aggregator.finalize()
        results = aggregator.results
        epoch_logs = cbks.make_logs(model, epoch_logs, results, mode)
        if len(results) == 1:
            results = results[0]

        # Run the test loop every epoch during training.
        if (do_validation and training_utils.should_run_validation(
                validation_freq, epoch) and not callbacks.model.stop_training):
            ############################################################################
            val_results = model_iteration(
                model,
                validation_data,
                steps_per_epoch=validation_steps,
                batch_size=batch_size,
                val_class_weight=val_class_weight,  ######## HACK
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                max_queue_size=max_queue_size,
                callbacks=callbacks,
                verbose=0,
                mode=ModeKeys.TEST,
                steps_name='validation_steps')
            ############################################################################

            if not isinstance(val_results, list):
                val_results = [val_results]
            epoch_logs = cbks.make_logs(model,
                                        epoch_logs,
                                        val_results,
                                        mode,
                                        prefix='val_')

        if mode == ModeKeys.TRAIN:
            # Epochs only apply to `fit`.
            callbacks.on_epoch_end(epoch, epoch_logs)

        # Recreate dataset iterator for the next epoch.
        if reset_dataset_after_each_epoch and epoch < epochs - 1:
            generator = dataset_ops.make_one_shot_iterator(original_dataset)

    callbacks._call_end_hook(mode)

    if enqueuer is not None:
        enqueuer.stop()

    if should_set_learning_phase:
        learning_phase_scope.__exit__(None, None, None)

    if mode == ModeKeys.TRAIN:
        return model.history
    return results
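
Example #12 patches model_iteration to carry separate class weights for training and validation, presumably because the stock class_weight argument to fit reweights the training loss only. For reference, a minimal sketch of that standard training-only behaviour with the public API (the weights and data are illustrative):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy')

x = np.random.rand(256, 4).astype('float32')
y = (np.random.rand(256, 1) > 0.8).astype('float32')  # imbalanced labels

# class_weight reweights the *training* loss only; validation metrics are
# computed unweighted, which is the gap the hack above works around.
model.fit(x, y,
          epochs=2,
          batch_size=32,
          validation_split=0.2,
          class_weight={0: 1.0, 1: 4.0})
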