Exemplo n.º 1
0
 def __init__(self,
              filename,
              val_data,
              lr_min=1e-6,
              lr_max=1,
              steps=1500,
              val_period=1,
              separator=','):
     """Configure a learning-rate range test with periodic validation.

     ``self.lr`` starts at ``lr_min``.  ``self.lr_increment`` is the constant
     factor for which ``lr_min * lr_increment ** (steps - 1) == lr_max`` (up to
     rounding); presumably the callback multiplies ``self.lr`` by it once per
     step -- confirm against the rest of the class.

     Args:
         filename: output file path, stored as ``self.filename``.
         val_data: validation data; wrapped in an ``OrderedEnqueuer`` that is
             started here, and ``len(val_data)`` becomes
             ``self.validation_steps``.
         lr_min: first (smallest) learning rate.
         lr_max: last (largest) learning rate.
         steps: number of learning-rate values; must be greater than 1.
         val_period: stored as ``self.val_period``.
         separator: field separator, stored as ``self.sep``.
     """
     # Output / bookkeeping settings.
     self.filename = filename
     self.sep = separator
     self.val_period = val_period
     self.epoch = 0
     self.batch_no = 0

     # Geometric learning-rate sweep from lr_min to lr_max.
     self.lr_min = lr_min
     self.lr_max = lr_max
     self.lr = lr_min
     assert steps > 1
     ratio = lr_max / lr_min
     self.lr_increment = ratio ** (1. / (steps - 1))

     # Writer state; presumably populated by the logging code outside
     # __init__ (not visible here).
     self._open_args = {}
     self.keys = None
     self.writer = None
     self.append_header = True

     # Background loading of validation batches.
     self.val_enqueuer = OrderedEnqueuer(val_data, use_multiprocessing=False)
     self.val_enqueuer.start(workers=5,  # TODO make a parameter
                             max_queue_size=10)  # TODO make a parameter
     self.val_enqueuer_gen = self.val_enqueuer.get()
     self.validation_steps = len(val_data)
     super(EvalLrTest, self).__init__()
Exemplo n.º 2
0
 def __init__(self,
              data_generator,
              val_model,
              metrics_to_keep='all',
              **kwargs):
     """Initialise the parent, then wrap the validation Sequence in an enqueuer.

     Args:
         data_generator: validation data forwarded to the parent constructor;
             afterwards ``self.data_generator`` must pass ``is_sequence``.
         val_model: model forwarded to the parent constructor.
         metrics_to_keep: optional filter on which metrics to track
             (``'all'`` by default).
         **kwargs: forwarded unchanged to the parent constructor.

     After this call ``self.data_generator`` is the enqueuer's output iterator
     (not the original Sequence) and ``self.eval_gen_workers`` is 0.
     """
     # Set before the parent constructor runs, in case it reads it.
     self.metrics_to_keep = metrics_to_keep
     super().__init__(data_generator, val_model, **kwargs)
     assert is_sequence(self.data_generator), (
         'validation generator must be an instance of keras.utils.Sequence')

     # n.b. that the enqueuer calls on_epoch_end: https://github.com/keras-team/keras/blob/efe72ef433852b1d7d54f283efff53085ec4f756/keras/utils/data_utils.py
     enqueuer = OrderedEnqueuer(self.data_generator, use_multiprocessing=False)
     enqueuer.start(workers=1, max_queue_size=10)
     self.data_generator = enqueuer.get()
     self.eval_gen_workers = 0
def concurrent_generator(sequence,
                         num_workers=8,
                         max_queue_size=32,
                         use_multiprocessing=False):
    """Yield a background-fed iterator over ``sequence``, then shut it down.

    An ``OrderedEnqueuer`` is started over ``sequence`` and its output
    iterator is yielded exactly once.  The enqueuer is stopped in ``finally``,
    i.e. when this generator is resumed past the yield, closed, or
    garbage-collected.

    Args:
        sequence: the data source handed to ``OrderedEnqueuer``.
        num_workers: number of enqueuer workers.
        max_queue_size: maximum number of prefetched batches.
        use_multiprocessing: use processes instead of threads for the workers.
    """
    batch_source = OrderedEnqueuer(sequence,
                                   use_multiprocessing=use_multiprocessing)
    try:
        batch_source.start(workers=num_workers, max_queue_size=max_queue_size)
        batches = batch_source.get()
        yield batches
    finally:
        batch_source.stop()
Exemplo n.º 4
0
def main(argv=None):
    """Debug script: inspect which items ``EAST_generator`` yields per epoch.

    Pulls 30 batches for each of 3 epochs, through an ``OrderedEnqueuer`` when
    ``FLAGS.workers > 0`` or via ``iter_sequence_infinite`` otherwise.  It
    prints every batch and any items already seen earlier in the same epoch,
    plus ``generator.indexes`` before and after each manual ``on_epoch_end``.

    Args:
        argv: unused.
    """
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    print('gpu id', FLAGS.gpu_list)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)):

        print('FLAGS.train_txt_dir', FLAGS.train_txt_dir)
        generator = EAST_generator(batch_size=FLAGS.batch_size, nums=50)

        enqueuer = None
        try:
            if FLAGS.workers > 0:
                # Load data with background workers (processes when
                # FLAGS.use_multiprocessing is set).
                enqueuer = OrderedEnqueuer(
                    generator,
                    use_multiprocessing=FLAGS.use_multiprocessing,
                    shuffle=True)
                enqueuer.start(workers=FLAGS.workers,
                               max_queue_size=FLAGS.max_queue_size)
                output_generator = enqueuer.get()
                print('workers ', FLAGS.workers)
            else:
                output_generator = iter_sequence_infinite(generator)

            for _ in range(3):
                print(generator.indexes)
                # Items seen so far in this epoch, used to detect duplicates.
                seen = set()
                for _ in range(30):
                    data = next(output_generator)
                    repeated = set(data) & seen
                    if repeated:
                        print(repeated)
                    seen.update(data)
                    print('worker', data)
                    time.sleep(1)
                generator.on_epoch_end()
                print(generator.indexes)
        finally:
            # Without this the enqueuer's workers keep running after main returns.
            if enqueuer is not None:
                enqueuer.stop()
Exemplo n.º 5
0
def mine_hard_samples(model, datagen, batch_size, use_multiprocessing, shuffle, workers, max_queue_size,
                      error_threshold=.99):
    """Endlessly yield training batches in which at least half the samples are "hard".

    Batches from ``datagen`` are read through a background enqueuer.  A sample
    is "hard" when the maximum over the last axis of
    ``|model.predict(x) - y|`` exceeds ``error_threshold``.  Hard samples are
    collected until there are at least ``batch_size`` of them.  The pool is
    then topped up with ordinary samples from the same stream to
    ``2 * batch_size`` (extra hard samples beyond that are dropped), shuffled,
    and yielded as two batches of ``batch_size`` rows.

    Args:
        model: object with a ``predict`` method.
        datagen: a Keras ``Sequence`` or a generator of ``(x, y)`` numpy
            array batches.
        batch_size: number of rows in each yielded batch.
        use_multiprocessing: enqueuer setting.
        shuffle: shuffle the Sequence order (only used for Sequences).
        workers: number of enqueuer workers.
        max_queue_size: maximum number of prefetched batches.
        error_threshold: threshold on the per-sample max absolute error.

    Yields:
        ``(samples, targets)`` numpy arrays with ``batch_size`` rows each.
    """
    if is_sequence(datagen):
        enqueuer = OrderedEnqueuer(
            datagen,
            use_multiprocessing=use_multiprocessing,
            shuffle=shuffle)
    else:
        enqueuer = GeneratorEnqueuer(
            datagen,
            use_multiprocessing=use_multiprocessing)
    enqueuer.start(workers=workers, max_queue_size=max_queue_size)

    pool_size = batch_size * 2
    try:
        output_generator = enqueuer.get()
        while True:
            sample_chunks, target_chunks = [], []
            num_hard = 0
            while num_hard < batch_size:
                x_data, y_data = next(output_generator)
                preds = model.predict(x_data)
                max_errors = np.abs(preds - y_data).max(axis=-1)
                print('\nERRORS:\n{}\n'.format(max_errors))
                is_hard = max_errors > error_threshold
                sample_chunks.append(x_data[is_hard])
                target_chunks.append(y_data[is_hard])
                num_hard += int(np.count_nonzero(is_hard))

            # Top up with regular samples from the enqueued stream.  Calling
            # next() on ``datagen`` itself would fail for a Sequence and would
            # race the enqueuer's workers for a plain generator.  Loop because
            # a single batch may be shorter than what is missing.
            num_missing = pool_size - num_hard
            while num_missing > 0:
                x_data, y_data = next(output_generator)
                x_part, y_part = x_data[:num_missing], y_data[:num_missing]
                sample_chunks.append(x_part)
                target_chunks.append(y_part)
                num_missing -= len(x_part)

            # One batch can push the hard count past pool_size; keep exactly pool_size.
            samples = np.concatenate(sample_chunks)[:pool_size]
            targets = np.concatenate(target_chunks)[:pool_size]

            idx = np.arange(pool_size)
            np.random.shuffle(idx)
            batch1, batch2 = np.split(idx, 2)

            yield samples[batch1], targets[batch1]
            yield samples[batch2], targets[batch2]
    finally:
        # Runs when the consumer closes (or drops) this generator.
        enqueuer.stop()
def evaluate_and_predict_generator_with_sceneinst_metrics(
        model,
        generator,
        params,
        multithreading_metrics=False,
        steps=None,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False,
        verbose=0):
    """Evaluate ``model`` on ``generator`` and accumulate per-scene-instance metrics.

    Modelled on Keras' ``Model.evaluate_generator``.  For every batch this runs
    ``heiner_test_and_predict_on_batch`` against ``y[:, :, :, 0]``.  It then
    thresholds the sigmoid of the returned logits at
    ``params['outputthreshold']`` and accumulates metrics for ``(y_pred, y)``
    into ``model.scene_instance_id_metrics_dict_eval``.  That happens either
    inline or, when ``multithreading_metrics`` is true, on a background thread
    running ``metrics_per_batch_thread_handler``.

    Args:
        model: compiled model.  This function (re)sets
            ``model.scene_instance_id_metrics_dict_eval`` and
            ``model.val_loss_batch`` on it.
        generator: a ``Sequence`` or an iterator yielding ``(x, y)`` or
            ``(x, y, sample_weight)``.  ``y`` is indexed as ``y[:, :, :, 0]``,
            so it must have at least 4 dimensions.
        params: dict; ``'mask_value'`` and ``'outputthreshold'`` are read.
        multithreading_metrics: compute the per-batch metrics on a separate thread.
        steps: number of batches; defaults to ``len(generator)`` for a Sequence.
        max_queue_size: maximum number of prefetched batches.
        workers: number of enqueuer workers; ``0`` consumes ``generator``
            directly on the calling thread.
        use_multiprocessing: enqueuer setting.
        verbose: ``1`` shows a progress bar.

    Returns:
        The unweighted mean of the per-batch losses in ``model.val_loss_batch``.

    Raises:
        ValueError: if ``steps`` is None for a non-Sequence generator, if a
            generator output is not a 2- or 3-tuple, or if a batch is empty.
    """
    model._make_test_function()

    # Reset stateful metrics before evaluating, as Keras' evaluate_generator does.
    if hasattr(model, 'metrics'):
        for m in model.stateful_metric_functions:
            m.reset_states()

    steps_done = 0
    wait_time = 0.01
    is_sequence = isinstance(generator, Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        warnings.warn(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the`keras.utils.Sequence'
                        ' class.'))
    if steps is None:
        if is_sequence:
            steps = len(generator)
        else:
            raise ValueError('`steps=None` is only valid for a generator'
                             ' based on the `keras.utils.Sequence` class.'
                             ' Please specify `steps` or use the'
                             ' `keras.utils.Sequence` class.')
    enqueuer = None
    # Bound before the ``try`` so the ``finally`` clause never references an
    # unbound name when setup fails before the thread is created.
    validmetrics_thread = None

    try:
        if workers > 0:
            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    generator, use_multiprocessing=use_multiprocessing)
            else:
                enqueuer = GeneratorEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    wait_time=wait_time)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        else:
            if is_sequence:
                output_generator = iter(generator)
            else:
                output_generator = generator

        if verbose == 1:
            progbar = Progbar(target=steps)

        # setup scene instance dictionary
        model.scene_instance_id_metrics_dict_eval = {}

        # create thread for asynchronous batch metrics calculation
        if multithreading_metrics:
            # threadsafe queue into which we will push (y_pred, y) tuples; the
            # handler also receives ``steps`` (presumably how many items it
            # consumes -- confirm in metrics_per_batch_thread_handler)
            label_queue = queue.Queue()
            validmetrics_thread = threading.Thread(
                target=metrics_per_batch_thread_handler,
                args=(label_queue, model.scene_instance_id_metrics_dict_eval,
                      params['mask_value'], steps))
            validmetrics_thread.start()

        model.val_loss_batch = []
        while steps_done < steps:
            generator_output = next(output_generator)
            if not hasattr(generator_output, '__len__'):
                raise ValueError('Output of generator should be a tuple '
                                 '(x, y, sample_weight) '
                                 'or (x, y). Found: ' + str(generator_output))
            if len(generator_output) == 2:
                x, y = generator_output
            elif len(generator_output) == 3:
                # sample weights are not used during evaluation
                x, y, _ = generator_output
            else:
                raise ValueError('Output of generator should be a tuple '
                                 '(x, y, sample_weight) '
                                 'or (x, y). Found: ' + str(generator_output))

            # run forward pass
            # remark on label shape: last (fourth) dimension contains in 0 the true labels, in 1 the corresponding sceneinstid (millioncode)
            batch_loss, y_pred_logits = heiner_test_and_predict_on_batch(
                model, x, y[:, :, :, 0])

            model.val_loss_batch.append(batch_loss)

            # from logits to predicted class probabilities
            y_pred_probs = sigmoid(y_pred_logits,
                                   out=y_pred_logits)  # last arg: inplace
            # from probabilities to hard class decisions
            y_pred = np.greater_equal(y_pred_probs,
                                      params['outputthreshold'],
                                      out=y_pred_probs)  # last arg: inplace

            # increment metrics for scene instances in batch
            if multithreading_metrics:
                # the following two arrays need to be unchanged in order for being thread-safe
                # assumption 1: batchloader yields array copies (true for moritz loader)
                # assumption 2: *_and_predict_on_batch return newly allocated arrays
                label_queue.put((y_pred, y))
            else:
                heiner_calculate_class_accuracies_metrics_per_scene_instance_in_batch(
                    model.scene_instance_id_metrics_dict_eval, y_pred, y,
                    params['mask_value'])

            if x is None or len(x) == 0:
                # Handle data tensors support when no input given
                # step-size = 1 for data tensors
                batch_size = 1
            elif isinstance(x, list):
                batch_size = x[0].shape[0]
            elif isinstance(x, dict):
                batch_size = list(x.values())[0].shape[0]
            else:
                batch_size = x.shape[0]
            if batch_size == 0:
                raise ValueError('Received an empty batch. '
                                 'Batches should contain '
                                 'at least one item.')
            steps_done += 1
            if verbose == 1:
                progbar.update(steps_done)

    finally:
        if enqueuer is not None:
            enqueuer.stop()

        if validmetrics_thread is not None:
            # NOTE(review): if the loop raised part-way, the handler may still
            # be waiting for the remaining batches, in which case this join
            # could block -- depends on metrics_per_batch_thread_handler.
            validmetrics_thread.join()

    # for test phase: simply use the model.scene_instance_id_metrics_dict_test after execution
    return np.average(np.array(model.val_loss_batch))
def fit_and_predict_generator_with_sceneinst_metrics(
        model,
        generator,
        params,
        multithreading_metrics=False,
        steps_per_epoch=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_data=None,
        validation_steps=None,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False,
        shuffle=True,
        initial_epoch=0):
    """Train ``model`` on ``generator`` while accumulating per-scene-instance metrics.

    Modelled on Keras' ``Model.fit_generator``.  Each batch is trained with
    ``heiner_train_and_predict_on_batch`` against ``y[:, :, :, 0]``.  Unless
    ``params['nosceneinstweights']`` is set, any sample weights from the
    generator are replaced by ``heiner_calculate_sample_weights_batch``.
    The thresholded predictions update
    ``model.scene_instance_id_metrics_dict_train`` (reset every epoch), either
    inline or on a background thread when ``multithreading_metrics`` is true.
    After the last batch of each epoch, validation runs via
    ``evaluate_and_predict_generator_with_sceneinst_metrics`` for
    generator/Sequence validation data, or ``model.evaluate`` for array
    tuples.

    Args:
        model: compiled model; ``model.history``, ``model.gradient_norm`` and
            ``model.scene_instance_id_metrics_dict_train`` are set on it.
        generator: a ``Sequence`` or iterator yielding ``(x, y)`` or
            ``(x, y, sample_weight)`` with ``y`` of at least 4 dimensions.
            When sample weights are computed, ``generator.length_dict`` and
            ``generator.scene_instance_ids_dict`` are read.
        params: dict; reads ``'mask_value'``, ``'nosceneinstweights'``,
            ``'nocalcgradientnorm'`` and ``'outputthreshold'``.
        multithreading_metrics: compute per-batch metrics on separate threads.
        steps_per_epoch: batches per epoch; defaults to ``len(generator)``
            for a Sequence.
        epochs: index of the last epoch (exclusive).
        verbose: add a progress-bar logger when truthy.
        callbacks: extra Keras callbacks.
        validation_data: Sequence, generator, or ``(x, y[, sample_weight])``.
        validation_steps: validation batches; required for plain generators.
        max_queue_size: maximum number of prefetched batches.
        workers: number of enqueuer workers (``0`` reads on this thread).
        use_multiprocessing: enqueuer setting.
        shuffle: enqueuer setting (only used for Sequences).
        initial_epoch: epoch index to start from.

    Returns:
        ``model.history``.

    Raises:
        ValueError: for missing ``steps_per_epoch``/``validation_steps`` with
            plain generators, or malformed generator/validation outputs.
    """
    wait_time = 0.01  # in seconds
    epoch = initial_epoch

    do_validation = bool(validation_data)
    model._make_train_function()
    if do_validation:
        model._make_test_function()

    is_sequence = isinstance(generator, Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        warnings.warn(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the`keras.utils.Sequence'
                        ' class.'))
    if steps_per_epoch is None:
        if is_sequence:
            steps_per_epoch = len(generator)
        else:
            raise ValueError('`steps_per_epoch=None` is only valid for a'
                             ' generator based on the '
                             '`keras.utils.Sequence`'
                             ' class. Please specify `steps_per_epoch` '
                             'or use the `keras.utils.Sequence` class.')

    # python 2 has 'next', 3 has '__next__'
    # avoid any explicit version checks
    val_gen = (hasattr(validation_data, 'next')
               or hasattr(validation_data, '__next__')
               or isinstance(validation_data, Sequence))
    if (val_gen and not isinstance(validation_data, Sequence)
            and not validation_steps):
        raise ValueError('`validation_steps=None` is only valid for a'
                         ' generator based on the `keras.utils.Sequence`'
                         ' class. Please specify `validation_steps` or use'
                         ' the `keras.utils.Sequence` class.')

    # Prepare display labels.
    out_labels = model.metrics_names
    callback_metrics = out_labels + ['val_' + n for n in out_labels]

    # prepare callbacks
    model.history = cbks.History()
    _callbacks = [
        cbks.BaseLogger(stateful_metrics=model.stateful_metric_names)
    ]
    if verbose:
        _callbacks.append(
            cbks.ProgbarLogger(count_mode='steps',
                               stateful_metrics=model.stateful_metric_names))
    _callbacks += (callbacks or []) + [model.history]
    callbacks = cbks.CallbackList(_callbacks)

    # it's possible to callback a different model than self:
    if hasattr(model, 'callback_model') and model.callback_model:
        callback_model = model.callback_model
    else:
        callback_model = model
    callbacks.set_model(callback_model)
    callbacks.set_params({
        'epochs': epochs,
        'steps': steps_per_epoch,
        'verbose': verbose,
        'do_validation': do_validation,
        'metrics': callback_metrics,
    })
    callbacks.on_train_begin()

    enqueuer = None
    val_enqueuer = None
    # Bound up front: the ``finally`` clause joins it, and it would otherwise
    # be unbound if setup fails or the epoch loop never runs
    # (``initial_epoch >= epochs``).
    trainmetrics_thread = None

    try:
        if do_validation:
            if val_gen and workers > 0:
                # Create an Enqueuer that can be reused
                val_data = validation_data
                if isinstance(val_data, Sequence):
                    val_enqueuer = OrderedEnqueuer(
                        val_data, use_multiprocessing=use_multiprocessing)
                    validation_steps = len(val_data)
                else:
                    val_enqueuer = GeneratorEnqueuer(
                        val_data, use_multiprocessing=use_multiprocessing)
                val_enqueuer.start(workers=workers,
                                   max_queue_size=max_queue_size)
                val_enqueuer_gen = val_enqueuer.get()
            elif val_gen:
                val_data = validation_data
                if isinstance(val_data, Sequence):
                    val_enqueuer_gen = iter(val_data)
                else:
                    val_enqueuer_gen = val_data
            else:
                # Prepare data for validation
                if len(validation_data) == 2:
                    val_x, val_y = validation_data
                    val_sample_weight = None
                elif len(validation_data) == 3:
                    val_x, val_y, val_sample_weight = validation_data
                else:
                    raise ValueError('`validation_data` should be a tuple '
                                     '`(val_x, val_y, val_sample_weight)` '
                                     'or `(val_x, val_y)`. Found: ' +
                                     str(validation_data))
                val_x, val_y, val_sample_weights = model._standardize_user_data(
                    val_x, val_y, val_sample_weight)
                val_data = val_x + val_y + val_sample_weights
                if model.uses_learning_phase and not isinstance(
                        K.learning_phase(), int):
                    val_data += [0.]
                for cbk in callbacks:
                    cbk.validation_data = val_data

        if workers > 0:
            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    shuffle=shuffle)
            else:
                enqueuer = GeneratorEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    wait_time=wait_time)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        else:
            if is_sequence:
                output_generator = iter(generator)
            else:
                output_generator = generator

        callback_model.stop_training = False
        # Construct epoch logs.
        epoch_logs = {}
        while epoch < epochs:

            # setup scene instance dictionary
            model.scene_instance_id_metrics_dict_train = {}

            # create thread for asynchronous batch metrics calculation (one thread per epoch, joined before final metrics calculation)
            if multithreading_metrics:
                # threadsafe queue into which we will push (y_pred, y) tuples
                label_queue = queue.Queue()
                trainmetrics_thread = threading.Thread(
                    target=metrics_per_batch_thread_handler,
                    args=(label_queue,
                          model.scene_instance_id_metrics_dict_train,
                          params['mask_value'], steps_per_epoch))

                trainmetrics_thread.start()

            for m in model.stateful_metric_functions:
                m.reset_states()
            callbacks.on_epoch_begin(epoch)
            steps_done = 0
            batch_index = 0

            runtime_generator_cumulated = 0.
            runtime_train_and_predict_on_batch_cumulated = 0.
            runtime_class_accuracies_cumulated = 0.
            skip_runtime_avg = 5  # skipping the first few batches to reduce bias due to inital extra time

            while steps_done < steps_per_epoch:
                t_start = time()
                generator_output = next(output_generator)
                runtime_generator_next = time() - t_start

                if batch_index >= skip_runtime_avg:
                    runtime_generator_cumulated += runtime_generator_next

                if not hasattr(generator_output, '__len__'):
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y, sample_weight)` '
                                     'or `(x, y)`. Found: ' +
                                     str(generator_output))

                if len(generator_output) == 2:
                    x, y = generator_output
                elif len(generator_output) == 3:
                    # generator-provided weights are replaced below
                    x, y, _ = generator_output
                else:
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y, sample_weight)` '
                                     'or `(x, y)`. Found: ' +
                                     str(generator_output))
                # build batch logs
                batch_logs = {}
                if x is None or len(x) == 0:
                    # Handle data tensors support when no input given
                    # step-size = 1 for data tensors
                    batch_size = 1
                elif isinstance(x, list):
                    batch_size = x[0].shape[0]
                elif isinstance(x, dict):
                    batch_size = list(x.values())[0].shape[0]
                else:
                    batch_size = x.shape[0]
                batch_logs['batch'] = batch_index
                batch_logs['size'] = batch_size
                callbacks.on_batch_begin(batch_index, batch_logs)

                # remark on label shape: last (fourth) dimension contains in 0 the true labels, in 1 the corresponding sceneinstid (millioncode)
                t_start = time()

                # set sample weights
                if params['nosceneinstweights']:
                    sample_weight = None
                else:
                    sample_weight = heiner_calculate_sample_weights_batch(
                        y[:, :, 0, 1], generator.length_dict,
                        generator.scene_instance_ids_dict, 'train')

                # run forward and backward pass and do the gradient descent step
                batch_loss, y_pred_logits, gradient_norm = heiner_train_and_predict_on_batch(
                    model,
                    x,
                    y[:, :, :, 0],
                    sample_weight=sample_weight,
                    calc_global_gradient_norm=not params['nocalcgradientnorm'])
                runtime_train_and_predict_on_batch = time() - t_start
                if batch_index >= skip_runtime_avg:
                    runtime_train_and_predict_on_batch_cumulated += runtime_train_and_predict_on_batch

                batch_logs['loss'] = batch_loss

                model.gradient_norm = gradient_norm

                t_start = time()
                # from logits to predicted class probabilities
                y_pred_probs = sigmoid(y_pred_logits,
                                       out=y_pred_logits)  # last arg: inplace
                # from probabilities to hard class decisions
                y_pred = np.greater_equal(
                    y_pred_probs, params['outputthreshold'],
                    out=y_pred_probs)  # last arg: inplace

                # increment metrics for scene instances in batch
                if multithreading_metrics:
                    # the following two arrays need to be unchanged in order for being thread-safe
                    # assumption 1: batchloader yields array copies (true for moritz loader)
                    # assumption 2: *_and_predict_on_batch return newly allocated arrays
                    label_queue.put((y_pred, y))
                else:
                    heiner_calculate_class_accuracies_metrics_per_scene_instance_in_batch(
                        model.scene_instance_id_metrics_dict_train, y_pred, y,
                        params['mask_value'])
                runtime_class_accuracies = time() - t_start
                if batch_index >= skip_runtime_avg:
                    runtime_class_accuracies_cumulated += runtime_class_accuracies

                callbacks.on_batch_end(batch_index, batch_logs)

                batch_index += 1
                steps_done += 1

                if steps_done > skip_runtime_avg and steps_done == steps_per_epoch - 1:
                    print(
                        ' --> batch {} we have average runtimes: generator {:.2f}, train_predict {:.2f}, metrics {:.2f}'
                        .format(
                            batch_index, runtime_generator_cumulated /
                            (steps_done - skip_runtime_avg),
                            runtime_train_and_predict_on_batch_cumulated /
                            (steps_done - skip_runtime_avg),
                            runtime_class_accuracies_cumulated /
                            (steps_done - skip_runtime_avg)))

                # Epoch finished.
                if steps_done >= steps_per_epoch and do_validation:
                    if val_gen:
                        val_outs = evaluate_and_predict_generator_with_sceneinst_metrics(
                            model,
                            val_enqueuer_gen,
                            params,
                            multithreading_metrics,
                            validation_steps,
                            workers=0,
                            verbose=1)
                    else:
                        # No need for try/except because
                        # data has already been validated.
                        val_outs = model.evaluate(
                            val_x,
                            val_y,
                            batch_size=batch_size,
                            sample_weight=val_sample_weights,
                            verbose=0)
                    val_outs = to_list(val_outs)
                    # Same labels assumed.
                    for l, o in zip(out_labels, val_outs):
                        epoch_logs['val_' + l] = o

                if callback_model.stop_training:
                    # NOTE(review): breaking mid-epoch queues fewer than
                    # steps_per_epoch items for the metrics thread; whether
                    # the join below can then block depends on
                    # metrics_per_batch_thread_handler -- confirm.
                    break

            if multithreading_metrics:
                trainmetrics_thread.join()
                print(
                    ' --> both threads for calculating the batch metrics -- training and validation -- finished all their work'
                )

            callbacks.on_epoch_end(epoch, epoch_logs)
            epoch += 1
            if callback_model.stop_training:
                break

    finally:
        try:
            if enqueuer is not None:
                enqueuer.stop()
        finally:
            if val_enqueuer is not None:
                val_enqueuer.stop()

        if trainmetrics_thread is not None:
            trainmetrics_thread.join()  # joined again (harmless)

    callbacks.on_train_end()
    return model.history
