def train_epoch(self, epoch_info: EpochInfo) -> None:
    """ Run one training epoch: fire epoch-level callbacks and process every batch in turn """
    epoch_info.on_epoch_begin()

    progress = tqdm.trange(
        epoch_info.batches_per_epoch, file=sys.stdout, desc="Training", unit="batch"
    )

    for index in progress:
        batch_info = BatchInfo(epoch_info, index)

        batch_info.on_batch_begin()
        self.train_batch(batch_info)
        batch_info.on_batch_end()

    # Lock in accumulated metrics before the epoch-end callbacks read them
    epoch_info.result_accumulator.freeze_results()
    epoch_info.on_epoch_end()
def train_epoch(self, epoch_info: EpochInfo):
    """ Train model on an epoch of a fixed number of batch updates """
    epoch_info.on_epoch_begin()

    for idx in tqdm.trange(epoch_info.batches_per_epoch, file=sys.stdout, desc="Training", unit="batch"):
        info = BatchInfo(epoch_info, idx)
        info.on_batch_begin()
        self.train_batch(info)
        info.on_batch_end()

    # Freeze metric results so epoch-end consumers see a stable snapshot
    epoch_info.result_accumulator.freeze_results()
    epoch_info.on_epoch_end()
def epoch_info(self, training_info: TrainingInfo, global_idx: int, local_idx: int) -> EpochInfo:
    """ Build an EpochInfo record for the given global/local epoch indices (no batches scheduled) """
    return EpochInfo(
        training_info,
        global_epoch_idx=global_idx,
        local_epoch_idx=local_idx,
        batches_per_epoch=0
    )
def run(self):
    """ Run the command with supplied configuration """
    device = torch.device(self.model_config.device)
    learner = Learner(device, self.model_factory.instantiate())
    optimizer = self.optimizer_factory.instantiate(learner.model)

    # Callbacks and metrics tracked throughout this training run
    callbacks = self.gather_callbacks(optimizer)
    metrics = learner.metrics()

    # If a previous run was interrupted, pick up from its last recorded epoch
    training_info = self.resume_training(learner, optimizer, callbacks, metrics)

    training_info.on_train_begin()

    first_epoch = training_info.start_epoch_idx + 1

    for epoch_idx in range(first_epoch, self.epochs + 1):
        epoch_info = EpochInfo(
            training_info=training_info,
            global_epoch_idx=epoch_idx,
            batches_per_epoch=self.source.train_iterations_per_epoch(),
            optimizer=optimizer
        )

        # Execute learning, then persist state so the run is resumable
        learner.run_epoch(epoch_info, self.source)
        self.storage.checkpoint(epoch_info, learner.model)

    training_info.on_train_end()

    return training_info
def checkpoint(self, epoch_info: EpochInfo, model: Model):
    """ When epoch is done, we persist the training state

    Saves the model weights and the hidden training state for the current epoch,
    optionally deletes the previous epoch's checkpoint and refreshes the "best"
    checkpoint, then pushes the epoch result to the storage backend.
    """
    # Remove any stale state newer than the previous epoch (e.g. from an interrupted run)
    self.clean(epoch_info.global_epoch_idx - 1)

    self._make_sure_dir_exists()

    # Checkpoint latest model weights for this epoch
    torch.save(model.state_dict(), self.checkpoint_filename(epoch_info.global_epoch_idx))

    # Hidden state = non-model training state (callbacks, counters, etc.) gathered from epoch_info
    hidden_state = epoch_info.state_dict()

    # Let the strategy record its own bookkeeping inside the hidden state before it is saved
    self.checkpoint_strategy.write_state_dict(hidden_state)

    torch.save(hidden_state, self.checkpoint_hidden_filename(epoch_info.global_epoch_idx))

    # Optionally drop the previous epoch's files to save disk space (policy decided by the strategy)
    if epoch_info.global_epoch_idx > 1 and self.checkpoint_strategy.should_delete_previous_checkpoint(
            epoch_info.global_epoch_idx):
        prev_epoch_idx = epoch_info.global_epoch_idx - 1

        os.remove(self.checkpoint_filename(prev_epoch_idx))
        os.remove(self.checkpoint_hidden_filename(prev_epoch_idx))

    # If this epoch's result is the new best, replace the stored "best" checkpoint
    if self.checkpoint_strategy.should_store_best_checkpoint(epoch_info.global_epoch_idx, epoch_info.result):
        best_checkpoint_idx = self.checkpoint_strategy.current_best_checkpoint_idx

        if best_checkpoint_idx is not None:
            os.remove(self.checkpoint_best_filename(best_checkpoint_idx))

        torch.save(model.state_dict(), self.checkpoint_best_filename(epoch_info.global_epoch_idx))

        self.checkpoint_strategy.store_best_checkpoint_idx(epoch_info.global_epoch_idx)

    # Finally publish the epoch result to the metrics backend
    self.backend.store(epoch_info.result)
