コード例 #1
0
    def create_causal_observation(
        env,
        action,
        cur_state,
        prev_state,
        causal_observations,
        trial_count,
        attempt_count,
    ):
        """Record the effect of ``action`` as a new CausalObservation.

        The state change is ``cur_state - prev_state``. Entries whose first
        element is ``"door_lock"`` are ignored. If exactly one other entry
        changed, a ``CausalObservation`` is appended to
        ``causal_observations``. It records the changed object's attributes,
        the relation type, and the precondition carried over from the most
        recent observation whose relation type is not None.

        Args:
            env: environment providing ``get_obj_attributes(name)`` and
                ``attribute_order``.
            action: the action that was just executed.
            cur_state: state after the action. ``cur_state - prev_state``
                must yield items indexed as ``item[0]`` (object name passed to
                ``env.get_obj_attributes``) and ``item[1]`` (value passed to
                ``CausalRelationType``); presumably sets of tuples -- confirm.
            prev_state: state before the action.
            causal_observations: previous observations; appended to in place.
            trial_count: forwarded to ``print_message``.
            attempt_count: forwarded to ``print_message``.

        Returns:
            The same ``causal_observations`` object, with at most one new entry.

        Raises:
            AssertionError: if more than one non-door_lock entry changed.
        """
        state_diff = cur_state - prev_state
        # TODO(mjedmonds): generalize to more than 1 state change
        # NOTE(review): the threshold is 2 although the message says "more than
        # one"; presumably this tolerates door_lock changing together with one
        # other fluent (door_lock is filtered out below) -- confirm.
        if len(state_diff) > 2:
            print_message(
                trial_count,
                attempt_count,
                "More than one state change this iteration, chain assumes only one variable changes at a time: {}"
                .format(state_diff),
            )

        # TODO(mjedmonds): refactor to include door_lock
        changed_fluents = [x for x in state_diff if x[0] != "door_lock"]
        if not changed_fluents:
            # Nothing changed, or only door_lock changed: there is no fluent to
            # record, and indexing changed_fluents[0] below would fail.
            return causal_observations
        # TODO(mjedmonds): this only handles a single state_diff per timestep
        assert (
            len(changed_fluents) < 2
        ), "Multiple fluents changing at each time step not yet implemented"

        # Carry the precondition over from the most recent observation that had
        # an effect. Example sequence: an action with an effect, then one
        # without, then one with an effect. The last action should still see
        # the precondition from the first.
        precondition = None
        for observation in reversed(causal_observations):
            relation = observation.causal_relation
            if relation.causal_relation_type is not None:
                precondition = (
                    relation.attributes,
                    relation.causal_relation_type[1],
                )
                # the most recent effective observation wins
                break

        change = changed_fluents[0]
        causal_relation_type = CausalRelationType(change[1])
        obj_attributes = env.get_obj_attributes(change[0])
        attributes = tuple(obj_attributes[key] for key in env.attribute_order)
        causal_observations.append(
            CausalObservation(
                CausalRelation(
                    action=action,
                    attributes=attributes,
                    causal_relation_type=causal_relation_type,
                    precondition=precondition,
                ),
                info_gain=None,
            ))
        return causal_observations
コード例 #2
0
    def prune_chains_from_initial_observation_common(
        self,
        causal_chain_space,
        position_to_color_dict,
        trial_count,
        attempt_count,
        starting_idx,
        ending_idx,
    ):
        """Zero the belief of chains whose colors contradict the initial observation.

        Each chain index in ``[starting_idx, ending_idx)`` is examined
        subchain by subchain. A subchain is checked on its ``attributes``
        (position, color) pair, and also on ``precondition[0]`` when it has a
        precondition. The check compares the pair against
        ``position_to_color_dict``. The first mismatch sets the chain's
        bottom-up belief to 0.0 and records its index.

        Returns:
            tuple: (bottom-up beliefs sliced to the index range, set of pruned
            chain indices, starting_idx, ending_idx).
        """
        structure_space = causal_chain_space.structure_space
        position_index = structure_space.attribute_order.index("position")
        color_index = structure_space.attribute_order.index("color")

        def _contradicts(attrs):
            # True when the observed color at attrs' position differs from attrs' color.
            return attrs[color_index] != position_to_color_dict[attrs[position_index]]

        pruned = set()
        for chain_idx in range(starting_idx, ending_idx):
            chain = structure_space.causal_chains[chain_idx]
            # Short-circuits on the first offending subchain, checking its
            # attributes before its precondition.
            invalid = any(
                _contradicts(sub.attributes)
                or (sub.precondition is not None and _contradicts(sub.precondition[0]))
                for sub in chain
            )
            if invalid:
                causal_chain_space.bottom_up_belief_space.beliefs[chain_idx] = 0.0
                pruned.add(chain_idx)

        print_message(
            trial_count,
            attempt_count,
            "Pruned {}/{} chains based on initial observation".format(
                len(pruned), ending_idx - starting_idx),
            self.print_messages,
        )
        window = causal_chain_space.bottom_up_belief_space.beliefs[starting_idx:ending_idx]
        return window, pruned, starting_idx, ending_idx
