Code example #1
File: schema_learner.py  Project: sidorovTV/schema-rl
    def __init__(self):
        self._params = [[ParamMatrix() for _ in range(C.N_PREDICTABLE_ATTRIBUTES)]
                        for _ in range(2)]
        self._R = ParamMatrix()

        self._buff = []
        self._replay = self.Batch(np.empty((0, C.SCHEMA_VEC_SIZE), dtype=bool),
                                  np.empty((0, C.N_PREDICTABLE_ATTRIBUTES), dtype=bool),
                                  np.empty((0, C.N_PREDICTABLE_ATTRIBUTES), dtype=bool),
                                  np.empty(0, dtype=bool))

        self._attr_mip_models = [[MipModel() for _ in range(C.N_PREDICTABLE_ATTRIBUTES)]
                                 for _ in range(2)]
        self._reward_mip_model = MipModel()
        self._solved = []

        self._curr_iter = None
        self._visualizer = Visualizer(None, None, None)
Code example #2
File: schema_learner.py  Project: pkuderov/schema-rl
    def __init__(self):
        self._W = [np.ones((C.SCHEMA_VEC_SIZE, C.L), dtype=bool)
                   for _ in range(C.N_PREDICTABLE_ATTRIBUTES)]
        self._n_attr_schemas = np.ones(C.N_PREDICTABLE_ATTRIBUTES, dtype=int)

        self._R = np.ones((C.SCHEMA_VEC_SIZE, C.L), dtype=bool)
        self._n_reward_schemas = 1

        self._buff = []
        self._replay = self.Batch(np.empty((0, C.SCHEMA_VEC_SIZE), dtype=bool),
                                  np.empty((0, C.N_PREDICTABLE_ATTRIBUTES), dtype=bool),
                                  np.empty(0, dtype=bool))

        self._attr_mip_models = [MipModel() for _ in range(C.N_PREDICTABLE_ATTRIBUTES)]
        self._reward_mip_model = MipModel()
        self._solved = []

        self._curr_iter = None
        self._visualizer = Visualizer(None, None, None)
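
This variant allocates W and R at their full width C.L and uses all-ones columns to mark unused schema slots; get_weights() in code example #8 strips those columns with ~np.all(W, axis=0). A minimal standalone sketch of that convention, with illustrative sizes standing in for C.SCHEMA_VEC_SIZE and C.L:

import numpy as np

VEC_SIZE, MAX_SCHEMAS = 6, 4          # stand-ins for C.SCHEMA_VEC_SIZE and C.L

# all-ones columns mark free schema slots
W = np.ones((VEC_SIZE, MAX_SCHEMAS), dtype=bool)

# pretend two schema vectors were learned into the first two slots
W[:, 0] = np.array([1, 0, 1, 0, 0, 0], dtype=bool)
W[:, 1] = np.array([0, 1, 0, 0, 1, 0], dtype=bool)

# keep only the learned columns, i.e. those that are not all ones
learned_W = W[:, ~np.all(W, axis=0)]
print(learned_W.shape)                # (6, 2)
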
Code example #3
    def learn(self):
        env = StandardBreakout(**self.env_params)
        shaper = Shaper()
        data_loader = DataLoader(shaper)
        learner = GreedySchemaLearner()
        visualizer = Visualizer(None, None, None)

        env.reset()
        reward = 0

        for episode_idx in range(self.n_episodes):
            observations = deque(maxlen=self.LEARNING_BATCH_SIZE)
            actions = deque(maxlen=self.LEARNING_BATCH_SIZE)

            for step_idx in range(self.n_steps):
                curr_iter = episode_idx * self.n_steps + step_idx
                print('\ncurr_iter: {}'.format(curr_iter))

                obs = EntityExtractor.extract(env)
                observations.append(obs)

                # hardcoded action choice
                chosen_action = self._get_hardcoded_action(env)
                actions.append(chosen_action)

                # visualizing
                if self.VISUALIZE_STATE:
                    visualizer.set_iter(curr_iter)
                    visualizer.visualize_env_state(obs)

                # learning
                if len(observations) >= self.LEARNING_BATCH_SIZE:
                    batch = data_loader.make_batch(observations, actions,
                                                   reward)

                    learner.set_curr_iter(curr_iter)
                    learner.take_batch(batch)

                    is_flush_needed = curr_iter == self.n_episodes * self.n_steps - 1
                    if curr_iter % self.LEARNING_PERIOD == 0 or is_flush_needed:
                        learner.learn()

                obs, reward, done, _ = env.step(chosen_action)
                if done:
                    print('END_OF_EPISODE, step_idx == {}'.format(step_idx))
                    break
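
The batching here is a sliding window: observations and actions go into deques with maxlen=self.LEARNING_BATCH_SIZE, and a batch is only formed once the window is full. A self-contained sketch of the same pattern, with an illustrative batch size and integer stand-ins for observations:

from collections import deque

BATCH_SIZE = 3                        # stand-in for LEARNING_BATCH_SIZE
observations = deque(maxlen=BATCH_SIZE)

for step in range(6):
    observations.append(step)         # the step index stands in for an observation
    if len(observations) >= BATCH_SIZE:
        # the deque always holds the BATCH_SIZE most recent observations
        print(step, list(observations))
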
Code example #4
    def loop(self):
        handcrafted_W, handcrafted_R, _ = HardcodedSchemaVectors.gen_schema_matrices()

        env = self.env_class(**self.env_params)
        shaper = Shaper()
        visualizer = Visualizer(None, None, None)
        data_loader = DataLoader(shaper)
        learner = GreedySchemaLearner()
        planner = SchemaNetwork()

        for episode_idx in range(self.n_episodes):
            env.reset()
            reward = 0

            observations = deque(maxlen=C.LEARNING_BATCH_SIZE)
            frame_stack = deque(maxlen=C.FRAME_STACK_SIZE)
            actions_taken = deque(maxlen=C.LEARNING_BATCH_SIZE)
            exec_actions = deque()

            planning_timer = 0
            emergency_replanning_timer = None

            episode_reward = 0
            step_idx = 0

            for step_idx in range(self.n_steps):
                curr_iter = episode_idx * self.n_steps + step_idx
                print('\ncurr_iter: {}'.format(curr_iter))

                obs = EntityExtractor.extract(env)
                observations.append(obs)
                frame_stack.append(obs)

                if C.VISUALIZE_STATE:
                    # visualize env state
                    visualizer.set_iter(curr_iter)
                    visualizer.visualize_env_state(obs)

                # --- planning ---
                
                learned_W, learned_R = learner.get_weights()

                W = handcrafted_W if C.USE_HANDCRAFTED_ATTRIBUTE_SCHEMAS else learned_W
                R = handcrafted_R if C.USE_HANDCRAFTED_REWARD_SCHEMAS else learned_R

                are_weights_ok = W is not None and R is not None

                can_run_planner = are_weights_ok and len(frame_stack) == C.FRAME_STACK_SIZE

                if C.USE_EMERGENCY_REPLANNING:
                    is_planning_needed = len(exec_actions) == 0 and emergency_replanning_timer is None \
                                        or emergency_replanning_timer == 0
                else:
                    is_planning_needed = (planning_timer == 0)

                if is_planning_needed and can_run_planner:
                    print('Launching planning procedure...')

                    # handle timers
                    emergency_replanning_timer = None
                    planning_timer = C.PLANNING_PERIOD

                    planner.set_weights(W, R)
                    planner.set_curr_iter(curr_iter)

                    planned_actions = planner.plan_actions(frame_stack)
                    if planned_actions is not None:
                        exec_actions.clear()
                        exec_actions.extend(planned_actions)

                if exec_actions:
                    chosen_action = exec_actions.popleft()
                else:
                    chosen_action = np.random.choice(C.ACTION_SPACE_DIM)

                    if can_run_planner:
                        if emergency_replanning_timer is None:
                            emergency_replanning_timer = C.EMERGENCY_REPLANNING_PERIOD
                        emergency_replanning_timer -= 1

                if planning_timer > 0:
                    planning_timer -= 1
                # ---------------------

                actions_taken.append(chosen_action)

                # --- learning ---
                if len(observations) >= C.LEARNING_BATCH_SIZE:
                    print('adding batch to learner')
                    batch = data_loader.make_batch(observations, actions_taken, reward)

                    learner.set_curr_iter(curr_iter)
                    learner.take_batch(batch)

                    is_flush_needed = curr_iter == self.n_episodes * self.n_steps - 1
                    if curr_iter % C.LEARNING_PERIOD == 0 or is_flush_needed:
                        print('Launching learning procedure...')
                        learner.learn()

                obs, reward, done, _ = env.step(chosen_action)
                episode_reward += reward
                if done:
                    break

            self._end_of_episode_handler(episode_idx, step_idx, episode_reward)
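
When C.USE_EMERGENCY_REPLANNING is set, the planning trigger above fires either when the action queue is empty and no emergency timer is running, or when the emergency timer reaches zero, so a persistently empty queue forces a re-plan every C.EMERGENCY_REPLANNING_PERIOD steps. A stripped-down sketch of just that trigger logic, with an illustrative period and no real planner attached:

from collections import deque

EMERGENCY_REPLANNING_PERIOD = 2       # stand-in for C.EMERGENCY_REPLANNING_PERIOD

exec_actions = deque()
emergency_replanning_timer = None

for step in range(6):
    is_planning_needed = (len(exec_actions) == 0 and emergency_replanning_timer is None
                          or emergency_replanning_timer == 0)
    if is_planning_needed:
        emergency_replanning_timer = None
        # a real planner would refill exec_actions here; it stays empty in this
        # sketch so the emergency path keeps firing
    if exec_actions:
        action = exec_actions.popleft()
    else:
        action = 0                    # fallback action while no plan is available
        if emergency_replanning_timer is None:
            emergency_replanning_timer = EMERGENCY_REPLANNING_PERIOD
        emergency_replanning_timer -= 1
    print(step, is_planning_needed, emergency_replanning_timer)
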
Code example #5
    def loop(self):
        W = self.load_schema_matrices()
        _, R, _ = HardcodedSchemaVectors.gen_schema_matrices()

        env = self.env_class(**self.env_params)
        planner = SchemaNetwork()
        visualizer = Visualizer(None, None, None)

        for episode_idx in range(self.n_episodes):
            env.reset()
            reward = 0
            episode_reward = 0

            frame_stack = deque(maxlen=self.FRAME_STACK_SIZE)
            actions = deque()

            planning_timer = 0
            emergency_replanning_timer = None

            for step_idx in range(self.n_steps):
                curr_iter = episode_idx * self.n_steps + step_idx
                print('\ncurr_iter: {}'.format(curr_iter))

                obs = EntityExtractor.extract(env)
                frame_stack.append(obs)

                if self.VISUALIZE_STATE:
                    # visualize env state
                    visualizer.set_iter(curr_iter)
                    visualizer.visualize_env_state(obs)

                can_run_planner = len(frame_stack) == self.FRAME_STACK_SIZE
                #is_planning_needed = len(actions) == 0 and emergency_replanning_timer is None \
                #                     or emergency_replanning_timer == 0
                is_planning_needed = (planning_timer == 0)

                if is_planning_needed and can_run_planner:
                    emergency_replanning_timer = None

                    planner.set_weights(W, R)
                    planner.set_curr_iter(curr_iter)

                    planned_actions = planner.plan_actions(frame_stack)
                    if planned_actions is not None:
                        actions.clear()
                        actions.extend(planned_actions)

                    planning_timer = self.PLANNING_PERIOD

                if planning_timer > 0:
                    planning_timer -= 1

                if actions:
                    action = actions.popleft()
                else:
                    action = 0

                    if can_run_planner:
                        if emergency_replanning_timer is None:
                            emergency_replanning_timer = self.EMERGENCY_REPLANNING_PERIOD
                        emergency_replanning_timer -= 1

                obs, reward, done, _ = env.step(action)
                episode_reward += reward
                if done:
                    print('END_OF_EPISODE, step_idx == {}'.format(step_idx))
                    break

            self._end_of_episode_handler(episode_idx, step_idx, episode_reward)
Code example #6
File: run_agent.py  Project: sidorovTV/schema-rl
    def loop(self, logger):
        env = self._env_class(**self._env_params)
        shaper = Shaper()
        visualizer = Visualizer(None, None, None)

        learner = GreedySchemaLearner()
        if C.DO_PRELOAD_DUMP_PARAMS:
            W_pos, W_neg, R = self._load_dumped_params()
            learner.set_params(W_pos, W_neg, R)
        planner = SchemaNetwork()

        curr_iter = 0

        for episode_idx in range(self._n_max_episodes):
            env.reset()
            reward = 0
            episode_reward = 0
            step_idx = 0

            learning_handler = LearningHandler(learner, shaper,
                                               self._n_max_steps)
            planning_handler = PlanningHandler(planner, env)

            for step_idx in range(self._n_max_steps):
                if curr_iter % self._print_freq == 0:
                    print('\ncurr_iter: {}'.format(curr_iter))

                obs = EntityExtractor.extract(env)
                if C.VISUALIZE_STATE:
                    visualizer.set_iter(curr_iter)
                    visualizer.visualize_env_state(obs)

                # --- planning ---
                W_pos, W_neg, R = learner.get_params()

                if C.DO_PRELOAD_HANDCRAFTED_ATTRIBUTE_PARAMS:
                    W_pos = self.hc_W_pos
                    W_neg = self.hc_W_neg
                if C.DO_PRELOAD_HANDCRAFTED_REWARD_PARAMS:
                    R = self.hc_R

                for params in (W_pos, W_neg, R):
                    for idx, matrix in enumerate(params):
                        if not matrix.size:
                            params[idx] = np.ones((C.SCHEMA_VEC_SIZE, 1),
                                                  dtype=bool)

                chosen_action = planning_handler.plan(obs, W_pos, W_neg, R,
                                                      curr_iter, reward)

                # --- learning ---
                learning_handler.learn(obs, chosen_action, reward, step_idx,
                                       curr_iter)

                obs, reward, done, _ = env.step(chosen_action)
                curr_iter += 1
                episode_reward += reward
                if done:
                    learning_handler.flush()
                    break

            record = {
                'episode_idx': episode_idx,
                'n_steps_taken': step_idx,
                'end_iter': curr_iter,
                'episode_reward': episode_reward
            }
            self._end_of_episode_handler(logger, record)
            if curr_iter >= self._n_max_iters:
                break
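
Before planning, this loop replaces any still-empty parameter matrix with a single all-ones column so the planner always receives well-formed (C.SCHEMA_VEC_SIZE, 1) matrices; under the ~(~x @ W) prediction an all-ones schema only fires on an all-ones input, so the placeholder is effectively inert. A standalone sketch of that guard, with an illustrative vector size:

import numpy as np

SCHEMA_VEC_SIZE = 6                   # stand-in for C.SCHEMA_VEC_SIZE

# one matrix with learned columns, one still empty (nothing learned yet)
params = [np.zeros((SCHEMA_VEC_SIZE, 2), dtype=bool),
          np.empty((SCHEMA_VEC_SIZE, 0), dtype=bool)]

for idx, matrix in enumerate(params):
    if not matrix.size:
        # substitute a single all-ones placeholder column
        params[idx] = np.ones((SCHEMA_VEC_SIZE, 1), dtype=bool)

print([p.shape for p in params])      # [(6, 2), (6, 1)]
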
Code example #7
File: schema_learner.py  Project: sidorovTV/schema-rl
class GreedySchemaLearner:
    Batch = namedtuple('Batch', ['x', 'y_creation', 'y_destruction', 'r'])
    CREATION_T, DESTRUCTION_T, REWARD_T = range(3)
    ATTR_SCHEMA_TYPES = (CREATION_T, DESTRUCTION_T)

    @Visualizer.measure_time('Learner')
    def __init__(self):
        self._params = [[ParamMatrix() for _ in range(C.N_PREDICTABLE_ATTRIBUTES)]
                        for _ in range(2)]
        self._R = ParamMatrix()

        self._buff = []
        self._replay = self.Batch(np.empty((0, C.SCHEMA_VEC_SIZE), dtype=bool),
                                  np.empty((0, C.N_PREDICTABLE_ATTRIBUTES), dtype=bool),
                                  np.empty((0, C.N_PREDICTABLE_ATTRIBUTES), dtype=bool),
                                  np.empty(0, dtype=bool))

        self._attr_mip_models = [[MipModel() for _ in range(C.N_PREDICTABLE_ATTRIBUTES)]
                                 for _ in range(2)]
        self._reward_mip_model = MipModel()
        self._solved = []

        self._curr_iter = None
        self._visualizer = Visualizer(None, None, None)

    def set_curr_iter(self, curr_iter):
        self._curr_iter = curr_iter
        self._visualizer.set_iter(curr_iter)

    def set_params(self, W_pos, W_neg, R):
        for schema_type, params in zip(self.ATTR_SCHEMA_TYPES, (W_pos, W_neg)):
            for attr_idx in range(C.N_PREDICTABLE_ATTRIBUTES):
                self._params[schema_type][attr_idx].set_matrix(params[attr_idx])

        self._R.set_matrix(R[0])

    def _handle_duplicates(self, batch, return_index=False):
        augmented_entities, *rest = batch
        samples, idx = np.unique(augmented_entities, axis=0, return_index=True)
        out = self.Batch(samples, *[vec[idx] for vec in rest])
        if return_index:
            return out, idx
        return out

    def take_batch(self, batch):
        for part in batch:
            assert part.dtype == bool

        if batch.x.size:
            assert np.all(batch.r == batch.r[0])
            filtered_batch = self._handle_duplicates(batch)
            self._buff.append(filtered_batch)

    def _get_buff_batch(self):
        out = None
        if self._buff:
            # sort buff to keep r = 0 entries
            self._buff = sorted(self._buff, key=lambda b: b.r[0])

            batch = self.Batch(*[np.concatenate(minibatches_part, axis=0)
                                 for minibatches_part in zip(*self._buff)])

            out = self._handle_duplicates(batch)
            assert isinstance(out, self.Batch)
            self._buff.clear()
        return out

    def _add_to_replay_and_constraints_buff(self, batch):
        batch_size = len(batch.x)
        old_replay_size = len(self._replay.x)

        # concatenate replay + batch
        concat_batch = self.Batch(*[np.concatenate((a, b), axis=0)
                                    for a, b in zip(self._replay, batch)])
        # remove duplicates
        self._replay, unique_idx = self._handle_duplicates(concat_batch, return_index=True)

        concat_size = len(concat_batch.x)

        # find r = 0 duplicates (they can only come from the new batch)
        duplicates_mask_concat = np.ones(concat_size, dtype=bool)
        duplicates_mask_concat[unique_idx] = False
        zero_reward_mask_concat = (concat_batch.r == 0)
        reward_renew_indices = np.nonzero(duplicates_mask_concat & zero_reward_mask_concat)[0]
        assert (reward_renew_indices >= old_replay_size).all()
        samples_to_update = concat_batch.x[reward_renew_indices]

        # update rewards to zero
        replay_indices_to_update = []
        for sample in samples_to_update:
            new_replay_indices = np.nonzero((self._replay.x == sample).all(axis=1))[0]
            assert len(new_replay_indices) == 1
            new_replay_idx = new_replay_indices[0]
            if self._replay.r[new_replay_idx] != 0:
                self._replay.r[new_replay_idx] = 0
                replay_indices_to_update.append(new_replay_idx)
        n_updated_indices = len(replay_indices_to_update)
        if n_updated_indices:
            print('Nullified rewards of {} old samples.'.format(n_updated_indices))

        # find non-duplicate indices in new batch (batch-based indexing)
        batch_mask_of_concat = unique_idx >= old_replay_size
        new_non_duplicate_indices = unique_idx[batch_mask_of_concat] - old_replay_size

        # find indices that will index constraints_buff + new_batch_unique synchronously with replay
        constraints_unique_idx = unique_idx.copy()
        constraints_unique_idx[batch_mask_of_concat] = old_replay_size + np.arange(len(new_non_duplicate_indices))

        for schema_type in self.ATTR_SCHEMA_TYPES:
            for attr_idx in range(C.N_PREDICTABLE_ATTRIBUTES):
                y = batch.y_creation if schema_type == self.CREATION_T else batch.y_destruction
                attr_batch = (batch.x[new_non_duplicate_indices],
                              y[new_non_duplicate_indices, attr_idx])
                self._attr_mip_models[schema_type][attr_idx].add_to_constraints_buff(attr_batch, constraints_unique_idx)

        reward_batch = (batch.x[new_non_duplicate_indices],
                        batch.r[new_non_duplicate_indices])
        self._reward_mip_model.add_to_constraints_buff(reward_batch, constraints_unique_idx,
                                                       replay_renewed_indices=replay_indices_to_update)

    def _get_replay_batch(self):
        if self._replay.x.size:
            out = self._replay
        else:
            out = None
        return out

    def _predict_attribute_delta(self, augmented_entities, attr_idx, attr_schema_type):
        assert augmented_entities.dtype == bool
        delta = ~(~augmented_entities @ self._params[attr_schema_type][attr_idx].mult_me)
        return delta

    def _predict_reward(self, augmented_entities):
        assert augmented_entities.dtype == bool
        reward_prediction = ~(~augmented_entities @ self._R.mult_me)
        return reward_prediction

    def _delete_incorrect_schemas(self, batch):
        augmented_entities, target_creation, target_destruction, rewards = batch
        for param_type in self.ATTR_SCHEMA_TYPES:
            for attr_idx in range(C.N_PREDICTABLE_ATTRIBUTES):
                target = target_creation if param_type == self.CREATION_T else target_destruction

                attr_delta = self._predict_attribute_delta(augmented_entities, attr_idx,
                                                           attr_schema_type=param_type)
                # false positive predictions
                mispredicted_samples_mask = attr_delta.any(axis=1) & ~target[:, attr_idx]

                incorrect_schemas_mask = attr_delta[mispredicted_samples_mask, :].any(axis=0)
                incorrect_schemas_indices = np.nonzero(incorrect_schemas_mask)[0]

                assert incorrect_schemas_indices.ndim == 1
                n_incorrect_attr_schemas = incorrect_schemas_indices.size
                if n_incorrect_attr_schemas:
                    self._params[param_type][attr_idx].purge_vectors(incorrect_schemas_indices)
                    print('Deleted incorrect attr ({}) delta schemas: {} of {}'.format(
                        param_type, n_incorrect_attr_schemas, C.ENTITY_NAMES[attr_idx]))

        # regarding reward

        reward_prediction = self._predict_reward(augmented_entities)

        # false positive predictions
        mispredicted_samples_mask = reward_prediction.any(axis=1) & ~rewards

        incorrect_schemas_mask = reward_prediction[mispredicted_samples_mask, :].any(axis=0)
        incorrect_schemas_indices = np.nonzero(incorrect_schemas_mask)[0]

        assert incorrect_schemas_indices.ndim == 1
        n_incorrect_reward_schemas = incorrect_schemas_indices.size
        if n_incorrect_reward_schemas:
            self._R.purge_vectors(incorrect_schemas_indices)
            print('Deleted incorrect reward schemas: {}'.format(n_incorrect_reward_schemas))

        return n_incorrect_attr_schemas, n_incorrect_reward_schemas

    def _find_cluster(self, zp_pl_mask, zp_nl_mask, augmented_entities, target, attr_idx, opt_model):
        """
        augmented_entities: zero-predicted only
        target: scalar vector
        """
        assert augmented_entities.dtype == int
        assert target.dtype == int

        # find all entries that can potentially be solved (have True labels)
        candidates = augmented_entities[zp_pl_mask]

        if not candidates.size:
            return None

        print('finding cluster...    zp pos samples: {}'.format(candidates.shape[0]))
        # print('augmented_entities: {}'.format(augmented_entities.shape[0]))

        zp_pl_indices = np.nonzero(zp_pl_mask)[0]

        # sample one entry and add its idx to 'solved'
        idx = np.random.choice(zp_pl_indices)
        self._solved.append(idx)

        # resample candidates
        zp_pl_mask[idx] = False
        zp_pl_indices = np.nonzero(zp_pl_mask)[0]
        candidates = augmented_entities[zp_pl_mask]

        # solve LP
        objective_coefficients = (1 - candidates).sum(axis=0)
        objective_coefficients = list(objective_coefficients)

        new_schema_vector = opt_model.optimize(objective_coefficients, zp_nl_mask, self._solved)

        if new_schema_vector is None:
            print('Cannot find cluster!')
            return None

        # add all samples that are solved by just learned schema vector
        if candidates.size:
            new_predicted_attribute = (1 - candidates) @ new_schema_vector
            cluster_members_mask = np.isclose(new_predicted_attribute, 0, rtol=0, atol=C.ADDING_SCHEMA_TOLERANCE)
            n_new_members = np.count_nonzero(cluster_members_mask)

            if n_new_members:
                print('Also added to solved: {}'.format(n_new_members))
                self._solved.extend(zp_pl_indices[cluster_members_mask])

        return new_schema_vector

    def _simplify_schema(self, zp_nl_mask, schema_vector, opt_model):
        objective_coefficients = [1] * len(schema_vector)

        new_schema_vector = opt_model.optimize(objective_coefficients, zp_nl_mask, self._solved)
        assert new_schema_vector is not None
        return new_schema_vector

    def _binarize_schema(self, schema_vector):
        threshold = 0.5
        return schema_vector > threshold

    def _generate_new_schema(self, augmented_entities, targets, attr_idx, schema_type):
        if schema_type in self.ATTR_SCHEMA_TYPES:
            target = targets[:, attr_idx].astype(int, copy=False)
            prediction = self._predict_attribute_delta(augmented_entities, attr_idx, schema_type)
            opt_model = self._attr_mip_models[schema_type][attr_idx]
        elif schema_type == self.REWARD_T:
            target = targets.astype(int, copy=False)
            prediction = self._predict_reward(augmented_entities)
            opt_model = self._reward_mip_model
        else:
            assert False

        augmented_entities = augmented_entities.astype(int, copy=False)

        # sample only entries with zero-prediction
        zp_mask = ~prediction.any(axis=1)
        pl_mask = (target == 1)
        # pos and neg labels' masks
        zp_pl_mask = zp_mask & pl_mask
        zp_nl_mask = zp_mask & ~pl_mask

        new_schema_vector = self._find_cluster(zp_pl_mask, zp_nl_mask,
                                               augmented_entities, target, attr_idx,
                                               opt_model)
        if new_schema_vector is None:
            return None

        new_schema_vector = self._simplify_schema(zp_nl_mask, new_schema_vector, opt_model)
        new_schema_vector = self._binarize_schema(new_schema_vector)

        self._solved.clear()

        return new_schema_vector

    def get_params(self):
        W_pos = [W.get_matrix() for W in self._params[self.CREATION_T]]
        W_neg = [W.get_matrix() for W in self._params[self.DESTRUCTION_T]]
        R = [self._R.get_matrix()]
        return W_pos, W_neg, R

    def _dump_params(self):
        dir_name = 'dump'
        os.makedirs(dir_name, exist_ok=True)

        W_pos, W_neg, R = self.get_params()
        names = ['w_pos', 'w_neg', 'r']

        for params, name in zip((W_pos, W_neg, R), names):
            for idx, matrix in enumerate(params):
                file_name = name + '_{}'.format(idx)
                path = os.path.join(dir_name, file_name)
                np.save(path, matrix, allow_pickle=False)

    @Visualizer.measure_time('learn()')
    def learn(self):
        print('Launching learning procedure...')

        # get full batch from buffer
        buff_batch = self._get_buff_batch()
        if buff_batch is not None:
            self._add_to_replay_and_constraints_buff(buff_batch)
            self._delete_incorrect_schemas(buff_batch)

        # get all data to learn on
        replay_batch = self._get_replay_batch()
        if replay_batch is None:
            return

        # check if replay consistent
        # a, b = self._delete_incorrect_schemas(replay_batch)
        # assert a == 0 and b == 0

        augmented_entities, targets_construction, targets_destruction, rewards = replay_batch

        if C.DO_LEARN_ATTRIBUTE_PARAMS:
            for schema_type in self.ATTR_SCHEMA_TYPES:
                params = self._params[schema_type]
                targets = targets_construction if schema_type == self.CREATION_T else targets_destruction

                for attr_idx in range(C.N_PREDICTABLE_ATTRIBUTES):
                    while params[attr_idx].has_free_space():
                        new_schema_vec = self._generate_new_schema(augmented_entities, targets, attr_idx, schema_type)
                        if new_schema_vec is None:
                            break
                        params[attr_idx].add_vector(new_schema_vec)

        if C.DO_LEARN_REWARD_PARAMS:
            while self._R.has_free_space():
                new_schema_vec = self._generate_new_schema(augmented_entities, rewards, None, self.REWARD_T)
                if new_schema_vec is None:
                    break
                self._R.add_vector(new_schema_vec)

        self._dump_params()
        if C.VISUALIZE_SCHEMAS:
            W_pos, W_neg, R = self.get_params()
            self._visualizer.visualize_schemas(W_pos, W_neg, R)
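
The _predict_* helpers above rely on NumPy's boolean matrix product being an OR of ANDs: (~x @ W)[i, k] is True exactly when schema k requires some feature that sample i lacks, so ~(~x @ W) marks the schemas whose preconditions are all met. A tiny standalone check with illustrative shapes and values:

import numpy as np

# x: 2 samples over 3 features; W: 2 schema vectors (columns of required features)
x = np.array([[1, 1, 0],
              [1, 1, 1]], dtype=bool)
W = np.array([[1, 0],
              [1, 0],
              [0, 1]], dtype=bool)

prediction = ~(~x @ W)
print(prediction)
# [[ True False]   sample 0 has features 0 and 1, so it satisfies schema 0 only
#  [ True  True]]  sample 1 has every feature, so it satisfies both schemas
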
Code example #8
File: schema_learner.py  Project: pkuderov/schema-rl
class GreedySchemaLearner:
    Batch = namedtuple('Batch', ['x', 'y', 'r'])

    def __init__(self):
        self._W = [np.ones((C.SCHEMA_VEC_SIZE, C.L), dtype=bool)
                   for _ in range(C.N_PREDICTABLE_ATTRIBUTES)]
        self._n_attr_schemas = np.ones(C.N_PREDICTABLE_ATTRIBUTES, dtype=int)

        self._R = np.ones((C.SCHEMA_VEC_SIZE, C.L), dtype=bool)
        self._n_reward_schemas = 1

        self._buff = []
        self._replay = self.Batch(np.empty((0, C.SCHEMA_VEC_SIZE), dtype=bool),
                                  np.empty((0, C.N_PREDICTABLE_ATTRIBUTES), dtype=bool),
                                  np.empty(0, dtype=bool))

        self._attr_mip_models = [MipModel() for _ in range(C.N_PREDICTABLE_ATTRIBUTES)]
        self._reward_mip_model = MipModel()
        self._solved = []

        self._curr_iter = None
        self._visualizer = Visualizer(None, None, None)

    def set_curr_iter(self, curr_iter):
        self._curr_iter = curr_iter
        self._visualizer.set_iter(curr_iter)

    def take_batch(self, batch):
        for part in batch:
            assert part.dtype == bool

        if batch.x.size:
            assert np.all(batch.r == batch.r[0])
            x, y, r = self._handle_duplicates(batch.x, batch.y, batch.r)
            filtered_batch = self.Batch(x, y, r)
            self._buff.append(filtered_batch)

    def _handle_duplicates(self, augmented_entities, target, rewards, return_index=False):
        samples, idx = np.unique(augmented_entities, axis=0, return_index=True)
        out = [samples, target[idx, :], rewards[idx]]
        if return_index:
            out.append(idx)
        return tuple(out)

    def _get_buff_batch(self):
        out = None
        if self._buff:
            # sort buff to keep r = 0 entries
            self._buff = sorted(self._buff, key=lambda batch: batch.r[0])

            x, y, r = zip(*self._buff)
            augmented_entities = np.concatenate(x, axis=0)
            targets = np.concatenate(y, axis=0)
            rewards = np.concatenate(r, axis=0)

            augmented_entities, targets, rewards = self._handle_duplicates(augmented_entities, targets, rewards)
            out = self.Batch(augmented_entities, targets, rewards)

            self._buff.clear()
        return out

    def _add_to_replay_and_constraints_buff(self, batch):
        batch_size = len(batch.x)
        old_replay_size = len(self._replay.x)

        # concatenate replay + batch
        x_concat = np.concatenate((self._replay.x, batch.x), axis=0)
        y_concat = np.concatenate((self._replay.y, batch.y), axis=0)
        r_concat = np.concatenate((self._replay.r, batch.r), axis=0)

        concat_size = len(x_concat)

        # remove duplicates
        x_filtered, y_filtered, r_filtered, unique_idx = self._handle_duplicates(
            x_concat, y_concat, r_concat, return_index=True)

        # check that np.unique keeps the first occurrence of each duplicate row
        mask_concat = np.zeros(concat_size, dtype=bool)
        mask_concat[unique_idx] = True
        assert mask_concat[:old_replay_size].all()

        self._replay = self.Batch(x_filtered, y_filtered, r_filtered)

        # find r = 0 duplicates (they can only come from the new batch)
        duplicates_mask_concat = np.ones(concat_size, dtype=bool)
        duplicates_mask_concat[unique_idx] = False
        zero_reward_mask_concat = (r_concat == 0)
        reward_renew_indices = np.nonzero(duplicates_mask_concat & zero_reward_mask_concat)[0]
        assert (reward_renew_indices >= old_replay_size).all()
        samples_to_update = x_concat[reward_renew_indices]

        # update rewards to zero
        replay_indices_to_update = []
        for sample in samples_to_update:
            new_replay_indices = np.nonzero((self._replay.x == sample).all(axis=1))[0]
            assert len(new_replay_indices) == 1
            new_replay_idx = new_replay_indices[0]
            if self._replay.r[new_replay_idx] != 0:
                self._replay.r[new_replay_idx] = 0
                replay_indices_to_update.append(new_replay_idx)
        print('Nullified rewards of {} old samples.'.format(len(replay_indices_to_update)))

        # find non-duplicate indices in new batch (batch-based indexing)
        batch_mask_of_concat = unique_idx >= old_replay_size
        new_non_duplicate_indices = unique_idx[batch_mask_of_concat] - old_replay_size

        # find indices that will index constraints_buff + new_batch_unique synchronously with replay
        constraints_unique_idx = unique_idx.copy()
        constraints_unique_idx[batch_mask_of_concat] = old_replay_size + np.arange(len(new_non_duplicate_indices))

        for attr_idx in range(C.N_PREDICTABLE_ATTRIBUTES):
            attr_batch = (batch.x[new_non_duplicate_indices],
                          batch.y[new_non_duplicate_indices, attr_idx])
            self._attr_mip_models[attr_idx].add_to_constraints_buff(attr_batch, constraints_unique_idx)

        reward_batch = (batch.x[new_non_duplicate_indices],
                        batch.r[new_non_duplicate_indices])
        self._reward_mip_model.add_to_constraints_buff(reward_batch, constraints_unique_idx,
                                                       replay_renewed_indices=replay_indices_to_update)

    def _get_replay_batch(self):
        if self._replay.x.size:
            out = self._replay
        else:
            out = None
        return out

    def _predict_attribute(self, augmented_entities, attr_idx):
        assert augmented_entities.dtype == bool

        n_schemas = self._n_attr_schemas[attr_idx]
        W = self._W[attr_idx][:, :n_schemas]
        attribute_prediction = ~(~augmented_entities @ W)
        return attribute_prediction

    def _predict_reward(self, augmented_entities):
        assert augmented_entities.dtype == bool

        R = self._R[:, :self._n_reward_schemas]
        reward_prediction = ~(~augmented_entities @ R)
        return reward_prediction

    def _add_attr_schema_vec(self, attr_idx, schema_vec):
        vec_idx = self._n_attr_schemas[attr_idx]
        if vec_idx < C.L:
            self._W[attr_idx][:, vec_idx] = schema_vec
            self._n_attr_schemas[attr_idx] += 1

    def _add_reward_schema_vec(self, schema_vec):
        vec_idx = self._n_reward_schemas
        if vec_idx < C.L:
            self._R[:, vec_idx] = schema_vec
            self._n_reward_schemas += 1

    def _purge_matrix_columns(self, matrix, col_indices):
        n_cols_purged = len(col_indices)

        if n_cols_purged:
            col_size, _ = matrix.shape
            matrix = np.delete(matrix, col_indices, axis=1)
            padding = np.ones((col_size, n_cols_purged), dtype=bool)
            matrix = np.hstack((matrix, padding))

        return matrix, n_cols_purged

    def _delete_attr_schema_vectors(self, attr_idx, vec_indices):
        matrix, n_cols_purged = self._purge_matrix_columns(self._W[attr_idx], vec_indices)
        if n_cols_purged:
            self._W[attr_idx] = matrix
            self._n_attr_schemas[attr_idx] -= n_cols_purged
        return n_cols_purged

    def _delete_reward_schema_vectors(self, vec_indices):
        matrix, n_cols_purged = self._purge_matrix_columns(self._R, vec_indices)
        if n_cols_purged:
            self._R = matrix
            self._n_reward_schemas -= n_cols_purged
        return n_cols_purged

    def _delete_incorrect_schemas(self, batch):
        augmented_entities, targets, rewards = batch
        for attr_idx in range(C.N_PREDICTABLE_ATTRIBUTES):
            attr_prediction = self._predict_attribute(augmented_entities, attr_idx)

            # false positive predictions
            mispredicted_samples_mask = attr_prediction.any(axis=1) & ~targets[:, attr_idx]

            incorrect_schemas_mask = attr_prediction[mispredicted_samples_mask, :].any(axis=0)
            incorrect_schemas_indices = np.nonzero(incorrect_schemas_mask)[0]

            assert incorrect_schemas_indices.ndim == 1

            n_schemas_deleted = self._delete_attr_schema_vectors(attr_idx, incorrect_schemas_indices)

            if n_schemas_deleted:
                print('Deleted incorrect attr schemas: {} of {}'.format(
                    n_schemas_deleted, C.ENTITY_NAMES[attr_idx]))

        reward_prediction = self._predict_reward(augmented_entities)

        # false positive predictions
        mispredicted_samples_mask = reward_prediction.any(axis=1) & ~rewards

        incorrect_schemas_mask = reward_prediction[mispredicted_samples_mask, :].any(axis=0)
        incorrect_schemas_indices = np.nonzero(incorrect_schemas_mask)[0]

        assert incorrect_schemas_indices.ndim == 1
        n_schemas_deleted = self._delete_reward_schema_vectors(incorrect_schemas_indices)

        if n_schemas_deleted:
            print('Deleted incorrect reward schemas: {}'.format(n_schemas_deleted))

    def _find_cluster(self, zp_pl_mask, zp_nl_mask, augmented_entities, target, attr_idx, opt_model):
        """
        augmented_entities: zero-predicted only
        target: scalar vector
        """
        assert augmented_entities.dtype == int
        assert target.dtype == int

        # find all entries that can potentially be solved (have True labels)
        candidates = augmented_entities[zp_pl_mask]

        print('finding cluster...')
        print('augmented_entities: {}'.format(augmented_entities.shape[0]))
        print('zp pos samples: {}'.format(candidates.shape[0]))

        if not candidates.size:
            return None

        zp_pl_indices = np.nonzero(zp_pl_mask)[0]

        #if not is_reward:
        # sample one entry and add its idx to 'solved'
        idx = np.random.choice(zp_pl_indices)
        self._solved.append(idx)

        # resample candidates
        zp_pl_mask[idx] = False
        zp_pl_indices = np.nonzero(zp_pl_mask)[0]
        candidates = augmented_entities[zp_pl_mask]

        # solve LP
        objective_coefficients = (1 - candidates).sum(axis=0)
        objective_coefficients = list(objective_coefficients)

        new_schema_vector = opt_model.optimize(objective_coefficients, zp_nl_mask, self._solved)

        if new_schema_vector is None:
            print('!!! Cannot find cluster !!!')
            return None

        # add all samples that are solved by just learned schema vector
        if candidates.size:
            new_predicted_attribute = (1 - candidates) @ new_schema_vector
            cluster_members_mask = np.isclose(new_predicted_attribute, 0, rtol=0, atol=C.ADDING_SCHEMA_TOLERANCE)
            n_new_members = np.count_nonzero(cluster_members_mask)

            if n_new_members:
                print('Also added to solved: {}'.format(n_new_members))
                self._solved.extend(zp_pl_indices[cluster_members_mask])

        return new_schema_vector

    def _simplify_schema(self, zp_nl_mask, schema_vector, opt_model):
        objective_coefficients = [1] * len(schema_vector)

        new_schema_vector = opt_model.optimize(objective_coefficients, zp_nl_mask, self._solved)
        assert new_schema_vector is not None

        return new_schema_vector

    def _binarize_schema(self, schema_vector):
        threshold = 0.5
        return schema_vector > threshold

    def _generate_new_schema(self, augmented_entities, targets, attr_idx, is_reward=False):
        if not is_reward:
            target = targets[:, attr_idx].astype(int, copy=False)
            prediction = self._predict_attribute(augmented_entities, attr_idx)
            opt_model = self._attr_mip_models[attr_idx]
        else:
            target = targets.astype(int, copy=False)
            prediction = self._predict_reward(augmented_entities)
            opt_model = self._reward_mip_model

        augmented_entities = augmented_entities.astype(int, copy=False)

        # sample only entries with zero-prediction
        zp_mask = ~prediction.any(axis=1)
        pl_mask = target == 1
        # pos and neg labels' masks
        zp_pl_mask = zp_mask & pl_mask
        zp_nl_mask = zp_mask & ~pl_mask

        new_schema_vector = self._find_cluster(zp_pl_mask, zp_nl_mask,
                                               augmented_entities, target, attr_idx,
                                               opt_model)
        if new_schema_vector is None:
            return None

        new_schema_vector = self._simplify_schema(zp_nl_mask, new_schema_vector, opt_model)
        if new_schema_vector is None:
            print('!!! Cannot simplify !!!')
            return None

        new_schema_vector = self._binarize_schema(new_schema_vector)

        self._solved.clear()

        return new_schema_vector

    def dump_weights(self, learned_W, learned_R):
        dir_name = 'dump'
        os.makedirs(dir_name, exist_ok=True)
        for attr_idx in range(C.N_PREDICTABLE_ATTRIBUTES):
            file_name = 'w_{}.pkl'.format(attr_idx)
            path = os.path.join(dir_name, file_name)
            learned_W[attr_idx].dump(path)

        file_name = 'r_pos.pkl'
        path = os.path.join(dir_name, file_name)
        learned_R[0].dump(path)

    def get_weights(self):
        learned_W = [W[:, ~np.all(W, axis=0)] for W in self._W]
        if any(w.size == 0 for w in learned_W):
            learned_W = None

        learned_R = [self._R[:, ~np.all(self._R, axis=0)]]
        if learned_R[0].size == 0:
            learned_R = None
        return learned_W, learned_R

    def learn(self):
        # get full batch from buffer
        buff_batch = self._get_buff_batch()
        if buff_batch is not None:
            self._delete_incorrect_schemas(buff_batch)
            self._add_to_replay_and_constraints_buff(buff_batch)

        # get all data to learn on
        replay_batch = self._get_replay_batch()
        if replay_batch is None:
            return

        augmented_entities, targets, rewards = replay_batch

        if not C.USE_HANDCRAFTED_ATTRIBUTE_SCHEMAS:
            for attr_idx in range(C.N_PREDICTABLE_ATTRIBUTES):
                while self._n_attr_schemas[attr_idx] < C.L:
                    new_schema_vec = self._generate_new_schema(augmented_entities, targets, attr_idx)
                    if new_schema_vec is None:
                        break
                    self._add_attr_schema_vec(attr_idx, new_schema_vec)

        if not C.USE_HANDCRAFTED_REWARD_SCHEMAS:
            while self._n_reward_schemas < C.L:
                new_schema_vec = self._generate_new_schema(augmented_entities, rewards, None, is_reward=True)
                if new_schema_vec is None:
                    break
                self._add_reward_schema_vec(new_schema_vec)

        if C.VISUALIZE_SCHEMAS:
            learned_W = [W[:, ~np.all(W, axis=0)] for W in self._W]
            learned_R = [self._R[:, ~np.all(self._R, axis=0)]]
            self._visualizer.visualize_schemas(learned_W, learned_R)
            self.dump_weights(learned_W, learned_R)

        if C.VISUALIZE_REPLAY_BUFFER:
            self._visualizer.visualize_replay_buffer(self._replay)
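
The duplicate handling in both learners depends on np.unique(..., axis=0, return_index=True) returning the index of the first occurrence of each unique row, which is what the mask_concat assertion in _add_to_replay_and_constraints_buff verifies: rows already in the replay are kept, and the complement of unique_idx picks out the duplicates contributed by the new batch. A small standalone illustration:

import numpy as np

# three rows; the last row duplicates the first
x = np.array([[1, 0],
              [0, 1],
              [1, 0]], dtype=bool)

samples, unique_idx = np.unique(x, axis=0, return_index=True)
# np.unique reports the first occurrence of each unique row
duplicates_mask = np.ones(len(x), dtype=bool)
duplicates_mask[unique_idx] = False

print(unique_idx)                     # [1 0]: first occurrences of the unique rows
print(duplicates_mask)                # [False False  True]: row 2 is a duplicate
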