Exemplo n.º 8
0
    def train_srgan(self,
        epochs, batch_size,
        dataname,
        datapath_train,
        datapath_validation=None,
        steps_per_validation=10,
        datapath_test=None,
        workers=40, max_queue_size=100,
        first_epoch=0,
        print_frequency=2,
        crops_per_image=2,
        log_weight_frequency=1000,
        log_weight_path='./data/weights/',
        log_tensorboard_path='./data/logs/',
        log_tensorboard_name='SRGAN',
        log_tensorboard_update_freq=500,
        log_test_frequency=500,
        log_test_path="./images/samples/",
        ):
        """Train the SRGAN generator together with the relativistic discriminator.

        Each "epoch" of this loop processes a single batch pulled from an
        ``OrderedEnqueuer`` over the training ``DataLoader``.  First
        ``self.RaGAN`` is trained on ``[imgs_hr, generated_hr]``, then
        ``self.srgan`` on ``[imgs_lr, imgs_hr]``.

        Args:
            epochs: number of iterations (batches) to run.
            batch_size: batch size passed to the ``DataLoader``.
            dataname: name joined onto ``log_weight_path`` when saving weights.
            datapath_train: training image path for the ``DataLoader``.
            datapath_validation: if given, a validation ``DataLoader`` is
                built (the loop below does not use it).
            steps_per_validation: currently unused.
            datapath_test: if given, test images are plotted every
                ``log_test_frequency`` iterations.
            workers: number of enqueuer workers.
            max_queue_size: maximum number of prefetched batches.
            first_epoch: iteration index to start counting from.
            print_frequency: print averaged losses every this many iterations.
            crops_per_image: passed to the ``DataLoader``.
            log_weight_frequency: save weights every this many iterations
                (falsy disables saving).
            log_weight_path: directory for saved weights.
            log_tensorboard_path: currently unused.
            log_tensorboard_name: currently unused.
            log_tensorboard_update_freq: currently unused.
            log_test_frequency: how often to plot test images.
            log_test_path: directory for plotted test images.
        """

        # Create train data loader
        loader = DataLoader(
            datapath_train, batch_size,
            self.height_hr, self.width_hr,
            self.upscaling_factor,
            crops_per_image
        )

        # Validation data loader (built when a path is given; not used below)
        if datapath_validation is not None:
            validation_loader = DataLoader(
                datapath_validation, batch_size,
                self.height_hr, self.width_hr,
                self.upscaling_factor,
                crops_per_image
            )
        print("Picture Loaders has been ready.")
        # Use several workers on CPU for preparing batches
        enqueuer = OrderedEnqueuer(
            loader,
            use_multiprocessing=False,
            shuffle=True
        )
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        try:
            output_generator = enqueuer.get()
            print("Data Enqueuer has been ready.")

            print_losses = {"G": [], "D": []}
            # Start of the current reporting interval; restarted after each
            # report so every printed time covers one interval.
            start_epoch = datetime.datetime.now()

            # Loop through epochs / iterations
            for epoch in range(first_epoch, epochs + first_epoch):
                report = epoch % print_frequency == 0

                # Train Relativistic Discriminator on real vs. generated images
                # (the plain SRGAN discriminator losses are intentionally not used)
                imgs_lr, imgs_hr = next(output_generator)
                generated_hr = self.generator.predict(imgs_lr)
                discriminator_loss = self.RaGAN.train_on_batch([imgs_hr, generated_hr], None)

                # Train generator
                generator_loss = self.srgan.train_on_batch([imgs_lr, imgs_hr], None)

                # Save losses
                print_losses['G'].append(generator_loss)
                print_losses['D'].append(discriminator_loss)

                # Show the progress
                if report:
                    g_avg_loss = np.array(print_losses['G']).mean(axis=0)
                    d_avg_loss = np.array(print_losses['D']).mean(axis=0)
                    print(self.srgan.metrics_names, g_avg_loss)
                    print(self.RaGAN.metrics_names, d_avg_loss)
                    print("\nEpoch {}/{} | Time: {}s\n>> Generator/GAN: {}\n>> Discriminator: {}".format(
                        epoch, epochs + first_epoch,
                        (datetime.datetime.now() - start_epoch).seconds,
                        ", ".join(["{}={:.4f}".format(k, v) for k, v in zip(self.srgan.metrics_names, g_avg_loss)]),
                        ", ".join(["{}={:.4f}".format(k, v) for k, v in zip(self.RaGAN.metrics_names, d_avg_loss)])
                    ))
                    print_losses = {"G": [], "D": []}

                # If test images are supplied, run model on them and save to log_test_path
                if datapath_test and epoch % log_test_frequency == 0:
                    print(">> Ploting test images")
                    plot_test_images(self, loader, datapath_test, log_test_path, epoch, refer_model=self.refer_model)

                # Check if we should save the network weights
                if log_weight_frequency and epoch % log_weight_frequency == 0:
                    # Save the network weights
                    print(">> Saving the network weights")
                    self.save_weights(os.path.join(log_weight_path, dataname), epoch)

                # Restart the interval timer after a report.  (Resetting on
                # ``epoch % print_frequency == 1`` never fired for
                # print_frequency=1, so times accumulated across reports.)
                if report:
                    start_epoch = datetime.datetime.now()
        finally:
            # Stop the enqueuer's worker threads even if training is interrupted.
            enqueuer.stop()
Exemplo n.º 9
0
    def evaluate_generator(self, generator, steps=None,
                           max_queue_size=10,
                           workers=1,
                           use_multiprocessing=False):
        """Evaluates the model on a data generator.

        The generator should return the same kind of data
        as accepted by `test_on_batch`: `(x, y)` or `(x, y, analysis)`
        tuples; the optional third element is forwarded to
        `self.evaluate_on_batch` as `analysis`.
        For documentation, refer to keras.engine.training.evaluate_generator (https://keras.io/models/model/)

        # Arguments
            generator: a `keras.utils.Sequence` or a Python generator.
            steps: number of batches to draw. Optional for a `Sequence`,
                in which case `len(generator)` is used.
            max_queue_size: maximum size of the enqueuer queue.
            workers: number of enqueuer workers; 0 draws batches directly
                from `generator` on the calling thread.
            use_multiprocessing: passed through to the enqueuer.

        # Returns
            The batch-size-weighted average of the per-batch outputs of
            `self.evaluate_on_batch`: a scalar, or a list with one average
            per position when `evaluate_on_batch` returns a list.

        # Raises
            ValueError: if `steps` is None for a non-`Sequence` generator,
                if `steps` is not positive, or if the generator yields data
                in an invalid format or an empty batch.
        """

        steps_done = 0
        wait_time = 0.01
        all_outs = []
        batch_sizes = []
        is_sequence = isinstance(generator, Sequence)
        if not is_sequence and use_multiprocessing and workers > 1:
            warnings.warn(
                UserWarning('Using a generator with `use_multiprocessing=True`'
                            ' and multiple workers may duplicate your data.'
                            ' Please consider using the`keras.utils.Sequence'
                            ' class.'))
        if steps is None:
            if is_sequence:
                steps = len(generator)
            else:
                raise ValueError('`steps=None` is only valid for a generator'
                                 ' based on the `keras.utils.Sequence` class.'
                                 ' Please specify `steps` or use the'
                                 ' `keras.utils.Sequence` class.')
        if steps <= 0:
            # With no batch there is nothing to average, and `outs` below
            # would never be bound (previously an UnboundLocalError).
            raise ValueError('`steps` must be a positive number of batches,'
                             ' got {}.'.format(steps))
        enqueuer = None

        try:
            if workers > 0:
                if is_sequence:
                    enqueuer = OrderedEnqueuer(generator,
                                               use_multiprocessing=use_multiprocessing)
                else:
                    enqueuer = GeneratorEnqueuer(generator,
                                                 use_multiprocessing=use_multiprocessing,
                                                 wait_time=wait_time)
                enqueuer.start(workers=workers, max_queue_size=max_queue_size)
                output_generator = enqueuer.get()
            else:
                output_generator = generator

            while steps_done < steps:
                generator_output = next(output_generator)
                if not hasattr(generator_output, '__len__'):
                    raise ValueError('Output of generator should be a tuple '
                                     '(x, y, analysis) '
                                     'or (x, y). Found: ' +
                                     str(generator_output))
                if len(generator_output) == 2:
                    x, y = generator_output
                    analysis = None
                elif len(generator_output) == 3:
                    x, y, analysis = generator_output
                else:
                    raise ValueError('Output of generator should be a tuple '
                                     '(x, y, analysis) '
                                     'or (x, y). Found: ' +
                                     str(generator_output))
                outs = self.evaluate_on_batch(x, y, analysis=analysis, sample_weight=None)

                # The batch size weights this batch in the final average.
                if isinstance(x, list):
                    batch_size = x[0].shape[0]
                elif isinstance(x, dict):
                    batch_size = list(x.values())[0].shape[0]
                else:
                    batch_size = x.shape[0]
                if batch_size == 0:
                    raise ValueError('Received an empty batch. '
                                     'Batches should at least contain one item.')
                all_outs.append(outs)

                steps_done += 1
                batch_sizes.append(batch_size)

        finally:
            if enqueuer is not None:
                enqueuer.stop()

        if not isinstance(outs, list):
            return np.average(np.asarray(all_outs),
                              weights=batch_sizes)
        # One weighted average per output position.
        return [np.average([out[i] for out in all_outs], weights=batch_sizes)
                for i in range(len(outs))]
Exemplo n.º 10
0
    def __getitem__(self, item):
        indexes = [i + item * self.batch_size for i in range(self.batch_size)]
        a, la = self.generate_data(indexes)
        return a, la