コード例 #3
0
    def select_intervention_random(
        self,
        causal_chain_space,
        causal_chain_idxs,
        chain_sample_size,
        interventions_executed,
        trial_count,
        attempt_count,
        sample_chains=False,
    ):
        """Choose the actions of one uniformly random candidate chain.

        Args:
            causal_chain_space: space whose ``structure_space`` holds the chains.
            causal_chain_idxs: candidate chain indices.
            chain_sample_size: forwarded to ``self.sample_chains``.
            interventions_executed: forwarded to ``self.sample_chains``.
            trial_count: forwarded to ``print_message``.
            attempt_count: forwarded to ``print_message``.
            sample_chains: if True, candidates come from ``self.sample_chains``
                rather than ``causal_chain_idxs`` directly.

        Returns:
            tuple: ``(intervention, 0)``, where ``intervention`` is
            ``structure_space.causal_chains.get_actions(idx)`` for the chosen
            index. The constant 0 is presumably a placeholder score -- confirm
            against callers.

        Raises:
            AssertionError: if the structure space reports a negative number of
                chains with belief above threshold.
            ValueError: if there are no candidate chains to choose from.
        """
        structure_space = causal_chain_space.structure_space

        # sanity check
        assert structure_space.num_chains_with_belief_above_threshold >= 0, (
            "Causal chain space has negative number ({}) of chains with belief above threshold {}"
            .format(
                structure_space.num_chains_with_belief_above_threshold,
                structure_space.belief_threshold,
            ))

        if sample_chains:
            selected_causal_chain_idxs = self.sample_chains(
                structure_space,
                causal_chain_idxs,
                chain_sample_size,
                interventions_executed,
            )
        else:
            selected_causal_chain_idxs = causal_chain_idxs
        print_message(
            trial_count,
            attempt_count,
            "Randomly picking intervention using {} chains from {} possible chains."
            .format(
                len(selected_causal_chain_idxs),
                structure_space.num_chains_with_belief_above_threshold,
            ),
            self.print_messages,
        )
        # len() rather than truthiness: the indices may be a numpy array.
        if len(selected_causal_chain_idxs) == 0:
            raise ValueError(
                "Cannot pick a random intervention: no candidate causal chains")
        rand_idx = np.random.randint(0, len(selected_causal_chain_idxs))
        intervention_idx = selected_causal_chain_idxs[rand_idx]
        intervention = structure_space.causal_chains.get_actions(intervention_idx)
        return intervention, 0
コード例 #4
0
    def initialize_local_q2(self, trial_name):
        """Create the local Q table for ``trial_name`` as a copy of the global one.

        The entry is first created with ``qlearner.initialize_local_Q``. When
        ``global_Q.max()`` is 0.0 that fresh entry is kept and the method
        returns without logging the elapsed time. Otherwise the entry is
        replaced by ``copy.deepcopy(global_Q)`` and the elapsed time is
        logged.
        """
        began_at = time.time()
        learner = self.qlearner
        print_message(
            self.trial_count,
            self.attempt_count,
            "Initializing new Q for {}".format(trial_name),
        )
        learner.initialize_local_Q(trial_name)

        if learner.global_Q.max() == 0.0:
            # todo: refactor; stub access so the default dict entry gets created
            _ = learner.local_Q[trial_name].shape
            return

        # Seed the local table with an independent copy of the global table.
        learner.local_Q[trial_name] = copy.deepcopy(learner.global_Q)

        elapsed = time.time() - began_at
        print_message(
            self.trial_count,
            self.attempt_count,
            "Initializing new Q for {} took {:0.6f}s".format(trial_name, elapsed),
        )
