Beispiel #1
0
def fit_ising_matlab(n, X):
    """
    Fit a network to data handed over from MATLAB and return its parameters.

    ``X`` is cast to an integer array, passed together with ``n`` through
    ``_convert_2d_array_matlab``, wrapped in a ``Spikes`` object and learned
    with a ``Learner`` using ``window_size=1``.

    Parameters
    ----------
    n : Type
        Forwarded unchanged to ``_convert_2d_array_matlab``
        (presumably a dimension of the data -- TODO confirm)
    X : array-like
        Data to fit; converted with ``np.array(X).astype(int)``

    Returns
    -------
    J : list
        Flattened ``learner.network.J``
    theta : list
        Flattened ``learner.network.theta``
    """
    # The builtin int is used because the np.int alias was removed in
    # NumPy 1.24; it selects the same dtype as the old alias did.
    X = _convert_2d_array_matlab(np.array(X).astype(int), n)
    spikes = Spikes(spikes=X)
    learner = Learner(spikes)
    learner.learn_from_spikes(spikes, window_size=1)
    network = learner.network
    return network.J.ravel().tolist(), network.theta.ravel().tolist()
Beispiel #2
0
    def sample_from_model(self,
                          J=None,
                          theta=None,
                          trials=None,
                          reshape=False):
        """
        Returns new spikes object with iid Ising spike trains:
        (with Ising model determined by learning with MPF)

        .. warning:

            MIGHT NOT BE WORKING PROPERLY!

        Parameters
        ----------
        J : Type, optional
            Coupling parameter handed to ``sample_from_ising``; when either
            ``J`` or ``theta`` is None both are re-learned for every trial
            (default None)
        theta : Type, optional
            Second parameter handed to ``sample_from_ising``; see ``J``
            (default None)
        trials : sequence, optional
            Trials to sample for; any falsy value selects
            ``range(self._original_spikes.T)`` (default None)
        reshape : bool, optional
            Accepted but not used by this method (default False)

        Returns
        -------
        spikes : Spikes
            Wraps an array of shape ``(len(trials), N, M)`` where ``N`` and
            ``M`` are read from the original spikes
        """
        source = self._original_spikes
        if not trials:
            trials = range(source.T)
        samples = np.zeros((len(trials), source.N, source.M))

        learner = Learner(spikes=source)

        # Without a complete (J, theta) pair the parameters are learned
        # afresh from each individual trial before sampling.
        learn_per_trial = J is None or theta is None

        for row, trial in enumerate(trials):
            if learn_per_trial:
                learner.learn_from_spikes(window_size=1, trials=[trial])
                J, theta = learner._network.J, learner._network.theta
            samples[row, :, :] = sample_from_ising(J, theta, source.M)

        return Spikes(spikes=samples)
Beispiel #3
0
    def sample_from_model(self, J=None, theta=None, trials=None, reshape=False):
        """
        Returns new spikes object with iid Ising spike trains:
        (with Ising model determined by learning with MPF)

        .. warning:

            MIGHT NOT BE WORKING PROPERLY!

        Parameters
        ----------
        J : Type, optional
            Coupling parameter handed to ``sample_from_ising``; when either
            ``J`` or ``theta`` is None both are re-learned for every trial
            (default None)
        theta : Type, optional
            Second parameter handed to ``sample_from_ising``; see ``J``
            (default None)
        trials : sequence, optional
            Trials to sample for; any falsy value selects
            ``range(self._original_spikes.T)`` (default None)
        reshape : bool, optional
            Accepted but not used by this method (default False)

        Returns
        -------
        spikes : Spikes
            Wraps an array of shape ``(len(trials), N, M)`` where ``N`` and
            ``M`` are read from the original spikes
        """
        original = self._original_spikes
        if not trials:
            trials = range(original.T)
        sampled = np.zeros((len(trials), original.N, original.M))

        learner = Learner(spikes=original)

        # Parameters are re-learned per trial unless both were supplied.
        needs_fit = J is None or theta is None

        for index, trial in enumerate(trials):
            if needs_fit:
                learner.learn_from_spikes(window_size=1, trials=[trial])
                J, theta = learner._network.J, learner._network.theta
            sampled[index, :, :] = sample_from_ising(J, theta, original.M)

        return Spikes(spikes=sampled)
