def process_condition_batch(max_time_steps, hparams, batch):
    """Randomly crop locally-conditioned examples to ``max_time_steps``.

    Args:
        max_time_steps (int or None): Upper bound on audio samples kept per
            example; ``None`` disables cropping.
        hparams: Hyper-parameter object. Reads ``cin_pad`` and
            ``upsample_conditional_features``.
        batch (iterable): ``(x, c, g)`` tuples of audio ``x``, local
            conditioning ``c`` (sliced as ``c[frames, :]``) and global
            condition ``g``, which is passed through unchanged.

    Returns:
        list: New ``(x, c, g)`` tuples. When an example is cropped, ``c``
        keeps ``cin_pad`` extra frames on each side of the window matching
        ``x``.

    Raises:
        ValueError: If an example is longer than the crop window but too
            short to supply ``cin_pad`` frames of context on both sides.
    """
    cin_pad = hparams.cin_pad

    def _random_start(total, window):
        # Pick a start index s with cin_pad <= s <= total - window - cin_pad.
        # np.random.randint excludes ``high``, so the ``+ 1`` keeps the last
        # valid window reachable. Without it that window was never drawn, and
        # numpy raised an opaque error when it was the only valid one.
        high = total - window - cin_pad + 1
        if high <= cin_pad:
            raise ValueError(
                "cannot crop a window of {} from length {} with cin_pad={}"
                .format(window, total, cin_pad))
        return np.random.randint(cin_pad, high)

    new_batch = []
    for x, c, g in batch:
        if hparams.upsample_conditional_features:
            assert_ready_for_upsampling(x, c, cin_pad=0)
            if max_time_steps is not None:
                hop_size = audio.get_hop_size()
                max_steps = ensure_divisible(max_time_steps, hop_size, True)
                if len(x) > max_steps:
                    max_time_frames = max_steps // hop_size
                    s = _random_start(len(c), max_time_frames)
                    ts = s * hop_size
                    x = x[ts:ts + hop_size * max_time_frames]
                    # Keep cin_pad frames of context around the cropped window.
                    c = c[s - cin_pad:s + max_time_frames + cin_pad, :]
                    assert_ready_for_upsampling(x, c, cin_pad=cin_pad)
        else:
            x, c = audio.adjust_time_resolution(x, c)
            context = 0
            if max_time_steps is not None and len(x) > max_time_steps:
                s = _random_start(len(x), max_time_steps)
                x = x[s:s + max_time_steps]
                c = c[s - cin_pad:s + max_time_steps + cin_pad, :]
                context = cin_pad
            # A cropped ``c`` carries cin_pad context frames per side. The old
            # ``len(x) == len(c)`` check failed whenever cin_pad > 0.
            assert len(c) == len(x) + 2 * context
        new_batch.append((x, c, g))
    return new_batch
