import argparse
import os
import random
import numpy as np
import tensorflow as tf
# save_codes_and_config and get_pretrain_model are project helpers
# (import path assumed).
from misc.utils import save_codes_and_config, get_pretrain_model

parser = argparse.ArgumentParser()
parser.add_argument("--tune_period", type=int, default=100, help="How many steps per learning rate.")
parser.add_argument("--checkpoint", type=str, default="-1",
                    help="The checkpoint in the pre-trained model. The default is to load the BEST checkpoint (according to valid_loss).")
parser.add_argument("--config", type=str, help="The configuration file.")
parser.add_argument("train_dir", type=str, help="The data directory of the training set.")
parser.add_argument("train_spklist", type=str, help="The spklist file that maps the TRAINING speakers to indices.")
parser.add_argument("valid_dir", type=str, help="The data directory of the validation set.")
parser.add_argument("valid_spklist", type=str, help="The spklist file that maps the VALIDATION speakers to indices.")
parser.add_argument("pretrain_model", type=str, help="The pre-trained model directory.")
parser.add_argument("finetune_model", type=str, help="The fine-tuned model directory.")

if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    args = parser.parse_args()
    params = save_codes_and_config(False, args.finetune_model, args.config)

    # Copy the pre-trained model into the target model directory. The copy
    # becomes the starting point of the fine-tuned model and is loaded from the
    # new directory, so the pre-trained model behaves just like an initialized
    # model.
    get_pretrain_model(os.path.join(args.pretrain_model, "nnet"),
                       os.path.join(args.finetune_model, "nnet"),
                       args.checkpoint)

    # The model directory always has a folder named nnet.
    model_dir = os.path.join(args.finetune_model, "nnet")

    # Set the random seed. Random operations may appear in data input, batch
    # forming, etc.
    tf.set_random_seed(params.seed)
    random.seed(params.seed)
    np.random.seed(params.seed)
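# A minimal sketch of what a helper like get_pretrain_model is assumed to do:
# copy the chosen checkpoint from the pre-trained nnet directory into the
# fine-tuned nnet directory and rewrite the "checkpoint" bookkeeping file so
# TensorFlow restores from the copy. Selecting the BEST checkpoint by
# valid_loss is project-specific, so this hypothetical helper only copies the
# latest checkpoint.
import glob
import shutil

def copy_pretrain_checkpoint(src_nnet, dst_nnet):
    ckpt = tf.train.get_checkpoint_state(src_nnet)
    assert ckpt and ckpt.model_checkpoint_path, "No checkpoint found in %s" % src_nnet
    prefix = ckpt.model_checkpoint_path  # e.g. .../nnet/model-12000
    if not os.path.isdir(dst_nnet):
        os.makedirs(dst_nnet)
    # Copy the .index/.meta/.data-* files that belong to this checkpoint.
    for filename in glob.glob(prefix + ".*"):
        shutil.copy(filename, dst_nnet)
    # Point the new directory's "checkpoint" file at the copied model.
    new_prefix = os.path.join(dst_nnet, os.path.basename(prefix))
    with open(os.path.join(dst_nnet, "checkpoint"), 'w') as f:
        f.write('model_checkpoint_path: "%s"\n' % new_prefix)
        f.write('all_model_checkpoint_paths: "%s"\n' % new_prefix)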
import argparse
import os
import random
import numpy as np
import tensorflow as tf
from dataset.data_loader import KaldiDataRandomQueue
from dataset.kaldi_io import FeatureReader
# save_codes_and_config is a project helper (import path assumed).
from misc.utils import save_codes_and_config

parser = argparse.ArgumentParser()
parser.add_argument("data_dir", type=str, help="The data directory of the dataset.")
parser.add_argument("data_spklist", type=str, help="The spklist file that maps the speakers to indices.")
parser.add_argument("model", type=str, help="The output model directory.")

if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    args = parser.parse_args()
    params = save_codes_and_config(True, args.model, None)

    # The model directory always has a folder named nnet.
    model_dir = os.path.join(args.model, "nnet")

    # Set the random seed. Random operations may appear in data input, batch
    # forming, etc.
    tf.set_random_seed(params.seed)
    random.seed(params.seed)
    np.random.seed(params.seed)

    dim = FeatureReader(args.data_dir).get_dim()
    if "selected_dim" in params.dict:
        dim = params.selected_dim

    # Each line of the spklist corresponds to one training speaker.
    with open(args.data_spklist, 'r') as f:
        num_total_train_speakers = len(f.readlines())
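# The speaker count above works by counting lines because the spklist is
# (by assumption) a plain text file with one "<speaker> <index>" pair per
# line. A hedged sketch of parsing it into a mapping, in case the indices
# themselves are needed:
def load_spk2index(spklist_path):
    spk2index = {}
    with open(spklist_path, 'r') as f:
        for line in f:
            spk, index = line.split()
            spk2index[spk] = int(index)
    return spk2index

# Counting lines and counting parsed entries should then agree:
# num_total_train_speakers == len(load_spk2index(args.data_spklist))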
parser.add_argument("train_data_dir", type=str, help="The data directory of the training set.") parser.add_argument("train_ali_dir", type=str, help="The ali directory of the training set.") parser.add_argument( "train_spklist", type=str, help="The spklist file maps the TRAINING speakers to the indices.") parser.add_argument("model", type=str, help="The output model directory.") if __name__ == '__main__': tf.logging.set_verbosity(tf.logging.INFO) args = parser.parse_args() params = save_codes_and_config(False, args.model, args.config) # The model directory always has a folder named nnet model_dir = os.path.join(args.model, "nnet") # Set the random seed. The random operations may appear in data input, batch forming, etc. tf.set_random_seed(params.seed) random.seed(params.seed) np.random.seed(params.seed) start_epoch = 0 feat_reader = FeatureReaderV2(args.train_data_dir, args.train_ali_dir) dim = feat_reader.get_dim() feat_reader = KaldiDataRandomQueueV2(args.train_data_dir, args.train_ali_dir,
import argparse
import os
import random
import numpy as np
import tensorflow as tf
# save_codes_and_config is a project helper (import path assumed).
from misc.utils import save_codes_and_config

# Pin the job to a single GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--cont", action="store_true", help="Continue training from an existing model.")
parser.add_argument("--config", type=str, help="The configuration file.")
parser.add_argument("train_dir", type=str, help="The data directory of the training set.")
parser.add_argument("train_spklist", type=str, help="The spklist file that maps the TRAINING speakers to indices.")
parser.add_argument("valid_dir", type=str, help="The data directory of the validation set.")
parser.add_argument("valid_spklist", type=str, help="The spklist file that maps the VALIDATION speakers to indices.")
parser.add_argument("model", type=str, help="The output model directory.")

if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    args = parser.parse_args()
    params = save_codes_and_config(args.cont, args.model, args.config)

    # The model directory always has a folder named nnet.
    model_dir = os.path.join(args.model, "nnet")

    # Set the random seed. Random operations may appear in data input, batch
    # forming, etc.
    tf.set_random_seed(params.seed)
    random.seed(params.seed)
    np.random.seed(params.seed)

    if args.cont:
        # When continuing training, figure out how many steps the model has
        # already been trained for, using the index of the checkpoint.
        import re
        ckpt = tf.train.get_checkpoint_state(model_dir)
        if ckpt and ckpt.model_checkpoint_path:
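            # The original fragment is cut off here. A hedged continuation
            # sketch, assuming (as the `re` import above hints) that the global
            # step is encoded at the end of the checkpoint filename, e.g.
            # "model-12000":
            match = re.search(r"-(\d+)$", ckpt.model_checkpoint_path)
            trained_steps = int(match.group(1)) if match else 0
            tf.logging.info("Continue training from step %d." % trained_steps)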
action="store_true", help="About whether to continue training.") parser.add_argument("--config", type=str, help="The configuration file.") parser.add_argument("train_dir", type=str, help="The data directory of the training set.") parser.add_argument( "train_spklist", type=str, help="The spklist file maps the TRAINING speakers to the indices.") parser.add_argument("model", type=str, help="The output model directory.") if __name__ == '__main__': args = parser.parse_args() params = save_codes_and_config(args.continue_training, args.model, args.config) model_dir = os.path.join(args.model, "nnet") os.environ['CUDA_VISIBLE_DEVICES'] = params.gpu_id torch.manual_seed(params.random_seed) np.random.seed(params.random_seed) random.seed(params.random_seed) dim = FeatureReader(args.train_dir).get_dim() with open(os.path.join(model_dir, "feature_dim"), 'w') as f: f.write("%d\n" % dim) num_total_train_speakers = KaldiDataRandomQueue( args.train_dir, args.train_spklist).num_total_speakers # 训练说话人数目