def __init__(self, dir_phoneme_labels, dir_durations, id_list, file_symbol_dict, hparams=None):
    """Build generators, datasets and training defaults for the duration model.

    :param dir_phoneme_labels:  Path to the directory containing the label files with monophones.
    :param dir_durations:       Path to the directory containing the durations.
    :param id_list:             List containing all ids. Subset is taken as test set.
    :param file_symbol_dict:    List of all used monophones.
    :param hparams:             Hyper-parameter container. When None, the result of
                                create_hparams() is used with out_dir set to the current directory.
    """
    if hparams is None:
        hparams = self.create_hparams()
        hparams.out_dir = os.path.curdir

    # Fill in every default that the caller left unset.
    if hparams.variable_sequence_length_train is None:
        hparams.variable_sequence_length_train = hparams.batch_size_train > 1
    if hparams.variable_sequence_length_test is None:
        hparams.variable_sequence_length_test = hparams.batch_size_test > 1
    if getattr(hparams, "synth_dir", None) is None:
        hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

    super().__init__(id_list, hparams)

    # Input side: one-hot phoneme labels. Output side: durations with normalisation parameters.
    self.InputGen = PhonemeLabelGen(dir_phoneme_labels,
                                    file_symbol_dict,
                                    hparams.phoneme_label_type,
                                    one_hot=True)
    self.OutputGen = PhonemeDurationLabelGen(dir_durations)
    self.OutputGen.get_normalisation_params(dir_durations,
                                            hparams.output_norm_params_file_prefix)

    def build_dataset(ids):
        # Both subsets share the generators and are built without length matching.
        return PyTorchLabelGensDataset(ids,
                                       self.InputGen,
                                       self.OutputGen,
                                       hparams,
                                       match_lengths=False)

    self.dataset_train = build_dataset(self.id_list_train)
    self.dataset_val = build_dataset(self.id_list_val)

    # Only install the default loss when none has been set before.
    if self.loss_function is None:
        self.loss_function = torch.nn.MSELoss(reduction='none')

    # Replace the placeholder scheduler by the plateau scheduler.
    if hparams.scheduler_type == "default":
        hparams.scheduler_type = "Plateau"
        hparams.add_hparams(plateau_verbose=True)
def __init__(self, wcad_root, dir_atom_labels, dir_question_labels, id_list, thetas, k, num_questions, hparams=None):
    """Build generators, datasets and training defaults for the atom model.

    :param wcad_root:            Path to main directory of wcad.
    :param dir_atom_labels:      Path to directory that contains the .atom files.
    :param dir_question_labels:  Path to directory that contains the .questions files.
    :param id_list:              List containing all ids. Subset is taken as test set.
    :param thetas:               List of theta values.
    :param k:                    K value of atoms.
    :param num_questions:        Expected number of questions in question labels.
    :param hparams:              Hyper-parameter container. When None, the result of
                                 create_hparams() is used with out_dir set to the current directory.
    """
    if hparams is None:
        hparams = self.create_hparams()
        hparams.out_dir = os.path.curdir

    # Fill in every default that the caller left unset.
    if hparams.variable_sequence_length_train is None:
        hparams.variable_sequence_length_train = hparams.batch_size_train > 1
    if hparams.variable_sequence_length_test is None:
        hparams.variable_sequence_length_test = hparams.batch_size_test > 1
    if hparams.synth_dir is None:
        hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

    # Default loss weights are the inverse of the assumed occurrence of non-zero and zero
    # values, so that both get equal weight. They are only applied when no weight is given.
    non_zero_occurrence = min(0.99, 0.02 / len(thetas))
    zero_occurrence = 1 - non_zero_occurrence
    default_weight_non_zero = 1 / non_zero_occurrence
    default_weight_zero = 1 / zero_occurrence
    if hasattr(hparams, "weight_zero"):
        if hparams.weight_zero is None:
            # The hyper-parameter exists but is unset: assign in place.
            hparams.weight_non_zero = default_weight_non_zero
            hparams.weight_zero = default_weight_zero
    else:
        # The hyper-parameter is unknown so far: register both weights.
        hparams.add_hparam("weight_non_zero", default_weight_non_zero)
        hparams.add_hparam("weight_zero", default_weight_zero)

    super().__init__(id_list, hparams)

    # Input side: question labels. Output side: atom labels. Both with normalisation parameters.
    self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
    self.InputGen.get_normalisation_params(dir_question_labels,
                                           hparams.input_norm_params_file_prefix)
    self.OutputGen = AtomLabelGen(wcad_root,
                                  dir_atom_labels,
                                  thetas,
                                  k,
                                  hparams.frame_size_ms)
    self.OutputGen.get_normalisation_params(dir_atom_labels,
                                            hparams.output_norm_params_file_prefix)

    def build_dataset(ids):
        # Both subsets share the generators and are built with length matching enabled.
        return PyTorchLabelGensDataset(ids,
                                       self.InputGen,
                                       self.OutputGen,
                                       hparams,
                                       match_lengths=True)

    self.dataset_train = build_dataset(self.id_list_train)
    self.dataset_val = build_dataset(self.id_list_val)

    # Only install the default loss when none has been set before.
    if self.loss_function is None:
        self.loss_function = WeightedNonzeroMSELoss(hparams.use_gpu,
                                                    hparams.weight_zero,
                                                    hparams.weight_non_zero,
                                                    size_average=False,
                                                    reduce=False)

    # Replace the placeholder scheduler by the plateau scheduler.
    if hparams.scheduler_type == "default":
        hparams.scheduler_type = "Plateau"
        hparams.add_hparams(plateau_patience=10,
                            plateau_factor=0.5,
                            plateau_verbose=True)
