def setup_param_noise(self, normalized_obs0):
        assert self.param_noise is not None

        # Configure perturbed actor.
        param_noise_actor = copy(self.actor)
        param_noise_actor.name = 'param_noise_actor'
        self.perturbed_actor_tf = param_noise_actor(normalized_obs0)
        logger.info('setting up param noise')
        self.perturb_policy_ops = get_perturbed_actor_updates(
            self.actor, param_noise_actor, self.param_noise_stddev)

        # Configure a separate copy for stddev adaptation.
        adaptive_param_noise_actor = copy(self.actor)
        adaptive_param_noise_actor.name = 'adaptive_param_noise_actor'
        adaptive_actor_tf = adaptive_param_noise_actor(normalized_obs0)
        self.perturb_adaptive_policy_ops = get_perturbed_actor_updates(
            self.actor, adaptive_param_noise_actor, self.param_noise_stddev)
        self.adaptive_policy_distance = tf.sqrt(
            tf.reduce_mean(tf.square(self.actor_tf - adaptive_actor_tf)))
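# A minimal sketch (not part of the original class) of how the adaptive
# distance above is typically consumed: after collecting experience, the
# measured distance between the perturbed and unperturbed policies is fed
# back to the parameter-noise spec, which grows or shrinks its stddev to
# track a target action-space distance. `agent.obs0`, `param_noise.adapt`
# and `param_noise.current_stddev` are assumptions here (they follow the
# AdaptiveParamNoiseSpec-style interface); adjust to your own noise object.
def adapt_param_noise_sketch(sess, agent, obs_batch):
    feed = {agent.obs0: obs_batch,
            agent.param_noise_stddev: agent.param_noise.current_stddev}
    # re-perturb the adaptive actor copy with the current stddev
    sess.run(agent.perturb_adaptive_policy_ops, feed_dict=feed)
    # measure the mean action distance between perturbed and clean policies
    distance = sess.run(agent.adaptive_policy_distance, feed_dict=feed)
    # let the noise spec move its stddev toward the desired distance
    agent.param_noise.adapt(distance)
    return distance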
    def getKfacPrecondUpdates(self, gradlist, varlist):
        updatelist = []
        vg = 0.

        assert len(self.stats) > 0
        assert len(self.stats_eigen) > 0
        assert len(self.factors) > 0
        counter = 0

        grad_dict = {var: grad for grad, var in zip(gradlist, varlist)}

        for grad, var in zip(gradlist, varlist):
            GRAD_RESHAPE = False
            GRAD_TRANSPOSE = False

            fpropFactoredFishers = self.stats[var]['fprop_concat_stats']
            bpropFactoredFishers = self.stats[var]['bprop_concat_stats']

            if (len(fpropFactoredFishers) + len(bpropFactoredFishers)) > 0:
                counter += 1
                GRAD_SHAPE = grad.get_shape()
                if len(grad.get_shape()) > 2:
                    # reshape conv kernel parameters
                    KW = int(grad.get_shape()[0])
                    KH = int(grad.get_shape()[1])
                    C = int(grad.get_shape()[2])
                    D = int(grad.get_shape()[3])

                    if len(fpropFactoredFishers) > 1 and self._channel_fac:
                        # reshape conv kernel parameters into tensor
                        grad = tf.reshape(grad, [KW * KH, C, D])
                    else:
                        # reshape conv kernel parameters into 2D grad
                        grad = tf.reshape(grad, [-1, D])
                    GRAD_RESHAPE = True
                elif len(grad.get_shape()) == 1:
                    # reshape bias or 1D parameters
                    D = int(grad.get_shape()[0])

                    grad = tf.expand_dims(grad, 0)
                    GRAD_RESHAPE = True
                else:
                    # 2D parameters
                    C = int(grad.get_shape()[0])
                    D = int(grad.get_shape()[1])

                if (self.stats[var]['assnBias']
                        is not None) and not self._blockdiag_bias:
                    # using homogeneous coordinates only works for a 2D grad
                    # TODO: figure out how to factorize the bias grad
                    # stack the bias grad onto the weight grad
                    var_assnBias = self.stats[var]['assnBias']
                    grad = tf.concat(
                        [grad,
                         tf.expand_dims(grad_dict[var_assnBias], 0)], 0)

                # project gradient to eigen space and reshape the eigenvalues
                # for broadcasting
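                # The Fisher block for this layer is approximated by a
                # Kronecker product of factors: the fprop stats (covariance of
                # the layer's input activations) and the bprop stats
                # (covariance of the gradients w.r.t. the layer's outputs).
                # Using each factor's eigendecomposition Q diag(e) Q^T,
                # applying the inverse block amounts to: rotate the grad into
                # each factor's eigenbasis (the two loops below), divide
                # elementwise by the product of eigenvalues plus damping (the
                # whitening step further down), then rotate back.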
                eigVals = []

                for idx, stats in enumerate(
                        self.stats[var]['fprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    e = detectMinVal(self.stats_eigen[stats]['e'],
                                     var,
                                     name='act',
                                     debug=KFAC_DEBUG)

                    Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='act')
                    eigVals.append(e)
                    grad = gmatmul(Q, grad, transpose_a=True, reduce_dim=idx)

                for idx, stats in enumerate(
                        self.stats[var]['bprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    e = detectMinVal(self.stats_eigen[stats]['e'],
                                     var,
                                     name='grad',
                                     debug=KFAC_DEBUG)

                    Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='grad')
                    eigVals.append(e)
                    grad = gmatmul(grad, Q, transpose_b=False, reduce_dim=idx)
                ##

                #####
                # whiten using eigenvalues
                weightDecayCoeff = 0.
                if var in self._weight_decay_dict:
                    weightDecayCoeff = self._weight_decay_dict[var]
                    if KFAC_DEBUG:
                        print(('weight decay coeff for %s is %f' %
                               (var.name, weightDecayCoeff)))

                if self._factored_damping:
                    if KFAC_DEBUG:
                        print(('use factored damping for %s' % (var.name)))
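                    # Factored Tikhonov damping (as in the K-FAC paper):
                    # rather than adding (epsilon + weight decay) once to the
                    # whole Kronecker product, split it across the k factors
                    # as (epsilon + wd)^(1/k), and scale each factor's share
                    # by the ratio of the factors' average absolute
                    # eigenvalues (trace norms), so each factor absorbs
                    # damping in proportion to its own scale.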
                    coeffs = 1.
                    num_factors = len(eigVals)
                    # compute the ratio of the trace norms of the left and
                    # right K-FAC factors, and its generalization to more
                    # than two factors
                    if len(eigVals) == 1:
                        damping = self._epsilon + weightDecayCoeff
                    else:
                        damping = tf.pow(self._epsilon + weightDecayCoeff,
                                         1. / num_factors)
                    eigVals_tnorm_avg = [
                        tf.reduce_mean(tf.abs(e)) for e in eigVals
                    ]
                    for e, e_tnorm in zip(eigVals, eigVals_tnorm_avg):
                        eig_tnorm_negList = [
                            item for item in eigVals_tnorm_avg
                            if item != e_tnorm
                        ]
                        if len(eigVals) == 1:
                            adjustment = 1.
                        elif len(eigVals) == 2:
                            adjustment = tf.sqrt(e_tnorm /
                                                 eig_tnorm_negList[0])
                        else:
                            eig_tnorm_negList_prod = reduce(
                                lambda x, y: x * y, eig_tnorm_negList)
                            adjustment = tf.pow(
                                tf.pow(e_tnorm, num_factors - 1.) /
                                eig_tnorm_negList_prod, 1. / num_factors)
                        coeffs *= (e + adjustment * damping)
                else:
                    coeffs = 1.
                    damping = (self._epsilon + weightDecayCoeff)
                    for e in eigVals:
                        coeffs *= e
                    coeffs += damping

                #grad = tf.Print(grad, [tf.convert_to_tensor('1'), tf.convert_to_tensor(var.name), grad.get_shape()])

                grad /= coeffs

                #grad = tf.Print(grad, [tf.convert_to_tensor('2'), tf.convert_to_tensor(var.name), grad.get_shape()])
                #####
                # project gradient back to euclidean space
                for idx, stats in enumerate(
                        self.stats[var]['fprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    grad = gmatmul(Q, grad, transpose_a=False, reduce_dim=idx)

                for idx, stats in enumerate(
                        self.stats[var]['bprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    grad = gmatmul(grad, Q, transpose_b=True, reduce_dim=idx)
                ##

                #grad = tf.Print(grad, [tf.convert_to_tensor('3'), tf.convert_to_tensor(var.name), grad.get_shape()])
                if (self.stats[var]['assnBias']
                        is not None) and not self._blockdiag_bias:
                    # using homogeneous coordinates only works for a 2D grad
                    # TODO: figure out how to factorize the bias grad
                    # un-stack the bias grad from the weight grad
                    var_assnBias = self.stats[var]['assnBias']
                    C_plus_one = int(grad.get_shape()[0])
                    grad_assnBias = tf.reshape(
                        tf.slice(grad, begin=[C_plus_one - 1, 0], size=[1,
                                                                        -1]),
                        var_assnBias.get_shape())
                    grad_assnWeights = tf.slice(grad,
                                                begin=[0, 0],
                                                size=[C_plus_one - 1, -1])
                    grad_dict[var_assnBias] = grad_assnBias
                    grad = grad_assnWeights

                #grad = tf.Print(grad, [tf.convert_to_tensor('4'), tf.convert_to_tensor(var.name), grad.get_shape()])
                if GRAD_RESHAPE:
                    grad = tf.reshape(grad, GRAD_SHAPE)

                grad_dict[var] = grad

        print(('projecting %d gradient matrices' % counter))

        for g, var in zip(gradlist, varlist):
            grad = grad_dict[var]
            ### clipping ###
            if KFAC_DEBUG:
                print(('apply clipping to %s' % (var.name)))
                # tf.Print is a no-op unless its result is used downstream, so
                # reassign grad to keep the debug print node in the graph.
                grad = tf.Print(grad,
                                [tf.sqrt(tf.reduce_sum(tf.pow(grad, 2)))],
                                "Euclidean norm of new grad")
            local_vg = tf.reduce_sum(grad * g * (self._lr * self._lr))
            vg += local_vg

        # rescale everything
        if KFAC_DEBUG:
            print('apply vFv clipping')

        scaling = tf.minimum(1., tf.sqrt(self._clip_kl / vg))
        if KFAC_DEBUG:
            scaling = tf.Print(scaling, [
                tf.convert_to_tensor('clip: '), scaling,
                tf.convert_to_tensor(' vFv: '), vg
            ])
        with tf.control_dependencies([tf.assign(self.vFv, vg)]):
            updatelist = [grad_dict[var] for var in varlist]
            for i, item in enumerate(updatelist):
                updatelist[i] = scaling * item

        return updatelist
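# A minimal numpy sketch (independent of the class above) of the identity the
# preconditioner exploits: if the Fisher block for a dense layer is
# approximated as kron(A, B), with A the input-activation covariance and B
# the backprop-gradient covariance, then solving against the full Kronecker
# product equals two small solves against the factors. This is why the code
# only ever touches the factor eigendecompositions, never the full block.
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out = 3, 2
M_a = rng.normal(size=(n_in, n_in))
M_b = rng.normal(size=(n_out, n_out))
A = M_a @ M_a.T + np.eye(n_in)      # SPD activation covariance
B = M_b @ M_b.T + np.eye(n_out)     # SPD backprop covariance
G = rng.normal(size=(n_in, n_out))  # layer gradient, shape (n_in, n_out)

# full (n_in*n_out)^2 solve against kron(A, B), using row-major vec(G) ...
full = np.linalg.solve(np.kron(A, B), G.reshape(-1)).reshape(n_in, n_out)
# ... equals two small solves: A^{-1} G B^{-1}
factored = np.linalg.solve(A, G) @ np.linalg.inv(B)
assert np.allclose(full, factored)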
    def compute_stats(self, loss_sampled, var_list=None):
        varlist = var_list
        if varlist is None:
            varlist = tf.trainable_variables()

        gs = tf.gradients(loss_sampled, varlist, name='gradientsSampled')
        self.gs = gs
        factors = self.getFactors(gs, varlist)
        stats = self.getStats(factors, varlist)

        updateOps = []
        statsUpdates = {}
        statsUpdates_cache = {}
        for var in varlist:
            opType = factors[var]['opName']
            fops = factors[var]['op']
            fpropFactor = factors[var]['fpropFactors_concat']
            fpropStats_vars = stats[var]['fprop_concat_stats']
            bpropFactor = factors[var]['bpropFactors_concat']
            bpropStats_vars = stats[var]['bprop_concat_stats']
            SVD_factors = {}
            for stats_var in fpropStats_vars:
                stats_var_dim = int(stats_var.get_shape()[0])
                if stats_var not in statsUpdates_cache:
                    old_fpropFactor = fpropFactor
                    B = (tf.shape(fpropFactor)[0])  # batch size
                    if opType == 'Conv2D':
                        strides = fops.get_attr("strides")
                        padding = fops.get_attr("padding")
                        convkernel_size = var.get_shape()[0:3]

                        KH = int(convkernel_size[0])
                        KW = int(convkernel_size[1])
                        C = int(convkernel_size[2])
                        flatten_size = int(KH * KW * C)

                        Oh = int(bpropFactor.get_shape()[1])
                        Ow = int(bpropFactor.get_shape()[2])

                        if Oh == 1 and Ow == 1 and self._channel_fac:
                            # factorization along the channels
                            # assume independence among input channels
                            # factor  = B x 1  x 1  x (KH x KW x C)
                            # patches = B x Oh x Ow x (KH x KW x C)
                            if len(SVD_factors) == 0:
                                if KFAC_DEBUG:
                                    print((
                                        'approx %s act factor with rank-1 SVD factors'
                                        % (var.name)))
                                # find closest rank-1 approx to the feature map
                                S, U, V = tf.batch_svd(
                                    tf.reshape(fpropFactor, [-1, KH * KW, C]))
                                # get rank-1 approx slices
                                sqrtS1 = tf.expand_dims(tf.sqrt(S[:, 0, 0]), 1)
                                patches_k = U[:, :, 0] * sqrtS1  # B x KH*KW
                                full_factor_shape = fpropFactor.get_shape()
                                patches_k.set_shape(
                                    [full_factor_shape[0], KH * KW])
                                patches_c = V[:, :, 0] * sqrtS1  # B x C
                                patches_c.set_shape([full_factor_shape[0], C])
                                SVD_factors[C] = patches_c
                                SVD_factors[KH * KW] = patches_k
                            fpropFactor = SVD_factors[stats_var_dim]

                        else:
                            # memory-inefficient implementation (extracts all
                            # patches explicitly)
                            patches = tf.extract_image_patches(
                                fpropFactor,
                                ksizes=[
                                    1, convkernel_size[0], convkernel_size[1],
                                    1
                                ],
                                strides=strides,
                                rates=[1, 1, 1, 1],
                                padding=padding)

                            if self._approxT2:
                                if KFAC_DEBUG:
                                    print(('approxT2 act fisher for %s' %
                                           (var.name)))
                                # T^2 terms * 1/T^2, size: B x C
                                fpropFactor = tf.reduce_mean(patches, [1, 2])
                            else:
                                # size: (B x Oh x Ow) x C
                                fpropFactor = tf.reshape(
                                    patches, [-1, flatten_size]) / Oh / Ow
                    fpropFactor_size = int(fpropFactor.get_shape()[-1])
                    if stats_var_dim == (fpropFactor_size +
                                         1) and not self._blockdiag_bias:
                        if opType == 'Conv2D' and not self._approxT2:
                            # correct the homogeneous padding for numerical
                            # stability (we divided out Oh x Ow from the
                            # activations for the T1 approx)
                            fpropFactor = tf.concat([
                                fpropFactor,
                                tf.ones([tf.shape(fpropFactor)[0], 1]) / Oh /
                                Ow
                            ], 1)
                        else:
                            # use homogeneous coordinates
                            fpropFactor = tf.concat([
                                fpropFactor,
                                tf.ones([tf.shape(fpropFactor)[0], 1])
                            ], 1)

                    # average over the number of data points in a batch
                    # divided by B
                    cov = tf.matmul(fpropFactor, fpropFactor,
                                    transpose_a=True) / tf.cast(B, tf.float32)
                    updateOps.append(cov)
                    statsUpdates[stats_var] = cov
                    if opType != 'Conv2D':
                        # HACK: for convolution we recompute fprop stats for
                        # every layer including forking layers
                        statsUpdates_cache[stats_var] = cov

            for stats_var in bpropStats_vars:
                stats_var_dim = int(stats_var.get_shape()[0])
                if stats_var not in statsUpdates_cache:
                    old_bpropFactor = bpropFactor
                    bpropFactor_shape = bpropFactor.get_shape()
                    B = tf.shape(bpropFactor)[0]  # batch size
                    C = int(bpropFactor_shape[-1])  # num channels
                    if opType == 'Conv2D' or len(bpropFactor_shape) == 4:
                        if fpropFactor is not None:
                            if self._approxT2:
                                if KFAC_DEBUG:
                                    print(('approxT2 grad fisher for %s' %
                                           (var.name)))
                                bpropFactor = tf.reduce_sum(
                                    bpropFactor, [1, 2])  # T^2 terms * 1/T^2
                            else:
                                bpropFactor = tf.reshape(
                                    bpropFactor,
                                    [-1, C]) * Oh * Ow  # T * 1/T terms
                        else:
                            # only doing the block-diagonal approximation; the
                            # spatially-independent structure does not apply
                            # here, so sum over spatial locations
                            if KFAC_DEBUG:
                                print(('block diag approx fisher for %s' %
                                       (var.name)))
                            bpropFactor = tf.reduce_sum(bpropFactor, [1, 2])

                    # assume the sampled loss is averaged over the batch
                    # TODO: figure out a better way to handle this
                    bpropFactor *= tf.to_float(B)
                    ##

                    cov_b = tf.matmul(bpropFactor,
                                      bpropFactor,
                                      transpose_a=True) / tf.to_float(
                                          tf.shape(bpropFactor)[0])

                    updateOps.append(cov_b)
                    statsUpdates[stats_var] = cov_b
                    statsUpdates_cache[stats_var] = cov_b

        if KFAC_DEBUG:
            aKey = list(statsUpdates.keys())[0]
            statsUpdates[aKey] = tf.Print(statsUpdates[aKey], [
                tf.convert_to_tensor('step:'),
                self.global_step,
                tf.convert_to_tensor('computing stats'),
            ])
        self.statsUpdates = statsUpdates
        return statsUpdates
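# A minimal sketch (dense layer only, not the conv path above) of the two
# statistics compute_stats builds per layer: the fprop factor is the
# second-moment matrix of the layer inputs (with a homogeneous 1 appended
# when the bias is folded in), and the bprop factor is the second-moment
# matrix of the gradients flowing back into the layer's outputs. The real
# code additionally rescales the bprop factor by the batch size because the
# sampled loss is averaged; that detail is omitted here.
def dense_layer_kfac_stats_sketch(acts, back_grads):
    # acts:       [B, n_in]  layer inputs
    # back_grads: [B, n_out] grads of the sampled loss w.r.t. pre-activations
    batch = tf.cast(tf.shape(acts)[0], tf.float32)
    acts_h = tf.concat([acts, tf.ones([tf.shape(acts)[0], 1])], 1)
    fprop_stat = tf.matmul(acts_h, acts_h, transpose_a=True) / batch
    bprop_stat = tf.matmul(back_grads, back_grads, transpose_a=True) / batch
    return fprop_stat, bprop_stat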
def reduce_std(x, axis=None, keep_dims=False):
    return tf.sqrt(reduce_var(x, axis=axis, keep_dims=keep_dims))
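# reduce_std relies on a reduce_var helper that is not shown in this excerpt;
# a typical companion definition (an assumption, matching the keep_dims-style
# TF1 API used above) is:
def reduce_var(x, axis=None, keep_dims=False):
    m = tf.reduce_mean(x, axis=axis, keep_dims=True)
    devs_squared = tf.square(x - m)
    return tf.reduce_mean(devs_squared, axis=axis, keep_dims=keep_dims)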
    def __init__(self, size, eps=1e-2, default_clip_range=np.inf, sess=None):
        """A normalizer that ensures that observations are approximately distributed according to
        a standard Normal distribution (i.e. have mean zero and variance one).

        Args:
            size (int): the size of the observation to be normalized
            eps (float): a small constant that avoids underflows
            default_clip_range (float): normalized observations are clipped to be in
                [-default_clip_range, default_clip_range]
            sess (object): the TensorFlow session to be used
        """
        self.size = size
        self.eps = eps
        self.default_clip_range = default_clip_range
        self.sess = sess if sess is not None else tf.get_default_session()

        self.local_sum = np.zeros(self.size, np.float32)
        self.local_sumsq = np.zeros(self.size, np.float32)
        self.local_count = np.zeros(1, np.float32)

        self.sum_tf = tf.get_variable(initializer=tf.zeros_initializer(),
                                      shape=self.local_sum.shape,
                                      name='sum',
                                      trainable=False,
                                      dtype=tf.float32)
        self.sumsq_tf = tf.get_variable(initializer=tf.zeros_initializer(),
                                        shape=self.local_sumsq.shape,
                                        name='sumsq',
                                        trainable=False,
                                        dtype=tf.float32)
        self.count_tf = tf.get_variable(initializer=tf.ones_initializer(),
                                        shape=self.local_count.shape,
                                        name='count',
                                        trainable=False,
                                        dtype=tf.float32)
        self.mean = tf.get_variable(initializer=tf.zeros_initializer(),
                                    shape=(self.size, ),
                                    name='mean',
                                    trainable=False,
                                    dtype=tf.float32)
        self.std = tf.get_variable(initializer=tf.ones_initializer(),
                                   shape=(self.size, ),
                                   name='std',
                                   trainable=False,
                                   dtype=tf.float32)
        self.count_pl = tf.placeholder(name='count_pl',
                                       shape=(1, ),
                                       dtype=tf.float32)
        self.sum_pl = tf.placeholder(name='sum_pl',
                                     shape=(self.size, ),
                                     dtype=tf.float32)
        self.sumsq_pl = tf.placeholder(name='sumsq_pl',
                                       shape=(self.size, ),
                                       dtype=tf.float32)

        self.update_op = tf.group(self.count_tf.assign_add(self.count_pl),
                                  self.sum_tf.assign_add(self.sum_pl),
                                  self.sumsq_tf.assign_add(self.sumsq_pl))
        self.recompute_op = tf.group(
            tf.assign(self.mean, self.sum_tf / self.count_tf),
            tf.assign(
                self.std,
                tf.sqrt(
                    tf.maximum(
                        tf.square(self.eps), self.sumsq_tf / self.count_tf -
                        tf.square(self.sum_tf / self.count_tf)))),
        )
        self.lock = threading.Lock()
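
    # A minimal sketch (hypothetical helpers, not part of the original class)
    # of how this normalizer is typically used: accumulate local sums from
    # new observations, push them into the shared TF variables, recompute
    # mean/std, then normalize with clipping.
    def update_sketch(self, v):
        v = v.reshape(-1, self.size)
        with self.lock:
            self.local_sum += v.sum(axis=0)
            self.local_sumsq += np.square(v).sum(axis=0)
            self.local_count[0] += v.shape[0]

    def recompute_stats_sketch(self):
        with self.lock:
            local_sum = self.local_sum.copy()
            local_sumsq = self.local_sumsq.copy()
            local_count = self.local_count.copy()
            self.local_sum[...] = 0
            self.local_sumsq[...] = 0
            self.local_count[...] = 0
        self.sess.run(self.update_op,
                      feed_dict={self.sum_pl: local_sum,
                                 self.sumsq_pl: local_sumsq,
                                 self.count_pl: local_count})
        self.sess.run(self.recompute_op)

    def normalize_sketch(self, v, clip_range=None):
        if clip_range is None:
            clip_range = self.default_clip_range
        mean, std = self.sess.run([self.mean, self.std])
        return np.clip((v - mean) / std, -clip_range, clip_range)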