Exemple #1
0
    def step(self, player, home_away_race, upgrades, available_act_mask, minimap):
        """Sample an action type plus its arguments and accumulate logp(a|s).

        ``home_away_race`` and ``upgrades`` are accepted but not read here.

        Returns:
            Tuple ``(value, action_id, arg_spatial, arg_nonspatial, logp_a)``
            where the two ``arg_*`` lists hold the samples drawn for the
            arguments of the chosen action, in the order the action spec
            lists them.
        """
        heads = self.call(player, available_act_mask, minimap)

        # Draw the action type; the availability mask is handed to the sampler.
        action_id = categorical_sample(heads["action_id"], available_act_mask)
        fn_index = action_id.numpy().item()

        # Sanity check: the mask entry of the drawn action must be positive.
        tf.assert_greater(available_act_mask[:, fn_index], 0.0)

        logp_a = log_prob(action_id, heads["action_id"])

        # Every argument of the chosen action contributes one sample and one
        # log-probability term.
        arg_spatial, arg_nonspatial = [], []
        spatial_names = ("screen", "screen2", "minimap")
        for arg_type in self.action_spec.functions[fn_index].args:
            if arg_type.name in spatial_names:
                # All spatial arguments share the "target_location" head.
                head_key, bucket = "target_location", arg_spatial
            else:
                head_key, bucket = arg_type.name, arg_nonspatial
            drawn = categorical_sample(heads[head_key])
            bucket.append(drawn)
            logp_a += log_prob(drawn, heads[head_key])

        return heads["value"], action_id, arg_spatial, arg_nonspatial, logp_a
def get_words(expn, parent, lmk=None, rel=None):
    """Recursively expand the symbols of ``expn`` into sampled words.

    Each whitespace-separated symbol of ``expn`` is handled in turn:

    * a symbol in ``NONTERMINALS`` is expanded with ``get_expansion`` and the
      expansion is processed recursively; the recursive result (a list) is
      appended as a single nested entry of ``words``;
    * any other symbol is treated as a part of speech and a word is sampled
      from the counts of the matching ``Word`` rows.

    Returns:
        ``(words, p, H)`` -- the collected words, the product of the collected
        probabilities and the sum of the collected entropies.
    """
    words = []
    probs = []
    entropy = []

    for n in expn.split():
        if n in NONTERMINALS:
            if n == parent == 'LANDMARK-PHRASE':
                # we need to move to the parent landmark
                lmk = parent_landmark(lmk)
            # we need to keep expanding
            expansion, exp_prob, exp_ent = get_expansion(n, parent, lmk, rel)
            w, w_prob, w_ent = get_words(expansion, n, lmk, rel)
            words.append(w)
            probs.append(exp_prob * w_prob)
            entropy.append(exp_ent + w_ent)
        else:
            # get word for POS
            w_db = Word.get_words(pos=n, lmk=lmk_id(lmk), rel=rel_type(rel))
            counter = collections.Counter(w_db)
            keys, counts = zip(*counter.items())
            # Counter values are integers; normalising an integer array in
            # place either floors every probability or raises a casting
            # error, so the array has to be float before dividing.
            counts = np.array(counts, dtype=float)
            counts /= counts.sum()
            w, w_prob, w_entropy = categorical_sample(keys, counts)
            words.append(w.word)
            # NOTE(review): this records the row's own ``prob`` attribute, not
            # the sampled ``w_prob`` -- confirm which one is intended.
            probs.append(w.prob)
            entropy.append(w_entropy)
    p, H = np.prod(probs), np.sum(entropy)
    # Single-argument parenthesised form prints identically on Python 2 and 3.
    print('expanding %s to %s (p: %f, H: %f)' % (expn, words, p, H))
    return words, p, H
def get_words(terminals, landmarks, rel=None):
    """Sample one word for every ``(terminal, landmark)`` pair.

    Args:
        terminals: symbols to turn into words. Entries found in
            ``NONTERMINALS`` are passed through unchanged.
        landmarks: landmark (or ``None``) paired positionally with each
            terminal.
        rel: optional relation; its class and, when it has a ``measurement``
            attribute, its distance/degree classes are used in the lookup.

    Returns:
        ``(words, p, H)`` -- the chosen words, the product of their
        probabilities and the sum of their entropies.

    Side effects:
        When no counts are found for a terminal it is logged and appended to
        the caller's ``terminals`` list; nothing is added to ``words`` for it.
    """
    words = []
    probs = []
    entropy = []

    for n, lmk in zip(terminals, landmarks):
        # if we could not get an expansion for the LHS, we just pass down the unexpanded nonterminal symbol
        # it gets the probability of 1 and entropy of 0
        if n in NONTERMINALS:
            words.append(n)
            probs.append(1.0)
            entropy.append(0.0)
            continue

        lmk_class = (lmk.object_class if lmk else None)
        lmk_color = (lmk.color if lmk else None)
        rel_class = rel_type(rel)
        dist_class = (rel.measurement.best_distance_class if hasattr(rel, 'measurement') else None)
        deg_class = (rel.measurement.best_degree_class if hasattr(rel, 'measurement') else None)

        cp_db = CWord.get_word_counts(pos=n,
                                      lmk_class=lmk_class,
                                      lmk_ori_rels=get_lmk_ori_rels_str(lmk),
                                      lmk_color=lmk_color,
                                      rel=rel_class,
                                      rel_dist_class=dist_class,
                                      rel_deg_class=deg_class)

        if cp_db.count() <= 0:
            logger( 'Could not expand %s (lmk_class: %s, lmk_color: %s, rel: %s, dist_class: %s, deg_class: %s)' % (n, lmk_class, lmk_color, rel_class, dist_class, deg_class) )
            # NOTE(review): this mutates the caller's list while it is being
            # zipped over; kept as-is because callers may rely on it -- confirm.
            terminals.append( n )
            continue

        logger( 'Expanded %s (lmk_class: %s, lmk_color: %s, rel: %s, dist_class: %s, deg_class: %s)' % (n, lmk_class, lmk_color, rel_class, dist_class, deg_class) )

        # Merge the counts of rows that share a word.  The rows are fetched
        # once; an earlier extra fetch whose result was immediately
        # overwritten has been removed.
        ccounter = {}
        for cword in cp_db.all():
            ccounter[cword.word] = ccounter.get(cword.word, 0) + cword.count

        ckeys, ccounts = zip(*ccounter.items())

        # Normalise the merged counts into a probability vector.
        ccounts = np.array(ccounts, dtype=float)
        ccounts /= ccounts.sum()

        w, w_prob, w_entropy = categorical_sample(ckeys, ccounts)
        words.append(w)
        probs.append(w_prob)
        entropy.append(w_entropy)

    p, H = np.prod(probs), np.sum(entropy)
    return words, p, H
 def sample_next(self, a):
     """Draw one successor for action ``a`` and return its search node.

     Each entry of ``self.successors[a]`` is a 4-tuple whose first element
     is passed to ``categorical_sample`` as the weight of that entry.  A
     state already present in ``self.node_register`` yields the registered
     node; otherwise a new ``SearchNode`` is built for it.
     """
     outcomes = self.successors[a]
     picked = categorical_sample([outcome[0] for outcome in outcomes])
     _prob, state, _reward, terminal = outcomes[picked]
     if state in self.node_register:
         return self.node_register[state]
     return SearchNode(state, self.env, self.gamma, terminal,
                       self.node_register)