Beispiel #4
0
 def fit(self, trials=None, remove_zeros=True, reshape=False):
     """
     Sample spikes from the model and train a fresh learner on them.

     The result of ``self.sample_from_model`` is stored in
     ``_sample_spikes``; a new ``Learner`` built on those spikes is stored
     in ``_learner`` and then trained with ``learn_from_spikes``.

     Parameters
     ----------
     trials : Type, optional
         Forwarded to ``sample_from_model`` (default None)
     remove_zeros : bool, optional
         Remove all 0 training patterns; forwarded to
         ``learn_from_spikes`` (default True)
     reshape : bool, optional
         Forwarded to ``sample_from_model`` (default False)

     Returns
     -------
     None
     """
     # TODO: take care of remove_zeros
     sampled = self.sample_from_model(trials=trials, reshape=reshape)
     self._sample_spikes = sampled
     trained = Learner(spikes=sampled)
     # The learner is stored before training, as in the original flow.
     self._learner = trained
     trained.learn_from_spikes(remove_zeros=remove_zeros)
Beispiel #5
0
 def fit(self, trials=None, remove_zeros=False, reshape=False):
     """
     Sample spikes from the model and train a fresh learner on them.

     The result of ``self.sample_from_model`` is stored in
     ``_sample_spikes``; a new ``Learner`` built on those spikes is stored
     in ``_learner`` and then trained with ``learn_from_spikes``.

     Parameters
     ----------
     trials : Type, optional
         Forwarded to ``sample_from_model`` (default None)
     remove_zeros : bool, optional
         Forwarded to ``learn_from_spikes`` (default False)
     reshape : bool, optional
         Forwarded to ``sample_from_model`` (default False)

     Returns
     -------
     None
     """
     # TODO: take care of remove_zeros
     sampled = self.sample_from_model(trials=trials, reshape=reshape)
     self._sample_spikes = sampled
     trained = Learner(spikes=sampled)
     # The learner is stored before training, as in the original flow.
     self._learner = trained
     trained.learn_from_spikes(remove_zeros=remove_zeros)
