def test_train_TrainTestSplit_continuous(tub_path):
    """Verify the Train-Test split ratio holds when a dataset is extended.

    A tub is created with 100 records, gathered and collated, and the
    resulting ratio is compared to ``cfg.TRAIN_TEST_SPLIT``. Then 200 more
    records are added to the same tub. The same collation dict is reused, so
    only the new records should be added, and the ratio over the combined
    300 records is checked against the configured split again.
    """
    first_batch = 100
    second_batch = 200

    # Build the initial sample tub.
    tub = create_sample_tub(tub_path, records=first_batch)
    assert tub is not None

    import donkeycar.templates.cfg_complete as cfg

    split_opts = {'categorical': False, 'cfg': cfg}
    # Shared across both phases on purpose: the second collation extends it.
    collated = {}

    def announce(message):
        # Pad the message with blank lines so it stands out in verbose output.
        print()
        print(message)
        print()

    def split_and_verify(expected_count):
        # Gather everything currently in the tub, merge it into ``collated``
        # and compare the computed ratio with the configured split.
        found = gather_records(cfg, tub_path, split_opts, verbose=True)
        assert len(found) == expected_count
        collate_records(found, collated, split_opts)
        assert calculate_TrainTestSplit(collated) == cfg.TRAIN_TEST_SPLIT

    announce("Initial split of {} records to {} Test-Train split...".format(
        first_batch, split_opts['cfg'].TRAIN_TEST_SPLIT))
    split_and_verify(first_batch)

    # Extend the tub; only the NEW records should be added to ``collated``.
    announce(
        "Added an extra {} records, aiming for overall {} Test-Train split..."
        .format(second_batch, split_opts['cfg'].TRAIN_TEST_SPLIT))
    create_sample_tub(tub_path, records=second_batch)
    split_and_verify(first_batch + second_batch)
def test_train_TrainTestSplit_simple(tub_path):
    """Check that Train-Test splitting honours ``cfg.TRAIN_TEST_SPLIT``.

    A 100-record sample tub is gathered once. The records are then collated
    twice into a fresh dict each time:

    * first with ``cfg.TRAIN_TEST_SPLIT`` temporarily forced to 0.5;
    * then with the value originally present in the config module.

    Each time, the ratio reported by ``calculate_TrainTestSplit`` must equal
    the configured split.

    ``cfg`` is an imported module, so assigning ``cfg.TRAIN_TEST_SPLIT``
    changes it for every other test in the session. The override is undone in
    a ``finally`` clause so the original value is restored even if an
    assertion fails.
    """
    initial_records = 100

    # Setup the test data
    t = create_sample_tub(tub_path, records=initial_records)
    assert t is not None

    # Import the configuration
    import donkeycar.templates.cfg_complete as cfg

    # Initial Setup
    opts = {'categorical': False}
    opts['cfg'] = cfg

    orig_TRAIN_TEST_SPLIT = cfg.TRAIN_TEST_SPLIT
    records = gather_records(cfg, tub_path, opts, verbose=True)
    assert len(records) == initial_records

    try:
        # Attempt a 50:50 split
        gen_records = {}
        cfg.TRAIN_TEST_SPLIT = 0.5
        print()
        print("Testing a {} Test-Train split...".format(
            opts['cfg'].TRAIN_TEST_SPLIT))
        print()

        collate_records(records, gen_records, opts)
        ratio = calculate_TrainTestSplit(gen_records)
        assert ratio == cfg.TRAIN_TEST_SPLIT
    finally:
        # Always undo the module-level override so later tests see the
        # configured value, even when the assertion above fails.
        cfg.TRAIN_TEST_SPLIT = orig_TRAIN_TEST_SPLIT

    # Attempt a split based on config file (reset of the record set)
    gen_records = {}
    print()
    print("Testing a {} Test-Train split...".format(
        opts['cfg'].TRAIN_TEST_SPLIT))
    print()

    collate_records(records, gen_records, opts)
    ratio = calculate_TrainTestSplit(gen_records)
    assert ratio == cfg.TRAIN_TEST_SPLIT