Esempio n. 1
0
    def initialize(self,
                   grid_size=(5, 5),
                   report_interval=5.0,
                   grid_orientation="vertical",
                   node_layout="rect",
                   show_plots=False,
                   cts_type="oriented_hex",
                   run_duration=1.0,
                   output_interval=1.0e99,
                   plot_every_transition=False,
                   **kwds):
        """Set up timing, grid, node states, transitions and the CTS object.

        Parameters
        ----------
        grid_size : sequence
            First two items are the row and column counts handed to
            ``create_grid_and_node_state_field``.
        report_interval : float
            Wall-clock seconds between progress reports.
        grid_orientation : str
            Forwarded to ``create_grid_and_node_state_field``.
        node_layout : str
            Forwarded to ``create_grid_and_node_state_field``.
        show_plots : bool
            When truthy, ``initialize_plotting(**kwds)`` is called.
        cts_type : str
            'raster', 'oriented_raster' or 'hex' pick the matching CTS class;
            any other value picks OrientedHexCTS.
        run_duration : float
            Stored as ``self.run_duration``.
        output_interval : float
            Stored as ``self.output_interval``.
        plot_every_transition : bool
            Accepted but not used by this method.
        **kwds
            Forwarded to ``initialize_plotting`` when plotting is enabled.
        """
        # Wall-clock bookkeeping so progress can be reported periodically.
        now = time.time()
        self.current_real_time = now
        self.next_report = now + report_interval
        self.report_interval = report_interval

        # Output cadence and total simulated duration.
        self.output_interval = output_interval
        self.run_duration = run_duration

        num_rows = grid_size[0]
        num_cols = grid_size[1]
        self.create_grid_and_node_state_field(num_rows, num_cols,
                                              grid_orientation, node_layout,
                                              cts_type)

        # Model ingredients supplied by (overridable) template methods.
        ns_dict = self.node_state_dictionary()
        nsg = self.initialize_node_state_grid()
        xn_list = self.transition_list()

        # Choose the CTS implementation; unrecognised names fall back to the
        # oriented-hex variant.
        if cts_type == "raster":
            from landlab.ca.raster_cts import RasterCTS as cts_class
        elif cts_type == "oriented_raster":
            from landlab.ca.oriented_raster_cts import (
                OrientedRasterCTS as cts_class)
        elif cts_type == "hex":
            from landlab.ca.hex_cts import HexCTS as cts_class
        else:
            from landlab.ca.oriented_hex_cts import (
                OrientedHexCTS as cts_class)
        self.ca = cts_class(self.grid, ns_dict, xn_list, nsg)

        # Optional on-screen graphics.
        self._show_plots = show_plots
        if show_plots:
            self.initialize_plotting(**kwds)
Esempio n. 2
0
    def initialize(self, grid_size=(5, 5), report_interval=5.0,
                   grid_orientation='vertical', grid_shape='rect',
                   show_plots=False, cts_type='oriented_hex',
                   run_duration=1.0, output_interval=1.0e99,
                   plot_every_transition=False, **kwds):
        """Set up timing, grid, node states, transitions and the CTS object.

        Parameters
        ----------
        grid_size : sequence
            First two items are the row and column counts handed to
            ``create_grid_and_node_state_field``.
        report_interval : float
            Wall-clock seconds between progress reports.
        grid_orientation : str
            Forwarded to ``create_grid_and_node_state_field``.
        grid_shape : str
            Forwarded to ``create_grid_and_node_state_field``.
        show_plots : bool
            When truthy, ``initialize_plotting(**kwds)`` is called.
        cts_type : str
            'raster', 'oriented_raster' or 'hex' pick the matching CTS class;
            any other value picks OrientedHexCTS.
        run_duration : float
            Stored as ``self.run_duration``.
        output_interval : float
            Stored as ``self.output_interval``.
        plot_every_transition : bool
            Accepted but not used by this method.
        **kwds
            Forwarded to ``initialize_plotting`` when plotting is enabled.
        """
        # Remember the clock time, and calculate when we next want to report
        # progress.
        self.current_real_time = time.time()
        self.next_report = self.current_real_time + report_interval
        self.report_interval = report_interval

        # Interval for output
        self.output_interval = output_interval

        # Duration for run
        self.run_duration = run_duration

        # Create a grid
        self.create_grid_and_node_state_field(grid_size[0], grid_size[1],
                                              grid_orientation, grid_shape,
                                              cts_type)

        # Create the node-state dictionary
        ns_dict = self.node_state_dictionary()

        # Initialize values of the node-state grid
        nsg = self.initialize_node_state_grid()

        # Create the transition list
        xn_list = self.transition_list()

        # Create the CA object
        if cts_type == 'raster':
            from landlab.ca.raster_cts import RasterCTS
            self.ca = RasterCTS(self.grid, ns_dict, xn_list, nsg)
        elif cts_type == 'oriented_raster':
            from landlab.ca.oriented_raster_cts import OrientedRasterCTS
            self.ca = OrientedRasterCTS(self.grid, ns_dict, xn_list, nsg)
        elif cts_type == 'hex':
            from landlab.ca.hex_cts import HexCTS
            self.ca = HexCTS(self.grid, ns_dict, xn_list, nsg)
        else:
            from landlab.ca.oriented_hex_cts import OrientedHexCTS
            self.ca = OrientedHexCTS(self.grid, ns_dict, xn_list, nsg)

        # Initialize graphics. Test the flag's truthiness instead of comparing
        # it with ``== True`` (PEP 8 programming recommendation).
        self._show_plots = show_plots
        if show_plots:
            self.initialize_plotting(**kwds)
Esempio n. 3
0
def test_oriented_hex_cts():
    """Check that an OrientedHexCTS can be built on a small vertical hex grid.

    Verifies the number of link states and the orientation code assigned to
    each link.
    """
    grid = HexModelGrid(3, 2, 1.0, orientation="vertical", reorient_links=True)
    state_names = {0: "zero", 1: "one"}
    transitions = [Transition((0, 1, 0), (1, 1, 0), 1.0, "transitioning")]
    node_states = grid.add_zeros("node", "node_state_grid")
    ohcts = OrientedHexCTS(grid, state_names, transitions, node_states)

    # Presumably 2 node states squared x 3 orientation codes (0, 1, 2).
    assert ohcts.num_link_states == 12

    expected_orientation = [2, 0, 1, 0, 2, 0, 1, 0, 2, 0, 1]
    assert_array_equal(ohcts.link_orientation, expected_orientation)
Esempio n. 4
0
def test_oriented_hex_cts():
    """Test instantiation of an OrientedHexCTS() object.

    Builds a 3x2 vertical hex grid with a single transition and checks the
    number of link states and the orientation code of every link.
    """
    mg = HexModelGrid(3, 2, 1.0, orientation='vertical', reorient_links=True)
    nsd = {0: 'zero', 1: 'one'}
    xnlist = []
    xnlist.append(Transition((0, 1, 0), (1, 1, 0), 1.0, 'transitioning'))
    nsg = mg.add_zeros('node', 'node_state_grid')
    ohcts = OrientedHexCTS(mg, nsd, xnlist, nsg)

    # Plain assert (as in the sibling version of this test) instead of a
    # helper whose origin is not visible here.
    assert ohcts.num_link_states == 12
    assert_array_equal(ohcts.link_orientation,
                       [2, 0, 1, 0, 2, 0, 1, 0, 2, 0, 1])
Esempio n. 5
0
def test_shift_link_and_transition_data_upward():
    """Test LatticeUplifter.shift_link_and_transition_data_upward.

    Builds an OrientedHexCTS on a (4, 3) vertical, rect-layout HexModelGrid,
    checks the initial link states, transition ids, next-update times and
    priority-queue entries, then calls shift_link_and_transition_data_upward
    and checks the rescheduled update times and queue entries.

    NOTE(review): the expected update times are hard-coded, so this test
    presumably relies on a random seed fixed elsewhere (fixture or module
    level) -- confirm.
    """

    mg = HexModelGrid((4, 3),
                      spacing=1.0,
                      orientation="vertical",
                      node_layout="rect")
    nsd = {0: "yes", 1: "no"}
    xnlist = []
    # One transition per value of the third state-tuple element (0, 1, 2),
    # which presumably selects the link orientation.
    xnlist.append(Transition((0, 0, 0), (1, 1, 0), 1.0, "frogging"))
    xnlist.append(Transition((0, 0, 1), (1, 1, 1), 1.0, "frogging"))
    xnlist.append(Transition((0, 0, 2), (1, 1, 2), 1.0, "frogging"))
    nsg = mg.add_zeros("node_state_grid", at="node")
    ohcts = OrientedHexCTS(mg, nsd, xnlist, nsg)

    assert_array_equal(ohcts.link_state[mg.active_links],
                       [0, 4, 8, 8, 4, 0, 4, 8, 8, 4, 0])

    assert_array_equal(ohcts.next_trn_id[mg.active_links],
                       [0, 1, 2, 2, 1, 0, 1, 2, 2, 1, 0])

    assert_array_equal(
        np.round(ohcts.next_update[mg.active_links], 2),
        [0.8, 1.26, 0.92, 0.79, 0.55, 1.04, 0.58, 2.22, 3.31, 0.48, 1.57],
    )

    # Queue entries are inspected as [0] = scheduled time, [2] = link id.
    pq = ohcts.priority_queue

    assert_equal(pq._queue[0][2], 19)  # link for first event = 19, not shifted
    assert_equal(round(pq._queue[0][0], 2), 0.48)  # trn scheduled for t = 0.48
    assert_equal(pq._queue[2][2], 14)  # this event scheduled for link 14...
    assert_equal(round(pq._queue[2][0], 2), 0.58)  # ...trn sched for t = 0.58

    lu = LatticeUplifter(grid=mg)
    lu.shift_link_and_transition_data_upward(ohcts, 0.0)

    # note new events lowest 5 links
    assert_array_equal(
        np.round(ohcts.next_update[mg.active_links], 2),
        [0.75, 0.84, 2.6, 0.07, 0.09, 0.8, 0.02, 1.79, 1.51, 2.04, 3.85],
    )
    assert_equal(pq._queue[0][2], 14)  # new soonest event
    assert_equal(pq._queue[9][2], 13)  # was previously 7, now shifted up...
    assert_equal(round(pq._queue[9][0], 2), 0.8)  # ...still sched for t = 0.80