Exemple #5
0
 def choose_action(self, observations, influencer_action):
     """Choose an action for every influencee agent.

     Takes the observations and the influencers' actions as input.  Only
     the entries from index ``self.influencer_num`` onwards of the agents,
     the observations and ``influencer_action`` are used.

     Returns:
         ``(one-hot actions, logits)`` as two lists, one entry per
         influencee agent.
     """
     first = self.influencer_num
     influencee_onehot, influencee_logist = [], []
     triples = zip(self.agents[first:],
                   observations[first:],
                   influencer_action[first:])
     for agent, obs, inf_act in triples:
         # See A3CAgent in the network module for the details of this call.
         prob, logist = agent.choose_action(obs, inf_act)
         # The integer form of the sampled action is not needed here.
         _int_act, onehot = categorical_sample(prob)
         influencee_logist.append(logist)
         influencee_onehot.append(onehot)
     return influencee_onehot, influencee_logist
def get_expansion(lhs, parent=None, lmk=None, rel=None):
    """Sample one production for ``lhs`` and return its right-hand side.

    The productions returned by ``Production.get_productions`` for the given
    context are counted, the counts are normalised into probabilities and one
    production is drawn with ``categorical_sample``.

    Returns:
        ``(rhs, probability, entropy)`` for the sampled production.
    """
    p_db = Production.get_productions(lhs=lhs, parent=parent,
                                      lmk=lmk_id(lmk), rel=rel_type(rel))

    counter = collections.Counter(p_db)
    keys, counts = zip(*counter.items())
    # Counter values are integers; dividing an integer array in place either
    # floors every probability or raises a casting error, so use floats.
    counts = np.array(counts, dtype=float)
    counts /= counts.sum()

    prod, prod_prob, prod_entropy = categorical_sample(keys, counts)
    # Same space-separated output as before, in a form that is valid on both
    # Python 2 and Python 3.
    print('expanding: %s %s %s' % (prod, prod_prob, prod_entropy))
    return prod.rhs, prod_prob, prod_entropy
Exemple #7
0
 def choose_influencer_action(self, observations):
     """Choose an action for every influencer agent from its observation.

     Only the first ``self.influencer_num`` agents and observations are
     used; each observation is passed straight to the agent.

     Returns:
         Four lists, one entry per influencer:
         ``(one-hot actions, probabilities, integer actions, logits)``.
     """
     count = self.influencer_num
     onehots, probs, ints, logits = [], [], [], []
     for agent, obs in zip(self.agents[:count], observations[:count]):
         prob, logist = agent.choose_action(obs)
         # Helper taken over from MAAC: given a probability distribution it
         # returns the action as an integer and as a one-hot vector.
         int_act, onehot = categorical_sample(prob)
         logits.append(logist)
         probs.append(prob)
         onehots.append(onehot)
         ints.append(int_act)
     return onehots, probs, ints, logits
Exemple #8
0
    def sample_landmark(self, landmarks, trajector, usebest=False):
        '''Choose a landmark for ``trajector`` and orient it head-on.

        The per-landmark probabilities come from ``all_landmark_probs``.
        With ``usebest`` the index is chosen by ``index_max``; otherwise it
        is drawn with ``categorical_sample``.

        Returns ``(landmark, its probability, entropy of the probabilities,
        head-on viewpoint)``.
        '''
        weights = self.all_landmark_probs(landmarks, trajector)

        chooser = index_max if usebest else categorical_sample
        pick = chooser(weights)

        chosen = landmarks[pick]
        viewpoint = self.get_head_on_viewpoint(chosen)
        # Orientations are set on the chosen landmark before the entropy
        # is computed for the return value.
        self.set_orientations(chosen, viewpoint)

        return chosen, weights[pick], self.get_entropy(weights), viewpoint
Exemple #9
0
    def sample_landmark(self, landmarks, trajector, usebest=False):
        '''Select one of ``landmarks`` for ``trajector`` and set its orientations.

        Probabilities are obtained from ``all_landmark_probs``.  When
        ``usebest`` is true the index comes from ``index_max``; otherwise
        from ``categorical_sample``.

        Returns a 4-tuple: the selected landmark, its probability, the
        entropy of all landmark probabilities and the head-on viewpoint.
        '''
        lm_probs = self.all_landmark_probs(landmarks, trajector)

        if not usebest:
            idx = categorical_sample(lm_probs)
        else:
            idx = index_max(lm_probs)

        landmark = landmarks[idx]
        head_on = self.get_head_on_viewpoint(landmark)
        self.set_orientations(landmark, head_on)

        landmark_prob = lm_probs[idx]
        entropy = self.get_entropy(lm_probs)
        return landmark, landmark_prob, entropy, head_on
Exemple #10
0
    def sample_relation(self, trajector, bounding_box, perspective, landmark, step=0.02, usebest=False):
        """
        Sample a relation given a trajector and landmark.
        Evaluate each relation and probabilistically choose the one that is likely to
        generate the trajector given a landmark.

        With ``usebest`` the index is chosen by ``index_max``; otherwise it is
        drawn with ``categorical_sample``.

        Returns ``(relation class, its probability, entropy of all relation
        probabilities)``.
        """
        rel_probabilities, rel_classes = self.all_relation_probs(trajector, bounding_box, perspective, landmark, step)
        if usebest:
            index = index_max(rel_probabilities)
        else:
            index = categorical_sample(rel_probabilities)

        # A second, unconditional re-sampling of ``index`` used to follow
        # here; it discarded the choice made above, so ``usebest`` had no
        # effect.  It has been removed.

        return rel_classes[index], rel_probabilities[index], self.get_entropy(rel_probabilities)
Exemple #11
0
    def sample_relation(self,
                        trajector,
                        bounding_box,
                        perspective,
                        landmark,
                        step=0.02,
                        usebest=False):
        """
        Sample a relation given a trajector and landmark.
        Evaluate each relation and probabilistically choose the one that is likely to
        generate the trajector given a landmark.

        With ``usebest`` the index is chosen by ``index_max``; otherwise it is
        drawn with ``categorical_sample``.

        Returns ``(relation class, its probability, entropy of all relation
        probabilities)``.
        """
        rel_probabilities, rel_classes = self.all_relation_probs(
            trajector, bounding_box, perspective, landmark, step)
        if usebest:
            index = index_max(rel_probabilities)
        else:
            index = categorical_sample(rel_probabilities)

        # An unconditional second draw used to overwrite ``index`` at this
        # point, which made the ``usebest`` branch above ineffective; it has
        # been removed so the selected index is the one actually returned.

        return rel_classes[index], rel_probabilities[index], self.get_entropy(
            rel_probabilities)
