Example #1
import gtn
import numpy as np


def sample_model(
        num_features, num_classes,
        potentials, transitions,
        num_samples, max_len=20):
    """
    Draw `num_samples` samples from a linear-chain CRF specified
    by a `potentials` graph and a `transitions` graph. Each
    sample has a random length in `[1, max_len]`.
    """
    model = gtn.compose(potentials, transitions)

    # Draw a random X with a length sampled uniformly from [1, max_len] and
    # find the most likely Y under the model:
    samples = []
    while len(samples) < num_samples:
        # Sample X:
        T = np.random.randint(1, max_len + 1)
        X = np.random.randint(0, num_features, size=(T,))
        X = make_chain_graph(X)
        # Find the most likely Y given X:
        Y = gtn.viterbi_path(gtn.compose(X, model))
        # Clean up Y:
        Y = gtn.project_output(Y)
        Y.set_weights(np.zeros(Y.num_arcs()))
        samples.append((X, Y))
    return samples
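The helper `make_chain_graph` used above (and again in Example #2) is not shown on this page. A minimal sketch of such a helper, assuming each label in the sequence becomes one arc of a linear acceptor, could look like this:

def make_chain_graph(sequence):
    # Build a linear acceptor with one arc per label in `sequence`.
    graph = gtn.Graph(False)
    graph.add_node(True)  # start node
    for i, label in enumerate(sequence):
        # The node added for the final label is the accepting node:
        graph.add_node(False, i == (len(sequence) - 1))
        graph.add_arc(i, i + 1, int(label))
    return graph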
Example #2
        def process(b):
            # Create emissions graph:
            emissions = gtn.linear_graph(T, C, inputs.requires_grad)
            cpu_data = inputs[b].cpu().contiguous()
            emissions.set_weights(cpu_data.data_ptr())
            target = make_chain_graph(targets[b])
            target.arc_sort(True)

            # Create the token-to-grapheme decomposition graph:
            tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon)))
            tokens_target.arc_sort()

            # Create alignment graph:
            alignments = gtn.project_input(
                gtn.remove(gtn.compose(tokens, tokens_target))
            )
            alignments.arc_sort()

            # Add transition scores:
            if transitions is not None:
                alignments = gtn.intersect(transitions, alignments)
                alignments.arc_sort()

            loss = gtn.forward_score(gtn.intersect(emissions, alignments))

            # Normalize if needed:
            if transitions is not None:
                norm = gtn.forward_score(gtn.intersect(emissions, transitions))
                loss = gtn.subtract(loss, norm)

            losses[b] = gtn.negate(loss)

            # Save for backward:
            if emissions.calc_grad:
                emissions_graphs[b] = emissions
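The emissions graphs are saved so that a matching backward pass can read gradients out of them. A rough sketch of that step, assuming the surrounding code keeps `losses`, `emissions_graphs`, `T`, and `C` as above and relies on gtn's autograd (`gtn.backward` and `Graph.grad`), might be:

        def process_backward(b, grad_output=1.0):
            # Backpropagate through the stored loss graph:
            gtn.backward(losses[b], False)
            # Read the per-arc gradient of the emissions graph into a dense array:
            emissions = emissions_graphs[b]
            grad = emissions.grad().weights_to_numpy().reshape(T, C)
            return grad_output * grad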
Example #3
        def pred_seq(batch_index):
            obs_fst = linearFstFromArray(arc_scores[batch_index].reshape(
                num_samples, -1))

            # Compose each sequence fst individually: it seems like composition
            # only works for lattices
            denom_fst = obs_fst
            for seq_fst in seq_fsts:
                denom_fst = gtn.compose(denom_fst, seq_fst)

            viterbi_path = gtn.viterbi_path(denom_fst)
            best_paths[batch_index] = gtn.remove(
                gtn.project_output(viterbi_path))
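If flat label indices are needed rather than the path graph, the stored result can be read out with the same call Example #6 uses; a hedged one-liner:

            pred_labels = best_paths[batch_index].labels_to_list()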
Example #4
    def test_project_clone(self):

        g_str = [
            "0 1",
            "3 4",
            "0 1 0 0 2",
            "0 2 1 1 1",
            "1 2 0 0 2",
            "2 3 0 0 1",
            "2 3 1 1 1",
            "1 4 0 0 2",
            "2 4 1 1 3",
            "3 4 0 0 2",
        ]
        graph = create_graph_from_text(g_str)

        # Test clone
        cloned = gtn.clone(graph)
        self.assertTrue(gtn.equal(graph, cloned))

        # Test projecting input
        g_str = [
            "0 1",
            "3 4",
            "0 1 0 0 2",
            "0 2 1 1 1",
            "1 2 0 0 2",
            "2 3 0 0 1",
            "2 3 1 1 1",
            "1 4 0 0 2",
            "2 4 1 1 3",
            "3 4 0 0 2",
        ]
        inputExpected = create_graph_from_text(g_str)
        self.assertTrue(gtn.equal(gtn.project_input(graph), inputExpected))

        # Test projecting output
        g_str = [
            "0 1",
            "3 4",
            "0 1 0 0 2",
            "0 2 1 1 1",
            "1 2 0 0 2",
            "2 3 0 0 1",
            "2 3 1 1 1",
            "1 4 0 0 2",
            "2 4 1 1 3",
            "3 4 0 0 2",
        ]
        outputExpected = create_graph_from_text(g_str)
        self.assertTrue(gtn.equal(gtn.project_output(graph), outputExpected))
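The text format above appears to list the start nodes, then the accept nodes, then one arc per line as `src dst ilabel olabel weight`. Because every arc has matching input and output labels, projecting onto inputs or outputs leaves the graph unchanged, which is why all three expected graphs are identical. Under that assumption, the same acceptor could be built directly with the graph API:

        g = gtn.Graph(False)
        for n in range(5):
            # Nodes 0 and 1 are start nodes; nodes 3 and 4 are accept nodes.
            g.add_node(n in (0, 1), n in (3, 4))
        g.add_arc(0, 1, 0, 0, 2)
        g.add_arc(0, 2, 1, 1, 1)
        g.add_arc(1, 2, 0, 0, 2)
        g.add_arc(2, 3, 0, 0, 1)
        g.add_arc(2, 3, 1, 1, 1)
        g.add_arc(1, 4, 0, 0, 2)
        g.add_arc(2, 4, 1, 1, 3)
        g.add_arc(3, 4, 0, 0, 2)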
Example #5
        def seq_loss(batch_index):
            obs_fst = linearFstFromArray(arc_scores[batch_index].reshape(
                num_samples, -1))
            gt_fst = fromSequence(arc_labels[batch_index])

            # Compose each sequence fst individually: it seems like composition
            # only works for lattices
            denom_fst = obs_fst
            for seq_fst in seq_fsts:
                denom_fst = gtn.compose(denom_fst, seq_fst)
                denom_fst = gtn.project_output(denom_fst)

            num_fst = gtn.compose(denom_fst, gt_fst)

            loss = gtn.subtract(gtn.forward_score(num_fst),
                                gtn.forward_score(denom_fst))

            losses[batch_index] = loss
            obs_fsts[batch_index] = obs_fst
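Once every sequence in the batch has been processed, the stored loss graphs can be reduced to a single scalar. A hedged sketch, negating as in Example #2 and reading the values out as plain floats (any gradient plumbing is assumed to happen elsewhere):

        batch_nll = torch.tensor(
            [gtn.negate(losses[i]).item() for i in range(len(losses))]
        ).mean()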
Example #6
        def process(b):
            emissions = gtn.linear_graph(T, C, False)
            cpu_data = outputs[b].cpu().contiguous()
            emissions.set_weights(cpu_data.data_ptr())
            if self.transitions is not None:
                full_graph = gtn.intersect(emissions, self.transitions)
            else:
                full_graph = emissions

            # Find the best path and remove back-off arcs:
            path = gtn.remove(gtn.viterbi_path(full_graph))
            # Left compose the viterbi path with the "alignment to token"
            # transducer to get the outputs:
            path = gtn.compose(path, self.tokens)

            # When there are ambiguous paths (allow_repeats is true), we take
            # the shortest:
            path = gtn.viterbi_path(path)
            path = gtn.remove(gtn.project_output(path))
            paths[b] = path.labels_to_list()
Example #7
def main(out_dir=None,
         gpu_dev_id=None,
         num_samples=10,
         random_seed=None,
         learning_rate=1e-3,
         num_epochs=500,
         dataset_kwargs={},
         dataloader_kwargs={},
         model_kwargs={}):

    if out_dir is None:
        out_dir = os.path.join('~', 'data', 'output', 'seqtools', 'test_gtn')

    out_dir = os.path.expanduser(out_dir)

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    vocabulary = ['a', 'b', 'c', 'd', 'e']

    transition = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 1, 0], [0, 0, 0, 0, 1],
                           [0, 1, 0, 0, 1], [0, 0, 0, 0, 0]],
                          dtype=float)
    initial = np.array([1, 0, 1, 0, 0], dtype=float)
    final = np.array([0, 1, 0, 0, 1], dtype=float) / 10

    seq_params = (transition, initial, final)
    simulated_dataset = simulate(num_samples, *seq_params)
    label_seqs, obsv_seqs = tuple(zip(*simulated_dataset))
    seq_params = tuple(map(lambda x: -np.log(x), seq_params))

    dataset = torchutils.SequenceDataset(obsv_seqs, label_seqs,
                                         **dataset_kwargs)
    data_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)

    train_loader = data_loader
    val_loader = data_loader

    transition_weights = torch.tensor(transition, dtype=torch.float).log()
    initial_weights = torch.tensor(initial, dtype=torch.float).log()
    final_weights = torch.tensor(final, dtype=torch.float).log()

    model = libfst.LatticeCrf(vocabulary,
                              transition_weights=transition_weights,
                              initial_weights=initial_weights,
                              final_weights=final_weights,
                              debug_output_dir=fig_dir,
                              **model_kwargs)

    gtn.draw(model._transition_fst,
             os.path.join(fig_dir, 'transitions-init.png'),
             isymbols=model._arc_symbols,
             osymbols=model._arc_symbols)

    gtn.draw(model._duration_fst,
             os.path.join(fig_dir, 'durations-init.png'),
             isymbols=model._arc_symbols,
             osymbols=model._arc_symbols)

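    # Sanity check before training: compute the loss on the first batch element
    # and, if it comes out infinite, draw the intermediate graphs and drop into pdb.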
    if True:
        for i, (inputs, targets, seq_id) in enumerate(train_loader):
            arc_scores = model.scores_to_arc(inputs)
            arc_labels = model.labels_to_arc(targets)

            batch_size, num_samples, num_classes = arc_scores.shape

            obs_fst = libfst.linearFstFromArray(arc_scores[0].reshape(
                num_samples, -1))
            gt_fst = libfst.fromSequence(arc_labels[0])
            d1_fst = gtn.compose(obs_fst, model._duration_fst)
            d1_fst = gtn.project_output(d1_fst)
            denom_fst = gtn.compose(d1_fst, model._transition_fst)
            # denom_fst = gtn.project_output(denom_fst)
            num_fst = gtn.compose(denom_fst, gt_fst)
            viterbi_fst = gtn.viterbi_path(denom_fst)
            pred_fst = gtn.remove(gtn.project_output(viterbi_fst))

            loss = gtn.subtract(gtn.forward_score(num_fst),
                                gtn.forward_score(denom_fst))
            loss = torch.tensor(loss.item())

            if torch.isinf(loss).any():
                denom_alt = gtn.compose(obs_fst, model._transition_fst)
                d1_min = gtn.remove(gtn.project_output(d1_fst))
                denom_alt = gtn.compose(d1_min, model._transition_fst)
                num_alt = gtn.compose(denom_alt, gt_fst)
                gtn.draw(obs_fst,
                         os.path.join(fig_dir, 'observations-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(gt_fst,
                         os.path.join(fig_dir, 'labels-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(d1_fst,
                         os.path.join(fig_dir, 'd1-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(d1_min,
                         os.path.join(fig_dir, 'd1-min-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(denom_fst,
                         os.path.join(fig_dir, 'denominator-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(denom_alt,
                         os.path.join(fig_dir, 'denominator-alt-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(num_fst,
                         os.path.join(fig_dir, 'numerator-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(num_alt,
                         os.path.join(fig_dir, 'numerator-alt-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(viterbi_fst,
                         os.path.join(fig_dir, 'viterbi-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                gtn.draw(pred_fst,
                         os.path.join(fig_dir, 'pred-init.png'),
                         isymbols=model._arc_symbols,
                         osymbols=model._arc_symbols)
                import pdb
                pdb.set_trace()

    # Train the model
    train_epoch_log = collections.defaultdict(list)
    val_epoch_log = collections.defaultdict(list)
    metric_dict = {
        'Avg Loss': metrics.AverageLoss(),
        'Accuracy': metrics.Accuracy()
    }

    criterion = model.nllLoss
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1,
                                                gamma=1.00)

    model, last_model_wts = torchutils.trainModel(
        model,
        criterion,
        optimizer,
        scheduler,
        train_loader,
        val_loader,
        metrics=metric_dict,
        test_metric='Avg Loss',
        train_epoch_log=train_epoch_log,
        val_epoch_log=val_epoch_log,
        num_epochs=num_epochs)

    gtn.draw(model._transition_fst,
             os.path.join(fig_dir, 'transitions-trained.png'),
             isymbols=model._arc_symbols,
             osymbols=model._arc_symbols)
    gtn.draw(model._duration_fst,
             os.path.join(fig_dir, 'durations-trained.png'),
             isymbols=model._arc_symbols,
             osymbols=model._arc_symbols)

    torchutils.plotEpochLog(train_epoch_log,
                            title="Train Epoch Log",
                            fn=os.path.join(fig_dir, "train-log.png"))
Example #8
    tokens = token_graph(word_pieces)
    gtn.draw(tokens, "tokens.pdf", idx_to_wp, idx_to_wp)

    # Recognizes "abc":
    abc = gtn.Graph(False)
    abc.add_node(True)
    abc.add_node()
    abc.add_node()
    abc.add_node(False, True)
    abc.add_arc(0, 1, let_to_idx["a"])
    abc.add_arc(1, 2, let_to_idx["b"])
    abc.add_arc(2, 3, let_to_idx["c"])
    gtn.draw(abc, "abc.pdf", idx_to_let)

    # Compute the decomposition graph for "abc":
    abc_decomps = gtn.remove(gtn.project_output(gtn.compose(abc, lex)))
    gtn.draw(abc_decomps, "abc_decomps.pdf", idx_to_wp, idx_to_wp)

    # Compute the alignment graph for "abc":
    abc_alignments = gtn.project_input(
        gtn.remove(gtn.compose(tokens, abc_decomps)))
    gtn.draw(abc_alignments, "abc_alignments.pdf", idx_to_wp)

    # From here we can use the alignment graph with an emissions graph and
    # transition graphs to compute the sequence-level criterion:
    emissions = gtn.linear_graph(10, len(word_pieces), True)
    loss = gtn.subtract(
        gtn.forward_score(emissions),
        gtn.forward_score(gtn.intersect(emissions, abc_alignments)))
    print(f"Loss is {loss.item():.2f}")