Esempio n. 6
0
class CTSModel(object):
    """
    Implement a generic CellLab-CTS model.

    This is the base class from which models should inherit. The methods
    ``node_state_dictionary``, ``transition_list`` and
    ``initialize_node_state_grid`` provide simple two-state templates that
    subclasses are meant to override.
    """
    def __init__(self,
                 grid_size=(5, 5),
                 report_interval=5.0,
                 grid_orientation='vertical',
                 grid_shape='rect',
                 show_plots=False,
                 cts_type='oriented_hex',
                 run_duration=1.0,
                 output_interval=1.0e99,
                 plot_every_transition=False,
                 initial_state_grid=None,
                 prop_data=None,
                 prop_reset_value=None,
                 **kwds):
        """Create the model; every argument is forwarded to ``initialize``."""
        self.initialize(grid_size, report_interval, grid_orientation,
                        grid_shape, show_plots, cts_type, run_duration,
                        output_interval, plot_every_transition,
                        initial_state_grid, prop_data, prop_reset_value,
                        **kwds)

    def initialize(self,
                   grid_size=(5, 5),
                   report_interval=5.0,
                   grid_orientation='vertical',
                   grid_shape='rect',
                   show_plots=False,
                   cts_type='oriented_hex',
                   run_duration=1.0,
                   output_interval=1.0e99,
                   plot_every_transition=False,
                   initial_state_grid=None,
                   prop_data=None,
                   prop_reset_value=None,
                   **kwds):
        """Initialize CTSModel.

        Parameters
        ----------
        grid_size : sequence
            First two items are the number of rows and columns.
        report_interval : float
            Wall-clock seconds between progress reports.
        grid_orientation : str
            Hex-grid orientation (not used for raster grids).
        grid_shape : str
            Hex-grid shape (not used for raster grids).
        show_plots : bool
            When truthy, ``initialize_plotting(**kwds)`` is called.
        cts_type : str
            'raster', 'oriented_raster' or 'hex' pick the matching CTS class;
            any other value picks OrientedHexCTS.
        run_duration : float
            Stored as ``self.run_duration``.
        output_interval : float
            Stored as ``self.output_interval``.
        plot_every_transition : bool
            Accepted but not used by this method.
        initial_state_grid : array of int, optional
            Initial node states. When None, ``initialize_node_state_grid``
            generates them instead.
        prop_data : array or str, optional
            Property data handed to the CTS object. A string is treated as
            the name of a new all-zero node field.
        prop_reset_value : optional
            Passed through to the CTS constructor.
        **kwds
            Forwarded to ``initialize_plotting`` when plotting is enabled.

        Raises
        ------
        Exception
            Whatever copying ``initial_state_grid`` into the 'node_state'
            field raises; it is re-raised after printing a hint.
        """
        # Remember the clock time, and calculate when we next want to report
        # progress.
        self.current_real_time = time.time()
        self.next_report = self.current_real_time + report_interval
        self.report_interval = report_interval

        # Interval for output
        self.output_interval = output_interval

        # Duration for run
        self.run_duration = run_duration

        # Create a grid
        self.create_grid_and_node_state_field(grid_size[0], grid_size[1],
                                              grid_orientation, grid_shape,
                                              cts_type)

        # If prop_data is a string, we assume it is a field name
        if isinstance(prop_data, string_types):
            prop_data = self.grid.add_zeros('node', prop_data)

        # Create the node-state dictionary
        ns_dict = self.node_state_dictionary()

        # Initialize values of the node-state grid
        if initial_state_grid is None:
            nsg = self.initialize_node_state_grid()
        else:
            # NOTE(review): here the CTS object receives initial_state_grid
            # itself, not the grid's 'node_state' field (which only gets a
            # copy of the values); confirm that this is intended.
            try:
                nsg = initial_state_grid
                self.grid.at_node['node_state'][:] = nsg
            except Exception:
                # TODO: use new Messaging capability
                print('If initial_state_grid given, must be array of int')
                raise

        # Create the transition list
        xn_list = self.transition_list()

        # Create the CA object
        if cts_type == 'raster':
            from landlab.ca.raster_cts import RasterCTS
            self.ca = RasterCTS(self.grid, ns_dict, xn_list, nsg, prop_data,
                                prop_reset_value)
        elif cts_type == 'oriented_raster':
            from landlab.ca.oriented_raster_cts import OrientedRasterCTS
            self.ca = OrientedRasterCTS(self.grid, ns_dict, xn_list, nsg,
                                        prop_data, prop_reset_value)
        elif cts_type == 'hex':
            from landlab.ca.hex_cts import HexCTS
            self.ca = HexCTS(self.grid, ns_dict, xn_list, nsg, prop_data,
                             prop_reset_value)
        else:
            from landlab.ca.oriented_hex_cts import OrientedHexCTS
            self.ca = OrientedHexCTS(self.grid, ns_dict, xn_list, nsg,
                                     prop_data, prop_reset_value)

        # Initialize graphics. Test truthiness rather than ``== True``
        # (PEP 8 programming recommendation).
        self._show_plots = show_plots
        if show_plots:
            self.initialize_plotting(**kwds)

    def create_grid_and_node_state_field(self, num_rows, num_cols,
                                         grid_orientation, grid_shape,
                                         cts_type):
        """Create the grid and the integer 'node_state' node field.

        A RasterModelGrid is built for the raster CTS types, otherwise a
        HexModelGrid with the given orientation and shape. Nodes on the
        right and top edges are then set to CLOSED_BOUNDARY.
        """

        if cts_type == 'raster' or cts_type == 'oriented_raster':
            from landlab import RasterModelGrid
            self.grid = RasterModelGrid(shape=(num_rows, num_cols),
                                        spacing=1.0)
        else:
            from landlab import HexModelGrid
            self.grid = HexModelGrid(num_rows,
                                     num_cols,
                                     1.0,
                                     orientation=grid_orientation,
                                     shape=grid_shape)

        self.grid.add_zeros('node', 'node_state', dtype=int)
        for edge in (self.grid.nodes_at_right_edge,
                     self.grid.nodes_at_top_edge):
            self.grid.status_at_node[edge] = CLOSED_BOUNDARY

    def node_state_dictionary(self):
        """Create and return a dictionary of all possible node (cell) states.

        This method creates a default set of states (just two); it is a
        template meant to be overridden.
        """
        ns_dict = {0: 'on', 1: 'off'}
        return ns_dict

    def transition_list(self):
        """Create and return a list of transition objects.

        This method creates a default set of transitions (just two); it is a
        template meant to be overridden.
        """
        xn_list = []
        xn_list.append(Transition((0, 1, 0), (1, 0, 0), 1.0))
        xn_list.append(Transition((1, 0, 0), (0, 1, 0), 1.0))
        return xn_list

    def write_output(self, grid, outfilename, iteration):
        """Write ``grid`` to '<outfilename><iteration zero-padded to 4>.nc'."""
        filename = outfilename + str(iteration).zfill(4) + '.nc'
        save_grid(grid, filename)

    def initialize_node_state_grid(self):
        """Initialize values in the node-state grid.

        This method should be overridden. The default is random "on" and "off".
        Returns the grid's 'node_state' field.
        """
        num_states = 2
        # NOTE(review): one-argument randint implies ``random`` is
        # numpy.random here (stdlib random.randint needs two arguments).
        for i in range(self.grid.number_of_nodes):
            self.grid.at_node['node_state'][i] = random.randint(num_states)
        return self.grid.at_node['node_state']

    def initialize_plotting(self, **kwds):
        """Create and configure CAPlotter object."""
        self.ca_plotter = CAPlotter(self.ca, **kwds)
        self.ca_plotter.update_plot()
        axis('off')

    def run_for(self, dt):
        """Advance the CA by ``dt`` beyond its current time."""
        self.ca.run(self.ca.current_time + dt, self.ca.node_state)
Esempio n. 7
0
def main():
    """Run the silo demonstration on an oriented hex CTS.

    Phase 1 lets randomly seeded moving particles evolve for
    ``run_duration`` with the silo base closed. Phase 2 then turns wall nodes
    within ``silo_opening_half_width`` of the horizontal midpoint (with
    0 < y < 38) into empty nodes. It runs for ``5 * run_duration`` more and
    saves a frame ``silo<k>.png`` at the start and after every plot interval.
    """

    # INITIALIZE

    # User-defined parameters
    nr = 41
    nc = 61
    g = 1.0  # passed to setup_transition_list
    f = 0.7  # passed to setup_transition_list
    silo_y0 = 30.0
    silo_opening_half_width = 6
    plot_interval = 10.0
    run_duration = 240.0  # duration of the closed-silo phase
    open_run_duration = 5 * run_duration  # duration of the open-silo phase
    report_interval = 300.0  # report interval, in real-time seconds
    p_init = 0.4  # probability that a cell is occupied at start
    plot_every_transition = False

    # Remember the clock time, and calculate when we next want to report
    # progress.
    current_real_time = time.time()
    next_report = current_real_time + report_interval

    # Create a grid
    hmg = HexModelGrid(
        nr, nc, 1.0, orientation="vertical", shape="rect", reorient_links=True
    )

    # Set up the states and pair transitions.
    # Transition data here represent particles moving on a lattice: one state
    # per direction (for 6 directions), plus an empty state, a stationary
    # state, and a wall state.
    ns_dict = {
        0: "empty",
        1: "moving up",
        2: "moving right and up",
        3: "moving right and down",
        4: "moving down",
        5: "moving left and down",
        6: "moving left and up",
        7: "rest",
        8: "wall",
    }
    xn_list = setup_transition_list(g, f)

    # Create data and initialize values.
    node_state_grid = hmg.add_zeros("node", "node_state_grid")

    # Make the grid boundary all wall particles
    node_state_grid[hmg.boundary_nodes] = 8

    # Place wall particles to form the base of the silo, initially closed.
    # Two 30-degree ramps meet in a V; 0.866 ~ sqrt(3)/2, so nc * 0.866
    # approximates the grid width.
    tan30deg = numpy.tan(numpy.pi / 6.)
    rampy1 = silo_y0 - hmg.node_x * tan30deg
    rampy2 = silo_y0 - ((nc * 0.866 - 1.) - hmg.node_x) * tan30deg
    rampy = numpy.maximum(rampy1, rampy2)
    (ramp_nodes,) = numpy.where(
        numpy.logical_and(hmg.node_y > rampy - 0.5, hmg.node_y < rampy + 0.5)
    )
    node_state_grid[ramp_nodes] = 8

    # Seed the grid interior with randomly oriented particles
    for i in hmg.core_nodes:
        if hmg.node_y[i] > rampy[i] and random.random() < p_init:
            node_state_grid[i] = random.randint(1, 7)

    # Create the CA model
    ca = OrientedHexCTS(hmg, ns_dict, xn_list, node_state_grid)

    # Import the submodule explicitly: ``import matplotlib`` alone does not
    # guarantee that ``matplotlib.colors`` is loaded.
    import matplotlib.colors

    rock = (0.0, 0.0, 0.0)
    sed = (0.6, 0.6, 0.6)
    sky = (1.0, 1.0, 1.0)
    mob = (0.3, 0.3, 0.3)
    # One colour per node state 0-8: empty, six moving states, rest, wall.
    clist = [sky, mob, mob, mob, mob, mob, mob, sed, rock]
    my_cmap = matplotlib.colors.ListedColormap(clist)

    # Create a CAPlotter object for handling screen display
    ca_plotter = CAPlotter(ca, cmap=my_cmap)
    k = 0

    # Plot the initial grid
    ca_plotter.update_plot()

    # RUN

    # Run with closed silo
    current_time = 0.0
    while current_time < run_duration:

        # Once in a while, print out simulation and real time to let the user
        # know that the sim is running ok
        current_real_time = time.time()
        if current_real_time >= next_report:
            print(
                "Current sim time",
                current_time,
                "(",
                100 * current_time / run_duration,
                "%)",
            )
            next_report = current_real_time + report_interval

        # Run the model forward in time until the next output step
        ca.run(
            current_time + plot_interval,
            ca.node_state,
            plot_each_transition=plot_every_transition,
            plotter=ca_plotter,
        )
        current_time += plot_interval

        # Plot the current grid
        ca_plotter.update_plot()

    # Open the silo
    xmid = nc * 0.866 * 0.5
    for i in range(hmg.number_of_nodes):
        if (
            node_state_grid[i] == 8
            and hmg.node_x[i] > (xmid - silo_opening_half_width)
            and hmg.node_x[i] < (xmid + silo_opening_half_width)
            and hmg.node_y[i] > 0
            and hmg.node_y[i] < 38.0
        ):
            node_state_grid[i] = 0

    # Create the CA model
    ca = OrientedHexCTS(hmg, ns_dict, xn_list, node_state_grid)

    # Create a CAPlotter object for handling screen display; reuse the same
    # colormap so the saved frames match the closed-silo display.
    ca_plotter = CAPlotter(ca, cmap=my_cmap)

    # Plot the initial grid
    ca_plotter.update_plot()

    # Re-run with open silo
    savefig("silo" + str(k) + ".png")
    k += 1
    current_time = 0.0
    while current_time < open_run_duration:

        # Once in a while, print out simulation and real time to let the user
        # know that the sim is running ok. Progress is relative to this
        # phase's own length, so it tops out at 100 %.
        current_real_time = time.time()
        if current_real_time >= next_report:
            print(
                "Current sim time",
                current_time,
                "(",
                100 * current_time / open_run_duration,
                "%)",
            )
            next_report = current_real_time + report_interval

        # Run the model forward in time until the next output step
        ca.run(
            current_time + plot_interval,
            ca.node_state,
            plot_each_transition=plot_every_transition,
            plotter=ca_plotter,
        )
        current_time += plot_interval

        # Plot the current grid
        ca_plotter.update_plot()
        savefig("silo" + str(k) + ".png")
        k += 1

    # FINALIZE

    # Plot
    ca_plotter.finalize()
