def _res_assignments_new(self, o, **kwargs):
        """Fold every constant-valued assignment of ``o`` into a SymbolicConstant.

        For each assignment in ``o.assignments`` whose ``rhs_map`` this visitor
        evaluates to a truthy fixed value, the assigned variable (``lhs``) is
        replaced throughout ``o`` by a new ``ast.SymbolicConstant`` with the
        same symbol and that value. The assignment is then removed from
        ``o._eqn_assignment`` and the new constant is added to
        ``o._symbolicconstants``.

        :param o: object whose assignments are resolved; modified in place.
        :param kwargs: unused.
        :raises AssertionError: if a resolved assignment or its lhs is still
            reachable from ``o`` afterwards.
        """
        removed = []

        # Iterate over a snapshot: the loop body removes entries from
        # o._eqn_assignment (presumably the storage behind o.assignments).
        for assignment in list(o.assignments):
            fixed_value = self.visit(assignment.rhs_map)
            if not fixed_value:
                # Not a fixed value; leave this assignment alone.
                continue

            removed.extend([assignment, assignment.lhs])

            # Replace the 'Assigned' object with a 'SymbolicConstant' in the tree:
            sym_node = ast.SymbolicConstant(symbol=assignment.lhs.symbol, value=fixed_value)
            ReplaceNode.replace_and_check(srcObj=assignment.lhs, dstObj=sym_node, root=o)

            # Remove the assignment equation and register the new constant:
            o._eqn_assignment._objs.remove(assignment)
            o._symbolicconstants._add_item(sym_node)

        # Double check that nothing we replaced is still reachable. The tree is
        # not modified any more, so collect its nodes once instead of once per
        # removed item (list(): all() may return a one-shot iterator).
        if removed:
            remaining_nodes = list(EqnsetVisitorNodeCollector(o).all())
            for a in removed:
                assert a not in remaining_nodes, 'Did not fully remove: %s' % a
Beispiel #2
0
    def VisitLibrary(self, o, **kwargs):
        """Replace every assignment of library ``o`` by a SymbolicConstant.

        For each entry of ``o._eqn_assignment`` whose ``rhs`` this visitor
        evaluates to a truthy fixed value, the assigned node is replaced
        throughout ``o`` by an ``ast.SymbolicConstant`` with the same symbol.
        The constant is stored in ``o._symbolicconstants`` under that symbol
        and the assignment entry is deleted.

        :param o: library object; modified in place.
        :param kwargs: unused.
        :raises AssertionError: if a replaced node is still reachable from
            ``o``, or if any assignment could not be folded (afterwards
            ``o._eqn_assignment`` must be empty).
        """
        removed = []
        # Snapshot the keys: entries are deleted from o._eqn_assignment inside
        # the loop, which must not happen while iterating it directly.
        for aKey in list(o._eqn_assignment.keys()):
            a = o._eqn_assignment[aKey]
            # Keep a reference to the original lhs node for the final check;
            # the replacement below may rewrite references held by `a`.
            alhs = a.lhs
            fixed_value = self.visit(a.rhs)
            if not fixed_value:
                continue

            s = ast.SymbolicConstant(symbol=aKey.symbol, value=fixed_value)
            ReplaceNode(alhs, s).visit(o)

            o._symbolicconstants[aKey.symbol] = s
            del o._eqn_assignment[aKey]

            removed.append(alhs)

        # Double check they have gone. The tree is no longer modified, so
        # collect its nodes once (list(): all() may return an iterator).
        if removed:
            nc = EqnsetVisitorNodeCollector()
            nc.visit(o)
            remaining_nodes = list(nc.all())
            for a in removed:
                assert a not in remaining_nodes, 'Did not fully remove: %s' % a

        # Every assignment must have been folded into a constant:
        assert len(o._eqn_assignment) == 0
    def _res_assignments(self, o, **kwargs):
        """Replace each constant-valued assignment of ``o`` by a SymbolicConstant.

        For each entry of ``o._eqn_assignment`` whose ``rhs_map`` this visitor
        evaluates to a truthy fixed value, the assigned node is replaced
        throughout ``o`` (checked by ``ReplaceNode.replace_and_check``) by an
        ``ast.SymbolicConstant`` with the same symbol. The constant is stored
        in ``o._symbolicconstants`` under that symbol and the assignment entry
        is deleted. Assignments that do not fold are left untouched.

        :param o: object whose assignments are resolved; modified in place.
        :param kwargs: unused.
        :raises AssertionError: if a replaced node is still reachable from ``o``.
        """
        from neurounits.misc import SeqUtils

        removed = []
        # Snapshot the keys: entries are deleted from o._eqn_assignment inside
        # the loop, which must not happen while iterating it directly.
        for aKey in list(o._eqn_assignment.keys()):
            a = o._eqn_assignment[aKey]
            # Keep a reference to the original lhs node for the final check;
            # the replacement below may rewrite references held by `a`.
            alhs = a.lhs
            fixed_value = self.visit(a.rhs_map)
            if not fixed_value:
                continue

            s = ast.SymbolicConstant(symbol=aKey.symbol, value=fixed_value)
            ReplaceNode.replace_and_check(srcObj=alhs, dstObj=s, root=o)

            o._symbolicconstants[aKey.symbol] = s

            # NOTE(review): the entry is looked up again by symbol name rather
            # than deleted via aKey -- presumably because the replacement above
            # can rewrite the key objects stored in o._eqn_assignment.
            target_symbol = aKey.symbol
            old_ass = SeqUtils.filter_expect_single(
                o._eqn_assignment, lambda entry: entry.symbol == target_symbol)
            del o._eqn_assignment[old_ass]

            removed.append(alhs)

        # Double check they have gone. The tree is no longer modified, so
        # collect its nodes once (list(): all() may return an iterator).
        if removed:
            nc = EqnsetVisitorNodeCollector()
            nc.visit(o)
            remaining_nodes = list(nc.all())
            for a in removed:
                assert a not in remaining_nodes, 'Did not fully remove: %s' % a
