Пример #1
0
def preact_resnet_symbol(channels=256, channels_value_head=8,
                   channels_policy_head=81, value_fc_size=256, value_kernelsize=7, res_blocks=19, act_type='relu',
                   n_labels=4992, grad_scale_value=0.01, grad_scale_policy=0.99, select_policy_from_plane=True):
    """
    Builds the pre-activation ResNet symbol with a value and a policy output.

    Graph layout: input variable -> stem -> `res_blocks` pre-activation residual blocks ->
    batch normalization + activation -> policy head and value head.

    :param channels: Number of channels passed to the stem, every residual block and the policy head
    :param channels_value_head: Number of channels passed to the value head
    :param channels_policy_head: Number of channels passed to the policy head
    :param value_fc_size: Fully connected layer size passed to the value head
    :param value_kernelsize: Currently unused, the value head is always called with value_kernelsize=1
    :param res_blocks: Number of pre-activation residual blocks to stack
    :param act_type: Activation type passed to the stem, the residual blocks, the final activation and both heads
    :param n_labels: Number of policy labels passed to the policy head
    :param grad_scale_value: Gradient scale passed to the value head
    :param grad_scale_policy: Gradient scale passed to the policy head
    :param select_policy_from_plane: Passed to the policy head to choose its output type
    :return: mxnet symbol which groups [value_out, policy_out]
    """
    # input variable followed by the stem
    features = get_stem(data=mx.sym.Variable(name='data'), channels=channels, act_type=act_type)

    # residual tower, every block uses a 3x3 kernel
    block_names = ['res_block%d' % block_idx for block_idx in range(res_blocks)]
    for block_name in block_names:
        features = preact_residual_block(features, channels, name=block_name, kernel=3, act_type=act_type)

    # pre-activation blocks end without normalization, so it is applied once after the tower
    normalized = mx.sym.BatchNorm(data=features, name='stem_bn1')
    features = get_act(data=normalized, act_type=act_type, name='stem_act1')

    # the policy head is created before the value head (same order as before)
    policy_kwargs = dict(channels=channels,
                         act_type=act_type,
                         channels_policy_head=channels_policy_head,
                         select_policy_from_plane=select_policy_from_plane,
                         n_labels=n_labels,
                         grad_scale_policy=grad_scale_policy,
                         use_se=False,
                         no_bias=True)
    policy_out = policy_head(data=features, **policy_kwargs)

    value_kwargs = dict(channels_value_head=channels_value_head,
                        value_kernelsize=1,
                        act_type=act_type,
                        value_fc_size=value_fc_size,
                        grad_scale_value=grad_scale_value,
                        use_se=False,
                        use_mix_conv=False)
    value_out = value_head(data=features, **value_kwargs)

    # value output first, policy output second
    return mx.symbol.Group([value_out, policy_out])
