Example #1
def ConvBlock(kernel_size, filters, strides=(2, 2)):
    ks = kernel_size
    filters1, filters2, filters3 = filters
    Main = stax.serial(Conv(filters1, (1, 1), strides), BatchNorm(), Relu,
                       Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(),
                       Relu, Conv(filters3, (1, 1)), BatchNorm())
    Shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm())
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum,
                       Relu)
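A hedged usage sketch (not part of the original example): ConvBlock returns a stax (init_fun, apply_fun) pair, so it initializes and applies like any other stax layer. The shapes below are illustrative; in recent JAX releases stax lives at jax.example_libraries.stax.

import jax.numpy as jnp
from jax import random

init_fun, apply_fun = ConvBlock(3, (64, 64, 256))
# NHWC input; both branches stride by 2 and end in 256 channels, so FanInSum shapes match
out_shape, params = init_fun(random.PRNGKey(0), (-1, 56, 56, 256))
y = apply_fun(params, jnp.zeros((8, 56, 56, 256)))  # y.shape == (8, 28, 28, 256)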
Example #2
def DeepQNetwork():
    init_fun, predict_fun = stax.serial(
        Conv(16, (8, 8), strides=(4, 4)), Relu,
        Conv(32, (4, 4), strides=(2, 2)), Relu,
        Conv(64, (3, 3)), Relu,
        Flatten,
        Dense(256), Relu,
        Dense(6)
    )
    return init_fun, predict_fun
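A minimal sketch (not from the source) of initializing this network, assuming the conventional DQN preprocessing of four stacked 84x84 grayscale frames:

from jax import random

init_fun, predict_fun = DeepQNetwork()
out_shape, params = init_fun(random.PRNGKey(0), (-1, 84, 84, 4))
# out_shape == (-1, 6): one Q-value per action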
Example #3
def construct_main(input_shape):
    # `filters` and `ks` are free variables closed over from the enclosing scope
    return stax.serial(
        Conv(filters[0], (1, 1), strides=(1, 1)),
        BatchNorm(),
        Relu,
        Conv(filters[1], (ks, ks), padding="SAME"),
        BatchNorm(),
        Relu,
        Conv(input_shape[3], (1, 1)),
        BatchNorm(),
    )
Example #4
def make_main(input_shape):
    # the number of output channels depends on the number of input channels
    return stax.serial(
        Conv(filters1, (1, 1)),
        BatchNorm(),
        Relu,
        Conv(filters2, (ks, ks), padding="SAME"),
        BatchNorm(),
        Relu,
        Conv(input_shape[3], (1, 1)),
        BatchNorm(),
    )
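This helper matches the one in JAX's bundled resnet50 example, where it is wrapped with stax.shape_dependent so the final 1x1 convolution can emit as many channels as the block receives. A sketch of that enclosing identity block (imports shown for recent JAX; older releases used jax.experimental.stax):

from jax.example_libraries import stax
from jax.example_libraries.stax import (BatchNorm, Conv, FanInSum, FanOut,
                                        Identity, Relu)

def IdentityBlock(kernel_size, filters):
    ks = kernel_size
    filters1, filters2 = filters

    def make_main(input_shape):
        # output channels mirror the input channels, as noted above
        return stax.serial(
            Conv(filters1, (1, 1)), BatchNorm(), Relu,
            Conv(filters2, (ks, ks), padding="SAME"), BatchNorm(), Relu,
            Conv(input_shape[3], (1, 1)), BatchNorm())

    # shape_dependent defers building Main until the input shape is known
    Main = stax.shape_dependent(make_main)
    return stax.serial(FanOut(2), stax.parallel(Main, Identity),
                       FanInSum, Relu)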
Example #5
def state_encoder(output_num):
    return serial(
        Conv(4, (3, 3), (1, 1), "SAME"),
        Tanh,  # BatchNorm(),
        Conv(4, (3, 3), (1, 1), "SAME"),
        Tanh,  # BatchNorm(),
        Conv(4, (3, 3), (1, 1), "SAME"),
        Tanh,  # BatchNorm(),
        Flatten,
        Dense(128),
        Tanh,  # with BatchNorm the output inexplicably becomes a constant value
        Dense(output_num))
Example #6
def CnnSmall():
    """CNN used for the Catch and Key-to-Door experiments in:
    https://arxiv.org/abs/2102.12425"""
    return serial(
        Conv(32, (2, 2), (1, 1), "VALID"),
        Relu,
        Conv(64, (2, 2), (1, 1), "VALID"),
        Relu,
        Flatten,
        Dense(256),
        Relu,
    )
Example #7
File: dqn.py Project: epignatelli/helx
def Cnn(n_actions: int, hidden_size) -> Module:
    return serial(
        Conv(32, (8, 8), (4, 4), "VALID"),
        Relu,
        Conv(64, (4, 4), (2, 2), "VALID"),
        Relu,
        Conv(64, (3, 3), (1, 1), "VALID"),
        Relu,
        Flatten,
        Dense(hidden_size),
        Relu,
        Dense(n_actions),
    )
Example #8
def CnnLarge():
    """CNN used for the Pong and Skiing experiments in:
    https://arxiv.org/abs/2102.12425"""
    return serial(
        Conv(32, (3, 3), (2, 2), "VALID"),
        Relu,
        Conv(64, (3, 3), (2, 2), "VALID"),
        Relu,
        Conv(64, (3, 3), (2, 2), "VALID"),
        Relu,
        Flatten,
        Dense(256),
        Relu,
    )
Example #9
def init_conv_affine_coupling(rng,
                              in_shape,
                              n_channels,
                              flip,
                              sigmoid=True,
                              init_batch=None):
    """
    in_shape: tuple of (h, w, c)
    """
    h, w, c = in_shape
    assert c % 2 == 0, "number of channels must be even"
    half_c = c // 2
    net_init, net_apply = stax.serial(Conv(n_channels, (3, 3),
                                           padding="SAME"), Relu,
                                      Conv(n_channels, (3, 3), padding="SAME"),
                                      Relu, Conv(c, (3, 3), padding="SAME"))
    _, net_params = net_init(rng, (-1, h, w, half_c))

    def shift_and_log_scale_fn(net_params, x1):
        s = net_apply(net_params, x1)
        return np.split(s, 2, axis=3)

    def conv_coupling_forward(net_params, prev_sample, prev_logp=0.):
        x1, x2 = prev_sample[:, :, :, :half_c], prev_sample[:, :, :, half_c:]
        if flip:
            x2, x1 = x1, x2
        shift, log_scale = shift_and_log_scale_fn(net_params, x1)
        if sigmoid:
            log_scale = log_sigmoid(log_scale + 2.)
        y2 = x2 * np.exp(log_scale) + shift
        if flip:
            x1, y2 = y2, x1
        y = np.concatenate([x1, y2], axis=-1)
        return y, prev_logp + np.sum(log_scale, axis=(1, 2, 3))

    def conv_coupling_reverse(net_params, next_sample, next_logp=0.):
        y1, y2 = next_sample[:, :, :, :half_c], next_sample[:, :, :, half_c:]
        if flip:
            y1, y2 = y2, y1
        shift, log_scale = shift_and_log_scale_fn(net_params, y1)
        if sigmoid:
            log_scale = log_sigmoid(log_scale + 2.)
        x2 = (y2 - shift) * np.exp(-log_scale)
        if flip:
            y1, x2 = x2, y1
        x = np.concatenate([y1, x2], axis=-1)
        return x, next_logp - np.sum(log_scale, axis=(1, 2, 3))

    return net_params, conv_coupling_forward, conv_coupling_reverse
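A hedged round-trip check (not from the source): composing the returned forward and reverse functions should recover the input and cancel the accumulated log-density. This assumes np is jax.numpy and log_sigmoid is jax.nn.log_sigmoid, as the snippet's usage suggests.

import jax.numpy as np  # assumption: the snippet aliases jax.numpy as np
from jax import random

rng = random.PRNGKey(0)
params, fwd, rev = init_conv_affine_coupling(rng, (8, 8, 4), n_channels=16, flip=True)
x = random.normal(rng, (2, 8, 8, 4))
y, logp = fwd(params, x)
x_rec, logp_rec = rev(params, y, logp)
assert np.allclose(x, x_rec, atol=1e-5)       # invertibility
assert np.allclose(logp_rec, 0.0, atol=1e-5)  # log-densities cancel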
Example #10
def convBlock(ks, filters, stride=(1, 1)):
    Main = stax.serial(Conv(filters[0], (1, 1), strides=(1, 1)), BatchNorm(),
                       Relu, Conv(filters[1], (ks, ks), strides=stride),
                       BatchNorm(), Relu,
                       Conv(filters[2], (1, 1),
                            strides=(1, 1)), BatchNorm(), Relu)

    Shortcut = stax.serial(
        Conv(filters[3], (1, 1), strides=stride),
        BatchNorm(),
    )

    fullInternal = stax.parallel(Main, Shortcut)

    return stax.serial(FanOut(2), fullInternal, FanInSum, Relu)
Example #11
def feature_extractor(rng, dim):
  """Feature extraction network."""
  init_params, forward = stax.serial(
    Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
    Relu,
    MaxPool((2, 2), (1, 1)),
    Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
    Relu,
    MaxPool((2, 2), (1, 1)),
    Flatten,
    Dense(dim),
  )
  temp, rng = random.split(rng)
  params = init_params(temp, (-1, 28, 28, 1))[1]
  return params, forward
Example #12
def init_nn():
    """
    Initialize Stax model. This function can be customized as needed to define architecture
    """
    layers = [
        Conv(16, (3, 3)),
        Relu,
        Conv(16, (3, 3)),
        Relu,
        Flatten,
        Dense(10),
        LogSoftmax,
    ]

    return stax.serial(*layers)
Example #13
def conv():
    init_fun, predict = stax.serial(
        Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
        Relu,
        MaxPool((2, 2), (1, 1)),
        Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
        Relu,
        MaxPool((2, 2), (1, 1)),
        Flatten,
        Dense(32),
        Relu,
        Dense(10),
    )

    def init_params(rng):
        return init_fun(rng, (-1, 28, 28, 1))[1]

    return init_params, predict
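A short usage sketch (illustrative, not from the source): drawing parameters with the returned initializer and running a forward pass on an MNIST-shaped batch.

import jax.numpy as jnp
from jax import random

init_params, predict = conv()
params = init_params(random.PRNGKey(0))
logits = predict(params, jnp.zeros((128, 28, 28, 1)))  # shape (128, 10)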
Example #14
def LeNet5(batch_size, num_particles):
    input_shape = _input_shape(batch_size)
    return make_model(
        stax.serial(
            GeneralConv(('NCHW', 'OIHW', 'NHWC'),
                        out_chan=6,
                        filter_shape=(5, 5),
                        strides=(1, 1),
                        padding="VALID"), Relu,
            MaxPool(window_shape=(2, 2), strides=(2, 2), padding="VALID"),
            Conv(out_chan=16,
                 filter_shape=(5, 5),
                 strides=(1, 1),
                 padding="SAME"), Relu,
            MaxPool(window_shape=(2, 2), strides=(2, 2), padding="SAME"),
            Conv(out_chan=120,
                 filter_shape=(5, 5),
                 strides=(1, 1),
                 padding="VALID"), Relu,
            MaxPool(window_shape=(2, 2),
                    strides=(2, 2), padding="SAME"), Flatten, Dense(84), Relu,
            Dense(10), LogSoftmax), input_shape, num_particles)
Example #15
File: layers.py Project: lilujunai/dnet
def __init__(self,
             filters: int,
             kernel_size: Tuple[int, int],
             strides: Tuple[int, int] = (1, 1),
             padding: str = "valid",
             activation: str = "linear") -> None:
    layer_activation: Tuple[Callable,
                            Callable] = getattr(activations, activation)()
    self.layer: List = [
        Conv(out_chan=filters,
             filter_shape=kernel_size,
             strides=strides,
             padding=padding.upper()), layer_activation
    ]
Example #16
def LeNet5(num_classes):
    return stax.serial(
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(),
        Relu,
        AvgPool((3, 3)),

        Conv(16, (5, 5), strides=(1, 1), padding="SAME"),
        BatchNorm(),
        Relu,
        AvgPool((3, 3)),

        Flatten,
        Dense(num_classes * 10),
        Dense(num_classes * 5),
        Dense(num_classes),
        LogSoftmax
    )
Example #17
def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(np.sum(preds * targets, axis=1))


def accuracy(params, batch):
    inputs, targets = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict(params, inputs), axis=1)
    return np.mean(predicted_class == target_class)


init_random_params, predict = stax.serial(
    Conv(10, (5, 5), (1, 1)), Activator,
    MaxPool((4, 4)), Flatten,
    Dense(10), LogSoftmax)

if __name__ == "__main__":
    rng = random.PRNGKey(0)
    
    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9

    # input shape for CNN
    input_shape = (-1, 28, 28, 1)
    
    # training/test split
Example #18

def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(np.sum(preds * targets, axis=1))


def accuracy(params, batch):
    inputs, targets = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict(params, inputs), axis=1)
    return np.mean(predicted_class == target_class)


init_random_params, predict = stax.serial(Conv(10, (5, 5), (1, 1)), Activator,
                                          MaxPool((4, 4)), Flatten, Dense(24),
                                          LogSoftmax)

if __name__ == "__main__":
    rng = random.PRNGKey(0)

    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9

    # input shape for CNN
    input_shape = (-1, 28, 28, 1)

    # training/test split
Example #19
def MyConv(*args, **kwargs):
    return Conv(*args, **kwargs)
Example #20
def conv_net(mode="train"):
    out_dim = 1
    dim_nums = ("NHWC", "HWIO", "NHWC")
    unit_stride = (1,1)
    zero_pad = ((0,0), (0,0))

    # Primary convolutional layer.
    conv_channels = 32
    conv_init, conv_apply = Conv(out_chan=conv_channels, filter_shape=(3,3),
                                 strides=(1,3), padding=zero_pad)
    # Group all possible pairs.
    pair_channels, filter_shape = 256, (1, 2)

    # Convolutional block with the same number of channels.
    block_channels = pair_channels
    conv_block_init, conv_block_apply = serial(Conv(block_channels, (1,3), unit_stride, "SAME"), Relu,  # One block of convolutions.
                                               Conv(block_channels, (1,3), unit_stride, "SAME"), Relu,          
                                               Conv(block_channels, (1,3), unit_stride, "SAME"))
    # Forward pass.
    hidden_size = 2048
    dropout_rate = 0.25
    serial_init, serial_apply = serial(Conv(block_channels, (1,3), (1, 3), zero_pad), Relu,     # Using convolution with strides
                                       Flatten, Dense(hidden_size),                             # instead of pooling for downsampling.
                                    #    Dropout(dropout_rate, mode),
                                       Relu, Dense(out_dim))

    def init_fun(rng, input_shape):
        rng, conv_rng, block_rng, serial_rng = jax.random.split(rng, num=4)

        # Primary convolutional layer.
        conv_shape, conv_params = conv_init(conv_rng, (-1,) + input_shape)

        # Grouping all possible pairs.
        kernel_shape = [filter_shape[0], filter_shape[1], conv_channels, pair_channels]
        bias_shape = [1, 1, 1, pair_channels]
        W_init = glorot_normal(in_axis=2, out_axis=3)
        b_init = normal(1e-6)
        k1, k2 = jax.random.split(rng)
        W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
        pair_shape = conv_shape[:2] + (15,) + (pair_channels,)
        pair_params = (W, b)

        # Convolutional block.
        conv_block_shape, conv_block_params = conv_block_init(block_rng, pair_shape)

        # Forward pass.
        serial_shape, serial_params = serial_init(serial_rng, conv_block_shape)
        params = [conv_params, pair_params, conv_block_params, serial_params]
        return serial_shape, params

    def apply_fun(params, inputs):
        conv_params, pair_params, conv_block_params, serial_params = params

        # Apply the primary convolutional layer.
        conv_out = conv_apply(conv_params, inputs)
        conv_out = relu(conv_out)

        # Group all possible pairs.
        W, b = pair_params
        pair_1 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,1), dim_nums) + b
        pair_2 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,2), dim_nums) + b
        pair_3 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,3), dim_nums) + b
        pair_4 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,4), dim_nums) + b
        pair_5 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,5), dim_nums) + b
        pair_out = jnp.dstack([pair_1, pair_2, pair_3, pair_4, pair_5])
        pair_out = relu(pair_out)

        # Convolutional block.
        conv_block_out = conv_block_apply(conv_block_params, pair_out)

        # Residual connection.
        res_out = conv_block_out + pair_out
        res_out = relu(res_out)

        # Forward pass.
        out = serial_apply(serial_params, res_out)
        return out

    return init_fun, apply_fun
Example #21
# batches = synth_batches()
# inputs = next(batches)

# train_images, labels, _, _ = mnist(permute_train=True, resize=True)
# del _
# inputs = train_images[:data_size]
#
# del train_images

# u, s, v_t = onp.linalg.svd(inputs, full_matrices=False)
# I = np.eye(v_t.shape[-1])
# I_add = npr.normal(0.0, 0.002, size=I.shape)
# noisy_I = I + I_add

init_fun, conv_net = stax.serial(
    Conv(32, (5, 5), (2, 2), padding="SAME"),
    BatchNorm(),
    Relu,
    Conv(10, (3, 3), (2, 2), padding="SAME"),
    Relu,
    Flatten,
    Dense(num_classes),
    LogSoftmax,
)
_, key = random.split(random.PRNGKey(0))


class DataTopologyAE(AbstractProblem):
    def __init__(self):
        self.HPARAMS_PATH = "hparams.json"