Example 1
    def do_train_forced_reading(self, agent, train_dataset, tune_dataset,
                                experiment_name):
        """ Perform training """

        assert isinstance(
            agent, ReadPointerAgent
        ), "This learning algorithm works only with READPointerAgent"

        dataset_size = len(train_dataset)

        for epoch in range(1, self.max_epoch + 1):

            logging.info("Starting epoch %d", epoch)
            action_counts = dict()
            action_counts[ReadPointerAgent.READ_MODE] = [0] * 2
            action_counts[ReadPointerAgent.ACT_MODE] = [0] * self.action_space.num_actions()

            # Test on tuning data
            agent.test_forced_reading(tune_dataset,
                                      tensorboard=self.tensorboard)

            batch_replay_items = []
            total_reward = 0
            episodes_in_batch = 0

            for data_point_ix, data_point in enumerate(train_dataset):

                if (data_point_ix + 1) % 100 == 0:
                    logging.info("Done %d out of %d", data_point_ix,
                                 dataset_size)
                    logging.info("Training data action counts %r",
                                 action_counts)

                num_actions = 0
                # Cap the episode at the gold trajectory length plus a slack horizon
                max_num_actions = len(data_point.get_trajectory())
                max_num_actions += self.constants["max_extra_horizon"]

                image, metadata = agent.server.reset_receive_feedback(
                    data_point)
                oracle_segments = data_point.get_instruction_oracle_segmented()
                pose = int(metadata["y_angle"] / 15.0)  # discretize the camera y-angle into 15-degree pose bins
                state = AgentObservedState(instruction=data_point.instruction,
                                           config=self.config,
                                           constants=self.constants,
                                           start_image=image,
                                           previous_action=None,
                                           pose=pose)

                # Evenly split the action budget across the oracle instruction segments
                per_segment_budget = int(max_num_actions / len(oracle_segments))
                num_segment_actions = 0

                mode = ReadPointerAgent.READ_MODE
                current_segment_ix = 0

                while True:

                    if mode == ReadPointerAgent.READ_MODE:
                        # Find the number of tokens to read for the gold segment
                        num_segment_size = len(
                            oracle_segments[current_segment_ix])
                        current_segment_ix += 1
                        for i in range(0, num_segment_size):
                            state = state.update_on_read()
                        mode = ReadPointerAgent.ACT_MODE

                    elif mode == ReadPointerAgent.ACT_MODE:

                        # Sample action using the policy
                        # Generate probabilities over actions
                        probabilities = list(
                            torch.exp(self.model.get_probs(state, mode).data))

                        # Use test policy to get the action
                        action = gp.sample_action_from_prob(probabilities)
                        action_counts[mode][action] += 1

                        # deal with act mode boundary conditions
                        if num_actions >= max_num_actions:
                            forced_stop = True
                            break

                        elif (action == agent.action_space.get_stop_action_index()
                              or num_segment_actions > per_segment_budget):
                            if state.are_tokens_left_to_be_read():
                                # reward = self._calc_reward_act_halt(state)
                                if metadata["error"] < 5.0:
                                    reward = 1.0
                                else:
                                    reward = -1.0

                                # Add to replay memory
                                replay_item = ReplayMemoryItem(
                                    state,
                                    agent.action_space.get_stop_action_index(),
                                    reward, mode)
                                if action == agent.action_space.get_stop_action_index():
                                    batch_replay_items.append(replay_item)

                                mode = ReadPointerAgent.READ_MODE
                                agent.server.force_goal_update()
                                state = state.update_on_act_halt()
                                num_segment_actions = 0
                            else:
                                if action == agent.action_space.get_stop_action_index():
                                    forced_stop = False
                                else:  # stopping due to per segment budget exhaustion
                                    forced_stop = True
                                break

                        else:
                            image, reward, metadata = agent.server.send_action_receive_feedback(
                                action)

                            # Store it in the replay memory list
                            replay_item = ReplayMemoryItem(state,
                                                           action,
                                                           reward,
                                                           mode=mode)
                            batch_replay_items.append(replay_item)

                            # Update the agent state
                            pose = int(metadata["y_angle"] / 15.0)
                            state = state.update(image, action, pose=pose)

                            num_actions += 1
                            num_segment_actions += 1
                            total_reward += reward

                    else:
                        raise AssertionError(
                            "Mode should be either read or act. Unhandled mode: "
                            + str(mode))

                assert mode == ReadPointerAgent.ACT_MODE, "Agent should end on Act Mode"

                # Send final STOP action and get feedback
                image, reward, metadata = agent.server.halt_and_receive_feedback()
                total_reward += reward

                # Store it in the replay memory list
                if not forced_stop:
                    replay_item = ReplayMemoryItem(
                        state, agent.action_space.get_stop_action_index(),
                        reward, mode)
                    batch_replay_items.append(replay_item)

                # Update the scores based on meta_data
                # self.meta_data_util.log_results(metadata)

                # Perform update (the batch here is a single episode)
                episodes_in_batch += 1
                if episodes_in_batch == 1:
                    loss_val = self.do_update(batch_replay_items)
                    batch_replay_items = []
                    entropy_val = float(self.entropy.data[0])
                    self.tensorboard.log(entropy_val, loss_val, total_reward)
                    total_reward = 0
                    episodes_in_batch = 0

                self.tensorboard.log_train_error(metadata["error"])

            # Save the model
            self.model.save_model(
                experiment_name +
                "/read_pointer_forced_reading_contextual_bandit_resnet_epoch_"
                + str(epoch))

            logging.info("Training data action counts %r", action_counts)
