def use_orcanet():
    """
    Run the complete example: create data, train a model and predict.

    The folder ``output/`` is created with ``os.mkdir`` (which raises
    ``FileExistsError`` if it already exists) and handed to
    ``make_dummy_data``. An Organizer is then set up on
    ``output/sum_model`` with the list file ``example_list.toml``, the
    model from ``make_dummy_model`` is trained with ``epochs=3`` and
    finally ``predict`` is called on the Organizer.

    """
    out_dir = "output/"
    os.mkdir(out_dir)
    make_dummy_data(out_dir)

    orga = Organizer(out_dir + "sum_model", "example_list.toml")
    orga.cfg.train_logger_display = 10

    # The model is created only after the Organizer has been configured.
    orga.train_and_validate(make_dummy_model(), epochs=3)
    orga.predict()
def train(directory, list_file=None, config_file=None, model_file=None, to_epoch=None):
    """
    Train and validate a model with an Organizer set up on a directory.

    Parameters
    ----------
    directory : str
        Passed to the Organizer as its first argument. Also searched
        for a ``model.toml`` if no model_file is given.
    list_file : str, optional
        Passed on to the Organizer.
    config_file : str, optional
        Passed on to the Organizer.
    model_file : str, optional
        File handed to the ModelBuilder when a new model has to be
        built. If None, it is looked up via
        ``find_file(directory, "model.toml")``.
    to_epoch : optional
        Passed on to ``train_and_validate``.

    Returns
    -------
    Whatever ``orga.train_and_validate`` returns.

    """
    from orcanet.core import Organizer
    from orcanet.model_builder import ModelBuilder
    from orcanet.misc import find_file

    orga = Organizer(directory, list_file, config_file, tf_log_level=1)

    # A model is only built when no epoch exists yet; otherwise None is
    # handed to train_and_validate.
    model = None
    if orga.io.get_latest_epoch() is None:
        print("Building new model")
        if model_file is None:
            model_file = find_file(directory, "model.toml")
        model_builder = ModelBuilder(model_file)
        model = model_builder.build(orga, verbose=False)

    return orga.train_and_validate(model=model, to_epoch=to_epoch)
def orca_train(output_folder, list_file, config_file, model_file, recompile_model=False):
    """
    Train a ModelBuilder network via the Organizer, as used by the parser.

    Parameters
    ----------
    output_folder : str
        Path to the folder where everything gets saved to, e.g. the
        summary log file, the plots, the trained models, etc.
    list_file : str
        Path to a list file which contains the paths to all the h5 files
        that should be used for training and validation.
    config_file : str
        Path to a .toml file which overwrites some of the default
        settings for training and validating a model.
    model_file : str
        Path to a file with parameters to build a model of a predefined
        architecture with OrcaNet.
    recompile_model : bool
        Whether the latest saved model should be loaded and compiled
        again before training continues. Necessary if e.g. the
        loss_weights are changed during the training. Only the literal
        value True triggers the recompilation.

    """
    # The Organizer holds the input data and the configuration.
    orga = Organizer(output_folder, list_file, config_file, tf_log_level=1)

    # Register the sample-, label- and dataset-modifiers as well as the
    # custom objects on the orga.
    update_objects(orga, model_file)

    # By default no model is handed over to train_and_validate.
    model = None
    if orga.io.get_latest_epoch() is None:
        # Start of the training: a compiled model must be supplied. The
        # ModelBuilder constructs it from the toml file, adapted to the
        # datasets (and modifiers) of the orga instance.
        model_builder = ModelBuilder(model_file)
        model = model_builder.build(orga, log_comp_opts=True)
    elif recompile_model is True:
        # Load the saved model returned by get_model_path(-1, -1) and
        # compile it again with the settings from the model file.
        model_builder = ModelBuilder(model_file)
        saved_model_path = orga.io.get_model_path(-1, -1)
        model = ks.models.load_model(
            saved_model_path,
            custom_objects=orga.cfg.get_custom_objects(),
        )
        print("Recompiling the saved model")
        model = model_builder.compile_model(
            model,
            custom_objects=orga.cfg.get_custom_objects(),
        )
        model_builder.log_model_properties(orga)

    # Try to swap the configured learning rate for a custom LR schedule.
    # On a NameError the configured learning rate is left untouched.
    try:
        configured_lr = orga.cfg.learning_rate
        n_train_files = orga.io.get_no_of_files("train")
        orga.cfg.learning_rate = orca_learning_rates(configured_lr, n_train_files)
    except NameError:
        pass

    # Start the training.
    orga.train_and_validate(model=model)