def epoch_info(self, training_info: TrainingInfo, global_idx: int, local_idx: int) -> EpochInfo:
    """ Assemble an EpochInfo for one epoch, sized by the data source's iteration count """
    batches = self._source.train_iterations_per_epoch()

    return EpochInfo(
        training_info=training_info,
        global_epoch_idx=global_idx,
        local_epoch_idx=local_idx,
        batches_per_epoch=batches,
        optimizer=self._optimizer_instance
    )
def pivoting_rl(args):
    """ Train an RL agent (DDPG or PPO) on HalfCheetah-v2 with a hand-rolled training loop.

    Args:
        args: parsed CLI arguments; must provide `gpu` (int GPU index) and
            `algo` ('ddpg' or 'ppo').

    Raises:
        ValueError: if `args.algo` is not one of the supported algorithms.
    """
    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    seed = 1002

    # Set random seed in python std lib, numpy and pytorch
    set_seed(seed)

    vec_env = DummyVecEnvWrapper(
        MujocoEnv('HalfCheetah-v2')
    ).instantiate(parallel_envs=1, seed=seed)

    if args.algo == 'ddpg':
        model, reinforcer = get_ddpg(vec_env, device)
    elif args.algo == 'ppo':
        model, reinforcer = get_ppo(vec_env, device)
    else:
        # Fail loudly with a real exception: the original `assert(False)` would be
        # silently stripped under `python -O`, letting execution continue unbound.
        raise ValueError('Unknown algo: {}'.format(args.algo))

    # Optimizer helper - A weird regularization settings I've copied from OpenAI code
    adam_optimizer = AdamFactory(
        lr=[1.0e-4, 1.0e-3, 1.0e-3],
        weight_decay=[0.0, 0.0, 0.001],
        eps=1.0e-4,
        layer_groups=True
    ).instantiate(model)

    # Overall information store for training information
    training_info = TrainingInfo(
        metrics=[
            EpisodeRewardMetric('episode_rewards'),  # Calculate average reward from episode
        ],
        callbacks=[StdoutStreaming()]  # Print live metrics every epoch to standard output
    )

    # A bit of training initialization bookkeeping...
    training_info.initialize()
    reinforcer.initialize_training(training_info)
    training_info.on_train_begin()

    # Epoch count derived from a total budget — presumably 1e6 frames at 2 frames/step
    # with 1000 batches per epoch; TODO confirm the budget against the environment setup
    num_epochs = int(1.0e6 / 2 / 1000)

    # Normal handrolled training loop
    for i in range(1, num_epochs + 1):
        epoch_info = EpochInfo(
            training_info=training_info,
            global_epoch_idx=i,
            batches_per_epoch=1000,
            optimizer=adam_optimizer
        )

        reinforcer.train_epoch(epoch_info)

    training_info.on_train_end()
def epoch_info(self, training_info: TrainingInfo, global_idx: int, local_idx: int) -> EpochInfo:
    """ Build EpochInfo for one epoch, prepending this object's special callback to the training callbacks """
    # The special callback runs before the regular training callbacks for this epoch
    epoch_callbacks = [self.special_callback] + training_info.callbacks

    return EpochInfo(
        training_info=training_info,
        global_epoch_idx=global_idx,
        local_epoch_idx=local_idx,
        batches_per_epoch=self._source.train_iterations_per_epoch(),
        optimizer=self._optimizer_instance,
        callbacks=epoch_callbacks
    )
def run(self):
    """ Run reinforcement learning algorithm

    Trains until the recorded frame count reaches `self.total_frames`,
    checkpointing after every epoch, and returns the final TrainingInfo.
    """
    device = torch.device(self.model_config.device)

    # Reinforcer is the learner for the reinforcement learning model
    reinforcer = self.reinforcer.instantiate(device)
    optimizer = self.optimizer_factory.instantiate(reinforcer.model)

    # All callbacks used for learning
    callbacks = self.gather_callbacks(optimizer)

    # Metrics to track through this training
    metrics = reinforcer.metrics()

    # Check if training was already started and potentially continue where we left off
    training_info = self.resume_training(reinforcer, callbacks, metrics)

    reinforcer.initialize_training(training_info)
    training_info.on_train_begin()

    # Restore optimizer state from a resumed run, if any was recorded.
    # NOTE(review): this happens after on_train_begin — presumably deliberate so
    # callbacks can populate the state first; confirm before reordering.
    if training_info.optimizer_initial_state:
        optimizer.load_state_dict(training_info.optimizer_initial_state)

    global_epoch_idx = training_info.start_epoch_idx + 1

    training_info['total_frames'] = self.total_frames

    # Loop until the frame budget is exhausted ('frames' is updated during training)
    while training_info['frames'] < self.total_frames:
        epoch_info = EpochInfo(
            training_info,
            global_epoch_idx=global_epoch_idx,
            batches_per_epoch=self.batches_per_epoch,
            optimizer=optimizer,
        )

        reinforcer.train_epoch(epoch_info)

        # Optionally mirror epoch results in OpenAI-baselines-style logging
        if self.openai_logging:
            self._openai_logging(epoch_info.result)

        # Persist model weights and training state after every epoch
        self.storage.checkpoint(epoch_info, reinforcer.model)

        global_epoch_idx += 1

    training_info.on_train_end()

    return training_info
