Exemplo n.º 1
0
def ResidualBlock(out_channels, kernel_size, stride, padding, input_format):
    """Build a residual unit computing ``x + conv(elu(conv(x)))``.

    The input is duplicated with ``FanOut(2)``. One copy runs through two
    ``GeneralConv`` layers separated by an ``Elu`` activation, the other is
    passed through ``Identity``, and ``FanInSum`` adds the two branches.

    Args:
        out_channels: number of output channels of each convolution.
        kernel_size: filter shape forwarded to ``GeneralConv``.
        stride: strides forwarded to ``GeneralConv``.
        padding: padding mode forwarded to ``GeneralConv``.
        input_format: dimension-numbers spec forwarded to ``GeneralConv``.

    Returns:
        ``Module`` constructed from the ``(init_fun, apply_fun)`` pair of the
        composed stax layer.
    """
    # NOTE(review): the shortcut is a plain Identity, so the sum only works
    # when the convolutions keep the input shape (matching channel count and
    # a shape-preserving stride/padding) — confirm at call sites.
    def make_conv():
        return GeneralConv(input_format, out_channels, kernel_size, stride, padding)

    main_branch = stax.serial(make_conv(), Elu, make_conv())
    init_fun, apply_fun = stax.serial(
        FanOut(2), stax.parallel(main_branch, Identity), FanInSum
    )
    return Module(init_fun, apply_fun)
Exemplo n.º 2
0
def ResNet(
    hidden_channels, out_channels, kernel_size, strides, padding, depth, input_format
):
    """Build a residual network wrapped with an outer skip connection.

    The inner stack is: a ``GeneralConv`` to ``hidden_channels``, then
    ``depth`` ``ResidualBlock`` units, then a ``GeneralConv`` to
    ``out_channels``. The input is fanned out in two. One copy goes through
    the stack and the other through ``Identity``. Both outputs are then handed
    to ``AddLastItem(1)``. Presumably that layer adds the shortcut onto part of
    the stack output; confirm against its definition.

    Args:
        hidden_channels: channel count inside the residual stack.
        out_channels: channel count of the final convolution.
        kernel_size: filter shape used by every convolution.
        strides: strides used by every convolution.
        padding: padding mode used by every convolution.
        depth: number of ``ResidualBlock`` units.
        input_format: dimension-numbers spec used by every convolution.

    Returns:
        ``Module`` constructed from the composed ``(init_fun, apply_fun)``.
    """
    def conv_to(channels):
        return GeneralConv(input_format, channels, kernel_size, strides, padding)

    entry_conv = conv_to(hidden_channels)
    blocks = [
        ResidualBlock(hidden_channels, kernel_size, strides, padding, input_format)
        for _ in range(depth)
    ]
    exit_conv = conv_to(out_channels)

    trunk = stax.serial(entry_conv, *blocks, exit_conv)
    init_fun, apply_fun = stax.serial(
        FanOut(2), stax.parallel(trunk, Identity), AddLastItem(1)
    )
    return Module(init_fun, apply_fun)
Exemplo n.º 3
0
def RevNet(
    hidden_channels, out_channels, kernel_size, strides, padding, depth, input_format
):
    """Build a reversible-block network wrapped with an outer skip connection.

    The inner stack is: ``Split`` on the channel axis, a ``GeneralConv`` to
    ``hidden_channels``, ``depth`` ``ReversibleBlock`` units, and a
    ``GeneralConv`` to ``out_channels``. The input is fanned out in two. One
    copy goes through the stack and the other through ``Identity``. Both
    outputs are then handed to ``AddLastItem(1)``.

    Args:
        hidden_channels: channel count inside the reversible stack.
        out_channels: channel count of the final convolution.
        kernel_size: filter shape for the convolutions and reversible blocks.
        strides: strides of the two outer convolutions.
        padding: padding mode of the two outer convolutions.
        depth: number of ``ReversibleBlock`` units.
        input_format: dimension-numbers spec; its first entry is searched for
            the channel letter.

    Returns:
        ``Module`` constructed from the composed ``(init_fun, apply_fun)``.
    """
    # Position of "c" in the lowercased first entry of input_format. This
    # assumes input_format[0] is the input layout string (e.g. "NCHW");
    # TODO confirm. Raises ValueError if no "c" is present.
    channel_axis = input_format[0].lower().index("c")

    def conv_to(channels):
        return GeneralConv(input_format, channels, kernel_size, strides, padding)

    # NOTE(review): Split feeds straight into a single-input GeneralConv;
    # confirm what Split emits.
    split_layer = Split(channel_axis)
    entry_conv = conv_to(hidden_channels)
    blocks = [
        ReversibleBlock(hidden_channels, kernel_size, input_format)
        for _ in range(depth)
    ]
    exit_conv = conv_to(out_channels)

    trunk = stax.serial(split_layer, entry_conv, *blocks, exit_conv)
    init_fun, apply_fun = stax.serial(
        FanOut(2), stax.parallel(trunk, Identity), AddLastItem(1)
    )
    return Module(init_fun, apply_fun)
