Example #1
0
def test_tensor2str():
    """`Decoder.tensor2str` validates its input and decodes rows to strings.

    Checks that a 3-D tensor and a floating-point tensor are rejected with
    ``ValueError``, that explicit ``sizes`` truncate each row before decoding,
    and that without ``sizes`` every row matches ``alphabet.idx2str``.
    """
    alphabet = Alphabet('-abc ', blank_index=0)
    decoder = Decoder(alphabet)

    expected = ['ab c', 'aa-', 'c ', 'a', '']
    tensor = torch.tensor([
        [1, 2, 4, 3, 0],
        [1, 1, 0, 0, 0],
        [3, 4, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
    ],
                          dtype=torch.int)

    sizes = [4, 3, 2, 1, 0]

    # Tensors that are neither 1-D nor 2-D are rejected.
    with pytest.raises(ValueError) as excinfo:
        decoder.tensor2str(torch.tensor([[[1, 2, 3]]]))

    assert '`tensor.dim()` != 1 or 2' in str(excinfo.value)

    # Non-integer tensors are rejected.
    with pytest.raises(ValueError) as excinfo:
        decoder.tensor2str(tensor.float())

    assert 'must be int' in str(excinfo.value)

    # `zip` stops at the shorter sequence, so without the explicit length
    # checks a decoder returning too few strings would still pass.
    output = list(decoder.tensor2str(tensor, sizes))
    assert len(output) == len(expected)
    for o, e in zip(output, expected):
        assert o == e

    output = list(decoder.tensor2str(tensor))
    assert len(output) == len(tensor)
    for o, t in zip(output, tensor):
        assert o == alphabet.idx2str(t)
Example #2
0
def test_real_ctc_beam_decoder():
    """Beam-search decode a stored RNN output: width 1, width 25, and with a bigram LM."""
    labels = ' !"#&\'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_'
    alphabet = Alphabet(labels, blank_index=labels.index('_'))

    rnn_output_path = os.path.join(base_dir, "data/rnn_output_log_softmax.pth")
    log_input = torch.load(rnn_output_path)
    sizes = torch.tensor([log_input.shape[1]])

    # A beam of width one is equivalent to greedy decoding.
    narrow_decoder = BeamCTCDecoder(alphabet, beam_width=1)
    transcripts, _, _ = narrow_decoder.decode(log_input, sizes)
    assert "the fak friend of the fomly hae tC" == transcripts[0]

    # Default-width beam search without a language model.
    wide_decoder = BeamCTCDecoder(alphabet, beam_width=25)
    transcripts, _, _ = wide_decoder.decode(log_input, sizes)
    assert "the fak friend of the fomcly hae tC" == transcripts[0]

    # Beam search rescored by a bigram language model.
    bigram_path = os.path.join(base_dir, 'data', 'bigram.arpa')
    lm_decoder = BeamCTCDecoder(alphabet,
                                lm_path=bigram_path,
                                beam_width=25,
                                alpha=2,
                                beta=0)
    transcripts, _, _ = lm_decoder.decode(log_input, sizes)
    assert "the fake friend of the family, like the" == transcripts[0]
Example #3
0
def test_wer():
    """`WER` metric: requires an alphabet, formats itself, and keeps running averages."""
    alphabet = Alphabet('-abc ', blank_index=0)

    def one_hot_batch(strings):
        # Build a (2, 20, 5) batch whose rows one-hot encode each string.
        batch = torch.zeros(2, 20, 5)
        for row, text in enumerate(strings):
            indices = torch.tensor(alphabet.str2idx(text))
            batch[row, ...].scatter_(1, indices.unsqueeze(1), 1)
        return batch

    with pytest.raises(ValueError) as excinfo:
        WER()
    assert '`alphabet` is required' in str(excinfo.value)

    targets = torch.cat(
        [torch.tensor(alphabet.str2idx(t)) for t in ['abcc abc', 'abc']])
    targets_size = torch.tensor([8, 3])

    outputs = one_hot_batch(['aa-b-c-bc   abb-c--a', 'a-bbbb-cc---------aa'])
    outputs_size = torch.tensor([17, 9])

    wers = torch.tensor([1 / 2, 0])

    metric = WER(alphabet=alphabet)
    assert repr(metric) == 'WER'

    metric.update(['loss', outputs, targets, outputs_size, targets_size])
    assert metric.val == wers.mean()
    assert metric.count == 2
    assert str(metric) == f'WER {wers.mean(0):.02%} ({wers.mean(0):.02%})'

    # Second batch decodes exactly to the targets, so its WER is zero.
    outputs = one_hot_batch(['aa-b-c-cc   abb-c---', 'a-bbbb-cc-----------'])
    metric.update(['loss', outputs, targets, outputs_size, targets_size])
    assert metric.val == 0
    assert metric.avg == (1 / 2) / 4
    assert metric.count == 4
Example #4
0
def test_constructor(tmpdir, model, params):
    """`Trainer` validates its params and exposes the configured defaults.

    The `params` fixture is never mutated: every variant is built from a deep
    copy so that other tests sharing the fixture see the original values.
    """
    serialization_dir = tmpdir / 'serialization_dir'
    loss = losses.CTCLoss(backend='pytorch')
    alphabet = Alphabet('-abc ', blank_index=0)

    # 'optimizer' is a required key.
    missing_params = copy.deepcopy(params)
    del missing_params['optimizer']
    with pytest.raises(ConfigurationError) as excinfo:
        Trainer(serialization_dir,
                copy.deepcopy(missing_params),
                model,
                loss,
                alphabet,
                device='cpu')

    assert "key 'optimizer' is required" in str(excinfo.value)

    # 'lr_scheduler' may be omitted.
    allowed_missing_params = copy.deepcopy(params)
    del allowed_missing_params['lr_scheduler']
    trainer = Trainer(serialization_dir,
                      copy.deepcopy(allowed_missing_params),
                      model,
                      loss,
                      alphabet,
                      device='cpu')
    assert trainer.lr_scheduler is None

    trainer = Trainer(serialization_dir,
                      copy.deepcopy(params),
                      model,
                      loss,
                      alphabet,
                      device='cpu')

    for phase in ['train', 'val']:
        assert isinstance(trainer.metrics[phase], Metrics)

    assert trainer.monitor == 'loss'
    assert trainer.clip_grad_norm == 400
    assert trainer.clip_grad_value is None
    assert trainer.start_epoch == 0
    assert trainer.start_iteration == 0
    assert trainer.iterations_per_epoch is None
    assert trainer.start_time == 0

    # Override the monitored metric on a copy rather than on the fixture.
    cer_params = copy.deepcopy(params)
    cer_params['monitor'] = 'cer'
    trainer = Trainer(serialization_dir,
                      cer_params,
                      model,
                      loss,
                      alphabet,
                      device='cpu')

    assert trainer.monitor == 'cer'
Example #5
0
def trainer(tmpdir_factory, params, model):
    """Build a CPU `Trainer` writing into a fresh temporary serialization dir."""
    output_dir = str(tmpdir_factory.mktemp('serialization_dir'))
    ctc_loss = losses.CTCLoss(backend='pytorch')
    alphabet = Alphabet('-abc ', blank_index=0)
    return Trainer(output_dir, params, model, ctc_loss, alphabet, device='cpu')
Example #6
0
def tokenizer(args):
    """Rewrite each line of ``args.infile`` as space-separated characters.

    Surrounding whitespace is stripped, inner spaces become ``@``, and every
    remaining character is separated by a single space. The result is
    written line by line to ``args.outfile``.

    Raises:
        ValueError: if ``args.unit`` is ``'word'`` (not implemented), or if
            the token file ``args.tokens`` contains the reserved ``<space>``
            token.
    """
    if args.unit == 'word':
        raise ValueError('Not implemented yet')

    vocabulary = Alphabet.from_file(args.tokens)
    if '<space>' in vocabulary:
        raise ValueError(
            f'Reserved token `<space>` found in {str(args.tokens)}')

    for raw_line in tqdm(args.infile.readlines(), unit='line'):
        characters = raw_line.strip().replace(' ', '@')
        args.outfile.write(' '.join(characters) + '\n')
Example #7
0
def test_from_params():
    """`metrics.from_params` accepts both bare metric names and dict specs."""
    alphabet = Alphabet('-abc ', blank_index=0)

    for spec in (['cer', 'wer'], [{'type': 'cer'}, 'wer']):
        built = metrics.from_params(spec, alphabet=alphabet)

        assert isinstance(built, Metrics)
        assert isinstance(built[0], CER)
        assert isinstance(built[1], WER)
Example #8
0
def test_base_decoder():
    """The base `Decoder` needs an alphabet, offers WER/CER helpers, and leaves `decode` abstract."""
    with pytest.raises(TypeError) as excinfo:
        Decoder()
    error_text = str(excinfo.value)
    assert "missing 1 required positional argument: 'alphabet'" in error_text

    decoder = Decoder(Alphabet('-abc ', blank_index=0))
    assert hasattr(decoder, 'alphabet')

    # Edit-distance helpers.
    assert decoder.wer('a bcd c', 'a dcc c') == 1
    assert decoder.cer('a bc c', 'a dcc c') == 2
    assert decoder.cer('a bcc', 'a dcc c', remove_space=True) == 2

    # Subclasses must provide `decode`.
    with pytest.raises(NotImplementedError):
        decoder.decode(None, None)
Example #9
0
def load(model_params, serialization_dir, weights_file=None, device='cpu'):
    """Instantiate a model from ``model_params`` and load saved weights into it.

    Args:
        model_params: mapping of constructor arguments. Its ``'type'`` entry
            names the model class, which is resolved with ``by_name``. The
            mapping itself is not modified, so it can be reused.
        serialization_dir: directory holding the ``vocabulary/alphabet`` file
            and, unless ``weights_file`` is given, the weights file.
        weights_file: explicit path to the weights; defaults to
            ``os.path.join(serialization_dir, DEFAULT_WEIGHTS)``.
        device: device the loaded model is moved to.

    Returns:
        The model with the state dict loaded, moved to ``device``.
    """
    weights_file = weights_file or os.path.join(serialization_dir,
                                                DEFAULT_WEIGHTS)

    # The vocabulary size fixes the default number of output classes.
    alphabet_file = os.path.join(serialization_dir, 'vocabulary', 'alphabet')
    alphabet = Alphabet.from_file(alphabet_file)

    default_params = {'num_classes': len(alphabet)}

    # Loading weights
    logger.info(f'Loading weights from {weights_file}.')
    state_dict = torch.load(weights_file, map_location='cpu')
    # Strip a leading literal "module." prefix (presumably added by a
    # DataParallel-style wrapper at save time). The dot is escaped: an
    # unescaped "." would also match e.g. "modules" and corrupt such keys.
    state_dict = {re.sub(r'^module\.', '', k): v for k, v in state_dict.items()}

    # Merge into a fresh dict and pop 'type' from it, leaving the caller's
    # params untouched. Explicit params override the defaults.
    model_kwargs = {**default_params, **model_params}
    model_name = model_kwargs.pop('type')

    model = by_name(model_name)(**model_kwargs)
    model.load_state_dict(state_dict)
    model = model.to(device)

    return model
Example #10
0
def evaluate_from_args(args):
    """Evaluate a trained model on ``args.input_file`` and cache its outputs.

    Loads the configuration and weights stored in ``args.serialization_dir``,
    runs the model over the dataset built from ``args.input_file`` and saves a
    list of ``(output[:length], reference_text)`` pairs to
    ``<serialization_dir>/logits/<input basename>.pth``. While generating
    them, greedy-decoding WER and CER are printed. If the logits file already
    exists, nothing is recomputed.

    Args:
        args: parsed command-line namespace. Reads ``serialization_dir``,
            ``overrides``, ``weights_file``, ``input_file``, ``batch_size``
            and ``device``.
    """
    # Disable some of the more verbose logging statements
    logging.getLogger('asr.common.params').disabled = True
    logging.getLogger('asr.common.registrable').disabled = True

    # Load from archive
    _, weights_file = load_archive(args.serialization_dir, args.overrides,
                                   args.weights_file)

    params = Params.load(os.path.join(args.serialization_dir, CONFIG_NAME),
                         args.overrides)

    prepare_environment(params)

    # Try to use the validation dataset reader if there is one - otherwise fall back
    # to the default dataset_reader used for both training and validation.
    dataset_params = params.pop('val_dataset', params.get('dataset_reader'))

    logger.info("Reading evaluation data from %s", args.input_file)
    dataset_params['manifest_filepath'] = args.input_file
    dataset = datasets.from_params(dataset_params)

    if os.path.exists(os.path.join(args.serialization_dir, "alphabet")):
        alphabet = Alphabet.from_file(
            os.path.join(args.serialization_dir, "alphabet", "tokens"))
    else:
        alphabet = Alphabet.from_params(params.pop("alphabet", {}))

    logits_dir = os.path.join(args.serialization_dir, 'logits')
    os.makedirs(logits_dir, exist_ok=True)

    basename = os.path.splitext(os.path.split(args.input_file)[1])[0]
    logger.debug(f'Input basename: {basename}')
    logits_file = os.path.join(logits_dir, basename + '.pth')

    if not os.path.exists(logits_file):
        model = models.from_params(alphabet=alphabet,
                                   params=params.pop('model'))
        model.load_state_dict(
            torch.load(weights_file,
                       map_location=lambda storage, loc: storage)['model'])
        model.eval()

        decoder = GreedyCTCDecoder(alphabet)

        loader_params = params.pop("val_data_loader",
                                   params.get("data_loader"))
        batch_sampler = samplers.BucketingSampler(dataset,
                                                  batch_size=args.batch_size)
        loader = loaders.from_params(loader_params,
                                     dataset=dataset,
                                     batch_sampler=batch_sampler)

        logger.info(f'Logits file `{logits_file}` not found. Generating...')

        with torch.no_grad():
            model.to(args.device)

            logits = []
            total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0
            for batch in tqdm.tqdm(loader):
                sample, target, sample_lengths, target_lengths = batch
                sample = sample.to(args.device)
                sample_lengths = sample_lengths.to(args.device)

                output, output_lengths = model(sample, sample_lengths)

                output = output.to('cpu')
                if torch.is_tensor(output_lengths):
                    # Keep lengths on the same device as the decoded output.
                    output_lengths = output_lengths.to('cpu')

                references = decoder.tensor2str(target, target_lengths)

                # Pass the per-utterance lengths so the padded frames at the
                # end of shorter utterances are not decoded; the cached
                # logits below are truncated to the same lengths.
                transcripts = decoder.decode(output, output_lengths)[0]

                logits.extend(
                    (o[:l, ...], r)
                    for o, l, r in zip(output, output_lengths, references))

                del sample, sample_lengths, output

                for reference, transcript in zip(references, transcripts):
                    total_wer += decoder.wer(transcript, reference)
                    total_cer += decoder.cer(transcript, reference)
                    num_tokens += float(len(reference.split()))
                    num_chars += float(len(reference))

            torch.save(logits, logits_file)

            # Guard against an empty dataset / empty references.
            wer = float(total_wer) / num_tokens if num_tokens else float('nan')
            cer = float(total_cer) / num_chars if num_chars else float('nan')

            print(f'WER: {wer:.02%}\nCER: {cer:.02%}')

        del model

    else:
        logger.info(f'Logits file `{logits_file}` already generated.')
Example #11
0
def tune_from_args(args):
    """Grid-search LM decoding weights (alpha, beta) on ``args.input_file``.

    If missing, model outputs for the input are first generated and cached in
    ``<serialization_dir>/logits/<input basename>.pth`` as
    ``(output[:length], reference_text)`` pairs. Each (alpha, beta) pair from
    the linspace grids is then evaluated by ``tune_step`` in a process pool
    initialised with ``init``. The collected (alpha, beta, wer, cer) rows are
    written to ``<serialization_dir>/tune/<input basename>.csv``.

    Args:
        args: parsed command-line namespace. Reads the serialization and data
            options, the ``alpha_*``/``beta_*`` grid bounds, ``num_workers``,
            ``lm_workers``, ``lm_path``, ``cutoff_top_n``, ``cutoff_prob`` and
            ``beam_width``.
    """
    # Disable some of the more verbose logging statements
    logging.getLogger('asr.common.params').disabled = True
    logging.getLogger('asr.common.registrable').disabled = True

    # Load from archive
    _, weights_file = load_archive(args.serialization_dir, args.overrides,
                                   args.weights_file)

    params = Params.load(os.path.join(args.serialization_dir, CONFIG_NAME),
                         args.overrides)

    prepare_environment(params)

    # Try to use the validation dataset reader if there is one - otherwise fall back
    # to the default dataset_reader used for both training and validation.
    dataset_params = params.pop('val_dataset', params.get('dataset_reader'))

    logger.info("Reading evaluation data from %s", args.input_file)
    dataset_params['manifest_filepath'] = args.input_file
    dataset = datasets.from_params(dataset_params)

    if os.path.exists(os.path.join(args.serialization_dir, "alphabet")):
        alphabet = Alphabet.from_file(
            os.path.join(args.serialization_dir, "alphabet", "tokens"))
    else:
        alphabet = Alphabet.from_params(params.pop("alphabet", {}))

    logits_dir = os.path.join(args.serialization_dir, 'logits')
    os.makedirs(logits_dir, exist_ok=True)

    basename = os.path.splitext(os.path.split(args.input_file)[1])[0]
    logits_file = os.path.join(logits_dir, basename + '.pth')

    if not os.path.exists(logits_file):
        model = models.from_params(alphabet=alphabet,
                                   params=params.pop('model'))
        model.load_state_dict(
            torch.load(weights_file,
                       map_location=lambda storage, loc: storage)['model'])
        model.eval()

        # Only used here to turn target index tensors into reference strings.
        decoder = GreedyCTCDecoder(alphabet)

        loader_params = params.pop("val_data_loader",
                                   params.get("data_loader"))
        batch_sampler = samplers.BucketingSampler(dataset,
                                                  batch_size=args.batch_size)
        loader = loaders.from_params(loader_params,
                                     dataset=dataset,
                                     batch_sampler=batch_sampler)

        logger.info(f'Logits file `{logits_file}` not found. Generating...')

        with torch.no_grad():
            model.to(args.device)

            logits = []
            for batch in tqdm.tqdm(loader):
                sample, target, sample_lengths, target_lengths = batch
                sample = sample.to(args.device)
                sample_lengths = sample_lengths.to(args.device)

                output, output_lengths = model(sample, sample_lengths)

                output = output.to('cpu')

                references = decoder.tensor2str(target, target_lengths)

                logits.extend(
                    (o[:l, ...], r)
                    for o, l, r in zip(output, output_lengths, references))

                del sample, sample_lengths, output

            torch.save(logits, logits_file)

        del model

    tune_dir = os.path.join(args.serialization_dir, 'tune')
    os.makedirs(tune_dir, exist_ok=True)

    params_grid = list(
        product(
            torch.linspace(args.alpha_from, args.alpha_to, args.alpha_steps),
            torch.linspace(args.beta_from, args.beta_to, args.beta_steps)))

    print(
        'Scheduling {} jobs for alphas=linspace({}, {}, {}), betas=linspace({}, {}, {})'
        .format(len(params_grid), args.alpha_from, args.alpha_to,
                args.alpha_steps, args.beta_from, args.beta_to,
                args.beta_steps))

    # start worker processes
    logger.info(
        f"Using {args.num_workers} processes and {args.lm_workers} for each CTCDecoder."
    )
    extract_start = default_timer()

    scores = []
    best_wer = float('inf')
    # The context manager shuts the pool down once every result has been
    # consumed, instead of leaving worker processes running.
    with Pool(args.num_workers, init, [
            logits_file, alphabet, args.lm_path, args.cutoff_top_n,
            args.cutoff_prob, args.beam_width, args.lm_workers
    ]) as pool:
        with tqdm.tqdm(pool.imap(tune_step, params_grid),
                       total=len(params_grid),
                       desc='Grid search') as pbar:
            for alpha, beta, wer, cer in pbar:
                scores.append([alpha, beta, wer, cer])

                if wer < best_wer:
                    best_wer = wer
                    pbar.set_postfix(alpha=alpha, beta=beta, wer=wer, cer=cer)

    logger.info(
        f"Finished {len(params_grid)} processes in {default_timer() - extract_start:.1f}s"
    )

    df = pd.DataFrame(scores, columns=['alpha', 'beta', 'wer', 'cer'])
    df.to_csv(os.path.join(tune_dir, basename + '.csv'), index=False)
Example #12
0
def test_load_checkpoint(tmpdir, caplog, model, params):
    """`Trainer.load_checkpoint` restores state from the highest-numbered checkpoint.

    ``torch.load`` is patched only for the duration of the call under test.
    Reassigning the module attribute directly would leave the mock in place
    for every later test in the session.
    """
    from unittest.mock import patch

    serialization_dir = (tmpdir / 'serialization_dir').mkdir()
    loss = losses.CTCLoss(backend='pytorch')
    alphabet = Alphabet('-abc ', blank_index=0)

    trainer = Trainer(serialization_dir,
                      params,
                      model,
                      loss,
                      alphabet,
                      device='cpu')

    # no checkpoint
    trainer.load_checkpoint()
    assert 'Last model checkpoint found' not in caplog.record_tuples[-1][2]

    (serialization_dir / 'models').mkdir()

    # no checkpoint
    trainer.load_checkpoint()
    assert 'Last model checkpoint found' not in caplog.record_tuples[-1][2]

    # mocking calls

    ckpt_dict = {
        'model': 'model-mock',
        'optimizer': 'optimizer-mock',
        'best_monitor': 2.0,
        'metrics': {
            'train': 'train-metric-mock',
            'val': 'val-metric-mock'
        },
        'epoch': 1,
        'epoch_iterations': 8,
        'iterations_per_epoch': 10
    }

    trainer.model.load_state_dict = Mock()
    trainer.optimizer.load_state_dict = Mock()
    for split in ['train', 'val']:
        trainer.metrics[split].load_state_dict = Mock()

    # find biggest iterations
    f1 = (serialization_dir / 'models').join('model-20.pth')
    f1.write('')
    f2 = (serialization_dir / 'models').join('model-25.pth')
    f2.write('')

    load_mock = Mock(return_value=ckpt_dict)
    with patch.object(torch, 'load', load_mock):
        trainer.load_checkpoint()

    assert 'Last model checkpoint found' in caplog.record_tuples[-1][2]

    load_mock.assert_called_with(str(f2), map_location='cpu')
    trainer.model.load_state_dict.assert_called_with('model-mock')
    trainer.optimizer.load_state_dict.assert_called_with('optimizer-mock')

    for split in ['train', 'val']:
        trainer.metrics[split].load_state_dict.assert_called_with(
            f'{split}-metric-mock')

    assert trainer.best_monitor == 2.0
    assert trainer.start_epoch == 1
    assert trainer.start_iteration == 8
    assert trainer.iterations_per_epoch == 10
Example #13
0
def test_save_checkpoint(tmpdir, model, params):
    """`Trainer.save_checkpoint` writes end-of-epoch, best-model and time-based checkpoints."""
    serialization_dir = tmpdir / 'serialization_dir'
    models_dir = serialization_dir / 'models'

    def saved_files():
        # Every file currently present in the checkpoint directory.
        return glob.glob(str(models_dir / '*'))

    trainer = Trainer(serialization_dir,
                      params,
                      model,
                      losses.CTCLoss(backend='pytorch'),
                      Alphabet('-abc ', blank_index=0),
                      device='cpu')

    # Pretend to be in the first epoch of 10 iterations, just started.
    trainer.iterations_per_epoch = 10
    trainer.epoch = 0
    trainer.start_time = time.time()

    # Iteration 5: nothing is written yet.
    trainer.save_checkpoint(iteration=5, is_train=True)
    assert not models_dir.exists()

    trainer.model.state_dict = Mock(return_value='model dict')
    trainer.optimizer.state_dict = Mock(return_value='optimizer dict')
    trainer.score = Mock(return_value=float('inf'))

    # Iteration 9 closes the epoch and produces the first checkpoint.
    trainer.save_checkpoint(iteration=9, is_train=True)

    assert models_dir.exists()
    assert len(saved_files()) == 1
    assert (models_dir / 'model-10.pth').exists()

    ckpt_dict = torch.load(str(models_dir / 'model-10.pth'))

    empty_metric_state_dict = {
        'val': 0,
        'avg': 0,
        'count': 0,
        'sum': 0,
        'history': [],
    }
    empty_metrics_state_dict = [{
        'type': f'asr.metrics.{metric_name}',
        'state_dict': empty_metric_state_dict,
    } for metric_name in ('Loss', 'CER', 'WER')]

    expected_ckpt_dict = {
        'model': 'model dict',
        'epoch': 1,
        'epoch_iterations': 0,
        'iterations_per_epoch': 10,
        'best_monitor': float('inf'),
        'metrics': {
            'train': empty_metrics_state_dict,
            'val': empty_metrics_state_dict,
        },
        'optimizer': 'optimizer dict',
    }
    assert ckpt_dict == expected_ckpt_dict

    # Validation with the score dropping from inf to 2.0: a regular checkpoint
    # and an identical best-model copy are written.
    trainer.epoch = 1
    trainer.score = Mock(return_value=2.0)
    trainer.save_checkpoint(is_train=False)

    assert len(saved_files()) == 3
    assert (models_dir / 'model-20.pth').exists()
    assert (models_dir / 'best-model.pth').exists()

    ckpt_dict = torch.load(str(models_dir / 'model-20.pth'))
    best_ckpt_dict = torch.load(str(models_dir / 'best-model.pth'))
    assert ckpt_dict == best_ckpt_dict

    expected_ckpt_dict.update(best_monitor=2.0, epoch=2)
    assert best_ckpt_dict == expected_ckpt_dict

    # Ten minutes after start, a mid-epoch iteration also gets saved.
    trainer.start_time = time.time() - 60 * 10
    trainer.epoch = 2
    trainer.save_checkpoint(iteration=8, is_train=True)

    assert len(saved_files()) == 4
    assert (models_dir / 'model-29.pth').exists()

    ckpt_dict = torch.load(str(models_dir / 'model-29.pth'))
    expected_ckpt_dict.update(best_monitor=2.0, epoch=2, epoch_iterations=9)
    assert ckpt_dict == expected_ckpt_dict
Example #14
0
def test_greedy_decoder():
    """Greedy CTC decoding of a two-utterance batch.

    Code adapted from tensorflow. The alphabet is 'abc-' with the blank at
    index 3; each probability matrix is time x classes with 6 time steps, of
    which only the first ``seq_len`` are decoded.
    """
    probs_first = torch.tensor(
        [
            [1.0, 0.0, 0.0, 0.0],  # t=0
            [0.0, 0.0, 0.4, 0.6],  # t=1
            [0.0, 0.0, 0.4, 0.6],  # t=2
            [0.0, 0.9, 0.1, 0.0],  # t=3
            [0.0, 0.0, 0.0, 0.0],  # t=4 (ignored)
            [0.0, 0.0, 0.0, 0.0],  # t=5 (ignored)
        ],
        dtype=torch.float32)
    probs_second = torch.tensor(
        [
            [0.1, 0.9, 0.0, 0.0],  # t=0
            [0.0, 0.9, 0.1, 0.0],  # t=1
            [0.0, 0.0, 0.1, 0.9],  # t=2
            [0.0, 0.9, 0.1, 0.1],  # t=3
            [0.9, 0.1, 0.0, 0.0],  # t=4
            [0.0, 0.0, 0.0, 0.0],  # t=5 (ignored)
        ],
        dtype=torch.float32)

    # Batch of log-probabilities plus the number of valid steps per item.
    inputs = torch.stack([probs_first.log(), probs_second.log()])
    seq_lens = torch.tensor([4, 5], dtype=torch.int32)

    # Expected scores: negated sum of the log-probabilities on each best path.
    best_paths = ([1.0, 0.6, 0.6, 0.9], [0.9, 0.9, 0.9, 0.9, 0.9])
    log_prob_truth = torch.tensor(
        [-torch.tensor(path).log().sum().item() for path in best_paths])

    expected_text = ['ab', 'bba']
    expected_offsets = [torch.tensor([0, 3]), torch.tensor([0, 3, 4])]

    decoder = GreedyCTCDecoder(Alphabet('abc-', blank_index=3))
    out, scores, offsets = decoder.decode(inputs, seq_lens)

    for idx, text in enumerate(expected_text):
        assert out[idx] == text

    assert torch.allclose(scores, log_prob_truth)

    for idx, offset in enumerate(expected_offsets):
        assert torch.all(offsets[idx] == offset)
Example #15
0
def test_beam_search_decoder():
    """Beam-search CTC decoding of two sequences, one at a time and batched.

    For reference, greedy decoding of the same inputs gives
    ``["ac'bdc", "b'da"]``.
    """
    labels = ['\'', ' ', 'a', 'b', 'c', 'd', '-']
    beam_width = 20

    # Per-time-step class probabilities (6 steps x 7 classes each).
    probs_seq1 = [
        [0.06390443, 0.21124858, 0.27323887, 0.06870235, 0.0361254, 0.18184413, 0.16493624],
        [0.03309247, 0.22866108, 0.24390638, 0.09699597, 0.31895462, 0.0094893, 0.06890021],
        [0.218104, 0.19992557, 0.18245131, 0.08503348, 0.14903535, 0.08424043, 0.08120984],
        [0.12094152, 0.19162472, 0.01473646, 0.28045061, 0.24246305, 0.05206269, 0.09772094],
        [0.1333387, 0.00550838, 0.00301669, 0.21745861, 0.20803985, 0.41317442, 0.01946335],
        [0.16468227, 0.1980699, 0.1906545, 0.18963251, 0.19860937, 0.04377724, 0.01457421],
    ]
    probs_seq2 = [
        [0.08034842, 0.22671944, 0.05799633, 0.36814645, 0.11307441, 0.04468023, 0.10903471],
        [0.09742457, 0.12959763, 0.09435383, 0.21889204, 0.15113123, 0.10219457, 0.20640612],
        [0.45033529, 0.09091417, 0.15333208, 0.07939558, 0.08649316, 0.12298585, 0.01654384],
        [0.02512238, 0.22079203, 0.19664364, 0.11906379, 0.07816055, 0.22538587, 0.13483174],
        [0.17928453, 0.06065261, 0.41153005, 0.1172041, 0.11880313, 0.07113197, 0.04139363],
        [0.15882358, 0.1235788, 0.23376776, 0.20510435, 0.00279306, 0.05294827, 0.22298418],
    ]
    log_probs_seq1 = torch.log(torch.as_tensor(probs_seq1))
    log_probs_seq2 = torch.log(torch.as_tensor(probs_seq2))

    beam_search_result = ['acdc', "b'a"]

    alphabet = Alphabet(labels, blank_index=labels.index('-'))
    decoder = BeamCTCDecoder(alphabet, beam_width=beam_width)

    # Each sequence decoded as a batch of one.
    for log_probs, expected in zip((log_probs_seq1, log_probs_seq2),
                                   beam_search_result):
        beam_result, _, _ = decoder.decode(log_probs[None, ...])
        assert beam_result[0] == expected

    # Both sequences decoded together.
    batched = torch.stack([log_probs_seq1, log_probs_seq2])
    beam_results, _, _ = decoder.decode(batched)

    assert beam_results[0] == beam_search_result[0]
    assert beam_results[1] == beam_search_result[1]
Example #16
0
def train_model_from_args(args):
    """Train a model described by the parameter file ``args.param_path``.

    Optionally copies results from ``args.prev_output_dir`` and fine-tunes
    from the archive ``args.fine_tune``. The configuration and alphabet are
    saved into ``args.serialization_dir`` before training. In distributed
    runs, only local rank 0 creates the serialization directory and saves the
    config, and all ranks synchronise on a barrier afterwards.

    Args:
        args: parsed command-line namespace (serialization, distributed,
            fine-tuning and mixed-precision options).

    Raises:
        ConfigurationError: if ``args.param_path`` does not exist.
        KeyboardInterrupt: re-raised after logging when training is
            interrupted.
    """

    if args.local_rank == 0 and args.prev_output_dir is not None:
        logger.info('Copying results from {} to {}...'.format(args.prev_output_dir, args.serialization_dir))

        copy_tree(args.prev_output_dir, args.serialization_dir, update=True, verbose=True)

    if not os.path.isfile(args.param_path):
        raise ConfigurationError(f'Parameters file {args.param_path} not found.')

    logger.info(f'Loading experiment from {args.param_path} with overrides `{args.overrides}`.')

    params = Params.load(args.param_path, args.overrides)

    prepare_environment(params)

    logger.info(f'Local rank: {args.local_rank}')
    if args.local_rank == 0:
        create_serialization_dir(params, args.serialization_dir, args.reset)

    if args.distributed:
        logger.info(f'World size: {dist.get_world_size()} | Rank {dist.get_rank()} | ' f'Local Rank {args.local_rank}')
        # Other ranks wait until rank 0 has created the serialization dir.
        dist.barrier()

    prepare_global_logging(args.serialization_dir, local_rank=args.local_rank, verbosity=args.verbosity)

    if args.local_rank == 0:
        params.save(os.path.join(args.serialization_dir, CONFIG_NAME))

    loaders = loaders_from_params(params,
                                  distributed=args.distributed,
                                  world_size=args.world_size,
                                  first_epoch=args.first_epoch)

    # Reuse a previously saved alphabet when present so indices stay stable.
    if os.path.exists(os.path.join(args.serialization_dir, "alphabet")):
        alphabet = Alphabet.from_file(os.path.join(args.serialization_dir, "alphabet"))
    else:
        alphabet = Alphabet.from_params(params.pop("alphabet", {}))

    alphabet.save_to_files(os.path.join(args.serialization_dir, "alphabet"))

    loss = losses.from_params(params.pop('loss'))
    model = models.from_params(alphabet=alphabet, params=params.pop('model'))

    trainer_params = params.pop("trainer")
    if args.fine_tune:
        _, archive_weight_file = models.load_archive(args.fine_tune)

        archive_weights = torch.load(archive_weight_file, map_location=lambda storage, loc: storage)['model']

        # Avoiding initializing from archive some weights
        no_ft_regex = trainer_params.pop("no_ft", ())

        finetune_weights = {}
        random_weights = []
        for name, parameter in archive_weights.items():
            if any(re.search(regex, name) for regex in no_ft_regex):
                random_weights.append(name)
                continue
            finetune_weights[name] = parameter

        logger.info(f'Loading the following weights from archive {args.fine_tune}:')
        logger.info(','.join(finetune_weights.keys()))
        logger.info('The following weights are at random:')
        logger.info(','.join(random_weights))

        # strict=False: excluded weights keep their fresh initialisation.
        model.load_state_dict(finetune_weights, strict=False)

    # Freezing some parameters
    freeze_params(model, trainer_params.pop('no_grad', ()))

    trainer = Trainer(args.serialization_dir,
                      trainer_params,
                      model,
                      loss,
                      alphabet,
                      local_rank=args.local_rank,
                      world_size=args.world_size,
                      sync_bn=args.sync_bn,
                      opt_level=args.opt_level,
                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                      loss_scale=args.loss_scale)

    try:
        trainer.run(loaders['train'], val_loader=loaders.get('val'), num_epochs=trainer_params['num_epochs'])
    except KeyboardInterrupt:
        # If best-epoch weights exist, report it before re-raising.
        # NOTE(review): no archive is actually created here despite the
        # message wording; confirm whether archiving was intended.
        if os.path.exists(os.path.join(args.serialization_dir, models.DEFAULT_WEIGHTS)):
            logger.info("Training interrupted by the user. Attempting to create "
                        "a model archive using the current best epoch weights.")
        raise