def get_runner_args():
    """Parse the runner's command-line options and load its YAML config.

    The options are read from ``sys.argv`` via ``argparse``. Two derived
    attributes are added to the parsed namespace: ``gpu`` (the negation of
    ``--cpu``) and ``verbose`` (the negation of ``--no_msg``). The file named
    by ``--config`` is loaded with ``yaml.FullLoader`` and handed to
    ``parse_prune_heads`` before being returned.

    Returns:
        tuple: ``(config, args)`` -- the loaded config first, then the
        ``argparse.Namespace``. Note the order: config comes *before* args.
    """
    parser = argparse.ArgumentParser(description='Argument Parser for the S3PLR project.')

    # setting
    # The default used to end in ",yaml" (comma), which can never match a
    # ".yaml" file on disk; it now uses the intended extension.
    parser.add_argument('--config', default='../config/deprecated_runner/tera_libri_fmllrBase_pretrain.yaml', type=str, help='Path to experiment config.', required=False)
    parser.add_argument('--seed', default=1337, type=int, help='Random seed for reproducable results.', required=False)

    # Logging
    parser.add_argument('--logdir', default='../log/log_transformer/', type=str, help='Logging path.', required=False)
    parser.add_argument('--name', default=None, type=str, help='Name for logging.', required=False)

    # model ckpt
    parser.add_argument('--load', action='store_true', help='Load pre-trained model to restore training, no need to specify this during testing.')
    parser.add_argument('--ckpdir', default='../result/result_transformer/', type=str, help='path to store experiment result.', required=False)
    parser.add_argument('--ckpt', default='fmllrBase960-F-N-K-libri/states-1000000.ckpt', type=str, help='path to transformer model checkpoint.', required=False)
    parser.add_argument('--dckpt', default='baseline_sentiment_libri_sd1337/baseline_sentiment-500000.ckpt', type=str, help='path to downstream checkpoint.', required=False)
    parser.add_argument('--apc_path', default='../result/result_apc/apc_libri_sd1337_standard/apc-500000.ckpt', type=str, help='path to the apc model checkpoint.', required=False)

    # mockingjay
    parser.add_argument('--train', action='store_true', help='Train the model.')
    parser.add_argument('--run_transformer', action='store_true', help='train and test the downstream tasks using speech representations.')
    parser.add_argument('--run_apc', action='store_true', help='train and test the downstream tasks using apc representations.')
    parser.add_argument('--fine_tune', action='store_true', help='fine tune the transformer model with downstream task.')
    parser.add_argument('--plot', action='store_true', help='Plot model generated results during testing.')

    # phone task
    parser.add_argument('--train_phone', action='store_true', help='Train the phone classifier on mel or speech representations.')
    parser.add_argument('--test_phone', action='store_true', help='Test mel or speech representations using the trained phone classifier.')

    # cpc phone task
    parser.add_argument('--train_cpc_phone', action='store_true', help='Train the phone classifier on mel or speech representations with the alignments in CPC paper.')
    parser.add_argument('--test_cpc_phone', action='store_true', help='Test mel or speech representations using the trained phone classifier with the alignments in CPC paper.')

    # sentiment task
    parser.add_argument('--train_sentiment', action='store_true', help='Train the sentiment classifier on mel or speech representations.')
    parser.add_argument('--test_sentiment', action='store_true', help='Test mel or speech representations using the trained sentiment classifier.')

    # speaker verification task
    parser.add_argument('--train_speaker', action='store_true', help='Train the speaker classifier on mel or speech representations.')
    parser.add_argument('--test_speaker', action='store_true', help='Test mel or speech representations using the trained speaker classifier.')

    # Options
    parser.add_argument('--with_head', action='store_true', help='inference with the spectrogram head, the model outputs spectrogram.')
    parser.add_argument('--plot_attention', action='store_true', help='plot attention')
    parser.add_argument('--load_ws', default='result/result_transformer_sentiment/10111754-10170300-weight_sum/best_val.ckpt', help='load weighted-sum weights from trained downstream model')
    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
    parser.add_argument('--multi_gpu', action='store_true', help='Enable Multi-GPU training.')
    parser.add_argument('--no_msg', action='store_true', help='Hide all messages.')
    parser.add_argument('--test_reconstruct', action='store_true', help='Test reconstruction capability')

    # parse
    args = parser.parse_args()
    # Positive-sense mirrors of the negative CLI switches.
    args.gpu = not args.cpu
    args.verbose = not args.no_msg

    # Use a context manager so the config file handle is closed promptly.
    # NOTE(review): yaml.FullLoader is meant for trusted files; switch to
    # yaml.safe_load if configs can come from untrusted sources.
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    # parse_prune_heads is defined elsewhere; its return value is ignored here,
    # so it presumably adjusts `config` in place -- confirm against its definition.
    parse_prune_heads(config)

    return config, args
