예제 #1
0
import torch

import Resources.training as r
from Models.erfh5_fullyConnected import S1140DryspotModelFCWide
from Pipeline.data_gather import get_filelist_within_folder_blacklisted
from Pipeline.data_loader_dryspot import DataloaderDryspots
from Trainer.ModelTrainer import ModelTrainer
from Trainer.evaluation import BinaryClassificationEvaluator
from Utils.eval_utils import run_eval_w_binary_classificator

if __name__ == "__main__":
    dl = DataloaderDryspots(aux_info=True)

    checkpoint_p = r.chkp_S1140_densenet_baseline_full_trainingset
    adv_output_dir = checkpoint_p.parent / "advanced_eval"
    m = ModelTrainer(
        lambda: S1140DryspotModelFCWide(),
        data_source_paths=r.get_data_paths_base_0(),
        save_path=r.save_path,
        dataset_split_path=r.dataset_split,
        cache_path=r.cache_path,
        batch_size=32768,
        train_print_frequency=100,
        epochs=1000,
        num_workers=75,
        num_validation_samples=131072,
        num_test_samples=1048576,
        data_processing_function=dl.get_sensor_bool_dryspot,
        data_gather_function=get_filelist_within_folder_blacklisted,
        loss_criterion=torch.nn.BCELoss(),
        optimizer_function=lambda params: torch.optim.AdamW(params, lr=1e-4),
        classification_evaluator_function=lambda: BinaryClassificationEvaluator(),
예제 #2
0
import Resources.training as r
from Models.erfh5_fullyConnected import S1140DryspotModelFCWide
from Pipeline.data_gather import get_filelist_within_folder_blacklisted
from Pipeline.data_loader_flowfront_sensor import DataloaderFlowfrontSensor
from Trainer.ModelTrainer import ModelTrainer
from Trainer.evaluation import BinaryClassificationEvaluator
from Utils.training_utils import read_cmd_params

if __name__ == "__main__":
    args = read_cmd_params()

    dlds = DataloaderFlowfrontSensor(sensor_indizes=((0, 1), (0, 1)),
                                     frame_count=1,
                                     use_binary_sensor_only=True)
    m = ModelTrainer(lambda: S1140DryspotModelFCWide(),
                     data_source_paths=r.get_data_paths_debug(),
                     save_path=r.save_path,
                     dataset_split_path=r.dataset_split,
                     cache_path=r.cache_path,
                     batch_size=2048,
                     train_print_frequency=100,
                     epochs=100,
                     num_workers=75,
                     num_validation_samples=512,  # 131072,
                     num_test_samples=1024,  # 1048576,
                     data_processing_function=dlds.get_flowfront_sensor_bool_dryspot,
                     data_gather_function=get_filelist_within_folder_blacklisted,
                     loss_criterion=torch.nn.BCELoss(),
                     optimizer_function=lambda params: torch.optim.AdamW(params, lr=1e-4),
                     classification_evaluator_function=lambda: BinaryClassificationEvaluator(),