Example 1
    def train(self, tmp_dir):
        from rastervision.backend.keras_classification.commands.train \
            import _train

        dataset_files = DatasetFiles(self.config.training_data_uri, tmp_dir)
        dataset_files.download()

        model_files = ModelFiles(
            self.config.training_output_uri,
            tmp_dir,
            replace_model=self.config.train_options.replace_model)
        model_paths = model_files.download_backend_config(
            self.config.pretrained_model_uri, self.config.kc_config,
            dataset_files, self.class_map)
        backend_config_path, pretrained_model_path = model_paths

        # Get output from potential previous run so we can resume training.
        if not self.config.train_options.replace_model:
            sync_from_dir(self.config.training_output_uri,
                          model_files.base_dir)

        sync = start_sync(
            model_files.base_dir,
            self.config.training_output_uri,
            sync_interval=self.config.train_options.sync_interval)
        with sync:
            do_monitoring = self.config.train_options.do_monitoring
            _train(backend_config_path, pretrained_model_path, do_monitoring)

        # Perform final sync
        sync_to_dir(model_files.base_dir,
                    self.config.training_output_uri,
                    delete=True)
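
This example, and the ones that follow, share the same skeleton: pull down output from any previous run, train inside a start_sync context that periodically mirrors the local output directory to the remote training URI, then perform a final sync_to_dir. Below is a minimal sketch of that skeleton, assuming the sync helpers live in rastervision.utils.files as in the Raster Vision 0.x releases these snippets come from; train_with_sync and run_training are hypothetical names.

    # Minimal sketch of the train/sync skeleton, assuming the helpers live in
    # rastervision.utils.files (Raster Vision 0.x). run_training stands in for
    # the backend-specific training call.
    from rastervision.utils.files import (make_dir, start_sync, sync_from_dir,
                                          sync_to_dir)

    def train_with_sync(run_training, output_dir, output_uri,
                        resume=True, sync_interval=600):
        # Local output directory that mirrors the remote training URI.
        make_dir(output_dir)
        if resume:
            # Resume from any previous run by pulling down its output.
            sync_from_dir(output_uri, output_dir)

        # Periodically upload output_dir to output_uri while training runs.
        sync = start_sync(output_dir, output_uri, sync_interval=sync_interval)
        with sync:
            run_training(output_dir)

        # Final upload once training has finished.
        sync_to_dir(output_dir, output_uri)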
Example 2
    def train(self, tmp_dir):
        training_package = TrainingPackage(self.config.training_data_uri,
                                           self.config, tmp_dir,
                                           self.partition_id)
        # Download training data and update config file.
        training_package.download_data()
        config_path = training_package.download_config(self.class_map)

        # Setup output dirs.
        output_dir = get_local_path(self.config.training_output_uri, tmp_dir)
        make_dir(output_dir)

        # Get output from potential previous run so we can resume training.
        if not self.config.train_options.replace_model:
            sync_from_dir(self.config.training_output_uri, output_dir)
        else:
            # Replacing the model: clear out previous training output, but
            # keep any command-config files.
            for f in os.listdir(output_dir):
                if not f.startswith('command-config'):
                    path = os.path.join(output_dir, f)
                    if os.path.isfile(path):
                        os.remove(path)
                    else:
                        shutil.rmtree(path)

        local_config_path = os.path.join(output_dir, 'pipeline.config')
        shutil.copy(config_path, local_config_path)

        model_main_py = self.config.script_locations.model_main_uri
        export_py = self.config.script_locations.export_uri

        # Train model and sync output periodically.
        sync = start_sync(
            output_dir,
            self.config.training_output_uri,
            sync_interval=self.config.train_options.sync_interval)
        with sync:
            train(local_config_path,
                  output_dir,
                  self.config.get_num_steps(),
                  model_main_py=model_main_py,
                  do_monitoring=self.config.train_options.do_monitoring)

        export_inference_graph(
            output_dir,
            local_config_path,
            output_dir,
            fine_tune_checkpoint_name=self.config.fine_tune_checkpoint_name,
            export_py=export_py)

        # Perform final sync
        sync_to_dir(output_dir, self.config.training_output_uri)
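
The replace_model branch above clears everything left over from a previous run except the command-config files. Pulled out on its own, that cleanup is plain standard-library file handling; clear_output_dir is a hypothetical helper name.

    import os
    import shutil

    def clear_output_dir(output_dir, keep_prefix='command-config'):
        # Remove previous training output, keeping command-config files.
        for f in os.listdir(output_dir):
            if f.startswith(keep_prefix):
                continue
            path = os.path.join(output_dir, f)
            if os.path.isfile(path):
                os.remove(path)
            else:
                shutil.rmtree(path)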
Example 3
    def train(self, tmp_dir: str) -> None:
        """Train a DeepLab model the task and backend config.

        Args:
            tmp_dir: (str) temporary directory to use

        Returns:
             None
        """
        train_py = self.backend_config.script_locations.train_py
        eval_py = self.backend_config.script_locations.eval_py
        export_py = self.backend_config.script_locations.export_py

        # Setup local input and output directories
        log.info('Setting up local input and output directories')
        train_logdir = self.backend_config.training_output_uri
        train_logdir_local = get_local_path(train_logdir, tmp_dir)
        dataset_dir = get_record_dir(self.backend_config.training_data_uri,
                                     TRAIN)
        dataset_dir_local = get_local_path(dataset_dir, tmp_dir)
        make_dir(tmp_dir)
        make_dir(train_logdir_local)
        make_dir(dataset_dir_local)

        # Download training data
        log.info('Downloading training data')
        for record_file in list_paths(dataset_dir):
            download_if_needed(record_file, tmp_dir)

        # Download and untar initial checkpoint.
        log.info('Downloading and untarring initial checkpoint')
        tf_initial_checkpoints_uri = self.backend_config.pretrained_model_uri
        download_if_needed(tf_initial_checkpoints_uri, tmp_dir)
        tfic_tarball = get_local_path(tf_initial_checkpoints_uri, tmp_dir)
        tfic_dir = os.path.dirname(tfic_tarball)
        with tarfile.open(tfic_tarball, 'r:gz') as tar:
            tar.extractall(tfic_dir)
        tfic_ckpt = glob.glob('{}/*/*.index'.format(tfic_dir))[0]
        tfic_ckpt = tfic_ckpt[0:-len('.index')]

        # Restart support: fall back to the training output URI if no
        # restart directory was configured.
        train_restart_dir = self.backend_config.train_options.train_restart_dir
        if not isinstance(train_restart_dir, str) or not train_restart_dir:
            train_restart_dir = train_logdir

        # Get output from a potential previous run so we can resume training,
        # unless we are replacing the model, in which case start clean.
        replace_model = self.backend_config.train_options.replace_model
        if train_restart_dir and not replace_model:
            sync_from_dir(train_restart_dir, train_logdir_local)
        elif replace_model:
            if os.path.exists(train_logdir_local):
                shutil.rmtree(train_logdir_local)
            make_dir(train_logdir_local)

        # Periodically synchronize with remote
        sync = start_sync(
            train_logdir_local,
            train_logdir,
            sync_interval=self.backend_config.train_options.sync_interval)

        with sync:
            # Setup TFDL config
            tfdl_config = json_format.ParseDict(
                self.backend_config.tfdl_config, TrainingParametersMsg())
            log.info('tfdl_config={}'.format(tfdl_config))
            log.info('Training steps={}'.format(
                tfdl_config.training_number_of_steps))

            # Additional training options
            max_class = max(
                list(map(lambda c: c.id, self.class_map.get_items())))
            num_classes = len(self.class_map.get_items())
            num_classes = max(max_class, num_classes) + 1
            (train_args, train_env) = get_training_args(
                train_py, train_logdir_local, tfic_ckpt, dataset_dir_local,
                num_classes, tfdl_config)

            # Start training
            log.info('Starting training process')
            log.info(' '.join(train_args))
            train_process = Popen(train_args, env=train_env)
            terminate_at_exit(train_process)

            if self.backend_config.train_options.do_monitoring:
                # Start tensorboard
                log.info('Starting tensorboard process')
                tensorboard_process = Popen(
                    ['tensorboard', '--logdir={}'.format(train_logdir_local)])
                terminate_at_exit(tensorboard_process)

            if self.backend_config.train_options.do_eval:
                # Start eval script
                log.info('Starting eval script')
                eval_logdir = train_logdir_local
                eval_args = get_evaluation_args(eval_py, train_logdir_local,
                                                dataset_dir_local, eval_logdir,
                                                tfdl_config)
                eval_process = Popen(eval_args, env=train_env)
                terminate_at_exit(eval_process)

            # Wait for training and tensorboard
            log.info('Waiting for training and tensorboard processes')
            train_process.wait()
            if self.backend_config.train_options.do_monitoring:
                tensorboard_process.terminate()

            # Export frozen graph
            log.info(
                'Exporting frozen graph ({}/model)'.format(train_logdir_local))
            export_args = get_export_args(export_py, train_logdir_local,
                                          num_classes, tfdl_config)
            export_process = Popen(export_args)
            terminate_at_exit(export_process)
            export_process.wait()

            # Package up the model files for usage as fine tuning checkpoints
            fine_tune_checkpoint_name = self.backend_config.fine_tune_checkpoint_name
            latest_checkpoints = get_latest_checkpoint(train_logdir_local)
            model_checkpoint_files = glob.glob(
                '{}*'.format(latest_checkpoints))
            inference_graph_path = os.path.join(train_logdir_local, 'model')

            # Use a separate temp dir so the packaging work does not collide
            # with the training temp dir passed into this method.
            with RVConfig.get_tmp_dir() as package_tmp_dir:
                model_dir = os.path.join(package_tmp_dir,
                                         fine_tune_checkpoint_name)
                make_dir(model_dir)
                model_tar = os.path.join(
                    train_logdir_local,
                    '{}.tar.gz'.format(fine_tune_checkpoint_name))
                shutil.copy(inference_graph_path,
                            '{}/frozen_inference_graph.pb'.format(model_dir))
                for path in model_checkpoint_files:
                    shutil.copy(path, model_dir)
                with tarfile.open(model_tar, 'w:gz') as tar:
                    tar.add(model_dir, arcname=os.path.basename(model_dir))

        # Perform final sync
        sync_to_dir(train_logdir_local, train_logdir, delete=False)
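
Much of the setup in this example turns the downloaded pretrained tarball into a checkpoint prefix that the DeepLab training script can restore from: untar the archive next to itself, locate the *.index file, and strip the extension. Isolated, that step looks roughly like the sketch below; extract_checkpoint_prefix is a hypothetical name, and it assumes (as the glob pattern in the example does) that the tarball contains a single top-level directory holding the checkpoint files.

    import glob
    import os
    import tarfile

    def extract_checkpoint_prefix(tarball_path):
        # Extract the pretrained-model tarball next to itself.
        extract_dir = os.path.dirname(tarball_path)
        with tarfile.open(tarball_path, 'r:gz') as tar:
            tar.extractall(extract_dir)
        # TensorFlow addresses checkpoints by prefix, so take the .index
        # file's path and drop the extension.
        index_path = glob.glob('{}/*/*.index'.format(extract_dir))[0]
        return index_path[:-len('.index')]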
Example 4
    def train(self, tmp_dir):
        """Train a model.
        """
        from tile2vec.datasets import triplet_dataloader
        from tile2vec.tilenet import make_tilenet
        from tile2vec.training import train_triplet_epoch
        import torch
        from torch import optim

        # TODO: move these hard-coded hyperparameters into the backend config.
        batch_size = self.backend_config.batch_size
        epochs = self.backend_config.epochs
        epoch_size = self.backend_config.epoch_size

        img_type = 'naip'
        bands = 4
        augment = True

        shuffle = False
        num_workers = 1

        z_dim = 512

        lr = 1e-3
        betas = (0.5, 0.999)

        margin = 10
        l2 = 0.01
        print_every = 10000
        save_models = False

        sync_interval = 60

        scenes = list(
            map(lambda s: s.create_scene(self.task_config, tmp_dir),
                self.backend_config.scenes))

        # Load dataset
        dataloader = triplet_dataloader(img_type,
                                        scenes,
                                        self.task_config.chip_size,
                                        augment=augment,
                                        batch_size=batch_size,
                                        epoch_size=epoch_size,
                                        shuffle=shuffle,
                                        num_workers=num_workers)
        print('Dataloader set up complete.')

        # Setup TileNet
        in_channels = len(
            self.backend_config.scenes[0].raster_source.channel_order)
        if in_channels < 1:
            raise Exception("Must set channel order on RasterSource")

        TileNet = make_tilenet(in_channels=in_channels, z_dim=z_dim)
        if self.cuda:
            TileNet.cuda()

        if self.backend_config.pretrained_model_uri:
            model_path = download_if_needed(
                self.backend_config.pretrained_model_uri, tmp_dir)
            TileNet.load_state_dict(torch.load(model_path))

        TileNet.train()
        print('TileNet set up complete.')

        # Setup Optimizer
        optimizer = optim.Adam(TileNet.parameters(), lr=lr, betas=betas)
        print('Optimizer set up complete.')

        model_dir = os.path.join(tmp_dir, 'training/model_files')

        make_dir(model_dir)

        sync = start_sync(model_dir,
                          self.backend_config.training_output_uri,
                          sync_interval=sync_interval)
        model_path = None
        with sync:
            print('Begin training...')
            t0 = time()
            for epoch in range(0, epochs):
                print('Epoch {}'.format(epoch))
                (avg_loss, avg_l_n, avg_l_d,
                 avg_l_nd) = train_triplet_epoch(TileNet,
                                                 self.cuda,
                                                 dataloader,
                                                 optimizer,
                                                 epoch + 1,
                                                 margin=margin,
                                                 l2=l2,
                                                 print_every=print_every,
                                                 t0=t0)

                if epoch % self.backend_config.epoch_save_rate == 0 or epoch + 1 == epochs:
                    print('Saving model for epoch {}'.format(epoch))
                    model_path = os.path.join(
                        model_dir, 'TileNet_epoch{}.ckpt'.format(epoch))
                    torch.save(TileNet.state_dict(), model_path)
                else:
                    print('Skipping model save for epoch {}'.format(epoch))

        if model_path:
            shutil.copy(model_path, os.path.join(model_dir, 'model.ckpt'))
            # Perform final sync
            sync_to_dir(model_dir,
                        self.backend_config.training_output_uri,
                        delete=True)
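
Because the final sync copies the last per-epoch checkpoint to model.ckpt, loading it back for embedding extraction only requires rebuilding a TileNet with the same architecture. The sketch below assumes the same tile2vec package used above; load_tilenet is a hypothetical helper, and in_channels and z_dim must match the values used during training.

    import torch
    from tile2vec.tilenet import make_tilenet

    def load_tilenet(checkpoint_path, in_channels=4, z_dim=512, cuda=False):
        # Rebuild the architecture, then restore the saved weights.
        tilenet = make_tilenet(in_channels=in_channels, z_dim=z_dim)
        state_dict = torch.load(checkpoint_path,
                                map_location='cuda' if cuda else 'cpu')
        tilenet.load_state_dict(state_dict)
        if cuda:
            tilenet.cuda()
        tilenet.eval()  # switch to inference mode
        return tilenet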