Exemplo n.º 4
0
def ResNet50(num_classes):
    """Build a ResNet-50-style classifier as a stax ``(init_fun, apply_fun)`` pair.

    Layout:
    - Stem: 7x7 stride-2 ``GeneralConv`` to 64 channels (input read as HWCN,
      output emitted as NHWC), ``BatchNorm``, ``Relu``, 3x3 stride-2 ``MaxPool``.
    - Four stages, each one ``ConvBlock`` followed by 2, 3, 5 and 2
      ``IdentityBlock`` units respectively.
    - Head: 7x7 ``AvgPool``, ``Flatten``, ``Dense(num_classes)``, ``LogSoftmax``.

    NOTE(review): expects ConvBlock/IdentityBlock taking
    ``(kernel_size, filters[, strides])``. Confirm which definitions are in
    scope.
    """
    def identity_run(count, filters):
        return [IdentityBlock(3, filters) for _ in range(count)]

    return stax.serial(
        GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
        ConvBlock(3, [64, 64, 256], strides=(1, 1)),
        *identity_run(2, [64, 64]),
        ConvBlock(3, [128, 128, 512]),
        *identity_run(3, [128, 128]),
        ConvBlock(3, [256, 256, 1024]),
        *identity_run(5, [256, 256]),
        ConvBlock(3, [512, 512, 2048]),
        *identity_run(2, [512, 512]),
        AvgPool((7, 7)),
        Flatten,
        Dense(num_classes),
        LogSoftmax,
    )
Exemplo n.º 5
0
def constructDuelNetwork(n_actions, seed, input_shape):
    """Build a DDQN agent whose network has a dueling (value + advantage) head.

    Inputs are scaled by 1/255, then passed through three conv+ReLU layers and
    flattened. The features are then fanned out into two heads: an advantage
    head (``n_actions`` outputs) and a state-value head (1 output). The mapping
    function combines them as ``Q = V + (A - mean(A, axis=1))``.

    ``dim_nums`` and ``adam_params`` are not defined in this function; they are
    resolved from the enclosing (module) scope.

    Args:
        n_actions: number of discrete actions (width of the advantage head).
        seed: forwarded to ``DDQN``.
        input_shape: forwarded to ``DDQN``.

    Returns:
        The ``DDQN`` instance.
    """
    advantage_head = stax.serial(Dense(512), Relu, Dense(n_actions))
    value_head = stax.serial(Dense(512), Relu, Dense(1))

    # (out_channels, filter_shape, strides) for each conv layer of the trunk.
    trunk_spec = (
        (32, (8, 8), (4, 4)),
        (64, (4, 4), (2, 2)),
        (64, (3, 3), (1, 1)),
    )
    layers = [elementwise(lambda x: x / 255.0)]
    for channels, window, step in trunk_spec:
        layers += [GeneralConv(dim_nums, channels, window, strides=step), Relu]
    layers += [Flatten, FanOut(2), parallel(advantage_head, value_head)]
    dueling_architecture = stax.serial(*layers)

    def combine_streams(inputs):
        advantages = inputs[0]
        values = inputs[1]
        # Per-row mean advantage as a column, so it broadcasts over actions.
        mean_advantage = (jnp.sum(advantages, axis=1) / float(n_actions)).reshape(-1, 1)
        return values + (advantages - mean_advantage)

    return DDQN(
        n_actions,
        input_shape,
        adam_params,
        architecture=dueling_architecture,
        seed=seed,
        mappingFunction=jit(combine_streams),
    )
