Example #1
0
    def __init__(self, units, dt, num_grids, **kwargs):
        """Create the internal dense network and the node function built on it.

        Args:
            units: Width of the last dense layer; also used as the feature
                size when the internal model is built.
            dt: Value passed to `RKSolver`; multiplied by `num_grids` to
                obtain `self.tN`.
            num_grids: Multiplier applied to `dt` when computing `self.tN`.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(**kwargs)
        self.dt = dt
        self.num_grids = num_grids

        # End value: zero plus `num_grids` increments of `dt`.
        start_time = tf.constant(0.)
        self.tN = start_time + num_grids * dt

        hidden_layer = tf.keras.layers.Dense(
            128,
            activation='relu',
            kernel_initializer=GlorotUniform(1e-1))
        output_layer = tf.keras.layers.Dense(
            units,
            activation='relu',
            kernel_initializer=GlorotUniform(1e-1))
        self._model = tf.keras.Sequential([hidden_layer, output_layer])
        self._model.build([None, units])

        @tf.function
        def fn(t, x):
            # `t` is part of the signature but is not used here.
            z = self._model(x)
            with tf.GradientTape() as tape:
                tape.watch(x)
                r = normalize(x, axis=-1)
            # The model output `z` is supplied as the output gradients, so
            # the result is the gradient of `r` w.r.t. `x` weighted by `z`.
            return tape.gradient(r, x, output_gradients=z)

        self._node_fn = get_node_function(
            RKSolver(self.dt), tf.constant(0.), fn)
Example #2
0
    def __init__(self,
                 t0,
                 t1,
                 solver=RKSolver(1e-1),
                 validate_args=False,
                 name='cnf'):
        """Store the time endpoints and solver, then build the node functions.

        Args:
            t0: Converted to a tensor and stored as `self.t0`; used as the
                second argument of `get_node_function` for the forward and
                forward-log-prob functions.
            t1: Converted to a tensor and stored as `self.t1`; used as the
                second argument of `get_node_function` for the inverse
                function.
            solver: Solver object passed to every `get_node_function` call.
                NOTE(review): the default `RKSolver(1e-1)` is evaluated once
                when the method is defined, so all instances relying on the
                default share that single solver object.
            validate_args: Forwarded to the parent constructor.
            name: Forwarded to the parent constructor.
        """
        super().__init__(forward_min_event_ndims=1,
                         validate_args=validate_args,
                         name=name)

        self.t0 = tf.convert_to_tensor(t0)
        self.t1 = tf.convert_to_tensor(t1)
        self.solver = solver

        def build_node_fn(start_time, dynamics):
            # Every node function shares the same solver; only the time
            # argument and the dynamics callable differ.
            return get_node_function(self.solver, start_time, dynamics)

        self._forward_fn = build_node_fn(self.t0, self._dynamics)
        self._inverse_fn = build_node_fn(self.t1, self._dynamics)
        self._forward_log_prob_fn = build_node_fn(self.t0,
                                                  self._log_prob_dynamics)
Example #3
0
    def __init__(self, units, dt, num_grids, **kwargs):
        """Create the dense network, its energy wrapper and the node function.

        Args:
            units: Width of the last dense layer; also used as the feature
                size when the internal model is built.
            dt: Value passed to `RKSolver`; multiplied by `num_grids` to
                obtain `self.tN`.
            num_grids: Multiplier applied to `dt` when computing `self.tN`.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(**kwargs)
        self.dt = dt
        self.num_grids = num_grids

        # End value: zero (in DTYPE) plus `num_grids` increments of `dt`.
        origin = tf.constant(0., dtype=DTYPE)
        self.tN = origin + num_grids * dt

        hidden_layer = tf.keras.layers.Dense(
            256, activation='relu', dtype=DTYPE)
        output_layer = tf.keras.layers.Dense(units, dtype=DTYPE)
        self._model = tf.keras.Sequential([hidden_layer, output_layer])
        self._model.build([None, units])

        # The raw field ignores its first argument and applies the model to
        # the second one.
        self._raw_pvf = lambda _, x: self._model(x)
        self._energy = Energy(identity, self._raw_pvf)
        self._pvf = energy_based(identity, self._energy)

        solver = RKSolver(self.dt, dtype=DTYPE)
        start_time = tf.constant(0., dtype=DTYPE)
        self._node_fn = get_node_function(solver, start_time, self._pvf)
Example #4
0
    def __init__(self, filters, kernel_size, t, **kwargs):
        """Create the convolution, the solver and the node function.

        Args:
            filters: Stored on the instance and passed to `Conv2D`.
            kernel_size: Stored on the instance and passed to `Conv2D`.
            t: Converted to a tensor and stored as `self.t`.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.solver = RKSolver(0.1, dtype=tf.float32)
        self.t = tf.convert_to_tensor(t)

        self.convolve = tf.keras.layers.Conv2D(
            filters,
            kernel_size,
            padding='same',
            activation='relu',
            kernel_initializer=GlorotUniform(0.1),
            dtype=tf.float32)

        @tf.function
        def pvf(t, x):
            # `t` is part of the signature but is not used here.
            z = self.convolve(x)
            with tf.GradientTape() as tape:
                tape.watch(x)
                r = normalize(x, axis=[-3, -2])
            # The convolution output `z` is supplied as the output
            # gradients, so the result is the gradient of `r` w.r.t. `x`
            # weighted by `z`.
            return tape.gradient(r, x, output_gradients=z)

        start_time = tf.convert_to_tensor(0.)
        self._pvf = pvf
        self._node_fn = get_node_function(self.solver, start_time, pvf)
Example #5
0

# Two stacked dense layers: 50 units with tanh, then 2 units with no
# activation argument.
model = tf.keras.Sequential(layers=[
    tf.keras.layers.Dense(units=50, activation="tanh"),
    tf.keras.layers.Dense(units=2),
])
# Build for inputs with 2 features before collecting the trainable variables.
model.build(input_shape=[None, 2])
var_list = model.trainable_variables


@tf.function
def network(t, x):
    """Feed the cube of `x` (`x**3`) through `model` and return the result.

    `t` is part of the signature but is not used.
    """
    return model(x ** 3)


# Wrap `network` with `get_node_function`, passing an `RKSolver` built from
# `dt` together with `t0` (neither `dt` nor `t0` is defined in this excerpt).
# NOTE(review): presumably the result integrates `network` as an ODE -- confirm
# against `get_node_function`.
node_network = get_node_function(RKSolver(dt), t0, network)


def get_batch():
    """Sample trajectory fragments and return their first and last points.

    Draws `batch_size` distinct start indices (without replacement) from
    `range(data_size - batch_time - 1)`, pairs each with the index
    `batch_time` steps later, and looks both up in `true_y`.

    Returns:
        A pair of constant tensors: the points at the start indices and the
        points at the corresponding end indices.
    """
    candidate_starts = np.arange(data_size - batch_time - 1, dtype=np.int64)
    start_idx = np.random.choice(candidate_starts, batch_size, replace=False)
    end_idx = start_idx + batch_time

    # Original note: initial points have shape (batch_size, 2) -- TODO confirm
    # against the definition of `true_y`.
    initial_points = true_y[start_idx]
    final_points = true_y[end_idx]
    return tf.constant(initial_points), tf.constant(final_points)