Example #1
    def __init__(self, lr_tensor, recorder):
        """Constructor.

        Args:
          lr_tensor: Tensor holding the learning rate during training.
          recorder: MeasurementsRecorder that records the learning rate.
            If None, set_recorder() should be called before training.
        """
        # The base class gets no recorder; this class keeps its own
        # reference (possibly None until set_recorder() is called).
        super(LearningRateSchedule,
              self).__init__(meas.Frequency(freq=1, stepwise=True),
                             recorder=None)
        self.lr_tensor = lr_tensor
        self.recorder = recorder
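
A minimal usage sketch for this callback (hypothetical wiring; assumes a compiled Keras model with an SGD optimizer, whose learning-rate variable is exposed as model.optimizer.lr, and a MeasurementsRecorder named recorder):

# Hypothetical: construct the schedule without a recorder and attach
# one via set_recorder() before training, as the docstring requires.
lr_tensor = model.optimizer.lr
schedule = LearningRateSchedule(lr_tensor, recorder=None)
schedule.set_recorder(recorder)  # must happen before training
model.fit(x_train, y_train, callbacks=[schedule])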
Example #2
    def test_gradient_measurement(self):
        """Test that the full-batch gradient is computed correctly."""
        K.clear_session()

        d = 12
        n = 20
        batch_size = n // 4
        x = np.random.rand(n, d).astype(np.float32)
        y = np.sin(2 * np.pi * x[:, 0]).reshape((-1, 1)).astype(np.float32)

        x_test = np.random.rand(n, d).astype(np.float32)
        y_test = np.sin(2 * np.pi * x_test[:, 0]).reshape(
            (-1, 1)).astype(np.float32)

        # Small MLP regression model (nonlinear, due to the ReLU activations)
        model = keras.models.Sequential()
        model.add(keras.layers.Dense(1, use_bias=False, input_shape=(d,)))
        model.add(keras.layers.Activation('relu'))
        model.add(keras.layers.Dense(10))
        model.add(keras.layers.Activation('relu'))
        model.add(keras.layers.Dense(1))
        model.compile(loss='mean_squared_error',
                      optimizer=keras.optimizers.SGD())

        # Compute the loss once (result unused) so the model's graph is
        # fully built before the gradient computation below.
        tfutils.keras_compute_tensors(model, x, y, model.total_loss)

        grad_t = tfutils.flatten_tensor_list(
            tf.gradients(model.total_loss, model.trainable_weights))
        grad = tfutils.keras_compute_tensors(model, x, y, grad_t)

        train_batches = tfutils.MiniBatchMaker(x, y, batch_size)
        test_batches = tfutils.MiniBatchMaker(x_test, y_test, batch_size)

        meas = measurements.GradientMeasurement(
            MockRecorder(), model,
            measurements.Frequency(freq=1, stepwise=False), train_batches,
            test_batches)

        # Drive one epoch of the callback lifecycle so the measurement
        # accumulates the full-batch gradient.
        meas.on_epoch_begin(0)
        meas.on_batch_begin(0)
        meas.on_batch_end(0)
        meas.on_epoch_end(0)
        actual_grad = meas.full_batch_g
        self.assertTrue(np.allclose(grad, actual_grad))
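
The test instantiates a MockRecorder that is not shown in this excerpt. A minimal stand-in consistent with how it is used here could be a no-op stub (hypothetical; the real class presumably implements whatever interface MeasurementsRecorder defines):

class MockRecorder:
    """Hypothetical no-op recorder: absorbs and discards any method call."""

    def __getattr__(self, name):
        # Return a do-nothing callable for any attribute lookup, so the
        # measurement can call record methods without side effects.
        return lambda *args, **kwargs: None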
Example #3
    def __init__(self, lr_tensor, eta0, alpha=None, T=None):
        """Constructor. The learning rate at step t is given by:

            eta(t) = eta0 - (1 - alpha) * eta0 * t / T   if t <= T
            eta(t) = alpha * eta0                        if t > T

        Args:
          lr_tensor: Tensor holding the learning rate during training.
          eta0: Initial learning rate.
          alpha: Linear decay coefficient, or None to keep the learning
            rate constant.
          T: Step at which to stop decaying, or None to keep the learning
            rate constant.
        """
        super(LearningRateLinearDecaySchedule,
              self).__init__(meas.Frequency(freq=1, stepwise=True),
                             recorder=None)
        self.lr_tensor = lr_tensor
        self.eta0 = eta0
        self.alpha = alpha
        self.T = T
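
As a sanity check on the formula, note that the two branches agree at t = T, since eta0 - (1 - alpha) * eta0 = alpha * eta0. A standalone sketch of the decay rule:

def linear_decay_lr(t, eta0, alpha=None, T=None):
    # Piecewise-linear decay from eta0 down to alpha * eta0 over T steps.
    if alpha is None or T is None:
        return eta0  # constant learning rate
    if t <= T:
        return eta0 - (1 - alpha) * eta0 * t / T
    return alpha * eta0

assert linear_decay_lr(0, eta0=1.0, alpha=0.5, T=10) == 1.0
assert linear_decay_lr(10, eta0=1.0, alpha=0.5, T=10) == 0.5  # branches meet
assert linear_decay_lr(20, eta0=1.0, alpha=0.5, T=10) == 0.5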
Example #4
def add_callbacks(callbacks, recorder, model, x_train, y_train, x_test, y_test,
                  lr_schedule):
    """Add measurement callbacks."""

    # TODO: convert to tf.data.Dataset
    def get_batch_makers(batch_size):
        """Returns train and test mini-batch makers."""
        train_batches = tfutils.MiniBatchMaker(x_train, y_train, batch_size)
        test_batches = tfutils.MiniBatchMaker(x_test, y_test, batch_size)
        return train_batches, test_batches

    if xFLAGS.loss_and_acc is not None:
        train_batches, test_batches = get_batch_makers(
            xFLAGS.measure_batch_size)
        freq = meas.Frequency.from_string(xFLAGS.loss_and_acc)
        loss_acc_cb = meas.BasicMetricsMeasurement(
            recorder,
            model,
            freq,
            train_batches,
            test_batches,
            lr_schedule,
            # Show textual progress only when the interactive progress
            # bar is disabled.
            show_progress=not xFLAGS.show_progress_bar)
        callbacks.append(loss_acc_cb)

        weight_norm_cb = meas.WeightNormMeasurement(recorder, model, freq)
        callbacks.append(weight_norm_cb)

    grad_cb = None
    if xFLAGS.gradients is not None:
        train_batches, test_batches = get_batch_makers(
            xFLAGS.measure_batch_size)
        freq = meas.Frequency.from_string(xFLAGS.gradients)
        grad_cb = meas.GradientMeasurement(recorder, model, freq,
                                           train_batches, test_batches,
                                           xFLAGS.random_overlap)
        callbacks.append(grad_cb)

    if xFLAGS.hessian is not None:
        freq = meas.Frequency.from_string(xFLAGS.hessian)
        hess_cb = meas.LanczosHessianMeasurement(recorder,
                                                 model,
                                                 freq,
                                                 xFLAGS.hessian_num_evs,
                                                 x_train,
                                                 y_train,
                                                 xFLAGS.hessian_batch_size,
                                                 lr=xFLAGS.lr,
                                                 log_dir=xFLAGS.runlogdir)
        callbacks.append(hess_cb)

    if xFLAGS.last_layer_hessian is not None:
        freq = meas.Frequency.from_string(xFLAGS.last_layer_hessian)
        if xFLAGS.use_bias:
            weights = model.trainable_weights[-2:]
        else:
            weights = model.trainable_weights[-1:]
        num_weights = tfutils.num_weights(weights)
        # Restrict the full gradient vector to its last-layer components.
        grad_subvec = lambda g: g[-num_weights:]
        ll_hess_cb = meas.LanczosHessianMeasurement(recorder,
                                                    model,
                                                    freq,
                                                    xFLAGS.hessian_num_evs,
                                                    x_train,
                                                    y_train,
                                                    xFLAGS.hessian_batch_size,
                                                    lr=xFLAGS.lr,
                                                    log_dir=xFLAGS.runlogdir,
                                                    weights=weights,
                                                    grad_subvec=grad_subvec,
                                                    name=meas.LAST_LAYER)
        callbacks.append(ll_hess_cb)

    if xFLAGS.full_hessian is not None:
        # Only the training batches are needed for the full Hessian.
        train_batches, _ = get_batch_makers(xFLAGS.full_hessian_batch_size)
        freq = meas.Frequency.from_string(xFLAGS.full_hessian)
        full_hess_cb = meas.FullHessianMeasurement(
            recorder,
            model,
            freq,
            train_batches,
            xFLAGS.runlogdir,
            num_eigenvector_correlations=xFLAGS.output_dim)
        callbacks.append(full_hess_cb)

    if xFLAGS.interpolate_loss is not None:
        train_batches, test_batches = get_batch_makers(
            xFLAGS.measure_batch_size)
        freq = meas.Frequency.from_string(xFLAGS.interpolate_loss)
        loss_interp_cb = meas.LossInterpolationMeasurement(
            recorder, model, freq, train_batches, test_batches)
        callbacks.append(loss_interp_cb)

    if xFLAGS.dataset == 'gaussians':
        freq = meas.Frequency(1, xFLAGS.measure_gaussians_every_step)
        gauss_cb = meas.GaussiansMeasurement(recorder, model, freq, x_train,
                                             y_train)
        callbacks.append(gauss_cb)
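
For context, a sketch of how add_callbacks might be wired into a training script. Only the add_callbacks signature and the schedule classes come from the code above; the recorder setup and the flags xFLAGS.batch_size and xFLAGS.epochs are assumptions for illustration:

# Hypothetical training-script wiring around add_callbacks().
callbacks = []
lr_schedule = LearningRateLinearDecaySchedule(model.optimizer.lr,
                                              eta0=xFLAGS.lr,
                                              alpha=0.1, T=10000)
lr_schedule.set_recorder(recorder)
callbacks.append(lr_schedule)
add_callbacks(callbacks, recorder, model,
              x_train, y_train, x_test, y_test,
              lr_schedule)
model.fit(x_train, y_train,
          batch_size=xFLAGS.batch_size,  # assumed flag
          epochs=xFLAGS.epochs,          # assumed flag
          callbacks=callbacks)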