if __name__ == "__main__":

    # Visual smoke test: stream every batch of the training generator
    # through an OrderedEnqueuer and show each image with its label
    # vector and arg-max class, waiting for a key press per image.
    dataset = DataGenerator(
        "/home/redivan/datasets/dog_breeds/images",
        img_shape=(512, 512),
        uniq_classes="/home/redivan/datasets/dog_breeds/images/model_thr30.list",
        batch_size=16,
    )
    batch_queue = OrderedEnqueuer(dataset)
    batch_queue.start(workers=1, max_queue_size=4)
    batches = batch_queue.get()

    n_batches = len(dataset)
    try:
        for _ in range(n_batches):
            current = next(batches)
            images, labels = current[0], current[1]
            for image, label in zip(images, labels):
                print(image.shape)
                cv2.imshow("win", image)
                print(label)
                print(np.argmax(label))
                cv2.waitKey(0)
    finally:
        # Always shut the enqueuer's worker down, even on error/Ctrl-C.
        batch_queue.stop()
Exemplo n.º 11
0
def evaluate_generator(model,
                       generator,
                       steps=None,
                       batch_size=1,
                       margin=0.5,
                       N_diff=5,
                       max_queue_size=10,
                       workers=1,
                       use_multiprocessing=False):
    """Evaluates a triplet model on a data generator.

    For every generator output, the model is run once per negative index
    `ii_ndiff` in `range(N_diff)`: the inputs are padded to a common length
    with `get_maximum_length` / `make_same_length_batch`, predicted with
    `model.predict_on_batch`, and scored with `triplet_loss_no_mean`. The
    per-sample losses form a `(batch_size, N_diff)` matrix; the batch score
    is the mean over samples of the largest loss across the `N_diff`
    negatives.

    # Arguments
        model: model whose `predict_on_batch` accepts
            `[input_anchor, input_same, input_diff]`.
        generator: Generator or an instance of Sequence
            (keras.utils.Sequence). Each output must be a sized object of
            length `batch_size`; it is passed unchanged to
            `get_maximum_length` and `make_same_length_batch`.
        steps: Total number of steps (batches of samples)
            to yield from `generator` before stopping.
            Optional for `Sequence`: if unspecified, will use
            the `len(generator)` as a number of steps.
        batch_size: expected length of every generator output; also the
            number of rows of the per-batch loss matrix.
        margin: margin forwarded to `triplet_loss_no_mean`.
        N_diff: number of negative indices evaluated per sample (>= 1).
        max_queue_size: maximum size for the generator queue
        workers: Integer. Maximum number of processes to spin up
            when using process based threading.
            If unspecified, `workers` will default to 1. If 0, will
            execute the generator on the main thread.
        use_multiprocessing: if True, use process based threading.
            Note that because this implementation relies on
            multiprocessing, you should not pass non picklable arguments
            to the generator as they can't be passed easily to children
            processes.
    # Returns
        Scalar: the average of the per-batch scores. Every batch has the
        same weight (`batch_size`), so this is their plain mean.
    # Raises
        ValueError: if `steps` is None for a non-Sequence generator, if
            `steps` or `N_diff` is not positive, or if the generator yields
            data in an invalid format.
    """
    steps_done = 0
    wait_time = 0.01
    all_outs = []
    batch_sizes = []
    is_sequence = isinstance(generator, Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        warnings.warn(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the`keras.utils.Sequence'
                        ' class.'))
    if steps is None:
        if is_sequence:
            steps = len(generator)
        else:
            raise ValueError('`steps=None` is only valid for a generator'
                             ' based on the `keras.utils.Sequence` class.'
                             ' Please specify `steps` or use the'
                             ' `keras.utils.Sequence` class.')
    if steps <= 0:
        # Nothing to average; previously this fell through to an
        # UnboundLocalError on `outs`.
        raise ValueError('`steps` must be a positive number of batches,'
                         ' got {}.'.format(steps))
    if N_diff < 1:
        raise ValueError('`N_diff` must be at least 1, got {}.'.format(N_diff))
    enqueuer = None

    try:
        if workers > 0:
            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    generator, use_multiprocessing=use_multiprocessing)
            else:
                enqueuer = GeneratorEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    wait_time=wait_time)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        else:
            output_generator = generator

        while steps_done < steps:
            generator_output = next(output_generator)
            if not hasattr(generator_output, '__len__'):
                raise ValueError('Output of generator should be a tuple '
                                 '(x, y, z, ii_ndiff). Found: ' +
                                 str(generator_output))
            if len(generator_output) != batch_size:
                raise ValueError('Output of generator should have length '
                                 'batch_size={}, got length {}. Found: '
                                 .format(batch_size, len(generator_output)) +
                                 str(generator_output))

            # loss_mat[s, j]: triplet loss of sample s against negative j.
            loss_mat = np.zeros((batch_size, N_diff))
            for ii_ndiff in range(N_diff):
                # get the maximum sequence length
                len_anchor_max, len_same_max, len_diff_max = \
                    get_maximum_length(batch_size=batch_size,
                                       generator_output=generator_output,
                                       index=[ii_ndiff] * batch_size)

                # organize the input for the prediction
                input_anchor, input_same, input_diff = \
                    make_same_length_batch(batch_size=batch_size,
                                           len_anchor_max=len_anchor_max,
                                           len_same_max=len_same_max,
                                           len_diff_max=len_diff_max,
                                           generator_output=generator_output,
                                           index=[ii_ndiff] * batch_size)

                output_batch_pred = model.predict_on_batch(
                    [input_anchor, input_same, input_diff])

                # NOTE(review): if `triplet_loss_no_mean` builds backend ops,
                # calling it + K.eval per batch grows the TF1 graph over time;
                # consider compiling it once with K.function — confirm.
                loss = K.eval(triplet_loss_no_mean(output_batch_pred, margin))
                loss_mat[:, ii_ndiff] = loss

            # Largest loss over the N_diff negatives, averaged over samples.
            all_outs.append(np.mean(np.max(loss_mat, axis=-1)))

            steps_done += 1
            batch_sizes.append(batch_size)

    finally:
        if enqueuer is not None:
            enqueuer.stop()

    # Each batch score is a scalar, so a single weighted average suffices.
    return np.average(np.asarray(all_outs), weights=batch_sizes)
Exemplo n.º 12
0
def evaluate_generator_autosized(model,
                                 generator,
                                 steps=None,
                                 callbacks=None,
                                 max_queue_size=10,
                                 workers=1,
                                 use_multiprocessing=False,
                                 verbose=0):
    """Evaluate `model` on `generator` without knowing the step count.

    Variant of Keras `Model.evaluate_generator`: instead of drawing a fixed
    number of batches it iterates until the underlying iterator stops or
    yields a `None`/empty output, and also returns how many batches were
    consumed. See the docstring of `Model.evaluate_generator` for the
    shared arguments.

    NOTE(review): with `workers > 0` the enqueuer's `get()` may cycle over
    the data indefinitely; termination then relies on the generator
    yielding an empty output — confirm for the generators used.

    # Arguments
        steps: only used as the progress-bar target and as the `steps`
            callback parameter; it does not bound the iteration.
        callbacks: optional list of Keras callbacks notified per batch and
            around a single pseudo-epoch.
        verbose: 1 shows a progress bar with `val_`-prefixed metrics.

    # Returns
        `(averages, steps_done)`. `averages` is a scalar when the model has
        a single metric name, otherwise a list aligned with
        `model.metrics_names`; non-stateful metrics are averaged with the
        batch sizes as weights, stateful ones take their last value.

    # Raises
        ValueError: if the generator yields data in an invalid format or
            an empty batch, or yields no batch at all.
    """
    model._make_test_function()

    stateful_metric_indices = []
    if hasattr(model, 'metrics'):
        for m in model.stateful_metric_functions:
            m.reset_states()
        stateful_metric_indices = [
            i for i, name in enumerate(model.metrics_names)
            if str(name) in model.stateful_metric_names
        ]

    callbacks = cbks.CallbackList(callbacks or [])

    # it's possible to callback a different model than self:
    if hasattr(model, 'callback_model') and model.callback_model:
        callback_model = model.callback_model
    else:
        callback_model = model
    callbacks.set_model(callback_model)
    callbacks.set_params({
        'epochs': 1,
        'steps': steps,  # if None, will be refined during first epoch
        'verbose': verbose,
        'do_validation': False,
        'metrics': model.metrics_names,
    })

    steps_done = 0
    wait_time = 0.01
    outs_per_batch = []
    batch_sizes = []
    is_sequence = isinstance(generator, Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        warnings.warn(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the`keras.utils.Sequence'
                        ' class.'))
    enqueuer = None

    try:
        if workers > 0:
            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    generator, use_multiprocessing=use_multiprocessing)
            else:
                enqueuer = GeneratorEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    wait_time=wait_time)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        else:
            if is_sequence:
                output_generator = iter(generator)
            else:
                output_generator = generator

        if verbose == 1:
            progbar = Progbar(target=steps)
        callbacks.on_epoch_begin(0)

        for generator_output in output_generator:
            # A None/empty output marks the end of the data. Check length
            # instead of truthiness so an array output cannot raise
            # "truth value of an array is ambiguous".
            if generator_output is None or (
                    hasattr(generator_output, '__len__')
                    and len(generator_output) == 0):
                break
            if not hasattr(generator_output, '__len__'):
                raise ValueError('Output of generator should be a tuple '
                                 '(x, y, sample_weight) '
                                 'or (x, y). Found: ' + str(generator_output))
            if len(generator_output) == 2:
                x, y = generator_output
                sample_weight = None
            elif len(generator_output) == 3:
                x, y, sample_weight = generator_output
            else:
                raise ValueError('Output of generator should be a tuple '
                                 '(x, y, sample_weight) '
                                 'or (x, y). Found: ' + str(generator_output))
            # build batch logs
            batch_logs = {}
            # `not x` would raise for a numpy array with more than one
            # element; mirror Keras and test None/length explicitly.
            if x is None or len(x) == 0:
                # Handle data tensors support when no input given
                # step-size = 1 for data tensors
                batch_size = 1
            elif isinstance(x, list):
                batch_size = x[0].shape[0]
            elif isinstance(x, dict):
                batch_size = list(x.values())[0].shape[0]
            else:
                batch_size = x.shape[0]
            if batch_size == 0:
                raise ValueError('Received an empty batch. '
                                 'Batches should contain '
                                 'at least one item.')
            batch_logs['batch'] = steps_done
            batch_logs['size'] = batch_size
            callbacks.on_batch_begin(steps_done, batch_logs)

            outs = model.test_on_batch(x, y, sample_weight=sample_weight)
            if not isinstance(outs, list):
                outs = [outs]
            for l, o in zip(model.metrics_names, outs):
                batch_logs[l] = o
            outs_per_batch.append(outs)

            callbacks.on_batch_end(steps_done, batch_logs)

            steps_done += 1
            batch_sizes.append(batch_size)
            if verbose == 1:
                log_values = []
                for k in model.metrics_names:
                    if k in batch_logs:
                        log_values.append(('val_' + k, batch_logs[k]))
                progbar.update(steps_done, log_values)

        callbacks.on_epoch_end(1, {})

    finally:
        if enqueuer is not None:
            enqueuer.stop()

    if steps_done == 0:
        # Averaging zero batches would fail with a confusing
        # ZeroDivisionError/IndexError below.
        raise ValueError('The generator did not yield any batch to '
                         'evaluate.')

    averages = []
    for i in range(len(model.metrics_names)):
        if i not in stateful_metric_indices:
            averages.append(
                np.average([out[i] for out in outs_per_batch],
                           weights=batch_sizes))
        else:
            # Stateful metrics already accumulate across batches.
            averages.append(float(outs_per_batch[-1][i]))
    if len(averages) == 1:
        return averages[0], steps_done
    return averages, steps_done
Exemplo n.º 13
0
def main(dataset, batch_size, patch_size, epochs, label_smoothing,
         label_flipping):
    """Train the generator/discriminator GAN on `dataset`.

    Builds train/validation `data_utils.DataGenerator`s for `dataset`,
    feeds them through multiprocessing `OrderedEnqueuer`s and trains for
    `epochs` epochs. Each batch performs one discriminator update
    (`models.DCGAN_discriminator`) followed by one generator update through
    the combined `models.DCGAN` model (L1 loss on the generated image plus
    binary cross-entropy on the discriminator output). Per-batch losses are
    sent to TensorBoard under `reports/logs/<timestamp>`, sample images are
    written every `save_images_every_n_batches` batches, and the weights of
    all three models are saved under `models/<timestamp>` on the last epoch
    (and every `save_model_every_n_epochs` epochs when that is >= 1).

    Ctrl-C ends training early. On every exit path the enqueuers that were
    started are stopped and, once it was attached to the model, the
    TensorBoard callback is closed.

    # Arguments
        dataset: passed as `file_path` to `data_utils.DataGenerator`.
        batch_size: batch size of every data generator.
        patch_size: discriminator patch size (forwarded to
            `data_utils.get_nb_patch`, `models.DCGAN` and
            `data_utils.get_disc_batch`).
        epochs: number of training epochs.
        label_smoothing: forwarded to `data_utils.get_disc_batch`.
        label_flipping: forwarded to `data_utils.get_disc_batch`.
    """
    print(project_dir)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # dynamically grow the memory used on the GPU
    sess = tf.Session(config=config)
    K.tensorflow_backend.set_session(
        sess)  # set this TensorFlow session as the default session for Keras

    image_data_format = "channels_first"
    K.set_image_data_format(image_data_format)

    save_images_every_n_batches = 30
    save_model_every_n_epochs = 0

    # configuration parameters
    print("Config params:")
    print("  dataset = {}".format(dataset))
    print("  batch_size = {}".format(batch_size))
    print("  patch_size = {}".format(patch_size))
    print("  epochs = {}".format(epochs))
    print("  label_smoothing = {}".format(label_smoothing))
    print("  label_flipping = {}".format(label_flipping))
    print("  save_images_every_n_batches = {}".format(
        save_images_every_n_batches))
    print("  save_model_every_n_epochs = {}".format(save_model_every_n_epochs))

    model_name = datetime.strftime(datetime.now(), '%y%m%d-%H%M')
    model_dir = os.path.join(project_dir, "models", model_name)
    fig_dir = os.path.join(project_dir, "reports", "figures")
    logs_dir = os.path.join(project_dir, "reports", "logs", model_name)

    os.makedirs(model_dir)

    # Load and rescale data
    ds_train_gen = data_utils.DataGenerator(file_path=dataset,
                                            dataset_type="train",
                                            batch_size=batch_size)
    ds_train_disc = data_utils.DataGenerator(file_path=dataset,
                                             dataset_type="train",
                                             batch_size=batch_size)
    ds_val = data_utils.DataGenerator(file_path=dataset,
                                      dataset_type="val",
                                      batch_size=batch_size)
    enq_train_gen = OrderedEnqueuer(ds_train_gen,
                                    use_multiprocessing=True,
                                    shuffle=True)
    enq_train_disc = OrderedEnqueuer(ds_train_disc,
                                     use_multiprocessing=True,
                                     shuffle=True)
    enq_val = OrderedEnqueuer(ds_val, use_multiprocessing=True, shuffle=False)

    img_dim = ds_train_gen[0][0].shape[-3:]

    n_batch_per_epoch = len(ds_train_gen)
    epoch_size = n_batch_per_epoch * batch_size

    print("Derived params:")
    print("  n_batch_per_epoch = {}".format(n_batch_per_epoch))
    print("  epoch_size = {}".format(epoch_size))
    print("  n_batches_val = {}".format(len(ds_val)))

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size)

    tensorboard = TensorBoard(log_dir=logs_dir,
                              histogram_freq=0,
                              batch_size=batch_size,
                              write_graph=True,
                              write_grads=True,
                              update_freq='batch')

    # Only enqueuers that were actually started are stopped on exit; the
    # TensorBoard writer only exists after `set_model`.
    started_enqueuers = []
    tensorboard_ready = False

    try:
        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3,
                                 beta_1=0.9,
                                 beta_2=0.999,
                                 epsilon=1e-08)

        # Load generator model
        generator_model = models.generator_unet_upsampling(img_dim)
        generator_model.summary()
        plot_model(generator_model,
                   to_file=os.path.join(fig_dir, "generator_model.png"),
                   show_shapes=True,
                   show_layer_names=True)

        # Load discriminator model
        # TODO: modify disc to accept real input as well
        discriminator_model = models.DCGAN_discriminator(
            img_dim_disc, nb_patch)
        discriminator_model.summary()
        plot_model(discriminator_model,
                   to_file=os.path.join(fig_dir, "discriminator_model.png"),
                   show_shapes=True,
                   show_layer_names=True)

        # TODO: pretty sure this is unnecessary
        generator_model.compile(loss='mae', optimizer=opt_discriminator)
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   img_dim, patch_size, image_data_format)

        # L1 loss applies to generated image, cross entropy applies to predicted label
        loss = [models.l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss,
                            loss_weights=loss_weights,
                            optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy',
                                    optimizer=opt_discriminator)

        tensorboard.set_model(DCGAN_model)
        tensorboard_ready = True

        # Start training
        for enqueuer in (enq_train_gen, enq_train_disc, enq_val):
            enqueuer.start(workers=1, max_queue_size=20)
            started_enqueuers.append(enqueuer)
        out_train_gen = enq_train_gen.get()
        out_train_disc = enq_train_disc.get()
        out_val = enq_val.get()

        # Defined up front so on_epoch_end works even for an empty epoch.
        logs = {}

        print("Start training")
        for e in range(1, epochs + 1):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            start = time.time()

            for batch_counter in range(1, n_batch_per_epoch + 1):
                X_transformed_batch, X_orig_batch = next(out_train_disc)

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(
                    X_transformed_batch,
                    X_orig_batch,
                    generator_model,
                    batch_counter,
                    patch_size,
                    label_smoothing=label_smoothing,
                    label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(out_train_gen)
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                # Set labels to 1 (real) to maximize the discriminator loss
                y_gen[:, 1] = 1

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen,
                                                      [X_gen_target, y_gen])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                metrics = [("D logloss", disc_loss), ("G tot", gen_loss[0]),
                           ("G L1", gen_loss[1]), ("G logloss", gen_loss[2])]
                progbar.add(batch_size, values=metrics)

                logs = {k: v for (k, v) in metrics}
                logs["size"] = batch_size

                tensorboard.on_batch_end(batch_counter, logs=logs)

                # Save images for visualization
                if batch_counter % save_images_every_n_batches == 0:
                    # Get new images from validation
                    data_utils.plot_generated_batch(
                        X_transformed_batch, X_orig_batch, generator_model,
                        os.path.join(logs_dir, "current_batch_training.png"))
                    X_transformed_batch, X_orig_batch = next(out_val)
                    data_utils.plot_generated_batch(
                        X_transformed_batch, X_orig_batch, generator_model,
                        os.path.join(logs_dir, "current_batch_validation.png"))

            print("")
            print('Epoch %s/%s, Time: %s' % (e, epochs, time.time() - start))
            tensorboard.on_epoch_end(e, logs=logs)

            if (save_model_every_n_epochs >= 1 and e % save_model_every_n_epochs == 0) or \
                    (e == epochs):
                print("Saving model for epoch {}...".format(e), end="")
                sys.stdout.flush()
                gen_weights_path = os.path.join(
                    model_dir, 'gen_weights_epoch{:03d}.h5'.format(e))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join(
                    model_dir, 'disc_weights_epoch{:03d}.h5'.format(e))
                discriminator_model.save_weights(disc_weights_path,
                                                 overwrite=True)

                DCGAN_weights_path = os.path.join(
                    model_dir, 'DCGAN_weights_epoch{:03d}.h5'.format(e))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)
                print("done")

    except KeyboardInterrupt:
        # Ctrl-C deliberately ends training early.
        pass

    finally:
        # Stop worker processes on every exit path, not only on success or
        # Ctrl-C, so an exception does not leave them running.
        for enqueuer in started_enqueuers:
            enqueuer.stop()
        if tensorboard_ready:
            # Closes the summary writer so pending events are flushed.
            tensorboard.on_train_end(None)
