def __init__(
        self,
        model,
        batch_size=128,
        log_interval=0,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        weight_decay=0,
        is_auto_encoder=False,
):
    self.model = model
    self.log_interval = log_interval
    self.batch_size = batch_size
    self.beta = beta
    if lr is None:
        if is_auto_encoder:
            lr = 1e-2
        else:
            lr = 1e-3
    self.beta_schedule = beta_schedule
    if self.beta_schedule is None or is_auto_encoder:
        self.beta_schedule = ConstantSchedule(self.beta)
    self.imsize = model.imsize
    model.to(ptu.device)
    self.representation_size = model.representation_size
    self.input_channels = model.input_channels
    self.imlength = self.imsize * self.imsize * self.input_channels
    self.lr = lr
    params = list(self.model.parameters())
    self.optimizer = optim.Adam(
        params,
        lr=self.lr,
        weight_decay=weight_decay,
    )
    self.eval_statistics = {}
class VAETrainer(object):
    def __init__(
            self,
            train_dataset,
            test_dataset,
            model,
            batch_size=128,
            beta=0.5,
            beta_schedule=None,
            lr=1e-3,
            extra_recon_logging=dict(),
            recon_weights=None,
            recon_loss_type='mse',
            **kwargs
    ):
        assert recon_loss_type in ['mse', 'wse']
        self.batch_size = batch_size
        self.beta = beta
        self.beta_schedule = beta_schedule
        if self.beta_schedule is None:
            self.beta_schedule = ConstantSchedule(self.beta)
        if ptu.gpu_enabled():
            model.cuda()
        self.model = model
        self.representation_size = model.representation_size
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.train_dataset, self.test_dataset = train_dataset, test_dataset
        assert self.train_dataset['next_obs'].dtype == np.float32
        assert self.test_dataset['next_obs'].dtype == np.float32
        assert self.train_dataset['obs'].dtype == np.float32
        assert self.test_dataset['obs'].dtype == np.float32
        self.normalize = model.normalize
        self.mse = nn.MSELoss()
        if self.normalize:
            self.train_data_mean = ptu.np_to_var(
                np.mean(self.train_dataset['next_obs'], axis=0))
            np_std = np.std(self.train_dataset['next_obs'], axis=0)
            for i in range(len(np_std)):
                if np_std[i] < 1e-3:
                    np_std[i] = 1.0
            self.train_data_std = ptu.np_to_var(np_std)
            self.model.train_data_mean = self.train_data_mean
            self.model.train_data_std = self.train_data_std
        self.extra_recon_logging = extra_recon_logging
        self.recon_weights = recon_weights
        self.recon_loss_type = recon_loss_type

    def get_batch(self, train=True):
        dataset = self.train_dataset if train else self.test_dataset
        ind = np.random.randint(0, len(dataset['obs']), self.batch_size)
        samples_obs = dataset['obs'][ind, :]
        samples_actions = dataset['actions'][ind, :]
        samples_next_obs = dataset['next_obs'][ind, :]
        return {
            'obs': ptu.np_to_var(samples_obs),
            'actions': ptu.np_to_var(samples_actions),
            'next_obs': ptu.np_to_var(samples_next_obs),
        }

    def logprob(self, recon_x, x, normalize=None, idx=None, unorm_weights=None):
        if normalize is None:
            normalize = self.normalize
        if normalize:
            x = (x - self.train_data_mean) / self.train_data_std
            recon_x = (recon_x - self.train_data_mean) / self.train_data_std
        if idx is not None:
            x = x[:, idx]
            recon_x = recon_x[:, idx]
            if unorm_weights is not None:
                unorm_weights = unorm_weights[idx]
        if unorm_weights is not None:
            dim = x.shape[1]
            norm_weights = unorm_weights / (np.sum(unorm_weights) / dim)
            norm_weights = ptu.np_to_var(norm_weights)
            recon_x = recon_x * norm_weights
            x = x * norm_weights
        return self.mse(recon_x, x)

    def kl_divergence(self, mu, logvar):
        kl = - torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()
        return kl

    def train_epoch(self, epoch, batches=100):
        self.model.train()
        losses = []
        kles = []
        mses = []
        beta = self.beta_schedule.get_value(epoch)
        for batch_idx in range(batches):
            data = self.get_batch()
            obs = data['obs']
            next_obs = data['next_obs']
            actions = data['actions']

            self.optimizer.zero_grad()
            recon_batch, mu, logvar = self.model(next_obs)
            mse = self.logprob(recon_batch, next_obs)
            kle = self.kl_divergence(mu, logvar)
            if self.recon_loss_type == 'mse':
                loss = mse + beta * kle
            elif self.recon_loss_type == 'wse':
                wse = self.logprob(
                    recon_batch, next_obs, unorm_weights=self.recon_weights)
                loss = wse + beta * kle
            loss.backward()

            losses.append(loss.data[0])
            mses.append(mse.data[0])
            kles.append(kle.data[0])

            self.optimizer.step()

        logger.record_tabular("train/epoch", epoch)
        logger.record_tabular("train/MSE", np.mean(mses))
        logger.record_tabular("train/KL", np.mean(kles))
        logger.record_tabular("train/loss", np.mean(losses))

    def test_epoch(self, epoch, save_vae=True, **kwargs):
        self.model.eval()
        losses = []
        kles = []
        zs = []
        recon_logging_dict = {
            'MSE': [],
            'WSE': [],
        }
        for k in self.extra_recon_logging:
            recon_logging_dict[k] = []
        beta = self.beta_schedule.get_value(epoch)
        for batch_idx in range(100):
            data = self.get_batch(train=False)
            obs = data['obs']
            next_obs = data['next_obs']
            actions = data['actions']

            recon_batch, mu, logvar = self.model(next_obs)
            mse = self.logprob(recon_batch, next_obs)
            wse = self.logprob(
                recon_batch, next_obs, unorm_weights=self.recon_weights)
            for k, idx in self.extra_recon_logging.items():
                recon_loss = self.logprob(recon_batch, next_obs, idx=idx)
                recon_logging_dict[k].append(recon_loss.data[0])
            kle = self.kl_divergence(mu, logvar)
            if self.recon_loss_type == 'mse':
                loss = mse + beta * kle
            elif self.recon_loss_type == 'wse':
                loss = wse + beta * kle

            z_data = ptu.get_numpy(mu.cpu())
            for i in range(len(z_data)):
                zs.append(z_data[i, :])
            losses.append(loss.data[0])
            recon_logging_dict['WSE'].append(wse.data[0])
            recon_logging_dict['MSE'].append(mse.data[0])
            kles.append(kle.data[0])

        zs = np.array(zs)
        self.model.dist_mu = zs.mean(axis=0)
        self.model.dist_std = zs.std(axis=0)

        for k in recon_logging_dict:
            logger.record_tabular(
                "/".join(["test", k]), np.mean(recon_logging_dict[k]))
        logger.record_tabular("test/KL", np.mean(kles))
        logger.record_tabular("test/loss", np.mean(losses))
        logger.record_tabular("beta", beta)

        process = psutil.Process(os.getpid())
        logger.record_tabular(
            "RAM Usage (Mb)", int(process.memory_info().rss / 1000000))

        num_active_dims = 0
        for std in self.model.dist_std:
            if std > 0.15:
                num_active_dims += 1
        logger.record_tabular("num_active_dims", num_active_dims)

        logger.dump_tabular()
        if save_vae:
            logger.save_itr_params(
                epoch, self.model, prefix='vae', save_anyway=True)  # slow...
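# Minimal, self-contained sketch (plain PyTorch, no project imports) of the
# per-batch objective train_epoch() above optimizes: an MSE reconstruction
# term plus beta times the analytic Gaussian KL term from kl_divergence().
# The tensors below are random stand-ins for recon_batch / next_obs / mu / logvar.
import torch
import torch.nn as nn

batch_size, obs_dim, latent_dim = 8, 6, 3
next_obs = torch.randn(batch_size, obs_dim)
recon_batch = torch.randn(batch_size, obs_dim)
mu = torch.randn(batch_size, latent_dim)
logvar = torch.randn(batch_size, latent_dim)

mse = nn.MSELoss()(recon_batch, next_obs)
kl = -torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()
beta = 0.5
loss = mse + beta * kl  # the 'mse' branch of the trainer's loss
print(float(mse), float(kl), float(loss))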
def __init__(
        self,
        max_tau=10,
        epoch_max_tau_schedule=None,
        sample_train_goals_from='replay_buffer',
        sample_rollout_goals_from='environment',
        cycle_taus_for_rollout=False,
):
    """
    :param max_tau: Maximum tau (planning horizon) to train with.
    :param epoch_max_tau_schedule: A schedule for the maximum planning
    horizon tau.
    :param sample_train_goals_from: Sampling strategy for goals used in
    training. Can be one of the following strings:
        - environment: Sample from the environment
        - replay_buffer: Sample from the replay_buffer
        - her: Sample from a HER-based replay_buffer
    :param sample_rollout_goals_from: Sampling strategy for goals used
    during rollout. Can be one of the following strings:
        - environment: Sample from the environment
        - replay_buffer: Sample from the replay_buffer
        - fixed: Do not resample the goal. Just use the one in the
        environment.
    :param cycle_taus_for_rollout: Decrement the tau passed into the
    policy during rollout?
    """
    assert sample_train_goals_from in [
        'environment',
        'replay_buffer',
        'her',
    ]
    assert sample_rollout_goals_from in [
        'environment',
        'replay_buffer',
        'fixed',
    ]
    if epoch_max_tau_schedule is None:
        epoch_max_tau_schedule = ConstantSchedule(max_tau)

    self.max_tau = max_tau
    self.epoch_max_tau_schedule = epoch_max_tau_schedule
    self.sample_train_goals_from = sample_train_goals_from
    self.sample_rollout_goals_from = sample_rollout_goals_from
    self.cycle_taus_for_rollout = cycle_taus_for_rollout
    self._current_path_goal = None
    self._rollout_tau = self.max_tau

    self.policy = MakeUniversal(self.policy)
    self.eval_policy = MakeUniversal(self.eval_policy)
    self.exploration_policy = MakeUniversal(self.exploration_policy)
    self.eval_sampler = MultigoalSimplePathSampler(
        env=self.env,
        policy=self.eval_policy,
        max_samples=self.num_steps_per_eval,
        max_path_length=self.max_path_length,
        discount_sampling_function=self._sample_max_tau_for_rollout,
        goal_sampling_function=self._sample_goal_for_rollout,
        cycle_taus_for_rollout=self.cycle_taus_for_rollout,
    )
    if self.collection_mode == 'online-parallel':
        # TODO(murtaza): What happens to the eval env?
        # see `eval_sampler` definition above.
        self.training_env = RemoteRolloutEnv(
            env=self.env,
            policy=self.eval_policy,
            exploration_policy=self.exploration_policy,
            max_path_length=self.max_path_length,
            normalize_env=self.normalize_env,
            rollout_function=self.rollout,
        )
def __init__(
        self,
        max_tau=10,
        epoch_max_tau_schedule=None,
        vectorized=True,
        cycle_taus_for_rollout=True,
        dense_rewards=False,
        finite_horizon=True,
        tau_sample_strategy='uniform',
        goal_reached_epsilon=1e-3,
        terminate_when_goal_reached=False,
        truncated_geom_factor=2.,
        square_distance=False,
        goal_weights=None,
        normalize_distance=False,
        observation_key=None,
        desired_goal_key=None,
):
    """
    :param max_tau: Maximum tau (planning horizon) to train with.
    :param epoch_max_tau_schedule: A schedule for the maximum planning
    horizon tau.
    :param vectorized: Train the QF in vectorized form?
    :param cycle_taus_for_rollout: Decrement the tau passed into the
    policy during rollout?
    :param dense_rewards: If True, always give rewards. Otherwise, only
    give rewards when the episode terminates.
    :param finite_horizon: If True, use a finite horizon formulation:
    give the time as input to the Q-function and terminate.
    :param tau_sample_strategy: Sampling strategy for taus used during
    training. Can be one of the following strings:
        - no_resampling: Do not resample the tau. Use the one from rollout.
        - uniform: Sample uniformly from [0, max_tau]
        - truncated_geometric: Sample from a truncated geometric
        distribution, truncated at max_tau.
        - all_valid: Always use all values from 0 to max_tau.
    :param goal_reached_epsilon: Epsilon used to determine if the goal
    has been reached. Used by the `indicator` version of `reward_type`
    and when `terminate_when_goal_reached` is True.
    :param terminate_when_goal_reached: Do you terminate when you have
    reached the goal?
    :param goal_weights: None or the weights for the different goal
    dimensions. These weights are used to compute the distances to the
    goal.
    """
    assert tau_sample_strategy in [
        'no_resampling',
        'uniform',
        'truncated_geometric',
        'all_valid',
    ]
    if epoch_max_tau_schedule is None:
        epoch_max_tau_schedule = ConstantSchedule(max_tau)

    if not finite_horizon:
        max_tau = 0
        epoch_max_tau_schedule = ConstantSchedule(max_tau)
        cycle_taus_for_rollout = False

    self.max_tau = max_tau
    self.epoch_max_tau_schedule = epoch_max_tau_schedule
    self.vectorized = vectorized
    self.cycle_taus_for_rollout = cycle_taus_for_rollout
    self.dense_rewards = dense_rewards
    self.finite_horizon = finite_horizon
    self.tau_sample_strategy = tau_sample_strategy
    self.goal_reached_epsilon = goal_reached_epsilon
    self.terminate_when_goal_reached = terminate_when_goal_reached
    self.square_distance = square_distance
    self._rollout_tau = np.array([self.max_tau])
    self.truncated_geom_factor = float(truncated_geom_factor)
    self.goal_weights = goal_weights
    if self.goal_weights is not None:
        # In case they were passed in as (e.g.) tuples or lists
        self.goal_weights = np.array(self.goal_weights)
        assert self.goal_weights.size == self.env.goal_dim
    self.normalize_distance = normalize_distance

    self.observation_key = observation_key
    self.desired_goal_key = desired_goal_key

    self.eval_sampler = MultigoalSimplePathSampler(
        env=self.env,
        policy=self.eval_policy,
        max_samples=self.num_steps_per_eval,
        max_path_length=self.max_path_length,
        tau_sampling_function=self._sample_max_tau_for_rollout,
        cycle_taus_for_rollout=self.cycle_taus_for_rollout,
        render=self.render_during_eval,
        observation_key=self.observation_key,
        desired_goal_key=self.desired_goal_key,
    )
    self.pretrain_obs = None
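# Illustrative, self-contained sketch of the 'truncated_geometric' tau
# sampling strategy described in the docstring above: draw taus from a
# geometric pmf renormalized over {0, ..., max_tau}. The helper the trainer
# actually calls is not shown in this excerpt, and the mapping from
# truncated_geom_factor to the geometric parameter p below is an assumption.
import numpy as np

def sample_truncated_geometric(p, truncate_at, size):
    # pmf proportional to (1 - p)**k * p for k = 0..truncate_at, renormalized
    ks = np.arange(truncate_at + 1)
    pmf = (1 - p) ** ks * p
    pmf = pmf / pmf.sum()
    return np.random.choice(ks, size=size, p=pmf)

taus = sample_truncated_geometric(p=2.0 / (10 + 1), truncate_at=10, size=5)
print(taus)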
class AWRSACTrainer(TorchTrainer):
    def __init__(
            self,
            env,
            policy,
            qf1,
            qf2,
            target_qf1,
            target_qf2,
            buffer_policy=None,
            discount=0.99,
            reward_scale=1.0,
            beta=1.0,
            beta_schedule_kwargs=None,
            policy_lr=1e-3,
            qf_lr=1e-3,
            policy_weight_decay=0,
            q_weight_decay=0,
            optimizer_class=optim.Adam,
            soft_target_tau=1e-2,
            target_update_period=1,
            plotter=None,
            render_eval_paths=False,
            use_automatic_entropy_tuning=True,
            target_entropy=None,
            bc_num_pretrain_steps=0,
            q_num_pretrain1_steps=0,
            q_num_pretrain2_steps=0,
            bc_batch_size=128,
            bc_loss_type="mle",
            awr_loss_type="mle",
            save_bc_policies=0,
            alpha=1.0,
            policy_update_period=1,
            q_update_period=1,
            weight_loss=True,
            compute_bc=True,
            bc_weight=0.0,
            rl_weight=1.0,
            use_awr_update=True,
            use_reparam_update=False,
            reparam_weight=1.0,
            awr_weight=1.0,
            post_pretrain_hyperparams=None,
            post_bc_pretrain_hyperparams=None,
            awr_use_mle_for_vf=False,
            awr_sample_actions=False,
            awr_min_q=False,
            reward_transform_class=None,
            reward_transform_kwargs=None,
            terminal_transform_class=None,
            terminal_transform_kwargs=None,
            pretraining_env_logging_period=100000,
            pretraining_logging_period=1000,
            do_pretrain_rollouts=False,
            train_bc_on_rl_buffer=False,
            use_automatic_beta_tuning=False,
            beta_epsilon=1e-10,
    ):
        super().__init__()
        self.env = env
        self.policy = policy
        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        self.buffer_policy = buffer_policy
        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period
        self.use_awr_update = use_awr_update
        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            if target_entropy:
                self.target_entropy = target_entropy
            else:
                # heuristic value from Tuomas
                self.target_entropy = -np.prod(
                    self.env.action_space.shape).item()
            self.log_alpha = ptu.zeros(1, requires_grad=True)
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=policy_lr,
            )

        self.awr_use_mle_for_vf = awr_use_mle_for_vf
        self.awr_sample_actions = awr_sample_actions
        self.awr_min_q = awr_min_q

        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            weight_decay=policy_weight_decay,
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            weight_decay=q_weight_decay,
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            weight_decay=q_weight_decay,
            lr=qf_lr,
        )
        if buffer_policy and train_bc_on_rl_buffer:
            self.buffer_policy_optimizer = optimizer_class(
                self.buffer_policy.parameters(),
                weight_decay=policy_weight_decay,
                lr=policy_lr,
            )

        self.use_automatic_beta_tuning = (
            use_automatic_beta_tuning
            and buffer_policy
            and train_bc_on_rl_buffer
        )
        self.beta_epsilon = beta_epsilon
        if self.use_automatic_beta_tuning:
            self.log_beta = ptu.zeros(1, requires_grad=True)
            self.beta_optimizer = optimizer_class(
                [self.log_beta],
                lr=policy_lr,
            )
        else:
            self.beta = beta
            self.beta_schedule_kwargs = beta_schedule_kwargs
            if beta_schedule_kwargs is None:
                self.beta_schedule = ConstantSchedule(beta)
            else:
                schedule_class = beta_schedule_kwargs.pop(
                    "schedule_class", PiecewiseLinearSchedule)
                self.beta_schedule = schedule_class(**beta_schedule_kwargs)

        self.discount = discount
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True

        self.bc_num_pretrain_steps = bc_num_pretrain_steps
        self.q_num_pretrain1_steps = q_num_pretrain1_steps
        self.q_num_pretrain2_steps = q_num_pretrain2_steps
        self.bc_batch_size = bc_batch_size
        self.bc_loss_type = bc_loss_type
        self.awr_loss_type = awr_loss_type
        self.rl_weight = rl_weight
        self.bc_weight = bc_weight
        self.save_bc_policies = save_bc_policies
        self.eval_policy = MakeDeterministic(self.policy)
        self.compute_bc = compute_bc
        self.alpha = alpha
        self.q_update_period = q_update_period
        self.policy_update_period = policy_update_period
        self.weight_loss = weight_loss

        self.reparam_weight = reparam_weight
        self.awr_weight = awr_weight
        self.post_pretrain_hyperparams = post_pretrain_hyperparams
        self.post_bc_pretrain_hyperparams = post_bc_pretrain_hyperparams
        self.update_policy = True
        self.pretraining_env_logging_period = pretraining_env_logging_period
        self.pretraining_logging_period = pretraining_logging_period
        self.do_pretrain_rollouts = do_pretrain_rollouts

        self.reward_transform_class = reward_transform_class or LinearTransform
        self.reward_transform_kwargs = reward_transform_kwargs or dict(m=1, b=0)
        self.terminal_transform_class = terminal_transform_class or LinearTransform
        self.terminal_transform_kwargs = terminal_transform_kwargs or dict(m=1, b=0)
        self.reward_transform = self.reward_transform_class(
            **self.reward_transform_kwargs)
        self.terminal_transform = self.terminal_transform_class(
            **self.terminal_transform_kwargs)

        self.use_reparam_update = use_reparam_update
        self.train_bc_on_rl_buffer = train_bc_on_rl_buffer and buffer_policy

    def get_batch_from_buffer(self, replay_buffer, batch_size):
        batch = replay_buffer.random_batch(batch_size)
        batch = np_to_pytorch_batch(batch)
        return batch

    def run_bc_batch(self, replay_buffer, policy):
        batch = self.get_batch_from_buffer(replay_buffer, self.bc_batch_size)
        o = batch["observations"]
        u = batch["actions"]
        # g = batch["resampled_goals"]
        # og = torch.cat((o, g), dim=1)
        og = o
        # pred_u, *_ = self.policy(og)
        pred_u, policy_mean, policy_log_std, log_pi, entropy, policy_std, \
            mean_action_log_prob, pretanh_value, dist = policy(
                og,
                deterministic=False,
                reparameterize=True,
                return_log_prob=True,
            )
        mse = (policy_mean - u) ** 2
        mse_loss = mse.mean()

        policy_logpp = dist.log_prob(u)
        logp_loss = -policy_logpp.mean()

        if self.bc_loss_type == "mle":
            policy_loss = logp_loss
        elif self.bc_loss_type == "mse":
            policy_loss = mse_loss
        else:
            raise ValueError(
                "Unsupported bc_loss_type: {}".format(self.bc_loss_type))

        return policy_loss, logp_loss, mse_loss, policy_log_std

    def do_rollouts(self):
        total_ret = 0
        for _ in range(20):
            o = self.env.reset()
            ret = 0
            for _ in range(1000):
                a, _ = self.policy.get_action(o)
                o, r, done, info = self.env.step(a)
                ret += r
                if done:
                    break
            total_ret += ret
        return total_ret

    def pretrain_policy_with_bc(self):
        logger.remove_tabular_output(
            'progress.csv', relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'pretrain_policy.csv', relative_to_snapshot_dir=True,
        )
        if self.do_pretrain_rollouts:
            total_ret = self.do_rollouts()
            print("INITIAL RETURN", total_ret / 20)

        prev_time = time.time()
        for i in range(self.bc_num_pretrain_steps):
            train_policy_loss, train_logp_loss, train_mse_loss, train_log_std = \
                self.run_bc_batch(self.demo_train_buffer, self.policy)
            train_policy_loss = train_policy_loss * self.bc_weight

            self.policy_optimizer.zero_grad()
            train_policy_loss.backward()
            self.policy_optimizer.step()

            test_policy_loss, test_logp_loss, test_mse_loss, test_log_std = \
                self.run_bc_batch(self.demo_test_buffer, self.policy)
            test_policy_loss = test_policy_loss * self.bc_weight

            if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
                total_ret = self.do_rollouts()
                print("Return at step {} : {}".format(i, total_ret / 20))

            if i % self.pretraining_logging_period == 0:
                stats = {
                    "pretrain_bc/batch": i,
                    "pretrain_bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                    "pretrain_bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                    "pretrain_bc/Train MSE": ptu.get_numpy(train_mse_loss),
                    "pretrain_bc/Test MSE": ptu.get_numpy(test_mse_loss),
                    "pretrain_bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                    "pretrain_bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
                    "pretrain_bc/epoch_time": time.time() - prev_time,
                }
                if self.do_pretrain_rollouts:
                    stats["pretrain_bc/avg_return"] = total_ret / 20
                logger.record_dict(stats)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                pickle.dump(
                    self.policy,
                    open(logger.get_snapshot_dir() + '/bc.pkl', "wb"),
                )
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_policy.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        if self.post_bc_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)

    def pretrain_q_with_bc_data(self):
        logger.remove_tabular_output(
            'progress.csv', relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'pretrain_q.csv', relative_to_snapshot_dir=True,
        )

        self.update_policy = False
        # first train only the Q function
        for i in range(self.q_num_pretrain1_steps):
            self.eval_statistics = dict()

            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data)
            if i % self.pretraining_logging_period == 0:
                logger.record_dict(self.eval_statistics)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)

        self.update_policy = True
        # then train policy and Q function together
        prev_time = time.time()
        for i in range(self.q_num_pretrain2_steps):
            self.eval_statistics = dict()
            if i % self.pretraining_logging_period == 0:
                self._need_to_update_eval_statistics = True
            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data)

            if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
                total_ret = self.do_rollouts()
                print("Return at step {} : {}".format(i, total_ret / 20))

            if i % self.pretraining_logging_period == 0:
                if self.do_pretrain_rollouts:
                    self.eval_statistics["pretrain_bc/avg_return"] = total_ret / 20
                self.eval_statistics["batch"] = i
                self.eval_statistics["epoch_time"] = time.time() - prev_time
                logger.record_dict(self.eval_statistics)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_q.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        self._need_to_update_eval_statistics = True
        self.eval_statistics = dict()

        if self.post_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_pretrain_hyperparams)

    def set_algorithm_weights(
            self,
            # bc_weight,
            # rl_weight,
            # use_awr_update,
            # use_reparam_update,
            # reparam_weight,
            # awr_weight,
            **kwargs
    ):
        for key in kwargs:
            self.__dict__[key] = kwargs[key]
        # self.bc_weight = bc_weight
        # self.rl_weight = rl_weight
        # self.use_awr_update = use_awr_update
        # self.use_reparam_update = use_reparam_update
        # self.awr_weight = awr_weight

    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        weights = batch.get('weights', None)
        if self.reward_transform:
            rewards = self.reward_transform(rewards)
        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)

        """
        Policy and Alpha Loss
        """
        new_obs_actions, policy_mean, policy_log_std, log_pi, entropy, \
            policy_std, mean_action_log_prob, pretanh_value, dist = self.policy(
                obs,
                reparameterize=True,
                return_log_prob=True,
            )

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (
                log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha

        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        new_next_actions, _, _, new_log_pi, *_ = self.policy(
            next_obs,
            reparameterize=True,
            return_log_prob=True,
        )
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards \
            + (1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

        """
        Policy Loss
        """
        qf1_new_actions = self.qf1(obs, new_obs_actions)
        qf2_new_actions = self.qf2(obs, new_obs_actions)
        q_new_actions = torch.min(
            qf1_new_actions,
            qf2_new_actions,
        )

        # Advantage-weighted regression
        if self.awr_use_mle_for_vf:
            v_pi = self.qf1(obs, policy_mean)
        else:
            v_pi = self.qf1(obs, new_obs_actions)

        if self.awr_sample_actions:
            u = new_obs_actions
            if self.awr_min_q:
                q_adv = q_new_actions
            else:
                q_adv = qf1_new_actions
        else:
            u = actions
            if self.awr_min_q:
                q_adv = torch.min(q1_pred, q2_pred)
            else:
                q_adv = q1_pred

        if self.awr_loss_type == "mse":
            policy_logpp = -(policy_mean - actions) ** 2
        else:
            policy_logpp = dist.log_prob(u)
            policy_logpp = policy_logpp.sum(dim=1, keepdim=True)

        advantage = q_adv - v_pi

        if self.weight_loss and weights is None:
            if self.use_automatic_beta_tuning:
                _, _, _, _, _, _, _, _, buffer_dist = self.buffer_policy(
                    obs,
                    reparameterize=True,
                    return_log_prob=True,
                )
                beta = self.log_beta.exp()
                kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
                beta_loss = -1 * (
                    beta * (kldiv - self.beta_epsilon).detach()).mean()

                self.beta_optimizer.zero_grad()
                beta_loss.backward()
                self.beta_optimizer.step()
            else:
                beta = self.beta_schedule.get_value(self._n_train_steps_total)
            weights = F.softmax(advantage / beta, dim=0)

        policy_loss = alpha * log_pi.mean()

        if self.use_awr_update and self.weight_loss:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp * len(weights) * weights.detach()).mean()
        elif self.use_awr_update:
            policy_loss = policy_loss + self.awr_weight * (-policy_logpp).mean()

        if self.use_reparam_update:
            policy_loss = policy_loss + self.reparam_weight * (
                -q_new_actions).mean()

        policy_loss = self.rl_weight * policy_loss
        if self.compute_bc:
            train_policy_loss, train_logp_loss, train_mse_loss, _ = \
                self.run_bc_batch(self.demo_train_buffer, self.policy)
            policy_loss = policy_loss + self.bc_weight * train_policy_loss

        if self.train_bc_on_rl_buffer:
            buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = \
                self.run_bc_batch(self.replay_buffer, self.buffer_policy)

        """
        Update networks
        """
        if self._n_train_steps_total % self.q_update_period == 0:
            self.qf1_optimizer.zero_grad()
            qf1_loss.backward()
            self.qf1_optimizer.step()

            self.qf2_optimizer.zero_grad()
            qf2_loss.backward()
            self.qf2_optimizer.step()

        if self._n_train_steps_total % self.policy_update_period == 0 \
                and self.update_policy:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

        if self.train_bc_on_rl_buffer \
                and self._n_train_steps_total % self.policy_update_period == 0:
            self.buffer_policy_optimizer.zero_grad()
            buffer_policy_loss.backward()
            self.buffer_policy_optimizer.step()

        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.qf1, self.target_qf1, self.soft_target_tau
            )
            ptu.soft_update_from_to(
                self.qf2, self.target_qf2, self.soft_target_tau
            )

        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Advantage Weights',
                ptu.get_numpy(weights),
            ))

            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

            if self.compute_bc:
                test_policy_loss, test_logp_loss, test_mse_loss, _ = \
                    self.run_bc_batch(self.demo_test_buffer, self.policy)
                self.eval_statistics.update({
                    "bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                    "bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                    "bc/Train MSE": ptu.get_numpy(train_mse_loss),
                    "bc/Test MSE": ptu.get_numpy(test_mse_loss),
                    "bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                    "bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
                })
            if self.train_bc_on_rl_buffer:
                test_policy_loss, test_logp_loss, test_mse_loss, _ = \
                    self.run_bc_batch(self.replay_buffer, self.buffer_policy)
                _, _, _, _, _, _, _, _, buffer_dist = self.buffer_policy(
                    obs,
                    reparameterize=True,
                    return_log_prob=True,
                )
                kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
                self.eval_statistics.update({
                    "buffer_policy/Train Logprob Loss": ptu.get_numpy(buffer_train_logp_loss),
                    "buffer_policy/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                    "buffer_policy/Train MSE": ptu.get_numpy(buffer_train_mse_loss),
                    "buffer_policy/Test MSE": ptu.get_numpy(test_mse_loss),
                    "buffer_policy/train_policy_loss": ptu.get_numpy(buffer_policy_loss),
                    "buffer_policy/test_policy_loss": ptu.get_numpy(test_policy_loss),
                    "buffer_policy/kl_div": ptu.get_numpy(kldiv.mean()),
                })
            if self.use_automatic_beta_tuning:
                self.eval_statistics.update({
                    "adaptive_beta/beta": ptu.get_numpy(beta.mean()),
                    "adaptive_beta/beta loss": ptu.get_numpy(beta_loss.mean()),
                })

        self._n_train_steps_total += 1

    def get_diagnostics(self):
        stats = super().get_diagnostics()
        stats.update(self.eval_statistics)
        return stats

    def end_epoch(self, epoch):
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        nets = [
            self.policy,
            self.qf1,
            self.qf2,
            self.target_qf1,
            self.target_qf2,
        ]
        if self.buffer_policy:
            nets.append(self.buffer_policy)
        return nets

    def get_snapshot(self):
        return dict(
            policy=self.policy,
            qf1=self.qf1,
            qf2=self.qf2,
            target_qf1=self.target_qf1,
            target_qf2=self.target_qf2,
            buffer_policy=self.buffer_policy,
        )
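# Standalone sketch (plain PyTorch) of the advantage-weighted regression
# update used in train_from_torch() above: advantages are turned into
# normalized weights with a softmax over the batch, and the policy
# log-likelihood of the buffer actions is reweighted by them. The tensors
# here are random stand-ins for advantage and policy_logpp.
import torch
import torch.nn.functional as F

batch = 16
advantage = torch.randn(batch, 1)       # q_adv - v_pi in the trainer
policy_logpp = torch.randn(batch, 1)    # log pi(a|s) of the buffer actions
beta = 1.0

weights = F.softmax(advantage / beta, dim=0)  # sums to 1 over the batch
awr_loss = (-policy_logpp * len(weights) * weights.detach()).mean()
print(float(awr_loss))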
def __init__(
        self,
        model,
        log_interval=0,
        beta=0.5,
        beta_schedule=None,
        lr=1e-3,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
        linearity_weight=0.0,
        distance_weight=0.0,
        loss_weights=None,
        use_linear_dynamics=False,
        use_parallel_dataloading=False,
        train_data_workers=2,
        skew_dataset=False,
        skew_config=None,
        priority_function_kwargs=None,
        start_skew_epoch=0,
        weight_decay=0,
        batch_size=64,
):
    # TODO: steven fix pickling
    assert not use_parallel_dataloading, \
        "Have to fix pickling the dataloaders first"

    if skew_config is None:
        skew_config = {}
    self.log_interval = log_interval
    self.beta = beta
    if is_auto_encoder:
        self.beta = 0
    self.beta_schedule = beta_schedule
    if self.beta_schedule is None or is_auto_encoder:
        self.beta_schedule = ConstantSchedule(self.beta)
    self.do_scatterplot = do_scatterplot
    model.to(ptu.device)
    self.model = model
    self.lr = lr
    params = list(self.model.parameters())
    self.optimizer = optim.Adam(
        params,
        lr=self.lr,
        weight_decay=weight_decay,
    )
    self.batch_size = batch_size
    self.use_parallel_dataloading = use_parallel_dataloading
    self.train_data_workers = train_data_workers
    self.skew_dataset = skew_dataset
    self.skew_config = skew_config
    self.start_skew_epoch = start_skew_epoch
    if priority_function_kwargs is None:
        self.priority_function_kwargs = dict()
    else:
        self.priority_function_kwargs = priority_function_kwargs
    self.normalize = normalize
    self.mse_weight = mse_weight
    self.background_subtract = background_subtract

    self.linearity_weight = linearity_weight
    self.distance_weight = distance_weight
    self.loss_weights = loss_weights

    self.loss_fn = torch.nn.CrossEntropyLoss()
    self.log_softmax = torch.nn.LogSoftmax()
    self.use_linear_dynamics = use_linear_dynamics
    self._extra_stats_to_log = None

    # stateful tracking variables, reset every epoch
    self.eval_statistics = collections.defaultdict(list)
    self.eval_data = collections.defaultdict(list)
    self.num_train_batches = 0
    self.num_test_batches = 0

    self.bin_midpoints = (
        torch.arange(0, self.model.output_classes).float() + 0.5
    ) / self.model.output_classes
    self.bin_midpoints = self.bin_midpoints.to(ptu.device)
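# Quick, self-contained illustration of the bin_midpoints buffer set up above:
# with output_classes = 8, each discrete output class maps back to the centre
# of its intensity bin in [0, 1]. The value 8 is purely illustrative.
import torch

output_classes = 8
bin_midpoints = (torch.arange(0, output_classes).float() + 0.5) / output_classes
print(bin_midpoints)  # tensor([0.0625, 0.1875, ..., 0.9375])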
class VAETrainer(LossFunction):
    def __init__(
            self,
            model,
            batch_size=128,
            log_interval=0,
            beta=0.5,
            beta_schedule=None,
            lr=None,
            do_scatterplot=False,
            normalize=False,
            mse_weight=0.1,
            is_auto_encoder=False,
            background_subtract=False,
            linearity_weight=0.0,
            distance_weight=0.0,
            loss_weights=None,
            use_linear_dynamics=False,
            use_parallel_dataloading=False,
            train_data_workers=2,
            skew_dataset=False,
            skew_config=None,
            priority_function_kwargs=None,
            start_skew_epoch=0,
            weight_decay=0,
            key_to_reconstruct="observations",
    ):
        # TODO: steven fix pickling
        assert not use_parallel_dataloading, \
            "Have to fix pickling the dataloaders first"

        if skew_config is None:
            skew_config = {}
        self.log_interval = log_interval
        self.batch_size = batch_size
        self.beta = beta
        if is_auto_encoder:
            self.beta = 0
        if lr is None:
            if is_auto_encoder:
                lr = 1e-2
            else:
                lr = 1e-3
        self.beta_schedule = beta_schedule
        if self.beta_schedule is None or is_auto_encoder:
            self.beta_schedule = ConstantSchedule(self.beta)
        self.imsize = model.imsize
        self.do_scatterplot = do_scatterplot
        model.to(ptu.device)
        self.model = model
        self.representation_size = model.representation_size
        self.input_channels = model.input_channels
        self.imlength = model.imlength
        self.lr = lr
        params = list(self.model.parameters())
        self.optimizer = optim.Adam(
            params,
            lr=self.lr,
            weight_decay=weight_decay,
        )
        self.key_to_reconstruct = key_to_reconstruct
        self.use_parallel_dataloading = use_parallel_dataloading
        self.train_data_workers = train_data_workers
        self.skew_dataset = skew_dataset
        self.skew_config = skew_config
        self.start_skew_epoch = start_skew_epoch
        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if use_parallel_dataloading:
            self.train_dataset_pt = ImageDataset(
                train_dataset,
                should_normalize=True,
            )
            self.test_dataset_pt = ImageDataset(
                test_dataset,
                should_normalize=True,
            )

            if self.skew_dataset:
                base_sampler = InfiniteWeightedRandomSampler(
                    self.train_dataset, self._train_weights)
            else:
                base_sampler = InfiniteRandomSampler(self.train_dataset)
            self.train_dataloader = DataLoader(
                self.train_dataset_pt,
                sampler=InfiniteRandomSampler(self.train_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=train_data_workers,
                pin_memory=True,
            )
            self.test_dataloader = DataLoader(
                self.test_dataset_pt,
                sampler=InfiniteRandomSampler(self.test_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=0,
                pin_memory=True,
            )
            self.train_dataloader = iter(self.train_dataloader)
            self.test_dataloader = iter(self.test_dataloader)

        self.normalize = normalize
        self.mse_weight = mse_weight
        self.background_subtract = background_subtract

        if self.normalize or self.background_subtract:
            self.train_data_mean = np.mean(self.train_dataset, axis=0)
            self.train_data_mean = normalize_image(
                np.uint8(self.train_data_mean))
        self.linearity_weight = linearity_weight
        self.distance_weight = distance_weight
        self.loss_weights = loss_weights

        self.use_linear_dynamics = use_linear_dynamics
        self._extra_stats_to_log = None

        # stateful tracking variables, reset every epoch
        self.eval_statistics = collections.defaultdict(list)
        self.eval_data = collections.defaultdict(list)

    @property
    def log_dir(self):
        return logger.get_snapshot_dir()

    def get_dataset_stats(self, data):
        torch_input = ptu.from_numpy(normalize_image(data))
        mus, log_vars = self.model.encode(torch_input)
        mus = ptu.get_numpy(mus)
        mean = np.mean(mus, axis=0)
        std = np.std(mus, axis=0)
        return mus, mean, std

    def _kl_np_to_np(self, np_imgs):
        torch_input = ptu.from_numpy(normalize_image(np_imgs))
        mu, log_var = self.model.encode(torch_input)
        return ptu.get_numpy(
            -torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))

    def _reconstruction_squared_error_np_to_np(self, np_imgs):
        torch_input = ptu.from_numpy(normalize_image(np_imgs))
        recons, *_ = self.model(torch_input)
        error = torch_input - recons
        return ptu.get_numpy((error ** 2).sum(dim=1))

    def set_vae(self, vae):
        self.model = vae
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

    def get_batch(self, test_data=False, epoch=None):
        if self.use_parallel_dataloading:
            if test_data:
                dataloader = self.test_dataloader
            else:
                dataloader = self.train_dataloader
            samples = next(dataloader).to(ptu.device)
            return samples

        dataset = self.test_dataset if test_data else self.train_dataset
        skew = False
        if epoch is not None:
            skew = (self.start_skew_epoch < epoch)
        if not test_data and self.skew_dataset and skew:
            probs = self._train_weights / np.sum(self._train_weights)
            ind = np.random.choice(
                len(probs),
                self.batch_size,
                p=probs,
            )
        else:
            ind = np.random.randint(0, len(dataset), self.batch_size)
        samples = normalize_image(dataset[ind, :])
        if self.normalize:
            samples = ((samples - self.train_data_mean) + 1) / 2
        if self.background_subtract:
            samples = samples - self.train_data_mean
        return ptu.from_numpy(samples)

    def get_debug_batch(self, train=True):
        dataset = self.train_dataset if train else self.test_dataset
        X, Y = dataset
        ind = np.random.randint(0, Y.shape[0], self.batch_size)
        X = X[ind, :]
        Y = Y[ind, :]
        return ptu.from_numpy(X), ptu.from_numpy(Y)

    def train_epoch(self, epoch, dataset, batches=100):
        start_time = time.time()
        for b in range(batches):
            self.train_batch(epoch, dataset.random_batch(self.batch_size))
        self.eval_statistics["train/epoch_duration"].append(
            time.time() - start_time)

    def test_epoch(self, epoch, dataset, batches=10):
        start_time = time.time()
        for b in range(batches):
            self.test_batch(epoch, dataset.random_batch(self.batch_size))
        self.eval_statistics["test/epoch_duration"].append(
            time.time() - start_time)

    def compute_loss(self, batch, epoch=-1, test=False):
        prefix = "test/" if test else "train/"

        beta = float(self.beta_schedule.get_value(epoch))
        obs = batch[self.key_to_reconstruct]
        reconstructions, obs_distribution_params, latent_distribution_params = \
            self.model(obs)
        log_prob = self.model.logprob(obs, obs_distribution_params)
        kle = self.model.kl_divergence(latent_distribution_params)
        loss = -1 * log_prob + beta * kle

        self.eval_statistics['epoch'] = epoch
        self.eval_statistics['beta'] = beta
        self.eval_statistics[prefix + "losses"].append(loss.item())
        self.eval_statistics[prefix + "log_probs"].append(log_prob.item())
        self.eval_statistics[prefix + "kles"].append(kle.item())

        encoder_mean = self.model.get_encoding_from_latent_distribution_params(
            latent_distribution_params)
        z_data = ptu.get_numpy(encoder_mean.cpu())
        for i in range(len(z_data)):
            self.eval_data[prefix + "zs"].append(z_data[i, :])
        self.eval_data[prefix + "last_batch"] = (obs, reconstructions)

        return loss

    def train_batch(self, epoch, batch):
        self.model.train()
        self.optimizer.zero_grad()
        loss = self.compute_loss(batch, epoch, False)
        loss.backward()
        self.optimizer.step()

    def test_batch(
            self,
            epoch,
            batch,
    ):
        self.model.eval()
        loss = self.compute_loss(batch, epoch, True)

    def end_epoch(self, epoch):
        self.eval_statistics = collections.defaultdict(list)
        self.test_last_batch = None

    def get_diagnostics(self):
        stats = OrderedDict()
        for k in sorted(self.eval_statistics.keys()):
            stats[k] = np.mean(self.eval_statistics[k])
        return stats

    def dump_scatterplot(self, z, epoch):
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            logger.log(__file__ + ": Unable to load matplotlib. Consider "
                                  "setting do_scatterplot to False")
            return
        dim_and_stds = [(i, np.std(z[:, i])) for i in range(z.shape[1])]
        dim_and_stds = sorted(dim_and_stds, key=lambda x: x[1])
        dim1 = dim_and_stds[-1][0]
        dim2 = dim_and_stds[-2][0]
        plt.figure(figsize=(8, 8))
        plt.scatter(z[:, dim1], z[:, dim2], marker='o', edgecolor='none')
        if self.model.dist_mu is not None:
            x1 = self.model.dist_mu[dim1:dim1 + 1]
            y1 = self.model.dist_mu[dim2:dim2 + 1]
            x2 = (self.model.dist_mu[dim1:dim1 + 1]
                  + self.model.dist_std[dim1:dim1 + 1])
            y2 = (self.model.dist_mu[dim2:dim2 + 1]
                  + self.model.dist_std[dim2:dim2 + 1])
            plt.plot([x1, x2], [y1, y2], color='k', linestyle='-', linewidth=2)
        axes = plt.gca()
        axes.set_xlim([-6, 6])
        axes.set_ylim([-6, 6])
        axes.set_title('dim {} vs dim {}'.format(dim1, dim2))
        plt.grid(True)
        save_file = osp.join(self.log_dir, 'scatter%d.png' % epoch)
        plt.savefig(save_file)
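# The train_epoch()/test_epoch() methods above only require that `dataset`
# expose random_batch(batch_size) returning a dict whose key_to_reconstruct
# entry is a torch tensor. A minimal in-memory dataset satisfying that
# contract might look like the hypothetical helper below (tensors are left
# on CPU here; in practice they would live on ptu.device).
import numpy as np
import torch

class InMemoryDictDataset:
    def __init__(self, observations):
        # (N, imlength) float array with values in [0, 1]
        self.observations = observations

    def random_batch(self, batch_size):
        ind = np.random.randint(0, len(self.observations), batch_size)
        return {"observations": torch.from_numpy(self.observations[ind]).float()}

dataset = InMemoryDictDataset(np.random.rand(1000, 48 * 48 * 3))
batch = dataset.random_batch(128)
print(batch["observations"].shape)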
def train_vae(variant, return_data=False):
    from railrl.misc.ml_util import PiecewiseLinearSchedule, ConstantSchedule
    from railrl.torch.vae.conv_vae import (
        ConvVAE,
        SpatialAutoEncoder,
        AutoEncoder,
    )
    import railrl.torch.vae.conv_vae as conv_vae
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    representation_size = variant.get(
        "representation_size", variant.get("latent_sizes", None))
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get(
        'generate_vae_data_fctn', generate_vae_dataset)
    variant['generate_vae_dataset_kwargs']['use_linear_dynamics'] = \
        use_linear_dynamics
    variant['generate_vae_dataset_kwargs']['batch_size'] = \
        variant['algo_kwargs']['batch_size']
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])

    if use_linear_dynamics:
        action_dim = train_dataset.data['actions'].shape[2]

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if 'context_schedule' in variant:
        schedule = variant['context_schedule']
        if type(schedule) is dict:
            context_schedule = PiecewiseLinearSchedule(**schedule)
        else:
            context_schedule = ConstantSchedule(schedule)
        variant['algo_kwargs']['context_schedule'] = context_schedule
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    if variant['algo_kwargs'].get('is_auto_encoder', False):
        model = AutoEncoder(
            representation_size,
            decoder_output_activation=decoder_activation,
            **variant['vae_kwargs'])
    elif variant.get('use_spatial_auto_encoder', False):
        model = SpatialAutoEncoder(
            representation_size,
            decoder_output_activation=decoder_activation,
            **variant['vae_kwargs'])
    else:
        vae_class = variant.get('vae_class', ConvVAE)
        if use_linear_dynamics:
            model = vae_class(
                representation_size,
                decoder_output_activation=decoder_activation,
                action_dim=action_dim,
                **variant['vae_kwargs'])
        else:
            model = vae_class(
                representation_size,
                decoder_output_activation=decoder_activation,
                **variant['vae_kwargs'])
    model.to(ptu.device)

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(
        model,
        beta=beta,
        beta_schedule=beta_schedule,
        **variant['algo_kwargs'])
    save_period = variant['save_period']

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
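# Sketch of the variant dict train_vae() expects, assembled from the keys the
# function reads above. The values are illustrative placeholders, and the
# contents of generate_vae_dataset_kwargs depend on the dataset generator used.
variant = dict(
    beta=0.5,
    representation_size=16,
    imsize=48,
    num_epochs=100,
    save_period=25,
    generate_vae_dataset_kwargs=dict(),  # forwarded to the dataset generator
    vae_kwargs=dict(),                   # architecture is filled in from imsize if absent
    algo_kwargs=dict(batch_size=128),
    decoder_activation='sigmoid',
)
# model = train_vae(variant)  # requires the railrl package and an environment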
class ConvVAETrainer(Serializable):
    def __init__(
            self,
            train_dataset,
            test_dataset,
            model,
            batch_size=64,
            beta=0.5,
            beta_schedule=None,
            lr=None,
            linearity_weight=0.0,
            use_linear_dynamics=False,
            noisy_linear_dynamics=False,
            scale_linear_dynamics=False,
            use_parallel_dataloading=True,
            train_data_workers=2,
    ):
        self.quick_init(locals())
        self.batch_size = batch_size
        self.beta = beta
        if lr is None:
            lr = 1e-3
        self.beta_schedule = beta_schedule
        if self.beta_schedule is None:
            self.beta_schedule = ConstantSchedule(self.beta)
        self.imsize = model.imsize
        if ptu.gpu_enabled():
            model.cuda()
        self.model = model
        self.representation_size = model.representation_size
        self.input_channels = model.input_channels
        self.imlength = model.imlength
        self.lr = lr
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.train_dataset, self.test_dataset = train_dataset, test_dataset
        assert self.train_dataset['next_obs'].dtype == np.uint8
        assert self.test_dataset['next_obs'].dtype == np.uint8
        assert self.train_dataset['obs'].dtype == np.uint8
        assert self.test_dataset['obs'].dtype == np.uint8
        self.use_parallel_dataloading = use_parallel_dataloading
        self.train_data_workers = train_data_workers
        self.gaussian_decoder_loss = self.model.gaussian_decoder
        if use_parallel_dataloading:
            self.train_dataset_pt = ImageDataset(
                train_dataset,
                should_normalize=True,
            )
            self.test_dataset_pt = ImageDataset(
                test_dataset,
                should_normalize=True,
            )
            self._train_weights = None
            base_sampler = InfiniteRandomSampler(self.train_dataset)
            self.train_dataloader = DataLoader(
                self.train_dataset_pt,
                sampler=BatchSampler(
                    base_sampler,
                    batch_size=batch_size,
                    drop_last=False,
                ),
                num_workers=train_data_workers,
                pin_memory=True,
            )
            self.test_dataloader = DataLoader(
                self.test_dataset_pt,
                sampler=BatchSampler(
                    InfiniteRandomSampler(self.test_dataset),
                    batch_size=batch_size,
                    drop_last=False,
                ),
                num_workers=0,
                pin_memory=True,
            )
            self.train_dataloader = iter(self.train_dataloader)
            self.test_dataloader = iter(self.test_dataloader)
        self.linearity_weight = linearity_weight
        self.use_linear_dynamics = use_linear_dynamics
        self.noisy_linear_dynamics = noisy_linear_dynamics
        self.scale_linear_dynamics = scale_linear_dynamics
        self.vae_logger_stats_for_rl = {}
        self._extra_stats_to_log = None

    def _kl_np_to_np(self, np_imgs):
        torch_input = ptu.np_to_var(normalize_image(np_imgs))
        mu, log_var = self.model.encode(torch_input)
        return ptu.get_numpy(
            -torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))

    def set_vae(self, vae):
        self.model = vae
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

    def get_batch(self, train=True):
        if self.use_parallel_dataloading:
            if not train:
                dataloader = self.test_dataloader
            else:
                dataloader = self.train_dataloader
            samples = next(dataloader)
            return {
                'obs': ptu.Variable(samples[0][0]),
                'actions': ptu.Variable(samples[1][0]),
                'next_obs': ptu.Variable(samples[2][0]),
            }
        dataset = self.train_dataset if train else self.test_dataset
        ind = np.random.randint(0, len(dataset), self.batch_size)
        samples = normalize_image(dataset[ind, :])
        return ptu.np_to_var(samples)

    def logprob_iwae(self, recon_x, x):
        if self.gaussian_decoder_loss:
            error = -(recon_x - x) ** 2
        else:
            error = x * torch.log(torch.clamp(recon_x, min=1e-30)) \
                    + (1 - x) * torch.log(torch.clamp(1 - recon_x, min=1e-30))
        return error

    def logprob_vae(self, recon_x, x):
        batch_size = recon_x.shape[0]
        if self.gaussian_decoder_loss:
            return -((recon_x - x) ** 2).sum() / batch_size
        else:
            # Divide by batch_size rather than setting size_average=True
            # because otherwise the averaging will also happen across
            # dimension 1 (the pixels).
            return -F.binary_cross_entropy(
                recon_x,
                x.narrow(start=0, length=self.imlength,
                         dimension=1).contiguous().view(-1, self.imlength),
                size_average=False,
            ) / batch_size

    def kl_divergence(self, mu, logvar):
        return -0.5 * torch.sum(
            1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()

    def compute_vae_loss(self, x_recon, x, z_mu, z_logvar, z_sampled, beta):
        batch_size = x_recon.shape[0]
        k = x_recon.shape[1]
        x_recon = x_recon.view((batch_size * k, -1))
        x = x.view((batch_size * k, -1))
        z_mu = z_mu.view((batch_size * k, -1))
        z_logvar = z_logvar.view((batch_size * k, -1))
        de = -self.logprob_vae(x_recon, x)
        kle = self.kl_divergence(z_mu, z_logvar)
        loss = de + beta * kle
        return loss, de, kle

    def compute_iwae_loss(self, x_recon, x, z_mu, z_logvar, z_sampled, beta):
        batch_size = x_recon.shape[0]
        log_p_xgz = self.logprob_iwae(x_recon, x).sum(dim=-1)
        prior_dist = torch.distributions.Normal(
            ptu.Variable(torch.zeros(z_sampled.shape)),
            ptu.Variable(torch.ones(z_sampled.shape)))
        log_p_z = prior_dist.log_prob(z_sampled).sum(dim=-1)
        z_std = torch.exp(0.5 * z_logvar)
        encoder_dist = torch.distributions.Normal(z_mu, z_std)
        log_q_zgx = encoder_dist.log_prob(z_sampled).sum(dim=-1)
        log_w = log_p_xgz + beta * (log_p_z - log_q_zgx)
        w_tilde = F.softmax(log_w, dim=-1).detach()
        loss = -(log_w * w_tilde).sum() / batch_size
        return loss

    def state_linearity_loss(self, obs, next_obs, actions):
        latent_obs_mu, latent_obs_logvar = self.model.encode(obs)
        latent_next_obs_mu, latent_next_obs_logvar = self.model.encode(
            next_obs)
        if self.noisy_linear_dynamics:
            latent_obs = self.model.reparameterize(latent_obs_mu,
                                                   latent_obs_logvar)
        else:
            latent_obs = latent_obs_mu
        action_obs_pair = torch.cat([latent_obs, actions], dim=1)
        prediction = self.model.linear_constraint_fc(action_obs_pair)
        if self.scale_linear_dynamics:
            std = latent_next_obs_logvar.mul(0.5).exp_()
            scaling = 1 / std
        else:
            scaling = 1.0
        return torch.norm(
            scaling * (prediction - latent_next_obs_mu)) ** 2 / self.batch_size

    def train_epoch(self, epoch, batches=100):
        self.model.own_train()
        vae_losses = []
        losses = []
        des = []
        kles = []
        linear_losses = []
        beta = float(self.beta_schedule.get_value(epoch))
        for batch_idx in range(batches):
            data = self.get_batch()
            obs = data['obs']
            next_obs = data['next_obs']
            actions = data['actions']
            self.optimizer.zero_grad()
            x_recon, z_mu, z_logvar, z = self.model(next_obs)
            batch_size = x_recon.shape[0]
            k = x_recon.shape[1]
            x = next_obs.view((batch_size, 1, -1)).repeat(
                torch.Size([1, k, 1]))
            vae_loss, de, kle = self.compute_vae_loss(
                x_recon, x, z_mu, z_logvar, z, beta)
            loss = vae_loss
            # print("---------------------")
            # print("loss ", loss.data)
            # print("z_mu ", z_mu.data[0][0])
            # print("z_logvar ", z_logvar.data[0][0])
            # print("z ", z.data[0][0])
            # print("---------------------")
            if self.use_linear_dynamics:
                linear_dynamics_loss = self.state_linearity_loss(
                    obs, next_obs, actions)
                loss += self.linearity_weight * linear_dynamics_loss
                linear_losses.append(float(linear_dynamics_loss.data[0]))
            loss.backward()
            vae_losses.append(float(vae_loss.data[0]))
            losses.append(float(loss.data[0]))
            des.append(float(de.data[0]))
            kles.append(float(kle.data[0]))
            self.optimizer.step()
            del data, obs, next_obs, actions, x_recon, z_mu, z_logvar, \
                z, x, vae_loss, de, kle, loss
        logger.record_tabular("train/epoch", epoch)
        logger.record_tabular("train/decoder_loss", np.mean(des))
        logger.record_tabular("train/KL", np.mean(kles))
        if self.use_linear_dynamics:
            logger.record_tabular("train/linear_loss", np.mean(linear_losses))
        logger.record_tabular("train/vae_loss", np.mean(vae_losses))
        logger.record_tabular("train/loss", np.mean(losses))

    def test_epoch(
            self,
            epoch,
            save_reconstruction=True,
            save_interpolation=True,
            save_vae=True,
    ):
        self.model.eval()
        vae_losses = []
        iwae_losses = []
        losses = []
        des = []
        kles = []
        linear_losses = []
        zs = []
        beta = float(self.beta_schedule.get_value(epoch))
        for batch_idx in range(10):
            data = self.get_batch(train=False)
            obs = data['obs']
            obs = obs.detach()
            next_obs = data['next_obs']
            next_obs = next_obs.detach()
            actions = data['actions']
            actions = actions.detach()
            x_recon, z_mu, z_logvar, z = self.model(next_obs, n_imp=25)
            x_recon = x_recon.detach()
            z_mu = z_mu.detach()
            z = z.detach()
            batch_size = x_recon.shape[0]
            k = x_recon.shape[1]
            x = next_obs.view((batch_size, 1, -1)).repeat(
                torch.Size([1, k, 1]))
            x = x.detach()
            vae_loss, de, kle = self.compute_vae_loss(
                x_recon, x, z_mu, z_logvar, z, beta)
            vae_loss, de, kle = vae_loss.detach(), de.detach(), kle.detach()
            iwae_loss = self.compute_iwae_loss(
                x_recon, x, z_mu, z_logvar, z, beta)
            iwae_loss = iwae_loss.detach()
            loss = vae_loss
            if self.use_linear_dynamics:
                linear_dynamics_loss = self.state_linearity_loss(
                    obs, next_obs, actions)
                linear_dynamics_loss = linear_dynamics_loss.detach()
                loss += self.linearity_weight * linear_dynamics_loss
                linear_losses.append(
                    float(linear_dynamics_loss.data[0]))  # here too
            z_data = ptu.get_numpy(z_mu[:, 0].cpu())
            for i in range(len(z_data)):
                zs.append(z_data[i, :].copy())
            vae_losses.append(float(vae_loss.data[0]))
            iwae_losses.append(float(iwae_loss.data[0]))
            losses.append(float(loss.data[0]))
            des.append(float(de.data[0]))
            kles.append(float(kle.data[0]))
            if batch_idx == 0 and save_reconstruction:
                n = min(data['next_obs'].size(0), 16)
                comparison = torch.cat([
                    data['next_obs'][:n].narrow(
                        start=0, length=self.imlength, dimension=1
                    ).contiguous().view(
                        -1, self.input_channels, self.imsize, self.imsize),
                    x_recon[:, 0].contiguous().view(
                        self.batch_size,
                        self.input_channels,
                        self.imsize,
                        self.imsize,
                    )[:n]
                ])
                save_dir = osp.join(logger.get_snapshot_dir(),
                                    'r_%d.png' % epoch)
                save_image(comparison.data.cpu(), save_dir, nrow=n)
                del comparison
            if batch_idx == 0 and save_interpolation:
                n = min(data['next_obs'].size(0), 10)
                z1 = z_mu[:n, 0]
                z2 = z_mu[n:2 * n, 0]
                num_steps = 8
                z_interp = []
                for i in np.linspace(0.0, 1.0, num_steps):
                    z_interp.append(float(i) * z1 + float(1 - i) * z2)
                z_interp = torch.cat(z_interp)
                imgs = self.model.decode(z_interp)
                imgs = imgs.view(
                    (num_steps, n, 3, self.imsize, self.imsize))
                imgs = imgs.permute([1, 0, 2, 3, 4])
                imgs = imgs.contiguous().view(
                    (n * num_steps, 3, self.imsize, self.imsize))
                save_dir = osp.join(logger.get_snapshot_dir(),
                                    'i_%d.png' % epoch)
                save_image(
                    imgs.data.cpu(),
                    save_dir,
                    nrow=num_steps,
                )
                del imgs
                del z_interp
            del obs, next_obs, actions, x_recon, z_mu, z_logvar, \
                z, x, vae_loss, de, kle, loss
        zs = np.array(zs)
        self.model.dist_mu = zs.mean(axis=0)
        self.model.dist_std = zs.std(axis=0)
        del zs
        logger.record_tabular("test/decoder_loss", np.mean(des))
        logger.record_tabular("test/KL", np.mean(kles))
        if self.use_linear_dynamics:
            logger.record_tabular("test/linear_loss", np.mean(linear_losses))
        logger.record_tabular("test/loss", np.mean(losses))
        logger.record_tabular("test/vae_loss", np.mean(vae_losses))
        logger.record_tabular("test/iwae_loss", np.mean(iwae_losses))
        logger.record_tabular(
            "test/iwae_vae_diff",
            np.mean(np.array(iwae_losses) - np.array(vae_losses)))
        logger.record_tabular("beta", beta)
        process = psutil.Process(os.getpid())
        logger.record_tabular("RAM Usage (Mb)",
                              int(process.memory_info().rss / 1000000))
        num_active_dims = 0
        num_active_dims2 = 0
        for std in self.model.dist_std:
            if std > 0.15:
                num_active_dims += 1
            if std > 0.05:
                num_active_dims2 += 1
        logger.record_tabular("num_active_dims", num_active_dims)
        logger.record_tabular("num_active_dims2", num_active_dims2)
        logger.dump_tabular()
        if save_vae:
            # slow...
            logger.save_itr_params(epoch, self.model, prefix='vae',
                                   save_anyway=True)
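# --- Hedged example: the importance-weighted (IWAE) objective used above ---
# A minimal, self-contained sketch of the weighting scheme in
# `compute_iwae_loss`: for k importance samples per datapoint, the per-sample
# log-weight is log p(x|z) + beta * (log p(z) - log q(z|x)), gradients flow
# through log_w, and the self-normalized weights are detached. The toy shapes
# and the unit-variance Gaussian decoder below are illustrative assumptions,
# not part of this codebase; they only demonstrate the arithmetic, not the
# ConvVAE model itself.
import torch
import torch.nn.functional as F


def toy_iwae_loss(x, x_recon, z, z_mu, z_logvar, beta=1.0):
    """x: (B, D); x_recon: (B, K, D); z, z_mu, z_logvar: (B, K, Z)."""
    batch_size = x_recon.shape[0]
    x = x.unsqueeze(1).expand_as(x_recon)                       # (B, K, D)
    # Unit-variance Gaussian decoder log-likelihood, up to additive constants.
    log_p_xgz = -((x_recon - x) ** 2).sum(dim=-1)               # (B, K)
    prior = torch.distributions.Normal(torch.zeros_like(z), torch.ones_like(z))
    log_p_z = prior.log_prob(z).sum(dim=-1)                     # (B, K)
    encoder = torch.distributions.Normal(z_mu, torch.exp(0.5 * z_logvar))
    log_q_zgx = encoder.log_prob(z).sum(dim=-1)                 # (B, K)
    log_w = log_p_xgz + beta * (log_p_z - log_q_zgx)            # (B, K)
    w_tilde = F.softmax(log_w, dim=-1).detach()                 # no gradient through the weights
    return -(w_tilde * log_w).sum() / batch_size

# Usage sketch (shapes only): with B=4, K=25, D=16, Z=8 random tensors for
# x, x_recon (requires_grad=True), z, z_mu, z_logvar, calling
# toy_iwae_loss(...).backward() propagates gradients through x_recon exactly
# as compute_iwae_loss does through the decoder output.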
def __init__(
        self,
        vae,
        *args,
        decoded_obs_key='image_observation',
        decoded_achieved_goal_key='image_achieved_goal',
        decoded_desired_goal_key='image_desired_goal',
        exploration_rewards_type='None',
        exploration_rewards_scale=1.0,
        vae_priority_type='None',
        start_skew_epoch=0,
        power=1.0,
        internal_keys=None,
        exploration_schedule_kwargs=None,
        priority_function_kwargs=None,
        exploration_counter_kwargs=None,
        relabeling_goal_sampling_mode='vae_prior',
        decode_vae_goals=False,
        **kwargs
):
    if internal_keys is None:
        internal_keys = []
    for key in [
        decoded_obs_key,
        decoded_achieved_goal_key,
        decoded_desired_goal_key,
    ]:
        if key not in internal_keys:
            internal_keys.append(key)
    super().__init__(internal_keys=internal_keys, *args, **kwargs)
    # assert isinstance(self.env, VAEWrappedEnv)
    self.vae = vae
    self.decoded_obs_key = decoded_obs_key
    self.decoded_desired_goal_key = decoded_desired_goal_key
    self.decoded_achieved_goal_key = decoded_achieved_goal_key
    self.exploration_rewards_type = exploration_rewards_type
    self.exploration_rewards_scale = exploration_rewards_scale
    self.start_skew_epoch = start_skew_epoch
    self.vae_priority_type = vae_priority_type
    self.power = power
    self._relabeling_goal_sampling_mode = relabeling_goal_sampling_mode
    self.decode_vae_goals = decode_vae_goals
    if exploration_schedule_kwargs is None:
        self.explr_reward_scale_schedule = \
            ConstantSchedule(self.exploration_rewards_scale)
    else:
        self.explr_reward_scale_schedule = \
            PiecewiseLinearSchedule(**exploration_schedule_kwargs)
    self._give_explr_reward_bonus = (
        exploration_rewards_type != 'None'
        and exploration_rewards_scale != 0.
    )
    self._exploration_rewards = np.zeros((self.max_size, 1),
                                         dtype=np.float32)
    self._prioritize_vae_samples = (
        vae_priority_type != 'None'
        and power != 0.
    )
    self._vae_sample_priorities = np.zeros((self.max_size, 1),
                                           dtype=np.float32)
    self._vae_sample_probs = None
    self.use_dynamics_model = (
        self.exploration_rewards_type == 'forward_model_error'
    )
    if self.use_dynamics_model:
        self.initialize_dynamics_model()
    type_to_function = {
        'reconstruction_error': self.reconstruction_mse,
        'bce': self.binary_cross_entropy,
        'latent_distance': self.latent_novelty,
        'latent_distance_true_prior': self.latent_novelty_true_prior,
        'forward_model_error': self.forward_model_error,
        'gaussian_inv_prob': self.gaussian_inv_prob,
        'bernoulli_inv_prob': self.bernoulli_inv_prob,
        'vae_prob': self.vae_prob,
        'hash_count': self.hash_count_reward,
        'None': self.no_reward,
    }
    self.exploration_reward_func = (
        type_to_function[self.exploration_rewards_type]
    )
    self.vae_prioritization_func = (
        type_to_function[self.vae_priority_type]
    )
    if priority_function_kwargs is None:
        self.priority_function_kwargs = dict()
    else:
        self.priority_function_kwargs = priority_function_kwargs
    if self.exploration_rewards_type == 'hash_count':
        if exploration_counter_kwargs is None:
            exploration_counter_kwargs = dict()
        self.exploration_counter = CountExploration(
            env=self.env, **exploration_counter_kwargs)
    self.epoch = 0
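# --- Hedged example: power-skewed prioritized sampling ---
# The buffer constructor above only allocates `_vae_sample_priorities` and
# stores `power`; the refresh of `_vae_sample_probs` and the actual sampling
# live elsewhere in the class and are not shown in this file. The sketch
# below is a generic illustration of the idea with a hypothetical helper
# name: raise the priorities to `power`, normalize them into a distribution,
# and sample buffer indices from it. A positive `power` over-samples
# high-priority transitions; a negative `power` over-samples low-priority
# ones. This is not the repo's API, just an assumption-labeled sketch.
import numpy as np


def skewed_sample_indices(priorities, power, batch_size, rng=None):
    """priorities: (N,) strictly positive scores for each stored transition."""
    rng = np.random.default_rng() if rng is None else rng
    weights = np.asarray(priorities, dtype=np.float64) ** power
    probs = weights / weights.sum()
    return rng.choice(len(probs), size=batch_size, p=probs)

# Usage sketch: skewed_sample_indices(np.array([0.1, 1.0, 10.0]), power=1.0,
# batch_size=8) draws mostly index 2; power=-1.0 flips the preference.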
def __init__(
        self,
        env,
        qf,
        policy,
        target_qf,
        target_policy,
        exploration_policy,
        policy_learning_rate=1e-4,
        qf_learning_rate=1e-3,
        qf_weight_decay=0,
        target_hard_update_period=1000,
        tau=1e-2,
        use_soft_update=True,
        qf_criterion=None,
        residual_gradient_weight=0,
        epoch_discount_schedule=None,
        eval_with_target_policy=False,
        policy_pre_activation_weight=0.,
        optimizer_class=optim.Adam,
        plotter=None,
        render_eval_paths=False,
        obs_normalizer: TorchFixedNormalizer = None,
        action_normalizer: TorchFixedNormalizer = None,
        num_paths_for_normalization=0,
        min_q_value=-np.inf,
        max_q_value=np.inf,
        **kwargs
):
    """
    :param env:
    :param qf:
    :param policy:
    :param exploration_policy:
    :param policy_learning_rate:
    :param qf_learning_rate:
    :param qf_weight_decay:
    :param target_hard_update_period:
    :param tau:
    :param use_soft_update:
    :param qf_criterion: Loss function to use for the Q function. Should be
        a function that takes two inputs (y_predicted, y_target).
    :param residual_gradient_weight: c, a float between 0 and 1. The
        gradient used for training the Q function is then
        (1 - c) * normal TD gradient + c * residual gradient.
    :param epoch_discount_schedule: A schedule for the discount factor that
        varies with the epoch.
    :param kwargs:
    """
    self.target_policy = target_policy
    if eval_with_target_policy:
        eval_policy = self.target_policy
    else:
        eval_policy = policy
    super().__init__(
        env,
        exploration_policy,
        eval_policy=eval_policy,
        **kwargs
    )
    if qf_criterion is None:
        qf_criterion = nn.MSELoss()
    self.qf = qf
    self.policy = policy
    self.policy_learning_rate = policy_learning_rate
    self.qf_learning_rate = qf_learning_rate
    self.qf_weight_decay = qf_weight_decay
    self.target_hard_update_period = target_hard_update_period
    self.tau = tau
    self.use_soft_update = use_soft_update
    self.residual_gradient_weight = residual_gradient_weight
    self.policy_pre_activation_weight = policy_pre_activation_weight
    self.qf_criterion = qf_criterion
    if epoch_discount_schedule is None:
        epoch_discount_schedule = ConstantSchedule(self.discount)
    self.epoch_discount_schedule = epoch_discount_schedule
    self.plotter = plotter
    self.render_eval_paths = render_eval_paths
    self.obs_normalizer = obs_normalizer
    self.action_normalizer = action_normalizer
    self.num_paths_for_normalization = num_paths_for_normalization
    self.min_q_value = min_q_value
    self.max_q_value = max_q_value
    self.target_qf = target_qf
    self.qf_optimizer = optimizer_class(
        self.qf.parameters(),
        lr=self.qf_learning_rate,
    )
    self.policy_optimizer = optimizer_class(
        self.policy.parameters(),
        lr=self.policy_learning_rate,
    )
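# --- Hedged example: target-network synchronization controlled by `tau` ---
# The constructor above only stores `tau`, `use_soft_update`, and
# `target_hard_update_period`; the updates themselves happen later, in the
# training loop elsewhere in the class. This standalone sketch shows the two
# standard variants those settings typically select between. The function
# names are illustrative, not this repo's API.
import torch.nn as nn


def soft_update(source: nn.Module, target: nn.Module, tau: float) -> None:
    """Polyak averaging: target <- tau * source + (1 - tau) * target."""
    for src_param, tgt_param in zip(source.parameters(), target.parameters()):
        tgt_param.data.copy_(
            tau * src_param.data + (1.0 - tau) * tgt_param.data)


def hard_update(source: nn.Module, target: nn.Module) -> None:
    """Copy the source parameters wholesale, e.g. every
    `target_hard_update_period` gradient steps."""
    target.load_state_dict(source.state_dict())

# Usage sketch: after each Q-function gradient step, either
# soft_update(qf, target_qf, tau=1e-2) when use_soft_update is True, or
# hard_update(qf, target_qf) on the configured period otherwise.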
def __init__(
        self,
        env,
        exploration_policy,
        beta_q,
        beta_q2,
        beta_v,
        policy,
        train_with='both',
        goal_reached_epsilon=1e-3,
        learning_rate=1e-3,
        prioritized_replay=False,
        always_reset_env=True,
        finite_horizon=False,
        max_num_steps_left=0,
        flip_training_period=100,
        train_simultaneously=True,
        policy_and_target_update_period=2,
        target_policy_noise=0.2,
        target_policy_noise_clip=0.5,
        soft_target_tau=0.005,
        per_beta_schedule=None,
        **kwargs
):
    self.train_simultaneously = train_simultaneously
    assert train_with in ['both', 'off_policy', 'on_policy']
    super().__init__(env, exploration_policy, **kwargs)
    self.eval_sampler = MultigoalSimplePathSampler(
        env=self.env,
        policy=self.eval_policy,
        max_samples=self.num_steps_per_eval,
        max_path_length=self.max_path_length,
        tau_sampling_function=lambda: 0,
        goal_sampling_function=self.env.sample_goal_for_rollout,
        cycle_taus_for_rollout=False,
        render=self.render_during_eval,
    )
    self.goal_reached_epsilon = goal_reached_epsilon
    self.beta_q = beta_q
    self.beta_v = beta_v
    self.beta_q2 = beta_q2
    self.target_beta_q = self.beta_q.copy()
    self.target_beta_q2 = self.beta_q2.copy()
    self.train_with = train_with
    self.policy = policy
    self.target_policy = policy
    self.prioritized_replay = prioritized_replay
    self.flip_training_period = flip_training_period
    self.always_reset_env = always_reset_env
    self.finite_horizon = finite_horizon
    self.max_num_steps_left = max_num_steps_left
    assert max_num_steps_left >= 0
    self.policy_and_target_update_period = policy_and_target_update_period
    self.target_policy_noise = target_policy_noise
    self.target_policy_noise_clip = target_policy_noise_clip
    self.soft_target_tau = soft_target_tau
    if per_beta_schedule is None:
        per_beta_schedule = ConstantSchedule(1.0)
    self.per_beta_schedule = per_beta_schedule
    self.beta_q_optimizer = Adam(self.beta_q.parameters(),
                                 lr=learning_rate)
    self.beta_q2_optimizer = Adam(self.beta_q2.parameters(),
                                  lr=learning_rate)
    self.beta_v_optimizer = Adam(self.beta_v.parameters(),
                                 lr=learning_rate)
    self.policy_optimizer = Adam(
        self.policy.parameters(),
        lr=learning_rate,
    )
    self.q_criterion = nn.BCELoss()
    self.v_criterion = nn.BCELoss()
    # For the multitask env
    self._rollout_goal = None
    self.extra_eval_statistics = OrderedDict()
    for key_not_always_updated in [
        'Policy Gradient Norms',
        'Beta Q Gradient Norms',
        'dQ/da',
    ]:
        self.eval_statistics.update(
            create_stats_ordered_dict(
                key_not_always_updated,
                np.zeros(2),
            ))
    self.training_policy = False
    # For debugging
    self.train_batches = []
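# --- Hedged example: target-policy smoothing implied by the noise settings ---
# The constructor above stores `target_policy_noise` and
# `target_policy_noise_clip` but applies them elsewhere, when forming the
# bootstrap target. This standalone sketch shows the usual TD3-style recipe
# those two numbers parameterize: add clipped Gaussian noise to the target
# policy's action before querying the target networks. The function name and
# the [-1, 1] action bounds are assumptions for illustration, not this
# repo's API.
import torch


def smoothed_target_action(target_policy_action,
                           noise_std=0.2,
                           noise_clip=0.5):
    """Add clipped Gaussian noise, then clamp back into the action range."""
    noise = torch.randn_like(target_policy_action) * noise_std
    noise = noise.clamp(-noise_clip, noise_clip)
    return (target_policy_action + noise).clamp(-1.0, 1.0)

# Usage sketch: next_action = smoothed_target_action(
#     target_policy(next_obs), noise_std=0.2, noise_clip=0.5)
# before evaluating target_beta_q(next_obs, next_action, goal).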