示例#1
0
def main():
    """Entry point: parse the command line and launch training.

    Reads the ``--tubs``, ``--model`` and ``--type`` options parsed by
    docopt from the module's usage string and forwards them, together with
    the config returned by ``dk.load_config()``, to ``train``.
    """
    options = docopt(__doc__)
    config = dk.load_config()
    train(config,
          options['--tubs'],
          options['--model'],
          options['--type'])
示例#2
0
def auto_train(model_types, datasets):
    """Train one model per (model type, dataset) pair.

    Model types form the outer loop and datasets the inner loop. Each run
    trains on the directory ``datasets/<dataset>``. It writes the model to
    ``models/model_<model_type>_<dataset>.h5``.

    :param model_types: iterable of model type names passed to ``train``
    :param datasets:    iterable of dataset directory names
    """
    config = dk.load_config()

    for kind in model_types:
        for name in datasets:
            tub_dir = f"datasets/{name}"
            model_path = f"models/model_{kind}_{name}.h5"
            train(config, tub_dir, model_path, kind)
示例#3
0
    def run(self, args):
        """Parse CLI arguments and train a model with the selected framework.

        The framework comes from ``--framework`` when given. Otherwise it
        comes from ``cfg.DEFAULT_AI_FRAMEWORK``; configs that do not define
        that setting fall back to ``'tensorflow'``. Unknown framework names
        are reported with ``print``.

        :param args: raw command-line arguments handed to ``self.parse_args``
        """
        args = self.parse_args(args)
        # The trainers are given the tub list as one comma-separated string.
        args.tub = ','.join(args.tub)
        cfg = load_config(args.config)
        # Older config files may not define DEFAULT_AI_FRAMEWORK; fall back
        # to tensorflow instead of failing with AttributeError.
        framework = args.framework if args.framework \
            else getattr(cfg, 'DEFAULT_AI_FRAMEWORK', 'tensorflow')

        if framework == 'tensorflow':
            # Imported per branch, presumably so only the selected
            # framework needs to be installed.
            from donkeycar.pipeline.training import train
            train(cfg, args.tub, args.model, args.type)
        elif framework == 'pytorch':
            from donkeycar.parts.pytorch.torch_train import train
            train(cfg,
                  args.tub,
                  args.model,
                  args.type,
                  checkpoint_path=args.checkpoint)
        else:
            print(
                "Unrecognized framework: {}. Please specify one of 'tensorflow' or 'pytorch'"
                .format(framework))
示例#4
0
def main():
    """Entry point: parse the command line and train with the chosen framework.

    ``--framework`` selects ``'tf'`` (the default when absent) or
    ``'torch'``. The torch trainer additionally receives the optional
    ``--checkpoint`` path. Any other framework name is reported with
    ``print``.
    """
    args = docopt(__doc__)
    cfg = dk.load_config()
    tubs = args['--tubs']
    model = args['--model']
    model_type = args['--type']
    # If --framework is declared in the usage string without a
    # ``[default: ...]``, docopt stores None for it when omitted. In that
    # case ``args.get('--framework', 'tf')`` would return None, not 'tf'.
    # ``or`` handles both the missing-key and the None case.
    framework = args.get('--framework') or 'tf'

    if framework == 'tf':
        from donkeycar.pipeline.training import train
        train(cfg, tubs, model, model_type)
    elif framework == 'torch':
        from donkeycar.parts.pytorch.torch_train import train
        checkpoint_path = args.get('--checkpoint', None)

        train(cfg, tubs, model, model_type, checkpoint_path=checkpoint_path)
    else:
        print(
            "Unrecognized framework: {}. Please specify one of 'tf' or 'torch'"
            .format(framework))
示例#5
0
    def run(self, args):
        """Parse CLI arguments and dispatch training to the chosen framework.

        The framework is taken from ``--framework`` if set. Otherwise it is
        ``cfg.DEFAULT_AI_FRAMEWORK``, defaulting to ``'tensorflow'`` when
        the config lacks it. The tensorflow trainer also gets the transfer
        model and comment. The pytorch trainer gets the checkpoint path.
        Unknown frameworks are reported with ``print``.

        :param args: raw command-line arguments handed to ``self.parse_args``
        """
        parsed = self.parse_args(args)
        # Collapse the tub list into a single comma-separated string.
        parsed.tub = ','.join(parsed.tub)
        cfg = load_config(parsed.config)
        framework = parsed.framework or getattr(
            cfg, 'DEFAULT_AI_FRAMEWORK', 'tensorflow')

        if framework not in ('tensorflow', 'pytorch'):
            print(f"Unrecognized framework: {framework}. Please specify one of "
                  "'tensorflow' or 'pytorch'")
            return

        if framework == 'tensorflow':
            from donkeycar.pipeline.training import train
            train(cfg, parsed.tub, parsed.model, parsed.type,
                  parsed.transfer, parsed.comment)
        else:
            from donkeycar.parts.pytorch.torch_train import train
            train(cfg, parsed.tub, parsed.model, parsed.type,
                  checkpoint_path=parsed.checkpoint)