Beispiel #6
0
class SpikeModel(Restoreable, object):
    """
    Generic model of spikes (and stimulus).

    Parameters
    ----------
    spikes : Spikes, optional
        Spikes to model (default None)
    stimulus : Stimulus, optional
        Corresponding stimulus, if existent (default None)
    window_size : int, optional
        Length of time window in binary bins (default 1)
    learner : Learner, optional
        Learner to start with; replaced on every call to :meth:`fit`
        (default None)
    """
    _SAVE_ATTRIBUTES_V1 = ['_window_size', '_learn_time']
    _SAVE_VERSION = 1
    _SAVE_TYPE = 'SpikeModel'
    # (class, attribute name, storage name) triples handed to
    # Restoreable._save / _load.  Materialised as a list: on Python 3 a bare
    # zip() is a one-shot iterator and would be empty after its first use.
    _INTERNAL_OBJECTS = list(zip([
        Spikes, Spikes, Spikes, PatternsRaw, PatternsHopfield, Stimulus,
        Learner
    ], [
        '_original_spikes', '_sample_spikes', '_hopfield_spikes',
        '_raw_patterns', '_hopfield_patterns', '_stimulus', '_learner'
    ], [
        'spikes_original', 'spikes_sample', 'spikes_hopfield', 'patterns_raw',
        'patterns_hopfield', 'stimulus', 'learner'
    ]))

    def __init__(self,
                 spikes=None,
                 stimulus=None,
                 window_size=1,
                 learner=None):
        object.__init__(self)
        Restoreable.__init__(self)

        self._stimulus = stimulus
        self._window_size = window_size
        self._learner = learner or None
        self._original_spikes = spikes
        self._learn_time = None
        self._sample_spikes = None
        self._raw_patterns = None
        self._hopfield_patterns = None
        self._hopfield_spikes = None

    @property
    def stimulus(self):
        """
        Stimulus given at construction (may be None).
        """
        return self._stimulus

    @property
    def window_size(self):
        """
        Current window size; overwritten by
        :meth:`distinct_patterns_over_windows` while it iterates.
        """
        return self._window_size

    @property
    def learner(self):
        """
        Learner created by the last :meth:`fit` call, or the one given at
        construction.
        """
        return self._learner

    @property
    def original_spikes(self):
        """
        Spikes given at construction.
        """
        return self._original_spikes

    @property
    def learn_time(self):
        """
        Total fit time accumulated by the last
        :meth:`distinct_patterns_over_windows` call (None before that).
        """
        return self._learn_time

    @property
    def sample_spikes(self):
        """
        Spikes produced by :meth:`sample_from_model` during the last
        :meth:`fit` (None before that).
        """
        return self._sample_spikes

    @property
    def raw_patterns(self):
        """
        ``PatternsRaw`` built by the last :meth:`chomp` (None before that).
        """
        return self._raw_patterns

    @property
    def hopfield_patterns(self):
        """
        ``PatternsHopfield`` built by the last :meth:`chomp` (None before
        that).
        """
        return self._hopfield_patterns

    @property
    def hopfield_spikes(self):
        """
        Result of applying the learned dynamics to the sample spikes in the
        last :meth:`chomp` (None before that).
        """
        return self._hopfield_spikes

    def fit(self, trials=None, remove_zeros=False, reshape=False):
        """
        Sample spikes from the model and train a fresh learner on them.

        Stores the sample in ``_sample_spikes`` and the new ``Learner`` in
        ``_learner``.

        Parameters
        ----------
        trials : Type, optional
            Forwarded to :meth:`sample_from_model` (default None)
        remove_zeros : bool, optional
            Forwarded to ``Learner.learn_from_spikes`` (default False)
        reshape : bool, optional
            Forwarded to :meth:`sample_from_model` (default False)

        Returns
        -------
        None
        """
        # TODO: take care of remove_zeros
        self._sample_spikes = self.sample_from_model(trials=trials,
                                                     reshape=reshape)
        self._learner = Learner(spikes=self._sample_spikes)
        self._learner.learn_from_spikes(remove_zeros=remove_zeros)

    def chomp(self):
        """
        Collect raw and Hopfield patterns from the current sample spikes.

        Requires a previous :meth:`fit`, since it reads ``_sample_spikes``
        and ``_learner``.  Fills ``_raw_patterns``, ``_hopfield_patterns``
        and ``_hopfield_spikes``.

        Returns
        -------
        None
        """
        hdlog.debug("Chomping samples from model")
        self._raw_patterns = PatternsRaw(save_sequence=True)
        self._raw_patterns.chomp_spikes(spikes=self._sample_spikes)
        hdlog.info("Raw: %d-bit, %d patterns" %
                   (self._sample_spikes.N, len(self._raw_patterns)))

        hdlog.debug(
            "Chomping dynamics (from network learned on the samples) applied to samples"
        )
        self._hopfield_patterns = PatternsHopfield(learner=self._learner,
                                                   save_sequence=True)
        self._hopfield_patterns.chomp_spikes(spikes=self._sample_spikes)
        hdlog.info("Hopfield: %d-bit, %d patterns" %
                   (self._sample_spikes.N, len(self._hopfield_patterns)))

        self._hopfield_spikes = self._hopfield_patterns.apply_dynamics(
            spikes=self._sample_spikes, reshape=True)

    def distinct_patterns_over_windows(self,
                                       window_sizes=None,
                                       trials=None,
                                       save_couplings=False,
                                       remove_zeros=False):
        """
        Count distinct patterns per trial and window size.

        For every window size and trial the model is fitted on that single
        trial and chomped.  As a side effect ``_window_size`` ends up at the
        last window size and ``_learn_time`` holds the total fit time.

        Parameters
        ----------
        window_sizes : list of int, optional
            Window sizes to iterate over; None means ``[1]`` (default None)
        trials : sequence, optional
            Trials to use; any falsy value selects
            ``range(self._original_spikes.T)`` (default None)
        save_couplings : bool, optional
            Also return a copy of the learned ``J`` for every fit
            (default False)
        remove_zeros : bool, optional
            Forwarded to :meth:`fit` (default False)

        Returns
        -------
        counts : numpy array
            Shape ``2 x len(trials) x len(window_sizes)``
            (0: empirical from model sample, 1: dynamics from learned model
            on sample)
        couplings : dict
            Only when ``save_couplings`` is True: window size mapped to a
            list with one ``J`` copy per trial
        """
        if window_sizes is None:
            window_sizes = [1]
        trials = trials or range(self._original_spikes.T)
        counts = np.zeros((2, len(trials), len(window_sizes)))
        # NOTE: entropies are not computed at present, only counts.

        couplings = {}

        tot_learn_time = 0

        for ws, window_size in enumerate(window_sizes):
            couplings[window_size] = []

            for c, trial in enumerate(trials):
                hdlog.info("Trial %d | ws %d" % (trial, window_size))

                self._window_size = window_size

                t = now()
                self.fit(trials=[trial], remove_zeros=remove_zeros)
                diff = now() - t
                hdlog.info("[%1.3f min]" % (diff / 60.))
                tot_learn_time += diff

                if save_couplings:
                    # Keyed by window size, matching the bucket created
                    # above; the enumerate index is only for `counts`.
                    couplings[window_size].append(
                        self._learner.network.J.copy())

                self.chomp()
                counts[0, c, ws] = len(self._raw_patterns)
                counts[1, c, ws] = len(self._hopfield_patterns)

        hdlog.info("Total learn time: %1.3f mins" % (tot_learn_time / 60.))
        self._learn_time = tot_learn_time
        if save_couplings:
            return counts, couplings
        return counts

    def sample_from_model(self, trials=None, reshape=False):
        """
        Return the original spikes windowed with the current window size.

        Parameters
        ----------
        trials : Type, optional
            Forwarded to ``to_windowed`` (default None)
        reshape : bool, optional
            Forwarded to ``to_windowed`` (default False)

        Returns
        -------
        Value : Type
            Whatever ``self._original_spikes.to_windowed`` returns
        """
        return self._original_spikes.to_windowed(window_size=self._window_size,
                                                 trials=trials,
                                                 reshape=reshape)

    def save(self, folder_name='spikes_model'):
        """
        saves as npz's: network, params, spikes file_name

        Parameters
        ----------
        folder_name : str, optional
            Folder to save into (default 'spikes_model')

        Returns
        -------
        None
        """
        super(SpikeModel, self)._save('spikes_model.npz',
                                      self._SAVE_ATTRIBUTES_V1,
                                      self._SAVE_VERSION,
                                      has_internal=True,
                                      folder_name=folder_name,
                                      internal_objects=self._INTERNAL_OBJECTS)

    @classmethod
    def load(cls, folder_name='spikes_model', load_extra=False):
        """
        Load a model previously written by :meth:`save`.

        Parameters
        ----------
        folder_name : str, optional
            Folder to load from (default 'spikes_model')
        load_extra : bool, optional
            Forwarded to ``Restoreable._load`` (default False)

        Returns
        -------
        Value : Type
            Whatever ``Restoreable._load`` returns
            (presumably a SpikeModel instance -- TODO confirm)
        """
        # TODO: document
        return super(SpikeModel,
                     cls)._load('spikes_model.npz',
                                has_internal=True,
                                folder_name=folder_name,
                                internal_objects=cls._INTERNAL_OBJECTS,
                                load_extra=load_extra)

    def _load_v1(self, contents, load_extra=False):
        # internal function to load v1 file format; load_extra is unused here
        hdlog.debug('Loading SpikeModel, format version 1')
        return Restoreable._load_attributes(self, contents,
                                            self._SAVE_ATTRIBUTES_V1)

    # representation

    def __repr__(self):
        return '<SpikeModel: {s}, window size {ws}>'.\
            format(s=repr(self.original_spikes), ws=self.window_size)