Esempio n. 8
0
def main():
    """Run the silo demonstration on an oriented hex CTS.

    Phase 1 lets randomly seeded moving particles evolve for
    ``run_duration`` with the silo base closed. Phase 2 then turns wall nodes
    within ``silo_opening_half_width`` of the horizontal midpoint (with
    0 < y < 38) into empty nodes. It runs for ``5 * run_duration`` more and
    saves a frame ``silo<k>.png`` at the start and after every plot interval.
    """

    # INITIALIZE

    # User-defined parameters
    nr = 41
    nc = 61
    g = 1.0  # passed to setup_transition_list
    f = 0.7  # passed to setup_transition_list
    silo_y0 = 30.0
    silo_opening_half_width = 6
    plot_interval = 10.0
    run_duration = 240.0  # duration of the closed-silo phase
    open_run_duration = 5 * run_duration  # duration of the open-silo phase
    report_interval = 300.0  # report interval, in real-time seconds
    p_init = 0.4  # probability that a cell is occupied at start
    plot_every_transition = False

    # Remember the clock time, and calculate when we next want to report
    # progress.
    current_real_time = time.time()
    next_report = current_real_time + report_interval

    # Create a grid
    hmg = HexModelGrid(nr,
                       nc,
                       1.0,
                       orientation='vertical',
                       shape='rect',
                       reorient_links=True)

    # Set up the states and pair transitions.
    # Transition data here represent particles moving on a lattice: one state
    # per direction (for 6 directions), plus an empty state, a stationary
    # state, and a wall state.
    ns_dict = {
        0: 'empty',
        1: 'moving up',
        2: 'moving right and up',
        3: 'moving right and down',
        4: 'moving down',
        5: 'moving left and down',
        6: 'moving left and up',
        7: 'rest',
        8: 'wall',
    }
    xn_list = setup_transition_list(g, f)

    # Create data and initialize values.
    node_state_grid = hmg.add_zeros('node', 'node_state_grid')

    # Make the grid boundary all wall particles
    node_state_grid[hmg.boundary_nodes] = 8

    # Place wall particles to form the base of the silo, initially closed.
    # Two 30-degree ramps meet in a V; 0.866 ~ sqrt(3)/2, so nc * 0.866
    # approximates the grid width.
    tan30deg = numpy.tan(numpy.pi / 6.)
    rampy1 = silo_y0 - hmg.node_x * tan30deg
    rampy2 = silo_y0 - ((nc * 0.866 - 1.) - hmg.node_x) * tan30deg
    rampy = numpy.maximum(rampy1, rampy2)
    (ramp_nodes, ) = numpy.where(numpy.logical_and(hmg.node_y > rampy - 0.5,
                                                   hmg.node_y < rampy + 0.5))
    node_state_grid[ramp_nodes] = 8

    # Seed the grid interior with randomly oriented particles
    for i in hmg.core_nodes:
        if hmg.node_y[i] > rampy[i] and random.random() < p_init:
            node_state_grid[i] = random.randint(1, 7)

    # Create the CA model
    ca = OrientedHexCTS(hmg, ns_dict, xn_list, node_state_grid)

    # Import the submodule explicitly: ``import matplotlib`` alone does not
    # guarantee that ``matplotlib.colors`` is loaded.
    import matplotlib.colors
    rock = (0.0, 0.0, 0.0)
    sed = (0.6, 0.6, 0.6)
    sky = (1.0, 1.0, 1.0)
    mob = (0.3, 0.3, 0.3)
    # One colour per node state 0-8: empty, six moving states, rest, wall.
    clist = [sky, mob, mob, mob, mob, mob, mob, sed, rock]
    my_cmap = matplotlib.colors.ListedColormap(clist)

    # Create a CAPlotter object for handling screen display
    ca_plotter = CAPlotter(ca, cmap=my_cmap)
    k = 0

    # Plot the initial grid
    ca_plotter.update_plot()

    # RUN

    # Run with closed silo
    current_time = 0.0
    while current_time < run_duration:

        # Once in a while, print out simulation and real time to let the user
        # know that the sim is running ok. (print() is a function: the
        # Python 2 print statement is a SyntaxError under Python 3.)
        current_real_time = time.time()
        if current_real_time >= next_report:
            print('Current sim time', current_time, '(',
                  100 * current_time / run_duration, '%)')
            next_report = current_real_time + report_interval

        # Run the model forward in time until the next output step
        ca.run(current_time + plot_interval,
               ca.node_state,
               plot_each_transition=plot_every_transition,
               plotter=ca_plotter)
        current_time += plot_interval

        # Plot the current grid
        ca_plotter.update_plot()

    # Open the silo
    xmid = nc * 0.866 * 0.5
    for i in range(hmg.number_of_nodes):
        if (node_state_grid[i] == 8
                and hmg.node_x[i] > (xmid - silo_opening_half_width)
                and hmg.node_x[i] < (xmid + silo_opening_half_width)
                and hmg.node_y[i] > 0 and hmg.node_y[i] < 38.0):
            node_state_grid[i] = 0

    # Create the CA model
    ca = OrientedHexCTS(hmg, ns_dict, xn_list, node_state_grid)

    # Create a CAPlotter object for handling screen display; reuse the same
    # colormap so the saved frames match the closed-silo display.
    ca_plotter = CAPlotter(ca, cmap=my_cmap)

    # Plot the initial grid
    ca_plotter.update_plot()

    # Re-run with open silo
    savefig('silo' + str(k) + '.png')
    k += 1
    current_time = 0.0
    while current_time < open_run_duration:

        # Once in a while, print out simulation and real time to let the user
        # know that the sim is running ok. Progress is relative to this
        # phase's own length, so it tops out at 100 %.
        current_real_time = time.time()
        if current_real_time >= next_report:
            print('Current sim time', current_time, '(',
                  100 * current_time / open_run_duration, '%)')
            next_report = current_real_time + report_interval

        # Run the model forward in time until the next output step
        ca.run(current_time + plot_interval,
               ca.node_state,
               plot_each_transition=plot_every_transition,
               plotter=ca_plotter)
        current_time += plot_interval

        # Plot the current grid
        ca_plotter.update_plot()
        savefig('silo' + str(k) + '.png')
        k += 1

    # FINALIZE

    # Plot
    ca_plotter.finalize()
Esempio n. 9
0
def main():
    """Run a two-state (fluid/grain) oriented hex CTS demonstration.

    Starts with grains in the middle columns of a hex grid whose open
    boundary nodes are closed, and replots the node states every
    ``plot_interval`` until ``run_duration``.
    """

    # INITIALIZE

    # User-defined parameters
    nr = 21
    nc = 21
    plot_interval = 0.5
    run_duration = 25.0
    report_interval = 5.0  # report interval, in real-time seconds

    # Remember the clock time, and calculate when we next want to report
    # progress.
    current_real_time = time.time()
    next_report = current_real_time + report_interval

    # Create a grid
    hmg = HexModelGrid(nr, nc, 1.0, orientation='vertical',
                       reorient_links=True)

    # Close the grid boundaries
    hmg.set_closed_nodes(hmg.open_boundary_nodes)

    # Set up the states and pair transitions.
    # Each node is either fluid (state 0) or grain (state 1).
    ns_dict = {0: 'fluid', 1: 'grain'}
    xn_list = setup_transition_list()

    # Create data and initialize values. We start with the 3 middle columns full
    # of grains, and the others empty.
    node_state_grid = hmg.add_zeros('node', 'node_state_grid')
    middle = 0.25 * (nc - 1) * sqrt(3)
    is_middle_cols = logical_and(hmg.node_x < middle + 1.,
                                 hmg.node_x > middle - 1.)
    node_state_grid[where(is_middle_cols)[0]] = 1

    # Create the CA model
    ca = OrientedHexCTS(hmg, ns_dict, xn_list, node_state_grid)

    # Create a CAPlotter object for handling screen display
    ca_plotter = CAPlotter(ca)

    # Plot the initial grid
    ca_plotter.update_plot()

    # RUN
    current_time = 0.0
    while current_time < run_duration:

        # Once in a while, print out simulation and real time to let the user
        # know that the sim is running ok
        current_real_time = time.time()
        if current_real_time >= next_report:
            print('Current sim time', current_time, '(',
                  100 * current_time / run_duration, '%)')
            next_report = current_real_time + report_interval

        # Run the model forward in time until the next output step
        ca.run(current_time + plot_interval, ca.node_state,
               plot_each_transition=False)
        current_time += plot_interval

        # Plot the current grid
        ca_plotter.update_plot()

    # FINALIZE

    # Plot
    ca_plotter.finalize()
Esempio n. 10
0
    def initialize(self, grid_size=(5, 5), report_interval=5.0,
                   grid_orientation='vertical', grid_shape='rect',
                   show_plots=False, cts_type='oriented_hex',
                   run_duration=1.0, output_interval=1.0e99,
                   plot_every_transition=False, initial_state_grid=None,
                   prop_data=None, prop_reset_value=None,
                   closed_boundaries=(False, False, False, False), **kwds):
        """Initialize CTSModel.

        Parameters
        ----------
        grid_size : sequence
            First two items are the row and column counts handed to
            ``create_grid_and_node_state_field``.
        report_interval : float
            Wall-clock seconds between progress reports.
        grid_orientation, grid_shape : str
            Forwarded to ``create_grid_and_node_state_field``.
        show_plots : bool
            When truthy, ``initialize_plotting(**kwds)`` is called.
        cts_type : str
            'raster', 'oriented_raster' or 'hex' pick the matching CTS class;
            any other value picks OrientedHexCTS.
        run_duration : float
            Stored as ``self.run_duration``.
        output_interval : float
            Stored as ``self.output_interval``.
        plot_every_transition : bool
            Accepted but not used by this method.
        initial_state_grid : array of int, optional
            Initial node states. When None, ``initialize_node_state_grid``
            generates them instead.
        prop_data : array or str, optional
            Property data handed to the CTS object. A string is treated as
            the name of a new all-zero node field.
        prop_reset_value : optional
            Passed through to the CTS constructor.
        closed_boundaries : tuple of 4 bool
            Forwarded to ``create_grid_and_node_state_field``.
        **kwds
            Forwarded to ``initialize_plotting`` when plotting is enabled.

        Raises
        ------
        Exception
            Whatever copying ``initial_state_grid`` into the 'node_state'
            field raises; it is re-raised after printing a hint.
        """
        # Remember the clock time, and calculate when we next want to report
        # progress.
        self.current_real_time = time.time()
        self.next_report = self.current_real_time + report_interval
        self.report_interval = report_interval

        # Interval for output
        self.output_interval = output_interval

        # Duration for run
        self.run_duration = run_duration

        # Create a grid
        self.create_grid_and_node_state_field(grid_size[0], grid_size[1],
                                              grid_orientation, grid_shape,
                                              cts_type, closed_boundaries)

        # If prop_data is a string, we assume it is a field name
        if isinstance(prop_data, string_types):
            prop_data = self.grid.add_zeros('node', prop_data)

        # Create the node-state dictionary
        ns_dict = self.node_state_dictionary()

        # Initialize values of the node-state grid
        if initial_state_grid is None:
            nsg = self.initialize_node_state_grid()
        else:
            # NOTE(review): the CTS object receives initial_state_grid itself,
            # not the grid's 'node_state' field (which only gets a copy of the
            # values); confirm that this is intended.
            try:
                nsg = initial_state_grid
                self.grid.at_node['node_state'][:] = nsg
            except Exception:
                # TODO: use new Messaging capability
                print('If initial_state_grid given, must be array of int')
                raise

        # Create the transition list
        xn_list = self.transition_list()

        # Create the CA object
        if cts_type == 'raster':
            from landlab.ca.raster_cts import RasterCTS
            self.ca = RasterCTS(self.grid, ns_dict, xn_list, nsg, prop_data,
                                prop_reset_value)
        elif cts_type == 'oriented_raster':
            from landlab.ca.oriented_raster_cts import OrientedRasterCTS
            self.ca = OrientedRasterCTS(self.grid, ns_dict, xn_list, nsg,
                                        prop_data, prop_reset_value)
        elif cts_type == 'hex':
            from landlab.ca.hex_cts import HexCTS
            self.ca = HexCTS(self.grid, ns_dict, xn_list, nsg, prop_data,
                             prop_reset_value)
        else:
            from landlab.ca.oriented_hex_cts import OrientedHexCTS
            self.ca = OrientedHexCTS(self.grid, ns_dict, xn_list, nsg,
                                     prop_data, prop_reset_value)

        # Initialize graphics. Test truthiness rather than ``== True``
        # (PEP 8 programming recommendation).
        self._show_plots = show_plots
        if show_plots:
            self.initialize_plotting(**kwds)