def get_expansion(lhs, parent=None, lmk=None, rel=None):
    """Recursively expand the symbols of ``lhs`` down to terminals.

    Every whitespace-separated symbol of ``lhs`` is processed in order.  A
    symbol outside ``NONTERMINALS`` is recorded as a terminal together with
    the current landmark.  A nonterminal has one production sampled from the
    stored production counts and that production's right-hand side is
    expanded recursively.  A nonterminal with no stored counts is logged and
    passed down unexpanded as a terminal.

    Returns:
        ``(lhs_rhs_parent_chain, prob_chain, entropy_chain, terminals,
        landmarks)`` where ``terminals`` and ``landmarks`` are parallel lists
        and the three chains hold one entry per sampled production.
    """
    lhs_rhs_parent_chain = []
    prob_chain = []
    entropy_chain = []
    terminals = []
    landmarks = []

    for n in lhs.split():
        if n not in NONTERMINALS:
            terminals.append( n )
            landmarks.append( lmk )
            continue

        if n == parent == 'LANDMARK-PHRASE':
            # we need to move to the parent landmark
            lmk = parent_landmark(lmk)

        lmk_class = (lmk.object_class if lmk else None)
        lmk_ori_rels = get_lmk_ori_rels_str(lmk)
        lmk_color = (lmk.color if lmk else None)
        rel_class = rel_type(rel)
        dist_class = (rel.measurement.best_distance_class if hasattr(rel, 'measurement') else None)
        deg_class = (rel.measurement.best_degree_class if hasattr(rel, 'measurement') else None)

        cp_db = CProduction.get_production_counts(lhs=n,
                                                  parent=parent,
                                                  lmk_class=lmk_class,
                                                  lmk_ori_rels=lmk_ori_rels,
                                                  lmk_color=lmk_color,
                                                  rel=rel_class,
                                                  dist_class=dist_class,
                                                  deg_class=deg_class)

        if cp_db.count() <= 0:
            logger('Could not expand %s (parent: %s, lmk_class: %s, lmk_ori_rels: %s, lmk_color: %s, rel: %s, dist_class: %s, deg_class: %s)' % (n, parent, lmk_class, lmk_ori_rels, lmk_color, rel_class, dist_class, deg_class))
            terminals.append( n )
            # Keep ``landmarks`` in step with ``terminals``: consumers zip the
            # two lists, so a terminal without a landmark would shift every
            # later pairing and drop the tail.
            landmarks.append( lmk )
            continue

        # Merge the counts of rows that share a right-hand side.  The rows
        # are fetched once; an earlier extra fetch whose result was
        # immediately overwritten has been removed.
        ccounter = {}
        for row in cp_db.all():
            ccounter[row.rhs] = ccounter.get(row.rhs, 0) + row.count

        ckeys, ccounts = zip(*ccounter.items())

        # Normalise the merged counts into a probability vector.
        ccounts = np.array(ccounts, dtype=float)
        ccounts /= ccounts.sum()

        cprod, cprod_prob, cprod_entropy = categorical_sample(ckeys, ccounts)

        lhs_rhs_parent_chain.append( ( n,cprod,parent,lmk ) )
        prob_chain.append( cprod_prob )
        entropy_chain.append( cprod_entropy )

        # Expand the sampled right-hand side with this symbol as its parent.
        lrpc, pc, ec, t, ls = get_expansion( lhs=cprod, parent=n, lmk=lmk, rel=rel )
        lhs_rhs_parent_chain.extend( lrpc )
        prob_chain.extend( pc )
        entropy_chain.extend( ec )
        terminals.extend( t )
        landmarks.extend( ls )

    return lhs_rhs_parent_chain, prob_chain, entropy_chain, terminals, landmarks
def get_words(terminals, landmarks, rel=None, prevword=None):
    """Sample a word per terminal from a mix of unigram and bigram counts.

    For each ``(terminal, landmark)`` pair the word counts matching the
    meaning (part of speech, landmark and relation features) are normalised
    into a unigram distribution.  ``alpha`` is the count for the meaning
    restricted to the previous word divided by the count for the meaning
    alone; when it is truthy the final distribution is
    ``alpha * bigram + (1 - alpha) * unigram``, otherwise the unigram
    distribution is used as is.

    Terminals found in ``NONTERMINALS`` are passed through with probability
    1.0 and entropy 0.0 and contribute no ``alpha``.

    Returns:
        ``(words, p, H, alphas)`` -- the chosen words, the product of their
        probabilities, the sum of their entropies and the list of alphas.
    """
    count_of = CWord.get_count

    def tally(rows):
        # Sum the counts of rows that share a word, keeping first-seen order.
        totals = {}
        for row in rows:
            totals[row.word] = totals.get(row.word, 0) + row.count
        return totals

    def normalised(raw):
        # Float copy of ``raw`` scaled so that it sums to one.
        scaled = np.array(raw, dtype=float)
        scaled /= scaled.sum()
        return scaled

    words, probs, alphas, entropy = [], [], [], []

    for n, lmk in zip(terminals, landmarks):
        # An unexpanded nonterminal is handed down unchanged with
        # probability 1 and entropy 0.
        if n in NONTERMINALS:
            words.append(n)
            probs.append(1.0)
            entropy.append(0.0)
            continue

        if lmk:
            lmk_class = lmk.object_class
            lmk_color = lmk.color
        else:
            lmk_class = lmk_color = None
        rel_class = rel_type(rel)
        if hasattr(rel, 'measurement'):
            dist_class = rel.measurement.best_distance_class
            deg_class = rel.measurement.best_degree_class
        else:
            dist_class = deg_class = None

        meaning = dict(pos=n,
                       lmk_class=lmk_class,
                       lmk_ori_rels=get_lmk_ori_rels_str(lmk),
                       lmk_color=lmk_color,
                       rel=rel_class,
                       rel_dist_class=dist_class,
                       rel_deg_class=deg_class)

        # Unigram distribution over the words seen with this meaning.
        uni_totals = tally(CWord.get_word_counts(**meaning))
        ckeys, uni_raw = zip(*uni_totals.items())
        uni = normalised(uni_raw)

        # The previous word is the last one emitted so far, else ``prevword``.
        prev_word = words[-1] if words else prevword
        alpha = count_of(prev_word=prev_word, **meaning) / count_of(**meaning)
        alphas.append(alpha)

        cprob = uni
        if alpha:
            # Bigram distribution over the same keys, then interpolate.
            bi_totals = tally(CWord.get_word_counts(prev_word=prev_word, **meaning))
            bi = normalised([bi_totals.get(k, 0) for k in ckeys])
            cprob = (alpha * bi) + ((1 - alpha) * uni)

        w, w_prob, w_entropy = categorical_sample(ckeys, cprob)
        words.append(w)
        probs.append(w_prob)
        entropy.append(w_entropy)

    return words, np.prod(probs), np.sum(entropy), alphas
Exemple #14
0
def get_expansion(lhs, parent=None, lmk=None, rel=None):
    """Recursively expand the symbols of ``lhs`` down to terminals.

    Every whitespace-separated symbol of ``lhs`` is processed in order.  A
    symbol outside ``NONTERMINALS`` is recorded as a terminal together with
    the current landmark.  A nonterminal has one production sampled from the
    stored production counts and that production's right-hand side is
    expanded recursively.  A nonterminal with no stored counts is logged and
    passed down unexpanded as a terminal.

    Returns:
        ``(lhs_rhs_parent_chain, prob_chain, entropy_chain, terminals,
        landmarks)`` where ``terminals`` and ``landmarks`` are parallel lists
        and the three chains hold one entry per sampled production.
    """
    lhs_rhs_parent_chain = []
    prob_chain = []
    entropy_chain = []
    terminals = []
    landmarks = []

    for n in lhs.split():
        if n not in NONTERMINALS:
            terminals.append(n)
            landmarks.append(lmk)
            continue

        if n == parent == 'LANDMARK-PHRASE':
            # we need to move to the parent landmark
            lmk = parent_landmark(lmk)

        lmk_class = (lmk.object_class if lmk else None)
        lmk_ori_rels = get_lmk_ori_rels_str(lmk)
        lmk_color = (lmk.color if lmk else None)
        rel_class = rel_type(rel)
        dist_class = (rel.measurement.best_distance_class if hasattr(
            rel, 'measurement') else None)
        deg_class = (rel.measurement.best_degree_class if hasattr(
            rel, 'measurement') else None)

        cp_db = CProduction.get_production_counts(
            lhs=n,
            parent=parent,
            lmk_class=lmk_class,
            lmk_ori_rels=lmk_ori_rels,
            lmk_color=lmk_color,
            rel=rel_class,
            dist_class=dist_class,
            deg_class=deg_class)

        if cp_db.count() <= 0:
            logger(
                'Could not expand %s (parent: %s, lmk_class: %s, lmk_ori_rels: %s, lmk_color: %s, rel: %s, dist_class: %s, deg_class: %s)'
                % (n, parent, lmk_class, lmk_ori_rels, lmk_color,
                   rel_class, dist_class, deg_class))
            terminals.append(n)
            # Keep ``landmarks`` in step with ``terminals``: consumers zip the
            # two lists, so a terminal without a landmark would shift every
            # later pairing and drop the tail.
            landmarks.append(lmk)
            continue

        # Merge the counts of rows that share a right-hand side.  The rows
        # are fetched once; an earlier extra fetch whose result was
        # immediately overwritten has been removed.
        ccounter = {}
        for row in cp_db.all():
            ccounter[row.rhs] = ccounter.get(row.rhs, 0) + row.count

        ckeys, ccounts = zip(*ccounter.items())

        # Normalise the merged counts into a probability vector.
        ccounts = np.array(ccounts, dtype=float)
        ccounts /= ccounts.sum()

        cprod, cprod_prob, cprod_entropy = categorical_sample(
            ckeys, ccounts)

        lhs_rhs_parent_chain.append((n, cprod, parent, lmk))
        prob_chain.append(cprod_prob)
        entropy_chain.append(cprod_entropy)

        # Expand the sampled right-hand side with this symbol as its parent.
        lrpc, pc, ec, t, ls = get_expansion(lhs=cprod,
                                            parent=n,
                                            lmk=lmk,
                                            rel=rel)
        lhs_rhs_parent_chain.extend(lrpc)
        prob_chain.extend(pc)
        entropy_chain.extend(ec)
        terminals.extend(t)
        landmarks.extend(ls)

    return lhs_rhs_parent_chain, prob_chain, entropy_chain, terminals, landmarks
