# Imports collected for the snippets below. The donkeycar module paths are
# assumptions based on the project layout and may differ between versions;
# tub_screen() refers to a helper in the surrounding UI module.
import json
import math
import os
from datetime import datetime
from threading import Thread
from time import time
from typing import Tuple

import pandas as pd
import tensorflow as tf
from kivy.properties import ObjectProperty
from kivy.uix.screenmanager import Screen
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import load_model

from donkeycar.config import Config
from donkeycar.parts.interpreter import (keras_model_to_tflite,
                                         saved_model_to_tensor_rt)
from donkeycar.pipeline.database import PilotDatabase
from donkeycar.pipeline.sequence import BatchSequence
from donkeycar.pipeline.types import TubDataset
from donkeycar.utils import get_model_by_type


def get_model_train_details(cfg: Config, database: PilotDatabase,
                            model: str = None, model_type: str = None) \
        -> Tuple[str, int, str, bool]:
    """
    Returns an automatic model name if none is given.

    :param cfg:         donkey config
    :param database:    model database with existing training data
    :param model:       model path
    :param model_type:  type of model, like 'linear', 'tflite_linear', etc.
    :return:            tuple of model path, number, training type, and
                        whether tflite is requested
    """
    if not model_type:
        model_type = cfg.DEFAULT_MODEL_TYPE
    train_type = model_type
    is_tflite = False
    if 'tflite_' in train_type:
        train_type = train_type.replace('tflite_', '')
        is_tflite = True
    model_num = 0
    if not model:
        model_path, model_num = database.generate_model_name()
    else:
        _, model_ext = os.path.splitext(model)
        model_path = model
        is_tflite = model_ext == '.tflite'
    return model_path, model_num, train_type, is_tflite
def get_model_train_details(database: PilotDatabase, model: str = None) \
        -> Tuple[str, int]:
    if not model:
        model_name, model_num = database.generate_model_name()
    else:
        model_name, model_num = os.path.abspath(model), 0
    return model_name, model_num
def get_model_train_details(cfg: Config, database: PilotDatabase,
                            model: str = None, model_type: str = None) \
        -> Tuple[str, int, str, bool]:
    if not model_type:
        model_type = cfg.DEFAULT_MODEL_TYPE
    train_type = model_type
    is_tflite = False
    if 'tflite_' in train_type:
        train_type = train_type.replace('tflite_', '')
        is_tflite = True
    model_num = 0
    if not model:
        model_name, model_num = database.generate_model_name()
    else:
        model_name, model_ext = os.path.splitext(model)
        is_tflite = model_ext == '.tflite'
    return model_name, model_num, train_type, is_tflite
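# Illustrative sketch (not part of the source): how the model_type parsing in
# the variants above splits a type string into the training type and the
# TFLite flag, e.g. 'tflite_linear' trains a plain 'linear' model and marks
# a TFLite export.
for mt in ('linear', 'tflite_linear', 'categorical'):
    train_type = mt.replace('tflite_', '')
    is_tflite = 'tflite_' in mt
    print(f'{mt} -> train_type={train_type}, is_tflite={is_tflite}')
# linear -> train_type=linear, is_tflite=False
# tflite_linear -> train_type=linear, is_tflite=True
# categorical -> train_type=categorical, is_tflite=False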
def train(cfg: Config, tub_paths: str, model: str = None,
          model_type: str = None, transfer: str = None, comment: str = None) \
        -> tf.keras.callbacks.History:
    """
    Train the model
    """
    database = PilotDatabase(cfg)
    model_name, model_num, train_type, is_tflite = \
        get_model_train_details(cfg, database, model, model_type)

    output_path = os.path.join(cfg.MODELS_PATH, model_name + '.h5')
    kl = get_model_by_type(train_type, cfg)
    if transfer:
        kl.load(transfer)
    if cfg.PRINT_MODEL_SUMMARY:
        print(kl.model.summary())

    tubs = tub_paths.split(',')
    all_tub_paths = [os.path.expanduser(tub) for tub in tubs]
    dataset = TubDataset(cfg, all_tub_paths)
    training_records, validation_records = dataset.train_test_split()
    print(f'Records # Training {len(training_records)}')
    print(f'Records # Validation {len(validation_records)}')

    training_pipe = BatchSequence(kl, cfg, training_records, is_train=True)
    validation_pipe = BatchSequence(kl, cfg, validation_records,
                                    is_train=False)

    dataset_train = training_pipe.create_tf_data().prefetch(
        tf.data.experimental.AUTOTUNE)
    dataset_validate = validation_pipe.create_tf_data().prefetch(
        tf.data.experimental.AUTOTUNE)
    train_size = len(training_pipe)
    val_size = len(validation_pipe)

    assert val_size > 0, "Not enough validation data, decrease the batch " \
                         "size or add more data."

    history = kl.train(model_path=output_path, train_data=dataset_train,
                       train_steps=train_size, batch_size=cfg.BATCH_SIZE,
                       validation_data=dataset_validate,
                       validation_steps=val_size, epochs=cfg.MAX_EPOCHS,
                       verbose=cfg.VERBOSE_TRAIN, min_delta=cfg.MIN_DELTA,
                       patience=cfg.EARLY_STOP_PATIENCE,
                       show_plot=cfg.SHOW_PLOT)

    if is_tflite:
        tf_lite_model_path = f'{os.path.splitext(output_path)[0]}.tflite'
        keras_model_to_tflite(output_path, tf_lite_model_path)

    database_entry = {
        'Number': model_num,
        'Name': model_name,
        'Type': str(kl),
        'Tubs': tub_paths,
        'Time': time(),
        'History': history.history,
        'Transfer': os.path.basename(transfer) if transfer else None,
        'Comment': comment,
        'Config': str(cfg)
    }
    database.add_entry(database_entry)
    database.write()
    return history
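# Hedged usage sketch (not part of the source): calling the train() function
# above from a script. The '~/mycar' paths and the comment text are
# placeholders; dk.load_config is donkeycar's standard config loader.
import donkeycar as dk

cfg = dk.load_config(config_path=os.path.expanduser('~/mycar/config.py'))
history = train(cfg,
                tub_paths='~/mycar/data/tub_1,~/mycar/data/tub_2',
                model_type='tflite_linear',
                comment='first training run')
print(f"final loss {history.history['loss'][-1]}")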
class TrainScreen(Screen):
    """ Class showing the training screen. """
    config = ObjectProperty(force_dispatch=True, allownone=True)
    database = ObjectProperty()
    pilot_df = ObjectProperty(force_dispatch=True)
    tub_df = ObjectProperty(force_dispatch=True)

    def train_call(self, model_type, *args):
        # remove car directory from path
        tub_path = tub_screen().ids.tub_loader.tub.base_path
        transfer = self.ids.transfer_spinner.text
        if transfer != 'Choose transfer model':
            transfer = os.path.join(self.config.MODELS_PATH, transfer + '.h5')
        else:
            transfer = None
        try:
            history = train(self.config, tub_paths=tub_path,
                            model_type=model_type,
                            transfer=transfer,
                            comment=self.ids.comment.text)
            self.ids.status.text = 'Training completed.'
            self.ids.train_button.state = 'normal'
            self.ids.transfer_spinner.text = 'Choose transfer model'
            self.reload_database()
        except Exception as e:
            self.ids.status.text = f'Train error {e}'

    def train(self, model_type):
        self.config.SHOW_PLOT = False
        Thread(target=self.train_call, args=(model_type,)).start()
        self.ids.status.text = 'Training started.'
        self.ids.comment.text = 'Comment'

    def set_config_attribute(self, input):
        try:
            val = json.loads(input)
        except ValueError:
            val = input
        att = self.ids.cfg_spinner.text.split(':')[0]
        setattr(self.config, att, val)
        self.ids.cfg_spinner.values = self.value_list()
        self.ids.status.text = f'Setting {att} to {val} of type ' \
                               f'{type(val).__name__}'

    def value_list(self):
        if self.config:
            return [f'{k}: {v}' for k, v in self.config.__dict__.items()]
        else:
            return ['select']

    def on_config(self, obj, config):
        if self.config and self.ids:
            self.ids.cfg_spinner.values = self.value_list()
            self.reload_database()

    def reload_database(self):
        if self.config:
            self.database = PilotDatabase(self.config)

    def on_database(self, obj, database):
        if self.ids.check.state == 'down':
            self.pilot_df, self.tub_df = self.database.to_df_tubgrouped()
            self.ids.scroll_tubs.text = self.tub_df.to_string()
        else:
            self.pilot_df = self.database.to_df()
            self.tub_df = pd.DataFrame()
            self.ids.scroll_tubs.text = ''

        self.pilot_df.drop(columns=['History', 'Config'], errors='ignore',
                           inplace=True)
        text = self.pilot_df.to_string(formatters=self.formatter())
        self.ids.scroll_pilots.text = text
        values = ['Choose transfer model']
        if not self.pilot_df.empty:
            values += self.pilot_df['Name'].tolist()
        self.ids.transfer_spinner.values = values

    @staticmethod
    def formatter():
        def time_fmt(t):
            fmt = '%Y-%m-%d %H:%M:%S'
            return datetime.fromtimestamp(t).strftime(format=fmt)

        def transfer_fmt(model_name):
            return model_name.replace('.h5', '')

        return {'Time': time_fmt, 'Transfer': transfer_fmt}
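# Illustration (not from the source) of the json.loads fallback used in
# set_config_attribute above: numeric and boolean strings become typed
# values, and anything unparseable stays a plain string.
for s in ('0.8', '128', 'true', 'linear'):
    try:
        val = json.loads(s)
    except ValueError:
        val = s
    print(f'{s!r} -> {val!r} ({type(val).__name__})')
# '0.8' -> 0.8 (float)
# '128' -> 128 (int)
# 'true' -> True (bool)
# 'linear' -> 'linear' (str)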
def train(cfg: Config, tub_paths: str, model: str = None,
          model_type: str = None, transfer: str = None, comment: str = None) \
        -> tf.keras.callbacks.History:
    """
    Train the model
    """
    database = PilotDatabase(cfg)
    if model_type is None:
        model_type = cfg.DEFAULT_MODEL_TYPE
    model_path, model_num = get_model_train_details(database, model)
    base_path = os.path.splitext(model_path)[0]

    kl = get_model_by_type(model_type, cfg)
    if transfer:
        kl.load(transfer)
    if cfg.PRINT_MODEL_SUMMARY:
        print(kl.interpreter.model.summary())

    tubs = tub_paths.split(',')
    all_tub_paths = [os.path.expanduser(tub) for tub in tubs]
    dataset = TubDataset(config=cfg, tub_paths=all_tub_paths,
                         seq_size=kl.seq_size())
    training_records, validation_records = \
        train_test_split(dataset.get_records(), shuffle=True,
                         test_size=(1. - cfg.TRAIN_TEST_SPLIT))
    print(f'Records # Training {len(training_records)}')
    print(f'Records # Validation {len(validation_records)}')

    # We need augmentation in validation when using crop / trapeze
    training_pipe = BatchSequence(kl, cfg, training_records, is_train=True)
    validation_pipe = BatchSequence(kl, cfg, validation_records,
                                    is_train=False)
    tune = tf.data.experimental.AUTOTUNE
    dataset_train = training_pipe.create_tf_data().prefetch(tune)
    dataset_validate = validation_pipe.create_tf_data().prefetch(tune)
    train_size = len(training_pipe)
    val_size = len(validation_pipe)

    # Training/validation length limit; large validation datasets cause
    # memory leaks.
    train_limit = cfg.TRAIN_LIMIT
    train_len = len(training_records)
    if train_limit is not None and train_len > train_limit:
        train_decrease = train_limit / train_len
        _train_size = math.ceil(train_size * train_decrease)
        print(f'train steps decrease from {train_size} to {_train_size}')
        train_size = _train_size

    val_limit = cfg.VALIDATION_LIMIT
    val_len = len(validation_records)
    if val_limit is not None and val_len > val_limit:
        val_decrease = val_limit / val_len
        _val_size = math.ceil(val_size * val_decrease)
        print(f'val steps decrease from {val_size} to {_val_size}')
        val_size = _val_size

    assert val_size > 0, "Not enough validation data, decrease the batch " \
                         "size or add more data."

    history = kl.train(model_path=model_path, train_data=dataset_train,
                       train_steps=train_size, batch_size=cfg.BATCH_SIZE,
                       validation_data=dataset_validate,
                       validation_steps=val_size, epochs=cfg.MAX_EPOCHS,
                       verbose=cfg.VERBOSE_TRAIN, min_delta=cfg.MIN_DELTA,
                       use_early_stop=cfg.USE_EARLY_STOP,
                       patience=cfg.EARLY_STOP_PATIENCE,
                       show_plot=cfg.SHOW_PLOT)

    if getattr(cfg, 'CREATE_TF_LITE', True):
        tf_lite_model_path = f'{base_path}.tflite'
        keras_model_to_tflite(model_path, tf_lite_model_path)

    if getattr(cfg, 'CREATE_TENSOR_RT', False):
        # load the h5 (i.e. keras) model
        model_rt = load_model(model_path)
        # save it in tensorflow savedmodel format (i.e. a directory)
        model_rt.save(f'{base_path}.savedmodel')
        # pass the savedmodel to the tensor rt converter
        saved_model_to_tensor_rt(f'{base_path}.savedmodel',
                                 f'{base_path}.trt')

    database_entry = {
        'Number': model_num,
        'Name': os.path.basename(base_path),
        'Type': str(kl),
        'Tubs': tub_paths,
        'Time': time(),
        'History': history.history,
        'Transfer': os.path.basename(transfer) if transfer else None,
        'Comment': comment,
        'Config': str(cfg)
    }
    database.add_entry(database_entry)
    database.write()
    return history
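# Worked example (illustrative numbers, not from the source) of the step
# scaling above: with 10,000 training records, TRAIN_LIMIT = 2,500 and
# train_size = 79 batches per epoch, steps shrink by the ratio 0.25 to
# math.ceil(79 * 0.25) = 20 steps per epoch.
train_size, train_len, train_limit = 79, 10_000, 2_500
if train_limit is not None and train_len > train_limit:
    train_size = math.ceil(train_size * train_limit / train_len)
print(train_size)  # -> 20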