def train_batch(self, batch_info: BatchInfo) -> None:
        """
        Batch - the most atomic unit of learning.

        For this reinforcer, that involves:

        1. Roll out the environment using the current policy
        2. Use that rollout to train the policy
        """
        # Roll out the environment using the current policy (model in training mode)
        self.model.train()

        rollout = self.env_roller.rollout(batch_info, self.model,
                                          self.settings.number_of_steps)

        # Process the rollout with the 'algo' (e.g. perform advantage estimation)
        rollout = self.algo.process_rollout(batch_info, rollout)

        # Perform the training step
        # Algo will aggregate data into this list:
        batch_info['sub_batch_data'] = []

        if self.settings.shuffle_transitions:
            rollout = rollout.to_transitions()

        if self.settings.stochastic_experience_replay:
            # Always play experience at least once
            experience_replay_count = max(
                np.random.poisson(self.settings.experience_replay), 1)
        else:
            experience_replay_count = self.settings.experience_replay

        # Repeat the experience N times
        for i in range(experience_replay_count):
            # We may need to split the rollout into multiple mini-batches
            if self.settings.batch_size >= rollout.frames():
                batch_result = self.algo.optimizer_step(
                    batch_info=batch_info,
                    device=self.device,
                    model=self.model,
                    rollout=rollout.to_device(self.device))

                batch_info['sub_batch_data'].append(batch_result)
            else:
                # Rollout too big, split it into shuffled mini-batches
                for batch_rollout in rollout.shuffled_batches(
                        self.settings.batch_size):
                    batch_result = self.algo.optimizer_step(
                        batch_info=batch_info,
                        device=self.device,
                        model=self.model,
                        rollout=batch_rollout.to_device(self.device))

                    batch_info['sub_batch_data'].append(batch_result)

        batch_info['frames'] = rollout.frames()
        batch_info['episode_infos'] = rollout.episode_information()

        # Even with all the experience replay, we count the single rollout as a single batch
        batch_info.aggregate_key('sub_batch_data')
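
A note on the stochastic experience replay used above: drawing the replay count from a Poisson distribution keeps the expected number of passes over each rollout equal to settings.experience_replay while letting it vary from batch to batch. Below is a minimal, self-contained sketch of that logic; replay_count is a hypothetical helper for illustration, not part of the library above.

import numpy as np

def replay_count(expected_replays: float, stochastic: bool) -> int:
    """How many times to train on a single rollout. With stochastic replay
    the count is Poisson-distributed around the configured mean, floored at
    one so every rollout is played at least once."""
    if stochastic:
        return max(int(np.random.poisson(expected_replays)), 1)
    return int(expected_replays)

# With an expected replay of 2.0, individual rollouts are trained 1, 2, 3...
# times; the long-run average sits slightly above 2.0 due to the max(..., 1) floor.
counts = [replay_count(2.0, stochastic=True) for _ in range(10_000)]
print(np.mean(counts))
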
Example #2
    def train_batch(self, batch_info: BatchInfo) -> None:
        """
        Batch - the most atomic unit of learning.

        For this reinforcer, that involves:

        1. Roll out the environment and store the experience in the buffer
        2. Sample the buffer and train the algo on the sampled batch
        """
        # Each DQN batch consists of two phases:
        # 1. Roll out the environment and store the experience in the buffer
        self.model.eval()

        # Helper variables for rollouts
        episode_information = []
        frames = 0

        with torch.no_grad():
            if not self.env_roller.is_ready_for_sampling():
                while not self.env_roller.is_ready_for_sampling():
                    rollout = self.env_roller.rollout(batch_info, self.model)

                    episode_information.extend(rollout.episode_information())
                    frames += rollout.frames()
            else:
                for i in range(self.settings.batch_rollout_rounds):
                    rollout = self.env_roller.rollout(batch_info, self.model)

                    episode_information.extend(rollout.episode_information())
                    frames += rollout.frames()

        batch_info['frames'] = frames
        batch_info['episode_infos'] = episode_information

        # 2. Sample the buffer and train the algo on the sampled batch
        self.model.train()

        # Algo will aggregate data into this list:
        batch_info['sub_batch_data'] = []

        for i in range(self.settings.batch_training_rounds):
            sampled_rollout = self.env_roller.sample(batch_info, self.model)

            batch_result = self.algo.optimizer_step(
                batch_info=batch_info,
                device=self.device,
                model=self.model,
                rollout=sampled_rollout
            )

            self.env_roller.update(rollout=sampled_rollout, batch_info=batch_result)

            batch_info['sub_batch_data'].append(batch_result)

        batch_info.aggregate_key('sub_batch_data')
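
The reinforcer above leans on env_roller.is_ready_for_sampling() to separate the warm-up phase (keep rolling out until the buffer is full enough) from steady-state training (alternate rollout and sampling rounds). The sketch below shows a minimal buffer contract that would support that pattern; the class name and fields are assumptions for illustration, not the library's actual roller.

import random
from collections import deque

class MinimalReplayBuffer:
    """Hypothetical sketch of the buffer behind an off-policy env roller."""

    def __init__(self, capacity: int, warmup_size: int):
        self.transitions = deque(maxlen=capacity)  # oldest entries are evicted first
        self.warmup_size = warmup_size  # minimum size before sampling is allowed

    def is_ready_for_sampling(self) -> bool:
        return len(self.transitions) >= self.warmup_size

    def store(self, transition) -> None:
        self.transitions.append(transition)

    def sample(self, batch_size: int):
        return random.sample(list(self.transitions), batch_size)
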
Example #3
    def train_batch(self, batch_info: BatchInfo):
        """ Single, most atomic 'step' of learning this reinforcer can perform """
        batch_info['sub_batch_data'] = []

        self.on_policy_train_batch(batch_info)

        if self.settings.experience_replay > 0 and self.env_roller.is_ready_for_sampling():
            if self.settings.stochastic_experience_replay:
                experience_replay_count = np.random.poisson(
                    self.settings.experience_replay)
            else:
                experience_replay_count = self.settings.experience_replay

            for i in range(experience_replay_count):
                self.off_policy_train_batch(batch_info)

        # Even with all the experience replay, we count the single rollout as a single batch
        batch_info.aggregate_key('sub_batch_data')
Example #4
    def train_epoch(self, epoch_info: EpochInfo) -> None:
        """ Train model for a single epoch """
        epoch_info.on_epoch_begin()

        for batch_idx in tqdm.trange(epoch_info.batches_per_epoch, file=sys.stdout, desc="Training", unit="batch"):
            batch_info = BatchInfo(epoch_info, batch_idx)

            batch_info.on_batch_begin()
            self.train_batch(batch_info)
            batch_info.on_batch_end()

        epoch_info.result_accumulator.freeze_results()
        epoch_info.on_epoch_end()
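
The train_epoch driver above relies on the Info objects (TrainingInfo, EpochInfo, BatchInfo) to fan lifecycle events out to callbacks and to accumulate metrics. Here is a rough sketch of that pattern, under the assumption that the on_* hooks simply dispatch to registered callbacks; SketchBatchInfo and the callbacks attribute are illustrative, not the library's real API.

class SketchBatchInfo(dict):
    """Illustrative stand-in for BatchInfo: a dict of per-batch data that
    also notifies registered callbacks at batch boundaries."""

    def __init__(self, epoch_info, batch_idx: int):
        super().__init__()
        self.epoch_info = epoch_info
        self.batch_idx = batch_idx

    def on_batch_begin(self) -> None:
        for callback in self.epoch_info.callbacks:
            callback.on_batch_begin(self)

    def on_batch_end(self) -> None:
        # Typical uses: push metrics into the result accumulator,
        # step learning-rate schedulers, write checkpoints.
        for callback in self.epoch_info.callbacks:
            callback.on_batch_end(self)
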
Example #6
    def run(self):
        """ Run the command with supplied configuration """
        device = torch.device(self.model_config.device)
        learner = Learner(device, self.model.instantiate())

        lr_schedule = interp.interpolate_series(self.start_lr, self.end_lr,
                                                self.num_it,
                                                self.interpolation)

        if self.freeze:
            learner.model.freeze()

        # Optimizer should be created after the freeze
        optimizer = self.optimizer_factory.instantiate(learner.model)

        iterator = iter(self.source.train_loader())

        # Metrics to track through this training
        metrics = learner.metrics() + [AveragingNamedMetric("lr")]

        learner.train()

        best_value = None

        training_info = TrainingInfo(start_epoch_idx=0, metrics=metrics)

        # Treat it all as one epoch
        epoch_info = EpochInfo(training_info,
                               global_epoch_idx=1,
                               batches_per_epoch=1,
                               optimizer=optimizer)

        for iteration_idx, lr in enumerate(tqdm.tqdm(lr_schedule)):
            batch_info = BatchInfo(epoch_info, iteration_idx)

            # First, set the learning rate, the same for each parameter group
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            try:
                data, target = next(iterator)
            except StopIteration:
                iterator = iter(self.source.train_loader())
                data, target = next(iterator)

            learner.train_batch(batch_info, data, target)

            batch_info['lr'] = lr

            # METRIC RECORDING PART
            epoch_info.result_accumulator.calculate(batch_info)

            current_value = epoch_info.result_accumulator.intermediate_value(
                self.metric)

            final_metrics = {
                'epoch_idx': iteration_idx,
                self.metric: current_value,
                'lr': lr
            }

            if best_value is None or current_value < best_value:
                best_value = current_value

            # Stop on divergence
            if self.stop_dv and (np.isnan(current_value) or current_value >
                                 best_value * self.divergence_threshold):
                break

            training_info.history.add(final_metrics)

        frame = training_info.history.frame()

        fig, ax = plt.subplots(1, 2)

        ax[0].plot(frame.index, frame.lr)
        ax[0].set_title("LR Schedule")
        ax[0].set_xlabel("Num iterations")
        ax[0].set_ylabel("Learning rate")

        if self.interpolation == 'logscale':
            ax[0].set_yscale("log")  # non-positive values are clipped by default

        ax[1].plot(frame.lr, frame[self.metric], label=self.metric)
        # ax[1].plot(frame.lr, frame[self.metric].ewm(com=20).mean(), label=self.metric + ' smooth')
        ax[1].set_title(self.metric)
        ax[1].set_xlabel("Learning rate")
        ax[1].set_ylabel(self.metric)
        # ax[1].legend()

        if self.interpolation == 'logscale':
            ax[1].set_xscale("log")

        plt.show()
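
The learning-rate finder above builds its schedule with interp.interpolate_series(start_lr, end_lr, num_it, interpolation). A plausible stand-in for that function is sketched below, assuming 'logscale' means geometrically spaced learning rates (the usual choice for an LR range test, since each step then multiplies the rate by a constant factor); the real implementation may differ.

import numpy as np

def interpolate_series(start: float, end: float, steps: int, how: str = 'logscale'):
    """Hypothetical stand-in for interp.interpolate_series."""
    if how == 'logscale':
        # Geometric spacing sweeps several orders of magnitude evenly
        return np.logspace(np.log10(start), np.log10(end), num=steps)
    return np.linspace(start, end, num=steps)

# A typical range test from 1e-7 up to 10 over 100 iterations:
lr_schedule = interpolate_series(1e-7, 10.0, 100)
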
    def train_batch(self, batch_info: BatchInfo) -> None:
        """
        Batch - the most atomic unit of learning.

        For this reinforcer, that involves:

        1. Roll out the environment and store the experience in the buffer
        2. Sample the buffer and train the algo on the sampled batch
        """
        # Each DQN batch consists of two phases:
        # 1. Roll out the environment and store the experience in the buffer
        self.model.eval()

        # Helper variables for rollouts
        episode_information = []
        rollout_actions = []
        rollout_values = []
        frames = 0

        with torch.no_grad():
            if not self.env_roller.is_ready_for_sampling():
                while not self.env_roller.is_ready_for_sampling():
                    rollout = self.env_roller.rollout(batch_info, self.model)
                    maybe_episode_info = rollout['episode_information']

                    if maybe_episode_info is not None:
                        episode_information.append(maybe_episode_info)

                    frames += 1
                    rollout_actions.append(rollout['action'].detach().cpu().numpy())
                    rollout_values.append(rollout['value'].detach().cpu().numpy())
            else:
                for i in range(self.settings.batch_rollout_rounds):
                    rollout = self.env_roller.rollout(batch_info, self.model)
                    maybe_episode_info = rollout['episode_information']

                    if maybe_episode_info is not None:
                        episode_information.append(maybe_episode_info)

                    frames += 1
                    rollout_actions.append(rollout['action'].detach().cpu().numpy())
                    rollout_values.append(rollout['value'].detach().cpu().numpy())

        batch_info['rollout_action_mean'] = np.mean(rollout_actions)
        batch_info['rollout_action_std'] = np.std(rollout_actions)
        batch_info['rollout_value_mean'] = np.mean(rollout_values)

        batch_info['frames'] = frames
        batch_info['episode_infos'] = episode_information

        # 2. Sample the buffer and train the algo on the sampled batch
        self.model.train()

        # Algo will aggregate data into this list:
        batch_info['sub_batch_data'] = []

        for i in range(self.settings.batch_training_rounds):
            batch_sample = self.env_roller.sample(batch_info, self.model)

            batch_result = self.algo.optimizer_step(
                batch_info=batch_info,
                device=self.device,
                model=self.model,
                rollout=batch_sample
            )

            self.env_roller.update(sample=batch_sample, batch_info=batch_result)

            batch_info['sub_batch_data'].append(batch_result)

        batch_info.aggregate_key('sub_batch_data')
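
Every reinforcer above finishes with batch_info.aggregate_key('sub_batch_data'), collapsing the per-sub-batch results into one record so that a rollout trained over many replay rounds or mini-batches still reports as a single batch. The sketch below is a guess at what that aggregation looks like, assuming each sub-batch result is a dict of scalar metrics; the real BatchInfo method may differ.

import numpy as np

def aggregate_key(batch_info: dict, key: str) -> None:
    """Illustrative aggregation: replace a list of per-sub-batch result
    dicts with the mean of each metric across sub-batches."""
    sub_batch_data = batch_info.pop(key)
    if not sub_batch_data:
        return
    for metric in sub_batch_data[0]:
        batch_info[metric] = float(np.mean([d[metric] for d in sub_batch_data]))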