Code Example #1
    def applyStatsEigen(self, eigen_list):
        updateOps = []
        print(('updating %d eigenvalue/vectors' % len(eigen_list)))
        for i, (tensor,
                mark) in enumerate(zip(eigen_list, self.eigen_update_list)):
            stats_eigen_var = self.eigen_reverse_lookup[mark]
            updateOps.append(
                tf.assign(stats_eigen_var, tensor, use_locking=True))

        with tf.control_dependencies(updateOps):
            factor_step_op = tf.assign_add(self.factor_step, 1)
            updateOps.append(factor_step_op)
            if KFAC_DEBUG:
                updateOps.append(
                    tf.Print(tf.constant(0.),
                             [tf.convert_to_tensor('updated kfac factors')]))
        return updateOps
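
A minimal sketch of how applyStatsEigen pairs with computeStatsEigen (Code Example #5) at run time. The optimizer instance name `optim`, the tf.Session `sess`, and the one-shot call pattern are assumptions for illustration, not part of the excerpt:

    # Hypothetical TF1-style driver:
    # 1) build the eigendecomposition ops once,
    # 2) run them to fetch concrete eigenvalue/eigenvector arrays,
    # 3) assign those arrays back into the stats-eigen variables.
    eigen_ops = optim.computeStatsEigen()
    eigen_results = sess.run(eigen_ops)
    sess.run(optim.applyStatsEigen(eigen_results))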
Code Example #2
        def coldSGDstart():
            sgd_grads, sgd_var = zip(*grads)

            if self.max_grad_norm is not None:
                sgd_grads, sgd_grad_norm = tf.clip_by_global_norm(
                    sgd_grads, self.max_grad_norm)

            sgd_grads = list(zip(sgd_grads, sgd_var))

            sgd_step_op = tf.assign_add(self.sgd_step, 1)
            coldOptim_op = coldOptim.apply_gradients(sgd_grads)
            if KFAC_DEBUG:
                with tf.control_dependencies([sgd_step_op, coldOptim_op]):
                    sgd_step_op = tf.Print(sgd_step_op, [
                        self.sgd_step,
                        tf.convert_to_tensor('doing cold sgd step')
                    ])
            return tf.group(*[sgd_step_op, coldOptim_op])
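
For context, a hedged sketch of how a cold-start branch like coldSGDstart is typically selected against the regular K-FAC update. The names `kfacOptim_op`, `warmKFACstart`, and `self._cold_iter` are assumptions, not defined in the excerpt above:

        # Assumed: kfacOptim_op is the training op built by the K-FAC path,
        # and self._cold_iter is the number of cold SGD burn-in steps.
        def warmKFACstart():
            return kfacOptim_op

        train_op = tf.cond(tf.greater(self.sgd_step, self._cold_iter),
                           warmKFACstart, coldSGDstart)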
Code Example #3
    def _apply_stats(self, statsUpdates, accumulate=False, accumulateCoeff=0.):
        updateOps = []
        # iterate over the stats variables and build an update op for each
        for stats_var in statsUpdates:
            stats_new = statsUpdates[stats_var]
            if accumulate:
                # simple superbatch averaging
                update_op = tf.assign_add(stats_var,
                                          accumulateCoeff * stats_new,
                                          use_locking=True)
            else:
                # exponential moving average
                update_op = tf.assign(stats_var,
                                      stats_var * self._stats_decay,
                                      use_locking=True)
                update_op = tf.assign_add(update_op,
                                          (1. - self._stats_decay) * stats_new,
                                          use_locking=True)
            updateOps.append(update_op)

        with tf.control_dependencies(updateOps):
            stats_step_op = tf.assign_add(self.stats_step, 1)

        if KFAC_DEBUG:
            stats_step_op = (tf.Print(stats_step_op, [
                tf.convert_to_tensor('step:'), self.global_step,
                tf.convert_to_tensor('fac step:'), self.factor_step,
                tf.convert_to_tensor('sgd step:'), self.sgd_step,
                tf.convert_to_tensor('Accum:'),
                tf.convert_to_tensor(accumulate),
                tf.convert_to_tensor('Accum coeff:'),
                tf.convert_to_tensor(accumulateCoeff),
                tf.convert_to_tensor('stat step:'), self.stats_step,
                updateOps[0], updateOps[1]
            ]))
        return [
            stats_step_op,
        ]
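
A quick reading aid for the else-branch above: the chained assign / assign_add pair implements an exponential moving average of each statistic. In plain Python the recurrence is simply (the decay value here is only illustrative; the real code uses self._stats_decay):

    # one EMA step: blend the old statistic with the freshly computed one
    decay = 0.95
    stat = decay * stat + (1.0 - decay) * stat_new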
Code Example #4
    def getKfacPrecondUpdates(self, gradlist, varlist):
        updatelist = []
        vg = 0.

        assert len(self.stats) > 0
        assert len(self.stats_eigen) > 0
        assert len(self.factors) > 0
        counter = 0

        grad_dict = {var: grad for grad, var in zip(gradlist, varlist)}

        for grad, var in zip(gradlist, varlist):
            GRAD_RESHAPE = False
            GRAD_TRANSPOSE = False

            fpropFactoredFishers = self.stats[var]['fprop_concat_stats']
            bpropFactoredFishers = self.stats[var]['bprop_concat_stats']

            if (len(fpropFactoredFishers) + len(bpropFactoredFishers)) > 0:
                counter += 1
                GRAD_SHAPE = grad.get_shape()
                if len(grad.get_shape()) > 2:
                    # reshape conv kernel parameters
                    KW = int(grad.get_shape()[0])
                    KH = int(grad.get_shape()[1])
                    C = int(grad.get_shape()[2])
                    D = int(grad.get_shape()[3])

                    if len(fpropFactoredFishers) > 1 and self._channel_fac:
                        # reshape conv kernel parameters into tensor
                        grad = tf.reshape(grad, [KW * KH, C, D])
                    else:
                        # reshape conv kernel parameters into 2D grad
                        grad = tf.reshape(grad, [-1, D])
                    GRAD_RESHAPE = True
                elif len(grad.get_shape()) == 1:
                    # reshape bias or 1D parameters
                    D = int(grad.get_shape()[0])

                    grad = tf.expand_dims(grad, 0)
                    GRAD_RESHAPE = True
                else:
                    # 2D parameters
                    C = int(grad.get_shape()[0])
                    D = int(grad.get_shape()[1])

                if (self.stats[var]['assnBias']
                        is not None) and not self._blockdiag_bias:
                    # using homogeneous coordinates only works for a 2D grad
                    # TO-DO: figure out how to factorize the bias grad
                    # stack the bias grad onto the weight grad
                    var_assnBias = self.stats[var]['assnBias']
                    grad = tf.concat(
                        [grad,
                         tf.expand_dims(grad_dict[var_assnBias], 0)], 0)

                # project gradient to eigen space and reshape the eigenvalues
                # for broadcasting
                eigVals = []

                for idx, stats in enumerate(
                        self.stats[var]['fprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    e = detectMinVal(self.stats_eigen[stats]['e'],
                                     var,
                                     name='act',
                                     debug=KFAC_DEBUG)

                    Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='act')
                    eigVals.append(e)
                    grad = gmatmul(Q, grad, transpose_a=True, reduce_dim=idx)

                for idx, stats in enumerate(
                        self.stats[var]['bprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    e = detectMinVal(self.stats_eigen[stats]['e'],
                                     var,
                                     name='grad',
                                     debug=KFAC_DEBUG)

                    Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='grad')
                    eigVals.append(e)
                    grad = gmatmul(grad, Q, transpose_b=False, reduce_dim=idx)
                ##

                #####
                # whiten using eigenvalues
                weightDecayCoeff = 0.
                if var in self._weight_decay_dict:
                    weightDecayCoeff = self._weight_decay_dict[var]
                    if KFAC_DEBUG:
                        print(('weight decay coeff for %s is %f' %
                               (var.name, weightDecayCoeff)))

                if self._factored_damping:
                    if KFAC_DEBUG:
                        print(('use factored damping for %s' % (var.name)))
                    coeffs = 1.
                    num_factors = len(eigVals)
                    # compute the ratio of the trace norms of the left and
                    # right K-FAC factor matrices (and its generalization to
                    # more than two factors)
                    if len(eigVals) == 1:
                        damping = self._epsilon + weightDecayCoeff
                    else:
                        damping = tf.pow(self._epsilon + weightDecayCoeff,
                                         1. / num_factors)
                    eigVals_tnorm_avg = [
                        tf.reduce_mean(tf.abs(e)) for e in eigVals
                    ]
                    for e, e_tnorm in zip(eigVals, eigVals_tnorm_avg):
                        eig_tnorm_negList = [
                            item for item in eigVals_tnorm_avg
                            if item != e_tnorm
                        ]
                        if len(eigVals) == 1:
                            adjustment = 1.
                        elif len(eigVals) == 2:
                            adjustment = tf.sqrt(e_tnorm /
                                                 eig_tnorm_negList[0])
                        else:
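                            # note: reduce here is functools.reduce
                            # (Python 3 needs "from functools import reduce")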
                            eig_tnorm_negList_prod = reduce(
                                lambda x, y: x * y, eig_tnorm_negList)
                            adjustment = tf.pow(
                                tf.pow(e_tnorm, num_factors - 1.) /
                                eig_tnorm_negList_prod, 1. / num_factors)
                        coeffs *= (e + adjustment * damping)
                else:
                    coeffs = 1.
                    damping = (self._epsilon + weightDecayCoeff)
                    for e in eigVals:
                        coeffs *= e
                    coeffs += damping

                #grad = tf.Print(grad, [tf.convert_to_tensor('1'), tf.convert_to_tensor(var.name), grad.get_shape()])

                grad /= coeffs

                #grad = tf.Print(grad, [tf.convert_to_tensor('2'), tf.convert_to_tensor(var.name), grad.get_shape()])
                #####
                # project gradient back to euclidean space
                for idx, stats in enumerate(
                        self.stats[var]['fprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    grad = gmatmul(Q, grad, transpose_a=False, reduce_dim=idx)

                for idx, stats in enumerate(
                        self.stats[var]['bprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    grad = gmatmul(grad, Q, transpose_b=True, reduce_dim=idx)
                ##

                #grad = tf.Print(grad, [tf.convert_to_tensor('3'), tf.convert_to_tensor(var.name), grad.get_shape()])
                if (self.stats[var]['assnBias']
                        is not None) and not self._blockdiag_bias:
                    # using homogeneous coordinates only works for a 2D grad
                    # TO-DO: figure out how to factorize the bias grad
                    # un-stack the bias grad from the weight grad
                    var_assnBias = self.stats[var]['assnBias']
                    C_plus_one = int(grad.get_shape()[0])
                    grad_assnBias = tf.reshape(
                        tf.slice(grad, begin=[C_plus_one - 1, 0], size=[1,
                                                                        -1]),
                        var_assnBias.get_shape())
                    grad_assnWeights = tf.slice(grad,
                                                begin=[0, 0],
                                                size=[C_plus_one - 1, -1])
                    grad_dict[var_assnBias] = grad_assnBias
                    grad = grad_assnWeights

                #grad = tf.Print(grad, [tf.convert_to_tensor('4'), tf.convert_to_tensor(var.name), grad.get_shape()])
                if GRAD_RESHAPE:
                    grad = tf.reshape(grad, GRAD_SHAPE)

                grad_dict[var] = grad

        print(('projecting %d gradient matrices' % counter))

        for g, var in zip(gradlist, varlist):
            grad = grad_dict[var]
            ### clipping ###
            if KFAC_DEBUG:
                print(('apply clipping to %s' % (var.name)))
                # tf.Print is a no-op in the graph unless its output tensor is
                # actually consumed, so keep the returned tensor
                grad = tf.Print(grad,
                                [tf.sqrt(tf.reduce_sum(tf.pow(grad, 2)))],
                                "Euclidean norm of new grad")
            local_vg = tf.reduce_sum(grad * g * (self._lr * self._lr))
            vg += local_vg

        # rescale everything
        if KFAC_DEBUG:
            print('apply vFv clipping')

        scaling = tf.minimum(1., tf.sqrt(self._clip_kl / vg))
        if KFAC_DEBUG:
            scaling = tf.Print(scaling, [
                tf.convert_to_tensor('clip: '), scaling,
                tf.convert_to_tensor(' vFv: '), vg
            ])
        with tf.control_dependencies([tf.assign(self.vFv, vg)]):
            updatelist = [grad_dict[var] for var in varlist]
            for i, item in enumerate(updatelist):
                updatelist[i] = scaling * item

        return updatelist
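
The final rescaling above is a trust-region style clip: every update is scaled by min(1, sqrt(clip_kl / vFv)), where vFv is the (approximate) quadratic form of the scaled step under the Fisher, so the implied KL change stays bounded. A hedged sketch of consuming the returned list; `optim`, `learning_rate`, `momentum`, and the choice of a Momentum optimizer are assumptions:

    # The returned updates line up with varlist, so they can be re-paired
    # into (update, var) tuples for a standard apply_gradients call.
    precond_updates = optim.getKfacPrecondUpdates(gradlist, varlist)
    sgd = tf.train.MomentumOptimizer(learning_rate, momentum)
    train_op = sgd.apply_gradients(list(zip(precond_updates, varlist)))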
Code Example #5
    def computeStatsEigen(self):
        """ compute the eigen decomp using copied var stats to avoid concurrent read/write from other queue """
        # TO-DO: figure out why this op has delays (possibly moving
        # eigenvectors around?)
        with tf.device('/cpu:0'):

            def removeNone(tensor_list):
                local_list = []
                for item in tensor_list:
                    if item is not None:
                        local_list.append(item)
                return local_list

            def copyStats(var_list):
                print("copying stats to buffer tensors before eigen decomp")
                redundant_stats = {}
                copied_list = []
                for item in var_list:
                    if item is not None:
                        if item not in redundant_stats:
                            if self._use_float64:
                                redundant_stats[item] = tf.cast(
                                    tf.identity(item), tf.float64)
                            else:
                                redundant_stats[item] = tf.identity(item)
                        copied_list.append(redundant_stats[item])
                    else:
                        copied_list.append(None)
                return copied_list

            #stats = [copyStats(self.fStats), copyStats(self.bStats)]
            #stats = [self.fStats, self.bStats]

            stats_eigen = self.stats_eigen
            computedEigen = {}
            eigen_reverse_lookup = {}
            updateOps = []
            # sync copied stats
            # with tf.control_dependencies(removeNone(stats[0]) +
            # removeNone(stats[1])):
            with tf.control_dependencies([]):
                for stats_var in stats_eigen:
                    if stats_var not in computedEigen:
                        eigens = tf.self_adjoint_eig(stats_var)
                        e = eigens[0]
                        Q = eigens[1]
                        if self._use_float64:
                            e = tf.cast(e, tf.float32)
                            Q = tf.cast(Q, tf.float32)
                        updateOps.append(e)
                        updateOps.append(Q)
                        computedEigen[stats_var] = {'e': e, 'Q': Q}
                        eigen_reverse_lookup[e] = stats_eigen[stats_var]['e']
                        eigen_reverse_lookup[Q] = stats_eigen[stats_var]['Q']

            self.eigen_reverse_lookup = eigen_reverse_lookup
            self.eigen_update_list = updateOps

            if KFAC_DEBUG:
                self.eigen_update_list = [item for item in updateOps]
                with tf.control_dependencies(updateOps):
                    updateOps.append(
                        tf.Print(
                            tf.constant(0.),
                            [tf.convert_to_tensor('computed factor eigen')]))

        return updateOps
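
For reference, tf.self_adjoint_eig returns the (eigenvalues, eigenvectors) pair that the loop above unpacks as (e, Q); a tiny standalone illustration:

    # 2x2 symmetric matrix with eigenvalues 1 and 3
    mat = tf.constant([[2., 1.],
                       [1., 2.]])
    e, Q = tf.self_adjoint_eig(mat)   # e: eigenvalues, Q: eigenvectors as columns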
Code Example #6
    def compute_stats(self, loss_sampled, var_list=None):
        varlist = var_list
        if varlist is None:
            varlist = tf.trainable_variables()

        gs = tf.gradients(loss_sampled, varlist, name='gradientsSampled')
        self.gs = gs
        factors = self.getFactors(gs, varlist)
        stats = self.getStats(factors, varlist)

        updateOps = []
        statsUpdates = {}
        statsUpdates_cache = {}
        for var in varlist:
            opType = factors[var]['opName']
            fops = factors[var]['op']
            fpropFactor = factors[var]['fpropFactors_concat']
            fpropStats_vars = stats[var]['fprop_concat_stats']
            bpropFactor = factors[var]['bpropFactors_concat']
            bpropStats_vars = stats[var]['bprop_concat_stats']
            SVD_factors = {}
            for stats_var in fpropStats_vars:
                stats_var_dim = int(stats_var.get_shape()[0])
                if stats_var not in statsUpdates_cache:
                    old_fpropFactor = fpropFactor
                    B = (tf.shape(fpropFactor)[0])  # batch size
                    if opType == 'Conv2D':
                        strides = fops.get_attr("strides")
                        padding = fops.get_attr("padding")
                        convkernel_size = var.get_shape()[0:3]

                        KH = int(convkernel_size[0])
                        KW = int(convkernel_size[1])
                        C = int(convkernel_size[2])
                        flatten_size = int(KH * KW * C)

                        Oh = int(bpropFactor.get_shape()[1])
                        Ow = int(bpropFactor.get_shape()[2])

                        if Oh == 1 and Ow == 1 and self._channel_fac:
                            # factorization along the channels
                            # assume independence among input channels
                            # factor = B x 1 x 1 x (KH x KW x C)
                            # patches = B x Oh x Ow x (KH x KW x C)
                            if len(SVD_factors) == 0:
                                if KFAC_DEBUG:
                                    print((
                                        'approx %s act factor with rank-1 SVD factors'
                                        % (var.name)))
                                # find closest rank-1 approx to the feature map
                                S, U, V = tf.batch_svd(
                                    tf.reshape(fpropFactor, [-1, KH * KW, C]))
                                # get rank-1 approx slices
                                sqrtS1 = tf.expand_dims(tf.sqrt(S[:, 0, 0]), 1)
                                patches_k = U[:, :, 0] * sqrtS1  # B x KH*KW
                                full_factor_shape = fpropFactor.get_shape()
                                patches_k.set_shape(
                                    [full_factor_shape[0], KH * KW])
                                patches_c = V[:, :, 0] * sqrtS1  # B x C
                                patches_c.set_shape([full_factor_shape[0], C])
                                SVD_factors[C] = patches_c
                                SVD_factors[KH * KW] = patches_k
                            fpropFactor = SVD_factors[stats_var_dim]

                        else:
                            # poor mem usage implementation
                            patches = tf.extract_image_patches(
                                fpropFactor,
                                ksizes=[
                                    1, convkernel_size[0], convkernel_size[1],
                                    1
                                ],
                                strides=strides,
                                rates=[1, 1, 1, 1],
                                padding=padding)

                            if self._approxT2:
                                if KFAC_DEBUG:
                                    print(('approxT2 act fisher for %s' %
                                           (var.name)))
                                # T^2 terms * 1/T^2, size: B x C
                                fpropFactor = tf.reduce_mean(patches, [1, 2])
                            else:
                                # size: (B x Oh x Ow) x C
                                fpropFactor = tf.reshape(
                                    patches, [-1, flatten_size]) / Oh / Ow
                    fpropFactor_size = int(fpropFactor.get_shape()[-1])
                    if stats_var_dim == (fpropFactor_size +
                                         1) and not self._blockdiag_bias:
                        if opType == 'Conv2D' and not self._approxT2:
                            # correct padding for numerical stability (we
                            # divided out OhxOw from activations for T1 approx)
                            fpropFactor = tf.concat([
                                fpropFactor,
                                tf.ones([tf.shape(fpropFactor)[0], 1]) / Oh /
                                Ow
                            ], 1)
                        else:
                            # use homogeneous coordinates
                            fpropFactor = tf.concat([
                                fpropFactor,
                                tf.ones([tf.shape(fpropFactor)[0], 1])
                            ], 1)

                    # average over the number of data points in a batch
                    # divided by B
                    cov = tf.matmul(fpropFactor, fpropFactor,
                                    transpose_a=True) / tf.cast(B, tf.float32)
                    updateOps.append(cov)
                    statsUpdates[stats_var] = cov
                    if opType != 'Conv2D':
                        # HACK: for convolution we recompute fprop stats for
                        # every layer including forking layers
                        statsUpdates_cache[stats_var] = cov

            for stats_var in bpropStats_vars:
                stats_var_dim = int(stats_var.get_shape()[0])
                if stats_var not in statsUpdates_cache:
                    old_bpropFactor = bpropFactor
                    bpropFactor_shape = bpropFactor.get_shape()
                    B = tf.shape(bpropFactor)[0]  # batch size
                    C = int(bpropFactor_shape[-1])  # num channels
                    if opType == 'Conv2D' or len(bpropFactor_shape) == 4:
                        if fpropFactor is not None:
                            if self._approxT2:
                                if KFAC_DEBUG:
                                    print(('approxT2 grad fisher for %s' %
                                           (var.name)))
                                bpropFactor = tf.reduce_sum(
                                    bpropFactor, [1, 2])  # T^2 terms * 1/T^2
                            else:
                                bpropFactor = tf.reshape(
                                    bpropFactor,
                                    [-1, C]) * Oh * Ow  # T * 1/T terms
                        else:
                            # just doing block diag approx. spatial independent
                            # structure does not apply here. summing over
                            # spatial locations
                            if KFAC_DEBUG:
                                print(('block diag approx fisher for %s' %
                                       (var.name)))
                            bpropFactor = tf.reduce_sum(bpropFactor, [1, 2])

                    # assume the sampled loss is averaged. TO-DO: figure out a
                    # better way to handle this
                    bpropFactor *= tf.to_float(B)
                    ##

                    cov_b = tf.matmul(bpropFactor,
                                      bpropFactor,
                                      transpose_a=True) / tf.to_float(
                                          tf.shape(bpropFactor)[0])

                    updateOps.append(cov_b)
                    statsUpdates[stats_var] = cov_b
                    statsUpdates_cache[stats_var] = cov_b

        if KFAC_DEBUG:
            aKey = list(statsUpdates.keys())[0]
            statsUpdates[aKey] = tf.Print(statsUpdates[aKey], [
                tf.convert_to_tensor('step:'),
                self.global_step,
                tf.convert_to_tensor('computing stats'),
            ])
        self.statsUpdates = statsUpdates
        return statsUpdates
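
Both loops above ultimately build the same kind of statistic: a batch-averaged second-moment matrix A^T A / B of an activation (fprop) or gradient (bprop) factor A. A minimal, shape-only illustration with arbitrary sizes:

    A = tf.random_normal([32, 64])                  # B x D factor, B = 32
    cov = tf.matmul(A, A, transpose_a=True) / 32.   # D x D Fisher block factor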
Code Example #7
def detectMinVal(input_mat, var, threshold=1e-6, name='', debug=False):
    eigen_min = tf.reduce_min(input_mat)
    eigen_max = tf.reduce_max(input_mat)
    eigen_ratio = eigen_max / eigen_min
    input_mat_clipped = clipoutNeg(input_mat, threshold)

    if debug:
        input_mat_clipped = tf.cond(
            tf.logical_or(tf.greater(eigen_ratio, 0.),
                          tf.less(eigen_ratio, -500)),
            lambda: input_mat_clipped,
            lambda: tf.Print(input_mat_clipped, [
                tf.convert_to_tensor('screwed ratio ' + name +
                                     ' eigen values!!!'),
                tf.convert_to_tensor(var.name), eigen_min, eigen_max,
                eigen_ratio
            ]))

    return input_mat_clipped
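
clipoutNeg is not included in this excerpt; a plausible implementation consistent with how detectMinVal uses it (zero out eigenvalues at or below the threshold) would look like the following hedged sketch:

    def clipoutNeg(vec, threshold=1e-6):
        # keep entries above the threshold, zero out the rest
        mask = tf.cast(vec > threshold, tf.float32)
        return mask * vec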