Exemplo n.º 6
0
def constructSingleStreamNetwork(n_actions, seed, input_shape):
    """Build a DDQN agent with a plain single-stream convolutional Q-network.

    Inputs are scaled by 1/255, then passed through three conv+ReLU layers,
    flattened, and fed to a 1024-unit ReLU layer and a ``Dense(n_actions)``
    output.

    ``dim_nums`` and ``adam_params`` are not defined in this function; they are
    resolved from the enclosing (module) scope.

    Args:
        n_actions: number of discrete actions (output width).
        seed: forwarded to ``DDQN``.
        input_shape: forwarded to ``DDQN``.

    Returns:
        The ``DDQN`` instance.
    """
    # (out_channels, filter_shape, strides) for each conv layer.
    conv_spec = (
        (32, (8, 8), (4, 4)),
        (64, (4, 4), (2, 2)),
        (64, (3, 3), (1, 1)),
    )
    layers = [elementwise(lambda x: x / 255.0)]
    for channels, window, step in conv_spec:
        layers.extend((GeneralConv(dim_nums, channels, window, strides=step), Relu))
    layers.extend((Flatten, Dense(1024), Relu, Dense(n_actions)))

    return DDQN(
        n_actions,
        input_shape,
        adam_params,
        architecture=stax.serial(*layers),
        seed=seed,
    )
Exemplo n.º 7
0
def ResNet(num_classes):
    """Build a ResNet-50-style classifier from ``convBlock``/``identityBlock``.

    Stem: 7x7 stride-2 conv to 64 channels (HWCN in, NHWC out), BatchNorm,
    ReLU, 3x3 stride-2 max-pool. Then four stages with base widths 64, 128,
    256 and 512. Each stage is one ``convBlock`` with filters
    ``[w, w, 4w]`` followed by 2, 3, 5 and 2 ``identityBlock`` units
    respectively. Head: 7x7 avg-pool, flatten, ``Dense(num_classes)``,
    log-softmax.

    Returns:
        The stax ``(init_fun, apply_fun)`` pair.
    """
    layers = [
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
    ]
    # (base width, number of identity blocks after the stage's conv block)
    for width, repeats in ((64, 2), (128, 3), (256, 5), (512, 2)):
        layers.append(convBlock(3, [width, width, 4 * width]))
        layers.extend(identityBlock(3, [width, width]) for _ in range(repeats))
    layers += [AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax]
    return stax.serial(*layers)
Exemplo n.º 8
0
def ResNet(num_classes):
    """Build a small ResNet-style classifier: one conv block and one identity block.

    NOTE(review): expects ConvBlock/IdentityBlock taking
    ``(kernel_size, filters[, strides])``. Confirm which definitions are in
    scope.

    Returns:
        The stax ``(init_fun, apply_fun)`` pair.
    """
    stem = (
        GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
    )
    body = (
        ConvBlock(3, [4, 4, 4], strides=(1, 1)),
        IdentityBlock(3, [4, 4]),
    )
    head = (AvgPool((3, 3)), Flatten, Dense(num_classes), LogSoftmax)
    return stax.serial(*stem, *body, *head)
