def main(argv):
    """Run the ABX phone-discriminability evaluation and dump results as JSON.

    Features come either from pre-computed files on disk (.pt / .npy, when no
    checkpoint is given) or are built on the fly from a CPC checkpoint.
    Writes `ABX_scores.json` and `ABX_args.json` into the output directory.
    """
    args = parse_args(argv)

    if args.path_checkpoint is None:
        # Pre-computed features: pick the loader matching the file extension.
        if args.file_extension == '.pt':
            feature_function = load_pt
        elif args.file_extension == '.npy':
            feature_function = load_npy
    else:
        # Compute features on the fly from the given CPC checkpoint.
        state_dict = torch.load(args.path_checkpoint)
        feature_maker = load_cpc_features(state_dict)
        feature_maker.cuda()

        def feature_function(x):
            return build_feature_from_file(x, feature_maker)

    # Modes
    if args.mode == 'all':
        modes = ["within", "across"]
    else:
        modes = [args.mode]

    # Time step (in s) between two consecutive feature frames.
    step_feature = 1 / args.feature_size

    # Get the list of sequences
    seq_list = find_all_files(args.path_data, args.file_extension)

    scores = ABX(feature_function, args.path_item_file,
                 seq_list, args.distance_mode,
                 step_feature, modes,
                 cuda=args.cuda,
                 max_x_across=args.max_x_across,
                 max_size_group=args.max_size_group)

    # NOTE(review): if both --path_checkpoint and --out are None this raises
    # on Path(None) — confirm the CLI forbids that combination.
    out_dir = Path(args.path_checkpoint).parent if args.out is None \
        else Path(args.out)
    out_dir.mkdir(exist_ok=True)

    path_score = out_dir / 'ABX_scores.json'
    with open(path_score, 'w') as file:
        json.dump(scores, file, indent=2)

    path_args = out_dir / 'ABX_args.json'
    with open(path_args, 'w') as file:
        json.dump(vars(args), file, indent=2)
def per(args):
    """Evaluate the phone error rate of a trained CPC + CTC phone classifier.

    Restores both the feature extractor and the CTC criterion from the
    checkpoint, rebuilds the criterion with the hyper-parameters stored in
    `args_training.json`, then runs `per_step` over the whole dataset.
    """
    # Load the model
    state_dict = torch.load(args.path_checkpoint)
    feature_maker = load_cpc_features(state_dict["model"])
    feature_maker.cuda()
    feature_maker.eval()
    feat_dim = feature_maker.get_output_dim()

    # Get the model training configuration
    path_config = Path(args.path_checkpoint).parent / "args_training.json"
    with open(path_config, 'rb') as file:
        config_training = json.load(file)

    n_phones = get_n_phones(config_training["path_phone_converter"])

    # Rebuild the classifier exactly as it was trained, then restore weights.
    phone_criterion = per_src.CTCPhoneCriterion(
        feat_dim, n_phones, config_training["LSTM"],
        seqNorm=config_training["seqNorm"],
        dropout=config_training["dropout"],
        reduction=config_training["loss_reduction"])
    phone_criterion.load_state_dict(state_dict["classifier"])
    phone_criterion.cuda()

    # Ratio of raw audio samples to one feature frame.
    downsamplingFactor = 160

    # dataset
    inSeqs = find_all_files(args.pathDB, args.file_extension)
    phoneLabels = parse_phone_labels(args.pathPhone)
    val_dataset = per_src.SingleSequenceDataset(args.pathDB, inSeqs,
                                                phoneLabels, inDim=1)
    val_loader = DataLoader(val_dataset, batch_size=args.batchSize,
                            shuffle=False)

    per_step(val_loader, feature_maker, phone_criterion, downsamplingFactor)