def collate_fn(batch):
    """Create a padded training batch.

    Args:
        batch (list): Non-empty list of ``(x, c, g)`` tuples:
            - x (ndarray): audio samples, shape (T,)
            - c (ndarray): local conditioning features, shape (T', D)
            - g: speaker id

    Returns:
        tuple: ``(x_batch, y_batch, c_batch, g_batch, input_lengths)``:
            - x_batch (FloatTensor): network inputs (B, C, T), where C is
              ``hparams.quantize_channels`` for mu-law input, else 1
            - y_batch (LongTensor for mu-law input, else FloatTensor):
              network targets (B, T, 1)
            - c_batch (FloatTensor or None): local conditioning, channel-first
            - g_batch (LongTensor or None): speaker ids
            - input_lengths (LongTensor): unpadded lengths, shape (B,)
    """
    local_conditioning = len(batch[0]) >= 2 and hparams.cin_channels > 0
    global_conditioning = len(batch[0]) >= 3 and hparams.gin_channels > 0

    # Long examples are randomly cropped to save GPU memory.
    if hparams.max_time_sec is not None:
        max_time_steps = int(hparams.max_time_sec * hparams.sample_rate)
    elif hparams.max_time_steps is not None:
        max_time_steps = hparams.max_time_steps
    else:
        max_time_steps = None

    # Time resolution adjustment and random cropping. np.random.randint
    # excludes its upper bound, hence the ``+ 1`` so the final window can be
    # drawn.
    new_batch = []
    for x, c, g in batch:
        if local_conditioning:
            if hparams.upsample_conditional_features:
                assert_ready_for_upsampling(x, c)
                if max_time_steps is not None:
                    hop_size = audio.get_hop_size()
                    max_steps = ensure_divisible(max_time_steps, hop_size, True)
                    if len(x) > max_steps:
                        max_time_frames = max_steps // hop_size
                        s = np.random.randint(0, len(c) - max_time_frames + 1)
                        ts = s * hop_size
                        x = x[ts:ts + hop_size * max_time_frames]
                        c = c[s:s + max_time_frames, :]
                        assert_ready_for_upsampling(x, c)
            else:
                x, c = audio.adjust_time_resolution(x, c)
                if max_time_steps is not None and len(x) > max_time_steps:
                    s = np.random.randint(0, len(x) - max_time_steps + 1)
                    x, c = x[s:s + max_time_steps], c[s:s + max_time_steps, :]
                assert len(x) == len(c)
        else:
            x = audio.trim(x)
            if max_time_steps is not None and len(x) > max_time_steps:
                s = np.random.randint(0, len(x) - max_time_steps + 1)
                x = x[s:s + max_time_steps]
        new_batch.append((x, c, g))
    batch = new_batch

    input_lengths = [len(item[0]) for item in batch]
    max_input_len = max(input_lengths)
    quantized = is_mulaw_quantize(hparams.input_type)

    # Inputs (B, T, C), padded along the time axis.
    if quantized:
        x_batch = np.array(
            [_pad_2d(np_utils.to_categorical(
                item[0], num_classes=hparams.quantize_channels), max_input_len)
             for item in batch],
            dtype=np.float32)
    else:
        x_batch = np.array(
            [_pad_2d(item[0].reshape(-1, 1), max_input_len) for item in batch],
            dtype=np.float32)
    assert len(x_batch.shape) == 3

    # Targets (B, T). ``np.int`` was removed in NumPy 1.24; int64 matches the
    # LongTensor built below.
    y_batch = np.array(
        [_pad(item[0], max_input_len) for item in batch],
        dtype=np.int64 if quantized else np.float32)
    assert len(y_batch.shape) == 2

    # Local conditioning, padded along time, then made channel-first.
    if local_conditioning:
        max_len = max(len(item[1]) for item in batch)
        c_batch = np.array(
            [_pad_2d(item[1], max_len) for item in batch], dtype=np.float32)
        assert len(c_batch.shape) == 3
        c_batch = torch.FloatTensor(c_batch).transpose(1, 2).contiguous()
    else:
        c_batch = None

    if global_conditioning:
        g_batch = torch.LongTensor([item[2] for item in batch])
    else:
        g_batch = None

    # Convert to channel-first, i.e. (B, C, T).
    x_batch = torch.FloatTensor(x_batch).transpose(1, 2).contiguous()

    # Add a trailing axis: (B, T) -> (B, T, 1).
    if quantized:
        y_batch = torch.LongTensor(y_batch).unsqueeze(-1).contiguous()
    else:
        y_batch = torch.FloatTensor(y_batch).unsqueeze(-1).contiguous()

    input_lengths = torch.LongTensor(input_lengths)
    return x_batch, y_batch, c_batch, g_batch, input_lengths
