コード例 #1
0
    def __init__(self):
        """Prepare the object for training by running its setup steps.

        Logs a start message, creates the ``sly.TaskHelperTrain`` helper and
        then calls the private setup methods in a fixed sequence.

        NOTE(review): the setup methods are defined outside this snippet;
        later steps presumably rely on state stored by earlier ones, so the
        call order should be preserved -- confirm before reordering.
        """
        logger.info('Will init all required to train.')
        # Task helper instance; kept on self for use by the other methods.
        self.helper = sly.TaskHelperTrain()

        # Setup steps, executed strictly in this order.
        self._determine_settings()
        self._determine_model_classes()
        self._determine_out_config()
        self._construct_data_dicts()
コード例 #2
0
ファイル: train.py プロジェクト: onisimchukv/supervisely
    def __init__(self):
        """Prepare the object for training by running its setup steps.

        Logs a start message, creates the ``sly.TaskHelperTrain`` helper and
        then calls the private setup methods in a fixed sequence.

        NOTE(review): the setup methods are defined outside this snippet;
        later steps presumably rely on state stored by earlier ones, so the
        call order should be preserved -- confirm before reordering.
        """
        logger.info('Will init all required to train.')
        # Task helper instance; kept on self for use by the other methods.
        self.helper = sly.TaskHelperTrain()

        # Settings / classes / output config are resolved first.
        self._determine_settings()
        self._determine_model_classes()
        self._determine_out_config()

        # Remaining preparation steps (names suggest data, YOLO train config,
        # initialization parameters and checkpoint directory -- see the
        # respective methods for what they actually do).
        self._construct_data_dicts()
        self._make_yolo_train_config()
        self._define_initializing_params()
        self._create_checkpoints_dir()
コード例 #3
0
ファイル: train.py プロジェクト: onisimchukv/supervisely
    def __init__(self):
        """Prepare the object for training by running its setup steps.

        Logs a start message, creates the ``sly.TaskHelperTrain`` helper,
        calls the private setup methods in a fixed sequence, resets the
        fractional epoch counter and logs that the model is ready.

        NOTE(review): the setup methods are defined outside this snippet;
        later steps presumably rely on state stored by earlier ones, so the
        call order should be preserved -- confirm before reordering.
        """
        logger.info('Will init all required to train.')
        # Task helper instance; kept on self for use by the other methods.
        self.helper = sly.TaskHelperTrain()

        # Setup steps, executed strictly in this order.
        self._determine_settings()
        self._determine_model_classes()
        self._determine_out_config()
        self._construct_and_fill_model()
        self._construct_criterion_and_metrics()

        self._construct_data_loaders()
        # Training progress expressed as a (possibly fractional) epoch number;
        # starts at zero.
        self.epoch_flt = 0  # real progress

        logger.info('Model is ready to train.')
コード例 #4
0
    def __init__(self):
        """Prepare the object for training by running its setup steps.

        Logs a start message, creates the ``sly.TaskHelperTrain`` helper,
        resolves settings, remaps the configured GPU devices, seeds
        TensorFlow, runs the remaining private setup methods in a fixed
        sequence, resets the fractional epoch counter and logs readiness.

        NOTE(review): the setup methods are defined outside this snippet;
        later steps presumably rely on state stored by earlier ones, so the
        call order should be preserved -- confirm before reordering.
        """
        logger.info('Will init all required to train.')
        # Task helper instance; kept on self for use by the other methods.
        self.helper = sly.TaskHelperTrain()

        self._determine_settings()
        # Reads self.config['gpu_devices']; self.config is presumably set by
        # _determine_settings() just above -- verify.
        self.device_ids = sly.remap_gpu_devices(self.config['gpu_devices'])
        # Seed is a module-level constant for now rather than a task setting.
        tf.set_random_seed(RANDOM_SEED)  # @TODO: from config

        # Remaining setup steps, executed strictly in this order.
        self._determine_model_classes()
        self._determine_out_config()
        self._explore_input_project()
        self._tf_common_init()
        self._load_weights()
        # Training progress expressed as a (possibly fractional) epoch number;
        # starts at zero.
        self.epoch_flt = 0  # real progress

        logger.info('Model is ready to train.')
