Exemplo n.º 1
0
    def _apply_stats(self, statsUpdates, accumulate=False, accumulateCoeff=0.):
        """Fold freshly computed factor statistics into their stored variables.

        Args:
            statsUpdates: mapping from each statistics variable to the tensor
                holding its new estimate.
            accumulate: if True, add ``accumulateCoeff * new`` to the variable
                (simple superbatch averaging); otherwise apply an exponential
                running average with decay ``self._stats_decay``.
            accumulateCoeff: scale applied to each new estimate when
                ``accumulate`` is True.

        Returns:
            A one-element list holding the op that increments
            ``self.stats_step`` once every variable update has run (wrapped
            in a ``tf.Print`` when ``KFAC_DEBUG`` is set).
        """
        updateOps = []
        for stats_var, stats_new in statsUpdates.items():
            if accumulate:
                # simple superbatch averaging
                update_op = tf.assign_add(stats_var,
                                          accumulateCoeff * stats_new,
                                          use_locking=True)
            else:
                # exponential running averaging:
                # var <- decay * var + (1 - decay) * new, done as a scale
                # followed by an add on the result of the first assign.
                update_op = tf.assign(stats_var,
                                      stats_var * self._stats_decay,
                                      use_locking=True)
                update_op = tf.assign_add(update_op,
                                          (1. - self._stats_decay) * stats_new,
                                          use_locking=True)
            updateOps.append(update_op)

        # The step counter only advances after all statistics are written.
        with tf.control_dependencies(updateOps):
            stats_step_op = tf.assign_add(self.stats_step, 1)

        if KFAC_DEBUG:
            # Show at most the first two update ops; indexing [0] and [1]
            # directly would raise IndexError with fewer than two stats vars.
            stats_step_op = tf.Print(stats_step_op, [
                tf.convert_to_tensor('step:'), self.global_step,
                tf.convert_to_tensor('fac step:'), self.factor_step,
                tf.convert_to_tensor('sgd step:'), self.sgd_step,
                tf.convert_to_tensor('Accum:'),
                tf.convert_to_tensor(accumulate),
                tf.convert_to_tensor('Accum coeff:'),
                tf.convert_to_tensor(accumulateCoeff),
                tf.convert_to_tensor('stat step:'), self.stats_step,
            ] + updateOps[:2])
        return [
            stats_step_op,
        ]
Exemplo n.º 2
0
    def applyStatsEigen(self, eigen_list):
        """Copy freshly computed eigen-decomposition tensors into storage.

        Each tensor in ``eigen_list`` is paired positionally with a marker in
        ``self.eigen_update_list``; the marker is resolved to its destination
        variable through ``self.eigen_reverse_lookup``.

        Returns:
            A list of ops: the per-variable assign ops, then the
            ``self.factor_step`` increment (which depends on all assigns),
            then a debug print op when ``KFAC_DEBUG`` is set.
        """
        print('updating %d eigenvalue/vectors' % len(eigen_list))
        ops = [
            tf.assign(self.eigen_reverse_lookup[marker], value,
                      use_locking=True)
            for value, marker in zip(eigen_list, self.eigen_update_list)
        ]

        # The dependency set is captured on entry, so ops appended inside
        # the context do not depend on themselves.
        with tf.control_dependencies(ops):
            ops.append(tf.assign_add(self.factor_step, 1))
            if KFAC_DEBUG:
                debug_msg = [tf.convert_to_tensor('updated kfac factors')]
                ops.append(tf.Print(tf.constant(0.), debug_msg))
        return ops
Exemplo n.º 3
0
        def coldSGDstart():
            """Build one plain SGD step used before K-FAC takes over.

            ``grads`` (a sequence of ``(gradient, variable)`` pairs) and
            ``coldOptim`` are free names taken from the enclosing scope,
            which is not visible here.

            Returns:
                A grouped op that increments ``self.sgd_step`` and applies
                the (optionally norm-clipped) gradients with ``coldOptim``.
            """
            sgd_grads, sgd_var = zip(*grads)

            if self.max_grad_norm is not None:
                # clip_by_global_norm returns (clipped_list, global_norm);
                # only the clipped gradients are needed.
                sgd_grads, _ = tf.clip_by_global_norm(
                    sgd_grads, self.max_grad_norm)

            sgd_grads = list(zip(sgd_grads, sgd_var))

            sgd_step_op = tf.assign_add(self.sgd_step, 1)
            coldOptim_op = coldOptim.apply_gradients(sgd_grads)
            if KFAC_DEBUG:
                # Print only after both the counter bump and the update ran.
                with tf.control_dependencies([sgd_step_op, coldOptim_op]):
                    sgd_step_op = tf.Print(sgd_step_op, [
                        self.sgd_step,
                        tf.convert_to_tensor('doing cold sgd step')
                    ])
            return tf.group(sgd_step_op, coldOptim_op)
Exemplo n.º 4
0
 def no_op_wrapper():
     """Increment ``self.cold_step`` and return that update as a grouped op.

     Despite the name this is not a no-op: it bumps the cold-step counter.
     ``self`` is captured from the enclosing method.
     """
     cold_step_inc = tf.assign_add(self.cold_step, 1)
     return tf.group(cold_step_inc)