Example 2
    def do_train_forced_reading(self, agent, train_dataset, tune_dataset, experiment_name):
        """ Perform training """

        assert isinstance(agent, ReadPointerAgent), "This learning algorithm works only with READPointerAgent"

        dataset_size = len(train_dataset)

        for epoch in range(1, self.max_epoch + 1):

            logging.info("Starting epoch %d", epoch)
            action_counts = dict()
            action_counts[ReadPointerAgent.READ_MODE] = [0] * 2
            action_counts[ReadPointerAgent.ACT_MODE] = [0] * self.action_space.num_actions()

            # Test on tuning data
            agent.test_forced_reading(tune_dataset, tensorboard=self.tensorboard)

            batch_replay_items = []
            total_reward = 0
            episodes_in_batch = 0

            for data_point_ix, data_point in enumerate(train_dataset):

                if (data_point_ix + 1) % 100 == 0:
                    logging.info("Done %d out of %d", data_point_ix, dataset_size)
                    logging.info("Training data action counts %r", action_counts)

                image, metadata = agent.server.reset_receive_feedback(data_point)
                pose = int(metadata["y_angle"] / 15.0)
                oracle_segments = data_point.get_instruction_oracle_segmented()
                state = AgentObservedState(instruction=data_point.instruction,
                                           config=self.config,
                                           constants=self.constants,
                                           start_image=image,
                                           previous_action=None,
                                           pose=pose)

                mode = ReadPointerAgent.READ_MODE
                current_segment_ix = 0

                trajectories = data_point.get_sub_trajectory_list()
                action_ix = 0

                while True:

                    if mode == ReadPointerAgent.READ_MODE:
                        # Find the number of tokens to read for the gold segment
                        num_segment_size = len(oracle_segments[current_segment_ix])
                        current_segment_ix += 1
                        for i in range(0, num_segment_size):
                            state = state.update_on_read()
                        mode = ReadPointerAgent.ACT_MODE

                    elif mode == ReadPointerAgent.ACT_MODE:

                        # Teacher forcing: replay the gold sub-trajectory for this segment,
                        # then emit STOP once it is exhausted
                        if action_ix == len(trajectories[current_segment_ix - 1]):
                            action = agent.action_space.get_stop_action_index()
                            action_ix = 0
                        else:
                            action = trajectories[current_segment_ix - 1][action_ix]
                            action_ix += 1

                        action_counts[mode][action] += 1

                        if action == agent.action_space.get_stop_action_index():
                            if state.are_tokens_left_to_be_read():
                                # Add to replay memory
                                replay_item = ReplayMemoryItem(state, agent.action_space.get_stop_action_index(),
                                                               1.0, mode)
                                batch_replay_items.append(replay_item)

                                mode = ReadPointerAgent.READ_MODE
                                agent.server.force_goal_update()
                                state = state.update_on_act_halt()
                            else:
                                break
                        else:
                            image, reward, metadata = agent.server.send_action_receive_feedback(action)

                            # Store it in the replay memory list with a reward of 1, since
                            # the gold action is always correct under maximum-likelihood training
                            replay_item = ReplayMemoryItem(state, action, 1, mode=mode)
                            batch_replay_items.append(replay_item)

                            # Update the agent state
                            pose = int(metadata["y_angle"] / 15.0)
                            state = state.update(image, action, pose=pose)
                            total_reward += reward

                    else:
                        raise AssertionError("Mode should be either read or act. Unhandled mode: " + str(mode))

                assert mode == ReadPointerAgent.ACT_MODE, "Agent should end on Act Mode"

                # Send final STOP action and get feedback
                image, reward, metadata = agent.server.halt_and_receive_feedback()
                total_reward += reward

                # Store it in the replay memory list
                replay_item = ReplayMemoryItem(state, agent.action_space.get_stop_action_index(), 1, mode)
                batch_replay_items.append(replay_item)

                # Update the scores based on meta_data
                # self.meta_data_util.log_results(metadata)

                # Perform update
                episodes_in_batch += 1
                if episodes_in_batch == 1:
                    loss_val = self.do_update(batch_replay_items)
                    batch_replay_items = []
                    entropy_val = float(self.entropy.data[0])
                    self.tensorboard.log(entropy_val, loss_val, total_reward)
                    total_reward = 0
                    episodes_in_batch = 0

                self.tensorboard.log_train_error(metadata["error"])

            # Save the model
            self.model.save_model(
                experiment_name + "/ml_estimation_epoch_" + str(epoch))

            logging.info("Training data action counts %r", action_counts)