def main(args):
    """Train and/or evaluate a letter classifier on top of a CPC model.

    Requires either `--path_wer` (WER evaluation) or both `--path_train`
    and `--path_val` (training); prints a usage message and returns
    otherwise.
    """
    description = "Training and evaluation of a letter classifier on top of a pre-trained CPC model. " \
        "Please specify at least one `path_wer` (to calculate WER) or `path_train` and `path_val` (for training)."
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('--path_checkpoint', type=str)
    parser.add_argument('--path_train', default=None, type=str)
    parser.add_argument('--path_val', default=None, type=str)
    parser.add_argument('--n_epochs', type=int, default=30)
    parser.add_argument('--seed', type=int, default=7)
    parser.add_argument('--downsampling_factor', type=int, default=160)
    parser.add_argument('--lr', type=float, default=2e-04)
    parser.add_argument('--output', type=str, default='out',
                        help="Output directory")
    parser.add_argument('--p_dropout', type=float, default=0.0)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lm_weight', type=float, default=2.0)
    parser.add_argument('--path_wer',
                        help="For computing the WER on specific sequences",
                        action='append')
    parser.add_argument('--letters_path', type=str,
                        default='WER_data/letters.lst')
    args = parser.parse_args(args=args)

    if not args.path_wer and not (args.path_train and args.path_val):
        print(
            'Please specify at least one `path_wer` (to calculate WER) or `path_train` and `path_val` (for training).'
        )
        # FIX: bail out — previously execution fell through and continued
        # with unusable arguments.
        return

    if not os.path.isdir(args.output):
        os.mkdir(args.output)

    # creating models before reading the datasets
    with open(args.letters_path) as f:
        n_chars = len(f.readlines())

    state_dict = torch.load(args.path_checkpoint)
    feature_maker = load_cpc_features(state_dict)
    feature_maker.cuda()
    hidden = feature_maker.get_output_dim()
    letter_classifier = LetterClassifier(
        feature_maker, hidden, n_chars,
        p_dropout=args.p_dropout if hasattr(args, 'p_dropout') else 0.0)
    criterion = CTCLetterCriterion(letter_classifier, n_chars)
    criterion.cuda()
    criterion = torch.nn.DataParallel(criterion)

    # Checkpoint file where the model should be saved
    path_checkpoint = os.path.join(args.output, 'checkpoint.pt')

    if args.path_train and args.path_val:
        set_seed(args.seed)
        # FIX: honor --letters_path; previously "./WER_data/letters.lst"
        # was hard-coded here, silently ignoring the CLI flag.
        char_labels_val, n_chars, _ = parse_ctc_labels_from_root(
            args.path_val, letters_path=args.letters_path)
        print(f"Loading the validation dataset at {args.path_val}")
        dataset_val = SingleSequenceDataset(args.path_val, char_labels_val)
        val_loader = DataLoader(dataset_val, batch_size=args.batch_size,
                                shuffle=False)

        # train dataset
        char_labels_train, n_chars, _ = parse_ctc_labels_from_root(
            args.path_train, letters_path=args.letters_path)
        print(f"Loading the training dataset at {args.path_train}")
        dataset_train = SingleSequenceDataset(args.path_train,
                                              char_labels_train)
        train_loader = DataLoader(dataset_train, batch_size=args.batch_size,
                                  shuffle=True)

        # Optimizer
        g_params = list(criterion.parameters())
        optimizer = torch.optim.Adam(g_params, lr=args.lr)

        args_path = os.path.join(args.output, "args_training.json")
        with open(args_path, 'w') as file:
            json.dump(vars(args), file, indent=2)

        run(train_loader, val_loader, criterion, optimizer,
            args.downsampling_factor, args.n_epochs, path_checkpoint)

    if args.path_wer:
        args = get_eval_args(args)
        state_dict = torch.load(path_checkpoint)
        criterion.load_state_dict(state_dict)
        # Unwrap DataParallel for single-stream evaluation.
        criterion = criterion.module
        criterion.eval()

        args_path = os.path.join(args.output, "args_validation.json")
        with open(args_path, 'w') as file:
            json.dump(vars(args), file, indent=2)

        for path_wer in args.path_wer:
            print(f"Loading the validation dataset at {path_wer}")
            char_labels_wer, _, (letter2index, index2letter) = \
                parse_ctc_labels_from_root(path_wer,
                                           letters_path=args.letters_path)
            dataset_eval = SingleSequenceDataset(path_wer, char_labels_wer)
            eval_loader = DataLoader(dataset_eval,
                                     batch_size=args.batch_size,
                                     shuffle=False)
            wer = eval_wer(eval_loader, criterion, args.lm_weight,
                           index2letter)
            print(f'WER: {wer}')
