Ejemplo n.º 1
0
  def _compute_alpha(self, particleWeights, finalIndex):
    """Return the log-space gap between two leave-one-out weight totals.

    The result is ``logaddexp`` over every weight except the last entry,
    minus ``logaddexp`` over every weight except the one at ``finalIndex``.

    The selected weight is dropped from a copy of the list rather than
    subtracted in log space, because that subtraction is prone to
    catastrophic cancellation (same reasoning as
    PGibbsOperator._compute_alpha).
    """
    # Shallow copy so the caller's list is left untouched by the pop.
    remaining = copy.copy(particleWeights)
    remaining.pop(finalIndex)

    logTotalWithoutXi = logaddexp(remaining)
    # Presumably the last entry is the rho weight (judging by the original
    # variable name) -- confirm against the caller.
    logTotalWithoutRho = logaddexp(particleWeights[:-1])
    return logTotalWithoutRho - logTotalWithoutXi
Ejemplo n.º 2
0
  def _compute_alpha(self, rhoWeight, xiWeights, finalIndex):
    """Return the log-space gap between two leave-one-out weight totals.

    The result is ``logaddexp`` over all xi weights, minus ``logaddexp``
    over the rho weight together with every xi weight except the one at
    ``finalIndex``.

    The chosen xi weight is removed from a copy of the list instead of
    being subtracted in log space: if that weight were much larger than
    the other xi weights and the rho weight, the subtraction would suffer
    catastrophic cancellation.
    """
    # Shallow copy so the caller's list is left untouched by pop/append.
    competitors = copy.copy(xiWeights)
    competitors.pop(finalIndex)
    competitors.append(rhoWeight)

    logTotalWithoutXi = logaddexp(competitors)
    logTotalWithoutRho = logaddexp(xiWeights)
    return logTotalWithoutRho - logTotalWithoutXi
Ejemplo n.º 3
0
def dp_inside_batch(batch_size, sentence_len, tags_dim, weights):
    """Fill the batched inside chart in log space and return it with the root score.

    Args:
        batch_size: size of the leading dimension of the chart.
        sentence_len: sentence length; the chart reserves
            ``sentence_len * sentence_len * 8`` span slots.
        tags_dim: size of each of the two trailing chart dimensions.
        weights: log-space arc scores, indexed here as
            ``weights[:, l, :, r, :]`` and, after ``permute(0, 1, 4, 3, 2)``,
            as ``[:, r, :, l, :]``.
            # assumes a 5-D tensor laid out (batch, pos, tag, pos, tag) -- TODO confirm

    Returns:
        ``(chart, ll)``: the filled inside chart of shape
        ``(batch_size, sentence_len * sentence_len * 8, tags_dim, tags_dim)``
        and the combination (``utils.logaddexp``) of the two full-sentence
        spans' scores read from row 0 of the first tag dimension.
    """
    chart = torch.DoubleTensor(batch_size, sentence_len * sentence_len * 8, tags_dim, tags_dim)
    chart.fill_(-np.inf)
    if torch.cuda.is_available():
        chart = chart.cuda()

    (seed_spans, base_left_spans, base_right_spans, left_spans, right_spans,
     ijss, ikss, kjss, id_span_map, span_id_map) = test_constituent_indexes(sentence_len, False)

    # Seed spans start at log(1) = 0.
    for span_id in seed_spans:
        chart[:, span_id, :, :] = 0.0

    # Base spans are initialised directly from the arc scores; the right
    # variant reads them with the two tag axes swapped.
    for span_id in base_right_spans:
        left, right, _ = id_span_map[span_id]
        chart[:, span_id, :, :] = weights.permute(0, 1, 4, 3, 2)[:, right, :, left, :]

    for span_id in base_left_spans:
        left, right, _ = id_span_map[span_id]
        chart[:, span_id, :, :] = weights[:, left, :, right, :]

    for span_id in ijss:
        left, right, _ = id_span_map[span_id]
        is_left = span_id in left_spans
        if is_left or span_id in right_spans:
            # Arc-attaching spans extend the (0, 0, 0)-state span over the
            # same boundaries by one arc score.  A missing key falls back to
            # index -1, i.e. the last chart slot.
            source_id = span_id_map.get((left, right, get_state_code(0, 0, 0)), -1)
            if is_left:
                arc_scores = weights[:, left, :, right, :]
            else:
                arc_scores = weights.permute(0, 1, 4, 3, 2)[:, right, :, left, :]
            update = chart[:, source_id, :, :] + arc_scores
        else:
            # Combine every (ik, kj) sub-span pair: broadcast to
            # (batch, k, tag, tag, tag), then reduce over axes 1 and 3.
            left_ids, right_ids = ikss[span_id], kjss[span_id]
            split_count = len(left_ids)
            left_part = chart[:, left_ids, :, :].contiguous().view(
                batch_size, split_count, tags_dim, tags_dim, 1)
            right_part = chart[:, right_ids, :, :].contiguous().view(
                batch_size, split_count, 1, tags_dim, tags_dim)
            update = utils.logsumexp(left_part + right_part, axis=(1, 3))
        chart[:, span_id, :, :] = utils.logaddexp(chart[:, span_id, :, :], update)

    # The final score combines the two full-sentence spans, reading only
    # row 0 of the first tag dimension.
    full_id_a = span_id_map.get((0, sentence_len - 1, get_state_code(0, 1, 0)), -1)
    full_id_b = span_id_map.get((0, sentence_len - 1, get_state_code(0, 1, 1)), -1)

    full_scores_a = chart[:, full_id_a, 0, :].contiguous().view(batch_size, 1, tags_dim)
    full_scores_b = chart[:, full_id_b, 0, :].contiguous().view(batch_size, 1, tags_dim)
    ll = utils.logaddexp(utils.logsumexp(full_scores_a, axis=2),
                         utils.logsumexp(full_scores_b, axis=2))
    return chart, ll
