def reorg_dog_data(data_dir, label_file, train_dir, test_dir, input_dir,
                   valid_ratio):
    """Reorganise the raw data set into per-split, per-label folders.

    Reads the id/label pairs from ``data_dir/label_file``, delegates the
    training files to ``reorg_train_valid`` and copies every file found in
    ``data_dir/test_dir`` into ``data_dir/input_dir/test/unknown``.

    Args:
        data_dir: Root directory that contains all other paths.
        label_file: Name of the comma-separated label file inside
            ``data_dir``; its first line is skipped and every remaining
            non-blank line must have exactly two fields (id, label).
        train_dir: Sub-directory of ``data_dir`` with the training files.
        test_dir: Sub-directory of ``data_dir`` with the test files.
        input_dir: Sub-directory of ``data_dir`` that receives the copies.
        valid_ratio: Passed through unchanged to ``reorg_train_valid``.

    Raises:
        ValueError: If a non-blank line does not split into two fields.
    """
    with open(os.path.join(data_dir, label_file), 'r') as f:
        # Skip the first line (presumably the column header row).
        lines = f.readlines()[1:]
    # Blank lines (e.g. a stray empty line at the end of the file) are
    # ignored; otherwise the two-value unpacking below would fail on them.
    tokens = [line.rstrip().split(',') for line in lines if line.strip()]
    idx_label = {idx: label for idx, label in tokens}
    reorg_train_valid(data_dir, train_dir, input_dir, valid_ratio, idx_label)
    # Manage the test set: every file goes into one 'unknown' folder.
    d2l.mkdir_if_not_exist([data_dir, input_dir, 'test', 'unknown'])
    test_src = os.path.join(data_dir, test_dir)
    test_dst = os.path.join(data_dir, input_dir, 'test', 'unknown')
    for test_file in os.listdir(test_src):
        shutil.copy(os.path.join(test_src, test_file), test_dst)
def reorg_train_valid(data_dir, train_dir, input_dir, valid_ratio, idx_label):
    """Copy the training files into train / valid / train_valid folders.

    Every file in ``data_dir/train_dir`` is copied to
    ``data_dir/input_dir/train_valid/<label>`` and, in addition, to either
    ``data_dir/input_dir/valid/<label>`` or ``data_dir/input_dir/train/<label>``.

    Args:
        data_dir: Root directory that contains all other paths.
        train_dir: Sub-directory of ``data_dir`` with the training files.
        input_dir: Sub-directory of ``data_dir`` that receives the copies.
        valid_ratio: Multiplied by the sample count of the rarest label
            (and floored) to get the per-label validation quota.
        idx_label: Mapping from file id (the file name up to its first
            ``.``) to the label of that file.

    Raises:
        KeyError: If a file's id is not a key of ``idx_label``.
    """
    # Sample count of the rarest label in the training set (0 when there
    # are no labels at all, instead of failing on an empty Counter).
    label_freq = collections.Counter(idx_label.values())
    min_n_train_per_label = min(label_freq.values(), default=0)
    # Number of samples per label that go to the validation set.
    n_valid_per_label = math.floor(min_n_train_per_label * valid_ratio)
    src_root = os.path.join(data_dir, train_dir)
    label_count = {}
    for train_file in os.listdir(src_root):
        idx = train_file.split('.')[0]
        label = idx_label[idx]
        src = os.path.join(src_root, train_file)
        _copy_to_split(src, data_dir, input_dir, 'train_valid', label)
        # The first file seen for a label always lands in 'valid', even
        # when the computed quota is 0; after that the quota decides.
        if label not in label_count or label_count[label] < n_valid_per_label:
            _copy_to_split(src, data_dir, input_dir, 'valid', label)
            label_count[label] = label_count.get(label, 0) + 1
        else:
            _copy_to_split(src, data_dir, input_dir, 'train', label)


def _copy_to_split(src, data_dir, input_dir, split, label):
    """Ensure ``data_dir/input_dir/split/label`` exists and copy ``src`` in."""
    d2l.mkdir_if_not_exist([data_dir, input_dir, split, label])
    shutil.copy(src, os.path.join(data_dir, input_dir, split, label))