Beispiel #7
0
    def test_patterns_hopfield(self):
        """
        Exercise PatternsHopfield: pattern counts for window sizes 1 and 3,
        a save/load round trip, and the recorded pattern sequence for
        window sizes 1 and 2.
        """
        test_dir = os.path.dirname(__file__)
        tiny_path = os.path.join(test_dir, 'test_data/tiny_spikes.npz')
        trials_path = os.path.join(test_dir, 'test_data/spikes_trials.npz')

        # NpzFile.files is a plain list on every NumPy/Python version
        # (keys() is not subscriptable on Python 3); the context manager
        # closes the archive once the array has been read.
        with np.load(tiny_path) as file_contents:
            spikes = Spikes(file_contents[file_contents.files[0]])
        learner = Learner(spikes)
        learner.learn_from_spikes(spikes)

        patterns = PatternsHopfield(learner=learner)
        patterns.chomp_spikes(spikes)
        self.assertEqual(len(patterns), 3)

        # save / load round trip
        patterns.save(os.path.join(self.TMP_PATH, 'patterns'))
        patterns2 = PatternsHopfield.load(os.path.join(self.TMP_PATH, 'patterns'))
        self.assertTrue(isinstance(patterns2, PatternsHopfield))
        self.assertEqual(len(patterns2), 3)
        self.assertEqual(len(patterns2.mtas), 3)
        self.assertEqual(len(patterns2.mtas_raw), 3)

        learner.learn_from_spikes(spikes, window_size=3)
        patterns = PatternsHopfield(learner=learner)
        patterns.chomp_spikes(spikes, window_size=3)
        self.assertEqual(len(patterns), 4)

        # Learning on a small hand-written array; nothing is asserted here,
        # the call only has to complete.
        spikes_arr1 = np.array([[1, 0, 1], [0, 0, 1], [0, 1, 0]])
        spikes = Spikes(spikes=spikes_arr1)
        learner = Learner(spikes)
        learner.learn_from_spikes(spikes)

        # test recording fixed-points
        with np.load(trials_path) as file_contents:
            spikes = Spikes(file_contents[file_contents.files[0]])
        learner = Learner(spikes)
        learner.learn_from_spikes(spikes)
        patterns = PatternsHopfield(learner, save_sequence=True)
        patterns.chomp_spikes(spikes)
        self.assertEqual(patterns._sequence, [0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1])

        with np.load(trials_path) as file_contents:
            spikes = Spikes(file_contents[file_contents.files[0]])
        learner = Learner(spikes)
        learner.learn_from_spikes(spikes, window_size=2)
        patterns = PatternsHopfield(learner, save_sequence=True)
        patterns.chomp_spikes(spikes, window_size=2)
        self.assertEqual(patterns._sequence, [0, 1, 2, 3, 0, 1, 4, 5, 6, 5, 7, 3])

        hdlog.info(spikes._spikes)
        hdlog.info(patterns.pattern_to_trial_raster(3))