Пример #2
0
def rise_mobile_v3_symbol(channels=256, channels_operating_init=128, channel_expansion=64, act_type='relu',
                          channels_value_head=32, channels_policy_head=81, value_fc_size=128, dropout_rate=0.15,
                          select_policy_from_plane=True, use_se=True, res_blocks=13, n_labels=4992):
    """
    RISEv3 architecture

    Graph layout: input variable -> stem -> `res_blocks` pre-activation mixed-convolution residual blocks ->
    batch normalization + activation -> optional dropout -> value head and policy head.

    :param channels: Main number of channels
    :param channels_operating_init: Number of operating channels of the first residual block
    :param channel_expansion: Number of operating channels to add after each residual block
    :param act_type: Activation type to use
    :param channels_value_head: Number of channels for the value head
    :param channels_policy_head: Number of channels for the policy head
    :param value_fc_size: Number of units in the fully connected layer of the value head
    :param dropout_rate: Dropout factor to use. If 0, no dropout will be applied.
    :param select_policy_from_plane: Passed to the policy head to choose its output type
    :param use_se: Global switch for squeeze excitation. If True, the blocks with index 4 and >= 9 as well as the
     value head use it. If False, it is disabled everywhere.
    :param res_blocks: Number of residual blocks. The kernel table below has 13 entries, a larger value raises
     an IndexError.
    :param n_labels: Number of policy target labels, passed to the policy head
    :return: mxnet symbol which groups [value_out, policy_out]
    """
    # get the input data
    data = mx.sym.Variable(name='data')

    data = get_stem(data=data, channels=channels, act_type=act_type)

    cur_channels = channels_operating_init

    # kernel sizes of the mixed convolution for each residual block
    kernels = [
        [3],  # 0
        [3],  # 1
        [3, 5],  # 2
        [3, 5],  # 3
        [3, 5, 7, 9],  # 4
        [3, 5],  # 5
        [3, 5],  # 6
        [3, 5],  # 7
        [3, 5],  # 8
        [3, 5],  # 9
        [3, 5],  # 10
        [3, 5],  # 11
        [3, 5],  # 12
    ]
    for idx in range(res_blocks):
        # Kept in a separate local so that the caller's `use_se` is not overwritten by the loop.
        block_use_se = bool(use_se) and (idx == 4 or idx >= 9)
        data = preact_residual_dmixconv_block(data=data, channels=channels, channels_operating=cur_channels,
                                              kernels=kernels[idx], name='dconv_%d' % idx, use_se=block_use_se)
        cur_channels += channel_expansion

    # pre-activation blocks end without normalization, so it is applied once after the tower
    data = mx.sym.BatchNorm(data=data, name='stem_bn1')
    data = get_act(data=data, act_type=act_type, name='stem_act1')

    if dropout_rate != 0:
        data = mx.sym.Dropout(data, p=dropout_rate)

    value_out = value_head(data=data, act_type=act_type, use_se=use_se, channels_value_head=channels_value_head,
                           value_fc_size=value_fc_size, use_mix_conv=True)
    policy_out = policy_head(data=data, act_type=act_type, channels_policy_head=channels_policy_head, n_labels=n_labels,
                             select_policy_from_plane=select_policy_from_plane, use_se=False, channels=channels)
    # group value_out and policy_out together
    sym = mx.symbol.Group([value_out, policy_out])

    return sym
