Example #1
    def getStatsEigen(self, stats=None):
        # Lazily create one eigenvalue vector 'e' and one eigenvector matrix 'Q'
        # per covariance statistic; statistics shared between variables reuse
        # the same slots via tmpEigenCache.
        if len(self.stats_eigen) == 0:
            stats_eigen = {}
            if stats is None:
                stats = self.stats

            tmpEigenCache = {}
            with tf.device('/cpu:0'):
                for var in stats:
                    for key in ['fprop_concat_stats', 'bprop_concat_stats']:
                        for stats_var in stats[var][key]:
                            if stats_var not in tmpEigenCache:
                                stats_dim = stats_var.get_shape()[1].value
                                e = tf.Variable(tf.ones([stats_dim]),
                                                name='KFAC_FAC/' +
                                                stats_var.name.split(':')[0] +
                                                '/e',
                                                trainable=False)
                                Q = tf.Variable(tf.diag(tf.ones([stats_dim])),
                                                name='KFAC_FAC/' +
                                                stats_var.name.split(':')[0] +
                                                '/Q',
                                                trainable=False)
                                stats_eigen[stats_var] = {'e': e, 'Q': Q}
                                tmpEigenCache[stats_var] = stats_eigen[
                                    stats_var]
                            else:
                                stats_eigen[stats_var] = tmpEigenCache[
                                    stats_var]
            self.stats_eigen = stats_eigen
        return self.stats_eigen
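The slots created above are only initialized to placeholders (a vector of ones for 'e', the identity for 'Q'); a separate operation has to fill them from the current covariance statistics. A minimal sketch of that step, assuming TF1 graph mode; the helper name refresh_eigen_sketch is hypothetical and not part of the original class:

def refresh_eigen_sketch(stats_var, slots):
    # Hypothetical helper: eigendecompose a square covariance statistic and
    # write the result into its 'e'/'Q' slots, pinned to the CPU as above.
    with tf.device('/cpu:0'):
        e, Q = tf.self_adjoint_eig(stats_var)  # eigenvalues, eigenvectors
    return [tf.assign(slots['e'], e), tf.assign(slots['Q'], Q)]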
Example #2
def test_MpiAdam():
    np.random.seed(0)
    tf.set_random_seed(0)

    a = tf.Variable(np.random.randn(3).astype('float32'))
    b = tf.Variable(np.random.randn(2, 5).astype('float32'))
    loss = tf.reduce_sum(tf.square(a)) + tf.reduce_sum(tf.sin(b))

    stepsize = 1e-2
    update_op = tf.train.AdamOptimizer(stepsize).minimize(loss)
    do_update = U.function([], loss, updates=[update_op])

    tf.get_default_session().run(tf.global_variables_initializer())
    for i in range(10):
        print(i, do_update())

    tf.set_random_seed(0)
    tf.get_default_session().run(tf.global_variables_initializer())

    var_list = [a, b]
    # Run only MpiAdam's own update in this second pass: re-running the graph
    # AdamOptimizer op here would double-apply updates and break the comparison.
    lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)])
    adam = MpiAdam(var_list)

    for i in range(10):
        l, g = lossandgrad()
        adam.update(g, stepsize)
        print(i, l)
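The test above assumes an active default TensorFlow session. A hedged driver sketch, assuming U is baselines.common.tf_util as in the snippet:

# Hypothetical driver: single_threaded_session() installs a default session.
if __name__ == '__main__':
    with U.single_threaded_session():
        test_MpiAdam()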
Example #3
def build_summaries():
    episode_reward = tf.Variable(0.)
    tf.summary.scalar("Reward", episode_reward)
    episode_ave_max_q = tf.Variable(0.)
    tf.summary.scalar("Qmax Value", episode_ave_max_q)

    summary_vars = [episode_reward, episode_ave_max_q]
    summary_ops = tf.summary.merge_all()

    return summary_ops, summary_vars
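A hedged usage sketch for build_summaries; sess, the log directory, and the per-episode values (ep_reward, ep_ave_max_q, step, episode) are assumptions, not part of the original snippet:

# Hypothetical usage: feed per-episode values into the summary variables and
# write the merged summary for TensorBoard.
summary_ops, summary_vars = build_summaries()
writer = tf.summary.FileWriter('./logs', sess.graph)
summary_str = sess.run(summary_ops, feed_dict={
    summary_vars[0]: ep_reward,
    summary_vars[1]: ep_ave_max_q / float(step),
})
writer.add_summary(summary_str, episode)
writer.flush()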
Example #4
def get_xavier_weights(filter_shape, poolsize=(2, 2)):
    fan_in = np.prod(filter_shape[1:])
    fan_out = (filter_shape[0] * np.prod(filter_shape[2:]) //
               np.prod(poolsize))

    low = -4*np.sqrt(6.0/(fan_in + fan_out)) # use 4 for sigmoid, 1 for tanh activation
    high = 4*np.sqrt(6.0/(fan_in + fan_out))
    return tf.Variable(tf.random_uniform(filter_shape, minval=low, maxval=high, dtype=tf.float32))
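The fan-in/fan-out arithmetic above implies a (num_filters, num_input_channels, kernel_h, kernel_w) filter layout; a hypothetical call for 64 filters of size 5x5 over 3 input channels:

# Illustrative shapes only, not from the original code:
conv_w = get_xavier_weights((64, 3, 5, 5), poolsize=(2, 2))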
Example #5
    def getStats(self, factors, varlist):
        if len(self.stats) == 0:
            # initialize stats variables on CPU because eigen decomp is
            # computed on CPU
            with tf.device('/cpu'):
                tmpStatsCache = {}

                # search for tensor factors and
                # use block diag approx for the bias units
                for var in varlist:
                    fpropFactor = factors[var]['fpropFactors_concat']
                    bpropFactor = factors[var]['bpropFactors_concat']
                    opType = factors[var]['opName']
                    if opType == 'Conv2D':
                        Kh = var.get_shape()[0]
                        Kw = var.get_shape()[1]
                        C = fpropFactor.get_shape()[-1]

                        Oh = bpropFactor.get_shape()[1]
                        Ow = bpropFactor.get_shape()[2]
                        if Oh == 1 and Ow == 1 and self._channel_fac:
                            # factorization along the channels does not support
                            # homogeneous coordinates
                            var_assnBias = factors[var]['assnBias']
                            if var_assnBias:
                                factors[var]['assnBias'] = None
                                factors[var_assnBias]['assnWeights'] = None
                ##

                for var in varlist:
                    fpropFactor = factors[var]['fpropFactors_concat']
                    bpropFactor = factors[var]['bpropFactors_concat']
                    opType = factors[var]['opName']
                    self.stats[var] = {
                        'opName': opType,
                        'fprop_concat_stats': [],
                        'bprop_concat_stats': [],
                        'assnWeights': factors[var]['assnWeights'],
                        'assnBias': factors[var]['assnBias'],
                    }
                    if fpropFactor is not None:
                        if fpropFactor not in tmpStatsCache:
                            if opType == 'Conv2D':
                                Kh = var.get_shape()[0]
                                Kw = var.get_shape()[1]
                                C = fpropFactor.get_shape()[-1]

                                Oh = bpropFactor.get_shape()[1]
                                Ow = bpropFactor.get_shape()[2]
                                if Oh == 1 and Ow == 1 and self._channel_fac:
                                    # factorization along the channels
                                    # assume independence between input channels and spatial
                                    # 2K-1 x 2K-1 covariance matrix and C x C covariance matrix
                                    # factorization along the channels does not
                                    # support homogeneous coordinates; assnBias
                                    # is always None
                                    fpropFactor2_size = Kh * Kw
                                    slot_fpropFactor_stats2 = tf.Variable(
                                        tf.diag(tf.ones([fpropFactor2_size])) *
                                        self._diag_init_coeff,
                                        name='KFAC_STATS/' +
                                        fpropFactor.op.name,
                                        trainable=False)
                                    self.stats[var][
                                        'fprop_concat_stats'].append(
                                            slot_fpropFactor_stats2)

                                    fpropFactor_size = C
                                else:
                                    # 2K-1 x 2K-1 x C x C covariance matrix
                                    # assume BHWC
                                    fpropFactor_size = Kh * Kw * C
                            else:
                                # D x D covariance matrix
                                fpropFactor_size = fpropFactor.get_shape()[-1]

                            # use homogeneous coordinate
                            if not self._blockdiag_bias and self.stats[var][
                                    'assnBias']:
                                fpropFactor_size += 1

                            slot_fpropFactor_stats = tf.Variable(
                                tf.diag(tf.ones([fpropFactor_size])) *
                                self._diag_init_coeff,
                                name='KFAC_STATS/' + fpropFactor.op.name,
                                trainable=False)
                            self.stats[var]['fprop_concat_stats'].append(
                                slot_fpropFactor_stats)
                            if opType != 'Conv2D':
                                tmpStatsCache[fpropFactor] = self.stats[var][
                                    'fprop_concat_stats']
                        else:
                            self.stats[var][
                                'fprop_concat_stats'] = tmpStatsCache[
                                    fpropFactor]

                    if bpropFactor is not None:
                        # no need to collect backward stats for bias vectors if
                        # using homogeneous coordinates
                        if not ((not self._blockdiag_bias)
                                and self.stats[var]['assnWeights']):
                            if bpropFactor not in tmpStatsCache:
                                slot_bpropFactor_stats = tf.Variable(
                                    tf.diag(
                                        tf.ones([bpropFactor.get_shape()[-1]
                                                 ])) * self._diag_init_coeff,
                                    name='KFAC_STATS/' + bpropFactor.op.name,
                                    trainable=False)
                                self.stats[var]['bprop_concat_stats'].append(
                                    slot_bpropFactor_stats)
                                tmpStatsCache[bpropFactor] = self.stats[var][
                                    'bprop_concat_stats']
                            else:
                                self.stats[var][
                                    'bprop_concat_stats'] = tmpStatsCache[
                                        bpropFactor]

        return self.stats
Example #6
    def __init__(self,
                 learning_rate=0.01,
                 momentum=0.9,
                 clip_kl=0.01,
                 kfac_update=2,
                 stats_accum_iter=60,
                 full_stats_init=False,
                 cold_iter=100,
                 cold_lr=None,
                 is_async=False,
                 async_stats=False,
                 epsilon=1e-2,
                 stats_decay=0.95,
                 blockdiag_bias=False,
                 channel_fac=False,
                 factored_damping=False,
                 approxT2=False,
                 use_float64=False,
                 weight_decay_dict={},
                 max_grad_norm=0.5):
        self.max_grad_norm = max_grad_norm
        self._lr = learning_rate
        self._momentum = momentum
        self._clip_kl = clip_kl
        self._channel_fac = channel_fac
        self._kfac_update = kfac_update
        self._async = is_async
        self._async_stats = async_stats
        self._epsilon = epsilon
        self._stats_decay = stats_decay
        self._blockdiag_bias = blockdiag_bias
        self._approxT2 = approxT2
        self._use_float64 = use_float64
        self._factored_damping = factored_damping
        self._cold_iter = cold_iter
        if cold_lr is None:
            # good heuristics
            self._cold_lr = self._lr  # * 3.
        else:
            self._cold_lr = cold_lr
        self._stats_accum_iter = stats_accum_iter
        self._weight_decay_dict = weight_decay_dict
        self._diag_init_coeff = 0.
        self._full_stats_init = full_stats_init
        if not self._full_stats_init:
            self._stats_accum_iter = self._cold_iter

        self.sgd_step = tf.Variable(0, name='KFAC/sgd_step', trainable=False)
        self.global_step = tf.Variable(0,
                                       name='KFAC/global_step',
                                       trainable=False)
        self.cold_step = tf.Variable(0, name='KFAC/cold_step', trainable=False)
        self.factor_step = tf.Variable(0,
                                       name='KFAC/factor_step',
                                       trainable=False)
        self.stats_step = tf.Variable(0,
                                      name='KFAC/stats_step',
                                      trainable=False)
        self.vFv = tf.Variable(0., name='KFAC/vFv', trainable=False)

        self.factors = {}
        self.param_vars = []
        self.stats = {}
        self.stats_eigen = {}
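A hedged construction sketch; the class name KfacOptimizer and the argument values are assumptions about how this constructor is typically wired up (e.g. in baselines.acktr), not part of the snippet:

# Hypothetical instantiation; values are illustrative only.
optim = KfacOptimizer(learning_rate=0.25, clip_kl=0.001, momentum=0.9,
                      kfac_update=2, epsilon=0.01, stats_decay=0.99,
                      is_async=True, cold_iter=10, max_grad_norm=0.5)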