def update(self, num_training_epochs=10, batch_size=128, verbose=False):
  """Runs several epochs of training on samples from the replay buffer.

  Each epoch draws ceil(len(replay_buffer) / batch_size) random mini-batches,
  i.e. roughly one pass over the buffer contents. Calling update resets the
  optimizer state.

  Args:
    num_training_epochs: number of passes over the training data. The total
      iteration count is num_training_epochs * len(replay_buffer)/batch_size.
    batch_size: number of examples sampled from the replay buffer for each
      net training iteration.
    verbose: if True, print the averaged losses after every epoch.

  Returns:
    A list of num_training_epochs elements, each a Losses tuple averaged
    over that epoch's iterations.
  """
  iters_per_epoch = math.ceil(len(self.replay_buffer) / float(batch_size))
  per_epoch_losses = []
  for epoch_idx in range(num_training_epochs):
    batch_losses = [
        self.model.update(self.replay_buffer.sample(batch_size))
        for _ in range(iters_per_epoch)
    ]
    mean_loss = (sum(batch_losses, model_lib.Losses(0, 0, 0)) /
                 len(batch_losses))
    per_epoch_losses.append(mean_loss)
    if verbose:
      print("Epoch {}: {}".format(epoch_idx, mean_loss))
  return per_epoch_losses
def learn(step):
  """Sample from the replay buffer, update weights and save a checkpoint."""
  num_batches = len(replay_buffer) // config.train_batch_size
  batch_losses = [
      model.update(replay_buffer.sample(config.train_batch_size))
      for _ in range(num_batches)
  ]
  # Always save a checkpoint, either for keeping or for loading the weights to
  # the actors. It only allows numbers, so use -1 as "latest".
  checkpoint_step = step if step % config.checkpoint_freq == 0 else -1
  save_path = model.save_checkpoint(checkpoint_step)
  mean_losses = sum(batch_losses, model_lib.Losses(0, 0, 0)) / len(batch_losses)
  logger.print("Step {}: {}, path: {}".format(step, mean_losses, save_path))
  return save_path
def learn(step, replay_buffer, model, config_learn, model_num):
  """Sample from the replay buffer, update weights and save a checkpoint.

  Args:
    step: current training step; kept checkpoints are named after it.
    replay_buffer: buffer of training examples to sample batches from.
    model: the model whose weights are updated and checkpointed.
    config_learn: config providing train_batch_size and checkpoint_freq.
    model_num: when 1, every 4th sampled batch is also collected into a
      buffer of samples for the bigger network and returned to the caller.

  Returns:
    (save_path, losses, mpv_upd) when model_num == 1, otherwise
    (save_path, losses).
  """
  losses = []
  # A buffer capacity should be an integer; floor division here (the old
  # `/ 3` passed a float capacity to Buffer).
  mpv_upd = Buffer(len(replay_buffer) // 3)
  for i in range(len(replay_buffer) // config_learn.train_batch_size):
    data = replay_buffer.sample(config_learn.train_batch_size)
    losses.append(model.update(data))  # weight update
    # Keep every 4th sampled batch as training data for the bigger network.
    if (i + 1) % 4 == 0:
      mpv_upd.append_buffer(data)
  # Always save a checkpoint, either for keeping or for loading the weights to
  # the actors. It only allows numbers, so use -1 as "latest".
  save_path = model.save_checkpoint(
      step if step % config_learn.checkpoint_freq == 0 else -1)
  losses = sum(losses, model_lib.Losses(0, 0, 0)) / len(losses)
  logger.print(losses)
  logger.print("Checkpoint saved:", save_path)
  if model_num == 1:
    return save_path, losses, mpv_upd
  else:
    return save_path, losses