Beispiel #1
0
    def write_feature_transformer(self, model):
        """Append the quantized feature transformer of ``model`` to ``self.buf``.

        The bias of ``model.input`` is emitted first, followed by the
        coalesced weights returned by ``M.coalesce_ft_weights``. Both tensors
        are multiplied by 127, rounded and cast to int16; a histogram of each
        quantized tensor is reported through ``ascii_hist`` before its raw
        bytes are appended to the buffer.
        """
        ft_layer = model.input

        def quantize(tensor):
            # int16 value = round(x * 127)
            return torch.round(tensor * 127).to(torch.int16)

        bias_q = quantize(ft_layer.bias.data).numpy()
        ascii_hist('ft bias:', bias_q)
        self.buf.extend(bias_q.tobytes())

        weight_q = quantize(M.coalesce_ft_weights(model, ft_layer)).numpy()
        ascii_hist('ft weight:', weight_q)
        # Bytes are emitted in row-major (C) order. The original note gives
        # the layout as [41024][256]; that shape is not checked here.
        self.buf.extend(weight_q.tobytes())
Beispiel #2
0
    def write_feature_transformer(self, model):
        """Append the quantized feature transformer (with PSQT) to ``self.buf``.

        Only the first ``M.L1`` bias entries and the first ``M.L1`` weight
        columns are treated as the regular feature transformer; they are
        scaled by 127, rounded and stored as int16. The remaining weight
        columns are the PSQT part; they are scaled by 9600, rounded and
        stored as int32. Output order is: bias, weights, PSQT weights.
        """
        ft_layer = model.input
        l1_size = M.L1

        # int16 bias = round(x * 127)
        bias_q = torch.round(ft_layer.bias.data[:l1_size] * 127)
        bias_q = bias_q.to(torch.int16).numpy()
        ascii_hist('ft bias:', bias_q)
        self.buf.extend(bias_q.tobytes())

        coalesced = M.coalesce_ft_weights(model, ft_layer)

        # int16 weight = round(x * 127)
        weight_q = torch.round(coalesced[:, :l1_size] * 127)
        weight_q = weight_q.to(torch.int16).numpy()

        # int32 psqt weight = round(x * 9600)
        # kPonanzaConstant * FV_SCALE = 9600
        psqt_q = torch.round(coalesced[:, l1_size:] * 9600)
        psqt_q = psqt_q.to(torch.int32).numpy()

        ascii_hist('ft weight:', weight_q)
        # Bytes are emitted in row-major (C) order. The original note gives
        # the regular weight layout as [41024][256]; not checked here.
        self.buf.extend(weight_q.tobytes())
        self.buf.extend(psqt_q.tobytes())