def __init__(self, wcad_root, dir_audio, dir_atom_labels, dir_lf0_labels, dir_question_labels, id_list, thetas, k, num_questions, dist_window_size=51, hparams_phrase=None):
    """Build the phrase trainer together with its nested flat trainer.

    :param wcad_root:            Path to main directory of wcad.
    :param dir_audio:            Path to directory that contains the .wav files.
    :param dir_atom_labels:      Path to directory that contains the .atoms files.
    :param dir_lf0_labels:       Path to directory that contains the .lf0 files.
    :param dir_question_labels:  Path to directory that contains the .lab files.
    :param id_list:              List containing all ids. Subset is taken as test set.
    :param thetas:               List of used theta values.
    :param k:                    k-order of each atom.
    :param num_questions:        Expected number of questions in question labels.
    :param dist_window_size:     Width of the distribution surrounding each atom spike.
                                 The window is only used for amps. Thetas are surrounded by a window of 5.
    :param hparams_phrase:       Hyper-parameter container. When None, the result of
                                 create_hparams() is used with out_dir set to the current directory.
    """
    if hparams_phrase is None:
        hparams_phrase = self.create_hparams()
        hparams_phrase.out_dir = os.path.curdir

    # Hyper-parameters of the nested flat trainer. Without explicit ones, a deep copy of the
    # phrase hyper-parameters is used. The copy is taken before the defaults below are written.
    hparams_flat = hparams_phrase.hparams_flat
    if hparams_flat is None:
        hparams_flat = copy.deepcopy(hparams_phrase)

    # Default locations of the pre-trained models.
    networks_path = (hparams_phrase.out_dir, hparams_phrase.networks_dir)
    if hparams_phrase.atom_model_path is None:
        hparams_phrase.atom_model_path = os.path.join(
            *networks_path, hparams_phrase.model_name + "_flat_atoms")
    if hparams_phrase.flat_model_path is None:
        hparams_phrase.flat_model_path = os.path.join(
            *networks_path, hparams_phrase.model_name + "_flat")

    # Fill in the remaining default.
    if hparams_phrase.synth_dir is None:
        hparams_phrase.synth_dir = os.path.join(hparams_phrase.out_dir, "synth")

    super().__init__(id_list, hparams_phrase)

    # Input side: question labels. Output side: LF0 labels built with remove_phrase=False.
    self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
    self.InputGen.get_normalisation_params(dir_question_labels,
                                           hparams_phrase.input_norm_params_file_prefix)
    self.OutputGen = FlatLF0LabelGen(dir_lf0_labels, dir_atom_labels, remove_phrase=False)
    self.OutputGen.get_normalisation_params(dir_atom_labels,
                                            hparams_phrase.output_norm_params_file_prefix)

    def build_dataset(ids):
        # Both subsets share the generators and are built with length matching enabled.
        return PyTorchLabelGensDataset(ids,
                                       self.InputGen,
                                       self.OutputGen,
                                       hparams_phrase,
                                       match_lengths=True)

    self.dataset_train = build_dataset(self.id_list_train)
    self.dataset_val = build_dataset(self.id_list_val)

    # Nested trainer of the flat model, receives the full id list.
    self.flat_trainer = AtomNeuralFilterModelTrainer(wcad_root,
                                                     dir_audio,
                                                     dir_atom_labels,
                                                     dir_lf0_labels,
                                                     dir_question_labels,
                                                     id_list,
                                                     thetas,
                                                     k,
                                                     num_questions,
                                                     dist_window_size,
                                                     hparams_flat)

    # Only install the default loss when none has been set before.
    if self.loss_function is None:
        self.loss_function = L1WeightedVUVMSELoss(
            weight_unvoiced=hparams_phrase.weight_unvoiced,
            vuv_loss_weight=hparams_phrase.vuv_loss_weight,
            L1_loss_weight=hparams_phrase.L1_loss_weight,
            reduce=False)

    # The placeholder scheduler is replaced by the string "None".
    if hparams_phrase.scheduler_type == "default":
        hparams_phrase.scheduler_type = "None"

    # Batches are collated and decollated by the methods of this trainer.
    self.batch_collate_fn = self.prepare_batch
    self.batch_decollate_fn = self.decollate_network_output