Exemple #15
0
def get_words(terminals, landmarks, rel=None, prevword=None):
    """Choose a word for each terminal by interpolating unigram and bigram counts.

    Terminals that are in ``NONTERMINALS`` are emitted unchanged with
    probability 1.0 and entropy 0.0.  For every other terminal:

    1. the counts of all words matching the meaning are normalised into a
       unigram distribution;
    2. ``alpha`` is computed as the count of the meaning together with the
       previous word divided by the count of the meaning alone;
    3. if ``alpha`` is truthy, a bigram distribution over the same words is
       built and mixed in as ``alpha * bigram + (1 - alpha) * unigram``;
    4. a word is drawn with ``categorical_sample``.

    Returns:
        ``(words, p, H, alphas)`` where ``p`` is the product of the word
        probabilities and ``H`` the sum of the word entropies.
    """
    get_count = CWord.get_count

    def word_totals(rows):
        # Accumulate per-word counts; duplicate words are summed.
        acc = {}
        for entry in rows:
            acc[entry.word] = acc.get(entry.word, 0) + entry.count
        return acc

    def to_distribution(values):
        # Convert raw counts to a float array that sums to one.
        dist = np.array(values, dtype=float)
        dist /= dist.sum()
        return dist

    words = []
    probs = []
    alphas = []
    entropy = []

    for n, lmk in zip(terminals, landmarks):
        if n in NONTERMINALS:
            # Unexpanded nonterminal: pass it down as-is.
            words.append(n)
            probs.append(1.0)
            entropy.append(0.0)
            continue

        if lmk:
            lmk_class, lmk_color = lmk.object_class, lmk.color
        else:
            lmk_class, lmk_color = None, None
        rel_class = rel_type(rel)
        if hasattr(rel, 'measurement'):
            dist_class = rel.measurement.best_distance_class
            deg_class = rel.measurement.best_degree_class
        else:
            dist_class, deg_class = None, None

        meaning = dict(pos=n,
                       lmk_class=lmk_class,
                       lmk_ori_rels=get_lmk_ori_rels_str(lmk),
                       lmk_color=lmk_color,
                       rel=rel_class,
                       rel_dist_class=dist_class,
                       rel_deg_class=deg_class)

        unigram_totals = word_totals(CWord.get_word_counts(**meaning))
        ckeys, unigram_counts = zip(*unigram_totals.items())
        unigram = to_distribution(unigram_counts)

        # Last emitted word if there is one, otherwise the caller's prevword.
        prev_word = words[-1] if words else prevword
        alpha = get_count(prev_word=prev_word, **meaning) / get_count(**meaning)
        alphas.append(alpha)

        if not alpha:
            cprob = unigram
        else:
            bigram_totals = word_totals(
                CWord.get_word_counts(prev_word=prev_word, **meaning))
            # Words unseen after prev_word get a bigram count of zero.
            bigram = to_distribution(
                [bigram_totals.get(k, 0) for k in ckeys])
            cprob = (alpha * bigram) + ((1 - alpha) * unigram)

        w, w_prob, w_entropy = categorical_sample(ckeys, cprob)
        words.append(w)
        probs.append(w_prob)
        entropy.append(w_entropy)

    p = np.prod(probs)
    H = np.sum(entropy)
    return words, p, H, alphas
    def loop(data):

        time.sleep(random())

        if 'num_iterations' in data:
            scene, speaker = construct_training_scene(True)
            num_iterations = data['num_iterations']
        else:
            scene = data['scene']
            speaker = data['speaker']
            num_iterations = len(data['loc_descs'])

        utils.scene.set_scene(scene,speaker)

        scene_bb = scene.get_bounding_box()
        scene_bb = scene_bb.inflate( Vec2(scene_bb.width*0.5,scene_bb.height*0.5) )
        table = scene.landmarks['table'].representation.get_geometry()

        # step = 0.04
        loi = [lmk for lmk in scene.landmarks.values() if lmk.name != 'table']
        all_heatmaps_tupless, xs, ys = speaker.generate_all_heatmaps(scene, step=step, loi=loi)

        loi_infos = []
        all_meanings = set()
        for obj_lmk,all_heatmaps_tuples in zip(loi, all_heatmaps_tupless):

            lmks, rels, heatmapss = zip(*all_heatmaps_tuples)
            meanings = zip(lmks,rels)
            # print meanings
            all_meanings.update(meanings)
            loi_infos.append( (obj_lmk, meanings, heatmapss) )

        all_heatmaps_tupless, xs, ys = speaker.generate_all_heatmaps(scene, step=step)
        all_heatmaps_tuples = all_heatmaps_tupless[0]
        # x = np.array( [list(xs-step*0.5)]*len(ys) )
        # y = np.array( [list(ys-step*0.5)]*len(xs) ).T
        # for lamk, rel, (heatmap1,heatmap2) in all_heatmaps_tuples:
        #     logger( m2s(lamk,rel))
        #     if isinstance(rel, DistanceRelation):
        #         probabilities = heatmap2.reshape( (len(xs),len(ys)) ).T
        #         plt.pcolor(x, y, probabilities, cmap = 'jet', edgecolors='none', alpha=0.7)
        #         plt.colorbar()
        #         for lmk in scene.landmarks.values():
        #             if isinstance(lmk.representation, GroupLineRepresentation):
        #                 xxs = [lmk.representation.line.start.x, lmk.representation.line.end.x]
        #                 yys = [lmk.representation.line.start.y, lmk.representation.line.end.y]
        #                 plt.fill(xxs,yys,facecolor='none',linewidth=2)
        #             elif isinstance(lmk.representation, RectangleRepresentation):
        #                 rect = lmk.representation.rect
        #                 xxs = [rect.min_point.x,rect.min_point.x,rect.max_point.x,rect.max_point.x]
        #                 yys = [rect.min_point.y,rect.max_point.y,rect.max_point.y,rect.min_point.y]
        #                 plt.fill(xxs,yys,facecolor='none',linewidth=2)
        #                 plt.text(rect.min_point.x+0.01,rect.max_point.y+0.02,lmk.name)
        #         plt.title(m2s(lamk,rel))
        #         logger("Showing")
        #         plt.show()
        #     logger("End")

        x = np.array( [list(xs-step*0.5)]*len(ys) )
        y = np.array( [list(ys-step*0.5)]*len(xs) ).T
        lmks, rels, heatmapss = zip(*all_heatmaps_tuples)
        # graphmax1 = graphmax2 = 0
        meanings = zip(lmks,rels)
        landmarks = list(set(lmks))
        # relations = list(set(rels))

        epsilon = 0.0001
        def heatmaps_for_sentences(sentences, all_meanings, loi_infos, xs, ys, scene, speaker, step=0.02):
            printing=False
            x = np.array( [list(xs-step*0.5)]*len(ys) )
            y = np.array( [list(ys-step*0.5)]*len(xs) ).T
            scene_bb = scene.get_bounding_box()
            scene_bb = scene_bb.inflate( Vec2(scene_bb.width*0.5,scene_bb.height*0.5) )
            # x = np.array( [list(xs-step*0.5)]*len(ys) )
            # y = np.array( [list(ys-step*0.5)]*len(xs) ).T

            combined_heatmaps = []
            for obj_lmk, ms, heatmapss in loi_infos:

                # for m,(h1,h2) in zip(ms, heatmapss):

                #     logger( h1.shape )
                #     logger( x.shape )
                #     logger( y.shape )
                #     logger( xs.shape )
                #     logger( ys.shape )
                #     plt.pcolor(x, y, h1.reshape((len(xs),len(ys))).T, cmap = 'jet', edgecolors='none', alpha=0.7)
                #     plt.colorbar()

                #     for lmk in scene.landmarks.values():
                #         if isinstance(lmk.representation, GroupLineRepresentation):
                #             xxs = [lmk.representation.line.start.x, lmk.representation.line.end.x]
                #             yys = [lmk.representation.line.start.y, lmk.representation.line.end.y]
                #             plt.fill(xxs,yys,facecolor='none',linewidth=2)
                #         elif isinstance(lmk.representation, RectangleRepresentation):
                #             rect = lmk.representation.rect
                #             xxs = [rect.min_point.x,rect.min_point.x,rect.max_point.x,rect.max_point.x]
                #             yys = [rect.min_point.y,rect.max_point.y,rect.max_point.y,rect.min_point.y]
                #             plt.fill(xxs,yys,facecolor='none',linewidth=2)
                #             plt.text(rect.min_point.x+0.01,rect.max_point.y+0.02,lmk.name)
                #     plt.title(m2s(*m))
                #     logger( m2s(*m))
                #     plt.axis('scaled')
                #     plt.show()

                combined_heatmap = None
                for sentence in sentences:
                    posteriors = get_all_sentence_posteriors(sentence, all_meanings, printing=printing)

                    big_heatmap1 = None
                    for m,(h1,h2) in zip(ms, heatmapss):

                        lmk,rel = m
                        p = posteriors[rel]*posteriors[lmk]
                        if big_heatmap1 is None:
                            big_heatmap1 = p*h1
                        else:
                            big_heatmap1 += p*h1

                    if combined_heatmap is None:
                        combined_heatmap = big_heatmap1
                    else:
                        combined_heatmap *= big_heatmap1

                combined_heatmaps.append(combined_heatmap)

            return combined_heatmaps

        object_meaning_applicabilities = {}
        for obj_lmk, ms, heatmapss in loi_infos:
            for m,(h1,h2) in zip(ms, heatmapss):
                ps = [p for (x,y),p in zip(list(product(xs,ys)),h1) if obj_lmk.representation.contains_point( Vec2(x,y) )]
                if m not in object_meaning_applicabilities:
                    object_meaning_applicabilities[m] = {}
                object_meaning_applicabilities[m][obj_lmk] = sum(ps)/len(ps)

        k = len(loi)
        for meaning_dict in object_meaning_applicabilities.values():
            total = sum( meaning_dict.values() )
            if total != 0:
                for obj_lmk in meaning_dict.keys():
                    meaning_dict[obj_lmk] = meaning_dict[obj_lmk]/total - 1.0/k
                total = sum( [value for value in meaning_dict.values() if value > 0] )
                for obj_lmk in meaning_dict.keys():
                    meaning_dict[obj_lmk] = (2 if meaning_dict[obj_lmk] > 0 else 1)*meaning_dict[obj_lmk] - total

        sorted_meaning_lists = {}

        for m in object_meaning_applicabilities.keys():
            for obj_lmk in object_meaning_applicabilities[m].keys():
                if obj_lmk not in sorted_meaning_lists:
                    sorted_meaning_lists[obj_lmk] = []
                sorted_meaning_lists[obj_lmk].append( (object_meaning_applicabilities[m][obj_lmk], m) )
        for obj_lmk in sorted_meaning_lists.keys():
            sorted_meaning_lists[obj_lmk].sort(reverse=True)

        min_dists = []
        lmk_priors = []
        rel_priors = []
        lmk_posts = []
        rel_posts = []
        golden_log_probs = []
        golden_entropies = []
        golden_ranks = []
        rel_types = []

        total_mass = []

        student_probs = []
        student_entropies = []
        student_ranks = []
        student_rel_types = []

        object_answers = []
        object_distributions = []
        object_sentences =[]

        epsilon = 1e-15

        for iteration in range(num_iterations):
            logger(('Iteration %d comprehension' % iteration),'okblue')

            if 'loc_descs' in data:
                trajector = data['lmks'][iteration]
                logger( 'Teacher chooses: %s' % trajector )
                sentences = data['loc_descs'][iteration]
                probs, sorted_meanings = zip(*sorted_meaning_lists[trajector][:30])
                probs = np.array(probs)# - min(probs)
                probs /= probs.sum()
                if sentences is None:
                    (sampled_landmark, sampled_relation) = categorical_sample( sorted_meanings, probs )[0]
                    logger( 'Teacher tries to say: %s' % m2s(sampled_landmark,sampled_relation) )
                    head_on = speaker.get_head_on_viewpoint(sampled_landmark)

                    sentences = [describe( head_on, trajector, sampled_landmark, sampled_relation )]
            else:
                # Teacher describe
                trajector = choice(loi)
                # sentence, sampled_relation, sampled_landmark = speaker.describe(trajector, scene, max_level=1)
                logger( 'Teacher chooses: %s' % trajector )
                # Choose from meanings
                probs, sorted_meanings = zip(*sorted_meaning_lists[trajector][:30])
                probs = np.array(probs)# - min(probs)
                probs /= probs.sum()
                (sampled_landmark, sampled_relation) = categorical_sample( sorted_meanings, probs )[0]
                logger( 'Teacher tries to say: %s' % m2s(sampled_landmark,sampled_relation) )

                # Generate sentence
                # _, sentence = generate_sentence(None, False, scene, speaker, meaning=(sampled_landmark, sampled_relation), golden=True, printing=printing)

                sentences = [describe( speaker.get_head_on_viewpoint(sampled_landmark), trajector, sampled_landmark, sampled_relation )]

            object_sentences.append( ' '.join(sentences) )
            logger( 'Teacher says: %s' % ' '.join(sentences))
            for i,(p,sm) in enumerate(zip(probs[:15],sorted_meanings[:15])):
                lm,re = sm
                logger( '%i: %f %s' % (i,p,m2s(*sm)) )
                # head_on = speaker.get_head_on_viewpoint(lm)
                # speaker.visualize( scene, trajector, head_on, lm, re)

            lmk_probs = []

            try:
                combined_heatmaps = heatmaps_for_sentences(sentences, all_meanings, loi_infos, xs, ys, scene, speaker, step=step)

                for combined_heatmap,obj_lmk in zip(combined_heatmaps, loi):

                    # x = np.array( [list(xs-step*0.5)]*len(ys) )
                    # y = np.array( [list(ys-step*0.5)]*len(xs) ).T
                    # logger( combined_heatmap.shape )
                    # logger( x.shape )
                    # logger( y.shape )
                    # logger( xs.shape )
                    # logger( ys.shape )
                    # plt.pcolor(x, y, combined_heatmap.reshape((len(xs),len(ys))).T, cmap = 'jet', edgecolors='none', alpha=0.7)
                    # plt.colorbar()

                    # for lmk in scene.landmarks.values():
                    #     if isinstance(lmk.representation, GroupLineRepresentation):
                    #         xxs = [lmk.representation.line.start.x, lmk.representation.line.end.x]
                    #         yys = [lmk.representation.line.start.y, lmk.representation.line.end.y]
                    #         plt.fill(xxs,yys,facecolor='none',linewidth=2)
                    #     elif isinstance(lmk.representation, RectangleRepresentation):
                    #         rect = lmk.representation.rect
                    #         xxs = [rect.min_point.x,rect.min_point.x,rect.max_point.x,rect.max_point.x]
                    #         yys = [rect.min_point.y,rect.max_point.y,rect.max_point.y,rect.min_point.y]
                    #         plt.fill(xxs,yys,facecolor='none',linewidth=2)
                    #         plt.text(rect.min_point.x+0.01,rect.max_point.y+0.02,lmk.name)
                    # plt.axis('scaled')
                    # plt.axis([scene_bb.min_point.x, scene_bb.max_point.x, scene_bb.min_point.y, scene_bb.max_point.y])
                    # plt.show()

                    ps = [p for (x,y),p in zip(list(product(xs,ys)),combined_heatmap) if obj_lmk.representation.contains_point( Vec2(x,y) )]
                    # print ps, xs.shape, ys.shape, combined_heatmap.shape
                    lmk_probs.append( (sum(ps)/len(ps), obj_lmk) )

                lmk_probs = sorted(lmk_probs, reverse=True)
                top_p, top_lmk = lmk_probs[0]
                lprobs, lmkss = zip(*lmk_probs)

                answer, distribution = loi.index(trajector), [ (lprob, loi.index(lmk)) for lprob,lmk in lmk_probs ]
                logger( sorted(zip(np.array(lprobs)/sum(lprobs), [(l.name, l.color, l.object_class) for l in lmkss]), reverse=True) )
                logger( 'I bet %f you are talking about a %s %s %s' % (top_p/sum(lprobs), top_lmk.name, top_lmk.color, top_lmk.object_class) )
                # objects.append(top_lmk)
            except Exception as e:
                logger( 'Unable to get object from sentence. %s' % e, 'fail' )
                answer = None
                top_lmk = None
                distribution = [(0,False)]

            object_answers.append( answer )
            object_distributions.append( distribution )

            # Present top_lmk to teacher
            if top_lmk == trajector:
                # Give morphine
                pass
            else:
                updates, _ = zip(*sorted_meaning_lists[trajector][:30])
                howmany=5
                for sentence in sentences:
                    for _ in range(howmany):
                        meaning = categorical_sample( sorted_meanings, probs )[0]
                        update = updates[ sorted_meanings.index(meaning) ]
                        try:
                            accept_object_correction( meaning, sentence, update*scale, printing=printing)
                        except:
                            pass
                    for update, meaning in sorted_meaning_lists[trajector][-howmany:]:
                        try:
                            accept_object_correction( meaning, sentence, update*scale, printing=printing)
                        except:
                            pass

            for _ in range(0):
	            logger(('Iteration %d production' % iteration),'okblue')
	            rand_p = Vec2(random()*table.width+table.min_point.x, random()*table.height+table.min_point.y)
	            trajector = Landmark( 'point', PointRepresentation(rand_p), None, Landmark.POINT )

	            meaning, sentence = generate_sentence(rand_p, False, scene, speaker, usebest=True, printing=printing)
	            logger( 'Generated sentence: %s' % sentence)

	            landmark = meaning.args[0]
	            # if ambiguous_pointing:
	                # pointing_point = landmark.representation.middle + Vec2(random()*0.1-0.05,random()*0.1-0.05)
	            #_, bestsentence = generate_sentence(rand_p, False, scene, speaker, usebest=True, printing=printing)

	            try:
	                golden_posteriors = get_all_sentence_posteriors(sentence, meanings, golden=True, printing=printing)
	            except ParseError as e:
	                logger( e )
	                prob = 0
	                rank = len(meanings)-1
	                entropy = 0
	                ed = len(sentence)
	                golden_log_probs.append( prob )
	                golden_entropies.append( entropy )
	                golden_ranks.append( rank )
	                min_dists.append( ed )
	                continue
	            epsilon = 1e-15
	            ps = [[golden_posteriors[lmk]*golden_posteriors[rel],(lmk,rel)] for lmk, rel in meanings if (lmk == landmark)]

	            all_lmk_probs = speaker.all_landmark_probs(landmarks, Landmark(None, PointRepresentation(rand_p), None))
	            all_lmk_probs = dict(zip(landmarks, all_lmk_probs))
	            temp = None
	            for i,(p,(lmk,rel)) in enumerate(ps):
	                # lmk,rel = meanings[i]
	                # logger( '%f, %s' % (p, m2s(lmk,rel)))
	                head_on = speaker.get_head_on_viewpoint(lmk)
	                ps[i][0] *= speaker.get_probabilities_points( np.array([rand_p]), rel, head_on, lmk)[0]
	                if lmk == meaning.args[0] and rel == meaning.args[3]:
	                    temp = i

	            ps,_meanings = zip(*ps)
	            print ps
	            ps = np.array(ps)
	            ps += epsilon
	            ps = ps/ps.sum()
	            temp = ps[temp]

	            ps = sorted(zip(ps,_meanings),reverse=True)

	            logger( 'Attempted to say: %s' %  m2s(meaning.args[0],meaning.args[3]) )
	            logger( 'Interpreted as: %s' % m2s(ps[0][1][0],ps[0][1][1]) )
	            logger( 'Attempted: %f vs Interpreted: %f' % (temp, ps[0][0]))

	            # logger( 'Golden entropy: %f, Max entropy %f' % (golden_entropy, max_entropy))

	            landmark, relation = ps[0][1]
	            head_on = speaker.get_head_on_viewpoint(landmark)
	            all_descs = speaker.get_all_meaning_descriptions(trajector, scene, landmark, relation, head_on, 1)

	            distances = []
	            for desc in all_descs:
	                distances.append([edit_distance( sentence, desc ), desc])

	            distances.sort()
	            print distances

	            correction = distances[0][1]
	            # if correction == sentence:
	            #     correction = None
	            #     logger( 'No correction!!!!!!!!!!!!!!!!!!', 'okgreen' )
	            accept_correction( meaning, correction, update_scale=scale, eval_lmk=False, multiply=False, printing=printing )


            def probs_metric(inverse=False):
                """Measure how a sentence generated for a random table point is interpreted.

                A point is picked inside the extent of ``table`` with ``random()``, a
                sentence is generated for it with ``generate_sentence(..., golden=inverse)``
                and scored with ``get_all_sentence_posteriors(..., golden=(not inverse))``.
                Every candidate in ``meanings`` is weighted by its landmark posterior,
                its relation posterior, the landmark probability at the point and the
                relation probability at the point; ``epsilon`` is added and the
                weights are normalised to sum to one.

                Args:
                    inverse: forwarded as ``golden=inverse`` to the sentence generator
                        and as ``golden=(not inverse)`` to the posterior computation.

                Returns:
                    A 9-tuple ``(lmk_prior, rel_prior, lmk_post, rel_post, prob,
                    entropy, rank, min_edit_distance, relation_type)``.  When a
                    ``ParseError`` or ``RuntimeError`` is raised, the first seven
                    entries fall back to ``0, 0, 0, 0, 0, len(meanings)-1``-style
                    defaults (all zero except ``rank == len(meanings)-1``).
                    ``min_edit_distance`` is ``None`` when no sentence could be
                    generated or the chosen meaning has no descriptions; in the
                    former case ``relation_type`` is ``None`` too.
                """
                rand_p = Vec2(random()*table.width+table.min_point.x, random()*table.height+table.min_point.y)

                # Becomes True only once generate_sentence() has succeeded.  The tail of
                # this function needs the generated sentence and meaning, and used to
                # fail with UnboundLocalError when generation itself raised.
                generated = False
                try:
                    bestmeaning, bestsentence = generate_sentence(rand_p, False, scene, speaker, usebest=True, golden=inverse, printing=printing)
                    sampled_landmark, sampled_relation = bestmeaning.args[0], bestmeaning.args[3]
                    generated = True
                    golden_posteriors = get_all_sentence_posteriors(bestsentence, meanings, golden=(not inverse), printing=printing)

                    all_lmk_probs = speaker.all_landmark_probs(landmarks, Landmark(None, PointRepresentation(rand_p), None))
                    all_lmk_probs = dict(zip(landmarks, all_lmk_probs))

                    lmk_prior = all_lmk_probs[sampled_landmark]
                    head_on = speaker.get_head_on_viewpoint(sampled_landmark)
                    rel_prior = speaker.get_probabilities_points( np.array([rand_p]), sampled_relation, head_on, sampled_landmark)
                    lmk_post = golden_posteriors[sampled_landmark]
                    rel_post = golden_posteriors[sampled_relation]

                    ps = np.array([golden_posteriors[lmk]*golden_posteriors[rel] for lmk, rel in meanings])
                    for i,(lmk,rel) in enumerate(meanings):
                        head_on = speaker.get_head_on_viewpoint(lmk)
                        ps[i] *= all_lmk_probs[lmk]
                        ps[i] *= speaker.get_probabilities_points( np.array([rand_p]), rel, head_on, lmk)
                        if lmk == sampled_landmark and rel == sampled_relation:
                            # Remember where the meaning that was actually expressed sits.
                            idx = i

                    # epsilon comes from the enclosing scope; it keeps the sum non-zero.
                    ps += epsilon
                    ps = ps/ps.sum()
                    prob = ps[idx]
                    rank = sorted(ps, reverse=True).index(prob)
                    entropy = entropy_of_probs(ps)
                except (ParseError,RuntimeError) as e:
                    logger( e )
                    lmk_prior = 0
                    rel_prior = 0
                    lmk_post = 0
                    rel_post = 0
                    prob = 0
                    rank = len(meanings)-1
                    entropy = 0

                if not generated:
                    # No sentence and no meaning: there is nothing to compare against.
                    return lmk_prior,rel_prior,lmk_post,rel_post,\
                           prob,entropy,rank,None,None

                head_on = speaker.get_head_on_viewpoint(sampled_landmark)
                # NOTE(review): `trajector` is not assigned in this function, so it is the
                # enclosing scope's current trajector rather than rand_p -- confirm intended.
                all_descs = speaker.get_all_meaning_descriptions(trajector, scene, sampled_landmark, sampled_relation, head_on, 1)
                distances = [edit_distance( bestsentence, desc ) for desc in all_descs]
                min_dist = min(distances) if distances else None
                return lmk_prior,rel_prior,lmk_post,rel_post,\
                       prob,entropy,rank,min_dist,type(sampled_relation)

            def db_mass():
                """Return the production sum and the word sum added together."""
                production_mass = CProduction.get_production_sum(None)
                word_mass = CWord.get_word_sum(None)
                return production_mass + word_mass

            def choosing_object_metric():
                """Describe a randomly chosen object and rank all objects against the description.

                An object is drawn from ``loi`` with ``choice`` and described by
                ``speaker.describe``.  Each object in ``loi`` is then scored with the
                mean of its heatmap values over the ``(x, y)`` grid points that its
                representation contains.  Any exception while scoring is logged, the
                traceback is printed and ``exit()`` is called.

                Returns:
                    ``(index of the chosen object in loi, [(score, index in loi), ...])``
                    with the list sorted by descending score.
                """
                target = choice(loi)

                # describe() returns three values; only the sentence is used here.
                description, _, _ = speaker.describe(target, scene, max_level=1)

                scored = []
                try:
                    heatmaps = heatmaps_for_sentence(description, all_meanings, loi_infos, xs, ys, scene, speaker, step=step)

                    for heatmap, candidate in zip(heatmaps, loi):
                        inside = [value
                                  for (px, py), value in zip(product(xs, ys), heatmap)
                                  if candidate.representation.contains_point( Vec2(px, py) )]
                        scored.append( (sum(inside)/len(inside), candidate) )

                    scored.sort(reverse=True)
                    best_p, best_lmk = scored[0]
                    masses, ranked_lmks = zip(*scored)

                    shares = np.array(masses)/sum(masses)
                    labels = [(l.name, l.color, l.object_class) for l in ranked_lmks]
                    logger( sorted(zip(shares, labels), reverse=True) )
                    logger( 'I bet %f you are talking about a %s %s %s' % (best_p/sum(masses), best_lmk.name, best_lmk.color, best_lmk.object_class) )
                except Exception as err:
                    logger( 'Unable to get object from sentence. %s' % err, 'fail' )
                    print(traceback.format_exc())
                    exit()

                target_index = loi.index(target)
                return target_index, [ (mass, loi.index(lmk)) for mass, lmk in scored ]

            if golden_metric:
                lmk_prior,rel_prior,lmk_post,rel_post,prob,entropy,rank,ed,rel_type = probs_metric()
            else:
                lmk_prior,rel_prior,lmk_post,rel_post,prob,entropy,rank,ed,rel_type = [None]*9

            lmk_priors.append( lmk_prior )
            rel_priors.append( rel_prior )
            lmk_posts.append( lmk_post )
            rel_posts.append( rel_post )
            golden_log_probs.append( prob )
            golden_entropies.append( entropy )
            golden_ranks.append( rank )
            min_dists.append( ed )
            rel_types.append( rel_type )

            if mass_metric:
                total_mass.append( db_mass() )
            else:
                total_mass.append( None )

            if student_metric:
                _,_,_,_,student_prob,student_entropy,student_rank,_,student_rel_type = probs_metric(inverse=True)
            else:
                _,_,_,_,student_prob,student_entropy,student_rank,_,student_rel_type = \
                None, None, None, None, None, None, None, None, None

            student_probs.append( student_prob )
            student_entropies.append( student_entropy )
            student_ranks.append( student_rank )
            student_rel_types.append( student_rel_type )

            # if choosing_metric:
            #     answer, distribution = choosing_object_metric()
            # else:
            #     answer, distribution = None, None
            # object_answers.append( answer )
            # object_distributions.append( distribution )

        return zip(lmk_priors, rel_priors, lmk_posts, rel_posts,
                   golden_log_probs, golden_entropies, golden_ranks,
                   min_dists, rel_types, total_mass, student_probs,
                   student_entropies, student_ranks, student_rel_types,
                   object_answers, object_distributions, object_sentences)