Exemplo n.º 9
0
def LeNet5(num_classes):
    """Build a small CNN classifier as a stax ``(init_fun, apply_fun)`` pair.

    It has two conv stages, each ``conv -> BatchNorm -> Relu -> AvgPool((3, 3))``:
    - Stage 1: 7x7 stride-2 ``GeneralConv`` to 64 channels, input read as
      HWCN and output emitted as NHWC.
    - Stage 2: 5x5 ``Conv`` to 16 channels with SAME padding.

    These are followed by ``Flatten``, three ``Dense`` layers of widths
    ``10*num_classes``, ``5*num_classes`` and ``num_classes``, and
    ``LogSoftmax``.

    Args:
        num_classes: number of output classes.
    """
    def conv_stage(conv_layer):
        return [conv_layer, BatchNorm(), Relu, AvgPool((3, 3))]

    layers = conv_stage(GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"))
    layers += conv_stage(Conv(16, (5, 5), strides=(1, 1), padding="SAME"))
    layers.append(Flatten)
    # NOTE(review): there is no activation between these Dense layers, so
    # together they form a single affine map. Confirm this is intended.
    layers += [Dense(num_classes * factor) for factor in (10, 5, 1)]
    layers.append(LogSoftmax)
    return stax.serial(*layers)
Exemplo n.º 10
0
    def _create_network_architecture(self, action_dim):

        dim_nums = ('NHWC', 'HWIO', 'NHWC')

        initialize_params, predict = stax.serial(
            # elementwise(lambda x: x/255.0),  # normalize
            ### convolutional NN (CNN)
            GeneralConv(dim_nums, 32, (8, 8), strides=(4, 4)),
            Relu,
            # GeneralConv(dim_nums, 64, (4,4), strides=(2,2) ),
            # Relu,
            # GeneralConv(dim_nums, 64, (3,3), strides=(1,1) ),
            # Relu,
            Flatten,  # flatten output
            Dense(512),
            Relu,
            Dense(action_dim))

        return initialize_params, predict
Exemplo n.º 11
0
def LeNet5(batch_size, num_particles):
    """Build a LeNet-5-style CNN and hand it to ``make_model``.

    Layers:
    - 5x5 conv to 6 channels (input read as NCHW, output NHWC), ReLU,
      2x2/2 max-pool (VALID).
    - 5x5 ``Conv`` to 16 channels (SAME), ReLU, 2x2/2 max-pool (SAME).
    - 5x5 ``Conv`` to 120 channels (VALID), ReLU, 2x2/2 max-pool (SAME).
    - Flatten, ``Dense(84)``, ReLU, ``Dense(10)``, ``LogSoftmax``. The output
      width of 10 is fixed.

    Args:
        batch_size: passed to ``_input_shape`` to obtain the input shape.
        num_particles: forwarded to ``make_model``.

    Returns:
        Whatever ``make_model(layer, input_shape, num_particles)`` returns.
    """
    input_shape = _input_shape(batch_size)
    layers = [
        GeneralConv(('NCHW', 'OIHW', 'NHWC'),
                    out_chan=6,
                    filter_shape=(5, 5),
                    strides=(1, 1),
                    padding="VALID"),
        Relu,
        MaxPool(window_shape=(2, 2), strides=(2, 2), padding="VALID"),
        Conv(out_chan=16, filter_shape=(5, 5), strides=(1, 1), padding="SAME"),
        Relu,
        MaxPool(window_shape=(2, 2), strides=(2, 2), padding="SAME"),
        Conv(out_chan=120, filter_shape=(5, 5), strides=(1, 1), padding="VALID"),
        Relu,
        MaxPool(window_shape=(2, 2), strides=(2, 2), padding="SAME"),
        Flatten,
        Dense(84),
        Relu,
        Dense(10),
        LogSoftmax,
    ]
    return make_model(stax.serial(*layers), input_shape, num_particles)
Exemplo n.º 12
0
 def MyGeneralConv(*args, **kwargs):
     """Forward every positional and keyword argument to ``GeneralConv`` unchanged."""
     conv_layer = GeneralConv(*args, **kwargs)
     return conv_layer
Exemplo n.º 13
0
def ConvBlock(out_channels, kernel_size, input_format):
    """Build ``conv -> Elu -> conv`` with unit stride and ``"SAME"`` padding.

    Both convolutions use stride 1 in every spatial dimension and ``"SAME"``
    padding, so the spatial extent of the input is preserved.

    Args:
        out_channels: number of output channels of each convolution.
        kernel_size: spatial filter shape, a sequence with one entry per
            spatial dimension (e.g. ``(3, 3)``).
        input_format: dimension-numbers spec forwarded to ``GeneralConv``.

    Returns:
        The ``(init_fun, apply_fun)`` pair produced by ``stax.serial``.
    """
    # stax.GeneralConv passes ``strides`` through unchanged to lax, and lax
    # convolutions take one stride per spatial dimension. A bare int ``1`` is
    # therefore rejected, so build the per-dimension tuple explicitly.
    unit_strides = (1,) * len(kernel_size)
    return stax.serial(
        GeneralConv(input_format, out_channels, kernel_size, unit_strides, "SAME"),
        Elu,
        GeneralConv(input_format, out_channels, kernel_size, unit_strides, "SAME"),
    )