Esempio n. 11
0
class CTSModel(object):
    """
    Implement a generic CellLab-CTS model.

    This is the base class from which models should inherit.
    """

    def __init__(self, grid_size=(5, 5), report_interval=5.0,
                 grid_orientation='vertical', grid_shape='rect',
                 show_plots=False, cts_type='oriented_hex',
                 run_duration=1.0, output_interval=1.0e99,
                 plot_every_transition=False, initial_state_grid=None,
                 prop_data=None, prop_reset_value=None,
                 closed_boundaries=(False, False, False, False), **kwds):
        """Create the model; all arguments are forwarded to ``initialize``."""
        self.initialize(grid_size, report_interval, grid_orientation,
                        grid_shape, show_plots, cts_type, run_duration,
                        output_interval, plot_every_transition,
                        initial_state_grid, prop_data, prop_reset_value,
                        closed_boundaries, **kwds)

    def initialize(self, grid_size=(5, 5), report_interval=5.0,
                   grid_orientation='vertical', grid_shape='rect',
                   show_plots=False, cts_type='oriented_hex',
                   run_duration=1.0, output_interval=1.0e99,
                   plot_every_transition=False, initial_state_grid=None,
                   prop_data=None, prop_reset_value=None,
                   closed_boundaries=(False, False, False, False), **kwds):
        """Initialize CTSModel.

        Parameters
        ----------
        grid_size : tuple of int
            (number of rows, number of columns) of the grid.
        report_interval : float
            Wall-clock seconds between progress reports.
        grid_orientation : str
            Passed to HexModelGrid as ``orientation`` (hex types only).
        grid_shape : str
            Passed to HexModelGrid as ``shape`` (hex types only).
        show_plots : bool
            If true, ``initialize_plotting(**kwds)`` is called at the end.
        cts_type : str
            'raster', 'oriented_raster' or 'hex'; any other value selects
            the oriented hex CA.
        run_duration : float
            Stored as ``self.run_duration``.
        output_interval : float
            Stored as ``self.output_interval``.
        plot_every_transition : bool
            Accepted for interface compatibility; not used here.
        initial_state_grid : array of int, optional
            Initial node states. If None, ``initialize_node_state_grid()``
            supplies them.
        prop_data : array or str, optional
            Property data handed to the CA. A string is taken as the name of
            a new zero-filled node field.
        prop_reset_value : optional
            Passed through to the CA constructor.
        closed_boundaries : 4-element tuple of bool
            Whether the right, top, left, and bottom edges are closed.
        **kwds
            Forwarded to ``initialize_plotting`` when plotting is enabled.
        """
        # Remember the clock time, and calculate when we next want to report
        # progress.
        self.current_real_time = time.time()
        self.next_report = self.current_real_time + report_interval
        self.report_interval = report_interval

        # Interval for output
        self.output_interval = output_interval

        # Duration for run
        self.run_duration = run_duration

        # Create a grid
        self.create_grid_and_node_state_field(grid_size[0], grid_size[1],
                                              grid_orientation, grid_shape,
                                              cts_type, closed_boundaries)

        # If prop_data is a string, we assume it is a field name
        if isinstance(prop_data, string_types):
            prop_data = self.grid.add_zeros('node', prop_data)

        # Create the node-state dictionary
        ns_dict = self.node_state_dictionary()

        # Initialize values of the node-state grid
        if initial_state_grid is None:
            nsg = self.initialize_node_state_grid()
        else:
            # NOTE(review): in this branch the CA receives the caller's array,
            # while the 'node_state' field only receives a copy of its values
            # (the default branch passes the field itself) -- confirm intent.
            try:
                nsg = initial_state_grid
                self.grid.at_node['node_state'][:] = nsg
            except Exception:
                # Print a hint about the expected input, then propagate the
                # original error unchanged.
                # TODO: use new Messaging capability
                print('If initial_state_grid given, must be array of int')
                raise

        # Create the transition list
        xn_list = self.transition_list()

        # Create the CA object; each CA class is imported only when needed.
        if cts_type == 'raster':
            from landlab.ca.raster_cts import RasterCTS
            self.ca = RasterCTS(self.grid, ns_dict, xn_list, nsg, prop_data,
                                prop_reset_value)
        elif cts_type == 'oriented_raster':
            from landlab.ca.oriented_raster_cts import OrientedRasterCTS
            self.ca = OrientedRasterCTS(self.grid, ns_dict, xn_list, nsg,
                                        prop_data, prop_reset_value)
        elif cts_type == 'hex':
            from landlab.ca.hex_cts import HexCTS
            self.ca = HexCTS(self.grid, ns_dict, xn_list, nsg, prop_data,
                             prop_reset_value)
        else:
            from landlab.ca.oriented_hex_cts import OrientedHexCTS
            self.ca = OrientedHexCTS(self.grid, ns_dict, xn_list, nsg,
                                     prop_data, prop_reset_value)

        # Initialize graphics
        self._show_plots = show_plots
        if show_plots:
            self.initialize_plotting(**kwds)

    def _set_closed_boundaries_for_hex_grid(self, closed_boundaries):
        """Setup one or more closed boundaries for a hex grid.

        Parameters
        ----------
        closed_boundaries : 4-element tuple of bool
            Whether right, top, left, and bottom edges have closed nodes

        Examples
        --------
        >>> from grainhill import CTSModel
        >>> cm = CTSModel(closed_boundaries=(True, True, True, True))
        >>> cm.grid.status_at_node
        array([4, 4, 4, 4, 4, 4, 0, 4, 0, 0, 4, 0, 4, 0, 0, 4, 0, 4, 0, 0, 4, 4,
               4, 4, 4], dtype=uint8)
        """
        g = self.grid
        if closed_boundaries[0]:
            g.status_at_node[g.nodes_at_right_edge] = CLOSED_BOUNDARY
        if closed_boundaries[1]:
            g.status_at_node[g.nodes_at_top_edge] = CLOSED_BOUNDARY
        if closed_boundaries[2]:
            g.status_at_node[g.nodes_at_left_edge] = CLOSED_BOUNDARY
        if closed_boundaries[3]:
            g.status_at_node[g.nodes_at_bottom_edge] = CLOSED_BOUNDARY

    def create_grid_and_node_state_field(self, num_rows, num_cols,
                                         grid_orientation, grid_shape,
                                         cts_type, closed_bounds):
        """Create the grid and the field containing node states.

        Raster types get a RasterModelGrid; every other ``cts_type`` gets a
        HexModelGrid. ``closed_bounds`` is (right, top, left, bottom).
        """
        if cts_type == 'raster' or cts_type == 'oriented_raster':
            from landlab import RasterModelGrid
            self.grid = RasterModelGrid(shape=(num_rows, num_cols),
                                        spacing=1.0)
            self.grid.set_closed_boundaries_at_grid_edges(closed_bounds[0],
                                                          closed_bounds[1],
                                                          closed_bounds[2],
                                                          closed_bounds[3])
        else:
            from landlab import HexModelGrid
            self.grid = HexModelGrid(num_rows, num_cols, 1.0,
                                     orientation=grid_orientation,
                                     shape=grid_shape)
            # Use truthiness, matching the per-edge tests in the helper.
            if any(closed_bounds):
                self._set_closed_boundaries_for_hex_grid(closed_bounds)

        self.grid.add_zeros('node', 'node_state', dtype=int)

    def node_state_dictionary(self):
        """Create and return a dictionary of all possible node (cell) states.

        This method creates a default set of states (just two); it is a
        template meant to be overridden.
        """
        ns_dict = { 0 : 'on',
                    1 : 'off'}
        return ns_dict

    def transition_list(self):
        """Create and return a list of transition objects.

        This method creates a default set of transitions (just two); it is a
        template meant to be overridden.
        """
        xn_list = []
        xn_list.append(Transition((0, 1, 0), (1, 0, 0), 1.0))
        xn_list.append(Transition((1, 0, 0), (0, 1, 0), 1.0))
        return xn_list

    def write_output(self, grid, outfilename, iteration):
        """Write output to file (currently netCDF)."""
        filename = outfilename + str(iteration).zfill(4) + '.nc'
        save_grid(grid, filename)

    def initialize_node_state_grid(self):
        """Initialize values in the node-state grid.

        This method should be overridden. The default is random "on" and "off".
        Returns the grid's 'node_state' field.
        """
        num_states = 2
        for i in range(self.grid.number_of_nodes):
            self.grid.at_node['node_state'][i] = random.randint(num_states)
        return self.grid.at_node['node_state']

    def initialize_plotting(self, **kwds):
        """Create and configure CAPlotter object."""
        self.ca_plotter = CAPlotter(self.ca, **kwds)
        self.ca_plotter.update_plot()
        axis('off')

    def run_for(self, dt):
        """Run the CA from its current time to ``current_time + dt``."""
        self.ca.run(self.ca.current_time + dt, self.ca.node_state)
Esempio n. 12
0
class CTSModel(object):
    """
    Implement a generic CellLab-CTS model.

    This is the base class from which models should inherit.
    """

    def __init__(self, grid_size=(5, 5), report_interval=5.0,
                 grid_orientation='vertical', grid_shape='rect',
                 show_plots=False, cts_type='oriented_hex',
                 run_duration=1.0, output_interval=1.0e99,
                 plot_every_transition=False, **kwds):
        """Create the model; all arguments are forwarded to ``initialize``."""
        self.initialize(grid_size, report_interval, grid_orientation,
                        grid_shape, show_plots, cts_type, run_duration,
                        output_interval, plot_every_transition, **kwds)

    def initialize(self, grid_size=(5, 5), report_interval=5.0,
                   grid_orientation='vertical', grid_shape='rect',
                   show_plots=False, cts_type='oriented_hex',
                   run_duration=1.0, output_interval=1.0e99,
                   plot_every_transition=False, **kwds):
        """Initialize CTSModel.

        Parameters
        ----------
        grid_size : tuple of int
            (number of rows, number of columns) of the grid.
        report_interval : float
            Wall-clock seconds between progress reports.
        grid_orientation, grid_shape : str
            Passed to HexModelGrid (hex types only).
        show_plots : bool
            If true, ``initialize_plotting(**kwds)`` is called at the end.
        cts_type : str
            'raster', 'oriented_raster' or 'hex'; any other value selects
            the oriented hex CA.
        run_duration, output_interval : float
            Stored on the instance.
        plot_every_transition : bool
            Accepted for interface compatibility; not used here.
        **kwds
            Forwarded to ``initialize_plotting`` when plotting is enabled.
        """
        # Remember the clock time, and calculate when we next want to report
        # progress.
        self.current_real_time = time.time()
        self.next_report = self.current_real_time + report_interval
        self.report_interval = report_interval

        # Interval for output
        self.output_interval = output_interval

        # Duration for run
        self.run_duration = run_duration

        # Create a grid
        self.create_grid_and_node_state_field(grid_size[0], grid_size[1],
                                              grid_orientation, grid_shape,
                                              cts_type)

        # Create the node-state dictionary
        ns_dict = self.node_state_dictionary()

        # Initialize values of the node-state grid
        nsg = self.initialize_node_state_grid()

        # Create the transition list
        xn_list = self.transition_list()

        # Create the CA object; each CA class is imported only when needed.
        if cts_type == 'raster':
            from landlab.ca.raster_cts import RasterCTS
            self.ca = RasterCTS(self.grid, ns_dict, xn_list, nsg)
        elif cts_type == 'oriented_raster':
            from landlab.ca.oriented_raster_cts import OrientedRasterCTS
            self.ca = OrientedRasterCTS(self.grid, ns_dict, xn_list, nsg)
        elif cts_type == 'hex':
            from landlab.ca.hex_cts import HexCTS
            self.ca = HexCTS(self.grid, ns_dict, xn_list, nsg)
        else:
            from landlab.ca.oriented_hex_cts import OrientedHexCTS
            self.ca = OrientedHexCTS(self.grid, ns_dict, xn_list, nsg)

        # Initialize graphics
        self._show_plots = show_plots
        if show_plots:
            self.initialize_plotting(**kwds)

    def create_grid_and_node_state_field(self, num_rows, num_cols,
                                         grid_orientation, grid_shape,
                                         cts_type):
        """Create the grid and the field containing node states.

        Raster types get a RasterModelGrid; every other ``cts_type`` gets a
        HexModelGrid.
        """
        if cts_type == 'raster' or cts_type == 'oriented_raster':
            from landlab import RasterModelGrid
            self.grid = RasterModelGrid(shape=(num_rows, num_cols),
                                        spacing=1.0)
        else:
            from landlab import HexModelGrid
            self.grid = HexModelGrid(num_rows, num_cols, 1.0,
                                     orientation=grid_orientation,
                                     shape=grid_shape)

        self.grid.add_zeros('node', 'node_state', dtype=int)

    def node_state_dictionary(self):
        """Create and return a dictionary of all possible node (cell) states.

        This method creates a default set of states (just two); it is a
        template meant to be overridden.
        """
        ns_dict = { 0 : 'on',
                    1 : 'off'}
        return ns_dict

    def transition_list(self):
        """Create and return a list of transition objects.

        This method creates a default set of transitions (just two); it is a
        template meant to be overridden.
        """
        xn_list = []
        xn_list.append(Transition((0, 1, 0), (1, 0, 0), 1.0))
        xn_list.append(Transition((1, 0, 0), (0, 1, 0), 1.0))
        return xn_list

    def write_output(self, grid, outfilename, iteration):
        """Write output to file (currently netCDF)."""
        filename = outfilename + str(iteration).zfill(4) + '.nc'
        save_grid(grid, filename)

    def initialize_node_state_grid(self):
        """Initialize values in the node-state grid.

        This method should be overridden. The default is random "on" and "off".
        Returns the grid's 'node_state' field.
        """
        num_states = 2
        for i in range(self.grid.number_of_nodes):
            self.grid.at_node['node_state'][i] = random.randint(num_states)
        return self.grid.at_node['node_state']

    def initialize_plotting(self, **kwds):
        """Create and configure CAPlotter object."""
        self.ca_plotter = CAPlotter(self.ca, **kwds)
        self.ca_plotter.update_plot()
        axis('off')

    def run_for(self, dt):
        """Run the CA from its current time to ``current_time + dt``."""
        self.ca.run(self.ca.current_time + dt, self.ca.node_state)