Beispiel #8
0
class SpikeModel(Restoreable, object):
    """
    Generic model of spikes (and stimulus).

    Parameters
    ----------
    spikes : Spikes, optional
        Spikes to model (default None)
    stimulus : Stimulus, optional
        Corresponding stimulus, if existent (default None)
    window_size : int, optional
        Length of time window in binary bins (default 1)
    learner : Learner, optional
        Learner to start with; replaced on every call to :meth:`fit`
        (default None)
    """
    _SAVE_ATTRIBUTES_V1 = ['_window_size', '_learn_time']
    _SAVE_VERSION = 1
    _SAVE_TYPE = 'SpikeModel'
    # (class, attribute name, storage name) triples handed to
    # Restoreable._save / _load.  Materialised as a list: on Python 3 a bare
    # zip() is a one-shot iterator and would be empty after its first use.
    _INTERNAL_OBJECTS = list(zip(
        [Spikes, Spikes, Spikes, PatternsRaw, PatternsHopfield, Stimulus, Learner],
        ['_original_spikes', '_sample_spikes', '_hopfield_spikes',
         '_raw_patterns', '_hopfield_patterns', '_stimulus', '_learner'],
        ['spikes_original', 'spikes_sample', 'spikes_hopfield',
         'patterns_raw', 'patterns_hopfield', 'stimulus', 'learner']))

    def __init__(self, spikes=None, stimulus=None, window_size=1, learner=None):
        object.__init__(self)
        Restoreable.__init__(self)

        self._stimulus = stimulus
        self._window_size = window_size
        self._learner = learner or None
        self._original_spikes = spikes
        self._learn_time = None
        self._sample_spikes = None
        self._raw_patterns = None
        self._hopfield_patterns = None
        self._hopfield_spikes = None

    @property
    def stimulus(self):
        """
        Stimulus given at construction (may be None).
        """
        return self._stimulus

    @property
    def window_size(self):
        """
        Current window size; overwritten by
        :meth:`distinct_patterns_over_windows` while it iterates.
        """
        return self._window_size

    @property
    def learner(self):
        """
        Learner created by the last :meth:`fit` call, or the one given at
        construction.
        """
        return self._learner

    @property
    def original_spikes(self):
        """
        Spikes given at construction.
        """
        return self._original_spikes

    @property
    def learn_time(self):
        """
        Total fit time accumulated by the last
        :meth:`distinct_patterns_over_windows` call (None before that).
        """
        return self._learn_time

    @property
    def sample_spikes(self):
        """
        Spikes produced by :meth:`sample_from_model` during the last
        :meth:`fit` (None before that).
        """
        return self._sample_spikes

    @property
    def raw_patterns(self):
        """
        ``PatternsRaw`` built by the last :meth:`chomp` (None before that).
        """
        return self._raw_patterns

    @property
    def hopfield_patterns(self):
        """
        ``PatternsHopfield`` built by the last :meth:`chomp` (None before
        that).
        """
        return self._hopfield_patterns

    @property
    def hopfield_spikes(self):
        """
        Result of applying the learned dynamics to the sample spikes in the
        last :meth:`chomp` (None before that).
        """
        return self._hopfield_spikes

    def fit(self, trials=None, remove_zeros=False, reshape=False):
        """
        Sample spikes from the model and train a fresh learner on them.

        Stores the sample in ``_sample_spikes`` and the new ``Learner`` in
        ``_learner``.

        Parameters
        ----------
        trials : Type, optional
            Forwarded to :meth:`sample_from_model` (default None)
        remove_zeros : bool, optional
            Forwarded to ``Learner.learn_from_spikes`` (default False)
        reshape : bool, optional
            Forwarded to :meth:`sample_from_model` (default False)

        Returns
        -------
        None
        """
        # TODO: take care of remove_zeros
        self._sample_spikes = self.sample_from_model(trials=trials, reshape=reshape)
        self._learner = Learner(spikes=self._sample_spikes)
        self._learner.learn_from_spikes(remove_zeros=remove_zeros)

    def chomp(self):
        """
        Collect raw and Hopfield patterns from the current sample spikes.

        Requires a previous :meth:`fit`, since it reads ``_sample_spikes``
        and ``_learner``.  Fills ``_raw_patterns``, ``_hopfield_patterns``
        and ``_hopfield_spikes``.

        Returns
        -------
        None
        """
        hdlog.debug("Chomping samples from model")
        self._raw_patterns = PatternsRaw(save_sequence=True)
        self._raw_patterns.chomp_spikes(spikes=self._sample_spikes)
        hdlog.info("Raw: %d-bit, %d patterns" % (
            self._sample_spikes.N, len(self._raw_patterns)))

        hdlog.debug("Chomping dynamics (from network learned on the samples) applied to samples")
        self._hopfield_patterns = PatternsHopfield(learner=self._learner, save_sequence=True)
        self._hopfield_patterns.chomp_spikes(spikes=self._sample_spikes)
        hdlog.info("Hopfield: %d-bit, %d patterns" % (
            self._sample_spikes.N, len(self._hopfield_patterns)))

        self._hopfield_spikes = self._hopfield_patterns.apply_dynamics(
            spikes=self._sample_spikes, reshape=True)

    def distinct_patterns_over_windows(self, window_sizes=None, trials=None, save_couplings=False, remove_zeros=False):
        """
        Count distinct patterns per trial and window size.

        For every window size and trial the model is fitted on that single
        trial and chomped.  As a side effect ``_window_size`` ends up at the
        last window size and ``_learn_time`` holds the total fit time.

        Parameters
        ----------
        window_sizes : list of int, optional
            Window sizes to iterate over; None means ``[1]`` (default None)
        trials : sequence, optional
            Trials to use; any falsy value selects
            ``range(self._original_spikes.T)`` (default None)
        save_couplings : bool, optional
            Also return a copy of the learned ``J`` for every fit
            (default False)
        remove_zeros : bool, optional
            Forwarded to :meth:`fit` (default False)

        Returns
        -------
        counts : numpy array
            Shape ``2 x len(trials) x len(window_sizes)``
            (0: empirical from model sample, 1: dynamics from learned model
            on sample)
        couplings : dict
            Only when ``save_couplings`` is True: window size mapped to a
            list with one ``J`` copy per trial
        """
        if window_sizes is None:
            window_sizes = [1]
        trials = trials or range(self._original_spikes.T)
        counts = np.zeros((2, len(trials), len(window_sizes)))
        # NOTE: entropies are not computed at present, only counts.

        couplings = {}

        tot_learn_time = 0

        for ws, window_size in enumerate(window_sizes):
            couplings[window_size] = []

            for c, trial in enumerate(trials):
                hdlog.info("Trial %d | ws %d" % (trial, window_size))

                self._window_size = window_size

                t = now()
                self.fit(trials=[trial], remove_zeros=remove_zeros)
                diff = now() - t
                hdlog.info("[%1.3f min]" % (diff / 60.))
                tot_learn_time += diff

                if save_couplings:
                    # Keyed by window size, matching the bucket created
                    # above; the enumerate index is only for `counts`.
                    couplings[window_size].append(self._learner.network.J.copy())

                self.chomp()
                counts[0, c, ws] = len(self._raw_patterns)
                counts[1, c, ws] = len(self._hopfield_patterns)

        hdlog.info("Total learn time: %1.3f mins" % (tot_learn_time / 60.))
        self._learn_time = tot_learn_time
        if save_couplings:
            return counts, couplings
        return counts

    def sample_from_model(self, trials=None, reshape=False):
        """
        Return the original spikes windowed with the current window size.

        Parameters
        ----------
        trials : Type, optional
            Forwarded to ``to_windowed`` (default None)
        reshape : bool, optional
            Forwarded to ``to_windowed`` (default False)

        Returns
        -------
        Value : Type
            Whatever ``self._original_spikes.to_windowed`` returns
        """
        return self._original_spikes.to_windowed(window_size=self._window_size, trials=trials, reshape=reshape)

    def save(self, folder_name='spikes_model'):
        """
        saves as npz's: network, params, spikes file_name

        Parameters
        ----------
        folder_name : str, optional
            Folder to save into (default 'spikes_model')

        Returns
        -------
        None
        """
        super(SpikeModel, self)._save(
            'spikes_model.npz', self._SAVE_ATTRIBUTES_V1, self._SAVE_VERSION,
            has_internal=True, folder_name=folder_name, internal_objects=self._INTERNAL_OBJECTS)

    @classmethod
    def load(cls, folder_name='spikes_model', load_extra=False):
        """
        Load a model previously written by :meth:`save`.

        Parameters
        ----------
        folder_name : str, optional
            Folder to load from (default 'spikes_model')
        load_extra : bool, optional
            Forwarded to ``Restoreable._load`` (default False)

        Returns
        -------
        Value : Type
            Whatever ``Restoreable._load`` returns
            (presumably a SpikeModel instance -- TODO confirm)
        """
        # TODO: document
        return super(SpikeModel, cls)._load('spikes_model.npz', has_internal=True,
                                            folder_name=folder_name,
                                            internal_objects=cls._INTERNAL_OBJECTS,
                                            load_extra=load_extra)

    def _load_v1(self, contents, load_extra=False):
        # internal function to load v1 file format; load_extra is unused here
        hdlog.debug('Loading SpikeModel, format version 1')
        return Restoreable._load_attributes(self, contents, self._SAVE_ATTRIBUTES_V1)

    # representation

    def __repr__(self):
        return '<SpikeModel: {s}, window size {ws}>'.\
            format(s=repr(self.original_spikes), ws=self.window_size)
