Exemple #1
0
 def __init__(self, config):
     """Create the sub-modules of this layer from ``config``.

     Attributes read from ``config``: ``intermediate_size``, ``hidden_size``,
     ``layer_norm_eps`` and ``hidden_dropout_prob``.
     """
     super().__init__()
     inner_width, model_width = config.intermediate_size, config.hidden_size
     # Linear map from the intermediate width to the hidden width.
     self.dense = nn.Linear(inner_width, model_width)
     self.LayerNorm = nn.LayerNorm(model_width, eps=config.layer_norm_eps)
     self.dropout = nn.Dropout(config.hidden_dropout_prob)
     # FloatFunctional module registered as ``hidden_add``; its call site is
     # outside this method.
     self.hidden_add = FloatFunctional()
Exemple #2
0
    def __init__(self, config):
        """Create the embedding tables and helper modules from ``config``.

        Attributes read from ``config``: ``vocab_size``, ``hidden_size``,
        ``pad_token_id``, ``max_position_embeddings``, ``type_vocab_size``,
        ``layer_norm_eps``, ``hidden_dropout_prob`` and, optionally,
        ``position_embedding_type`` (falls back to ``"absolute"``).
        """
        super().__init__()
        width = config.hidden_size
        n_positions = config.max_position_embeddings

        self.word_embeddings = nn.Embedding(
            config.vocab_size, width, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(n_positions, width)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, width)

        self.add_token_embeddings = FloatFunctional()

        # The attribute keeps its CamelCase name (instead of snake_case) so
        # that it lines up with the TensorFlow variable name and TensorFlow
        # checkpoint files can be loaded.
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids has shape (1, max_position_embeddings); being a
        # registered buffer it is exported when the module is serialized.
        id_row = torch.arange(n_positions).expand((1, -1))
        self.register_buffer("position_ids", id_row)

        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute")

        # The extra adder module only exists for absolute position embeddings.
        if self.position_embedding_type == 'absolute':
            self.add_position_embeddings = FloatFunctional()
Exemple #3
0
 def __init__(self, activation: str, quant: bool=False):
     """Store the quantization flag and build the optional activation.

     Args:
         activation: key into ``ACTIVATION_MAP``; the value ``'linear'``
             means no activation module is created.
         quant: when true, a ``FloatFunctional`` is registered as ``ffunc``.
     """
     super().__init__()
     self.quant = quant
     if quant:
         self.ffunc = FloatFunctional()
     # ``act`` stays None for 'linear'; otherwise the mapped class is
     # instantiated with no arguments.
     self.act = (None if activation == 'linear'
                 else ACTIVATION_MAP[activation]())