def __init__(self, dir_world_features, id_list, hparams=None):
    """Build generators, datasets and training defaults for the WaveNet vocoder.

    :param dir_world_features:  Directory handed to WorldFeatLabelGen and to its
                                normalisation parameter loading.
    :param id_list:             List of ids, forwarded to the base class constructor.
    :param hparams:             Hyper-parameter container. When None, the result of
                                create_hparams() is used with out_dir set to the current directory.
    """
    if hparams is None:
        hparams = self.create_hparams()
        hparams.out_dir = os.path.curdir

    # Fill in every default that the caller left unset.
    if hparams.variable_sequence_length_train is None:
        hparams.variable_sequence_length_train = hparams.batch_size_train > 1
    if hparams.variable_sequence_length_test is None:
        hparams.variable_sequence_length_test = hparams.batch_size_test > 1
    if hparams.synth_dir is None:
        hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

    super().__init__(id_list, hparams)

    # Input frames per second, derived from the frame size given in milliseconds.
    frames_per_sec = 1000.0 / hparams.frame_size_ms
    # Number of output samples per input frame.
    in_to_out_multiplier = int(hparams.frame_rate_output_Hz / frames_per_sec)
    # Frame limits: whole frames in the configured number of seconds, times the multiplier.
    max_frames_input_trainset = int(frames_per_sec * hparams.max_input_train_sec) * in_to_out_multiplier
    # NOTE: The test limit exists because of memory constraints.
    max_frames_input_testset = int(frames_per_sec * hparams.max_input_test_sec) * in_to_out_multiplier

    # Input side: WORLD features sampled by sample_linearly with the multiplier above.
    upsampling_fn = partial(sample_linearly,
                            in_to_out_multiplier=in_to_out_multiplier,
                            dtype=np.float32)
    self.InputGen = WorldFeatLabelGen(dir_world_features,
                                      add_deltas=False,
                                      sampling_fn=upsampling_fn,
                                      num_coded_sps=hparams.num_coded_sps,
                                      sp_type=hparams.sp_type,
                                      load_sp=hparams.load_sp,
                                      load_lf0=hparams.load_lf0,
                                      load_vuv=hparams.load_vuv,
                                      load_bap=hparams.load_bap)
    self.InputGen.get_normalisation_params(dir_world_features,
                                           hparams.input_norm_params_file_prefix)

    # Output side: raw waveform. mu is only handed over for the "mulaw-quantize" input type.
    # No normalisation parameters required.
    mu = hparams.mu if hparams.input_type == "mulaw-quantize" else None
    self.OutputGen = RawWaveformLabelGen(
        frame_rate_output_Hz=hparams.frame_rate_output_Hz,
        frame_size_ms=hparams.frame_size_ms,
        mu=mu,
        silence_threshold_quantized=hparams.silence_threshold_quantized)

    self.dataset_train = LabelGensDataset(self.id_list_train,
                                          self.InputGen,
                                          self.OutputGen,
                                          hparams,
                                          random_select=True,
                                          max_frames_input=max_frames_input_trainset)
    self.dataset_val = LabelGensDataset(self.id_list_val,
                                        self.InputGen,
                                        self.OutputGen,
                                        hparams,
                                        random_select=True,
                                        max_frames_input=max_frames_input_testset)

    # Only install a default loss when none has been set before. The loss depends on the input type.
    if self.loss_function is None:
        if hparams.input_type == "mulaw-quantize":
            self.loss_function = OneHotCrossEntropyLoss(reduction='none', shift=1)
        else:
            self.loss_function = DiscretizedMixturelogisticLoss(
                hparams.quantize_channels,
                hparams.log_scale_min,
                reduction='none',
                hinge_loss=hparams.hinge_regularizer)

    # Replace the placeholder scheduler by the Noam scheduler.
    if hparams.scheduler_type == "default":
        hparams.scheduler_type = "Noam"
        # hparams.scheduler_args['exponential_gamma'] = 0.99
        # NOTE(review): the key is spelled 'wormup_steps'; presumably the scheduler setup reads
        # exactly this key - confirm before renaming it.
        hparams.scheduler_args['wormup_steps'] = 4000

    # Batches are collated and decollated by the methods of this trainer.
    self.batch_collate_fn = partial(self.prepare_batch,
                                    use_cond=hparams.use_cond,
                                    one_hot_target=True)
    self.batch_decollate_fn = self.decollate_network_output
