def __init__(self, ratio=(0.90, 0.05, 0.05), **kwargs):
		"""
		Index every chunk of the HDF5 dataset and split the chunks into
		training / validation / testing sets.

		Inputs:
			ratio: ratio for train / valid / test set. Only ratio[0] and
				ratio[1] are read: the test set receives every item left
				after the train and valid slices, so ratio[2] is
				informational.
			kwargs: Dataset parameters. Required keys:
				'nb_speakers', 'sex', 'batch_size', 'chunk_size',
				'no_random_picking' and 'dataset' (path given to h5py.File).

		Raises:
			KeyError: if one of the required kwargs is missing.
			Exception: if kwargs['sex'] is not one of
				['M', 'F'], ['F', 'M'], ['M'] or ['F'].
		"""

		# Seed numpy's global RNG so the shuffle below is reproducible
		np.random.seed(config.seed)

		self.nb_speakers = kwargs['nb_speakers']
		self.sex = kwargs['sex']
		self.batch_size = kwargs['batch_size']
		self.chunk_size = kwargs['chunk_size']
		self.no_random_picking = kwargs['no_random_picking']

		# Flags for Training/Validation/Testing sets
		self.TRAIN = 0
		self.VALID = 1
		self.TEST = 2

		# TODO (left unspecified by the original author)
		metadata = data_tools.read_metadata()

		if self.sex not in (['M', 'F'], ['F', 'M'], ['M'], ['F']):
			raise Exception('Sex must be ["M","F"] |  ["F","M"] | ["M"] | ["F"]')

		# Create a key to speaker index dictionary
		# and count the number of speakers.
		# Males are always indexed before females, whatever the order in self.sex.
		self.key_to_index = {}
		self.sex_to_keys = {}
		j = 0

		for sex, keys_of in (('M', data_tools.males_keys), ('F', data_tools.females_keys)):
			if sex not in self.sex:
				continue
			keys = keys_of(metadata)
			self.sex_to_keys[sex] = keys
			for k in keys:
				self.key_to_index[k] = j
				j += 1

		self.tot_speakers = j

		# The file handle is kept open on the instance
		self.file = h5py.File(kwargs['dataset'], 'r')

		# Define all the items related to each key/speaker
		self.total_items = []

		for key in self.key_to_index:
			for val in self.file[key]:
				path = '/'.join([key, val])
				# Get one file related to a speaker and check how many full chunks
				# can be obtained with the current chunk size
				chunks = self.file[path].shape[0] // self.chunk_size
				# Add each possible chunk to the items with the following form:
				# 'key/file/#chunk'
				self.total_items.extend('/'.join([path, str(i)]) for i in range(chunks))

		# Shuffle all the items (in place)
		np.random.shuffle(self.total_items)

		# Training / Valid / Test Separation
		L = len(self.total_items)
		train_end = int(L * ratio[0])
		valid_end = int(L * (ratio[0] + ratio[1]))

		train = self.create_tree(self.total_items[:train_end])
		valid = self.create_tree(self.total_items[train_end:valid_end])
		test = self.create_tree(self.total_items[valid_end:])

		self.train = TreeIterator(train, self)
		self.valid = TreeIterator(valid, self)
		self.test = TreeIterator(test, self)
# Example #2
# 0
# File: dataset.py  Project: Qoboty/das
	def get_labels(self):
		"""Return the object stored in ``self.dico`` (same reference, no copy)."""
		labels = self.dico
		return labels

from data_tools import read_metadata, males_keys, females_keys
# Manual smoke test: only runs when this file is executed as a script.
if __name__ == "__main__":

	###
	### TEST
	###

	# Load the speaker metadata and dump it for inspection
	H5_dic = read_metadata()
	print H5_dic
	# Chunk size handed to Mixer below (presumably a number of samples -- TODO confirm)
	chunk_size = 512*100

	# Two H5PY_RW objects over the same file, one given the male keys and one the
	# female keys as `subset` (presumably restricts each to those speakers -- TODO confirm)
	males = H5PY_RW('test_raw.h5py', subset = males_keys(H5_dic))
	fem = H5PY_RW('test_raw.h5py', subset = females_keys(H5_dic))

	print 'Data with', len(H5_dic), 'male and female speakers'
	print males.length(), 'elements'
	print fem.length(), 'elements'

	# Combine both sources into a single Mixer
	mixed_data = Mixer([males, fem], chunk_size= chunk_size, with_mask=False, with_inputs=True, shuffling=True)

	batch_size = 128

	# Adjust the Mixer's split size for this batch size, then ask how many batches it yields
	mixed_data.adjust_split_size_to_batchsize(batch_size)
	nb_batches = mixed_data.nb_batches(batch_size)

	# NOTE(review): judging by the variable name, `dico` maps a number to a speaker -- confirm
	nb_to_speaker = mixed_data.dico
	# Not used in the visible part of this script; the excerpt may be truncated here
	id_f = []