Exemple #4
0
class NNUE(pl.LightningModule):
    """NNUE network laid out for PyTorch's built-in quantization framework.

    Targeting that framework (QuantStub / DeQuantStub / FloatFunctional)
    leads to some different design decisions, which is why this is kept as a
    separate implementation.
    """

    def __init__(self):
        super().__init__()
        # One input layer, applied to both w_in and b_in in forward().
        self.input = nn.Linear(halfkp.INPUTS, L1)
        self.input_act = nn.ReLU()
        # The first hidden layer consumes two concatenated L1-wide halves.
        self.l1 = nn.Linear(2 * L1, L2)
        self.l1_act = nn.ReLU()
        self.l2 = nn.Linear(L2, L3)
        self.l2_act = nn.ReLU()
        self.output = nn.Linear(L3, 1)
        # Quantization boundary markers and functional wrappers.
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.input_mul = FloatFunctional()
        self.input_add = FloatFunctional()

    def forward(self, us, them, w_in, b_in):
        """Evaluate the network.

        ``w_in`` and ``b_in`` go through the shared input layer; their outputs
        are concatenated in both orders, scaled by ``us`` and ``them``
        respectively, summed, and passed through the remaining layers.
        """
        us, them, w_in, b_in = (self.quant(t) for t in (us, them, w_in, b_in))
        w_feat = self.input(w_in)
        b_feat = self.input(b_in)
        us_term = self.input_mul.mul(us, torch.cat([w_feat, b_feat], dim=1))
        them_term = self.input_mul.mul(them,
                                       torch.cat([b_feat, w_feat], dim=1))
        hidden = self.input_act(self.input_add.add(us_term, them_term))
        hidden = self.l1_act(self.l1(hidden))
        hidden = self.l2_act(self.l2(hidden))
        return self.dequant(self.output(hidden))

    def step_(self, batch, batch_idx, loss_type):
        """Compute the MSE loss for ``batch`` and log it under ``loss_type``."""
        us, them, white, black, _outcome, score = batch
        prediction = self(us, them, white, black)
        target = cp_conversion(score)
        loss = F.mse_loss(prediction, target)
        self.log(loss_type, loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss (logged as 'train_loss')."""
        return self.step_(batch, batch_idx, 'train_loss')

    def validation_step(self, batch, batch_idx):
        """Log the validation loss as 'val_loss'; nothing is returned."""
        self.step_(batch, batch_idx, 'val_loss')

    def test_step(self, batch, batch_idx):
        """Log the test loss as 'test_loss'; nothing is returned."""
        self.step_(batch, batch_idx, 'test_loss')

    def configure_optimizers(self):
        """Return an Adadelta optimizer over all parameters with lr=1.0."""
        return torch.optim.Adadelta(self.parameters(), lr=1.0)
Exemple #5
0
 def __init__(self):
     """Create the linear layers, activations and quantization helpers."""
     super(NNUE, self).__init__()
     # Input layer and its activation.
     self.input = nn.Linear(in_features=halfkp.INPUTS, out_features=L1)
     self.input_act = nn.ReLU()
     # Hidden stack; the first hidden layer takes an input of width 2 * L1.
     self.l1 = nn.Linear(in_features=2 * L1, out_features=L2)
     self.l1_act = nn.ReLU()
     self.l2 = nn.Linear(in_features=L2, out_features=L3)
     self.l2_act = nn.ReLU()
     # Single-value output layer.
     self.output = nn.Linear(in_features=L3, out_features=1)
     # Quantization stubs and functional wrappers.
     self.quant = QuantStub()
     self.dequant = DeQuantStub()
     self.input_mul = FloatFunctional()
     self.input_add = FloatFunctional()
Exemple #6
0
class DeepPoolLayer(nn.Module):
    """Multi-scale pooling block with optional upsampling and fusion.

    The input is average-pooled with strides 2, 4 and 8; each pooled map is
    convolved, bilinearly resized back to the input's spatial size and added
    to the input. The sum goes through ReLU and a 3x3 convolution. All
    additions use ``FloatFunctional`` modules.

    Args:
        k: channel count of the input (and of the pooled branches).
        k_out: channel count produced by ``conv_sum``.
        need_x2: if true, the merged map is resized to ``x2``'s spatial size
            before ``conv_sum``.
        need_fuse: if true, ``x2`` and ``x3`` are added after ``conv_sum`` and
            the result goes through one more convolution.
    """

    def __init__(self, k, k_out, need_x2, need_fuse):
        super().__init__()
        self.pools_sizes = [2, 4, 8]
        self.need_x2 = need_x2
        self.need_fuse = need_fuse
        self.pools = nn.ModuleList(
            [nn.AvgPool2d(kernel_size=size, stride=size)
             for size in self.pools_sizes])
        self.convs = nn.ModuleList(
            [nn.Conv2d(k, k, 3, 1, 1, bias=False) for _ in self.pools_sizes])
        # One adder module per pooled branch.
        self.q_add00 = FloatFunctional()
        self.q_add01 = FloatFunctional()
        self.q_add02 = FloatFunctional()
        self.relu = nn.ReLU()
        self.conv_sum = nn.Conv2d(k, k_out, 3, 1, 1, bias=False)
        if self.need_fuse:
            self.q_add1 = FloatFunctional()
            self.q_add2 = FloatFunctional()
            self.conv_sum_c = nn.Conv2d(k_out, k_out, 3, 1, 1, bias=False)

    def forward(self, x, x2=None, x3=None):
        """Run the block.

        ``x2`` is read when ``need_x2`` or ``need_fuse`` is set; ``x3`` is
        read only when ``need_fuse`` is set.
        """
        target_hw = x.size()[2:]
        merged = x
        branch_adders = (self.q_add00, self.q_add01, self.q_add02)
        for pool, conv, adder in zip(self.pools, self.convs, branch_adders):
            upsampled = nn.functional.interpolate(conv(pool(x)),
                                                  target_hw,
                                                  mode='bilinear',
                                                  align_corners=True)
            merged = adder.add(merged, upsampled)

        merged = self.relu(merged)

        if self.need_x2:
            merged = nn.functional.interpolate(merged,
                                               x2.size()[2:],
                                               mode='bilinear',
                                               align_corners=True)
        merged = self.conv_sum(merged)

        if not self.need_fuse:
            return merged

        merged = self.q_add1.add(merged, x2)
        merged = self.q_add2.add(merged, x3)
        return self.conv_sum_c(merged)
Exemple #7
0
 def __init__(self,
              channels,
              kernel_size=3,
              leak=0,
              norm_type='batch',
              DWS=False,
              dilation=1,
              groups=1):
     """Build two ``Conv`` stages, the skip adder and the output activation.

     ``conv1`` receives ``dilation`` and ``leak``; ``conv2`` is built with
     ``leak=-1`` and no dilation argument (how ``Conv`` interprets these
     values is defined outside this method).
     """
     super().__init__()
     # Keyword arguments common to both convolution stages.
     shared = dict(DWS=DWS, groups=groups, norm_type=norm_type)
     self.conv1 = Conv(channels,
                       channels,
                       kernel_size,
                       dilation=dilation,
                       leak=leak,
                       **shared)
     self.conv2 = Conv(channels, channels, kernel_size, leak=-1, **shared)
     self.skip_add = FF()
     self.leak = leak
     # leak == 0 selects a plain in-place ReLU, anything else a LeakyReLU.
     self.relu = (torch.nn.ReLU(inplace=True)
                  if leak == 0 else torch.nn.LeakyReLU(leak))
Exemple #8
0
 def __init__(self,
              channels,
              kernel_size=3,
              leak=0,
              norm_type='batch',
              DWS=False,
              groups=1,
              dilation=1):
     """Build two conv + norm stages, the activation and the skip adder.

     ``norm_type == 'batch'`` selects ``BatchNorm2d``; every other value
     selects ``InstanceNorm2d``. Only ``conv1`` receives ``dilation``.
     """
     super().__init__()
     make_norm = (torch.nn.BatchNorm2d
                  if norm_type == 'batch' else torch.nn.InstanceNorm2d)
     # Keyword arguments common to both convolution stages.
     shared = dict(DWS=DWS, groups=groups, norm_type=norm_type)

     self.conv1 = Conv(channels,
                       channels,
                       kernel_size,
                       dilation=dilation,
                       **shared)
     self.norm1 = make_norm(channels, affine=True)
     self.leak = leak
     # leak == 0 selects a plain in-place ReLU, anything else a LeakyReLU.
     self.relu = (torch.nn.ReLU(inplace=True)
                  if leak == 0 else torch.nn.LeakyReLU(leak))
     self.conv2 = Conv(channels, channels, kernel_size, **shared)
     self.norm2 = make_norm(channels, affine=True)
     self.groups = groups
     self.skip_add = FF()
Exemple #9
0
    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        """Build two 3x3 conv + batch-norm stages and the merge modules.

        Args:
            in_channels: channels of the incoming feature map.
            out_channels: channels produced by both convolutions.
            stride: stride of the first convolution (the second uses 1).
        """
        super().__init__()
        # Settings shared by both 3x3 convolutions.
        conv_opts = dict(kernel_size=3, padding=1, bias=False)

        self.conv1 = Conv2d(in_channels,
                            out_channels,
                            stride=stride,
                            **conv_opts)
        self.bn1 = BatchNorm2d(out_channels)
        self.act1 = ReLU(num_channels=out_channels, inplace=True)
        self.conv2 = Conv2d(out_channels, out_channels, stride=1, **conv_opts)
        self.bn2 = BatchNorm2d(out_channels)

        # An identity modifier is only built when _IdentityModifier.required
        # says so for this channel/stride combination.
        if _IdentityModifier.required(in_channels, out_channels, stride):
            self.identity = _IdentityModifier(in_channels, out_channels,
                                              stride)
        else:
            self.identity = None

        # Use FloatFunctional when that name is bound to something other than
        # None; otherwise fall back to a ReLU module.
        if FloatFunctional is not None:
            self.add_relu = FloatFunctional()
        else:
            self.add_relu = ReLU(num_channels=out_channels, inplace=True)

        self.initialize()
Exemple #10
0
    def __init__(self, feature_set, lambda_=1.0):
        """Create the layers and zero the virtual-feature input weights.

        Args:
            feature_set: object whose ``num_features`` sizes the input layer.
            lambda_: stored on the instance as ``self.lambda_``.
        """
        super(NNUE, self).__init__()
        self.feature_set = feature_set
        self.lambda_ = lambda_

        # Input layer and its activation.
        self.input = nn.Linear(in_features=feature_set.num_features,
                               out_features=L1)
        self.input_act = nn.ReLU()
        # Hidden stack; the first hidden layer takes an input of width 2 * L1.
        self.l1 = nn.Linear(in_features=2 * L1, out_features=L2)
        self.l1_act = nn.ReLU()
        self.l2 = nn.Linear(in_features=L2, out_features=L3)
        self.l2_act = nn.ReLU()
        # Single-value output layer.
        self.output = nn.Linear(in_features=L3, out_features=1)
        # Quantization stubs and functional wrappers.
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.input_mul = FloatFunctional()
        self.input_add = FloatFunctional()

        self._zero_virtual_feature_weights()
Exemple #11
0
class non_bottleneck_1d(nn.Module):
    """Residual block built from factorised (3x1 then 1x3) convolutions.

    Args:
        chann: number of input and output channels of every convolution.
        dropprob: probability handed to ``nn.Dropout2d``; dropout is skipped
            in ``forward`` when it is 0.
        dilated: dilation used by the second convolution pair.
    """

    def __init__(self, chann, dropprob, dilated):
        super().__init__()

        tall, wide = (3, 1), (1, 3)

        # First pair: undilated 3x1 then 1x3.
        self.conv3x1_1 = nn.Conv2d(chann, chann, tall,
                                   stride=1, padding=(1, 0), bias=True)
        self.conv1x3_1 = nn.Conv2d(chann, chann, wide,
                                   stride=1, padding=(0, 1), bias=True)
        self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)

        # Second pair: dilated along the axis with kernel size 3; padding
        # equal to the dilation keeps the spatial size unchanged.
        self.conv3x1_2 = nn.Conv2d(chann, chann, tall,
                                   stride=1,
                                   padding=(1 * dilated, 0),
                                   bias=True,
                                   dilation=(dilated, 1))
        self.conv1x3_2 = nn.Conv2d(chann, chann, wide,
                                   stride=1,
                                   padding=(0, 1 * dilated),
                                   bias=True,
                                   dilation=(1, dilated))
        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)

        self.dropout = nn.Dropout2d(dropprob)

        self.adder = FloatFunctional()

    def forward(self, input):
        """Apply both convolution pairs and add the residual connection."""
        feats = F.relu(self.conv3x1_1(input))
        feats = F.relu(self.bn1(self.conv1x3_1(feats)))

        feats = F.relu(self.conv3x1_2(feats))
        feats = self.bn2(self.conv1x3_2(feats))

        if self.dropout.p != 0:
            feats = self.dropout(feats)

        # Identity shortcut: the block's input is added back before the ReLU.
        with_skip = self.adder.add(feats, input)
        return F.relu(with_skip)
Exemple #12
0
class ScaleChannels(nn.Module):
    """Multiply two tensors, optionally through a ``FloatFunctional``.

    Args:
        quant: when true the product is computed with ``self.ffunc.mul``;
            otherwise with the plain ``*`` operator.
    """

    def __init__(self, quant: bool=False):
        super().__init__()
        self.quant = quant
        # The functional wrapper is only created on the quantization path.
        if quant:
            self.ffunc = FloatFunctional()

    def forward(self, x, other):
        """Return the element-wise product of ``x`` and ``other``."""
        if not self.quant:
            return other * x
        return self.ffunc.mul(x, other)
Exemple #13
0
 def __init__(self, k, k_out, need_x2, need_fuse):
     """Create the pooling branches, adders and output convolutions.

     Args:
         k: channel count of the branch convolutions and of ``conv_sum``'s
             input.
         k_out: channel count produced by ``conv_sum``.
         need_x2: stored as ``self.need_x2``.
         need_fuse: stored as ``self.need_fuse``; when true, two more adders
             and ``conv_sum_c`` are created.
     """
     super(DeepPoolLayer, self).__init__()
     self.pools_sizes = [2, 4, 8]
     self.need_x2 = need_x2
     self.need_fuse = need_fuse
     # One average-pool / 3x3 conv pair per entry of pools_sizes.
     self.pools = nn.ModuleList(
         [nn.AvgPool2d(kernel_size=size, stride=size)
          for size in self.pools_sizes])
     self.convs = nn.ModuleList(
         [nn.Conv2d(k, k, 3, 1, 1, bias=False) for _ in self.pools_sizes])
     # One adder module per pooled branch.
     self.q_add00 = FloatFunctional()
     self.q_add01 = FloatFunctional()
     self.q_add02 = FloatFunctional()
     self.relu = nn.ReLU()
     self.conv_sum = nn.Conv2d(k, k_out, 3, 1, 1, bias=False)
     if self.need_fuse:
         self.q_add1 = FloatFunctional()
         self.q_add2 = FloatFunctional()
         self.conv_sum_c = nn.Conv2d(k_out, k_out, 3, 1, 1, bias=False)
Exemple #14
0
    def __init__(self, config):
        """Create the attention projections and helper modules from ``config``.

        Raises:
            ValueError: if ``config.hidden_size`` is not a multiple of
                ``config.num_attention_heads`` and ``config`` has no
                ``embedding_size`` attribute.
        """
        super().__init__()
        hidden, heads = config.hidden_size, config.num_attention_heads
        if hidden % heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden, heads))

        self.num_attention_heads = heads
        self.attention_head_size = int(hidden / heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Query / key / value projections.
        self.query = nn.Linear(hidden, self.all_head_size)
        self.key = nn.Linear(hidden, self.all_head_size)
        self.value = nn.Linear(hidden, self.all_head_size)

        self.attention_scores = Einsum()
        self.normalize = FloatFunctional()
        self.softmax = nn.Softmax(dim=-1)
        self.context_layer = Einsum()

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config,
                                               "position_embedding_type",
                                               "absolute")
        pos_type = self.position_embedding_type

        # Both relative variants need a distance embedding table.
        if pos_type in ("relative_key", "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1,
                self.attention_head_size)

        # Variant-specific helper modules.
        if pos_type == 'relative_key_query':
            self.relative_position_scores_query = Einsum()
            self.relative_position_scores_key = Einsum()
            self.rel_attention_add = FloatFunctional()
            self.attention_add = FloatFunctional()
        elif pos_type == 'relative_key':
            self.relative_position_scores = Einsum()
            self.rel_attention_add = FloatFunctional()
Exemple #15
0
 def __init__(self, ins, outs, expansion, stride=1, leak=0, dilation=1):
     """Build the ``Layer131`` body and the skip adder.

     ``is_res`` is true only when ``stride == 1`` and ``ins == outs``. The
     inner width passed to ``Layer131`` is ``ins * expansion``.
     """
     super().__init__()
     self.stride = stride
     assert stride in (1, 2)
     same_shape = stride == 1 and ins == outs
     self.is_res = same_shape
     inner_width = ins * expansion
     self.conv = Layer131(ins,
                          outs,
                          inner_width,
                          kernel_size=3,
                          stride=stride,
                          leak=leak,
                          dilation=dilation)
     self.skip_add = FF()
Exemple #16
0
class Route(nn.Module):
    """Pass through a single input or concatenate several along dim 1.

    Args:
        quant: when true, concatenation goes through ``self.ffunc.cat``
            instead of ``torch.cat``.
        single: when true, ``forward`` returns the first element of its input
            unchanged and no ``FloatFunctional`` is created.
    """

    def __init__(self, quant: bool=False, single: bool=False):
        super().__init__()
        self.quant = quant
        self.single = single
        if not single:
            self.ffunc = FloatFunctional()

    def forward(self, xs):
        """Return ``xs[0]`` in single mode, else ``xs`` concatenated on dim 1."""
        if self.single:
            return xs[0]
        concat = self.ffunc.cat if self.quant else torch.cat
        return concat(xs, dim=1)
Exemple #17
0
    def __init__(self, chann, dropprob, dilated):
        """Create the factorised convolutions, norms, dropout and adder.

        Args:
            chann: number of input and output channels of every convolution.
            dropprob: probability handed to ``nn.Dropout2d``.
            dilated: dilation used by the second convolution pair.
        """
        super(non_bottleneck_1d, self).__init__()

        tall, wide = (3, 1), (1, 3)

        # First pair: undilated 3x1 then 1x3.
        self.conv3x1_1 = nn.Conv2d(chann, chann, tall,
                                   stride=1, padding=(1, 0), bias=True)
        self.conv1x3_1 = nn.Conv2d(chann, chann, wide,
                                   stride=1, padding=(0, 1), bias=True)
        self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)

        # Second pair: dilated along the axis with kernel size 3; padding
        # equal to the dilation keeps the spatial size unchanged.
        self.conv3x1_2 = nn.Conv2d(chann, chann, tall,
                                   stride=1,
                                   padding=(1 * dilated, 0),
                                   bias=True,
                                   dilation=(dilated, 1))
        self.conv1x3_2 = nn.Conv2d(chann, chann, wide,
                                   stride=1,
                                   padding=(0, 1 * dilated),
                                   bias=True,
                                   dilation=(1, dilated))
        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)

        self.dropout = nn.Dropout2d(dropprob)

        self.adder = FloatFunctional()
Exemple #18
0
class BertOutput(nn.Module):
    """Dense projection, dropout, residual add and layer normalisation.

    ``config`` must provide ``intermediate_size``, ``hidden_size``,
    ``layer_norm_eps`` and ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(config.intermediate_size, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # The residual addition goes through a FloatFunctional module.
        self.hidden_add = FloatFunctional()

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        summed = self.hidden_add.add(projected, input_tensor)
        return self.LayerNorm(summed)
Exemple #19
0
class ResShuffleLayer(torch.nn.Module):
    """Basic residual layer, followed by a shuffle, to import into
    ImageTransformer.

    ``norm_type == 'batch'`` selects ``BatchNorm2d``; every other value
    selects ``InstanceNorm2d``. Only the first convolution receives
    ``dilation``.
    """

    def __init__(self,
                 channels,
                 kernel_size=3,
                 leak=0,
                 norm_type='batch',
                 DWS=False,
                 groups=1,
                 dilation=1):
        super().__init__()
        make_norm = (torch.nn.BatchNorm2d
                     if norm_type == 'batch' else torch.nn.InstanceNorm2d)
        # Keyword arguments common to both convolution stages.
        shared = dict(DWS=DWS, groups=groups, norm_type=norm_type)

        self.conv1 = Conv(channels,
                          channels,
                          kernel_size,
                          dilation=dilation,
                          **shared)
        self.norm1 = make_norm(channels, affine=True)
        self.leak = leak
        # leak == 0 selects a plain in-place ReLU, anything else a LeakyReLU.
        self.relu = (torch.nn.ReLU(inplace=True)
                     if leak == 0 else torch.nn.LeakyReLU(leak))
        self.conv2 = Conv(channels, channels, kernel_size, **shared)
        self.norm2 = make_norm(channels, affine=True)
        self.groups = groups
        self.skip_add = FF()

    def forward(self, ins):
        """Forward pass: conv/norm/act, conv/norm, skip add, act, shuffle."""
        hidden = self.conv1(ins)
        hidden = self.norm1(hidden)
        hidden = self.relu(hidden)
        hidden = self.norm2(self.conv2(hidden))
        # Residual connection with the untouched input.
        summed = self.skip_add.add(hidden, ins)
        return shuffle_v1(self.relu(summed), self.groups)
Exemple #20
0
class ShortCut(nn.Module):
    """Add two tensors and optionally apply an activation.

    Args:
        activation: key into ``ACTIVATION_MAP``; ``'linear'`` means no
            activation is applied.
        quant: when true the sum is computed with ``self.ffunc.add`` (a
            ``FloatFunctional``); otherwise with the ``+`` operator.
    """

    def __init__(self, activation: str, quant: bool=False):
        super().__init__()
        self.quant = quant
        if quant:
            self.ffunc = FloatFunctional()
        self.act = None
        if activation != 'linear':
            self.act = ACTIVATION_MAP[activation]()

    def forward(self, x, other):
        """Return ``act(x + other)`` (or just the sum when ``act`` is None).

        Neither ``x`` nor ``other`` is modified by the addition.
        """
        if self.quant:
            x = self.ffunc.add(x, other)
        else:
            # Deliberately out-of-place. ``x += other`` would overwrite the
            # caller's tensor (corrupting any other consumer of it), raise on
            # leaf tensors that require grad, and behave differently from the
            # quantized branch above, which never mutates its inputs.
            x = x + other
        if self.act is not None:
            x = self.act(x)
        return x
Exemple #21
0
class ResLayer(torch.nn.Module):
    """Basic residual layer to import into ImageTransformer.

    Two ``Conv`` stages are chained and their result is added to the input
    before the final activation. ``conv1`` receives ``dilation`` and
    ``leak``; ``conv2`` is built with ``leak=-1`` and no dilation argument.
    """

    def __init__(self,
                 channels,
                 kernel_size=3,
                 leak=0,
                 norm_type='batch',
                 DWS=False,
                 dilation=1,
                 groups=1):
        super().__init__()
        # Keyword arguments common to both convolution stages.
        shared = dict(DWS=DWS, groups=groups, norm_type=norm_type)
        self.conv1 = Conv(channels,
                          channels,
                          kernel_size,
                          dilation=dilation,
                          leak=leak,
                          **shared)
        self.conv2 = Conv(channels, channels, kernel_size, leak=-1, **shared)
        self.skip_add = FF()
        self.leak = leak
        # leak == 0 selects a plain in-place ReLU, anything else a LeakyReLU.
        self.relu = (torch.nn.ReLU(inplace=True)
                     if leak == 0 else torch.nn.LeakyReLU(leak))

    def forward(self, ins):
        """Forward pass: two convs, skip add with the input, activation."""
        body = self.conv1(ins)
        body = self.conv2(body)
        summed = self.skip_add.add(body, ins)
        return self.relu(summed)
Exemple #22
0
class InvertedResidual(torch.nn.Module):
    """MobileNetv2 style residual linear bottleneck layer to import into
    ImageTransformer.

    The skip connection is used only when ``stride == 1`` and
    ``ins == outs``; otherwise the layer is just its ``Layer131`` body.
    """

    def __init__(self, ins, outs, expansion, stride=1, leak=0, dilation=1):
        super().__init__()
        self.stride = stride
        assert stride in (1, 2)
        same_shape = stride == 1 and ins == outs
        self.is_res = same_shape
        inner_width = ins * expansion
        self.conv = Layer131(ins,
                             outs,
                             inner_width,
                             kernel_size=3,
                             stride=stride,
                             leak=leak,
                             dilation=dilation)
        self.skip_add = FF()

    def forward(self, x):
        """Return ``x + conv(x)`` for residual layers, else ``conv(x)``."""
        branch = self.conv(x)
        if not self.is_res:
            return branch
        return self.skip_add.add(x, branch)
Exemple #23
0
 def __init__(self, quant: bool=False):
     """Store ``quant``; create ``self.ffunc`` only when it is true."""
     super().__init__()
     self.quant = quant
     if not quant:
         return
     self.ffunc = FloatFunctional()
Exemple #24
0
 def _conv_float_functional():
     return torch.nn.Sequential(
         torch.nn.Conv2d(20, 20, 3),
         FloatFunctional(),
     )
Exemple #25
0
 def __init__(self, quant: bool=False, single: bool=False):
     """Store both flags; create ``self.ffunc`` unless ``single`` is true."""
     super().__init__()
     self.quant, self.single = quant, single
     if single:
         return
     self.ffunc = FloatFunctional()
Exemple #26
0
class NNUE(pl.LightningModule):
    """
  This model implementation is designed to be quantized using the built-in
  Pytorch quantization framework.  This leads to some different design decisions
  which is why it's a separate implementation.

  lambda_ = 0.0 - purely based on game results
  lambda_ = 1.0 - purely based on search scores

  Values in between blend the two losses linearly (see step_).
  """
    def __init__(self, feature_set, lambda_=1.0):
        """Build the layers and zero the virtual-feature input weights.

        feature_set: object providing num_features here and, in other
            methods, name, features and get_virtual_feature_ranges().
        lambda_: blend factor between the score-based and the outcome-based
            loss terms used in step_.
        """
        super(NNUE, self).__init__()
        self.feature_set = feature_set
        self.lambda_ = lambda_
        # One input layer, applied to both w_in and b_in in forward().
        self.input = nn.Linear(feature_set.num_features, L1)
        self.input_act = nn.ReLU()
        # The first hidden layer consumes two concatenated L1-wide halves.
        self.l1 = nn.Linear(2 * L1, L2)
        self.l1_act = nn.ReLU()
        self.l2 = nn.Linear(L2, L3)
        self.l2_act = nn.ReLU()
        self.output = nn.Linear(L3, 1)
        # Quantization boundary stubs and functional wrappers for the
        # element-wise multiply / add in forward().
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.input_mul = FloatFunctional()
        self.input_add = FloatFunctional()

        self._zero_virtual_feature_weights()

    '''
  We zero all virtual feature weights because during serialization to .nnue
  we compute weights for each real feature as being the sum of the weights for
  the real feature in question and the virtual features it can be factored to.
  This means that if we didn't initialize the virtual feature weights to zero
  we would end up with the real features having effectively unexpected values
  at initialization - following the bell curve based on how many factors there are.
  '''

    def _zero_virtual_feature_weights(self):
        """Zero the input-layer weight columns of every virtual feature range.

        The ranges come from feature_set.get_virtual_feature_ranges() as
        (start, stop) column pairs. The weight is then re-wrapped in a new
        nn.Parameter.
        """
        weights = self.input.weight
        # In-place edit of a parameter, so autograd tracking is disabled.
        with torch.no_grad():
            for a, b in self.feature_set.get_virtual_feature_ranges():
                weights[:, a:b] = 0.0
        self.input.weight = nn.Parameter(weights)

    '''
  This method attempts to convert the model from using the self.feature_set
  to new_feature_set.
  '''

    def set_feature_set(self, new_feature_set):
        """Switch the model to new_feature_set, resizing the input weights.

        Does nothing when both feature sets have the same name. Otherwise the
        only supported conversion is appending zero-initialised virtual
        feature columns (unfactorized -> factorized).

        Raises Exception when the current feature set has more than one
        feature block, or when the current block's name is not the first
        factor of the new block.
        """
        if self.feature_set.name == new_feature_set.name:
            return

        # TODO: Implement this for more complicated conversions.
        #       Currently we support only a single feature block.
        if len(self.feature_set.features) > 1:
            raise Exception('Cannot change feature set from {} to {}.'.format(
                self.feature_set.name, new_feature_set.name))

        # Currently we only support conversion for feature sets with
        # one feature block each so we'll dig the feature blocks directly
        # and forget about the set.
        old_feature_block = self.feature_set.features[0]
        new_feature_block = new_feature_set.features[0]

        # next(iter(new_feature_block.factors)) is the way to get the
        # first item in a OrderedDict. (the ordered dict being str : int
        # mapping of the factor name to its size).
        # It is our new_feature_factor_name.
        # For example old_feature_block.name == "HalfKP"
        # and new_feature_factor_name == "HalfKP^"
        # We assume here that the "^" denotes factorized feature block
        # and we would like feature block implementers to follow this convention.
        # So if our current feature_set matches the first factor in the new_feature_set
        # we only have to add the virtual feature on top of the already existing real ones.
        if old_feature_block.name == next(iter(new_feature_block.factors)):
            # We can just extend with zeros since it's unfactorized -> factorized
            weights = self.input.weight
            # new_zeros keeps the dtype and device of the existing weights.
            padding = weights.new_zeros(
                (weights.shape[0], new_feature_block.num_virtual_features))
            weights = torch.cat([weights, padding], dim=1)
            # NOTE(review): self.input.in_features is not updated to the new
            # column count here - confirm nothing relies on it.
            self.input.weight = nn.Parameter(weights)
            self.feature_set = new_feature_set
        else:
            raise Exception('Cannot change feature set from {} to {}.'.format(
                self.feature_set.name, new_feature_set.name))

    def forward(self, us, them, w_in, b_in):
        """Evaluate the network.

        w_in and b_in go through the shared input layer; their outputs are
        concatenated in both orders, multiplied by us and them respectively,
        summed, and passed through the remaining layers.

        NOTE(review): us / them are presumably side-to-move masks and
        w_in / b_in the white / black feature inputs (step_ passes the batch
        fields named white and black) - confirm against the data loader.
        """
        # Every input passes through the same QuantStub.
        us = self.quant(us)
        them = self.quant(them)
        w_in = self.quant(w_in)
        b_in = self.quant(b_in)
        w = self.input(w_in)
        b = self.input(b_in)
        # The same input_mul instance is used for both products.
        l0_ = self.input_add.add(
            self.input_mul.mul(us, torch.cat([w, b], dim=1)),
            self.input_mul.mul(them, torch.cat([b, w], dim=1)))
        l0_ = self.input_act(l0_)
        l1_ = self.l1_act(self.l1(l0_))
        l2_ = self.l2_act(self.l2(l1_))
        x = self.output(l2_)
        x = self.dequant(x)
        return x

    def step_(self, batch, batch_idx, loss_type):
        """Compute the loss for batch, log it under loss_type and return it.

        The loss blends, by lambda_, a cross-entropy against the sigmoid of
        the scaled score with a cross-entropy against the game outcome. The
        matching blend of the targets' own entropies is subtracted.
        batch_idx is unused.
        """
        us, them, white, black, outcome, score = batch

        # 600 is the kPonanzaConstant scaling factor needed to convert the training net output to a score.
        # This needs to match the value used in the serializer
        nnue2score = 600
        scaling = 361

        # q: network output in logit units; t: outcome target;
        # p: teacher target derived from the search score.
        q = self(us, them, white, black) * nnue2score / scaling
        t = outcome
        p = (score / scaling).sigmoid()

        # epsilon keeps the logarithms finite when a target is exactly 0 or 1.
        epsilon = 1e-12
        teacher_entropy = -(p * (p + epsilon).log() + (1.0 - p) *
                            (1.0 - p + epsilon).log())
        outcome_entropy = -(t * (t + epsilon).log() + (1.0 - t) *
                            (1.0 - t + epsilon).log())
        # Binary cross-entropy of sigmoid(q) against each target.
        teacher_loss = -(p * F.logsigmoid(q) + (1.0 - p) * F.logsigmoid(-q))
        outcome_loss = -(t * F.logsigmoid(q) + (1.0 - t) * F.logsigmoid(-q))
        result = self.lambda_ * teacher_loss + (1.0 -
                                                self.lambda_) * outcome_loss
        entropy = self.lambda_ * teacher_entropy + (
            1.0 - self.lambda_) * outcome_entropy
        # Cross-entropy minus target entropy.
        loss = result.mean() - entropy.mean()
        self.log(loss_type, loss)
        return loss

        # MSE Loss function for debugging
        # Scale score by 600.0 to match the expected NNUE scaling factor
        # output = self(us, them, white, black) * 600.0
        # loss = F.mse_loss(output, score)

    def training_step(self, batch, batch_idx):
        """Return the training loss (logged as 'train_loss')."""
        return self.step_(batch, batch_idx, 'train_loss')

    def validation_step(self, batch, batch_idx):
        """Log the validation loss as 'val_loss'; nothing is returned."""
        self.step_(batch, batch_idx, 'val_loss')

    def test_step(self, batch, batch_idx):
        """Log the test loss as 'test_loss'; nothing is returned."""
        self.step_(batch, batch_idx, 'test_loss')

    def configure_optimizers(self):
        """Return ([optimizer], [scheduler]) for training.

        Two parameter groups are passed to ranger.Ranger: every direct-child
        nn.Linear except self.output at LR, and self.output at LR / 10. A
        StepLR scheduler multiplies the learning rate by 0.3 every 75 steps
        of the scheduler.
        """
        # Train with a lower LR on the output layer
        LR = 1e-3
        train_params = [
            {
                'params': self.get_layers(lambda x: self.output != x),
                'lr': LR
            },
            {
                'params': self.get_layers(lambda x: self.output == x),
                'lr': LR / 10
            },
        ]
        # increasing the eps leads to less saturated nets with a few dead neurons
        optimizer = ranger.Ranger(train_params, betas=(.9, 0.999), eps=1.0e-7)
        # Drop learning rate after 75 epochs
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=75,
                                                    gamma=0.3)
        return [optimizer], [scheduler]

    def get_layers(self, filt):
        """
    Yields the trainable parameters of selected layers (a generator, not a list).
    Only direct children that are nn.Linear are considered.
    filt: Return true to include the given layer.
    """
        for i in self.children():
            if filt(i):
                if isinstance(i, nn.Linear):
                    for p in i.parameters():
                        # Frozen parameters are skipped.
                        if p.requires_grad:
                            yield p
Exemple #27
0
class BertEmbeddings(nn.Module):
    """Combine word, token-type and (optionally) position embeddings.

    The element-wise sums are routed through ``FloatFunctional`` modules
    instead of the ``+`` operator; the summed embedding is then passed
    through LayerNorm and dropout.
    """
    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        max_positions = config.max_position_embeddings

        self.word_embeddings = nn.Embedding(config.vocab_size,
                                            hidden_size,
                                            padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(max_positions, hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  hidden_size)

        self.add_token_embeddings = FloatFunctional()

        # The attribute keeps its CamelCase name (LayerNorm) so that it lines
        # up with TensorFlow variable names and TensorFlow checkpoints load.
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids has shape (1, max_position_embeddings); as a registered
        # buffer it is exported when the module is serialized.
        all_positions = torch.arange(max_positions).expand((1, -1))
        self.register_buffer("position_ids", all_positions)

        self.position_embedding_type = getattr(config,
                                               "position_embedding_type",
                                               "absolute")
        # The second adder is only needed when position embeddings are summed in.
        if self.position_embedding_type == 'absolute':
            self.add_position_embeddings = FloatFunctional()

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                position_ids=None,
                inputs_embeds=None):
        """Embed the inputs and return the normalized, dropped-out result.

        Args:
            input_ids: Token ids; used for the word-embedding lookup unless
                ``inputs_embeds`` is supplied.
            token_type_ids: Token-type ids; defaults to all zeros with the
                shape of the input.
            position_ids: Position ids; defaults to the first ``seq_length``
                entries of the ``position_ids`` buffer.
            inputs_embeds: Precomputed word embeddings that bypass the
                word-embedding lookup.

        Returns:
            The embedding tensor after LayerNorm and dropout.
        """
        # Shape of the id grid: taken from the ids when present, otherwise
        # from the embeddings with their last (feature) axis dropped.
        if input_ids is None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            input_shape = input_ids.size()
        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape,
                                         dtype=torch.long,
                                         device=self.position_ids.device)
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        summed = self.add_token_embeddings.add(
            inputs_embeds, self.token_type_embeddings(token_type_ids))
        if self.position_embedding_type == "absolute":
            summed = self.add_position_embeddings.add(
                summed, self.position_embeddings(position_ids))

        return self.dropout(self.LayerNorm(summed))
