def compute_log_p_log_q_log_d(model,
                              batch,
                              decoder_distribution='bernoulli',
                              num_latents_to_sample=1,
                              sampling_method='importance_sampling'):
    x_0 = ptu.from_numpy(batch["x_0"])
    data = batch["x_t"]
    imgs = ptu.from_numpy(data)
    latent_distribution_params = model.encode(imgs, x_0)
    r1 = model.latent_sizes[0]
    batch_size = data.shape[0]
    log_p = ptu.zeros((batch_size, num_latents_to_sample))
    log_q = ptu.zeros((batch_size, num_latents_to_sample))
    log_d = ptu.zeros((batch_size, num_latents_to_sample))
    true_prior = Normal(ptu.zeros((batch_size, r1)), ptu.ones(
        (batch_size, r1)))
    mus, logvars = latent_distribution_params[:2]
    for i in range(num_latents_to_sample):
        if sampling_method == 'importance_sampling':
            latents = model.rsample(latent_distribution_params[:2])
        elif sampling_method == 'biased_sampling':
            latents = model.rsample(latent_distribution_params[:2])
        elif sampling_method == 'true_prior_sampling':
            latents = true_prior.rsample()
        else:
            raise EnvironmentError('Invalid Sampling Method Provided')

        stds = logvars.exp().pow(.5)
        vae_dist = Normal(mus, stds)
        log_p_z = true_prior.log_prob(latents).sum(dim=1)
        log_q_z_given_x = vae_dist.log_prob(latents).sum(dim=1)

        if len(latent_distribution_params) == 3:  # add conditioning for CVAEs
            latents = torch.cat((latents, latent_distribution_params[2]),
                                dim=1)

        if decoder_distribution == 'bernoulli':
            decoded = model.decode(latents)[0]
            log_d_x_given_z = torch.log(imgs * decoded + (1 - imgs) *
                                        (1 - decoded) + 1e-8).sum(dim=1)
        elif decoder_distribution == 'gaussian_identity_variance':
            _, obs_distribution_params = model.decode(latents)
            dec_mu, dec_logvar = obs_distribution_params
            dec_var = dec_logvar.exp()
            decoder_dist = Normal(dec_mu, dec_var.pow(.5))
            log_d_x_given_z = decoder_dist.log_prob(imgs).sum(dim=1)
        else:
            raise EnvironmentError('Invalid Decoder Distribution Provided')

        log_p[:, i] = log_p_z
        log_q[:, i] = log_q_z_given_x
        log_d[:, i] = log_d_x_given_z
    return log_p, log_q, log_d
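A note on usage (not part of the snippet above): the three returned tensors are typically combined into an importance-sampled estimate of log p(x). A minimal sketch, with the function name invented for illustration:

import math
import torch

def estimate_log_prob(log_p, log_q, log_d):
    # log p(x) ~= -log K + logsumexp_i [ log d(x|z_i) + log p(z_i) - log q(z_i|x) ]
    num_samples = log_p.shape[1]
    return torch.logsumexp(log_d + log_p - log_q, dim=1) - math.log(num_samples)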
def compute_log_p_log_q_log_d(
    model,
    data,
    decoder_distribution='bernoulli',
    num_latents_to_sample=1,
    sampling_method='importance_sampling'
):
    assert data.dtype != np.uint8, 'images should be normalized'
    imgs = ptu.from_numpy(data)
    latent_distribution_params = model.encode(imgs)
    representation_size = model.representation_size
    batch_size = data.shape[0]
    log_p = ptu.zeros((batch_size, num_latents_to_sample))
    log_q = ptu.zeros((batch_size, num_latents_to_sample))
    log_d = ptu.zeros((batch_size, num_latents_to_sample))
    true_prior = Normal(ptu.zeros((batch_size, representation_size)),
                        ptu.ones((batch_size, representation_size)))
    mus, logvars = latent_distribution_params
    for i in range(num_latents_to_sample):
        if sampling_method == 'importance_sampling':
            latents = model.rsample(latent_distribution_params)
        elif sampling_method == 'biased_sampling':
            latents = model.rsample(latent_distribution_params)
        elif sampling_method == 'true_prior_sampling':
            latents = true_prior.rsample()
        else:
            raise EnvironmentError('Invalid Sampling Method Provided')

        stds = logvars.exp().pow(.5)
        vae_dist = Normal(mus, stds)
        log_p_z = true_prior.log_prob(latents).sum(dim=1)
        log_q_z_given_x = vae_dist.log_prob(latents).sum(dim=1)

        if decoder_distribution == 'bernoulli':
            decoded = model.decode(latents)[0]
            log_d_x_given_z = torch.log(imgs * decoded + (1 - imgs) * (1 - decoded) + 1e-8).sum(dim=1)
        elif decoder_distribution == 'gaussian_identity_variance':
            _, obs_distribution_params = model.decode(latents)
            dec_mu, dec_logvar = obs_distribution_params
            dec_var = dec_logvar.exp()
            decoder_dist = Normal(dec_mu, dec_var.pow(.5))
            log_d_x_given_z = decoder_dist.log_prob(imgs).sum(dim=1)
        else:
            raise EnvironmentError('Invalid Decoder Distribution Provided')

        log_p[:, i] = log_p_z
        log_q[:, i] = log_q_z_given_x
        log_d[:, i] = log_d_x_given_z
    return log_p, log_q, log_d
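This variant asserts that inputs are not raw uint8 images. A minimal preprocessing sketch, assuming the usual 0-255 to 0-1 scaling (the exact convention depends on how the VAE was trained):

import numpy as np

def normalize_images(raw_images):
    # Convert raw uint8 pixels to floats in [0, 1] before calling the function above.
    assert raw_images.dtype == np.uint8
    return raw_images.astype(np.float32) / 255.0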
Example #3
    def rsample(self, ):
        z = (self.normal_means + self.normal_stds *
             Normal(ptu.zeros(self.normal_means.size()),
                    ptu.ones(self.normal_stds.size())).sample())
        z.requires_grad_()
        c = self.categorical.sample()[:, :, None]
        s = torch.matmul(z, c)
        return torch.squeeze(s, 2)
Example #4
    def mean(self, ):
        """Misleading function name; this actually now samples the mean of the
        most likely component.
        c ~ argmax(C), returns mu_c

        This often computes the mode of the distribution, but not always.
        """
        c = ptu.zeros(self.weights.shape[:2])
        ind = torch.argmax(self.weights, dim=1)  # [:, 0]
        c.scatter_(1, ind, 1)
        s = torch.matmul(self.normal_means, c[:, :, None])
        return torch.squeeze(s, 2)
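Both rsample and mean above select a single mixture component per batch element by multiplying the stacked component means with a one-hot vector. A standalone shape sketch, with all sizes assumed purely for illustration:

import torch

batch, action_dim, num_gaussians = 2, 3, 4
means = torch.randn(batch, action_dim, num_gaussians)
weights = torch.rand(batch, num_gaussians, 1)

one_hot = torch.zeros(weights.shape[:2])
one_hot.scatter_(1, torch.argmax(weights, dim=1), 1)            # one-hot over components
selected = torch.matmul(means, one_hot[:, :, None]).squeeze(2)  # pick one mean per row
assert selected.shape == (batch, action_dim)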
Example #5
    def compute_density(self, data):
        orig_data_length = len(data)
        data = np.vstack([
            data for _ in range(self.n_average)
        ])
        data = ptu.from_numpy(data)
        if self.mode == 'biased':
            latents, means, log_vars, stds = (
                self.encoder.get_encoding_and_suff_stats(data)
            )
            importance_weights = ptu.ones(data.shape[0])
        elif self.mode == 'prior':
            latents = ptu.randn(len(data), self.z_dim)
            importance_weights = ptu.ones(data.shape[0])
        elif self.mode == 'importance_sampling':
            latents, means, log_vars, stds = (
                self.encoder.get_encoding_and_suff_stats(data)
            )
            prior = Normal(ptu.zeros(1), ptu.ones(1))
            prior_log_prob = prior.log_prob(latents).sum(dim=1)

            encoder_distrib = Normal(means, stds)
            encoder_log_prob = encoder_distrib.log_prob(latents).sum(dim=1)

            importance_weights = (prior_log_prob - encoder_log_prob).exp()
        else:
            raise NotImplementedError()

        unweighted_data_log_prob = self.compute_log_prob(
            data, self.decoder, latents
        ).squeeze(1)
        unweighted_data_prob = unweighted_data_log_prob.exp()
        unnormalized_data_prob = unweighted_data_prob * importance_weights
        """
        Average over `n_average`
        """
        dp_split = torch.split(unnormalized_data_prob, orig_data_length, dim=0)
        # dp_stacked.shape = (ORIG_LEN, N_AVERAGE)
        dp_stacked = torch.stack(dp_split, dim=1)
        # unnormalized_dp.shape = (ORIG_LEN,)
        unnormalized_dp = torch.sum(dp_stacked, dim=1, keepdim=False)

        """
        Compute the importance weight denominators.
        This requires summing across the `n_average` dimension.
        """
        iw_split = torch.split(importance_weights, orig_data_length, dim=0)
        iw_stacked = torch.stack(iw_split, dim=1)
        iw_denominators = iw_stacked.sum(dim=1, keepdim=False)

        final = unnormalized_dp / iw_denominators
        return ptu.get_numpy(final)
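For reference, the quantity the 'importance_sampling' branch estimates, written out as a standalone function (names assumed; the helpers above are not reused here):

import torch

def self_normalized_density(prior_log_prob, encoder_log_prob, data_log_prob):
    # p(x) ~= sum_i w_i * p(x | z_i) / sum_i w_i,  with  w_i = p(z_i) / q(z_i | x)
    weights = (prior_log_prob - encoder_log_prob).exp()
    return (weights * data_log_prob.exp()).sum() / weights.sum()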
Example #6
    def rsample(self, return_pretanh_value=False):
        """
        Sampling in the reparameterization case.
        """
        z = (self.normal_mean + self.normal_std *
             Normal(ptu.zeros(self.normal_mean.size()),
                    ptu.ones(self.normal_std.size())).sample())
        z.requires_grad_()

        if return_pretanh_value:
            return torch.tanh(z), z
        else:
            return torch.tanh(z)
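The pre-tanh value is returned because the log-density of the squashed sample needs it for the change-of-variables correction. A sketch of that correction (not necessarily the class's own log_prob implementation):

import torch
from torch.distributions import Normal

def tanh_normal_log_prob(mean, std, value, pre_tanh_value, epsilon=1e-6):
    # log p(tanh(z)) = log N(z; mean, std) - log(1 - tanh(z)^2)
    return Normal(mean, std).log_prob(pre_tanh_value) - torch.log(1 - value ** 2 + epsilon)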
Example #7
    def __init__(
            self,
            env,
            dsp,
            policy,
            classifier,
            search_buffer,

            policy_lr=1e-3,
            classifier_lr=1e-3,
            optimizer_class=optim.Adam,

            use_automatic_entropy_tuning=True,
            target_entropy=None,
    ):
        super().__init__()
        self.env = env
        self.dsp = dsp
        self.policy = policy
        self.classifier = classifier
        self.search_buffer = search_buffer

        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            if target_entropy:
                self.target_entropy = target_entropy
            else:
                self.target_entropy = -np.prod(self.env.action_space.shape).item()
            self.log_alpha = ptu.zeros(1, requires_grad=True)
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=policy_lr,
            )

        self.classifier_criterion = nn.MSELoss()
        self.dsp_optimizer = optimizer_class(
            self.dsp.parameters(),
            lr=policy_lr,
        )
        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=policy_lr,
        )
        self.classifier_optimizer = optimizer_class(
            self.classifier.parameters(),
            lr=classifier_lr,
        )

        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
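log_alpha and target_entropy above are normally driven by the standard SAC temperature loss; a sketch (the trainer's train step is not shown here):

import torch

def temperature_loss(log_alpha, log_pi, target_entropy):
    # Only log_alpha receives gradient; the policy's log-probabilities are detached.
    return -(log_alpha * (log_pi + target_entropy).detach()).mean()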
Example #8
    def __init__(self,
                 hidden_sizes,
                 obs_dim,
                 action_dim,
                 std=None,
                 init_w=1e-3,
                 min_log_std=None,
                 max_log_std=None,
                 num_gaussians=1,
                 std_architecture="shared",
                 **kwargs):
        super().__init__(
            hidden_sizes,
            input_size=obs_dim,
            output_size=action_dim * num_gaussians,
            init_w=init_w,
            # output_activation=torch.tanh,
            **kwargs)
        self.action_dim = action_dim
        self.num_gaussians = num_gaussians
        self.min_log_std = min_log_std
        self.max_log_std = max_log_std
        self.log_std = None
        self.std = std
        self.std_architecture = std_architecture
        # last_hidden_size is also needed for the mixture-weight head below,
        # even when a fixed std is provided, so compute it unconditionally.
        last_hidden_size = obs_dim
        if len(hidden_sizes) > 0:
            last_hidden_size = hidden_sizes[-1]
        if std is None:

            if self.std_architecture == "shared":
                self.last_fc_log_std = nn.Linear(last_hidden_size,
                                                 action_dim * num_gaussians)
                self.last_fc_log_std.weight.data.uniform_(-init_w, init_w)
                self.last_fc_log_std.bias.data.uniform_(-init_w, init_w)
            elif self.std_architecture == "values":
                self.log_std_logits = nn.Parameter(
                    ptu.zeros(action_dim * num_gaussians, requires_grad=True))
            else:
                raise ValueError(
                    "Unknown std_architecture: {}".format(self.std_architecture))
        else:
            self.log_std = np.log(std)
            assert LOG_SIG_MIN <= self.log_std <= LOG_SIG_MAX
        self.last_fc_weights = nn.Linear(last_hidden_size, num_gaussians)
        self.last_fc_weights.weight.data.uniform_(-init_w, init_w)
        self.last_fc_weights.bias.data.uniform_(-init_w, init_w)
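min_log_std and max_log_std above are typically applied by clamping the predicted log-std before exponentiating; a sketch under that assumption (the policy's forward pass is not shown here):

import torch

def bounded_std(log_std, min_log_std, max_log_std):
    # Clamp the raw log-std into [min_log_std, max_log_std], then exponentiate.
    return torch.exp(torch.clamp(log_std, min=min_log_std, max=max_log_std))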
Example #9
    def __init__(
        self,
        downsample_size,
        input_channels,
        num_feat_points,
        temperature=1.0,
        init_w=1e-3,
        input_size=32,
        hidden_init=ptu.fanin_init,
        output_activation=identity,
    ):
        super().__init__()

        self.downsample_size = downsample_size
        self.temperature = temperature
        self.num_feat_points = num_feat_points
        self.hidden_init = hidden_init
        self.output_activation = output_activation
        self.input_channels = input_channels
        self.input_size = input_size

        #        self.bn1 = nn.BatchNorm2d(1)
        self.conv1 = nn.Conv2d(input_channels, 48, kernel_size=5, stride=2)
        #        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(48, 48, kernel_size=5, stride=1)
        self.conv3 = nn.Conv2d(48,
                               self.num_feat_points,
                               kernel_size=5,
                               stride=1)

        test_mat = ptu.zeros(1, self.input_channels, self.input_size,
                             self.input_size)
        test_mat = self.conv1(test_mat)
        test_mat = self.conv2(test_mat)
        test_mat = self.conv3(test_mat)
        self.out_size = int(np.prod(test_mat.shape))
        self.fc1 = nn.Linear(2 * self.num_feat_points, 400)
        self.fc2 = nn.Linear(400, 300)
        self.last_fc = nn.Linear(
            300,
            self.input_channels * self.downsample_size * self.downsample_size)

        self.init_weights(init_w)
        self.i = 0
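The 2 * num_feat_points input to fc1 suggests a spatial-softmax step that reduces each of the num_feat_points feature maps to an expected (x, y) coordinate. A sketch of that operation, assumed here since the module's forward() is not shown:

import torch
import torch.nn.functional as F

def spatial_softmax_keypoints(feature_maps, temperature=1.0):
    # feature_maps: (N, C, H, W) -> (N, 2 * C) expected coordinates in [-1, 1].
    n, c, h, w = feature_maps.shape
    probs = F.softmax(feature_maps.view(n, c, h * w) / temperature, dim=-1).view(n, c, h, w)
    xs = torch.linspace(-1, 1, w)
    ys = torch.linspace(-1, 1, h)
    expected_x = (probs.sum(dim=2) * xs).sum(dim=-1)   # marginal over rows, then E[x]
    expected_y = (probs.sum(dim=3) * ys).sum(dim=-1)   # marginal over cols, then E[y]
    return torch.cat([expected_x, expected_y], dim=1)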
Example #10
    def __init__(
            self,
            env,
            policy,
            qf1,
            qf2,
            target_qf1,
            target_qf2,

            discount=0.99,
            reward_scale=1.0,

            policy_lr=1e-3,
            qf_lr=1e-3,
            optimizer_class=optim.Adam,

            soft_target_tau=1e-2,
            target_update_period=1,
            plotter=None,
            render_eval_paths=False,

            use_automatic_entropy_tuning=True,
            target_entropy=None,
    ):
        super().__init__()
        self.env = env
        self.policy = policy
        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period

        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            if target_entropy:
                self.target_entropy = target_entropy
            else:
                self.target_entropy = -np.prod(self.env.action_space.shape).item()  # heuristic value from Tuomas
            self.log_alpha = ptu.zeros(1, requires_grad=True)
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=policy_lr,
            )

        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            lr=qf_lr,
        )

        self.discount = discount
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
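soft_target_tau and target_update_period above control a Polyak update of the target networks; a standalone sketch of that update:

def soft_update(source, target, tau):
    # target <- tau * source + (1 - tau) * target, parameter by parameter.
    for param, target_param in zip(source.parameters(), target.parameters()):
        target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)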
Example #11
    def __init__(
            self,
            env,
            policy,
            qf1,
            qf2,
            target_qf1,
            target_qf2,
            behavior_policy=None,
            dim_mult=1,
            discount=0.99,
            reward_scale=1.0,
            policy_lr=1e-3,
            qf_lr=1e-3,
            optimizer_class=optim.Adam,
            soft_target_tau=1e-2,
            target_update_period=1,
            plotter=None,
            render_eval_paths=False,
            use_automatic_entropy_tuning=True,
            target_entropy=None,
            use_target_nets=True,
            policy_eval_start=0,
            num_qs=2,

            ## For min_Q runs
            with_min_q=False,
            new_min_q=False,
            min_q_version=0,
            temp=1.0,
            hinge_bellman=False,
            use_projected_grad=False,
            normalize_magnitudes=False,
            regress_constant=False,
            min_q_weight=1.0,
            data_subtract=True,

            ## sort of backup
            max_q_backup=False,
            deterministic_backup=False,
            num_random=4,

            ## handle lagrange
            with_lagrange=False,
            lagrange_thresh=10.0,

            ## Handling discrete actions
            discrete=False,
            *args,
            **kwargs):
        super().__init__()
        self.env = env
        self.policy = policy
        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period

        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            if target_entropy:
                self.target_entropy = target_entropy
            else:
                if self.env is None:
                    self.target_entropy = -2
                else:
                    self.target_entropy = -np.prod(
                        self.env.action_space.shape).item(
                        )  # heuristic value from Tuomas
            self.log_alpha = ptu.zeros(1, requires_grad=True)
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=policy_lr,
            )

        self.with_lagrange = with_lagrange
        if self.with_lagrange:
            self.target_action_gap = lagrange_thresh
            self.log_alpha_prime = ptu.zeros(1, requires_grad=True)
            self.alpha_prime_optimizer = optimizer_class(
                [
                    self.log_alpha_prime,
                ],
                lr=qf_lr,
            )

        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            lr=qf_lr,
        )

        self.discount = discount
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
        self._use_target_nets = use_target_nets
        self.policy_eval_start = policy_eval_start

        self._current_epoch = 0
        self._policy_update_ctr = 0
        self._num_q_update_steps = 0
        self._num_policy_update_steps = 0
        self.policy_eval_start = policy_eval_start
        self._num_policy_steps = 1

        if not self._use_target_nets:
            self.target_qf1 = qf1
            self.target_qf2 = qf2

        self.softmax = torch.nn.Softmax(dim=-1)
        self.num_qs = num_qs

        ## min Q
        self.with_min_q = with_min_q
        self.new_min_q = new_min_q
        self.temp = temp
        self.min_q_version = min_q_version
        self.use_projected_grad = use_projected_grad
        self.normalize_magnitudes = normalize_magnitudes
        self.regress_constant = regress_constant
        self.min_q_weight = min_q_weight
        self.softmax = torch.nn.Softmax(dim=1)
        self.hinge_bellman = hinge_bellman
        self.softplus = torch.nn.Softplus(beta=self.temp, threshold=20)
        self.data_subtract = data_subtract

        self.max_q_backup = max_q_backup
        self.deterministic_backup = deterministic_backup
        self.num_random = num_random

        self.discrete = discrete
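A sketch of how the Lagrange variables above (log_alpha_prime, target_action_gap) are typically used to scale the conservative penalty; the actual train step is not shown here, so treat names and signs as assumptions:

import torch

def lagrange_scaled_penalty(log_alpha_prime, conservative_gap, target_action_gap):
    # alpha' scales the penalty; its own loss pushes alpha' up when the gap
    # exceeds the threshold and down otherwise.
    alpha_prime = torch.clamp(log_alpha_prime.exp(), min=0.0, max=1e6)
    scaled_penalty = alpha_prime * (conservative_gap - target_action_gap)
    alpha_prime_loss = -scaled_penalty.mean()
    return scaled_penalty, alpha_prime_loss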
Example #12
    def __init__(
            self,
            env,
            policy,
            qf1,
            qf2,
            target_qf1,
            target_qf2,
            buffer_policy=None,

            discount=0.99,
            reward_scale=1.0,
            beta=1.0,
            beta_schedule_kwargs=None,

            policy_lr=1e-3,
            qf_lr=1e-3,
            policy_weight_decay=0,
            q_weight_decay=0,
            optimizer_class=optim.Adam,

            soft_target_tau=1e-2,
            target_update_period=1,
            plotter=None,
            render_eval_paths=False,

            use_automatic_entropy_tuning=True,
            target_entropy=None,

            bc_num_pretrain_steps=0,
            q_num_pretrain1_steps=0,
            q_num_pretrain2_steps=0,
            bc_batch_size=128,
            bc_loss_type="mle",
            awr_loss_type="mle",
            save_bc_policies=0,
            alpha=1.0,

            policy_update_period=1,
            q_update_period=1,

            weight_loss=True,
            compute_bc=True,

            bc_weight=0.0,
            rl_weight=1.0,
            use_awr_update=True,
            use_reparam_update=False,
            reparam_weight=1.0,
            awr_weight=1.0,
            post_pretrain_hyperparams=None,
            post_bc_pretrain_hyperparams=None,

            awr_use_mle_for_vf=False,
            awr_sample_actions=False,
            awr_min_q=False,

            reward_transform_class=None,
            reward_transform_kwargs=None,
            terminal_transform_class=None,
            terminal_transform_kwargs=None,

            pretraining_env_logging_period=100000,
            pretraining_logging_period=1000,
            do_pretrain_rollouts=False,

            train_bc_on_rl_buffer=False,
            use_automatic_beta_tuning=False,
            beta_epsilon=1e-10,
    ):
        super().__init__()
        self.env = env
        self.policy = policy
        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        self.buffer_policy = buffer_policy
        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period

        self.use_awr_update = use_awr_update
        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            if target_entropy:
                self.target_entropy = target_entropy
            else:
                self.target_entropy = -np.prod(self.env.action_space.shape).item()  # heuristic value from Tuomas
            self.log_alpha = ptu.zeros(1, requires_grad=True)
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=policy_lr,
            )

        self.awr_use_mle_for_vf = awr_use_mle_for_vf
        self.awr_sample_actions = awr_sample_actions
        self.awr_min_q = awr_min_q

        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            weight_decay=policy_weight_decay,
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            weight_decay=q_weight_decay,
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            weight_decay=q_weight_decay,
            lr=qf_lr,
        )

        if buffer_policy and train_bc_on_rl_buffer:
            self.buffer_policy_optimizer = optimizer_class(
                self.buffer_policy.parameters(),
                weight_decay=policy_weight_decay,
                lr=policy_lr,
            )

        self.use_automatic_beta_tuning = (
            use_automatic_beta_tuning and buffer_policy and train_bc_on_rl_buffer
        )
        self.beta_epsilon = beta_epsilon
        if self.use_automatic_beta_tuning:
            self.log_beta = ptu.zeros(1, requires_grad=True)
            self.beta_optimizer = optimizer_class(
                [self.log_beta],
                lr=policy_lr,
            )
        else:
            self.beta = beta
            self.beta_schedule_kwargs = beta_schedule_kwargs
            if beta_schedule_kwargs is None:
                self.beta_schedule = ConstantSchedule(beta)
            else:
                schedule_class = beta_schedule_kwargs.pop("schedule_class", PiecewiseLinearSchedule)
                self.beta_schedule = schedule_class(**beta_schedule_kwargs)

        self.discount = discount
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True

        self.bc_num_pretrain_steps = bc_num_pretrain_steps
        self.q_num_pretrain1_steps = q_num_pretrain1_steps
        self.q_num_pretrain2_steps = q_num_pretrain2_steps
        self.bc_batch_size = bc_batch_size
        self.bc_loss_type = bc_loss_type
        self.awr_loss_type = awr_loss_type
        self.rl_weight = rl_weight
        self.bc_weight = bc_weight
        self.save_bc_policies = save_bc_policies
        self.eval_policy = MakeDeterministic(self.policy)
        self.compute_bc = compute_bc
        self.alpha = alpha
        self.q_update_period = q_update_period
        self.policy_update_period = policy_update_period
        self.weight_loss = weight_loss

        self.reparam_weight = reparam_weight
        self.awr_weight = awr_weight
        self.post_pretrain_hyperparams = post_pretrain_hyperparams
        self.post_bc_pretrain_hyperparams = post_bc_pretrain_hyperparams
        self.update_policy = True
        self.pretraining_env_logging_period = pretraining_env_logging_period
        self.pretraining_logging_period = pretraining_logging_period
        self.do_pretrain_rollouts = do_pretrain_rollouts

        self.reward_transform_class = reward_transform_class or LinearTransform
        self.reward_transform_kwargs = reward_transform_kwargs or dict(m=1, b=0)
        self.terminal_transform_class = terminal_transform_class or LinearTransform
        self.terminal_transform_kwargs = terminal_transform_kwargs or dict(m=1, b=0)
        self.reward_transform = self.reward_transform_class(**self.reward_transform_kwargs)
        self.terminal_transform = self.terminal_transform_class(**self.terminal_transform_kwargs)
        self.use_reparam_update = use_reparam_update

        self.train_bc_on_rl_buffer = train_bc_on_rl_buffer and buffer_policy
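beta (or the learned log_beta) above typically enters the training loss through exponential advantage weights; one common form, shown as an assumption since the train step is not included here:

import torch

def advantage_weights(advantage, beta, max_weight=20.0):
    # w = exp(A / beta), clipped for numerical stability.
    return torch.exp(advantage / beta).clamp(max=max_weight)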
Example #13
    def __init__(self,
                 env,
                 policy,
                 qf1,
                 qf2,
                 vf,
                 policy_lr=1e-3,
                 qf_lr=1e-3,
                 vf_lr=1e-3,
                 policy_mean_reg_weight=1e-3,
                 policy_std_reg_weight=1e-3,
                 policy_pre_activation_weight=0.,
                 optimizer_class=optim.Adam,
                 train_policy_with_reparameterization=True,
                 soft_target_tau=1e-3,
                 policy_update_period=1,
                 target_update_period=1,
                 plotter=None,
                 render_eval_paths=False,
                 eval_deterministic=True,
                 eval_policy=None,
                 exploration_policy=None,
                 use_automatic_entropy_tuning=True,
                 target_entropy=None,
                 **kwargs):
        if eval_policy is None:
            if eval_deterministic:
                eval_policy = MakeDeterministic(policy)
            else:
                eval_policy = policy
        super().__init__(env=env,
                         exploration_policy=exploration_policy or policy,
                         eval_policy=eval_policy,
                         **kwargs)
        self.policy = policy
        self.qf1 = qf1
        self.qf2 = qf2
        self.vf = vf
        self.soft_target_tau = soft_target_tau
        self.policy_update_period = policy_update_period
        self.target_update_period = target_update_period
        self.policy_mean_reg_weight = policy_mean_reg_weight
        self.policy_std_reg_weight = policy_std_reg_weight
        self.policy_pre_activation_weight = policy_pre_activation_weight
        self.train_policy_with_reparameterization = (
            train_policy_with_reparameterization)

        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            if target_entropy:
                self.target_entropy = target_entropy
            else:
                self.target_entropy = -np.prod(
                    self.env.action_space.shape).item(
                    )  # heuristic value from Tuomas
            self.log_alpha = ptu.Variable(ptu.zeros(1), requires_grad=True)
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=policy_lr,
            )

        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.target_vf = vf.copy()
        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            lr=qf_lr,
        )
        self.vf_optimizer = optimizer_class(
            self.vf.parameters(),
            lr=vf_lr,
        )
Example #14
    def __init__(
        self,
        env,
        policy,
        qf1,
        qf2,
        target_qf1,
        target_qf2,
        vae,
        discount=0.99,
        reward_scale=1.0,
        policy_lr=1e-3,
        qf_lr=1e-3,
        optimizer_class=optim.Adam,
        soft_target_tau=1e-2,
        target_update_period=1,
        plotter=None,
        render_eval_paths=False,

        # BEAR specific params
        mode='auto',
        kernel_choice='laplacian',
        policy_update_style=0,
        mmd_sigma=10.0,
        target_mmd_thresh=0.05,
        num_samples_mmd_match=4,
        with_grad_penalty_v1=False,
        with_grad_penalty_v2=False,
        grad_coefficient_policy=0.0,
        grad_coefficient_q=0.0,
        use_target_nets=False,
        policy_update_delay=100,
        start_epoch_grad_penalty=0,
        num_steps_policy_update_only=50,
        bc_pretrain_steps=20000,
        target_update_method='default',
        use_adv_weighting=False,
        positive_reward=False,
    ):
        super().__init__()
        self.env = env
        self.policy = policy
        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        self.vae = vae
        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period

        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            lr=qf_lr,
        )
        self.vae_optimizer = optimizer_class(
            self.vae.parameters(),
            lr=3e-4,
        )

        self.mode = mode
        if self.mode == 'auto':
            self.log_alpha = ptu.zeros(1, requires_grad=True)
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=1e-3,
            )
        self.mmd_sigma = mmd_sigma
        self.kernel_choice = kernel_choice
        self.num_samples_mmd_match = num_samples_mmd_match
        self.policy_update_style = policy_update_style
        self.target_mmd_thresh = target_mmd_thresh

        self.discount = discount
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
        self._with_gradient_penalty_v1 = with_grad_penalty_v1
        self._with_gradient_penalty_v2 = with_grad_penalty_v2
        self._grad_coefficient_q = grad_coefficient_q
        self._grad_coefficient_policy = grad_coefficient_policy
        self._use_target_nets = use_target_nets
        self._policy_delay_update = policy_update_delay
        self._num_policy_steps = num_steps_policy_update_only
        self._start_epoch_grad_penalty = start_epoch_grad_penalty
        self._bc_pretrain_steps = bc_pretrain_steps
        self._target_update_method = target_update_method
        self._use_adv_weighting = use_adv_weighting
        self._positive_reward = positive_reward

        if self._target_update_method == 'distillation':
            self.target_qf1_opt = optimizer_class(self.target_qf1.parameters(),
                                                  lr=qf_lr)
            self.target_qf2_opt = optimizer_class(self.target_qf2.parameters(),
                                                  lr=qf_lr)

        self._current_epoch = 0
        self._policy_update_ctr = 0
        self._num_q_update_steps = 0
        self._num_policy_update_steps = 0

        if not self._use_target_nets:
            self.target_qf1 = qf1
            self.target_qf2 = qf2
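kernel_choice, mmd_sigma and target_mmd_thresh above refer to an MMD constraint between policy actions and behavior (VAE) actions. A standalone sketch of an MMD estimate with a Laplacian kernel; BEAR's own kernel code is not shown here, so the constants and shapes are assumptions:

import torch

def laplacian_mmd(x, y, sigma=10.0):
    # x, y: (batch, num_samples, action_dim) action samples from two policies.
    def kernel(a, b):
        diff = a.unsqueeze(2) - b.unsqueeze(1)           # (batch, n, m, action_dim)
        return torch.exp(-diff.abs().sum(dim=-1) / sigma)
    mmd_sq = (kernel(x, x).mean(dim=(1, 2))
              - 2 * kernel(x, y).mean(dim=(1, 2))
              + kernel(y, y).mean(dim=(1, 2)))
    return mmd_sq.clamp(min=1e-6).sqrt()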