コード例 #5
0
ファイル: train.py プロジェクト: yunweidashuju/supervisely
def main():
    """Run a fake training task that demonstrates the Supervisely task API.

    The function walks through the stages of a training task:

    1. read task settings (``"epochs"`` and ``"iters_per_epoch"``);
    2. create a fake model or load it from the input model directory;
    3. iterate over every sample of the input project, reading the image
       and its annotation and logging basic facts about them;
    4. run the imitation training loop, reporting progress and metrics
       after every iteration, and validation metrics plus a checkpoint
       after every epoch.

    Raises:
        RuntimeError: if an image of the input project cannot be read
            by OpenCV (``cv2.imread`` returned ``None``).
    """
    # Please note that auxiliary methods from sly (supervisely_lib) use supervisely_lib.logger to format output.
    # So don't replace formatters or handlers of the logger.
    # One may use other loggers or simple prints for other output, but it's recommended to use supervisely_lib.logger.
    logger.info('Hello ML world')
    print('Glad to see u')

    # TaskHelperTrain contains almost all needed to run training as Supervisely task,
    # including task settings and paths to data and models.
    task_helper = sly.TaskHelperTrain()

    # All settings and parameters are passed to task in json file.
    # Content of the file is entirely dependent on model implementation.
    training_settings = task_helper.task_settings
    logger.info('Task settings are read',
                extra={'task_settings': training_settings})
    # In the fake model we only need the number of epochs and iterations.
    cnt_epochs = training_settings["epochs"]
    cnt_iters_per_epoch = training_settings["iters_per_epoch"]

    # Let's imitate model weights loading.
    # Task acquires directory with input model weights (e.g. to continue training or to initialize some parameters).
    # Content of the directory is entirely dependent on model implementation.
    model_dir = task_helper.paths.model_dir
    if task_helper.model_dir_is_empty():
        model = create_fake_model()
        logger.info('Model created from scratch')
    else:
        model = load_fake_model(model_dir)
        logger.info('Init model weights are loaded',
                    extra={'model_dir': model_dir})

    # We will save weights of trained model (checkpoints) into directories provided by the checkpoints_saver.
    checkpoints_saver = task_helper.checkpoints_saver
    logger.info('Ready to save checkpoints',
                extra={'results_dir': task_helper.paths.results_dir})

    # Let's imitate reading input project with training data.
    # Of course in real implementations it is usually wrapped in some data loaders which are executed in parallel.
    project_meta = task_helper.in_project_meta  # Project meta contains list of project classes.
    project_dir = task_helper.paths.project_dir
    project_fs = sly.ProjectFS.from_disk_dir_project(project_dir)
    # ProjectFS enlists all samples (image/annotation pairs) in input project.
    for item_descr in project_fs:
        logger.info('Processing input sample',
                    extra={
                        'dataset': item_descr.ds_name,
                        'image_name': item_descr.image_name
                    })

        # Open some image...
        img = cv2.imread(item_descr.img_path)
        # cv2.imread does not raise on a missing or undecodable file, it
        # returns None. Fail with an explicit message instead of an opaque
        # AttributeError on `img.shape` below.
        if img is None:
            raise RuntimeError(
                'Unable to read image: {!r}'.format(item_descr.img_path))
        logger.info('Read image from input project',
                    extra={
                        'width': img.shape[1],
                        'height': img.shape[0]
                    })

        # And read corresponding annotation...
        ann_packed = sly.json_load(item_descr.ann_path)
        ann = sly.Annotation.from_packed(ann_packed, project_meta)
        logger.info('Read annotation from input project',
                    extra={
                        'object_cnt': len(ann['objects']),
                        'tags': ann['tags']
                    })

    # We are to report progress of task over sly.ProgressCounter if we want to observe the progress in web panel.
    # In fact one task may report progress for some sequential (not nested) subtasks,
    # but here we will report training progress only.
    progress = sly.progress_counter_train(cnt_epochs, cnt_iters_per_epoch)

    # Fractional epoch number; updated after every training iteration and
    # reused as the x-coordinate for validation metrics and checkpoints.
    epoch_flt = 0
    for epoch in range(cnt_epochs):
        logger.info("Epoch started", extra={'epoch': epoch})

        for train_iter in range(cnt_iters_per_epoch):
            logger.info('Some forward-backward pass...')
            time.sleep(1)

            # Call it after every iteration to report progress.
            progress.iter_done_report()
            epoch_flt = sly.epoch_float(epoch, train_iter + 1,
                                        cnt_iters_per_epoch)

            # And we are to report some metrics if we want to observe those amazing charts in web panel.
            # Regrettably, only the fixed metric types may be displayed now: 'loss', 'accuracy' and 'dice'.
            metric_values_train = {
                'loss': random.random(),
                'my_metric': random.uniform(0, 100)
            }
            sly.report_metrics_training(epoch_flt, metric_values_train)

        logger.info("Epoch finished", extra={'epoch': epoch})

        # Validation is not necessary but may be performed. So let's imitate validation...
        logger.info("Validation...")
        time.sleep(2)
        # Metrics for validation may also be reported.
        metric_values_val = {
            'loss': random.random(),
            'my_metric': random.uniform(0, 20)
        }
        sly.report_metrics_validation(epoch_flt, metric_values_val)

        # Save trained model weights when you want.
        # Model weights (checkpoint) should be written into directory provided by sly checkpoints_saver.
        # Content of the directory is entirely dependent on model implementation.
        cur_checkpoint_dir = checkpoints_saver.get_dir_to_write()
        dump_fake_model(cur_checkpoint_dir, model)
        checkpoints_saver.saved(is_best=True,
                                optional_data={
                                    'epoch': epoch_flt,
                                    'val_metrics': metric_values_val
                                })
        # It is necessary to call checkpoints_saver.saved after saving.
        # By default new model will be created from the best checkpoint over the whole training
        # (which is determined by "is_best" flag).
        # Some optional info may be provided. It will be linked with the checkpoint
        # and may help to distinguish checkpoints from same training.

    # Thank you for your patience.
    logger.info('Training finished')