    def train_epoch(self, epoch_info: EpochInfo) -> None:
        """ Train model for a single epoch """
        epoch_info.on_epoch_begin()

        for batch_idx in tqdm.trange(epoch_info.batches_per_epoch, file=sys.stdout, desc="Training", unit="batch"):
            batch_info = BatchInfo(epoch_info, batch_idx)

            batch_info.on_batch_begin()
            self.train_batch(batch_info)
            batch_info.on_batch_end()

        epoch_info.result_accumulator.freeze_results()
        epoch_info.on_epoch_end()
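The loop above does no bookkeeping of its own: EpochInfo and BatchInfo fan the on_* events out to whatever callbacks were registered on the training run. A minimal self-contained sketch of that dispatch pattern (stub classes for illustration, not vel's actual implementation):

class Callback:
    """ Hypothetical no-op base class; vel's real callbacks have a richer interface """
    def on_epoch_begin(self, epoch_info): pass
    def on_epoch_end(self, epoch_info): pass
    def on_batch_begin(self, batch_info): pass
    def on_batch_end(self, batch_info): pass


class EpochInfoSketch:
    """ Minimal stand-in for EpochInfo: stores settings and dispatches epoch events """
    def __init__(self, callbacks, batches_per_epoch):
        self.callbacks = callbacks
        self.batches_per_epoch = batches_per_epoch

    def on_epoch_begin(self):
        for callback in self.callbacks:
            callback.on_epoch_begin(self)

    def on_epoch_end(self):
        for callback in self.callbacks:
            callback.on_epoch_end(self)

BatchInfo's on_batch_begin/on_batch_end would dispatch the same way against the parent epoch's callback list.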
Example #2
    def train_epoch(self, epoch_info: EpochInfo):
        """ Train model on an epoch of a fixed number of batch updates """
        epoch_info.on_epoch_begin()

        for batch_idx in tqdm.trange(epoch_info.batches_per_epoch,
                                     file=sys.stdout,
                                     desc="Training",
                                     unit="batch"):
            batch_info = BatchInfo(epoch_info, batch_idx)

            batch_info.on_batch_begin()
            self.train_batch(batch_info)
            batch_info.on_batch_end()

        epoch_info.result_accumulator.freeze_results()
        epoch_info.on_epoch_end()
Example #3
    def epoch_info(self, training_info: TrainingInfo, global_idx: int,
                   local_idx: int) -> EpochInfo:
        """ Create Epoch info """
        return EpochInfo(training_info,
                         global_epoch_idx=global_idx,
                         local_epoch_idx=local_idx,
                         batches_per_epoch=0)
Example #4
    def run(self):
        """ Run the command with supplied configuration """
        device = torch.device(self.model_config.device)
        learner = Learner(device, self.model_factory.instantiate())
        optimizer = self.optimizer_factory.instantiate(learner.model)

        # All callbacks used for learning
        callbacks = self.gather_callbacks(optimizer)

        # Metrics to track through this training
        metrics = learner.metrics()

        # Check if training was already started and potentially continue where we left off
        training_info = self.resume_training(learner, optimizer, callbacks, metrics)

        training_info.on_train_begin()

        for global_epoch_idx in range(training_info.start_epoch_idx + 1, self.epochs + 1):
            epoch_info = EpochInfo(
                training_info=training_info,
                global_epoch_idx=global_epoch_idx,
                batches_per_epoch=self.source.train_iterations_per_epoch(),
                optimizer=optimizer
            )

            # Execute learning
            learner.run_epoch(epoch_info, self.source)

            self.storage.checkpoint(epoch_info, learner.model)

        training_info.on_train_end()

        return training_info
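The epoch range starts at start_epoch_idx + 1, which is what makes the command restartable: a fresh run resumes from 0, a resumed run from the last checkpointed epoch. A toy illustration of that arithmetic (hypothetical values):

epochs = 10

# Fresh run: start_epoch_idx == 0, so epochs 1..10 are trained
assert list(range(0 + 1, epochs + 1)) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# Resumed run: epoch 6 was the last checkpoint, so only 7..10 remain
start_epoch_idx = 6
assert list(range(start_epoch_idx + 1, epochs + 1)) == [7, 8, 9, 10]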
Example #5
    def checkpoint(self, epoch_info: EpochInfo, model: Model):
        """ When epoch is done, we persist the training state """
        self.clean(epoch_info.global_epoch_idx - 1)

        self._make_sure_dir_exists()

        # Checkpoint latest
        torch.save(model.state_dict(), self.checkpoint_filename(epoch_info.global_epoch_idx))

        hidden_state = epoch_info.state_dict()
        self.checkpoint_strategy.write_state_dict(hidden_state)

        torch.save(hidden_state, self.checkpoint_hidden_filename(epoch_info.global_epoch_idx))

        if (epoch_info.global_epoch_idx > 1
                and self.checkpoint_strategy.should_delete_previous_checkpoint(epoch_info.global_epoch_idx)):
            prev_epoch_idx = epoch_info.global_epoch_idx - 1

            os.remove(self.checkpoint_filename(prev_epoch_idx))
            os.remove(self.checkpoint_hidden_filename(prev_epoch_idx))

        if self.checkpoint_strategy.should_store_best_checkpoint(epoch_info.global_epoch_idx, epoch_info.result):
            best_checkpoint_idx = self.checkpoint_strategy.current_best_checkpoint_idx

            if best_checkpoint_idx is not None:
                os.remove(self.checkpoint_best_filename(best_checkpoint_idx))

            torch.save(model.state_dict(), self.checkpoint_best_filename(epoch_info.global_epoch_idx))

            self.checkpoint_strategy.store_best_checkpoint_idx(epoch_info.global_epoch_idx)

        self.backend.store(epoch_info.result)
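The retention policy itself lives in checkpoint_strategy. Below is a minimal sketch of one plausible strategy that keeps only the newest checkpoint and remembers the best epoch by a monitored metric; the method names follow the calls above, but everything else (the metric key, the dict-style result access) is an assumption, not vel's actual implementation:

class KeepLatestAndBestStrategy:
    """ Hypothetical checkpoint strategy: keep only the newest checkpoint,
        track the best epoch by a lower-is-better monitored metric """

    def __init__(self, metric='val:loss'):
        self.metric = metric
        self.current_best_checkpoint_idx = None
        self._best_value = None
        self._candidate_value = None

    def write_state_dict(self, hidden_state):
        # Persist our own bookkeeping alongside the epoch's hidden state
        hidden_state['checkpoint_strategy/best_idx'] = self.current_best_checkpoint_idx
        hidden_state['checkpoint_strategy/best_value'] = self._best_value

    def should_delete_previous_checkpoint(self, epoch_idx):
        # Always drop the previous "latest" checkpoint
        return True

    def should_store_best_checkpoint(self, epoch_idx, result):
        value = result.get(self.metric)
        if value is None:
            return False
        self._candidate_value = value
        return self._best_value is None or value < self._best_value

    def store_best_checkpoint_idx(self, epoch_idx):
        # Called only after should_store_best_checkpoint returned True
        self.current_best_checkpoint_idx = epoch_idx
        self._best_value = self._candidate_value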
Example #6
    def epoch_info(self, training_info: TrainingInfo, global_idx: int,
                   local_idx: int) -> EpochInfo:
        """ Create Epoch info """
        return EpochInfo(
            training_info=training_info,
            global_epoch_idx=global_idx,
            local_epoch_idx=local_idx,
            batches_per_epoch=self._source.train_iterations_per_epoch(),
            optimizer=self._optimizer_instance)
Example #7
def pivoting_rl(args):
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    seed = 1002

    # Set random seed in python std lib, numpy and pytorch
    set_seed(seed)

    vec_env = DummyVecEnvWrapper(
        MujocoEnv('HalfCheetah-v2')
    ).instantiate(parallel_envs=1, seed=seed)

    if args.algo == 'ddpg':
        model, reinforcer = get_ddpg(vec_env, device)
    elif args.algo == 'ppo':
        model, reinforcer = get_ppo(vec_env, device)
    else:
        raise ValueError(f'Unknown algo: {args.algo}')

    # Optimizer helper - weird regularization settings copied from OpenAI code
    adam_optimizer = AdamFactory(
        lr=[1.0e-4, 1.0e-3, 1.0e-3],
        weight_decay=[0.0, 0.0, 0.001],
        eps=1.0e-4,
        layer_groups=True
    ).instantiate(model)

    # Overall information store for training information
    training_info = TrainingInfo(
        metrics=[
            EpisodeRewardMetric('episode_rewards'),  # Calculate average reward from episode
        ],
        callbacks=[StdoutStreaming()]  # Print live metrics every epoch to standard output
    )

    # A bit of training initialization bookkeeping...
    training_info.initialize()
    reinforcer.initialize_training(training_info)
    training_info.on_train_begin()

    # Number of epochs, at 1000 batch updates each, to cover the frame budget
    num_epochs = int(1.0e6 / 2 / 1000)

    # Normal handrolled training loop
    for i in range(1, num_epochs+1):
        epoch_info = EpochInfo(
            training_info=training_info,
            global_epoch_idx=i,
            batches_per_epoch=1000,
            optimizer=adam_optimizer
        )

        reinforcer.train_epoch(epoch_info)

    training_info.on_train_end()
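pivoting_rl only needs an args object carrying gpu and algo. A plausible command-line entry point (assumed; not part of the original snippet):

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train HalfCheetah-v2 with DDPG or PPO')
    parser.add_argument('--algo', choices=['ddpg', 'ppo'], default='ddpg')
    parser.add_argument('--gpu', type=int, default=0, help='CUDA device index')

    pivoting_rl(parser.parse_args())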
Example #8
    def epoch_info(self, training_info: TrainingInfo, global_idx: int,
                   local_idx: int) -> EpochInfo:
        """ Create Epoch info """
        return EpochInfo(
            training_info=training_info,
            global_epoch_idx=global_idx,
            local_epoch_idx=local_idx,
            batches_per_epoch=self._source.train_iterations_per_epoch(),
            optimizer=self._optimizer_instance,
            # Add a special callback that applies to this epoch only
            callbacks=[self.special_callback] + training_info.callbacks)
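Prepending to the shared list gives the special callback first crack at every event for this epoch, without mutating training_info.callbacks. A quick demonstration of that list arithmetic (placeholder strings standing in for callback objects):

shared_callbacks = ['stdout_streaming', 'checkpointer']  # stand-ins for shared callbacks
special_callback = 'lr_warmup'                           # fires during this epoch only

epoch_callbacks = [special_callback] + shared_callbacks

assert epoch_callbacks == ['lr_warmup', 'stdout_streaming', 'checkpointer']
assert shared_callbacks == ['stdout_streaming', 'checkpointer']  # shared list untouched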
Example #9
    def run(self):
        """ Run reinforcement learning algorithm """
        device = torch.device(self.model_config.device)
        # Reinforcer is the learner for the reinforcement learning model
        reinforcer = self.reinforcer.instantiate(device)
        optimizer = self.optimizer_factory.instantiate(reinforcer.model)

        # All callbacks used for learning
        callbacks = self.gather_callbacks(optimizer)
        # Metrics to track through this training
        metrics = reinforcer.metrics()

        training_info = self.resume_training(reinforcer, callbacks, metrics)

        reinforcer.initialize_training(training_info)
        training_info.on_train_begin()

        if training_info.optimizer_initial_state:
            optimizer.load_state_dict(training_info.optimizer_initial_state)

        global_epoch_idx = training_info.start_epoch_idx + 1
        training_info['total_frames'] = self.total_frames

        while training_info['frames'] < self.total_frames:
            epoch_info = EpochInfo(
                training_info,
                global_epoch_idx=global_epoch_idx,
                batches_per_epoch=self.batches_per_epoch,
                optimizer=optimizer,
            )

            reinforcer.train_epoch(epoch_info)

            if self.openai_logging:
                self._openai_logging(epoch_info.result)

            self.storage.checkpoint(epoch_info, reinforcer.model)

            global_epoch_idx += 1

        training_info.on_train_end()

        return training_info
Example #10
    # Overall information store for training information
    training_info = TrainingInfo(
        metrics=[
            EpisodeRewardMetric('episode_rewards'),  # Calculate average reward from episode
        ],
        callbacks=[StdoutStreaming()]  # Print live metrics every epoch to standard output
    )

    # A bit of training initialization bookkeeping...
    training_info.initialize()
    reinforcer.initialize_training(training_info)
    training_info.on_train_begin()

    # Number of epochs, at 1000 batch updates each, to cover the frame budget
    num_epochs = int(1.0e6 / 2 / 1000)

    # Normal handrolled training loop
    for i in range(1, num_epochs + 1):
        epoch_info = EpochInfo(training_info=training_info,
                               global_epoch_idx=i,
                               batches_per_epoch=1000,
                               optimizer=adam_optimizer)

        reinforcer.train_epoch(epoch_info)

    training_info.on_train_end()


if __name__ == '__main__':
    half_cheetah_ddpg()
Example #11
    def run(self):
        """ Run the command with supplied configuration """
        device = torch.device(self.model_config.device)
        learner = Learner(device, self.model.instantiate())

        lr_schedule = interp.interpolate_series(self.start_lr, self.end_lr,
                                                self.num_it,
                                                self.interpolation)

        if self.freeze:
            learner.model.freeze()

        # Optimizer should be created after the freeze
        optimizer = self.optimizer_factory.instantiate(learner.model)

        iterator = iter(self.source.train_loader())

        # Metrics to track through this training
        metrics = learner.metrics() + [AveragingNamedMetric("lr")]

        learner.train()

        best_value = None

        training_info = TrainingInfo(start_epoch_idx=0, metrics=metrics)

        # Treat it all as one epoch
        epoch_info = EpochInfo(training_info,
                               global_epoch_idx=1,
                               batches_per_epoch=1,
                               optimizer=optimizer)

        for iteration_idx, lr in enumerate(tqdm.tqdm(lr_schedule)):
            batch_info = BatchInfo(epoch_info, iteration_idx)

            # First, set the learning rate, the same for each parameter group
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            try:
                data, target = next(iterator)
            except StopIteration:
                iterator = iter(self.source.train_loader())
                data, target = next(iterator)

            learner.train_batch(batch_info, data, target)

            batch_info['lr'] = lr

            # METRIC RECORDING PART
            epoch_info.result_accumulator.calculate(batch_info)

            current_value = epoch_info.result_accumulator.intermediate_value(
                self.metric)

            final_metrics = {
                'epoch_idx': iteration_idx,
                self.metric: current_value,
                'lr': lr
            }

            if best_value is None or current_value < best_value:
                best_value = current_value

            # Stop on divergence
            if self.stop_dv and (np.isnan(current_value) or current_value >
                                 best_value * self.divergence_threshold):
                break

            training_info.history.add(final_metrics)

        frame = training_info.history.frame()

        fig, ax = plt.subplots(1, 2)

        ax[0].plot(frame.index, frame.lr)
        ax[0].set_title("LR Schedule")
        ax[0].set_xlabel("Num iterations")
        ax[0].set_ylabel("Learning rate")

        if self.interpolation == 'logscale':
            ax[0].set_yscale("log", nonpositive='clip')

        ax[1].plot(frame.lr, frame[self.metric], label=self.metric)
        # ax[1].plot(frame.lr, frame[self.metric].ewm(com=20).mean(), label=self.metric + ' smooth')
        ax[1].set_title(self.metric)
        ax[1].set_xlabel("Learning rate")
        ax[1].set_ylabel(self.metric)
        # ax[1].legend()

        if self.interpolation == 'logscale':
            ax[1].set_xscale("log", nonpositive='clip')

        plt.show()
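interp.interpolate_series is presumably a ramp from start_lr to end_lr over num_it steps, linear or geometric depending on interpolation. A numpy sketch of those assumed semantics:

import numpy as np

def interpolate_series(start, end, num_it, interpolation='logscale'):
    """ Sketch of an LR range-test schedule: ramp from start to end over num_it steps """
    if interpolation == 'logscale':
        # Equal multiplicative steps - the usual choice for LR finders
        return np.geomspace(start, end, num_it)
    # Equal additive steps
    return np.linspace(start, end, num_it)

# e.g. sweep learning rates from 1e-6 up to 10 over 100 iterations
schedule = interpolate_series(1.0e-6, 10.0, 100)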