def thread_main(self, sess):
    """Feed audio/conditioning pieces into the input queue until stopped.

    Repeatedly iterates ``load_npy_data(...)`` and pushes pieces of each
    utterance with ``sess.run(self.enqueue, ...)``. The global condition is
    included only when ``self.gc_enable`` is set. Returns as soon as
    ``self.coord.should_stop()`` is true at the start of an utterance.
    """
    def push(wav_chunk, lc_chunk, gc):
        # The placeholder tuple carries the global condition only when enabled.
        if self.gc_enable:
            values = (wav_chunk, lc_chunk, gc)
        else:
            values = (wav_chunk, lc_chunk)
        sess.run(self.enqueue, feed_dict=dict(zip(self._placeholders, values)))

    while True:
        utterances = load_npy_data(self.metadata_filename, self.npy_dataroot,
                                   self.speaker_id)
        for wav, local_condition, global_condition in utterances:
            if self.coord.should_stop():
                return

            upsampling = hparams.upsample_conditional_features or (
                hparams.lc_conv_layers > 0 and hparams.lc_average)
            if upsampling:
                wav = wav.reshape(-1, 8)
                assert_ready_for_upsampling(wav, local_condition)
                if self.sample_size is None:
                    continue
                wav_len, lc_len = ensure_divisible(
                    self.sample_size, hparams.average_window_len,
                    hparams.average_window_shift, True)
                if wav.shape[0] <= wav_len:
                    continue
                # NOTE(review): only the first piece of each utterance is
                # enqueued here. The non-upsampling path below loops over
                # every piece. Confirm a single piece is intended.
                push(wav[:wav_len, :], local_condition[:lc_len, :],
                     global_condition)
            else:
                wav, local_condition = audio.adjust_time_resolution(
                    wav, local_condition)
                wav = wav.reshape(-1, 8)
                if self.sample_size is None:
                    continue
                # Each piece spans receptive_field + sample_size rows; the
                # cursor advances by sample_size, so consecutive pieces overlap.
                span = self.receptive_field + self.sample_size
                while wav.shape[0] > self.sample_size:
                    wav_chunk = wav[:span, :]
                    lc_chunk = local_condition[:span, :]
                    assert wav_chunk.shape[0] == lc_chunk.shape[0]
                    wav = wav[self.sample_size:, :]
                    local_condition = local_condition[self.sample_size:, :]
                    push(wav_chunk, lc_chunk, global_condition)
def thread_main(self, sess):
    """Feed audio/conditioning pieces into the input queue until stopped.

    Repeatedly iterates ``load_npy_data(...)`` and pushes each utterance (or
    pieces of it) with ``sess.run(self.enqueue, ...)``. The global condition
    is included only when ``self.gc_enable`` is set. Returns as soon as
    ``self.coord.should_stop()`` is true at the start of an utterance.

    With ``hparams.upsample_conditional_features``, an utterance longer than
    ``self.sample_size`` (rounded by ``ensure_divisible``) is randomly cropped
    to one window. Otherwise it is enqueued whole. In the other mode, the
    utterance is cut into overlapping ``receptive_field + sample_size``
    pieces, advancing by ``sample_size``.
    """
    def enqueue(wav_piece, lc_piece, gc):
        # The placeholder tuple carries the global condition only when enabled.
        if self.gc_enable:
            values = (wav_piece, lc_piece, gc)
        else:
            values = (wav_piece, lc_piece)
        sess.run(self.enqueue, feed_dict=dict(zip(self._placeholders, values)))

    while True:
        iterator = load_npy_data(self.metadata_filename, self.npy_dataroot,
                                 self.speaker_id)
        for wav, local_condition, global_condition in iterator:
            if self.coord.should_stop():
                return

            if hparams.upsample_conditional_features:
                wav = wav.reshape(-1, 1)
                assert_ready_for_upsampling(wav, local_condition)
                if self.sample_size is not None:
                    hop_size = audio.get_hop_size()
                    sample_size = ensure_divisible(self.sample_size, hop_size, True)
                    if wav.shape[0] > sample_size:
                        max_frames = sample_size // hop_size
                        # randint excludes its upper bound; + 1 lets the
                        # window ending at the last frame be drawn.
                        s = np.random.randint(
                            0, len(local_condition) - max_frames + 1)
                        ts = s * hop_size
                        wav = wav[ts:ts + hop_size * max_frames, :]
                        local_condition = local_condition[s:s + max_frames, :]
                enqueue(wav, local_condition, global_condition)
            else:
                wav, local_condition = audio.adjust_time_resolution(
                    wav, local_condition)
                wav = wav.reshape(-1, 1)
                # NOTE(review): nothing is enqueued in this mode when
                # self.sample_size is None. Confirm that is intended.
                if self.sample_size is not None:
                    span = self.receptive_field + self.sample_size
                    while wav.shape[0] > self.sample_size:
                        wav_piece = wav[:span, :]
                        lc_piece = local_condition[:span, :]
                        assert len(wav_piece) == len(lc_piece)
                        # Advance past the consumed samples. The previous
                        # ``wav[:self.sample_size]`` truncated instead, so the
                        # loop ran once and the rest of the utterance was lost.
                        wav = wav[self.sample_size:, :]
                        local_condition = local_condition[self.sample_size:, :]
                        enqueue(wav_piece, lc_piece, global_condition)