Beispiel #9
0
    def test_basic(self):
        """
        Exercise Learner: learning with default, explicit and windowed
        spikes, followed by a save/load round trip.
        """
        data_path = os.path.join(os.path.dirname(__file__), 'test_data/tiny_spikes.npz')
        # NpzFile.files is a plain list on every NumPy/Python version
        # (keys() is not subscriptable on Python 3); the context manager
        # closes the archive once the array has been read.
        with np.load(data_path) as file_contents:
            spikes = Spikes(file_contents[file_contents.files[0]])
        learner = Learner(spikes)
        self.assertEqual(learner._spikes.N, 3)

        learner.learn_from_spikes()
        self.assertTrue(learner._network.J.mean() != 0.)

        learner.learn_from_spikes(spikes)
        self.assertTrue(learner._network.J.mean() != 0.)

        learner.learn_from_spikes(spikes, window_size=3)
        self.assertTrue(learner._network.J.mean() != 0.)
        self.assertTrue(learner._network.J.shape == (9, 9))

        # save / load round trip keeps params, window size and couplings
        learner._params['hi'] = 'chris'
        learner.save(os.path.join(self.TMP_PATH, 'learner'))
        learner2 = Learner.load(os.path.join(self.TMP_PATH, 'learner'))
        self.assertEqual(learner2.params['hi'], 'chris')
        self.assertEqual(learner2.window_size, 3)
        self.assertTrue(learner2.network.J.mean() != 0.)
        self.assertTrue(learner2.network.J.shape == (9, 9))
