def main(): """Run training process.""" parser = argparse.ArgumentParser( description="Train FastSpeech (See detail in tensorflow_tts/bin/train-fastspeech.py)" ) parser.add_argument( "--train-dir", default=None, type=str, help="directory including training data. ", ) parser.add_argument( "--dev-dir", default=None, type=str, help="directory including development data. ", ) parser.add_argument( "--use-norm", default=1, type=int, help="usr norm-mels for train or raw." ) parser.add_argument( "--f0-stat", default="./dump/stats_f0.npy", type=str, required=True, help="f0-stat path.", ) parser.add_argument( "--energy-stat", default="./dump/stats_energy.npy", type=str, required=True, help="energy-stat path.", ) parser.add_argument( "--outdir", type=str, required=True, help="directory to save checkpoints." ) parser.add_argument( "--config", type=str, required=True, help="yaml format configuration file." ) parser.add_argument( "--resume", default="", type=str, nargs="?", help='checkpoint file path to resume training. (default="")', ) parser.add_argument( "--verbose", type=int, default=1, help="logging level. higher is more logging. (default=1)", ) parser.add_argument( "--mixed_precision", default=0, type=int, help="using mixed precision for generator or not.", ) parser.add_argument( "--pretrained", default="", type=str, nargs="?", help="pretrained weights .h5 file to load weights from. Auto-skips non-matching layers", ) args = parser.parse_args() # return strategy STRATEGY = return_strategy() # set mixed precision config if args.mixed_precision == 1: tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True}) args.mixed_precision = bool(args.mixed_precision) args.use_norm = bool(args.use_norm) # set logger if args.verbose > 1: logging.basicConfig( level=logging.DEBUG, stream=sys.stdout, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", ) elif args.verbose > 0: logging.basicConfig( level=logging.INFO, stream=sys.stdout, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", ) else: logging.basicConfig( level=logging.WARN, stream=sys.stdout, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", ) logging.warning("Skip DEBUG/INFO messages") # check directory existence if not os.path.exists(args.outdir): os.makedirs(args.outdir) # check arguments if args.train_dir is None: raise ValueError("Please specify --train-dir") if args.dev_dir is None: raise ValueError("Please specify --valid-dir") # load and save config with open(args.config) as f: config = yaml.load(f, Loader=yaml.Loader) config.update(vars(args)) config["version"] = tensorflow_tts.__version__ with open(os.path.join(args.outdir, "config.yml"), "w") as f: yaml.dump(config, f, Dumper=yaml.Dumper) for key, value in config.items(): logging.info(f"{key} = {value}") # get dataset if config["remove_short_samples"]: mel_length_threshold = config["mel_length_threshold"] else: mel_length_threshold = None if config["format"] == "npy": charactor_query = "*-ids.npy" mel_query = "*-raw-feats.npy" if args.use_norm is False else "*-norm-feats.npy" duration_query = "*-durations.npy" f0_query = "*-raw-f0.npy" energy_query = "*-raw-energy.npy" else: raise ValueError("Only npy are supported.") # define train/valid dataset train_dataset = CharactorDurationF0EnergyMelDataset( root_dir=args.train_dir, charactor_query=charactor_query, mel_query=mel_query, duration_query=duration_query, f0_query=f0_query, energy_query=energy_query, f0_stat=args.f0_stat, energy_stat=args.energy_stat, 
mel_length_threshold=mel_length_threshold, ).create( is_shuffle=config["is_shuffle"], allow_cache=config["allow_cache"], batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync * config["gradient_accumulation_steps"], ) valid_dataset = CharactorDurationF0EnergyMelDataset( root_dir=args.dev_dir, charactor_query=charactor_query, mel_query=mel_query, duration_query=duration_query, f0_query=f0_query, energy_query=energy_query, f0_stat=args.f0_stat, energy_stat=args.energy_stat, mel_length_threshold=mel_length_threshold, ).create( is_shuffle=config["is_shuffle"], allow_cache=config["allow_cache"], batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync, ) # define trainer trainer = FastSpeech2Trainer( config=config, strategy=STRATEGY, steps=0, epochs=0, is_mixed_precision=args.mixed_precision, ) with STRATEGY.scope(): # define model fastspeech = TFFastSpeech2( config=FastSpeech2Config(**config["fastspeech2_params"]) ) fastspeech._build() fastspeech.summary() if len(args.pretrained) > 1: fastspeech.load_weights(args.pretrained, by_name=True, skip_mismatch=True) logging.info( f"Successfully loaded pretrained weight from {args.pretrained}." ) # AdamW for fastspeech learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay( initial_learning_rate=config["optimizer_params"]["initial_learning_rate"], decay_steps=config["optimizer_params"]["decay_steps"], end_learning_rate=config["optimizer_params"]["end_learning_rate"], ) learning_rate_fn = WarmUp( initial_learning_rate=config["optimizer_params"]["initial_learning_rate"], decay_schedule_fn=learning_rate_fn, warmup_steps=int( config["train_max_steps"] * config["optimizer_params"]["warmup_proportion"] ), ) optimizer = AdamWeightDecay( learning_rate=learning_rate_fn, weight_decay_rate=config["optimizer_params"]["weight_decay"], beta_1=0.9, beta_2=0.98, epsilon=1e-6, exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"], ) _ = optimizer.iterations # compile trainer trainer.compile(model=fastspeech, optimizer=optimizer) # start training try: trainer.fit( train_dataset, valid_dataset, saved_path=os.path.join(config["outdir"], "checkpoints/"), resume=args.resume, ) except KeyboardInterrupt: trainer.save_checkpoint() logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
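# The scripts here call return_strategy() without defining it. Below is a
# minimal sketch, assuming the usual behavior of picking a tf.distribute
# strategy from the number of visible GPUs; the project's real helper may
# differ in details.
def return_strategy():
    physical_devices = tf.config.list_physical_devices("GPU")
    if len(physical_devices) == 0:
        # no GPU available: run everything on the CPU
        return tf.distribute.OneDeviceStrategy(device="/cpu:0")
    elif len(physical_devices) == 1:
        return tf.distribute.OneDeviceStrategy(device="/gpu:0")
    else:
        # synchronous data parallelism across all visible GPUs
        return tf.distribute.MirroredStrategy()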
def main(): """Run training process.""" parser = argparse.ArgumentParser( description= "Train FastSpeech (See detail in tensorflow_tts/bin/train-fastspeech.py)" ) parser.add_argument("--train-dir", default=None, type=str, help="directory including training data. ") parser.add_argument("--dev-dir", default=None, type=str, help="directory including development data. ") parser.add_argument("--use-norm", default=1, type=int, help="usr norm-mels for train or raw.") parser.add_argument("--outdir", type=str, required=True, help="directory to save checkpoints.") parser.add_argument("--config", type=str, required=True, help="yaml format configuration file.") parser.add_argument( "--resume", default="", type=str, nargs="?", help="checkpoint file path to resume training. (default=\"\")") parser.add_argument( "--verbose", type=int, default=1, help="logging level. higher is more logging. (default=1)") parser.add_argument("--mixed_precision", default=0, type=int, help="using mixed precision for generator or not.") args = parser.parse_args() # set mixed precision config if args.mixed_precision == 1: tf.config.optimizer.set_experimental_options( {"auto_mixed_precision": True}) args.mixed_precision = bool(args.mixed_precision) args.use_norm = bool(args.use_norm) # set logger if args.verbose > 1: logging.basicConfig( level=logging.DEBUG, stream=sys.stdout, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") elif args.verbose > 0: logging.basicConfig( level=logging.INFO, stream=sys.stdout, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") else: logging.basicConfig( level=logging.WARN, stream=sys.stdout, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") logging.warning("Skip DEBUG/INFO messages") # check directory existence if not os.path.exists(args.outdir): os.makedirs(args.outdir) # check arguments if args.train_dir is None: raise ValueError("Please specify --train-dir") if args.dev_dir is None: raise ValueError("Please specify --valid-dir") # load and save config with open(args.config) as f: config = yaml.load(f, Loader=yaml.Loader) config.update(vars(args)) config["version"] = tensorflow_tts.__version__ with open(os.path.join(args.outdir, "config.yml"), "w") as f: yaml.dump(config, f, Dumper=yaml.Dumper) for key, value in config.items(): logging.info(f"{key} = {value}") # get dataset if config["remove_short_samples"]: mel_length_threshold = config["mel_length_threshold"] else: mel_length_threshold = None if config["format"] == "npy": charactor_query = "*-ids.npy" mel_query = "*-raw-feats.npy" if args.use_norm is False else "*-norm-feats.npy" duration_query = "*-durations.npy" charactor_load_fn = np.load mel_load_fn = np.load duration_load_fn = np.load else: raise ValueError("Only npy are supported.") # define train/valid dataset train_dataset = CharactorDurationMelDataset( root_dir=args.train_dir, charactor_query=charactor_query, mel_query=mel_query, duration_query=duration_query, charactor_load_fn=charactor_load_fn, mel_load_fn=mel_load_fn, duration_load_fn=duration_load_fn, mel_length_threshold=mel_length_threshold, return_utt_id=False).create(is_shuffle=config["is_shuffle"], allow_cache=config["allow_cache"], batch_size=config["batch_size"]) valid_dataset = CharactorDurationMelDataset( root_dir=args.dev_dir, charactor_query=charactor_query, mel_query=mel_query, duration_query=duration_query, charactor_load_fn=charactor_load_fn, mel_load_fn=mel_load_fn, duration_load_fn=duration_load_fn, mel_length_threshold=None, 
return_utt_id=False).create(is_shuffle=config["is_shuffle"], allow_cache=config["allow_cache"], batch_size=config["batch_size"]) fastspeech = TFFastSpeech(config=FASTSPEECH_CONFIG.FastSpeechConfig( **config["fastspeech_params"])) fastspeech._build() fastspeech.summary() # define trainer trainer = FastSpeechTrainer(config=config, steps=0, epochs=0, is_mixed_precision=False) # AdamW for fastspeech learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay( initial_learning_rate=config["optimizer_params"] ["initial_learning_rate"], decay_steps=config["optimizer_params"]["decay_steps"], end_learning_rate=config["optimizer_params"]["end_learning_rate"]) learning_rate_fn = WarmUp( initial_learning_rate=config["optimizer_params"] ["initial_learning_rate"], decay_schedule_fn=learning_rate_fn, warmup_steps=int(config["train_max_steps"] * config["optimizer_params"]["warmup_proportion"])) optimizer = AdamWeightDecay( learning_rate=learning_rate_fn, weight_decay_rate=config["optimizer_params"]["weight_decay"], beta_1=0.9, beta_2=0.98, epsilon=1e-6, exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias']) # compile trainer trainer.compile(model=fastspeech, optimizer=optimizer) # start training try: trainer.fit(train_dataset, valid_dataset, saved_path=os.path.join(config["outdir"], 'checkpoints/'), resume=args.resume) except KeyboardInterrupt: trainer.save_checkpoint() logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
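# Every script above wraps its PolynomialDecay schedule in WarmUp. A minimal
# sketch of such a schedule follows, modeled on the HuggingFace-style
# implementation (an assumption; the project's real class may differ): the
# learning rate ramps up polynomially for warmup_steps, then hands over to the
# wrapped decay schedule.
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, initial_learning_rate, decay_schedule_fn, warmup_steps,
                 power=1.0, name=None):
        super().__init__()
        self.initial_learning_rate = initial_learning_rate
        self.decay_schedule_fn = decay_schedule_fn
        self.warmup_steps = warmup_steps
        self.power = power
        self.name = name

    def __call__(self, step):
        with tf.name_scope(self.name or "WarmUp") as name:
            global_step_float = tf.cast(step, tf.float32)
            warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
            warmup_percent_done = global_step_float / warmup_steps_float
            warmup_learning_rate = self.initial_learning_rate * tf.math.pow(
                warmup_percent_done, self.power
            )
            # before warmup_steps: ramp up; afterwards: delegate to the decay schedule
            return tf.cond(
                global_step_float < warmup_steps_float,
                lambda: warmup_learning_rate,
                lambda: self.decay_schedule_fn(step - self.warmup_steps),
                name=name,
            )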
def main(): """Run training process.""" parser = argparse.ArgumentParser(description="Train Tacotron2") parser.add_argument("--outdir", type=str, required=True, help="directory to save checkpoints.") parser.add_argument("--rootdir", type=str, required=True, help="dataset directory root") parser.add_argument( "--resume", default="", type=str, nargs="?", help='checkpoint file path to resume training. (default="")') parser.add_argument( "--verbose", type=int, default=1, help="logging level. higher is more logging. (default=1)") parser.add_argument("--batch-size", default=16, type=int, help="batch size.") parser.add_argument("--mixed_precision", default=0, type=int, help="using mixed precision for generator or not.") parser.add_argument( "--pretrained", default="", type=str, nargs="?", help= 'pretrained weights .h5 file to load weights from. Auto-skips non-matching layers', ) args = parser.parse_args() if args.resume is not None and os.path.isdir(args.resume): args.resume = tf.train.latest_checkpoint(args.resume) # set mixed precision config if args.mixed_precision == 1: tf.config.optimizer.set_experimental_options( {"auto_mixed_precision": True}) args.mixed_precision = bool(args.mixed_precision) # set logger log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" if args.verbose > 1: logging.basicConfig(level=logging.DEBUG, stream=sys.stdout, format=log_format) elif args.verbose > 0: logging.basicConfig(level=logging.INFO, stream=sys.stdout, format=log_format) else: logging.basicConfig(level=logging.WARN, stream=sys.stdout, format=log_format) logging.warning("Skip DEBUG/INFO messages") # check directory existence(checkpoint) if not os.path.exists(args.outdir): os.makedirs(args.outdir) # select processor Processor = JSpeechProcessor # for test class Generator(Processor.Generator): def __init__(self): super().__init__() self._scaler_energy = StandardScaler(copy=False) self._scaler_f0 = StandardScaler(copy=False) self._energy_stat = np.stack((0, 0)) self._f0_stat = np.stack((0, 0)) def __call__(self, rootdir, tid, seq, speaker): tid, seq, feat_path, speaker = super().__call__( rootdir, tid, seq, speaker) f0_path = os.path.join(rootdir, "f0", f"{tid}.f0") energy_path = os.path.join(rootdir, "energies", f"{tid}.e") duration_path = os.path.join(rootdir, "durations", f"{tid}.dur") with open(f0_path) as f: f0 = np.fromfile(f, dtype='float32') self._scaler_f0.partial_fit(f0[f0 != 0].reshape(-1, 1)) with open(energy_path) as f: energy = np.fromfile(f, dtype='float32') self._scaler_energy.partial_fit(energy[energy != 0].reshape( -1, 1)) return tid, seq, feat_path, f0_path, energy_path, duration_path, speaker def complete(self): self._f0_stat = np.stack( (self._scaler_f0.mean_, self._scaler_f0.scale_)) self._energy_stat = np.stack( (self._scaler_energy.mean_, self._scaler_energy.scale_)) print("energy stat: mean {}, scale {}".format( self._energy_stat[0], self._energy_stat[1])) print("f0 stat: mean {}, scale {}".format(self._f0_stat[0], self._f0_stat[1])) def energy_stat(self): return self._energy_stat def f0_stat(self): return self._f0_stat generator = Generator() processor = Processor(rootdir=args.rootdir, generator=generator) config = Config(args.outdir, args.batch_size, processor.vocab_size()) # split train and test train_split, valid_split = train_test_split(processor.items, test_size=config.test_size, random_state=42, shuffle=True) train_dataset = generate_datasets(train_split, config, generator.f0_stat(), generator.energy_stat()) valid_dataset = generate_datasets(valid_split, 
config, generator.f0_stat(), generator.energy_stat()) # define trainer trainer = FastSpeech2Trainer(config=config, strategy=STRATEGY, steps=0, epochs=0, is_mixed_precision=args.mixed_precision) with STRATEGY.scope(): # define model fastspeech = TFFastSpeech2(config=config) # build fastspeech._build() fastspeech.summary() if len(args.pretrained) > 1: fastspeech.load_weights(args.pretrained, by_name=True, skip_mismatch=True) logging.info( f"Successfully loaded pretrained weight from {args.pretrained}." ) # AdamW for fastspeech learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay( initial_learning_rate=config.initial_learning_rate, decay_steps=config.decay_steps, end_learning_rate=config.end_learning_rate, ) learning_rate_fn = WarmUp( initial_learning_rate=config.initial_learning_rate, decay_schedule_fn=learning_rate_fn, warmup_steps=int(config.train_max_steps * config.warmup_proportion)) optimizer = AdamWeightDecay( learning_rate=learning_rate_fn, weight_decay_rate=config.weight_decay, beta_1=0.9, beta_2=0.98, epsilon=1e-6, exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"], ) _ = optimizer.iterations # compile trainer trainer.compile(model=fastspeech, optimizer=optimizer) # start training try: trainer.fit(train_dataset, valid_dataset, saved_path=os.path.join(config.outdir, "checkpoints/"), resume=args.resume) except KeyboardInterrupt: trainer.save_checkpoint() logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
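# The Generator above stacks (mean, scale) pairs for f0 and energy. A
# hypothetical helper showing how such stats would typically be applied when
# building the datasets (generate_datasets is not shown in these scripts, so
# this is only an illustration of the normalization, not the project's code):
def normalize_with_stat(values, stat):
    # stat[0] is the mean, stat[1] the scale, as produced by Generator.complete();
    # zero (unvoiced) frames are left untouched, mirroring the partial_fit masking.
    values = np.array(values, dtype="float32")
    nonzero = values != 0
    values[nonzero] = (values[nonzero] - stat[0]) / stat[1]
    return values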
def main(): """Run training process.""" parser = argparse.ArgumentParser( description="Train/finetune Tacotron2 with reinitiated Embeddings") parser.add_argument( "--train-dir", default=None, type=str, help="directory including training data. ", ) parser.add_argument( "--dev-dir", default=None, type=str, help="directory including development data. ", ) parser.add_argument("--use-norm", default=1, type=int, help="usr norm-mels for train or raw.") parser.add_argument("--outdir", type=str, required=True, help="directory to save checkpoints.") parser.add_argument("--config", type=str, required=True, help="yaml format configuration file.") parser.add_argument( "--resume", default="", type=str, nargs="?", help='checkpoint file path to resume training. (default="")', ) parser.add_argument( "--verbose", type=int, default=1, help="logging level. higher is more logging. (default=1)", ) parser.add_argument( "--mixed_precision", default=0, type=int, help="using mixed precision for generator or not.", ) parser.add_argument( "--pretrained", default="", type=str, nargs="?", help= "pretrained weights .h5 file to load weights from. Auto-skips non-matching layers", ) args = parser.parse_args() # return strategy STRATEGY = return_strategy() # set mixed precision config if args.mixed_precision == 1: tf.config.optimizer.set_experimental_options( {"auto_mixed_precision": True}) args.mixed_precision = bool(args.mixed_precision) args.use_norm = bool(args.use_norm) # set logger if args.verbose > 1: logging.basicConfig( level=logging.DEBUG, stream=sys.stdout, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", ) elif args.verbose > 0: logging.basicConfig( level=logging.INFO, stream=sys.stdout, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", ) else: logging.basicConfig( level=logging.WARN, stream=sys.stdout, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", ) logging.warning("Skip DEBUG/INFO messages") # check directory existence if not os.path.exists(args.outdir): os.makedirs(args.outdir) # check arguments if args.train_dir is None: raise ValueError("Please specify --train-dir") if args.dev_dir is None: raise ValueError("Please specify --valid-dir") # load and save config with open(args.config) as f: config = yaml.load(f, Loader=yaml.Loader) config.update(vars(args)) config["version"] = tensorflow_tts.__version__ # get dataset if config["remove_short_samples"]: mel_length_threshold = config["mel_length_threshold"] else: mel_length_threshold = 0 if config["format"] == "npy": charactor_query = "*-ids.npy" mel_query = "*-raw-feats.npy" if args.use_norm is False else "*-norm-feats.npy" charactor_load_fn = np.load mel_load_fn = np.load else: raise ValueError("Only npy are supported.") train_dataset = CharactorMelDataset( dataset=config["tacotron2_params"]["dataset"], root_dir=args.train_dir, charactor_query=charactor_query, mel_query=mel_query, charactor_load_fn=charactor_load_fn, mel_load_fn=mel_load_fn, mel_length_threshold=mel_length_threshold, reduction_factor=config["tacotron2_params"]["reduction_factor"], use_fixed_shapes=config["use_fixed_shapes"], ) # update max_mel_length and max_char_length to config config.update({"max_mel_length": int(train_dataset.max_mel_length)}) config.update({"max_char_length": int(train_dataset.max_char_length)}) with open(os.path.join(args.outdir, "config.yml"), "w") as f: yaml.dump(config, f, Dumper=yaml.Dumper) for key, value in config.items(): logging.info(f"{key} = {value}") train_dataset = 
train_dataset.create( is_shuffle=config["is_shuffle"], allow_cache=config["allow_cache"], batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync * config["gradient_accumulation_steps"], ) valid_dataset = CharactorMelDataset( dataset=config["tacotron2_params"]["dataset"], root_dir=args.dev_dir, charactor_query=charactor_query, mel_query=mel_query, charactor_load_fn=charactor_load_fn, mel_load_fn=mel_load_fn, mel_length_threshold=mel_length_threshold, reduction_factor=config["tacotron2_params"]["reduction_factor"], use_fixed_shapes=False, # don't need apply fixed shape for evaluation. ).create( is_shuffle=config["is_shuffle"], allow_cache=config["allow_cache"], batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync, ) # define trainer trainer = Tacotron2Trainer( config=config, strategy=STRATEGY, steps=0, epochs=0, is_mixed_precision=args.mixed_precision, ) with STRATEGY.scope(): # define model. tacotron_config = Tacotron2Config(**config["tacotron2_params"]) tacotron2 = TFTacotron2(config=tacotron_config, name="tacotron2") tacotron2._build() tacotron2.summary() if len(args.pretrained) > 1: tacotron2.load_weights(args.pretrained, by_name=True, skip_mismatch=True) logging.info( f"Successfully loaded pretrained weight from {args.pretrained}." ) # re-define embedding NEW_VOCAB_SIZE = 44 # 149 -> LJSpeech-mapper, 44 -> TPI-mapper tacotron_config.vocab_size = NEW_VOCAB_SIZE new_embedding_layers = TFTacotronEmbeddings(tacotron_config, name='embeddings') tacotron2.embeddings = new_embedding_layers # re-build model tacotron2._build() tacotron2.summary() # AdamW for tacotron2 learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay( initial_learning_rate=config["optimizer_params"] ["initial_learning_rate"], decay_steps=config["optimizer_params"]["decay_steps"], end_learning_rate=config["optimizer_params"]["end_learning_rate"], ) learning_rate_fn = WarmUp( initial_learning_rate=config["optimizer_params"] ["initial_learning_rate"], decay_schedule_fn=learning_rate_fn, warmup_steps=int(config["train_max_steps"] * config["optimizer_params"]["warmup_proportion"]), ) optimizer = AdamWeightDecay( learning_rate=learning_rate_fn, weight_decay_rate=config["optimizer_params"]["weight_decay"], beta_1=0.9, beta_2=0.98, epsilon=1e-6, exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"], ) _ = optimizer.iterations # compile trainer trainer.compile(model=tacotron2, optimizer=optimizer) # start training try: trainer.fit( train_dataset, valid_dataset, saved_path=os.path.join(config["outdir"], "checkpoints/"), resume=args.resume, ) except KeyboardInterrupt: trainer.save_checkpoint() logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
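# Example fine-tuning invocation for the script above (script name and paths
# are hypothetical):
#
#   python train_tacotron2_finetune.py \
#       --train-dir ./dump/train --dev-dir ./dump/valid \
#       --config ./conf/tacotron2.v1.yaml --outdir ./exp/tacotron2_finetune \
#       --pretrained ./pretrained/tacotron2.h5
#
# Note the order of operations: the pretrained weights are loaded into the
# originally-sized model first, and only then is the embedding layer swapped
# for a freshly initialized one sized for the new 44-symbol vocabulary. Only
# the embeddings start from scratch, while by_name=True/skip_mismatch=True
# guards the load against any non-matching layers.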
def main(): """Run training process.""" parser = argparse.ArgumentParser(description="Train Tacotron2") parser.add_argument("--outdir", type=str, required=True, help="directory to save checkpoints.") parser.add_argument("--rootdir", type=str, required=True, help="dataset directory root") parser.add_argument("--resume",default="",type=str,nargs="?",help='checkpoint file path to resume training. (default="")') parser.add_argument("--verbose",type=int,default=1,help="logging level. higher is more logging. (default=1)") parser.add_argument("--batch-size", default=12, type=int, help="batch size.") parser.add_argument("--mixed_precision",default=0,type=int,help="using mixed precision for generator or not.") args = parser.parse_args() if args.resume is not None and os.path.isdir(args.resume): args.resume = tf.train.latest_checkpoint(args.resume) # set mixed precision config if args.mixed_precision == 1: tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True}) args.mixed_precision = bool(args.mixed_precision) # set logger log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" if args.verbose > 1: logging.basicConfig(level=logging.DEBUG,stream=sys.stdout,format=log_format) elif args.verbose > 0: logging.basicConfig(level=logging.INFO,stream=sys.stdout,format=log_format) else: logging.basicConfig(level=logging.WARN,stream=sys.stdout,format=log_format) logging.warning("Skip DEBUG/INFO messages") # check directory existence(checkpoint) if not os.path.exists(args.outdir): os.makedirs(args.outdir) # select processor Processor = JSpeechProcessor # for test processor = Processor(rootdir=args.rootdir) config = Config(args.outdir, args.batch_size, processor.vocab_size()) max_mel_length = processor.max_feat_length() // config.n_mels max_seq_length = processor.max_seq_length() # split train and test train_split, valid_split = train_test_split(processor.items, test_size=config.test_size,random_state=42,shuffle=True) train_dataset = generate_datasets(train_split, config, max_mel_length, max_seq_length) valid_dataset = generate_datasets(valid_split, config, max_mel_length, max_seq_length) # define trainer trainer = Tacotron2Trainer( config=config, strategy=STRATEGY, steps=0, epochs=0, is_mixed_precision=args.mixed_precision ) with STRATEGY.scope(): # define model. 
tacotron2 = TFTacotron2(config=config, training=True, name="tacotron2") #build input_ids = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9]]) input_lengths = np.array([9]) speaker_ids = np.array([0]) mel_outputs = np.random.normal(size=(1, 50, config.n_mels)).astype(np.float32) mel_lengths = np.array([50]) tacotron2(input_ids,input_lengths,speaker_ids,mel_outputs,mel_lengths,10,training=True) tacotron2.summary() # AdamW for tacotron2 learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay( initial_learning_rate=config.initial_learning_rate, decay_steps=config.decay_steps, end_learning_rate=config.end_learning_rate, ) learning_rate_fn = WarmUp( initial_learning_rate=config.initial_learning_rate, decay_schedule_fn=learning_rate_fn, warmup_steps=int(config.train_max_steps* config.warmup_proportion), ) optimizer = AdamWeightDecay( learning_rate=learning_rate_fn, weight_decay_rate=config.weight_decay, beta_1=0.9, beta_2=0.98, epsilon=1e-6, exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"], ) _ = optimizer.iterations # compile trainer trainer.compile(model=tacotron2, optimizer=optimizer) # start training try: trainer.fit(train_dataset,valid_dataset,saved_path=os.path.join(args.outdir, "checkpoints/"),resume=args.resume) except KeyboardInterrupt: trainer.save_checkpoint() logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
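# The dummy forward pass above plays the role of _build() in the other
# scripts: subclassed Keras models create their variables lazily, so the model
# must see one batch of correctly-shaped inputs before summary(),
# load_weights(), or the optimizer can touch its weights. Example invocation
# (script name and paths are hypothetical):
#
#   python train_tacotron2.py --rootdir ./datasets/jspeech \
#       --outdir ./exp/tacotron2 --batch-size 12 --verbose 1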