Esempio n. 1
0
def get_model_train_details(cfg: Config, database: PilotDatabase,
                            model: str = None, model_type: str = None) \
        -> Tuple[str, int, str, bool]:
    """
    Resolve the model path, number, training type and tflite flag.

    When no model path is given, a path and number are obtained from
    ``database.generate_model_name()``. Otherwise the given path is returned
    unchanged together with number 0.

    :param cfg:         donkey config; ``cfg.DEFAULT_MODEL_TYPE`` is used when
                        ``model_type`` is empty
    :param database:    model database with existing training data
    :param model:       model path
    :param model_type:  type of model, like 'linear', 'tflite_linear', etc
    :return:            tuple of model path, number, training type (the model
                        type with 'tflite_' removed), and whether tflite is
                        requested - either through a 'tflite_' model type or
                        through a '.tflite' model path extension
    """
    if not model_type:
        model_type = cfg.DEFAULT_MODEL_TYPE
    # A 'tflite_' marker in the type requests a tflite conversion; the marker
    # itself is not part of the model type that actually gets trained.
    is_tflite = 'tflite_' in model_type
    train_type = model_type.replace('tflite_', '')
    model_num = 0
    if not model:
        model_path, model_num = database.generate_model_name()
    else:
        model_path = model
        _, model_ext = os.path.splitext(model)
        # An explicit model path must not discard a tflite request that was
        # made through the model type.
        is_tflite = is_tflite or model_ext == '.tflite'

    return model_path, model_num, train_type, is_tflite
Esempio n. 2
0
def get_model_train_details(database: PilotDatabase, model: str = None) \
        -> Tuple[str, int]:
    """
    Return the model path and number to train into.

    With an explicit ``model`` the path is made absolute and the number is 0.
    Otherwise both values come from ``database.generate_model_name()``.
    """
    if model:
        return os.path.abspath(model), 0
    generated_name, generated_num = database.generate_model_name()
    return generated_name, generated_num
Esempio n. 3
0
def get_model_train_details(cfg: Config, database: PilotDatabase,
                            model: str = None, model_type: str = None) \
        -> Tuple[str, int, str, bool]:
    """
    Resolve the model name, number, training type and tflite flag.

    When no model is given, a name and number are obtained from
    ``database.generate_model_name()``. Otherwise the given path with its
    extension stripped is returned together with number 0.

    :param cfg:         donkey config; ``cfg.DEFAULT_MODEL_TYPE`` is used when
                        ``model_type`` is empty
    :param database:    model database with existing training data
    :param model:       model path
    :param model_type:  type of model, like 'linear', 'tflite_linear', etc
    :return:            tuple of model name (without extension), number,
                        training type (the model type with 'tflite_'
                        removed), and whether tflite is requested - either
                        through a 'tflite_' model type or a '.tflite' model
                        extension
    """
    if not model_type:
        model_type = cfg.DEFAULT_MODEL_TYPE
    # A 'tflite_' marker in the type requests a tflite conversion; the marker
    # itself is not part of the model type that actually gets trained.
    is_tflite = 'tflite_' in model_type
    train_type = model_type.replace('tflite_', '')
    model_num = 0
    if not model:
        model_name, model_num = database.generate_model_name()
    else:
        model_name, model_ext = os.path.splitext(model)
        # An explicit model path must not discard a tflite request that was
        # made through the model type.
        is_tflite = is_tflite or model_ext == '.tflite'

    return model_name, model_num, train_type, is_tflite