Exemplo n.º 14
0
class EvalLrTest(Callback):
    """Learning-rate range test callback with periodic validation.

    Starting at `lr_min`, the optimizer learning rate is multiplied by a
    constant factor after every batch. The factor is chosen so that the
    learning rate used on batch `steps - 1` (0-based) is `lr_max`, up to
    floating-point rounding. Training is stopped once the next learning
    rate would exceed `lr_max`.

    Every `val_period` batches, and always on the final batch of the
    sweep, the model is evaluated on `val_data`. One CSV row is then
    written to `filename`, containing `batch_no`, the batch logs, `lr`,
    `epoch` and the `val_*` metrics.

    # Arguments
        filename: path of the CSV file; it is (re)created at train begin.
        val_data: validation data; presumably a `keras.utils.Sequence`
            (it is wrapped in an `OrderedEnqueuer` and `len()` is taken).
            The enqueuer is started here in the constructor and stopped in
            `on_train_end`.
        lr_min: initial learning rate (must be non-zero).
        lr_max: final learning rate of the sweep.
        steps: number of batches in the sweep; must be greater than 1.
        val_period: run validation every this many batches.
        separator: CSV delimiter.

    # Raises
        ValueError: if `steps <= 1`.
    """
    def __init__(self,
                 filename,
                 val_data,
                 lr_min=1e-6,
                 lr_max=1,
                 steps=1500,
                 val_period=1,
                 separator=','):
        # Explicit check instead of `assert`, which is stripped under -O.
        if steps <= 1:
            raise ValueError('`steps` must be greater than 1, got {}.'
                             .format(steps))
        self.filename = filename
        self.sep = separator
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr = lr_min
        self.val_period = val_period
        self.epoch = 0
        # Per-batch factor such that lr_min * lr_increment**(steps - 1) == lr_max.
        self.lr_increment = pow((lr_max / lr_min), 1. / (steps - 1))
        # On Python 3 the csv module requires files opened with newline='',
        # otherwise rows get blank lines between them on Windows.
        self._open_args = {} if six.PY2 else {'newline': ''}
        self.batch_no = 0
        self.keys = None
        self.writer = None
        self.append_header = True
        workers = 5  # TODO make a parameter
        max_queue_size = 10  # TODO make a parameter
        self.val_enqueuer = OrderedEnqueuer(val_data,
                                            use_multiprocessing=False)
        self.val_enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        self.val_enqueuer_gen = self.val_enqueuer.get()
        self.validation_steps = len(val_data)
        super(EvalLrTest, self).__init__()

    def on_train_begin(self, logs=None):
        self.csv_file = io.open(self.filename, 'w', **self._open_args)
        self.lr = self.lr_min

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch

    def on_batch_begin(self, batch, logs=None):
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        K.set_value(self.model.optimizer.lr, self.lr)
        if batch == 0:  # epoch start
            print(
                '\nEpoch %05d Batch %05d: EvalLrTest setting learning rate to %s.'
                % (self.epoch, self.batch_no, self.lr))

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        logs['lr'] = self.lr
        logs['epoch'] = self.epoch

        # The LR is advanced only after this check, so `self.lr > lr_max`
        # alone can never flag the last batch; look at the next LR instead
        # so the final step of the sweep is always validated and logged.
        next_lr = self.lr * self.lr_increment
        is_last_batch = next_lr > self.lr_max

        if (self.batch_no % self.val_period == 0
                or self.lr > self.lr_max or is_last_batch):
            self._validate_and_log(logs)

        self.lr = next_lr
        self.batch_no += 1
        if self.lr > self.lr_max:
            self.model.stop_training = True

    def on_train_end(self, logs=None):
        self.val_enqueuer.stop()
        self.csv_file.close()
        self.writer = None

    @staticmethod
    def _handle_value(k):
        """Format one log value for the CSV: iterables become a quoted list."""
        is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
        if isinstance(k, six.string_types):
            return k
        elif isinstance(k, Iterable) and not is_zero_dim_ndarray:
            return '"[%s]"' % (', '.join(map(str, k)))
        else:
            return k

    def _validate_and_log(self, logs):
        """Evaluate on the validation enqueuer, add `val_*` to `logs`, write a row."""
        val_outs = self.model.evaluate_generator(self.val_enqueuer_gen,
                                                 self.validation_steps,
                                                 workers=0)
        val_outs = to_list(val_outs)
        # Same labels assumed.
        for l, o in zip(self.model.metrics_names, val_outs):
            logs['val_' + l] = o

        # The column set is fixed by the first row written.
        if self.keys is None:
            self.keys = sorted(logs.keys())

        if self.model.stop_training:
            # We set NA so that csv parsers do not fail for this last epoch.
            logs = dict([(k, logs[k]) if k in logs else (k, 'NA')
                         for k in self.keys])

        if not self.writer:

            class CustomDialect(csv.excel):
                delimiter = self.sep

            fieldnames = ['batch_no'] + self.keys
            if six.PY2:
                fieldnames = [unicode(x) for x in fieldnames]
            self.writer = csv.DictWriter(self.csv_file,
                                         fieldnames=fieldnames,
                                         dialect=CustomDialect)
            if self.append_header:
                self.writer.writeheader()

        row_dict = OrderedDict({'batch_no': self.batch_no})
        row_dict.update(
            (key, self._handle_value(logs[key])) for key in self.keys)
        self.writer.writerow(row_dict)
        self.csv_file.flush()
Exemplo n.º 15
0
def evaluate_gen(model,
                 generator,
                 steps=None,
                 callbacks=None,
                 max_queue_size=10,
                 workers=1,
                 use_multiprocessing=False,
                 verbose=0):
    """Compute the top-3 categorical accuracy of ``model`` over ``generator``.

    Each batch is scored with
    ``metrics.top_k_categorical_accuracy(y, y_pred, k=3)``. The per-batch
    scores are combined into one mean weighted by batch size, so the result
    covers all ``steps`` batches rather than only the last one.

    Args:
        model: Compiled Keras model; ``predict_on_batch`` produces the
            predictions that are scored.
        generator: A ``keras.utils.Sequence`` or Python generator yielding
            ``(x, y)`` or ``(x, y, sample_weight)`` tuples. ``sample_weight``
            is accepted but ignored.
        steps: Number of batches to evaluate; defaults to ``len(generator)``.
        callbacks: Unused; kept for signature compatibility.
        max_queue_size: Maximum queue size of the background enqueuer.
        workers: Number of enqueuer workers; ``0`` reads the generator on the
            calling thread.
        use_multiprocessing: Passed to the enqueuer.
        verbose: ``1`` shows a ``Progbar``.

    Returns:
        The batch-size-weighted mean top-3 accuracy, as a ``float``.

    Raises:
        ValueError: If ``steps`` is not positive, a generator output is not a
            2- or 3-tuple, or an empty batch is received.
    """
    import numpy as np

    model._make_test_function()

    use_sequence_api = is_sequence(generator)
    if steps is None:
        steps = len(generator)
    if steps <= 0:
        # Without at least one batch there is nothing to average (this used
        # to surface as an UnboundLocalError on the return value).
        raise ValueError('`steps` must be positive to evaluate; got %r' %
                         (steps,))

    enqueuer = None
    weighted_score_sum = 0.0
    samples_seen = 0

    try:
        if workers > 0:
            if use_sequence_api:
                enqueuer = OrderedEnqueuer(
                    generator, use_multiprocessing=use_multiprocessing)
            else:
                enqueuer = GeneratorEnqueuer(
                    generator, use_multiprocessing=use_multiprocessing)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        else:
            if use_sequence_api:
                output_generator = iter_sequence_infinite(generator)
            else:
                output_generator = generator

        if verbose == 1:
            progbar = Progbar(target=steps)

        steps_done = 0
        while steps_done < steps:
            generator_output = next(output_generator)
            if not hasattr(generator_output, '__len__'):
                raise ValueError('Output of generator should be a tuple '
                                 '(x, y, sample_weight) '
                                 'or (x, y). Found: ' + str(generator_output))
            if len(generator_output) == 2:
                x, y = generator_output
            elif len(generator_output) == 3:
                x, y, _ = generator_output  # sample_weight is not used
            else:
                raise ValueError('Output of generator should be a tuple '
                                 '(x, y, sample_weight) '
                                 'or (x, y). Found: ' + str(generator_output))

            if x is None or len(x) == 0:
                batch_size = 1
            elif isinstance(x, list):
                batch_size = x[0].shape[0]
            elif isinstance(x, dict):
                batch_size = list(x.values())[0].shape[0]
            else:
                batch_size = x.shape[0]
            if batch_size == 0:
                raise ValueError('Received an empty batch. '
                                 'Batches should contain '
                                 'at least one item.')

            y_pred = model.predict_on_batch(x)
            # The metric may evaluate to a scalar batch mean or a per-sample
            # vector depending on the Keras version (assumption - verify);
            # np.mean reduces either form to the batch's mean score.
            # NOTE(review): building the metric op per batch grows the TF
            # graph; consider creating it once with placeholders.
            batch_score = np.mean(
                K.eval(metrics.top_k_categorical_accuracy(y, y_pred, k=3)))
            weighted_score_sum += float(batch_score) * batch_size
            samples_seen += batch_size

            steps_done += 1

            if verbose == 1:
                progbar.update(steps_done)

    finally:
        if enqueuer is not None:
            enqueuer.stop()
    return weighted_score_sum / samples_seen
Exemplo n.º 16
0
    def fit_generator_feed(self,
                           generator,
                           steps_per_epoch=None,
                           epochs=1,
                           verbose=1,
                           callbacks=None,
                           validation_data=None,
                           validation_steps=None,
                           class_weight=None,
                           max_queue_size=10,
                           workers=1,
                           use_multiprocessing=False,
                           shuffle=True,
                           initial_epoch=0,
                           check_array_lengths=True):
        """Train the model on data generated batch-by-batch by a Python generator
        or an instance of `Sequence`.

        See `Model.fit_generator()` for the full documentation.

        The only difference here is that the generator must also generate data for
        any native placeholders of the model. Each item must be a tuple
        `(x, y, feeds)` or `(x, y, feeds, sample_weight)`, and `feeds` is passed to
        `self.train_on_fed_batch()`.

        Only use this if you know what you are doing (especially with the `shuffle`
        and `check_array_lengths` parameters). If not, prefer `self.fit_fullbatches()`
        or `self.fit_minibatches()`.

        Validation is not supported yet: passing `validation_data` raises.

        Returns:
            `self.history`, the `keras.callbacks.History` recorded for this run.

        Raises:
            ValueError: If `validation_data` is given, if `steps_per_epoch` is
                None for a non-`Sequence` generator, or if a generator output
                is not a 3- or 4-tuple.
        """

        # Disable validation, as we haven't converted the code for this yet.
        # All related code is commented with a `disable:` prefix.
        if validation_data is not None:
            raise ValueError(
                'Validation with a feeding generator is not yet supported')
        # The original (feed-modified) method starts here.

        wait_time = 0.01  # in seconds
        epoch = initial_epoch

        # disable: do_validation = bool(validation_data)
        self._make_train_function()
        # disable: if do_validation:
        # disable:     self._make_test_function()

        is_sequence = isinstance(generator, Sequence)
        if not is_sequence and use_multiprocessing and workers > 1:
            warnings.warn(
                UserWarning('Using a generator with `use_multiprocessing=True`'
                            ' and multiple workers may duplicate your data.'
                            ' Please consider using the`keras.utils.Sequence'
                            ' class.'))
        if steps_per_epoch is None:
            if is_sequence:
                steps_per_epoch = len(generator)
            else:
                raise ValueError(
                    '`steps_per_epoch=None` is only valid for a'
                    ' generator based on the `keras.utils.Sequence`'
                    ' class. Please specify `steps_per_epoch` or use'
                    ' the `keras.utils.Sequence` class.')

        # disable: # python 2 has 'next', 3 has '__next__'
        # disable: # avoid any explicit version checks
        # disable: val_gen = (hasattr(validation_data, 'next') or
        # disable:            hasattr(validation_data, '__next__') or
        # disable:            isinstance(validation_data, Sequence))
        # disable: if (val_gen and not isinstance(validation_data, Sequence) and
        # disable:         not validation_steps):
        # disable:     raise ValueError('`validation_steps=None` is only valid for a'
        # disable:                      ' generator based on the `keras.utils.Sequence`'
        # disable:                      ' class. Please specify `validation_steps` or use'
        # disable:                      ' the `keras.utils.Sequence` class.')

        # Prepare display labels.
        out_labels = self.metrics_names
        callback_metrics = out_labels + ['val_' + n for n in out_labels]

        # prepare callbacks
        self.history = cbks.History()
        _callbacks = [
            cbks.BaseLogger(stateful_metrics=self.stateful_metric_names)
        ]
        if verbose:
            _callbacks.append(
                cbks.ProgbarLogger(
                    count_mode='steps',
                    stateful_metrics=self.stateful_metric_names))
        _callbacks += (callbacks or []) + [self.history]
        callbacks = cbks.CallbackList(_callbacks)

        # it's possible to callback a different model than self:
        if hasattr(self, 'callback_model') and self.callback_model:
            callback_model = self.callback_model
        else:
            callback_model = self
        callbacks.set_model(callback_model)
        callbacks.set_params({
            'epochs': epochs,
            'steps': steps_per_epoch,
            'verbose': verbose,
            # disable: 'do_validation': do_validation,
            'metrics': callback_metrics,
        })
        callbacks.on_train_begin()

        enqueuer = None
        # disable: val_enqueuer = None

        try:
            # disable: if do_validation and not val_gen:
            # disable:     # Prepare data for validation
            # disable:     if len(validation_data) == 2:
            # disable:         val_x, val_y = validation_data
            # disable:         val_sample_weight = None
            # disable:     elif len(validation_data) == 3:
            # disable:         val_x, val_y, val_sample_weight = validation_data
            # disable:     else:
            # disable:         raise ValueError('`validation_data` should be a tuple '
            # disable:                          '`(val_x, val_y, val_sample_weight)` '
            # disable:                          'or `(val_x, val_y)`. Found: ' +
            # disable:                          str(validation_data))
            # disable:     val_x, val_y, val_sample_weights = self._standardize_user_data(
            # disable:         val_x, val_y, val_sample_weight)
            # disable:     val_data = val_x + val_y + val_sample_weights
            # disable:     if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
            # disable:         val_data += [0.]
            # disable:     for cbk in callbacks:
            # disable:         cbk.validation_data = val_data

            if workers > 0:
                if is_sequence:
                    enqueuer = OrderedEnqueuer(
                        generator,
                        use_multiprocessing=use_multiprocessing,
                        shuffle=shuffle)
                else:
                    enqueuer = GeneratorEnqueuer(
                        generator,
                        use_multiprocessing=use_multiprocessing,
                        wait_time=wait_time)
                enqueuer.start(workers=workers, max_queue_size=max_queue_size)
                output_generator = enqueuer.get()
            else:
                if is_sequence:
                    # Cycle over the Sequence by index indefinitely. A plain
                    # `iter(generator)` may stop after one pass over the
                    # Sequence, which would exhaust the data in epoch 2.
                    def _cycle_sequence(seq):
                        while True:
                            for i in range(len(seq)):
                                yield seq[i]

                    output_generator = _cycle_sequence(generator)
                else:
                    output_generator = generator

            callback_model.stop_training = False
            # Construct epoch logs.
            epoch_logs = {}
            while epoch < epochs:
                # Reset stateful metric layers so their values are per-epoch.
                for m in self.metrics:
                    if isinstance(m, Layer) and m.stateful:
                        m.reset_states()
                callbacks.on_epoch_begin(epoch)
                steps_done = 0
                batch_index = 0
                while steps_done < steps_per_epoch:
                    generator_output = next(output_generator)

                    if not hasattr(generator_output, '__len__'):
                        raise ValueError(
                            'Output of generator should be '
                            'a tuple `(x, y, feeds, sample_weight)` '
                            'or `(x, y, feeds)`. Found: ' +
                            str(generator_output))

                    if len(generator_output) == 3:
                        x, y, feeds = generator_output
                        sample_weight = None
                    elif len(generator_output) == 4:
                        x, y, feeds, sample_weight = generator_output
                    else:
                        raise ValueError(
                            'Output of generator should be '
                            'a tuple `(x, y, feeds, sample_weight)` '
                            'or `(x, y, feeds)`. Found: ' +
                            str(generator_output))
                    # build batch logs
                    batch_logs = {}
                    if x is None or len(x) == 0:
                        # Handle data tensors support when no input given
                        # step-size = 1 for data tensors
                        batch_size = 1
                    elif isinstance(x, list):
                        batch_size = x[0].shape[0]
                    elif isinstance(x, dict):
                        batch_size = list(x.values())[0].shape[0]
                    else:
                        batch_size = x.shape[0]
                    batch_logs['batch'] = batch_index
                    batch_logs['size'] = batch_size
                    callbacks.on_batch_begin(batch_index, batch_logs)

                    outs = self.train_on_fed_batch(
                        x,
                        y,
                        feeds=feeds,
                        sample_weight=sample_weight,
                        class_weight=class_weight,
                        check_array_lengths=check_array_lengths)

                    if not isinstance(outs, list):
                        outs = [outs]
                    for l, o in zip(out_labels, outs):
                        batch_logs[l] = o

                    callbacks.on_batch_end(batch_index, batch_logs)

                    batch_index += 1
                    steps_done += 1

                    # Epoch finished.
                    # disable: if steps_done >= steps_per_epoch and do_validation:
                    # disable:     if val_gen:
                    # disable:         val_outs = self.evaluate_generator(
                    # disable:             validation_data,
                    # disable:             validation_steps,
                    # disable:             workers=workers,
                    # disable:             use_multiprocessing=use_multiprocessing,
                    # disable:             max_queue_size=max_queue_size)
                    # disable:     else:
                    # disable:         # No need for try/except because
                    # disable:         # data has already been validated.
                    # disable:         val_outs = self.evaluate(
                    # disable:             val_x, val_y,
                    # disable:             batch_size=batch_size,
                    # disable:             sample_weight=val_sample_weights,
                    # disable:             verbose=0)
                    # disable:     if not isinstance(val_outs, list):
                    # disable:         val_outs = [val_outs]
                    # disable:     # Same labels assumed.
                    # disable:     for l, o in zip(out_labels, val_outs):
                    # disable:         epoch_logs['val_' + l] = o

                    if callback_model.stop_training:
                        break

                callbacks.on_epoch_end(epoch, epoch_logs)
                epoch += 1
                if callback_model.stop_training:
                    break

        finally:
            if enqueuer is not None:
                enqueuer.stop()
            # disable: if val_enqueuer is not None:
            # disable:     val_enqueuer.stop()

        callbacks.on_train_end()
        return self.history