class WaveNetVocoderTrainer(ModelTrainer):
    """Trainer of a WaveNet vocoder, WORLD features as input and raw waveform as output."""

    logger = logging.getLogger(__name__)

    def __init__(self, dir_world_features, id_list, hparams=None):
        """Build generators, datasets and training defaults for the WaveNet vocoder.

        :param dir_world_features:  Directory handed to WorldFeatLabelGen and to its
                                    normalisation parameter loading.
        :param id_list:             List of ids, forwarded to the base class constructor.
        :param hparams:             Hyper-parameter container. When None, the result of
                                    create_hparams() is used with out_dir set to the current directory.
        """
        if hparams is None:
            hparams = self.create_hparams()
            hparams.out_dir = os.path.curdir

        # Fill in every default that the caller left unset.
        if hparams.variable_sequence_length_train is None:
            hparams.variable_sequence_length_train = hparams.batch_size_train > 1
        if hparams.variable_sequence_length_test is None:
            hparams.variable_sequence_length_test = hparams.batch_size_test > 1
        if hparams.synth_dir is None:
            hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

        super().__init__(id_list, hparams)

        # Input frames per second, derived from the frame size given in milliseconds.
        frames_per_sec = 1000.0 / hparams.frame_size_ms
        # Number of output samples per input frame.
        in_to_out_multiplier = int(hparams.frame_rate_output_Hz / frames_per_sec)
        # Frame limits: whole frames in the configured number of seconds, times the multiplier.
        max_frames_input_trainset = int(frames_per_sec * hparams.max_input_train_sec) * in_to_out_multiplier
        # NOTE: The test limit exists because of memory constraints.
        max_frames_input_testset = int(frames_per_sec * hparams.max_input_test_sec) * in_to_out_multiplier

        # Input side: WORLD features sampled by sample_linearly with the multiplier above.
        upsampling_fn = partial(sample_linearly,
                                in_to_out_multiplier=in_to_out_multiplier,
                                dtype=np.float32)
        self.InputGen = WorldFeatLabelGen(dir_world_features,
                                          add_deltas=False,
                                          sampling_fn=upsampling_fn,
                                          num_coded_sps=hparams.num_coded_sps,
                                          sp_type=hparams.sp_type,
                                          load_sp=hparams.load_sp,
                                          load_lf0=hparams.load_lf0,
                                          load_vuv=hparams.load_vuv,
                                          load_bap=hparams.load_bap)
        self.InputGen.get_normalisation_params(dir_world_features,
                                               hparams.input_norm_params_file_prefix)

        # Output side: raw waveform. mu is only handed over for the "mulaw-quantize" input type.
        # No normalisation parameters required.
        mu = hparams.mu if hparams.input_type == "mulaw-quantize" else None
        self.OutputGen = RawWaveformLabelGen(
            frame_rate_output_Hz=hparams.frame_rate_output_Hz,
            frame_size_ms=hparams.frame_size_ms,
            mu=mu,
            silence_threshold_quantized=hparams.silence_threshold_quantized)

        self.dataset_train = LabelGensDataset(self.id_list_train,
                                              self.InputGen,
                                              self.OutputGen,
                                              hparams,
                                              random_select=True,
                                              max_frames_input=max_frames_input_trainset)
        self.dataset_val = LabelGensDataset(self.id_list_val,
                                            self.InputGen,
                                            self.OutputGen,
                                            hparams,
                                            random_select=True,
                                            max_frames_input=max_frames_input_testset)

        # Only install a default loss when none has been set before. The loss depends on the input type.
        if self.loss_function is None:
            if hparams.input_type == "mulaw-quantize":
                self.loss_function = OneHotCrossEntropyLoss(reduction='none', shift=1)
            else:
                self.loss_function = DiscretizedMixturelogisticLoss(
                    hparams.quantize_channels,
                    hparams.log_scale_min,
                    reduction='none',
                    hinge_loss=hparams.hinge_regularizer)

        # Replace the placeholder scheduler by the Noam scheduler.
        if hparams.scheduler_type == "default":
            hparams.scheduler_type = "Noam"
            # hparams.scheduler_args['exponential_gamma'] = 0.99
            # NOTE(review): the key is spelled 'wormup_steps'; presumably the scheduler setup reads
            # exactly this key - confirm before renaming it.
            hparams.scheduler_args['wormup_steps'] = 4000

        # Batches are collated and decollated by the methods of this trainer.
        self.batch_collate_fn = partial(self.prepare_batch,
                                        use_cond=hparams.use_cond,
                                        one_hot_target=True)
        self.batch_decollate_fn = self.decollate_network_output

    @staticmethod
    def create_hparams(hparams_string=None, verbose=False):
        """Create model hyper-parameters. Parse non-default from given string.

        :param hparams_string:  Forwarded to ModelTrainer.create_hparams.
        :param verbose:         If True, the debug string of the final container is logged.
        :return:                Container of ModelTrainer.create_hparams extended by the
                                vocoder specific hyper-parameters.
        """
        hparams = ModelTrainer.create_hparams(hparams_string, verbose=False)
        hparams.synth_vocoder = "raw"

        hparams.add_hparams(
            batch_first=True,
            frame_rate_output_Hz=16000,
            mu=255,
            bit_depth=16,
            silence_threshold_quantized=None,  # Beginning and end of audio below the threshold are trimmed.
            teacher_forcing_in_test=True,
            ema_decay=0.9999,
            # Model parameters.
            input_type="mulaw-quantize",
            hinge_regularizer=True,  # Only used in MoL prediction (input_type="raw").
            log_scale_min=float(np.log(1e-14)),  # Only used for mixture of logistic distributions.
            quantize_channels=256)  # 256 for input type mulaw-quantize, otherwise 65536

        # The number of output channels depends on the input type.
        if hparams.input_type == "mulaw-quantize":
            out_channels = hparams.quantize_channels
        else:
            out_channels = 10 * 3  # num_mixtures * 3 (pi, mean, log_scale)
        hparams.add_hparam("out_channels", out_channels)

        hparams.add_hparams(
            layers=24,  # 20
            stacks=4,  # 2
            residual_channels=512,
            gate_channels=512,
            skip_out_channels=256,
            dropout=1 - 0.95,
            kernel_size=3,
            weight_normalization=True,
            use_cond=True,  # Determines if conditioning is used.
            cin_channels=63,
            upsample_conditional_features=False,
            upsample_scales=[5, 4, 2])

        # With upsampling the multiplier is the product of all upsample scales, otherwise 1.
        hparams.len_in_out_multiplier = (reduce(mul, hparams.upsample_scales, 1)
                                         if hparams.upsample_conditional_features
                                         else 1)

        hparams.add_hparams(freq_axis_kernel_size=3,
                            gin_channels=-1,
                            n_speakers=1,
                            use_speaker_embedding=False,
                            sp_type="mcep",
                            load_sp=True,
                            load_lf0=True,
                            load_vuv=True,
                            load_bap=True)

        if verbose:
            logging.info(hparams.get_debug_string())

        return hparams

    @staticmethod
    def prepare_batch(batch, common_divisor=1, batch_first=False, use_cond=True, one_hot_target=True):
        """Collate a batch with ModelHandler.prepare_batch and rearrange it for the network.

        :param batch:           Forwarded to ModelHandler.prepare_batch.
        :param common_divisor:  Forwarded to ModelHandler.prepare_batch.
        :param batch_first:     Forwarded to ModelHandler.prepare_batch. Inputs and targets are
                                only transposed when it is True.
        :param use_cond:        If False, None is returned in place of the inputs.
        :param one_hot_target:  If False, targets are replaced by the float index of their
                                maximum along dimension 1.
        :return:                Tuple (inputs, targets, seq_lengths_input, seq_lengths_output,
                                mask, permutation).
        """
        (inputs,
         targets,
         seq_lengths_input,
         seq_lengths_output,
         mask,
         permutation) = ModelHandler.prepare_batch(batch,
                                                   common_divisor=common_divisor,
                                                   batch_first=batch_first)

        if batch_first:
            # inputs: (B x T x C) --permute--> (B x C x T)
            inputs = inputs.transpose(1, 2).contiguous()
        # TODO: Handle case where batch_first=False: inputs = inputs.transpose(2, 0, 1).contiguous()?

        if targets is not None:
            if batch_first:
                # targets: (B x T x C) --permute--> (B x C x T)
                targets = targets.transpose(1, 2).contiguous()

            if not one_hot_target:
                _, max_indices = targets.max(dim=1, keepdim=True)
                targets = max_indices.float()

        if mask is not None:
            # Drops index 0 along dimension 1.
            # NOTE(review): presumably aligns the mask with targets shifted by one - confirm.
            mask = mask[:, 1:].contiguous()

        if not use_cond:
            inputs = None

        return inputs, targets, seq_lengths_input, seq_lengths_output, mask, permutation

    @staticmethod
    def decollate_network_output(output, hidden, seq_lengths=None, permutation=None, batch_first=True):
        """Transpose the network output and split it with ModelTrainer.split_batch.

        :param output:       Network output of a batch, expected as B x C x T.
        :param hidden:       Forwarded to ModelTrainer.split_batch.
        :param seq_lengths:  Forwarded to ModelTrainer.split_batch as seq_length_output.
        :param permutation:  Forwarded to ModelTrainer.split_batch.
        :param batch_first:  If False, the output is handed over as T x B x C instead of B x T x C.
        :return:             Result of ModelTrainer.split_batch.
        """
        # Output of r9y9 Wavenet has batch first, thus output: B x C x T --transpose--> B x T x C
        transposed = np.transpose(output, (0, 2, 1))
        if not batch_first:
            # output: B x T x C --transpose--> T x B x C
            transposed = np.transpose(transposed, (1, 0, 2))

        return ModelTrainer.split_batch(transposed,
                                        hidden,
                                        seq_length_output=seq_lengths,
                                        permutation=permutation,
                                        batch_first=batch_first)

    def gen_figure_from_output(self, id_name, labels, hidden, hparams):
        """Plot the loaded original waveform against the post-processed labels and save the figure.

        :param id_name:  Id of the sample. The original waveform is loaded with the full id,
                         its basename without extension is used for title and file name.
        :param labels:   Labels to plot, come in as T x C.
        :param hidden:   Not used.
        :param hparams:  Hyper-parameter container. model_name, out_dir, frame_rate_output_Hz
                         and gen_figure_ext are read.
        """
        predicted = self.dataset_train.postprocess_sample(labels)  # Labels come in as T x C.
        reference = RawWaveformLabelGen.load_sample(id_name, self.OutputGen.frame_rate_output_Hz)

        net_name = os.path.basename(hparams.model_name)
        id_name = os.path.basename(id_name).rsplit('.', 1)[0]
        filename = os.path.join(hparams.out_dir, id_name + "." + net_name)

        # Both curves share a single grid.
        grid_idx = 0
        graphs = [(reference, 'Org'), (predicted, 'Wavenet')]

        plotter = DataPlotter()
        plotter.set_title(id_name + " - " + net_name)
        plotter.set_data_list(grid_idx=grid_idx, data_list=graphs)
        plotter.set_linewidth(grid_idx=grid_idx, linewidth=[0.1])
        plotter.set_colors(grid_idx=grid_idx, alpha=0.8)
        plotter.set_lim(grid_idx, ymin=-1, ymax=1)
        plotter.set_label(grid_idx=grid_idx,
                          xlabel='frames [' + str(hparams.frame_rate_output_Hz) + ' Hz]',
                          ylabel='raw')
        plotter.gen_plot()
        plotter.save_to_file(filename + '.Raw' + hparams.gen_figure_ext)

    # def synthesize(self, file_id_list, synth_output, hparams):
    #     self.run_raw_synth(synth_output, hparams)

    # def synth_ref(self, hparams, file_id_list):
    #     self.logger.info("Synthesise references for [{0}].".format(", ".join([id_name for id_name in file_id_list])))  # Can be different from original by sampling frequency.
    #
    #     synth_output = dict()
    #     for id_name in file_id_list:
    #         # Use extracted data. Useful to create a reference.
    #         raw = RawWaveformLabelGen.load_sample(id_name, self.OutputGen.frame_rate_output_Hz)
    #         synth_output[id_name] = raw
    #
    #     # Add identifier to suffix.
    #     old_synth_file_suffix = hparams.synth_file_suffix
    #     hparams.synth_file_suffix += '_ref'
    #
    #     # Run the WORLD synthesiser.
    #     self.run_raw_synth(synth_output, hparams)
    #
    #     # Restore identifier.
    #     hparams.synth_file_suffix = old_synth_file_suffix

    def save_for_vocoding(self, filename):
        """Save the full model and the input normalisation parameters.

        :param filename:  Target of the full model. The normalisation parameters are saved
                          next to it, with the extension replaced by "_norm_params".
        """
        # Save the full model so that hyper-parameters are already set.
        self.model_handler.save_full_model(filename, self.model_handler.model, verbose=True)

        # Save an easily loadable version of the normalisation parameters on the input side
        # used during training.
        norm_params_file = os.path.splitext(filename)[0] + "_norm_params"
        np.save(norm_params_file, np.concatenate(self.InputGen.norm_params, axis=0))
