Exemple #1
0
    def compute_gradients(self, loss, var_list=None):
        """Pair the gradient of ``loss`` with each variable it was taken against.

        Args:
            loss: quantity handed to ``tf.gradients`` for differentiation.
            var_list: variables to differentiate with respect to. When
                ``None``, ``tf.trainable_variables()`` is used instead.

        Returns:
            A list of ``(gradient, variable)`` tuples, in the same order as
            the variable list.
        """
        variables = tf.trainable_variables() if var_list is None else var_list
        grads = tf.gradients(loss, variables)
        return list(zip(grads, variables))
Exemple #2
0
    def compute_and_apply_stats(self, loss_sampled, var_list=None):
        """Compute the statistics for ``loss_sampled`` and apply them at once.

        Args:
            loss_sampled: forwarded unchanged to ``self.compute_stats``.
            var_list: variables forwarded to ``self.compute_stats``. When
                ``None``, ``tf.trainable_variables()`` is used instead.

        Returns:
            Whatever ``self.apply_stats`` returns for the computed statistics.
        """
        if var_list is None:
            var_list = tf.trainable_variables()
        return self.apply_stats(
            self.compute_stats(loss_sampled, var_list=var_list))
    def __init__(self,
                 loss_scalar,
                 solver_name='adam',
                 base_lr=None,
                 lr_policy=None,
                 momentum=None,
                 weight_decay=None,
                 fc_vars=None,
                 last_conv_vars=None,
                 vars_to_opt=None):
        """Record the solver settings and build the training op(s).

        Args:
            loss_scalar: loss to minimise; replaced on the instance by the
                regularised loss when ``weight_decay`` is given.
            solver_name: optimizer name, interpreted by ``get_solver_op``.
            base_lr: stored as ``self.base_lr``.
            lr_policy: must be ``'fixed'``; any other value (including the
                default ``None``) raises.
            momentum: stored as ``self.momentum``.
            weight_decay: when not ``None``, ``weight_decay * l2_loss(var)``
                is added to the loss for every regularised variable.
            fc_vars: when not ``None``, a second op ``self.fc_solver_op`` is
                built over just these variables.
            last_conv_vars: stored as ``self.last_conv_vars`` only when
                ``fc_vars`` is given.
            vars_to_opt: variables to regularise; defaults to
                ``tf.trainable_variables()``. It is used only for the weight
                decay term: ``get_solver_op()`` is called without a variable
                list.

        Raises:
            NotImplementedError: if ``lr_policy`` is not ``'fixed'``.
        """
        self.base_lr = base_lr
        self.lr_policy = lr_policy
        self.momentum = momentum
        self.solver_name = solver_name
        self.loss_scalar = loss_scalar
        if lr_policy != 'fixed':
            raise NotImplementedError(
                'learning rate policies other than fixed are not implemented')

        self.weight_decay = weight_decay
        if weight_decay is not None:
            reg_vars = (tf.trainable_variables()
                        if vars_to_opt is None else vars_to_opt)
            regularised = self.loss_scalar
            for reg_var in reg_vars:
                regularised += self.weight_decay * tf.nn.l2_loss(reg_var)
            self.loss_scalar = regularised

        self.solver_op = self.get_solver_op()
        if fc_vars is None:
            # No separate fully-connected op requested; the fc_* attributes
            # are intentionally left unset, as before.
            return
        self.fc_vars = fc_vars
        self.last_conv_vars = last_conv_vars
        self.fc_solver_op = self.get_solver_op(var_list=fc_vars)
 def get_solver_op(self, var_list=None, loss=None):
     """Build the minimisation op for the configured optimizer.

     Args:
         var_list: variables to optimise; defaults to
             ``tf.trainable_variables()``.
         loss: loss to minimise; defaults to ``self.loss_scalar``.

     Returns:
         The result of ``minimize(loss, var_list=var_list)`` on the optimizer
         selected by ``self.solver_name`` (matched case-insensitively).

     Raises:
         NotImplementedError: if ``self.solver_name`` is not one of
             ``adam``, ``rmsprop``, ``momentum``, ``adagrad`` or ``sgd``.
     """
     chosen = self.solver_name.lower()
     if var_list is None:
         var_list = tf.trainable_variables()
     if loss is None:
         loss = self.loss_scalar

     # Factories are lazy so only the selected optimizer class is touched.
     # ``self.momentum`` feeds a different hyper-parameter for each optimizer.
     factories = {
         'adam': lambda: tf.train.AdamOptimizer(
             learning_rate=self.base_lr, beta1=self.momentum),
         'rmsprop': lambda: tf.train.RMSPropOptimizer(
             learning_rate=self.base_lr, decay=self.momentum),
         'momentum': lambda: tf.train.MomentumOptimizer(
             learning_rate=self.base_lr, momentum=self.momentum),
         'adagrad': lambda: tf.train.AdagradOptimizer(
             learning_rate=self.base_lr,
             initial_accumulator_value=self.momentum),
         'sgd': lambda: tf.train.GradientDescentOptimizer(
             learning_rate=self.base_lr),
     }
     make_optimizer = factories.get(chosen)
     if make_optimizer is None:
         raise NotImplementedError("Please select a valid optimizer.")
     return make_optimizer().minimize(loss, var_list=var_list)