示例#6
0
def test_train(config: Config, data: Data) -> None:
    """
    Test that training the linear and categorical models converges.

    If ``data.pretrained`` is set, ``config.LATENT_TRAINED`` is first pointed
    at that pilot file. The model is then trained on ``config.DATA_PATH``.
    The final recorded loss must be below the initial loss multiplied by
    ``data.convergence``.

    :param config:          donkey config
    :param data:            test case data
    :return:                None
    """
    def model_file(label):
        return os.path.join(config.MODELS_PATH, f'pilot_{label}.h5')

    if data.pretrained:
        config.LATENT_TRAINED = model_file(data.pretrained)

    history = train(config, config.DATA_PATH, model_file(data.name),
                    data.type)
    losses = history.history['loss']
    first, last = losses[0], losses[-1]
    # convergence check: last loss vs. first loss scaled by data.convergence
    assert last < first * data.convergence
示例#7
0
 def train_call(self, model_type, *args):
     """Train a ``model_type`` model on the tub currently loaded in the UI.

     The transfer model is read from the transfer spinner. The text
     ``'Choose transfer model'`` means no transfer model; any other
     selection is resolved to ``<MODELS_PATH>/<name>.h5``. The comment field
     is passed through to ``train``, and the outcome is shown in the status
     label. The train button is reset to ``'normal'`` whether training
     succeeds or fails.

     :param model_type: model type name passed to ``train``
     :param args:       extra callback arguments; unused
     """
     import logging

     tub_path = tub_screen().ids.tub_loader.tub.base_path
     transfer = self.ids.transfer_spinner.text
     if transfer != 'Choose transfer model':
         transfer = os.path.join(self.config.MODELS_PATH, transfer + '.h5')
     else:
         transfer = None
     try:
         train(self.config, tub_paths=tub_path,
               model_type=model_type,
               transfer=transfer,
               comment=self.ids.comment.text)
         self.ids.status.text = 'Training completed.'
         self.ids.transfer_spinner.text = 'Choose transfer model'
         self.reload_database()
     except Exception as e:
         # UI boundary: report the failure in the status label instead of
         # crashing, but keep the full traceback in the log for debugging.
         logging.getLogger(__name__).exception('Training failed')
         self.ids.status.text = f'Train error {e}'
     finally:
         # Previously only reset on success; a failed run must also
         # release the button.
         self.ids.train_button.state = 'normal'
示例#8
0
def test_train(config: Config, data: Data) -> None:
    """
    Test that training converges for each model type / preprocessing case.

    The test first optionally points ``config.LATENT_TRAINED`` at a
    pretrained pilot. It then picks ``config.DATA_PATH_ALL`` for model types
    listed in ``full_tub`` and ``config.DATA_PATH`` otherwise. Next it
    optionally adds augmentations (``'aug'``) or transformations
    (``'trans'``) to the config. Finally it trains and checks that the final
    loss is below the initial loss scaled by ``data.convergence``.

    :param config:          donkey config
    :param data:            test case data
    :return:                None
    """
    def model_file(label):
        return os.path.join(config.MODELS_PATH, f'pilot_{label}.h5')

    if data.pretrained:
        config.LATENT_TRAINED = model_file(data.pretrained)

    if data.type in full_tub:
        tub_dir = config.DATA_PATH_ALL
    else:
        tub_dir = config.DATA_PATH

    # optional preprocessing added to the config before training
    if data.preprocess == 'aug':
        add_augmentation_to_config(config)
    elif data.preprocess == 'trans':
        add_transformation_to_config(config)

    history = train(config, tub_dir, model_file(data.name), data.type)
    losses = history.history['loss']
    first, last = losses[0], losses[-1]
    # convergence check: last loss vs. first loss scaled by data.convergence
    assert last < first * data.convergence