Beispiel #3
0
    def plot_input_weights(self):
        """Plot the feature-transformer ("input") weights as an image and histogram.

        The weights are obtained from ``M.coalesce_ft_weights`` for
        ``self.model``, restricted to the first ``M.L1`` columns and
        flattened. If ``self.args.ref_model`` is truthy, the corresponding
        weights of ``self.ref_model`` are subtracted, so the plots show the
        difference between the two models.

        Side effects:
            * sets ``self.M`` to ``M.L1``;
            * sets ``self.sorted_input_neurons`` (either the identity order,
              or neurons ordered by descending L1 norm of their weights when
              ``self.args.sort_input_neurons`` is set);
            * unless ``self.args.no_input_weights`` is set, draws the weight
              image with matplotlib and hands it to
              ``self._process_fig("input-weights")``; unless
              ``self.args.no_hist`` is set it additionally draws a histogram
              and hands it to ``self._process_fig("input-weights-histogram")``.
        """
        # Coalesce weights and transform them to Numpy domain.
        weights = M.coalesce_ft_weights(self.model, self.model.input)
        weights = weights[:, :M.L1]
        weights = weights.flatten().numpy()

        if self.args.ref_model:
            # Plot the difference to the reference model instead of the
            # raw weights.
            ref_weights = M.coalesce_ft_weights(self.ref_model,
                                                self.ref_model.input)
            ref_weights = ref_weights[:, :M.L1]
            ref_weights = ref_weights.flatten().numpy()
            weights -= ref_weights

        hd = M.L1  # Number of input neurons.
        self.M = hd

        # Preferred ratio of number of input neurons per row/col.
        preferred_ratio = 4

        # Number of input neurons per row.
        # Find a factor of hd such that the aspect ratio
        # is as close to the preferred ratio as possible.
        # For a divisor n, the grid is (hd / n) columns by n rows, so the
        # column/row ratio is hd / n**2.
        factor, smallest_diff = 0, hd
        for n in range(1, hd + 1):
            if hd % n == 0:
                ratio = hd / (n * n)
                diff = abs(preferred_ratio - ratio)
                if diff < smallest_diff:
                    factor = n
                    smallest_diff = diff

        numx = hd // factor

        if self.args.sort_input_neurons:
            # Sort input neurons by the L1-norm of their associated weights.
            # The flattened array is row-major with hd entries per row, so
            # weights[i::hd] selects all weights belonging to neuron i.
            neuron_weights_norm = np.zeros(hd)
            for i in range(hd):
                neuron_weights_norm[i] = np.sum(np.abs(weights[i::hd]))

            # argsort is ascending; flip it to get the largest norm first.
            self.sorted_input_neurons = np.flip(
                np.argsort(neuron_weights_norm))
        else:
            self.sorted_input_neurons = np.arange(hd, dtype=int)

        # Derived/fixed constants.
        # Each neuron occupies a widthx-by-widthy tile; the tiles are laid
        # out on a numx-by-numy grid.
        numy = hd // numx
        widthx = 128
        widthy = 400
        totalx = numx * widthx
        totaly = numy * widthy
        totaldim = totalx * totaly

        if not self.args.no_input_weights:
            default_order = self.args.input_weights_order == "piece-centric-flipped-king"

            # Calculate masks for first input neuron.
            # img_mask[m] is the flat image index and weights_mask[m] the flat
            # weight index of the m-th plotted weight of neuron 0; the other
            # neurons reuse both masks with an offset (see the fill loop).
            img_mask = []
            weights_mask = []
            for j in range(0, weights.size, hd):
                # Calculate piece and king placement.
                # j // hd is the feature (row) index; it is split into a
                # king index ki and a piece/square index pi (704 = 11 * 64).
                pi = (j // hd) % 704
                ki = (j // hd) // 704
                piece = pi // 64
                rank = (pi % 64) // 8

                if ((rank == 0 or rank == 7) and (piece == 0 or piece == 1)):
                    # Ignore unused weights for pawns on first/last rank.
                    continue

                # [file, rank] pairs for the king and the piece.
                kipos = [ki % 8, ki // 8]
                pipos = [pi % 8, rank]

                if default_order:
                    # Piece centric, but with flipped king position.
                    # Same order as used by https://github.com/hxim/Stockfish-Evaluation-Guide.
                    # See also https://github.com/glinscott/nnue-pytorch/issues/42#issuecomment-753604393.
                    inpos = [(7 - kipos[0]) + pipos[0] * 8,
                             kipos[1] + (7 - pipos[1]) * 8]
                    d = -8 if piece < 2 else 48 + (piece // 2 - 1) * 64
                else:
                    # King centric.
                    inpos = [
                        8 * kipos[0] + pipos[0],
                        8 * (7 - kipos[1]) + (7 - pipos[1])
                    ]
                    d = -2*(7-kipos[1]) - 1 if piece < 2 else 48 + \
                        (piece // 2 - 1) * 64

                # j is a multiple of hd here, so jhd is always 0 and the
                # tile offsets below vanish (mask is for the first neuron).
                jhd = j % hd
                # Odd piece indices are shifted into the right half of the
                # tile (64 pixels); d is the vertical offset of the piece's
                # band inside the tile.
                x = inpos[0] + widthx * (jhd % numx) + (piece % 2) * 64
                y = inpos[1] + d + widthy * (jhd // numx)
                ii = x + y * totalx

                img_mask.append(ii)
                weights_mask.append(j)

            img_mask = np.array(img_mask, dtype=int)
            weights_mask = np.array(weights_mask, dtype=int)

            # Fill image for all input neurons.
            # Tile k of the grid shows neuron self.sorted_input_neurons[k].
            img = np.zeros(totaldim)
            for k in range(hd):
                offset_x = k % numx
                offset_y = k // numx
                img[img_mask + offset_x * widthx + totalx * widthy *
                    offset_y] = weights[weights_mask +
                                        self.sorted_input_neurons[k]]

            # With auto scaling, matplotlib picks the color limits itself.
            if self.args.input_weights_auto_scale:
                vmin = None
                vmax = None
            else:
                vmin = self.args.input_weights_vmin
                vmax = self.args.input_weights_vmax

            # Build the "; sorted, <order>" suffix shown in the plot title.
            extra_info = ""
            if self.args.sort_input_neurons:
                extra_info += "sorted"
                if not default_order:
                    extra_info += ", " + self.args.input_weights_order
            else:
                if not default_order:
                    extra_info += self.args.input_weights_order
            if len(extra_info) > 0:
                extra_info = "; " + extra_info

            # Signed weights get a diverging colormap; if the user asked for
            # a non-negative lower limit, absolute values are plotted instead.
            if self.args.input_weights_auto_scale or self.args.input_weights_vmin < 0:
                title_template = "input weights [{LABEL}" + extra_info + "]"
                hist_title_template = "input weights histogram [{LABEL}]"
                cmap = 'coolwarm'
            else:
                img = np.abs(img)
                title_template = "abs(input weights) [{LABEL}" + \
                    extra_info + "]"
                hist_title_template = "abs(input weights) histogram [{LABEL}]"
                cmap = 'viridis'

            # Input weights.
            # Widen or narrow the figure according to how far the actual
            # grid ratio deviates from the preferred one.
            scalex = (numx / numy) / preferred_ratio
            plt.figure(figsize=((scalex * self.args.default_width) // self.dpi,
                                self.args.default_height // self.dpi))
            plt.matshow(img.reshape((totaldim // totalx, totalx)),
                        fignum=0,
                        vmin=vmin,
                        vmax=vmax,
                        cmap=cmap)
            plt.colorbar(fraction=0.046, pad=0.04)

            # Thin separator lines between neighbouring neuron tiles.
            line_options = {'color': 'black', 'linewidth': 0.5}
            for i in range(1, numx):
                plt.axvline(x=widthx * i - 0.5, **line_options)

            for j in range(1, numy):
                plt.axhline(y=widthy * j - 0.5, **line_options)

            plt.xlim([0, totalx])
            plt.ylim([totaly, 0])
            plt.xticks(ticks=widthx * np.arange(1, numx) - 0.5)
            plt.yticks(ticks=widthy * np.arange(1, numy) - 0.5)
            plt.axis('off')
            plt.title(title_template.format(LABEL=self.args.label))
            plt.tight_layout()

            def format_coord(x, y):
                """Describe the weight at image position (x, y).

                Returns a string naming the neuron, the piece, the piece
                square and the white king square; this inverts the pixel
                layout computed for the image masks.
                """
                x, y = int(round(x)), int(round(y))

                # Position inside the neuron's tile.
                x_ = x % widthx
                y_ = y % widthy
                piece_type = (y_ + 16) // 64
                # Left half of a tile is labelled white, right half black.
                piece_name = "{} {}".format(
                    "white" if x_ // (widthx // 2) == 0 else "black",
                    chess.piece_name(piece_type + 1))

                x_ = x_ % (widthx // 2)
                y_ = (y_ + 16) % 64 if y_ >= 48 else y_ + 8
                if default_order:
                    # Piece centric, flipped king.
                    piece_square_name = chess.square_name(x_ // 8 + 8 *
                                                          (7 - y_ // 8))
                    king_square_name = chess.square_name(7 - (x_ % 8) + 8 *
                                                         (y_ % 8))
                else:
                    # King centric.
                    if piece_type == 0:
                        # Pawn band: only 6 rows per king rank are plotted.
                        piece_square_name = chess.square_name(x_ % 8 + 8 *
                                                              (6 -
                                                               ((y_ - 8) % 6)))
                        king_square_name = chess.square_name(x_ // 8 + 8 *
                                                             (7 -
                                                              (y_ - 8) // 6))
                    else:
                        piece_square_name = chess.square_name(x_ % 8 + 8 *
                                                              (7 - (y_ % 8)))
                        king_square_name = chess.square_name(x_ // 8 + 8 *
                                                             (7 - y_ // 8))

                # Index of the tile under the cursor (row-major over the grid).
                neuron_id = int(numx * (y // widthy) + x // widthx)
                if self.args.sort_input_neurons:
                    neuron_label = "sorted neuron {} (original {})".format(
                        neuron_id, self.sorted_input_neurons[neuron_id])
                else:
                    neuron_label = "neuron {}".format(neuron_id)

                return "{}, {} on {}, white king on {}".format(
                    neuron_label, piece_name, piece_square_name,
                    king_square_name)

            # Install the formatter on the current axes so matplotlib uses it
            # for its coordinate read-out.
            ax = plt.gca()
            ax.format_coord = format_coord

            self._process_fig("input-weights")

            if not self.args.no_hist:
                # Input weights histogram.
                # Bin edges lie at (k - 0.5) / 127 for consecutive integers
                # k, i.e. each bin is 1/127 wide and centred on a multiple
                # of 1/127.
                plt.figure()
                plt.hist(img,
                         log=True,
                         bins=(np.arange(
                             int(np.min(img) * 127) - 1,
                             int(np.max(img) * 127) + 3) - 0.5) / 127)
                plt.title(hist_title_template.format(LABEL=self.args.label))
                plt.tight_layout()
                self._process_fig("input-weights-histogram")
def main():
    """Plot weight-distribution histograms for one or more networks.

    Command-line entry point. Parses the model paths (plus the options added
    by ``features.add_argparse_args``), loads every model with
    ``load_model`` and calls ``plot_hists`` for:

    * the feature transformer weights (first ``M.L1`` columns),
    * the feature transformer PSQT weights (remaining columns, times 600),
    * the l1, l2 and output weights of every layer stack.

    Unless ``--dont-show`` is given, the figures are finally displayed with
    ``plt.show()``.

    An unsupported ``--features`` value is reported through
    ``parser.error`` (usage message, exit status 2) instead of an ``assert``,
    so the check also runs under ``python -O`` and the user gets a readable
    message rather than a traceback.
    """
    parser = argparse.ArgumentParser(
        description="Visualizes networks in ckpt, pt and nnue format.")
    parser.add_argument("models",
                        nargs='+',
                        help="Source model (can be .ckpt, .pt or .nnue)")
    parser.add_argument("--dont-show",
                        action="store_true",
                        help="Don't show the plots.")
    features.add_argparse_args(parser)
    args = parser.parse_args()

    supported_features = ('HalfKAv2', 'HalfKAv2^')
    if args.features not in supported_features:
        # Validate user input explicitly; asserts are stripped under -O.
        parser.error("unsupported feature set {!r}; expected one of: {}".format(
            args.features, ', '.join(supported_features)))
    feature_set = features.get_feature_set_from_name(args.features)

    from os.path import basename

    # Build one plot label per model: the file name without the 'nn-' prefix
    # and '.nnue' suffix, with dashes turned into line breaks.
    labels = []
    for m in args.models:
        label = basename(m)
        if label.startswith('nn-'):
            label = label[3:]
        if label.endswith('.nnue'):
            label = label[:-5]
        labels.append('\n'.join(label.split('-')))

    models = [load_model(m, feature_set) for m in args.models]

    coalesced_ins = [
        M.coalesce_ft_weights(model, model.input) for model in models
    ]
    # The first M.L1 columns are the regular feature transformer weights,
    # the remaining columns are the PSQT weights.
    input_weights = [
        coalesced_in[:, :M.L1].flatten().numpy()
        for coalesced_in in coalesced_ins
    ]
    # PSQT weights are scaled by 600 for display (the plot title calls the
    # result "stockfish internal units").
    input_weights_psqt = [(coalesced_in[:, M.L1:] * 600).flatten().numpy()
                          for coalesced_in in coalesced_ins]
    plot_hists(
        [input_weights],
        labels, [None],
        w=10.0,
        h=3.0,
        num_bins=8 * 128,
        title=
        'Distribution of feature transformer weights among different nets',
        filename='input_weights_hist.png')
    plot_hists(
        [input_weights_psqt],
        labels, [None],
        w=10.0,
        h=3.0,
        num_bins=8 * 128,
        title=
        'Distribution of feature transformer PSQT weights among different nets (in stockfish internal units)',
        filename='input_weights_psqt_hist.png')

    layer_stacks = [model.layer_stacks for model in models]
    # The stack count of the first model sizes the per-stack lists; the
    # other models are expected to yield the same number of stacks.
    num_stacks = layer_stacks[0].count
    layers_l1 = [[] for _ in range(num_stacks)]
    layers_l2 = [[] for _ in range(num_stacks)]
    layers_l3 = [[] for _ in range(num_stacks)]
    # layers_lX[i] collects, for stack i, one flat weight array per model.
    for ls in layer_stacks:
        for i, sublayers in enumerate(ls.get_coalesced_layer_stacks()):
            l1, l2, l3 = sublayers
            layers_l1[i].append(l1.weight.flatten().numpy())
            layers_l2[i].append(l2.weight.flatten().numpy())
            layers_l3[i].append(l3.weight.flatten().numpy())
    col_names = ['Subnet {}'.format(i) for i in range(num_stacks)]
    plot_hists(
        layers_l1,
        labels,
        col_names,
        w=2.0,
        h=2.0,
        num_bins=128,
        title='Distribution of l1 weights among different nets and buckets',
        filename='l1_weights_hist.png')
    plot_hists(
        layers_l2,
        labels,
        col_names,
        w=2.0,
        h=2.0,
        num_bins=32,
        title='Distribution of l2 weights among different nets and buckets',
        filename='l2_weights_hist.png')
    plot_hists(
        layers_l3,
        labels,
        col_names,
        w=2.0,
        h=2.0,
        num_bins=16,
        title='Distribution of output weights among different nets and buckets',
        filename='output_weights_hist.png')

    if not args.dont_show:
        plt.show()