Exemplo n.º 17
0
    reg_energy = load_model(rege_path)
    reg_direction = load_model(altaz_path)

    print('Building test generator...')
    test_generator = DataGeneratorChain(h5files,
                                        batch_size=batch_size,
                                        arrival_time=time,
                                        shuffle=True)

    # retrieve ground truth
    print('Inference on data...')
    steps_done = 0
    steps = len(test_generator)
    # steps = 2

    enqueuer = OrderedEnqueuer(test_generator, use_multiprocessing=True)
    enqueuer.start(workers=4, max_queue_size=10)
    output_generator = enqueuer.get()

    progbar = Progbar(target=steps)

    table = np.array([]).reshape(0, 9)

    while steps_done < steps:
        generator_output = next(output_generator)
        x, y, intensity, energy, altaz = generator_output

        y_prd = classifier.predict_on_batch(x)
        e_reco = reg_energy.predict_on_batch(x)
        altaz_reco = reg_direction.predict_on_batch(x)
Exemplo n.º 18
0
def main(argv=None):
    """Train the multi-GPU EAST model configured by ``FLAGS``.

    Builds one model tower per entry in ``gpus``, averages the tower
    gradients, and runs ``FLAGS.epochs`` passes over an
    ``icdar.EAST_generator``. Checkpoints and summaries are written to
    ``FLAGS.checkpoint_path``. Training stops early if the total loss
    becomes NaN.

    NOTE(review): ``FLAGS``, ``gpus``, ``tower_loss``, ``average_gradients``,
    ``icdar``, ``slim`` and ``iter_sequence_infinite`` are defined outside this
    function; ``gpus`` is presumably a list of GPU ids matching
    ``FLAGS.gpu_list`` - verify.

    Args:
        argv: Unused (presumably supplied by the ``tf.app.run`` entry point).
    """
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    print('gpu id', FLAGS.gpu_list)
    if not tf.gfile.Exists(FLAGS.checkpoint_path):
        tf.gfile.MkDir(FLAGS.checkpoint_path)
    else:
        if not FLAGS.restore:
            tf.gfile.DeleteRecursively(FLAGS.checkpoint_path)
            tf.gfile.MkDir(FLAGS.checkpoint_path)

    input_images = tf.placeholder(tf.float32,
                                  shape=[None, None, None, 3],
                                  name='input_images')
    input_score_maps = tf.placeholder(tf.float32,
                                      shape=[None, None, None, 1],
                                      name='input_score_maps')
    # RBOX geometry uses 5 channels; otherwise 8 channels are used.
    if FLAGS.geometry == 'RBOX':
        input_geo_maps = tf.placeholder(tf.float32,
                                        shape=[None, None, None, 5],
                                        name='input_geo_maps')
    else:
        input_geo_maps = tf.placeholder(tf.float32,
                                        shape=[None, None, None, 8],
                                        name='input_geo_maps')
    input_training_masks = tf.placeholder(tf.float32,
                                          shape=[None, None, None, 1],
                                          name='input_training_masks')

    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                               global_step,
                                               decay_steps=10000,
                                               decay_rate=0.94,
                                               staircase=True)
    # add summary
    tf.summary.scalar('learning_rate', learning_rate)
    opt = tf.train.AdamOptimizer(learning_rate)
    # opt = tf.train.MomentumOptimizer(learning_rate, 0.9)

    # Split every input along the batch axis, one slice per GPU tower.
    input_images_split = tf.split(input_images, len(gpus))
    input_score_maps_split = tf.split(input_score_maps, len(gpus))
    input_geo_maps_split = tf.split(input_geo_maps, len(gpus))
    input_training_masks_split = tf.split(input_training_masks, len(gpus))

    tower_grads = []
    reuse_variables = None
    for i, gpu_id in enumerate(gpus):
        with tf.device('/gpu:%d' % gpu_id):
            with tf.name_scope('model_%d' % gpu_id) as scope:
                iis = input_images_split[i]
                isms = input_score_maps_split[i]
                igms = input_geo_maps_split[i]
                itms = input_training_masks_split[i]
                total_loss, model_loss = tower_loss(iis, isms, igms, itms,
                                                    reuse_variables)
                # Only the last tower's batch-norm update ops survive the loop.
                batch_norm_updates_op = tf.group(
                    *tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
                # Later towers share the variables created by the first.
                reuse_variables = True

                grads = opt.compute_gradients(total_loss)
                tower_grads.append(grads)

    grads = average_gradients(tower_grads)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    summary_op = tf.summary.merge_all()
    # save moving average
    variable_averages = tf.train.ExponentialMovingAverage(
        FLAGS.moving_average_decay, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    # batch norm updates
    with tf.control_dependencies(
        [variables_averages_op, apply_gradient_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')

    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_path,
                                           tf.get_default_graph())

    init = tf.global_variables_initializer()

    if FLAGS.pretrained_model_path is not None:
        variable_restore_op = slim.assign_from_checkpoint_fn(
            FLAGS.pretrained_model_path,
            slim.get_trainable_variables(),
            ignore_missing_vars=True)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        if FLAGS.restore:
            print('continue training from previous checkpoint')
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
            saver.restore(sess, ckpt)
        else:
            sess.run(init)
            if FLAGS.pretrained_model_path is not None:
                variable_restore_op(sess)

        print('FLAGS.train_txt_dir', FLAGS.train_txt_dir)
        generator = icdar.EAST_generator(
            data_path=FLAGS.images_dir,
            txt_dir=FLAGS.txt_dir,
            batch_size=FLAGS.batch_size,
        )

        enqueuer = None
        if FLAGS.workers > 0:
            # Load data in background workers (optionally multiprocessing).
            enqueuer = OrderedEnqueuer(
                generator,
                use_multiprocessing=FLAGS.use_multiprocessing,
                shuffle=True)
            enqueuer.start(workers=FLAGS.workers,
                           max_queue_size=FLAGS.max_queue_size)
            output_generator = enqueuer.get()
            print('workers ', FLAGS.workers)
        else:
            output_generator = iter_sequence_infinite(generator)

        # At least 1, so `step % step_print` cannot divide by zero when
        # save_checkpoint_steps < 10.
        step_print = max(1, FLAGS.save_checkpoint_steps // 10)
        # os.path.join keeps the checkpoint inside the directory even when
        # checkpoint_path has no trailing separator.
        checkpoint_prefix = os.path.join(FLAGS.checkpoint_path, 'model.ckpt')
        num_gpus = len(gpus)
        diverged = False
        try:
            start = time.time()
            for epoch in range(FLAGS.epochs):
                for step in range(len(generator)):
                    data = next(output_generator)
                    # Drop trailing samples so the batch splits evenly across
                    # the GPU towers (tf.split needs equal-sized parts).
                    usable = len(data[0]) // num_gpus * num_gpus
                    if usable != len(data[0]):
                        data = [item[:usable] for item in data]
                    feed_dict = {
                        input_images: data[0],
                        input_score_maps: data[1],
                        input_geo_maps: data[2],
                        input_training_masks: data[3]
                    }
                    ml, tl, _ = sess.run([model_loss, total_loss, train_op],
                                         feed_dict=feed_dict)
                    if np.isnan(tl):
                        print('Loss diverged, stop training')
                        diverged = True
                        break

                    if step % step_print == 0:
                        avg_time_per_step = (time.time() - start) / step_print
                        avg_examples_per_second = (
                            step_print * FLAGS.batch_size) / (time.time() -
                                                              start)
                        start = time.time()
                        print(
                            'Step {:06d}, model loss {:.4f}, total loss {:.4f}, {:.2f} seconds/step, {:.2f} examples/second'
                            .format(step, ml, tl, avg_time_per_step,
                                    avg_examples_per_second))

                    if step % FLAGS.save_checkpoint_steps == 0:
                        saver.save(sess,
                                   checkpoint_prefix,
                                   global_step=global_step)

                    if step % FLAGS.save_summary_steps == 0:
                        # NOTE(review): fetching `train_op` here applies a
                        # second optimizer update on the same batch; kept to
                        # preserve the existing training schedule.
                        _, tl, summary_str = sess.run(
                            [train_op, total_loss, summary_op],
                            feed_dict=feed_dict)
                        # Log against the real global step: `step` restarts
                        # every epoch and would overwrite earlier summaries.
                        summary_writer.add_summary(
                            summary_str,
                            global_step=tf.train.global_step(
                                sess, global_step))
                if diverged:
                    # Stop all epochs, not just the current one.
                    break
                generator.on_epoch_end()
        finally:
            # Shut down loader workers even if training raises.
            if enqueuer is not None:
                enqueuer.stop()
Exemplo n.º 19
0
def create_parallel_queue(data_seq, workers=4, max_queue_size=8, shuffle=True):
    """Start a multiprocessing ``OrderedEnqueuer`` over ``data_seq``.

    Args:
        data_seq: The ``keras.utils.Sequence`` to load batches from.
        workers: Number of workers started by the enqueuer (default 4).
        max_queue_size: Maximum size of the enqueuer's queue (default 8).
        shuffle: Passed to ``OrderedEnqueuer`` as ``shuffle`` (default True).

    Returns:
        The enqueuer's output generator (``enqueuer.get()``).

    NOTE(review): the enqueuer object itself is not returned, so callers
    cannot call ``stop()`` on it.
    """
    enqueuer = OrderedEnqueuer(data_seq,
                               use_multiprocessing=True,
                               shuffle=shuffle)
    enqueuer.start(workers=workers, max_queue_size=max_queue_size)
    return enqueuer.get()
Exemplo n.º 20
0
def fit_generator_autosized(
        model,
        generator,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_data=None,
        validation_steps=None,
        validation_callbacks=None,
        class_weight=None,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False,
        shuffle=True,
        initial_epoch=0):
    """Like ``Model.fit_generator``, but learns the epoch length while training.

    See the docstring of ``Model.fit_generator`` for the parameters shared
    with Keras. Differences:

    * There is no ``steps_per_epoch``. An epoch ends when the training
      generator yields a falsy value (for example ``None``) or when
      ``stop_training`` is set. The number of steps counted in the first
      epoch is then pushed to the callbacks via ``set_params``.
    * During the first epoch a plain ``Progbar`` with an unknown target is
      shown. From the second epoch on, a ``ProgbarLogger`` is used instead.
    * Generator-based validation goes through ``evaluate_generator_autosized``.
      The validation step count it returns is reused in later epochs.

    NOTE(review): when ``generator`` is fed through an enqueuer, the falsy
    end-of-epoch item must come from the generator/Sequence itself. Otherwise
    the first epoch never ends.

    Args:
        model: Compiled Keras model to train.
        generator: ``keras.utils.Sequence`` or Python generator yielding
            ``(x, y)`` / ``(x, y, sample_weight)`` tuples, and a falsy value
            at the end of each epoch.
        validation_callbacks: Forwarded to ``evaluate_generator_autosized``
            as ``callbacks``.
        Other arguments behave as in ``Model.fit_generator``.

    Returns:
        ``model.history``, the ``keras.callbacks.History`` of this run.

    Raises:
        ValueError: On malformed generator output or ``validation_data``.
    """
    wait_time = 0.01  # in seconds
    epoch = initial_epoch

    do_validation = bool(validation_data)
    model._make_train_function()
    if do_validation:
        model._make_test_function()

    is_sequence = isinstance(generator, Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        warnings.warn(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the`keras.utils.Sequence'
                        ' class.'))
    # Unlike Keras, neither `steps_per_epoch` nor `validation_steps` is
    # required: both lengths are discovered while iterating.

    # python 2 has 'next', 3 has '__next__'
    # avoid any explicit version checks
    val_gen = (hasattr(validation_data, 'next')
               or hasattr(validation_data, '__next__')
               or isinstance(validation_data, Sequence))

    # Prepare display labels.
    out_labels = model.metrics_names
    callback_metrics = out_labels + ['val_' + n for n in out_labels]

    # prepare callbacks
    model.history = cbks.History()
    _callbacks = [
        cbks.BaseLogger(stateful_metrics=model.stateful_metric_names)
    ]
    # instead of ProgbarLogger (but only for first epoch):
    if verbose:
        print('Epoch 1/%d' % epochs)
        progbar = Progbar(target=None,
                          verbose=1,
                          stateful_metrics=model.stateful_metric_names)
    _callbacks += (callbacks or []) + [model.history]
    callbacks = cbks.CallbackList(_callbacks)

    # it's possible to callback a different model than self:
    if hasattr(model, 'callback_model') and model.callback_model:
        callback_model = model.callback_model
    else:
        callback_model = model
    callbacks.set_model(callback_model)
    callbacks.set_params({
        'epochs': epochs,
        'steps': None,  # will be refined during first epoch
        'verbose': verbose,
        'do_validation': do_validation,
        'metrics': callback_metrics,
    })
    callbacks.on_train_begin()

    enqueuer = None
    val_enqueuer = None

    try:
        if do_validation and not val_gen:
            # Prepare data for validation
            if len(validation_data) == 2:
                val_x, val_y = validation_data
                val_sample_weight = None
            elif len(validation_data) == 3:
                val_x, val_y, val_sample_weight = validation_data
            else:
                raise ValueError('`validation_data` should be a tuple '
                                 '`(val_x, val_y, val_sample_weight)` '
                                 'or `(val_x, val_y)`. Found: ' +
                                 str(validation_data))
            val_x, val_y, val_sample_weights = model._standardize_user_data(
                val_x, val_y, val_sample_weight)
            val_data = val_x + val_y + val_sample_weights
            if model.uses_learning_phase and not isinstance(
                    K.learning_phase(), int):
                val_data += [0.]
            for cbk in callbacks:
                cbk.validation_data = val_data

        if workers > 0:
            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    shuffle=shuffle)
            else:
                enqueuer = GeneratorEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    wait_time=wait_time)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        else:
            if is_sequence:
                output_generator = iter(generator)
            else:
                output_generator = generator

        callback_model.stop_training = False
        # Construct epoch logs.
        epoch_logs = {}
        # Last seen training batch size, reused for array-based validation.
        # It stays None (Keras' default) if no batch has been seen yet.
        batch_size = None
        while epoch < epochs:
            for m in model.stateful_metric_functions:
                m.reset_states()
            callbacks.on_epoch_begin(epoch)
            steps_done = 0
            batch_index = 0
            # Bound up front so the first-epoch progress bar below works even
            # if the generator ends the epoch before yielding a batch.
            batch_logs = {}
            for generator_output in output_generator:
                if not generator_output:  # end-of-epoch sentinel
                    break
                if not hasattr(generator_output, '__len__'):
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y, sample_weight)` '
                                     'or `(x, y)`. Found: ' +
                                     str(generator_output))

                if len(generator_output) == 2:
                    x, y = generator_output
                    sample_weight = None
                elif len(generator_output) == 3:
                    x, y, sample_weight = generator_output
                else:
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y, sample_weight)` '
                                     'or `(x, y)`. Found: ' +
                                     str(generator_output))
                # build batch logs
                batch_logs = {}
                # Explicit None/len check: `not x` raises for numpy arrays
                # with more than one element.
                if x is None or len(x) == 0:
                    # Handle data tensors support when no input given
                    # step-size = 1 for data tensors
                    batch_size = 1
                elif isinstance(x, list):
                    batch_size = x[0].shape[0]
                elif isinstance(x, dict):
                    batch_size = list(x.values())[0].shape[0]
                else:
                    batch_size = x.shape[0]
                batch_logs['batch'] = batch_index
                batch_logs['size'] = batch_size
                callbacks.on_batch_begin(batch_index, batch_logs)

                outs = model.train_on_batch(x,
                                            y,
                                            sample_weight=sample_weight,
                                            class_weight=class_weight)

                if not isinstance(outs, list):
                    outs = [outs]
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o

                callbacks.on_batch_end(batch_index, batch_logs)

                batch_index += 1
                steps_done += 1

                # First epoch: drive the manual progress bar with the count
                # of completed steps.
                if epoch == initial_epoch and verbose:
                    log_values = [(k, batch_logs[k])
                                  for k in callback_metrics
                                  if k in batch_logs]
                    progbar.update(steps_done, log_values)

                if callback_model.stop_training:
                    break

            if epoch == initial_epoch:
                if verbose:
                    log_values = [(k, batch_logs[k])
                                  for k in callback_metrics
                                  if k in batch_logs]
                    progbar.update(steps_done, log_values)

            # Epoch finished.
            if do_validation:
                if val_gen:
                    val_outs, validation_steps = evaluate_generator_autosized(
                        model,
                        validation_data,
                        steps=validation_steps,
                        callbacks=validation_callbacks,
                        workers=workers,
                        use_multiprocessing=use_multiprocessing,
                        max_queue_size=max_queue_size,
                        verbose=1)
                else:
                    # No need for try/except because
                    # data has already been validated.
                    val_outs = model.evaluate(val_x,
                                              val_y,
                                              batch_size=batch_size,
                                              sample_weight=val_sample_weights,
                                              verbose=0)
                if not isinstance(val_outs, list):
                    val_outs = [val_outs]
                # Same labels assumed.
                for l, o in zip(out_labels, val_outs):
                    epoch_logs['val_' + l] = o
            # `stop_training` is handled after on_epoch_end below, so
            # epoch-end callbacks run whether or not validation is enabled.

            callbacks.on_epoch_end(epoch, epoch_logs)
            if epoch == initial_epoch:
                # Switch from the manual Progbar to a regular ProgbarLogger
                # now that the epoch length is known.
                if verbose:
                    print()
                    progbar = cbks.ProgbarLogger(
                        count_mode='steps',
                        stateful_metrics=model.stateful_metric_names)
                    progbar.set_model(callback_model)
                    callbacks.append(progbar)
                callbacks.set_params({
                    'epochs': epochs,
                    'steps': steps_done,  # refine
                    'verbose': verbose,
                    'do_validation': do_validation,
                    'metrics': callback_metrics,
                })
                if verbose:
                    progbar.on_train_begin()

            epoch += 1
            if callback_model.stop_training:
                break

    finally:
        try:
            if enqueuer is not None:
                enqueuer.stop()
        finally:
            if val_enqueuer is not None:
                val_enqueuer.stop()

    callbacks.on_train_end()
    return model.history
Exemplo n.º 21
0
def custom_fit_generator(model,
                         generator,
                         steps_per_epoch=None,
                         epochs=1,
                         verbose=1,
                         callbacks=None,
                         validation_data=None,
                         validation_steps=None,
                         class_weight=None,
                         max_queue_size=10,
                         workers=1,
                         use_multiprocessing=False,
                         shuffle=True,
                         initial_epoch=0):
    """Train `model` batch-by-batch like Keras' `fit_generator`, reporting a fixed metric subset.

    This mirrors the Keras 2.x `fit_generator` training loop. The difference
    is that the metric names advertised to callbacks (and therefore shown by
    the progress bar) are a hard-coded subset: `loss`, `acc`, `case_loss`,
    `case_acc` and their `val_` counterparts.

    # Arguments
        model: compiled Keras model to train.
        generator: a `keras.utils.Sequence`, or a Python generator yielding
            `(x, y)` or `(x, y, sample_weight)` tuples.
        steps_per_epoch: batches per epoch. Defaults to `len(generator)` for
            a `Sequence` and is required otherwise.
        epochs: index of the final epoch (training runs from `initial_epoch`
            while `epoch < epochs`).
        verbose: if truthy, a `ProgbarLogger` callback is attached.
        callbacks: optional list of additional Keras callbacks.
        validation_data: a generator / `Sequence`, or a tuple
            `(val_x, val_y)` / `(val_x, val_y, val_sample_weight)`.
        validation_steps: batches to draw from a validation generator.
            Required unless `validation_data` is a `Sequence`.
        class_weight: forwarded to `model.train_on_batch`.
        max_queue_size, workers, use_multiprocessing: enqueuer settings.
            With `workers == 0` the generator is read on the calling thread.
        shuffle: forwarded to `OrderedEnqueuer` (only used when
            `generator` is a `Sequence` and `workers > 0`).
        initial_epoch: epoch index to start from.

    # Returns
        `model.history`, the `History` callback populated during training.

    # Raises
        ValueError: if a required step count is missing, or the generator
            or validation data has an unexpected structure.
    """
    wait_time = 0.01  # in seconds
    epoch = initial_epoch

    def _cycle_sequence(seq):
        """Yield seq[0] .. seq[len(seq) - 1] over and over; stop if `seq` is empty."""
        # iter() over a Sequence is a single pass, which would run dry after
        # the first epoch; cycling keeps every epoch supplied with batches.
        while len(seq) > 0:
            for i in range(len(seq)):
                yield seq[i]

    do_validation = bool(validation_data)
    model._make_train_function()
    if do_validation:
        model._make_test_function()

    is_sequence = isinstance(generator, Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        warnings.warn(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the`keras.utils.Sequence'
                        ' class.'))
    if steps_per_epoch is None:
        if is_sequence:
            steps_per_epoch = len(generator)
        else:
            raise ValueError('`steps_per_epoch=None` is only valid for a'
                             ' generator based on the `keras.utils.Sequence`'
                             ' class. Please specify `steps_per_epoch` or use'
                             ' the `keras.utils.Sequence` class.')

    # python 2 has 'next', 3 has '__next__'
    # avoid any explicit version checks
    val_gen = (hasattr(validation_data, 'next')
               or hasattr(validation_data, '__next__')
               or isinstance(validation_data, Sequence))
    if (val_gen and not isinstance(validation_data, Sequence)
            and not validation_steps):
        raise ValueError('`validation_steps=None` is only valid for a'
                         ' generator based on the `keras.utils.Sequence`'
                         ' class. Please specify `validation_steps` or use'
                         ' the `keras.utils.Sequence` class.')

    # Labels used to fill batch/epoch logs from the model outputs.
    out_labels = model.metrics_names
    # Only this fixed subset is advertised to callbacks (and so displayed).
    # NOTE(review): presumably the model is compiled with metrics that
    # produce exactly these names — confirm against the caller.
    callback_metrics = [
        'loss', 'acc', 'case_loss', 'case_acc', 'val_loss', 'val_acc',
        'val_case_loss', 'val_case_acc'
    ]
    # prepare callbacks
    model.history = cbks.History()
    _callbacks = [
        cbks.BaseLogger(stateful_metrics=model.stateful_metric_names)
    ]
    if verbose:
        _callbacks.append(
            cbks.ProgbarLogger(count_mode='steps',
                               stateful_metrics=model.stateful_metric_names))
    _callbacks += (callbacks or []) + [model.history]
    callbacks = cbks.CallbackList(_callbacks)

    # it's possible to callback a different model than model:
    if hasattr(model, 'callback_model') and model.callback_model:
        callback_model = model.callback_model
    else:
        callback_model = model
    callbacks.set_model(callback_model)
    callbacks.set_params({
        'epochs': epochs,
        'steps': steps_per_epoch,
        'verbose': verbose,
        'do_validation': do_validation,
        'metrics': callback_metrics,
    })
    callbacks.on_train_begin()

    enqueuer = None
    val_enqueuer = None

    try:
        if do_validation and not val_gen:
            # Prepare in-memory data for validation
            if len(validation_data) == 2:
                val_x, val_y = validation_data
                val_sample_weight = None
            elif len(validation_data) == 3:
                val_x, val_y, val_sample_weight = validation_data
            else:
                raise ValueError('`validation_data` should be a tuple '
                                 '`(val_x, val_y, val_sample_weight)` '
                                 'or `(val_x, val_y)`. Found: ' +
                                 str(validation_data))
            val_x, val_y, val_sample_weights = model._standardize_user_data(
                val_x, val_y, val_sample_weight)
            val_data = val_x + val_y + val_sample_weights
            if model.uses_learning_phase and not isinstance(
                    K.learning_phase(), int):
                val_data += [0.]
            for cbk in callbacks:
                cbk.validation_data = val_data

        if workers > 0:
            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    shuffle=shuffle)
            else:
                enqueuer = GeneratorEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    wait_time=wait_time)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        elif is_sequence:
            output_generator = _cycle_sequence(generator)
        else:
            output_generator = generator

        callback_model.stop_training = False
        # Epoch logs are shared across epochs; keys are overwritten each time.
        epoch_logs = {}
        while epoch < epochs:
            callbacks.on_epoch_begin(epoch)
            steps_done = 0
            batch_index = 0
            while steps_done < steps_per_epoch:
                generator_output = next(output_generator)

                if not hasattr(generator_output, '__len__'):
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y, sample_weight)` '
                                     'or `(x, y)`. Found: ' +
                                     str(generator_output))

                if len(generator_output) == 2:
                    x, y = generator_output
                    sample_weight = None
                elif len(generator_output) == 3:
                    x, y, sample_weight = generator_output
                else:
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y, sample_weight)` '
                                     'or `(x, y)`. Found: ' +
                                     str(generator_output))
                # build batch logs
                batch_logs = {}
                if isinstance(x, list):
                    batch_size = x[0].shape[0]
                elif isinstance(x, dict):
                    batch_size = list(x.values())[0].shape[0]
                else:
                    batch_size = x.shape[0]
                batch_logs['batch'] = batch_index
                batch_logs['size'] = batch_size
                callbacks.on_batch_begin(batch_index, batch_logs)

                outs = model.train_on_batch(x,
                                            y,
                                            sample_weight=sample_weight,
                                            class_weight=class_weight)

                if not isinstance(outs, list):
                    outs = [outs]
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o

                callbacks.on_batch_end(batch_index, batch_logs)

                batch_index += 1
                steps_done += 1

                # Epoch finished: run validation once per epoch.
                if steps_done >= steps_per_epoch and do_validation:
                    if val_gen:
                        val_outs = model.evaluate_generator(
                            validation_data,
                            validation_steps,
                            workers=workers,
                            use_multiprocessing=use_multiprocessing,
                            max_queue_size=max_queue_size)
                    else:
                        # No need for try/except because
                        # data has already been validated.
                        # `batch_size` is the size of the last training batch.
                        val_outs = model.evaluate(
                            val_x,
                            val_y,
                            batch_size=batch_size,
                            sample_weight=val_sample_weights,
                            verbose=0)
                    if not isinstance(val_outs, list):
                        val_outs = [val_outs]
                    # Same labels assumed.
                    for l, o in zip(out_labels, val_outs):
                        epoch_logs['val_' + l] = o

                if callback_model.stop_training:
                    break

            callbacks.on_epoch_end(epoch, epoch_logs)
            epoch += 1
            if callback_model.stop_training:
                break

    finally:
        try:
            if enqueuer is not None:
                enqueuer.stop()
        finally:
            if val_enqueuer is not None:
                val_enqueuer.stop()

    callbacks.on_train_end()
    return model.history
