Example #1
    def forward_impl(self, inputs, outputs):
        x = inputs[0].data   # input batch
        m = inputs[1].data   # running minimum, updated in place
        M = inputs[2].data   # running maximum, updated in place
        y = outputs[0].data
        y.copy_from(x)       # pass the input through unchanged

        if not self.training:
            return
        # fold the batch min/max into the running statistics
        mb = F.min(x, keepdims=True)
        Mb = F.max(x, keepdims=True)
        F.minimum2(m, mb, outputs=[m])
        F.maximum2(M, Mb, outputs=[M])
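For reference, here is a minimal stand-alone sketch (not taken from the repository above) of the element-wise behaviour of F.minimum2 / F.maximum2 that this forward pass relies on:

import numpy as np
import nnabla as nn
import nnabla.functions as F

a = nn.Variable.from_numpy_array(np.array([1.0, 5.0, -2.0]))
b = nn.Variable.from_numpy_array(np.array([3.0, 0.0, -1.0]))
lo = F.minimum2(a, b)  # element-wise minimum
hi = F.maximum2(a, b)  # element-wise maximum
F.sink(lo, hi).forward()
print(lo.d)  # [ 1.  0. -2.]
print(hi.d)  # [ 3.  5. -1.]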
Example #2
    def chamfer_hausdorff_oneside_dists(X0, X1):
        # One-sided Chamfer distance (mean over X0 of the distance to the
        # nearest point in X1) and one-sided Hausdorff distance (the max of
        # the same), computed block-wise to bound memory use.
        b0 = X0.shape[0]
        b1 = X1.shape[0]

        sum_ = 0
        max_ = nn.NdArray.from_numpy_array(np.array(-np.inf))
        n = 0
        for i in tqdm.tqdm(range(0, b0, sub_batch_size),
                           desc="cdist-outer-loop"):
            x0 = nn.NdArray.from_numpy_array(X0[i:i + sub_batch_size])
            norm_x0 = F.sum(x0**2.0, axis=1, keepdims=True)
            min_ = nn.NdArray.from_numpy_array(np.ones(x0.shape[0]) * np.inf)
            for j in tqdm.tqdm(range(0, b1, sub_batch_size),
                               desc="cdist-inner-loop"):
                x1 = nn.NdArray.from_numpy_array(X1[j:j + sub_batch_size])
                # block of pairwise Euclidean distances between x0 and x1
                norm_x1 = F.transpose(F.sum(x1**2.0, axis=1, keepdims=True),
                                      (1, 0))
                x1_T = F.transpose(x1, (1, 0))
                x01 = F.affine(x0, x1_T)
                bpwd = (norm_x0 + norm_x1 - 2.0 * x01)**0.5
                # block min
                min_ = F.minimum2(min_, F.min(bpwd, axis=1))
            # accumulate the sum and the max of the row-wise minima
            sum_ += F.sum(min_)
            n += x0.shape[0]
            max_ = F.maximum2(max_, F.max(min_))
            max_ = F.maximum2(max_, F.max(min_))
        ocd = sum_.data / n
        ohd = max_.data
        return ocd, ohd
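The helper above is one-sided. A hypothetical sibling function (not part of the original snippet, and assuming it is defined in the same scope) could symmetrize the result into the usual Chamfer and Hausdorff distances; the 0.5 averaging is one common convention:

    def chamfer_hausdorff_dists(X0, X1):
        # symmetrize the one-sided distances in both directions
        ocd01, ohd01 = chamfer_hausdorff_oneside_dists(X0, X1)
        ocd10, ohd10 = chamfer_hausdorff_oneside_dists(X1, X0)
        cd = 0.5 * (ocd01 + ocd10)   # symmetric Chamfer distance
        hd = max(ohd01, ohd10)       # symmetric Hausdorff distance
        return cd, hd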
Example #3
def clip_quant_vals():
    # Clamp the learnable quantization parameters (dynamic range, bit-width,
    # step size, xmin/xmax) to the bounds given by the global config `cfg`.
    p = nn.get_parameters()
    if cfg.w_quantize in [
            'parametric_fp_b_xmax', 'parametric_fp_d_xmax',
            'parametric_fp_d_b', 'parametric_pow2_b_xmax',
            'parametric_pow2_b_xmin', 'parametric_pow2_xmin_xmax'
    ]:
        for k in p:
            if 'Wquant' in k.split('/') or 'bquant' in k.split('/'):
                if k.endswith('/m'):  # range
                    p[k].data = clip_scalar(p[k].data,
                                            cfg.w_dynrange_min + 1e-5,
                                            cfg.w_dynrange_max - 1e-5)
                elif k.endswith('/n'):  # bits
                    p[k].data = clip_scalar(p[k].data,
                                            cfg.w_bitwidth_min + 1e-5,
                                            cfg.w_bitwidth_max - 1e-5)
                elif k.endswith('/d'):  # delta
                    if cfg.w_quantize == 'parametric_fp_d_xmax':
                        g = k.replace('/d', '/xmax')
                        min_value = F.minimum2(p[k].data, p[g].data - 1e-5)
                        max_value = F.maximum2(p[k].data + 1e-5, p[g].data)
                        p[k].data = min_value
                        p[g].data = max_value
                    p[k].data = clip_scalar(p[k].data,
                                            cfg.w_stepsize_min + 1e-5,
                                            cfg.w_stepsize_max - 1e-5)
                elif k.endswith('/xmin'):  # xmin
                    if cfg.w_quantize == 'parametric_pow2_xmin_xmax':
                        g = k.replace('/xmin', '/xmax')
                        min_value = F.minimum2(p[k].data, p[g].data - 1e-5)
                        max_value = F.maximum2(p[k].data + 1e-5, p[g].data)
                        p[k].data = min_value
                        p[g].data = max_value
                    p[k].data = clip_scalar(p[k].data, cfg.w_xmin_min + 1e-5,
                                            cfg.w_xmin_max - 1e-5)
                elif k.endswith('/xmax'):  # xmax
                    p[k].data = clip_scalar(p[k].data, cfg.w_xmax_min + 1e-5,
                                            cfg.w_xmax_max - 1e-5)

    if cfg.a_quantize in [
            'parametric_fp_b_xmax_relu', 'parametric_fp_d_xmax_relu',
            'parametric_fp_d_b_relu', 'parametric_pow2_b_xmax_relu',
            'parametric_pow2_b_xmin_relu', 'parametric_pow2_xmin_xmax_relu'
    ]:
        for k in p:
            if 'Aquant' in k.split('/'):
                if k.endswith('/m'):  # range
                    p[k].data = clip_scalar(p[k].data,
                                            cfg.a_dynrange_min + 1e-5,
                                            cfg.a_dynrange_max - 1e-5)
                elif k.endswith('/n'):  # bits
                    p[k].data = clip_scalar(p[k].data,
                                            cfg.a_bitwidth_min + 1e-5,
                                            cfg.a_bitwidth_max - 1e-5)
                elif k.endswith('/d'):  # delta
                    if cfg.a_quantize == 'parametric_fp_d_xmax_relu':
                        g = k.replace('/d', '/xmax')
                        min_value = F.minimum2(p[k].data, p[g].data - 1e-5)
                        max_value = F.maximum2(p[k].data + 1e-5, p[g].data)
                        p[k].data = min_value
                        p[g].data = max_value
                    p[k].data = clip_scalar(p[k].data,
                                            cfg.a_stepsize_min + 1e-5,
                                            cfg.a_stepsize_max - 1e-5)
                elif k.endswith('/xmin'):  # xmin
                    if cfg.a_quantize == 'parametric_pow2_xmin_xmax_relu':
                        g = k.replace('/xmin', '/xmax')
                        min_value = F.minimum2(p[k].data, p[g].data - 1e-5)
                        max_value = F.maximum2(p[k].data + 1e-5, p[g].data)
                        p[k].data = min_value
                        p[g].data = max_value
                    p[k].data = clip_scalar(p[k].data, cfg.a_xmin_min + 1e-5,
                                            cfg.a_xmin_max - 1e-5)
                elif k.endswith('/xmax'):  # xmax
                    p[k].data = clip_scalar(p[k].data, cfg.a_xmax_min + 1e-5,
                                            cfg.a_xmax_max - 1e-5)
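This function assumes a global cfg object and a clip_scalar helper defined elsewhere in the file. A plausible definition of the helper (an assumption, not shown in the excerpt) is an element-wise clamp built from nnabla's scalar min/max functions:

import nnabla.functions as F

def clip_scalar(v, min_value, max_value):
    # clamp every element of v into [min_value, max_value]
    return F.minimum_scalar(F.maximum_scalar(v, min_value), max_value)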
Example #4
    def _build(self):
        # inference
        self.infer_obs_t = nn.Variable((1,) + self.obs_shape)
        with nn.parameter_scope('trainable'):
            self.infer_policy_t = policy_network(self.infer_obs_t,
                                                 self.action_size, 'actor')

        # training
        self.obss_t = nn.Variable((self.batch_size,) + self.obs_shape)
        self.acts_t = nn.Variable((self.batch_size, self.action_size))
        self.rews_tp1 = nn.Variable((self.batch_size, 1))
        self.obss_tp1 = nn.Variable((self.batch_size,) + self.obs_shape)
        self.ters_tp1 = nn.Variable((self.batch_size, 1))

        # critic loss
        with nn.parameter_scope('trainable'):
            # critic functions
            q1_t = q_network(self.obss_t, self.acts_t, 'critic/1')
            q2_t = q_network(self.obss_t, self.acts_t, 'critic/2')
        with nn.parameter_scope('target'):
            # target functions
            policy_tp1 = policy_network(self.obss_tp1, self.action_size,
                                        'actor')
            smoothed_target = _smoothing_target(policy_tp1,
                                                self.target_reg_sigma,
                                                self.target_reg_clip)
            q1_tp1 = q_network(self.obss_tp1, smoothed_target, 'critic/1')
            q2_tp1 = q_network(self.obss_tp1, smoothed_target, 'critic/2')
        q_tp1 = F.minimum2(q1_tp1, q2_tp1)
        y = self.rews_tp1 + self.gamma * q_tp1 * (1.0 - self.ters_tp1)
        td1 = F.mean(F.squared_error(q1_t, y))
        td2 = F.mean(F.squared_error(q2_t, y))
        self.critic_loss = td1 + td2

        # actor loss
        with nn.parameter_scope('trainable'):
            policy_t = policy_network(self.obss_t, self.action_size, 'actor')
            q1_t_with_actor = q_network(self.obss_t, policy_t, 'critic/1')
            q2_t_with_actor = q_network(self.obss_t, policy_t, 'critic/2')
        q_t_with_actor = F.minimum2(q1_t_with_actor, q2_t_with_actor)
        self.actor_loss = -F.mean(q_t_with_actor)

        # get neural network parameters
        with nn.parameter_scope('trainable'):
            with nn.parameter_scope('critic'):
                critic_params = nn.get_parameters()
            with nn.parameter_scope('actor'):
                actor_params = nn.get_parameters()

        # setup optimizers
        self.critic_solver = S.Adam(self.critic_lr)
        self.critic_solver.set_parameters(critic_params)
        self.actor_solver = S.Adam(self.actor_lr)
        self.actor_solver.set_parameters(actor_params)

        with nn.parameter_scope('trainable'):
            trainable_params = nn.get_parameters()
        with nn.parameter_scope('target'):
            target_params = nn.get_parameters()

        # build target update
        update_targets = []
        sync_targets = []
        for key, src in trainable_params.items():
            dst = target_params[key]
            updated_dst = (1.0 - self.tau) * dst + self.tau * src
            update_targets.append(F.assign(dst, updated_dst))
            sync_targets.append(F.assign(dst, src))
        self.update_target_expr = F.sink(*update_targets)
        self.sync_target_expr = F.sink(*sync_targets)
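The two sink expressions built at the end are meant to be executed rather than differentiated. A hypothetical pair of helper methods (names are assumptions, not part of the snippet) would run them once after building the graphs and after every gradient step, respectively:

    def sync_target(self):
        # hard copy: target <- trainable, call once at initialization
        self.sync_target_expr.forward()

    def update_target(self):
        # soft update: target <- (1 - tau) * target + tau * trainable
        self.update_target_expr.forward()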
Example #5
    def _build(self):
        # inference graph
        self.infer_obs_t = nn.Variable((1, ) + self.obs_shape)
        with nn.parameter_scope('trainable'):
            infer_dist = policy_network(self.infer_obs_t, self.action_size,
                                        'actor')
        self.infer_act_t, _ = _squash_action(infer_dist)
        self.deterministic_act_t = infer_dist.mean()

        # training graph
        self.obss_t = nn.Variable((self.batch_size, ) + self.obs_shape)
        self.acts_t = nn.Variable((self.batch_size, self.action_size))
        self.rews_tp1 = nn.Variable((self.batch_size, 1))
        self.obss_tp1 = nn.Variable((self.batch_size, ) + self.obs_shape)
        self.ters_tp1 = nn.Variable((self.batch_size, 1))

        with nn.parameter_scope('trainable'):
            self.log_temp = get_parameter_or_create('temp', [1, 1],
                                                    ConstantInitializer(0.0))
            dist_t = policy_network(self.obss_t, self.action_size, 'actor')
            dist_tp1 = policy_network(self.obss_tp1, self.action_size, 'actor')
            squashed_act_t, log_prob_t = _squash_action(dist_t)
            squashed_act_tp1, log_prob_tp1 = _squash_action(dist_tp1)
            q1_t = q_network(self.obss_t, self.acts_t, 'critic/1')
            q2_t = q_network(self.obss_t, self.acts_t, 'critic/2')
            q1_t_with_actor = q_network(self.obss_t, squashed_act_t,
                                        'critic/1')
            q2_t_with_actor = q_network(self.obss_t, squashed_act_t,
                                        'critic/2')

        with nn.parameter_scope('target'):
            q1_tp1 = q_network(self.obss_tp1, squashed_act_tp1, 'critic/1')
            q2_tp1 = q_network(self.obss_tp1, squashed_act_tp1, 'critic/2')

        # q function loss
        q_tp1 = F.minimum2(q1_tp1, q2_tp1)
        entropy_tp1 = F.exp(self.log_temp) * log_prob_tp1
        mask = (1.0 - self.ters_tp1)
        q_target = self.rews_tp1 + self.gamma * (q_tp1 - entropy_tp1) * mask
        q_target.need_grad = False
        q1_loss = 0.5 * F.mean(F.squared_error(q1_t, q_target))
        q2_loss = 0.5 * F.mean(F.squared_error(q2_t, q_target))
        self.critic_loss = q1_loss + q2_loss

        # policy function loss
        q_t = F.minimum2(q1_t_with_actor, q2_t_with_actor)
        entropy_t = F.exp(self.log_temp) * log_prob_t
        self.actor_loss = F.mean(entropy_t - q_t)

        # temperature loss
        temp_target = log_prob_t - self.action_size
        temp_target.need_grad = False
        self.temp_loss = -F.mean(F.exp(self.log_temp) * temp_target)

        # trainable parameters
        with nn.parameter_scope('trainable'):
            with nn.parameter_scope('critic'):
                critic_params = nn.get_parameters()
            with nn.parameter_scope('actor'):
                actor_params = nn.get_parameters()
        # target parameters
        with nn.parameter_scope('target/critic'):
            target_params = nn.get_parameters()

        # target update
        update_targets = []
        sync_targets = []
        for key, src in critic_params.items():
            dst = target_params[key]
            updated_dst = (1.0 - self.tau) * dst + self.tau * src
            update_targets.append(F.assign(dst, updated_dst))
            sync_targets.append(F.assign(dst, src))
        self.update_target_expr = F.sink(*update_targets)
        self.sync_target_expr = F.sink(*sync_targets)

        # setup solvers
        self.critic_solver = S.Adam(self.critic_lr)
        self.critic_solver.set_parameters(critic_params)
        self.actor_solver = S.Adam(self.actor_lr)
        self.actor_solver.set_parameters(actor_params)
        self.temp_solver = S.Adam(self.temp_lr)
        self.temp_solver.set_parameters({'temp': self.log_temp})
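A hypothetical training step over the graph built above (the method name and feeding pattern are assumptions, not part of the snippet) would fill the placeholder variables and then take one gradient step per loss with its matching solver:

    def _train_step(self, obss_t, acts_t, rews_tp1, obss_tp1, ters_tp1):
        # feed one mini-batch into the placeholders declared in _build
        self.obss_t.d = obss_t
        self.acts_t.d = acts_t
        self.rews_tp1.d = rews_tp1
        self.obss_tp1.d = obss_tp1
        self.ters_tp1.d = ters_tp1
        # one update per loss / solver pair
        for loss, solver in ((self.critic_loss, self.critic_solver),
                             (self.actor_loss, self.actor_solver),
                             (self.temp_loss, self.temp_solver)):
            loss.forward()
            solver.zero_grad()
            loss.backward()
            solver.update()
        # soft-update the target critics
        self.update_target_expr.forward()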
Example #6
    def _build(self):
        # inference graph
        self.infer_obs_t = nn.Variable((1, ) + self.obs_shape)
        with nn.parameter_scope('trainable'):
            infer_dist = policy_network(self.infer_obs_t, self.action_size,
                                        'actor')
        self.infer_act_t, _ = _squash_action(infer_dist)
        self.deterministic_act_t = infer_dist.mean()

        # training graph
        self.obss_t = nn.Variable((self.batch_size, ) + self.obs_shape)
        self.acts_t = nn.Variable((self.batch_size, self.action_size))
        self.rews_tp1 = nn.Variable((self.batch_size, 1))
        self.obss_tp1 = nn.Variable((self.batch_size, ) + self.obs_shape)
        self.ters_tp1 = nn.Variable((self.batch_size, 1))

        with nn.parameter_scope('trainable'):
            dist = policy_network(self.obss_t, self.action_size, 'actor')
            squashed_act_t, log_prob_t = _squash_action(dist)
            v_t = v_network(self.obss_t, 'value')
            q_t1 = q_network(self.obss_t, self.acts_t, 'critic/1')
            q_t2 = q_network(self.obss_t, self.acts_t, 'critic/2')
            q_t1_with_actor = q_network(self.obss_t, squashed_act_t,
                                        'critic/1')
            q_t2_with_actor = q_network(self.obss_t, squashed_act_t,
                                        'critic/2')

        with nn.parameter_scope('target'):
            v_tp1 = v_network(self.obss_tp1, 'value')

        # value loss
        q_t = F.minimum2(q_t1_with_actor, q_t2_with_actor)
        v_target = q_t - log_prob_t
        v_target.need_grad = False
        self.value_loss = 0.5 * F.mean(F.squared_error(v_t, v_target))

        # q function loss
        scaled_rews_tp1 = self.rews_tp1 * self.reward_scale
        q_target = scaled_rews_tp1 + self.gamma * v_tp1 * (1.0 - self.ters_tp1)
        q_target.need_grad = False
        q1_loss = 0.5 * F.mean(F.squared_error(q_t1, q_target))
        q2_loss = 0.5 * F.mean(F.squared_error(q_t2, q_target))
        self.critic_loss = q1_loss + q2_loss

        # policy function loss
        mean_loss = 0.5 * F.mean(dist.mean()**2)
        logstd_loss = 0.5 * F.mean(F.log(dist.stddev())**2)
        policy_reg_loss = self.policy_reg * (mean_loss + logstd_loss)
        self.objective_loss = F.mean(log_prob_t - q_t)
        self.actor_loss = self.objective_loss + policy_reg_loss

        # trainable parameters
        with nn.parameter_scope('trainable'):
            with nn.parameter_scope('value'):
                value_params = nn.get_parameters()
            with nn.parameter_scope('critic'):
                critic_params = nn.get_parameters()
            with nn.parameter_scope('actor'):
                actor_params = nn.get_parameters()
        # target parameters
        with nn.parameter_scope('target/value'):
            target_params = nn.get_parameters()

        # target update
        update_targets = []
        sync_targets = []
        for key, src in value_params.items():
            dst = target_params[key]
            updated_dst = (1.0 - self.tau) * dst + self.tau * src
            update_targets.append(F.assign(dst, updated_dst))
            sync_targets.append(F.assign(dst, src))
        self.update_target_expr = F.sink(*update_targets)
        self.sync_target_expr = F.sink(*sync_targets)

        # setup solvers
        self.value_solver = S.Adam(self.value_lr)
        self.value_solver.set_parameters(value_params)
        self.critic_solver = S.Adam(self.critic_lr)
        self.critic_solver.set_parameters(critic_params)
        self.actor_solver = S.Adam(self.actor_lr)
        self.actor_solver.set_parameters(actor_params)
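A hypothetical inference helper for this agent (the method name is an assumption, and numpy is assumed to be imported as np in the file) would feed a single observation into infer_obs_t and read back either the squashed stochastic action or the deterministic mean action:

    def _act(self, obs, deterministic=False):
        # obs: a single observation with shape self.obs_shape
        self.infer_obs_t.d = np.expand_dims(obs, axis=0)
        act = self.deterministic_act_t if deterministic else self.infer_act_t
        act.forward()
        return act.d[0]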