Ejemplo n.º 4
0
  def propose(self, trace, scaffold):
    """Propose new target values using this operator's proposal program.

    Steps, in order:
      1. Collect the target nodes (``program.tlist``) and conditioned nodes
         (``program.clist``) from ``trace`` and read their current numbers.
      2. Pin the old target values with deterministic LKernels and call
         ``self.prepare`` to obtain ``rhoWeight``.
      3. Run the generated conditioned / latent / target sources on
         ``program.ripl`` to obtain ``new_target``.
      4. Compute ``qratio``: the constant ``1e10`` when ``self.method`` is
         ``'assumed_gibbs'``; otherwise ``logaddexp(old_logscores) -
         logaddexp(new_logscores)`` gathered over ``self.mc_samples`` rounds.
      5. Forget every directive that was recorded, pin the new target
         values, and regenerate to obtain ``xiWeight``.

    Returns:
      ``(trace, xiWeight + qratio - rhoWeight)``.

    Raises:
      Exception: if the number of target nodes differs from
        ``program.n_target``, or if ``program.gen_target_src`` yields
        nothing.
    """
    program = trace.proposal_programs[self.program_name]
    pnodes = [node for tar in program.tlist for node in trace.getNodesInBlock(tar[0], tar[1])]
    cnodes = [node for cond in program.clist for node in trace.scopes[cond[0]][cond[1]]]
    # Explicit check rather than `assert`: assertions are stripped under
    # `python -O`, which would silently disable this validation.
    if len(pnodes) != program.n_target:
      raise Exception('Expect to have one and only one random node in a block. Check if 1) multiple nodes are defined in a target block; 2) some target nodes are observed (and thus not random).')

    old_target = [node.value.number for node in pnodes]
    conditioned = [node.value.number for node in cnodes]
    # Concrete lists (not `map`) so the values survive repeated iteration on
    # Python 3 as well as Python 2.
    registerDeterministicLKernels(trace, scaffold, pnodes, [VentureNumber(x) for x in old_target])
    rhoWeight = self.prepare(trace, scaffold)

    def forget_all(labels):
      # Forget recorded directives in reverse order of creation.
      for label in reversed(labels):
        program.ripl.forget(label)

    conditioned_labels, latent_labels = [], []
    conditioned_src = program.gen_conditioned_src(conditioned)
    if conditioned_src:
      conditioned_labels, _, _ = execute_and_record(program.ripl, conditioned_src, "conditioned")
    latent_src = program.gen_latent_src(conditioned)
    if latent_src:
      latent_labels, _, _ = execute_and_record(program.ripl, latent_src, "latent")
    target_src = program.gen_target_src(conditioned)
    if not target_src:
      raise Exception("No target_src found. Procedure 'gen_target_src' must be specified.")
    target_labels, target_strings, new_target = execute_and_record(program.ripl, target_src, "target")
    forget_all(target_labels)

    if self.method == 'assumed_gibbs':
      qratio = 1e10
    else:
      old_logscores, new_logscores = [], []
      # NOTE(review): this permanently overwrites the operator's configured
      # sample count once a proposal has no latent source -- confirm that
      # the sticky change is intended.
      if not latent_src:
        self.mc_samples = 1
      for _ in range(self.mc_samples):
        # Score the new and then the old target values under the current
        # latent sample, forgetting the observations after each scoring.
        predict_to_observe(program.ripl, target_strings, new_target)
        new_logscores.append(sum_directive_logscore(program.ripl, target_labels))
        forget_all(target_labels)
        predict_to_observe(program.ripl, target_strings, old_target)
        old_logscores.append(sum_directive_logscore(program.ripl, target_labels))
        forget_all(target_labels)
        # Re-run the latent source for the next round.
        # NOTE(review): the labels returned here are discarded, so this
        # presumably relies on execute_and_record producing the same labels
        # each time; it is also called when latent_src is empty -- confirm.
        forget_all(latent_labels)
        execute_and_record(program.ripl, latent_src, "latent")
      old_logscore = logaddexp(old_logscores)
      new_logscore = logaddexp(new_logscores)
      qratio = old_logscore - new_logscore

    forget_all(latent_labels)
    forget_all(conditioned_labels)

    # Pin the proposed values and regenerate to get the new trace weight.
    registerDeterministicLKernels(trace, scaffold, pnodes, [VentureNumber(x) for x in new_target])
    xiWeight = regenAndAttach(trace, scaffold.border[0], scaffold, False, OmegaDB(), {})
    return (trace, xiWeight + qratio - rhoWeight)
