Code example #1
def ConvBlock(kernel_size,
              filters,
              strides=(2, 2),
              batchnorm=True,
              parameterization='standard',
              nonlin=Relu):
    ks = kernel_size
    filters1, filters2, filters3 = filters
    if parameterization == 'standard':

        def MyConv(*args, **kwargs):
            return Conv(*args, **kwargs)
    elif parameterization == 'ntk':

        def MyConv(*args, **kwargs):
            return stax.Conv(*args, **kwargs)[:2]

    if batchnorm:
        Main = jax_stax.serial(MyConv(filters1, (1, 1), strides), BatchNorm(),
                               nonlin,
                               MyConv(filters2, (ks, ks), padding='SAME'),
                               BatchNorm(), nonlin, MyConv(filters3, (1, 1)),
                               BatchNorm())
        Shortcut = jax_stax.serial(MyConv(filters3, (1, 1), strides),
                                   BatchNorm())
    else:
        Main = jax_stax.serial(MyConv(filters1, (1, 1), strides), nonlin,
                               MyConv(filters2, (ks, ks), padding='SAME'),
                               nonlin, MyConv(filters3, (1, 1)))
        Shortcut = jax_stax.serial(MyConv(filters3, (1, 1), strides))
    return jax_stax.serial(FanOut(2), jax_stax.parallel(Main, Shortcut),
                           FanInSum, nonlin)
Code example #2
def IdentityBlock(kernel_size,
                  filters,
                  batchnorm=True,
                  parameterization='standard',
                  nonlin=Relu):
    ks = kernel_size
    filters1, filters2 = filters
    if parameterization == 'standard':

        def MyConv(*args, **kwargs):
            return Conv(*args, **kwargs)
    elif parameterization == 'ntk':

        def MyConv(*args, **kwargs):
            return stax.Conv(*args, **kwargs)[:2]

    def make_main(input_shape):
        # the number of output channels depends on the number of input channels
        if batchnorm:
            return jax_stax.serial(MyConv(filters1, (1, 1)), BatchNorm(),
                                   nonlin,
                                   MyConv(filters2, (ks, ks), padding='SAME'),
                                   BatchNorm(), nonlin,
                                   MyConv(input_shape[3], (1, 1)), BatchNorm())
        else:
            return jax_stax.serial(MyConv(filters1, (1, 1)), nonlin,
                                   MyConv(filters2, (ks, ks), padding='SAME'),
                                   nonlin, MyConv(input_shape[3], (1, 1)))

    Main = jax_stax.shape_dependent(make_main)
    return jax_stax.serial(FanOut(2), jax_stax.parallel(Main, Identity),
                           FanInSum, nonlin)
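
A minimal usage sketch for the IdentityBlock above, assuming jax_stax is jax.example_libraries.stax and that Conv, BatchNorm, Relu, FanOut, FanInSum and Identity are imported from it: shape_dependent defers building the main branch until the input shape is known, so the closing 1x1 convolution always matches the input channel count and the FanInSum with the identity shortcut is well defined.

import jax
import jax.numpy as jnp

# Build the block with the 'standard' parameterization so only plain stax is involved.
init_fn, apply_fn = IdentityBlock(3, (64, 64), parameterization='standard')
out_shape, params = init_fn(jax.random.PRNGKey(0), (1, 8, 8, 256))  # NHWC input
y = apply_fn(params, jnp.ones((1, 8, 8, 256)))
# out_shape == (1, 8, 8, 256): the main branch ends in Conv(input_shape[3], (1, 1)).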
Code example #3
def WideResnetBlock(channels,
                    strides=(1, 1),
                    channel_mismatch=False,
                    nonlin=Relu,
                    parameterization='standard',
                    order=None):
    Main = jax_stax.serial(
        nonlin,
        MyConv(channels, (3, 3),
               strides,
               padding='SAME',
               parameterization=parameterization,
               order=order), nonlin,
        MyConv(channels, (3, 3),
               padding='SAME',
               parameterization=parameterization,
               order=order))
    Shortcut = Identity if not channel_mismatch else MyConv(
        channels, (3, 3),
        strides,
        padding='SAME',
        parameterization=parameterization,
        order=order)
    return jax_stax.serial(FanOut(2), jax_stax.parallel(Main, Shortcut),
                           FanInSum)
Code example #4
File: resnet.py Project: worldblue0214/bayesian-sde
def CifarBasicBlockv2(planes,
                      stride=1,
                      option="A",
                      normalization_method=None,
                      use_fixup=False,
                      num_layers=None,
                      w_init=None,
                      actfn=stax.Relu):
    assert not use_fixup, "nah"
    Main = stax.serial(
        maybe_use_normalization(normalization_method),
        actfn,
        Conv(planes, (3, 3),
             strides=(stride, stride),
             padding="SAME",
             W_init=w_init,
             bias=False),
        maybe_use_normalization(normalization_method),
        actfn,
        Conv(planes, (3, 3), padding="SAME", W_init=w_init, bias=False),
    )
    Shortcut = Identity
    if stride > 1:
        if option == "A":
            # For CIFAR10, the ResNet paper uses option A.
            Shortcut = LambdaLayer(_shortcut_pad)
        elif option == "B":
            Shortcut = Conv(planes, (1, 1),
                            strides=(stride, stride),
                            W_init=w_init,
                            bias=False)
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum)
Code example #5
def ConvBlock(kernel_size, filters, strides=(2, 2)):
    ks = kernel_size
    filters1, filters2, filters3 = filters
    Main = stax.serial(Conv(filters1, (1, 1), strides), BatchNorm(), Relu,
                       Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(),
                       Relu, Conv(filters3, (1, 1)), BatchNorm())
    Shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm())
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum,
                       Relu)
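
A usage sketch for the ConvBlock above, assuming the standard jax.example_libraries.stax imports (stax, Conv, BatchNorm, Relu, FanOut, FanInSum): the constructor returns an ordinary stax (init, apply) pair, and with the default strides=(2, 2) both branches downsample, so the summed output has half the spatial size and filters3 channels.

import jax
import jax.numpy as jnp

init_fn, apply_fn = ConvBlock(3, (64, 64, 256))                      # a stax (init, apply) pair
out_shape, params = init_fn(jax.random.PRNGKey(0), (1, 32, 32, 3))   # NHWC input
y = apply_fn(params, jnp.ones((1, 32, 32, 3)))
# y.shape == (1, 16, 16, 256): Main and Shortcut both apply strides=(2, 2).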
Code example #6
def ResidualBlock(out_channels, kernel_size, stride, padding, input_format):
    double_conv = stax.serial(
        GeneralConv(input_format, out_channels, kernel_size, stride, padding),
        Elu,
        GeneralConv(input_format, out_channels, kernel_size, stride, padding),
    )
    return Module(
        *stax.serial(FanOut(2), stax.parallel(double_conv, Identity), FanInSum)
    )
Code example #7
def PolicyNetwork():
    """Policy network for the experiments in:
    https://arxiv.org/abs/2102.12425"""
    return serial(
        helx.nn.rnn.LSTM(256),
        Dense(256),
        Relu,
        FanOut(2),
        parallel(Dense(1), Dense(1)),
    )
Code example #8
def IdentityBlock(kernel_size, filters):
  ks = kernel_size
  filters1, filters2 = filters
  def make_main(input_shape):
    # the number of output channels depends on the number of input channels
    return stax.serial(
        Conv(filters1, (1, 1)), BatchNorm(), Relu,
        Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
        Conv(input_shape[3], (1, 1)), BatchNorm())
  Main = stax.shape_dependent(make_main)
  return stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu)
Code example #9
def Lpg(hparams):
    phi = serial(Dense(16), Dense(1))
    return serial(
        # FanOut(6),
        parallel(Identity, Identity, Identity, Identity, phi, phi),
        FanInConcat(),
        LSTMCell(hparams.hidden_size)[0:2],
        DiscardHidden(),
        Relu,
        FanOut(2),
        parallel(phi, phi),
    )
Code example #10
def SyntheticReturn(features_network):
    """Synthetic return module as described in:
    https://arxiv.org/abs/2102.12425,
    Raposo, D., Synthetic Returns for Long-Term Credit Assignment, 2021."""
    #  sigmoid gate
    g = lambda: serial(Dense(256), Relu, Dense(1), Relu, Dense(1), Sigmoid)
    #  state utility contribution
    c = lambda: serial(Dense(256), Relu, Dense(256), Relu, Dense(1))
    #  state utility baseline
    b = lambda: serial(Dense(256), Relu, Dense(256), Relu, Dense(1))
    return serial(features_network, Flatten, FanOut(3),
                  parallel(g(), c(), b()))
Code example #11
def ResNet(
    hidden_channels, out_channels, kernel_size, strides, padding, depth, input_format
):
    residual = stax.serial(
        GeneralConv(input_format, hidden_channels, kernel_size, strides, padding),
        *[
            ResidualBlock(hidden_channels, kernel_size, strides, padding, input_format)
            for _ in range(depth)
        ],
        GeneralConv(input_format, out_channels, kernel_size, strides, padding)
    )
    return Module(
        *stax.serial(FanOut(2), stax.parallel(residual, Identity), AddLastItem(1))
    )
Code example #12
File: resnet.py Project: ls-da3m0ns/JAXified_ML
def convBlock(ks, filters, stride=(1, 1)):
    Main = stax.serial(Conv(filters[0], (1, 1), strides=(1, 1)), BatchNorm(),
                       Relu, Conv(filters[1], (ks, ks), strides=stride),
                       BatchNorm(), Relu,
                       Conv(filters[2], (1, 1),
                            strides=(1, 1)), BatchNorm(), Relu)

    Shortcut = stax.serial(
        Conv(filters[2], (1, 1), strides=stride),  # match Main's filters[2] output channels for FanInSum
        BatchNorm(),
    )

    fullInternal = stax.parallel(Main, Shortcut)

    return stax.serial(FanOut(2), fullInternal, FanInSum, Relu)
Code example #13
def RevNet(
    hidden_channels, out_channels, kernel_size, strides, padding, depth, input_format
):
    residual = stax.serial(  #
        Split(input_format[0].lower().index("c")),
        GeneralConv(input_format, hidden_channels, kernel_size, strides, padding),
        *[
            ReversibleBlock(hidden_channels, kernel_size, input_format)
            for _ in range(depth)
        ],
        GeneralConv(input_format, out_channels, kernel_size, strides, padding)
    )
    return Module(
        *stax.serial(FanOut(2), stax.parallel(residual, Identity), AddLastItem(1))
    )
Code example #14
File: resnet.py Project: ls-da3m0ns/JAXified_ML
def identityBlock(ks, filters):
    def construct_main(inp_shape):
        return stax.serial(
            Conv(filters[0], (1, 1), strides=(1, 1)),
            BatchNorm(),
            Relu,
            Conv(filters[1], (ks, ks), padding="SAME"),
            BatchNorm(),
            Relu,
            Conv(inp_shape[3], (1, 1)),
            BatchNorm(),
        )

    Main = stax.shape_dependent(construct_main)
    return stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum,
                       Relu)
Code example #15
File: netjit.py Project: rl-team-10/rl_project
    def q_network(self):
        #no regression !
        if self.dueling:
            init, apply = stax.serial(
                elementwise(lambda x: x / 10000.0),
                stax.serial(Dense(128), Relu, Dense(64), Relu),  #base layers
                FanOut(2),
                stax.parallel(
                    stax.serial(Dense(32), Relu, Dense(1)),  #state value
                    stax.serial(Dense(32), Relu,
                                Dense(self.num_actions)))  #advantage func
            )

        else:
            init, apply = stax.serial(elementwise(lambda x: x/10000.0), Dense(64), Relu, \
                                      Dense(32), Relu, Dense(self.num_actions))

        return init, apply
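
For reference, a self-contained sketch (not taken from the project above) of how the dueling head built with FanOut(2) and stax.parallel is consumed: apply returns a (value, advantage) tuple, which is combined into Q-values outside the network.

import jax
import jax.numpy as jnp
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu, FanOut, elementwise

num_actions = 4
init, apply = stax.serial(
    elementwise(lambda x: x / 10000.0),
    stax.serial(Dense(128), Relu, Dense(64), Relu),        # shared base layers
    FanOut(2),
    stax.parallel(
        stax.serial(Dense(32), Relu, Dense(1)),            # state value V(s)
        stax.serial(Dense(32), Relu, Dense(num_actions)),  # advantages A(s, a)
    ),
)
_, params = init(jax.random.PRNGKey(0), (1, 8))            # batch of 8-dimensional observations
value, advantage = apply(params, jnp.ones((1, 8)))         # stax.parallel returns a tuple
q_values = value + advantage - advantage.mean(axis=1, keepdims=True)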
Code example #16
File: resnet.py Project: worldblue0214/bayesian-sde
def CifarBasicBlock(planes,
                    stride=1,
                    option="A",
                    normalization_method=None,
                    use_fixup=False,
                    num_layers=None,
                    w_init=None,
                    actfn=stax.Relu):
    Main = stax.serial(
        FixupBias() if use_fixup else Identity,
        Conv(planes, (3, 3),
             strides=(stride, stride),
             padding="SAME",
             W_init=fixup_init(num_layers) if use_fixup else w_init,
             bias=False),
        maybe_use_normalization(normalization_method),
        FixupBias() if use_fixup else Identity,
        actfn,
        FixupBias() if use_fixup else Identity,
        Conv(planes, (3, 3),
             padding="SAME",
             bias=False,
             W_init=zeros if use_fixup else w_init),
        maybe_use_normalization(normalization_method),
        FixupScale() if use_fixup else Identity,
        FixupBias() if use_fixup else Identity,
    )
    Shortcut = Identity
    if stride > 1:
        if option == "A":
            # For CIFAR10, the ResNet paper uses option A.
            Shortcut = stax.serial(
                #  FixupBias() if use_fixup else Identity,
                LambdaLayer(_shortcut_pad))
        elif option == "B":
            Shortcut = stax.serial(
                FixupBias() if use_fixup else Identity,
                Conv(planes, (1, 1),
                     strides=(stride, stride),
                     W_init=w_init,
                     bias=False),
                maybe_use_normalization(normalization_method))
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum,
                       actfn)
Code example #17
File: resnet.py Project: worldblue0214/bayesian-sde
def BasicBlock(planes,
               stride=1,
               downsample=None,
               base_width=64,
               norm_layer=Identity,
               actfn=stax.Relu):
    if base_width != 64:
        raise ValueError("BasicBlock only supports base_width=64")
    Main = stax.serial(
        Conv(planes, (3, 3),
             strides=(stride, stride),
             padding="SAME",
             bias=False),
        norm_layer,
        actfn,
        Conv(planes, (3, 3), padding="SAME", bias=False),
        norm_layer,
    )
    Shortcut = downsample if downsample is not None else Identity
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum,
                       actfn)
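
BasicBlock above sums Main and Shortcut with FanInSum, so when stride > 1 the caller has to pass a downsample layer whose output shape matches Main's; a hypothetical example using the same Conv helper as the block itself:

# Hypothetical stride-2 usage: a strided 1x1 conv as the downsample shortcut, so its
# (H/2, W/2, planes) output lines up with Main's and FanInSum is well defined.
downsample = Conv(128, (1, 1), strides=(2, 2), bias=False)
init_fn, apply_fn = BasicBlock(128, stride=2, downsample=downsample)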
Code example #18
File: DL_Agent.py Project: dmitsov/RL-Dual-Networks
def constructDuelNetwork(n_actions, seed, input_shape):
    advantage_stream = stax.serial(Dense(512), Relu, Dense(n_actions))

    state_function_stream = stax.serial(Dense(512), Relu, Dense(1))
    dueling_architecture = stax.serial(
        elementwise(lambda x: x / 255.0),
        GeneralConv(dim_nums, 32, (8, 8), strides=(4, 4)),
        Relu,
        GeneralConv(dim_nums, 64, (4, 4), strides=(2, 2)),
        Relu,
        GeneralConv(dim_nums, 64, (3, 3), strides=(1, 1)),
        Relu,
        Flatten,
        FanOut(2),
        parallel(advantage_stream, state_function_stream),
    )

    def duelingNetworkMapping(inputs):
        advantage_values = inputs[0]
        state_values = inputs[1]
        advantage_sums = jnp.sum(advantage_values, axis=1)

        advantage_sums = advantage_sums / float(n_actions)
        advantage_sums = advantage_sums.reshape(-1, 1)

        Q_values = state_values + (advantage_values - advantage_sums)

        return Q_values

    duelArchitectureMapping = jit(duelingNetworkMapping)

    ##### Create deep neural net
    model = DDQN(n_actions,
                 input_shape,
                 adam_params,
                 architecture=dueling_architecture,
                 seed=seed,
                 mappingFunction=duelArchitectureMapping)

    return model
Code example #19
def create_pi_net(
    obs_dim: int, action_dim: int, rngkey=jax.random.PRNGKey(0)
) -> TT.Tuple[RT.NNParams, RT.NNParamsFn]:
    pi_init, pi_fn = serial(
        Dense(64, he_normal(), zeros),
        Relu,
        FanOut(2),
        parallel(
            serial(
                Dense(64, he_normal(), zeros),
                Relu,
                Dense(action_dim, he_normal(), zeros),
            ),
            serial(
                Dense(64, he_normal(), zeros),
                Relu,
                Dense(action_dim, he_normal(), zeros),
            ),
        ),
    )
    output_shape, pi_params = pi_init(rngkey, (1, obs_dim))
    pi_fn = jit(pi_fn)
    return pi_params, pi_fn
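
A brief usage note for create_pi_net above, assuming serial, parallel, Dense, Relu and FanOut come from jax.example_libraries.stax and he_normal, zeros from jax.nn.initializers: the jitted pi_fn maps a (batch, obs_dim) array to a pair of (batch, action_dim) outputs, one per parallel branch.

import jax.numpy as jnp

pi_params, pi_fn = create_pi_net(obs_dim=8, action_dim=2)
head_a, head_b = pi_fn(pi_params, jnp.ones((1, 8)))   # two (1, 2) arrays from FanOut(2)/parallel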
Code example #20
File: resnet.py Project: worldblue0214/bayesian-sde
def BottleneckBlock(planes,
                    stride=1,
                    downsample=None,
                    base_width=64,
                    norm_layer=Identity,
                    actfn=stax.Relu):
    width = int(planes * (base_width / 64.))
    Main = stax.serial(
        Conv(width, (1, 1), bias=False),
        norm_layer,
        actfn,
        Conv(width, (3, 3),
             strides=(stride, stride),
             padding="SAME",
             bias=False),
        norm_layer,
        actfn,
        Conv(planes * 4, (1, 1), bias=False),
        norm_layer,
    )
    Shortcut = downsample if downsample is not None else Identity
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum,
                       actfn)
Code example #21
File: mnist_vae.py Project: zhouj/jax
  code_rng, img_rng = random.split(rng)
  logits = decode(dec_params, random.normal(code_rng, (nrow * ncol, 10)))
  sampled_images = random.bernoulli(img_rng, np.logaddexp(0., logits))
  return image_grid(nrow, ncol, sampled_images, (28, 28))

def image_grid(nrow, ncol, imagevecs, imshape):
  """Reshape a stack of image vectors into an image grid for plotting."""
  images = iter(imagevecs.reshape((-1,) + imshape))
  return np.vstack([np.hstack([next(images).T for _ in range(ncol)][::-1])
                    for _ in range(nrow)]).T


encoder_init, encode = stax.serial(
    Dense(512), Relu,
    Dense(512), Relu,
    FanOut(2),
    stax.parallel(Dense(10), stax.serial(Dense(10), Softplus)),
)

decoder_init, decode = stax.serial(
    Dense(512), Relu,
    Dense(512), Relu,
    Dense(28 * 28),
)
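
A short usage sketch for the encoder defined above (random here is jax.random, as elsewhere in this snippet): because of FanOut(2) and stax.parallel, encode returns the two latent-distribution parameters as a tuple, the mean head and the Softplus-constrained second head.

enc_rng, data_rng = random.split(random.PRNGKey(0))
_, enc_params = encoder_init(enc_rng, (32, 28 * 28))
mu, sigmasq = encode(enc_params, random.normal(data_rng, (32, 28 * 28)))  # each of shape (32, 10)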


if __name__ == "__main__":
  step_size = 0.001
  num_epochs = 100
  batch_size = 32
  nrow, ncol = 10, 10  # sampled image grid size