Ejemplo n.º 1
0
def _make_enqueued_generator(generator,
                             workers=1,
                             use_multiprocessing=False,
                             max_queue_size=10,
                             shuffle=False):
    """Return an iterator over `generator`, optionally fed by an enqueuer.

    Args:
        generator: A plain generator or a `data_utils.Sequence` instance.
        workers: Number of workers passed to `enqueuer.start`. When it is
            not greater than 0, no enqueuer is created.
        use_multiprocessing: Forwarded to the enqueuer constructor.
        max_queue_size: Forwarded to `enqueuer.start`.
        shuffle: Forwarded to `data_utils.OrderedEnqueuer`; only used when
            `generator` is a `Sequence` and an enqueuer is created.

    Returns:
        A tuple `(output_generator, enqueuer)`. `enqueuer` is `None` when no
        enqueuer was created; otherwise it has already been started and the
        caller is the one holding the reference needed to stop it.
    """
    is_sequence = isinstance(generator, data_utils.Sequence)

    if not workers > 0:
        # No enqueuer requested: hand back something that can be iterated
        # directly.
        if is_sequence:
            return data_utils.iter_sequence_infinite(generator), None
        return generator, None

    if is_sequence:
        queue_feeder = data_utils.OrderedEnqueuer(
            generator,
            use_multiprocessing=use_multiprocessing,
            shuffle=shuffle)
    else:
        queue_feeder = data_utils.GeneratorEnqueuer(
            generator, use_multiprocessing=use_multiprocessing)

    queue_feeder.start(workers=workers, max_queue_size=max_queue_size)
    return queue_feeder.get(), queue_feeder
Ejemplo n.º 2
0
 def generator_fn():
   """Start a `GeneratorEnqueuer` over `x` and return its output generator.

   `x`, `workers`, `use_multiprocessing` and `max_queue_size` are read from
   the enclosing scope.
   """
   queue_feeder = data_utils.GeneratorEnqueuer(
       x, use_multiprocessing=use_multiprocessing)
   queue_feeder.start(workers=workers, max_queue_size=max_queue_size)
   return queue_feeder.get()
Ejemplo n.º 3
0
 def generator_fn():
     """Start a `GeneratorEnqueuer` over `peek` followed by `x`.

     `peek`, `x`, `workers`, `use_multiprocessing` and `max_queue_size` are
     read from the enclosing scope. `peek` is chained in front of `x` so it
     is yielded first.
     """
     peek_then_rest = itertools.chain([peek], x)
     queue_feeder = data_utils.GeneratorEnqueuer(
         peek_then_rest, use_multiprocessing=use_multiprocessing)
     queue_feeder.start(workers=workers, max_queue_size=max_queue_size)
     return queue_feeder.get()