Ejemplo n.º 5
0
    def batch_outside(self, inside_table, crf_score):
        """Compute batched outside scores for complete and incomplete spans.

        All arithmetic is done in log space: tables start at ``LOGZERO`` and
        contributions are combined with ``utils.logsumexp`` /
        ``utils.logaddexp``.

        Args:
            inside_table: indexable pair; element 0 is used as the complete
                inside table and element 1 as the incomplete inside table.
            crf_score: score tensor, read here as ``crf_score[:, r, l, :, :]``
                or ``crf_score[:, l, r, :, :]`` depending on the span
                direction.
                # assumes (batch, position, position, tag, tag) -- TODO confirm

        Returns:
            ``(outside_complete_table, outside_incomplete_table)`` with
            shapes ``(batch_size, sentence_length**2 * 2, tag_num)`` and
            ``(batch_size, sentence_length**2 * 2, tag_num, tag_num)``.
        """
        inside_complete_table = inside_table[0]
        inside_incomplete_table = inside_table[1]
        outside_complete_table = torch.DoubleTensor(
            self.batch_size, self.sentence_length * self.sentence_length * 2,
            self.tag_num)
        outside_incomplete_table = torch.DoubleTensor(
            self.batch_size, self.sentence_length * self.sentence_length * 2,
            self.tag_num, self.tag_num)
        # The tables follow CUDA availability, not the device of the inputs.
        if torch.cuda.is_available():
            outside_complete_table = outside_complete_table.cuda()
            outside_incomplete_table = outside_incomplete_table.cuda()
        span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(
            self.sentence_length, False)
        outside_complete_table.fill_(LOGZERO)
        outside_incomplete_table.fill_(LOGZERO)

        # Seed: the span (0, sentence_length - 1, 1) gets outside score
        # log(1) = 0, for tag index 0 only.
        root_id = span_2_id.get((0, self.sentence_length - 1, 1))
        outside_complete_table[:, root_id, 0] = 0.0

        # Spans that have already received a contribution.  The first
        # contribution is assigned; later ones are accumulated via logaddexp.
        complete_span_used = set()
        incomplete_span_used = set()
        complete_span_used.add(root_id)

        # Walk the spans in the reverse of the ijss order, pushing each
        # span's outside score down to its sub-spans.
        for ij in reversed(ijss):
            (l, r, dir) = id_2_span[ij]
            # complete span consists of one incomplete span and one complete span
            num_kc = len(ikcs[ij])
            # NOTE(review): dir == 0 vs. otherwise presumably distinguishes
            # the two head directions -- confirm against utils.constituent_index.
            if dir == 0:
                # Here the ik sub-span is read from the complete table and
                # the kj sub-span from the incomplete table.
                outside_ij_cc = outside_complete_table[:, ij, :].contiguous(
                ).view(self.batch_size, 1, 1, self.tag_num)
                inside_kj_ic = inside_incomplete_table[:, kjcs[
                    ij], :, :].contiguous().view(self.batch_size, num_kc,
                                                 self.tag_num, self.tag_num)
                inside_ik_cc = inside_complete_table[:,
                                                     ikcs[ij], :].contiguous(
                                                     ).view(
                                                         self.batch_size,
                                                         num_kc, self.tag_num,
                                                         1)
                outside_ik_cc = (outside_ij_cc + inside_kj_ic).permute(
                    0, 1, 3, 2)
                # swap left-right position since right tags are to be indexed
                outside_kj_ic = outside_ij_cc + inside_ik_cc
                for i in range(num_kc):
                    ik = ikcs[ij][i]
                    kj = kjcs[ij][i]
                    # Reduce the (batch, tag, tag) slice over axis 1 to get
                    # one value per remaining tag.
                    outside_ik_cc_i = utils.logsumexp(outside_ik_cc[:,
                                                                    i, :, :],
                                                      axis=1)
                    if ik in complete_span_used:
                        outside_complete_table[:, ik, :] = utils.logaddexp(
                            outside_complete_table[:, ik, :], outside_ik_cc_i)
                    else:
                        outside_complete_table[:,
                                               ik, :] = outside_ik_cc_i.clone(
                                               )
                        complete_span_used.add(ik)

                    if kj in incomplete_span_used:
                        outside_incomplete_table[:, kj, :, :] = utils.logaddexp(
                            outside_incomplete_table[:, kj, :, :],
                            outside_kj_ic[:, i, :, :])
                    else:
                        outside_incomplete_table[:,
                                                 kj, :, :] = outside_kj_ic[:,
                                                                           i, :, :]
                        incomplete_span_used.add(kj)
            else:
                # Mirror case: the ik sub-span is read from the incomplete
                # table and the kj sub-span from the complete table.
                outside_ij_cc = outside_complete_table[:, ij, :].contiguous(
                ).view(self.batch_size, 1, self.tag_num, 1)
                inside_ik_ic = inside_incomplete_table[:, ikcs[
                    ij], :, :].contiguous().view(self.batch_size, num_kc,
                                                 self.tag_num, self.tag_num)
                inside_kj_cc = inside_complete_table[:,
                                                     kjcs[ij], :].contiguous(
                                                     ).view(
                                                         self.batch_size,
                                                         num_kc, 1,
                                                         self.tag_num)
                outside_kj_cc = outside_ij_cc + inside_ik_ic
                outside_ik_ic = outside_ij_cc + inside_kj_cc
                for i in range(num_kc):
                    kj = kjcs[ij][i]
                    ik = ikcs[ij][i]
                    # Reduce the (batch, tag, tag) slice over axis 1.
                    outside_kj_cc_i = utils.logsumexp(outside_kj_cc[:,
                                                                    i, :, :],
                                                      axis=1)
                    if kj in complete_span_used:
                        outside_complete_table[:, kj, :] = utils.logaddexp(
                            outside_complete_table[:, kj, :], outside_kj_cc_i)
                    else:
                        outside_complete_table[:,
                                               kj, :] = outside_kj_cc_i.clone(
                                               )
                        complete_span_used.add(kj)

                    if ik in incomplete_span_used:
                        outside_incomplete_table[:, ik, :, :] = utils.logaddexp(
                            outside_incomplete_table[:, ik, :, :],
                            outside_ik_ic[:, i, :, :])
                    else:
                        outside_incomplete_table[:,
                                                 ik, :, :] = outside_ik_ic[:,
                                                                           i, :, :]
                        incomplete_span_used.add(ik)

            # incomplete span consists of two complete spans
            num_ki = len(ikis[ij])

            outside_ij_ii = outside_incomplete_table[:, ij, :, :].contiguous(
            ).view(self.batch_size, 1, self.tag_num, self.tag_num)
            inside_ik_ci = inside_complete_table[:, ikis[ij], :].contiguous(
            ).view(self.batch_size, num_ki, self.tag_num, 1)
            inside_kj_ci = inside_complete_table[:, kjis[ij], :].contiguous(
            ).view(self.batch_size, num_ki, 1, self.tag_num)

            # Each sub-span receives: outside of ij + inside of its sibling
            # + the arc score.  For dir == 0 the score is read at (r, l)
            # with its two tag axes swapped; otherwise at (l, r) as stored.
            if dir == 0:
                outside_ik_ci = outside_ij_ii + inside_kj_ci + crf_score[:, r, l, :, :]. \
                    permute(0, 2, 1).contiguous().view(self.batch_size, 1, self.tag_num, self.tag_num)

                outside_kj_ci = outside_ij_ii + inside_ik_ci + crf_score[:, r, l, :, :]. \
                    permute(0, 2, 1).contiguous().view(self.batch_size, 1, self.tag_num, self.tag_num)
            else:
                outside_ik_ci = outside_ij_ii + inside_kj_ci + crf_score[:, l, r, :, :].contiguous(
                ).view(self.batch_size, 1, self.tag_num, self.tag_num)
                outside_kj_ci = outside_ij_ii + inside_ik_ci + crf_score[:, l, r, :, :].contiguous(
                ).view(self.batch_size, 1, self.tag_num, self.tag_num)

            for i in range(num_ki):
                ik = ikis[ij][i]
                kj = kjis[ij][i]

                # ik keeps the first tag axis (reduce axis 2); kj keeps the
                # second tag axis (reduce axis 1).
                outside_ik_ci_i = utils.logsumexp(outside_ik_ci[:, i, :, :],
                                                  axis=2)
                outside_kj_ci_i = utils.logsumexp(outside_kj_ci[:, i, :, :],
                                                  axis=1)
                if ik in complete_span_used:
                    outside_complete_table[:, ik, :] = utils.logaddexp(
                        outside_complete_table[:, ik, :], outside_ik_ci_i)
                else:
                    outside_complete_table[:, ik, :] = outside_ik_ci_i.clone()
                    complete_span_used.add(ik)
                if kj in complete_span_used:
                    outside_complete_table[:, kj, :] = utils.logaddexp(
                        outside_complete_table[:, kj, :], outside_kj_ci_i)
                else:
                    outside_complete_table[:, kj, :] = outside_kj_ci_i.clone()
                    complete_span_used.add(kj)

        return (outside_complete_table, outside_incomplete_table)