def collate_fn(batch, ignore_base_signal_computations=False):
    """Create a padded training batch.

    Args:
        batch (list): Non-empty list of ``(x, c, g)`` tuples:
            - x (ndarray): audio samples, shape (T,)
            - c (ndarray): local conditioning features, shape (T', D)
            - g: speaker id
        ignore_base_signal_computations (bool): When True, the upsampling
            branch skips the ``assert_ready_for_upsampling`` checks and does
            not crop ``x``. ``c`` is still cropped.

    Returns:
        tuple: ``(x_batch, y_batch, c_batch, g_batch, input_lengths)``:
            - x_batch (FloatTensor): network inputs (B, C, T), where C is
              ``wavenet_hparams.quantize_channels`` for mu-law input, else 1
            - y_batch (LongTensor for mu-law input, else FloatTensor):
              network targets (B, T, 1)
            - c_batch (FloatTensor or None): local conditioning, channel-first
            - g_batch (LongTensor or None): speaker ids
            - input_lengths (LongTensor): unpadded lengths, shape (B,)

    Raises:
        ValueError: If an example must be cropped but is too short to supply
            ``cin_pad`` frames of context on both sides of the window.
    """
    local_conditioning = len(batch[0]) >= 2 and wavenet_hparams.cin_channels > 0
    global_conditioning = len(batch[0]) >= 3 and wavenet_hparams.gin_channels > 0

    if wavenet_hparams.max_time_sec is not None:
        max_time_steps = int(wavenet_hparams.max_time_sec * wavenet_hparams.sample_rate)
    elif wavenet_hparams.max_time_steps is not None:
        max_time_steps = wavenet_hparams.max_time_steps
    else:
        max_time_steps = None

    cin_pad = wavenet_hparams.cin_pad

    def _random_start(total, window):
        # Pick s with cin_pad <= s <= total - window - cin_pad. randint
        # excludes ``high``; the + 1 keeps the last valid window reachable.
        high = total - window - cin_pad + 1
        if high <= cin_pad:
            raise ValueError(
                "cannot crop a window of {} from length {} with cin_pad={}"
                .format(window, total, cin_pad))
        return np.random.randint(cin_pad, high)

    # Time resolution adjustment and random cropping.
    new_batch = []
    for x, c, g in batch:
        if local_conditioning:
            if wavenet_hparams.upsample_conditional_features:
                if not ignore_base_signal_computations:
                    assert_ready_for_upsampling(x, c, cin_pad=0)
                if max_time_steps is not None:
                    hop_size = audio.get_hop_size()
                    max_steps = ensure_divisible(max_time_steps, hop_size, True)
                    if len(x) > max_steps:
                        max_time_frames = max_steps // hop_size
                        s = _random_start(len(c), max_time_frames)
                        # Keep cin_pad frames of context around the window.
                        c = c[s - cin_pad:s + max_time_frames + cin_pad, :]
                        if not ignore_base_signal_computations:
                            ts = s * hop_size
                            x = x[ts:ts + hop_size * max_time_frames]
                            # Only meaningful when x was cropped alongside c.
                            assert_ready_for_upsampling(x, c, cin_pad=cin_pad)
            else:
                x, c = audio.adjust_time_resolution(x, c)
                context = 0
                if max_time_steps is not None and len(x) > max_time_steps:
                    s = _random_start(len(x), max_time_steps)
                    x = x[s:s + max_time_steps]
                    c = c[s - cin_pad:s + max_time_steps + cin_pad, :]
                    context = cin_pad
                # A cropped c carries cin_pad context frames per side. The old
                # ``len(x) == len(c)`` check failed whenever cin_pad > 0.
                assert len(c) == len(x) + 2 * context
        else:
            x = audio.trim(x)
            if max_time_steps is not None and len(x) > max_time_steps:
                s = np.random.randint(0, len(x) - max_time_steps + 1)
                x = x[s:s + max_time_steps]
        new_batch.append((x, c, g))
    batch = new_batch

    input_lengths = [len(item[0]) for item in batch]
    max_input_len = max(input_lengths)
    quantized = is_mulaw_quantize(wavenet_hparams.input_type)
    if quantized:
        padding_value = P.mulaw_quantize(
            0, mu=wavenet_hparams.quantize_channels - 1)

    # Inputs (B, T, C), padded along the time axis.
    if quantized:
        # NOTE(review): padding_value is passed as _pad_2d's 4th argument for
        # the one-hot input as well. Confirm that is the intended fill.
        x_batch = np.array(
            [
                _pad_2d(
                    to_categorical(
                        item[0], num_classes=wavenet_hparams.quantize_channels),
                    max_input_len,
                    0,
                    padding_value,
                )
                for item in batch
            ],
            dtype=np.float32,
        )
    else:
        x_batch = np.array(
            [_pad_2d(item[0].reshape(-1, 1), max_input_len) for item in batch],
            dtype=np.float32,
        )
    assert len(x_batch.shape) == 3

    # Targets (B, T). ``np.int`` was removed in NumPy 1.24; int64 matches the
    # LongTensor built below.
    if quantized:
        y_batch = np.array(
            [_pad(item[0], max_input_len, constant_values=padding_value)
             for item in batch],
            dtype=np.int64,
        )
    else:
        y_batch = np.array(
            [_pad(item[0], max_input_len) for item in batch], dtype=np.float32)
    assert len(y_batch.shape) == 2

    # Local conditioning, padded along time, then made channel-first.
    if local_conditioning:
        max_len = max(len(item[1]) for item in batch)
        c_batch = np.array(
            [_pad_2d(item[1], max_len) for item in batch], dtype=np.float32)
        assert len(c_batch.shape) == 3
        c_batch = torch.FloatTensor(c_batch).transpose(1, 2).contiguous()
    else:
        c_batch = None

    if global_conditioning:
        g_batch = torch.LongTensor([item[2] for item in batch])
    else:
        g_batch = None

    # Convert to channel-first, i.e. (B, C, T).
    x_batch = torch.FloatTensor(x_batch).transpose(1, 2).contiguous()

    # Add a trailing axis: (B, T) -> (B, T, 1).
    if quantized:
        y_batch = torch.LongTensor(y_batch).unsqueeze(-1).contiguous()
    else:
        y_batch = torch.FloatTensor(y_batch).unsqueeze(-1).contiguous()

    input_lengths = torch.LongTensor(input_lengths)
    return x_batch, y_batch, c_batch, g_batch, input_lengths