Esempio n. 4
0
def train(cfg: Config, tub_paths: str, model: str = None,
          model_type: str = None, transfer: str = None, comment: str = None) \
        -> tf.keras.callbacks.History:
    """
    Train the model on the given tubs and record the run in the pilot database.

    The keras model is written to ``<cfg.MODELS_PATH>/<model name>.h5``. A
    ``.tflite`` copy is written next to it when a tflite model is requested.

    :param cfg:         donkey config
    :param tub_paths:   comma separated tub paths; '~' is expanded, and
                        surrounding whitespace and empty entries are ignored
    :param model:       optional model path; an automatic name is generated
                        by the pilot database when not given
    :param model_type:  type of model, like 'linear', 'tflite_linear', etc
    :param transfer:    optional model path passed to ``kl.load`` before
                        training
    :param comment:     free text stored with the database entry
    :return:            the keras training history
    """
    database = PilotDatabase(cfg)
    model_name, model_num, train_type, is_tflite = \
        get_model_train_details(cfg, database, model, model_type)

    output_path = os.path.join(cfg.MODELS_PATH, model_name + '.h5')
    kl = get_model_by_type(train_type, cfg)
    if transfer:
        kl.load(transfer)
    if cfg.PRINT_MODEL_SUMMARY:
        # keras' summary() prints by itself and returns None, so it must not
        # be wrapped in print() (that emitted a stray 'None' line).
        kl.model.summary()

    # Tolerate 'a, b' and trailing commas in the user supplied list.
    all_tub_paths = [os.path.expanduser(tub.strip())
                     for tub in tub_paths.split(',') if tub.strip()]
    dataset = TubDataset(cfg, all_tub_paths)
    training_records, validation_records = dataset.train_test_split()
    print(f'Records # Training {len(training_records)}')
    print(f'Records # Validation {len(validation_records)}')

    training_pipe = BatchSequence(kl, cfg, training_records, is_train=True)
    validation_pipe = BatchSequence(kl,
                                    cfg,
                                    validation_records,
                                    is_train=False)

    dataset_train = training_pipe.create_tf_data().prefetch(
        tf.data.experimental.AUTOTUNE)
    dataset_validate = validation_pipe.create_tf_data().prefetch(
        tf.data.experimental.AUTOTUNE)
    train_size = len(training_pipe)
    val_size = len(validation_pipe)

    assert val_size > 0, "Not enough validation data, decrease the batch " \
                         "size or add more data."

    history = kl.train(model_path=output_path,
                       train_data=dataset_train,
                       train_steps=train_size,
                       batch_size=cfg.BATCH_SIZE,
                       validation_data=dataset_validate,
                       validation_steps=val_size,
                       epochs=cfg.MAX_EPOCHS,
                       verbose=cfg.VERBOSE_TRAIN,
                       min_delta=cfg.MIN_DELTA,
                       patience=cfg.EARLY_STOP_PATIENCE,
                       show_plot=cfg.SHOW_PLOT)

    if is_tflite:
        tf_lite_model_path = f'{os.path.splitext(output_path)[0]}.tflite'
        keras_model_to_tflite(output_path, tf_lite_model_path)

    database_entry = {
        'Number': model_num,
        'Name': model_name,
        'Type': str(kl),
        'Tubs': tub_paths,
        'Time': time(),
        'History': history.history,
        'Transfer': os.path.basename(transfer) if transfer else None,
        'Comment': comment,
        'Config': str(cfg)
    }
    database.add_entry(database_entry)
    database.write()

    return history
Esempio n. 5
0
 def reload_database(self):
     """Rebuild ``self.database`` from ``self.config``; does nothing without a config."""
     if not self.config:
         return
     self.database = PilotDatabase(self.config)
