Example #1
def exact_rescoring(model: ModelContainer,
                    forest: Hypergraph, goal_maker: GoalRuleMaker, log=dummyfunc) -> SimpleNamespace:
    """
    Exactly rescore a forest with a given model.

    :param model: an instance of ModelContainer
    :param forest: a Hypergraph
    :param goal_maker: an object to deliver (output view of) goal rules
    :param log: a logging function
    :return: a SimpleNamespace with attributes forest and components
    """
    result = SimpleNamespace()

    if not model.stateful:  # when the model is not stateful, we don't need Earley
        log('Lookup scoring')
        lookup_comps = get_lookup_components(forest, model.lookup.extractors())  # lookup
        log('Stateless scoring')
        stateless_comps = get_stateless_components(forest, model.stateless.extractors())  # stateless
        result.forest = forest
        result.components = [FComponents([comps1, comps2]) for comps1, comps2 in zip(lookup_comps, stateless_comps)]

    else:  # here we cannot avoid it
        log('Forest rescoring')
        goal_maker.update()
        result.forest, result.components = rescore_forest(forest,
                                                          0,
                                                          TableLookupScorer(model.lookup),
                                                          StatelessScorer(model.stateless),
                                                          StatefulScorer(model.stateful),
                                                          goal_rule=goal_maker.get_oview(),
                                                          keep_components=True)

    return result
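
A minimal call sketch, mirroring the call site in Example #7 (the variable names are illustrative, not part of the function):

goal_maker = GoalRuleMaker(goal_str='GOAL', start_str='S', n=1)
result = exact_rescoring(joint_model.local_model(),  # a ModelContainer holding only local components
                         tgt_forest,                 # a Hypergraph
                         goal_maker,
                         log=logging.info)
rescored_forest = result.forest      # identical to the input forest when the model is purely local
edge_components = result.components  # one FComponents object per edge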
Example #2
def make_slice_sampler(seg, model,
                       extra_grammar_paths=[], glue_grammar_paths=[], pass_through=True,
                       default_symbol='X', goal_str='GOAL', start_str='S',
                       saving={}, redo=True,
                       log=dummyfunc) -> 'tuple':
    """
    Construct a slice sampler for f(d) = n(d) * l(d) over the target forest.
    :return: target forest, local score function l(d), top-sort table, and a SlicedRescoring sampler
    """

    # check for pass1
    if all(is_step_complete(step, saving, redo) for step in ['forest', 'lookup', 'stateless']):
        tgt_forest = unpickle_it(saving['forest'])
        lookup_comps = unpickle_it(saving['lookup'])
        stateless_comps = unpickle_it(saving['stateless'])
    else:
        src_forest = pass0(seg,
                           extra_grammar_paths=extra_grammar_paths,
                           glue_grammar_paths=glue_grammar_paths,
                           pass_through=pass_through,
                           default_symbol=default_symbol,
                           goal_str=goal_str,
                           start_str=start_str,
                           n_goal=0,
                           saving={},
                           redo=redo,
                           log=log)

        # pass1: local scoring
        tgt_forest, lookup_comps, stateless_comps = pass1(seg,
                                                          src_forest,
                                                          model,
                                                          saving=saving,
                                                          redo=redo,
                                                          log=log)

    # l(d)
    lfunc = TableLookupFunction(np.array([semiring.inside.times(model.lookup.score(ff1),
                                                                model.stateless.score(ff2))
                                          for ff1, ff2 in zip(lookup_comps, stateless_comps)], dtype=ptypes.weight))
    # top sort table
    tsort = AcyclicTopSortTable(tgt_forest)
    goal_maker = GoalRuleMaker(goal_str=goal_str, start_str=start_str, n=1)
    # slice sampler
    sampler = SlicedRescoring(tgt_forest,
                              lfunc,
                              tsort,
                              TableLookupScorer(model.dummy),
                              StatelessScorer(model.dummy),
                              StatefulScorer(model.stateful),
                              semiring.inside,
                              goal_rule=goal_maker.get_oview(),
                              dead_rule=make_dead_oview())
    return tgt_forest, lfunc, tsort, sampler
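
The returned sampler is driven exactly as in Example #9; a sketch (keyword values are placeholders taken from an assumed args namespace):

tgt_forest, lfunc, tsort, sampler = make_slice_sampler(seg, model,
                                                       extra_grammar_paths=args.extra_grammar,
                                                       glue_grammar_paths=args.glue_grammar,
                                                       log=logging.info)
# samples are represented as sequences of edge ids
d0, markov_chain = sampler.sample(n_samples=args.samples,
                                  batch_size=args.batch,
                                  within=args.within,
                                  initial=args.initial,
                                  prior=args.prior,
                                  burn=args.burn,
                                  lag=args.lag,
                                  temperature0=args.temperature0)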
Example #3
def pass0(seg, extra_grammar_paths=[], glue_grammar_paths=[], pass_through=True,
          default_symbol='X', goal_str='GOAL', start_str='S', max_span=-1, n_goal=0,
          saving={}, redo=True, log=dummyfunc) -> 'Hypergraph':
    """
    Pass0 consists of parsing the input with the source side of the grammar.
    For now, pass0 does not do any scoring (not even local), but it could (TODO).

    Steps
        1. Make a hypergraph view of the grammar
        2. Make an input DFA
        3. Parse the input DFA

    :return: source forest
    """
    if is_step_complete('forest', saving, redo):
        return unpickle_it(saving['forest'])

    # here we need to decode for sure
    log('[%d] Make hypergraph view of all available grammars', seg.id)
    # make a hypergraph view of all available grammars
    grammar = make_grammar_hypergraph(seg,
                                      extra_grammar_paths=extra_grammar_paths,
                                      glue_grammar_paths=glue_grammar_paths,
                                      pass_through=pass_through,
                                      default_symbol=default_symbol)

    # parse source lattice
    log('[%d] Parse source DFA', seg.id)
    goal_maker = GoalRuleMaker(goal_str=goal_str, start_str=start_str, n=n_goal)
    dfa = make_input_dfa(seg)
    forest = parse_dfa(grammar,
                       grammar.fetch(Nonterminal(start_str)),
                       dfa,
                       goal_maker.get_iview(),
                       bottomup=True,
                       constraint=HieroConstraints(grammar, dfa, max_span))
    if 'forest' in saving:
        pickle_it(saving['forest'], forest)
    return forest
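
A call sketch with caching enabled (the saving path is illustrative):

saving = {'forest': '{0}/{1}.src.forest'.format(workingdir, seg.id)}
src_forest = pass0(seg,
                   extra_grammar_paths=args.extra_grammar,
                   glue_grammar_paths=args.glue_grammar,
                   saving=saving,
                   redo=False,  # reuse the pickled forest if it exists
                   log=logging.info)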
Example #4
def decode(seg, args, proxy, target):
    # pass0
    src_forest = pipeline.pass0(seg,
                                extra_grammar_paths=args.extra_grammar,
                                glue_grammar_paths=args.glue_grammar,
                                pass_through=args.pass_through,
                                default_symbol=args.default_symbol,
                                goal_str=args.goal,
                                start_str=args.start,
                                max_span=args.max_span,
                                n_goal=0,
                                log=logging.info)

    if not proxy.stateful:
        tgt_forest, lookup_comps, stateless_comps = pipeline.pass1(
            seg, src_forest, proxy, saving={}, redo=True, log=logging.info)
        q_components = [
            FComponents([comp1, comp2])
            for comp1, comp2 in zip(lookup_comps, stateless_comps)
        ]
    else:
        tgt_forest = pipeline.make_target_forest(src_forest)
        goal_maker = GoalRuleMaker(goal_str=args.goal,
                                   start_str=args.start,
                                   n=1)
        tgt_forest, q_components = pipeline.pass2(
            seg,
            tgt_forest,
            TableLookupScorer(proxy.lookup),
            StatelessScorer(proxy.stateless),
            StatefulScorer(proxy.stateful),
            goal_rule=goal_maker.get_oview(),
            omega=None,
            saving={},
            redo=True,
            log=logging.info)
    # TODO: save tgt_forest and q_components
    # Make unnormalised q(d)
    q_func = TableLookupFunction(
        np.array([proxy.score(comps) for comps in q_components],
                 dtype=ptypes.weight))

    logging.info('[%d] Forest: nodes=%d edges=%d', seg.id,
                 tgt_forest.n_nodes(), tgt_forest.n_edges())
    tsort = AcyclicTopSortTable(tgt_forest)

    sampler = AncestralSampler(tgt_forest, tsort, omega=q_func)
    samples = sampler.sample(args.samples)
    n_samples = len(samples)

    d_groups = group_by_identity(samples)
    y_groups = group_by_projection(
        d_groups, lambda group: yield_string(tgt_forest, group.key))

    is_yields = []
    for y_group in y_groups:
        y = y_group.key
        is_derivations = []
        for d_group in y_group.values:
            edges = d_group.key
            # reduce q weights through inside.times
            q_score = derivation_weight(tgt_forest,
                                        edges,
                                        semiring.inside,
                                        omega=q_func)
            # reduce q components through inside.times
            q_comps = proxy.constant(semiring.inside.one)
            for e in edges:
                q_comps = q_comps.hadamard(q_components[e],
                                           semiring.inside.times)
            # compute p components and p score
            p_comps, p_score = score_derivation(
                tgt_forest, edges, semiring.inside,
                TableLookupScorer(target.lookup),
                StatelessScorer(target.stateless),
                StatefulScorer(target.stateful))
            # TODO: save {y => {edges: (q_comps, p_comps, count)}}
            is_derivations.append(
                ISDerivation(edges, q_comps, p_comps, d_group.count))
        is_yields.append(ISYield(y, is_derivations, y_group.count))
    # TODO: pickle the results
    return decide(is_yields, n_samples, proxy, target)
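
In isolation, the proposal-sampling step above reduces to the following (same names as in the function; the sample size is illustrative):

q_func = TableLookupFunction(np.array([proxy.score(comps) for comps in q_components],
                                      dtype=ptypes.weight))
sampler = AncestralSampler(tgt_forest, tsort, omega=q_func)
samples = sampler.sample(1000)
for d_group in group_by_identity(samples):
    # the proposal weight of a derivation is a product (inside.times) over its edges
    q_score = derivation_weight(tgt_forest, d_group.key, semiring.inside, omega=q_func)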
Example #5
def training_biparse(seg, args, workingdir, model, log=dummyfunc) -> 'bool':
    """
    Steps:
        I. Pass0 and pass1: parse source, project, local scoring
        II. Pass2
            - make a reference DFA
            - parse the reference DFA
            - fully score the reference forest (lookup, stateless, stateful)
                - save rescored forest and components
    :return: whether or not the input is bi-parsable
    """

    pass1_files = ['{0}/{1}.hyp.forest'.format(workingdir, seg.id),
                   '{0}/{1}.hyp.ffs.rule'.format(workingdir, seg.id),
                   '{0}/{1}.hyp.ffs.stateless'.format(workingdir, seg.id)]
    ref_files = ['{0}/{1}.ref.ffs.all'.format(workingdir, seg.id),
                 '{0}/{1}.ref.forest'.format(workingdir, seg.id)]

    # check for redundant work
    if all(os.path.exists(path) for path in pass1_files) and not args.redo:
        if all(os.path.exists(path) for path in ref_files):
            log('[%d] Reusing forests for segment', seg.id)
            return True   # parsable
        else:
            return False  # not parsable

    # pass0: parsing

    src_forest = pass0(seg,
                       extra_grammar_paths=args.extra_grammar,
                       glue_grammar_paths=args.glue_grammar,
                       pass_through=args.pass_through,
                       default_symbol=args.default_symbol,
                       goal_str=args.goal,
                       start_str=args.start,
                       n_goal=0,
                       saving={},
                       redo=args.redo,
                       log=log)

    # pass1: local scoring

    saving1 = {
        'forest': '{0}/{1}.hyp.forest'.format(workingdir, seg.id),
        'lookup': '{0}/{1}.hyp.ffs.rule'.format(workingdir, seg.id),
        'stateless': '{0}/{1}.hyp.ffs.stateless'.format(workingdir, seg.id)
    }

    tgt_forest, lookup_comps, stateless_comps = pass1(seg,
                                                      src_forest,
                                                      model,
                                                      saving=saving1,
                                                      redo=args.redo,
                                                      log=log)

    # parse reference lattice
    log('[%d] Parse reference DFA', seg.id)
    ref_dfa = make_reference_dfa(seg)
    goal_maker = GoalRuleMaker(goal_str=args.goal, start_str=args.start, n=1)
    ref_forest = parse_dfa(tgt_forest,
                           0,
                           ref_dfa,
                           goal_maker.get_oview(),
                           bottomup=False)

    if not ref_forest:
        return False  # not parsable

    # pass2: rescore reference forest

    saving2 = {
        'forest': '{0}/{1}.ref.forest'.format(workingdir, seg.id),
        'components': '{0}/{1}.ref.ffs.all'.format(workingdir, seg.id)
    }
    goal_maker.update()
    pass2(seg, ref_forest,
          TableLookupScorer(model.lookup),
          StatelessScorer(model.stateless),
          StatefulScorer(model.stateful),
          goal_maker.get_oview(),
          saving=saving2, redo=args.redo,
          log=log)

    return True  # parsable
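
A driver-loop sketch (segments is a hypothetical iterable of SegmentMetaData objects):

for seg in segments:
    parsable = training_biparse(seg, args, workingdir, model, log=logging.info)
    if not parsable:
        logging.info('[%d] reference not parsable, skipping segment', seg.id)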
Example #6
def decode(seg, args, model, outdir):
    """
    """

    # pass0
    src_forest = pipeline.pass0(seg,
                                extra_grammar_paths=args.extra_grammar,
                                glue_grammar_paths=args.glue_grammar,
                                pass_through=args.pass_through,
                                default_symbol=args.default_symbol,
                                goal_str=args.goal,
                                start_str=args.start,
                                max_span=args.max_span,
                                n_goal=0,
                                log=logging.info)
    tgt_forest = pipeline.make_target_forest(src_forest,
                                             TableLookupScorer(model.lookup))
    tsort = AcyclicTopSortTable(tgt_forest)

    if args.viterbi:
        viterbi(seg.id, tgt_forest, tsort, outdir, "pass0")

    # pass1
    if model.stateless:
        tgt_forest = stateless_rescoring(tgt_forest,
                                         StatelessScorer(model.stateless),
                                         semiring.inside)
        if args.viterbi:
            viterbi(seg.id, tgt_forest, tsort, outdir, "pass1")

    samples = []

    if args.framework == 'exact' or not model.stateful:  # exact scoring or no stateful scoring
        # we have access to Viterbi, k-best, sampling
        if model.stateful:
            goal_maker = GoalRuleMaker(goal_str=args.goal,
                                       start_str=args.start,
                                       n=1)

            rescorer = EarleyRescorer(tgt_forest,
                                      TableLookupScorer(model.dummy),
                                      StatelessScorer(model.dummy),
                                      StatefulScorer(model.stateful),
                                      semiring.inside)

            tgt_forest = rescorer.do(tsort.root(), goal_maker.get_oview())
            tsort = AcyclicTopSortTable(tgt_forest)

        # Do everything: viterbi, map, consensus, etc...
        if args.viterbi:
            viterbi(seg.id, tgt_forest, tsort, outdir, "pass2")

        if args.kbest > 0:
            # TODO: call kbest code
            pass
        if args.samples > 0:
            sampler = AncestralSampler(tgt_forest, tsort)
            samples = sampler.sample(args.samples)
            derivations = group_by_identity(samples)
            save_mc_derivations(
                '{0}/exact/derivations/{1}.gz'.format(outdir, seg.id),
                derivations,
                sampler.Z,
                valuefunc=lambda d: derivation_weight(tgt_forest, d,
                                                      semiring.inside),
                derivation2str=lambda d: bracketed_string(tgt_forest, d))
            projections = group_by_projection(
                samples, lambda d: yield_string(tgt_forest, d))
            save_mc_yields('{0}/exact/yields/{1}.gz'.format(outdir, seg.id),
                           projections)

            # TODO: fix this hack
            # it's here just so I can reuse pipeline.consensus
            # the fix involves moving SampleReturn to a more general module
            # and making AncestralSampler use it
            from grasp.alg.rescoring import SampleReturn
            samples = [SampleReturn(s, 0.0, FComponents([])) for s in samples]

    else:  # for sliced scoring, we only have access to sampling

        logging.info('Sliced rescoring...')
        from grasp.alg.rescoring import SlicedRescoring
        goal_maker = GoalRuleMaker(goal_str=args.goal,
                                   start_str=args.start,
                                   n=1)

        rescorer = SlicedRescoring(tgt_forest,
                                   HypergraphLookupFunction(tgt_forest), tsort,
                                   TableLookupScorer(model.dummy),
                                   StatelessScorer(model.dummy),
                                   StatefulScorer(model.stateful),
                                   semiring.inside, goal_maker.get_oview(),
                                   make_dead_oview(args.default_symbol))

        if args.gamma_shape > 0:
            gamma_shape = args.gamma_shape
        else:
            gamma_shape = len(model)  # number of local components
        gamma_scale_type = args.gamma_scale[0]
        gamma_scale_parameter = float(args.gamma_scale[1])

        # here samples are represented as sequences of edge ids
        d0, markov_chain = rescorer.sample(
            n_samples=args.samples,
            batch_size=args.batch,
            within=args.within,
            initial=args.initial,
            gamma_shape=gamma_shape,
            gamma_scale_type=gamma_scale_type,
            gamma_scale_parameter=gamma_scale_parameter,
            burn=args.burn,
            lag=args.lag,
            temperature0=args.temperature0)

        # apply usual MCMC heuristics (e.g. burn-in, lag)
        samples = apply_filters(markov_chain, burn=args.burn, lag=args.lag)

        # group by derivation (now a sample is represented by a Derivation object)
        derivations = group_by_identity(samples)
        save_mcmc_derivations(
            '{0}/slice/derivations/{1}.gz'.format(outdir, seg.id),
            derivations,
            valuefunc=lambda d: d.score,
            derivation2str=lambda d: bracketed_string(tgt_forest, d.edges))
        projections = group_by_projection(
            samples, lambda d: yield_string(tgt_forest, d.edges))
        save_mcmc_yields('{0}/slice/yields/{1}.gz'.format(outdir, seg.id),
                         projections)

        if args.save_chain:
            markov_chain.appendleft(d0)
            save_markov_chain(
                '{0}/slice/chain/{1}.gz'.format(outdir, seg.id),
                markov_chain,
                flat=True,
                valuefunc=lambda d: d.score,
                derivation2str=lambda d: bracketed_string(tgt_forest, d.edges))

    if samples:
        # decision rule
        decisions = pipeline.consensus(seg, tgt_forest, samples)
        return decisions[0]
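
A call sketch (outdir is illustrative; the function implicitly returns None when no samples were drawn):

decision = decode(seg, args, model, outdir='out')
if decision is not None:
    print(decision)  # the first (best) consensus decision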
Example #7
def biparse(seg: SegmentMetaData, options: SimpleNamespace,
            joint_model: ModelView, conditional_model: ModelView,
            workingdir=None, redo=True, log=dummyfunc) -> SimpleNamespace:
    """
    Biparse a segment using local models:
    1. we parse the source with a joint model
    2. we bi-parse source and target with a conditional model
    This separation allows us to factorise these models differently with respect to local/nonlocal components.
    For example, an LM may be seen as a local (read: tractable) component of a conditional model,
    and as a nonlocal (read: intractable) component of a joint model.
    An implementation detail: bi-parsing is implemented as a cascade of intersections (with projections in between).

    :param seg: a segment
    :param options: parsing options
    :param joint_model: a factorised view of the joint model; here we use only the local components
    :param conditional_model: a factorised view of the conditional model; here we use only the local components
    :param workingdir: where to save files
    :param redo: whether or not previously saved computation should be discarded
    :param log: a logging function
    :return: result.{joint,conditional}.{forest,components} for the respective local model
    """

    if workingdir:
        saving = preprocessed_training_files('{0}/{1}'.format(workingdir, seg.id))
    else:
        saving = {}

    result = SimpleNamespace()
    result.joint = SimpleNamespace()
    result.conditional = SimpleNamespace()

    if conditional_model is None:
        steps = ['joint.forest', 'joint.components']
        if all(is_step_complete(step, saving, redo) for step in steps):
            log('[%d] Reusing joint and conditional distributions from files', seg.id)
            result.joint.forest = unpickle_it(saving['joint.forest'])
            result.joint.components = unpickle_it(saving['joint.components'])
            result.conditional.forest = None
            result.conditional.components = []
            return result

    steps = ['joint.forest', 'joint.components', 'conditional.forest', 'conditional.components']
    if all(is_step_complete(step, saving, redo) for step in steps):
        log('[%d] Reusing joint and conditional distributions from files', seg.id)
        result.joint.forest = unpickle_it(saving['joint.forest'])
        result.joint.components = unpickle_it(saving['joint.components'])
        result.conditional.forest = unpickle_it(saving['conditional.forest'])
        result.conditional.components = unpickle_it(saving['conditional.components'])
        return result

    # 1. Make a grammar

    # here we need to decode for sure
    log('[%d] Make hypergraph view of all available grammars', seg.id)
    # make a hypergraph view of all available grammars
    grammar = make_grammar_hypergraph(seg,
                                      extra_grammar_paths=options.extra_grammars,
                                      glue_grammar_paths=options.glue_grammars,
                                      pass_through=options.pass_through,
                                      default_symbol=options.default_symbol)
    #print('GRAMMAR')
    #print(grammar)

    # 2. Joint distribution - Step 1: parse source lattice
    n_goal = 0
    log('[%d] Parse source DFA', seg.id)
    goal_maker = GoalRuleMaker(goal_str=options.goal, start_str=options.start, n=n_goal)
    src_dfa = make_input_dfa(seg)
    src_forest = parse_dfa(grammar,
                           grammar.fetch(Nonterminal(options.start)),
                           src_dfa,
                           goal_maker.get_iview(),
                           bottomup=True,
                           constraint=HieroConstraints(grammar, src_dfa, options.max_span))
    #print('SOURCE')
    #print(src_forest)

    if not src_forest:
        raise ValueError('I cannot parse the input lattice: i) make sure your grammar has glue rules; ii) make sure it handles OOVs')

    # 3. Target projection of the forest
    log('[%d] Project target rules', seg.id)
    tgt_forest = make_target_forest(src_forest)
    #print('TARGET')
    #print(tgt_forest)

    # 4. Joint distribution - Step 2: scoring

    log('[%d] Joint model: (exact) local scoring', seg.id)
    result.joint = exact_rescoring(joint_model.local_model(), tgt_forest, goal_maker, log)

    # save joint distribution
    if 'joint.forest' in saving:
        pickle_it(saving['joint.forest'], result.joint.forest)
    if 'joint.components' in saving:
        pickle_it(saving['joint.components'], result.joint.components)

    if conditional_model is None:
        result.conditional.forest = None
        result.conditional.components = []
        return result

    # 5. Conditional distribution - Step 1: parse the reference lattice

    log('[%d] Parse reference DFA', seg.id)
    ref_dfa = make_reference_dfa(seg)
    goal_maker.update()
    ref_forest = parse_dfa(result.joint.forest,
                           0,
                           ref_dfa,
                           goal_maker.get_oview(),
                           bottomup=False)

    if not ref_forest:  # reference cannot be parsed
        log('[%d] References cannot be parsed', seg.id)
        result.conditional.forest = ref_forest
        result.conditional.components = []
    else:
        # 6. Conditional distribution - Step 2: scoring
        log('[%d] Conditional model: exact (local) scoring', seg.id)
        result.conditional = exact_rescoring(conditional_model.local_model(), ref_forest, goal_maker, log)

    # save conditional distribution
    if 'conditional.forest' in saving:
        pickle_it(saving['conditional.forest'], result.conditional.forest)
    if 'conditional.components' in saving:
        pickle_it(saving['conditional.components'], result.conditional.components)

    return result
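
A call sketch (workingdir is illustrative; passing conditional_model=None skips bi-parsing altogether):

result = biparse(seg, options, joint_model, conditional_model=None,
                 workingdir='out', redo=False, log=logging.info)
joint_forest = result.joint.forest
joint_comps = result.joint.components
# result.conditional.forest is None and result.conditional.components is [] in this case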
Example #8
def pass0_to_pass2(seg, options, lookup, stateless, stateful, saving={}, redo=True, log=dummyfunc) -> 'tuple':
    """
    Run passes 0 through 2; pass2 consists of exactly rescoring the forest.
    :return: rescored forest (a Hypergraph), and components (one FComponents object per edge)
    """

    # We try to reuse previous results
    if is_step_complete('pass2.forest', saving, redo) and is_step_complete('pass2.components', saving, redo):
        forest = unpickle_it(saving['pass2.forest'])
        components = unpickle_it(saving['pass2.components'])
        return forest, components

    # We check whether we need pass2
    if not stateful:  # execute passes 0 to 1 only
        forest, components = pass0_to_pass1(seg,
                                            options,
                                            lookup,
                                            stateless,
                                            saving,
                                            redo=redo,
                                            log=log)

        # TODO: complete components with empty stateful model
        # save (or link) forest
        if 'pass2.forest' in saving:
            if 'pass1.forest' in saving:
                symlink(saving['pass1.forest'], saving['pass2.forest'])
            else:
                pickle_it(saving['pass2.forest'], forest)
        # save (or link) components
        if 'pass2.components' in saving:
            if 'pass1.components' in saving:
                symlink(saving['pass1.components'], saving['pass2.components'])
            else:
                pickle_it(saving['pass2.components'], components)
        return forest, components

    # From here we are sure we have stateful scorers
    # then we first execute passes 0 to 1 (and discard dummy components)
    forest, _ = pass0_to_pass1(seg,
                               options,
                               TableLookupScorer(DummyModel()),
                               StatelessScorer(DummyModel()),
                               saving,
                               redo=redo,
                               log=log)

    # then we fully re-score the forest (keeping all components)
    log('[%d] Forest rescoring', seg.id)
    goal_maker = GoalRuleMaker(goal_str=options.goal, start_str=options.start, n=1)
    forest, components = rescore_forest(forest,
                                        0,
                                        TableLookupScorer(lookup),
                                        StatelessScorer(stateless),
                                        StatefulScorer(stateful),
                                        goal_rule=goal_maker.get_oview(),
                                        keep_components=True)
    # save the forest
    if 'pass2.forest' in saving:
        pickle_it(saving['pass2.forest'], forest)
    # save the components
    if 'pass2.components' in saving:
        pickle_it(saving['pass2.components'], components)

    return forest, components
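
A call sketch (the saving paths are illustrative; the scorers are built inside the function, so raw model parts are passed in):

saving = {'pass2.forest': 'out/{0}.pass2.forest'.format(seg.id),
          'pass2.components': 'out/{0}.pass2.components'.format(seg.id)}
forest, components = pass0_to_pass2(seg, options,
                                    model.lookup, model.stateless, model.stateful,
                                    saving=saving, redo=False, log=logging.info)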
Example #9
def slice_sample(seg, args, staticdir, supportdir, workspace, model):
    files = [
        '{0}/{1}.D.ffs.all'.format(supportdir, seg.id),
        '{0}/{1}.hyp.ffs.all'.format(workspace, seg.id)
    ]

    if all(os.path.exists(path) for path in files) and not args.redo:
        logging.info('Reusing samples for segment %d', seg.id)
        return

    # 1. Load pickled objects
    logging.debug('[%d] Loading target forest', seg.id)
    forest = unpickle_it('{0}/{1}.hyp.forest'.format(staticdir, seg.id))
    # TODO: store top sort table
    logging.debug('[%d] Loading local components', seg.id)
    lookupffs = unpickle_it('{0}/{1}.hyp.ffs.rule'.format(staticdir, seg.id))
    statelessffs = unpickle_it('{0}/{1}.hyp.ffs.stateless'.format(
        staticdir, seg.id))

    # 2. Compute l(d)
    # there is a guarantee that lookup components and stateless components were computed over the same forest
    # that is, with the same nodes/edges structure
    # this is crucial to compute l(d) as below
    logging.debug('[%d] Computing l(d)', seg.id)
    lfunc = TableLookupFunction(
        np.array([semiring.inside.times(model.lookup.score(ff1),
                                        model.stateless.score(ff2))
                  for ff1, ff2 in zip(lookupffs, statelessffs)],
                 dtype=ptypes.weight))

    # 3. Sample from f(d) = n(d) * l(d)
    logging.debug('[%d] Sampling from f(d) = n(d) * l(d)', seg.id)
    tsort = AcyclicTopSortTable(forest)
    goal_maker = GoalRuleMaker(args.goal, args.start, n=2)

    sampler = SlicedRescoring(forest, lfunc, tsort,
                              TableLookupScorer(model.dummy),
                              StatelessScorer(model.dummy),
                              StatefulScorer(model.stateful), semiring.inside,
                              goal_maker.get_oview(),
                              OutputView(make_dead_srule()))

    # here samples are represented as sequences of edge ids
    d0, markov_chain = sampler.sample(n_samples=args.samples[0],
                                      batch_size=args.batch,
                                      within=args.within,
                                      initial=args.initial,
                                      prior=args.prior,
                                      burn=args.burn,
                                      lag=args.lag,
                                      temperature0=args.temperature0)

    # save empirical support
    pickle_it(
        '{0}/{1}.D.ffs.all'.format(supportdir, seg.id),
        get_empirical_support(model, frozenset(seg.refs), forest, lookupffs,
                              statelessffs, markov_chain))

    # apply usual MCMC filters to the Markov chain
    samples = apply_filters(markov_chain, burn=args.burn, lag=args.lag)

    n_samples = len(samples)

    # 4. Complete feature vectors and compute expectation
    hypcomps = []
    hypexp = model.constant(semiring.prob.zero)
    d_groups = group_by_identity(samples)
    for d_group in d_groups:
        derivation = d_group.key
        # reconstruct components
        lookup_comps = model.lookup.constant(semiring.inside.one)
        stateless_comps = model.stateless.constant(semiring.inside.one)
        for e in derivation.edges:
            lookup_comps = lookup_comps.hadamard(lookupffs[e],
                                                 semiring.inside.times)
            stateless_comps = stateless_comps.hadamard(statelessffs[e],
                                                       semiring.inside.times)
        # complete components (lookup, stateless, stateful)
        # note that here we are updating derivation.components!
        derivation.components = FComponents(
            [lookup_comps, stateless_comps, derivation.components])
        # incorporate sample frequency
        hypcomps.append(
            derivation.components.power(
                float(d_group.count) / n_samples, semiring.inside))
        hypexp = hypexp.hadamard(hypcomps[-1], semiring.prob.plus)

    # save feature vectors
    pickle_it('{0}/{1}.hyp.ffs.all'.format(workspace, seg.id), hypcomps)

    # 5. Log stuff
    if args.save_d:
        save_mcmc_derivations(
            '{0}/{1}.hyp.d.gz'.format(workspace, seg.id),
            d_groups,
            valuefunc=lambda d: d.score,
            compfunc=lambda d: d.components,
            derivation2str=lambda d: bracketed_string(forest, d.edges))

    if args.save_y:
        projections = group_by_projection(
            samples, lambda d: yield_string(forest, d.edges))
        save_mcmc_yields('{0}/{1}.hyp.y.gz'.format(workspace, seg.id),
                         projections)

    if args.save_chain:
        markov_chain.appendleft(d0)
        save_markov_chain(
            '{0}/{1}.hyp.chain.gz'.format(workspace, seg.id),
            markov_chain,
            flat=True,
            valuefunc=lambda d: d.score,
            #compfunc=lambda d: d.components,  # TODO: complete feature vectors of all derivations
            derivation2str=lambda d: bracketed_string(forest, d.edges))
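
A call sketch (directory names are illustrative; results are pickled to disk rather than returned):

slice_sample(seg, args,
             staticdir='static',    # where pass1 forests and local components live
             supportdir='support',  # where the empirical support is written
             workspace='work',      # where completed feature vectors and samples go
             model=model)
hypcomps = unpickle_it('work/{0}.hyp.ffs.all'.format(seg.id))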