Esempio n. 13
0
def main():
    """Run a hex-lattice particle-motion CA and display it on screen.

    Transitions come from the module-level ``setup_transition_list(g, f)``;
    the meaning of ``g`` and ``f`` is defined there (presumably gravity and
    friction rates -- TODO confirm).
    """

    # INITIALIZE

    # User-defined parameters
    nr = 41
    nc = 61
    g = 0.8
    f = 1.0
    plot_interval = 1.0
    run_duration = 200.0
    report_interval = 5.0  # report interval, in real-time seconds
    p_init = 0.4  # probability that a cell is occupied at start
    plot_every_transition = False

    # Remember the clock time, and calculate when we next want to report
    # progress.
    current_real_time = time.time()
    next_report = current_real_time + report_interval

    # Create a grid
    hmg = HexModelGrid(nr, nc, 1.0, orientation='vertical', reorient_links=True)

    # Close the grid boundaries
    #hmg.set_closed_nodes(hmg.open_boundary_nodes)

    # Set up the states and pair transitions.
    # Transition data here represent particles moving on a lattice: one state
    # per direction (for 6 directions), plus an empty state, a stationary
    # state, and a wall state.
    ns_dict = { 0 : 'empty',
                1 : 'moving up',
                2 : 'moving right and up',
                3 : 'moving right and down',
                4 : 'moving down',
                5 : 'moving left and down',
                6 : 'moving left and up',
                7 : 'rest',
                8 : 'wall'}
    xn_list = setup_transition_list(g, f)

    # Create data and initialize values. Node states are integer codes
    # (keys of ns_dict), so the field must be integer, not the default float.
    node_state_grid = hmg.add_zeros('node', 'node_state_grid', dtype=int)

    # Make the grid boundary all wall particles
    node_state_grid[hmg.boundary_nodes] = 8

    # Seed the grid interior with randomly oriented particles
    for i in hmg.core_nodes:
        if random.random() < p_init:
            node_state_grid[i] = random.randint(1, 7)

    # Create the CA model
    ca = OrientedHexCTS(hmg, ns_dict, xn_list, node_state_grid)

    # Create a CAPlotter object for handling screen display
    ca_plotter = CAPlotter(ca)

    # Plot the initial grid
    ca_plotter.update_plot()

    # RUN
    current_time = 0.0
    while current_time < run_duration:

        # Once in a while, print out simulation and real time to let the user
        # know that the sim is running ok
        current_real_time = time.time()
        if current_real_time >= next_report:
            print('Current sim time',current_time,'(',100*current_time/run_duration,'%)')
            next_report = current_real_time + report_interval

        # Run the model forward in time until the next output step
        ca.run(current_time+plot_interval, ca.node_state,
               plot_each_transition=plot_every_transition, plotter=ca_plotter)
        current_time += plot_interval

        # Plot the current grid
        ca_plotter.update_plot()

    # FINALIZE

    # Plot
    ca_plotter.finalize()
def run(uplift_interval, d):  #d_ratio_exp):
    """Run the grain-hill CA with periodic uplift of interior nodes.

    Parameters
    ----------
    uplift_interval : float
        Simulation time between successive uplift events.
    d
        Passed as the third argument of ``setup_transition_list(g, f, d)``
        (presumably a disturbance rate -- TODO confirm).

    Notes
    -----
    Reads the module-level settings ``nr``, ``nc``, ``g``, ``f``,
    ``report_interval``, ``run_duration``, ``plot_interval`` and
    ``output_interval``; they must be defined before this is called.
    """

    # INITIALIZE

    #uplift_interval = 1e7

    #filenm = 'test_output'
    #imagenm = 'Hill141213/hill'+str(int(d_ratio_exp))+'d'

    # Remember the clock time, and calculate when we next want to report
    # progress.
    current_real_time = time.time()
    next_report = current_real_time + report_interval
    next_uplift = uplift_interval

    # Create a grid
    hmg = HexModelGrid(nr,
                       nc,
                       1.0,
                       orientation='vertical',
                       shape='rect',
                       reorient_links=True)

    # Close the right-hand grid boundaries
    #hmg.set_closed_nodes(arange((nc-1)*nr, hmg.number_of_nodes))

    # Set up the states and pair transitions.
    # Transition data here represent particles moving on a lattice: one state
    # per direction (for 6 directions), plus an empty state, a stationary
    # state, and a wall state.
    ns_dict = {
        0: 'empty',
        1: 'moving up',
        2: 'moving right and up',
        3: 'moving right and down',
        4: 'moving down',
        5: 'moving left and down',
        6: 'moving left and up',
        7: 'rest',
        8: 'wall'
    }
    xn_list = setup_transition_list(g, f, d)

    # Create data and initialize values.
    node_state_grid = hmg.add_zeros('node', 'node_state_grid', dtype=int)

    # Index of the bottom-right node; which index that is depends on whether
    # the number of columns is even or odd (HexModelGrid node ordering).
    if nc % 2 == 0:
        bottom_right = nc - 1
    else:
        bottom_right = nc // 2
    # 0.866... = sqrt(3)/2, presumably the column spacing for unit node
    # spacing, so this is the x coordinate of the right-most column.
    right_side_x = 0.866025403784 * (nc - 1)

    # Lower rows get resting particles (excluding the left and right columns)
    lower_interior = ((hmg.node_y < 3.0)
                      & (hmg.node_x > 0.0)
                      & (hmg.node_x < right_side_x))
    node_state_grid[lower_interior] = 7
    node_state_grid[0] = 8  # bottom left
    node_state_grid[bottom_right] = 8

    # Create an uplift object
    uplifter = LatticeUplifter(hmg, node_state_grid)

    # Create the CA model
    ca = OrientedHexCTS(hmg, ns_dict, xn_list, node_state_grid)

    # Create a CAPlotter object for handling screen display
    # potential colors: red3='#CD0000'
    #mob = 'r'
    #rock = '#5F594D'
    sed = '#A4874B'
    #sky = '#CBD5E1'
    #sky = '#85A5CC'
    sky = '#D0E4F2'
    rock = '#000000'  #sky
    mob = '#D98859'
    #mob = '#DB764F'
    #mob = '#FFFF00'
    #sed = '#CAAE98'
    #clist = [(0.5, 0.9, 0.9),mob, mob, mob, mob, mob, mob,'#CD6839',(0.3,0.3,0.3)]
    clist = [sky, mob, mob, mob, mob, mob, mob, sed, rock]
    my_cmap = matplotlib.colors.ListedColormap(clist)
    ca_plotter = CAPlotter(ca, cmap=my_cmap)
    k = 0

    # Plot the initial grid
    ca_plotter.update_plot()
    axis('off')
    #savefig(imagenm+str(k)+'.png')
    k += 1

    # Write output for initial grid
    #write_output(hmg, filenm, 0)
    #output_iteration = 1

    # Work out the next times to plot and output
    next_output = output_interval
    next_plot = plot_interval

    # RUN
    current_time = 0.0
    while current_time < run_duration:

        # Figure out what time to run to this iteration
        next_pause = min(next_output, next_plot)
        next_pause = min(next_pause, next_uplift)
        next_pause = min(next_pause, run_duration)

        # Once in a while, print out simulation and real time to let the user
        # know that the sim is running ok
        current_real_time = time.time()
        if current_real_time >= next_report:
            print('Current sim time', current_time, '(',
                  100 * current_time / run_duration, '%)')
            next_report = current_real_time + report_interval

        # Run the model forward in time until the next output step
        print('Running to...' + str(next_pause))
        ca.run(next_pause, ca.node_state)  #,
        #plot_each_transition=plot_every_transition, plotter=ca_plotter)
        current_time = next_pause

        # Handle output to file
        if current_time >= next_output:
            #write_output(hmg, filenm, output_iteration)
            #output_iteration += 1
            next_output += output_interval

        # Handle plotting on display
        if current_time >= next_plot:
            #node_state_grid[hmg.number_of_node_rows-1] = 8
            ca_plotter.update_plot()
            axis('off')
            next_plot += plot_interval

        # Handle uplift: shift the lattice, then refresh the CA's link states
        # and scheduled transitions.
        if current_time >= next_uplift:
            uplifter.uplift_interior_nodes(rock_state=7)
            ca.update_link_states_and_transitions(current_time)
            next_uplift += uplift_interval

    print('Finished with main loop')
Esempio n. 15
0
    def initialize(self,
                   grid_size=(5, 5),
                   report_interval=5.0,
                   grid_orientation='vertical',
                   grid_shape='rect',
                   show_plots=False,
                   cts_type='oriented_hex',
                   run_duration=1.0,
                   output_interval=1.0e99,
                   plot_every_transition=False,
                   initial_state_grid=None,
                   prop_data=None,
                   prop_reset_value=None,
                   **kwds):
        """Initialize CTSModel.

        Parameters
        ----------
        grid_size : tuple of int
            (number of rows, number of columns) of the grid.
        report_interval : float
            Wall-clock seconds between progress reports.
        grid_orientation, grid_shape : str
            Passed on to ``create_grid_and_node_state_field``.
        show_plots : bool
            If true, ``initialize_plotting(**kwds)`` is called at the end.
        cts_type : str
            'raster', 'oriented_raster' or 'hex'; any other value selects
            the oriented hex CA.
        run_duration, output_interval : float
            Stored on the instance.
        plot_every_transition : bool
            Accepted for interface compatibility; not used here.
        initial_state_grid : array of int, optional
            Initial node states. If None, ``initialize_node_state_grid()``
            supplies them.
        prop_data : array or str, optional
            Property data handed to the CA. A string is taken as the name of
            a new zero-filled node field.
        prop_reset_value : optional
            Passed through to the CA constructor.
        **kwds
            Forwarded to ``initialize_plotting`` when plotting is enabled.
        """
        # Remember the clock time, and calculate when we next want to report
        # progress.
        self.current_real_time = time.time()
        self.next_report = self.current_real_time + report_interval
        self.report_interval = report_interval

        # Interval for output
        self.output_interval = output_interval

        # Duration for run
        self.run_duration = run_duration

        # Create a grid
        self.create_grid_and_node_state_field(grid_size[0], grid_size[1],
                                              grid_orientation, grid_shape,
                                              cts_type)

        # If prop_data is a string, we assume it is a field name
        if isinstance(prop_data, string_types):
            prop_data = self.grid.add_zeros('node', prop_data)

        # Create the node-state dictionary
        ns_dict = self.node_state_dictionary()

        # Initialize values of the node-state grid
        if initial_state_grid is None:
            nsg = self.initialize_node_state_grid()
        else:
            # NOTE(review): the CA receives the caller's array here, while the
            # 'node_state' field only gets a copy of its values -- confirm.
            try:
                nsg = initial_state_grid
                self.grid.at_node['node_state'][:] = nsg
            except Exception:
                # Print a hint about the expected input, then propagate the
                # original error unchanged.
                # TODO: use new Messaging capability
                print('If initial_state_grid given, must be array of int')
                raise

        # Create the transition list
        xn_list = self.transition_list()

        # Create the CA object; each CA class is imported only when needed.
        if cts_type == 'raster':
            from landlab.ca.raster_cts import RasterCTS
            self.ca = RasterCTS(self.grid, ns_dict, xn_list, nsg, prop_data,
                                prop_reset_value)
        elif cts_type == 'oriented_raster':
            from landlab.ca.oriented_raster_cts import OrientedRasterCTS
            self.ca = OrientedRasterCTS(self.grid, ns_dict, xn_list, nsg,
                                        prop_data, prop_reset_value)
        elif cts_type == 'hex':
            from landlab.ca.hex_cts import HexCTS
            self.ca = HexCTS(self.grid, ns_dict, xn_list, nsg, prop_data,
                             prop_reset_value)
        else:
            from landlab.ca.oriented_hex_cts import OrientedHexCTS
            self.ca = OrientedHexCTS(self.grid, ns_dict, xn_list, nsg,
                                     prop_data, prop_reset_value)

        # Initialize graphics
        self._show_plots = show_plots
        if show_plots:
            self.initialize_plotting(**kwds)