Example 3
    def do_train(self, agent, train_dataset, tune_dataset, experiment_name):
        """ Perform training """

        assert isinstance(
            agent, ReadPointerAgent
        ), "This learning algorithm works only with READPointerAgent"

        dataset_size = len(train_dataset)

        for epoch in range(1, self.max_epoch + 1):

            logging.info("Starting epoch %d", epoch)
            action_counts = dict()
            action_counts[ReadPointerAgent.READ_MODE] = [0] * 2
            action_counts[ReadPointerAgent.ACT_MODE] = [0] * self.action_space.num_actions()

            # Test on tuning data
            agent.test(tune_dataset, tensorboard=self.tensorboard)

            batch_replay_items = []
            total_reward = 0
            episodes_in_batch = 0

            for data_point_ix, data_point in enumerate(train_dataset):

                if (data_point_ix + 1) % 100 == 0:
                    logging.info("Done %d out of %d", data_point_ix,
                                 dataset_size)
                    logging.info("Training data action counts %r",
                                 action_counts)

                num_actions = 0
                max_num_actions = len(data_point.get_trajectory())
                max_num_actions += self.constants["max_extra_horizon"]

                image, metadata = agent.server.reset_receive_feedback(
                    data_point)
                state = AgentObservedState(instruction=data_point.instruction,
                                           config=self.config,
                                           constants=self.constants,
                                           start_image=image,
                                           previous_action=None)

                mode = ReadPointerAgent.READ_MODE
                last_action_was_halt = False

                instruction = instruction_to_string(
                    data_point.get_instruction(), self.config)
                print("TRAIN INSTRUCTION: %r" % instruction)
                print("")

                while True:

                    # Sample action using the policy
                    # Generate probabilities over actions
                    probabilities = list(
                        torch.exp(self.model.get_probs(state, mode).data))

                    # Use test policy to get the action
                    action = gp.sample_action_from_prob(probabilities)
                    action_counts[mode][action] += 1

                    if mode == ReadPointerAgent.READ_MODE:
                        # read mode boundary conditions
                        forced_action = False
                        if not state.are_tokens_left_to_be_read():
                            # force halt
                            action = 1
                            forced_action = True
                        elif num_actions >= max_num_actions or last_action_was_halt:
                            # force read
                            action = 0
                            forced_action = True

                        if not forced_action:
                            # Store reward in the replay memory list
                            reward = self._calc_reward_read_mode(state, action)
                            replay_item = ReplayMemoryItem(state,
                                                           action,
                                                           reward,
                                                           mode=mode)
                            batch_replay_items.append(replay_item)

                        if action == 0:
                            last_action_was_halt = False
                            state = state.update_on_read()
                        elif action == 1:
                            last_action_was_halt = True
                            mode = ReadPointerAgent.ACT_MODE
                        else:
                            raise AssertionError(
                                "Read mode only supports two actions: read(0) and halt(1). "
                                + "Found " + str(action))

                    elif mode == ReadPointerAgent.ACT_MODE:
                        # deal with act mode boundary conditions
                        if num_actions >= max_num_actions:
                            forced_stop = True
                            break

                        elif action == agent.action_space.get_stop_action_index():
                            if state.are_tokens_left_to_be_read():
                                reward = self._calc_reward_act_halt(state)

                                # Add to replay memory
                                replay_item = ReplayMemoryItem(
                                    state,
                                    agent.action_space.get_stop_action_index(),
                                    reward, mode)
                                batch_replay_items.append(replay_item)

                                mode = ReadPointerAgent.READ_MODE
                                last_action_was_halt = True
                                state = state.update_on_act_halt()
                            else:
                                forced_stop = False
                                break

                        else:
                            image, reward, metadata = agent.server.send_action_receive_feedback(
                                action)

                            # Store it in the replay memory list
                            replay_item = ReplayMemoryItem(state,
                                                           action,
                                                           reward,
                                                           mode=mode)
                            batch_replay_items.append(replay_item)

                            # Update the agent state
                            state = state.update(image, action)

                            num_actions += 1
                            total_reward += reward
                            last_action_was_halt = False

                    else:
                        raise AssertionError(
                            "Mode should be either read or act. Unhandled mode: "
                            + str(mode))

                assert mode == ReadPointerAgent.ACT_MODE, "Agent should end on Act Mode"

                # Send final STOP action and get feedback
                image, reward, metadata = agent.server.halt_and_receive_feedback()
                total_reward += reward

                # Store it in the replay memory list
                if not forced_stop:
                    replay_item = ReplayMemoryItem(
                        state, agent.action_space.get_stop_action_index(),
                        reward, mode)
                    batch_replay_items.append(replay_item)

                # Update the scores based on meta_data
                # self.meta_data_util.log_results(metadata)

                # Perform update
                episodes_in_batch += 1
                if episodes_in_batch == 1:
                    loss_val = self.do_update(batch_replay_items)
                    batch_replay_items = []
                    entropy_val = float(self.entropy.data[0])
                    self.tensorboard.log(entropy_val, loss_val, total_reward)
                    total_reward = 0
                    episodes_in_batch = 0

                self.tensorboard.log_train_error(metadata["error"])

            # Save the model
            self.model.save_model(
                experiment_name +
                "/read_pointer_contextual_bandit_resnet_epoch_" + str(epoch))

            logging.info("Training data action counts %r", action_counts)
