Example #1
    def train(self):
        """ This allows the agent to train itself to better understand the environment dynamics.
        The agent will compute the expected reward for the state(t+1)
        and update the expected reward at step t according to this.

        The target expectation is computed through the Target Network, which is a more stable version
        of the Action Value Network for increasing training stability.

        The Target Network is a frozen copy of the Action Value Network updated as regular intervals.
        """

        agent_step = self._num_actions_taken

        if agent_step >= self._train_after:
            if (agent_step % self._train_interval) == 0:
                pre_states, actions, post_states, rewards, terminals = self._memory.minibatch(self._minibatch_size)

                self._trainer.train_minibatch(
                    self._trainer.loss_function.argument_map(
                        pre_states=pre_states,
                        actions=Value.one_hot(actions.reshape(-1, 1).tolist(), self.nb_actions),
                        post_states=post_states,
                        rewards=rewards,
                        terminals=terminals
                    )
                )

                # Update the Target Network if needed
                if (agent_step % self._target_update_interval) == 0:
                    self._target_net = self._action_value_net.clone(CloneMethod.freeze)
                    filename = "models\model%d" % agent_step
                    self._trainer.save_checkpoint(filename)
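
The loss optimized by train_minibatch above is hidden inside self._trainer.loss_function, but the docstring's description corresponds to a one-step Bellman target. As a rough illustration only, the sketch below computes that target in plain NumPy; td_targets, gamma, and the toy arrays are hypothetical and not part of the example.

    import numpy as np

    def td_targets(rewards, post_state_q_values, terminals, gamma=0.99):
        # Expected reward at step t = immediate reward plus the discounted best
        # Q-value of state(t+1); terminal transitions contribute no future reward.
        best_next_q = post_state_q_values.max(axis=1)
        return rewards + gamma * (1.0 - terminals) * best_next_q

    # Toy batch of 3 transitions over 2 actions.
    rewards = np.array([1.0, 0.0, -1.0], dtype=np.float32)
    post_q = np.array([[0.2, 0.5], [0.1, 0.3], [0.0, 0.0]], dtype=np.float32)
    terminals = np.array([0.0, 0.0, 1.0], dtype=np.float32)
    print(td_targets(rewards, post_q, terminals))  # -> approximately [1.495, 0.297, -1.0]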
Example #2
    def train(self):
        """ This allows the agent to train itself to better understand the environment dynamics.
        The agent will compute the expected reward for the state(t+1)
        and update the expected reward at step t according to this.

        The target expectation is computed through the Target Network, which is a more stable version
        of the Action Value Network for increasing training stability.

        The Target Network is a frozen copy of the Action Value Network updated as regular intervals.
        """

        agent_step = self._num_actions_taken

        if agent_step >= self._train_after:
            #if (agent_step % self._train_interval) == 0:
            print('\nTraining minibatch\n')
            client.setCarControls(zero_controls)
            pre_states, actions, post_states, rewards, terminals = self._memory.minibatch(self._minibatch_size)
            self._trainer.train_minibatch(
                self._trainer.loss_function.argument_map(
                    pre_states=pre_states,
                    actions=Value.one_hot(actions.reshape(-1, 1).tolist(), self.nb_actions),
                    post_states=post_states,
                    rewards=rewards,
                    terminals=terminals
                )
            )
            self._num_trains += 1
            # Update the Target Network if needed
            if self._num_trains % 20 == 0:
                print('updating network')
                self._target_net = self._action_value_net.clone(CloneMethod.freeze)
                filename = dirname + "\\model%d" % agent_step
                self._trainer.save_checkpoint(filename)
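
Every train() above encodes the sampled integer actions with Value.one_hot before feeding them to the loss. The standalone snippet below shows just that conversion, assuming CNTK is installed; nb_actions and the action array are made up for illustration.

    import numpy as np
    from cntk import Value

    nb_actions = 4
    actions = np.array([2, 0, 3], dtype=np.int64)
    # One sequence per transition, each holding a single one-hot vector over nb_actions.
    one_hot_actions = Value.one_hot(actions.reshape(-1, 1).tolist(), nb_actions)
    print(one_hot_actions.shape)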
Example #3
    def train(self, checkpoint_dir):
        """ This allows the agent to train itself to better understand the environment dynamics.
        The agent will compute the expected reward for the state(t+1)
        and update the expected reward at step t according to this.

        The target expectation is computed through the Target Network, which is a more stable version
        of the Action Value Network for increasing training stability.

        The Target Network is a frozen copy of the Action Value Network updated as regular intervals.
        """

        agent_step = self._num_actions_taken

        if agent_step >= self._train_after:
            if (agent_step % self._train_interval) == 0:
                #print('training... number of steps: {}'.format(agent_step))

                pre_states, actions, post_states, rewards, terminals = self._memory.minibatch(
                    self._minibatch_size)
                self._trainer.train_minibatch(
                    self._trainer.loss_function.argument_map(
                        pre_states=pre_states,
                        actions=Value.one_hot(
                            actions.reshape(-1, 1).tolist(), self.nb_actions),
                        post_states=post_states,
                        rewards=rewards,
                        terminals=terminals))

                # Update the Target Network if needed
                if (agent_step % self._target_update_interval) == 0:
                    self._target_net = self._action_value_net.clone(
                        CloneMethod.freeze)
                    filename = os.path.join(checkpoint_dir,
                                            "models\model%d" % agent_step)
                    self._trainer.save_checkpoint(filename)
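
The clone-and-freeze step shared by these examples can be tried in isolation. The sketch below uses a toy CNTK network in place of the agent's Action Value Network: immediately after cloning, the frozen copy returns the same outputs, but its parameters no longer receive gradient updates, which is what makes it a stable target.

    import numpy as np
    import cntk as C
    from cntk.ops.functions import CloneMethod

    x = C.input_variable(4)
    action_value_net = C.layers.Dense(2)(x)   # stand-in for the real network
    target_net = action_value_net.clone(CloneMethod.freeze)

    sample = np.random.rand(1, 4).astype(np.float32)
    # Identical outputs right after cloning; they diverge once training resumes.
    print(np.allclose(action_value_net.eval(sample), target_net.eval(sample)))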
Example #4
File: io_tests.py  Project: pospanet/CNTK
    def next_minibatch(self,
                       num_samples,
                       number_of_workers=1,
                       worker_rank=0,
                       device=None):
        features = []
        labels = []

        sweep_end = False

        f_sample_count = 0
        l_sample_count = 0

        while max(f_sample_count, l_sample_count) < num_samples:
            if self.next_seq_idx == len(self.sequences):
                sweep_end = True
                self.next_seq_idx = 0

            seq_id = self.sequences[self.sequences[self.next_seq_idx]]

            f_data = self.data[seq_id]['features']
            l_data = self.data[seq_id]['labels']
            if (features or labels) and max(
                    f_sample_count + len(f_data),
                    l_sample_count + len(l_data)) > num_samples:
                break
            f_sample_count += len(f_data)
            features.append(f_data)

            l_sample_count += len(l_data)
            labels.append(l_data)

            self.next_seq_idx += 1

        num_seq = len(features)

        f_data = Value.one_hot(batch=features, num_classes=self.f_dim)
        l_data = Value(batch=np.asarray(labels, dtype=np.float32))

        result = {
            self.fsi: MinibatchData(f_data, num_seq, f_sample_count,
                                    sweep_end),
            self.lsi: MinibatchData(l_data, num_seq, l_sample_count, sweep_end)
        }

        return result
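
The loop above packs a variable number of variable-length sequences and wraps them in MinibatchData. The standalone sketch below reproduces one such packing step with made-up dimensions (f_dim and the two index sequences are not from the test); it only shows what f_data and the returned MinibatchData carry.

    from cntk import Value
    from cntk.io import MinibatchData

    f_dim = 5
    features = [[0, 2], [4, 1, 3]]   # two sequences of token indices
    f_data = Value.one_hot(batch=features, num_classes=f_dim)
    # 2 sequences, 5 samples in total, and not the end of a sweep.
    mb = MinibatchData(f_data, 2, 5, False)
    print(mb.num_sequences, mb.num_samples)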
Example #5
    def next_minibatch(self, num_samples, number_of_workers, worker_rank, device=None):
        features = []
        labels = []

        sweep_end = False

        f_sample_count = 0
        l_sample_count = 0


        while max(f_sample_count, l_sample_count) < num_samples:
            if self.next_seq_idx == len(self.sequences):
                sweep_end = True
                self.next_seq_idx = 0

            seq_id = self.sequences[self.sequences[self.next_seq_idx]]

            f_data = self.data[seq_id]['features']
            l_data = self.data[seq_id]['labels']
            if (features or labels) and max(f_sample_count+len(f_data), l_sample_count+len(l_data)) > num_samples:
                break
            f_sample_count += len(f_data)
            features.append(f_data)

            l_sample_count += len(l_data)
            labels.append(l_data)

            self.next_seq_idx += 1

        num_seq = len(features)

        f_data = Value.one_hot(batch=features, num_classes=self.f_dim)
        l_data = Value(batch=np.asarray(labels, dtype=np.float32))
        result = {
                self.fsi: MinibatchData(f_data, num_seq, f_sample_count, sweep_end),
                self.lsi: MinibatchData(l_data, num_seq, l_sample_count, sweep_end)
                }

        return result
Example #6
    def train(self):

        agent_step = self._num_actions_taken

        if agent_step >= self._train_after:
            if (agent_step % self._train_interval) == 0:
                pre_states, actions, post_states, rewards, terminals = self._memory.minibatch(self._minibatch_size)

                self._trainer.train_minibatch(
                    self._trainer.loss_function.argument_map(
                        pre_states=pre_states,
                        actions=Value.one_hot(actions.reshape(-1,1).tolist(), self.nb_actions),
                        post_states=post_states,
                        rewards=rewards,
                        terminals=terminals
                    )
                )

            if (agent_step % self._target_update_interval) == 0:
                self._target_net = self._action_value_net.clone(CloneMethod.freeze)
                filename = "model\model%d" % agent_step # save ???? not good at using %d
                self._trainer.save_checkpoint(filename)
    def train(self):
        """ This allows the agent to train itself to better understand the environment dynamics.
        The agent will compute the expected reward for the state(t+1)
        and update the expected reward at step t according to this.

        The target expectation is computed through the Target Network, which is a more stable version
        of the Action Value Network for increasing training stability.

        The Target Network is a frozen copy of the Action Value Network updated as regular intervals.
        """
        agent_step = self._num_actions_taken

        if agent_step >= self._train_after:
            if (agent_step % self._train_interval) == 0:
                pre_states, actions, post_states, rewards, terminals = self._memory.minibatch(
                    self._minibatch_size)

                print('Training the agent')
                x, t = self.target_qvals([
                    pre_states,
                    Value.one_hot(
                        actions.reshape(-1, 1).tolist(), self.nb_actions),
                    post_states, rewards, terminals
                ], self._action_value_net, self._target_net)
                self._action_value_net.fit(
                    np.reshape(x, (len(pre_states), self.input_shape)),
                    np.reshape(t, (len(pre_states), self.nb_actions)),
                    epochs=1,
                    verbose=0)

                # Update the Target Network if needed
                if (agent_step % self._target_update_interval) == 0:
                    print('Updating the target Network')
                    self._target_net = self._action_value_net.clone(
                        CloneMethod.freeze)
                    filename = "models\model%d" % agent_step
                    self._trainer.save_checkpoint(filename)
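
All of the train() variants above finish by calling save_checkpoint with a backslash-joined path, which only resolves as intended on Windows. The sketch below shows the same checkpoint round trip with a toy trainer and an os.path.join path; the model, loss, and learner are placeholders and not the agent's networks.

    import os
    import cntk as C

    x = C.input_variable(4)
    y = C.input_variable(2)
    model = C.layers.Dense(2)(x)
    loss = C.squared_error(model, y)
    metric = C.squared_error(model, y)
    learner = C.sgd(model.parameters, lr=0.01)
    trainer = C.Trainer(model, (loss, metric), [learner])

    os.makedirs("models", exist_ok=True)
    filename = os.path.join("models", "model%d" % 1000)   # portable equivalent of the backslash paths above
    trainer.save_checkpoint(filename)
    trainer.restore_from_checkpoint(filename)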