Beispiel #10
0
    def test_basic(self):
        """
        Exercise Learner: learning with default, explicit and windowed
        spikes, followed by a save/load round trip.
        """
        data_path = os.path.join(os.path.dirname(__file__),
                                 'test_data/tiny_spikes.npz')
        # NpzFile.files is a plain list on every NumPy/Python version
        # (keys() is not subscriptable on Python 3); the context manager
        # closes the archive once the array has been read.
        with np.load(data_path) as file_contents:
            spikes = Spikes(file_contents[file_contents.files[0]])
        learner = Learner(spikes)
        self.assertEqual(learner._spikes.N, 3)

        learner.learn_from_spikes()
        self.assertTrue(learner._network.J.mean() != 0.)

        learner.learn_from_spikes(spikes)
        self.assertTrue(learner._network.J.mean() != 0.)

        learner.learn_from_spikes(spikes, window_size=3)
        self.assertTrue(learner._network.J.mean() != 0.)
        self.assertTrue(learner._network.J.shape == (9, 9))

        # save / load round trip keeps params, window size and couplings
        learner._params['hi'] = 'chris'
        learner.save(os.path.join(self.TMP_PATH, 'learner'))
        learner2 = Learner.load(os.path.join(self.TMP_PATH, 'learner'))
        self.assertEqual(learner2.params['hi'], 'chris')
        self.assertEqual(learner2.window_size, 3)
        self.assertTrue(learner2.network.J.mean() != 0.)
        self.assertTrue(learner2.network.J.shape == (9, 9))