Esempio n. 6
0
class TrainScreen(Screen):
    """ Class showing the training screen. """
    config = ObjectProperty(force_dispatch=True, allownone=True)
    database = ObjectProperty()
    pilot_df = ObjectProperty(force_dispatch=True)
    tub_df = ObjectProperty(force_dispatch=True)

    def train_call(self, model_type, *args):
        """
        Train a ``model_type`` model on the tub loaded in the tub screen and
        report the outcome in the status label.
        """
        # base path of the tub currently loaded in the tub screen
        tub_path = tub_screen().ids.tub_loader.tub.base_path
        transfer = self.ids.transfer_spinner.text
        if transfer != 'Choose transfer model':
            transfer = os.path.join(self.config.MODELS_PATH, transfer + '.h5')
        else:
            transfer = None
        # NOTE(review): this method is started on a worker thread by `train`
        # but updates widget properties directly; Kivy generally expects UI
        # changes on the main thread (e.g. kivy.clock.mainthread) - confirm.
        try:
            train(self.config, tub_paths=tub_path,
                  model_type=model_type,
                  transfer=transfer,
                  comment=self.ids.comment.text)
            self.ids.status.text = 'Training completed.'
            self.ids.transfer_spinner.text = 'Choose transfer model'
            self.reload_database()
        except Exception as e:
            # Report any failure in the UI; the worker thread has no caller
            # that could handle it.
            self.ids.status.text = f'Train error {e}'
        finally:
            # Release the train button on failure too, otherwise it stays
            # pressed after an error.
            self.ids.train_button.state = 'normal'

    def train(self, model_type):
        """Start training ``model_type`` on a background thread."""
        # Plot windows cannot be shown from the training thread.
        self.config.SHOW_PLOT = False
        Thread(target=self.train_call, args=(model_type,)).start()
        self.ids.status.text = 'Training started.'
        self.ids.comment.text = 'Comment'

    def set_config_attribute(self, input):
        """
        Set the config attribute selected in the config spinner.

        ``input`` is parsed as JSON, so numbers, booleans, lists etc. get
        their proper type; text that is not valid JSON is stored as a string.
        """
        try:
            val = json.loads(input)
        except ValueError:
            val = input

        # spinner entries look like '<attribute>: <value>', see value_list
        att = self.ids.cfg_spinner.text.split(':')[0]
        setattr(self.config, att, val)
        self.ids.cfg_spinner.values = self.value_list()
        self.ids.status.text = f'Setting {att} to {val} of type ' \
                               f'{type(val).__name__}'

    def value_list(self):
        """Return '<attribute>: <value>' entries for every config attribute."""
        if self.config:
            return [f'{k}: {v}' for k, v in self.config.__dict__.items()]
        else:
            return ['select']

    def on_config(self, obj, config):
        """Refresh the config spinner and the database when config changes."""
        if self.config and self.ids:
            self.ids.cfg_spinner.values = self.value_list()
            self.reload_database()

    def reload_database(self):
        """Re-create the pilot database from the current config, if any."""
        if self.config:
            self.database = PilotDatabase(self.config)

    def on_database(self, obj, database):
        """
        Render the pilot (and optionally tub-grouped) tables and refresh the
        transfer-model choices whenever the database changes.
        """
        if self.ids.check.state == 'down':
            self.pilot_df, self.tub_df = self.database.to_df_tubgrouped()
            self.ids.scroll_tubs.text = self.tub_df.to_string()
        else:
            self.pilot_df = self.database.to_df()
            self.tub_df = pd.DataFrame()
            self.ids.scroll_tubs.text = ''

        # History and Config are too large to display in the table.
        self.pilot_df.drop(columns=['History', 'Config'], errors='ignore',
                           inplace=True)
        text = self.pilot_df.to_string(formatters=self.formatter())
        self.ids.scroll_pilots.text = text
        values = ['Choose transfer model']
        if not self.pilot_df.empty:
            values += self.pilot_df['Name'].tolist()
        self.ids.transfer_spinner.values = values

    @staticmethod
    def formatter():
        """Return per-column formatters for the pilot table."""
        def time_fmt(t):
            fmt = '%Y-%m-%d %H:%M:%S'
            return datetime.fromtimestamp(t).strftime(format=fmt)

        def transfer_fmt(model_name):
            return model_name.replace('.h5', '')

        return {'Time': time_fmt, 'Transfer': transfer_fmt}
