def ConvBlock(kernel_size, filters, strides=(2, 2), batchnorm=True,
              parameterization='standard', nonlin=Relu):
    ks = kernel_size
    filters1, filters2, filters3 = filters
    if parameterization == 'standard':
        def MyConv(*args, **kwargs):
            return Conv(*args, **kwargs)
    elif parameterization == 'ntk':
        # 'ntk' layers return an (init_fn, apply_fn, kernel_fn) triple;
        # keep only the (init_fn, apply_fn) pair that stax combinators expect.
        def MyConv(*args, **kwargs):
            return stax.Conv(*args, **kwargs)[:2]
    if batchnorm:
        Main = jax_stax.serial(
            MyConv(filters1, (1, 1), strides), BatchNorm(), nonlin,
            MyConv(filters2, (ks, ks), padding='SAME'), BatchNorm(), nonlin,
            MyConv(filters3, (1, 1)), BatchNorm())
        Shortcut = jax_stax.serial(MyConv(filters3, (1, 1), strides), BatchNorm())
    else:
        Main = jax_stax.serial(
            MyConv(filters1, (1, 1), strides), nonlin,
            MyConv(filters2, (ks, ks), padding='SAME'), nonlin,
            MyConv(filters3, (1, 1)))
        Shortcut = jax_stax.serial(MyConv(filters3, (1, 1), strides))
    return jax_stax.serial(FanOut(2), jax_stax.parallel(Main, Shortcut),
                           FanInSum, nonlin)
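# Usage sketch for the block above — a minimal example, assuming `jax_stax`
# is `jax.example_libraries.stax` and that `Conv`, `BatchNorm`, `Relu`,
# `FanOut`, and `FanInSum` come from it. Every stax combinator, including
# this ConvBlock, returns an (init_fn, apply_fn) pair:
import jax
import jax.numpy as jnp
from jax.example_libraries import stax as jax_stax
from jax.example_libraries.stax import Conv, BatchNorm, Relu, FanOut, FanInSum

init_fn, apply_fn = ConvBlock(3, (64, 64, 256))
# stax Conv defaults to NHWC inputs; init_fn infers all parameter shapes.
out_shape, params = init_fn(jax.random.PRNGKey(0), (1, 32, 32, 3))
y = apply_fn(params, jnp.ones((1, 32, 32, 3)))  # y.shape == (1, 16, 16, 256)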
def IdentityBlock(kernel_size, filters, batchnorm=True,
                  parameterization='standard', nonlin=Relu):
    ks = kernel_size
    filters1, filters2 = filters
    if parameterization == 'standard':
        def MyConv(*args, **kwargs):
            return Conv(*args, **kwargs)
    elif parameterization == 'ntk':
        def MyConv(*args, **kwargs):
            return stax.Conv(*args, **kwargs)[:2]

    def make_main(input_shape):
        # the number of output channels depends on the number of input channels
        if batchnorm:
            return jax_stax.serial(
                MyConv(filters1, (1, 1)), BatchNorm(), nonlin,
                MyConv(filters2, (ks, ks), padding='SAME'), BatchNorm(), nonlin,
                MyConv(input_shape[3], (1, 1)), BatchNorm())
        else:
            return jax_stax.serial(
                MyConv(filters1, (1, 1)), nonlin,
                MyConv(filters2, (ks, ks), padding='SAME'), nonlin,
                MyConv(input_shape[3], (1, 1)))

    Main = jax_stax.shape_dependent(make_main)
    return jax_stax.serial(FanOut(2), jax_stax.parallel(Main, Identity),
                           FanInSum, nonlin)
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False,
                    nonlin=Relu, parameterization='standard', order=None):
    Main = jax_stax.serial(
        nonlin,
        MyConv(channels, (3, 3), strides, padding='SAME',
               parameterization=parameterization, order=order),
        nonlin,
        MyConv(channels, (3, 3), padding='SAME',
               parameterization=parameterization, order=order))
    Shortcut = Identity if not channel_mismatch else MyConv(
        channels, (3, 3), strides, padding='SAME',
        parameterization=parameterization, order=order)
    return jax_stax.serial(FanOut(2), jax_stax.parallel(Main, Shortcut), FanInSum)
def CifarBasicBlockv2(planes, stride=1, option="A", normalization_method=None,
                      use_fixup=False, num_layers=None, w_init=None,
                      actfn=stax.Relu):
    assert not use_fixup, "Fixup is not supported in this block"
    Main = stax.serial(
        maybe_use_normalization(normalization_method),
        actfn,
        Conv(planes, (3, 3), strides=(stride, stride), padding="SAME",
             W_init=w_init, bias=False),
        maybe_use_normalization(normalization_method),
        actfn,
        Conv(planes, (3, 3), padding="SAME", W_init=w_init, bias=False),
    )
    Shortcut = Identity
    if stride > 1:
        if option == "A":
            # For CIFAR-10, the ResNet paper uses option A.
            Shortcut = LambdaLayer(_shortcut_pad)
        elif option == "B":
            Shortcut = Conv(planes, (1, 1), strides=(stride, stride),
                            W_init=w_init, bias=False)
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum)
def ConvBlock(kernel_size, filters, strides=(2, 2)):
    ks = kernel_size
    filters1, filters2, filters3 = filters
    Main = stax.serial(
        Conv(filters1, (1, 1), strides), BatchNorm(), Relu,
        Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
        Conv(filters3, (1, 1)), BatchNorm())
    Shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm())
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu)
def ResidualBlock(out_channels, kernel_size, stride, padding, input_format):
    double_conv = stax.serial(
        GeneralConv(input_format, out_channels, kernel_size, stride, padding),
        Elu,
        GeneralConv(input_format, out_channels, kernel_size, stride, padding),
    )
    return Module(
        *stax.serial(FanOut(2), stax.parallel(double_conv, Identity), FanInSum)
    )
def PolicyNetwork():
    """Policy network for the experiments in:
    https://arxiv.org/abs/2102.12425"""
    return serial(
        helx.nn.rnn.LSTM(256),
        Dense(256),
        Relu,
        FanOut(2),
        parallel(Dense(1), Dense(1)),
    )
def IdentityBlock(kernel_size, filters):
    ks = kernel_size
    filters1, filters2 = filters

    def make_main(input_shape):
        # the number of output channels depends on the number of input channels
        return stax.serial(
            Conv(filters1, (1, 1)), BatchNorm(), Relu,
            Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
            Conv(input_shape[3], (1, 1)), BatchNorm())

    Main = stax.shape_dependent(make_main)
    return stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu)
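# Blocks like the ConvBlock/IdentityBlock pair above compose into a full
# network with stax.serial in the usual way. A minimal sketch; the layer
# widths and pooling size are illustrative, not taken from any source above:
from jax.example_libraries import stax
from jax.example_libraries.stax import (Conv, BatchNorm, Relu, AvgPool,
                                        Flatten, Dense, LogSoftmax)

def MiniResNet(num_classes=10):
    return stax.serial(
        Conv(64, (3, 3), padding='SAME'), BatchNorm(), Relu,
        ConvBlock(3, (64, 64, 256), strides=(1, 1)),
        IdentityBlock(3, (64, 64)),
        AvgPool((8, 8)),
        Flatten,
        Dense(num_classes),
        LogSoftmax,
    )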
def Lpg(hparams):
    phi = serial(Dense(16), Dense(1))
    return serial(
        # FanOut(6),
        parallel(Identity, Identity, Identity, Identity, phi, phi),
        FanInConcat(),
        LSTMCell(hparams.hidden_size)[0:2],
        DiscardHidden(),
        Relu,
        FanOut(2),
        parallel(phi, phi),
    )
def SyntheticReturn(features_network):
    """Synthetic return module as described in:
    Raposo, D., Synthetic Returns for Long-Term Credit Assignment, 2021.
    https://arxiv.org/abs/2102.12425"""
    # sigmoid gate
    g = lambda: serial(Dense(256), Relu, Dense(1), Relu, Dense(1), Sigmoid)
    # state utility contribution
    c = lambda: serial(Dense(256), Relu, Dense(256), Relu, Dense(1))
    # state utility baseline
    b = lambda: serial(Dense(256), Relu, Dense(256), Relu, Dense(1))
    return serial(features_network, Flatten, FanOut(3), parallel(g(), c(), b()))
def ResNet(hidden_channels, out_channels, kernel_size, strides, padding,
           depth, input_format):
    residual = stax.serial(
        GeneralConv(input_format, hidden_channels, kernel_size, strides, padding),
        *[
            ResidualBlock(hidden_channels, kernel_size, strides, padding,
                          input_format)
            for _ in range(depth)
        ],
        GeneralConv(input_format, out_channels, kernel_size, strides, padding)
    )
    return Module(
        *stax.serial(FanOut(2), stax.parallel(residual, Identity), AddLastItem(1))
    )
def convBlock(ks, filters, stride=(1, 1)):
    Main = stax.serial(
        Conv(filters[0], (1, 1), strides=(1, 1)), BatchNorm(), Relu,
        Conv(filters[1], (ks, ks), strides=stride), BatchNorm(), Relu,
        Conv(filters[2], (1, 1), strides=(1, 1)), BatchNorm(), Relu)
    Shortcut = stax.serial(
        Conv(filters[3], (1, 1), strides=stride),
        BatchNorm(),
    )
    fullInternal = stax.parallel(Main, Shortcut)
    return stax.serial(FanOut(2), fullInternal, FanInSum, Relu)
def RevNet(hidden_channels, out_channels, kernel_size, strides, padding,
           depth, input_format):
    residual = stax.serial(
        # Split(input_format[0].lower().index("c")),
        GeneralConv(input_format, hidden_channels, kernel_size, strides, padding),
        *[
            ReversibleBlock(hidden_channels, kernel_size, input_format)
            for _ in range(depth)
        ],
        GeneralConv(input_format, out_channels, kernel_size, strides, padding)
    )
    return Module(
        *stax.serial(FanOut(2), stax.parallel(residual, Identity), AddLastItem(1))
    )
def identityBlock(ks, filters):
    def construct_main(input_shape):
        # the last conv must emit input_shape[3] channels so the identity
        # shortcut can be summed with the main branch
        return stax.serial(
            Conv(filters[0], (1, 1), strides=(1, 1)), BatchNorm(), Relu,
            Conv(filters[1], (ks, ks), padding="SAME"), BatchNorm(), Relu,
            Conv(input_shape[3], (1, 1)), BatchNorm(),
        )

    Main = stax.shape_dependent(construct_main)
    return stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu)
def q_network(self):
    # no regression!
    if self.dueling:
        init, apply = stax.serial(
            elementwise(lambda x: x / 10000.0),
            stax.serial(Dense(128), Relu, Dense(64), Relu),  # base layers
            FanOut(2),
            stax.parallel(
                stax.serial(Dense(32), Relu, Dense(1)),                  # state value
                stax.serial(Dense(32), Relu, Dense(self.num_actions))))  # advantage func
    else:
        init, apply = stax.serial(
            elementwise(lambda x: x / 10000.0),
            Dense(64), Relu,
            Dense(32), Relu,
            Dense(self.num_actions))
    return init, apply
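# The dueling branch above returns a (state_value, advantages) tuple; the two
# streams still have to be combined into Q-values outside the stax graph.
# A minimal sketch of the standard aggregation (the function name is
# illustrative, not from the source above):
import jax.numpy as jnp

def combine_dueling(state_value, advantages):
    # Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)
    return state_value + advantages - jnp.mean(advantages, axis=-1, keepdims=True)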
def CifarBasicBlock(planes, stride=1, option="A", normalization_method=None,
                    use_fixup=False, num_layers=None, w_init=None,
                    actfn=stax.Relu):
    Main = stax.serial(
        FixupBias() if use_fixup else Identity,
        Conv(planes, (3, 3), strides=(stride, stride), padding="SAME",
             W_init=fixup_init(num_layers) if use_fixup else w_init,
             bias=False),
        maybe_use_normalization(normalization_method),
        FixupBias() if use_fixup else Identity,
        actfn,
        FixupBias() if use_fixup else Identity,
        Conv(planes, (3, 3), padding="SAME", bias=False,
             W_init=zeros if use_fixup else w_init),
        maybe_use_normalization(normalization_method),
        FixupScale() if use_fixup else Identity,
        FixupBias() if use_fixup else Identity,
    )
    Shortcut = Identity
    if stride > 1:
        if option == "A":
            # For CIFAR-10, the ResNet paper uses option A.
            Shortcut = stax.serial(
                # FixupBias() if use_fixup else Identity,
                LambdaLayer(_shortcut_pad))
        elif option == "B":
            Shortcut = stax.serial(
                FixupBias() if use_fixup else Identity,
                Conv(planes, (1, 1), strides=(stride, stride), W_init=w_init,
                     bias=False),
                maybe_use_normalization(normalization_method))
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, actfn)
def BasicBlock(planes, stride=1, downsample=None, base_width=64,
               norm_layer=Identity, actfn=stax.Relu):
    if base_width != 64:
        raise ValueError("BasicBlock only supports base_width=64")
    Main = stax.serial(
        Conv(planes, (3, 3), strides=(stride, stride), padding="SAME", bias=False),
        norm_layer,
        actfn,
        Conv(planes, (3, 3), padding="SAME", bias=False),
        norm_layer,
    )
    Shortcut = downsample if downsample is not None else Identity
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, actfn)
def constructDuelNetwork(n_actions, seed, input_shape):
    advantage_stream = stax.serial(Dense(512), Relu, Dense(n_actions))
    state_function_stream = stax.serial(Dense(512), Relu, Dense(1))
    dueling_architecture = stax.serial(
        elementwise(lambda x: x / 255.0),
        GeneralConv(dim_nums, 32, (8, 8), strides=(4, 4)), Relu,
        GeneralConv(dim_nums, 64, (4, 4), strides=(2, 2)), Relu,
        GeneralConv(dim_nums, 64, (3, 3), strides=(1, 1)), Relu,
        Flatten,
        FanOut(2),
        parallel(advantage_stream, state_function_stream),
    )

    def duelingNetworkMapping(inputs):
        advantage_values = inputs[0]
        state_values = inputs[1]
        # Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a))
        advantage_sums = jnp.sum(advantage_values, axis=1)
        advantage_sums = advantage_sums / float(n_actions)
        advantage_sums = advantage_sums.reshape(-1, 1)
        Q_values = state_values + (advantage_values - advantage_sums)
        return Q_values

    duelArchitectureMapping = jit(duelingNetworkMapping)

    # Create the deep neural net.
    model = DDQN(n_actions, input_shape, adam_params,
                 architecture=dueling_architecture, seed=seed,
                 mappingFunction=duelArchitectureMapping)
    return model
def create_pi_net(
    obs_dim: int, action_dim: int, rngkey=jax.random.PRNGKey(0)
) -> TT.Tuple[RT.NNParams, RT.NNParamsFn]:
    pi_init, pi_fn = serial(
        Dense(64, he_normal(), zeros),
        Relu,
        FanOut(2),
        parallel(
            serial(
                Dense(64, he_normal(), zeros),
                Relu,
                Dense(action_dim, he_normal(), zeros),
            ),
            serial(
                Dense(64, he_normal(), zeros),
                Relu,
                Dense(action_dim, he_normal(), zeros),
            ),
        ),
    )
    output_shape, pi_params = pi_init(rngkey, (1, obs_dim))
    pi_fn = jit(pi_fn)
    return pi_params, pi_fn
def BottleneckBlock(planes, stride=1, downsample=None, base_width=64,
                    norm_layer=Identity, actfn=stax.Relu):
    width = int(planes * (base_width / 64.))
    Main = stax.serial(
        Conv(width, (1, 1), bias=False),
        norm_layer,
        actfn,
        Conv(width, (3, 3), strides=(stride, stride), padding="SAME", bias=False),
        norm_layer,
        actfn,
        Conv(planes * 4, (1, 1), bias=False),
        norm_layer,
    )
    Shortcut = downsample if downsample is not None else Identity
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, actfn)
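# Both torchvision-style blocks above leave the shortcut projection to the
# caller via `downsample`. A minimal sketch of how such a branch might be
# built (the helper name is illustrative; `expansion` is 4 for
# BottleneckBlock, matching its `planes * 4` output, and 1 for BasicBlock):
def make_downsample(planes, stride, expansion=4, norm_layer=Identity):
    # 1x1 strided convolution that matches the main branch's output channels
    return stax.serial(
        Conv(planes * expansion, (1, 1), strides=(stride, stride), bias=False),
        norm_layer,
    )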
def image_sample(rng, dec_params, nrow, ncol):
    # NOTE: the original definition line was cut off; this signature is
    # reconstructed from the free variables used in the body.
    code_rng, img_rng = random.split(rng)
    logits = decode(dec_params, random.normal(code_rng, (nrow * ncol, 10)))
    sampled_images = random.bernoulli(img_rng, np.logaddexp(0., logits))
    return image_grid(nrow, ncol, sampled_images, (28, 28))


def image_grid(nrow, ncol, imagevecs, imshape):
    """Reshape a stack of image vectors into an image grid for plotting."""
    images = iter(imagevecs.reshape((-1,) + imshape))
    return np.vstack([np.hstack([next(images).T for _ in range(ncol)][::-1])
                      for _ in range(nrow)]).T


encoder_init, encode = stax.serial(
    Dense(512), Relu,
    Dense(512), Relu,
    FanOut(2),
    stax.parallel(Dense(10), stax.serial(Dense(10), Softplus)),
)

decoder_init, decode = stax.serial(
    Dense(512), Relu,
    Dense(512), Relu,
    Dense(28 * 28),
)

if __name__ == "__main__":
    step_size = 0.001
    num_epochs = 100
    batch_size = 32
    nrow, ncol = 10, 10  # sampled image grid size
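# The encoder above ends in two heads — conventionally the mean and the
# softplus-parameterized standard deviation of the approximate posterior.
# A minimal ELBO sketch under that assumption (not part of the snippet
# above), using the same Bernoulli-logits decoder:
from jax import random
import jax.numpy as np

def elbo(rng, enc_params, dec_params, images):
    mu, sigma = encode(enc_params, images)
    codes = mu + sigma * random.normal(rng, mu.shape)  # reparameterization trick
    logits = decode(dec_params, codes)
    # Bernoulli log-likelihood of binarized pixels: x * l - log(1 + e^l)
    log_px = np.sum(images * logits - np.logaddexp(0., logits))
    # KL(N(mu, sigma^2) || N(0, I)) in closed form
    kl = 0.5 * np.sum(mu**2 + sigma**2 - np.log(sigma**2) - 1.)
    return log_px - kl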