def train(args):
    """Train a CTC phone classifier on top of (optionally frozen) CPC features.

    Builds train/validation splits (holding out 10% of the training data
    when no validation path is given), then runs `run_training`. Stdout is
    mirrored to a per-command log file via `tee`.
    """
    # Output Directory
    if not os.path.isdir(args.output):
        os.mkdir(args.output)

    name = f"_{args.name}" if args.command == "per" else ""
    pathLogs = os.path.join(args.output, f'logs_{args.command}{name}.txt')
    # Duplicate stdout into the log file for the rest of the process.
    tee = subprocess.Popen(["tee", pathLogs], stdin=subprocess.PIPE)
    os.dup2(tee.stdin.fileno(), sys.stdout.fileno())

    nPhones = get_n_phones(args.path_phone_converter)
    phoneLabels = parse_phone_labels(args.pathPhone)
    inSeqs = find_all_files(args.pathDB, args.file_extension)

    # Model
    downsamplingFactor = 160  # raw audio samples per feature frame
    state_dict = torch.load(args.pathCheckpoint)
    featureMaker = load_cpc_features(state_dict)
    hiddenGar = featureMaker.get_output_dim()
    featureMaker.cuda()
    featureMaker = torch.nn.DataParallel(featureMaker)

    # Criterion
    phoneCriterion = per_src.CTCPhoneCriterion(hiddenGar, nPhones, args.LSTM,
                                               seqNorm=args.seqNorm,
                                               dropout=args.dropout,
                                               reduction=args.loss_reduction)
    phoneCriterion.cuda()
    phoneCriterion = torch.nn.DataParallel(phoneCriterion)

    # Datasets
    if args.command == 'train' and args.pathTrain is not None:
        seqTrain = filter_seq(args.pathTrain, inSeqs)
    else:
        seqTrain = inSeqs

    if args.pathVal is None:
        # No explicit validation set: hold out 10% of the training data.
        random.shuffle(seqTrain)
        sizeTrain = int(0.9 * len(seqTrain))
        seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
    else:
        # FIX: was `elif args.pathVal is not None:` — condition is always
        # true in this branch, so a plain `else` is equivalent and clearer.
        seqVal = filter_seq(args.pathVal, inSeqs)
        print(len(seqVal), len(inSeqs), args.pathVal)

    if args.debug:
        seqVal = seqVal[:100]

    print(f"Loading the validation dataset at {args.pathDB}")
    datasetVal = per_src.SingleSequenceDataset(args.pathDB, seqVal,
                                               phoneLabels,
                                               inDim=args.in_dim)
    # NOTE(review): shuffle=True on a validation loader is unusual — confirm
    # this is intentional before changing it.
    valLoader = DataLoader(datasetVal, batch_size=args.batchSize,
                           shuffle=True)

    # Checkpoint file where the model should be saved
    pathCheckpoint = os.path.join(args.output, 'checkpoint.pt')

    featureMaker.optimize = True
    if args.freeze:
        # Freeze the feature extractor entirely.
        featureMaker.eval()
        featureMaker.optimize = False
        for g in featureMaker.parameters():
            g.requires_grad = False

    if args.debug:
        print("debug")
        random.shuffle(seqTrain)
        seqTrain = seqTrain[:1000]
        # NOTE(review): this re-slice has no effect — valLoader was already
        # built from the (already truncated) seqVal above.
        seqVal = seqVal[:100]

    print(f"Loading the training dataset at {args.pathDB}")
    datasetTrain = per_src.SingleSequenceDataset(args.pathDB, seqTrain,
                                                 phoneLabels,
                                                 inDim=args.in_dim)
    trainLoader = DataLoader(datasetTrain, batch_size=args.batchSize,
                             shuffle=True)

    # Optimizer: always train the criterion; include the feature extractor
    # only when it is not frozen.
    g_params = list(phoneCriterion.parameters())
    if not args.freeze:
        print("Optimizing model")
        g_params += list(featureMaker.parameters())
    optimizer = torch.optim.Adam(g_params, lr=args.lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    pathArgs = os.path.join(args.output, "args_training.json")
    with open(pathArgs, 'w') as file:
        json.dump(vars(args), file, indent=2)

    run_training(trainLoader, valLoader, featureMaker, phoneCriterion,
                 optimizer, downsamplingFactor, args.nEpochs, pathCheckpoint)