Beispiel #4
0
    def replace_and_check(cls, srcObj, dstObj, root):
        root = ReplaceNode(srcObj, dstObj).visit(root)

        if srcObj in EqnsetVisitorNodeCollector(root).all():

            from .ast_node_connections import ASTAllConnections
            print 'A node has not been completely removed: %s' % srcObj
            print 'The following are still connected:'
            for node in EqnsetVisitorNodeCollector(root).all():
                conns = ASTAllConnections().visit(node)
                if srcObj in conns:
                    print '    node:', node

            print 'OK'

        # Lets make sure its completely gone:
        assert not srcObj in EqnsetVisitorNodeCollector(root).all()
Beispiel #5
0
 def _cache_nodes(self):
     """Refresh the cached Parameter and SuppliedValue nodes of this object.

     Walks ``self`` with an EqnsetVisitorNodeCollector and stores the node
     groups it collected for those two types on ``_parameters`` and
     ``_supplied_values``.
     """
     walker = EqnsetVisitorNodeCollector()
     walker.visit(self)
     found = walker.nodes
     self._parameters = found[Parameter]
     self._supplied_values = found[SuppliedValue]
Beispiel #6
0
    def InferFromEqnset(cls, eqnset):
        """Derive NEURON MODL build parameters from ``eqnset``'s io metadata.

        Scans the Input/Output entries of ``eqnset.io_data`` that carry an
        ``'mf'`` metadata dict with a ``'role'``. It records
        TRANSMEMBRANECURRENT outputs as ``NeuronMembraneCurrent`` objects and
        MEMBRANEVOLTAGE / TIME / TEMPERATURE inputs as ``NeuronSuppliedValues``.
        It then computes a unit for every node collected from ``eqnset``.

        NOTE(review): the body unconditionally reaches
        ``assert False, 'Deprecated, needs rewrite'`` before the event-handling
        section. Unless Python runs with ``-O``, the method therefore always
        raises AssertionError and never returns.

        NOTE(review): takes ``cls`` -- presumably a ``@classmethod`` whose
        decorator is outside this view.

        :param eqnset: equation set to inspect.
        :returns: a ``MODLBuildParameters`` (only reachable with assertions
            disabled).
        :raises ValueError: if no current was found, or if more than one
            zero-parameter event exists.
        :raises AssertionError: on an unknown role, a role used with the wrong
            io type, mixed current types, and at the deprecation marker.
        """

        currents = {}
        supplied_values = {}

        # Classify the Input/Output io entries that carry an 'mf' role:
        for io_info in [io_info for io_info in eqnset.io_data if io_info.iotype in (IOType.Output, IOType.Input)]:
            if not io_info.metadata or not 'mf' in io_info.metadata:
                continue
            role = io_info.metadata['mf'].get('role', None)

            if role:

                # Roles whose symbol is not a terminal object of the eqnset
                # are silently ignored.
                if not eqnset.has_terminal_obj(io_info.symbol):
                    continue

                obj = eqnset.get_terminal_obj(io_info.symbol)

                # Outputs:
                if role == "TRANSMEMBRANECURRENT":
                    assert io_info.iotype== IOType.Output

                    currents[obj] = NeuronMembraneCurrent( obj=obj,  symbol=obj.symbol)

                # Inputs (Supplied Values):
                elif role == "MEMBRANEVOLTAGE":
                    assert io_info.iotype== IOType.Input
                    supplied_values[obj] = NeuronSuppliedValues.MembraneVoltage
                elif role == "TIME":
                    assert io_info.iotype== IOType.Input
                    supplied_values[obj] = NeuronSuppliedValues.Time
                elif role == "TEMPERATURE":
                    assert io_info.iotype== IOType.Input
                    supplied_values[obj] = NeuronSuppliedValues.Temperature
                else:
                    # Unrecognised role.
                    assert False

        if not currents:
            raise ValueError('Mechanism does not expose any currents! %s'% eqnset.name)

        # PointProcess or Distributed  Process:
        # One current (whichever dict order yields first) sets the reference
        # type; every current must share it.
        mech_type = currents.values()[0].getCurrentType()
        for c in currents.values():
            assert mech_type == c.getCurrentType(),  'Mixed Current types [Distributed/Point]'


        # Work out the units of all the terminal symbols:
        # (in practice: every node found by EqnsetVisitorNodeCollector, minus
        # the node types skipped below).
        symbol_units = {}

        # NOTE(review): dead store -- `objs` is reassigned from n.all() below.
        objs = eqnset.terminal_symbols
        n = EqnsetVisitorNodeCollector()
        n.visit(eqnset)
        objs = n.all()
        for s in objs:

            if s in currents:
                symbol_units[s] = NEURONMappings.current_units[mech_type]
            elif s in supplied_values:

                t = supplied_values[s]
                if t in NeuronSuppliedValues.All:
                    symbol_units[s] = NEURONMappings.supplied_value_units[t]  #NeuroUnitParser.Unit("mV")
                else:
                    print 'Unknown supplied value:', t
                    assert False
            else:

                # Nodes of these types get no unit entry.
                if isinstance(s,(EqnTimeDerivativeByRegime, EqnAssignmentByRegime, EqnSet, ConstValue)):
                    continue
                if isinstance(s,(InEquality, OnEvent)):
                    continue

                symbol_units[s] = s.get_dimension()





        # Event Handling:
        # NOTE(review): this assert always fires unless assertions are
        # disabled, so everything below it is unreachable in normal runs.
        assert False, 'Deprecated, needs rewrite'
        # At most one parameterless event is allowed; it is passed on as
        # event_function (None if there is none).
        zero_arg_events = [ev for ev in eqnset.onevents if len(ev.parameters) == 0]
        if len(zero_arg_events) == 0:
            event_function = None
        elif len(zero_arg_events) == 1:
            event_function= zero_arg_events[0]
        else:
            raise ValueError('Multiple Zero-Param Events')

        return MODLBuildParameters(mechanismtype=mech_type, currents=currents, supplied_values=supplied_values, suffix="nmmodl"+eqnset.name, symbol_units=symbol_units, event_function=event_function  )
    def plot_func_exp(self):
        import tables
        h5file = tables.openFile("output.hd5")

        float_group = h5file.root._f_getChild(
            '/simulation_fixed/float/variables/')
        time_array = h5file.root._f_getChild(
            '/simulation_fixed/float/time').read()

        from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector

        from neurounits import ast
        func_call_nodes = EqnsetVisitorNodeCollector(
            self.eqnwriter.component).nodes[ast.FunctionDefInstantiation]
        print func_call_nodes

        #node_locs = dict(self.eqnwriter.intermediate_store_locs)
        #print node_locs

        nbits = 10
        LUT_size = 2**nbits

        in_min_max = None

        for func_call_node in func_call_nodes:

            try:
                node_loc = self.eqnwriter.node_labels[func_call_node]

                data = h5file.root._f_getChild(
                    '/simulation_fixed/float/operations/' +
                    "op%d" % node_loc).read()

                print data.shape
                print time_array.shape

                in_data = data[:, 0]
                out_data = data[:, 1]

                in_min = np.min(in_data)
                in_max = np.max(in_data)
                out_min = np.min(out_data)
                out_max = np.max(out_data)

                if in_min_max == None:
                    in_min_max = (in_min, in_max)
                else:
                    in_min_max = (np.min([in_min, in_min_max[0]]),
                                  np.max([in_max, in_min_max[1]]))

                f = pylab.figure()
                ax1 = f.add_subplot(3, 1, 1)
                ax2 = f.add_subplot(3, 1, 2)
                ax3 = f.add_subplot(3, 1, 3)

                x_space = np.linspace(in_min, in_max, num=2**(nbits + 8))
                x_space_binned = np.linspace(in_min, in_max, num=LUT_size)

                ax1.plot(time_array,
                         in_data,
                         label='in_data (in: %f to %f)' % (in_min, in_max))
                ax2.plot(time_array,
                         out_data,
                         label='out_data (in: %f to %f)' % (out_min, out_max))
                ax1.legend()
                ax2.legend()

                ax3.plot(x_space, np.exp(x_space))
                #fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, sharex=True, sharey=True)

                pylab.legend()

            except tables.exceptions.NoSuchNodeError, e:
                print 'Not such group!: ', e