Ejemplo n.º 4
0
def predict_generator(model,
                      generator,
                      steps=None,
                      max_queue_size=10,
                      workers=1,
                      use_multiprocessing=False,
                      verbose=0):
    """See docstring for `Model.predict_generator`.

    Pulls up to `steps` batches from `generator`, runs
    `model.predict_on_batch` on each and gathers the predictions.

    Args:
        model: Model providing `predict_on_batch`.
        generator: A generator or `data_utils.Sequence`. Each item is either
            the inputs alone, or a tuple `(x, y)` / `(x, y, sample_weight)`
            of which only `x` is used.
        steps: Number of batches to draw. Defaults to `len(generator)` for a
            `Sequence`; required otherwise.
        max_queue_size: Forwarded to `enqueuer.start`.
        workers: When greater than 0 an enqueuer is started with this many
            workers; otherwise the generator is consumed directly.
        use_multiprocessing: Forwarded to the enqueuer constructor.
        verbose: When equal to 1 a `Progbar` is updated after every batch.

    Returns:
        For a model with a single output, that output (the single batch
        result when exactly one step ran, else the `np.concatenate` of all
        batch results). For several outputs, a list with one such entry per
        output. An empty list when no batch was predicted.

    Raises:
        ValueError: If `steps` is missing for a non-`Sequence` generator, or
            if the generator yields a tuple whose length is not 2 or 3.
    """
    if not context.executing_eagerly():
        model._make_predict_function()

    steps_done = 0
    all_outs = []
    is_sequence = isinstance(generator, data_utils.Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        logging.warning(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the`keras.utils.Sequence'
                        ' class.'))
    if steps is None:
        if not is_sequence:
            raise ValueError('Please specify the `steps` argument.')
        steps = len(generator)
    enqueuer = None

    try:
        if workers > 0:
            if is_sequence:
                enqueuer = data_utils.OrderedEnqueuer(
                    generator, use_multiprocessing=use_multiprocessing)
            else:
                enqueuer = data_utils.GeneratorEnqueuer(
                    generator, use_multiprocessing=use_multiprocessing)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        elif is_sequence:
            output_generator = data_utils.iter_sequence_infinite(generator)
        else:
            output_generator = generator

        if verbose == 1:
            progbar = Progbar(target=steps)

        while steps_done < steps:
            generator_output = next(output_generator)
            if isinstance(generator_output, tuple):
                # Compatibility with the generators used for training:
                # targets and sample weights are discarded.
                if len(generator_output) == 2:
                    x, _ = generator_output
                elif len(generator_output) == 3:
                    x, _, _ = generator_output
                else:
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y, sample_weight)` '
                                     'or `(x, y)`. Found: ' +
                                     str(generator_output))
            else:
                # Assumes a generator that only yields inputs
                # (not targets and sample weights).
                x = generator_output

            outs = model.predict_on_batch(x)
            if not isinstance(outs, list):
                outs = [outs]

            if not all_outs:
                # One bucket per model output, created on the first batch.
                all_outs = [[] for _ in outs]

            for i, out in enumerate(outs):
                all_outs[i].append(out)
            steps_done += 1
            if verbose == 1:
                progbar.update(steps_done)

    except (errors.OutOfRangeError, StopIteration):
        # Best effort: keep whatever was predicted before the data ran out.
        logging.warning(
            'Your dataset iterator ran out of data; interrupting prediction. '
            'Make sure that your dataset can generate at least `steps` '
            'batches (in this case, %d batches). You may need to use the '
            'repeat() function when building your dataset.', steps)

    finally:
        if enqueuer is not None:
            enqueuer.stop()

    if len(all_outs) == 1:
        if steps_done == 1:
            return all_outs[0][0]
        return np.concatenate(all_outs[0])
    if steps_done == 1:
        return [out[0] for out in all_outs]
    return [np.concatenate(out) for out in all_outs]
Ejemplo n.º 5
0
def fit_generator(model,
                  generator,
                  steps_per_epoch=None,
                  epochs=1,
                  verbose=1,
                  callbacks=None,
                  validation_data=None,
                  validation_steps=None,
                  class_weight=None,
                  max_queue_size=10,
                  workers=1,
                  use_multiprocessing=False,
                  shuffle=True,
                  initial_epoch=0):
    """See docstring for `Model.fit_generator`.

    Trains `model` with `model.train_on_batch` on batches drawn from
    `generator`, running validation at the end of every completed epoch when
    `validation_data` is truthy.

    Args:
        model: Model providing `train_on_batch`, `evaluate`, `metrics`,
            `metrics_names` and `history`.
        generator: A generator or `data_utils.Sequence` yielding `(x, y)` or
            `(x, y, sample_weight)`.
        steps_per_epoch: Batches per epoch. Defaults to `len(generator)` for
            a `Sequence`; required otherwise.
        epochs: Epoch index at which training stops (exclusive).
        verbose: Forwarded to `cbks.configure_callbacks`.
        callbacks: Forwarded to `cbks.configure_callbacks`.
        validation_data: A generator, a `Sequence`, a dataset (turned into a
            one-shot iterator when executing eagerly), or a tuple
            `(val_x, val_y)` / `(val_x, val_y, val_sample_weight)`.
        validation_steps: Forwarded to `evaluate_generator`; required when
            `validation_data` is generator-like and not a `Sequence`.
        class_weight: Forwarded to `model.train_on_batch`.
        max_queue_size: Forwarded to `enqueuer.start`.
        workers: When greater than 0 an enqueuer is started with this many
            workers; otherwise the generator is consumed directly.
        use_multiprocessing: Forwarded to the enqueuer constructor.
        shuffle: Forwarded to `data_utils.OrderedEnqueuer` (only used for
            `Sequence` inputs when an enqueuer is created).
        initial_epoch: Epoch index at which training starts.

    Returns:
        `model.history`.

    Raises:
        ValueError: If `steps_per_epoch` or `validation_steps` is required
            but missing, if a `validation_data` tuple does not have 2 or 3
            elements, or if the generator yields something that is not a
            2- or 3-element sized object.
    """
    epoch = initial_epoch

    do_validation = bool(validation_data)
    if not context.executing_eagerly():
        model._make_train_function()
        if do_validation:
            model._make_test_function()

    is_sequence = isinstance(generator, data_utils.Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        logging.warning(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the`keras.utils.Sequence'
                        ' class.'))
    if steps_per_epoch is None:
        if not is_sequence:
            raise ValueError('Please specify the `steps_per_epoch` argument.')
        steps_per_epoch = len(generator)

    if (isinstance(validation_data, dataset_ops.Dataset)
            and context.executing_eagerly()):
        validation_data = validation_data.make_one_shot_iterator()
    val_gen = (data_utils.is_generator_or_sequence(validation_data)
               or isinstance(validation_data, iterator_ops.EagerIterator))
    if (val_gen and not isinstance(validation_data, data_utils.Sequence)
            and not validation_steps):
        raise ValueError('Please specify the `validation_steps` argument.')

    enqueuer = None

    try:
        val_x, val_y, val_sample_weights = validation_data, None, None
        if do_validation and not val_gen:
            # In-memory validation data: unpack and standardize it once.
            if len(validation_data) == 2:
                val_x, val_y = validation_data  # pylint: disable=unpacking-non-sequence
                val_sample_weights = None
            elif len(validation_data) == 3:
                val_x, val_y, val_sample_weights = validation_data  # pylint: disable=unpacking-non-sequence
            else:
                raise ValueError('`validation_data` should be a tuple '
                                 '`(val_x, val_y, val_sample_weight)` '
                                 'or `(val_x, val_y)`. Found: ' +
                                 str(validation_data))
            val_x, val_y, val_sample_weights = model._standardize_user_data(
                val_x, val_y, val_sample_weights)

        callbacks = cbks.configure_callbacks(
            callbacks,
            model,
            do_validation=do_validation,
            val_inputs=val_x,
            val_targets=val_y,
            val_sample_weights=val_sample_weights,
            epochs=epochs,
            validation_steps=validation_steps,
            steps_per_epoch=steps_per_epoch,
            verbose=verbose)

        if workers > 0:
            if is_sequence:
                enqueuer = data_utils.OrderedEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    shuffle=shuffle)
            else:
                enqueuer = data_utils.GeneratorEnqueuer(
                    generator, use_multiprocessing=use_multiprocessing)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        elif is_sequence:
            output_generator = data_utils.iter_sequence_infinite(generator)
        else:
            output_generator = generator

        callbacks.on_train_begin()
        # Construct epoch logs. The same dict is reused for every epoch.
        epoch_logs = {}
        while epoch < epochs:
            for m in model.metrics:
                m.reset_states()
            callbacks.on_epoch_begin(epoch)
            steps_done = 0
            batch_index = 0
            while steps_done < steps_per_epoch:
                generator_output = next(output_generator)
                if (not hasattr(generator_output, '__len__')
                        or len(generator_output) not in (2, 3)):
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y, sample_weight)` '
                                     'or `(x, y)`. Found: ' +
                                     str(generator_output))

                if len(generator_output) == 2:
                    x, y = generator_output
                    sample_weight = None
                else:
                    x, y, sample_weight = generator_output

                # Build batch logs; the batch size is read from the first
                # input when `x` is a list or dict.
                batch_logs = {}
                if isinstance(x, list):
                    batch_size = x[0].shape[0]
                elif isinstance(x, dict):
                    batch_size = list(x.values())[0].shape[0]
                else:
                    batch_size = x.shape[0]
                batch_logs['batch'] = int(batch_index)
                batch_logs['size'] = int(batch_size)
                callbacks.on_batch_begin(batch_index, batch_logs)

                outs = model.train_on_batch(x,
                                            y,
                                            sample_weight=sample_weight,
                                            class_weight=class_weight)

                if not isinstance(outs, list):
                    outs = [outs]
                for l, o in zip(model.metrics_names, outs):
                    batch_logs[l] = o

                callbacks.on_batch_end(batch_index, batch_logs)

                batch_index += 1
                steps_done += 1

                # Epoch finished.
                if steps_done >= steps_per_epoch and do_validation:
                    if val_gen:
                        val_outs = evaluate_generator(
                            model,
                            validation_data,
                            validation_steps,
                            workers=workers,
                            use_multiprocessing=use_multiprocessing,
                            max_queue_size=max_queue_size)
                    else:
                        # No need for try/except because
                        # data has already been validated.
                        # The size of the last training batch is reused as
                        # the evaluation batch size.
                        val_outs = model.evaluate(
                            val_x,
                            val_y,
                            batch_size=batch_size,
                            sample_weight=val_sample_weights,
                            verbose=0)
                    if not isinstance(val_outs, list):
                        val_outs = [val_outs]
                    # Same labels assumed.
                    for l, o in zip(model.metrics_names, val_outs):
                        epoch_logs['val_' + l] = o

                if callbacks.model.stop_training:
                    break

            callbacks.on_epoch_end(epoch, epoch_logs)
            epoch += 1
            if callbacks.model.stop_training:
                break

    except (errors.OutOfRangeError, StopIteration):
        # Best effort: stop training early instead of propagating the error.
        logging.warning(
            'Your dataset iterator ran out of data; interrupting training. '
            'Make sure that your dataset can generate at least `steps_per_epoch` '
            'batches (in this case, %d batches). You may need to use the '
            'repeat() function when building your dataset.', steps_per_epoch)

    finally:
        if enqueuer is not None:
            enqueuer.stop()

    callbacks.on_train_end()
    return model.history
Ejemplo n.º 6
0
def evaluate_generator(model,
                       generator,
                       steps=None,
                       max_queue_size=10,
                       workers=1,
                       use_multiprocessing=False,
                       verbose=0):
    """See docstring for `Model.evaluate_generator`.

    Pulls up to `steps` batches from `generator`, runs `model.test_on_batch`
    on each and combines the per-batch results.

    Args:
        model: Model providing `test_on_batch` and `metrics`.
        generator: A generator or `data_utils.Sequence` yielding `(x, y)` or
            `(x, y, sample_weight)`.
        steps: Number of batches to draw. Defaults to `len(generator)` for a
            `Sequence`; required otherwise.
        max_queue_size: Forwarded to `enqueuer.start`.
        workers: When greater than 0 an enqueuer is started with this many
            workers; otherwise the generator is consumed directly.
        use_multiprocessing: Forwarded to the enqueuer constructor.
        verbose: When equal to 1 a `Progbar` is updated after every batch.

    Returns:
        If `test_on_batch` returns a non-list value, the batch-size-weighted
        average of the per-batch values. Otherwise a list whose first entry
        is the first value of the last evaluated batch and whose remaining
        entries are batch-size-weighted averages.

    Raises:
        ValueError: If `steps` is missing for a non-`Sequence` generator, if
            the generator yields something that is not a 2- or 3-element
            sized object, if a batch is empty, or if no batch at all could
            be evaluated.
    """
    if not context.executing_eagerly():
        model._make_test_function()

    if hasattr(model, '_compile_metrics'):
        for m in model.metrics:
            m.reset_states()

    steps_done = 0
    all_outs = []
    batch_sizes = []
    is_sequence = isinstance(generator, data_utils.Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        logging.warning(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the`keras.utils.Sequence'
                        ' class.'))
    if steps is None:
        if not is_sequence:
            raise ValueError('Please specify the `steps` argument.')
        steps = len(generator)
    enqueuer = None

    try:
        if workers > 0:
            if is_sequence:
                enqueuer = data_utils.OrderedEnqueuer(
                    generator, use_multiprocessing=use_multiprocessing)
            else:
                enqueuer = data_utils.GeneratorEnqueuer(
                    generator, use_multiprocessing=use_multiprocessing)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        elif is_sequence:
            output_generator = data_utils.iter_sequence_infinite(generator)
        else:
            output_generator = generator

        if verbose == 1:
            progbar = Progbar(target=steps)

        while steps_done < steps:
            generator_output = next(output_generator)
            if (not hasattr(generator_output, '__len__')
                    or len(generator_output) not in (2, 3)):
                raise ValueError('Output of generator should be a tuple '
                                 '(x, y, sample_weight) '
                                 'or (x, y). Found: ' + str(generator_output))
            if len(generator_output) == 2:
                x, y = generator_output
                sample_weight = None
            else:
                x, y, sample_weight = generator_output
            outs = model.test_on_batch(x, y, sample_weight=sample_weight)

            # The batch size is read from the first input when `x` is a
            # list or dict.
            if isinstance(x, list):
                batch_size = int(x[0].shape[0])
            elif isinstance(x, dict):
                batch_size = int(list(x.values())[0].shape[0])
            else:
                batch_size = int(x.shape[0])
            if batch_size == 0:
                raise ValueError('Received an empty batch. '
                                 'Batches should at least contain one item.')
            all_outs.append(outs)

            steps_done += 1
            batch_sizes.append(batch_size)
            if verbose == 1:
                progbar.update(steps_done)

    except (errors.OutOfRangeError, StopIteration):
        # Best effort: keep whatever was evaluated before the data ran out.
        logging.warning(
            'Your dataset iterator ran out of data interrupting testing. '
            'Make sure that your dataset can generate at least `steps` '
            'batches (in this case, %d batches). You may need to use the '
            'repeat() function when building your dataset.', steps)

    finally:
        if enqueuer is not None:
            enqueuer.stop()

    if not all_outs:
        # Nothing was evaluated (data ran out before the first batch, or
        # `steps` was 0). Fail with a clear message instead of tripping over
        # an unbound loop variable below.
        raise ValueError('No batches were evaluated, so no evaluation '
                         'result can be computed. Make sure `steps` is at '
                         'least 1 and that the generator yields data.')

    last_outs = all_outs[-1]
    if not isinstance(last_outs, list):
        return np.average(np.asarray(all_outs), weights=batch_sizes)

    # NOTE(review): the loss is taken from the last batch only, while the
    # remaining entries are averaged over all batches - confirm this is
    # intended.
    averages = [float(last_outs[0])]  # index 0 = 'loss'
    averages.extend([
        np.average([out[i] for out in all_outs], weights=batch_sizes)
        for i in range(1, len(last_outs))
    ])
    return averages