Example 4
    def do_train_forced_reading(self, agent, train_dataset, tune_dataset,
                                experiment_name):
        """ Perform training """

        assert isinstance(
            agent, ReadPointerAgent
        ), "This learning algorithm works only with READPointerAgent"

        dataset_size = len(train_dataset)

        for epoch in range(1, self.max_epoch + 1):

            logging.info("Starting epoch %d", epoch)
            total_cb_segments = 0
            num_reached_acceptable_circle = 0
            total_segments = 0
            total_supervised_segments = 0

            action_counts = dict()
            action_counts[ReadPointerAgent.READ_MODE] = [0] * 2
            action_counts[ReadPointerAgent.ACT_MODE] = [0] * self.action_space.num_actions()

            # Test on tuning data
            agent.test_forced_reading(tune_dataset,
                                      tensorboard=self.tensorboard)

            batch_replay_items = []
            total_reward = 0
            episodes_in_batch = 0

            for data_point_ix, data_point in enumerate(train_dataset):

                if (data_point_ix + 1) % 100 == 0:
                    logging.info("Done %d out of %d", data_point_ix,
                                 dataset_size)
                    logging.info(
                        "Contextual bandit segments %r, success %r per.",
                        total_cb_segments,
                        (num_reached_acceptable_circle * 100) /
                        float(max(1, total_cb_segments)))
                    logging.info("Num segments %r, Percent supervised %r",
                                 total_segments,
                                 (total_supervised_segments * 100) /
                                 float(max(1, total_segments)))
                    logging.info("Training data action counts %r",
                                 action_counts)

                num_actions = 0
                max_num_actions = len(data_point.get_trajectory())
                max_num_actions += self.constants["max_extra_horizon"]

                image, metadata = agent.server.reset_receive_feedback(
                    data_point)
                oracle_segments = data_point.get_instruction_oracle_segmented()
                pose = int(metadata["y_angle"] / 15.0)
                state = AgentObservedState(instruction=data_point.instruction,
                                           config=self.config,
                                           constants=self.constants,
                                           start_image=image,
                                           previous_action=None,
                                           pose=pose)

                per_segment_budget = int(max_num_actions / len(oracle_segments))
                num_segment_actions = 0
                trajectory_segments = data_point.get_sub_trajectory_list()

                mode = ReadPointerAgent.READ_MODE
                current_segment_ix = 0
                # The roll-in policy decides how many leading segments replay oracle
                # (supervised) actions; the remaining segments use the learned policy
                num_supervised_rollout = self.rollin_policy.num_oracle_rollin_segments(
                    len(trajectory_segments))
                total_segments += len(trajectory_segments)

                while True:

                    if mode == ReadPointerAgent.READ_MODE:
                        # Find the number of tokens to read for the gold segment
                        num_segment_size = len(
                            oracle_segments[current_segment_ix])
                        current_segment_ix += 1
                        for i in range(0, num_segment_size):
                            state = state.update_on_read()
                        mode = ReadPointerAgent.ACT_MODE
                        total_segments += 1

                    elif mode == ReadPointerAgent.ACT_MODE:

                        if current_segment_ix <= num_supervised_rollout:
                            # Do supervised learning for this segment
                            total_supervised_segments += 1
                            for action in trajectory_segments[current_segment_ix - 1]:
                                image, reward, metadata = agent.server.send_action_receive_feedback(
                                    action)

                                # Store it in the replay memory list. Use reward of 1 as it is supervised learning
                                all_rewards = self._get_all_rewards(metadata)
                                replay_item = ReplayMemoryItem(
                                    state,
                                    action,
                                    reward=1,
                                    mode=mode,
                                    all_rewards=all_rewards)
                                batch_replay_items.append(replay_item)

                                # Update the agent state
                                pose = int(metadata["y_angle"] / 15.0)
                                state = state.update(image, action, pose=pose)

                                num_actions += 1
                                total_reward += reward

                            # Change the segment
                            assert metadata["goal_dist"] < 5.0, \
                                "oracle segments out of acceptable circle"

                            if state.are_tokens_left_to_be_read():

                                mode = ReadPointerAgent.READ_MODE

                                # Jump to the next goal
                                agent.server.force_goal_update()
                                state = state.update_on_act_halt()
                                num_segment_actions = 0
                            else:
                                forced_stop = True
                                break

                        else:
                            # Do contextual bandit for this segment and future

                            # Generate probabilities over actions
                            probabilities = list(
                                torch.exp(
                                    self.model.get_probs(state, mode).data))

                            # Sample an action from the distribution
                            action = gp.sample_action_from_prob(probabilities)

                            action_counts[mode][action] += 1

                            # deal with act mode boundary conditions
                            if num_actions >= max_num_actions:
                                break

                            elif (action == agent.action_space.get_stop_action_index()
                                  or num_segment_actions > per_segment_budget):

                                within_acceptable_circle = metadata["goal_dist"] < 5.0
                                if within_acceptable_circle:
                                    num_reached_acceptable_circle += 1
                                total_cb_segments += 1

                                if state.are_tokens_left_to_be_read():
                                    if within_acceptable_circle:

                                        if metadata["error"] < 5.0:
                                            reward = 1.0
                                        else:
                                            reward = -1.0

                                        # Add to replay memory
                                        all_rewards = metadata["all_reward"]
                                        replay_item = ReplayMemoryItem(
                                            state,
                                            agent.action_space.get_stop_action_index(),
                                            reward,
                                            mode,
                                            all_rewards=all_rewards)
                                        batch_replay_items.append(replay_item)

                                        mode = ReadPointerAgent.READ_MODE
                                        # Jump to the next goal
                                        agent.server.force_goal_update()

                                        state = state.update_on_act_halt()
                                        num_segment_actions = 0
                                    else:
                                        # No point going any further so break
                                        break
                                else:
                                    break

                            else:
                                image, reward, metadata = agent.server.send_action_receive_feedback(
                                    action)

                                # Store it in the replay memory list
                                all_rewards = self._get_all_rewards(metadata)
                                replay_item = ReplayMemoryItem(
                                    state,
                                    action,
                                    reward,
                                    mode=mode,
                                    all_rewards=all_rewards)
                                batch_replay_items.append(replay_item)

                                # Update the agent state
                                pose = int(metadata["y_angle"] / 15.0)
                                state = state.update(image, action, pose=pose)

                                num_actions += 1
                                num_segment_actions += 1
                                total_reward += reward

                    else:
                        raise AssertionError(
                            "Mode should be either read or act. Unhandled mode: "
                            + str(mode))

                assert mode == ReadPointerAgent.ACT_MODE, "Agent should end on Act Mode"

                # Send final STOP action and get feedback
                image, reward, metadata = agent.server.halt_and_receive_feedback()
                total_reward += reward

                # Update the scores based on meta_data
                # self.meta_data_util.log_results(metadata)

                # Perform update
                episodes_in_batch += 1
                if episodes_in_batch == 1:
                    loss_val = self.do_update(batch_replay_items)
                    batch_replay_items = []
                    entropy_val = float(self.entropy.data[0])
                    self.tensorboard.log(entropy_val, loss_val, total_reward)
                    total_reward = 0
                    episodes_in_batch = 0

                if self.tensorboard is not None:
                    self.tensorboard.log_all_train_errors(
                        metadata["edit_dist_error"],
                        metadata["closest_dist_error"],
                        metadata["stop_dist_error"])

            # Save the model
            self.model.save_model(
                experiment_name +
                "/read_pointer_forced_reading_curriculum_contextual_bandit_epoch_"
                + str(epoch))

            logging.info("Training data action counts %r", action_counts)