Example #1
import argparse
from os.path import join

from tensorflow.data import TFRecordDataset
from tensorflow.keras import callbacks

# Project-specific modules; the import paths are assumptions
# (Example #2 suggests DatasetBuilder lives in a `loader` module).
from loader import DatasetBuilder
from models import MODELS


def main():
    parser = argparse.ArgumentParser(description='Running Settings')

    parser.add_argument('--model', help='valid options: Vgg, ResNet, and '
                                        'Squeeze Excitation models',
                        required=True)

    parser.add_argument('--batch', help='batch size', type=int,
                        default=32)

    parser.add_argument('--data', help='path where the data is stored',
                        default='data')

    args = parser.parse_args()

    if MODELS.get(args.model) is None:
        raise ValueError('Model {} does not exist'.format(args.model))

    # Write the train/test TFRecord files from the raw data directory.
    builder = DatasetBuilder(args.data, shape=(256, 256))
    builder()

    # Training pipeline: decode, augment, shuffle, and batch.
    data_train = TFRecordDataset(join(args.data, 'train.records'))
    data_train = data_train.map(builder.decode)
    data_train = data_train.map(builder.augmentation)
    data_train = data_train.shuffle(7000)
    data_train = data_train.batch(batch_size=args.batch)

    # Evaluation pipeline: no augmentation or shuffling.
    data_test = TFRecordDataset(join(args.data, 'test.records'))
    data_test = data_test.map(builder.decode)
    data_test = data_test.batch(batch_size=args.batch)

    # Instantiate the selected model; build once so its weights are created.
    model = MODELS.get(args.model)()
    model.build((1, 256, 256, 3))
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])

    # Callbacks: TensorBoard logging, best-model checkpointing,
    # LR reduction on plateau, and early stopping.
    log_dir = join('logs', args.model)
    tensor_board_callback = callbacks.TensorBoard(log_dir=log_dir)
    model_checkpoint = callbacks.ModelCheckpoint('models/{}.h5'.format(args.model),
                                                 save_best_only=True)
    reduce_lr = callbacks.ReduceLROnPlateau(factor=0.2, patience=5,
                                            min_lr=1e-6)
    early_stop = callbacks.EarlyStopping(patience=10)

    _callbacks = [model_checkpoint, reduce_lr, early_stop,
                  tensor_board_callback]

    model.fit(data_train, epochs=100, validation_data=data_test,
              callbacks=_callbacks)
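
The snippet defines main() but no entry point. A minimal sketch for running it as a script, assuming the file is saved as main.py and that 'ResNet' is a valid key in the MODELS registry:

if __name__ == '__main__':
    main()

# Hypothetical invocation:
#   python main.py --model ResNet --batch 32 --data data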
Example #2
import unittest
from tensorflow.data import TFRecordDataset
import loader

class LoaderTest(unittest.TestCase):
    def test_loader_tf_data_set_should_be_ok(self):
        builder = loader.DatasetBuilder('data', shape=(256, 256))
        dataset = TFRecordDataset('data/test.records')
        dataset = dataset.map(builder.decode)
        dataset = dataset.map(builder.augmentation)  # pass the function, don't call it
        dataset = dataset.shuffle(4000)
        dataset = dataset.batch(batch_size=60)
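
As written, the test only builds the pipeline and asserts nothing. A minimal check one might append inside the test body, assuming eager execution and that builder.decode yields (image, label) pairs of 256x256 RGB images:

        # Pull one batch and verify the per-image shape.
        for images, labels in dataset.take(1):
            self.assertEqual(images.shape[1:], (256, 256, 3))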
Example #3
    def __init__(self, data_files, sampler_config):
        """ Create a new ImageDataGenerator.
        
        Receives a configure dictionary, which specify how to load the data
        """
        self.config = sampler_config
        self.__check_image_patch_shape()
        batch_size = self.config['batch_size']
        self.label_convert_source = self.config.get('label_convert_source',
                                                    None)
        self.label_convert_target = self.config.get('label_convert_target',
                                                    None)

        # Read ZLIB-compressed TFRecords and parse records in parallel.
        data = TFRecordDataset(data_files, "ZLIB")
        data = data.map(self._parse_function, num_parallel_calls=5)
        if self.config.get('data_shuffle', False):
            data = data.shuffle(buffer_size=20 * batch_size)
        data = data.batch(batch_size)
        self.data = data
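
A hypothetical configuration and instantiation sketch, assuming the enclosing class is called ImageDataGenerator as the docstring suggests. The keys shown are the ones this constructor reads; the values, the file name, and any extra keys consumed by __check_image_patch_shape (not shown above) are assumptions:

sampler_config = {
    'batch_size': 5,
    'data_shuffle': True,
    'label_convert_source': [0, 1, 2, 4],  # map these label values...
    'label_convert_target': [0, 1, 2, 3],  # ...to these
}
generator = ImageDataGenerator(['train_0.tfrecord'], sampler_config)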