Example #1
    def __init__(
            self,
            train_dataset,
            test_dataset,
            model,
            batch_size=128,
            beta=0.5,
            beta_schedule=None,
            lr=1e-3,
            extra_recon_logging=dict(),
            recon_weights=None,
            recon_loss_type='mse',
            **kwargs
    ):
        assert recon_loss_type in ['mse', 'wse']
        self.batch_size = batch_size
        self.beta = beta
        self.beta_schedule = beta_schedule
        if self.beta_schedule is None:
            self.beta_schedule = ConstantSchedule(self.beta)

        if ptu.gpu_enabled():
            model.cuda()

        self.model = model
        self.representation_size = model.representation_size

        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.train_dataset, self.test_dataset = train_dataset, test_dataset
        assert self.train_dataset['next_obs'].dtype == np.float32
        assert self.test_dataset['next_obs'].dtype == np.float32
        assert self.train_dataset['obs'].dtype == np.float32
        assert self.test_dataset['obs'].dtype == np.float32
        self.normalize = model.normalize
        self.mse = nn.MSELoss()

        if self.normalize:
            self.train_data_mean = ptu.np_to_var(np.mean(self.train_dataset['next_obs'], axis=0))
            np_std = np.std(self.train_dataset['next_obs'], axis=0)
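            # Guard against dividing by (near-)zero when normalizing: dimensions with
            # negligible variance get their std clamped to 1.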
            for i in range(len(np_std)):
                if np_std[i] < 1e-3:
                    np_std[i] = 1.0
            self.train_data_std = ptu.np_to_var(np_std)

            self.model.train_data_mean = self.train_data_mean
            self.model.train_data_std = self.train_data_std

        self.extra_recon_logging = extra_recon_logging
        self.recon_weights = recon_weights
        self.recon_loss_type = recon_loss_type
Example #2
    def __init__(
        self,
        model,
        batch_size=128,
        log_interval=0,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        weight_decay=0,
        is_auto_encoder=False,
    ):
        self.model = model
        self.log_interval = log_interval
        self.batch_size = batch_size
        self.beta = beta
        if lr is None:
            if is_auto_encoder:
                lr = 1e-2
            else:
                lr = 1e-3
        self.beta_schedule = beta_schedule
        if self.beta_schedule is None or is_auto_encoder:
            self.beta_schedule = ConstantSchedule(self.beta)
        self.imsize = model.imsize
        model.to(ptu.device)

        self.representation_size = model.representation_size
        self.input_channels = model.input_channels
        self.imlength = self.imsize * self.imsize * self.input_channels

        self.lr = lr
        params = list(self.model.parameters())
        self.optimizer = optim.Adam(
            params,
            lr=self.lr,
            weight_decay=weight_decay,
        )

        self.eval_statistics = {}
Example #3
class VAETrainer():
    def __init__(
            self,
            train_dataset,
            test_dataset,
            model,
            batch_size=128,
            beta=0.5,
            beta_schedule=None,
            lr=1e-3,
            extra_recon_logging=dict(),
            recon_weights=None,
            recon_loss_type='mse',
            **kwargs
    ):
        assert recon_loss_type in ['mse', 'wse']
        self.batch_size = batch_size
        self.beta = beta
        self.beta_schedule = beta_schedule
        if self.beta_schedule is None:
            self.beta_schedule = ConstantSchedule(self.beta)

        if ptu.gpu_enabled():
            model.cuda()

        self.model = model
        self.representation_size = model.representation_size

        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.train_dataset, self.test_dataset = train_dataset, test_dataset
        assert self.train_dataset['next_obs'].dtype == np.float32
        assert self.test_dataset['next_obs'].dtype == np.float32
        assert self.train_dataset['obs'].dtype == np.float32
        assert self.test_dataset['obs'].dtype == np.float32
        self.normalize = model.normalize
        self.mse = nn.MSELoss()

        if self.normalize:
            self.train_data_mean = ptu.np_to_var(np.mean(self.train_dataset['next_obs'], axis=0))
            np_std = np.std(self.train_dataset['next_obs'], axis=0)
            for i in range(len(np_std)):
                if np_std[i] < 1e-3:
                    np_std[i] = 1.0
            self.train_data_std = ptu.np_to_var(np_std)

            self.model.train_data_mean = self.train_data_mean
            self.model.train_data_std = self.train_data_std

        self.extra_recon_logging = extra_recon_logging
        self.recon_weights = recon_weights
        self.recon_loss_type = recon_loss_type

    def get_batch(self, train=True):
        dataset = self.train_dataset if train else self.test_dataset
        ind = np.random.randint(0, len(dataset['obs']), self.batch_size)
        samples_obs = dataset['obs'][ind, :]
        samples_actions = dataset['actions'][ind, :]
        samples_next_obs = dataset['next_obs'][ind, :]
        return {
            'obs': ptu.np_to_var(samples_obs),
            'actions': ptu.np_to_var(samples_actions),
            'next_obs': ptu.np_to_var(samples_next_obs),
        }

    def logprob(self, recon_x, x, normalize=None, idx=None, unorm_weights=None):
        if normalize is None:
            normalize = self.normalize
        if normalize:
            x = (x - self.train_data_mean) / self.train_data_std
            recon_x = (recon_x - self.train_data_mean) / self.train_data_std
        if idx is not None:
            x = x[:, idx]
            recon_x = recon_x[:, idx]
            if unorm_weights is not None:
                unorm_weights = unorm_weights[idx]
        if unorm_weights is not None:
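            # Rescale the per-dimension weights so they sum to the feature dimension,
            # keeping the weighted reconstruction error on a scale comparable to plain MSE.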
            dim = x.shape[1]
            norm_weights = unorm_weights / (np.sum(unorm_weights) / dim)
            norm_weights = ptu.np_to_var(norm_weights)
            recon_x = recon_x * norm_weights
            x = x * norm_weights

        return self.mse(recon_x, x)

    def kl_divergence(self, mu, logvar):
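        # Closed-form KL between the diagonal Gaussian posterior N(mu, diag(exp(logvar)))
        # and the standard normal prior, summed over latent dims and averaged over the batch.
        # Note: the conventional 1/2 factor is omitted here.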
        kl = - torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()
        return kl

    def train_epoch(self, epoch, batches=100):
        self.model.train()
        losses = []
        kles = []
        mses = []
        beta = self.beta_schedule.get_value(epoch)
        for batch_idx in range(batches):
            data = self.get_batch()
            obs = data['obs']
            next_obs = data['next_obs']
            actions = data['actions']
            self.optimizer.zero_grad()
            recon_batch, mu, logvar = self.model(next_obs)
            mse = self.logprob(recon_batch, next_obs)
            kle = self.kl_divergence(mu, logvar)
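            # beta-VAE objective: reconstruction term plus a beta-weighted KL regularizer.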
            if self.recon_loss_type == 'mse':
                loss = mse + beta * kle
            elif self.recon_loss_type == 'wse':
                wse = self.logprob(recon_batch, next_obs, unorm_weights=self.recon_weights)
                loss = wse + beta * kle
            loss.backward()

            losses.append(loss.data[0])
            mses.append(mse.data[0])
            kles.append(kle.data[0])

            self.optimizer.step()

        logger.record_tabular("train/epoch", epoch)
        logger.record_tabular("train/MSE", np.mean(mses))
        logger.record_tabular("train/KL", np.mean(kles))
        logger.record_tabular("train/loss", np.mean(losses))

    def test_epoch(self, epoch, save_vae=True, **kwargs):
        self.model.eval()
        losses = []
        kles = []
        zs = []

        recon_logging_dict = {
            'MSE': [],
            'WSE': [],
        }
        for k in self.extra_recon_logging:
            recon_logging_dict[k] = []

        beta = self.beta_schedule.get_value(epoch)
        for batch_idx in range(100):
            data = self.get_batch(train=False)
            obs = data['obs']
            next_obs = data['next_obs']
            actions = data['actions']
            recon_batch, mu, logvar = self.model(next_obs)
            mse = self.logprob(recon_batch, next_obs)
            wse = self.logprob(recon_batch, next_obs, unorm_weights=self.recon_weights)
            for k, idx in self.extra_recon_logging.items():
                recon_loss = self.logprob(recon_batch, next_obs, idx=idx)
                recon_logging_dict[k].append(recon_loss.data[0])
            kle = self.kl_divergence(mu, logvar)
            if self.recon_loss_type == 'mse':
                loss = mse + beta * kle
            elif self.recon_loss_type == 'wse':
                loss = wse + beta * kle
            z_data = ptu.get_numpy(mu.cpu())
            for i in range(len(z_data)):
                zs.append(z_data[i, :])
            losses.append(loss.data[0])
            recon_logging_dict['WSE'].append(wse.data[0])
            recon_logging_dict['MSE'].append(mse.data[0])
            kles.append(kle.data[0])
        zs = np.array(zs)
        self.model.dist_mu = zs.mean(axis=0)
        self.model.dist_std = zs.std(axis=0)

        for k in recon_logging_dict:
            logger.record_tabular("/".join(["test", k]), np.mean(recon_logging_dict[k]))
        logger.record_tabular("test/KL", np.mean(kles))
        logger.record_tabular("test/loss", np.mean(losses))
        logger.record_tabular("beta", beta)

        process = psutil.Process(os.getpid())
        logger.record_tabular("RAM Usage (Mb)", int(process.memory_info().rss / 1000000))

        num_active_dims = 0
        for std in self.model.dist_std:
            if std > 0.15:
                num_active_dims += 1
        logger.record_tabular("num_active_dims", num_active_dims)

        logger.dump_tabular()
        if save_vae:
            logger.save_itr_params(epoch, self.model, prefix='vae', save_anyway=True)  # slow...
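For orientation, a trainer like the one above is typically driven by a plain epoch loop. A minimal sketch, assuming a compatible `model` (exposing `representation_size` and `normalize`), float32 dataset dicts with 'obs', 'actions', and 'next_obs' keys, and a `num_epochs` value are already defined; none of these names are taken from the example itself:

trainer = VAETrainer(train_data, test_data, model, beta=0.5, lr=1e-3)
for epoch in range(num_epochs):
    trainer.train_epoch(epoch, batches=100)    # 100 gradient steps on random minibatches
    trainer.test_epoch(epoch, save_vae=False)  # logs test reconstruction/KL and latent stats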
Example #4
    def __init__(
        self,
        max_tau=10,
        epoch_max_tau_schedule=None,
        sample_train_goals_from='replay_buffer',
        sample_rollout_goals_from='environment',
        cycle_taus_for_rollout=False,
    ):
        """
        :param max_tau: Maximum tau (planning horizon) to train with.
        :param epoch_max_tau_schedule: A schedule for the maximum planning
        horizon tau.
        :param sample_train_goals_from: Sampling strategy for goals used in
        training. Can be one of the following strings:
            - environment: Sample from the environment
            - replay_buffer: Sample from the replay_buffer
            - her: Sample from a HER-based replay_buffer
        :param sample_rollout_goals_from: Sampling strategy for goals used
        during rollout. Can be one of the following strings:
            - environment: Sample from the environment
            - replay_buffer: Sample from the replay_buffer
            - fixed: Do not resample the goal. Just use the one in the
            environment.
        :param cycle_taus_for_rollout: Decrement the tau passed into the
        policy during rollout?
        """
        assert sample_train_goals_from in [
            'environment', 'replay_buffer', 'her'
        ]
        assert sample_rollout_goals_from in [
            'environment', 'replay_buffer', 'fixed'
        ]
        if epoch_max_tau_schedule is None:
            epoch_max_tau_schedule = ConstantSchedule(max_tau)

        self.max_tau = max_tau
        self.epoch_max_tau_schedule = epoch_max_tau_schedule
        self.sample_train_goals_from = sample_train_goals_from
        self.sample_rollout_goals_from = sample_rollout_goals_from
        self.cycle_taus_for_rollout = cycle_taus_for_rollout
        self._current_path_goal = None
        self._rollout_tau = self.max_tau

        self.policy = MakeUniversal(self.policy)
        self.eval_policy = MakeUniversal(self.eval_policy)
        self.exploration_policy = MakeUniversal(self.exploration_policy)
        self.eval_sampler = MultigoalSimplePathSampler(
            env=self.env,
            policy=self.eval_policy,
            max_samples=self.num_steps_per_eval,
            max_path_length=self.max_path_length,
            discount_sampling_function=self._sample_max_tau_for_rollout,
            goal_sampling_function=self._sample_goal_for_rollout,
            cycle_taus_for_rollout=self.cycle_taus_for_rollout,
        )
        if self.collection_mode == 'online-parallel':
            # TODO(murtaza): What happens to the eval env?
            # see `eval_sampler` definition above.

            self.training_env = RemoteRolloutEnv(
                env=self.env,
                policy=self.eval_policy,
                exploration_policy=self.exploration_policy,
                max_path_length=self.max_path_length,
                normalize_env=self.normalize_env,
                rollout_function=self.rollout,
            )
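To make the goal-sampling options described in the docstring above concrete, here is a purely illustrative dispatch. It is not code from this trainer; `env.sample_goal`, `replay_buffer.random_batch`, and the 'goals' key are assumed placeholder names:

def sample_rollout_goal(strategy, env, replay_buffer, current_goal):
    # Illustrative only: mirrors the three `sample_rollout_goals_from` options.
    if strategy == 'environment':
        return env.sample_goal()               # ask the environment for a fresh goal (assumed API)
    elif strategy == 'replay_buffer':
        batch = replay_buffer.random_batch(1)  # reuse a previously visited state as the goal
        return batch['goals'][0]
    elif strategy == 'fixed':
        return current_goal                    # do not resample; keep the environment's current goal
    raise ValueError("Unknown strategy: {}".format(strategy))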
Example #5
    def __init__(
        self,
        max_tau=10,
        epoch_max_tau_schedule=None,
        vectorized=True,
        cycle_taus_for_rollout=True,
        dense_rewards=False,
        finite_horizon=True,
        tau_sample_strategy='uniform',
        goal_reached_epsilon=1e-3,
        terminate_when_goal_reached=False,
        truncated_geom_factor=2.,
        square_distance=False,
        goal_weights=None,
        normalize_distance=False,
        observation_key=None,
        desired_goal_key=None,
    ):
        """

        :param max_tau: Maximum tau (planning horizon) to train with.
        :param epoch_max_tau_schedule: A schedule for the maximum planning
        horizon tau.
        :param vectorized: Train the QF in vectorized form?
        :param cycle_taus_for_rollout: Decrement the tau passed into the
        policy during rollout?
        :param dense_rewards: If True, always give rewards. Otherwise,
        only give rewards when the episode terminates.
        :param finite_horizon: If True, use a finite horizon formulation:
        give the time as input to the Q-function and terminate.
        :param tau_sample_strategy: Sampling strategy for taus used
        during training. Can be one of the following strings:
            - no_resampling: Do not resample the tau. Use the one from rollout.
            - uniform: Sample uniformly from [0, max_tau]
            - truncated_geometric: Sample from a truncated geometric
            distribution, truncated at max_tau.
            - all_valid: Always use all 0 to max_tau values
        :param goal_reached_epsilon: Epsilon used to determine if the goal
        has been reached. Used by `indicator` version of `reward_type` and when
        `terminate_when_goal_reached` is True.
        :param terminate_when_goal_reached: Do you terminate when you have
        reached the goal?
        :param goal_weights: None or the weights for the different goal
        dimensions. These weights are used to compute the distances to the goal.
        """
        assert tau_sample_strategy in [
            'no_resampling',
            'uniform',
            'truncated_geometric',
            'all_valid',
        ]
        if epoch_max_tau_schedule is None:
            epoch_max_tau_schedule = ConstantSchedule(max_tau)

        if not finite_horizon:
            max_tau = 0
            epoch_max_tau_schedule = ConstantSchedule(max_tau)
            cycle_taus_for_rollout = False

        self.max_tau = max_tau
        self.epoch_max_tau_schedule = epoch_max_tau_schedule
        self.vectorized = vectorized
        self.cycle_taus_for_rollout = cycle_taus_for_rollout
        self.dense_rewards = dense_rewards
        self.finite_horizon = finite_horizon
        self.tau_sample_strategy = tau_sample_strategy
        self.goal_reached_epsilon = goal_reached_epsilon
        self.terminate_when_goal_reached = terminate_when_goal_reached
        self.square_distance = square_distance
        self._rollout_tau = np.array([self.max_tau])
        self.truncated_geom_factor = float(truncated_geom_factor)
        self.goal_weights = goal_weights
        if self.goal_weights is not None:
            # In case they were passed in as (e.g.) tuples or list
            self.goal_weights = np.array(self.goal_weights)
            assert self.goal_weights.size == self.env.goal_dim
        self.normalize_distance = normalize_distance

        self.observation_key = observation_key
        self.desired_goal_key = desired_goal_key

        self.eval_sampler = MultigoalSimplePathSampler(
            env=self.env,
            policy=self.eval_policy,
            max_samples=self.num_steps_per_eval,
            max_path_length=self.max_path_length,
            tau_sampling_function=self._sample_max_tau_for_rollout,
            cycle_taus_for_rollout=self.cycle_taus_for_rollout,
            render=self.render_during_eval,
            observation_key=self.observation_key,
            desired_goal_key=self.desired_goal_key,
        )
        self.pretrain_obs = None
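The 'truncated_geometric' option described in the docstring above samples taus from a geometric distribution truncated at max_tau. A purely illustrative numpy sketch of such a sampler (not the trainer's own implementation):

import numpy as np

def sample_truncated_geometric(p, truncation, size):
    # Geometric pmf over k = 0, 1, ..., truncation, renormalized to sum to 1.
    support = np.arange(truncation + 1)
    pmf = (1 - p) ** support * p
    pmf = pmf / pmf.sum()
    return np.random.choice(support, size=size, p=pmf)

# e.g. taus = sample_truncated_geometric(p=0.2, truncation=max_tau, size=batch_size)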
Example #6
    def __init__(
            self,
            env,
            policy,
            qf1,
            qf2,
            target_qf1,
            target_qf2,
            buffer_policy=None,

            discount=0.99,
            reward_scale=1.0,
            beta=1.0,
            beta_schedule_kwargs=None,

            policy_lr=1e-3,
            qf_lr=1e-3,
            policy_weight_decay=0,
            q_weight_decay=0,
            optimizer_class=optim.Adam,

            soft_target_tau=1e-2,
            target_update_period=1,
            plotter=None,
            render_eval_paths=False,

            use_automatic_entropy_tuning=True,
            target_entropy=None,

            bc_num_pretrain_steps=0,
            q_num_pretrain1_steps=0,
            q_num_pretrain2_steps=0,
            bc_batch_size=128,
            bc_loss_type="mle",
            awr_loss_type="mle",
            save_bc_policies=0,
            alpha=1.0,

            policy_update_period=1,
            q_update_period=1,

            weight_loss=True,
            compute_bc=True,

            bc_weight=0.0,
            rl_weight=1.0,
            use_awr_update=True,
            use_reparam_update=False,
            reparam_weight=1.0,
            awr_weight=1.0,
            post_pretrain_hyperparams=None,
            post_bc_pretrain_hyperparams=None,

            awr_use_mle_for_vf=False,
            awr_sample_actions=False,
            awr_min_q=False,

            reward_transform_class=None,
            reward_transform_kwargs=None,
            terminal_transform_class=None,
            terminal_transform_kwargs=None,

            pretraining_env_logging_period=100000,
            pretraining_logging_period=1000,
            do_pretrain_rollouts=False,

            train_bc_on_rl_buffer=False,
            use_automatic_beta_tuning=False,
            beta_epsilon=1e-10,
    ):
        super().__init__()
        self.env = env
        self.policy = policy
        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        self.buffer_policy = buffer_policy
        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period

        self.use_awr_update = use_awr_update
        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            if target_entropy:
                self.target_entropy = target_entropy
            else:
                self.target_entropy = -np.prod(self.env.action_space.shape).item()  # heuristic value from Tuomas
            self.log_alpha = ptu.zeros(1, requires_grad=True)
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=policy_lr,
            )

        self.awr_use_mle_for_vf = awr_use_mle_for_vf
        self.awr_sample_actions = awr_sample_actions
        self.awr_min_q = awr_min_q

        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            weight_decay=policy_weight_decay,
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            weight_decay=q_weight_decay,
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            weight_decay=q_weight_decay,
            lr=qf_lr,
        )

        if buffer_policy and train_bc_on_rl_buffer:
            self.buffer_policy_optimizer = optimizer_class(
                self.buffer_policy.parameters(),
                weight_decay=policy_weight_decay,
                lr=policy_lr,
            )

        self.use_automatic_beta_tuning = use_automatic_beta_tuning and buffer_policy and train_bc_on_rl_buffer
        self.beta_epsilon = beta_epsilon
        if self.use_automatic_beta_tuning:
            self.log_beta = ptu.zeros(1, requires_grad=True)
            self.beta_optimizer = optimizer_class(
                [self.log_beta],
                lr=policy_lr,
            )
        else:
            self.beta = beta
            self.beta_schedule_kwargs = beta_schedule_kwargs
            if beta_schedule_kwargs is None:
                self.beta_schedule = ConstantSchedule(beta)
            else:
                schedule_class = beta_schedule_kwargs.pop("schedule_class", PiecewiseLinearSchedule)
                self.beta_schedule = schedule_class(**beta_schedule_kwargs)

        self.discount = discount
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True

        self.bc_num_pretrain_steps = bc_num_pretrain_steps
        self.q_num_pretrain1_steps = q_num_pretrain1_steps
        self.q_num_pretrain2_steps = q_num_pretrain2_steps
        self.bc_batch_size = bc_batch_size
        self.bc_loss_type = bc_loss_type
        self.awr_loss_type = awr_loss_type
        self.rl_weight = rl_weight
        self.bc_weight = bc_weight
        self.save_bc_policies = save_bc_policies
        self.eval_policy = MakeDeterministic(self.policy)
        self.compute_bc = compute_bc
        self.alpha = alpha
        self.q_update_period = q_update_period
        self.policy_update_period = policy_update_period
        self.weight_loss = weight_loss

        self.reparam_weight = reparam_weight
        self.awr_weight = awr_weight
        self.post_pretrain_hyperparams = post_pretrain_hyperparams
        self.post_bc_pretrain_hyperparams = post_bc_pretrain_hyperparams
        self.update_policy = True
        self.pretraining_env_logging_period = pretraining_env_logging_period
        self.pretraining_logging_period = pretraining_logging_period
        self.do_pretrain_rollouts = do_pretrain_rollouts

        self.reward_transform_class = reward_transform_class or LinearTransform
        self.reward_transform_kwargs = reward_transform_kwargs or dict(m=1, b=0)
        self.terminal_transform_class = terminal_transform_class or LinearTransform
        self.terminal_transform_kwargs = terminal_transform_kwargs or dict(m=1, b=0)
        self.reward_transform = self.reward_transform_class(**self.reward_transform_kwargs)
        self.terminal_transform = self.terminal_transform_class(**self.terminal_transform_kwargs)
        self.use_reparam_update = use_reparam_update

        self.train_bc_on_rl_buffer = train_bc_on_rl_buffer and buffer_policy
Example #7
class AWRSACTrainer(TorchTrainer):
    def __init__(
            self,
            env,
            policy,
            qf1,
            qf2,
            target_qf1,
            target_qf2,
            buffer_policy=None,

            discount=0.99,
            reward_scale=1.0,
            beta=1.0,
            beta_schedule_kwargs=None,

            policy_lr=1e-3,
            qf_lr=1e-3,
            policy_weight_decay=0,
            q_weight_decay=0,
            optimizer_class=optim.Adam,

            soft_target_tau=1e-2,
            target_update_period=1,
            plotter=None,
            render_eval_paths=False,

            use_automatic_entropy_tuning=True,
            target_entropy=None,

            bc_num_pretrain_steps=0,
            q_num_pretrain1_steps=0,
            q_num_pretrain2_steps=0,
            bc_batch_size=128,
            bc_loss_type="mle",
            awr_loss_type="mle",
            save_bc_policies=0,
            alpha=1.0,

            policy_update_period=1,
            q_update_period=1,

            weight_loss=True,
            compute_bc=True,

            bc_weight=0.0,
            rl_weight=1.0,
            use_awr_update=True,
            use_reparam_update=False,
            reparam_weight=1.0,
            awr_weight=1.0,
            post_pretrain_hyperparams=None,
            post_bc_pretrain_hyperparams=None,

            awr_use_mle_for_vf=False,
            awr_sample_actions=False,
            awr_min_q=False,

            reward_transform_class=None,
            reward_transform_kwargs=None,
            terminal_transform_class=None,
            terminal_transform_kwargs=None,

            pretraining_env_logging_period=100000,
            pretraining_logging_period=1000,
            do_pretrain_rollouts=False,

            train_bc_on_rl_buffer=False,
            use_automatic_beta_tuning=False,
            beta_epsilon=1e-10,
    ):
        super().__init__()
        self.env = env
        self.policy = policy
        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        self.buffer_policy = buffer_policy
        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period

        self.use_awr_update = use_awr_update
        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            if target_entropy:
                self.target_entropy = target_entropy
            else:
                self.target_entropy = -np.prod(self.env.action_space.shape).item()  # heuristic value from Tuomas
            self.log_alpha = ptu.zeros(1, requires_grad=True)
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=policy_lr,
            )

        self.awr_use_mle_for_vf = awr_use_mle_for_vf
        self.awr_sample_actions = awr_sample_actions
        self.awr_min_q = awr_min_q

        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            weight_decay=policy_weight_decay,
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            weight_decay=q_weight_decay,
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            weight_decay=q_weight_decay,
            lr=qf_lr,
        )

        if buffer_policy and train_bc_on_rl_buffer:
            self.buffer_policy_optimizer = optimizer_class(
                self.buffer_policy.parameters(),
                weight_decay=policy_weight_decay,
                lr=policy_lr,
            )

        self.use_automatic_beta_tuning = use_automatic_beta_tuning and buffer_policy and train_bc_on_rl_buffer
        self.beta_epsilon = beta_epsilon
        if self.use_automatic_beta_tuning:
            self.log_beta = ptu.zeros(1, requires_grad=True)
            self.beta_optimizer = optimizer_class(
                [self.log_beta],
                lr=policy_lr,
            )
        else:
            self.beta = beta
            self.beta_schedule_kwargs = beta_schedule_kwargs
            if beta_schedule_kwargs is None:
                self.beta_schedule = ConstantSchedule(beta)
            else:
                schedule_class = beta_schedule_kwargs.pop("schedule_class", PiecewiseLinearSchedule)
                self.beta_schedule = schedule_class(**beta_schedule_kwargs)

        self.discount = discount
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True

        self.bc_num_pretrain_steps = bc_num_pretrain_steps
        self.q_num_pretrain1_steps = q_num_pretrain1_steps
        self.q_num_pretrain2_steps = q_num_pretrain2_steps
        self.bc_batch_size = bc_batch_size
        self.bc_loss_type = bc_loss_type
        self.awr_loss_type = awr_loss_type
        self.rl_weight = rl_weight
        self.bc_weight = bc_weight
        self.save_bc_policies = save_bc_policies
        self.eval_policy = MakeDeterministic(self.policy)
        self.compute_bc = compute_bc
        self.alpha = alpha
        self.q_update_period = q_update_period
        self.policy_update_period = policy_update_period
        self.weight_loss = weight_loss

        self.reparam_weight = reparam_weight
        self.awr_weight = awr_weight
        self.post_pretrain_hyperparams = post_pretrain_hyperparams
        self.post_bc_pretrain_hyperparams = post_bc_pretrain_hyperparams
        self.update_policy = True
        self.pretraining_env_logging_period = pretraining_env_logging_period
        self.pretraining_logging_period = pretraining_logging_period
        self.do_pretrain_rollouts = do_pretrain_rollouts

        self.reward_transform_class = reward_transform_class or LinearTransform
        self.reward_transform_kwargs = reward_transform_kwargs or dict(m=1, b=0)
        self.terminal_transform_class = terminal_transform_class or LinearTransform
        self.terminal_transform_kwargs = terminal_transform_kwargs or dict(m=1, b=0)
        self.reward_transform = self.reward_transform_class(**self.reward_transform_kwargs)
        self.terminal_transform = self.terminal_transform_class(**self.terminal_transform_kwargs)
        self.use_reparam_update = use_reparam_update

        self.train_bc_on_rl_buffer = train_bc_on_rl_buffer and buffer_policy


    def get_batch_from_buffer(self, replay_buffer, batch_size):
        batch = replay_buffer.random_batch(batch_size)
        batch = np_to_pytorch_batch(batch)
        return batch

    def run_bc_batch(self, replay_buffer, policy):
        batch = self.get_batch_from_buffer(replay_buffer, self.bc_batch_size)
        o = batch["observations"]
        u = batch["actions"]
        # g = batch["resampled_goals"]
        # og = torch.cat((o, g), dim=1)
        og = o
        # pred_u, *_ = self.policy(og)
        pred_u, policy_mean, policy_log_std, log_pi, entropy, policy_std, mean_action_log_prob, pretanh_value, dist = policy(
            og, deterministic=False, reparameterize=True, return_log_prob=True,
        )

        mse = (policy_mean - u) ** 2
        mse_loss = mse.mean()

        policy_logpp = dist.log_prob(u)
        logp_loss = -policy_logpp.mean()

        # T = 0
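        # 'mle' maximizes the log-likelihood of the demonstration actions under the policy;
        # 'mse' instead regresses the policy mean onto them.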
        if self.bc_loss_type == "mle":
            policy_loss = logp_loss
        elif self.bc_loss_type == "mse":
            policy_loss = mse_loss
        else:
            raise ValueError("Unknown bc_loss_type: {}".format(self.bc_loss_type))

        return policy_loss, logp_loss, mse_loss, policy_log_std

    def do_rollouts(self):
        total_ret = 0
        for _ in range(20):
            o = self.env.reset()
            ret = 0
            for _ in range(1000):
                a, _ = self.policy.get_action(o)
                o, r, done, info = self.env.step(a)
                ret += r
                if done:
                    break
            total_ret += ret
        return total_ret

    def pretrain_policy_with_bc(self):
        logger.remove_tabular_output(
            'progress.csv', relative_to_snapshot_dir=True
        )
        logger.add_tabular_output(
            'pretrain_policy.csv', relative_to_snapshot_dir=True
        )
        if self.do_pretrain_rollouts:
            total_ret = self.do_rollouts()
            print("INITIAL RETURN", total_ret/20)

        prev_time = time.time()
        for i in range(self.bc_num_pretrain_steps):
            train_policy_loss, train_logp_loss, train_mse_loss, train_log_std = self.run_bc_batch(self.demo_train_buffer, self.policy)
            train_policy_loss = train_policy_loss * self.bc_weight

            self.policy_optimizer.zero_grad()
            train_policy_loss.backward()
            self.policy_optimizer.step()

            test_policy_loss, test_logp_loss, test_mse_loss, test_log_std = self.run_bc_batch(self.demo_test_buffer, self.policy)
            test_policy_loss = test_policy_loss * self.bc_weight

            if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
                total_ret = self.do_rollouts()
                print("Return at step {} : {}".format(i, total_ret/20))

            if i % self.pretraining_logging_period == 0:
                stats = {
                    "pretrain_bc/batch": i,
                    "pretrain_bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                    "pretrain_bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                    "pretrain_bc/Train MSE": ptu.get_numpy(train_mse_loss),
                    "pretrain_bc/Test MSE": ptu.get_numpy(test_mse_loss),
                    "pretrain_bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                    "pretrain_bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
                    "pretrain_bc/epoch_time": time.time() - prev_time,
                }

                if self.do_pretrain_rollouts:
                    stats["pretrain_bc/avg_return"] = total_ret / 20

                logger.record_dict(stats)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                pickle.dump(self.policy, open(logger.get_snapshot_dir() + '/bc.pkl', "wb"))
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_policy.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        if self.post_bc_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)

    def pretrain_q_with_bc_data(self):
        logger.remove_tabular_output(
            'progress.csv', relative_to_snapshot_dir=True
        )
        logger.add_tabular_output(
            'pretrain_q.csv', relative_to_snapshot_dir=True
        )

        self.update_policy = False
        # first train only the Q function
        for i in range(self.q_num_pretrain1_steps):
            self.eval_statistics = dict()

            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs # torch.cat((obs, goals), dim=1)
            train_data['next_observations'] = next_obs # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data)
            if i % self.pretraining_logging_period == 0:
                logger.record_dict(self.eval_statistics)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)

        self.update_policy = True
        # then train policy and Q function together
        prev_time = time.time()
        for i in range(self.q_num_pretrain2_steps):
            self.eval_statistics = dict()
            if i % self.pretraining_logging_period == 0:
                self._need_to_update_eval_statistics = True
            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs # torch.cat((obs, goals), dim=1)
            train_data['next_observations'] = next_obs # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data)
            if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
                total_ret = self.do_rollouts()
                print("Return at step {} : {}".format(i, total_ret/20))

            if i % self.pretraining_logging_period == 0:
                if self.do_pretrain_rollouts:
                    self.eval_statistics["pretrain_bc/avg_return"] = total_ret / 20
                self.eval_statistics["batch"] = i
                self.eval_statistics["epoch_time"] = time.time()-prev_time
                logger.record_dict(self.eval_statistics)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_q.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        self._need_to_update_eval_statistics = True
        self.eval_statistics = dict()

        if self.post_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_pretrain_hyperparams)

    def set_algorithm_weights(
        self,
        # bc_weight,
        # rl_weight,
        # use_awr_update,
        # use_reparam_update,
        # reparam_weight,
        # awr_weight,
        **kwargs
    ):
        for key in kwargs:
            self.__dict__[key] = kwargs[key]
        # self.bc_weight = bc_weight
        # self.rl_weight = rl_weight
        # self.use_awr_update = use_awr_update
        # self.use_reparam_update = use_reparam_update
        # self.awr_weight = awr_weight

    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        weights = batch.get('weights', None)
        if self.reward_transform:
            rewards = self.reward_transform(rewards)

        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)

        """
        Policy and Alpha Loss
        """
        new_obs_actions, policy_mean, policy_log_std, log_pi, entropy, policy_std, mean_action_log_prob, pretanh_value, dist = self.policy(
            obs, reparameterize=True, return_log_prob=True,
        )

        if self.use_automatic_entropy_tuning:
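            # SAC-style temperature tuning: adapt alpha so the policy entropy is pushed toward target_entropy.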
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha

        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        new_next_actions, _, _, new_log_pi, *_ = self.policy(
            next_obs, reparameterize=True, return_log_prob=True,
        )
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi
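        # Soft Bellman backup target used below:
        #   r + gamma * (1 - terminal) * (min_i target_Q_i(s', a') - alpha * log pi(a'|s'))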

        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

        """
        Policy Loss
        """
        qf1_new_actions = self.qf1(obs, new_obs_actions)
        qf2_new_actions = self.qf2(obs, new_obs_actions)
        q_new_actions = torch.min(
            qf1_new_actions,
            qf2_new_actions,
        )

        # Advantage-weighted regression
        if self.awr_use_mle_for_vf:
            v_pi = self.qf1(obs, policy_mean)
        else:
            v_pi = self.qf1(obs, new_obs_actions)

        if self.awr_sample_actions:
            u = new_obs_actions
            if self.awr_min_q:
                q_adv = q_new_actions
            else:
                q_adv = qf1_new_actions
        else:
            u = actions
            if self.awr_min_q:
                q_adv = torch.min(q1_pred, q2_pred)
            else:
                q_adv = q1_pred

        if self.awr_loss_type == "mse":
            policy_logpp = -(policy_mean - actions) ** 2
        else:
            policy_logpp = dist.log_prob(u)
            policy_logpp = policy_logpp.sum(dim=1, keepdim=True)

        advantage = q_adv - v_pi

        if self.weight_loss and weights is None:
            if self.use_automatic_beta_tuning:
                _, _, _, _, _, _, _, _, buffer_dist = self.buffer_policy(
                    obs, reparameterize=True, return_log_prob=True,
                )
                beta = self.log_beta.exp()
                kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
                beta_loss = -1 * (beta * (kldiv - self.beta_epsilon).detach()).mean()

                self.beta_optimizer.zero_grad()
                beta_loss.backward()
                self.beta_optimizer.step()
            else:
                beta = self.beta_schedule.get_value(self._n_train_steps_total)
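            # Advantage-weighted regression: per-sample weights are a softmax of the
            # advantages over the batch, with temperature beta.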
            weights = F.softmax(advantage / beta, dim=0)

        policy_loss = alpha * log_pi.mean()

        if self.use_awr_update and self.weight_loss:
            policy_loss = policy_loss + self.awr_weight * (-policy_logpp * len(weights) * weights.detach()).mean()
        elif self.use_awr_update:
            policy_loss = policy_loss + self.awr_weight * (-policy_logpp).mean()

        if self.use_reparam_update:
            policy_loss = policy_loss + self.reparam_weight * (-q_new_actions).mean()

        policy_loss = self.rl_weight * policy_loss
        if self.compute_bc:
            train_policy_loss, train_logp_loss, train_mse_loss, _ = self.run_bc_batch(self.demo_train_buffer, self.policy)
            policy_loss = policy_loss + self.bc_weight * train_policy_loss

        if self.train_bc_on_rl_buffer:
            buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = self.run_bc_batch(self.replay_buffer, self.buffer_policy)
        """
        Update networks
        """
        if self._n_train_steps_total % self.q_update_period == 0:
            self.qf1_optimizer.zero_grad()
            qf1_loss.backward()
            self.qf1_optimizer.step()

            self.qf2_optimizer.zero_grad()
            qf2_loss.backward()
            self.qf2_optimizer.step()

        if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

        if self.train_bc_on_rl_buffer and self._n_train_steps_total % self.policy_update_period == 0:
            self.buffer_policy_optimizer.zero_grad()
            buffer_policy_loss.backward()
            self.buffer_policy_optimizer.step()



        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
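            # Polyak averaging: target_param <- tau * param + (1 - tau) * target_param.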
            ptu.soft_update_from_to(
                self.qf1, self.target_qf1, self.soft_target_tau
            )
            ptu.soft_update_from_to(
                self.qf2, self.target_qf2, self.soft_target_tau
            )

        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Advantage Weights',
                ptu.get_numpy(weights),
            ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

            if self.compute_bc:
                test_policy_loss, test_logp_loss, test_mse_loss, _ = self.run_bc_batch(self.demo_test_buffer, self.policy)
                self.eval_statistics.update({
                    "bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                    "bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                    "bc/Train MSE": ptu.get_numpy(train_mse_loss),
                    "bc/Test MSE": ptu.get_numpy(test_mse_loss),
                    "bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                    "bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
                })
            if self.train_bc_on_rl_buffer:
                test_policy_loss, test_logp_loss, test_mse_loss, _ = self.run_bc_batch(self.replay_buffer,
                                                                                       self.buffer_policy)
                _, _, _, _, _, _, _, _, buffer_dist = self.buffer_policy(
                    obs, reparameterize=True, return_log_prob=True,
                )

                kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)

                self.eval_statistics.update({
                    "buffer_policy/Train Logprob Loss": ptu.get_numpy(buffer_train_logp_loss),
                    "buffer_policy/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                    "buffer_policy/Train MSE": ptu.get_numpy(buffer_train_mse_loss),
                    "buffer_policy/Test MSE": ptu.get_numpy(test_mse_loss),
                    "buffer_policy/train_policy_loss": ptu.get_numpy(buffer_policy_loss),
                    "buffer_policy/test_policy_loss": ptu.get_numpy(test_policy_loss),
                    "buffer_policy/kl_div":ptu.get_numpy(kldiv.mean()),
                })
            if self.use_automatic_beta_tuning:
                self.eval_statistics.update({
                    "adaptive_beta/beta":ptu.get_numpy(beta.mean()),
                    "adaptive_beta/beta loss": ptu.get_numpy(beta_loss.mean()),
                })

        self._n_train_steps_total += 1

    def get_diagnostics(self):
        stats = super().get_diagnostics()
        stats.update(self.eval_statistics)
        return stats

    def end_epoch(self, epoch):
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        nets = [
            self.policy,
            self.qf1,
            self.qf2,
            self.target_qf1,
            self.target_qf2,
        ]
        if self.buffer_policy:
            nets.append(self.buffer_policy)
        return nets

    def get_snapshot(self):
        return dict(
            policy=self.policy,
            qf1=self.qf1,
            qf2=self.qf2,
            target_qf1=self.target_qf1,
            target_qf2=self.target_qf2,
            buffer_policy=self.buffer_policy,
        )
Example #8
    def __init__(
        self,
        model,
        log_interval=0,
        beta=0.5,
        beta_schedule=None,
        lr=1e-3,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
        linearity_weight=0.0,
        distance_weight=0.0,
        loss_weights=None,
        use_linear_dynamics=False,
        use_parallel_dataloading=False,
        train_data_workers=2,
        skew_dataset=False,
        skew_config=None,
        priority_function_kwargs=None,
        start_skew_epoch=0,
        weight_decay=0,
        batch_size=64,
    ):
        # TODO(steven): fix pickling
        assert not use_parallel_dataloading, "Have to fix pickling the dataloaders first"

        if skew_config is None:
            skew_config = {}
        self.log_interval = log_interval
        self.beta = beta
        if is_auto_encoder:
            self.beta = 0
        self.beta_schedule = beta_schedule
        if self.beta_schedule is None or is_auto_encoder:
            self.beta_schedule = ConstantSchedule(self.beta)
        self.do_scatterplot = do_scatterplot
        model.to(ptu.device)

        self.model = model

        self.lr = lr
        params = list(self.model.parameters())
        self.optimizer = optim.Adam(
            params,
            lr=self.lr,
            weight_decay=weight_decay,
        )

        self.batch_size = batch_size
        self.use_parallel_dataloading = use_parallel_dataloading
        self.train_data_workers = train_data_workers
        self.skew_dataset = skew_dataset
        self.skew_config = skew_config
        self.start_skew_epoch = start_skew_epoch
        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        self.normalize = normalize
        self.mse_weight = mse_weight
        self.background_subtract = background_subtract

        self.linearity_weight = linearity_weight
        self.distance_weight = distance_weight
        self.loss_weights = loss_weights

        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.log_softmax = torch.nn.LogSoftmax()

        self.use_linear_dynamics = use_linear_dynamics
        self._extra_stats_to_log = None

        # stateful tracking variables, reset every epoch
        self.eval_statistics = collections.defaultdict(list)
        self.eval_data = collections.defaultdict(list)

        self.num_train_batches = 0
        self.num_test_batches = 0

        self.bin_midpoints = (
            torch.arange(0, self.model.output_classes).float() +
            0.5) / self.model.output_classes
        self.bin_midpoints = self.bin_midpoints.to(ptu.device)
Example #9
    def __init__(
        self,
        model,
        batch_size=128,
        log_interval=0,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
        linearity_weight=0.0,
        distance_weight=0.0,
        loss_weights=None,
        use_linear_dynamics=False,
        use_parallel_dataloading=False,
        train_data_workers=2,
        skew_dataset=False,
        skew_config=None,
        priority_function_kwargs=None,
        start_skew_epoch=0,
        weight_decay=0,
        key_to_reconstruct="observations",
    ):
        # TODO(steven): fix pickling
        assert not use_parallel_dataloading, "Have to fix pickling the dataloaders first"

        if skew_config is None:
            skew_config = {}
        self.log_interval = log_interval
        self.batch_size = batch_size
        self.beta = beta
        if is_auto_encoder:
            self.beta = 0
        if lr is None:
            if is_auto_encoder:
                lr = 1e-2
            else:
                lr = 1e-3
        self.beta_schedule = beta_schedule
        if self.beta_schedule is None or is_auto_encoder:
            self.beta_schedule = ConstantSchedule(self.beta)
        self.imsize = model.imsize
        self.do_scatterplot = do_scatterplot
        model.to(ptu.device)

        self.model = model
        self.representation_size = model.representation_size
        self.input_channels = model.input_channels
        self.imlength = model.imlength

        self.lr = lr
        params = list(self.model.parameters())
        self.optimizer = optim.Adam(
            params,
            lr=self.lr,
            weight_decay=weight_decay,
        )

        self.key_to_reconstruct = key_to_reconstruct
        self.batch_size = batch_size
        self.use_parallel_dataloading = use_parallel_dataloading
        self.train_data_workers = train_data_workers
        self.skew_dataset = skew_dataset
        self.skew_config = skew_config
        self.start_skew_epoch = start_skew_epoch
        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if use_parallel_dataloading:
            self.train_dataset_pt = ImageDataset(train_dataset,
                                                 should_normalize=True)
            self.test_dataset_pt = ImageDataset(test_dataset,
                                                should_normalize=True)

            if self.skew_dataset:
                base_sampler = InfiniteWeightedRandomSampler(
                    self.train_dataset, self._train_weights)
            else:
                base_sampler = InfiniteRandomSampler(self.train_dataset)
            self.train_dataloader = DataLoader(
                self.train_dataset_pt,
                sampler=InfiniteRandomSampler(self.train_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=train_data_workers,
                pin_memory=True,
            )
            self.test_dataloader = DataLoader(
                self.test_dataset_pt,
                sampler=InfiniteRandomSampler(self.test_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=0,
                pin_memory=True,
            )
            self.train_dataloader = iter(self.train_dataloader)
            self.test_dataloader = iter(self.test_dataloader)

        self.normalize = normalize
        self.mse_weight = mse_weight
        self.background_subtract = background_subtract

        if self.normalize or self.background_subtract:
            self.train_data_mean = np.mean(self.train_dataset, axis=0)
            self.train_data_mean = normalize_image(
                np.uint8(self.train_data_mean))
        self.linearity_weight = linearity_weight
        self.distance_weight = distance_weight
        self.loss_weights = loss_weights

        self.use_linear_dynamics = use_linear_dynamics
        self._extra_stats_to_log = None

        # stateful tracking variables, reset every epoch
        self.eval_statistics = collections.defaultdict(list)
        self.eval_data = collections.defaultdict(list)
Example #10
class VAETrainer(LossFunction):
    def __init__(
        self,
        model,
        batch_size=128,
        log_interval=0,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
        linearity_weight=0.0,
        distance_weight=0.0,
        loss_weights=None,
        use_linear_dynamics=False,
        use_parallel_dataloading=False,
        train_data_workers=2,
        skew_dataset=False,
        skew_config=None,
        priority_function_kwargs=None,
        start_skew_epoch=0,
        weight_decay=0,
        key_to_reconstruct="observations",
    ):
        # TODO(steven): fix pickling
        assert not use_parallel_dataloading, "Have to fix pickling the dataloaders first"

        if skew_config is None:
            skew_config = {}
        self.log_interval = log_interval
        self.batch_size = batch_size
        self.beta = beta
        if is_auto_encoder:
            self.beta = 0
        if lr is None:
            if is_auto_encoder:
                lr = 1e-2
            else:
                lr = 1e-3
        self.beta_schedule = beta_schedule
        if self.beta_schedule is None or is_auto_encoder:
            self.beta_schedule = ConstantSchedule(self.beta)
        self.imsize = model.imsize
        self.do_scatterplot = do_scatterplot
        model.to(ptu.device)

        self.model = model
        self.representation_size = model.representation_size
        self.input_channels = model.input_channels
        self.imlength = model.imlength

        self.lr = lr
        params = list(self.model.parameters())
        self.optimizer = optim.Adam(
            params,
            lr=self.lr,
            weight_decay=weight_decay,
        )

        self.key_to_reconstruct = key_to_reconstruct
        self.batch_size = batch_size
        self.use_parallel_dataloading = use_parallel_dataloading
        self.train_data_workers = train_data_workers
        self.skew_dataset = skew_dataset
        self.skew_config = skew_config
        self.start_skew_epoch = start_skew_epoch
        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if use_parallel_dataloading:
            self.train_dataset_pt = ImageDataset(train_dataset,
                                                 should_normalize=True)
            self.test_dataset_pt = ImageDataset(test_dataset,
                                                should_normalize=True)

            if self.skew_dataset:
                base_sampler = InfiniteWeightedRandomSampler(
                    self.train_dataset, self._train_weights)
            else:
                base_sampler = InfiniteRandomSampler(self.train_dataset)
            self.train_dataloader = DataLoader(
                self.train_dataset_pt,
                sampler=base_sampler,
                batch_size=batch_size,
                drop_last=False,
                num_workers=train_data_workers,
                pin_memory=True,
            )
            self.test_dataloader = DataLoader(
                self.test_dataset_pt,
                sampler=InfiniteRandomSampler(self.test_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=0,
                pin_memory=True,
            )
            self.train_dataloader = iter(self.train_dataloader)
            self.test_dataloader = iter(self.test_dataloader)

        self.normalize = normalize
        self.mse_weight = mse_weight
        self.background_subtract = background_subtract

        if self.normalize or self.background_subtract:
            self.train_data_mean = np.mean(self.train_dataset, axis=0)
            self.train_data_mean = normalize_image(
                np.uint8(self.train_data_mean))
        self.linearity_weight = linearity_weight
        self.distance_weight = distance_weight
        self.loss_weights = loss_weights

        self.use_linear_dynamics = use_linear_dynamics
        self._extra_stats_to_log = None

        # stateful tracking variables, reset every epoch
        self.eval_statistics = collections.defaultdict(list)
        self.eval_data = collections.defaultdict(list)

    @property
    def log_dir(self):
        return logger.get_snapshot_dir()

    def get_dataset_stats(self, data):
        torch_input = ptu.from_numpy(normalize_image(data))
        mus, log_vars = self.model.encode(torch_input)
        mus = ptu.get_numpy(mus)
        mean = np.mean(mus, axis=0)
        std = np.std(mus, axis=0)
        return mus, mean, std

    def _kl_np_to_np(self, np_imgs):
        torch_input = ptu.from_numpy(normalize_image(np_imgs))
        mu, log_var = self.model.encode(torch_input)
        return ptu.get_numpy(
            -torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))

    def _reconstruction_squared_error_np_to_np(self, np_imgs):
        torch_input = ptu.from_numpy(normalize_image(np_imgs))
        recons, *_ = self.model(torch_input)
        error = torch_input - recons
        return ptu.get_numpy((error**2).sum(dim=1))

    def set_vae(self, vae):
        self.model = vae
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

    def get_batch(self, test_data=False, epoch=None):
        if self.use_parallel_dataloading:
            if test_data:
                dataloader = self.test_dataloader
            else:
                dataloader = self.train_dataloader
            samples = next(dataloader).to(ptu.device)
            return samples

        dataset = self.test_dataset if test_data else self.train_dataset
        skew = False
        if epoch is not None:
            skew = (self.start_skew_epoch < epoch)
        if not test_data and self.skew_dataset and skew:
            probs = self._train_weights / np.sum(self._train_weights)
            ind = np.random.choice(
                len(probs),
                self.batch_size,
                p=probs,
            )
        else:
            ind = np.random.randint(0, len(dataset), self.batch_size)
        samples = normalize_image(dataset[ind, :])
        if self.normalize:
            samples = ((samples - self.train_data_mean) + 1) / 2
        if self.background_subtract:
            samples = samples - self.train_data_mean
        return ptu.from_numpy(samples)

    def get_debug_batch(self, train=True):
        dataset = self.train_dataset if train else self.test_dataset
        X, Y = dataset
        ind = np.random.randint(0, Y.shape[0], self.batch_size)
        X = X[ind, :]
        Y = Y[ind, :]
        return ptu.from_numpy(X), ptu.from_numpy(Y)

    def train_epoch(self, epoch, dataset, batches=100):
        start_time = time.time()
        for b in range(batches):
            self.train_batch(epoch, dataset.random_batch(self.batch_size))
        self.eval_statistics["train/epoch_duration"].append(time.time() -
                                                            start_time)

    def test_epoch(self, epoch, dataset, batches=10):
        start_time = time.time()
        for b in range(batches):
            self.test_batch(epoch, dataset.random_batch(self.batch_size))
        self.eval_statistics["test/epoch_duration"].append(time.time() -
                                                           start_time)

    def compute_loss(self, batch, epoch=-1, test=False):
        prefix = "test/" if test else "train/"

        beta = float(self.beta_schedule.get_value(epoch))
        obs = batch[self.key_to_reconstruct]
        reconstructions, obs_distribution_params, latent_distribution_params = self.model(
            obs)
        log_prob = self.model.logprob(obs, obs_distribution_params)
        kle = self.model.kl_divergence(latent_distribution_params)
        loss = -1 * log_prob + beta * kle

        self.eval_statistics['epoch'] = epoch
        self.eval_statistics['beta'] = beta
        self.eval_statistics[prefix + "losses"].append(loss.item())
        self.eval_statistics[prefix + "log_probs"].append(log_prob.item())
        self.eval_statistics[prefix + "kles"].append(kle.item())

        encoder_mean = self.model.get_encoding_from_latent_distribution_params(
            latent_distribution_params)
        z_data = ptu.get_numpy(encoder_mean.cpu())
        for i in range(len(z_data)):
            self.eval_data[prefix + "zs"].append(z_data[i, :])
        self.eval_data[prefix + "last_batch"] = (obs, reconstructions)

        return loss

    def train_batch(self, epoch, batch):
        self.model.train()
        self.optimizer.zero_grad()

        loss = self.compute_loss(batch, epoch, False)

        loss.backward()
        self.optimizer.step()

    def test_batch(
        self,
        epoch,
        batch,
    ):
        self.model.eval()
        loss = self.compute_loss(batch, epoch, True)

    def end_epoch(self, epoch):
        self.eval_statistics = collections.defaultdict(list)
        self.test_last_batch = None

    def get_diagnostics(self):
        stats = OrderedDict()
        for k in sorted(self.eval_statistics.keys()):
            stats[k] = np.mean(self.eval_statistics[k])
        return stats

    def dump_scatterplot(self, z, epoch):
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            logger.log(__file__ + ": Unable to load matplotlib. Consider "
                       "setting do_scatterplot to False")
            return
        dim_and_stds = [(i, np.std(z[:, i])) for i in range(z.shape[1])]
        dim_and_stds = sorted(dim_and_stds, key=lambda x: x[1])
        dim1 = dim_and_stds[-1][0]
        dim2 = dim_and_stds[-2][0]
        plt.figure(figsize=(8, 8))
        plt.scatter(z[:, dim1], z[:, dim2], marker='o', edgecolor='none')
        if self.model.dist_mu is not None:
            x1 = self.model.dist_mu[dim1:dim1 + 1]
            y1 = self.model.dist_mu[dim2:dim2 + 1]
            x2 = (self.model.dist_mu[dim1:dim1 + 1] +
                  self.model.dist_std[dim1:dim1 + 1])
            y2 = (self.model.dist_mu[dim2:dim2 + 1] +
                  self.model.dist_std[dim2:dim2 + 1])
            plt.plot([x1, x2], [y1, y2], color='k', linestyle='-', linewidth=2)
        axes = plt.gca()
        axes.set_xlim([-6, 6])
        axes.set_ylim([-6, 6])
        axes.set_title('dim {} vs dim {}'.format(dim1, dim2))
        plt.grid(True)
        save_file = osp.join(self.log_dir, 'scatter%d.png' % epoch)
        plt.savefig(save_file)
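
The train_epoch and test_epoch methods above expect a dataset object that exposes random_batch(batch_size) and returns a dict keyed by key_to_reconstruct ("observations" by default). A minimal sketch of such a wrapper, plus the surrounding epoch loop, is shown below; the wrapper is an illustrative stand-in rather than a railrl class, and the ConvVAE/trainer construction is elided.

import numpy as np
import torch


class DictDataset:
    """Illustrative stand-in: serves random minibatches as {'observations': tensor}."""
    def __init__(self, observations):
        # float32 array of shape (N, imlength); tensors are assumed to already
        # live on the same device as the model (ptu.device).
        self.observations = observations

    def random_batch(self, batch_size):
        ind = np.random.randint(0, len(self.observations), batch_size)
        return {'observations': torch.from_numpy(self.observations[ind])}


# Assumed driver loop around the trainer above (model construction elided):
# trainer = VAETrainer(model, batch_size=128, beta=0.5)
# train_data, test_data = DictDataset(train_array), DictDataset(test_array)
# for epoch in range(num_epochs):
#     trainer.train_epoch(epoch, train_data, batches=100)
#     trainer.test_epoch(epoch, test_data, batches=10)
#     print(trainer.get_diagnostics())
#     trainer.end_epoch(epoch)
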
Example #11
0
def train_vae(variant, return_data=False):
    from railrl.misc.ml_util import PiecewiseLinearSchedule, ConstantSchedule
    from railrl.torch.vae.conv_vae import (
        ConvVAE,
        SpatialAutoEncoder,
        AutoEncoder,
    )
    import railrl.torch.vae.conv_vae as conv_vae
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch
    beta = variant["beta"]
    representation_size = variant.get("representation_size",
                                      variant.get("latent_sizes", None))
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    variant['generate_vae_dataset_kwargs'][
        'use_linear_dynamics'] = use_linear_dynamics
    variant['generate_vae_dataset_kwargs']['batch_size'] = variant[
        'algo_kwargs']['batch_size']
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])

    if use_linear_dynamics:
        action_dim = train_dataset.data['actions'].shape[2]

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if 'context_schedule' in variant:
        schedule = variant['context_schedule']
        if type(schedule) is dict:
            context_schedule = PiecewiseLinearSchedule(**schedule)
        else:
            context_schedule = ConstantSchedule(schedule)
        variant['algo_kwargs']['context_schedule'] = context_schedule
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    if variant['algo_kwargs'].get('is_auto_encoder', False):
        model = AutoEncoder(representation_size,
                            decoder_output_activation=decoder_activation,
                            **variant['vae_kwargs'])
    elif variant.get('use_spatial_auto_encoder', False):
        model = SpatialAutoEncoder(
            representation_size,
            decoder_output_activation=decoder_activation,
            **variant['vae_kwargs'])
    else:
        vae_class = variant.get('vae_class', ConvVAE)
        if use_linear_dynamics:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              action_dim=action_dim,
                              **variant['vae_kwargs'])
        else:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              **variant['vae_kwargs'])
    model.to(ptu.device)

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
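
train_vae is driven entirely by the nested variant dict whose keys are read above. The sketch below collects those keys in one place; the specific values are illustrative placeholders (and the generate_vae_dataset_kwargs contents are hypothetical), not recommended settings.

variant = dict(
    beta=20,                          # KL weight passed to the trainer
    representation_size=4,            # latent dimension ('latent_sizes' also accepted)
    imsize=48,                        # selects conv_vae.imsize48_default_architecture
    num_epochs=100,
    save_period=25,                   # how often to dump reconstructions/samples
    decoder_activation='sigmoid',     # anything else falls back to identity
    generate_vae_dataset_kwargs=dict(
        N=1000,                       # hypothetical key; forwarded to the dataset fn
    ),
    vae_kwargs=dict(
        input_channels=3,             # illustrative ConvVAE kwarg
    ),
    algo_kwargs=dict(
        batch_size=128,               # also copied into generate_vae_dataset_kwargs
    ),
    # Optional: beta_schedule_kwargs=dict(...) to anneal beta with
    # PiecewiseLinearSchedule instead of holding it constant.
)
# model = train_vae(variant)
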
Example #12
0
    def __init__(
        self,
        train_dataset,
        test_dataset,
        model,
        batch_size=64,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        linearity_weight=0.0,
        use_linear_dynamics=False,
        noisy_linear_dynamics=False,
        scale_linear_dynamics=False,
        use_parallel_dataloading=True,
        train_data_workers=2,
    ):
        self.quick_init(locals())
        self.batch_size = batch_size
        self.beta = beta
        if lr is None:
            lr = 1e-3
        self.beta_schedule = beta_schedule
        if self.beta_schedule is None:
            self.beta_schedule = ConstantSchedule(self.beta)
        self.imsize = model.imsize

        if ptu.gpu_enabled():
            model.cuda()

        self.model = model
        self.representation_size = model.representation_size
        self.input_channels = model.input_channels
        self.imlength = model.imlength

        self.lr = lr
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.train_dataset, self.test_dataset = train_dataset, test_dataset
        assert self.train_dataset['next_obs'].dtype == np.uint8
        assert self.test_dataset['next_obs'].dtype == np.uint8
        assert self.train_dataset['obs'].dtype == np.uint8
        assert self.test_dataset['obs'].dtype == np.uint8

        self.use_parallel_dataloading = use_parallel_dataloading
        self.train_data_workers = train_data_workers
        self.gaussian_decoder_loss = self.model.gaussian_decoder
        if use_parallel_dataloading:
            self.train_dataset_pt = ImageDataset(train_dataset,
                                                 should_normalize=True)
            self.test_dataset_pt = ImageDataset(test_dataset,
                                                should_normalize=True)

            self._train_weights = None
            base_sampler = InfiniteRandomSampler(self.train_dataset)
            self.train_dataloader = DataLoader(
                self.train_dataset_pt,
                sampler=BatchSampler(
                    base_sampler,
                    batch_size=batch_size,
                    drop_last=False,
                ),
                num_workers=train_data_workers,
                pin_memory=True,
            )
            self.test_dataloader = DataLoader(
                self.test_dataset_pt,
                sampler=BatchSampler(
                    InfiniteRandomSampler(self.test_dataset),
                    batch_size=batch_size,
                    drop_last=False,
                ),
                num_workers=0,
                pin_memory=True,
            )
            self.train_dataloader = iter(self.train_dataloader)
            self.test_dataloader = iter(self.test_dataloader)

        self.linearity_weight = linearity_weight
        self.use_linear_dynamics = use_linear_dynamics
        self.noisy_linear_dynamics = noisy_linear_dynamics
        self.scale_linear_dynamics = scale_linear_dynamics
        self.vae_logger_stats_for_rl = {}
        self._extra_stats_to_log = None
Example #13
0
class ConvVAETrainer(Serializable):
    def __init__(
        self,
        train_dataset,
        test_dataset,
        model,
        batch_size=64,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        linearity_weight=0.0,
        use_linear_dynamics=False,
        noisy_linear_dynamics=False,
        scale_linear_dynamics=False,
        use_parallel_dataloading=True,
        train_data_workers=2,
    ):
        self.quick_init(locals())
        self.batch_size = batch_size
        self.beta = beta
        if lr is None:
            lr = 1e-3
        self.beta_schedule = beta_schedule
        if self.beta_schedule is None:
            self.beta_schedule = ConstantSchedule(self.beta)
        self.imsize = model.imsize

        if ptu.gpu_enabled():
            model.cuda()

        self.model = model
        self.representation_size = model.representation_size
        self.input_channels = model.input_channels
        self.imlength = model.imlength

        self.lr = lr
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.train_dataset, self.test_dataset = train_dataset, test_dataset
        assert self.train_dataset['next_obs'].dtype == np.uint8
        assert self.test_dataset['next_obs'].dtype == np.uint8
        assert self.train_dataset['obs'].dtype == np.uint8
        assert self.test_dataset['obs'].dtype == np.uint8

        self.use_parallel_dataloading = use_parallel_dataloading
        self.train_data_workers = train_data_workers
        self.gaussian_decoder_loss = self.model.gaussian_decoder
        if use_parallel_dataloading:
            self.train_dataset_pt = ImageDataset(train_dataset,
                                                 should_normalize=True)
            self.test_dataset_pt = ImageDataset(test_dataset,
                                                should_normalize=True)

            self._train_weights = None
            base_sampler = InfiniteRandomSampler(self.train_dataset)
            self.train_dataloader = DataLoader(
                self.train_dataset_pt,
                sampler=BatchSampler(
                    base_sampler,
                    batch_size=batch_size,
                    drop_last=False,
                ),
                num_workers=train_data_workers,
                pin_memory=True,
            )
            self.test_dataloader = DataLoader(
                self.test_dataset_pt,
                sampler=BatchSampler(
                    InfiniteRandomSampler(self.test_dataset),
                    batch_size=batch_size,
                    drop_last=False,
                ),
                num_workers=0,
                pin_memory=True,
            )
            self.train_dataloader = iter(self.train_dataloader)
            self.test_dataloader = iter(self.test_dataloader)

        self.linearity_weight = linearity_weight
        self.use_linear_dynamics = use_linear_dynamics
        self.noisy_linear_dynamics = noisy_linear_dynamics
        self.scale_linear_dynamics = scale_linear_dynamics
        self.vae_logger_stats_for_rl = {}
        self._extra_stats_to_log = None

    def _kl_np_to_np(self, np_imgs):
        torch_input = ptu.np_to_var(normalize_image(np_imgs))
        mu, log_var = self.model.encode(torch_input)
        return ptu.get_numpy(
            -torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))

    def set_vae(self, vae):
        self.model = vae
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

    def get_batch(self, train=True):
        if self.use_parallel_dataloading:
            if not train:
                dataloader = self.test_dataloader
            else:
                dataloader = self.train_dataloader
            samples = next(dataloader)
            return {
                'obs': ptu.Variable(samples[0][0]),
                'actions': ptu.Variable(samples[1][0]),
                'next_obs': ptu.Variable(samples[2][0]),
            }

        dataset = self.train_dataset if train else self.test_dataset
        ind = np.random.randint(0, len(dataset), self.batch_size)
        samples = normalize_image(dataset[ind, :])
        return ptu.np_to_var(samples)

    def logprob_iwae(self, recon_x, x):
        if self.gaussian_decoder_loss:
            error = -(recon_x - x)**2
        else:
            error = x * torch.log(torch.clamp(recon_x, min=1e-30)) \
                    + (1-x) * torch.log(torch.clamp(1-recon_x, min=1e-30))
        return error

    def logprob_vae(self, recon_x, x):
        batch_size = recon_x.shape[0]

        if self.gaussian_decoder_loss:
            return -((recon_x - x)**2).sum() / batch_size
        else:
            # Divide by batch_size rather than setting size_average=True because
            # otherwise the averaging will also happen across dimension 1 (the
            # pixels)
            return -F.binary_cross_entropy(
                recon_x,
                x.narrow(start=0, length=self.imlength,
                         dimension=1).contiguous().view(-1, self.imlength),
                size_average=False,
            ) / batch_size

    def kl_divergence(self, mu, logvar):
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(),
                                dim=1).mean()

    def compute_vae_loss(self, x_recon, x, z_mu, z_logvar, z_sampled, beta):
        batch_size = x_recon.shape[0]
        k = x_recon.shape[1]

        x_recon = x_recon.view((batch_size * k, -1))
        x = x.view((batch_size * k, -1))
        z_mu = z_mu.view((batch_size * k, -1))
        z_logvar = z_logvar.view((batch_size * k, -1))

        de = -self.logprob_vae(x_recon, x)
        kle = self.kl_divergence(z_mu, z_logvar)
        loss = de + beta * kle
        return loss, de, kle

    def compute_iwae_loss(self, x_recon, x, z_mu, z_logvar, z_sampled, beta):
        batch_size = x_recon.shape[0]
        log_p_xgz = self.logprob_iwae(x_recon, x).sum(dim=-1)

        prior_dist = torch.distributions.Normal(
            ptu.Variable(torch.zeros(z_sampled.shape)),
            ptu.Variable(torch.ones(z_sampled.shape)))
        log_p_z = prior_dist.log_prob(z_sampled).sum(dim=-1)

        z_std = torch.exp(0.5 * z_logvar)
        encoder_dist = torch.distributions.Normal(z_mu, z_std)
        log_q_zgx = encoder_dist.log_prob(z_sampled).sum(dim=-1)

        log_w = log_p_xgz + beta * (log_p_z - log_q_zgx)
        w_tilde = F.softmax(log_w, dim=-1).detach()

        loss = -(log_w * w_tilde).sum() / batch_size

        return loss

    def state_linearity_loss(self, obs, next_obs, actions):
        latent_obs_mu, latent_obs_logvar = self.model.encode(obs)
        latent_next_obs_mu, latent_next_obs_logvar = self.model.encode(
            next_obs)
        if self.noisy_linear_dynamics:
            latent_obs = self.model.reparameterize(latent_obs_mu,
                                                   latent_obs_logvar)
        else:
            latent_obs = latent_obs_mu
        action_obs_pair = torch.cat([latent_obs, actions], dim=1)
        prediction = self.model.linear_constraint_fc(action_obs_pair)
        if self.scale_linear_dynamics:
            std = latent_next_obs_logvar.mul(0.5).exp_()
            scaling = 1 / std
        else:
            scaling = 1.0
        return torch.norm(
            scaling * (prediction - latent_next_obs_mu))**2 / self.batch_size

    def train_epoch(self, epoch, batches=100):
        self.model.own_train()
        vae_losses = []
        losses = []
        des = []
        kles = []
        linear_losses = []
        beta = float(self.beta_schedule.get_value(epoch))
        for batch_idx in range(batches):
            data = self.get_batch()
            obs = data['obs']
            next_obs = data['next_obs']
            actions = data['actions']
            self.optimizer.zero_grad()
            x_recon, z_mu, z_logvar, z = self.model(next_obs)
            batch_size = x_recon.shape[0]
            k = x_recon.shape[1]
            x = next_obs.view((batch_size, 1, -1)).repeat(torch.Size([1, k,
                                                                      1]))
            vae_loss, de, kle = self.compute_vae_loss(x_recon, x, z_mu,
                                                      z_logvar, z, beta)
            loss = vae_loss
            '''print("---------------------")
            print("loss ", loss.data)
            print("z_mu ", z_mu.data[0][0])
            print("z_logvar ", z_logvar.data[0][0])
            print("z ", z.data[0][0])
            print("---------------------")'''

            if self.use_linear_dynamics:
                linear_dynamics_loss = self.state_linearity_loss(
                    obs, next_obs, actions)
                loss += self.linearity_weight * linear_dynamics_loss
                linear_losses.append(float(linear_dynamics_loss.data[0]))
            loss.backward()

            vae_losses.append(float(vae_loss.data[0]))
            losses.append(float(loss.data[0]))
            des.append(float(de.data[0]))
            kles.append(float(kle.data[0]))
            self.optimizer.step()
            del data, obs, next_obs, actions, x_recon, z_mu, z_logvar, \
                z, x, vae_loss, de, kle, loss

        logger.record_tabular("train/epoch", epoch)
        logger.record_tabular("train/decoder_loss", np.mean(des))
        logger.record_tabular("train/KL", np.mean(kles))
        if self.use_linear_dynamics:
            logger.record_tabular("train/linear_loss", np.mean(linear_losses))
        logger.record_tabular("train/vae_loss", np.mean(vae_losses))
        logger.record_tabular("train/loss", np.mean(losses))

    def test_epoch(self,
                   epoch,
                   save_reconstruction=True,
                   save_interpolation=True,
                   save_vae=True):
        self.model.eval()
        vae_losses = []
        iwae_losses = []
        losses = []
        des = []
        kles = []
        linear_losses = []
        zs = []
        beta = float(self.beta_schedule.get_value(epoch))

        for batch_idx in range(10):
            data = self.get_batch(train=False)
            obs = data['obs']
            obs = obs.detach()
            next_obs = data['next_obs']
            next_obs = next_obs.detach()
            actions = data['actions']
            actions = actions.detach()

            x_recon, z_mu, z_logvar, z = self.model(next_obs, n_imp=25)
            x_recon = x_recon.detach()
            z_mu = z_mu.detach()
            z = z.detach()

            batch_size = x_recon.shape[0]
            k = x_recon.shape[1]
            x = next_obs.view((batch_size, 1, -1)).repeat(torch.Size([1, k,
                                                                      1]))
            x = x.detach()
            vae_loss, de, kle = self.compute_vae_loss(x_recon, x, z_mu,
                                                      z_logvar, z, beta)
            vae_loss, de, kle = vae_loss.detach(), de.detach(), kle.detach()
            iwae_loss = self.compute_iwae_loss(x_recon, x, z_mu, z_logvar, z,
                                               beta)
            iwae_loss = iwae_loss.detach()
            loss = vae_loss
            if self.use_linear_dynamics:
                linear_dynamics_loss = self.state_linearity_loss(
                    obs, next_obs, actions)
                linear_dynamics_loss = linear_dynamics_loss.detach()
                loss += self.linearity_weight * linear_dynamics_loss
                linear_losses.append(float(
                    linear_dynamics_loss.data[0]))

            z_data = ptu.get_numpy(z_mu[:, 0].cpu())
            for i in range(len(z_data)):
                zs.append(z_data[i, :].copy())
            vae_losses.append(float(vae_loss.data[0]))
            iwae_losses.append(float(iwae_loss.data[0]))
            losses.append(float(loss.data[0]))
            des.append(float(de.data[0]))
            kles.append(float(kle.data[0]))

            if batch_idx == 0 and save_reconstruction:
                n = min(data['next_obs'].size(0), 16)
                comparison = torch.cat([
                    data['next_obs'][:n].narrow(start=0,
                                                length=self.imlength,
                                                dimension=1).contiguous().view(
                                                    -1, self.input_channels,
                                                    self.imsize, self.imsize),
                    x_recon[:, 0].contiguous().view(
                        self.batch_size,
                        self.input_channels,
                        self.imsize,
                        self.imsize,
                    )[:n]
                ])
                save_dir = osp.join(logger.get_snapshot_dir(),
                                    'r_%d.png' % epoch)
                save_image(comparison.data.cpu(), save_dir, nrow=n)
                del comparison

            if batch_idx == 0 and save_interpolation:
                n = min(data['next_obs'].size(0), 10)

                z1 = z_mu[:n, 0]
                z2 = z_mu[n:2 * n, 0]

                num_steps = 8

                z_interp = []
                for i in np.linspace(0.0, 1.0, num_steps):
                    z_interp.append(float(i) * z1 + float(1 - i) * z2)
                z_interp = torch.cat(z_interp)

                imgs = self.model.decode(z_interp)
                imgs = imgs.view((num_steps, n, 3, self.imsize, self.imsize))
                imgs = imgs.permute([1, 0, 2, 3, 4])
                imgs = imgs.contiguous().view(
                    (n * num_steps, 3, self.imsize, self.imsize))

                save_dir = osp.join(logger.get_snapshot_dir(),
                                    'i_%d.png' % epoch)
                save_image(
                    imgs.data.cpu(),
                    save_dir,
                    nrow=num_steps,
                )
                del imgs
                del z_interp

            del obs, next_obs, actions, x_recon, z_mu, z_logvar, \
                z, x, vae_loss, de, kle, loss

        zs = np.array(zs)
        self.model.dist_mu = zs.mean(axis=0)
        self.model.dist_std = zs.std(axis=0)
        del zs

        logger.record_tabular("test/decoder_loss", np.mean(des))
        logger.record_tabular("test/KL", np.mean(kles))
        if self.use_linear_dynamics:
            logger.record_tabular("test/linear_loss", np.mean(linear_losses))
        logger.record_tabular("test/loss", np.mean(losses))
        logger.record_tabular("test/vae_loss", np.mean(vae_losses))
        logger.record_tabular("test/iwae_loss", np.mean(iwae_losses))
        logger.record_tabular(
            "test/iwae_vae_diff",
            np.mean(np.array(iwae_losses) - np.array(vae_losses)))
        logger.record_tabular("beta", beta)

        process = psutil.Process(os.getpid())
        logger.record_tabular("RAM Usage (Mb)",
                              int(process.memory_info().rss / 1000000))

        num_active_dims = 0
        num_active_dims2 = 0
        for std in self.model.dist_std:
            if std > 0.15:
                num_active_dims += 1
            if std > 0.05:
                num_active_dims2 += 1
        logger.record_tabular("num_active_dims", num_active_dims)
        logger.record_tabular("num_active_dims2", num_active_dims2)

        logger.dump_tabular()
        if save_vae:
            logger.save_itr_params(epoch,
                                   self.model,
                                   prefix='vae',
                                   save_anyway=True)  # slow...
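
compute_iwae_loss above optimizes a surrogate: the softmax weights w_tilde over the k importance samples are detached, so minimizing -(log_w * w_tilde).sum() / batch_size follows the same gradient as maximizing the importance-weighted bound mean(logsumexp(log_w, dim=1) - log k). A small sketch of both quantities is given below, assuming log_w is any (batch_size, k) tensor of per-sample log importance weights; the bound itself can be useful for logging alongside the surrogate.

import math

import torch
import torch.nn.functional as F


def iwae_bound(log_w):
    """Importance-weighted ELBO: batch mean of logsumexp over the k samples minus log k."""
    k = log_w.shape[1]
    return (torch.logsumexp(log_w, dim=1) - math.log(k)).mean()


def iwae_surrogate_loss(log_w):
    """Gradient-matching surrogate used above: detached softmax weights times log_w."""
    w_tilde = F.softmax(log_w, dim=1).detach()
    return -(log_w * w_tilde).sum() / log_w.shape[0]
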
    def __init__(
            self,
            vae,
            *args,
            decoded_obs_key='image_observation',
            decoded_achieved_goal_key='image_achieved_goal',
            decoded_desired_goal_key='image_desired_goal',
            exploration_rewards_type='None',
            exploration_rewards_scale=1.0,
            vae_priority_type='None',
            start_skew_epoch=0,
            power=1.0,
            internal_keys=None,
            exploration_schedule_kwargs=None,
            priority_function_kwargs=None,
            exploration_counter_kwargs=None,
            relabeling_goal_sampling_mode='vae_prior',
            decode_vae_goals=False,
            **kwargs
    ):
        if internal_keys is None:
            internal_keys = []

        for key in [
            decoded_obs_key,
            decoded_achieved_goal_key,
            decoded_desired_goal_key
        ]:
            if key not in internal_keys:
                internal_keys.append(key)
        super().__init__(internal_keys=internal_keys, *args, **kwargs)
        # assert isinstance(self.env, VAEWrappedEnv)
        self.vae = vae
        self.decoded_obs_key = decoded_obs_key
        self.decoded_desired_goal_key = decoded_desired_goal_key
        self.decoded_achieved_goal_key = decoded_achieved_goal_key
        self.exploration_rewards_type = exploration_rewards_type
        self.exploration_rewards_scale = exploration_rewards_scale
        self.start_skew_epoch = start_skew_epoch
        self.vae_priority_type = vae_priority_type
        self.power = power
        self._relabeling_goal_sampling_mode = relabeling_goal_sampling_mode
        self.decode_vae_goals = decode_vae_goals

        if exploration_schedule_kwargs is None:
            self.explr_reward_scale_schedule = \
                ConstantSchedule(self.exploration_rewards_scale)
        else:
            self.explr_reward_scale_schedule = \
                PiecewiseLinearSchedule(**exploration_schedule_kwargs)

        self._give_explr_reward_bonus = (
                exploration_rewards_type != 'None'
                and exploration_rewards_scale != 0.
        )
        self._exploration_rewards = np.zeros((self.max_size, 1), dtype=np.float32)
        self._prioritize_vae_samples = (
                vae_priority_type != 'None'
                and power != 0.
        )
        self._vae_sample_priorities = np.zeros((self.max_size, 1), dtype=np.float32)
        self._vae_sample_probs = None

        self.use_dynamics_model = (
                self.exploration_rewards_type == 'forward_model_error'
        )
        if self.use_dynamics_model:
            self.initialize_dynamics_model()

        type_to_function = {
            'reconstruction_error': self.reconstruction_mse,
            'bce': self.binary_cross_entropy,
            'latent_distance': self.latent_novelty,
            'latent_distance_true_prior': self.latent_novelty_true_prior,
            'forward_model_error': self.forward_model_error,
            'gaussian_inv_prob': self.gaussian_inv_prob,
            'bernoulli_inv_prob': self.bernoulli_inv_prob,
            'vae_prob': self.vae_prob,
            'hash_count': self.hash_count_reward,
            'None': self.no_reward,
        }

        self.exploration_reward_func = (
            type_to_function[self.exploration_rewards_type]
        )
        self.vae_prioritization_func = (
            type_to_function[self.vae_priority_type]
        )

        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if self.exploration_rewards_type == 'hash_count':
            if exploration_counter_kwargs is None:
                exploration_counter_kwargs = dict()
            self.exploration_counter = CountExploration(env=self.env, **exploration_counter_kwargs)
        self.epoch = 0
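
The __init__ above only allocates the priority buffers (_vae_sample_priorities, power, _vae_sample_probs); it does not show how priorities become sampling probabilities. A common scheme, sketched below as an assumption rather than the library's exact refresh code, is a power-law reweighting: probabilities proportional to priority**power, with negative powers over-sampling low-priority (e.g. low-density) transitions.

import numpy as np


def skewed_sample_probs(priorities, power):
    """Assumed refresh step: probabilities proportional to priority**power."""
    p = priorities.flatten().astype(np.float64) ** power
    return p / p.sum()


def sample_skewed_indices(priorities, power, batch_size):
    probs = skewed_sample_probs(priorities, power)
    return np.random.choice(len(probs), size=batch_size, p=probs)


# Example: power < 0 draws rarely-seen (low-priority) samples more often.
priorities = np.random.rand(10000, 1) + 1e-6
idx = sample_skewed_indices(priorities, power=-1.0, batch_size=128)
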
Example #15
0
    def __init__(self,
                 env,
                 qf,
                 policy,
                 target_qf,
                 target_policy,
                 exploration_policy,
                 policy_learning_rate=1e-4,
                 qf_learning_rate=1e-3,
                 qf_weight_decay=0,
                 target_hard_update_period=1000,
                 tau=1e-2,
                 use_soft_update=True,
                 qf_criterion=None,
                 residual_gradient_weight=0,
                 epoch_discount_schedule=None,
                 eval_with_target_policy=False,
                 policy_pre_activation_weight=0.,
                 optimizer_class=optim.Adam,
                 plotter=None,
                 render_eval_paths=False,
                 obs_normalizer: TorchFixedNormalizer = None,
                 action_normalizer: TorchFixedNormalizer = None,
                 num_paths_for_normalization=0,
                 min_q_value=-np.inf,
                 max_q_value=np.inf,
                 **kwargs):
        """

        :param env:
        :param qf:
        :param policy:
        :param exploration_policy:
        :param policy_learning_rate:
        :param qf_learning_rate:
        :param qf_weight_decay:
        :param target_hard_update_period:
        :param tau:
        :param use_soft_update:
        :param qf_criterion: Loss function to use for the q function. Should
        be a function that takes in two inputs (y_predicted, y_target).
        :param residual_gradient_weight: c, float between 0 and 1. The gradient
        used for training the Q function is then
            (1-c) * normal td gradient + c * residual gradient
        :param epoch_discount_schedule: A schedule for the discount factor
        that varies with the epoch.
        :param kwargs:
        """
        self.target_policy = target_policy
        if eval_with_target_policy:
            eval_policy = self.target_policy
        else:
            eval_policy = policy
        super().__init__(env,
                         exploration_policy,
                         eval_policy=eval_policy,
                         **kwargs)
        if qf_criterion is None:
            qf_criterion = nn.MSELoss()
        self.qf = qf
        self.policy = policy
        self.policy_learning_rate = policy_learning_rate
        self.qf_learning_rate = qf_learning_rate
        self.qf_weight_decay = qf_weight_decay
        self.target_hard_update_period = target_hard_update_period
        self.tau = tau
        self.use_soft_update = use_soft_update
        self.residual_gradient_weight = residual_gradient_weight
        self.policy_pre_activation_weight = policy_pre_activation_weight
        self.qf_criterion = qf_criterion
        if epoch_discount_schedule is None:
            epoch_discount_schedule = ConstantSchedule(self.discount)
        self.epoch_discount_schedule = epoch_discount_schedule
        self.plotter = plotter
        self.render_eval_paths = render_eval_paths
        self.obs_normalizer = obs_normalizer
        self.action_normalizer = action_normalizer
        self.num_paths_for_normalization = num_paths_for_normalization
        self.min_q_value = min_q_value
        self.max_q_value = max_q_value

        self.target_qf = target_qf
        self.qf_optimizer = optimizer_class(
            self.qf.parameters(),
            lr=self.qf_learning_rate,
        )
        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=self.policy_learning_rate,
        )
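
The tau / use_soft_update / target_hard_update_period arguments above control how the target networks track the live networks: either Polyak-averaged every step, or copied wholesale every fixed number of steps. A minimal sketch of the two update rules, written against plain torch modules rather than this class, is shown below.

import torch.nn as nn


def soft_update(target: nn.Module, source: nn.Module, tau: float):
    """Polyak averaging: target <- (1 - tau) * target + tau * source."""
    for t_param, s_param in zip(target.parameters(), source.parameters()):
        t_param.data.copy_((1.0 - tau) * t_param.data + tau * s_param.data)


def hard_update(target: nn.Module, source: nn.Module):
    """Copy the source weights wholesale, e.g. every target_hard_update_period steps."""
    target.load_state_dict(source.state_dict())
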
    def __init__(self,
                 env,
                 exploration_policy,
                 beta_q,
                 beta_q2,
                 beta_v,
                 policy,
                 train_with='both',
                 goal_reached_epsilon=1e-3,
                 learning_rate=1e-3,
                 prioritized_replay=False,
                 always_reset_env=True,
                 finite_horizon=False,
                 max_num_steps_left=0,
                 flip_training_period=100,
                 train_simultaneously=True,
                 policy_and_target_update_period=2,
                 target_policy_noise=0.2,
                 target_policy_noise_clip=0.5,
                 soft_target_tau=0.005,
                 per_beta_schedule=None,
                 **kwargs):
        self.train_simultaneously = train_simultaneously
        assert train_with in ['both', 'off_policy', 'on_policy']
        super().__init__(env, exploration_policy, **kwargs)
        self.eval_sampler = MultigoalSimplePathSampler(
            env=self.env,
            policy=self.eval_policy,
            max_samples=self.num_steps_per_eval,
            max_path_length=self.max_path_length,
            tau_sampling_function=lambda: 0,
            goal_sampling_function=self.env.sample_goal_for_rollout,
            cycle_taus_for_rollout=False,
            render=self.render_during_eval)
        self.goal_reached_epsilon = goal_reached_epsilon
        self.beta_q = beta_q
        self.beta_v = beta_v
        self.beta_q2 = beta_q2
        self.target_beta_q = self.beta_q.copy()
        self.target_beta_q2 = self.beta_q2.copy()
        self.train_with = train_with
        self.policy = policy
        self.target_policy = policy
        self.prioritized_replay = prioritized_replay
        self.flip_training_period = flip_training_period

        self.always_reset_env = always_reset_env
        self.finite_horizon = finite_horizon
        self.max_num_steps_left = max_num_steps_left
        assert max_num_steps_left >= 0

        self.policy_and_target_update_period = policy_and_target_update_period
        self.target_policy_noise = target_policy_noise
        self.target_policy_noise_clip = target_policy_noise_clip
        self.soft_target_tau = soft_target_tau
        if per_beta_schedule is None:
            per_beta_schedule = ConstantSchedule(1.0)
        self.per_beta_schedule = per_beta_schedule

        self.beta_q_optimizer = Adam(self.beta_q.parameters(),
                                     lr=learning_rate)
        self.beta_q2_optimizer = Adam(self.beta_q2.parameters(),
                                      lr=learning_rate)
        self.beta_v_optimizer = Adam(self.beta_v.parameters(),
                                     lr=learning_rate)
        self.policy_optimizer = Adam(
            self.policy.parameters(),
            lr=learning_rate,
        )
        self.q_criterion = nn.BCELoss()
        self.v_criterion = nn.BCELoss()

        # For the multitask env
        self._rollout_goal = None

        self.extra_eval_statistics = OrderedDict()
        for key_not_always_updated in [
                'Policy Gradient Norms',
                'Beta Q Gradient Norms',
                'dQ/da',
        ]:
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    key_not_always_updated,
                    np.zeros(2),
                ))

        self.training_policy = False

        # For debugging
        self.train_batches = []
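
All of the trainers and algorithms above share one small dependency: schedule objects with a single get_value(t) method (beta_schedule, epoch_discount_schedule, per_beta_schedule, explr_reward_scale_schedule). The sketch below mirrors the interface implied by the calls above (ConstantSchedule(value) and PiecewiseLinearSchedule with breakpoint kwargs), but it is a simplified stand-in, not railrl.misc.ml_util itself.

import numpy as np


class ConstantSchedule:
    """Always returns the same value, e.g. a fixed beta."""
    def __init__(self, value):
        self._value = value

    def get_value(self, t):
        return self._value


class PiecewiseLinearSchedule:
    """Linearly interpolates between (x, y) breakpoints, clamped at both ends."""
    def __init__(self, x_values, y_values):
        self.x_values = x_values
        self.y_values = y_values

    def get_value(self, t):
        return float(np.interp(t, self.x_values, self.y_values))


# e.g. anneal beta from 0 to 20 over the first 100 epochs, then hold:
beta_schedule = PiecewiseLinearSchedule(x_values=[0, 100], y_values=[0.0, 20.0])
assert beta_schedule.get_value(150) == 20.0
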