def test_tensor2str():
    alphabet = Alphabet('-abc ', blank_index=0)
    decoder = Decoder(alphabet)

    expected = ['ab c', 'aa-', 'c ', 'a', '']
    tensor = torch.tensor(
        [
            [1, 2, 4, 3, 0],
            [1, 1, 0, 0, 0],
            [3, 4, 0, 0, 0],
            [1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
        ],
        dtype=torch.int)
    sizes = [4, 3, 2, 1, 0]

    # A 3-D input is rejected
    with pytest.raises(ValueError) as excinfo:
        decoder.tensor2str(torch.tensor([[[1, 2, 3]]]))
    assert '`tensor.dim()` != 1 or 2' in str(excinfo.value)

    # A float input is rejected
    with pytest.raises(ValueError) as excinfo:
        decoder.tensor2str(tensor.float())
    assert 'must be int' in str(excinfo.value)

    # With sizes, each row is truncated before conversion
    output = decoder.tensor2str(tensor, sizes)
    for o, e in zip(output, expected):
        assert o == e

    # Without sizes, conversion matches Alphabet.idx2str on the full rows
    output = decoder.tensor2str(tensor)
    for o, t in zip(output, tensor):
        assert o == alphabet.idx2str(t)

def test_real_ctc_beam_decoder():
    labels = ' !"#&\'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_'
    alphabet = Alphabet(labels, blank_index=labels.index('_'))

    log_input = torch.load(
        os.path.join(base_dir, 'data/rnn_output_log_softmax.pth'))
    sizes = torch.tensor([log_input.shape[1]])

    # Greedy decoding expressed as a beam of width 1
    decoder = BeamCTCDecoder(alphabet, beam_width=1)
    decode_result, scores, timesteps = decoder.decode(log_input, sizes)
    assert "the fak friend of the fomly hae tC" == decode_result[0]

    # Default beam decoding
    decoder = BeamCTCDecoder(alphabet, beam_width=25)
    decode_result, scores, timesteps = decoder.decode(log_input, sizes)
    assert "the fak friend of the fomcly hae tC" == decode_result[0]

    # LM-based decoding; by convention in CTC beam decoders, `alpha` weights
    # the language model score and `beta` is a word insertion bonus
    decoder = BeamCTCDecoder(alphabet,
                             lm_path=os.path.join(base_dir, 'data',
                                                  'bigram.arpa'),
                             beam_width=25,
                             alpha=2,
                             beta=0)
    decode_result, scores, timesteps = decoder.decode(log_input, sizes)
    assert "the fake friend of the family, like the" == decode_result[0]

def test_wer():
    alphabet = Alphabet('-abc ', blank_index=0)

    with pytest.raises(ValueError) as excinfo:
        metric = WER()
    assert '`alphabet` is required' in str(excinfo.value)

    targets = ['abcc abc', 'abc']
    targets = [torch.tensor(alphabet.str2idx(t)) for t in targets]
    targets = torch.cat(targets)
    targets_size = torch.tensor([8, 3])

    outputs_str = ['aa-b-c-bc abb-c--a', 'a-bbbb-cc---------aa']
    outputs_ints = [torch.tensor(alphabet.str2idx(o)) for o in outputs_str]
    # One-hot encode each index sequence along the class dimension
    outputs = torch.zeros(2, 20, 5)
    for i, o in enumerate(outputs_ints):
        outputs[i, ...].scatter_(1, o.unsqueeze(1), 1)
    outputs_size = torch.tensor([17, 9])

    wers = torch.tensor([1 / 2, 0])

    metric = WER(alphabet=alphabet)
    assert repr(metric) == 'WER'

    metric.update(['loss', outputs, targets, outputs_size, targets_size])
    assert metric.val == wers.mean()
    assert metric.count == 2
    assert str(metric) == f'WER {wers.mean(0):.02%} ({wers.mean(0):.02%})'

    outputs_str = ['aa-b-c-cc abb-c---', 'a-bbbb-cc-----------']
    outputs_ints = [torch.tensor(alphabet.str2idx(o)) for o in outputs_str]
    outputs = torch.zeros(2, 20, 5)
    for i, o in enumerate(outputs_ints):
        outputs[i, ...].scatter_(1, o.unsqueeze(1), 1)

    metric.update(['loss', outputs, targets, outputs_size, targets_size])
    assert metric.val == 0
    assert metric.avg == (1 / 2) / 4
    assert metric.count == 4

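# Aside (illustrative, not part of the original suite): the scatter_ calls in
# test_wer build one-hot frames from an index sequence. A minimal sketch:
def test_one_hot_scatter_sketch():
    idx = torch.tensor([1, 0, 2])
    one_hot = torch.zeros(3, 4).scatter_(1, idx.unsqueeze(1), 1)
    # one_hot[t, idx[t]] == 1 and every other entry is 0
    assert torch.equal(one_hot.argmax(dim=1), idx)
    assert one_hot.sum() == 3
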
def test_constructor(tmpdir, model, params):
    serialization_dir = tmpdir / 'serialization_dir'
    loss = losses.CTCLoss(backend='pytorch')
    alphabet = Alphabet('-abc ', blank_index=0)

    missing_params = copy.deepcopy(params)
    del missing_params['optimizer']
    with pytest.raises(ConfigurationError) as excinfo:
        Trainer(serialization_dir,
                copy.deepcopy(missing_params),
                model,
                loss,
                alphabet,
                device='cpu')
    assert "key 'optimizer' is required" in str(excinfo.value)

    allowed_missing_params = copy.deepcopy(params)
    del allowed_missing_params['lr_scheduler']
    trainer = Trainer(serialization_dir,
                      copy.deepcopy(allowed_missing_params),
                      model,
                      loss,
                      alphabet,
                      device='cpu')
    assert trainer.lr_scheduler is None

    trainer = Trainer(serialization_dir,
                      copy.deepcopy(params),
                      model,
                      loss,
                      alphabet,
                      device='cpu')
    for phase in ['train', 'val']:
        assert isinstance(trainer.metrics[phase], Metrics)
    assert trainer.monitor == 'loss'
    assert trainer.clip_grad_norm == 400
    assert trainer.clip_grad_value is None
    assert trainer.start_epoch == 0
    assert trainer.start_iteration == 0
    assert trainer.iterations_per_epoch is None
    assert trainer.start_time == 0

    params['monitor'] = 'cer'
    trainer = Trainer(serialization_dir,
                      copy.deepcopy(params),
                      model,
                      loss,
                      alphabet,
                      device='cpu')
    assert trainer.monitor == 'cer'

def trainer(tmpdir_factory, params, model):
    serialization_dir = tmpdir_factory.mktemp('serialization_dir')
    loss = losses.CTCLoss(backend='pytorch')
    alphabet = Alphabet('-abc ', blank_index=0)
    trainer = Trainer(str(serialization_dir),
                      params,
                      model,
                      loss,
                      alphabet,
                      device='cpu')
    return trainer

def tokenizer(args):
    if args.unit == 'word':
        raise ValueError('Not implemented yet')

    tokens = Alphabet.from_file(args.tokens)
    if '<space>' in tokens:
        raise ValueError(
            f'Reserved token `<space>` found in {str(args.tokens)}')

    lines = args.infile.readlines()
    for line in tqdm(lines, unit='line'):
        tokenized = ' '.join(line.strip().replace(' ', '@'))
        args.outfile.write(tokenized + '\n')

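# Illustrative sketch (hypothetical helper mirroring the loop above): the
# character-level tokenizer emits space-separated characters, with literal
# spaces first rewritten as '@', e.g. 'ab c' -> 'a b @ c'.
def _tokenize_line_example(line: str) -> str:
    return ' '.join(line.strip().replace(' ', '@'))
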
def test_from_params():
    params = ['cer', 'wer']
    alphabet = Alphabet('-abc ', blank_index=0)

    m = metrics.from_params(params, alphabet=alphabet)
    assert isinstance(m, Metrics)
    assert isinstance(m[0], CER)
    assert isinstance(m[1], WER)

    params = [{'type': 'cer'}, 'wer']
    m = metrics.from_params(params, alphabet=alphabet)
    assert isinstance(m, Metrics)
    assert isinstance(m[0], CER)
    assert isinstance(m[1], WER)

def test_base_decoder():
    with pytest.raises(TypeError) as excinfo:
        decoder = Decoder()
    assert "missing 1 required positional argument: 'alphabet'" in str(
        excinfo.value)

    alphabet = Alphabet('-abc ', blank_index=0)
    decoder = Decoder(alphabet)
    assert hasattr(decoder, 'alphabet')

    assert decoder.wer('a bcd c', 'a dcc c') == 1
    assert decoder.cer('a bc c', 'a dcc c') == 2
    assert decoder.cer('a bcc', 'a dcc c', remove_space=True) == 2

    with pytest.raises(NotImplementedError):
        decoder.decode(None, None)

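# Aside (an assumption about Decoder.wer/cer, not the library's code): both
# are commonly Levenshtein distances, word-level for WER and character-level
# for CER, which is consistent with the assertions above. Minimal sketch:
def _levenshtein(a, b):
    # Classic dynamic-programming edit distance
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                 # deletion
                           cur[j - 1] + 1,              # insertion
                           prev[j - 1] + (ca != cb)))   # substitution
        prev = cur
    return prev[-1]


def test_levenshtein_sketch():
    # One word substituted -> word-level distance 1
    assert _levenshtein('a bcd c'.split(), 'a dcc c'.split()) == 1
    # Two character edits -> character-level distance 2
    assert _levenshtein('a bc c', 'a dcc c') == 2
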
def load(model_params, serialization_dir, weights_file=None, device='cpu'):
    weights_file = weights_file or os.path.join(serialization_dir,
                                                DEFAULT_WEIGHTS)

    # Load the alphabet from file and derive the output size from it
    alphabet_file = os.path.join(serialization_dir, 'vocabulary', 'alphabet')
    alphabet = Alphabet.from_file(alphabet_file)
    default_params = {'num_classes': len(alphabet)}

    # Load the weights, stripping any `module.` prefix from parameter names
    logger.info(f'Loading weights from {weights_file}.')
    state_dict = torch.load(weights_file, map_location='cpu')
    state_dict = {
        re.sub(r'^module\.', '', k): v
        for k, v in state_dict.items()
    }

    model_name = model_params.pop('type')
    model_params = {**default_params, **model_params}
    model = by_name(model_name)(**model_params)
    model.load_state_dict(state_dict)
    model = model.to(device)

    return model

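# Aside (illustrative): the `module.` prefix stripped in `load` appears when a
# checkpoint was saved from a model wrapped in DataParallel or
# DistributedDataParallel, which registers parameters as `module.<name>`:
#
#     >>> state = {'module.rnn.weight': 0, 'fc.bias': 1}
#     >>> {re.sub(r'^module\.', '', k): v for k, v in state.items()}
#     {'rnn.weight': 0, 'fc.bias': 1}
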
def evaluate_from_args(args):
    # Disable some of the more verbose logging statements
    logging.getLogger('asr.common.params').disabled = True
    logging.getLogger('asr.common.registrable').disabled = True

    # Load from archive
    _, weights_file = load_archive(args.serialization_dir, args.overrides,
                                   args.weights_file)

    params = Params.load(os.path.join(args.serialization_dir, CONFIG_NAME),
                         args.overrides)
    prepare_environment(params)

    # Try to use the validation dataset reader if there is one - otherwise
    # fall back to the dataset_reader used for both training and validation.
    dataset_params = params.pop('val_dataset', params.get('dataset_reader'))

    logger.info('Reading evaluation data from %s', args.input_file)
    dataset_params['manifest_filepath'] = args.input_file
    dataset = datasets.from_params(dataset_params)

    if os.path.exists(os.path.join(args.serialization_dir, 'alphabet')):
        alphabet = Alphabet.from_file(
            os.path.join(args.serialization_dir, 'alphabet', 'tokens'))
    else:
        alphabet = Alphabet.from_params(params.pop('alphabet', {}))

    logits_dir = os.path.join(args.serialization_dir, 'logits')
    os.makedirs(logits_dir, exist_ok=True)

    basename = os.path.splitext(os.path.split(args.input_file)[1])[0]
    logits_file = os.path.join(logits_dir, basename + '.pth')

    if not os.path.exists(logits_file):
        model = models.from_params(alphabet=alphabet,
                                   params=params.pop('model'))
        model.load_state_dict(
            torch.load(weights_file,
                       map_location=lambda storage, loc: storage)['model'])
        model.eval()

        decoder = GreedyCTCDecoder(alphabet)

        loader_params = params.pop('val_data_loader',
                                   params.get('data_loader'))
        batch_sampler = samplers.BucketingSampler(dataset,
                                                  batch_size=args.batch_size)
        loader = loaders.from_params(loader_params,
                                     dataset=dataset,
                                     batch_sampler=batch_sampler)

        logger.info(f'Logits file `{logits_file}` not found. Generating...')

        with torch.no_grad():
            model.to(args.device)

            logits = []
            total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0
            for batch in tqdm.tqdm(loader):
                sample, target, sample_lengths, target_lengths = batch
                sample = sample.to(args.device)
                sample_lengths = sample_lengths.to(args.device)

                output, output_lengths = model(sample, sample_lengths)
                output = output.to('cpu')

                references = decoder.tensor2str(target, target_lengths)
                transcripts = decoder.decode(output)[0]

                # Keep each output trimmed to its true length, paired with
                # its reference transcript
                logits.extend(
                    (o[:l, ...], r)
                    for o, l, r in zip(output, output_lengths, references))

                del sample, sample_lengths, output

                for reference, transcript in zip(references, transcripts):
                    total_wer += decoder.wer(transcript, reference)
                    total_cer += decoder.cer(transcript, reference)
                    num_tokens += float(len(reference.split()))
                    num_chars += float(len(reference))

            torch.save(logits, logits_file)

            wer = float(total_wer) / num_tokens
            cer = float(total_cer) / num_chars
            print(f'WER: {wer:.02%}\nCER: {cer:.02%}')

        del model
    else:
        logger.info(f'Logits file `{logits_file}` already generated.')

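# Aside (the aggregation the accumulators above implement): WER and CER are
# micro-averaged over the evaluation set, i.e. total edit distance divided by
# total reference length, not a mean of per-utterance rates.
def _corpus_error_rate_example(distances, reference_lengths):
    # e.g. distances=[1, 0], reference_lengths=[3, 3] -> 1/6, about 16.67%
    return float(sum(distances)) / float(sum(reference_lengths))
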
def tune_from_args(args):
    # Disable some of the more verbose logging statements
    logging.getLogger('asr.common.params').disabled = True
    logging.getLogger('asr.common.registrable').disabled = True

    # Load from archive
    _, weights_file = load_archive(args.serialization_dir, args.overrides,
                                   args.weights_file)

    params = Params.load(os.path.join(args.serialization_dir, CONFIG_NAME),
                         args.overrides)
    prepare_environment(params)

    # Try to use the validation dataset reader if there is one - otherwise
    # fall back to the dataset_reader used for both training and validation.
    dataset_params = params.pop('val_dataset', params.get('dataset_reader'))

    logger.info('Reading evaluation data from %s', args.input_file)
    dataset_params['manifest_filepath'] = args.input_file
    dataset = datasets.from_params(dataset_params)

    if os.path.exists(os.path.join(args.serialization_dir, 'alphabet')):
        alphabet = Alphabet.from_file(
            os.path.join(args.serialization_dir, 'alphabet', 'tokens'))
    else:
        alphabet = Alphabet.from_params(params.pop('alphabet', {}))

    logits_dir = os.path.join(args.serialization_dir, 'logits')
    os.makedirs(logits_dir, exist_ok=True)

    basename = os.path.splitext(os.path.split(args.input_file)[1])[0]
    logits_file = os.path.join(logits_dir, basename + '.pth')

    if not os.path.exists(logits_file):
        model = models.from_params(alphabet=alphabet,
                                   params=params.pop('model'))
        model.load_state_dict(
            torch.load(weights_file,
                       map_location=lambda storage, loc: storage)['model'])
        model.eval()

        decoder = GreedyCTCDecoder(alphabet)

        loader_params = params.pop('val_data_loader',
                                   params.get('data_loader'))
        batch_sampler = samplers.BucketingSampler(dataset,
                                                  batch_size=args.batch_size)
        loader = loaders.from_params(loader_params,
                                     dataset=dataset,
                                     batch_sampler=batch_sampler)

        logger.info(f'Logits file `{logits_file}` not found. Generating...')

        with torch.no_grad():
            model.to(args.device)

            logits = []
            for batch in tqdm.tqdm(loader):
                sample, target, sample_lengths, target_lengths = batch
                sample = sample.to(args.device)
                sample_lengths = sample_lengths.to(args.device)

                output, output_lengths = model(sample, sample_lengths)
                output = output.to('cpu')

                references = decoder.tensor2str(target, target_lengths)

                logits.extend(
                    (o[:l, ...], r)
                    for o, l, r in zip(output, output_lengths, references))

                del sample, sample_lengths, output

            torch.save(logits, logits_file)

        del model

    tune_dir = os.path.join(args.serialization_dir, 'tune')
    os.makedirs(tune_dir, exist_ok=True)

    params_grid = list(
        product(
            torch.linspace(args.alpha_from, args.alpha_to, args.alpha_steps),
            torch.linspace(args.beta_from, args.beta_to, args.beta_steps)))

    print('Scheduling {} jobs for alphas=linspace({}, {}, {}), '
          'betas=linspace({}, {}, {})'.format(len(params_grid),
                                              args.alpha_from, args.alpha_to,
                                              args.alpha_steps,
                                              args.beta_from, args.beta_to,
                                              args.beta_steps))

    # Start the worker processes
    logger.info(f'Using {args.num_workers} processes and {args.lm_workers} '
                f'LM workers for each CTCDecoder.')
    extract_start = default_timer()

    p = Pool(args.num_workers, init, [
        logits_file, alphabet, args.lm_path, args.cutoff_top_n,
        args.cutoff_prob, args.beam_width, args.lm_workers
    ])

    scores = []
    best_wer = float('inf')
    with tqdm.tqdm(p.imap(tune_step, params_grid),
                   total=len(params_grid),
                   desc='Grid search') as pbar:
        for params in pbar:
            alpha, beta, wer, cer = params
            scores.append([alpha, beta, wer, cer])
            if wer < best_wer:
                best_wer = wer
                pbar.set_postfix(alpha=alpha, beta=beta, wer=wer, cer=cer)
    p.close()
    p.join()

    logger.info(f'Finished {len(params_grid)} processes in '
                f'{default_timer() - extract_start:.1f}s')

    df = pd.DataFrame(scores, columns=['alpha', 'beta', 'wer', 'cer'])
    df.to_csv(os.path.join(tune_dir, basename + '.csv'), index=False)

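# Aside (hypothetical, standalone sketch of the pattern above): fan a grid of
# (alpha, beta) points out to a process pool and keep the best result. The
# scoring function is a stand-in for re-decoding cached logits; it must live
# at module level so the pool can pickle it.
def _score_point(point):
    alpha, beta = point
    return alpha, beta, (alpha - 1.0) ** 2 + (beta - 0.5) ** 2


def _grid_search_example():
    from itertools import product
    from multiprocessing import Pool

    grid = list(product([0.0, 1.0, 2.0], [0.0, 0.5, 1.0]))
    with Pool(2) as pool:
        results = pool.map(_score_point, grid)
    return min(results, key=lambda r: r[2])  # best point: (1.0, 0.5, 0.0)
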
def test_load_checkpoint(tmpdir, caplog, model, params):
    serialization_dir = (tmpdir / 'serialization_dir').mkdir()
    loss = losses.CTCLoss(backend='pytorch')
    alphabet = Alphabet('-abc ', blank_index=0)

    trainer = Trainer(serialization_dir,
                      params,
                      model,
                      loss,
                      alphabet,
                      device='cpu')

    # No checkpoint: the models directory does not exist yet
    trainer.load_checkpoint()
    assert 'Last model checkpoint found' not in caplog.record_tuples[-1][2]

    (serialization_dir / 'models').mkdir()

    # No checkpoint: the models directory is empty
    trainer.load_checkpoint()
    assert 'Last model checkpoint found' not in caplog.record_tuples[-1][2]

    # Mock the state-dict consumers and patch torch.load so that
    # load_checkpoint receives this dictionary
    ckpt_dict = {
        'model': 'model-mock',
        'optimizer': 'optimizer-mock',
        'best_monitor': 2.0,
        'metrics': {
            'train': 'train-metric-mock',
            'val': 'val-metric-mock'
        },
        'epoch': 1,
        'epoch_iterations': 8,
        'iterations_per_epoch': 10
    }
    trainer.model.load_state_dict = Mock()
    trainer.optimizer.load_state_dict = Mock()
    for split in ['train', 'val']:
        trainer.metrics[split].load_state_dict = Mock()
    torch.load = Mock(return_value=ckpt_dict)

    # The checkpoint with the largest iteration number must be picked
    f1 = (serialization_dir / 'models').join('model-20.pth')
    f1.write('')
    f2 = (serialization_dir / 'models').join('model-25.pth')
    f2.write('')

    trainer.load_checkpoint()
    assert 'Last model checkpoint found' in caplog.record_tuples[-1][2]
    torch.load.assert_called_with(str(f2), map_location='cpu')
    trainer.model.load_state_dict.assert_called_with('model-mock')
    trainer.optimizer.load_state_dict.assert_called_with('optimizer-mock')
    for split in ['train', 'val']:
        trainer.metrics[split].load_state_dict.assert_called_with(
            f'{split}-metric-mock')

    assert trainer.best_monitor == 2.0
    assert trainer.start_epoch == 1
    assert trainer.start_iteration == 8
    assert trainer.iterations_per_epoch == 10

def test_save_checkpoint(tmpdir, model, params):
    serialization_dir = tmpdir / 'serialization_dir'
    loss = losses.CTCLoss(backend='pytorch')
    alphabet = Alphabet('-abc ', blank_index=0)

    trainer = Trainer(serialization_dir,
                      params,
                      model,
                      loss,
                      alphabet,
                      device='cpu')

    # Mocking variables
    trainer.iterations_per_epoch = 10
    trainer.epoch = 0
    trainer.start_time = time.time()

    # should do nothing
    trainer.save_checkpoint(iteration=5, is_train=True)
    assert not (serialization_dir / 'models').exists()

    # should save, end of epoch
    trainer.model.state_dict = Mock(return_value='model dict')
    trainer.optimizer.state_dict = Mock(return_value='optimizer dict')
    trainer.score = Mock(return_value=float('inf'))

    trainer.save_checkpoint(iteration=9, is_train=True)
    assert (serialization_dir / 'models').exists()
    assert len(glob.glob(str(serialization_dir / 'models' / '*'))) == 1
    assert (serialization_dir / 'models' / 'model-10.pth').exists()

    ckpt_dict = torch.load(str(serialization_dir / 'models' / 'model-10.pth'))
    empty_metric_state_dict = {
        'val': 0,
        'avg': 0,
        'count': 0,
        'sum': 0,
        'history': []
    }
    empty_metrics_state_dict = [{
        'type': 'asr.metrics.Loss',
        'state_dict': empty_metric_state_dict
    }, {
        'type': 'asr.metrics.CER',
        'state_dict': empty_metric_state_dict
    }, {
        'type': 'asr.metrics.WER',
        'state_dict': empty_metric_state_dict
    }]
    expected_ckpt_dict = {
        'model': 'model dict',
        'epoch': 1,
        'epoch_iterations': 0,
        'iterations_per_epoch': 10,
        'best_monitor': float('inf'),
        'metrics': {
            'train': empty_metrics_state_dict,
            'val': empty_metrics_state_dict
        },
        'optimizer': 'optimizer dict'
    }
    assert ckpt_dict == expected_ckpt_dict

    # save best
    trainer.epoch = 1
    trainer.score = Mock(return_value=2.0)

    trainer.save_checkpoint(is_train=False)
    assert len(glob.glob(str(serialization_dir / 'models' / '*'))) == 3
    assert (serialization_dir / 'models' / 'model-20.pth').exists()
    assert (serialization_dir / 'models' / 'best-model.pth').exists()

    ckpt_dict = torch.load(str(serialization_dir / 'models' / 'model-20.pth'))
    best_ckpt_dict = torch.load(
        str(serialization_dir / 'models' / 'best-model.pth'))
    assert ckpt_dict == best_ckpt_dict

    expected_ckpt_dict['best_monitor'] = 2.0
    expected_ckpt_dict['epoch'] = 2
    assert best_ckpt_dict == expected_ckpt_dict

    # save by time
    trainer.start_time = time.time() - 60 * 10
    trainer.epoch = 2

    trainer.save_checkpoint(iteration=8, is_train=True)
    assert len(glob.glob(str(serialization_dir / 'models' / '*'))) == 4
    assert (serialization_dir / 'models' / 'model-29.pth').exists()

    ckpt_dict = torch.load(str(serialization_dir / 'models' / 'model-29.pth'))
    expected_ckpt_dict['best_monitor'] = 2.0
    expected_ckpt_dict['epoch'] = 2
    expected_ckpt_dict['epoch_iterations'] = 9
    assert ckpt_dict == expected_ckpt_dict

def test_greedy_decoder():
    """Code adapted from TensorFlow."""
    max_time_steps = 6

    seq_len_0 = 4
    input_prob_matrix_0 = torch.tensor(
        [
            [1.0, 0.0, 0.0, 0.0],  # t=0
            [0.0, 0.0, 0.4, 0.6],  # t=1
            [0.0, 0.0, 0.4, 0.6],  # t=2
            [0.0, 0.9, 0.1, 0.0],  # t=3
            [0.0, 0.0, 0.0, 0.0],  # t=4 (ignored)
            [0.0, 0.0, 0.0, 0.0],  # t=5 (ignored)
        ],
        dtype=torch.float32)
    input_log_prob_matrix_0 = input_prob_matrix_0.log()

    seq_len_1 = 5
    # dimensions are time x depth
    input_prob_matrix_1 = torch.tensor(
        [
            [0.1, 0.9, 0.0, 0.0],  # t=0
            [0.0, 0.9, 0.1, 0.0],  # t=1
            [0.0, 0.0, 0.1, 0.9],  # t=2
            [0.0, 0.9, 0.1, 0.1],  # t=3
            [0.9, 0.1, 0.0, 0.0],  # t=4
            [0.0, 0.0, 0.0, 0.0],  # t=5 (ignored)
        ],
        dtype=torch.float32)
    input_log_prob_matrix_1 = input_prob_matrix_1.log()

    # batch_size x max_time_steps x depth tensor
    inputs = torch.stack([input_log_prob_matrix_0, input_log_prob_matrix_1])

    # batch_size length vector of sequence lengths
    seq_lens = torch.tensor([seq_len_0, seq_len_1], dtype=torch.int32)

    # batch_size length vector of negative log probabilities
    log_prob_truth = torch.tensor([
        -(torch.tensor([1.0, 0.6, 0.6, 0.9]).log()).sum().item(),
        -(torch.tensor([0.9, 0.9, 0.9, 0.9, 0.9]).log()).sum().item()
    ])
    decode_truth = ['ab', 'bba']
    offsets_truth = [
        torch.tensor([0, 3]),
        torch.tensor([0, 3, 4]),
    ]

    alphabet = Alphabet('abc-', blank_index=3)
    decoder = GreedyCTCDecoder(alphabet)

    out, scores, offsets = decoder.decode(inputs, seq_lens)
    assert out[0] == decode_truth[0]
    assert out[1] == decode_truth[1]
    assert torch.allclose(scores, log_prob_truth)
    assert torch.all(offsets[0] == offsets_truth[0])
    assert torch.all(offsets[1] == offsets_truth[1])

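# Aside (the standard best-path rule that the expectations above encode):
# take the argmax label per frame, collapse consecutive repeats, then drop
# blanks. `GreedyCTCDecoder` additionally returns scores and offsets.
def _greedy_ctc_sketch(log_probs, blank):
    # log_probs: (time, num_labels) tensor for a single sequence
    labels, prev = [], blank
    for t, idx in enumerate(log_probs.argmax(dim=1).tolist()):
        if idx != prev and idx != blank:
            labels.append((t, idx))  # (frame offset, label index)
        prev = idx
    return labels


def test_greedy_ctc_sketch():
    probs = torch.tensor([
        [1.0, 0.0, 0.0, 0.0],  # 'a'
        [0.0, 0.0, 0.4, 0.6],  # blank
        [0.0, 0.0, 0.4, 0.6],  # blank
        [0.0, 0.9, 0.1, 0.0],  # 'b'
    ])
    # 'a' at t=0 and 'b' at t=3, matching decode_truth[0] and offsets [0, 3]
    assert _greedy_ctc_sketch(probs.log(), blank=3) == [(0, 0), (3, 1)]
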
def test_beam_search_decoder():
    labels = ['\'', ' ', 'a', 'b', 'c', 'd', '-']
    beam_width = 20

    probs_seq1 = [[
        0.06390443, 0.21124858, 0.27323887, 0.06870235, 0.0361254,
        0.18184413, 0.16493624
    ], [
        0.03309247, 0.22866108, 0.24390638, 0.09699597, 0.31895462,
        0.0094893, 0.06890021
    ], [
        0.218104, 0.19992557, 0.18245131, 0.08503348, 0.14903535,
        0.08424043, 0.08120984
    ], [
        0.12094152, 0.19162472, 0.01473646, 0.28045061, 0.24246305,
        0.05206269, 0.09772094
    ], [
        0.1333387, 0.00550838, 0.00301669, 0.21745861, 0.20803985,
        0.41317442, 0.01946335
    ], [
        0.16468227, 0.1980699, 0.1906545, 0.18963251, 0.19860937,
        0.04377724, 0.01457421
    ]]
    probs_seq2 = [[
        0.08034842, 0.22671944, 0.05799633, 0.36814645, 0.11307441,
        0.04468023, 0.10903471
    ], [
        0.09742457, 0.12959763, 0.09435383, 0.21889204, 0.15113123,
        0.10219457, 0.20640612
    ], [
        0.45033529, 0.09091417, 0.15333208, 0.07939558, 0.08649316,
        0.12298585, 0.01654384
    ], [
        0.02512238, 0.22079203, 0.19664364, 0.11906379, 0.07816055,
        0.22538587, 0.13483174
    ], [
        0.17928453, 0.06065261, 0.41153005, 0.1172041, 0.11880313,
        0.07113197, 0.04139363
    ], [
        0.15882358, 0.1235788, 0.23376776, 0.20510435, 0.00279306,
        0.05294827, 0.22298418
    ]]
    log_probs_seq1 = torch.log(torch.as_tensor(probs_seq1))
    log_probs_seq2 = torch.log(torch.as_tensor(probs_seq2))

    greedy_result = ["ac'bdc", "b'da"]  # greedy baseline, for reference
    beam_search_result = ['acdc', "b'a"]

    alphabet = Alphabet(labels, blank_index=labels.index('-'))
    decoder = BeamCTCDecoder(alphabet, beam_width=beam_width)

    log_probs_seq = log_probs_seq1[None, ...]
    beam_result, beam_scores, timesteps = decoder.decode(log_probs_seq)
    assert beam_result[0] == beam_search_result[0]

    log_probs_seq = log_probs_seq2[None, ...]
    beam_result, beam_scores, timesteps = decoder.decode(log_probs_seq)
    assert beam_result[0] == beam_search_result[1]

    # Test batched decoding
    log_probs_seq = torch.stack([log_probs_seq1, log_probs_seq2])
    beam_results, beam_scores, timesteps = decoder.decode(log_probs_seq)
    assert beam_results[0] == beam_search_result[0]
    assert beam_results[1] == beam_search_result[1]

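# Aside (a toy illustration of why the beam results above differ from the
# greedy baseline): greedy keeps the single best frame-wise path, while CTC
# prefix beam search sums probability over all alignments that collapse to
# the same string. With blank '-' and label 'a' over two frames, where
# p('-') = 0.6 and p('a') = 0.4 at each frame:
#   greedy path '--' -> '' with probability 0.36
#   p('a') = p('a-') + p('-a') + p('aa') = 0.24 + 0.24 + 0.16 = 0.64
# so summing over alignments decodes 'a' even though '' has the best path.
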
def train_model_from_args(args):
    if args.local_rank == 0 and args.prev_output_dir is not None:
        logger.info('Copying results from {} to {}...'.format(
            args.prev_output_dir, args.serialization_dir))
        copy_tree(args.prev_output_dir,
                  args.serialization_dir,
                  update=True,
                  verbose=True)

    if not os.path.isfile(args.param_path):
        raise ConfigurationError(
            f'Parameters file {args.param_path} not found.')

    logger.info(f'Loading experiment from {args.param_path} with overrides '
                f'`{args.overrides}`.')
    params = Params.load(args.param_path, args.overrides)
    prepare_environment(params)

    logger.info(f'Local rank: {args.local_rank}')
    if args.local_rank == 0:
        create_serialization_dir(params, args.serialization_dir, args.reset)

    if args.distributed:
        logger.info(
            f'World size: {dist.get_world_size()} | Rank {dist.get_rank()} | '
            f'Local Rank {args.local_rank}')
        dist.barrier()

    prepare_global_logging(args.serialization_dir,
                           local_rank=args.local_rank,
                           verbosity=args.verbosity)

    if args.local_rank == 0:
        params.save(os.path.join(args.serialization_dir, CONFIG_NAME))

    loaders = loaders_from_params(params,
                                  distributed=args.distributed,
                                  world_size=args.world_size,
                                  first_epoch=args.first_epoch)

    if os.path.exists(os.path.join(args.serialization_dir, 'alphabet')):
        alphabet = Alphabet.from_file(
            os.path.join(args.serialization_dir, 'alphabet'))
    else:
        alphabet = Alphabet.from_params(params.pop('alphabet', {}))
    alphabet.save_to_files(os.path.join(args.serialization_dir, 'alphabet'))

    loss = losses.from_params(params.pop('loss'))
    model = models.from_params(alphabet=alphabet, params=params.pop('model'))

    trainer_params = params.pop('trainer')
    if args.fine_tune:
        _, archive_weight_file = models.load_archive(args.fine_tune)
        archive_weights = torch.load(
            archive_weight_file,
            map_location=lambda storage, loc: storage)['model']

        # Avoid initializing some weights from the archive
        no_ft_regex = trainer_params.pop('no_ft', ())

        finetune_weights = {}
        random_weights = []
        for name, parameter in archive_weights.items():
            if any(re.search(regex, name) for regex in no_ft_regex):
                random_weights.append(name)
                continue
            finetune_weights[name] = parameter

        logger.info(
            f'Loading the following weights from archive {args.fine_tune}:')
        logger.info(','.join(finetune_weights.keys()))
        logger.info('The following weights are initialized at random:')
        logger.info(','.join(random_weights))

        model.load_state_dict(finetune_weights, strict=False)

    # Freeze some parameters
    freeze_params(model, trainer_params.pop('no_grad', ()))

    trainer = Trainer(args.serialization_dir,
                      trainer_params,
                      model,
                      loss,
                      alphabet,
                      local_rank=args.local_rank,
                      world_size=args.world_size,
                      sync_bn=args.sync_bn,
                      opt_level=args.opt_level,
                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                      loss_scale=args.loss_scale)

    try:
        trainer.run(loaders['train'],
                    val_loader=loaders.get('val'),
                    num_epochs=trainer_params['num_epochs'])
    except KeyboardInterrupt:
        # If we have completed an epoch, try to create a model archive.
        if os.path.exists(
                os.path.join(args.serialization_dir, models.DEFAULT_WEIGHTS)):
            logging.info('Training interrupted by the user. Attempting to '
                         'create a model archive using the current best '
                         'epoch weights.')
        raise

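# Aside (illustrative, with hypothetical parameter names): the `no_ft` filter
# above partitions archive weights by regex, keeping only the entries that
# match none of the patterns.
def _filter_finetune_weights_example():
    weights = {'encoder.rnn.weight': 1, 'decoder.fc.weight': 2}
    no_ft = (r'^decoder\.',)
    keep = {
        name: w
        for name, w in weights.items()
        if not any(re.search(rx, name) for rx in no_ft)
    }
    return keep  # -> {'encoder.rnn.weight': 1}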