Exemple #5
0
    def compute_stats(self, loss_sampled, var_list=None):
        """Build covariance tensors for each variable's forward/backward factors.

        For every variable, the forward ("fprop") and backward ("bprop")
        factors returned by ``self.getFactors`` are turned into matrices of
        the form ``factor^T * factor`` divided by a row/batch count, and keyed
        by the corresponding stats variable returned by ``self.getStats``.

        Args:
            loss_sampled: quantity differentiated with ``tf.gradients``. A
                comment in the body assumes it is averaged over the batch.
            var_list: variables to process; defaults to
                ``tf.trainable_variables()``.

        Returns:
            A dict mapping each stats variable to its covariance tensor. The
            same dict is stored on ``self.statsUpdates``; the raw gradients
            are stored on ``self.gs``.
        """
        varlist = var_list
        if varlist is None:
            varlist = tf.trainable_variables()

        gs = tf.gradients(loss_sampled, varlist, name='gradientsSampled')
        self.gs = gs
        # Per-variable dicts; only the keys read below are relied on here.
        factors = self.getFactors(gs, varlist)
        stats = self.getStats(factors, varlist)

        # NOTE(review): updateOps is appended to but never returned or stored
        # by this method.
        updateOps = []
        statsUpdates = {}
        # Stats variables already handled; lets variables that share a stats
        # variable skip recomputation (see the Conv2D exception below).
        statsUpdates_cache = {}
        for var in varlist:
            opType = factors[var]['opName']
            fops = factors[var]['op']
            fpropFactor = factors[var]['fpropFactors_concat']
            fpropStats_vars = stats[var]['fprop_concat_stats']
            bpropFactor = factors[var]['bpropFactors_concat']
            bpropStats_vars = stats[var]['bprop_concat_stats']
            # Rank-1 SVD factors for this variable, keyed by their last
            # dimension (C or KH * KW); filled at most once per variable.
            SVD_factors = {}
            for stats_var in fpropStats_vars:
                stats_var_dim = int(stats_var.get_shape()[0])
                if stats_var not in statsUpdates_cache:
                    # NOTE(review): old_fpropFactor is assigned but not read
                    # anywhere in this method.
                    old_fpropFactor = fpropFactor
                    B = (tf.shape(fpropFactor)[0])  # batch size
                    if opType == 'Conv2D':
                        strides = fops.get_attr("strides")
                        padding = fops.get_attr("padding")
                        # First three dims of the kernel variable are read as
                        # (KH, KW, C).
                        convkernel_size = var.get_shape()[0:3]

                        KH = int(convkernel_size[0])
                        KW = int(convkernel_size[1])
                        C = int(convkernel_size[2])
                        flatten_size = int(KH * KW * C)

                        # Output spatial size, taken from dims 1 and 2 of the
                        # backward factor.
                        Oh = int(bpropFactor.get_shape()[1])
                        Ow = int(bpropFactor.get_shape()[2])

                        if Oh == 1 and Ow == 1 and self._channel_fac:
                            # factorization along the channels
                            # assume independence among input channels
                            # factor = B x 1 x 1 x (KH xKW x C)
                            # patches = B x Oh x Ow x (KH xKW x C)
                            if len(SVD_factors) == 0:
                                if KFAC_DEBUG:
                                    print((
                                        'approx %s act factor with rank-1 SVD factors'
                                        % (var.name)))
                                # find closest rank-1 approx to the feature map
                                # NOTE(review): tf.batch_svd may not exist in
                                # newer TensorFlow releases - confirm the
                                # targeted version.
                                S, U, V = tf.batch_svd(
                                    tf.reshape(fpropFactor, [-1, KH * KW, C]))
                                # get rank-1 approx slides
                                sqrtS1 = tf.expand_dims(tf.sqrt(S[:, 0, 0]), 1)
                                patches_k = U[:, :, 0] * sqrtS1  # B x KH*KW
                                full_factor_shape = fpropFactor.get_shape()
                                patches_k.set_shape(
                                    [full_factor_shape[0], KH * KW])
                                patches_c = V[:, :, 0] * sqrtS1  # B x C
                                patches_c.set_shape([full_factor_shape[0], C])
                                SVD_factors[C] = patches_c
                                SVD_factors[KH * KW] = patches_k
                            # Pick the factor whose width matches this stats
                            # variable's first dimension.
                            fpropFactor = SVD_factors[stats_var_dim]

                        else:
                            # poor mem usage implementation
                            patches = tf.extract_image_patches(
                                fpropFactor,
                                ksizes=[
                                    1, convkernel_size[0], convkernel_size[1],
                                    1
                                ],
                                strides=strides,
                                rates=[1, 1, 1, 1],
                                padding=padding)

                            if self._approxT2:
                                if KFAC_DEBUG:
                                    print(('approxT2 act fisher for %s' %
                                           (var.name)))
                                # T^2 terms * 1/T^2, size: B x C
                                fpropFactor = tf.reduce_mean(patches, [1, 2])
                            else:
                                # size: (B x Oh x Ow) x C
                                fpropFactor = tf.reshape(
                                    patches, [-1, flatten_size]) / Oh / Ow
                    fpropFactor_size = int(fpropFactor.get_shape()[-1])
                    # The stats variable is one wider than the factor: append
                    # a constant column (presumably for the bias - confirm).
                    if stats_var_dim == (fpropFactor_size +
                                         1) and not self._blockdiag_bias:
                        if opType == 'Conv2D' and not self._approxT2:
                            # correct padding for numerical stability (we
                            # divided out OhxOw from activations for T1 approx)
                            fpropFactor = tf.concat([
                                fpropFactor,
                                tf.ones([tf.shape(fpropFactor)[0], 1]) / Oh /
                                Ow
                            ], 1)
                        else:
                            # use homogeneous coordinates
                            fpropFactor = tf.concat([
                                fpropFactor,
                                tf.ones([tf.shape(fpropFactor)[0], 1])
                            ], 1)

                    # average over the number of data points in a batch
                    # divided by B
                    cov = tf.matmul(fpropFactor, fpropFactor,
                                    transpose_a=True) / tf.cast(B, tf.float32)
                    updateOps.append(cov)
                    statsUpdates[stats_var] = cov
                    if opType != 'Conv2D':
                        # HACK: for convolution we recompute fprop stats for
                        # every layer including forking layers
                        statsUpdates_cache[stats_var] = cov

            for stats_var in bpropStats_vars:
                stats_var_dim = int(stats_var.get_shape()[0])
                if stats_var not in statsUpdates_cache:
                    # NOTE(review): old_bpropFactor is assigned but not read
                    # anywhere in this method.
                    old_bpropFactor = bpropFactor
                    bpropFactor_shape = bpropFactor.get_shape()
                    B = tf.shape(bpropFactor)[0]  # batch size
                    C = int(bpropFactor_shape[-1])  # num channels
                    if opType == 'Conv2D' or len(bpropFactor_shape) == 4:
                        if fpropFactor is not None:
                            if self._approxT2:
                                if KFAC_DEBUG:
                                    print(('approxT2 grad fisher for %s' %
                                           (var.name)))
                                bpropFactor = tf.reduce_sum(
                                    bpropFactor, [1, 2])  # T^2 terms * 1/T^2
                            else:
                                # NOTE(review): Oh and Ow are only assigned in
                                # the Conv2D forward branch above; confirm it
                                # always runs first for variables that reach
                                # this line.
                                bpropFactor = tf.reshape(
                                    bpropFactor,
                                    [-1, C]) * Oh * Ow  # T * 1/T terms
                        else:
                            # just doing block diag approx. spatial independent
                            # structure does not apply here. summing over
                            # spatial locations
                            if KFAC_DEBUG:
                                print(('block diag approx fisher for %s' %
                                       (var.name)))
                            bpropFactor = tf.reduce_sum(bpropFactor, [1, 2])

                    # assume sampled loss is averaged. TO-DO:figure out better
                    # way to handle this
                    bpropFactor *= tf.to_float(B)
                    ##

                    # Normalised by the row count of the (possibly reshaped)
                    # backward factor rather than by B.
                    cov_b = tf.matmul(bpropFactor,
                                      bpropFactor,
                                      transpose_a=True) / tf.to_float(
                                          tf.shape(bpropFactor)[0])

                    updateOps.append(cov_b)
                    statsUpdates[stats_var] = cov_b
                    statsUpdates_cache[stats_var] = cov_b

        if KFAC_DEBUG:
            # Wrap one (arbitrary, first-listed) update in tf.Print so the
            # step is logged when that tensor is evaluated.
            aKey = list(statsUpdates.keys())[0]
            statsUpdates[aKey] = tf.Print(statsUpdates[aKey], [
                tf.convert_to_tensor('step:'),
                self.global_step,
                tf.convert_to_tensor('computing stats'),
            ])
        self.statsUpdates = statsUpdates
        return statsUpdates
def _serialize_variables():
    """Return a ``{variable name: current value}`` dict of trainable variables.

    All values are fetched with a single ``run`` call on the session returned
    by ``get_session()``.
    """
    session = get_session()
    trainable = tf.trainable_variables()
    fetched = session.run(trainable)
    return dict(zip((variable.name for variable in trainable), fetched))