Beispiel #8
0
    def InferFromEqnset(cls, eqnset):
        """Derive NEURON MODL build parameters from ``eqnset``'s io metadata.

        Scans the Input/Output entries of ``eqnset.io_data`` that carry an
        ``'mf'`` metadata dict with a ``'role'``. It records
        TRANSMEMBRANECURRENT outputs as ``NeuronMembraneCurrent`` objects and
        MEMBRANEVOLTAGE / TIME / TEMPERATURE inputs as ``NeuronSuppliedValues``.
        It then computes a unit for every node collected from ``eqnset``.

        NOTE(review): the body unconditionally reaches
        ``assert False, 'Deprecated, needs rewrite'`` before the event-handling
        section. Unless Python runs with ``-O``, the method therefore always
        raises AssertionError and never returns.

        NOTE(review): takes ``cls`` -- presumably a ``@classmethod`` whose
        decorator is outside this view.

        :param eqnset: equation set to inspect.
        :returns: a ``MODLBuildParameters`` (only reachable with assertions
            disabled).
        :raises ValueError: if no current was found, or if more than one
            zero-parameter event exists.
        :raises AssertionError: on an unknown role, a role used with the wrong
            io type, mixed current types, and at the deprecation marker.
        """

        currents = {}
        supplied_values = {}

        # Classify the Input/Output io entries that carry an 'mf' role:
        for io_info in [
                io_info for io_info in eqnset.io_data
                if io_info.iotype in (IOType.Output, IOType.Input)
        ]:
            if not io_info.metadata or not 'mf' in io_info.metadata:
                continue
            role = io_info.metadata['mf'].get('role', None)

            if role:

                # Roles whose symbol is not a terminal object of the eqnset
                # are silently ignored.
                if not eqnset.has_terminal_obj(io_info.symbol):
                    continue

                obj = eqnset.get_terminal_obj(io_info.symbol)

                # Outputs:
                if role == "TRANSMEMBRANECURRENT":
                    assert io_info.iotype == IOType.Output

                    currents[obj] = NeuronMembraneCurrent(obj=obj,
                                                          symbol=obj.symbol)

                # Inputs (Supplied Values):
                elif role == "MEMBRANEVOLTAGE":
                    assert io_info.iotype == IOType.Input
                    supplied_values[obj] = NeuronSuppliedValues.MembraneVoltage
                elif role == "TIME":
                    assert io_info.iotype == IOType.Input
                    supplied_values[obj] = NeuronSuppliedValues.Time
                elif role == "TEMPERATURE":
                    assert io_info.iotype == IOType.Input
                    supplied_values[obj] = NeuronSuppliedValues.Temperature
                else:
                    # Unrecognised role.
                    assert False

        if not currents:
            raise ValueError('Mechanism does not expose any currents! %s' %
                             eqnset.name)

        # PointProcess or Distributed  Process:
        # One current (whichever dict order yields first) sets the reference
        # type; every current must share it.
        mech_type = currents.values()[0].getCurrentType()
        for c in currents.values():
            assert mech_type == c.getCurrentType(
            ), 'Mixed Current types [Distributed/Point]'

        # Work out the units of all the terminal symbols:
        # (in practice: every node found by EqnsetVisitorNodeCollector, minus
        # the node types skipped below).
        symbol_units = {}

        # NOTE(review): dead store -- `objs` is reassigned from n.all() below.
        objs = eqnset.terminal_symbols
        n = EqnsetVisitorNodeCollector()
        n.visit(eqnset)
        objs = n.all()
        for s in objs:

            if s in currents:
                symbol_units[s] = NEURONMappings.current_units[mech_type]
            elif s in supplied_values:

                t = supplied_values[s]
                if t in NeuronSuppliedValues.All:
                    symbol_units[s] = NEURONMappings.supplied_value_units[
                        t]  #NeuroUnitParser.Unit("mV")
                else:
                    print 'Unknown supplied value:', t
                    assert False
            else:

                # Nodes of these types get no unit entry.
                if isinstance(s, (EqnTimeDerivativeByRegime,
                                  EqnAssignmentByRegime, EqnSet, ConstValue)):
                    continue
                if isinstance(s, (InEquality, OnEvent)):
                    continue

                symbol_units[s] = s.get_dimension()

        # Event Handling:
        # NOTE(review): this assert always fires unless assertions are
        # disabled, so everything below it is unreachable in normal runs.
        assert False, 'Deprecated, needs rewrite'
        # At most one parameterless event is allowed; it is passed on as
        # event_function (None if there is none).
        zero_arg_events = [
            ev for ev in eqnset.onevents if len(ev.parameters) == 0
        ]
        if len(zero_arg_events) == 0:
            event_function = None
        elif len(zero_arg_events) == 1:
            event_function = zero_arg_events[0]
        else:
            raise ValueError('Multiple Zero-Param Events')

        return MODLBuildParameters(mechanismtype=mech_type,
                                   currents=currents,
                                   supplied_values=supplied_values,
                                   suffix="nmmodl" + eqnset.name,
                                   symbol_units=symbol_units,
                                   event_function=event_function)