Exemplo n.º 5
0
    def apply_gradients_kfac(self, grads):
        """Build the graph ops for one K-FAC optimisation step.

        The step increments ``self.global_step``, runs the pre-built
        statistics update (``self._update_stats_op``), refreshes the stored
        eigen-decompositions when due (synchronously, or via a queue when
        ``self._async`` is set), preconditions the gradients once at least
        one factor update has happened, and applies them with a momentum
        optimizer.

        Args:
            grads: sequence of ``(gradient, variable)`` pairs.

        Returns:
            A tuple ``(update_op, qr)``: ``update_op`` groups every op built
            here; ``qr`` is the ``tf.train.QueueRunner`` that feeds the
            asynchronous eigen-decomposition queue, or ``None`` when
            ``self._async`` is false.
        """
        g, varlist = list(zip(*grads))

        if not self.stats_eigen:
            self.getStatsEigen()

        def eigen_update_due():
            # Builds, at call time (so under the caller's control-dependency
            # scope), a bool tensor that is True when stats_step is a
            # multiple of _kfac_update and at least _stats_accum_iter steps
            # of statistics have been accumulated.
            return tf.logical_and(
                tf.equal(tf.mod(self.stats_step, self._kfac_update),
                         tf.convert_to_tensor(0)),
                tf.greater_equal(self.stats_step, self._stats_accum_iter))

        qr = None
        # launch eigen-decomp on a queue thread
        if self._async:
            print('Use async eigen decomp')
            # factor loading tensors; only their dtypes/shapes are used here,
            # to declare the queue signature
            factorOps_dummy = self.computeStatsEigen()

            # define a queue for the list of factor loading tensors
            queue = tf.FIFOQueue(
                1, [item.dtype for item in factorOps_dummy],
                shapes=[item.get_shape() for item in factorOps_dummy])
            enqueue_op = tf.cond(
                eigen_update_due(),
                lambda: queue.enqueue(self.computeStatsEigen()), tf.no_op)

            def dequeue_op():
                return queue.dequeue()

            qr = tf.train.QueueRunner(queue, [enqueue_op])

        updateOps = []
        global_step_op = tf.assign_add(self.global_step, 1)
        updateOps.append(global_step_op)

        with tf.control_dependencies([global_step_op]):

            # compute updates
            assert self._update_stats_op is not None, (
                '_update_stats_op is not set; build the statistics update '
                'op before applying K-FAC gradients')
            updateOps.append(self._update_stats_op)
            dependency_list = []
            if not self._async:
                # In synchronous mode the factor refresh must see the
                # freshly updated statistics.
                dependency_list.append(self._update_stats_op)

            with tf.control_dependencies(dependency_list):

                def no_op_wrapper():
                    # Not a no-op: counts steps on which no factor refresh ran.
                    return tf.group(*[tf.assign_add(self.cold_step, 1)])

                if not self._async:
                    # synchronous eigen-decomp updates
                    updateFactorOps = tf.cond(
                        eigen_update_due(),
                        lambda: tf.group(*self.applyStatsEigen(
                            self.computeStatsEigen())), no_op_wrapper)
                else:
                    # asynchronous eigen-decomp updates using queue; apply a
                    # result only if the queue currently holds one
                    updateFactorOps = tf.cond(
                        tf.greater_equal(self.stats_step,
                                         self._stats_accum_iter),
                        lambda: tf.cond(
                            tf.equal(queue.size(), tf.convert_to_tensor(0)),
                            tf.no_op,
                            lambda: tf.group(*self.applyStatsEigen(dequeue_op(
                            ))),
                        ), no_op_wrapper)

                updateOps.append(updateFactorOps)

                with tf.control_dependencies([updateFactorOps]):

                    def gradOp():
                        return list(g)

                    def getKfacGradOp():
                        return self.getKfacPrecondUpdates(g, varlist)

                    # Use raw gradients until at least one factor update
                    # has been applied, preconditioned ones afterwards.
                    u = tf.cond(
                        tf.greater(self.factor_step, tf.convert_to_tensor(0)),
                        getKfacGradOp, gradOp)

                    optim = tf.train.MomentumOptimizer(
                        self._lr * (1. - self._momentum), self._momentum)

                    def optimOp():
                        def updateOptimOp():
                            if self._full_stats_init:
                                return tf.cond(
                                    tf.greater(self.factor_step,
                                               tf.convert_to_tensor(0)),
                                    lambda: optim.apply_gradients(
                                        list(zip(u, varlist))), tf.no_op)
                            else:
                                return optim.apply_gradients(
                                    list(zip(u, varlist)))

                        # Gate the parameter update on a warm-up threshold:
                        # accumulated stats steps with full stats init,
                        # otherwise completed cold SGD steps.
                        if self._full_stats_init:
                            return tf.cond(
                                tf.greater_equal(self.stats_step,
                                                 self._stats_accum_iter),
                                updateOptimOp, tf.no_op)
                        else:
                            return tf.cond(
                                tf.greater_equal(self.sgd_step,
                                                 self._cold_iter),
                                updateOptimOp, tf.no_op)

                    updateOps.append(optimOp())

        return tf.group(*updateOps), qr