Example #1
    def __init__(self, K, dt=1.0, dt_max=10.0,
                 B=5, basis=None,
                 sigma=np.inf,
                 lmbda=0):
        """
        Initialize a discrete time network Hawkes model with K processes.

        :param K:       Number of processes
        :param dt:      Time bin size
        :param dt_max:  Maximum time lag of the impulse responses (support of the basis)
        """
        self.K = K
        self.dt = dt
        self.dt_max = dt_max
        self.sigma = sigma
        self.lmbda = lmbda

        # Initialize the basis
        if basis is None:
            self.B = B
            self.basis = CosineBasis(self.B, self.dt, self.dt_max, norm=True,
                                     allow_instantaneous=False)
        else:
            self.basis = basis
            self.B = basis.B

        # Initialize nodes
        self.nodes = \
            [self._node_class(self.K, self.B, dt=self.dt,
                              sigma=self.sigma, lmbda=self.lmbda)
             for _ in range(self.K)]
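
The CosineBasis used here is not defined in this snippet. Below is a minimal
sketch of how it is exercised, assuming the import path pyhawkes.utils.basis
from the pyhawkes package; the constructor arguments and the
convolve_with_basis call mirror the ones that appear in these examples, but
the synthetic counts are purely illustrative.

import numpy as np
from pyhawkes.utils.basis import CosineBasis  # import path is an assumption

B, dt, dt_max = 5, 1.0, 10.0
basis = CosineBasis(B, dt, dt_max, norm=True, allow_instantaneous=False)

# Convolving a TxK array of counts with the basis yields a TxKxB tensor
# of filtered event histories, as consumed by add_data in Examples #4-6
S = np.random.poisson(0.1, size=(100, 2)).astype(int)
Ftens = basis.convolve_with_basis(S)
assert Ftens.shape == (100, 2, B)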
Example #2
    def __init__(self, K, dt=1.0, dt_max=10.0,
                 B=5, basis=None,
                 sigma=np.inf,
                 lmbda=0):
        """
        Initialize a discrete time network Hawkes model with K processes.

        :param K:       Number of processes
        :param dt:      Time bin size
        :param dt_max:  Maximum time lag of the impulse responses (support of the basis)
        """
        self.K = K
        self.dt = dt
        self.dt_max = dt_max
        self.sigma = sigma
        self.lmbda = lmbda

        # Initialize the basis
        if basis is None:
            self.B = B
            self.basis = CosineBasis(self.B, self.dt, self.dt_max, norm=True,
                                     allow_instantaneous=False)
        else:
            self.basis = basis
            self.B = basis.B

        # Initialize nodes
        self.nodes = \
            [self._node_class(self.K, self.B, dt=self.dt,
                              sigma=self.sigma, lmbda=self.lmbda)
             for _ in range(self.K)]
Example #3
    def __init__(self, K, dt=1.0, dt_max=10.0,
                 B=5, basis=None,
                 alpha=1.0, beta=1.0,
                 allow_instantaneous=False,
                 allow_self_connections=True):
        """
        Initialize a discrete time network Hawkes model with K processes.

        :param K:       Number of processes
        :param dt:      Time bin size
        :param dt_max:  Maximum time lag of the impulse responses (support of the basis)
        """
        self.K = K
        self.dt = dt
        self.dt_max = dt_max
        self.allow_self_connections = allow_self_connections

        # Initialize the basis
        if basis is None:
            self.B = B
            self.allow_instantaneous = allow_instantaneous
            self.basis = CosineBasis(self.B, self.dt, self.dt_max, norm=True,
                                     allow_instantaneous=allow_instantaneous)
        else:
            self.basis = basis
            self.allow_instantaneous = basis.allow_instantaneous
            self.B = basis.B

        assert not (self.allow_instantaneous and self.allow_self_connections), \
            "Cannot allow instantaneous self connections"

        # Save the gamma prior
        assert alpha >= 1.0, "Alpha must be at least 1.0 to ensure log concavity"
        self.alpha = alpha
        self.beta = beta

        # Initialize with sample from Gamma(alpha, beta)
        # self.weights = np.random.gamma(self.alpha, 1.0/self.beta, size=(self.K, 1 + self.K*self.B))
        # self.weights = self.alpha/self.beta * np.ones((self.K, 1 + self.K*self.B))
        self.weights = 1e-3 * np.ones((self.K, 1 + self.K*self.B))
        if not self.allow_self_connections:
            self._remove_self_weights()

        # Initialize the data list to empty
        self.data_list = []
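
The flattened weight matrix initialized above packs one bias plus K*B basis
weights into each row. The short NumPy sketch below illustrates the layout,
matching the reshape used by the W property in Example #6; the sizes are
arbitrary.

import numpy as np

K, B = 3, 5
weights = 1e-3 * np.ones((K, 1 + K * B))

# Column 0 of row k is the bias of process k; the remaining K*B columns
# hold the weights on the filtered inputs from each of the K processes
bias = weights[:, 0]                      # shape (K,)
WB = weights[:, 1:].reshape((K, K, B))    # (receiver, sender, basis)
assert WB[0, 1, 0] == weights[0, 1 + B]   # first basis weight from process 1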
Example #4
class _NonlinearHawkesProcessBase(object, metaclass=abc.ABCMeta):
    """
    Discrete time nonlinear Hawkes process, i.e. Poisson GLM
    """

    _node_class = None

    def __init__(self,
                 K,
                 dt=1.0,
                 dt_max=10.0,
                 B=5,
                 basis=None,
                 sigma=np.inf,
                 lmbda=0):
        """
        Initialize a discrete time network Hawkes model with K processes.

        :param K:       Number of processes
        :param dt:      Time bin size
        :param dt_max:  Maximum time lag of the impulse responses (support of the basis)
        """
        self.K = K
        self.dt = dt
        self.dt_max = dt_max
        self.sigma = sigma
        self.lmbda = lmbda

        # Initialize the basis
        if basis is None:
            self.B = B
            self.basis = CosineBasis(self.B,
                                     self.dt,
                                     self.dt_max,
                                     norm=True,
                                     allow_instantaneous=False)
        else:
            self.basis = basis
            self.B = basis.B

        # Initialize nodes
        self.nodes = \
            [self._node_class(self.K, self.B, dt=self.dt,
                              sigma=self.sigma, lmbda=self.lmbda)
             for _ in range(self.K)]

    def initialize_to_background_rate(self):
        for node in self.nodes:
            node.initialize_to_background_rate()

    @property
    def W(self):
        full_W = np.array([node.w for node in self.nodes])
        WB = full_W[:, 1:].reshape((self.K, self.K, self.B))

        # Weight matrix is summed over impulse response functions
        WT = WB.sum(axis=2)

        # Then we transpose so that the weight matrix is (outgoing x incoming)
        W = WT.T
        return W

    @property
    def G(self):
        full_W = np.array([node.w for node in self.nodes])
        WB = full_W[:, 1:].reshape((self.K, self.K, self.B))

        # Weight matrix is summed over impulse response functions
        WT = WB.sum(axis=2)

        # Impulse response weights are normalized weights
        GT = WB / WT[:, :, None]

        # Then we transpose so that the impulse matrix is (outgoing x incoming x basis)
        G = np.transpose(GT, [1, 0, 2])

        # TODO: Decide if this is still necessary
        for k1 in range(self.K):
            for k2 in range(self.K):
                if G[k1, k2, :].sum() < 1e-2:
                    G[k1, k2, :] = 1.0 / self.B
        return G

    @property
    def bias(self):
        full_W = np.array([node.w for node in self.nodes])
        return full_W[:, 0]

    def add_data(self, S, F=None):
        """
        Add a data set to the list of observations.
        First, filter the data with the impulse response basis,
        then instantiate a set of parents for this data set.

        :param S: a TxK matrix of event counts for each time bin
                  and each process.
        """
        assert isinstance(S, np.ndarray) and S.ndim == 2 and S.shape[1] == self.K \
               and np.amin(S) >= 0 and np.issubdtype(S.dtype, np.integer), \
               "Data must be a TxK array of event counts"

        T = S.shape[0]

        if F is None:
            # Filter the data into a TxKxB array
            Ftens = self.basis.convolve_with_basis(S)

            # Flatten this into a T x (KxB) matrix
            # [F00, F01, F02, F10, F11, ... F(K-1)0, F(K-1)(B-1)]
            F = Ftens.reshape((T, self.K * self.B))
            assert np.allclose(F[:, 0], Ftens[:, 0, 0])
            if self.B > 1:
                assert np.allclose(F[:, 1], Ftens[:, 0, 1])
            if self.K > 1:
                assert np.allclose(F[:, self.B], Ftens[:, 1, 0])

            # Prepend a column of ones
            F = np.hstack((np.ones((T, 1)), F))

        for k, node in enumerate(self.nodes):
            node.add_data(F, S[:, k])

    def remove_data(self, index):
        for node in self.nodes:
            del node.data_list[index]

    def log_likelihood(self, index=None):
        ll = np.sum([node.log_likelihood(index=index) for node in self.nodes])
        return ll

    def heldout_log_likelihood(self, S, F=None):
        self.add_data(S, F=F)
        hll = self.log_likelihood(index=-1)
        self.remove_data(-1)
        return hll

    def copy_sample(self):
        """
        Return a copy of the parameters of the model
        """
        # Shallow copy the data
        nodes_original = copy.copy(self.nodes)

        # Make a deep copy without the data
        self.nodes = [n.copy_node() for n in nodes_original]
        model_copy = copy.deepcopy(self)

        # Reset the data and return the data-less copy
        self.nodes = nodes_original
        return model_copy

    def fit_with_bfgs(self):
        # TODO: This can be parallelized
        for k, node in enumerate(self.nodes):
            print("")
            print("Fitting Node ", k)
            node.fit_with_bfgs()
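
The base class leaves _node_class unset and relies on subclasses to bind it
to a per-process GLM node. The sketch below shows that pattern with an
entirely hypothetical PoissonGLMNode; the real pyhawkes node classes expose
the same constructor and add_data interface used by the base class, but the
names and internals here are illustrative.

import numpy as np

class PoissonGLMNode(object):
    """Hypothetical per-process node: a Poisson GLM with weight vector w."""
    def __init__(self, K, B, dt=1.0, sigma=np.inf, lmbda=0):
        self.K, self.B, self.dt = K, B, dt
        self.sigma, self.lmbda = sigma, lmbda
        self.w = 1e-3 * np.ones(1 + K * B)  # bias plus K*B filtered-input weights
        self.data_list = []

    def add_data(self, F, S):
        # F: T x (1 + K*B) design matrix, S: length-T count vector for this node
        self.data_list.append((F, S))

class ToyNonlinearHawkesProcess(_NonlinearHawkesProcessBase):
    # Binding the node class is all a concrete subclass must do here
    _node_class = PoissonGLMNode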
Example #5
class _NonlinearHawkesProcessBase(object, metaclass=abc.ABCMeta):
    """
    Discrete time nonlinear Hawkes process, i.e. Poisson GLM
    """

    _node_class = None

    def __init__(self, K, dt=1.0, dt_max=10.0,
                 B=5, basis=None,
                 sigma=np.inf,
                 lmbda=0):
        """
        Initialize a discrete time network Hawkes model with K processes.

        :param K:       Number of processes
        :param dt:      Time bin size
        :param dt_max:  Maximum time lag of the impulse responses (support of the basis)
        """
        self.K = K
        self.dt = dt
        self.dt_max = dt_max
        self.sigma = sigma
        self.lmbda = lmbda

        # Initialize the basis
        if basis is None:
            self.B = B
            self.basis = CosineBasis(self.B, self.dt, self.dt_max, norm=True,
                                     allow_instantaneous=False)
        else:
            self.basis = basis
            self.B = basis.B

        # Initialize nodes
        self.nodes = \
            [self._node_class(self.K, self.B, dt=self.dt,
                              sigma=self.sigma, lmbda=self.lmbda)
             for _ in range(self.K)]

    def initialize_to_background_rate(self):
        for node in self.nodes:
            node.initialize_to_background_rate()

    @property
    def W(self):
        full_W = np.array([node.w for node in self.nodes])
        WB = full_W[:,1:].reshape((self.K,self.K, self.B))

        # Weight matrix is summed over impulse response functions
        WT = WB.sum(axis=2)

        # Then we transpose so that the weight matrix is (outgoing x incoming)
        W = WT.T
        return W

    @property
    def G(self):
        full_W = np.array([node.w for node in self.nodes])
        WB = full_W[:,1:].reshape((self.K,self.K, self.B))

        # Weight matrix is summed over impulse response functions
        WT = WB.sum(axis=2)

        # Impulse response weights are normalized weights
        GT = WB / WT[:,:,None]

        # Then we transpose so that the impulse matrix is (outgoing x incoming x basis)
        G = np.transpose(GT, [1,0,2])

        # TODO: Decide if this is still necessary
        for k1 in range(self.K):
            for k2 in range(self.K):
                if G[k1,k2,:].sum() < 1e-2:
                    G[k1,k2,:] = 1.0/self.B
        return G

    @property
    def bias(self):
        full_W = np.array([node.w for node in self.nodes])
        return full_W[:,0]

    def add_data(self, S, F=None):
        """
        Add a data set to the list of observations.
        First, filter the data with the impulse response basis,
        then instantiate a set of parents for this data set.

        :param S: a TxK matrix of event counts for each time bin
                  and each process.
        """
        assert isinstance(S, np.ndarray) and S.ndim == 2 and S.shape[1] == self.K \
               and np.amin(S) >= 0 and np.issubdtype(S.dtype, np.integer), \
               "Data must be a TxK array of event counts"

        T = S.shape[0]

        if F is None:
            # Filter the data into a TxKxB array
            Ftens = self.basis.convolve_with_basis(S)

            # Flatten this into a T x (KxB) matrix
            # [F00, F01, F02, F10, F11, ... F(K-1)0, F(K-1)(B-1)]
            F = Ftens.reshape((T, self.K * self.B))
            assert np.allclose(F[:,0], Ftens[:,0,0])
            if self.B > 1:
                assert np.allclose(F[:,1], Ftens[:,0,1])
            if self.K > 1:
                assert np.allclose(F[:,self.B], Ftens[:,1,0])

            # Prepend a column of ones
            F = np.hstack((np.ones((T,1)), F))

        for k,node in enumerate(self.nodes):
            node.add_data(F, S[:,k])

    def remove_data(self, index):
        for node in self.nodes:
            del node.data_list[index]

    def log_likelihood(self, index=None):
        ll = np.sum([node.log_likelihood(index=index) for node in self.nodes])
        return ll

    def heldout_log_likelihood(self, S, F=None):
        self.add_data(S, F=F)
        hll = self.log_likelihood(index=-1)
        self.remove_data(-1)
        return hll

    def copy_sample(self):
        """
        Return a copy of the parameters of the model
        """
        # Shallow copy the data
        nodes_original = copy.copy(self.nodes)

        # Make a deep copy without the data
        self.nodes = [n.copy_node() for n in nodes_original]
        model_copy = copy.deepcopy(self)

        # Reset the data and return the data-less copy
        self.nodes = nodes_original
        return model_copy

    def fit_with_bfgs(self):
        # TODO: This can be parallelized
        for k, node in enumerate(self.nodes):
            print ""
            print "Fitting Node ", k
            node.fit_with_bfgs()
Example #6
class DiscreteTimeStandardHawkesModel(object):
    """
    Discrete time standard Hawkes process model with support for
    regularized (stochastic) gradient descent.
    """
    def __init__(self, K, dt=1.0, dt_max=10.0,
                 B=5, basis=None,
                 alpha=1.0, beta=1.0,
                 allow_instantaneous=False,
                 allow_self_connections=True):
        """
        Initialize a discrete time network Hawkes model with K processes.

        :param K:       Number of processes
        :param dt:      Time bin size
        :param dt_max:  Maximum time lag of the impulse responses (support of the basis)
        """
        self.K = K
        self.dt = dt
        self.dt_max = dt_max
        self.allow_self_connections = allow_self_connections

        # Initialize the basis
        if basis is None:
            self.B = B
            self.allow_instantaneous = allow_instantaneous
            self.basis = CosineBasis(self.B, self.dt, self.dt_max, norm=True,
                                     allow_instantaneous=allow_instantaneous)
        else:
            self.basis = basis
            self.allow_instantaneous = basis.allow_instantaneous
            self.B = basis.B

        assert not (self.allow_instantaneous and self.allow_self_connections), \
            "Cannot allow instantaneous self connections"

        # Save the gamma prior
        assert alpha >= 1.0, "Alpha must be at least 1.0 to ensure log concavity"
        self.alpha = alpha
        self.beta = beta

        # Initialize with sample from Gamma(alpha, beta)
        # self.weights = np.random.gamma(self.alpha, 1.0/self.beta, size=(self.K, 1 + self.K*self.B))
        # self.weights = self.alpha/self.beta * np.ones((self.K, 1 + self.K*self.B))
        self.weights = 1e-3 * np.ones((self.K, 1 + self.K*self.B))
        if not self.allow_self_connections:
            self._remove_self_weights()

        # Initialize the data list to empty
        self.data_list = []

    def _remove_self_weights(self):
        for k in range(self.K):
            self.weights[k,1+(k*self.B):1+(k+1)*self.B] = 1e-32

    def initialize_with_gibbs_model(self, gibbs_model):
        """
        Initialize with a sample from the network Hawkes model
        :param W:
        :param g:
        :return:
        """
        assert isinstance(gibbs_model, _DiscreteTimeNetworkHawkesModelBase)
        assert gibbs_model.K == self.K
        assert gibbs_model.B == self.B

        lambda0 = gibbs_model.bias_model.lambda0
        Weff = gibbs_model.weight_model.W_effective
        g = gibbs_model.impulse_model.g

        for k in range(self.K):
            self.weights[k,0]  = lambda0[k]
            self.weights[k,1:] = (Weff[:,k][:,None] * g[:,k,:]).ravel()

        if not self.allow_self_connections:
            self._remove_self_weights()

    def initialize_to_background_rate(self):
        if len(self.data_list) > 0:
            N = 0
            T = 0
            for S,_ in self.data_list:
                N += S.sum(axis=0)
                T += S.shape[0] * self.dt

            lambda0 = N / float(T)
            self.weights[:,0] = lambda0

    @property
    def W(self):
        WB = self.weights[:,1:].reshape((self.K,self.K, self.B))

        # DEBUG
        assert WB[0,0,self.B-1] == self.weights[0,1+self.B-1]
        assert WB[0,self.K-1,0] == self.weights[0,1+(self.K-1)*self.B]

        if self.B > 2:
            assert WB[self.K-1,self.K-1,self.B-2] == self.weights[self.K-1,-2]

        # Weight matrix is summed over impulse response functions
        WT = WB.sum(axis=2)
        # Then we transpose so that the weight matrix is (outgoing x incoming)
        W = WT.T

        return W

    @property
    def bias(self):
        return self.weights[:,0]

    def add_data(self, S, F=None, minibatchsize=None):
        """
        Add a data set to the list of observations.
        First, filter the data with the impulse response basis,
        then instantiate a set of parents for this data set.

        :param S: a TxK matrix of event counts for each time bin
                  and each process.
        """
        assert isinstance(S, np.ndarray) and S.ndim == 2 and S.shape[1] == self.K \
               and np.amin(S) >= 0 and np.issubdtype(S.dtype, np.integer), \
               "Data must be a TxK array of event counts"

        T = S.shape[0]

        if F is None:
            # Filter the data into a TxKxB array
            Ftens = self.basis.convolve_with_basis(S)

            # Flatten this into a T x (KxB) matrix
            # [F00, F01, F02, F10, F11, ... F(K-1)0, F(K-1)(B-1)]
            F = Ftens.reshape((T, self.K * self.B))
            assert np.allclose(F[:,0], Ftens[:,0,0])
            if self.B > 1:
                assert np.allclose(F[:,1], Ftens[:,0,1])
            if self.K > 1:
                assert np.allclose(F[:,self.B], Ftens[:,1,0])

            # Prepend a column of ones
            F = np.concatenate((np.ones((T,1)), F), axis=1)

        # If minibatchsize is not None, add minibatches of data
        if minibatchsize is not None:
            for offset in np.arange(T, step=minibatchsize):
                end = min(offset+minibatchsize, T)
                S_mb = S[offset:end,:]
                F_mb = F[offset:end,:]

                # Add minibatch to the data list
                self.data_list.append((S_mb, F_mb))

        else:
            self.data_list.append((S,F))

    def check_stability(self):
        """
        Check that the weight matrix is stable

        :return: True if the maximum eigenvalue of the effective
                 K x K weight matrix is less than one
        """
        # Compute the effective weight matrix, summed over basis functions
        W_eff = self.W
        eigs = np.linalg.eigvals(W_eff)
        maxeig = np.amax(np.real(eigs))
        # print("Max eigenvalue: ", maxeig)
        return maxeig < 1.0

    def copy_sample(self):
        """
        Return a copy of the parameters of the model
        :return: The parameters of the model (A, W, lambda0, beta)
        """
        # return copy.deepcopy(self.get_parameters())

        # Shallow copy the data
        data_list = copy.copy(self.data_list)
        self.data_list = []

        # Make a deep copy without the data
        model_copy = copy.deepcopy(self)

        # Reset the data and return the data-less copy
        self.data_list = data_list
        return model_copy

    def compute_rate(self, index=None, ks=None):
        """
        Compute the rate of the k-th process.

        :param index:   Which dataset to compute the rate of
        :param ks:      Which processes to compute the rate of
        :return:
        """
        if index is None:
            index = 0
        _,F = self.data_list[index]

        if ks is None:
            ks = np.arange(self.K)

        if isinstance(ks, int):
            R = F.dot(self.weights[ks,:])
            return R

        elif isinstance(ks, np.ndarray):
            Rs = []
            for k in ks:
                Rs.append(F.dot(self.weights[k,:])[:,None])
            return np.concatenate(Rs, axis=1)

        else:
            raise Exception("ks must be int or array of indices in 0..K-1")

    def log_prior(self, ks=None):
        """
        Compute the log probability of the gamma prior on the weights
        (up to a constant)
        :param ks:
        :return:
        """
        if ks is None:
            ks = np.arange(self.K)

        lp = 0
        for k in ks:
            # lp += (self.alpha * np.log(self.weights[k,1:])).sum()
            # lp += (-self.beta * self.weights[k,1:]).sum()
            if self.alpha > 1:
                lp += (self.alpha - 1) * np.log(self.weights[k,1:]).sum()
            lp += (-self.beta * self.weights[k,1:]).sum()
        return lp

    def log_likelihood(self, indices=None, ks=None):
        """
        Compute the log likelihood
        :return:
        """
        ll = 0

        if indices is None:
            indices = np.arange(len(self.data_list))
        if isinstance(indices, int):
            indices = [indices]

        for index in indices:
            S,F = self.data_list[index]
            R = self.compute_rate(index, ks=ks)

            if ks is not None:
                ll += (-gammaln(S[:,ks]+1) + S[:,ks] * np.log(R) -R*self.dt).sum()
            else:
                ll += (-gammaln(S+1) + S * np.log(R) -R*self.dt).sum()

        return ll

    def log_posterior(self, indices=None, ks=None):
        if ks is None:
            ks = np.arange(self.K)

        lp = self.log_likelihood(indices, ks)
        lp += self.log_prior(ks)
        return lp

    def heldout_log_likelihood(self, S):
        self.add_data(S)
        hll = self.log_likelihood(indices=-1)
        self.data_list.pop()
        return hll

    def compute_gradient(self, k, indices=None):
        """
        Compute the gradient of the log likelihood with respect
        to the log biases and log weights

        :param k:   Which process to compute gradients for.
                    If none, return a list of gradients for each process.
        """
        grad = np.zeros(1 + self.K * self.B)

        if indices is None:
            indices = np.arange(len(self.data_list))

        d_W_d_log_W = self._d_W_d_logW(k)
        for index in indices:
            d_rate_d_W = self._d_rate_d_W(index, k)
            d_rate_d_log_W = d_rate_d_W.dot(d_W_d_log_W)
            d_ll_d_rate = self._d_ll_d_rate(index, k)
            d_ll_d_log_W = d_ll_d_rate.dot(d_rate_d_log_W)

            grad += d_ll_d_log_W

        # Add the prior
        # d_log_prior_d_log_W = self._d_log_prior_d_log_W(k)
        # grad += d_log_prior_d_log_W

        d_log_prior_d_W = self._d_log_prior_d_W(k)
        assert np.allclose(d_log_prior_d_W[0], 0.0)
        grad += d_log_prior_d_W.dot(d_W_d_log_W)

        # Zero out the gradient of the self-connection weights if they are not allowed
        if not self.allow_self_connections:
            assert np.allclose(self.weights[k,1+k*self.B:1+(k+1)*self.B], 0.0)
            grad[1+k*self.B:1+(k+1)*self.B] = 0

        return grad

    def _d_ll_d_rate(self, index, k):
        S,_ = self.data_list[index]
        T = S.shape[0]

        rate = self.compute_rate(index, k)
        # d/dR  S*ln(R) -R*dt
        grad = S[:,k] / rate  - self.dt * np.ones(T)
        return grad

    def _d_rate_d_W(self, index, k):
        _,F = self.data_list[index]
        grad = F
        return grad

    def _d_W_d_logW(self, k):
        """
        Let u = logW
        d{e^u}/du = e^u
                  = W
        """
        return np.diag(self.weights[k,:])

    def _d_log_prior_d_log_W(self, k):
        """
        Use a gamma prior on W (it is log concave for alpha >= 1)
        By change of variables this implies that
        LN p(LN W) = const + \alpha LN W - \beta W
        and
        d/d (LN W) (LN p(LN W)) = \alpha - \beta W

        TODO: Is this still concave? It is a concave function of W,
        but what about of LN W? As a function of u=LN(W) it is
        linear plus a -\beta e^u which is concave for beta > 0,
        so yes, it is still concave.

        So why does BFGS not converge monotonically?

        """
        d_log_prior_d_log_W = np.zeros_like(self.weights[k,:])
        d_log_prior_d_log_W[1:] = self.alpha  - self.beta * self.weights[k,1:]
        return d_log_prior_d_log_W

    def _d_log_prior_d_W(self, k):
        """
        Use a gamma prior on W (it is log concave for alpha >= 1)

        and
        LN p(W)       = (\alpha-1)LN W - \beta W
        d/dW LN p(W)) = (\alpha -1)/W  - \beta
        """
        d_log_prior_d_W = np.zeros_like(self.weights[k,:])
        if self.alpha > 1.0:
            d_log_prior_d_W[1:] += (self.alpha-1) / self.weights[k,1:]

        d_log_prior_d_W[1:] += -self.beta
        return d_log_prior_d_W

    def fit_with_bfgs(self):
        """
        Fit the model with BFGS
        """
        def objective(x, k):
            self.weights[k,:] = np.exp(x)
            self.weights[k,:] = np.nan_to_num(self.weights[k,:])
            return np.nan_to_num(-self.log_posterior(ks=np.array([k])))

        def gradient(x, k):
            self.weights[k,:] = np.exp(x)
            self.weights[k,:] = np.nan_to_num(self.weights[k,:])
            return np.nan_to_num(-self.compute_gradient(k))

        itr = [0]
        def callback(x):
            if itr[0] % 10 == 0:
                print "Iteration: %03d\t LP: %.1f" % (itr[0], self.log_posterior())
            itr[0] = itr[0] + 1

        for k in range(self.K):
            print("Optimizing process ", k)
            itr[0] = 0
            x0 = np.log(self.weights[k,:])
            res = minimize(objective, x0, args=(k,), jac=gradient, callback=callback)
            self.weights[k,:] = np.exp(res.x)

    def gradient_descent_step(self, stepsz=0.01):
        grad = np.zeros((self.K, 1+self.K*self.B))

        # Compute gradient and take a step for each process
        for k in range(self.K):
            grad[k,:] = self.compute_gradient(k)
            self.weights[k,:] = np.exp(np.log(self.weights[k,:]) + stepsz * grad[k,:])

        # Compute the current objective
        ll = self.log_likelihood()

        return self.weights, ll, grad

    def sgd_step(self, prev_velocity, learning_rate, momentum):
        """
        Take a step of the stochastic gradient descent algorithm
        """
        if prev_velocity is None:
            prev_velocity = np.zeros((self.K, 1+self.K*self.B))

        # Compute this gradient row by row
        grad = np.zeros((self.K, 1+self.K*self.B))
        velocity = np.zeros((self.K, 1+self.K*self.B))

        # Get a minibatch
        mb = np.random.choice(len(self.data_list))
        T = self.data_list[mb][0].shape[0]

        # Compute gradient and take a step for each process
        for k in range(self.K):
            grad[k,:] = self.compute_gradient(k, indices=[mb]) / T
            velocity[k,:] = momentum * prev_velocity[k,:] + learning_rate * grad[k,:]

            # Gradient steps are taken in log weight space
            log_weightsk = np.log(self.weights[k,:]) + velocity[k,:]

            # The true weights are stored
            self.weights[k,:] = np.exp(log_weightsk)

        # Compute the current objective
        ll = self.log_likelihood()

        return self.weights, ll, velocity
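
A minimal end-to-end sketch of the class above, assuming it is importable
along with its dependencies (numpy, scipy, and CosineBasis); the synthetic
Poisson counts are illustrative, not from the original source.

import numpy as np

K, T = 3, 1000
model = DiscreteTimeStandardHawkesModel(K=K, dt=1.0, dt_max=10.0, B=5)

# add_data requires a TxK integer array of event counts
S = np.random.poisson(0.5, size=(T, K)).astype(int)
model.add_data(S)

# Start from the empirical background rate, then fit each process with BFGS
model.initialize_to_background_rate()
model.fit_with_bfgs()

print("Log likelihood:", model.log_likelihood())
print("Stable:", model.check_stability())
print("Weight matrix W:\n", model.W)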