Пример #3
0
def rise_mobile_v2_symbol(channels=256,
                          channels_operating_init=128,
                          channel_expansion=64,
                          channels_value_head=8,
                          channels_policy_head=81,
                          value_fc_size=256,
                          bc_res_blocks=None,
                          res_blocks=None,
                          act_type='relu',
                          n_labels=4992,
                          grad_scale_value=0.01,
                          grad_scale_policy=0.99,
                          select_policy_from_plane=True,
                          use_se=True,
                          dropout_rate=0,
                          use_extra_variant_input=False):
    """
    Creates the rise mobile model symbol based on the given parameters.

    Graph layout: input variable -> stem -> bottleneck residual blocks -> residual blocks -> optional dropout ->
    optional concatenation of the variant information -> policy head and value head.

    :param channels: Number of channels passed to the stem, all residual blocks and the policy head
    :param channels_operating_init: Number of operating channels of the first bottleneck residual block
    :param channel_expansion: Number of operating channels to add after each bottleneck residual block
    :param channels_value_head: Number of channels for the value head
    :param channels_policy_head: Number of channels for the policy head
    :param value_fc_size: Fully Connected layer size. Used for the value output
    :param bc_res_blocks: List of kernel sizes, one bottleneck residual block is created per entry.
     None (default) means no bottleneck residual blocks.
    :param res_blocks: List of kernel sizes, one residual block is created per entry.
     None (default) means 13 blocks with kernel size 3.
    :param act_type: Activation function which will be used for all intermediate layers
    :param n_labels: Number of labels for the policy
    :param grad_scale_value: Constant scalar which the gradient for the value outputs are being scaled with.
                            (They used 1.0 for default and 0.01 in the supervised setting)
    :param grad_scale_policy: Constant scalar which the gradient for the policy outputs are being scaled with.
                            (They used 1.0 for default and 0.99 in the supervised setting)
    :param select_policy_from_plane: Passed to the policy head to choose its output type
    :param use_se: If True, squeeze excitation is enabled for the last five blocks of each of the two block lists.
     All earlier blocks never use it.
    :param dropout_rate: Optionally applies dropout with the given factor on the last feature space before
     branching into value and policy head. If 0, no dropout will be applied.
    :param use_extra_variant_input: If True, the variant information is extracted from the input, passed to each
     bottleneck residual block and concatenated to the final feature representation
    :return: mxnet symbol of the model
    """
    # None sentinels avoid list objects that are shared between calls as default arguments
    if bc_res_blocks is None:
        bc_res_blocks = []
    if res_blocks is None:
        res_blocks = [3] * 13

    # get the input data
    data = mx.sym.Variable(name='data')

    if use_extra_variant_input:
        data_variant = extract_variant_info(data,
                                            channels=NB_CHANNELS_TOTAL,
                                            name="variants")
    else:
        data_variant = None

    # first initial convolution layer followed by batchnormalization
    body = get_stem(data=data, channels=channels, act_type=act_type)
    channels_operating = channels_operating_init

    # build residual tower
    first_se_idx = len(bc_res_blocks) - 5
    for idx, kernel in enumerate(bc_res_blocks):
        # only the last five blocks may use squeeze excitation
        use_squeeze_excitation = use_se if idx >= first_se_idx else False

        body = bottleneck_residual_block(body,
                                         channels,
                                         channels_operating,
                                         name='bc_res_block%d' % idx,
                                         kernel=kernel,
                                         use_se=use_squeeze_excitation,
                                         act_type=act_type,
                                         data_variant=data_variant)
        channels_operating += channel_expansion

    first_se_idx = len(res_blocks) - 5
    for idx, kernel in enumerate(res_blocks):
        # only the last five blocks may use squeeze excitation
        use_squeeze_excitation = use_se if idx >= first_se_idx else False

        body = residual_block(body,
                              channels,
                              name='res_block%d' % idx,
                              kernel=kernel,
                              use_se=use_squeeze_excitation,
                              act_type=act_type)

    if dropout_rate != 0:
        body = mx.sym.Dropout(body, p=dropout_rate)

    if use_extra_variant_input:
        body = mx.sym.Concat(*[body, data_variant], name='feature_concat')

    # for policy output
    policy_out = policy_head(data=body,
                             channels=channels,
                             act_type=act_type,
                             channels_policy_head=channels_policy_head,
                             select_policy_from_plane=select_policy_from_plane,
                             n_labels=n_labels,
                             grad_scale_policy=grad_scale_policy,
                             use_se=False,
                             no_bias=True)

    # for value output
    value_out = value_head(data=body,
                           channels_value_head=channels_value_head,
                           value_kernelsize=1,
                           act_type=act_type,
                           value_fc_size=value_fc_size,
                           grad_scale_value=grad_scale_value,
                           use_se=False,
                           use_mix_conv=False)

    # group value_out and policy_out together
    sym = mx.symbol.Group([value_out, policy_out])

    return sym