Exemplo n.º 22
0
def my_fit_generator(params,
                     generator,
                     ckpt_dir,
                     val_data_list=None,
                     learning_rate=1e-3,
                     lr_decay_step=1.0,
                     lr_decay_rate=1.0,
                     epochs=1,
                     max_queue_size=10,
                     workers=1,
                     use_multiprocessing=False,
                     shuffle=True,
                     save_step=1,
                     initial_epoch=0,
                     steps_per_epoch=None):
    """Train the `mm.flownet` network (TF1 graph mode) from a batch generator.

    Two graphs are built in turn:

    1. An inference ("test") graph with an open batch dimension. Its freshly
       initialised variables are saved to `<ckpt_dir>/modelTst`, with a
       separate `checkpointTst` state file.
    2. A training graph. It is restored through `restore_networks`, trained
       with Adam on `mm.unsupervised_loss`, and saved to
       `<ckpt_dir>/model-<epoch>` after every epoch.

    # Arguments
        params: dict of hyper-parameters. This function reads
            `params['batch_size']` and `params['display_interval']` and
            passes the dict on to `mm.unsupervised_loss`.
        generator: a `keras.utils.Sequence`, or any iterator when
            `steps_per_epoch` is given. Each item is split along its last
            axis: index 0 feeds `Iref`, index 1 feeds `Imov`. Both are
            placeholders of shape (batch_size, 64, 64, 64, 1).
        ckpt_dir: directory for checkpoints and TensorBoard summaries.
            It is created if missing.
        val_data_list: optional argument for `get_eval`. When given, a test
            loss is computed and logged at every display step.
        learning_rate: initial Adam learning rate.
        lr_decay_step, lr_decay_rate: at the start of each epoch where
            `(epoch + 1) % lr_decay_step == 0`, the learning rate is divided
            by `lr_decay_rate`.
        epochs: index of the final epoch (training runs while
            `epoch < epochs`).
        max_queue_size, workers, use_multiprocessing, shuffle: enqueuer
            settings. With `workers == 0` the generator is read on this thread.
        save_step: currently unused; a checkpoint is written every epoch.
        initial_epoch: epoch index to start from.
        steps_per_epoch: batches per epoch. Defaults to `len(generator)` and
            is required when `generator` is not a `Sequence`.

    # Raises
        ValueError: if `steps_per_epoch` is None and `generator` is not a
            `keras.utils.Sequence`.
    """
    wait_time = 0.01  # in seconds
    epoch = initial_epoch

    is_sequence = isinstance(generator, Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        warnings.warn(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the`keras.utils.Sequence'
                        ' class.'))

    if steps_per_epoch is None:
        if is_sequence:
            steps_per_epoch = len(generator)
        else:
            raise ValueError('`steps_per_epoch=None` is only valid for a'
                             ' generator based on the '
                             '`keras.utils.Sequence`'
                             ' class. Please specify `steps_per_epoch` '
                             'or use the `keras.utils.Sequence` class.')

    enqueuer = None

    try:
        if workers > 0:
            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    shuffle=shuffle)
            else:
                enqueuer = GeneratorEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    wait_time=wait_time)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        elif is_sequence:
            output_generator = iter_sequence_infinite(generator)
        else:
            output_generator = generator

        # --- Inference ("test") graph: open batch dimension, built from
        # `mm.flownet` with its default arguments.
        tf.reset_default_graph()
        input_shape = (None, 64, 64, 64, 1)
        Iref = tf.placeholder(tf.float32, shape=input_shape, name='Iref')
        Imov = tf.placeholder(tf.float32, shape=input_shape, name='Imov')

        out = mm.flownet(Iref, Imov)
        flowTst = tf.identity(tf.squeeze(out), name='flowTst')

        Iwarp = image_warp(Imov, out)
        Iwarp = tf.identity(tf.squeeze(Iwarp), name='Iwarp')

        if not os.path.isdir(ckpt_dir):
            os.makedirs(ckpt_dir)
        # os.path.join keeps the files inside ckpt_dir whether or not the
        # caller included a trailing separator; plain concatenation did not.
        sessFileNameTst = os.path.join(ckpt_dir, 'modelTst')
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            savedFile = saver.save(sess, sessFileNameTst,
                                   latest_filename='checkpointTst')
        print('testing model saved:' + savedFile)

        # --- Training graph: fixed batch size, training/augment enabled.
        tf.reset_default_graph()
        input_shape = (params['batch_size'],) + (64, 64, 64, 1)
        Iref = tf.placeholder(tf.float32, shape=input_shape, name='Iref')
        Imov = tf.placeholder(tf.float32, shape=input_shape, name='Imov')

        flows, out, Iref_out, Imov_out, border_mask = mm.flownet(
            Iref, Imov, training=True, augment=True)

        lr = tf.placeholder(tf.float32, name='learning_rate')
        loss = mm.unsupervised_loss(flows, Iref_out, Imov_out, border_mask,
                                    params)
        tf.summary.scalar('loss', loss)
        # Fetched together with the optimizer step (e.g. batch-norm updates).
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.name_scope('optimizer'):
            optimizer = tf.train.AdamOptimizer(learning_rate=lr)
            gvs = optimizer.compute_gradients(loss)
            opToRun = optimizer.apply_gradients(gvs)

        # Scalar summaries fed from Python-side loss values.
        lossT = tf.placeholder(tf.float32)
        lossE = tf.placeholder(tf.float32)
        lossSumT = tf.summary.scalar("TrnLoss", lossT)
        lossSumE = tf.summary.scalar("TestLoss", lossE)

        sessFileName = os.path.join(ckpt_dir, 'model')

        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True

        ep = 0  # summary step; incremented at every display step

        with tf.Session(config=sess_config) as sess:
            saver = restore_networks(sess, ckpt_dir, ckpt)
            writer = tf.summary.FileWriter(ckpt_dir, sess.graph)
            try:
                while epoch < epochs:
                    # Decay happens before the epoch trains. With
                    # lr_decay_step == 1 every epoch (the first included)
                    # divides the rate by lr_decay_rate.
                    if (epoch + 1) % lr_decay_step == 0:
                        learning_rate = learning_rate / lr_decay_rate

                    steps_done = 0
                    while steps_done < steps_per_epoch:
                        generator_output = next(output_generator)
                        # Last-axis index 0 -> Iref, index 1 -> Imov; each
                        # regains a trailing channel axis of size 1.
                        img_ref = generator_output[..., 0][..., np.newaxis]
                        img_mov = generator_output[..., 1][..., np.newaxis]
                        feed_dict = {lr: learning_rate, Iref: img_ref,
                                     Imov: img_mov}

                        _, _, trn_loss = sess.run(
                            [opToRun, update_ops, loss],
                            feed_dict=feed_dict)
                        steps_done += 1

                        total_steps = epoch * steps_per_epoch + steps_done
                        if (total_steps % params['display_interval'] == 0
                                or (epoch == 0 and steps_done == 1)):
                            ep = ep + 1
                            lossSum = sess.run(lossSumT,
                                               feed_dict={lossT: trn_loss})
                            writer.add_summary(lossSum, ep)

                            if val_data_list is not None:
                                eval_data = get_eval(val_data_list)
                                img_ref = eval_data[..., 0][..., np.newaxis]
                                img_mov = eval_data[..., 1][..., np.newaxis]
                                tst_loss = sess.run(
                                    loss,
                                    feed_dict={Iref: img_ref, Imov: img_mov})
                                writer.add_summary(
                                    sess.run(lossSumE,
                                             feed_dict={lossE: tst_loss}),
                                    ep)

                                print("-- train: epoch = {}, steps_done/steps per epoch = {}/{}, Train loss = {}, Test loss = {}"
                                      .format(epoch + 1, steps_done, steps_per_epoch, trn_loss, tst_loss))
                            else:
                                print(
                                    "-- train: epoch = {}, steps_done/steps per epoch = {}/{}, Train loss = {}"
                                    .format(epoch + 1, steps_done, steps_per_epoch, trn_loss))

                    # NOTE(review): `save_step` is not consulted; every epoch
                    # is checkpointed.
                    saver.save(sess, sessFileName, global_step=epoch,
                               write_meta_graph=True)
                    epoch += 1
            finally:
                # Flush and close the event file even if training fails.
                writer.close()

    finally:
        if enqueuer is not None:
            enqueuer.stop()
Exemplo n.º 23
0
    def evaluate_generator_custom(model,
                                  generator,
                                  steps=None,
                                  max_queue_size=10,
                                  class_weight=None,
                                  workers=1,
                                  use_multiprocessing=False,
                                  verbose=0):
        """Evaluate `model` on batches from `generator`, applying `class_weight`.

        This follows Keras' `evaluate_generator`, but scores each batch with
        `model.test_on_batch_custom(..., class_weight=class_weight)` so that
        validation examples can be re-weighted.

        # Arguments
            model: the Keras model (used in place of `self`).
            generator: a `keras.utils.Sequence`, or a generator yielding
                `(x, y)` or `(x, y, sample_weight)` tuples.
            steps: number of batches to evaluate. Defaults to
                `len(generator)` for a `Sequence` and is required otherwise.
            max_queue_size, workers, use_multiprocessing: enqueuer settings.
                With `workers == 0` the generator is read on the calling
                thread.
            class_weight: forwarded to `model.test_on_batch_custom`.
            verbose: 1 shows a `Progbar`; any other value is silent.

        # Returns
            Per-output averages over all batches, weighted by batch size.
            Stateful metrics instead report their value after the last
            batch. A single-element result is unwrapped by `unpack_singleton`.

        # Raises
            ValueError: if `steps` is missing for a plain generator, a batch
                has an invalid structure or is empty, or no batch was
                evaluated (e.g. `steps == 0`).
        """
        model._make_test_function()

        def _cycle_sequence(seq):
            """Yield seq[0] .. seq[len(seq) - 1] repeatedly; stop if `seq` is empty."""
            while len(seq) > 0:
                for i in range(len(seq)):
                    yield seq[i]

        stateful_metric_indices = []
        if hasattr(model, 'metrics'):
            # Stateful metrics accumulate across batches, so start clean.
            for m in model.stateful_metric_functions:
                m.reset_states()
            stateful_metric_indices = [
                i for i, name in enumerate(model.metrics_names)
                if str(name) in model.stateful_metric_names
            ]

        steps_done = 0
        wait_time = 0.01
        outs_per_batch = []
        batch_sizes = []
        is_sequence = isinstance(generator, Sequence)
        if not is_sequence and use_multiprocessing and workers > 1:
            warnings.warn(
                UserWarning('Using a generator with `use_multiprocessing=True`'
                            ' and multiple workers may duplicate your data.'
                            ' Please consider using the`keras.utils.Sequence'
                            ' class.'))
        if steps is None:
            if is_sequence:
                steps = len(generator)
            else:
                raise ValueError('`steps=None` is only valid for a generator'
                                 ' based on the `keras.utils.Sequence` class.'
                                 ' Please specify `steps` or use the'
                                 ' `keras.utils.Sequence` class.')
        enqueuer = None

        try:
            if workers > 0:
                if is_sequence:
                    enqueuer = OrderedEnqueuer(
                        generator, use_multiprocessing=use_multiprocessing)
                else:
                    enqueuer = GeneratorEnqueuer(
                        generator,
                        use_multiprocessing=use_multiprocessing,
                        wait_time=wait_time)
                enqueuer.start(workers=workers, max_queue_size=max_queue_size)
                output_generator = enqueuer.get()
            elif is_sequence:
                # Cycle like the enqueuer path so `steps > len(generator)`
                # wraps around instead of raising StopIteration.
                output_generator = _cycle_sequence(generator)
            else:
                output_generator = generator

            if verbose == 1:
                progbar = Progbar(target=steps)

            while steps_done < steps:
                generator_output = next(output_generator)
                if not hasattr(generator_output, '__len__'):
                    raise ValueError('Output of generator should be a tuple '
                                     '(x, y, sample_weight) '
                                     'or (x, y). Found: ' +
                                     str(generator_output))
                if len(generator_output) == 2:
                    x, y = generator_output
                    sample_weight = None
                elif len(generator_output) == 3:
                    x, y, sample_weight = generator_output
                else:
                    raise ValueError('Output of generator should be a tuple '
                                     '(x, y, sample_weight) '
                                     'or (x, y). Found: ' +
                                     str(generator_output))

                # To weight validation examples, testing must re-weight them,
                # hence the custom test method taking `class_weight`.
                outs = model.test_on_batch_custom(x,
                                                  y,
                                                  sample_weight=sample_weight,
                                                  class_weight=class_weight)

                outs = to_list(outs)
                outs_per_batch.append(outs)

                if x is None or len(x) == 0:
                    # Handle data tensors support when no input given
                    # step-size = 1 for data tensors
                    batch_size = 1
                elif isinstance(x, list):
                    batch_size = x[0].shape[0]
                elif isinstance(x, dict):
                    batch_size = list(x.values())[0].shape[0]
                else:
                    batch_size = x.shape[0]
                if batch_size == 0:
                    raise ValueError('Received an empty batch. '
                                     'Batches should contain '
                                     'at least one item.')
                steps_done += 1
                batch_sizes.append(batch_size)
                if verbose == 1:
                    progbar.update(steps_done)

        finally:
            if enqueuer is not None:
                enqueuer.stop()

        # Without this guard, `steps == 0` would hit an unbound `outs` below.
        if not outs_per_batch:
            raise ValueError('No batches were evaluated; `steps` must be '
                             'a positive integer.')

        averages = []
        for i in range(len(outs_per_batch[-1])):
            if i in stateful_metric_indices:
                # Stateful metrics already aggregate; keep the final value.
                averages.append(np.float64(outs_per_batch[-1][i]))
            else:
                averages.append(
                    np.average([batch_outs[i] for batch_outs in outs_per_batch],
                               weights=batch_sizes))
        return unpack_singleton(averages)