def __init__(self, wcad_root, dir_atom_labels, dir_lf0_labels, dir_question_labels, id_list, thetas, k, num_questions, dist_window_size=51, hparams=None):
    """Build generators, datasets and training defaults for the atom model with V/UV and position distribution.

    :param wcad_root:            Path to main directory of wcad.
    :param dir_atom_labels:      Path to directory with the atom labels, handed to the output
                                 label generator and to its normalisation parameter loading.
    :param dir_lf0_labels:       Path to directory that contains the .lf0 files.
    :param dir_question_labels:  Path to directory that contains the .lab files.
    :param id_list:              List containing all ids. Subset is taken as test set.
    :param thetas:               List of theta values of atoms.
    :param k:                    K-value of atoms.
    :param num_questions:        Expected number of questions in question labels.
    :param dist_window_size:     Width of the distribution surrounding each atom spike.
                                 The window is only used for amps. Thetas are surrounded by a window of 5.
    :param hparams:              Hyper-parameter container. When None, the result of
                                 create_hparams() is used with out_dir set to the current directory.
    """
    if hparams is None:
        hparams = self.create_hparams()
        hparams.out_dir = os.path.curdir

    # Fill in every default that the caller left unset.
    if hparams.variable_sequence_length_train is None:
        hparams.variable_sequence_length_train = hparams.batch_size_train > 1
    if hparams.variable_sequence_length_test is None:
        hparams.variable_sequence_length_test = hparams.batch_size_test > 1
    if hparams.synth_dir is None:
        hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

    # Default loss weights are the inverse of the assumed occurrence of non-zero and zero
    # values, so that both get equal weight. They are only applied when no weight is given.
    # NOTE(review): add_hparam is also used when the attribute already exists with value None -
    # confirm that the hyper-parameter container accepts re-adding an existing name.
    if getattr(hparams, "weight_zero", None) is None:
        non_zero_occurrence = min(0.99, 0.015 / len(thetas))
        zero_occurrence = 1 - non_zero_occurrence
        hparams.add_hparam("weight_non_zero", 1 / non_zero_occurrence)
        hparams.add_hparam("weight_zero", 1 / zero_occurrence)
    if getattr(hparams, "weight_vuv", None) is None:
        hparams.add_hparam("weight_vuv", 0.5)
    if getattr(hparams, "atom_loss_theta", None) is None:
        hparams.add_hparam("atom_loss_theta", 0.01)

    # Explicitly call only the constructor of the baseclass of AtomModelTrainer.
    super(AtomModelTrainer, self).__init__(id_list, hparams)

    # An even window size in the hyper-parameters is raised to the next odd number.
    # NOTE(review): the label generator below receives the dist_window_size argument, not
    # hparams.dist_window_size, so this correction does not reach it - confirm which one is intended.
    if hparams.dist_window_size % 2 == 0:
        hparams.dist_window_size += 1
        self.logger.warning("hparams.dist_window_size should be odd, changed it to "
                            + str(hparams.dist_window_size))

    # Input side: question labels with normalisation parameters.
    self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
    self.InputGen.get_normalisation_params(dir_question_labels,
                                           hparams.input_norm_params_file_prefix)

    # Overwrite OutputGen by the one with beta distribution.
    self.OutputGen = AtomVUVDistPosLabelGen(wcad_root,
                                            dir_atom_labels,
                                            dir_lf0_labels,
                                            thetas,
                                            k,
                                            hparams.frame_size_ms,
                                            window_size=dist_window_size)
    self.OutputGen.get_normalisation_params(dir_atom_labels,
                                            hparams.output_norm_params_file_prefix)

    def build_dataset(ids):
        # Both subsets share the generators and are built with length matching enabled.
        return PyTorchLabelGensDataset(ids,
                                       self.InputGen,
                                       self.OutputGen,
                                       hparams,
                                       match_lengths=True)

    self.dataset_train = build_dataset(self.id_list_train)
    self.dataset_val = build_dataset(self.id_list_val)

    # Only install the default loss when none has been set before.
    if self.loss_function is None:
        self.loss_function = WeightedNonzeroWMSEAtomLoss(
            use_gpu=hparams.use_gpu,
            theta=hparams.atom_loss_theta,
            weights_vuv=hparams.weight_vuv,
            weights_zero=hparams.weight_zero,
            weights_non_zero=hparams.weight_non_zero,
            reduce=False)

    # The placeholder scheduler is replaced by the string "None".
    if hparams.scheduler_type == "default":
        hparams.scheduler_type = "None"