Exemple #17
0
    def loop(data):
        """Run one object-comprehension round per entry of ``data['loc_descs']``.

        After sleeping ``data['delay']`` seconds and installing the scene with
        ``utils.scene.set_scene``, each iteration:

        1. takes the trajector from ``data['lmks']`` (a random member of
           ``data['loi']`` when the entry is ``None``);
        2. takes the sentences from ``data['loc_descs']``, or generates one with
           ``describe`` from a sampled meaning when the entry is ``None``;
        3. scores every object in ``loi`` by the mean heatmap value over the
           ``(x, y)`` grid points its representation contains, and guesses the
           highest-scoring object;
        4. if the guess differs from the trajector and ``data['learn_objects']``
           is true, feeds corrections to ``accept_object_correction``.

        ``xs``, ``ys``, ``step``, ``scale`` and ``printing`` are read from the
        enclosing scope (not visible in this block).

        Args:
            data: mapping with the keys ``delay``, ``scene``, ``speaker``,
                ``loc_descs``, ``lmks``, ``ids``, ``all_meanings``, ``loi``,
                ``loi_infos``, ``sorted_meaning_lists`` and ``learn_objects``.

        Returns:
            ``zip(object_answers, object_distributions, object_sentences,
            object_ids)`` with one entry per iteration.
        """

        time.sleep(data['delay'])

        scene = data['scene']
        speaker = data['speaker']
        utils.scene.set_scene(scene,speaker)
        num_iterations = len(data['loc_descs'])

        all_meanings = data['all_meanings']
        loi = data['loi']
        loi_infos = data['loi_infos']
        sorted_meaning_lists = data['sorted_meaning_lists']
        learn_objects = data['learn_objects']

        def heatmaps_for_sentences(sentences, all_meanings, loi_infos, xs, ys, scene, speaker, step=0.02):
            """Return one combined heatmap per entry of ``loi_infos``.

            For each sentence the per-meaning heatmaps are summed, weighted by the
            sentence's relation and landmark posteriors; the per-sentence results
            are multiplied together.  ``xs``, ``ys``, ``scene``, ``speaker`` and
            ``step`` are accepted for interface compatibility but not used.

            Raises:
                ParseError: propagated unchanged from ``get_all_sentence_posteriors``.
            """
            printing=False

            combined_heatmaps = []
            for obj_lmk, ms, heatmapss in loi_infos:

                combined_heatmap = None
                for sentence in sentences:
                    posteriors = None
                    # NOTE(review): this retries without bound on any non-ParseError
                    # exception, and also while the result is falsy -- confirm intended.
                    while not posteriors:
                        try:
                            posteriors = get_all_sentence_posteriors(sentence, all_meanings, printing=printing)
                        except ParseError:
                            # Retrying cannot fix an unparseable sentence; bare raise
                            # keeps the original traceback.
                            raise
                        except Exception as e:
                            print(e)
                            sleeptime = random()*0.5
                            logger('Sleeping for %f and retrying "%s"' % (sleeptime,sentence))
                            time.sleep(sleeptime)

                    big_heatmap1 = None
                    for m,(h1,_) in zip(ms, heatmapss):

                        lmk,rel = m
                        p = posteriors[rel]*posteriors[lmk]
                        if big_heatmap1 is None:
                            big_heatmap1 = p*h1
                        else:
                            big_heatmap1 += p*h1

                    if combined_heatmap is None:
                        combined_heatmap = big_heatmap1
                    else:
                        combined_heatmap *= big_heatmap1

                combined_heatmaps.append(combined_heatmap)

            return combined_heatmaps

        def apply_correction(meaning, sentence, update):
            """Best-effort ``accept_object_correction`` call: failures are logged, not raised."""
            try:
                accept_object_correction( meaning, sentence, update*scale, printing=printing)
            except Exception as e:
                # Previously a bare `except: pass`, which also hid KeyboardInterrupt
                # and SystemExit and left no trace of failed updates.
                logger( 'Object correction failed for "%s": %s' % (sentence, e), 'fail' )

        object_answers = []
        object_distributions = []
        object_sentences =[]
        object_ids = []

        for iteration in range(num_iterations):
            logger(('Iteration %d comprehension' % iteration),'okblue')

            trajector = data['lmks'][iteration]
            if trajector is None:
                trajector = choice(loi)
            logger( 'Teacher chooses: %s' % trajector )

            # Keep the 30 best-ranked meanings for this object and normalise their scores.
            probs, sorted_meanings = zip(*sorted_meaning_lists[trajector][:30])
            probs = np.array(probs)# - min(probs)
            probs /= probs.sum()

            sentences = data['loc_descs'][iteration]
            if sentences is None:
                (sampled_landmark, sampled_relation) = categorical_sample( sorted_meanings, probs )[0]
                logger( 'Teacher tries to say: %s' % m2s(sampled_landmark,sampled_relation) )
                head_on = speaker.get_head_on_viewpoint(sampled_landmark)

                sentences = [describe( head_on, trajector, sampled_landmark, sampled_relation )]

            object_sentences.append( ' '.join(sentences) )
            object_ids.append( data['ids'][iteration] )

            logger( 'Teacher says: %s' % ' '.join(sentences))

            for i,(p,sm) in enumerate(zip(probs[:15],sorted_meanings[:15])):
                logger( '%i: %f %s' % (i,p,m2s(*sm)) )

            lmk_probs = []

            try:
                combined_heatmaps = heatmaps_for_sentences(sentences, all_meanings, loi_infos, xs, ys, scene, speaker, step=step)
            except ParseError as e:
                logger( 'Unable to get object from sentence. %s' % e, 'fail' )
                top_lmk = None
                distribution = [(0, False, False)]
            else:
                for combined_heatmap,obj_lmk in zip(combined_heatmaps, loi):

                    ps = [p for (x,y),p in zip(product(xs,ys),combined_heatmap) if obj_lmk.representation.contains_point( Vec2(x,y) )]
                    # NOTE(review): raises ZeroDivisionError when no grid point lies
                    # inside obj_lmk; unlike ParseError this is not handled here.
                    lmk_probs.append( (sum(ps)/len(ps), obj_lmk) )

                lmk_probs = sorted(lmk_probs, reverse=True)
                top_p, top_lmk = lmk_probs[0]
                lprobs, lmkss = zip(*lmk_probs)

                distribution = [ (lprob, lmk.name, loi.index(lmk)) for lprob,lmk in lmk_probs ]
                logger( sorted(zip(np.array(lprobs)/sum(lprobs), [(l.name, l.color, l.object_class) for l in lmkss]), reverse=True) )
                logger( 'I bet %f you are talking about a %s %s %s' % (top_p/sum(lprobs), top_lmk.name, top_lmk.color, top_lmk.object_class) )

            answer = (trajector.name,loi.index(trajector))
            object_answers.append( answer )
            object_distributions.append( distribution )

            # Present top_lmk to teacher
            logger("top_lmk == trajector: %r, learn_objects: %r" % (top_lmk == trajector,learn_objects), 'okgreen')
            if top_lmk == trajector or not learn_objects:
                # Correct guess (or learning disabled): nothing to update.
                logger("Ahhhhh, morphine...", 'okgreen')
            else:
                logger("LEARNING!!!!!!!!!!!", 'okgreen')
                updates, _ = zip(*sorted_meaning_lists[trajector][:30])
                howmany=5
                for sentence in sentences:
                    # Corrections for meanings sampled from the top-ranked candidates...
                    for _ in range(howmany):
                        meaning = categorical_sample( sorted_meanings, probs )[0]
                        update = updates[ sorted_meanings.index(meaning) ]
                        apply_correction(meaning, sentence, update)
                    # ...and for the `howmany` lowest-ranked meanings of this object.
                    for update, meaning in sorted_meaning_lists[trajector][-howmany:]:
                        apply_correction(meaning, sentence, update)

        return zip(object_answers, object_distributions, object_sentences, object_ids)