def main():
    """Run a hex-lattice particle-motion CA and display it on screen.

    Transitions come from the module-level ``setup_transition_list(g)``;
    the meaning of ``g`` is defined there (presumably a gravity rate --
    TODO confirm).
    """

    # INITIALIZE

    # User-defined parameters
    nr = 41
    nc = 61
    g = 0.05
    plot_interval = 1.0
    run_duration = 100.0
    report_interval = 5.0  # report interval, in real-time seconds
    p_init = 0.1  # probability that a cell is occupied at start
    plot_every_transition = False

    # Remember the clock time, and calculate when we next want to report
    # progress.
    current_real_time = time.time()
    next_report = current_real_time + report_interval

    # Create a grid
    hmg = HexModelGrid(nr,
                       nc,
                       1.0,
                       orientation='vertical',
                       reorient_links=True)

    # Close the grid boundaries
    #hmg.set_closed_nodes(hmg.open_boundary_nodes)

    # Set up the states and pair transitions.
    # Transition data here represent particles moving on a lattice: one state
    # per direction (for 6 directions), plus an empty state, a stationary
    # state, and a wall state.
    ns_dict = {
        0: 'empty',
        1: 'moving up',
        2: 'moving right and up',
        3: 'moving right and down',
        4: 'moving down',
        5: 'moving left and down',
        6: 'moving left and up',
        7: 'rest',
        8: 'wall'
    }
    xn_list = setup_transition_list(g)

    # Create data and initialize values. Node states are integer codes
    # (keys of ns_dict), so the field must be integer, not the default float.
    node_state_grid = hmg.add_zeros('node', 'node_state_grid', dtype=int)

    # Make the grid boundary all wall particles
    node_state_grid[hmg.boundary_nodes] = 8

    # Seed the grid interior with randomly oriented particles
    for i in hmg.core_nodes:
        if random.random() < p_init:
            node_state_grid[i] = random.randint(1, 7)

    # Create the CA model
    ca = OrientedHexCTS(hmg, ns_dict, xn_list, node_state_grid)

    # Create a CAPlotter object for handling screen display
    ca_plotter = CAPlotter(ca)

    # Plot the initial grid
    ca_plotter.update_plot()

    # RUN
    current_time = 0.0
    while current_time < run_duration:

        # Once in a while, print out simulation and real time to let the user
        # know that the sim is running ok
        current_real_time = time.time()
        if current_real_time >= next_report:
            print('Current sim time', current_time, '(',
                  100 * current_time / run_duration, '%)')
            next_report = current_real_time + report_interval

        # Run the model forward in time until the next output step
        ca.run(current_time + plot_interval,
               ca.node_state,
               plot_each_transition=plot_every_transition,
               plotter=ca_plotter)
        current_time += plot_interval

        # Plot the current grid
        ca_plotter.update_plot()

    # FINALIZE

    # Plot
    ca_plotter.finalize()
Esempio n. 17
0
def main():
    """Run a hex-lattice particle CA, then plot the number of particles in
    each moving/resting state over time.

    Transitions come from the module-level ``setup_transition_list()``.
    """

    # INITIALIZE

    # User-defined parameters
    nr = 52
    nc = 120
    plot_interval = 1.0
    run_duration = 100.0
    report_interval = 5.0  # report interval, in real-time seconds
    p_init = 0.1  # probability that a cell is occupied at start
    plot_every_transition = False

    # Remember the clock time, and calculate when we next want to report
    # progress.
    current_real_time = time.time()
    next_report = current_real_time + report_interval

    # Create a grid
    hmg = HexModelGrid(nr, nc, 1.0, orientation='vertical', reorient_links=True)

    # Close the grid boundaries
    #hmg.set_closed_nodes(hmg.open_boundary_nodes)

    # Set up the states and pair transitions.
    # Transition data here represent particles moving on a lattice: one state
    # per direction (for 6 directions), plus an empty state, a stationary
    # state, and a wall state.
    ns_dict = { 0 : 'empty',
                1 : 'moving up',
                2 : 'moving right and up',
                3 : 'moving right and down',
                4 : 'moving down',
                5 : 'moving left and down',
                6 : 'moving left and up',
                7 : 'rest',
                8 : 'wall'}
    xn_list = setup_transition_list()

    # Create data and initialize values.
    node_state_grid = hmg.add_zeros('node', 'node_state_grid', dtype=int)

    # Make the grid boundary all wall particles
    node_state_grid[hmg.boundary_nodes] = 8

    # Seed the grid interior with randomly oriented particles
    for i in hmg.core_nodes:
        if random.random() < p_init:
            node_state_grid[i] = random.randint(1, 7)

    # Create the CA model
    ca = OrientedHexCTS(hmg, ns_dict, xn_list, node_state_grid)

    # Set up a color map for plotting (one color per state, in state order).
    # Import the submodule explicitly, since matplotlib.colors is used below.
    import matplotlib.colors
    clist = [ (1.0, 1.0, 1.0),   # empty = white
              (1.0, 0.0, 0.0),   # up = red
              (1.0, 1.0, 0.0),   # right-up = yellow
              (0.0, 1.0, 0.0),   # right-down = green
              (0.0, 1.0, 1.0),   # down = cyan
              (0.0, 0.0, 1.0),   # left-down = blue
              (1.0, 0.0, 1.0),   # left-up = magenta
              (0.5, 0.5, 0.5),   # resting = gray
              (0.0, 0.0, 0.0) ]   # wall = black
    my_cmap = matplotlib.colors.ListedColormap(clist)

    # Create a CAPlotter object for handling screen display
    ca_plotter = CAPlotter(ca, cmap=my_cmap)

    # Plot the initial grid
    ca_plotter.update_plot()

    # Create an array to store the numbers of states at each plot interval
    # (rows = states, columns = plot intervals)
    num_states = len(ns_dict)
    nstates = zeros((num_states, int(run_duration/plot_interval)))
    k = 0

    # RUN
    current_time = 0.0
    while current_time < run_duration:

        # Once in a while, print out simulation and real time to let the user
        # know that the sim is running ok
        current_real_time = time.time()
        if current_real_time >= next_report:
            print('Current sim time',current_time,'(',100*current_time/run_duration,'%)')
            next_report = current_real_time + report_interval

        # Run the model forward in time until the next output step
        ca.run(current_time+plot_interval, ca.node_state,
               plot_each_transition=plot_every_transition, plotter=ca_plotter)
        current_time += plot_interval

        # Plot the current grid
        ca_plotter.update_plot()
        axis('off')

        # Record numbers in each state. minlength keeps one entry per state
        # even when the highest-numbered states happen to be absent.
        nstates[:,k] = bincount(node_state_grid, minlength=num_states)
        k += 1

    # FINALIZE

    # Plot
    ca_plotter.finalize()

    # Display the numbers of each state (moving states 1-6 and rest state 7)
    fig, ax = subplots()
    for i in range(1, 8):
        plot(arange(plot_interval, run_duration+plot_interval, plot_interval), nstates[i,:], label=ns_dict[i], color=clist[i])
    ax.legend()
    xlabel('Time')
    ylabel('Number of particles in state')
    title('Particle distribution by state')
    axis([0, run_duration, 0, 2*nstates[7,0]])
    show()
Esempio n. 18
0
class CTSModel(object):
    """
    Generic CellLab-CTS model, intended to be subclassed.

    A concrete model customizes behavior by overriding the template methods
    ``node_state_dictionary``, ``transition_list`` and
    ``initialize_node_state_grid``. This base class builds the grid, picks
    the CA implementation that matches ``cts_type`` and, optionally, sets up
    plotting.
    """

    def __init__(self,
                 grid_size=(5, 5),
                 report_interval=5.0,
                 grid_orientation="vertical",
                 node_layout="rect",
                 show_plots=False,
                 cts_type="oriented_hex",
                 run_duration=1.0,
                 output_interval=1.0e99,
                 plot_every_transition=False,
                 **kwds):
        # Construction is entirely delegated to initialize(); arguments are
        # forwarded positionally, in declaration order.
        self.initialize(
            grid_size,
            report_interval,
            grid_orientation,
            node_layout,
            show_plots,
            cts_type,
            run_duration,
            output_interval,
            plot_every_transition,
            **kwds
        )

    def initialize(self,
                   grid_size=(5, 5),
                   report_interval=5.0,
                   grid_orientation="vertical",
                   node_layout="rect",
                   show_plots=False,
                   cts_type="oriented_hex",
                   run_duration=1.0,
                   output_interval=1.0e99,
                   plot_every_transition=False,
                   **kwds):
        """Build the grid, node states, transitions and the CA object.

        Parameters
        ----------
        grid_size : tuple
            ``(number_of_rows, number_of_columns)`` of the grid.
        report_interval : float
            Wall-clock seconds until the next progress report; used to set
            ``self.next_report``.
        grid_orientation : str
            Passed to ``HexModelGrid`` (unused for raster CTS types).
        node_layout : str
            Passed to ``HexModelGrid`` (unused for raster CTS types).
        show_plots : bool
            If true, ``initialize_plotting(**kwds)`` is called.
        cts_type : str
            ``"raster"``, ``"oriented_raster"`` or ``"hex"``; any other value
            selects ``OrientedHexCTS``.
        run_duration : float
            Stored as ``self.run_duration``.
        output_interval : float
            Stored as ``self.output_interval``.
        plot_every_transition : bool
            Accepted but not used by this method.
        **kwds
            Forwarded to ``initialize_plotting`` when ``show_plots`` is true.
        """
        # Wall-clock bookkeeping used to schedule progress reports.
        now = time.time()
        self.current_real_time = now
        self.report_interval = report_interval
        self.next_report = now + report_interval

        # Simulation-time settings.
        self.output_interval = output_interval
        self.run_duration = run_duration

        num_rows, num_cols = grid_size[0], grid_size[1]
        self.create_grid_and_node_state_field(num_rows, num_cols,
                                              grid_orientation, node_layout,
                                              cts_type)

        # Template-method hooks; subclasses override these.
        states = self.node_state_dictionary()
        initial_states = self.initialize_node_state_grid()
        transitions = self.transition_list()

        # Import only the CA implementation that is actually requested.
        if cts_type == "raster":
            from landlab.ca.raster_cts import RasterCTS as ca_class
        elif cts_type == "oriented_raster":
            from landlab.ca.oriented_raster_cts import (
                OrientedRasterCTS as ca_class)
        elif cts_type == "hex":
            from landlab.ca.hex_cts import HexCTS as ca_class
        else:
            from landlab.ca.oriented_hex_cts import OrientedHexCTS as ca_class
        self.ca = ca_class(self.grid, states, transitions, initial_states)

        # Optional graphics.
        self._show_plots = show_plots
        if show_plots:
            self.initialize_plotting(**kwds)

    def create_grid_and_node_state_field(self, num_rows, num_cols,
                                         grid_orientation, node_layout,
                                         cts_type):
        """Create ``self.grid`` and attach an integer ``node_state`` field.

        Raster CTS types get a ``RasterModelGrid``; every other type gets a
        ``HexModelGrid`` with the requested orientation and node layout.
        Node spacing is 1.0 in both cases.
        """
        if cts_type in ("raster", "oriented_raster"):
            from landlab import RasterModelGrid

            grid = RasterModelGrid(shape=(num_rows, num_cols),
                                   xy_spacing=1.0)
        else:
            from landlab import HexModelGrid

            grid = HexModelGrid(
                (num_rows, num_cols),
                spacing=1.0,
                orientation=grid_orientation,
                node_layout=node_layout,
            )
        self.grid = grid
        grid.add_zeros("node_state", at="node", dtype=int)

    def node_state_dictionary(self):
        """Return the mapping of node-state codes to state names.

        The default is a two-state template; subclasses are expected to
        override it.
        """
        return {0: "on", 1: "off"}

    def transition_list(self):
        """Return the list of ``Transition`` objects for the model.

        The default is a two-transition template; subclasses are expected to
        override it.
        """
        return [
            Transition((0, 1, 0), (1, 0, 0), 1.0),
            Transition((1, 0, 0), (0, 1, 0), 1.0),
        ]

    def write_output(self, grid, outfilename, iteration):
        """Save ``grid`` to ``<outfilename><iteration as 4 digits>.nc``."""
        stamp = str(iteration).zfill(4)
        save_grid(grid, outfilename + stamp + ".nc")

    def initialize_node_state_grid(self):
        """Fill the ``node_state`` field with random states and return it.

        Each node independently gets ``random.randint(2)`` (0 or 1). Meant to
        be overridden by subclasses.
        """
        num_states = 2
        node_state = self.grid.at_node["node_state"]
        for node in range(self.grid.number_of_nodes):
            node_state[node] = random.randint(num_states)
        return node_state

    def initialize_plotting(self, **kwds):
        """Create ``self.ca_plotter``, draw the first frame, hide the axes."""
        self.ca_plotter = CAPlotter(self.ca, **kwds)
        self.ca_plotter.update_plot()
        axis("off")

    def run_for(self, dt):
        """Advance the CA model by ``dt`` past its current time."""
        target_time = self.ca.current_time + dt
        self.ca.run(target_time, self.ca.node_state)