def __init__(self, wcad_root, dir_audio, dir_atom_labels, dir_lf0_labels, dir_question_labels, id_list, thetas, k, num_questions, dist_window_size=51, hparams=None):
    """Build the neural filter trainer together with its nested atom trainer.

    :param wcad_root:            Path to main directory of wcad.
    :param dir_audio:            Path to directory that contains the .wav files.
    :param dir_atom_labels:      Path to directory that contains the .atoms files.
    :param dir_lf0_labels:       Path to directory that contains the .lf0 files.
    :param dir_question_labels:  Path to directory that contains the .lab files.
    :param id_list:              List containing all ids. Subset is taken as test set.
    :param thetas:               List of theta values of the used atoms.
    :param k:                    K-value of atoms.
    :param num_questions:        Expected number of questions in question labels.
    :param dist_window_size:     Size of distribution around atom amplitudes when training the atom model.
    :param hparams:              Hyper-parameter container. When None, the result of
                                 create_hparams() is used with out_dir set to the current directory.
    """
    if hparams is None:
        hparams = self.create_hparams()
        hparams.out_dir = os.path.curdir

    # Hyper-parameters of the nested atom trainer. Without explicit ones, a deep copy of the
    # given hyper-parameters is used. The copy is taken before the defaults below are written.
    hparams_atom = hparams.hparams_atom
    if hparams_atom is None:
        hparams_atom = copy.deepcopy(hparams)
    # NOTE(review): these two overrides also hit explicitly supplied atom hyper-parameters -
    # confirm that this is intended.
    hparams_atom.synth_gen_figure = False
    hparams_atom.synth_acoustic_model_path = None

    # Default location of the pre-trained atom model.
    if hparams.atom_model_path is None:
        hparams.atom_model_path = os.path.join(hparams.out_dir,
                                               hparams.networks_dir,
                                               hparams.model_name + "_atoms")

    # Fill in the remaining default.
    if hparams.synth_dir is None:
        hparams.synth_dir = os.path.join(hparams.out_dir, "synth")

    super().__init__(id_list, hparams)

    # Input side: question labels. Output side: LF0 labels. Both with normalisation parameters.
    self.InputGen = QuestionLabelGen(dir_question_labels, num_questions)
    self.InputGen.get_normalisation_params(dir_question_labels,
                                           hparams.input_norm_params_file_prefix)
    self.OutputGen = FlatLF0LabelGen(dir_lf0_labels, dir_atom_labels)
    self.OutputGen.get_normalisation_params(dir_atom_labels,
                                            hparams.output_norm_params_file_prefix)

    def build_dataset(ids):
        # Both subsets share the generators and are built with length matching enabled.
        return PyTorchLabelGensDataset(ids,
                                       self.InputGen,
                                       self.OutputGen,
                                       hparams,
                                       match_lengths=True)

    self.dataset_train = build_dataset(self.id_list_train)
    self.dataset_val = build_dataset(self.id_list_val)

    # Nested trainer of the atom model, receives the full id list. dir_audio is not handed on.
    self.atom_trainer = AtomVUVDistPosModelTrainer(wcad_root,
                                                   dir_atom_labels,
                                                   dir_lf0_labels,
                                                   dir_question_labels,
                                                   id_list,
                                                   thetas,
                                                   k,
                                                   num_questions,
                                                   dist_window_size,
                                                   hparams_atom)

    # Only install the default loss when none has been set before.
    if self.loss_function is None:
        self.loss_function = L1WeightedVUVMSELoss(
            weight_unvoiced=hparams.weight_unvoiced,
            vuv_loss_weight=hparams.vuv_loss_weight,
            L1_loss_weight=hparams.L1_loss_weight,
            reduce=False)

    # The placeholder scheduler is replaced by the string "None".
    if hparams.scheduler_type == "default":
        hparams.scheduler_type = "None"

    # Batches are collated and decollated by the methods of this trainer.
    self.batch_collate_fn = self.prepare_batch
    self.batch_decollate_fn = self.decollate_network_output