# Example no. 1
# 0
class GateRouter(Layer):
    """
    Takes as input a GateBatch and returns batch of gate coefficients.
    
    """
    def __init__(self,
                 layer_name,
                 num_gates,
                 irange = 0.05,
                 routing_protocol = 'nearest'
        ):
        """
        Router layer that maps a GateBatch to per-gate coefficients.

        Parameters
        ----------
        layer_name : str
            Name used to tag this layer's shared variables and outputs.
        num_gates : int
            Number of output gates; also the total dimension of the
            layer's output space.
        irange : float
            Half-width of the uniform interval used to initialise the
            weights (see set_input_space).
        routing_protocol : str
            Routing strategy identifier (only stored here; presumably
            consumed elsewhere — TODO confirm).
        """
        # Store each constructor argument explicitly (equivalent to the
        # self.__dict__.update(locals()) idiom, minus the self entry).
        self.layer_name = layer_name
        self.num_gates = num_gates
        self.irange = irange
        self.routing_protocol = routing_protocol

        # One coefficient per output gate.
        self.output_space = VectorSpace(self.num_gates)
        
    def set_input_space(self, space):
        """
        Configure the layer for `space` and (re)initialise the weights.

        Note: this resets parameters!

        Parameters
        ----------
        space : VectorSpace or GateSpace
            Space produced by the layer below. A VectorSpace is wrapped
            as a single-gate GateSpace and batches will be reformatted
            in fprop.

        Raises
        ------
        TypeError
            If `space` is neither a VectorSpace nor a GateSpace.
        """
        self.pre_input_space = space

        if isinstance(space, VectorSpace):
            # A plain vector batch is treated as the output of one gate.
            self.requires_reformat = True
            self.input_space \
                = GateSpace(num_gates = 1, 
                            gate_dims=space.get_total_dimension())
        elif isinstance(space, GateSpace):
            self.requires_reformat = False
            self.input_space = space
        else:
            # Previously an unsupported space fell through silently and
            # the code below failed with an opaque AttributeError on
            # self.input_space; fail fast with a clear message instead.
            raise TypeError("GateRouter expected a VectorSpace or a "
                            "GateSpace as input, got " + str(type(space)))

        rng = self.mlp.rng
        # a tensor3 of weights. Each dim0 represents the weight matrix 
        # from an input gate. In such weight matrices, 
        # each row (dim1) is a cluster centroid:
        W = rng.uniform(-self.irange, self.irange,
                        (self.input_space.get_num_gates(),
                         self.output_space.get_total_dimension(),
                         self.input_space.get_gate_dimension()))

        self.W = sharedX(W)
        self.W.name = self.layer_name + '_W'

    def fprop(self, G):
        """
        Forward-propagate a batch through the router.

        Parameters
        ----------
        G : batch from `self.pre_input_space`
            Reformatted into `self.input_space` (a GateSpace batch)
            when the input space was a plain VectorSpace.

        Returns
        -------
        A : theano tensor, (num_output_gates, batch_size)
            Mean over input gates of D.
        D : theano tensor, (num_input_gates, num_output_gates, batch_size)
            Per-input-gate example-to-centroid distance term.
        """
        self.pre_input_space.validate(G)

        if self.requires_reformat:
            # Wrap a plain VectorSpace batch as a single-gate GateBatch.
            G = self.input_space.format_from(G, self.pre_input_space)
        # Get the batch of gate activations and indices from G:
        # (batch_size, num_input_gates, input_gate_dims)
        X1 = G.get_activations() 
        # (num_input_gates, input_gate_dims, batch_size)
        X = X1.dimshuffle(1,2,0) 
       
        # Squared-norm terms of the expansion ||w - x||^2 = w^2 - 2wx + x^2.
        # (num_input_gates, 1, batch_size) 
        X2 = (X**2).sum(1).dimshuffle(0,'x',1)
        # (num_input_gates, num_output_gates, 1)
        W2 = (self.W**2).sum(2).dimshuffle(0,1,'x')
        
        # The D tensor holds the distance of each example to all clusters 
        # (WITHIN the context of each input gate)
        # This, as well as A, can be used to get a cost.
        # (num_input_gates, num_output_gates, batch_size) 
        # NOTE(review): -2*W.X + W2 + X2 already equals the squared
        # Euclidean distance, so T.sqr raises it to the 4th power.
        # Per-column ordering (argmin) is unaffected, but means taken
        # below are not — confirm the extra square is intentional.
        D = T.sqr(-2*T.batched_dot(self.W, X) + W2 + X2)
        # The A matrix holds the (mean) distance of each example-gate 
        # to all clusters
        # (num_output_gates, batch_size)
        A = D.mean(0)

        A.name = self.layer_name + '_A'
        D.name = self.layer_name + '_D'

        return A, D
        
    def cost(self, A, D):
        """
        Y must be one-hot binary. Y_hat is a softmax estimate.
        of Y. Returns negative log probability of Y under the Y_hat
        distribution.
        """
        
        # (
        cost = A.min(0).sum() 
        # + GateIndividuality*D.min(0).sum()
        # )*RouterUnsupervision
        [grad_W] = T.grad(cost, [W], disconnected_inputs = 'ignore')
        f = theano.function([X1],cost,updates=[(W,W-(LR*grad_W))])