Esempio n. 19
0
def main():
    """Run a fluid/grain CellLab-CTS demo on a closed hexagonal grid.

    The three middle columns of the grid start filled with grains and the
    rest with fluid. The model is advanced in ``plot_interval`` steps until
    ``run_duration`` and re-plotted after each step.
    """

    # INITIALIZE

    # User-defined parameters
    nr = 21
    nc = 21
    plot_interval = 0.5
    run_duration = 25.0
    report_interval = 5.0  # report interval, in real-time seconds

    # Remember the clock time, and calculate when we next want to report
    # progress.
    current_real_time = time.time()
    next_report = current_real_time + report_interval

    # Create a grid
    hmg = HexModelGrid(nr, nc, 1.0, orientation="vertical", reorient_links=True)

    # Close the grid boundaries
    hmg.set_closed_nodes(hmg.open_boundary_nodes)

    # Set up the states and pair transitions.
    # Node states: 0 = fluid, 1 = grain.
    ns_dict = {0: "fluid", 1: "grain"}
    xn_list = setup_transition_list()

    # Create data and initialize values. We start with the 3 middle columns full
    # of grains, and the others empty. Node states are integer codes (the keys
    # of ns_dict), so the field must be an int field, not the float default.
    node_state_grid = hmg.add_zeros("node", "node_state_grid", dtype=int)
    middle = 0.25 * (nc - 1) * sqrt(3)
    is_middle_cols = logical_and(hmg.node_x < middle + 1., hmg.node_x > middle - 1.)
    node_state_grid[where(is_middle_cols)[0]] = 1

    # Create the CA model
    ca = OrientedHexCTS(hmg, ns_dict, xn_list, node_state_grid)

    # Create a CAPlotter object for handling screen display
    ca_plotter = CAPlotter(ca)

    # Plot the initial grid
    ca_plotter.update_plot()

    # RUN
    current_time = 0.0
    while current_time < run_duration:

        # Once in a while, print out simulation and real time to let the user
        # know that the sim is running ok
        current_real_time = time.time()
        if current_real_time >= next_report:
            print(
                "Current sim time",
                current_time,
                "(",
                100 * current_time / run_duration,
                "%)",
            )
            next_report = current_real_time + report_interval

        # Run the model forward in time until the next output step
        ca.run(current_time + plot_interval, ca.node_state, plot_each_transition=False)
        current_time += plot_interval

        # Plot the current grid
        ca_plotter.update_plot()

    # FINALIZE

    # Plot
    ca_plotter.finalize()
Esempio n. 20
0
def main():
    """Run a lattice-gas CellLab-CTS demo on a walled hexagonal grid.

    Boundary nodes are walls (state 8). Each core node is occupied with
    probability ``p_init`` by a particle in a random state 1-6 (the exact
    range depends on the ``random.randint`` in scope). After the run, the
    number of nodes in each state 1-7 is plotted against time.
    """

    # INITIALIZE

    # User-defined parameters
    nr = 52
    nc = 120
    plot_interval = 10.0
    run_duration = 1000.0
    report_interval = 5.0  # report interval, in real-time seconds
    p_init = 0.1  # probability that a cell is occupied at start
    plot_every_transition = False

    # Remember the clock time, and calculate when we next want to report
    # progress.
    current_real_time = time.time()
    next_report = current_real_time + report_interval

    # Create a grid
    hmg = HexModelGrid(nr,
                       nc,
                       1.0,
                       orientation="vertical",
                       reorient_links=True)

    # Set up the states and pair transitions.
    # Transition data here represent particles moving on a lattice: one state
    # per direction (for 6 directions), plus an empty state, a stationary
    # state, and a wall state.
    ns_dict = {
        0: "empty",
        1: "moving up",
        2: "moving right and up",
        3: "moving right and down",
        4: "moving down",
        5: "moving left and down",
        6: "moving left and up",
        7: "rest",
        8: "wall",
    }
    xn_list = setup_transition_list()

    # Create data and initialize values.
    node_state_grid = hmg.add_zeros("node", "node_state_grid", dtype=int)

    # Make the grid boundary all wall particles
    node_state_grid[hmg.boundary_nodes] = 8

    # Seed the grid interior with randomly oriented particles
    for i in hmg.core_nodes:
        if random.random() < p_init:
            node_state_grid[i] = random.randint(1, 7)

    # Create the CA model
    ca = OrientedHexCTS(hmg, ns_dict, xn_list, node_state_grid)

    # Set up a color map for plotting (one color per state code in ns_dict)
    import matplotlib

    clist = [
        (1.0, 1.0, 1.0),  # empty = white
        (1.0, 0.0, 0.0),  # up = red
        (0.8, 0.8, 0.0),  # right-up = yellow
        (0.0, 1.0, 0.0),  # right-down = green
        (0.0, 1.0, 1.0),  # down = cyan
        (0.0, 0.0, 1.0),  # left-down = blue
        (1.0, 0.0, 1.0),  # left-up = magenta
        (0.5, 0.5, 0.5),  # resting = gray
        (0.0, 0.0, 0.0),  # wall = black
    ]
    line_styles = ["", "-", "--", "--", "-", "-.", "--", "-", ":"]
    my_cmap = matplotlib.colors.ListedColormap(clist)

    # Create a CAPlotter object for handling screen display
    ca_plotter = CAPlotter(ca, cmap=my_cmap)

    # Plot the initial grid
    ca_plotter.update_plot()

    # Create an array to store the numbers of states at each plot interval
    nstates = zeros((9, int(run_duration / plot_interval)))
    k = 0

    # RUN
    current_time = 0.0
    while current_time < run_duration:

        # Once in a while, print out simulation and real time to let the user
        # know that the sim is running ok
        current_real_time = time.time()
        if current_real_time >= next_report:
            print(
                "Current sim time",
                current_time,
                "(",
                100 * current_time / run_duration,
                "%)",
            )
            next_report = current_real_time + report_interval

        # Run the model forward in time until the next output step
        ca.run(
            current_time + plot_interval,
            ca.node_state,
            plot_each_transition=plot_every_transition,
            plotter=ca_plotter,
        )
        current_time += plot_interval

        # Plot the current grid
        ca_plotter.update_plot()
        axis("off")

        # Record numbers in each state. minlength guarantees one count per
        # state even if the highest-numbered states are momentarily absent,
        # so the column always matches nstates' first dimension.
        nstates[:, k] = bincount(node_state_grid, minlength=len(ns_dict))
        k += 1

    # FINALIZE

    # Plot
    ca_plotter.finalize()

    # Display the numbers of each state
    fig, ax = subplots()
    for i in range(1, 8):
        plot(
            arange(plot_interval, run_duration + plot_interval, plot_interval),
            nstates[i, :],
            label=ns_dict[i],
            color=clist[i],
            linestyle=line_styles[i],
        )
    ax.legend()
    xlabel("Time")
    ylabel("Number of particles in state")
    title("Particle distribution by state")
    axis([0, run_duration, 0, 2 * nstates[7, 0]])
    show()
def main():
    """Run a fluid/particle CellLab-CTS demo in a walled hexagonal grid.

    All boundary nodes are closed to act as container walls. Nodes with
    ``y < 0.1 * nr`` start as particles, and the grid is re-plotted every
    ``plot_interval`` of simulated time until ``run_duration``.
    """
    # Local import so the color map works even if the module does not
    # import matplotlib itself.
    import matplotlib.colors

    # INITIALIZE
    # User-defined parameters
    nr = 80  # number of rows in grid
    nc = 50  # number of columns in grid
    plot_interval = 0.5  # time interval for plotting, sec
    run_duration = 20.0  # duration of run, sec
    report_interval = 10.0  # report interval, in real-time seconds

    # Remember the clock time, and calculate when we next want to report
    # progress.
    current_real_time = time.time()
    next_report = current_real_time + report_interval

    # Create grid
    mg = HexModelGrid(nr, nc, 1.0)

    # Make the boundaries be walls. A hex grid has no
    # set_closed_boundaries_at_grid_edges, so close every open boundary
    # node instead.
    mg.set_closed_nodes(mg.open_boundary_nodes)

    # Create a node-state dictionary
    ns_dict = {0: 'fluid', 1: 'particle'}

    # Create the transition list
    xn_list = setup_transition_list()

    # Create the node-state array and attach it to the grid
    node_state_grid = mg.add_zeros('node', 'node_state_map', dtype=int)

    # Initialize the node-state array: here, the initial condition is a pile of
    # resting grains at the bottom of a container.
    bottom_rows = where(mg.node_y < 0.1 * nr)[0]
    node_state_grid[bottom_rows] = 1

    # For visual display purposes, set all (closed) boundary nodes to fluid
    node_state_grid[mg.closed_boundary_nodes] = 0

    # Create the CA model
    ca = OrientedHexCTS(mg, ns_dict, xn_list, node_state_grid)

    # Set up colors for plotting
    grain = '#5F594D'
    fluid = '#D0E4F2'
    clist = [fluid, grain]
    my_cmap = matplotlib.colors.ListedColormap(clist)

    # Create a CAPlotter object for handling screen display
    ca_plotter = CAPlotter(ca, cmap=my_cmap)

    # Plot the initial grid
    ca_plotter.update_plot()

    # RUN
    current_time = 0.0
    while current_time < run_duration:

        # Once in a while, print out simulation real time to let the user
        # know that the sim is running ok
        current_real_time = time.time()
        if current_real_time >= next_report:
            # Build the message with explicit concatenation; a backslash
            # continuation inside the literal would embed the next line's
            # indentation into the output.
            print('Current simulation time ' + str(current_time) + '  ('
                  + str(int(100 * current_time / run_duration)) + '%)')
            next_report = current_real_time + report_interval

        # Run the model forward in time until the next output step
        ca.run(current_time + plot_interval,
               ca.node_state,
               plot_each_transition=False)
        current_time += plot_interval

        # Plot the current grid
        ca_plotter.update_plot()

    ca_plotter.finalize()