コード例 #5
0
    def prune_inconsistent_chains_common(
        self,
        causal_chain_space,
        causal_chain_idxs,
        causal_observations,
        trial_count,
        attempt_count,
    ):
        """Deprecated: prune chains that contradict the given causal observations.

        Calling this fails immediately with ``AssertionError("function
        deprecated")``. The rest of the body only executes when assertions are
        stripped (``python -O``).

        For each candidate chain whose bottom-up belief is still positive, the
        observations are checked in order with ``self.check_node_consistency``.
        The first inconsistent outcome sets the chain's belief to 0.0.

        Returns:
            tuple: ``(chains_idxs_removed, chains_idxs_consistent)``: indices of
            chains that were pruned, and of chains not found inconsistent.
            Chains whose belief was already <= 0.0 appear in neither list.

        Raises:
            AssertionError: always (deprecation guard).
            ValueError: with ``DEBUGGING`` on, if a true chain would be pruned.
        """
        assert False, "function deprecated"
        chains_idxs_removed = []
        chains_idxs_consistent = []
        print_update_rate = 1000000
        start_time = time.time()

        for i, causal_chain_idx in enumerate(causal_chain_idxs):
            if i % print_update_rate == 0 and i != 0:
                print_message(
                    trial_count, attempt_count,
                    "Checking for chains to prune. {}/{} chains checked. Runtime: {:0.6f}s"
                    .format(i, len(causal_chain_idxs),
                            time.time() - start_time), self.print_messages)

            chain = causal_chain_space.structure_space.causal_chains[
                causal_chain_idx]
            chain_belief = causal_chain_space.bottom_up_belief_space.beliefs[
                causal_chain_idx]

            # skip checking this chain if we have already pruned it (chains below threshold should still be considered)
            if chain_belief <= 0.0:
                continue

            # Where we are in the chain's transitions. Not all observations are
            # causal, so only causal changes advance it.
            chain_change_idx = 0
            chain_consistent = True
            for causal_observation in causal_observations:
                skipped_relation_flag, skip_chain_flag, outcome_consistent = self.check_node_consistency(
                    causal_chain_space=causal_chain_space,
                    causal_observation=causal_observation,
                    chain=chain,
                    chain_change_idx=chain_change_idx,
                )
                # received signal to skip the rest of this chain
                if skip_chain_flag:
                    break
                # received signal to skip relation
                if skipped_relation_flag:
                    continue
                if not outcome_consistent:
                    assert (chain_belief >
                            0.0), "Removing chain already marked as invalid"

                    # Write through .beliefs, the same array the belief was read from above.
                    causal_chain_space.bottom_up_belief_space.beliefs[
                        causal_chain_idx] = 0.0

                    chains_idxs_removed.append(causal_chain_idx)
                    chain_consistent = False
                    if DEBUGGING and (causal_chain_idx in causal_chain_space.
                                      structure_space.true_chain_idxs):
                        print_message(trial_count, attempt_count,
                                      "GRAPH REMOVED: ")
                        causal_chain_space.structure_space.pretty_print_causal_chain_idxs(
                            [causal_chain_idx])
                        raise ValueError(
                            "true chain removed from plausible chains")
                    # exit loop over observations, we are finished with this chain
                    break

                # advance a chain change index if there was a causal change (causal relation present).
                if causal_observation.determine_causal_change_occurred():
                    chain_change_idx += 1

            if chain_consistent:
                chains_idxs_consistent.append(causal_chain_idx)

        return chains_idxs_removed, chains_idxs_consistent
コード例 #6
0
    def initialize_local_q(self, trial_name):
        """Seed the local Q table for ``trial_name`` from the global Q table.

        If ``global_Q.max()`` is 0.0, the entry created by
        ``qlearner.initialize_local_Q`` is left untouched and the method
        returns early.

        Otherwise, for every pair of chains (g, g') a similarity weight is
        computed. It sums, over each index ``i``,
        ``indexed_confidences[i]["color"]`` where the two chains'
        ``attributes[i]`` match and ``indexed_confidences[i]["position"]``
        where their ``states[i]`` match. Then ``weight * global_Q[n][g']`` is
        added to ``local_Q[trial_name][n][g]`` for every row ``n`` of
        ``global_Q``. The result is not normalized.
        """
        start_time = time.time()
        print_message(
            self.trial_count,
            self.attempt_count,
            "Initializing new Q for {}".format(trial_name),
        )
        self.qlearner.initialize_local_Q(trial_name)
        if self.qlearner.global_Q.max() == 0.0:
            # todo: this is stub code to get default dict to be created
            _ = self.qlearner.local_Q[trial_name].shape
            return

        indexed_confidences, _total_confidences = self.compute_indexed_confidences(
            trial_name)

        global_Q = self.qlearner.global_Q
        local_Q = self.qlearner.local_Q[trial_name]
        causal_chains = self.causal_chain_space.causal_chains
        num_rows = global_Q.shape[0]

        for causal_chain in causal_chains:
            for causal_chain_prime in causal_chains:
                # Similarity between g (causal_chain) and g' (causal_chain_prime).
                # Reset for every pair so that each g' is weighted only by its
                # own similarity to g, independent of iteration order.
                # todo: refactor to use attributes only rather than state + attributes
                chain_confidence = 0
                for i in range(len(causal_chain_prime.attributes)):
                    confidences = indexed_confidences[i]
                    # color confidence contributes if chain attributes match
                    if causal_chain.attributes[i] == causal_chain_prime.attributes[i]:
                        chain_confidence += confidences["color"]
                    # position confidence contributes if chain states (positions) match
                    if causal_chain.states[i] == causal_chain_prime.states[i]:
                        chain_confidence += confidences["position"]

                # The weight is constant across solutions remaining. Add that
                # proportion of Q(g') to Q(g) in every row.
                for num_solutions_remaining in range(num_rows):
                    local_Q[num_solutions_remaining][causal_chain.chain_id] += (
                        chain_confidence *
                        global_Q[num_solutions_remaining][causal_chain_prime.chain_id])

        print_message(
            self.trial_count,
            self.attempt_count,
            "Initializing new Q for {} took {:0.6f}s".format(
                trial_name,
                time.time() - start_time),
        )