Esempio n. 7
0
def train(cfg: Config, tub_paths: str, model: str = None,
          model_type: str = None, transfer: str = None, comment: str = None) \
        -> tf.keras.callbacks.History:
    """
    Train the model on the given tubs and record the run in the pilot database.

    Depending on the config, tflite (``CREATE_TF_LITE``, default on) and
    TensorRT (``CREATE_TENSOR_RT``, default off) versions are written next to
    the trained model.

    :param cfg:         donkey config
    :param tub_paths:   comma separated tub paths; '~' is expanded, and
                        surrounding whitespace and empty entries are ignored
    :param model:       optional model path; an automatic name is generated
                        by the pilot database when not given
    :param model_type:  model type; ``cfg.DEFAULT_MODEL_TYPE`` when None
    :param transfer:    optional model path passed to ``kl.load`` before
                        training
    :param comment:     free text stored with the database entry
    :return:            the keras training history
    """
    database = PilotDatabase(cfg)
    if model_type is None:
        model_type = cfg.DEFAULT_MODEL_TYPE
    model_path, model_num = \
        get_model_train_details(database, model)

    base_path = os.path.splitext(model_path)[0]
    kl = get_model_by_type(model_type, cfg)
    if transfer:
        kl.load(transfer)
    if cfg.PRINT_MODEL_SUMMARY:
        # keras' summary() prints by itself and returns None, so it must not
        # be wrapped in print() (that emitted a stray 'None' line).
        kl.interpreter.model.summary()

    # Tolerate 'a, b' and trailing commas in the user supplied list.
    all_tub_paths = [os.path.expanduser(tub.strip())
                     for tub in tub_paths.split(',') if tub.strip()]
    dataset = TubDataset(config=cfg,
                         tub_paths=all_tub_paths,
                         seq_size=kl.seq_size())
    training_records, validation_records \
        = train_test_split(dataset.get_records(), shuffle=True,
                           test_size=(1. - cfg.TRAIN_TEST_SPLIT))
    print(f'Records # Training {len(training_records)}')
    print(f'Records # Validation {len(validation_records)}')

    # NOTE(review): original comment said "We need augmentation in validation
    # when using crop / trapeze", yet the validation pipe is built with
    # is_train=False - confirm which is intended.
    training_pipe = BatchSequence(kl, cfg, training_records, is_train=True)
    validation_pipe = BatchSequence(kl,
                                    cfg,
                                    validation_records,
                                    is_train=False)
    tune = tf.data.experimental.AUTOTUNE
    dataset_train = training_pipe.create_tf_data().prefetch(tune)
    dataset_validate = validation_pipe.create_tf_data().prefetch(tune)

    def _limited_steps(steps, n_records, limit, label):
        """Scale ``steps`` down in proportion when ``n_records`` exceeds ``limit``."""
        if limit is None or n_records <= limit:
            return steps
        # keep the original evaluation order so the rounding is unchanged
        limited = math.ceil(steps * (limit / n_records))
        print(f'{label} steps decrease from {steps} to {limited}')
        return limited

    # Training/validation length limits: large validation datasets cause
    # memory leaks. Missing limits in the config mean "no limit".
    train_size = _limited_steps(len(training_pipe), len(training_records),
                                getattr(cfg, 'TRAIN_LIMIT', None), 'train')
    val_size = _limited_steps(len(validation_pipe), len(validation_records),
                              getattr(cfg, 'VALIDATION_LIMIT', None), 'val')

    assert val_size > 0, "Not enough validation data, decrease the batch " \
                         "size or add more data."

    history = kl.train(model_path=model_path,
                       train_data=dataset_train,
                       train_steps=train_size,
                       batch_size=cfg.BATCH_SIZE,
                       validation_data=dataset_validate,
                       validation_steps=val_size,
                       epochs=cfg.MAX_EPOCHS,
                       verbose=cfg.VERBOSE_TRAIN,
                       min_delta=cfg.MIN_DELTA,
                       use_early_stop=cfg.USE_EARLY_STOP,
                       patience=cfg.EARLY_STOP_PATIENCE,
                       show_plot=cfg.SHOW_PLOT)

    if getattr(cfg, 'CREATE_TF_LITE', True):
        tf_lite_model_path = f'{base_path}.tflite'
        keras_model_to_tflite(model_path, tf_lite_model_path)

    if getattr(cfg, 'CREATE_TENSOR_RT', False):
        # load h5 (ie. keras) model
        model_rt = load_model(model_path)
        # save in tensorflow savedmodel format (i.e. directory)
        model_rt.save(f'{base_path}.savedmodel')
        # pass savedmodel to the rt converter
        saved_model_to_tensor_rt(f'{base_path}.savedmodel', f'{base_path}.trt')

    database_entry = {
        'Number': model_num,
        'Name': os.path.basename(base_path),
        'Type': str(kl),
        'Tubs': tub_paths,
        'Time': time(),
        'History': history.history,
        'Transfer': os.path.basename(transfer) if transfer else None,
        'Comment': comment,
        'Config': str(cfg)
    }
    database.add_entry(database_entry)
    database.write()

    return history