Exemple #28
0
class BertSelfAttention(nn.Module):
    """Multi-head attention whose arithmetic goes through wrapper modules.

    The matrix products use ``Einsum`` modules, and the scaling and
    relative-position sums use ``FloatFunctional`` modules, rather than bare
    tensor operators. Supports the "absolute", "relative_key" and
    "relative_key_query" position embedding types.
    """
    def __init__(self, config):
        """Create the projections and arithmetic modules described by ``config``.

        Raises:
            ValueError: If ``config.hidden_size`` is not a multiple of
                ``config.num_attention_heads`` and ``config`` has no
                ``embedding_size`` attribute.
        """
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
                config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" %
                (config.hidden_size, config.num_attention_heads))

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = (config.hidden_size //
                                    config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.attention_scores = Einsum()
        self.normalize = FloatFunctional()
        self.softmax = nn.Softmax(dim=-1)
        self.context_layer = Einsum()

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config,
                                               "position_embedding_type",
                                               "absolute")
        if self.position_embedding_type in ("relative_key",
                                            "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            # One embedding per signed distance in
            # [-(max_position_embeddings - 1), max_position_embeddings - 1].
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1,
                self.attention_head_size)

        if self.position_embedding_type == 'relative_key':
            self.relative_position_scores = Einsum()
            self.rel_attention_add = FloatFunctional()
        elif self.position_embedding_type == 'relative_key_query':
            self.relative_position_scores_query = Einsum()
            self.relative_position_scores_key = Einsum()
            self.rel_attention_add = FloatFunctional()
            self.attention_add = FloatFunctional()

    def transpose_for_scores(self, x):
        """Split the last axis into heads and move the head axis to position 1.

        The last axis of ``x`` is viewed as
        ``(num_attention_heads, attention_head_size)`` and the resulting
        4-D tensor is permuted with ``(0, 2, 1, 3)``.
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads,
                                       self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        """Compute the attention output for ``hidden_states``.

        Args:
            hidden_states: Input to the query projection, and to the key and
                value projections when ``encoder_hidden_states`` is None.
            attention_mask: Optional tensor added to the scaled attention
                scores before the softmax. Replaced by
                ``encoder_attention_mask`` when ``encoder_hidden_states`` is
                given.
            head_mask: Optional tensor multiplied into the attention
                probabilities.
            encoder_hidden_states: If given, keys and values are projected
                from it instead of from ``hidden_states`` (cross-attention).
            encoder_attention_mask: Mask used in the cross-attention case.
            output_attentions: If true, the attention probabilities are
                returned as well.

        Returns:
            ``(context_layer,)`` or ``(context_layer, attention_probs)``.
        """
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        if encoder_hidden_states is not None:
            mixed_key_layer = self.key(encoder_hidden_states)
            mixed_value_layer = self.value(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        else:
            mixed_key_layer = self.key(hidden_states)
            mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Dot product between "query" and "key" gives the raw attention
        # scores; this is the einsum form of
        # torch.matmul(query_layer, key_layer.transpose(-1, -2)).
        attention_scores = self.attention_scores('bhij,bhjk->bhik',
                                                 query_layer,
                                                 key_layer.transpose(-1, -2))

        if self.position_embedding_type in ("relative_key",
                                            "relative_key_query"):
            seq_length = hidden_states.size()[1]
            positions = torch.arange(seq_length,
                                     dtype=torch.long,
                                     device=hidden_states.device)
            # distance[l, r] = l - r, shifted below to a non-negative index.
            distance = positions.view(-1, 1) - positions.view(1, -1)
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(
                dtype=query_layer.dtype)  # fp16 compatibility

            # FloatFunctional modules must be used through their named
            # operations (.add); calling the module itself raises.
            if self.position_embedding_type == "relative_key":
                relative_position_scores = self.relative_position_scores(
                    "bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = self.rel_attention_add.add(
                    attention_scores, relative_position_scores)
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = self.relative_position_scores_query(
                    "bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = self.relative_position_scores_key(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = self.attention_add.add(
                    attention_scores,
                    self.rel_attention_add.add(relative_position_scores_query,
                                               relative_position_scores_key))

        attention_scores = self.normalize.mul_scalar(
            attention_scores, 1 / math.sqrt(self.attention_head_size))
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
            # TODO: Why is this a +? Do we need to quantize?
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = self.softmax(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # Einsum form of torch.matmul(attention_probs, value_layer).
        context_layer = self.context_layer('bhij,bhjk->bhik', attention_probs,
                                           value_layer)

        # Move the head axis back next to the feature axis and merge the two.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.all_head_size, )
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer,
                   attention_probs) if output_attentions else (context_layer, )
        return outputs