def get_upstream_args():
    """Parse the upstream pre-training command-line options and load the config.

    ``--run`` and ``--config`` are both mandatory here. After parsing, a
    derived ``gpu`` attribute (the negation of ``--cpu``) is added to the
    namespace, and the YAML file named by ``--config`` is loaded with
    ``yaml.FullLoader`` and handed to ``parse_prune_heads``.

    NOTE(review): another ``get_upstream_args`` definition (the variant that
    supports ``--resume``) appears later in this source; if both live in the
    same module, the later one shadows this one.

    Returns:
        tuple: ``(args, config)`` -- the ``argparse.Namespace`` first, then
        the loaded config.
    """
    parser = argparse.ArgumentParser(
        description='Argument Parser for Upstream Models of the S3PLR project.'
    )

    # required
    # Adjacent string literals replace the old in-string backslash
    # continuation; argparse collapses whitespace in help text, so the
    # displayed message is the same.
    parser.add_argument(
        '--run',
        choices=['transformer', 'apc'],
        help='Select pre-training task. '
             'For the transformer models, which type of pre-training (mockingjay, tera, aalbert, etc) '
             'is determined by config file.',
        required=True,
    )
    parser.add_argument('--config', type=str, help='Path to experiment config.', required=True)

    # ckpt and logging
    parser.add_argument('--name', default=None, type=str, help='Name for logging.', required=False)
    parser.add_argument(
        '--ckpdir',
        default='',
        type=str,
        help='Path to store checkpoint result, if empty then default is used.',
        required=False,
    )
    parser.add_argument('--seed', default=1337, type=int, help='Random seed for reproducable results.', required=False)

    # Options
    parser.add_argument('--test', default='', type=str, help='Input path to the saved model ckpt for testing.')
    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
    parser.add_argument('--multi_gpu', action='store_true', help='Enable Multi-GPU training.')
    parser.add_argument('--test_reconstruct', action='store_true', help='Test reconstruction capability')

    # parse
    args = parser.parse_args()
    # Positive-sense mirror of the --cpu switch.
    args.gpu = not args.cpu

    # Use a context manager so the config file handle is closed promptly.
    # NOTE(review): yaml.FullLoader is meant for trusted files; switch to
    # yaml.safe_load if configs can come from untrusted sources.
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    # parse_prune_heads is defined elsewhere; its return value is ignored here,
    # so it presumably adjusts `config` in place -- confirm against its definition.
    parse_prune_heads(config)

    return args, config
