Example #1
0
class Animator:
    """Creator for animations and videos.

    Reads the frames of one variable (or of all tracers) from a history
    file and passes them, frame by frame, to a ``Plotting`` instance.  The
    animation can be shown on screen and/or saved as a video file.

    The history file stays open for the lifetime of the object so that the
    frames can be read sequentially.  Call ``close()`` to release it
    explicitly; otherwise it is closed when the object is destroyed.
    """

    def __init__(self, hist_path: str, varname="b", video_path=None, visible=True):
        """Open the history file and load parts of it.

        Arguments:
        hist_path -- path of the history file to animate
        varname -- variable to plot: "b" (drawn in the "b-interface"
                   style), "t0" (only the first tracer) or "tracer"
                   (all tracers stored in the file)
        video_path -- if given, the animation is saved to this file
        visible -- if False, no animation is shown during the process

        Raises NotImplementedError if varname is not one of the above.
        """
        # Initialize self.hist_file to prevent the destructor from failing
        self.hist_file = None

        # True when all tracers are plotted together (varname "tracer")
        self.tracers = False

        # Set parameters needed for Plotting that cannot be determined
        # so far; maybe make them command line arguments in the future
        param = {
            "figsize": (RESOLUTION[0] / DPI, RESOLUTION[1] / DPI),
            "aspect": "equal",
            "rotation_speed": 3,
        }

        if varname == "b":
            param["style"] = "b-interface"
            param["stable_stratification"] = True  # TODO make this a command line argument
        elif varname == "t0":
            # This option is to plot only the first tracer and also a
            # shorter notation in the common case with only one tracer
            param["style"] = "tracer"
            param["n_tracers"] = 1
        elif varname == "tracer":
            param["style"] = "tracer"
            self.tracers = True
        else:
            raise NotImplementedError("The given variable is not yet supported.")

        # Save necessary arguments
        self.video_path = video_path
        self.visible = visible

        # Create the metadata for the video (only needed when saving)
        if self.video_path:
            # Extract the name of the experiment: strip the "_hist.nc"
            # suffix (8 characters) and any leading directories
            exp_name = os.path.basename(
                hist_path[:-8] if hist_path.endswith("_hist.nc") else hist_path
            )
            self.metadata = {
                "title": "Nyles experiment {}".format(exp_name),
                "artist": CREATOR,
                "genre": "Computational Fluid Dynamics (CFD)",
                "comment": "Created on {} with Nyles.  Nyles is a Large Eddy "
                           "Simulation written in Python.  For more information"
                           " visit https://github.com/pvthinker/Nyles."
                           .format(time.strftime('%d %b %Y')),
                "date": time.strftime("%Y-%m-%d"),
            }

        # Open the history file and keep it open to allow sequential reading
        print("Loading history file {!r}:".format(hist_path))
        self.hist_file = nc.Dataset(hist_path)
        print(self.hist_file)

        # Load the needed data (file variables; frames are read lazily
        # by indexing them in init/update)
        if self.tracers:
            param["n_tracers"] = self.hist_file.n_tracers
            self.tracers_data = [
                self.hist_file["t{}".format(i)] for i in range(self.hist_file.n_tracers)
            ]
        else:
            self.vardata = self.hist_file[varname]
        self.t = self.hist_file["t"]
        self.n = self.hist_file["n"]
        self.n_frames = self.n.size

        # Load parameters needed for Grid
        param["Lx"] = self.hist_file.Lx
        param["Ly"] = self.hist_file.Ly
        param["Lz"] = self.hist_file.Lz
        param["nx"] = self.hist_file.global_nx
        param["ny"] = self.hist_file.global_ny
        param["nz"] = self.hist_file.global_nz

        # Set parameters needed for Scalar
        param["nh"] = 0
        param["neighbours"] = {}

        grid = Grid(param)

        # Create one or several Scalar variables as an interface for
        # passing data to the plotting module.  Note: Scalar takes
        # actually a dimension instead of a unit, but that does not
        # matter because this information is not processed here.
        if self.tracers:
            tracer_list = []
            self.arrays = []
            for data in self.tracers_data:
                tracer = Scalar(param, data.long_name, data.name, data.units)
                tracer_list.append(tracer)
                self.arrays.append(tracer.view("i"))
            state = State(tracer_list)
        else:
            scalar = Scalar(param, self.vardata.long_name, varname, self.vardata.units)
            self.array = scalar.view("i")
            state = State([scalar])

        self.p = Plotting(param, state, grid)

    def close(self):
        """Close the history file; further calls have no effect."""
        # getattr also covers instances whose __init__ never ran
        hist_file = getattr(self, "hist_file", None)
        if hist_file is not None:
            # Drop the reference before closing so that a repeated call
            # (e.g. from the destructor) never closes the file twice
            self.hist_file = None
            hist_file.close()

    def __del__(self):
        """Close the history file in the destructor."""
        self.close()

    def init(self):
        """Print a summary of the settings and show the initial frame."""
        if self.tracers:
            n_tracers = len(self.tracers_data)
            print("Variable:", n_tracers, "tracer" if n_tracers == 1 else "tracers")
        else:
            print("Variable:", self.vardata.long_name)
        print("Number of frames:", self.n_frames)
        if self.video_path:
            print("Output file:", self.video_path, end="")
            if os.path.exists(self.video_path):
                print(" -- file exists already and will be overwritten!")
            else:
                print("")
            if not self.visible:
                print("Fast mode: no animation will be visible during the process.")
            else:
                print('Slow mode: call script with "--fast" to speed up the video creation.')
        else:
            print("No video will be created.")
        # Load the initial data and show it
        if self.tracers:
            for array, data in zip(self.arrays, self.tracers_data):
                array[...] = data[0]
        else:
            self.array[...] = self.vardata[0]
        self.p.init(self.t[0], self.n[0])

    def run(self):
        """Create the animation and optionally save it."""
        if not self.video_path:
            plt.ioff()
        # Keep a reference on self: the animation must stay alive while it
        # is shown or saved.  Every frame is rendered by self.update.
        self.anim = animation.FuncAnimation(
            self.p.fig,
            self.update,
            frames=self.n_frames,
            repeat=False,
            interval=0,
        )
        if self.visible:
            plt.show()
        if self.video_path:
            self.anim.save(
                self.video_path,
                fps=FPS,
                dpi=DPI,
                bitrate=BPS,
                metadata=self.metadata,
            )

    def update(self, frame):
        """Load the data of the given frame and display it."""
        # "\r" rewrites the progress line in place
        print("\rProcessing frame {} of {} ...".format(frame+1, self.n_frames), end="")
        # Load the data and show it
        if self.tracers:
            for array, data in zip(self.arrays, self.tracers_data):
                array[...] = data[frame]
        else:
            self.array[...] = self.vardata[frame]
        self.p.update(self.t[frame], self.n[frame], self.visible)
        # At the end: replace the trailing "..." with a final status
        if frame + 1 == self.n_frames:
            if self.video_path:
                print("\b\b\b-- saved.")
            else:
                print("\b\b\b-- finished.")
                plt.pause(0.5)
                plt.close(self.p.fig)