class GateRouter(Layer):
    """
    Takes as input a GateBatch and returns a batch of gate coefficients.

    Each input gate owns its own set of ``num_gates`` cluster centroids
    (stored in ``self.W``). ``fprop`` measures how far each example's
    activation in every input gate lies from each of those centroids and
    averages that over the input gates.

    Parameters
    ----------
    layer_name : str
        Name of the layer; also used as a prefix for symbolic variable names.
    num_gates : int
        Number of output gates, i.e. number of centroids per input gate.
    irange : float
        Centroids are initialised uniformly in ``[-irange, irange]``.
    routing_protocol : str
        Stored on the instance; not read by any code in this class.
    """

    def __init__(self, layer_name, num_gates, irange=0.05,
                 routing_protocol='nearest'):
        # Store every constructor argument as an attribute of the same name.
        self.__dict__.update(locals())
        del self.self
        self.output_space = VectorSpace(self.num_gates)

    def set_input_space(self, space):
        """
        Configure the layer for ``space`` and (re)initialise the centroids.

        Note: this resets parameters!

        Parameters
        ----------
        space : VectorSpace or GateSpace
            A VectorSpace is treated as a single gate whose dimension is the
            space's total dimension; such inputs are reformatted in ``fprop``.

        Raises
        ------
        TypeError
            If ``space`` is neither a VectorSpace nor a GateSpace.
        """
        if isinstance(space, VectorSpace):
            requires_reformat = True
            input_space = GateSpace(num_gates=1,
                                    gate_dims=space.get_total_dimension())
        elif isinstance(space, GateSpace):
            requires_reformat = False
            input_space = space
        else:
            # Fail loudly: otherwise input_space would be unset, or silently
            # left over from a previous call.
            raise TypeError("GateRouter expects a VectorSpace or GateSpace "
                            "input space, got %s" % type(space))
        # Only mutate state once the space has been accepted.
        self.pre_input_space = space
        self.requires_reformat = requires_reformat
        self.input_space = input_space

        rng = self.mlp.rng
        # A tensor3 of weights with shape
        # (num_input_gates, num_output_gates, input_gate_dims).
        # Slice W[i] is the weight matrix for input gate i; each of its rows
        # (dim1) is one cluster centroid.
        W = rng.uniform(-self.irange, self.irange,
                        (self.input_space.get_num_gates(),
                         self.output_space.get_total_dimension(),
                         self.input_space.get_gate_dimension()))
        self.W = sharedX(W)
        self.W.name = self.layer_name + '_W'

    def fprop(self, G):
        """
        Compute example-to-centroid distances.

        Parameters
        ----------
        G : batch in ``self.pre_input_space``
            Validated against ``pre_input_space`` and, if that was a
            VectorSpace, reformatted into ``self.input_space``.

        Returns
        -------
        A : tensor, shape (num_output_gates, batch_size)
            ``D`` averaged over input gates.
        D : tensor, shape (num_input_gates, num_output_gates, batch_size)
            Per-input-gate distance of each example to each centroid.
        """
        # NOTE(review): A is (num_gates, batch_size), i.e. batch on axis 1,
        # and fprop returns a tuple; confirm this matches what consumers of
        # output_space (a VectorSpace) expect.
        self.pre_input_space.validate(G)
        if self.requires_reformat:
            G = self.input_space.format_from(G, self.pre_input_space)

        # Get the batch of gate activations from G:
        # (batch_size, num_input_gates, input_gate_dims)
        X1 = G.get_activations()
        # (num_input_gates, input_gate_dims, batch_size)
        X = X1.dimshuffle(1, 2, 0)
        # (num_input_gates, 1, batch_size)
        X2 = (X ** 2).sum(1).dimshuffle(0, 'x', 1)
        # (num_input_gates, num_output_gates, 1)
        W2 = (self.W ** 2).sum(2).dimshuffle(0, 1, 'x')

        # The D tensor holds the distance of each example to all clusters
        # (WITHIN the context of each input gate).
        # -2*w.x + |w|^2 + |x|^2 is the squared Euclidean distance.
        # NOTE(review): T.sqr squares it again, so D is distance**4, not
        # distance; T.sqrt may have been intended — confirm before changing.
        # (num_input_gates, num_output_gates, batch_size)
        D = T.sqr(-2 * T.batched_dot(self.W, X) + W2 + X2)

        # The A matrix holds the (mean over input gates) distance of each
        # example to all clusters.
        # (num_output_gates, batch_size)
        A = D.mean(0)

        A.name = self.layer_name + '_A'
        D.name = self.layer_name + '_D'
        return A, D

    def cost(self, A, D):
        """
        Unsupervised clustering cost for the router.

        Parameters
        ----------
        A : tensor, shape (num_output_gates, batch_size)
            First output of ``fprop``.
        D : tensor
            Second output of ``fprop``; currently unused.

        Returns
        -------
        cost : scalar tensor
            Sum over the batch of each example's smallest entry of ``A``,
            i.e. its (mean) distance to the nearest cluster.
        """
        cost = A.min(0).sum()
        cost.name = self.layer_name + '_cost'
        return cost