def get_upstream_args():
    """Parse the upstream command-line options and obtain the experiment config.

    Two modes are supported:

    * Fresh run (``--resume`` not given): ``--run`` and ``--config`` must both
      be supplied. A derived ``gpu`` attribute (the negation of ``--cpu``) is
      added, the YAML config is loaded and handed to ``parse_prune_heads``,
      and, when ``--online_config`` is given, that YAML file is loaded and
      stored under ``config['online']``.
    * Resume (``--resume`` given): ``--resume`` is either a checkpoint file or
      a directory. For a directory, the ``*.ckpt`` file with the largest
      trailing ``-<number>`` in its name is chosen. The checkpoint is loaded
      on CPU; the arguments saved under ``['Settings']['Paras']`` override the
      command-line values, the config comes from ``['Settings']['Config']``,
      and ``args.resume`` is set to the resolved checkpoint path. ``args.gpu``
      is not derived from ``--cpu`` in this mode.

    Invalid combinations are reported through ``parser.error`` (usage message
    on stderr, exit status 2) rather than ``assert``, so the checks still run
    under ``python -O``.

    Returns:
        tuple: ``(args, config)`` -- the ``argparse.Namespace`` first, then
        the config object.
    """
    parser = argparse.ArgumentParser(description='Argument Parser for Upstream Models of the S3PLR project.')

    # required, set either (--run and --config) or (--resume)
    # Adjacent string literals replace the old in-string backslash
    # continuation; argparse collapses whitespace in help text, so the
    # displayed message is the same.
    parser.add_argument(
        '--run',
        default=None,
        choices=['transformer', 'apc'],
        help='Select pre-training task. '
             'For the transformer models, which type of pre-training (mockingjay, tera, aalbert, etc) '
             'is determined by config file.',
    )
    parser.add_argument('--config', default=None, type=str, help='Path to experiment config.')
    parser.add_argument('--resume', default=None, help='Specify the upstream checkpoint path to resume training')

    # ckpt and logging
    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
    parser.add_argument('--ckpdir', default='', type=str, help='Path to store checkpoint result, if empty then default is used.')
    parser.add_argument('--seed', default=1337, type=int, help='Random seed for reproducable results.')

    # Options
    parser.add_argument('--test', default='', type=str, help='Input path to the saved model ckpt for testing.')
    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
    parser.add_argument('--multi_gpu', action='store_true', help='Enable Multi-GPU training.')
    parser.add_argument('--test_reconstruct', action='store_true', help='Test reconstruction capability')
    parser.add_argument('--online_config', default=None, help='Explicitly specify the config of on-the-fly feature extraction')
    parser.add_argument('--kaldi_data', action='store_true', help='Whether to use the Kaldi dataset')

    # parse
    args = parser.parse_args()

    if args.resume is None:
        # Fresh run: both --run and --config are mandatory.
        if args.run is None or args.config is None:
            parser.error('`--run` and `--config` must be given if `--resume` is not provided')
        args.gpu = not args.cpu

        # Context managers make sure the YAML file handles are closed.
        # NOTE(review): yaml.FullLoader is meant for trusted files; switch to
        # yaml.safe_load if configs can come from untrusted sources.
        with open(args.config, 'r') as config_file:
            config = yaml.load(config_file, Loader=yaml.FullLoader)

        # parse_prune_heads is defined elsewhere; its return value is ignored
        # here, so it presumably adjusts `config` in place -- confirm.
        parse_prune_heads(config)

        if args.online_config is not None:
            with open(args.online_config, 'r') as online_file:
                config['online'] = yaml.load(online_file, Loader=yaml.FullLoader)
    else:
        if os.path.isdir(args.resume):
            ckpts = glob.glob(f'{args.resume}/*.ckpt')
            if not ckpts:
                parser.error(f'no `*.ckpt` files found in the `--resume` directory: {args.resume}')

            def step_of(pth):
                # File names look like "<prefix>-<step>.ckpt". Only the
                # basename is parsed so that a '-' in a parent directory
                # name cannot corrupt the step number.
                return int(os.path.basename(pth).split('-')[-1].split('.')[0])

            # Resume from the checkpoint with the highest step number.
            resume_ckpt = sorted(ckpts, key=step_of)[-1]
        else:
            resume_ckpt = args.resume

        # NOTE(review): the checkpoint stores a pickled argument object;
        # recent PyTorch releases default torch.load to weights_only=True,
        # which may refuse such objects -- confirm against the torch version in use.
        ckpt = torch.load(resume_ckpt, map_location='cpu')

        # Overlay the arguments saved in the checkpoint on top of the CLI
        # ones: any key present in the checkpoint wins.
        merged = vars(args)
        merged.update(vars(ckpt['Settings']['Paras']))
        args = Namespace(**merged)

        config = ckpt['Settings']['Config']
        # Record the concrete checkpoint file (not the directory) that was used.
        args.resume = resume_ckpt

    return args, config