Ejemplo n.º 6
0
    def dp_outside_batch(self, inside_table, weights):
        """Compute the batched outside chart from the inside chart.

        All arithmetic is done in log space: the chart starts at ``LOGZERO``
        and contributions are combined with ``utils.logsumexp`` /
        ``utils.logaddexp``.

        Args:
            inside_table: inside chart, indexed here as
                ``inside_table[:, span_ids, :, :]``.
                # assumes shape (batch_size, sentence_len**2 * 8, tags_dim, tags_dim),
                # matching outside_table -- TODO confirm
            weights: arc scores, read as ``weights[:, l, :, r, :]`` and,
                after ``permute(0, 1, 4, 3, 2)``, as ``[:, r, :, l, :]``.

        Returns:
            The outside chart, shape
            ``(batch_size, sentence_len**2 * 8, tags_dim, tags_dim)``.
        """
        outside_table = torch.DoubleTensor(self.batch_size, self.sentence_len * self.sentence_len * 8, self.tags_dim, self.tags_dim)
        outside_table.fill_(LOGZERO)
        # The chart follows CUDA availability, not the device of the inputs.
        if torch.cuda.is_available():
            outside_table = outside_table.cuda()
        m = self.sentence_len
        seed_spans, base_left_spans, base_right_spans, left_spans, right_spans, ijss, ikss, kjss, id_span_map, span_id_map = utils.constituent_indexes(
            m, self.is_multi_root)
        # Seed: both full-sentence spans get outside score log(1) = 0.
        # A missing key falls back to index -1, i.e. the last chart slot.
        id1 = span_id_map.get((0, m - 1, utils.get_state_code(0, 1, 0)), -1)
        id2 = span_id_map.get((0, m - 1, utils.get_state_code(0, 1, 1)), -1)
        outside_table[:, id1, :, :] = 0.0
        outside_table[:, id2, :, :] = 0.0

        # Walk the spans in the reverse of the ijss order, pushing each
        # span's outside score down to the spans it is built from.
        for ij in reversed(ijss):
            (l, r, c) = id_span_map[ij]
            if ij in left_spans:
                # Arc-attaching span: pass outside score + arc score to the
                # (0, 0, 0)-state span over the same boundaries.
                assert c == utils.get_state_code(0, 1, 1)
                prob = outside_table[:, ij, :, :] + weights[:, l, :, r, :]
                ids = span_id_map.get((l, r, utils.get_state_code(0, 0, 0)), -1)
                outside_table[:, ids, :, :] = utils.logaddexp(outside_table[:, ids, :, :], prob)
            elif ij in right_spans:
                # Same as above, with the arc score read with its two tag
                # axes swapped.
                assert c == utils.get_state_code(1, 0, 1)
                swap_weights = weights.permute(0, 1, 4, 3, 2)
                prob = outside_table[:, ij, :, :] + swap_weights[:, r, :, l, :]
                ids = span_id_map.get((l, r, utils.get_state_code(0, 0, 0)), -1)
                outside_table[:, ids, :, :] = utils.logaddexp(outside_table[:, ids, :, :], prob)
            else:
                # Span built from (ik, kj) sub-span pairs: each sub-span
                # receives outside(ij) + inside(sibling).
                num_k = len(ikss[ij])
                if l == 0:
                    # ROOT's value can only be 0
                    alpha_ij = outside_table[:, ij, 0, :].contiguous().view(self.batch_size, 1, 1, 1, self.tags_dim)
                    beta_left = inside_table[:, ikss[ij], [0], :].contiguous().view(self.batch_size, num_k, 1, self.tags_dim, 1)
                    beta_right = inside_table[:, kjss[ij], :, :].contiguous().view(self.batch_size, num_k, 1, self.tags_dim, self.tags_dim)
                    new_left = alpha_ij + beta_right
                    new_left = utils.logsumexp(new_left, axis=4).contiguous().view(self.batch_size, num_k, 1, self.tags_dim, 1)
                    # No reduction is applied to new_right in this branch;
                    # it is written back with a size-1 third axis.
                    new_right = alpha_ij + beta_left
                    # Vectorized write to the left sub-spans only when their
                    # ids are unique; otherwise the per-k loop below does it.
                    if len(list(set(ikss[ij]))) == num_k:
                        outside_table[:, ikss[ij], [0], :] = utils.logaddexp(
                            outside_table[:, ikss[ij], [0], :].contiguous().view(self.batch_size, num_k, 1, self.tags_dim, 1), new_left).contiguous().view(self.batch_size, num_k, self.tags_dim)
                    # NOTE(review): no uniqueness check is made for kjss[ij];
                    # presumably those ids are always distinct -- confirm.
                    outside_table[:, kjss[ij], :, :] = utils.logaddexp(
                            outside_table[:, kjss[ij], :, :].contiguous().view(self.batch_size, num_k, 1, self.tags_dim, self.tags_dim),
                            new_right).contiguous().view(self.batch_size, num_k, self.tags_dim, self.tags_dim)
                else:
                    alpha_ij = outside_table[:, ij, :, :].contiguous().view(self.batch_size, 1, self.tags_dim, 1, self.tags_dim)
                    beta_left = inside_table[:, ikss[ij], :, :].contiguous().view(self.batch_size, num_k, self.tags_dim, self.tags_dim, 1)
                    beta_right = inside_table[:, kjss[ij], :, :].contiguous().view(self.batch_size, num_k, 1, self.tags_dim, self.tags_dim)
                    # Left sub-span: reduce over axis 4.
                    new_left = alpha_ij + beta_right
                    new_left = utils.logsumexp(new_left, axis=4).contiguous().view(self.batch_size, num_k, self.tags_dim, self.tags_dim, 1)
                    # Right sub-span: reduce over axis 2.
                    new_right = alpha_ij + beta_left
                    new_right = utils.logsumexp(new_right, axis=2).contiguous().view(self.batch_size, num_k, 1, self.tags_dim, self.tags_dim)
                    # Vectorized write to the left sub-spans only when their
                    # ids are unique; otherwise the per-k loop below does it.
                    if len(list(set(ikss[ij]))) == num_k:
                        outside_table[:, ikss[ij], :, :] = utils.logaddexp(
                            outside_table[:, ikss[ij], :, :].contiguous().view(self.batch_size, num_k, self.tags_dim, self.tags_dim, 1),
                            new_left).contiguous().view(self.batch_size, num_k, self.tags_dim, self.tags_dim)

                    # NOTE(review): no uniqueness check is made for kjss[ij];
                    # presumably those ids are always distinct -- confirm.
                    outside_table[:, kjss[ij], :, :] = utils.logaddexp(outside_table[:, kjss[ij], :, :].contiguous().view(self.batch_size, num_k, 1, self.tags_dim, self.tags_dim), new_right).contiguous().view(self.batch_size, num_k, self.tags_dim, self.tags_dim)

                if len(list(set(ikss[ij]))) == num_k:
                    # Already done in the above
                    pass
                else:
                    # TODO make this vectorized, the problem is the id in ikss is not unique
                    # Sequential fallback: update the left sub-spans one k
                    # at a time so repeated ids accumulate every contribution.
                    for i in range(num_k):
                        ik = ikss[ij][i]
                        kj = kjss[ij][i]
                        if l == 0:
                            alpha_ij = outside_table[:, ij, 0, :].contiguous().view(self.batch_size, 1, 1, 1, self.tags_dim)
                            beta_right = inside_table[:, kj, :, :].contiguous().view(self.batch_size, 1, 1, self.tags_dim, self.tags_dim)
                            new_left = alpha_ij + beta_right
                            new_left = utils.logsumexp(new_left, axis=4).contiguous().view(self.batch_size, 1, 1, self.tags_dim, 1)
                            outside_table[:, ik, 0, :] = utils.logaddexp(
                                outside_table[:, ik, 0, :].contiguous().view(self.batch_size, 1, 1, self.tags_dim, 1), new_left).contiguous().view(self.batch_size, self.tags_dim, )
                        else:
                            alpha_ij = outside_table[:, ij, :, :].contiguous().view(self.batch_size, 1, self.tags_dim, 1, self.tags_dim)
                            beta_right = inside_table[:, kj, :, :].contiguous().view(self.batch_size, 1, 1, self.tags_dim, self.tags_dim)
                            new_left = alpha_ij + beta_right
                            new_left = utils.logsumexp(new_left, axis=4).contiguous().view(self.batch_size, 1, self.tags_dim, self.tags_dim, 1)
                            outside_table[:, ik, :, :] = utils.logaddexp(outside_table[:, ik, :, :].contiguous().view(self.batch_size, 1, self.tags_dim, self.tags_dim, 1), new_left).contiguous().view(self.batch_size, self.tags_dim, self.tags_dim)

        # Finally push the base spans' outside scores (plus arc score) to
        # the (0, 0, 1)-state span over the same boundaries.
        for ij in base_left_spans:
            (l, r, c) = id_span_map[ij]
            prob = outside_table[:, ij, :, :] + weights[:, l, :, r, :]
            ids = span_id_map.get((l, r, utils.get_state_code(0, 0, 1)), -1)
            outside_table[:, ids, :, :] = utils.logaddexp(outside_table[:, ids, :, :], prob)

        for ij in base_right_spans:
            (l, r, c) = id_span_map[ij]
            swap_weights = weights.permute(0, 1, 4, 3, 2)
            prob = outside_table[:, ij, :, :] + swap_weights[:, r, :, l, :]
            ids = span_id_map.get((l, r, utils.get_state_code(0, 0, 1)), -1)
            outside_table[:, ids, :, :] = utils.logaddexp(outside_table[:, ids, :, :], prob)

        return outside_table