Beispiel #9
0
 def _cache_nodes(self):
     """Re-collect and cache this object's Parameter and SuppliedValue nodes.

     A fresh EqnsetVisitorNodeCollector walks ``self``; its per-type node
     groups for Parameter and SuppliedValue become ``self._parameters`` and
     ``self._supplied_values``.
     """
     node_collector = EqnsetVisitorNodeCollector()
     node_collector.visit(self)
     by_type = node_collector.nodes
     self._parameters, self._supplied_values = by_type[Parameter], by_type[SuppliedValue]
Beispiel #10
0
    def clone(self, ):

        from neurounits.visitors.common.ast_replace_node import ReplaceNode
        from neurounits.visitors.common.ast_node_connections import ASTAllConnections
        from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector

        class ReplaceNodeHack(ReplaceNode):
            def __init__(self, mapping_dict):
                assert isinstance(mapping_dict, dict)
                self.mapping_dict = mapping_dict

            def replace_or_visit(self, o):
                return self.replace(o)

            def replace(
                self,
                o,
            ):
                if o in self.mapping_dict:
                    return self.mapping_dict[o]
                else:
                    return o

        from neurounits.visitors.common.ast_cloning import ASTClone
        from collections import defaultdict

        import neurounits.ast as ast

        # CONCEPTUALLY THIS IS VERY SIMPLE< BUT THE CODE
        # IS A HORRIBLE HACK!

        no_remap = (ast.Interface, ast.InterfaceWireContinuous,
                    ast.InterfaceWireEvent, ast.BuiltInFunction,
                    ast.FunctionDefParameter)
        # First, lets clone each and every node:
        old_nodes = list(set(list(EqnsetVisitorNodeCollector(self).all())))
        old_to_new_dict = {}
        for old_node in old_nodes:

            if not isinstance(old_node, no_remap):
                new_node = ASTClone().visit(old_node)
            else:
                new_node = old_node

            #print old_node, '-->', new_node
            assert type(old_node) == type(new_node)
            old_to_new_dict[old_node] = new_node

        # Clone self:
        old_to_new_dict[self] = ASTClone().visit(self)

        # Check that all the nodes hav been replaced:
        overlap = (set(old_to_new_dict.keys()) & set(old_to_new_dict.values()))
        for o in overlap:
            assert isinstance(o, no_remap)

        # Now, lets visit each of the new nodes, and replace (old->new) on it:
        #print
        #print 'Replacing Nodes:'

        # Build the mapping dictionary:
        mapping_dict = {}
        for old_repl, new_repl in old_to_new_dict.items():
            #if new_repl == new_node:
            #    continue
            #print ' -- Replacing:',old_repl, new_repl

            if isinstance(old_repl, no_remap):
                continue

            mapping_dict[old_repl] = new_repl

        # Remap all the nodes:
        for new_node in old_to_new_dict.values():
            #print 'Replacing nodes on:', new_node

            node_mapping_dict = mapping_dict.copy()
            if new_node in node_mapping_dict:
                del node_mapping_dict[new_node]

            replacer = ReplaceNodeHack(mapping_dict=node_mapping_dict)
            new_node.accept_visitor(replacer)

        # ok, so the clone should now be all clear:
        new_obj = old_to_new_dict[self]

        new_nodes = list(EqnsetVisitorNodeCollector(new_obj).all())

        # Who points to what!?
        connections_map_obj_to_conns = {}
        connections_map_conns_to_objs = defaultdict(list)
        for node in new_nodes:

            conns = list(node.accept_visitor(ASTAllConnections()))
            connections_map_obj_to_conns[node] = conns
            for c in conns:
                connections_map_conns_to_objs[c].append(node)

        shared_nodes = set(new_nodes) & set(old_nodes)
        shared_nodes_invalid = [
            sn for sn in shared_nodes if not isinstance(sn, no_remap)
        ]

        if len(shared_nodes_invalid) != 0:
            print 'Shared Nodes:'
            print shared_nodes_invalid
            for s in shared_nodes_invalid:
                print ' ', s, s in old_to_new_dict
                print '  Referenced by:'
                for c in connections_map_conns_to_objs[s]:
                    print '    *', c
                print
            assert len(shared_nodes_invalid) == 0

        return new_obj