metrics=[ EpisodeRewardMetric( 'episode_rewards'), # Calculate average reward from episode ], callbacks=[StdoutStreaming() ] # Print live metrics every epoch to standard output ) # A bit of training initialization bookkeeping... training_info.initialize() reinforcer.initialize_training(training_info) training_info.on_train_begin() # Let's make 20 batches per epoch to average metrics nicely num_epochs = int(1.0e6 / 2 / 1000) # Normal handrolled training loop for i in range(1, num_epochs + 1): epoch_info = EpochInfo(training_info=training_info, global_epoch_idx=i, batches_per_epoch=1000, optimizer=adam_optimizer) reinforcer.train_epoch(epoch_info) training_info.on_train_end() if __name__ == '__main__': half_cheetah_ddpg()
def run(self):
    """ Run the command with supplied configuration

    Learning-rate finder: sweeps the learning rate over an interpolated schedule,
    recording the chosen metric at each step, then plots metric-vs-LR curves.
    """
    device = torch.device(self.model_config.device)
    learner = Learner(device, self.model.instantiate())

    # Interpolated LR values from start_lr to end_lr over num_it iterations
    lr_schedule = interp.interpolate_series(self.start_lr, self.end_lr, self.num_it, self.interpolation)

    if self.freeze:
        learner.model.freeze()

    # Optimizer should be created after freeze, so it only sees trainable parameters
    optimizer = self.optimizer_factory.instantiate(learner.model)

    iterator = iter(self.source.train_loader())

    # Metrics to track through this training
    metrics = learner.metrics() + [AveragingNamedMetric("lr")]

    learner.train()

    best_value = None

    training_info = TrainingInfo(start_epoch_idx=0, metrics=metrics)

    # Treat it all as one epoch
    epoch_info = EpochInfo(
        training_info, global_epoch_idx=1, batches_per_epoch=1, optimizer=optimizer
    )

    for iteration_idx, lr in enumerate(tqdm.tqdm(lr_schedule)):
        batch_info = BatchInfo(epoch_info, iteration_idx)

        # First, set the learning rate, the same for each parameter group
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Cycle the data loader: restart it when exhausted
        try:
            data, target = next(iterator)
        except StopIteration:
            iterator = iter(self.source.train_loader())
            data, target = next(iterator)

        learner.train_batch(batch_info, data, target)

        batch_info['lr'] = lr

        # METRIC RECORDING PART
        epoch_info.result_accumulator.calculate(batch_info)

        current_value = epoch_info.result_accumulator.intermediate_value(self.metric)

        final_metrics = {'epoch_idx': iteration_idx, self.metric: current_value, 'lr': lr}

        if best_value is None or current_value < best_value:
            best_value = current_value

        # Stop on divergence: NaN or metric blowing up past threshold * best seen so far
        if self.stop_dv and (np.isnan(current_value) or current_value > best_value * self.divergence_threshold):
            break

        training_info.history.add(final_metrics)

    frame = training_info.history.frame()

    # Left panel: the LR schedule itself; right panel: metric as a function of LR
    fig, ax = plt.subplots(1, 2)

    ax[0].plot(frame.index, frame.lr)
    ax[0].set_title("LR Schedule")
    ax[0].set_xlabel("Num iterations")
    ax[0].set_ylabel("Learning rate")

    if self.interpolation == 'logscale':
        ax[0].set_yscale("log", nonposy='clip')

    ax[1].plot(frame.lr, frame[self.metric], label=self.metric)
    # ax[1].plot(frame.lr, frame[self.metric].ewm(com=20).mean(), label=self.metric + ' smooth')
    ax[1].set_title(self.metric)
    ax[1].set_xlabel("Learning rate")
    ax[1].set_ylabel(self.metric)
    # ax[1].legend()

    if self.interpolation == 'logscale':
        ax[1].set_xscale("log", nonposx='clip')

    plt.show()