def main():
    """Run a fluid/particle CellLab-CTS demo in a walled hexagonal grid.

    All boundary nodes are closed to act as container walls. Nodes with
    ``y < 0.1 * nr`` start as particles, and the grid is re-plotted every
    ``plot_interval`` of simulated time until ``run_duration``.
    """
    # Local import so the color map works even if the module does not
    # import matplotlib itself.
    import matplotlib.colors

    # INITIALIZE
    # User-defined parameters
    nr = 80  # number of rows in grid
    nc = 50  # number of columns in grid
    plot_interval = 0.5  # time interval for plotting, sec
    run_duration = 20.0  # duration of run, sec
    report_interval = 10.0  # report interval, in real-time seconds

    # Remember the clock time, and calculate when we next want to report
    # progress.
    current_real_time = time.time()
    next_report = current_real_time + report_interval

    # Create grid
    mg = HexModelGrid(nr, nc, 1.0)

    # Make the boundaries be walls. A hex grid has no
    # set_closed_boundaries_at_grid_edges, so close every open boundary
    # node instead.
    mg.set_closed_nodes(mg.open_boundary_nodes)

    # Create a node-state dictionary
    ns_dict = {0: 'fluid', 1: 'particle'}

    # Create the transition list
    xn_list = setup_transition_list()

    # Create the node-state array and attach it to the grid
    node_state_grid = mg.add_zeros('node', 'node_state_map', dtype=int)

    # Initialize the node-state array: here, the initial condition is a pile of
    # resting grains at the bottom of a container.
    bottom_rows = where(mg.node_y < 0.1 * nr)[0]
    node_state_grid[bottom_rows] = 1

    # For visual display purposes, set all (closed) boundary nodes to fluid
    node_state_grid[mg.closed_boundary_nodes] = 0

    # Create the CA model
    ca = OrientedHexCTS(mg, ns_dict, xn_list, node_state_grid)

    # Set up colors for plotting
    grain = '#5F594D'
    fluid = '#D0E4F2'
    clist = [fluid, grain]
    my_cmap = matplotlib.colors.ListedColormap(clist)

    # Create a CAPlotter object for handling screen display
    ca_plotter = CAPlotter(ca, cmap=my_cmap)

    # Plot the initial grid
    ca_plotter.update_plot()

    # RUN
    current_time = 0.0
    while current_time < run_duration:

        # Once in a while, print out simulation real time to let the user
        # know that the sim is running ok
        current_real_time = time.time()
        if current_real_time >= next_report:
            # Build the message with explicit concatenation; a backslash
            # continuation inside the literal would embed the next line's
            # indentation into the output.
            print('Current simulation time ' + str(current_time) + '  ('
                  + str(int(100 * current_time / run_duration)) + '%)')
            next_report = current_real_time + report_interval

        # Run the model forward in time until the next output step
        ca.run(current_time + plot_interval, ca.node_state,
               plot_each_transition=False)
        current_time += plot_interval

        # Plot the current grid
        ca_plotter.update_plot()

    ca_plotter.finalize()
Esempio n. 23
0
class CTSModel(object):
    """
    Implement a generic CellLab-CTS model.

    This is the base class from which models should inherit. Subclasses
    customize it by overriding ``node_state_dictionary``,
    ``transition_list`` and ``initialize_node_state_grid``.
    """
    def __init__(self,
                 grid_size=(5, 5),
                 report_interval=5.0,
                 grid_orientation='vertical',
                 node_layout='rect',
                 show_plots=False,
                 cts_type='oriented_hex',
                 run_duration=1.0,
                 output_interval=1.0e99,
                 plot_every_transition=False,
                 initial_state_grid=None,
                 prop_data=None,
                 prop_reset_value=None,
                 seed=0,
                 closed_boundaries=(False, False, False, False),
                 **kwds):

        self.initialize(grid_size, report_interval, grid_orientation,
                        node_layout, show_plots, cts_type, run_duration,
                        output_interval, plot_every_transition,
                        initial_state_grid, prop_data, prop_reset_value, seed,
                        closed_boundaries, **kwds)

    def initialize(self,
                   grid_size=(5, 5),
                   report_interval=5.0,
                   grid_orientation='vertical',
                   node_layout='rect',
                   show_plots=False,
                   cts_type='oriented_hex',
                   run_duration=1.0,
                   output_interval=1.0e99,
                   plot_every_transition=False,
                   initial_state_grid=None,
                   prop_data=None,
                   prop_reset_value=None,
                   seed=0,
                   closed_boundaries=(False, False, False, False),
                   **kwds):
        """Initialize CTSModel.

        Parameters
        ----------
        grid_size : tuple of int
            Number of rows and columns in the grid.
        report_interval : float
            Wall-clock seconds until the next progress report.
        grid_orientation, node_layout : str
            Passed to ``HexModelGrid``; unused for raster CTS types.
        show_plots : bool
            If True, call ``initialize_plotting(**kwds)``.
        cts_type : str
            'raster', 'oriented_raster' or 'hex'; any other value selects
            ``OrientedHexCTS``.
        run_duration : float
            Stored as ``self.run_duration``.
        output_interval : float
            Stored as ``self.output_interval``.
        plot_every_transition : bool
            Accepted but not used by this method.
        initial_state_grid : array of int, optional
            Initial node states. If None, ``initialize_node_state_grid`` is
            called instead. Otherwise the values are copied into the grid's
            'node_state' field, and that field is passed to the CA object
            (as in the default path).
        prop_data : array or str, optional
            Property data passed to the CA object; a string is taken as the
            name of a new zero-filled node field.
        prop_reset_value : optional
            Passed through to the CA object.
        seed : int
            Random seed passed to the CA object.
        closed_boundaries : 4-tuple of bool
            Whether the right, top, left and bottom grid edges are closed.
        **kwds
            Passed to ``initialize_plotting`` when ``show_plots`` is True.

        Raises
        ------
        TypeError, ValueError
            If ``initial_state_grid`` cannot be copied into the node-state
            field; a message is printed before the error is re-raised.
        """

        # Remember the clock time, and calculate when we next want to report
        # progress.
        self.current_real_time = time.time()
        self.next_report = self.current_real_time + report_interval
        self.report_interval = report_interval

        # Interval for output
        self.output_interval = output_interval

        # Duration for run
        self.run_duration = run_duration

        # Create a grid
        self.create_grid_and_node_state_field(grid_size[0], grid_size[1],
                                              grid_orientation, node_layout,
                                              cts_type, closed_boundaries)

        # If prop_data is a string, we assume it is a field name
        if isinstance(prop_data, string_types):
            prop_data = self.grid.add_zeros(prop_data, at='node')

        # Create the node-state dictionary
        ns_dict = self.node_state_dictionary()

        # Initialize values of the node-state grid
        if initial_state_grid is None:
            nsg = self.initialize_node_state_grid()
        else:
            # Copy the caller's values into the grid field and give the
            # field itself to the CA, so the grid's 'node_state' and the CA
            # share one array instead of the field holding a stale copy.
            nsg = self.grid.at_node['node_state']
            try:
                nsg[:] = initial_state_grid
            except (TypeError, ValueError):
                # TypeError: non-numeric input; ValueError: shape mismatch.
                print('If initial_state_grid given, must be array of int')
                raise

        # Create the transition list
        xn_list = self.transition_list()

        # Create the CA object
        ca_class = self._cts_class(cts_type)
        self.ca = ca_class(self.grid,
                           ns_dict,
                           xn_list,
                           nsg,
                           prop_data,
                           prop_reset_value,
                           seed=seed)

        # Initialize graphics
        self._show_plots = show_plots
        if show_plots:
            self.initialize_plotting(**kwds)

    @staticmethod
    def _cts_class(cts_type):
        """Return the CellLab-CTS class that implements ``cts_type``.

        'raster', 'oriented_raster' and 'hex' map to RasterCTS,
        OrientedRasterCTS and HexCTS; any other value maps to
        OrientedHexCTS. Only the selected module is imported.
        """
        if cts_type == 'raster':
            from landlab.ca.raster_cts import RasterCTS
            return RasterCTS
        if cts_type == 'oriented_raster':
            from landlab.ca.oriented_raster_cts import OrientedRasterCTS
            return OrientedRasterCTS
        if cts_type == 'hex':
            from landlab.ca.hex_cts import HexCTS
            return HexCTS
        from landlab.ca.oriented_hex_cts import OrientedHexCTS
        return OrientedHexCTS

    def _set_closed_boundaries_for_hex_grid(self, closed_boundaries):
        """Setup one or more closed boundaries for a hex grid.

        Parameters
        ----------
        closed_boundaries : 4-element tuple of bool
            Whether right, top, left, and bottom edges have closed nodes

        Examples
        --------
        >>> from grainhill import CTSModel
        >>> cm = CTSModel(closed_boundaries=(True, True, True, True))
        >>> cm.grid.status_at_node  # doctest: +NORMALIZE_WHITESPACE
        array([4, 4, 4, 4, 4, 4, 0, 4, 0, 0, 4, 0, 4, 0, 0, 4, 0, 4, 0, 0, 4,
               4, 4, 4, 4], dtype=uint8)
        """
        g = self.grid
        if closed_boundaries[0]:
            g.status_at_node[g.nodes_at_right_edge] = g.BC_NODE_IS_CLOSED
        if closed_boundaries[1]:
            g.status_at_node[g.nodes_at_top_edge] = g.BC_NODE_IS_CLOSED
        if closed_boundaries[2]:
            g.status_at_node[g.nodes_at_left_edge] = g.BC_NODE_IS_CLOSED
        if closed_boundaries[3]:
            g.status_at_node[g.nodes_at_bottom_edge] = g.BC_NODE_IS_CLOSED

    def create_grid_and_node_state_field(self, num_rows, num_cols,
                                         grid_orientation, node_layout,
                                         cts_type, closed_bounds):
        """Create the grid and the field containing node states.

        ``closed_bounds`` gives (right, top, left, bottom) edge closure.
        Node spacing is 1.0 for both grid types.
        """

        if cts_type in ('raster', 'oriented_raster'):
            from landlab import RasterModelGrid
            # RasterModelGrid takes ``xy_spacing``; ``spacing`` is the
            # HexModelGrid keyword.
            self.grid = RasterModelGrid(shape=(num_rows, num_cols),
                                        xy_spacing=1.0)
            self.grid.set_closed_boundaries_at_grid_edges(
                closed_bounds[0], closed_bounds[1], closed_bounds[2],
                closed_bounds[3])
        else:
            from landlab import HexModelGrid
            self.grid = HexModelGrid(shape=(num_rows, num_cols),
                                     spacing=1.0,
                                     orientation=grid_orientation,
                                     node_layout=node_layout)
            if any(closed_bounds):
                self._set_closed_boundaries_for_hex_grid(closed_bounds)

        self.grid.add_zeros('node_state', at='node', dtype=int)

    def node_state_dictionary(self):
        """Create and return a dictionary of all possible node (cell) states.

        This method creates a default set of states (just two); it is a
        template meant to be overridden.
        """
        ns_dict = {0: 'on', 1: 'off'}
        return ns_dict

    def transition_list(self):
        """Create and return a list of transition objects.

        This method creates a default set of transitions (just two); it is a
        template meant to be overridden.
        """
        xn_list = []
        xn_list.append(Transition((0, 1, 0), (1, 0, 0), 1.0))
        xn_list.append(Transition((1, 0, 0), (0, 1, 0), 1.0))
        return xn_list

    def write_output(self, grid, outfilename, iteration):
        """Write output to file (currently netCDF).

        The file name is ``outfilename`` + ``iteration`` zero-padded to four
        digits + '.nc'.
        """
        filename = outfilename + str(iteration).zfill(4) + '.nc'
        save_grid(grid, filename)

    def initialize_node_state_grid(self):
        """Initialize values in the node-state grid.

        This method should be overridden. The default is random "on" and "off".
        """
        num_states = 2
        for i in range(self.grid.number_of_nodes):
            self.grid.at_node['node_state'][i] = random.randint(num_states)
        return self.grid.at_node['node_state']

    def initialize_plotting(self, **kwds):
        """Create and configure CAPlotter object."""
        self.ca_plotter = CAPlotter(self.ca, **kwds)
        self.ca_plotter.update_plot()
        axis('off')

    def run_for(self, dt):
        """Advance the CA model by ``dt`` past its current time."""
        self.ca.run(self.ca.current_time + dt, self.ca.node_state)