Beispiel #11
0
 def all_ast_nodes(self):
     """Return an iterator over every AST node collected from this object.

     The nodes come from an EqnsetVisitorNodeCollector walk of ``self`` and
     are yielded one per-type group after another.
     """
     from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
     gatherer = EqnsetVisitorNodeCollector()
     gatherer.visit(self)
     node_groups = gatherer.nodes.values()
     return itertools.chain.from_iterable(node_groups)
Beispiel #12
0
 def output_event_port_lut(self):
     """Return a LookUpDict of the ``ast.OutEventPort`` nodes collected from this object."""
     import neurounits.ast as ast
     from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
     out_ports = EqnsetVisitorNodeCollector(obj=self).nodes[ast.OutEventPort]
     return LookUpDict(out_ports)
Beispiel #13
0
 def _analog_reduce_ports_lut(self):
     """Return a LookUpDict of the AnalogReducePort nodes collected from this object."""
     from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
     reduce_ports = EqnsetVisitorNodeCollector(obj=self).nodes[AnalogReducePort]
     return LookUpDict(reduce_ports)
Beispiel #14
0
 def _supplied_lut(self):
     """Return a LookUpDict of the SuppliedValue nodes collected from this object."""
     from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
     supplied_nodes = EqnsetVisitorNodeCollector(obj=self).nodes[SuppliedValue]
     return LookUpDict(supplied_nodes)
Beispiel #15
0
 def _parameters_lut(self):
     """Return a LookUpDict of the Parameter nodes collected from this object."""
     from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
     parameter_nodes = EqnsetVisitorNodeCollector(obj=self).nodes[Parameter]
     return LookUpDict(parameter_nodes)
Beispiel #16
0
 def all_ast_nodes(self):
     """Return an iterator chaining together every per-type node group of this object.

     The groups come from an EqnsetVisitorNodeCollector walk of ``self``.
     """
     from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
     node_collector = EqnsetVisitorNodeCollector()
     node_collector.visit(self)
     per_type_groups = list(node_collector.nodes.values())
     return itertools.chain(*per_type_groups)