示例#1
0
    def construct_net(self, planes):
        """Build the network graph and return its policy and value outputs.

        The input is reshaped to NCHW ``[-1, 112, 8, 8]``, passed through an
        input convolution block and ``self.RESIDUAL_BLOCKS`` residual blocks,
        and then fed to two heads:

        * Policy head, selected by ``self.POLICY_HEAD``:
          - ``POLICY_CONVOLUTION``: a 3x3 conv block, a 3x3 convolution to
            80 channels plus bias, flattened and multiplied by the fixed
            (non-trainable) ``policy_map`` matrix built from
            ``lc0_az_policy_map.make_map()``.
          - ``POLICY_CLASSICAL``: a 1x1 conv block to
            ``self.policy_channels`` channels, flattened and passed through
            a dense layer with 1858 outputs.
        * Value head: a 1x1 conv block to 32 channels, a 128-unit ReLU dense
          layer and a final dense layer with 3 outputs when ``self.wdl`` is
          set, otherwise 1 output followed by ``tanh``.

        Side effects: every variable created here is appended to
        ``self.weights`` in creation order, and selected biases are added to
        the ``REGULARIZATION_LOSSES`` collection.

        Args:
            planes: Input tensor that can be reshaped to ``[-1, 112, 8, 8]``.

        Returns:
            A ``(policy, value)`` tuple of output tensors.

        Raises:
            ValueError: If ``self.POLICY_HEAD`` is not a supported head type.
        """
        board_cells = 8 * 8

        # NCHW layout: batch, 112 input channels, 8 x 8 board.
        net_input = tf.reshape(planes, [-1, 112, 8, 8])

        # Input convolution feeding the residual tower.
        flow = self.conv_block(
            net_input,
            filter_size=3,
            input_channels=112,
            output_channels=self.RESIDUAL_FILTERS,
            bn_scale=True,
        )

        # Residual tower.
        for _ in range(self.RESIDUAL_BLOCKS):
            flow = self.residual_block(flow, self.RESIDUAL_FILTERS)

        # Policy head.
        head_kind = self.POLICY_HEAD
        if head_kind == pb.NetworkFormat.POLICY_CONVOLUTION:
            pol_features = self.conv_block(
                flow,
                filter_size=3,
                input_channels=self.RESIDUAL_FILTERS,
                output_channels=self.RESIDUAL_FILTERS,
            )
            pol_kernel = weight_variable(
                [3, 3, self.RESIDUAL_FILTERS, 80], name='W_pol_conv2')
            pol_bias = bias_variable([80], name='b_pol_conv2')
            # Registration order matters: kernel, then bias.
            self.weights.append(pol_kernel)
            # NOTE(review): the bias variable itself is put in the
            # regularization collection, presumably so it is regularized
            # wherever that collection is consumed; confirm against caller.
            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, pol_bias)
            self.weights.append(pol_bias)

            pol_conv_out = conv2d(pol_features, pol_kernel)
            pol_planes = tf.nn.bias_add(pol_conv_out,
                                        pol_bias,
                                        data_format='NCHW')
            pol_flat = tf.reshape(pol_planes, [-1, 80 * board_cells])

            # Fixed matrix applied to the flattened 80x8x8 planes to produce
            # the policy output; it is never trained.
            move_map = tf.get_variable(
                "policy_map",
                initializer=tf.constant(lc0_az_policy_map.make_map()),
                trainable=False,
            )
            policy = tf.matmul(pol_flat, move_map, name='policy_head')
        elif head_kind == pb.NetworkFormat.POLICY_CLASSICAL:
            pol_features = self.conv_block(
                flow,
                filter_size=1,
                input_channels=self.RESIDUAL_FILTERS,
                output_channels=self.policy_channels,
            )
            flat_width = self.policy_channels * board_cells
            pol_flat = tf.reshape(pol_features, [-1, flat_width])

            dense_w = weight_variable([flat_width, 1858], name='fc1/weight')
            dense_b = bias_variable([1858], name='fc1/bias')
            self.weights.append(dense_w)
            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, dense_b)
            self.weights.append(dense_b)

            pol_logits = tf.matmul(pol_flat, dense_w)
            policy = tf.add(pol_logits, dense_b, name='policy_head')
        else:
            raise ValueError("Unknown policy head type {}".format(head_kind))

        # Value head.
        val_features = self.conv_block(
            flow,
            filter_size=1,
            input_channels=self.RESIDUAL_FILTERS,
            output_channels=32,
        )
        val_flat = tf.reshape(val_features, [-1, 32 * board_cells])

        hidden_w = weight_variable([32 * board_cells, 128], name='fc2/weight')
        hidden_b = bias_variable([128], name='fc2/bias')
        self.weights.append(hidden_w)
        self.weights.append(hidden_b)
        hidden_pre = tf.add(tf.matmul(val_flat, hidden_w), hidden_b)
        hidden = tf.nn.relu(hidden_pre)

        # With wdl set the head emits three raw outputs; otherwise a single
        # output squashed by tanh.
        use_wdl = self.wdl
        n_value_out = 3 if use_wdl else 1
        out_w = weight_variable([128, n_value_out], name='fc3/weight')
        out_b = bias_variable([n_value_out], name='fc3/bias')
        self.weights.append(out_w)
        self.weights.append(out_b)
        value = tf.add(tf.matmul(hidden, out_w), out_b, name='value_head')

        if use_wdl:
            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, out_b)
        else:
            value = tf.nn.tanh(value)

        return policy, value
示例#2
0
 def __init__(self, **kwargs):
     """Initialise the layer and cache the policy map as a constant.

     All keyword arguments are forwarded unchanged to the parent
     initialiser. The matrix returned by ``lc0_az_policy_map.make_map()``
     is wrapped in a ``tf.constant`` and stored on ``self.fc1``.
     """
     super(ApplyPolicyMap, self).__init__(**kwargs)
     # Build the map once at construction time and keep it as a constant.
     policy_map = lc0_az_policy_map.make_map()
     self.fc1 = tf.constant(policy_map)