Пример #4
0
def mixture_net_symbol(channels=256,
                       num_res_blocks=7,
                       act_type='relu',
                       channels_value_head=8,
                       channels_policy_head=81,
                       value_fc_size=256,
                       grad_scale_value=0.01,
                       grad_scale_policy=0.99,
                       select_policy_from_plane=True,
                       n_labels=4992):
    """
    Mixture net

    Three residual towers are built on top of a shared stem. Their outputs are combined by a weighted sum with
    the variables `w_a`, `w_b` and `w_c` (each initialized with 1/3) and fed into the value and the policy head.

    :param channels: Main number of channels
    :param num_res_blocks: Number of residual blocks in each of the three towers. If 0, the stem output is used
     as the output of every tower.
    :param act_type: Activation type passed to the stem and both heads
    :param channels_value_head: Number of channels for the value head
    :param channels_policy_head: Number of channels for the policy head
    :param value_fc_size: Number of units in the fully connected layer of the value head
    :param grad_scale_value: Constant scalar which the gradient for the value outputs are being scaled with.
                            (0.01 is recommended for supervised learning with little data)
    :param grad_scale_policy: Constant scalar which the gradient for the policy outputs are being scaled with.
    :param select_policy_from_plane: Passed to the policy head to choose its output type
    :param n_labels: Number of policy target labels, passed to the policy head
    :return: mxnet symbol which groups [value_out, policy_out]
    """
    # get the input data
    data = mx.sym.Variable(name='data')

    data = get_stem(data=data, channels=channels, act_type=act_type)

    # the mixture below uses exactly three weights (w_a, w_b, w_c)
    nb_subnets = 3
    bodies = []

    # build residual towers
    for z in range(nb_subnets):
        # Every tower starts at the stem output, so that num_res_blocks=0 yields a valid symbol instead of None.
        body = data
        for i in range(num_res_blocks):
            body = residual_block(body,
                                  channels,
                                  name='b%d_block%d' % (z, i),
                                  bn_mom=0.9,
                                  workspace=1024)
        bodies.append(body)

    w_a = mx.sym.Variable('w_a', init=mx.init.Constant(1 / 3))
    w_b = mx.sym.Variable('w_b', init=mx.init.Constant(1 / 3))
    w_c = mx.sym.Variable('w_c', init=mx.init.Constant(1 / 3))

    # both heads receive the same weighted sum, built as two separate expressions
    data_value = w_a * bodies[0] + w_b * bodies[1] + w_c * bodies[2]
    data_policy = w_a * bodies[0] + w_b * bodies[1] + w_c * bodies[2]

    value_out = value_head(data=data_value,
                           act_type=act_type,
                           use_se=False,
                           channels_value_head=channels_value_head,
                           value_fc_size=value_fc_size,
                           use_mix_conv=False,
                           grad_scale_value=grad_scale_value)
    policy_out = policy_head(data=data_policy,
                             act_type=act_type,
                             channels_policy_head=channels_policy_head,
                             n_labels=n_labels,
                             select_policy_from_plane=select_policy_from_plane,
                             use_se=False,
                             channels=channels,
                             grad_scale_policy=grad_scale_policy)
    # group value_out and policy_out together
    sym = mx.symbol.Group([value_out, policy_out])

    return sym
