Example #1
    def __next__(self):
        """Return the next (entry_name, components) tuple, opening each raster component."""
        length = len(self)
        if length == 0:
            raise StopIteration()
        if self._curent_position == length:
            if self._loop:
                self.reset()
            else:
                raise StopIteration()

        entry = self._datasets[self._curent_position]
        env = getattr(self, 'rasterio_env', {})
        self._curent_position += 1
        entry_name, entry_components = entry
        new_components = {}
        cache_data = self._cache_data
        use_tensorflow_io = False
        for component_name, component_path in entry_components.items():
            if isinstance(component_path, DatasetReader):
                component_path = component_path.name
            local_component_path = component_path
            url_components = urlparse(component_path)
            if not url_components.scheme:
                cache_data = False
                if url_components.path.startswith('/vsigs/'):
                    cache_data = True  # We should check if we run inside GCP ML Engine
                    use_tensorflow_io = True
                    # Translate the GDAL /vsigs/<bucket>/<object> path into a gs://<bucket>/<object> URL
                    component_path = url_components.path[6:]
                    component_path = "gs:/" + component_path
            else:
                if url_components.scheme == 'file':
                    local_component_path = url_components.path
                    use_tensorflow_io = False
                    cache_data = False

            with rasterio.Env(**env):
                if use_tensorflow_io:
                    real_path = component_path
                    with IOUtils.open_file(real_path, "rb") as src_file:
                        data = src_file.read()
                    if cache_data:
                        digest = sha224(component_path.encode("utf8")).hexdigest()
                        hash_part = "/".join(list(digest)[:3])
                        dataset_path = os.path.join(self._temp_dir, hash_part)
                        if not IOUtils.file_exists(dataset_path):
                            IOUtils.recursive_create_dir(dataset_path)
                        dataset_path = os.path.join(dataset_path, os.path.basename(component_path))
                        if not IOUtils.file_exists(dataset_path):
                            with IOUtils.open_file(dataset_path, "wb") as cache_file:
                                cache_file.write(data)
                        component_src = rasterio.open(dataset_path)
                    else:
                        with NamedTemporaryFile() as tmpfile:
                            tmpfile.write(data)
                            tmpfile.flush()
                            component_src = rasterio.open(tmpfile.name)
                else:
                    component_src = rasterio.open(local_component_path)
                new_components[component_name] = component_src
        return (entry_name, new_components)
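
For context, a minimal consumption sketch (assumptions: the surrounding class also defines __iter__ returning self, and `loader` is an already configured instance; neither is shown in the example):

# Hypothetical usage: iterate over a configured loader instance and inspect each
# opened raster component. With self._loop enabled the iteration never stops,
# so this sketch assumes looping is disabled.
for entry_name, components in loader:
    for component_name, dataset in components.items():
        # Each value is an opened rasterio DatasetReader
        print(entry_name, component_name, dataset.count, dataset.crs)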
Example #2
def cleanup_dir(temp_dir):
    IOUtils.delete_recursively(temp_dir)
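
A small wiring sketch (hypothetical, not part of the example): the temporary cache directory that the iterator in Example #1 writes into can be registered for removal at interpreter exit:

import atexit
import tempfile

# Hypothetical: create the cache directory and ensure cleanup_dir removes it on exit.
temp_dir = tempfile.mkdtemp(prefix="dataset_cache_")
atexit.register(cleanup_dir, temp_dir)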
Example #3
def train_handler(config, args):
    if args.switch_to_prefix:
        current_dir = os.path.abspath(os.path.dirname(__file__))
        current_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", "..", ".."))
        log.info("Switching to %s", current_dir)
        os.chdir(current_dir)
        log.info("Current dir: %s", os.path.abspath(os.getcwd()))
    with IOUtils.open_file(args.config, "r") as cfg_file:
        config = yaml.load(cfg_file, Loader=Loader)

    model_name = config["model_name"]
    model_type = config["model_type"]
    random_seed = config.get("random_seed", None)
    model_config = config["model"]
    tilling_config = config.get("tilling", {})
    if 'window_size' in tilling_config:
        window_size = tilling_config["window_size"]
    else:
        log.warning("Using deprectated `window_size` location")
        window_size = config["window_size"]
    if 'stride_size' in tilling_config:
        stride_size = tilling_config["stride_size"]
    else:
        log.warning("Using deprectated `stride_size` location")
        stride_size = config["stride_size"]

    if random_seed is not None:
        log.info("Setting Python and NumPy seed to: %d", random_seed)
        random.seed(random_seed)
        np.random.seed(random_seed)
    else:
        log.warning("No random seed specified!")

    limit_validation_datasets = config.get("limit_validation_datasets", None)
    limit_train_datasets = config.get("limit_train_datasets", None)

    data_source = config.get("data_source")
    mapping = config["mapping"]
    augment = config.get("augment", False)
    input_channels = len(mapping["inputs"])
    log.info("Input has %d channels", input_channels)
    log.info("Model type is: %s", model_type)

    if args.split is None:
        dataset_cache = config.get("dataset_cache", None)
        log.debug("dataset_cache is set from config to %s", dataset_cache)
        dataset_cache = dataset_cache.format(model_name=model_name,
                                             time=str(time.time()),
                                             hostname=socket.gethostname(),
                                             user=getpass.getuser())
    else:
        if not IOUtils.file_exists(args.split):
            raise FileNotFoundError("Invalid split file")
        dataset_cache = args.split

    log.info("dataset_cache will be directed to: %s", dataset_cache)

    if data_source.input_source is None:
        data_source.set_input_source(args.input)
    log.info("Using datasource: %s", data_source)

    if not IOUtils.file_exists(dataset_cache):
        log.info("Loading datasets")
        train_datasets, validation_datasets = data_source.get_dataset_loader()
        dump = (train_datasets._datasets, validation_datasets._datasets)

        log.info("Saving dataset cache to %s", dataset_cache)

        with IOUtils.open_file(dataset_cache, "w") as f:
            f.write(yaml.dump(dump, Dumper=Dumper))
    else:
        log.info("Loading training datasets from %s", dataset_cache)
        train_datasets, validation_datasets = yaml.load(IOUtils.open_file(dataset_cache), Loader=Loader)
        if isinstance(train_datasets, DatasetLoader):
            log.warning("Converting from legacy format: `train_datasets`")
            train_datasets = train_datasets._datasets
        if isinstance(validation_datasets, DatasetLoader):
            log.warning("Converting from legacy format: `validation_datasets`")
            validation_datasets = validation_datasets._datasets
        train_datasets, validation_datasets = data_source.build_dataset_loaders(train_datasets, validation_datasets)

    train_datasets.loop = True
    validation_datasets.loop = True

    if limit_validation_datasets:
        validation_datasets = validation_datasets[:limit_validation_datasets]

    if limit_train_datasets:
        train_datasets = train_datasets[:limit_train_datasets]

    pre_callbacks = []
    if augment:
        log.info("Enabling global level augmentation. Verify if this is desired!")

        def augment_callback(X, y):
            from ..preprocessing.augmentation import Augmentation
            aug = Augmentation(config)
            return aug.augment(X, y)

        pre_callbacks.append(augment_callback)

    log.info("Using %d training datasets", len(train_datasets))
    log.info("Using %d validation datasets", len(validation_datasets))

    if model_type == "keras":
        train_keras(model_name, window_size, stride_size, model_config, mapping, train_datasets, validation_datasets,
                    pre_callbacks=pre_callbacks,
                    enable_multi_gpu=args.keras_multi_gpu,
                    gpus=args.keras_gpus,
                    cpu_merge=args.keras_disable_cpu_merge,
                    cpu_relocation=args.keras_enable_cpu_relocation,
                    batch_size=args.keras_batch_size,
                    random_seed=random_seed,
                    )
        log.info("Keras Training completed")
    elif model_type == "sklearn":
        train_sklearn(model_name, window_size, stride_size, model_config, mapping, train_datasets, validation_datasets)
        log.info("Scikit Training completed")
    else:
        log.critical("Unknown model type: %s", model_type)
Example #4
    def __next__(self):
        """Return the next (entry_name, components) tuple, opening or generating each component."""
        length = len(self)
        if length == 0:
            raise StopIteration()
        if self._curent_position == length:
            if self._loop:
                if self.randomise_on_loop:
                    random.shuffle(self._datasets)
                self.reset()
            else:
                raise StopIteration()

        entry = self._datasets[self._curent_position]
        env = getattr(self, 'rasterio_env', {})
        self._curent_position += 1
        entry_name, entry_components = entry
        new_components = {}
        cache_data = self._cache_data
        use_tensorflow_io = False
        for component_name, component_path_entry in entry_components.items():
            if isinstance(component_path_entry, (RasterGenerator, GeoDataFrame, MemoryFile)):
                # Already usable as-is; RasterGenerator entries are resolved after the loop.
                new_components[component_name] = component_path_entry
                continue
            elif isinstance(component_path_entry, DatasetReader):
                component_path = component_path_entry.name
            elif isinstance(component_path_entry, str):
                component_path = component_path_entry
            else:
                raise NotImplementedError("Unsupported type for component value")
            local_component_path = component_path
            url_components = urlparse(component_path)
            if not url_components.scheme:
                cache_data = False
                if url_components.path.startswith('/vsigs/'):
                    cache_data = True  # We should check if we run inside GCP ML Engine
                    use_tensorflow_io = True
                    component_path = url_components.path[6:]
                    component_path = "gs:/" + component_path
            else:
                if url_components.scheme == 'file':
                    local_component_path = url_components.path
                    use_tensorflow_io = False
                    cache_data = False

            with rasterio.Env(**env):
                if use_tensorflow_io:
                    real_path = component_path
                    with IOUtils.open_file(real_path, "rb") as src_file:
                        data = src_file.read()
                    if cache_data:
                        digest = sha224(component_path.encode("utf8")).hexdigest()
                        hash_part = "/".join(list(digest)[:3])
                        dataset_path = os.path.join(self._temp_dir, hash_part)
                        if not IOUtils.file_exists(dataset_path):
                            IOUtils.recursive_create_dir(dataset_path)
                        dataset_path = os.path.join(dataset_path, os.path.basename(component_path))
                        if not IOUtils.file_exists(dataset_path):
                            with IOUtils.open_file(dataset_path, "wb") as cache_file:
                                cache_file.write(data)
                        component_src = self.get_component_file_descriptor(dataset_path)
                    else:
                        with NamedTemporaryFile() as tmpfile:
                            tmpfile.write(data)
                            tmpfile.flush()
                            component_src = self.get_component_file_descriptor(tmpfile.name)
                else:
                    component_src = self.get_component_file_descriptor(local_component_path)
                new_components[component_name] = component_src

        # Trigger the generation of the dynamic components
        for component_name, component_value in new_components.items():
            if isinstance(component_value, RasterGenerator):
                new_components[component_name] = component_value(new_components)

        return entry_name, new_components
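
The final loop above calls each RasterGenerator entry with the dict of already-resolved components and stores the return value in its place. A speculative illustration of that contract follows; a real implementation would subclass RasterGenerator so the isinstance check matches, and the component names "red"/"nir" are placeholders:

class NDVIGenerator:  # stand-in only; not the library's RasterGenerator API
    def __call__(self, components):
        # `components` holds the already-opened rasterio datasets of this entry;
        # the return value replaces the generator in new_components.
        red = components["red"].read(1).astype("float32")
        nir = components["nir"].read(1).astype("float32")
        return (nir - red) / (nir + red + 1e-6)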
Example #5
def train_keras(model_name,
                window_size,
                stride_size,
                model_config,
                mapping,
                train_datasets,
                validation_datasets,
                pre_callbacks=(),
                enable_multi_gpu=False,
                gpus=None,
                cpu_merge=True,
                cpu_relocation=False,
                batch_size=None,
                random_seed=None,
                ):
    log.info("Starting keras training")

    import tensorflow as tf

    # Seed initialization should happen as early as possible
    if random_seed is not None:
        log.info("Setting Tensorflow random seed to: %d", random_seed)
        tf.set_random_seed(random_seed)

    from keras.callbacks import EarlyStopping, TensorBoard, ReduceLROnPlateau
    from ..tools.callbacks import ModelCheckpoint, CSVLogger
    from keras.optimizers import Adam
    from ..tools.utils import import_model_builder
    from keras.models import load_model, Model
    from keras.layers import Conv2D
    from keras.utils import multi_gpu_model

    if batch_size is None:
        batch_size = model_config.get("batch_size", None)
    model_path = model_config["model_path"]
    model_loss = model_config.get("loss", "categorical_crossentropy")
    log.info("Using loss: %s", model_loss)
    model_metrics = model_config.get("metrics", "accuracy")
    # Make code compatible with previous version
    format_converter = model_config.get("format_converter", CategoricalConverter(2))
    swap_axes = model_config["swap_axes"]
    train_epochs = model_config["train_epochs"]
    prefetch_queue_size = model_config.get("prefetch_queue_size", 10)
    input_channels = len(mapping["inputs"])
    include_last_classfication = model_config.get("include_classfication_layer", True)

    z_scaler = model_config.get('z_scaler', None)

    train_data = DataGenerator(train_datasets,
                               batch_size,
                               mapping["inputs"],
                               mapping["target"],
                               format_converter=format_converter,
                               swap_axes=swap_axes,
                               postprocessing_callbacks=pre_callbacks,
                               default_window_size=window_size,
                               default_stride_size=stride_size,
                               z_scaler=z_scaler)

    train_data = ThreadedDataGenerator(train_data, queue_size=prefetch_queue_size)

    validation_data = DataGenerator(validation_datasets,
                                    batch_size,
                                    mapping["inputs"],
                                    mapping["target"],
                                    format_converter=format_converter,
                                    swap_axes=swap_axes,
                                    default_window_size=window_size,
                                    default_stride_size=stride_size,
                                    z_scaler=z_scaler)

    validation_data = ThreadedDataGenerator(validation_data, queue_size=prefetch_queue_size)

    model_builder, model_builder_custom_options = import_model_builder(model_config["model_builder"])
    model_builder_option = model_config.get("options", {})

    steps_per_epoch = model_config.get("steps_per_epoch", len(train_data) // batch_size)
    validation_steps_per_epoch = model_config.get("validation_steps_per_epoch", len(validation_data) // batch_size)

    log.info("Traing data has %d tiles", len(train_data))
    log.info("Validation data has %d tiles", len(validation_data))
    log.info("validation_steps_per_epoch: %d", validation_steps_per_epoch)
    log.info("steps_per_epoch: %d", steps_per_epoch)

    load_only_weights = model_config.get("load_only_weights", False)
    checkpoint = model_config.get("checkpoint", None)
    callbacks = []
    early_stopping = model_config.get("early_stopping", None)
    adaptive_lr = model_config.get("adaptive_lr", None)
    tensor_board = model_config.get("tensor_board", False)
    tb_log_dir = model_config.get("tb_log_dir", os.path.join("/tmp/", model_name))  # TensorBoard log directory
    tb_log_dir = tb_log_dir.format(model_name=model_name,
                                   time=str(time.time()),
                                   hostname=socket.gethostname(),
                                   user=getpass.getuser())
    keras_logging = model_config.get("log", None)
    if not keras_logging:
        log.info("Keras logging is disabled")
    else:
        csv_log_file = keras_logging.format(model_name=model_name,
                                            time=str(time.time()),
                                            hostname=socket.gethostname(),
                                            user=getpass.getuser())
        dir_head, dir_tail = os.path.split(csv_log_file)
        if dir_tail and not IOUtils.file_exists(dir_head):
            log.info("Creating directory: %s", dir_head)
            IOUtils.recursive_create_dir(dir_head)
        log.info("Logging training data to csv file: %s", csv_log_file)
        csv_logger = CSVLogger(csv_log_file, separator=',', append=False)
        callbacks.append(csv_logger)

    if tensor_board:
        log.info("Registering TensorBoard callback")
        log.info("Event log dir set to: {}".format(tb_log_dir))
        tb_callback = TensorBoard(log_dir=tb_log_dir, histogram_freq=0, write_graph=True, write_images=True)
        callbacks.append(tb_callback)
        log.info("To access TensorBoard run: tensorboard --logdir {} --port <port_number> --host <host_ip> ".format(
            tb_log_dir))

    if checkpoint:
        checkpoint_file = checkpoint["path"]
        log.info("Registering checkpoint callback")
        destination_file = checkpoint_file % {
            'model_name': model_name,
            'time': str(time.time()),
            'hostname': socket.gethostname(),
            'user': getpass.getuser()}
        dir_head, dir_tail = os.path.split(destination_file)
        if dir_tail and not IOUtils.file_exists(dir_head):
            log.info("Creating directory: %s", dir_head)
            IOUtils.recursive_create_dir(dir_head)
        log.info("Checkpoint data directed to: %s", destination_file)
        checkpoint_options = checkpoint.get("options", {})
        checkpoint_callback = ModelCheckpoint(destination_file, **checkpoint_options)
        callbacks.append(checkpoint_callback)

    log.info("Starting training")

    options = {
        'epochs': train_epochs,
        'callbacks': callbacks
    }

    if len(validation_data) > 0 and validation_steps_per_epoch:
        log.info("We have validation data")
        options['validation_data'] = validation_data
        options["validation_steps"] = validation_steps_per_epoch
        if early_stopping:
            log.info("Enabling early stopping %s", str(early_stopping))
            callback_early_stopping = EarlyStopping(**early_stopping)
            options["callbacks"].append(callback_early_stopping)
        if adaptive_lr:
            log.info("Enabling reduce lr on plateu: %s", str(adaptive_lr))
            callback_lr_loss = ReduceLROnPlateau(**adaptive_lr)
            options["callbacks"].append(callback_lr_loss)
    else:
        log.warn("No validation data available. Ignoring")

    final_model_location = model_path.format(model_name=model_name,
                                             time=str(time.time()),
                                             hostname=socket.gethostname(),
                                             user=getpass.getuser())
    log.info("Model path is %s", final_model_location)

    existing_model_location = None
    if IOUtils.file_exists(final_model_location):
        existing_model_location = final_model_location

    if existing_model_location is not None and not load_only_weights:
        log.info("Loading existing model from: %s", existing_model_location)
        custom_objects = {}
        if model_builder_custom_options is not None:
            custom_objects.update(model_builder_custom_options)
        if enable_multi_gpu:
            with tf.device('/cpu:0'):
                model = load_model(existing_model_location, custom_objects=custom_objects)
        else:
            model = load_model(existing_model_location, custom_objects=custom_objects)
            nr_classes = model_builder_option.get('nr_classes', None)

            if (not include_last_classfication) and nr_classes:
                # Replace the final classification layer with a new softmax head sized to nr_classes
                model.layers.pop()
                new_output = Conv2D(nr_classes, (1, 1), activation='softmax', name="conv_final")(model.layers[-1].output)
                model = Model(inputs=model.layers[0].input, outputs=new_output)
        log.info("Model loaded!")
    else:
        log.info("Building model")
        model_options = model_builder_option
        model_options['n_channels'] = input_channels
        input_height, input_width = window_size
        model_options['input_width'] = model_builder_option.get('input_width', input_width)
        model_options['input_height'] = model_builder_option.get('input_height', input_height)
        activation = model_config.get('activation', None)
        if activation:
            model_options["activation"] = activation
        if enable_multi_gpu:
            with tf.device('/cpu:0'):
                model = model_builder(**model_options)
        else:
            model = model_builder(**model_options)
        log.info("Model built")
        if load_only_weights and existing_model_location is not None:
            log.info("Loading weights from %s", existing_model_location)
            model.load_weights(existing_model_location)
            log.info("Finished loading weights")
    optimiser = model_config.get("optimiser", None)
    if optimiser is None:
        log.info("No optimiser specified. Using default Adam")
        optimiser = Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-8)

    if enable_multi_gpu:
        log.info("Using Keras Multi-GPU Training")
        fit_model = multi_gpu_model(model, gpus=gpus, cpu_merge=cpu_merge, cpu_relocation=cpu_relocation)
    else:
        log.info("Using Keras default GPU Training")
        fit_model = model

    log.info("Compiling model")
    fit_model.compile(loss=model_loss, optimizer=optimiser, metrics=model_metrics)
    log.info("Model compiled")
    model.summary()

    fit_model.fit_generator(train_data, steps_per_epoch, **options)

    log.info("Saving model to %s", os.path.abspath(final_model_location))
    dir_head, dir_tail = os.path.split(final_model_location)
    if dir_tail and not IOUtils.file_exists(dir_head):
        log.info("Creating directory: %s", dir_head)
        IOUtils.recursive_create_dir(dir_head)

    model.save(final_model_location)

    log.info("Done saving")
    log.info("Training completed")