Code Example #1
    def __init__(self, models, generators, workers=1, shuffle=True,
            max_queue_size=10, print_full_losses=False):

        assert len(models) == len(generators), \
                'ValueError: models and generators should be lists of same size'

        if type(workers) is not list:
            workers = len(models)*[workers]

        self.models = models
        self.output_generators = []
        self.batch_logs = {}
        self.print_full_losses = print_full_losses

        metric_names = []

        batch_size = 0
        for i in range(len(models)):
            assert isinstance(generators[i], BatchLoader), \
                    'Only BatchLoader class is supported'
            batch_size += generators[i].get_batch_size()
            enqueuer = OrderedEnqueuer(generators[i], shuffle=shuffle)
            enqueuer.start(workers=workers[i], max_queue_size=max_queue_size)
            self.output_generators.append(enqueuer.get())

            metric_names.append('loss%d' % i)
            if self.print_full_losses:
                for out in models[i].outputs:
                    metric_names.append(out.name.split('/')[0])

        self.batch_logs['size'] = batch_size
        self.metric_names = metric_names
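Most of the snippets on this page follow the same lifecycle: wrap a keras.utils.Sequence in an OrderedEnqueuer, call start() with a worker count and queue size, consume batches from the generator returned by get(), and call stop() when finished. A minimal, self-contained sketch of that pattern (the ToySequence class and its shapes are illustrative, not taken from any of the quoted projects):

import numpy as np
from keras.utils import Sequence, OrderedEnqueuer

class ToySequence(Sequence):
    # A tiny Sequence: 10 batches whose values equal the batch index.
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return np.full((4, 8), idx), np.full((4, 1), idx)

enqueuer = OrderedEnqueuer(ToySequence(), use_multiprocessing=False)
enqueuer.start(workers=2, max_queue_size=10)
gen = enqueuer.get()
try:
    for _ in range(10):
        x_batch, y_batch = next(gen)  # batches arrive in index order
finally:
    enqueuer.stop()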
Code Example #2
 def test(self, X):
     l = 0.
     p = 0.
     num_samples = 0
     if isinstance(X, tuple):
         y, A = X[0]
         x = X[1]
         for batch_start in tqdm.tqdm(
                 range(0, y.shape[0], self.config['batch_size'])):
             batch_end = np.min(
                 [batch_start + self.config['batch_size'], y.shape[0]])
             y_batch, A_batch, x_batch = y[batch_start:batch_end], A[
                 batch_start:batch_end], x[batch_start:batch_end]
             batch = ([y_batch, A_batch], x_batch)
             pred_batch = []
             for i in range(batch_end - batch_start):  # index relative to the slice
                 pred = solve(
                     y_batch[i],
                     A_batch[i],
                     (1, self.config['shape1'], self.config['shape2'],
                      self.config['shape3'], self.config['shape4']),
                     rho=self.config['rho'],
                     max_iter=self.config['max_iter'])
                 pred_batch.append(pred)
             pred_batch = np.concatenate(pred_batch, axis=0)
             loss = weighted(self.loss, x_batch, pred_batch)
             psnr = weighted(self.psnr, x_batch, pred_batch)
             l += (y_batch.shape[0] * loss)
             p += (y_batch.shape[0] * psnr)
             num_samples += y_batch.shape[0]
     elif isinstance(X, Sequence):
         if self.config['workers'] > 0:
             enqueuer = OrderedEnqueuer(X, use_multiprocessing=False)
             enqueuer.start(workers=self.config['workers'],
                            max_queue_size=self.config['max_queue_size'])
             output_generator = enqueuer.get()
         else:
             output_generator = X
         l = 0.
         for steps_done in tqdm.tqdm(range(len(X))):
             generator_output = next(output_generator)
             (y_batch, A_batch), x_batch = generator_output
             pred_batch = []
             for i in range(0, y_batch.shape[0]):
                 pred = solve(
                     y_batch[i],
                     A_batch[i],
                     (1, self.config['shape1'], self.config['shape2'],
                      self.config['shape3'], self.config['shape4']),
                     rho=self.config['rho'],
                     max_iter=self.config['max_iter'])
                 pred_batch.append(pred)
             pred_batch = np.concatenate(pred_batch, axis=0)
             loss = weighted(self.loss, x_batch, pred_batch)
             psnr = weighted(self.psnr, x_batch, pred_batch)
             l += (y_batch.shape[0] * loss)
             p += (y_batch.shape[0] * psnr)
             num_samples += y_batch.shape[0]
     return [(l / num_samples), (p / num_samples)]
Code Example #3
 def predict(self, X):
     outputs = []
     if isinstance(X, tuple) or isinstance(X, list):
         if isinstance(X, tuple):
             y, A = X[0]
         elif isinstance(X, list):
             y, A = X
         for batch_start in tqdm.tqdm(
                 range(0, y.shape[0], self.config['batch_size'])):
             batch_end = np.min(
                 [batch_start + self.config['batch_size'], y.shape[0]])
             y_batch, A_batch = y[batch_start:batch_end], A[
                 batch_start:batch_end]
             pred_batch = []
             for i in range(batch_end - batch_start):  # index relative to the slice
                 print(batch_start, i)
                 pred = solve(
                     y_batch[i],
                     A_batch[i],
                     (1, self.config['shape1'], self.config['shape2'],
                      self.config['shape3'], self.config['shape4']),
                     rho=self.config['rho'],
                     max_iter=self.config['max_iter'])
                 pred_batch.append(pred)
             pred_batch = np.concatenate(pred_batch, axis=0)
             outputs.append(pred_batch)
     elif isinstance(X, Sequence):
         if self.config['workers'] > 0:
             enqueuer = OrderedEnqueuer(X, use_multiprocessing=False)
             enqueuer.start(workers=self.config['workers'],
                            max_queue_size=self.config['max_queue_size'])
             output_generator = enqueuer.get()
         else:
             output_generator = X
         for steps_done in tqdm.tqdm(range(len(X))):
             generator_output = next(output_generator)
             (y_batch, A_batch), _ = generator_output
             pred_batch = []
             for i in range(0, y_batch.shape[0]):
                 pred = solve(
                     y_batch[i],
                     A_batch[i],
                     (1, self.config['shape1'], self.config['shape2'],
                      self.config['shape3'], self.config['shape4']),
                     rho=self.config['rho'],
                     max_iter=self.config['max_iter'])
                 pred_batch.append(pred)
             pred_batch = np.concatenate(pred_batch, axis=0)
             outputs.append(pred_batch)
     return np.concatenate(outputs, axis=0)
Code Example #4
def generator():
    from keras.preprocessing.image import ImageDataGenerator
    from keras.utils import OrderedEnqueuer
    gen = ImageDataGenerator(data_format='channels_first',
                             rescale=1. / 255,
                             fill_mode='nearest').flow_from_directory(
                                 data_dir + '/train',
                                 target_size=(height, width),
                                 batch_size=batch_size)
    enqueuer = OrderedEnqueuer(gen, use_multiprocessing=False)
    enqueuer.start(workers=16)
    n_classes = gen.num_classes

    # Create the output generator once; enqueuer.get() returns a fresh
    # generator on every call, so it should not be re-created per batch.
    output_generator = enqueuer.get()
    while True:
        batch_xs, batch_ys = next(output_generator)
        yield batch_xs, batch_ys
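The wrapper above yields batches forever, so the caller must bound the number of steps itself. A hedged usage sketch (model, data_dir, height, width, and batch_size are assumed to be defined elsewhere in the script; the step count is illustrative):

# Hypothetical usage of generator() above.
model.fit_generator(generator(),
                    steps_per_epoch=500,
                    epochs=10)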
Code Example #5
File: data_utils_test.py Project: heechul90/Keras
def test_missing_inputs():
    missing_idx = 10

    class TimeOutSequence(DummySequence):
        def __getitem__(self, item):
            if item == missing_idx:
                time.sleep(120)
            return super(TimeOutSequence, self).__getitem__(item)

    enqueuer = GeneratorEnqueuer(create_finite_generator_from_sequence_pcs(
        TimeOutSequence([3, 2, 2, 3])),
                                 use_multiprocessing=True)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    with pytest.warns(UserWarning, match='An input could not be retrieved.'):
        for _ in range(4 * missing_idx):
            next(gen_output)

    enqueuer = OrderedEnqueuer(TimeOutSequence([3, 2, 2, 3]),
                               use_multiprocessing=True)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    warning_msg = "The input {} could not be retrieved.".format(missing_idx)
    with pytest.warns(UserWarning, match=warning_msg):
        for _ in range(11):
            next(gen_output)
Code Example #6
File: utils.py Project: zjj-2015/land-cover
    def __init__(self, data_generator, batch_size, num_samples, output_dir,
                 superres):
        def read_file_colormap(file_path):
            out_list = []
            with open(file_path) as color_map:
                csv_color_map = csv.reader(color_map)
                next(csv_color_map)
                for row in csv_color_map:
                    out_list.append((int(row[0]), (row[1], (row[2:]))))
            return collections.OrderedDict(out_list)

        def to_matplotlib_colormap(ordered_dict):
            def rgb(r, g, b):
                def clamp(x):
                    return max(0, min(int(x), 255))

                return "#{0:02x}{1:02x}{2:02x}".format(clamp(r), clamp(g),
                                                       clamp(b))

            return matplotlib.colors.ListedColormap([
                rgb(*ordered_dict[i][1]) if i in ordered_dict else "#000000"
                for i in ordered_dict
            ])

        super().__init__()
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.superres = superres
        self.enqueuer = OrderedEnqueuer(data_generator,
                                        use_multiprocessing=True,
                                        shuffle=False)
        self.enqueuer.start(workers=4, max_queue_size=4)
        self.writer = tf.summary.create_file_writer(output_dir)

        self.hr_classes = read_file_colormap(config.HR_COLOR)
        assert (len(self.hr_classes) == config.HR_NCLASSES -
                1), f"Wrong HR color map {config.HR_COLOR}"
        self.sr_classes = read_file_colormap(config.LR_COLOR)
        assert (len(self.sr_classes) == config.LR_NCLASSES -
                1), f"Wrong SR color map {config.LR_COLOR}"

        self.hr_classes_cmap = to_matplotlib_colormap(self.hr_classes)
        self.sr_classes_cmap = to_matplotlib_colormap(self.sr_classes)
Code Example #7
File: data_utils_test.py Project: zqy1/keras
def test_on_epoch_end_processes():
    enqueuer = OrderedEnqueuer(DummySequence([3, 200, 200, 3]), use_multiprocessing=True)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    acc = []
    for i in range(200):
        acc.append(next(gen_output)[0, 0, 0, 0])
    assert acc[100:] == list([k * 5 for k in range(100)]), "Order was not keep in GeneratorEnqueuer with processes"
    enqueuer.stop()
Code Example #8
def test_ordered_enqueuer_processes():
    enqueuer = OrderedEnqueuer(TestSequence([3, 200, 200, 3]), use_multiprocessing=True)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    acc = []
    for i in range(100):
        acc.append(next(gen_output)[0, 0, 0, 0])
    assert acc == list(range(100)), "Order was not keep in GeneratorEnqueuer with processes"
    enqueuer.stop()
Code Example #9
File: data_utils_test.py Project: heechul90/Keras
def test_ordered_enqueuer_threads():
    enqueuer = OrderedEnqueuer(DummySequence([3, 10, 10, 3]),
                               use_multiprocessing=False)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    acc = []
    for i in range(100):
        acc.append(next(gen_output)[0, 0, 0, 0])
    assert acc == list(range(100)), ('Order was not keep in GeneratorEnqueuer '
                                     'with threads')
    enqueuer.stop()
Code Example #10
def Get_data(datapath, height, width, batch_size):
    generator = image.ImageDataGenerator(
            rescale = 1./255,
            featurewise_center=False,  # set input mean to 0 over the dataset
            samplewise_center=False,  # set each sample mean to 0
            featurewise_std_normalization=False,  # divide inputs by std of the dataset
            samplewise_std_normalization=False,  # divide each input by its std
            zca_whitening=False,  # apply ZCA whitening
            rotation_range=10,  # randomly rotate images in the range (degrees, 0 to 180)
            width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
            height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
            horizontal_flip=True,  # randomly flip images
            vertical_flip=False)
    dataset = generator.flow_from_directory(
        shuffle = True,
        batch_size = batch_size,
        target_size = (height, width),
        directory = datapath)
    enqueuer = OrderedEnqueuer(dataset, use_multiprocessing=False, shuffle=True)
    enqueuer.start(workers=1, max_queue_size=10)
    output_generator = enqueuer.get()

    return output_generator
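Because Get_data returns the enqueuer's output generator directly, the caller is responsible for stopping the enqueuer, which the snippet never exposes. A hedged sketch of consuming the generator in a manual training loop (model and the argument values are placeholders):

# Hypothetical consumer; `model` is assumed to be a compiled Keras model.
output_generator = Get_data('data/train', 224, 224, 32)
for step in range(100):
    x_batch, y_batch = next(output_generator)
    loss = model.train_on_batch(x_batch, y_batch)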
Code Example #11
File: data_utils_test.py Project: Bjoux2/keras
def test_on_epoch_end_processes():
    enqueuer = OrderedEnqueuer(DummySequence([3, 200, 200, 3]), use_multiprocessing=True)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    acc = []
    for i in range(200):
        acc.append(next(gen_output)[0, 0, 0, 0])
    assert acc[100:] == list([k * 5 for k in range(100)]), "Order was not keep in GeneratorEnqueuer with processes"
    enqueuer.stop()
Code Example #12
File: data_utils_test.py Project: heechul90/Keras
def test_on_epoch_end_threads_sequence_change_length():
    seq = LengthChangingSequence([3, 10, 10, 3])
    enqueuer = OrderedEnqueuer(seq, use_multiprocessing=False)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    acc = []
    for i in range(100):
        acc.append(next(gen_output)[0, 0, 0, 0])
    assert acc == list(range(100)), ('Order was not keep in GeneratorEnqueuer '
                                     'with threads')

    enqueuer.join_end_of_epoch()
    assert len(seq) == 50
    acc = []
    for i in range(50):
        acc.append(next(gen_output)[0, 0, 0, 0])
    assert acc == list([
        k * 5 for k in range(50)
    ]), ('Order was not keep in GeneratorEnqueuer with processes')
    enqueuer.stop()
Code Example #13
def DISABLED_test_on_epoch_end_threads():
    enqueuer = OrderedEnqueuer(DummySequence([3, 10, 10, 3]),
                               use_multiprocessing=False)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    acc = []
    for i in range(100):
        acc.append(next(gen_output)[0, 0, 0, 0])
    acc = []
    for i in range(100):
        acc.append(next(gen_output)[0, 0, 0, 0])
    assert acc == list([k * 5 for k in range(100)]), (
        'Order was not keep in GeneratorEnqueuer with processes')
    enqueuer.stop()
Code Example #14
File: data_utils_test.py Project: Bjoux2/keras
def test_ordered_enqueuer_threads_not_ordered():
    enqueuer = OrderedEnqueuer(DummySequence([3, 200, 200, 3]),
                               use_multiprocessing=False,
                               shuffle=True)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    acc = []
    for i in range(100):
        acc.append(next(gen_output)[0, 0, 0, 0])
    assert acc != list(range(100)), "Order was not keep in GeneratorEnqueuer with threads"
    enqueuer.stop()
Code Example #15
File: callbacks.py Project: payne911/SR-ResCNN-Keras
    def __init__(self, data_generator, m_batch_size, num_samples, output_dir, normalization_mean):
        super().__init__()
        self.epoch_index = 0
        self.data_generator = data_generator
        self.batch_size = m_batch_size
        self.num_samples = num_samples
        self.tensorboard_writer = TensorBoardWriter(output_dir)
        self.normalization_mean = normalization_mean
        is_sequence = isinstance(self.data_generator, Sequence)
        if is_sequence:
            self.enqueuer = OrderedEnqueuer(self.data_generator,
                                            use_multiprocessing=True,
                                            shuffle=False)
        else:
            self.enqueuer = GeneratorEnqueuer(self.data_generator,
                                              use_multiprocessing=False,  # todo: how to 'True' ?
                                              wait_time=0.01)
        # todo: integrate the Sequence generator properly
#        import multiprocessing
#        self.enqueuer.start(workers=multiprocessing.cpu_count(), max_queue_size=4)
        self.enqueuer.start(workers=1, max_queue_size=4)
Code Example #16
def test_ordered_enqueuer_timeout_threads():
    enqueuer = OrderedEnqueuer(SlowSequence([3, 10, 10, 3]),
                               use_multiprocessing=False)

    def handler(signum, frame):
        raise TimeoutError('Sequence deadlocked')

    old = signal.signal(signal.SIGALRM, handler)
    signal.setitimer(signal.ITIMER_REAL, 60)
    with pytest.warns(UserWarning) as record:
        enqueuer.start(5, 10)
        gen_output = enqueuer.get()
        for epoch_num in range(2):
            acc = []
            for i in range(10):
                acc.append(next(gen_output)[0, 0, 0, 0])
            assert acc == list(range(10)), 'Order was not keep in ' \
                                           'OrderedEnqueuer with threads'
        enqueuer.stop()
    assert len(record) == 1
    assert str(record[0].message) == 'The input 0 could not be retrieved. ' \
                                     'It could be because a worker has died.'
    signal.setitimer(signal.ITIMER_REAL, 0)
    signal.signal(signal.SIGALRM, old)
Code Example #17
            import matplotlib
            import matplotlib.pyplot as plt
            # clear view image
            plt.imshow((np.squeeze(data[0][0]) / 2.0) + 0.5)
            plt.draw()
            plt.pause(0.25)
            # current timestep image
            plt.imshow((np.squeeze(data[0][1]) / 2.0) + 0.5)
            plt.draw()
            plt.pause(0.25)
            # uncomment the following line to wait for
            # one window to be closed before showing the next
            # plt.show()
    # a = next(training_generator)
    enqueuer = OrderedEnqueuer(training_generator,
                               use_multiprocessing=False,
                               shuffle=True)
    enqueuer.start(workers=1, max_queue_size=1)
    generator = iter(enqueuer.get())
    print("-------------------")
    generator_output = next(generator)
    print("-------------------op")
    x, y = generator_output
    print("x-shape-----------", x.shape)
    print("y-shape---------", y.shape)

    # X,y=training_generator.__getitem__(1)
    #print(X.keys())
    # print(X[0].shape)
    # print(X[0].shape)
    # print(y[0])
Code Example #18
File: data_utils_test.py Project: poolio/keras
def test_ordered_enqueuer_fail_threads():
    enqueuer = OrderedEnqueuer(FaultSequence(), use_multiprocessing=False)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    with pytest.raises(IndexError):
        next(gen_output)
Code Example #19
# fix settings file
from src.settings.settings import Settings

WORKER = 3
N = 1000

Settings.FILE = "../../settings/settings-sebastian.yml"

if __name__ == "__main__":
    sequence = TrainSequence(32, input_caption=False)
    iterations = N #len(sequence)


    print("using keras utils enqueuer")
    from keras.utils import OrderedEnqueuer
    enq = OrderedEnqueuer(sequence, use_multiprocessing=True, shuffle=False)
    print("starting enqueuer")
    enq.start(WORKER, 20)

    batches = 0
    # get() returns a generator over the queued batches; create it once
    # instead of once per iteration.
    output_generator = enq.get()
    while batches < iterations:
        next(output_generator)
        batches += 1

        if batches % 100 == 0:
            print("Processed {} batches".format(batches))

    print("stopping enqueuer")
    enq.stop(1.5)
Code Example #20
if __name__ == '__main__':
    nlp = spacy.load('en_core_web_sm',
                     disable=['vectors', 'textcat', 'tagger', 'ner'])
    ds = TextDataset('dev-v1.1.jsonl').map(json.loads) \
        .map(lambda x: [token.text for token in nlp(x['question'])
                        if not token.is_space])

    # PyTorch
    print('PyTorch')
    loader = DataLoader(ds, batch_size=3, num_workers=4, shuffle=True)
    it = iter(loader)
    print(next(it))
    del it

    # Chainer
    print('Chainer')
    it = MultiprocessIterator(ds, batch_size=3, n_processes=4, shuffle=True)
    print(next(it))
    it.finalize()

    # Keras
    print('Keras')
    sequence = TextSequence(ds, batch_size=3)
    enqueuer = OrderedEnqueuer(sequence,
                               use_multiprocessing=True,
                               shuffle=True)
    enqueuer.start()
    it = enqueuer.get()
    print(next(it))
    enqueuer.stop()
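The TextSequence class used in the Keras branch is not shown in the snippet. A minimal sketch of what such a Sequence might look like, assuming the wrapped dataset supports len() and integer indexing (purely illustrative):

from keras.utils import Sequence

class TextSequence(Sequence):
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch.
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        start = idx * self.batch_size
        stop = min(start + self.batch_size, len(self.dataset))
        return [self.dataset[i] for i in range(start, stop)]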
Code Example #21
File: data_utils_test.py Project: Bjoux2/keras
def test_context_switch():
    enqueuer = OrderedEnqueuer(DummySequence([3, 200, 200, 3]), use_multiprocessing=True)
    enqueuer2 = OrderedEnqueuer(DummySequence([3, 200, 200, 3], value=15), use_multiprocessing=True)
    enqueuer.start(3, 10)
    enqueuer2.start(3, 10)
    gen_output = enqueuer.get()
    gen_output2 = enqueuer2.get()
    acc = []
    for i in range(100):
        acc.append(next(gen_output)[0, 0, 0, 0])
    assert acc[-1] == 99
    # One epoch is completed so enqueuer will switch the Sequence

    acc = []
    for i in range(100):
        acc.append(next(gen_output2)[0, 0, 0, 0])
    assert acc[-1] == 99 * 15
    # One epoch has been completed so enqueuer2 will switch

    # Be sure that both Sequence were updated
    assert next(gen_output)[0, 0, 0, 0] == 0
    assert next(gen_output)[0, 0, 0, 0] == 5
    assert next(gen_output2)[0, 0, 0, 0] == 0
    assert next(gen_output2)[0, 0, 0, 0] == 15 * 5

    # Tear down everything
    enqueuer.stop()
    enqueuer2.stop()
Code Example #22
def run(image_shape, data_dir, valid_pairs, classes, num_classes, architecture,
        weights, batch_size, last_base_layer, pooling, device,
        predictions_activation, dropout_rate, ckpt, validation_steps,
        use_multiprocessing, use_gram_matrix, dense_layers, embedding_units,
        limb_weights, trainable_limbs):
    if isinstance(classes, int):
        classes = sorted(os.listdir(os.path.join(data_dir, 'train')))[:classes]

    g = ImageDataGenerator(
        preprocessing_function=utils.get_preprocess_fn(architecture))
    valid_data = BalancedDirectoryPairsSequence(os.path.join(
        data_dir, 'valid'),
                                                g,
                                                target_size=image_shape[:2],
                                                pairs=valid_pairs,
                                                classes=classes,
                                                batch_size=batch_size)
    if validation_steps is None:
        validation_steps = len(valid_data)

    with tf.device(device):
        print('building...')
        model = build_siamese_model(
            image_shape,
            architecture,
            dropout_rate,
            weights,
            num_classes,
            last_base_layer,
            use_gram_matrix,
            dense_layers,
            pooling,
            include_base_top=False,
            include_top=True,
            predictions_activation=predictions_activation,
            limb_weights=limb_weights,
            trainable_limbs=trainable_limbs,
            embedding_units=embedding_units,
            joints='multiply')
        print('siamese model summary:')
        model.summary()
        if ckpt:
            print('loading weights...')
            model.load_weights(ckpt)

        enqueuer = None
        try:
            enqueuer = OrderedEnqueuer(valid_data,
                                       use_multiprocessing=use_multiprocessing)
            enqueuer.start()
            output_generator = enqueuer.get()

            y, p = [], []
            for step in range(validation_steps):
                x, _y = next(output_generator)
                _p = model.predict(x, batch_size=batch_size)
                y.append(_y)
                p.append(_p)

            y, p = (np.concatenate(e).flatten() for e in (y, p))

            print('actual:', y[:80])
            print('expected:', p[:80])
            print('accuracy:', metrics.accuracy_score(y, p >= 0.5))
            print(metrics.classification_report(y, p >= 0.5))
            print(metrics.confusion_matrix(y, p >= 0.5))

        finally:
            if enqueuer is not None:
                enqueuer.stop()
Code Example #23
    def fit_with_pseudo_label(self,
                              steps_per_epoch,
                              validation_steps=None,
                              use_checkpoints=True,
                              class_labels=None,
                              verbose=1,
                              use_multiprocessing=False,
                              shuffle=False,
                              workers=1,
                              max_queue_size=10):

        # Default value if validation steps is none
        if validation_steps is None:
            validation_steps = self.validation_generator.samples // self.batch_size

        wait_time = 0.01  # in seconds

        self.model._make_train_function()

        # Create a checkpoint callback
        checkpoint = ModelCheckpoint("../models_checkpoints/" +
                                     str(self.h5_filename) + ".h5",
                                     monitor='val_acc',
                                     verbose=1,
                                     save_best_only=True,
                                     save_weights_only=True,
                                     mode='auto',
                                     period=1)

        # Generate callbacks
        callback_list = []
        if use_checkpoints:
            callback_list.append(checkpoint)

        # Init train counters
        epoch = 0

        validation_data = self.validation_generator
        do_validation = bool(validation_data)
        self.model._make_train_function()
        if do_validation:
            self.model._make_test_function()

        val_gen = (hasattr(validation_data, 'next')
                   or hasattr(validation_data, '__next__')
                   or isinstance(validation_data, Sequence))
        if (val_gen and not isinstance(validation_data, Sequence)
                and not validation_steps):
            raise ValueError('`validation_steps=None` is only valid for a'
                             ' generator based on the `keras.utils.Sequence`'
                             ' class. Please specify `validation_steps` or use'
                             ' the `keras.utils.Sequence` class.')

        # Prepare display labels.
        out_labels = self.model.metrics_names
        callback_metrics = out_labels + ['val_' + n for n in out_labels]

        # Prepare train callbacks
        self.model.history = cbks.History()
        callbacks = [cbks.BaseLogger()] + (callback_list or []) + \
            [self.model.history]
        if verbose:
            callbacks += [cbks.ProgbarLogger(count_mode='steps')]
        callbacks = cbks.CallbackList(callbacks)

        # it's possible to callback a different model than self:
        if hasattr(self.model, 'callback_model') and self.model.callback_model:
            callback_model = self.model.callback_model

        else:
            callback_model = self.model

        callbacks.set_model(callback_model)

        is_sequence = isinstance(self.train_generator, Sequence)
        if not is_sequence and use_multiprocessing and workers > 1:
            warnings.warn(
                UserWarning('Using a generator with `use_multiprocessing=True`'
                            ' and multiple workers may duplicate your data.'
                            ' Please consider using the `keras.utils.Sequence`'
                            ' class.'))

        if is_sequence:
            steps_per_epoch = len(self.train_generator)

        enqueuer = None
        val_enqueuer = None

        callbacks.set_params({
            'epochs': self.epochs,
            'steps': steps_per_epoch,
            'verbose': verbose,
            'do_validation': do_validation,
            'metrics': callback_metrics,
        })
        callbacks.on_train_begin()

        try:
            if do_validation and not val_gen:
                # Prepare data for validation
                if len(validation_data) == 2:
                    val_x, val_y = validation_data
                    val_sample_weight = None
                elif len(validation_data) == 3:
                    val_x, val_y, val_sample_weight = validation_data
                else:
                    raise ValueError('`validation_data` should be a tuple '
                                     '`(val_x, val_y, val_sample_weight)` '
                                     'or `(val_x, val_y)`. Found: ' +
                                     str(validation_data))
                val_x, val_y, val_sample_weights = self.model._standardize_user_data(
                    val_x, val_y, val_sample_weight)
                val_data = val_x + val_y + val_sample_weights
                if self.model.uses_learning_phase and not isinstance(
                        K.learning_phase(), int):
                    val_data += [0.]
                for cbk in callbacks:
                    cbk.validation_data = val_data

            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    self.train_generator,
                    use_multiprocessing=use_multiprocessing,
                    shuffle=shuffle)
            else:
                enqueuer = GeneratorEnqueuer(
                    self.train_generator,
                    use_multiprocessing=use_multiprocessing,
                    wait_time=wait_time)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()

            # Train the model

            # Construct epoch logs.
            epoch_logs = {}
            # Epochs
            while epoch < self.epochs:
                callbacks.on_epoch_begin(epoch)
                steps_done = 0
                batch_index = 0

                # Steps per epoch
                while steps_done < steps_per_epoch:

                    generator_output = next(output_generator)

                    if len(generator_output) == 2:
                        x, y = generator_output
                        sample_weight = None
                    elif len(generator_output) == 3:
                        x, y, sample_weight = generator_output
                    else:
                        raise ValueError('Output of generator should be '
                                         'a tuple `(x, y, sample_weight)` '
                                         'or `(x, y)`. Found: ' +
                                         str(generator_output))

                    #==========================
                    # Mini-batch
                    #==========================
                    if self.print_pseudo_generate:
                        print('')
                        print('Generating pseudo-labels...')
                        verbose = 1
                    else:
                        verbose = 0

                    if self.no_label_generator.samples > 0:
                        no_label_output = self.model.predict_generator(
                            self.no_label_generator,
                            self.no_label_generator.samples,
                            verbose=verbose)

                        # One-hot encoded
                        self.no_label_generator.classes = np.argmax(
                            no_label_output, axis=1)

                        # Concat Pseudo labels with true labels
                        x_pseudo, y_pseudo = next(self.no_label_generator)
                        x, y = np.concatenate((x, x_pseudo),
                                              axis=0), np.concatenate(
                                                  (y, y_pseudo), axis=0)

                    # build batch logs
                    batch_logs = {}
                    if isinstance(x, list):
                        batch_size = x[0].shape[0]
                    elif isinstance(x, dict):
                        batch_size = list(x.values())[0].shape[0]
                    else:
                        batch_size = x.shape[0]
                    batch_logs['batch'] = batch_index
                    batch_logs['size'] = batch_size
                    callbacks.on_batch_begin(batch_index, batch_logs)

                    # Runs a single gradient update on a single batch of data
                    scalar_training_loss = self.model.train_on_batch(x=x, y=y)

                    if not isinstance(scalar_training_loss, list):
                        scalar_training_loss = [scalar_training_loss]
                    for l, o in zip(out_labels, scalar_training_loss):
                        batch_logs[l] = o

                    callbacks.on_batch_end(batch_index, batch_logs)

                    #==========================
                    # end Mini-batch
                    #==========================

                    batch_index += 1
                    steps_done += 1

                if steps_done >= steps_per_epoch and do_validation:
                    if val_gen:
                        val_outs = self.model.evaluate_generator(
                            validation_data,
                            validation_steps,
                            workers=workers,
                            use_multiprocessing=use_multiprocessing,
                            max_queue_size=max_queue_size)
                    else:
                        # No need for try/except because
                        # data has already been validated.
                        val_outs = self.model.evaluate(
                            val_x,
                            val_y,
                            batch_size=batch_size,
                            sample_weight=val_sample_weights,
                            verbose=0)
                    if not isinstance(val_outs, list):
                        val_outs = [val_outs]
                    # Same labels assumed.
                    for l, o in zip(out_labels, val_outs):
                        epoch_logs['val_' + l] = o

                # Epoch finished.
                callbacks.on_epoch_end(epoch, epoch_logs)
                epoch += 1

        finally:
            try:
                if enqueuer is not None:
                    enqueuer.stop()
            finally:
                if val_enqueuer is not None:
                    val_enqueuer.stop()

        callbacks.on_train_end()
        return self.model.history
Code Example #24
def predict_labels_generator(model,
                             generator,
                             steps=None,
                             max_queue_size=10,
                             workers=1,
                             verbose=1):
    """Reimplementation of the Keras function `predict_generator`, to return
    also the labels given by the generator.
    """
    model._make_predict_function()

    steps_done = 0
    all_outs = []
    all_labels = []

    if steps is None:
        steps = len(generator)
    enqueuer = None

    try:
        enqueuer = OrderedEnqueuer(generator)
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        output_generator = enqueuer.get()

        if verbose == 1:
            progbar = Progbar(target=steps)

        while steps_done < steps:
            generator_output = next(output_generator)
            if isinstance(generator_output, tuple):
                # Compatibility with the generators
                # used for training.
                if len(generator_output) == 2:
                    x, y = generator_output
                else:
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y)`. '
                                     'Found: ' + str(generator_output))
            else:
                raise ValueError('Generator should yield a tuple `(x, y)`')

            outs = model.predict_on_batch(x)
            outs = to_list(outs)
            labels = to_list(y)

            if not all_outs:
                for out in outs:
                    all_outs.append([])

            if not all_labels:
                for lab in labels:
                    all_labels.append([])

            for i, out in enumerate(outs):
                all_outs[i].append(out)

            for i, lab in enumerate(labels):
                all_labels[i].append(lab)

            steps_done += 1
            if verbose == 1:
                progbar.update(steps_done)

    finally:
        if enqueuer is not None:
            enqueuer.stop()

    if len(all_outs) == 1:
        if steps_done == 1:
            all_outs = all_outs[0][0]
            all_labels = all_labels[0][0]
        else:
            all_outs = np.concatenate(all_outs[0])
            all_labels = np.concatenate(all_labels[0])

        return all_outs, all_labels

    if steps_done == 1:
        all_outs = [out[0] for out in all_outs]
        all_labels = [lab[0] for lab in all_labels]
    else:
        all_outs = [np.concatenate(out) for out in all_outs]
        all_labels = [np.concatenate(lab) for lab in all_labels]

    return all_outs, all_labels
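A hedged usage sketch of the helper above (model and val_sequence are placeholders; val_sequence is assumed to be a keras.utils.Sequence yielding (x, y) batches):

predictions, labels = predict_labels_generator(model, val_sequence,
                                               workers=2, max_queue_size=10)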
Code Example #25
def test_ordered_enqueuer_fail_processes():
    enqueuer = OrderedEnqueuer(FaultSequence(), use_multiprocessing=True)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    with pytest.raises(StopIteration):
        next(gen_output)
Code Example #26
File: utils.py Project: zjj-2015/land-cover
class ModelDiagnoser(Callback):
    # pylint: disable=not-context-manager,too-many-instance-attributes
    # Disable context manager warning for generator
    def __init__(self, data_generator, batch_size, num_samples, output_dir,
                 superres):
        def read_file_colormap(file_path):
            out_list = []
            with open(file_path) as color_map:
                csv_color_map = csv.reader(color_map)
                next(csv_color_map)
                for row in csv_color_map:
                    out_list.append((int(row[0]), (row[1], (row[2:]))))
            return collections.OrderedDict(out_list)

        def to_matplotlib_colormap(ordered_dict):
            def rgb(r, g, b):
                def clamp(x):
                    return max(0, min(int(x), 255))

                return "#{0:02x}{1:02x}{2:02x}".format(clamp(r), clamp(g),
                                                       clamp(b))

            return matplotlib.colors.ListedColormap([
                rgb(*ordered_dict[i][1]) if i in ordered_dict else "#000000"
                for i in ordered_dict
            ])

        super().__init__()
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.superres = superres
        self.enqueuer = OrderedEnqueuer(data_generator,
                                        use_multiprocessing=True,
                                        shuffle=False)
        self.enqueuer.start(workers=4, max_queue_size=4)
        self.writer = tf.summary.create_file_writer(output_dir)

        self.hr_classes = read_file_colormap(config.HR_COLOR)
        assert (len(self.hr_classes) == config.HR_NCLASSES -
                1), f"Wrong HR color map {config.HR_COLOR}"
        self.sr_classes = read_file_colormap(config.LR_COLOR)
        assert (len(self.sr_classes) == config.LR_NCLASSES -
                1), f"Wrong SR color map {config.LR_COLOR}"

        self.hr_classes_cmap = to_matplotlib_colormap(self.hr_classes)
        self.sr_classes_cmap = to_matplotlib_colormap(self.sr_classes)

    def plot_confusion_matrix(self, correct_labels, predict_labels):
        labels = [0] + list(self.hr_classes.keys())
        cm = confusion_matrix(correct_labels.reshape(-1),
                              predict_labels.reshape(-1),
                              labels=labels)
        figure = plt.figure(figsize=(10, 10))
        plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
        plt.title("Confusion matrix")
        plt.colorbar()
        tick_marks = np.arange(len(labels))
        plt.xticks(tick_marks, labels, rotation=45)
        plt.yticks(tick_marks, labels)

        # Normalize the confusion matrix.
        cm = np.around(cm.astype("float") / cm.sum(axis=1)[:, np.newaxis],
                       decimals=2)

        # Use white text if squares are dark; otherwise black.
        threshold = cm.max() / 2.0
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            color = "white" if cm[i, j] > threshold else "black"
            plt.text(j, i, cm[i, j], horizontalalignment="center", color=color)

        plt.tight_layout()
        plt.ylabel("True label")
        plt.xlabel("Predicted label")
        return self.plot_to_image(figure)

    def plot_classification(self, np_array, cmap):
        figure = plt.figure(figsize=(10, 10))
        plt.imshow(np.squeeze(np_array), cmap=cmap, vmin=0, vmax=cmap.N)
        plt.axis("off")
        return self.plot_to_image(figure)

    @staticmethod
    def plot_to_image(figure):
        # Save the plot to a PNG in memory.
        buf = io.BytesIO()
        plt.savefig(buf, format="png")
        # Closing the figure prevents it from being displayed directly inside
        # the notebook.
        plt.close(figure)
        buf.seek(0)
        # Convert PNG buffer to TF image
        image = tf.image.decode_png(buf.getvalue(), channels=4)
        # Add the batch dimension
        image = tf.expand_dims(image, 0)
        return image

    def on_epoch_end(self, epoch, logs=None):
        def to_label(batch):
            label_batch = np.zeros(batch.shape[0:3])
            for i in range(batch.shape[0]):
                label_batch[i] = np.argmax(batch[i], axis=2)
            return label_batch

        output_generator = self.enqueuer.get()
        generator_output = next(output_generator)
        if self.superres:
            x_batch, y = generator_output
            y_batch_hr = y["outputs_hr"]
            y_batch_sr = y["outputs_sr"]
        else:
            x_batch, y_batch_hr = generator_output

        y_pred = self.model.predict(x_batch)
        if self.superres:
            y_pred, y_pred_sr = y_pred

        label_y_hr = to_label(y_batch_hr)
        label_y_pred = to_label(y_pred)

        with self.writer.as_default():
            for sample_index in [
                    i for i in range(0, 3) if i <= self.batch_size - 1
            ]:
                tf.summary.image(
                    "Epoch-{}/{}/image".format(epoch, sample_index),
                    x_batch[[sample_index], :, :, :3],
                    step=epoch,
                )
                tf.summary.image(
                    "Epoch-{}/{}/label".format(epoch, sample_index),
                    self.plot_classification(
                        label_y_hr[sample_index],
                        self.hr_classes_cmap,
                    ),
                    step=epoch,
                )
                tf.summary.image(
                    "Epoch-{}/{}/pred".format(epoch, sample_index),
                    self.plot_classification(
                        label_y_pred[sample_index],
                        self.hr_classes_cmap,
                    ),
                    step=epoch,
                )
                tf.summary.image(
                    f"Epoch-{epoch}/confusion_matrix",
                    self.plot_confusion_matrix(
                        label_y_hr.squeeze(),
                        label_y_pred.squeeze(),
                    ),
                    step=epoch,
                )

                if self.superres:
                    tf.summary.image(
                        "Epoch-{}/{}/label_sr".format(epoch, sample_index),
                        self.plot_classification(
                            np.argmax(y_batch_sr[sample_index, :, :, :],
                                      axis=2),
                            self.sr_classes_cmap,
                        ),
                        step=epoch,
                    )
                    tf.summary.image(
                        "Epoch-{}/{}/pred_sr".format(epoch, sample_index),
                        self.plot_classification(
                            np.argmax(y_pred_sr[sample_index, :, :, :],
                                      axis=2),
                            self.hr_classes_cmap,
                        ),
                        step=epoch,
                    )

    def on_train_end(self, logs=None):
        self.enqueuer.stop()
        self.writer.close()
Code Example #27
File: utils.py Project: up42/land-cover-public
class ModelDiagnoser(Callback):
    # pylint: disable=not-context-manager,too-many-instance-attributes
    # Disable context manager warning for generator
    """TensorFlow based Callback class to plot source imagery, labels,
    predicted labels, and confusion matrix in TensorBoard.
    Based on https://stackoverflow.com/a/55856716
    """

    def __init__(
        self,
        data_generator: keras.utils.Sequence,
        batch_size: int,
        num_samples: int,
        output_dir: str,
        superres: bool = False,
    ):
        """
        Parameters
        ----------
        data_generator : keras.utils.Sequence
            Training or validation generator. The imagery and labels that are
            plotted will come from here.
        batch_size : int
            The size of each batch.
        num_samples : int
            The number of samples (images) to plot at each epoch.
        output_dir : str
            Output directory of logs for TensorBoard.
        superres : bool
            If the model is using superres loss function.

        Returns
        -------
        ModelDiagnoser
            Initialized TensorFlow callback, ready to be used with fit method.

        """

        def read_file_colormap(file_path: str) -> collections.OrderedDict:
            """Reads a text file with color assigned to each class. i.e.:
            class, r, g, b
            1, 255, 0, 0

            Which will result in label 1 to be assigned color red.

            Parameters
            ----------
            file_path : str
                File path of the text file described above.

            Returns
            -------
            collections.OrderedDict
                OrderedDict with label plus assigned rgb color.
            """
            out_list = []
            with open(file_path) as color_map:
                csv_color_map = csv.reader(color_map)
                next(csv_color_map)
                for row in csv_color_map:
                    out_list.append((int(row[0]), (row[1], (row[2:]))))
            return collections.OrderedDict(out_list)

        def to_matplotlib_colormap(
            ordered_dict: collections.OrderedDict,
        ) -> matplotlib.colors.ListedColormap:
            """From ordered_dict generate a ListedColormap.

            Parameters
            ----------
            ordered_dict : collections.OrderedDict
                OrderedDict generated by read_file_colormap.

            Returns
            -------
            matplotlib.colors.ListedColormap
                Colormap for labels ready to use within a matplotlib figure.

            """

            def rgb(r: int, g: int, b: int) -> str:
                def clamp(x):
                    return max(0, min(int(x), 255))

                return "#{0:02x}{1:02x}{2:02x}".format(clamp(r), clamp(g), clamp(b))

            return matplotlib.colors.ListedColormap(
                [
                    rgb(*ordered_dict[i][1]) if i in ordered_dict else "#000000"
                    for i in ordered_dict
                ]
            )

        super().__init__()
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.superres = superres
        self.enqueuer = OrderedEnqueuer(
            data_generator, use_multiprocessing=True, shuffle=False
        )
        self.enqueuer.start(workers=4, max_queue_size=4)
        self.writer = tf.summary.create_file_writer(output_dir)

        self.hr_classes = read_file_colormap(config.HR_COLOR)
        assert (
            len(self.hr_classes) == config.HR_NCLASSES - 1
        ), f"Wrong HR color map {config.HR_COLOR}"
        self.sr_classes = read_file_colormap(config.LR_COLOR)
        assert (
            len(self.sr_classes) == config.LR_NCLASSES - 1
        ), f"Wrong SR color map {config.LR_COLOR}"

        self.hr_classes_cmap = to_matplotlib_colormap(self.hr_classes)
        self.sr_classes_cmap = to_matplotlib_colormap(self.sr_classes)

    def plot_confusion_matrix(
        self, correct_labels: np.ndarray, predict_labels: np.ndarray
    ) -> tf.Tensor:
        """From labels and predict_labels generate a confusion matrix.

        Parameters
        ----------
        correct_labels : np.ndarray
            Given with the data_generator.
        predict_labels : np.ndarray
            Result of applying the model on imagery given by data_generator.

        Returns
        -------
        tf.Tensor
            A tensor representing a confusion matrix image.

        """
        labels = [0] + list(self.hr_classes.keys())
        cm = confusion_matrix(
            correct_labels.reshape(-1), predict_labels.reshape(-1), labels=labels
        )
        figure = plt.figure(figsize=(10, 10))
        plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
        plt.title("Confusion matrix")
        plt.colorbar()
        tick_marks = np.arange(len(labels))
        plt.xticks(tick_marks, labels, rotation=45)
        plt.yticks(tick_marks, labels)

        # Normalize the confusion matrix.
        cm = np.around(cm.astype("float") / cm.sum(axis=1)[:, np.newaxis], decimals=2)

        # Use white text if squares are dark; otherwise black.
        threshold = cm.max() / 2.0
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            color = "white" if cm[i, j] > threshold else "black"
            plt.text(j, i, cm[i, j], horizontalalignment="center", color=color)

        plt.tight_layout()
        plt.ylabel("True label")
        plt.xlabel("Predicted label")
        return self.plot_to_image(figure)

    def plot_classification(
        self, np_array: np.ndarray, cmap: matplotlib.colors.ListedColormap
    ) -> tf.Tensor:
        """From labels generate an image with the classes mapped to the cmap (color).

        Parameters
        ----------
        np_array : np.ndarray
            Labels as given by the data_generator.
        cmap : matplotlib.colors.ListedColormap
            Mapping of class to color.

        Returns
        -------
        tf.Tensor
            A tensor representing a label plot with appropriate color.

        """
        figure = plt.figure(figsize=(10, 10))
        plt.imshow(np.squeeze(np_array), cmap=cmap, vmin=0, vmax=cmap.N)
        plt.axis("off")
        return self.plot_to_image(figure)

    @staticmethod
    def plot_to_image(figure: matplotlib.figure.Figure) -> tf.Tensor:
        """Generate a tensor from a matplotlib figure.
        From https://www.tensorflow.org/tensorboard/image_summaries#logging_arbitrary_image_data

        Parameters
        ----------
        figure : matplotlib.figure.Figure
            Figure to convert to tensor.

        Returns
        -------
        tf.Tensor
            A tensor representation of the figure.

        """
        # Save the plot to a PNG in memory.
        buf = io.BytesIO()
        plt.savefig(buf, format="png")
        # Closing the figure prevents it from being displayed directly inside
        # the notebook.
        plt.close(figure)
        buf.seek(0)
        # Convert PNG buffer to TF image
        image = tf.image.decode_png(buf.getvalue(), channels=4)
        # Add the batch dimension
        image = tf.expand_dims(image, 0)
        return image

    def on_epoch_end(self, epoch: int, logs=None):
        """Defines what is run at the end of one epoch.
        - Plots images
        - Plots label
        - Plots prediction labels
        - Plots confusion matrix of batch

        Parameters
        ----------
        epoch : int
            The epoch number.
        """

        def to_label(batch: np.ndarray) -> np.ndarray:
            """From label 1hot encoded to classified labels.

            Parameters
            ----------
            batch : np.ndarray
                One hot encoded labels (b_size, h, w, nclasses).

            Returns
            -------
            np.ndarray
                Regular labels. (b_size, h, w, 1)

            """

            label_batch = np.zeros(batch.shape[0:3])
            for i in range(batch.shape[0]):
                label_batch[i] = np.argmax(batch[i], axis=2)
            return label_batch

        output_generator = self.enqueuer.get()
        generator_output = next(output_generator)
        if self.superres:
            x_batch, y = generator_output
            y_batch_hr = y["outputs_hr"]
            y_batch_sr = y["outputs_sr"]
        else:
            x_batch, y_batch_hr = generator_output

        y_pred = self.model.predict(x_batch)
        if self.superres:
            y_pred, y_pred_sr = y_pred

        label_y_hr = to_label(y_batch_hr)
        label_y_pred = to_label(y_pred)

        with self.writer.as_default():
            for sample_index in [i for i in range(0, 3) if i <= self.batch_size - 1]:
                tf.summary.image(
                    "Epoch-{}/{}/image".format(epoch, sample_index),
                    x_batch[[sample_index], :, :, :3],
                    step=epoch,
                )
                tf.summary.image(
                    "Epoch-{}/{}/label".format(epoch, sample_index),
                    self.plot_classification(
                        label_y_hr[sample_index], self.hr_classes_cmap,
                    ),
                    step=epoch,
                )
                tf.summary.image(
                    "Epoch-{}/{}/pred".format(epoch, sample_index),
                    self.plot_classification(
                        label_y_pred[sample_index], self.hr_classes_cmap,
                    ),
                    step=epoch,
                )
                tf.summary.image(
                    f"Epoch-{epoch}/confusion_matrix",
                    self.plot_confusion_matrix(
                        label_y_hr.squeeze(), label_y_pred.squeeze(),
                    ),
                    step=epoch,
                )

                if self.superres:
                    tf.summary.image(
                        "Epoch-{}/{}/label_sr".format(epoch, sample_index),
                        self.plot_classification(
                            np.argmax(y_batch_sr[sample_index, :, :, :], axis=2),
                            self.sr_classes_cmap,
                        ),
                        step=epoch,
                    )
                    tf.summary.image(
                        "Epoch-{}/{}/pred_sr".format(epoch, sample_index),
                        self.plot_classification(
                            np.argmax(y_pred_sr[sample_index, :, :, :], axis=2),
                            self.hr_classes_cmap,
                        ),
                        step=epoch,
                    )

    def on_train_end(self, logs=None):
        """Defines what is run at the end of training.
        - Stops the data_generator
        - Closes TensorBoard writer
        """
        self.enqueuer.stop()
        self.writer.close()
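A hedged sketch of wiring the callback into training (the generators, model, and output path are placeholders; the color-map files referenced by config must exist for the constructor's asserts to pass):

diagnoser = ModelDiagnoser(val_generator, batch_size=8, num_samples=3,
                           output_dir="logs/diagnosis", superres=False)
model.fit(train_generator, epochs=10, callbacks=[diagnoser])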
Code Example #28
File: data_utils_test.py Project: heechul90/Keras
def test_ordered_enqueuer_fail_threads():
    enqueuer = OrderedEnqueuer(FaultSequence(), use_multiprocessing=False)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    with pytest.raises(IndexError):
        next(gen_output)
Code Example #29
File: apply_alan.py Project: jorisgu/keras-retinanet
# model_path =  "/dds/work/workspace/keras-retinanet/snapshots/alan_training_02_1eraout/resnet50_alan_13_inference.h5"
# model_path =  "/dds/work/workspace/keras-retinanet/snapshots/alan_training_02_1eraout/resnet50_alan_02_inference.h5"
model_path = "/dds/work/workspace/keras-retinanet/snapshots/alan_training_02_1eraout/resnet50_alan_06_inference.h5"
model = models.load_model(model_path, backbone_name='resnet50')


score_threshold = 0.1
max_detections = 100


save_path = "/dds/work/workspace/alan_tmp_files/results_apply_alan"
if save_path is not None:
    mkdir_p(save_path)
data_sequence = TestSequence(image_path=image_to_test, folder_crops="/dds/work/workspace/alan_tmp_files/sequence_crops")
psaver = prediction_saver(save_path)
ordered_data_sequence = OrderedEnqueuer(data_sequence, use_multiprocessing=True)
ordered_data_sequence.start(workers=4, max_queue_size=100)
datas = ordered_data_sequence.get()
t = tqdm(datas, total=len(data_sequence))

for id, xywh, image in t:
    if len([None for x in data_sequence.results if x is None])==0:
        break
    if data_sequence.results[id] is not None:
        continue
    # if not (id==211 or id==245):
    #     continue

    # run network
    boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))[:3]
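The excerpt stops right after the forward pass; a hedged sketch of how the score_threshold and max_detections defined above might be applied to the raw output (this filtering logic is an assumption, not taken from the source file):

# Hypothetical continuation: keep only confident detections, capped at max_detections.
boxes, scores, labels = boxes[0], scores[0], labels[0]       # drop the batch dimension
keep = np.where(scores > score_threshold)[0]                 # confidence filter
keep = keep[np.argsort(-scores[keep])][:max_detections]      # highest scores first
detections = list(zip(boxes[keep], scores[keep], labels[keep]))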
Code example #30
File: data_utils_test.py  Project: rilut/keras
def test_ordered_enqueuer_fail_processes():
    enqueuer = OrderedEnqueuer(FaultSequence(), use_multiprocessing=True)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    with pytest.raises(StopIteration):
        next(gen_output)
Code example #31
File: utils.py  Project: up42/land-cover-public
    def __init__(
        self,
        data_generator: keras.utils.Sequence,
        batch_size: int,
        num_samples: int,
        output_dir: str,
        superres: bool = False,
    ):
        """
        Parameters
        ----------
        data_generator : keras.utils.Sequence
            Training or validation generator. The imagery and labels that are
            plotted will come from here.
        batch_size : int
            The size of each batch.
        num_samples : int
            The number of samples (images) to plot at each epoch.
        output_dir : str
            Output directory of logs for TensorBoard.
        superres : bool
            Whether the model uses the superres loss function.

        Returns
        -------
        ModelDiagnoser
            Initialized TensorFlow callback, ready to be used with the fit method.

        """

        def read_file_colormap(file_path: str) -> collections.OrderedDict:
            """Reads a text file with color assigned to each class. i.e.:
            class, r, g, b
            1, 255, 0, 0

            Which results in label 1 being assigned the color red.

            Parameters
            ----------
            file_path : str
                File path of the text file described above.

            Returns
            -------
            collections.OrderedDict
                OrderedDict mapping each label to its assigned RGB color.
            """
            out_list = []
            with open(file_path) as color_map:
                csv_color_map = csv.reader(color_map)
                next(csv_color_map)
                for row in csv_color_map:
                    out_list.append((int(row[0]), tuple(row[1:4])))  # (label, (r, g, b))
            return collections.OrderedDict(out_list)

        def to_matplotlib_colormap(
            ordered_dict: collections.OrderedDict,
        ) -> matplotlib.colors.ListedColormap:
            """From ordered_dict generate a ListedColormap.

            Parameters
            ----------
            ordered_dict : collections.OrderedDict
                OrderedDict generated by read_file_colormap.

            Returns
            -------
            matplotlib.colors.ListedColormap
                Colormap for labels ready to use within a matplotlib figure.

            """

            def rgb(r: int, g: int, b: int) -> str:
                def clamp(x):
                    return max(0, min(int(x), 255))

                return "#{0:02x}{1:02x}{2:02x}".format(clamp(r), clamp(g), clamp(b))

            return matplotlib.colors.ListedColormap(
                [
                    rgb(*ordered_dict[i][1]) if i in ordered_dict else "#000000"
                    for i in ordered_dict
                ]
            )

        super().__init__()
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.superres = superres
        self.enqueuer = OrderedEnqueuer(
            data_generator, use_multiprocessing=True, shuffle=False
        )
        self.enqueuer.start(workers=4, max_queue_size=4)
        self.writer = tf.summary.create_file_writer(output_dir)

        self.hr_classes = read_file_colormap(config.HR_COLOR)
        assert (
            len(self.hr_classes) == config.HR_NCLASSES - 1
        ), f"Wrong HR color map {config.HR_COLOR}"
        self.sr_classes = read_file_colormap(config.LR_COLOR)
        assert (
            len(self.sr_classes) == config.LR_NCLASSES - 1
        ), f"Wrong SR color map {config.LR_COLOR}"

        self.hr_classes_cmap = to_matplotlib_colormap(self.hr_classes)
        self.sr_classes_cmap = to_matplotlib_colormap(self.sr_classes)
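A usage sketch for this callback (the generator names, paths, and argument values below are illustrative, not taken from the source; only the parameter names come from the signature above):

# Illustrative only: attach the diagnoser to a Keras training run.
diagnoser = ModelDiagnoser(
    data_generator=val_sequence,      # any keras.utils.Sequence yielding image/label batches
    batch_size=16,
    num_samples=3,
    output_dir="logs/diagnostics",
    superres=False,
)
model.fit(train_sequence, epochs=10, callbacks=[diagnoser])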
Code example #32
def test_context_switch():
    enqueuer = OrderedEnqueuer(DummySequence([3, 200, 200, 3]),
                               use_multiprocessing=True)
    enqueuer2 = OrderedEnqueuer(DummySequence([3, 200, 200, 3], value=15),
                                use_multiprocessing=True)
    enqueuer.start(3, 10)
    enqueuer2.start(3, 10)
    gen_output = enqueuer.get()
    gen_output2 = enqueuer2.get()
    acc = []
    for i in range(100):
        acc.append(next(gen_output)[0, 0, 0, 0])
    assert acc[-1] == 99
    # One epoch is completed so enqueuer will switch the Sequence

    acc = []
    for i in range(100):
        acc.append(next(gen_output2)[0, 0, 0, 0])
    assert acc[-1] == 99 * 15
    # One epoch has been completed so enqueuer2 will switch

    # Be sure that both Sequence were updated
    assert next(gen_output)[0, 0, 0, 0] == 0
    assert next(gen_output)[0, 0, 0, 0] == 5
    assert next(gen_output2)[0, 0, 0, 0] == 0
    assert next(gen_output2)[0, 0, 0, 0] == 15 * 5

    # Tear down everything
    enqueuer.stop()
    enqueuer2.stop()
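DummySequence is also not shown here; a sketch consistent with the assertions above (the on_epoch_end multiplier would explain the `5` and `15 * 5` values) might be, assuming this is roughly how the test fixture is written:

import numpy as np
from keras.utils import Sequence


class DummySequence(Sequence):
    """Hypothetical fixture: batch i is filled with i * value, and value grows 5x per epoch."""

    def __init__(self, shape, value=1.0):
        self.shape = shape
        self.value = value

    def __getitem__(self, item):
        return np.ones(self.shape, dtype=np.uint32) * item * self.value

    def __len__(self):
        return 100

    def on_epoch_end(self):
        self.value *= 5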
Code example #33
    def fit_with_pseudo_label(self,
                              steps_per_epoch,
                              use_checkpoints=False,
                              class_labels=None,
                              verbose=1,
                              use_multiprocessing=False,
                              shuffle=False,
                              workers=1,
                              max_queue_size=10):

        wait_time = 0.01  # in seconds

        self.model._make_train_function()

        # Create a checkpoint callback
        checkpoint = ModelCheckpoint("../models_checkpoints/" +
                                     str(self.h5_filename) + ".h5",
                                     monitor='val_acc',
                                     verbose=1,
                                     save_best_only=True,
                                     save_weights_only=True,
                                     mode='auto',
                                     period=1)

        # Generate callbacks
        callback_list = []
        if use_checkpoints:
            callback_list.append(checkpoint)

        # Init train counters
        epoch = 0

        # Prepare display labels.
        out_labels = self.model._get_deduped_metrics_names()
        callback_metrics = out_labels + ['val_' + n for n in out_labels]

        # Prepare train callbacks
        self.model.history = cbks.History()
        callbacks = [cbks.BaseLogger()] + (callback_list or []) + \
            [self.model.history]
        if verbose:
            callbacks += [cbks.ProgbarLogger(count_mode='steps')]
        callbacks = cbks.CallbackList(callbacks)

        # it's possible to callback a different model than self:
        if hasattr(self.model, 'callback_model') and self.model.callback_model:
            callback_model = self.model.callback_model

        else:
            callback_model = self.model

        callbacks.set_model(callback_model)

        is_sequence = isinstance(self.train_generator, Sequence)
        if not is_sequence and use_multiprocessing and workers > 1:
            warnings.warn(
                UserWarning('Using a generator with `use_multiprocessing=True`'
                            ' and multiple workers may duplicate your data.'
                            ' Please consider using the `keras.utils.Sequence`'
                            ' class.'))

        if is_sequence:
            steps_per_epoch = len(self.train_generator)
        enqueuer = None

        callbacks.set_params({
            'epochs': self.epochs,
            'steps': steps_per_epoch,
            'verbose': verbose,
            'do_validation': True,
            'metrics': callback_metrics,
        })
        callbacks.on_train_begin()

        try:
            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    self.train_generator,
                    use_multiprocessing=use_multiprocessing,
                    shuffle=shuffle)
            else:
                enqueuer = GeneratorEnqueuer(
                    self.train_generator,
                    use_multiprocessing=use_multiprocessing,
                    wait_time=wait_time)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()

            # Train the model
            # Epochs
            while epoch < self.epochs:
                callbacks.on_epoch_begin(epoch)
                steps_done = 0
                batch_index = 0

                # Steps per epoch
                while steps_done < steps_per_epoch:

                    generator_output = next(output_generator)

                    if len(generator_output) == 2:
                        x, y = generator_output
                        sample_weight = None
                    elif len(generator_output) == 3:
                        x, y, sample_weight = generator_output
                    else:
                        raise ValueError('Output of generator should be '
                                         'a tuple `(x, y, sample_weight)` '
                                         'or `(x, y)`. Found: ' +
                                         str(generator_output))

                    #==========================
                    # Mini-batch
                    #==========================
                    print()
                    print('Generating pseudo-labels...')
                    no_label_output = self.model.predict_generator(
                        self.no_label_generator,
                        None,  # steps=None because no_label_generator is a Sequence
                        verbose=1)

                    # One-hot encoded
                    self.no_label_generator.classes = np.argmax(
                        no_label_output, axis=1)

                    # Concat Pseudo labels with true labels
                    x_pseudo, y_pseudo = next(self.no_label_generator)
                    x, y = np.concatenate(
                        (x, x_pseudo), axis=0), np.concatenate((y, y_pseudo),
                                                               axis=0)

                    # build batch logs
                    batch_logs = {}
                    if isinstance(x, list):
                        batch_size = x[0].shape[0]
                    elif isinstance(x, dict):
                        batch_size = list(x.values())[0].shape[0]
                    else:
                        batch_size = x.shape[0]
                    batch_logs['batch'] = batch_index
                    batch_logs['size'] = batch_size
                    callbacks.on_batch_begin(batch_index, batch_logs)

                    # Runs a single gradient update on a single batch of data
                    scalar_training_loss = self.model.train_on_batch(x=x, y=y)

                    if not isinstance(scalar_training_loss, list):
                        scalar_training_loss = [scalar_training_loss]
                    for l, o in zip(out_labels, scalar_training_loss):
                        batch_logs[l] = o

                    callbacks.on_batch_end(batch_index, batch_logs)

                    #==========================
                    # end Mini-batch
                    #==========================

                    batch_index += 1
                    steps_done += 1

                # Epoch finished.
                epoch += 1

        finally:
            if enqueuer is not None:
                enqueuer.stop()

        callbacks.on_train_end()
        return self.model.history
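A hedged call sketch (the trainer object and every argument value below are illustrative; only the keyword names come from the signature above):

history = trainer.fit_with_pseudo_label(
    steps_per_epoch=100,        # overridden by len(train_generator) when it is a Sequence
    use_checkpoints=True,
    verbose=1,
    shuffle=True,
    workers=2,
    max_queue_size=10,
)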
Code example #34
File: models.py  Project: j-varun/enas
    def __init__(self,
                 images,
                 labels,
                 cutout_size=None,
                 batch_size=32,
                 eval_batch_size=32,
                 clip_mode=None,
                 grad_bound=None,
                 l2_reg=1e-4,
                 lr_init=0.1,
                 lr_dec_start=0,
                 lr_dec_every=100,
                 lr_dec_rate=0.1,
                 keep_prob=1.0,
                 optim_algo=None,
                 sync_replicas=False,
                 num_aggregate=None,
                 num_replicas=None,
                 data_format="NHWC",
                 name="generic_model",
                 seed=None,
                 valid_set_size=32,
                 image_shape=(32, 32, 3),
                 translation_only=False,
                 rotation_only=False,
                 stacking_reward=False,
                 use_root=False,
                 dataset="cifar",
                 data_base_path="",
                 one_hot_encoding=False,
                 random_augmentation=None):
        """
        Args:
          lr_dec_every: number of epochs to decay
        """
        print("-" * 80)
        print("Build model {}".format(name))

        self.cutout_size = cutout_size
        self.batch_size = batch_size
        # TODO change back to eval_batch size, pass eval_batch_size from arguments
        self.eval_batch_size = batch_size
        self.clip_mode = clip_mode
        self.grad_bound = grad_bound
        self.l2_reg = l2_reg
        self.lr_init = lr_init
        self.lr_dec_start = lr_dec_start
        self.lr_dec_rate = lr_dec_rate
        self.keep_prob = keep_prob
        self.optim_algo = optim_algo
        self.sync_replicas = sync_replicas
        self.num_aggregate = num_aggregate
        self.num_replicas = num_replicas
        self.data_format = data_format
        self.name = name
        self.seed = seed
        self.dataset = dataset
        self.valid_set_size = valid_set_size
        self.image_shape = image_shape
        self.rotation_only = rotation_only
        self.translation_only = translation_only
        self.stacking_reward = stacking_reward
        self.random_augmentation = random_augmentation
        self.data_base_path = data_base_path
        self.use_root = use_root
        self.one_hot_encoding = one_hot_encoding

        self.global_step = None
        self.valid_acc = None
        self.test_acc = None
        print("Build data ops")
        with tf.device("/cpu:0"):
            # training data

            # Support for stacking generator
            print("dataset----------------------", self.dataset)
            if self.dataset == "stacking":
                Dataset = tf.data.Dataset
                flags = tf.app.flags
                FLAGS = flags.FLAGS
                np.random.seed(0)
                val_test_size = self.valid_set_size
                if images["path"] != "":
                    print("datadir------------", images["path"])
                    file_names = glob.glob(os.path.expanduser(images["path"]))
                    train_data = file_names[val_test_size * 2:]
                    validation_data = file_names[val_test_size:val_test_size *
                                                 2]
                    self.validation_data = validation_data
                    test_data = file_names[:val_test_size]
                else:
                    print(
                        "-------Loading train-test-val from txt files-------")
                    self.data_base_path = os.path.expanduser(
                        self.data_base_path)
                    with open(
                            self.data_base_path +
                            'costar_block_stacking_v0.3_success_only_train_files.txt',
                            mode='r') as myfile:
                        train_data = myfile.read().splitlines()
                    with open(
                            self.data_base_path +
                            'costar_block_stacking_v0.3_success_only_test_files.txt',
                            mode='r') as myfile:
                        test_data = myfile.read().splitlines()
                    with open(
                            self.data_base_path +
                            'costar_block_stacking_v0.3_success_only_val_files.txt',
                            mode='r') as myfile:
                        validation_data = myfile.read().splitlines()
                    print(train_data)
                    # train_data = [self.data_base_path + name for name in train_data]
                    # test_data = [self.data_base_path + name for name in test_data]
                    # validation_data = [self.data_base_path + name for name in validation_data]
                    print(validation_data)
                # number of images to look at per example
                # TODO(ahundt) currently there is a bug in one of these calculations, lowering images per example to reduce number of steps per epoch for now.
                estimated_images_per_example = 2
                print("valid set size", val_test_size)
                # TODO(ahundt) fix quick hack to proceed through epochs faster
                # self.num_train_examples = len(train_data) * self.batch_size * estimated_images_per_example
                # self.num_train_batches = (self.num_train_examples + self.batch_size - 1) // self.batch_size
                self.num_train_examples = len(
                    train_data) * estimated_images_per_example
                self.num_train_batches = (self.num_train_examples +
                                          self.batch_size -
                                          1) // self.batch_size
                # output_shape = (32, 32, 3)
                # WARNING: IF YOU ARE EDITING THIS CODE, MAKE SURE TO ALSO CHECK micro_controller.py and micro_child.py WHICH ALSO HAS A GENERATOR
                if self.translation_only is True:
                    # We've found evidence in hyperopt (not yet conclusive) that feeding in
                    # the rotation component actually lowers translation accuracy,
                    # at least in the colored block case.
                    # Switch between the two commented lines to restore the previous behavior.
                    # data_features = ['image_0_image_n_vec_xyz_aaxyz_nsc_15']
                    # self.data_features_len = 15
                    data_features = ['image_0_image_n_vec_xyz_nxygrid_12']
                    self.data_features_len = 12
                    label_features = ['grasp_goal_xyz_3']
                    self.num_classes = 3
                elif self.rotation_only is True:
                    data_features = ['image_0_image_n_vec_xyz_aaxyz_nsc_15']
                    self.data_features_len = 15
                    # disabled the 2 lines below because the best run 2018_12_2054 used the settings above
                    # include a normalized xy grid, similar to uber's coordconv
                    # data_features = ['image_0_image_n_vec_xyz_aaxyz_nsc_nxygrid_17']
                    # self.data_features_len = 17
                    label_features = ['grasp_goal_aaxyz_nsc_5']
                    self.num_classes = 5
                elif self.stacking_reward is True:
                    data_features = [
                        'image_0_image_n_vec_0_vec_n_xyz_aaxyz_nsc_nxygrid_25'
                    ]
                    self.data_features_len = 25
                    label_features = ['stacking_reward']
                    self.num_classes = 1
                # elif self.use_root is True:
                #     data_features = ['current_xyz_aaxyz_nsc_8']
                #     self.data_features_len = 8
                #     label_features = ['grasp_goal_xyz_3']
                #     self.num_classes = 8
                else:
                    # original input block
                    # data_features = ['image_0_image_n_vec_xyz_aaxyz_nsc_15']
                    # include a normalized xy grid, similar to uber's coordconv
                    data_features = [
                        'image_0_image_n_vec_xyz_aaxyz_nsc_nxygrid_17'
                    ]
                    self.data_features_len = 17
                    label_features = ['grasp_goal_xyz_aaxyz_nsc_8']
                    self.num_classes = 8
                if self.one_hot_encoding:
                    self.data_features_len += 40
                training_generator = CostarBlockStackingSequence(
                    train_data,
                    batch_size=batch_size,
                    verbose=0,
                    label_features_to_extract=label_features,
                    data_features_to_extract=data_features,
                    output_shape=self.image_shape,
                    shuffle=True,
                    random_augmentation=self.random_augmentation,
                    one_hot_encoding=self.one_hot_encoding)

                train_enqueuer = OrderedEnqueuer(training_generator,
                                                 use_multiprocessing=False,
                                                 shuffle=True)
                train_enqueuer.start(workers=10, max_queue_size=100)

                def train_generator():
                    return iter(train_enqueuer.get())

                train_dataset = Dataset.from_generator(
                    train_generator, (tf.float32, tf.float32),
                    (tf.TensorShape([
                        None, self.image_shape[0], self.image_shape[1],
                        self.data_features_len
                    ]), tf.TensorShape([None, None])))
                # if self.use_root is True:
                #     train_dataset = Dataset.from_generator(train_generator, (tf.float32, tf.float32), (tf.TensorShape(
                #         [None, 2]), tf.TensorShape([None, None])))
                trainer = train_dataset.make_one_shot_iterator()
                x_train, y_train = trainer.get_next()
                # x_train_list = []
                # x_train_list[0] = np.reshape(x_train[0][0], [-1, self.image_shape[1], self.image_shape[2], 3])
                # x_train_list[1] = np.reshape(x_train[0][1], [-1, self.image_shape[1], self.image_shape[2], 3])
                # x_train_list[2] = np.reshape(x_train[0][2],[-1, ])
                # print("x shape--------------", x_train.shape)
                print("batch--------------------------",
                      self.num_train_examples, self.num_train_batches)
                print("y shape--------------", y_train.shape)
                self.x_train = x_train
                self.y_train = y_train

            else:
                self.num_train_examples = np.shape(images["train"])[0]
                self.num_classes = 10
                self.num_train_batches = (self.num_train_examples +
                                          self.batch_size -
                                          1) // self.batch_size

                x_train, y_train = tf.train.shuffle_batch(
                    [images["train"], labels["train"]],
                    batch_size=self.batch_size,
                    capacity=50000,
                    enqueue_many=True,
                    min_after_dequeue=0,
                    num_threads=16,
                    seed=self.seed,
                    allow_smaller_final_batch=True,
                )

                def _pre_process(x):
                    print("prep shape ", x.get_shape())
                    dims = list(x.get_shape())
                    dim = max(dims)
                    x = tf.pad(x, [[4, 4], [4, 4], [0, 0]])
                    #x = tf.random_crop(x, [32, 32, 3], seed=self.seed)
                    x = tf.random_crop(x, dims, seed=self.seed)
                    x = tf.image.random_flip_left_right(x, seed=self.seed)
                    if self.cutout_size is not None:
                        mask = tf.ones([self.cutout_size, self.cutout_size],
                                       dtype=tf.int32)
                        start = tf.random_uniform([2],
                                                  minval=0,
                                                  maxval=dim,
                                                  dtype=tf.int32)
                        mask = tf.pad(
                            mask,
                            [[self.cutout_size + start[0], dim - start[0]],
                             [self.cutout_size + start[1], dim - start[1]]])
                        mask = mask[self.cutout_size:self.cutout_size + dim,
                                    self.cutout_size:self.cutout_size + dim]
                        mask = tf.reshape(mask, [dim, dim, 1])
                        mask = tf.tile(mask, [1, 1, dims[2]])
                        x = tf.where(tf.equal(mask, 0),
                                     x=x,
                                     y=tf.zeros_like(x))
                    if self.data_format == "NCHW":
                        x = tf.transpose(x, [2, 0, 1])

                    return x

                self.x_train = tf.map_fn(_pre_process,
                                         x_train,
                                         back_prop=False)
                self.y_train = y_train
            self.lr_dec_every = lr_dec_every * self.num_train_batches

            # valid data
            self.x_valid, self.y_valid = None, None
            if self.dataset == "stacking":
                # TODO
                validation_generator = CostarBlockStackingSequence(
                    validation_data,
                    batch_size=batch_size,
                    verbose=0,
                    label_features_to_extract=label_features,
                    data_features_to_extract=data_features,
                    output_shape=self.image_shape,
                    one_hot_encoding=self.one_hot_encoding,
                    is_training=False)
                validation_enqueuer = OrderedEnqueuer(
                    validation_generator,
                    use_multiprocessing=False,
                    shuffle=True)
                validation_enqueuer.start(workers=10, max_queue_size=100)

                def validation_generator():
                    return iter(validation_enqueuer.get())

                validation_dataset = Dataset.from_generator(
                    validation_generator, (tf.float32, tf.float32),
                    (tf.TensorShape([
                        None, self.image_shape[0], self.image_shape[1],
                        self.data_features_len
                    ]), tf.TensorShape([None, None])))
                self.num_valid_examples = len(
                    validation_data
                ) * self.eval_batch_size * estimated_images_per_example
                self.num_valid_batches = (self.num_valid_examples +
                                          self.eval_batch_size -
                                          1) // self.eval_batch_size
                self.x_valid, self.y_valid = validation_dataset.make_one_shot_iterator(
                ).get_next()
                print("x-v........-------------", self.x_valid.shape)
                if "valid_original" not in images.keys():
                    images["valid_original"] = np.copy(self.x_valid)
                    labels["valid_original"] = np.copy(self.y_valid)
            else:
                if images["valid"] is not None:
                    images["valid_original"] = np.copy(images["valid"])
                    labels["valid_original"] = np.copy(labels["valid"])
                    if self.data_format == "NCHW":
                        images["valid"] = tf.transpose(images["valid"],
                                                       [0, 3, 1, 2])
                    self.num_valid_examples = np.shape(images["valid"])[0]
                    self.num_valid_batches = (
                        (self.num_valid_examples + self.eval_batch_size - 1) //
                        self.eval_batch_size)
                    self.x_valid, self.y_valid = tf.train.batch(
                        [images["valid"], labels["valid"]],
                        batch_size=self.eval_batch_size,
                        capacity=5000,
                        enqueue_many=True,
                        num_threads=1,
                        allow_smaller_final_batch=True,
                    )

            # test data
            if self.dataset == "stacking":
                # TODO
                test_generator = CostarBlockStackingSequence(
                    test_data,
                    batch_size=batch_size,
                    verbose=0,
                    label_features_to_extract=label_features,
                    data_features_to_extract=data_features,
                    output_shape=self.image_shape,
                    one_hot_encoding=self.one_hot_encoding,
                    is_training=False)
                test_enqueuer = OrderedEnqueuer(test_generator,
                                                use_multiprocessing=False,
                                                shuffle=True)
                test_enqueuer.start(workers=10, max_queue_size=100)

                def test_generator():
                    return iter(test_enqueuer.get())

                test_dataset = Dataset.from_generator(
                    test_generator, (tf.float32, tf.float32), (tf.TensorShape([
                        None, self.image_shape[0], self.image_shape[1],
                        self.data_features_len
                    ]), tf.TensorShape([None, None])))
                self.num_test_examples = len(
                    test_data
                ) * self.eval_batch_size * estimated_images_per_example
                self.num_test_batches = (self.num_test_examples +
                                         self.eval_batch_size -
                                         1) // self.eval_batch_size
                self.x_test, self.y_test = test_dataset.make_one_shot_iterator(
                ).get_next()
            else:
                if self.data_format == "NCHW":
                    images["test"] = tf.transpose(images["test"], [0, 3, 1, 2])
                self.num_test_examples = np.shape(images["test"])[0]
                self.num_test_batches = (
                    (self.num_test_examples + self.eval_batch_size - 1) //
                    self.eval_batch_size)
                self.x_test, self.y_test = tf.train.batch(
                    [images["test"], labels["test"]],
                    batch_size=self.eval_batch_size,
                    capacity=10000,
                    enqueue_many=True,
                    num_threads=1,
                    allow_smaller_final_batch=True,
                )

        # cache images and labels
        self.images = images
        self.labels = labels
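The essential pattern in this constructor, stripped of the dataset-specific branches, is to push a keras.utils.Sequence through an OrderedEnqueuer and expose it to the TF1 graph via Dataset.from_generator. A minimal sketch of that pattern, with illustrative shapes (my_sequence and the tensor shapes are assumptions):

# Minimal sketch of the Sequence -> OrderedEnqueuer -> tf.data bridge used above.
enqueuer = OrderedEnqueuer(my_sequence, use_multiprocessing=False, shuffle=True)
enqueuer.start(workers=4, max_queue_size=10)

def batch_generator():
    return iter(enqueuer.get())

dataset = tf.data.Dataset.from_generator(
    batch_generator, (tf.float32, tf.float32),
    (tf.TensorShape([None, 224, 224, 3]), tf.TensorShape([None, None])))
x_batch, y_batch = dataset.make_one_shot_iterator().get_next()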
Code example #35
File: dataset_iterator.py  Project: veegee82/tf_base
    def __init__(self, sequence, max_queue_size=10, workers=1):
        self.enqueuer = OrderedEnqueuer(sequence,
                                        use_multiprocessing=workers >= 2)
        self.enqueuer.start(workers, max_queue_size)
        self.generator = self.enqueuer.get()
        self._internal_index = 0
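Only the constructor is shown; a plausible continuation (assumed, not from the source) that lets the wrapper be consumed as a plain Python iterator and shut down cleanly could be:

    def __iter__(self):
        return self

    def __next__(self):
        batch = next(self.generator)   # blocks until the enqueuer has a batch ready
        self._internal_index += 1
        return batch

    def stop(self):
        self.enqueuer.stop()           # release worker threads/processes when done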