Пример #5
0
def rise_mobile_v3_symbol(channels=256,
                          channels_operating_init=128,
                          channel_expansion=64,
                          act_type='relu',
                          channels_value_head=8,
                          channels_policy_head=81,
                          value_fc_size=256,
                          dropout_rate=0.15,
                          grad_scale_value=0.01,
                          grad_scale_policy=0.99,
                          select_policy_from_plane=True,
                          kernels=None,
                          n_labels=4992,
                          se_ratio=4,
                          se_types="se"):
    """
    RISEv3 architecture

    Graph layout: input variable -> stem -> one pre-activation mixed-convolution residual block per entry of
    `kernels` -> batch normalization + activation -> optional dropout -> value head and policy head.

    :param channels: Main number of channels
    :param channels_operating_init: Number of operating channels of the first residual block
    :param channel_expansion: Number of operating channels to add after each residual block
    :param act_type: Activation type to use
    :param channels_value_head: Number of channels for the value head
    :param channels_policy_head: Number of channels for the policy head
    :param value_fc_size: Number of units in the fully connected layer of the value head
    :param dropout_rate: Dropout factor to use. If 0, no dropout will be applied.
    :param grad_scale_value: Constant scalar which the gradient for the value outputs are being scaled with.
                            (0.01 is recommended for supervised learning with little data)
    :param grad_scale_policy: Constant scalar which the gradient for the policy outputs are being scaled with.
    :param select_policy_from_plane: Passed to the policy head to choose its output type
    :param kernels: List of kernel sizes used for the residual blocks. The length of the list corresponds to the number
     of residual blocks. None (default) means [3] * 13.
    :param n_labels: Number of policy target labels, passed to the policy head
    :param se_ratio: Reduction ratio passed to each residual block for its squeeze excitation module
    :param se_types: Squeeze excitation module to use for each residual block. Either a list with the same length as
     `kernels`, or a single type (string or None) which is then used for every block. Available types:
    - None: value accepted by the validation, passed through to the residual block unchanged
    - "se": Squeeze excitation block - Hu et al. - https://arxiv.org/abs/1709.01507
    - "cbam": Convolutional Block Attention Module (CBAM) - Woo et al. - https://arxiv.org/pdf/1807.06521.pdf
    - "ca_se": Same as "se"
    - "cm_se": Squeeze excitation with max operator
    - "sa_se": Spatial excitation with average operator
    - "sm_se": Spatial excitation with max operator
    :return: mxnet symbol which groups [value_out, policy_out]
    :raises ValueError: If `kernels` and `se_types` differ in length or `se_types` contains an unknown type
    """
    # The default must be resolved before the length check, otherwise len(None) fails.
    if kernels is None:
        # NOTE(review): every entry is passed as `kernels=` to the residual block - confirm that the block
        # accepts a plain int here and does not expect a list of kernel sizes per block.
        kernels = [3] * 13

    # A single type is broadcast to all blocks. Without this, a string such as the default "se" would be
    # measured and iterated character by character and could never pass the validation.
    if se_types is None or isinstance(se_types, str):
        se_types = [se_types] * len(kernels)

    if len(kernels) != len(se_types):
        raise ValueError(
            f'The length of "kernels": {len(kernels)} must be the same as'
            f' the length of "se_types": {len(se_types)}')

    valid_se_types = [None, "se", "cbam", "ca_se", "cm_se", "sa_se", "sm_se"]
    for se_type in se_types:
        if se_type not in valid_se_types:
            raise ValueError(
                f"Unavailable se_type: {se_type}. Available se_types include {valid_se_types}"
            )

    # get the input data
    data = mx.sym.Variable(name='data')

    data = get_stem(data=data, channels=channels, act_type=act_type)

    cur_channels = channels_operating_init

    for idx, (cur_kernels, se_type) in enumerate(zip(kernels, se_types)):
        data = preact_residual_dmixconv_block(data=data,
                                              channels=channels,
                                              channels_operating=cur_channels,
                                              kernels=cur_kernels,
                                              name='dconv_%d' % idx,
                                              se_ratio=se_ratio,
                                              se_type=se_type)
        cur_channels += channel_expansion

    # pre-activation blocks end without normalization, so it is applied once after the tower
    data = mx.sym.BatchNorm(data=data, name='stem_bn1')
    data = get_act(data=data, act_type=act_type, name='stem_act1')

    if dropout_rate != 0:
        data = mx.sym.Dropout(data, p=dropout_rate)

    value_out = value_head(data=data,
                           act_type=act_type,
                           use_se=False,
                           channels_value_head=channels_value_head,
                           value_fc_size=value_fc_size,
                           use_mix_conv=False,
                           grad_scale_value=grad_scale_value)
    policy_out = policy_head(data=data,
                             act_type=act_type,
                             channels_policy_head=channels_policy_head,
                             n_labels=n_labels,
                             select_policy_from_plane=select_policy_from_plane,
                             use_se=False,
                             channels=channels,
                             grad_scale_policy=grad_scale_policy)
    # group value_out and policy_out together
    sym = mx.symbol.Group([value_out, policy_out])

    return sym