Exemplo n.º 24
0
def fit_generator_Ndiff(model,
                        generator,
                        steps_per_epoch=None,
                        batch_size=1,
                        N_diff=5,
                        margin=0.5,
                        epochs=1,
                        verbose=1,
                        callbacks=None,
                        validation_data=None,
                        validation_steps=None,
                        class_weight=None,
                        max_queue_size=10,
                        workers=1,
                        use_multiprocessing=False,
                        shuffle=True,
                        initial_epoch=0):
    """Trains the model on data yielded batch-by-batch by a Python generator.
    The generator is run in parallel to the model, for efficiency.
    For instance, this allows you to do real-time data augmentation
    on images on CPU in parallel to training your model on GPU.
    The use of `keras.utils.Sequence` guarantees the ordering
    and guarantees the single use of every input per epoch when
    using `use_multiprocessing=True`.
    # Arguments
        generator: A generator or an instance of `Sequence`
            (`keras.utils.Sequence`) object in order to avoid
            duplicate data when using multiprocessing.
            The output of the generator must be either
            - a tuple `(inputs, targets)`
            - a tuple `(inputs, targets, sample_weights)`.
            This tuple (a single output of the generator) makes a single
            batch. Therefore, all arrays in this tuple must have the same
            length (equal to the size of this batch). Different batches
            may have different sizes. For example, the last batch of the
            epoch is commonly smaller than the others, if the size of the
            dataset is not divisible by the batch size.
            The generator is expected to loop over its data
            indefinitely. An epoch finishes when `steps_per_epoch`
            batches have been seen by the model.
        steps_per_epoch: Integer.
            Total number of steps (batches of samples)
            to yield from `generator` before declaring one epoch
            finished and starting the next epoch. It should typically
            be equal to the number of samples of your dataset
            divided by the batch size.
            Optional for `Sequence`: if unspecified, will use
            the `len(generator)` as a number of steps.
        epochs: Integer. Number of epochs to train the model.
            An epoch is an iteration over the entire data provided,
            as defined by `steps_per_epoch`.
            Note that in conjunction with `initial_epoch`,
            `epochs` is to be understood as "final epoch".
            The model is not trained for a number of iterations
            given by `epochs`, but merely until the epoch
            of index `epochs` is reached.
        verbose: Integer. 0, 1, or 2. Verbosity mode.
            0 = silent, 1 = progress bar, 2 = one line per epoch.
        callbacks: List of `keras.callbacks.Callback` instances.
            List of callbacks to apply during training.
            See [callbacks](/callbacks).
        validation_data: This can be either
            - a generator for the validation data
            - tuple `(x_val, y_val)`
            - tuple `(x_val, y_val, val_sample_weights)`
            on which to evaluate
            the loss and any model metrics at the end of each epoch.
            The model will not be trained on this data.
        validation_steps: Only relevant if `validation_data`
            is a generator. Total number of steps (batches of samples)
            to yield from `validation_data` generator before stopping.
            Optional for `Sequence`: if unspecified, will use
            the `len(validation_data)` as a number of steps.
        class_weight: Optional dictionary mapping class indices (integers)
            to a weight (float) value, used for weighting the loss function
            (during training only).
            This can be useful to tell the model to
            "pay more attention" to samples from
            an under-represented class.
        max_queue_size: Integer. Maximum size for the generator queue.
            If unspecified, `max_queue_size` will default to 10.
        workers: Integer. Maximum number of processes to spin up
            when using process based threading.
            If unspecified, `workers` will default to 1. If 0, will
            execute the generator on the main thread.
        use_multiprocessing: Boolean. If True, use process based threading.
            If unspecified, `use_multiprocessing` will default to False.
            Note that because
            this implementation relies on multiprocessing,
            you should not pass
            non picklable arguments to the generator
            as they can't be passed
            easily to children processes.
        shuffle: Boolean. Whether to shuffle the training data
            in batch-sized chunks before each epoch.
            Only used with instances of `Sequence` (`keras.utils.Sequence`).
        initial_epoch: Integer.
            Epoch at which to start training
            (useful for resuming a previous training run).
    # Returns
        A `History` object. Its `History.history` attribute is
        a record of training loss values and metrics values
        at successive epochs, as well as validation loss values
        and validation metrics values (if applicable).
    # Example
    ```python
        def generate_arrays_from_file(path):
            while 1:
                with open(path) as f:
                    for line in f:
                        # create numpy arrays of input data
                        # and labels, from each line in the file
                        x1, x2, y = process_line(line)
                        yield ({'input_1': x1, 'input_2': x2}, {'output': y})
        model.fit_generator(generate_arrays_from_file('/my_file.txt'),
                            steps_per_epoch=10000, epochs=10)
    ```
    # Raises
        ValueError: In case the generator yields
            data in an invalid format.
    """
    wait_time = 0.01  # in seconds
    epoch = initial_epoch

    do_validation = bool(validation_data)
    # self._make_train_function()
    # if do_validation:
    #     self._make_test_function()

    is_sequence = isinstance(generator, Sequence)
    # do_validation = True if is_sequence else False

    if not is_sequence and use_multiprocessing and workers > 1:
        warnings.warn(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the`keras.utils.Sequence'
                        ' class.'))
    if steps_per_epoch is None:
        if is_sequence:
            steps_per_epoch = len(generator)
        else:
            raise ValueError('`steps_per_epoch=None` is only valid for a'
                             ' generator based on the `keras.utils.Sequence`'
                             ' class. Please specify `steps_per_epoch` or use'
                             ' the `keras.utils.Sequence` class.')

    # python 2 has 'next', 3 has '__next__'
    # avoid any explicit version checks
    val_gen = (hasattr(validation_data, 'next')
               or hasattr(validation_data, '__next__')
               or isinstance(validation_data, Sequence))
    if (val_gen and not isinstance(validation_data, Sequence)
            and not validation_steps):
        raise ValueError('`validation_steps=None` is only valid for a'
                         ' generator based on the `keras.utils.Sequence`'
                         ' class. Please specify `validation_steps` or use'
                         ' the `keras.utils.Sequence` class.')

    # Prepare display labels.
    out_labels = model._get_deduped_metrics_names()
    callback_metrics = out_labels + ['val_' + n for n in out_labels]

    # prepare callbacks
    history = cbks.History()
    callbacks = [cbks.BaseLogger()] + (callbacks or []) + [history]
    if verbose:
        callbacks += [cbks.ProgbarLogger(count_mode='steps')]
    callbacks = cbks.CallbackList(callbacks)

    # # it's possible to callback a different model than self:
    if hasattr(model, 'callback_model') and model.callback_model:
        callback_model = model.callback_model
    else:
        callback_model = model
    callbacks.set_model(callback_model)
    callbacks.set_params({
        'epochs': epochs,
        'steps': steps_per_epoch,
        'verbose': verbose,
        'do_validation': do_validation,
        'metrics': callback_metrics,
    })
    callbacks.on_train_begin()

    enqueuer = None
    val_enqueuer = None

    try:
        if do_validation:
            if val_gen:
                if workers > 0:
                    if isinstance(validation_data, Sequence):
                        val_enqueuer = OrderedEnqueuer(
                            validation_data,
                            use_multiprocessing=use_multiprocessing)
                        if validation_steps is None:
                            validation_steps = len(validation_data)
                    else:
                        val_enqueuer = GeneratorEnqueuer(
                            validation_data,
                            use_multiprocessing=use_multiprocessing,
                            wait_time=wait_time)
                    val_enqueuer.start(workers=workers,
                                       max_queue_size=max_queue_size)
                    validation_generator = val_enqueuer.get()
                else:
                    validation_generator = validation_data
            else:
                pass
                # if len(validation_data) == 2:
                #     val_x, val_y = validation_data
                #     val_sample_weights = None
                # elif len(validation_data) == 3:
                #     val_x, val_y, val_sample_weights = validation_data
                # else:
                #     raise ValueError('`validation_data` should be a tuple '
                #                      '`(val_x, val_y, val_sample_weight)` '
                #                      'or `(val_x, val_y)`. Found: ' +
                #                      str(validation_data))
                # val_x, val_y, val_sample_weights = _standardize_user_data(
                #     val_x, val_y, val_sample_weight)
                # val_data = val_x + val_y + val_sample_weights
                # if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
                #     val_data += [0.]
                # for cbk in callbacks:
                #     cbk.validation_data = val_data

        if workers > 0:
            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    shuffle=shuffle)
            else:
                enqueuer = GeneratorEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    wait_time=wait_time)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        else:
            output_generator = generator

        callback_model.stop_training = False
        # Construct epoch logs.
        epoch_logs = {}
        while epoch < epochs:
            callbacks.on_epoch_begin(epoch)
            steps_done = 0
            batch_index = 0
            while steps_done < steps_per_epoch:
                generator_output = next(output_generator)

                if not hasattr(generator_output, '__len__'):
                    raise ValueError('Output of generator should be '
                                     'batch_size lists ' +
                                     str(generator_output))

                if len(generator_output) == batch_size:
                    # ii_ndiff: the index of the negative sample
                    gen_out = generator_output
                    sample_weight = None
                else:
                    raise ValueError('Output of generator should be '
                                     'batch_size lists ' +
                                     str(generator_output))

                # build batch logs
                batch_logs = {}
                # if isinstance(x, list):
                #     batch_size = x[0].shape[0]
                # elif isinstance(x, dict):
                #     batch_size = list(x.values())[0].shape[0]
                # else:
                #     batch_size = x.shape[0]
                batch_logs['batch'] = batch_index
                batch_logs['size'] = batch_size
                callbacks.on_batch_begin(batch_index, batch_logs)

                # aggregate the losses by inner index n_diff
                loss_mat = np.zeros((batch_size, N_diff))
                for ii_ndiff in range(N_diff):

                    # get the maximum sequence length
                    len_anchor_max, len_same_max, len_diff_max = \
                        get_maximum_length(batch_size=batch_size,
                                           generator_output=gen_out,
                                           index=[ii_ndiff]*batch_size)

                    print(len_anchor_max, len_same_max, len_diff_max)
                    # organize the input for the prediction
                    input_anchor, input_same, input_diff = \
                        make_same_length_batch(batch_size=batch_size,
                                               len_anchor_max=len_anchor_max,
                                               len_same_max=len_same_max,
                                               len_diff_max=len_diff_max,
                                               generator_output=gen_out,
                                               index=[ii_ndiff]*batch_size)

                    output_batch_pred = model.predict_on_batch(
                        [input_anchor, input_same, input_diff])

                    loss = K.eval(
                        triplet_loss_no_mean(output_batch_pred, margin))
                    loss_mat[:, ii_ndiff] = loss

                # this the index of the input which has the maximum loss for each N_diff pairs
                index_max_loss = np.argmax(loss_mat, axis=-1)

                len_anchor_max, len_same_max, len_diff_max = get_maximum_length(
                    batch_size=batch_size,
                    generator_output=gen_out,
                    index=index_max_loss)

                input_anchor, input_same, input_diff = \
                    make_same_length_batch(batch_size=batch_size,
                                           len_anchor_max=len_anchor_max,
                                           len_same_max=len_same_max,
                                           len_diff_max=len_diff_max,
                                           generator_output=gen_out,
                                           index=index_max_loss)

                outs = model.train_on_batch(
                    [input_anchor, input_same, input_diff],
                    None,
                    sample_weight=sample_weight,
                    class_weight=class_weight)

                if not isinstance(outs, list):
                    outs = [outs]
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o

                callbacks.on_batch_end(batch_index, batch_logs)

                batch_index += 1
                steps_done += 1

                # Epoch finished.
                if steps_done >= steps_per_epoch and do_validation:
                    if val_gen:
                        val_outs = evaluate_generator(
                            model=model,
                            generator=validation_generator,
                            steps=validation_steps,
                            batch_size=batch_size,
                            margin=margin,
                            N_diff=N_diff,
                            workers=0)
                    else:
                        pass
                        # # No need for try/except because
                        # # data has already been validated.
                        # val_outs = model.evaluate(
                        #     val_x, val_y,
                        #     batch_size=batch_size,
                        #     sample_weight=val_sample_weights,
                        #     verbose=0)
                    if not isinstance(val_outs, list):
                        val_outs = [val_outs]
                    # Same labels assumed.
                    for l, o in zip(out_labels, val_outs):
                        epoch_logs['val_' + l] = o

                if callback_model.stop_training:
                    break

            callbacks.on_epoch_end(epoch, epoch_logs)
            epoch += 1
            if callback_model.stop_training:
                break

    finally:
        try:
            if enqueuer is not None:
                enqueuer.stop()
        finally:
            if val_enqueuer is not None:
                val_enqueuer.stop()

    callbacks.on_train_end()
    return history
Exemplo n.º 25
0
def multiple_models_generator(model1,
                              model2,
                              generator,
                              steps=None,
                              max_queue_size=10,
                              workers=1,
                              use_multiprocessing=False,
                              verbose=0,
                              kernels=None):
    """Compare two models' outputs channel by channel over a generator.

    Both models are run with ``predict_on_batch`` on the same inputs for every
    batch drawn from ``generator``.  For each selected channel index in
    ``kernels``, the mean of ``model1_output - model2_output`` is taken over
    every non-batch, non-channel axis (the two spatial axes for 4-D outputs).
    The batching / enqueuing logic mirrors Keras' ``Model.predict_generator``.

    # Arguments
        model1, model2: models whose outputs are indexed as
            ``out[sample, :, :, channel]`` (assumed channels-last feature maps,
            judging from the indexing below -- confirm against the models used).
        generator: a ``keras.utils.Sequence`` or a Python generator yielding
            ``x``, ``(x, y)`` or ``(x, y, sample_weight)``; only ``x`` is used.
        steps: number of batches to draw. Defaults to ``len(generator)`` for a
            ``Sequence``; required for plain generators.
        max_queue_size: maximum number of prefetched batches.
        workers: number of enqueuer workers; ``0`` runs the generator on the
            calling thread.
        use_multiprocessing: whether the enqueuer uses processes.
        verbose: ``1`` shows a progress bar.
        kernels: iterable of channel indices to compare. ``None`` (default)
            uses the historical list
            ``[27, 129, 138, 155, 195, 260, 301, 368, 406, 462, 482, 511]``.

    # Returns
        A float64 array of shape ``(num_samples, len(kernels))`` with the mean
        differences, or an empty list if no step was run.

    # Raises
        ValueError: if ``steps`` is None for a non-``Sequence`` generator, or
            if the generator yields a tuple that is not of length 2 or 3.
    """
    if kernels is None:
        kernels = [27, 129, 138, 155, 195, 260, 301, 368, 406, 462, 482, 511]
    kernels = list(kernels)

    steps_done = 0
    wait_time = 0.01
    batch_diffs = []
    is_sequence = isinstance(generator, Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        warnings.warn(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the`keras.utils.Sequence'
                        ' class.'))
    if steps is None:
        if is_sequence:
            steps = len(generator)
        else:
            raise ValueError('`steps=None` is only valid for a generator'
                             ' based on the `keras.utils.Sequence` class.'
                             ' Please specify `steps` or use the'
                             ' `keras.utils.Sequence` class.')

    def _loop_forever(seq):
        # Cycle over a Sequence endlessly so that steps > len(seq) does not
        # end in StopIteration (a plain iter() stops after one pass).
        while True:
            for index in range(len(seq)):
                yield seq[index]

    enqueuer = None
    try:
        if workers > 0:
            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    generator, use_multiprocessing=use_multiprocessing)
            else:
                enqueuer = GeneratorEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    wait_time=wait_time)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        elif is_sequence:
            output_generator = _loop_forever(generator)
        else:
            output_generator = generator

        if verbose == 1:
            progbar = Progbar(target=steps)

        while steps_done < steps:
            generator_output = next(output_generator)
            if isinstance(generator_output, tuple):
                # Compatibility with the generators used for training:
                # keep only the inputs.
                if len(generator_output) in (2, 3):
                    x = generator_output[0]
                else:
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y, sample_weight)` '
                                     'or `(x, y)`. Found: ' +
                                     str(generator_output))
            else:
                # Assumes a generator that only yields inputs.
                x = generator_output

            outs1 = model1.predict_on_batch(x)
            outs2 = model2.predict_on_batch(x)

            # Select the compared channels (axis 3) and average the difference
            # over every remaining non-batch axis; this is the vectorized
            # equivalent of mean(outs1[i, :, :, k].flatten() -
            # outs2[i, :, :, k].flatten()) for every sample i and kernel k.
            diff = outs1[:, :, :, kernels] - outs2[:, :, :, kernels]
            reduce_axes = tuple(ax for ax in range(diff.ndim)
                                if ax not in (0, 3))
            batch_diffs.append(diff.mean(axis=reduce_axes, dtype=np.float64))

            steps_done += 1
            if verbose == 1:
                progbar.update(steps_done)

    finally:
        if enqueuer is not None:
            enqueuer.stop()

    if not batch_diffs:
        return []
    if len(batch_diffs) == 1:
        return batch_diffs[0]
    return np.concatenate(batch_diffs)
Exemplo n.º 26
0
    def train_srgan(self,
                    epochs, batch_size,
                    dataname,
                    datapath_train,
                    datapath_validation=None,
                    steps_per_validation=10,
                    datapath_test=None,
                    workers=40, max_queue_size=100,
                    first_epoch=0,
                    print_frequency=2,
                    crops_per_image=2,
                    log_weight_frequency=1000,
                    log_weight_path='./data/weights/',
                    log_tensorboard_path='./data/logs/',
                    log_tensorboard_name='ESRGAN',
                    log_tensorboard_update_freq=500,
                    log_test_frequency=500,
                    log_test_path="./images/samples/",
                    ):
        """Train the ESRGAN network.

        Each "epoch" is a single update iteration: one batch is drawn from the
        training enqueuer, ``self.RaGAN`` (relativistic discriminator) is
        trained on ``[imgs_hr, generated_hr]`` and ``self.srgan`` on
        ``[imgs_lr, imgs_hr]``.

        :param int epochs: how many epochs (update iterations) to train for
        :param int batch_size: batch size passed to the DataLoader(s)
        :param str dataname: name to use for storing model weights etc.
        :param str datapath_train: path for the image files to use for training
        :param str datapath_validation: path for validation images; the loader is
            built but validation inference is currently disabled
        :param int steps_per_validation: currently unused (validation disabled)
        :param str datapath_test: path for the image files to use for testing / plotting
        :param int workers: number of enqueuer worker threads
        :param int max_queue_size: maximum number of prefetched batches
        :param int first_epoch: epoch number to start counting from
        :param int print_frequency: how often (in epochs) to print averaged losses
        :param int crops_per_image: crops per image passed to the DataLoader(s)
        :param int log_weight_frequency: how often (in epochs) should network weights be saved. None for never
        :param str log_weight_path: directory joined with ``dataname`` for saved weights
        :param int log_test_frequency: how often (in epochs) test images are plotted
        :param str log_test_path: where should test results be saved
        :param str log_tensorboard_path: currently unused (TensorBoard logging disabled)
        :param str log_tensorboard_name: currently unused (TensorBoard logging disabled)
        :param int log_tensorboard_update_freq: currently unused (TensorBoard logging disabled)
        """
        # Create train data loader
        loader = DataLoader(
            datapath_train, batch_size,
            self.height_hr, self.width_hr,
            self.upscaling_factor,
            crops_per_image
        )

        # Validation data loader. NOTE: built as before, but not consumed
        # while validation inference remains disabled in this trainer.
        if datapath_validation is not None:
            validation_loader = DataLoader(
                datapath_validation, batch_size,
                self.height_hr, self.width_hr,
                self.upscaling_factor,
                crops_per_image
            )
        print("Picture Loaders has been ready.")

        # Use several workers on CPU for preparing batches
        enqueuer = OrderedEnqueuer(
            loader,
            use_multiprocessing=False,
            shuffle=True
        )
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        try:
            output_generator = enqueuer.get()
            print("Data Enqueuer has been ready.")

            # Losses accumulated since the last progress print.
            print_losses = {"G": [], "D": []}
            start_epoch = datetime.datetime.now()

            # Loop through epochs / iterations
            for epoch in range(first_epoch, epochs + first_epoch):
                imgs_lr, imgs_hr = next(output_generator)
                generated_hr = self.generator.predict(imgs_lr)

                # Train Relativistic Discriminator
                discriminator_loss = self.RaGAN.train_on_batch([imgs_hr, generated_hr], None)

                # Train generator
                generator_loss = self.srgan.train_on_batch([imgs_lr, imgs_hr], None)

                # Save losses
                print_losses['G'].append(generator_loss)
                print_losses['D'].append(discriminator_loss)

                # Show the progress
                if epoch % print_frequency == 0:
                    g_avg_loss = np.array(print_losses['G']).mean(axis=0)
                    d_avg_loss = np.array(print_losses['D']).mean(axis=0)
                    print(self.srgan.metrics_names, g_avg_loss)
                    print(self.RaGAN.metrics_names, d_avg_loss)
                    print("\nEpoch {}/{} | Time: {}s\n>> Generator/GAN: {}\n>> Discriminator: {}".format(
                        epoch, epochs + first_epoch,
                        (datetime.datetime.now() - start_epoch).seconds,
                        ", ".join(["{}={:.4f}".format(k, v) for k, v in zip(self.srgan.metrics_names, g_avg_loss)]),
                        ", ".join(["{}={:.4f}".format(k, v) for k, v in zip(self.RaGAN.metrics_names, d_avg_loss)])
                    ))
                    print_losses = {"G": [], "D": []}

                # If test images are supplied, run model on them and save to log_test_path
                if datapath_test and epoch % log_test_frequency == 0:
                    print(">> Ploting test images")
                    plot_test_images(self, loader, datapath_test, log_test_path, epoch, refer_model=self.refer_model)

                # Check if we should save the network weights
                if log_weight_frequency and epoch % log_weight_frequency == 0:
                    print(">> Saving the network weights")
                    self.save_weights(os.path.join(log_weight_path, dataname), epoch)

                # Restart the timer after each reported epoch so the next report
                # covers only the epochs since then. (Previously keyed on
                # `epoch % print_frequency == 1` at the top of the loop, which
                # never fires when print_frequency == 1.)
                if epoch % print_frequency == 0:
                    start_epoch = datetime.datetime.now()
        finally:
            # Shut down the enqueuer's worker threads even if training fails.
            enqueuer.stop()
Exemplo n.º 27
0
def analysis_generator(model,
                       generator,
                       steps=None,
                       max_queue_size=10,
                       workers=1,
                       use_multiprocessing=False,
                       verbose=0):
    """Collect ``model.predict_on_batch`` outputs over a batch generator.

    This follows Keras' ``Model.predict_generator``: batches are pulled
    (through an enqueuer when ``workers > 0``), any targets / sample weights
    are discarded, and predictions are concatenated per model output.

    # Arguments
        model: model to run; its predict function is built up front.
        generator: a ``keras.utils.Sequence`` or a Python generator yielding
            ``x``, ``(x, y)`` or ``(x, y, sample_weight)``.
        steps: number of batches to draw. Defaults to ``len(generator)`` for a
            ``Sequence``; required for plain generators.
        max_queue_size: maximum number of prefetched batches.
        workers: number of enqueuer workers; ``0`` iterates on this thread.
        use_multiprocessing: whether the enqueuer uses processes.
        verbose: ``1`` shows a progress bar.

    # Returns
        For a single-output model, one array; otherwise a list with one array
        per output. When exactly one step ran, the batch arrays are returned
        as-is instead of concatenated.

    # Raises
        ValueError: if ``steps`` is None for a non-``Sequence`` generator, or
            if a yielded tuple is not of length 2 or 3.
    """
    model._make_predict_function()

    wait_time = 0.01
    is_sequence = isinstance(generator, Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        warnings.warn(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the`keras.utils.Sequence'
                        ' class.'))
    if steps is None:
        if not is_sequence:
            raise ValueError('`steps=None` is only valid for a generator'
                             ' based on the `keras.utils.Sequence` class.'
                             ' Please specify `steps` or use the'
                             ' `keras.utils.Sequence` class.')
        steps = len(generator)

    per_output = []  # one list of batch predictions per model output
    steps_done = 0
    enqueuer = None
    try:
        if workers <= 0:
            # No background workers: iterate directly on this thread.
            batches = (iter_sequence_infinite(generator) if is_sequence
                       else generator)
        else:
            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    generator, use_multiprocessing=use_multiprocessing)
            else:
                enqueuer = GeneratorEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    wait_time=wait_time)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            batches = enqueuer.get()

        progbar = Progbar(target=steps) if verbose == 1 else None

        while steps_done < steps:
            batch = next(batches)
            if not isinstance(batch, tuple):
                # Generator yields inputs only.
                x = batch
            elif len(batch) in (2, 3):
                # Training-style `(x, y)` / `(x, y, sample_weight)`: keep x.
                x = batch[0]
            else:
                raise ValueError('Output of generator should be '
                                 'a tuple `(x, y, sample_weight)` '
                                 'or `(x, y)`. Found: ' +
                                 str(batch))

            predictions = to_list(model.predict_on_batch(x))
            if not per_output:
                per_output = [[] for _ in predictions]
            for position, prediction in enumerate(predictions):
                per_output[position].append(prediction)

            steps_done += 1
            if progbar is not None:
                progbar.update(steps_done)

    finally:
        if enqueuer is not None:
            enqueuer.stop()

    if len(per_output) == 1:
        chunks = per_output[0]
        return chunks[0] if steps_done == 1 else np.concatenate(chunks)
    if steps_done == 1:
        return [chunks[0] for chunks in per_output]
    return [np.concatenate(chunks) for chunks in per_output]
