import os
import random

import numpy as np
import torch
import yaml

# Project-local classes (Transformer, TransformerOptimizer, Trainer, AudioDataset,
# FeatureLoader, TransformerRecognizer, load_vocab) are assumed to be imported
# from the surrounding package; their import paths are not shown here.


def main(args):
    # Fix all random seeds for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # Load the experiment configuration.
    with open(args.config, 'r') as f:
        params = yaml.load(f, Loader=yaml.FullLoader)

    expdir = os.path.join('egs', params['data']['name'], 'exp',
                          params['train']['save_name'])
    if not os.path.exists(expdir):
        os.makedirs(expdir)

    # Build the model and move it to GPU if requested.
    model = Transformer(params['model'])
    if args.ngpu >= 1:
        model.cuda()
    print(model)

    # Build the optimizer with the Transformer learning-rate schedule.
    optimizer = TransformerOptimizer(model, params['train'],
                                     model_size=params['model']['d_model'],
                                     parallel_mode=args.parallel_mode)

    trainer = Trainer(params, model=model, optimizer=optimizer, is_visual=True,
                      expdir=expdir, ngpu=args.ngpu,
                      parallel_mode=args.parallel_mode,
                      local_rank=args.local_rank)

    train_dataset = AudioDataset(params['data'], 'train')
    trainer.train(train_dataset=train_dataset)
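# --- Hypothetical entry point (not part of the original source): a minimal
# --- argparse sketch inferred from the attributes that main() reads off `args`.
# --- Flag names and defaults are assumptions, not the project's actual CLI.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train a Transformer ASR model')
    parser.add_argument('--config', type=str, required=True,
                        help='path to the YAML experiment configuration')
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--ngpu', type=int, default=1,
                        help='number of GPUs; 0 runs on CPU')
    parser.add_argument('--parallel_mode', type=str, default='dp',
                        help='parallelism mode passed to the optimizer and trainer')
    parser.add_argument('--local_rank', type=int, default=0,
                        help='rank used for distributed training')
    main(parser.parse_args())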
def main(args):
    # Restore the checkpoint; prefer the configuration stored inside it.
    checkpoint = torch.load(args.load_model)
    if 'params' in checkpoint:
        params = checkpoint['params']
    else:
        assert os.path.isfile(args.config), 'please specify a configuration file.'
        with open(args.config, 'r') as f:
            params = yaml.load(f, Loader=yaml.FullLoader)

    # Decoding: disable shuffling, SpecAugment and length-based sorting.
    params['data']['shuffle'] = False
    params['data']['spec_argument'] = False
    params['data']['short_first'] = False
    params['data']['batch_size'] = args.batch_size

    expdir = os.path.join('egs', params['data']['name'], 'exp',
                          params['train']['save_name'])
    if args.suffix is None:
        decode_dir = os.path.join(expdir, 'decode_%s' % args.decode_set)
    else:
        decode_dir = os.path.join(
            expdir, 'decode_%s_%s' % (args.decode_set, args.suffix))
    if not os.path.exists(decode_dir):
        os.makedirs(decode_dir)

    # Build the model and load the pre-trained weights.
    model = Transformer(params['model'])
    model.load_state_dict(checkpoint['model'])
    print('Load pre-trained model from %s' % args.load_model)
    model.eval()
    if args.ngpu > 0:
        model.cuda()

    # Vocabulary maps for converting predicted unit ids back to characters.
    char2unit = load_vocab(params['data']['vocab'])
    unit2char = {i: c for c, i in char2unit.items()}

    dataset = AudioDataset(params['data'], args.decode_set)
    data_loader = FeatureLoader(dataset)

    recognizer = TransformerRecognizer(model, unit2char=unit2char,
                                       beam_width=args.beam_width,
                                       max_len=args.max_len,
                                       penalty=args.penalty,
                                       lamda=args.lamda,
                                       ngpu=args.ngpu)

    totals = len(dataset)
    batch_size = params['data']['batch_size']
    writer = open(os.path.join(decode_dir, 'predict.txt'), 'w')

    for step, (utt_id, batch) in enumerate(data_loader.loader):
        # Keep inputs on CPU unless a GPU is requested (the original only
        # defined them inside the GPU branch, which broke CPU decoding).
        inputs = batch['inputs']
        inputs_length = batch['inputs_length']
        if args.ngpu > 0:
            inputs = inputs.cuda()
            inputs_length = inputs_length.cuda()

        # Beam-search decoding for the current batch.
        preds = recognizer.recognize(inputs, inputs_length)

        targets = batch['targets']
        targets_length = batch['targets_length']

        for b in range(len(preds)):
            n = step * batch_size + b
            # Drop the leading start-of-sequence token from the reference.
            truth = ' '.join(unit2char[i.item()]
                             for i in targets[b][1:targets_length[b]])
            print('[%d / %d] %s - pred : %s' % (n, totals, utt_id[b], preds[b]))
            print('[%d / %d] %s - truth: %s' % (n, totals, utt_id[b], truth))
            writer.write(utt_id[b] + ' ' + preds[b] + '\n')

    writer.close()
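# --- Hypothetical entry point (not part of the original source): a minimal
# --- argparse sketch inferred from the attributes that main() reads off `args`.
# --- Flag names and defaults are assumptions, not the project's actual CLI.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Decode a dataset with a trained Transformer ASR model')
    parser.add_argument('--load_model', type=str, required=True,
                        help='checkpoint to load (may embed its own params)')
    parser.add_argument('--config', type=str, default=None,
                        help='YAML config, needed when the checkpoint has no params')
    parser.add_argument('--decode_set', type=str, default='test')
    parser.add_argument('--suffix', type=str, default=None,
                        help='optional suffix appended to the decode directory name')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--beam_width', type=int, default=5)
    parser.add_argument('--max_len', type=int, default=100)
    parser.add_argument('--penalty', type=float, default=0.0,
                        help='length penalty applied during beam search')
    parser.add_argument('--lamda', type=float, default=0.0,
                        help='length-normalization weight (name kept as in main())')
    parser.add_argument('--ngpu', type=int, default=1)
    main(parser.parse_args())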