def construct_net(self, planes):
    """Build the network graph and return its policy and value outputs.

    The input is reshaped to NCHW ``[-1, 112, 8, 8]``. It then goes through a
    3x3 input convolution block and ``self.RESIDUAL_BLOCKS`` residual blocks.
    The resulting trunk feeds two heads: a policy head chosen by
    ``self.POLICY_HEAD`` and a value head.

    Side effects:
        * Weights and biases are appended to ``self.weights`` in creation
          order. The fixed policy map is NOT appended.
        * Some biases are added to ``tf.GraphKeys.REGULARIZATION_LOSSES``:
          ``b_pol_conv2``, ``fc1/bias``, and ``fc3/bias`` (WDL mode only).
          Presumably these feed logits where a uniform shift leaves a
          softmax unchanged, so they would otherwise drift.
          NOTE(review): confirm against the loss definition.
        * The convolutional head creates a non-trainable ``"policy_map"``
          variable via ``tf.get_variable``.

    NOTE(review): the order of ``conv_block`` calls and weight creation is
    kept exactly as before. Anything that serialises ``self.weights`` or
    names layers by call count may depend on it, so keep it stable.

    Args:
        planes: input tensor that must be reshapeable to ``[-1, 112, 8, 8]``
            (batch, 112 input planes, 8 x 8 board).

    Returns:
        ``(policy, value)``.
        ``policy`` is the op named ``'policy_head'``. The classical head has
        1858 outputs. The convolutional head's width is set by the columns
        of ``lc0_az_policy_map.make_map()``.
        ``value`` is the ``'value_head'`` op with 3 outputs when ``self.wdl``
        is set. Otherwise it is ``tanh`` of that single-output op.

    Raises:
        ValueError: if ``self.POLICY_HEAD`` is neither ``POLICY_CONVOLUTION``
            nor ``POLICY_CLASSICAL``.
    """
    def _affine(inputs, weight, bias, name=None):
        # Fully connected layer: inputs @ weight + bias.
        return tf.add(tf.matmul(inputs, weight), bias, name=name)

    filters = self.RESIDUAL_FILTERS

    # NCHW layout: batch, 112 input planes, 8 x 8 board.
    x_planes = tf.reshape(planes, [-1, 112, 8, 8])

    # Trunk: input convolution followed by the residual tower.
    trunk = self.conv_block(x_planes, filter_size=3, input_channels=112,
                            output_channels=filters, bn_scale=True)
    for _ in range(self.RESIDUAL_BLOCKS):
        trunk = self.residual_block(trunk, filters)

    # Policy head.
    head_kind = self.POLICY_HEAD
    if head_kind == pb.NetworkFormat.POLICY_CONVOLUTION:
        pol_features = self.conv_block(trunk, filter_size=3,
                                       input_channels=filters,
                                       output_channels=filters)
        W_pol_conv = weight_variable([3, 3, filters, 80], name='W_pol_conv2')
        b_pol_conv = bias_variable([80], name='b_pol_conv2')
        self.weights.append(W_pol_conv)
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, b_pol_conv)
        self.weights.append(b_pol_conv)
        pol_planes = tf.nn.bias_add(conv2d(pol_features, W_pol_conv),
                                    b_pol_conv, data_format='NCHW')
        # 80 planes of 8 x 8 are flattened and projected onto move logits by
        # a fixed, non-trainable map.
        pol_flat = tf.reshape(pol_planes, [-1, 80 * 8 * 8])
        # NOTE(review): tf.get_variable raises if "policy_map" already exists
        # in the current variable scope without reuse enabled.
        policy_map = tf.get_variable(
            "policy_map",
            initializer=tf.constant(lc0_az_policy_map.make_map()),
            trainable=False)
        policy = tf.matmul(pol_flat, policy_map, name='policy_head')
    elif head_kind == pb.NetworkFormat.POLICY_CLASSICAL:
        pol_channels = self.policy_channels
        pol_features = self.conv_block(trunk, filter_size=1,
                                       input_channels=filters,
                                       output_channels=pol_channels)
        pol_flat = tf.reshape(pol_features, [-1, pol_channels * 8 * 8])
        W_fc1 = weight_variable([pol_channels * 8 * 8, 1858],
                                name='fc1/weight')
        b_fc1 = bias_variable([1858], name='fc1/bias')
        self.weights.append(W_fc1)
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, b_fc1)
        self.weights.append(b_fc1)
        policy = _affine(pol_flat, W_fc1, b_fc1, name='policy_head')
    else:
        raise ValueError("Unknown policy head type {}".format(head_kind))

    # Value head: 1x1 conv down to 32 planes, a 128-unit ReLU layer, then
    # the output layer.
    val_features = self.conv_block(trunk, filter_size=1,
                                   input_channels=filters,
                                   output_channels=32)
    val_flat = tf.reshape(val_features, [-1, 32 * 8 * 8])
    W_fc2 = weight_variable([32 * 8 * 8, 128], name='fc2/weight')
    b_fc2 = bias_variable([128], name='fc2/bias')
    self.weights.append(W_fc2)
    self.weights.append(b_fc2)
    val_hidden = tf.nn.relu(_affine(val_flat, W_fc2, b_fc2))

    value_outputs = 3 if self.wdl else 1
    W_fc3 = weight_variable([128, value_outputs], name='fc3/weight')
    b_fc3 = bias_variable([value_outputs], name='fc3/bias')
    self.weights.append(W_fc3)
    self.weights.append(b_fc3)
    value = _affine(val_hidden, W_fc3, b_fc3, name='value_head')

    if self.wdl:
        # Three raw outputs; their bias is regularised (see docstring).
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, b_fc3)
        return policy, value
    # Single scalar output squashed into (-1, 1).
    return policy, tf.nn.tanh(value)
def __init__(self, **kwargs):
    """Initialise the layer and load the fixed policy map.

    All keyword arguments are forwarded unchanged to the parent class
    constructor. The parent constructor runs first, and only then is the
    map built.

    The array returned by ``lc0_az_policy_map.make_map()`` is stored as a
    constant tensor in ``self.fc1``. Presumably this is the same
    flattened-policy-planes -> move-logits map used elsewhere in the file;
    NOTE(review): confirm its shape against the caller.
    """
    super(ApplyPolicyMap, self).__init__(**kwargs)
    policy_map = lc0_az_policy_map.make_map()
    self.fc1 = tf.constant(policy_map)