Exemplo n.º 28
0
    def train_srgan(self,
                    epochs,
                    batch_size,
                    dataname,
                    datapath_train,
                    datapath_validation=None,
                    steps_per_epoch=10000,
                    steps_per_validation=100,
                    datapath_test=None,
                    workers=16,
                    max_queue_size=10,
                    first_step=0,
                    print_frequency=1,
                    crops_per_image=4,
                    log_weight_frequency=None,
                    log_weight_path='./data/weights/',
                    log_tensorboard_path='./data/logs/',
                    log_tensorboard_name='SRGAN',
                    log_tensorboard_update_freq=10000,
                    log_test_frequency=1,
                    log_test_path="./images/samples/",
                    job_dir=None):
        """Train the SRGAN network.

        Training runs ``steps_per_epoch * int(epochs)`` update steps. Each step
        draws one batch, trains ``self.discriminator`` on real (HR) and
        generated images, then trains ``self.srgan`` against
        ``[real, vgg_features]``. After the last step of every epoch:
        TensorBoard (if enabled) receives the epoch-mean generator losses, and
        progress printing, validation, test plotting and weight saving run
        according to their frequencies.

        :param int epochs: how many epochs to train the network for
        :param int batch_size: batch size passed to the DataLoader(s)
        :param str dataname: name to use for storing model weights etc.
        :param str datapath_train: path for the image files to use for training
        :param str datapath_validation: optional path for validation images
        :param int steps_per_epoch: update steps per epoch
        :param int steps_per_validation: batches used for validation inference
        :param str datapath_test: path for the image files to use for testing / plotting
        :param int workers: number of enqueuer workers (also used for validation)
        :param int max_queue_size: maximum number of prefetched batches
        :param int first_step: offset added to the step index passed to TensorBoard
        :param int print_frequency: how often (in epochs) to print progress to terminal. Warning: will run validation inference!
        :param int crops_per_image: crops per image passed to the DataLoader(s)
        :param int log_weight_frequency: how often (in epochs) should network weights be saved. None for never
        :param str log_weight_path: where should network weights be saved
        :param str log_tensorboard_path: where should tensorflow logs be sent; falsy disables TensorBoard
        :param str log_tensorboard_name: what folder should tf logs be saved under
        :param int log_tensorboard_update_freq: ``update_freq`` for the TensorBoard callback
        :param int log_test_frequency: how often (in epochs) should testing be performed
        :param str log_test_path: where should test results be saved
        :param job_dir: unused; kept for interface compatibility
        """
        # Create train data loader
        loader = DataLoader(datapath_train, batch_size, self.height_hr,
                            self.width_hr, self.upscaling_factor,
                            crops_per_image)

        # Validation data loader
        if datapath_validation is not None:
            validation_loader = DataLoader(datapath_validation, batch_size,
                                           self.height_hr, self.width_hr,
                                           self.upscaling_factor,
                                           crops_per_image)

        # Callback: tensorboard. `None` means logging is disabled; previously
        # the name was left unbound and the first epoch change raised NameError.
        tensorboard = None
        if log_tensorboard_path:
            tensorboard = TensorBoard(log_dir=os.path.join(
                log_tensorboard_path, log_tensorboard_name),
                                      histogram_freq=0,
                                      batch_size=batch_size,
                                      write_graph=False,
                                      write_grads=False,
                                      update_freq=log_tensorboard_update_freq)
            tensorboard.set_model(self.srgan)
        else:
            print(
                ">> Not logging to tensorboard since no log_tensorboard_path is set"
            )

        def named_logs(model, logs):
            """Map train_on_batch return values to the model's metric names."""
            return dict(zip(model.metrics_names, logs))

        # Shape of output from discriminator, with the batch dimension filled in
        discriminator_output_shape = list(self.discriminator.output_shape)
        discriminator_output_shape[0] = batch_size
        discriminator_output_shape = tuple(discriminator_output_shape)

        # VALID / FAKE targets for discriminator
        real = np.ones(discriminator_output_shape)
        fake = np.zeros(discriminator_output_shape)

        print_losses = {"G": [], "D": []}  # losses since the last progress print
        epoch_loss_sums = {}  # running per-metric generator-loss sums this epoch
        epoch_steps = 0
        current_epoch = 1
        start_epoch = datetime.datetime.now()

        # Use several workers on CPU for preparing batches
        enqueuer = OrderedEnqueuer(loader,
                                   use_multiprocessing=True,
                                   shuffle=True)
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        try:
            output_generator = enqueuer.get()

            for step in range(0, steps_per_epoch * int(epochs)):
                # Train discriminator
                imgs_lr, imgs_hr = next(output_generator)
                generated_hr = self.generator.predict(imgs_lr)
                real_loss = self.discriminator.train_on_batch(imgs_hr, real)
                fake_loss = self.discriminator.train_on_batch(generated_hr, fake)
                discriminator_loss = 0.5 * np.add(real_loss, fake_loss)

                # Train generator
                features_hr = self.vgg.predict(self.preprocess_vgg(imgs_hr))
                generator_loss = self.srgan.train_on_batch(imgs_lr,
                                                           [real, features_hr])

                for name, value in named_logs(self.srgan, generator_loss).items():
                    epoch_loss_sums[name] = epoch_loss_sums.get(name, 0) + value
                epoch_steps += 1

                # Save losses
                print_losses['G'].append(generator_loss)
                print_losses['D'].append(discriminator_loss)

                # Epoch bookkeeping runs after the LAST step of each epoch, so
                # every epoch has exactly steps_per_epoch steps and the final
                # epoch is logged too (the old `step > epoch * steps_per_epoch`
                # check made epoch 1 one step long and never closed the last).
                if (step + 1) % steps_per_epoch != 0:
                    continue

                if tensorboard is not None:
                    tensorboard.on_epoch_end(
                        step + 1 + first_step,
                        {k: v / epoch_steps for k, v in epoch_loss_sums.items()})

                # Show the progress
                if current_epoch % print_frequency == 0:
                    g_avg_loss = np.array(print_losses['G']).mean(axis=0)
                    d_avg_loss = np.array(print_losses['D']).mean(axis=0)
                    print(
                        "\nEpoch {}/{} | Time: {}s\n>> Generator/GAN: {}\n>> Discriminator: {}"
                        .format(
                            current_epoch, epochs,
                            (datetime.datetime.now() - start_epoch).seconds,
                            ", ".join([
                                "{}={:.4f}".format(k, v) for k, v in zip(
                                    self.srgan.metrics_names, g_avg_loss)
                            ]), ", ".join([
                                "{}={:.4f}".format(k, v) for k, v in zip(
                                    self.discriminator.metrics_names, d_avg_loss)
                            ])))
                    print_losses = {"G": [], "D": []}

                    # Run validation inference if specified
                    if datapath_validation:
                        validation_losses = self.generator.evaluate_generator(
                            validation_loader,
                            steps=steps_per_validation,
                            use_multiprocessing=workers > 1,
                            workers=workers)
                        print(">> Validation Losses: {}".format(", ".join([
                            "{}={:.4f}".format(k, v) for k, v in zip(
                                self.generator.metrics_names, validation_losses)
                        ])))

                # If test images are supplied, run model on them and save to log_test_path
                if datapath_test and current_epoch % log_test_frequency == 0:
                    plot_test_images(self, loader, datapath_test, log_test_path,
                                     current_epoch)

                # Check if we should save the network weights
                if log_weight_frequency and current_epoch % log_weight_frequency == 0:
                    self.save_weights(log_weight_path, dataname)

                # Start the next epoch with fresh counters and timer.
                current_epoch += 1
                epoch_loss_sums = {}
                epoch_steps = 0
                start_epoch = datetime.datetime.now()
        finally:
            # Stop worker processes and flush/close the TensorBoard writer.
            enqueuer.stop()
            if tensorboard is not None:
                tensorboard.on_train_end(None)
Exemplo n.º 29
0
def dragonn_predict_generator(model, generator,
                              steps=None,
                              max_queue_size=10,
                              workers=1,
                              use_multiprocessing=False,
                              verbose=1):
    """Run `model` over batches of `generator`, keeping the sample indices.

    Variant of Keras' `Model.predict_generator` for generators whose batches
    also report their batch number, so every prediction can be mapped back
    to an entry of `generator.indices`.

    The generator must expose `indices` (sliceable), `batch_size` and
    `__len__`, and every batch it yields must be a tuple `(x, idx)` or
    `(x, y, idx)` where `idx` is the integer batch number; `y` is ignored.

    Args:
        model: Keras model; predictions come from `model.predict_on_batch`.
        generator: sequence-like batch source, consumed through an
            `OrderedEnqueuer`.
        steps: number of batches to predict. Defaults to `len(generator)`.
        max_queue_size: maximum number of batches queued by the enqueuer.
        workers: number of enqueuer workers.
        use_multiprocessing: forwarded to `OrderedEnqueuer`.
        verbose: 1 shows a `Progbar`; any other value is silent.

    Returns:
        A tuple `(predictions, prediction_indices)`. `predictions` is one
        array for a single-output model, otherwise a list with one array
        per model output. `prediction_indices` holds the slices of
        `generator.indices` matching the predicted samples, in prediction
        order, or None when no batch was processed.

    Raises:
        ValueError: if a batch is not a 2- or 3-element tuple. This and any
            other error raised while predicting propagates to the caller
            after the enqueuer has been stopped.
    """
    model._make_predict_function()
    generator_indices = generator.indices
    batch_size = generator.batch_size
    if steps is None:
        # Only default when not given; an explicit `steps` is honoured.
        steps = len(generator)

    steps_done = 0
    all_outs = []       # one list of per-batch arrays for each model output
    index_chunks = []   # per-batch slices of generator.indices
    enqueuer = OrderedEnqueuer(
        generator,
        use_multiprocessing=use_multiprocessing)
    enqueuer.start(workers=workers, max_queue_size=max_queue_size)
    # try/finally: always stop the enqueuer, but never hide the error.
    try:
        output_generator = enqueuer.get()
        if verbose == 1:
            progbar = Progbar(target=steps)
        while steps_done < steps:
            x, idx = _split_prediction_batch(next(output_generator))
            outs = to_list(model.predict_on_batch(x))
            # `idx` is the batch number, so this batch covers rows
            # [idx * batch_size, (idx + 1) * batch_size) of `indices`.
            index_chunks.append(
                generator_indices[idx * batch_size:(idx + 1) * batch_size])
            if not all_outs:
                all_outs = [[] for _ in outs]
            for bucket, out in zip(all_outs, outs):
                bucket.append(out)
            steps_done += 1
            if verbose == 1:
                progbar.update(steps_done)
    finally:
        enqueuer.stop()

    if not index_chunks:
        prediction_indices = None
    elif len(index_chunks) == 1:
        prediction_indices = index_chunks[0]
    else:
        # Concatenate once rather than re-growing an array on every batch.
        prediction_indices = np.concatenate(index_chunks, axis=0)

    if len(all_outs) == 1:
        if steps_done == 1:
            return (all_outs[0][0], prediction_indices)
        return (np.concatenate(all_outs[0]), prediction_indices)
    if steps_done == 1:
        return ([out[0] for out in all_outs], prediction_indices)
    return ([np.concatenate(out) for out in all_outs], prediction_indices)


def _split_prediction_batch(generator_output):
    """Return `(x, idx)` from a `(x, idx)` or `(x, y, idx)` batch tuple.

    Raises:
        ValueError: if `generator_output` is not a 2- or 3-element tuple.
    """
    if isinstance(generator_output, tuple):
        if len(generator_output) == 2:
            x, idx = generator_output
            return x, idx
        if len(generator_output) == 3:
            x, _, idx = generator_output
            return x, idx
    raise ValueError('Output of generator should be '
                     'a tuple `(x, y, idx)` '
                     'or `(x, idx)`. Found: ' +
                     str(generator_output))