class VideoWriter:
  """Context-manager wrapper that opens an ffmpeg writer on the first frame.

  The underlying ``FFMPEG_VideoWriter`` is not created until ``add`` is
  called for the first time, because the frame size is taken from the
  first image. Extra keyword arguments are stored and forwarded to the
  ``FFMPEG_VideoWriter`` constructor unchanged.
  """

  def __init__(self, filename, fps=30.0, **kw):
    # Keep every constructor argument around until the frame size is known.
    self.params = dict(filename=filename, fps=fps, **kw)
    self.writer = None

  def add(self, img):
    """Append one frame.

    Float32/float64 images are clipped to [0, 1] and scaled to uint8;
    2-D (grayscale) images are replicated to three channels.
    """
    frame = np.asarray(img)
    if self.writer is None:
      # Array shape is (rows, cols, ...); the writer is given (cols, rows).
      height, width = frame.shape[:2]
      self.writer = FFMPEG_VideoWriter(size=(width, height), **self.params)
    if frame.dtype in (np.float32, np.float64):
      frame = np.uint8(frame.clip(0, 1) * 255)
    if frame.ndim == 2:
      frame = np.repeat(frame[..., None], 3, -1)
    self.writer.write_frame(frame)

  def close(self):
    """Close the underlying writer if one was ever opened."""
    if not self.writer:
      return
    self.writer.close()

  def __enter__(self):
    return self

  def __exit__(self, *kw):
    # Exceptions are not suppressed: the implicit return value is None.
    self.close()
Beispiel #2
0
def test_moviepy(infile, out_dir):
    """Time moviepy's ffmpeg reader and writer on one 128-frame chunk.

    Reads 128 frames from ``infile`` into a single uint8 array, then writes
    them to ``<out_dir>/mp_ov.mkv`` (libx264, preset "fast", ``-crf 0``,
    30 fps). Both phases run inside ``pipi.Timer`` blocks (presumably a
    timing context manager -- confirm against the ``pipi`` package).

    The reader and writer are closed even when reading or writing raises,
    so no ffmpeg subprocess is left behind on failure.
    """
    # ``os`` is imported locally like the other dependencies so the function
    # does not rely on a module-level import.
    import os

    import numpy as np
    import pipi
    from moviepy.video.io.ffmpeg_reader import FFMPEG_VideoReader
    from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter

    chunksize = 128

    vr = FFMPEG_VideoReader(infile)
    try:
        w, h = vr.size

        # Frames are stacked along the last axis: (rows, cols, depth, frame).
        chunk = np.zeros((h, w, vr.depth, chunksize), dtype=np.uint8)
        with pipi.Timer("mp read..."):
            for i in range(chunksize):
                chunk[..., i] = vr.read_frame()

        vw = FFMPEG_VideoWriter(os.path.join(out_dir, "mp_ov.mkv"), (w, h), 30,
                                codec="libx264", preset="fast",
                                ffmpeg_params=["-crf", "0"])
        try:
            with pipi.Timer("mp write..."):
                for i in range(chunksize):
                    vw.write_frame(chunk[..., i])
        finally:
            vw.close()
    finally:
        vr.close()
Beispiel #3
0
class VideoWriter:
    """Context-manager wrapper that opens an ffmpeg writer on the first frame.

    The frame size is taken from the first image passed to ``add``; extra
    keyword arguments are forwarded to ``FFMPEG_VideoWriter`` unchanged.
    """

    # name is just reused, does not actually autoplay...
    def __init__(self, filename="_autoplay.mp4", fps=30.0, **kw):
        self.writer = None
        self.params = dict(filename=filename, fps=fps, **kw)

    def add(self, img):
        """Append one frame.

        * Floating-point images are clipped to [0, 1] and scaled to uint8.
        * 2-D (grayscale) images are replicated to three channels.
        * 4-channel images are alpha-premultiplied down to three channels.
          Alpha is treated as a fraction: [0, 1] for float input and
          value / 255 for integer input.
        """
        img = np.asarray(img)
        if self.writer is None:
            h, w = img.shape[:2]
            self.writer = FFMPEG_VideoWriter(size=(w, h), **self.params)

        is_float = np.issubdtype(img.dtype, np.floating)
        if is_float:
            img = img.clip(0, 1)

        if img.ndim == 3 and img.shape[-1] == 4:
            # Premultiply in floating point. Multiplying two uint8 arrays
            # wraps modulo 256 and corrupts even fully opaque pixels.
            alpha = img[..., 3:]
            if not is_float:
                alpha = alpha / 255.0
            blended = img[..., :3] * alpha
            img = blended if is_float else blended.astype(img.dtype)

        if is_float:
            img = np.uint8(img * 255)
        if img.ndim == 2:
            img = np.repeat(img[..., None], 3, -1)
        self.writer.write_frame(img)

    def close(self):
        """Close the underlying writer if one was ever opened."""
        if self.writer:
            self.writer.close()

    def __enter__(self):
        return self

    def __exit__(self, *kw):
        self.close()
Beispiel #4
0
def test_moviepy(infile, out_dir):
    """Time moviepy's ffmpeg reader and writer on one 128-frame chunk.

    Reads 128 frames from ``infile`` into a single uint8 array, then writes
    them to ``<out_dir>/mp_ov.mkv`` (libx264, preset "fast", ``-crf 0``,
    30 fps). Both phases run inside ``pipi.Timer`` blocks (presumably a
    timing context manager -- confirm against the ``pipi`` package).

    The reader and writer are closed even when reading or writing raises,
    so no ffmpeg subprocess is left behind on failure.
    """
    # ``os`` is imported locally like the other dependencies so the function
    # does not rely on a module-level import.
    import os

    import numpy as np
    import pipi
    from moviepy.video.io.ffmpeg_reader import FFMPEG_VideoReader
    from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter

    chunksize = 128

    vr = FFMPEG_VideoReader(infile)
    try:
        w, h = vr.size

        # Frames are stacked along the last axis: (rows, cols, depth, frame).
        chunk = np.zeros((h, w, vr.depth, chunksize), dtype=np.uint8)
        with pipi.Timer("mp read..."):
            for i in range(chunksize):
                chunk[..., i] = vr.read_frame()

        vw = FFMPEG_VideoWriter(os.path.join(out_dir, "mp_ov.mkv"), (w, h), 30,
                                codec="libx264", preset="fast",
                                ffmpeg_params=["-crf", "0"])
        try:
            with pipi.Timer("mp write..."):
                for i in range(chunksize):
                    vw.write_frame(chunk[..., i])
        finally:
            vw.close()
    finally:
        vr.close()
Beispiel #5
0
def MakeVideo(m, fps=60.0, skip=1, video_fn='./tmp.mp4'):
    """Write every ``skip``-th frame of ``m.get_frames()`` to a video file.

    Args:
        m: object exposing ``get_observations()`` and ``get_frames()``.
            The frames array is assumed to be indexed
            (frame, row, column[, channel]) -- TODO confirm with the caller.
        fps: frame rate handed to ``FFMPEG_VideoWriter``.
        skip: step between written frames (``range`` step, must be non-zero).
        video_fn: output path.

    The writer is closed even if writing a frame raises.
    """
    # The observations are never used here; the call is kept in case it has
    # side effects on ``m`` that ``get_frames()`` depends on.
    m.get_observations()
    frames = m.get_frames()

    # shape[1] is the row count, shape[2] the column count; the writer is
    # given them swapped, i.e. (columns, rows).
    n_rows, n_cols = frames.shape[1:3]

    writer = FFMPEG_VideoWriter(video_fn, (n_cols, n_rows), fps)
    try:
        for idx in range(0, frames.shape[0], skip):
            writer.write_frame(frames[idx])
    finally:
        writer.close()
Beispiel #6
0
class RendererVideoExporter(RendererObserver):
    """Renderer observer that dumps each rendered frame into a video file.

    The ffmpeg writer is opened in ``onStartRender`` and closed in
    ``onStopRender``; ``onFrame`` appends the renderer's current image.
    """

    def __init__(self, filename: str, width: int, height: int, fps: float):
        """Validate and store the export settings.

        Raises:
            TypeError: ``filename`` is not a str, ``width``/``height`` is not
                an int, or ``fps`` is neither an int nor a float.
            ValueError: ``width``/``height`` is below 1, or ``fps`` is not
                strictly positive.
        """
        super().__init__()
        if not isinstance(filename, str):
            raise TypeError("filename must be a str")
        for label, value in (("width", width), ("height", height)):
            if not isinstance(value, int):
                raise TypeError("%s must be an int" % label)
            if value < 1:
                raise ValueError("%s must be >= 1" % label)
        # ``fps`` is annotated as float, so fractional rates such as 29.97
        # must be accepted; the previous int-only check rejected them.
        if not isinstance(fps, (int, float)):
            raise TypeError("fps must be an int or a float")
        if not fps > 0:  # written this way so that NaN is rejected too
            raise ValueError("fps must be > 0")
        self.filename = filename
        self.width = width
        self.height = height
        self.fps = fps
        # Opened lazily by onStartRender.
        self.video = None

    def onStartRender(self, renderer: Renderer, numberOfFrames: int):
        """Open the output video; ``numberOfFrames`` is not used."""
        self.video = FFMPEG_VideoWriter(self.filename,
                                        (self.width, self.height), self.fps)

    def onStopRender(self, renderer: Renderer):
        """Close the output video if one is open."""
        if self.video is not None:
            self.video.close()
            self.video = None

    def onFrame(self, renderer: Renderer):
        """Convert the renderer's current image to uint8 and append it."""
        drawData = renderer.drawableImageBatch().data
        # Only the first entry of the batch is exported.
        drawData = drawData[0]
        drawData = drawData.data.detach().cpu().numpy()
        # Move axis 0 to the end (presumably channels-first -> channels-last).
        drawData = np.moveaxis(drawData, 0, -1)
        # NOTE(review): assumes values lie in [0, 1]; anything outside that
        # range wraps when cast to uint8 -- confirm with the renderer.
        drawData = np.uint8(255.0 * drawData)
        self.video.write_frame(drawData)
Beispiel #7
0
        print("Seconds per frame: " + str(spf) + " s/f")
        timeElapsed = datetime.now() - startedTime
        m, s = divmod(timeElapsed.total_seconds(), 60)
        h, m = divmod(m, 60)
        timeElapsedSeconds = "%dH:%02dM:%02dS" % (h, m, s)
        print("Elapsed time: " + timeElapsedSeconds)
        timeLeft = float(timeTaken.total_seconds() *
                         float(totalFrames - countDone))
        m, s = divmod(timeLeft, 60)
        h, m = divmod(m, 60)
        timeLeftSeconds = "%dH:%02dM:%02dS" % (h, m, s)
        print("Estimated time remaining: " + timeLeftSeconds)
        print("")
    #end
    writer.write_frame(nextFrame)
    writer.close()
    reader.close()
    print("Done!")
else:
    # Process image
    tensorInputFirst = torch.FloatTensor(
        numpy.rollaxis(
            numpy.asarray(PIL.Image.open(arguments_strFirst))[:, :, ::-1], 2,
            0).astype(numpy.float32) / 255.0)
    tensorInputSecond = torch.FloatTensor(
        numpy.rollaxis(
            numpy.asarray(PIL.Image.open(arguments_strSecond))[:, :, ::-1], 2,
            0).astype(numpy.float32) / 255.0)
    process(tensorInputFirst, tensorInputSecond, tensorOutput)
    PIL.Image.fromarray(
        (numpy.rollaxis(tensorOutput.clamp(0.0, 1.0).numpy(), 0, 3)[:, :, ::-1]
Beispiel #8
0
def wfc_execute(WFC_VISUALIZE=False, WFC_PROFILE=False, WFC_LOGGING=False):
    """Run the WFC pipeline for every entry of ``samples_original.xml``.

    The XML root's children are processed in order. Marker tags
    (``backtracking_on``/``backtracking_off``, ``one_allowed_attempts``/
    ``ten_allowed_attempts``, ``force_use_all_patterns``) change the defaults
    applied to the entries that follow. Each ``overlapping`` entry is turned
    into a ``types.SimpleNamespace`` configuration, its source image is
    loaded and analysed (tiles, patterns, adjacencies), and the solver is
    run ``screenshots`` times with up to ``allowed_attempts`` attempts per
    run. One statistics row per run is appended to a TSV file under
    ``output/``. ``simpletiled`` and all other tags are skipped.

    Args:
        WFC_VISUALIZE: when true, show the intermediate visualisations and
            write an mp4 of the recorded solver states for each run.
        WFC_PROFILE: when true, dump the cProfile statistics of the setup
            phase to a ``.profile`` file for each run.
        WFC_LOGGING: forwarded to ``wfc_run`` as its ``logging`` argument.

    Side effects: sets the module-level global ``hackstring`` to the name of
    the entry being processed. Returns nothing.
    """

    solver_to_use = "default"  #"minizinc"

    # Template for the per-run statistics; each configuration gets a copy.
    wfc_stats_tracking = {
        "observations": 0,
        "propagations": 0,
        "time_start": None,
        "time_end": None,
        "choices_before_success": 0,
        "choices_per_run": [],
        "success": False
    }
    wfc_stats_data = []
    stats_file_name = f"output/stats_{time.time()}.tsv"

    # Write the TSV header once; rows are appended after every solver run.
    with open(stats_file_name, "a+") as stats_file:
        stats_file.write(
            "id\tname\tsuccess?\tattempts\tobservations\tpropagations\tchoices_to_solution\ttotal_observations_before_solution_in_last_restart\ttotal_choices_before_success_across_restarts\tbacktracking_total\ttime_passed\ttime_start\ttime_end\tfinal_time_end\tgenerated_size\tpattern_count\tseed\tbacktracking?\tallowed_restarts\tforce_the_use_of_all_patterns?\toutput_filename\n"
        )

    # Defaults that the marker tags in the XML can switch for later entries.
    default_backtracking = False
    default_allowed_attempts = 10
    default_force_use_all_patterns = False

    xdoc = ET.ElementTree(file="samples_original.xml")
    counter = 0
    choices_before_success = 0
    for xnode in xdoc.getroot():
        counter += 1
        choices_before_success = 0
        if ("#comment" == xnode.tag):
            continue

        name = xnode.get('name', "NAME")
        global hackstring
        hackstring = name
        print("< {0} ".format(name), end='')
        if "backtracking_on" == xnode.tag:
            default_backtracking = True
        if "backtracking_off" == xnode.tag:
            default_backtracking = False
        if "one_allowed_attempts" == xnode.tag:
            default_allowed_attempts = 1
        if "ten_allowed_attempts" == xnode.tag:
            default_allowed_attempts = 10
        if "force_use_all_patterns" == xnode.tag:
            default_force_use_all_patterns = True
        if "overlapping" == xnode.tag:
            choices_before_success = 0
            print("beginning...")
            print(xnode.attrib)
            current_output_file_number = 97000 + (counter * 10)
            wfc_ns = types.SimpleNamespace(
                output_path="output/",
                img_filename="samples/" + xnode.get('name', "NAME") +
                ".png",  # name of the input file
                output_file_number=current_output_file_number,
                operation_name=xnode.get('name', "NAME"),
                output_filename="output/" + xnode.get('name', "NAME") + "_" +
                str(current_output_file_number) + "_" + str(time.time()) +
                ".png",  # name of the output file
                debug_log_filename="output/" + xnode.get('name', "NAME") +
                "_" + str(current_output_file_number) + "_" +
                str(time.time()) + ".log",
                seed=11975,  # seed for random generation, can be any number
                tile_size=int(xnode.get('tile_size',
                                        1)),  # size of tile, in pixels
                pattern_width=int(
                    xnode.get('N', 2)
                ),  # Size of the patterns we want. 2x2 is the minimum, larger scales get slower fast.
                channels=3,  # Color channels in the image (usually 3 for RGB)
                symmetry=int(xnode.get('symmetry', 8)),
                ground=int(xnode.get('ground', 0)),
                adjacency_directions=dict(
                    enumerate([
                        CoordXY(x=0, y=-1),
                        CoordXY(x=1, y=0),
                        CoordXY(x=0, y=1),
                        CoordXY(x=-1, y=0)
                    ])
                ),  # The list of adjacencies that we care about - these will be turned into the edges of the graph
                periodic_input=string2bool(xnode.get(
                    'periodicInput', True)),  # Does the input wrap?
                periodic_output=string2bool(
                    xnode.get('periodicOutput',
                              False)),  # Do we want the output to wrap?
                generated_size=(int(xnode.get('width', 48)),
                                int(xnode.get('height',
                                              48))),  #Size of the final image
                screenshots=int(
                    xnode.get('screenshots', 3)
                ),  # Number of times to run the algorithm, will produce this many distinct outputs
                iteration_limit=int(
                    xnode.get('iteration_limit', 0)
                ),  # After this many iterations, time out. 0 = never time out.
                allowed_attempts=int(
                    xnode.get('allowed_attempts', default_allowed_attempts)
                ),  # Give up after this many contradictions
                stats_tracking=wfc_stats_tracking.copy(),
                backtracking=string2bool(
                    xnode.get('backtracking', default_backtracking)),
                force_use_all_patterns=default_force_use_all_patterns,
                force_fail_first_solution=False)
            wfc_ns.stats_tracking[
                "choices_before_success"] += choices_before_success
            wfc_ns.stats_tracking["time_start"] = time.time()
            # Only the setup phase (up to pr.disable() below) is profiled.
            pr = cProfile.Profile()
            pr.enable()
            wfc_ns = find_pattern_center(wfc_ns)
            wfc_ns = wfc.wfc_utilities.load_visualizer(wfc_ns)
            ##
            ## Load image and make tile data structures
            ##
            wfc_ns.img = load_source_image(wfc_ns.img_filename)
            wfc_ns.channels = wfc_ns.img.shape[
                -1]  # detect if it uses channels other than RGB...
            wfc_ns.tiles = image_to_tiles(wfc_ns.img, wfc_ns.tile_size)
            wfc_ns.tile_catalog, wfc_ns.tile_grid, wfc_ns.code_list, wfc_ns.unique_tiles = make_tile_catalog(
                wfc_ns)
            # Inverse lookup: entry of unique_tiles[0] -> its position.
            wfc_ns.tile_ids = {
                v: k
                for k, v in dict(enumerate(wfc_ns.unique_tiles[0])).items()
            }
            # Pairs each entry of unique_tiles[0] with the matching entry of
            # unique_tiles[1].
            wfc_ns.tile_weights = {
                a: b
                for a, b in zip(wfc_ns.unique_tiles[0], wfc_ns.unique_tiles[1])
            }

            if WFC_VISUALIZE:
                show_input_to_output(wfc_ns)
                show_extracted_tiles(wfc_ns)
                show_false_color_tile_grid(wfc_ns)

            wfc_ns.pattern_catalog, wfc_ns.pattern_weights, wfc_ns.patterns, wfc_ns.pattern_grid = make_pattern_catalog_with_symmetry(
                wfc_ns.tile_grid, wfc_ns.pattern_width, wfc_ns.symmetry,
                wfc_ns.periodic_input)
            if WFC_VISUALIZE:
                show_pattern_catalog(wfc_ns)
            adjacency_relations = adjacency_extraction_consistent(
                wfc_ns, wfc_ns.patterns)
            if WFC_VISUALIZE:
                show_adjacencies(wfc_ns, adjacency_relations[:256])
            wfc_ns = wfc.wfc_patterns.detect_ground(wfc_ns)
            pr.disable()

            # One solver run per requested screenshot, each with a new seed.
            screenshots_collected = 0
            while screenshots_collected < wfc_ns.screenshots:
                wfc_logger.info(f"Starting solver #{screenshots_collected}")
                screenshots_collected += 1
                wfc_ns.seed += 100

                # NOTE(review): "choice_before_success" (singular) is never
                # read; possibly a typo for "choices_before_success".
                choice_before_success = 0
                #wfc_ns.stats_tracking["choices_before_success"] = 0# += choices_before_success
                wfc_ns.stats_tracking["time_start"] = time.time()
                wfc_ns.stats_tracking["final_time_end"] = None

                # update output name so each iteration has a unique filename
                # NOTE(review): the trailing comma makes this a 1-tuple, and
                # the local is not used again in this function; the stats row
                # reports solution.wfc_ns.output_filename instead.
                output_filename = "output/" + xnode.get(
                    'name', "NAME"
                ) + "_" + str(current_output_file_number) + "_" + str(
                    time.time()) + "_" + str(
                        wfc_ns.seed) + ".png",  # name of the output file

                profile_filename = "" + str(
                    wfc_ns.output_path) + "setup_" + str(
                        wfc_ns.output_file_number) + "_" + str(
                            wfc_ns.seed) + "_" + str(time.time()) + "_" + str(
                                wfc_ns.seed) + ".profile"
                if WFC_PROFILE:
                    with open(profile_filename, 'w') as profile_file:
                        ps = pstats.Stats(pr, stream=profile_file)
                        ps.sort_stats('cumtime', 'ncalls')
                        ps.print_stats(20)
                solution = None

                # NOTE(review): the minizinc branch is commented out, so with
                # that solver "solution" stays None and the rendering code
                # further down receives None.
                if "minizinc" == solver_to_use:
                    attempt_count = 0
                    #while attempt_count < wfc_ns.allowed_attempts:
                    #    attempt_count += 1
                    #    solution = mz_run(wfc_ns)
                    #    solution.wfc_ns.stats_tracking["attempt_count"] = attempt_count
                    #    solution.wfc_ns.stats_tracking["choices_before_success"] += solution.wfc_ns.stats_tracking["observations"]

                else:
                    if True:
                        attempt_count = 0
                        #print("allowed attempts: " + str(wfc_ns.allowed_attempts))
                        # Attempts work on a deep copy so the shared
                        # configuration is not modified through this name.
                        attempt_wfc_ns = copy.deepcopy(wfc_ns)
                        attempt_wfc_ns.stats_tracking[
                            "time_start"] = time.time()
                        attempt_wfc_ns.stats_tracking[
                            "choices_before_success"] = 0
                        attempt_wfc_ns.stats_tracking[
                            "total_observations_before_success"] = 0
                        wfc.wfc_solver.reset_backtracking_count(
                        )  # reset the count of how many times we've backtracked, because multiple attempts are handled here instead of there
                        while attempt_count < wfc_ns.allowed_attempts:
                            attempt_count += 1
                            print(attempt_count, end=' ')
                            attempt_wfc_ns.seed += 7  # change seed for each attempt...
                            solution = wfc_run(attempt_wfc_ns,
                                               visualize=WFC_VISUALIZE,
                                               logging=WFC_LOGGING)
                            solution.wfc_ns.stats_tracking[
                                "attempt_count"] = attempt_count
                            solution.wfc_ns.stats_tracking[
                                "choices_before_success"] += solution.wfc_ns.stats_tracking[
                                    "observations"]
                            attempt_wfc_ns.stats_tracking[
                                "total_observations_before_success"] += solution.wfc_ns.stats_tracking[
                                    'total_observations']
                            wfc_logger.info("result: {} is {}".format(
                                attempt_count, solution.result))
                            # A result of -2 ends the retry loop by forcing
                            # the counter to its limit.
                            if solution.result == -2:
                                attempt_count = wfc_ns.allowed_attempts
                                solution.wfc_ns.stats_tracking[
                                    "time_end"] = time.time()
                            wfc_stats_data.append(
                                solution.wfc_ns.stats_tracking.copy())
                    solution.wfc_ns.stats_tracking[
                        "final_time_end"] = time.time()
                    print("tracking choices before success...")
                    choices_before_success = solution.wfc_ns.stats_tracking[
                        "choices_before_success"]
                    # Prefer time_end; fall back to final_time_end when no
                    # time_end was recorded.
                    time_passed = None
                    if None != solution.wfc_ns.stats_tracking["time_end"]:
                        time_passed = solution.wfc_ns.stats_tracking[
                            "time_end"] - solution.wfc_ns.stats_tracking[
                                "time_start"]
                    else:
                        if None != solution.wfc_ns.stats_tracking[
                                "final_time_end"]:
                            time_passed = solution.wfc_ns.stats_tracking[
                                "final_time_end"] - solution.wfc_ns.stats_tracking[
                                    "time_start"]

                    print("...finished calculating time passed")
                    #print(wfc_stats_data)
                    print("writing stats...", end='')

                    # One TSV row per solver run, in the header's column order.
                    with open(stats_file_name, "a+") as stats_file:
                        stats_file.write(
                            f"{solution.wfc_ns.output_file_number}\t{solution.wfc_ns.operation_name}\t{solution.wfc_ns.stats_tracking['success']}\t{solution.wfc_ns.stats_tracking['attempt_count']}\t{solution.wfc_ns.stats_tracking['observations']}\t{solution.wfc_ns.stats_tracking['propagations']}\t{solution.wfc_ns.stats_tracking['choices_before_success']}\t{solution.wfc_ns.stats_tracking['total_observations']}\t{attempt_wfc_ns.stats_tracking['total_observations_before_success']}\t{solution.backtracking_total}\t{time_passed}\t{solution.wfc_ns.stats_tracking['time_start']}\t{solution.wfc_ns.stats_tracking['time_end']}\t{solution.wfc_ns.stats_tracking['final_time_end']}\t{solution.wfc_ns.generated_size}\t{len(solution.wfc_ns.pattern_weights.keys())}\t{solution.wfc_ns.seed}\t{solution.wfc_ns.backtracking}\t{solution.wfc_ns.allowed_attempts}\t{solution.wfc_ns.force_use_all_patterns}\t{solution.wfc_ns.output_filename}\n"
                        )
                    print("done")

                if WFC_VISUALIZE:
                    print("visualize")
                    # NOTE(review): a None solution is only reported here;
                    # the attribute accesses below would still fail on it.
                    if None == solution:
                        print("n u l l")
                    #print(solution)
                    print(1)
                    solution_vis = wfc.wfc_solver.render_recorded_visualization(
                        solution.recorded_vis)
                    #print(solution)
                    print(2)

                    video_fn = f"{solution.wfc_ns.output_path}/crystal_example_{solution.wfc_ns.output_file_number}_{time.time()}.mp4"
                    wfc_logger.info("*****************************")
                    wfc_logger.warning(video_fn)
                    print(
                        f"solver recording stack len - {len(solution_vis.solver_recording_stack)}"
                    )
                    # NOTE(review): this indexes [0] before the emptiness
                    # check on the next statement, so an empty recording
                    # stack raises IndexError here.
                    print(solution_vis.solver_recording_stack[0].shape)
                    if len(solution_vis.solver_recording_stack) > 0:
                        wfc_logger.info(
                            solution_vis.solver_recording_stack[0].shape)
                        # NOTE(review): the size is passed as
                        # [shape[0], shape[1]]; confirm this matches the
                        # (width, height) order the writer expects when the
                        # frames are not square.
                        writer = FFMPEG_VideoWriter(video_fn, [
                            solution_vis.solver_recording_stack[0].shape[0],
                            solution_vis.solver_recording_stack[0].shape[1]
                        ], 12.0)
                        for img_data in solution_vis.solver_recording_stack:
                            writer.write_frame(img_data)
                        print('!', end='')
                        writer.close()
                        mpy.ipython_display(video_fn, height=700)
                print("recording done")
                if WFC_VISUALIZE:
                    solution = wfc_partial_output(solution)
                    show_rendered_patterns(solution, True)
                print("render to output")
                render_patterns_to_output(solution, True, False)
                print("completed")
                print("\n{0} >".format(name))

        elif "simpletiled" == xnode.tag:
            print("> ", end="\n")
            continue
        else:
            continue
    print "n_frames:", movie_reader.nframes
    # The 1st frame of the saved video can't be read directly by movie_reader.read_frame(); I don't know why.
    # It may be a bug in anim.save: if we look at the movie produced by anim.save,
    # we can easily see that the 1st frame disappears almost immediately.
    # So I fetch it manually at time 0; this is just a workaround, I think.
    tmp_frame = movie_reader.get_frame(0)
    [
        movie_writer.write_frame(tmp_frame)
        for _ in range(int(new_fps * play_slow_rate))
    ]
    # For the above reason, we should read (movie_reader.nframes-1) frames so that the last frame is not
    # read twice (note that get_frame(0) has already read one frame).
    # However, I soon figured out that it should be (movie_reader.nframes-2). The details: we actually have
    # 6 frames, but (print movie_reader.nframes) gives 7. I read the first frame through movie_reader.get_frame(0),
    # so there are 5 left. So I should use movie_reader.nframes - 2. Note that in fig1_pcm_fs2.py,
    # in the case of: original fps=1,
    # new_fps = 24, play_slow_rate = 1.5, the result is: the 1st frame lasts 1.8s and the others 1.5s, i.e., the 1st frame
    # has a longer duration. This is messy.
    for i in range(movie_reader.nframes - 2):
        tmp_frame = movie_reader.read_frame()
        [
            movie_writer.write_frame(tmp_frame)
            for _ in range(int(new_fps * play_slow_rate))
        ]
        pass
    movie_reader.close()
    movie_writer.close()
    os.remove(tmp_video_name)
    # plt.show()
    pass
Beispiel #10
0
class MovieEditor(object):
    """Copy a video frame by frame while drawing annotation commands on it.

    Frames are read with ``FFMPEG_VideoReader`` and written with
    ``FFMPEG_VideoWriter`` at the source size and frame rate. When a target
    size different from the source size is requested, the frames go to a
    temporary file first and a separate ffmpeg call rescales that file into
    the final output.
    """

    def __init__(self,
                 input,
                 output,
                 width=None,
                 height=None,
                 log_level=FFMPEG_LOGLEVEL):
        reader = FFMPEG_VideoReader(input)
        self._reader = reader
        self._fps = reader.fps

        # The querier is called on the input first; its duration is read after.
        querier = FfmpegQuerier()
        self._querier = querier
        self._info = querier(input)
        self._duration = querier.duration

        # pageid -> FrameEditor, filled lazily while drawing.
        self.draw_dict = {}

        # A target size is only recorded when both dimensions are given.
        if width and height:
            self._resize = (int(width), int(height))
        else:
            self._resize = None
        self._loglevel = log_level
        self._output = output

        if self._resize:
            self._tmp_file = self._make_tmp_file(input)
        else:
            self._tmp_file = None

        # Write to the temporary file only when a real rescale will follow.
        if self.need_resize:
            destination = self._tmp_file
        else:
            destination = output
        self._writer = FFMPEG_VideoWriter(destination,
                                          size=reader.size,
                                          fps=reader.fps)

    @property
    def need_resize(self):
        """Truthy when a target size is set and differs from the source size."""
        target = self._resize
        return target and list(target) != list(self._reader.size)

    def _make_tmp_file(self, input):
        """Build the name of the intermediate file used for resizing."""
        return '%s-%s.mp4' % (input, get_time())

    def _remove_tmp_file(self):
        """Delete the intermediate file if it exists."""
        tmp_path = self._tmp_file
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    def get_total_frame(self):
        """Return the approximate frame count derived from the duration."""
        return timestamp2frame(self._duration, self._fps)

    def draw_anno_file(self, anno_filename):
        """Render the commands of an annotation file onto the video.

        Frames before a command's frame index are copied, or redrawn with
        the page's accumulated commands when it has any. Each command is
        then appended to its page's ``FrameEditor`` and drawn on the frame
        it applies to. Remaining frames are copied, the streams are closed
        and the output is rescaled if required.
        """
        anno_xml = AnnoXml()
        anno_xml.read(anno_filename)
        anno_xml.parse()
        commands = anno_xml.command

        fps = self._fps
        # The very first frame is passed through untouched.
        self._writer.write_frame(self._reader.read_frame())

        total_frame = self.get_total_frame()
        with click.progressbar(length=total_frame,
                               label='Processing...') as bar:
            frame_index = 1
            bar.update(1)
            for command in commands:
                pageid = command.pageid
                command_frame = CommandParser.get_frame_index(command, fps)
                if pageid not in self.draw_dict:
                    self.draw_dict[pageid] = FrameEditor(fps=fps,
                                                         pageid=pageid)
                frame_editor = self.draw_dict[pageid]

                # Catch up to the frame this command applies to.
                while frame_index < command_frame:
                    has_overlay = bool(frame_editor.commands)
                    frame = self.read_frame()
                    if has_overlay:
                        frame_editor.draw(frame)
                        frame = frame_editor.image
                    self.write_frame(frame)
                    frame_index += 1
                    bar.update(1)

                # Register the command, then draw the frame it lands on.
                frame_editor.append_command(command)
                frame_editor.draw(self.read_frame())
                self.write_frame(frame_editor.image)
                frame_index += 1
                bar.update(1)

            # Copy whatever is left after the last command.
            while frame_index < total_frame:
                self.write_frame(self.read_frame())
                frame_index += 1
                bar.update(1)

        self.close()
        if self.need_resize:
            self.resize()
            self._remove_tmp_file()

    def read_frame(self):
        """Read the next frame from the source video."""
        return self._reader.read_frame()

    def write_frame(self, frame):
        """Write one frame at the source size.

        Per-frame resizing was tried here and found to be buggy, so the
        rescale is done afterwards by ``resize``.
        """
        self._writer.write_frame(frame)

    def resize(self):
        """Rescale the temporary file into the final output with ffmpeg."""
        size_arg = '{w}*{h}'.format(w=self._resize[0], h=self._resize[1])
        parameters = [FFMPEG_FILE, '-i', self._tmp_file, '-s', size_arg]
        parameters += ['-loglevel', self._loglevel, '-y', self._output]

        subprocess.call(parameters)

    def close(self):
        """Close the reader, then the writer."""
        self._reader.close()
        self._writer.close()
Beispiel #11
0
class ChatEditor(object):
    """Render a chat log as a video of text frames.

    Every frame is a fresh ``width`` x ``height`` RGB image on which the
    chats currently held in the chat queue are drawn, one wrapped block per
    chat, each sender in its own colour.
    """

    def __init__(self, width, height, filename, fps, maxsize, **kwargs):
        self._width = width
        self._height = height
        self._filename = filename
        self._fps = fps
        self._writer = FFMPEG_VideoWriter(filename, (width, height), fps)
        # Set by init_total_frame once the chat xml has been loaded.
        self._total_frame = 0
        self._chat_queue = ChatQueue(maxsize)
        self._draw = None
        self._image = None
        # Falsy or missing options fall back to their defaults.
        self._x_offset = kwargs.get('x_offset') or 0
        self._y_offset = kwargs.get('y_offset') or 0
        self._font = kwargs.get('font') or default_font
        self._bg_color = kwargs.get('bg_color') or (244, 244, 244)
        self._color_map = {}
        # Maximum number of characters on one line.
        usable_width = self._width - self._x_offset
        self._max_characters = usable_width // self._font.size

    def init_color_map(self, chats):
        """Assign a random colour to every sender appearing in ``chats``."""
        self._color_map = self.make_color_map(chats)

    def init_total_frame(self, duration):
        """Set the frame count from a duration and the frame rate."""
        self._total_frame = int(float(duration) * self._fps)

    def get_color(self, senderid):
        """Return the sender's colour, or (10, 10, 10) for unknown senders."""
        return self._color_map.get(senderid, (10, 10, 10))

    def get_color_by_chat(self, chat: Chat):
        """Return the colour assigned to the sender of ``chat``."""
        return self.get_color(chat.senderId)

    def chat_to_show(self):
        """Return everything currently held in the chat queue."""
        return self._chat_queue.get_all()

    def _new_canvas(self):
        """Replace the current image and drawing handle with blank ones."""
        self._image = Image.new('RGB', (self._width, self._height),
                                color=self._bg_color)
        self._draw = ImageDraw.Draw(self._image)

    def draw_chat(self, chat, x, y) -> int:
        """Write a frame containing one chat; return the y after its text."""
        self._new_canvas()
        text = self._format_chat(chat)
        sender_color = self.get_color_by_chat(chat)
        bottom = self._draw_text(self._draw, text, x, y, color=sender_color)
        self.write_frame(self._image)
        return bottom

    def draw_chats(self, chats):
        """Write a frame with all of ``chats`` stacked from the top."""
        self._new_canvas()
        left = 0
        cursor_y = 0
        for chat in chats:
            text = self._format_chat(chat)
            sender_color = self.get_color_by_chat(chat)
            cursor_y = self._draw_text(self._draw,
                                       text,
                                       left,
                                       cursor_y,
                                       color=sender_color)
        self.write_frame(self._image)

    def foresee_chats(self, chats):
        """Pop chats from the queue until the remaining ones fit the frame.

        The layout is measured from the newest chat backwards without
        drawing. Every chat from the first one that would end below the
        frame height onwards is counted, and that many entries are popped
        from the chat queue.
        """
        if not chats:
            return
        fitting = 0
        cursor_y = 0
        for chat in chats[::-1]:
            cursor_y = self._draw_text(self._draw,
                                       self._format_chat(chat),
                                       0,
                                       cursor_y,
                                       to_draw=False)
            if cursor_y > self._height:
                break
            fitting += 1
        for _ in range(len(chats) - fitting):
            self._chat_queue.pop()

    def write_frame(self, frame):
        """Append one frame to the output video."""
        self._writer.write_frame(frame)

    def close(self):
        """Close the output video."""
        self._writer.close()

    def _draw_text(self, draw, text, x, y, color=None, to_draw=True):
        """Lay out wrapped ``text``, optionally drawing it; return the end y.

        With ``to_draw`` false only the vertical extent is computed and
        ``draw`` is not touched.
        """
        # refer: https://stackoverflow.com/questions/7698231/python-pil-draw-multiline-text-on-image
        cursor_y = y + self._y_offset
        for line in textwrap.wrap(text, width=self._max_characters):
            _, line_height = self._font.getsize(line)
            if to_draw:
                draw.text((x + self._x_offset, cursor_y),
                          line,
                          font=self._font,
                          fill=color or (0, 0, 0))
            cursor_y += line_height + self._y_offset
        return cursor_y

    def _format_chat(self, chat: Chat):
        """Format a chat as ``[sender] content``."""
        return '[%s] %s' % (chat.sender,
                            self.format_chat_content(chat.content))

    @staticmethod
    def format_chat_content(content):
        """Pass the chat content through ``clean_text``."""
        return clean_text(content)

    @staticmethod
    def make_color_map(chats):
        """Map every distinct sender id to a random RGB tuple."""
        sender_ids = list({chat.senderId for chat in chats})
        palette = {}
        for sender_id in sender_ids:
            red = random.randint(0, 255)
            green = random.randint(0, 255)
            blue = random.randint(0, 255)
            palette[sender_id] = (red, green, blue)
        return palette

    def _draw_queue(self):
        """Trim the queue to what fits, then write a frame of its content."""
        self.foresee_chats(self._chat_queue.get_all())
        self.draw_chats(self._chat_queue.get_all())

    def draw(self, chat_filename):
        """Render the whole chat file into the output video.

        Until a chat's frame is reached, the queue content from before
        that chat is repeated. The chat is then pushed and a frame with
        the updated queue is written. After the last chat the queue is
        redrawn until the total frame count is reached.
        """
        chat_xml = ChatXml()
        chat_xml.read(chat_filename)
        chat_xml.parse()

        chats = chat_xml.chat

        self.init_total_frame(chat_xml.duration)
        self.init_color_map(chats)

        with click.progressbar(length=self._total_frame,
                               label='Processing...') as bar:
            frame = 0
            for chat in chats:
                chat_frame = timestamp2frame(chat.timestamp, self._fps)
                self.foresee_chats(self._chat_queue.get_all())
                visible = self._chat_queue.get_all()
                # Frames before this chat appears.
                while frame < chat_frame:
                    self.draw_chats(visible)
                    frame += 1
                    bar.update(1)
                # The frame on which this chat appears.
                self._chat_queue.push(chat)
                self._draw_queue()
                frame += 1
                bar.update(1)
            # Frames after the last chat.
            while frame < self._total_frame:
                self._draw_queue()
                frame += 1
                bar.update(1)
Beispiel #12
0
class OMRController(ArenaController.ArenaController, Machine):
    def __init__(self, parent, arenaMain):
        """Set up calibration/tracking state, the state machine and the settings UI.

        parent and arenaMain are forwarded to the ArenaController base class.
        The state machine starts in 'off'; 'update_state_data' is registered
        to run after every state change.  The constructor finishes by calling
        projectorPositionChanged() so the projector geometry is computed from
        the default spin-box values.
        """
        super(OMRController, self).__init__(parent, arenaMain)

        # debug output of the parent widget type
        print type(parent)

        #calibration
        self.arenaCamCorners = None
        self.arenaMidLine = []
        self.arenaSide1Sign = 1
        self.arenaProjCorners = []
        self.fishImg = None #image of fish

        #tracking
        self.arenaCvMask = None   

        #state
        self.mutex = Lock()
        states = ['off', 'between_session', 'omr_to_s1', 'no_omr_post_s1', 'omr_to_s2', 'no_omr_post_s2'] # off between L-N-R-N-L-N-R between LNRNLNR between off
        Machine.__init__(self, states=states, initial='off', after_state_change='update_state_data')

        # 'begin' only fires when isReadyToStart returns True.
        self.add_transition('begin', 'off', 'between_session', conditions='isReadyToStart')
        self.add_transition('start_session', 'between_session', 'omr_to_s1')
        # 'start_omr' alternates direction: after s1 comes s2 and vice versa.
        self.add_transition('start_omr', 'no_omr_post_s1', 'omr_to_s2')
        self.add_transition('start_omr', 'no_omr_post_s2', 'omr_to_s1')
        self.add_transition('stop_omr', 'omr_to_s1', 'no_omr_post_s1')
        self.add_transition('stop_omr', 'omr_to_s2', 'no_omr_post_s2')
        self.add_transition('stop_session', ['omr_to_s1', 'omr_to_s2', 'no_omr_post_s1', 'no_omr_post_s2'], 'between_session')
        self.add_transition('quit', '*', 'off')

        self.fishPosUpdate = False
        self.fishPos = (0,0)
        self.fishPosBuffer = []
        self.fishPosBufferT = []
        self.currCvFrame = None
        #Note: additional state variables will be created during state transitions (like nextStateTime)

        #data results
        self.arenaData = None
        self.fishSize = None

        #init UI
        #arena info group box
        self.arenaGroup = QtGui.QGroupBox(self)
        self.arenaGroup.setTitle('Arena Info')
        self.arenaLayout = QtGui.QGridLayout()
        self.arenaLayout.setHorizontalSpacing(3)
        self.arenaLayout.setVerticalSpacing(3)
        self.camCalibButton = QtGui.QPushButton('Set Cam Position')
        self.camCalibButton.setMaximumWidth(150)
        self.camCalibButton.clicked.connect(self.getArenaCameraPosition)
        # corners queued by configTank; applied in onNewFrame
        self.resetCamCorners = None
        self.projGroup = QtGui.QGroupBox(self.arenaGroup)
        self.projGroup.setTitle('Projector Position')
        self.projLayout = QtGui.QGridLayout(self.projGroup)
        self.projLayout.setHorizontalSpacing(3)
        self.projLayout.setVerticalSpacing(3)
        self.projCalibButton = QtGui.QPushButton('Calibrate Projector Position')
        self.projCalibButton.setCheckable(True)
        self.projCalibButton.clicked.connect(self.projectorPositionChanged)
        self.projLayout.addWidget(self.projCalibButton,0,0,1,6)
        # projector position (x, y)
        self.projPosL = QtGui.QLabel('Pos')
        self.projX = QtGui.QSpinBox()
        self.projX.setRange(0,1000)
        self.projX.setValue(250)
        self.projX.setMaximumWidth(50)
        self.projY = QtGui.QSpinBox()
        self.projY.setRange(0,1000)
        self.projY.setValue(250)
        self.projY.setMaximumWidth(50)
        self.projLayout.addWidget(self.projPosL,1,0)
        self.projLayout.addWidget(self.projX,1,1)
        self.projLayout.addWidget(self.projY,1,2)
        self.projX.valueChanged.connect(self.projectorPositionChanged)
        self.projY.valueChanged.connect(self.projectorPositionChanged)
        # projector size (width, length)
        self.projSizeL = QtGui.QLabel('Size W,L')
        self.projWid = QtGui.QSpinBox()
        self.projWid.setRange(0,1000)
        self.projWid.setValue(115)
        self.projWid.setMaximumWidth(50)
        self.projLen = QtGui.QSpinBox()
        self.projLen.setRange(0,1000)
        self.projLen.setValue(220)
        self.projLen.setMaximumWidth(50)
        self.projLen.valueChanged.connect(self.projectorPositionChanged)
        self.projWid.valueChanged.connect(self.projectorPositionChanged)
        self.projLayout.addWidget(self.projSizeL,2,0)
        self.projLayout.addWidget(self.projWid,2,1)
        self.projLayout.addWidget(self.projLen,2,2)
        # projector rotation; passed through radians() in projectorPositionChanged
        self.projRotL = QtGui.QLabel('Rotation')
        self.projRot = QtGui.QSpinBox()
        self.projRot.setRange(0,360)
        self.projRot.setValue(270)
        self.projRot.setMaximumWidth(50)
        self.projLayout.addWidget(self.projRotL,3,0)
        self.projLayout.addWidget(self.projRot,3,1)
        self.projRot.valueChanged.connect(self.projectorPositionChanged)
        
        # physical tank dimensions; tankLength is used for the mm-to-pixel scale of the grating
        self.tankLength = LabeledSpinBox(None, 'Tank Len (mm)', 0,100,46,60)
        self.projLayout.addWidget(self.tankLength, 4,0,1,2)
        self.tankWidth = LabeledSpinBox(None, 'Tank Wid (mm)', 0,100,21,60)
        self.projLayout.addWidget(self.tankWidth, 4,2,1,2)

        self.projLayout.setColumnStretch(6,1)
        self.projGroup.setLayout(self.projLayout)
        self.arenaLayout.addWidget(self.camCalibButton, 0,0)
        self.arenaLayout.addWidget(self.projGroup,1,0)
        self.arenaLayout.setColumnStretch(1,1)
        self.arenaGroup.setLayout(self.arenaLayout)

        #tracking group box
        self.trackWidget = FishTrackerWidget(self, self.arenaMain.ftDisp)

        #start button for individual tank
        self.startButton = QtGui.QPushButton('Start')
        self.startButton.setMaximumWidth(150)
        self.startButton.clicked.connect(self.startstop)
        
        #whether to save movie  
        self.paramSaveMovie = QtGui.QCheckBox('Save Movie')
            
        #experimental parameters groupbox
        self.paramGroup = QtGui.QGroupBox()
        self.paramGroup.setTitle('Exprimental Parameters')
        self.paramLayout = QtGui.QGridLayout()
        self.paramLayout.setHorizontalSpacing(2)
        self.paramLayout.setVerticalSpacing(2)

        #high level flow - Between - OMR Session ( OMRS1 - pause - OMRS2 )
        self.paramNumSession = LabeledSpinBox(None, 'Number of Sessions', 0, 200, 20, 60) #Number of Sessions (a cluster of OMR tests)
        self.paramLayout.addWidget(self.paramNumSession, 1,0,1,2) 
        self.paramBetween = LabeledSpinBox(None, 'Between Sessions (s)', 0, 1200, 300, 60)
        self.paramLayout.addWidget(self.paramBetween, 1,2,1,2)
        self.paramOMRPerSession = LabeledSpinBox(None,'#OMR Per Session', 0 ,10,4,60) # Number of OMR tests per session
        self.paramLayout.addWidget(self.paramOMRPerSession, 2,0,1,2)        
        self.paramOMRDuration = LabeledSpinBox(None,'OMR Duration (s)',0,60,12,60) #duration of each OMR tests
        self.paramLayout.addWidget(self.paramOMRDuration, 2,2,1,2)
        self.paramOMRInterval = LabeledSpinBox(None,'OMR Interval (s)',0,60 ,3 ,60) #time between OMR tests
        self.paramLayout.addWidget(self.paramOMRInterval,3,0,1,2)
 
        #Visual properties of OMR
        self.paramOMRPeriod = LabeledSpinBox(None,'OMR Grating Period (mm)',0,50,5,60) #spacing between grating bars
        self.paramLayout.addWidget(self.paramOMRPeriod,4,0,1,2)
        self.paramOMRDutyCycle = LabeledSpinBox(None, 'OMR DutyCycle %',0,100,50,60) #width of grating bars
        self.paramLayout.addWidget(self.paramOMRDutyCycle,4,2,1,2)
        self.paramOMRVelocity = LabeledSpinBox(None, 'OMR Speed (mm/s)',0,50,5,60) #velocity of moving gratings.
        self.paramLayout.addWidget(self.paramOMRVelocity,5,0,1,2)

        # colour selectors; the second argument is the default combo index (see getColorComboBox)
        (self.paramColorNeutral,self.labelColorNeutral) = self.getColorComboBox('Neutral', 0)
        self.paramLayout.addWidget(self.paramColorNeutral,6,0)
        self.paramLayout.addWidget(self.labelColorNeutral,6,1)
        (self.paramColorOMR,self.labelColorOMR) = self.getColorComboBox('OMR', 5)
        self.paramLayout.addWidget(self.paramColorOMR,6,2)
        self.paramLayout.addWidget(self.labelColorOMR,6,3)
        self.paramGroup.setLayout(self.paramLayout)

        #Experimental info group
        self.infoGroup = QtGui.QGroupBox()
        self.infoGroup.setTitle('Experiment Info')
        self.infoLayout = QtGui.QGridLayout()
        self.infoLayout.setHorizontalSpacing(3)
        self.infoLayout.setVerticalSpacing(3)

        self.labelDir = QtGui.QLabel('Dir: ')
        self.infoDir = PathSelectorWidget(browseCaption='Experimental Data Directory', default=os.path.expanduser('~/data/test'))
        self.infoLayout.addWidget(self.labelDir,0,0)
        self.infoLayout.addWidget(self.infoDir,0,1)

        self.labelDOB = QtGui.QLabel('DOB: ')
        self.infoDOB = QtGui.QDateEdit(QtCore.QDate.currentDate())
        self.infoLayout.addWidget(self.labelDOB,1,0)
        self.infoLayout.addWidget(self.infoDOB,1,1)

        self.labelType = QtGui.QLabel('Line: ')
        self.infoType = QtGui.QLineEdit('Nacre')
        self.infoLayout.addWidget(self.labelType,2,0)
        self.infoLayout.addWidget(self.infoType,2,1)     

        self.infoFish = QtGui.QPushButton('Snap Fish Image')
        self.infoFish.clicked.connect(self.getFishSize)
        self.infoLayout.addWidget(self.infoFish,3,0,1,2)

        self.infoGroup.setLayout(self.infoLayout)

        # assemble all groups into the controller's top-level grid
        self.settingsLayout = QtGui.QGridLayout()
        self.settingsLayout.addWidget(self.arenaGroup,0,0,1,2)
        self.settingsLayout.addWidget(self.trackWidget,1,0,1,2)
        self.settingsLayout.addWidget(self.infoGroup,2,0,1,2)
        self.settingsLayout.addWidget(self.startButton,3,0,1,1)
        self.settingsLayout.addWidget(self.paramSaveMovie,3,1,1,1)
        self.settingsLayout.addWidget(self.paramGroup,4,0,1,2)
        self.setLayout(self.settingsLayout)

        #HACK: couldn't quickly initialize the arena location until after the first image arrives, so that is done in onNewFrame

        #initialize projector
        self.projectorPositionChanged()
        self.t = 0

    def getColorComboBox(self, labelName, defaultNdx=1):
        """Build a colour combo box and its caption label.

        The combo box lists White, Red, Blue, Gray, Magenta and Black (in
        that order) and is preset to index defaultNdx.  Returns the tuple
        (comboBox, label) where the label shows labelName.
        """
        colorNames = ('White', 'Red', 'Blue', 'Gray', 'Magenta', 'Black')
        combo = QtGui.QComboBox()
        for colorName in colorNames:
            combo.addItem(colorName)
        combo.setCurrentIndex(defaultNdx)
        caption = QtGui.QLabel(labelName)
        return (combo, caption)

    def configTank(self, ardConfig, camPos, projPos):
        """Apply a stored tank configuration.

        camPos is queued in resetCamCorners (onNewFrame applies it on the
        next frame).  projPos is indexed 0..4 and written to the projector
        spin boxes in the order x, y, width, length, rotation.  ardConfig is
        not used by this method.
        """
        #camera corners are applied later, once a frame is available
        self.resetCamCorners = camPos

        #set tank position in projector
        spinBoxes = (self.projX, self.projY, self.projWid, self.projLen, self.projRot)
        for ndx, spinBox in enumerate(spinBoxes):
            spinBox.setValue(projPos[ndx])
        

    #---------------------------------------------------
    # OVERLOADED METHODS
    #---------------------------------------------------

    #update to overload and return an image based on dispmode.
    #def getArenaView(self):
    #    return self.arenaView

    def onNewFrame(self, frame, time):
        """Handle one camera frame: locate the fish and log data while running.

        frame -- the new camera image (stored as currCvFrame).
        time  -- timestamp of the frame (stored as frameTime).

        While the state machine is not 'off', every frame appends an entry to
        arenaData['video'] and arenaData['tracking'], and, if 'Save Movie' is
        ticked and a movie writer exists, the arena bounding box of the frame
        is written to the movie.  When the fish was found, its positions are
        buffered and an average speed is appended to averageSpeed once the
        buffer spans more than 60 seconds.

        Any exception is reported and drops into the ipdb debugger; the mutex
        is always released.
        """
        self.mutex.acquire()
        try:
            self.frameTime = time
            self.currCvFrame = frame # may need to make a deep copy

            (found, pos, view, allFish) = self.trackWidget.findFish(self.currCvFrame)
            running = not self.is_off()

            if found:
                self.fishPosUpdate = found
                self.fishPos = pos
                #if experiment is running then keep track of the average fish speed in one minute blocks
                if running:
                    self.fishPosBuffer.append(pos)
                    self.fishPosBufferT.append(time)
                    elapsed = self.fishPosBufferT[-1] - self.fishPosBufferT[0]
                    if elapsed > 60: #averaging every 60 seconds, must match constant in drawDisplayOverlay
                        steps = [np.sqrt((f1[0]-f0[0])**2 + (f1[1]-f0[1])**2)
                                 for f0, f1 in pairwise(self.fishPosBuffer)]
                        # NOTE(review): this is mean step length / window duration,
                        # not total distance / duration -- confirm the intended unit.
                        self.averageSpeed.append(np.mean(steps)/elapsed)
                        self.fishPosBuffer = []
                        self.fishPosBufferT = []

            if running:
                #record location
                self.arenaData['video'].append((self.frameTime, None))
                self.arenaData['tracking'].append((self.frameTime, pos[0], pos[1]))
                # The writer only exists if 'Save Movie' was ticked when the
                # experiment started; ticking the box later must not crash.
                movieLogger = getattr(self, 'movie_logger', None)
                if self.paramSaveMovie.isChecked() and movieLogger is not None:
                    img = np.array(self.currCvFrame[:,:]) #convert IplImage into numpy array
                    (left, top), (right, bottom) = self.arenaBB
                    movieLogger.write_frame(img[top:bottom, left:right])

            self.arenaView = view

            #HACK: couldn't quickly initialize the cam corners before currCvFrame is set, so it is done here.
            #not a great place since the condition is checked on every frame
            if self.resetCamCorners is not None:
                self.setArenaCamCorners(self.resetCamCorners)
                self.resetCamCorners = None

        except Exception:
            print('OMRController:onNewFrame failed')
            traceback.print_exc()
            QtCore.pyqtRemoveInputHook()
            ipdb.set_trace()
        finally:
            self.mutex.release()

  
    def updateState(self):
        """Forward to updateExperimentalState()."""
        self.updateExperimentalState()
        
    def updateExperimentalState(self):
        """Advance the experiment state machine and refresh the moving grating.

        Does nothing while the machine is 'off'.  Otherwise, under the mutex:
        stores the current time in self.t, fires the next transition once
        self.t has passed nextStateTime, and, while an OMR state is active,
        redraws the projector when more than 10 ms have elapsed since the
        last grating update (warning if more than 60 ms have elapsed).

        Any exception is reported and drops into the ipdb debugger; the mutex
        is always released.
        """
        if self.is_off():
            return

        self.mutex.acquire()
        try:
            self.t = time.time()

            #Check if it is time to transition to the next high level state
            if self.t > self.nextStateTime:
                if self.is_between_session():
                    if self.nSession < self.paramNumSession.value():
                        self.nOMR = 0 #TODO move to callback for start session
                        self.nSession += 1 #TODO move to callback for start session
                        self.start_session()
                    else:
                        # all sessions done
                        self.quit()

                elif self.is_omr_to_s1() or self.is_omr_to_s2():
                    self.stop_omr()

                elif self.is_no_omr_post_s1() or self.is_no_omr_post_s2():
                    if self.nOMR < self.paramOMRPerSession.value():
                        self.start_omr()
                    else:
                        self.stop_session()

            #handle omr projector updates
            if self.is_omr_to_s1() or self.is_omr_to_s2():
                sinceUpdate = self.t - self.omrLastUpdate
                if sinceUpdate > 0.010: #only update display once sufficient time has passed.
                    if sinceUpdate > 0.060:
                        print('WARNING: projector slow to update %s' % (sinceUpdate,))
                    self.updateProjectorDisplay()

        except Exception:
            print('OMRController:updateExperimentalState failed')
            traceback.print_exc()
            QtCore.pyqtRemoveInputHook()
            ipdb.set_trace()
        finally:
            self.mutex.release()

    def projectorPositionChanged(self):
        """Recompute the projected tank corners from the spin boxes and redraw.

        The four points defining the tank are such that:
          points 1 and 2 define side1,
          points 2 and 3 define the 'near' connecting edge,
          points 3 and 4 define side2,
          points 4 and 1 define the 'far' connecting edge.

        Stores the rotation matrix in tankRotM and the rotated, translated
        corners (one row per point) in pts, then calls updateProjectorDisplay().
        """
        width = self.projWid.value()
        length = self.projLen.value()
        #unrotated, untranslated rectangle
        corners = np.array([[0, 0],
                            [0, width],
                            [length, width],
                            [length, 0]])

        angle = radians(self.projRot.value())
        self.tankRotM = np.array([[cos(angle), -sin(angle)],
                                  [sin(angle), cos(angle)]])

        offset = [self.projX.value(), self.projY.value()]
        self.pts = np.dot(self.tankRotM, corners.T).T + offset
        self.updateProjectorDisplay()

    def drawProjectorDisplay(self, painter):
        """Paint this arena's projector image with painter.

        In calibration mode (machine 'off', calibration button checked and
        this controller current) the tank is drawn as a red half (side 1) and
        a blue half (side 2) with the corner numbers 1-4.  Otherwise the tank
        is filled with the neutral colour and, during an OMR state, a moving
        grating is drawn on top.  The grating phase is advanced by the time
        elapsed since omrLastUpdate, in opposite directions for the two OMR
        states.

        If the grating period, the tank length or the projected tank length
        is zero no grating can be laid out; the bars are then skipped
        instead of raising ZeroDivisionError.
        """
        def makePolygon(points):
            #build a QPolygonF from an iterable of (x, y) pairs
            poly = QtGui.QPolygonF()
            for (x, y) in points:
                poly.append(QtCore.QPointF(x, y))
            return poly

        def fillPolygon(brush, points):
            #draw a filled polygon without outline
            painter.setBrush(brush)
            painter.setPen(QtGui.QPen(QtCore.Qt.NoPen))
            painter.drawPolygon(makePolygon(points))

        pts = self.pts

        if self.is_off() and self.projCalibButton.isChecked() and self.isCurrent():
            #DRAW A CALIBRATION IMAGE THAT INDICATES THE PROJECTOR LOCATION AND SIDE1
            a = .5
            b = 1-a
            #points on the edges 2-3 and 1-4 where the two halves meet
            split23 = (a*pts[1,0] + b*pts[2,0], a*pts[1,1] + b*pts[2,1])
            split14 = (a*pts[0,0] + b*pts[3,0], a*pts[0,1] + b*pts[3,1])

            #Draw side one
            fillPolygon(QtGui.QBrush(QtCore.Qt.red), [pts[0], pts[1], split23, split14])
            #corner labels
            painter.setPen(QtGui.QColor(168, 34, 3))
            painter.setFont(QtGui.QFont('Decorative',24))
            for ndx in range(4):
                painter.drawText(pts[ndx,0], pts[ndx,1], str(ndx+1))
            #Draw side two
            fillPolygon(QtGui.QBrush(QtCore.Qt.blue), [split14, split23, pts[2], pts[3]])
            return

        #Draw Background
        fillPolygon(self.getBrush(self.paramColorNeutral.currentIndex()),
                    [pts[0], pts[1], pts[2], pts[3]])

        #Draw OMR grating
        towardS1 = self.is_omr_to_s1()
        if not (towardS1 or self.is_omr_to_s2()):
            return

        elapsed = self.t - self.omrLastUpdate
        self.omrLastUpdate = self.t

        periodMM = float(self.paramOMRPeriod.value())
        tankLengthMM = self.tankLength.value()
        if periodMM <= 0 or tankLengthMM <= 0:
            #the spin boxes allow 0; a grating cannot be laid out then
            return

        #update the position according to time since last update.
        phaseStep = (float(self.paramOMRVelocity.value())/periodMM) * 360.0 * elapsed
        if towardS1:
            self.omrPhase -= phaseStep
        else:
            self.omrPhase += phaseStep
        self.omrPhase = self.omrPhase%360.0

        #This stuff doesn't need to be computed on every frame
        sideLength = math.hypot(pts[2,0]-pts[1,0], pts[2,1] - pts[1,1])
        mm2pix = sideLength / tankLengthMM
        period = self.paramOMRPeriod.value() * mm2pix
        if period <= 0:
            #zero projected length: nothing to draw
            return
        numBars = int(math.ceil(sideLength / period))
        barWidth = period * (self.paramOMRDutyCycle.value()/100.0)

        #Draw the bars
        phaseShift = self.omrPhase/360.0 * period
        painter.setBrush(self.getBrush(self.paramColorOMR.currentIndex()))
        painter.setPen(QtGui.QPen(QtCore.Qt.NoPen))
        projLen = self.projLen.value()
        projWid = self.projWid.value()
        offset = [self.projX.value(), self.projY.value()]
        for nBar in range(-1,numBars+1):
            #define bar in unrotated untranslated space, clamped to the tank.
            barStart = phaseShift + nBar * period
            bl = min(max(barStart, 0), projLen)
            br = min(max(barStart + barWidth, 0), projLen)
            bar = np.array([[br,0],
                            [bl,0],
                            [bl,projWid],
                            [br,projWid]])
            #rotate and translate.
            bar = np.dot(self.tankRotM, bar.T).T + offset
            #draw
            painter.drawPolygon(makePolygon(bar))

    def drawDisplayOverlay(self, painter):
        """Draw the tracking overlay for this arena with painter.

        Draws, in order: a red dot at the fish position (when its x is > 0),
        the arena outline (green when bIsCurrentArena, else blue), a red line
        between the first two corners, a status text at the first corner and,
        while running with more than two entries in averageSpeed, a line plot
        of averageSpeed along the bottom of the arena.
        """
        #draw the fish position
        if self.fishPos and self.fishPos[0] > 0:
            painter.setBrush(QtGui.QBrush(QtCore.Qt.red))
            painter.setPen(QtCore.Qt.NoPen)
            painter.drawEllipse(QtCore.QPointF(self.fishPos[0],self.fishPos[1]), 3,3)

        #draw the arena overlay
        if not self.arenaCamCorners:
            return

        painter.setBrush(QtCore.Qt.NoBrush)
        outlineColor = QtCore.Qt.green if self.bIsCurrentArena else QtCore.Qt.blue
        outlinePen = QtGui.QPen(outlineColor)
        outlinePen.setWidth(3)
        painter.setPen(outlinePen)
        outline = QtGui.QPolygonF()
        for corner in self.arenaCamCorners:
            outline.append(QtCore.QPointF(corner[0],corner[1]))
        painter.drawPolygon(outline)

        if len(self.arenaCamCorners) < 2:
            return

        #highlight the edge between the first two corners
        edgePen = QtGui.QPen(QtCore.Qt.red)
        edgePen.setWidth(2)
        painter.setPen(edgePen)
        firstCorner = self.arenaCamCorners[0]
        secondCorner = self.arenaCamCorners[1]
        painter.drawLine(firstCorner[0],firstCorner[1],
                         secondCorner[0],secondCorner[1])
        painter.setPen(QtGui.QColor(168, 34, 3))
        painter.setFont(QtGui.QFont('Decorative',14))

        #status line: state, session progress / time to next state, OMR progress
        statustext = "" + self.state
        if not self.is_off():
            statustext = statustext + ": %d/%d %0.1f"%(self.nSession, self.paramNumSession.value(), self.nextStateTime-self.t)
        if self.is_omr_to_s1() or self.is_omr_to_s2() or self.is_no_omr_post_s1() or self.is_no_omr_post_s2():
            statustext = statustext + ": %d/%d"%(self.nOMR, self.paramOMRPerSession.value())
        painter.drawText(firstCorner[0],firstCorner[1],statustext)

        #Plot average speed of fish over time
        if self.is_off() or not len(self.averageSpeed)>2:
            return

        ac = np.array(self.arenaCamCorners)
        plt_b = np.max(ac[:,1]) #Plot bottom
        plt_l = np.min(ac[:,0]) #Plot left
        plt_r = np.max(ac[:,0]) #Plot right

        sessionDuration = self.paramOMRPerSession.value() * (self.paramOMRDuration.value() + self.paramOMRInterval.value())
        totalDuration = self.paramBetween.value() + self.paramNumSession.value()*(self.paramBetween.value()+
                        sessionDuration)
        # NOTE(review): (60/60.0) is 1.0, so x advances one duration-second per
        # 60 s sample -- confirm whether totalDuration/60.0 was intended.
        totalDuration = totalDuration*(60/60.0) #averaging every N seconds (see onNewFrame)
        x = (np.arange(len(self.averageSpeed))/totalDuration) * (plt_r - plt_l) + plt_l
        #scale = 75 / np.max(self.averageSpeed)
        scale = 3000
        y = plt_b - (scale * np.array(self.averageSpeed))
        for ((x1,x2),(y1,y2)) in zip(pairwise(x),pairwise(y)):
            painter.drawLine(x1,y1,x2,y2)


    def start(self):
        """Delegate to startstop(), which toggles: it begins the experiment when
        the machine is 'off' and quits it otherwise."""
        self.startstop()

    def stop(self):
        """Delegate to startstop(), which toggles: it quits a running experiment
        but begins one when the machine is 'off'."""
        self.startstop()

    #---------------------------------------------------
    # STATE MACHINE CALLBACK METHODS
    #---------------------------------------------------

    def update_state_data(self):
        """Runs after every state transition: redraw and log the new state.

        Appends (self.t, 0, 0, 0, self.state) to arenaData['stateinfo'].
        """
        self.updateProjectorDisplay()
        record = (self.t, 0, 0, 0, self.state)
        self.arenaData['stateinfo'].append(record)
        #results are only saved at the end to avoid long interframe intervals

    def on_exit_off(self):
        """State-machine callback for leaving 'off', i.e. the experiment starts.

        Disables the parameter and info groups, relabels the start button,
        derives the result file names from the data directory and the current
        time, opens the movie writer if 'Save Movie' is ticked, and resets the
        per-experiment data (arenaData, nSession, averageSpeed).
        """
        #experiment has started so disable parts of UI
        for group in (self.paramGroup, self.infoGroup):
            group.setDisabled(True)
        self.startButton.setText('Stop')

        #result files are named <directory name>_<timestamp> inside the data directory
        startedAt = datetime.datetime.now()
        self.saveLocation = str(self.infoDir.text())
        directoryName = os.path.split(self.saveLocation)[1]
        self.fnResults = directoryName + '_' + startedAt.strftime('%Y-%m-%d-%H-%M-%S')
        outputStem = str(self.infoDir.text()) + os.sep + self.fnResults
        self.jsonFileName = outputStem + '.json'

        #prepare to write movie
        if self.paramSaveMovie.isChecked():
            self.movieFileName = outputStem + '.mp4'
            from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter as VidWriter
            (left, top), (right, bottom) = self.arenaBB
            self.movie_logger = VidWriter(filename=self.movieFileName,
                                          size=(right - left, bottom - top),
                                          fps=15,
                                          codec='mpeg4',
                                          preset='ultrafast')

        self.initArenaData()
        self.nSession = 0
        self.averageSpeed = []

    def on_enter_between_session(self):
        """Schedule the next transition 'Between Sessions (s)' after self.t."""
        pause = self.paramBetween.value()
        self.nextStateTime = self.t + pause

    def on_enter_omr_to_s1(self):
        """Begin an OMR presentation in state 'omr_to_s1'.

        Counts the presentation, schedules its end after 'OMR Duration (s)',
        resets the grating phase and update time, and redraws the projector.
        """
        now = self.t
        self.nOMR = self.nOMR + 1
        self.nextStateTime = now + self.paramOMRDuration.value()
        self.omrPhase = 0
        self.omrLastUpdate = now #updated every time the projector updates
        self.updateProjectorDisplay()

    def on_exit_omr_to_s1(self):
        """Redraw the projector when leaving 'omr_to_s1'."""
        self.updateProjectorDisplay()

    def on_enter_omr_to_s2(self):
        """Begin an OMR presentation in state 'omr_to_s2'.

        Counts the presentation, schedules its end after 'OMR Duration (s)',
        resets the grating phase and update time, and redraws the projector.
        """
        now = self.t
        self.nOMR = self.nOMR + 1
        self.nextStateTime = now + self.paramOMRDuration.value()
        self.omrPhase = 0
        self.omrLastUpdate = now #updated every time the projector updates
        self.updateProjectorDisplay()

    def on_exit_omr_to_s2(self):
        """Redraw the projector when leaving 'omr_to_s2'."""
        self.updateProjectorDisplay()

    def on_enter_no_omr_post_s1(self):
        """Start the pause after an 'omr_to_s1' presentation.

        Schedules the next transition 'OMR Interval (s)' after self.t and
        redraws the projector.
        """
        interval = self.paramOMRInterval.value()
        self.nextStateTime = self.t + interval
        self.updateProjectorDisplay()

    def on_enter_no_omr_post_s2(self):
        """Start the pause after an 'omr_to_s2' presentation.

        Schedules the next transition 'OMR Interval (s)' after self.t and
        redraws the projector.
        """
        interval = self.paramOMRInterval.value()
        self.nextStateTime = self.t + interval
        self.updateProjectorDisplay()

    def on_enter_off(self):
        """State-machine callback for entering 'off', i.e. the experiment ends.

        Clears nextStateTime, redraws the projector, saves the results,
        closes the movie writer and re-enables the UI.

        The movie writer is closed whenever one exists, independent of the
        current state of the 'Save Movie' check box: the box can be toggled
        while the experiment runs, which previously either left the writer
        open or raised AttributeError here.  The writer is closed and the UI
        re-enabled even if saving the results raises; the exception still
        propagates.
        """
        self.nextStateTime = None
        self.updateProjectorDisplay()

        try:
            self.saveResults()
        finally:
            #note: the projector colors will be reset on call to after_state_change
            movieLogger = getattr(self, 'movie_logger', None)
            if movieLogger is not None:
                movieLogger.close()
                self.movie_logger = None

            #experiment has ended so enable parts of UI
            self.paramGroup.setDisabled(False)
            self.infoGroup.setDisabled(False)
            self.startButton.setText('Start')

    def isReadyToStart(self):
        """Return True when the experiment can start, else show why not.

        Checks, in order: the data directory exists or can be created
        (including missing parent directories), the arena corners have been
        set, and the tracker has a background image.  On failure the reason
        is shown in the main window's status bar and False is returned.
        """
        def notReady(reason):
            self.arenaMain.statusBar().showMessage('%s arena not ready to start.  %s'%(self.getYokedStr(), reason))
            return False

        dataDir = self.infoDir.text()
        if not os.path.isdir(dataDir):
            try:
                #makedirs so that a nested default such as ~/data/test can be created
                os.makedirs(dataDir)
            except OSError:
                return notReady('Experiment directory does not exist and cannot be created.')

        if not self.arenaCamCorners:
            return notReady('The arena location has not been specified.')

        if not self.trackWidget.getBackgroundImage():
            return notReady('Background image has not been taken.')

        #a fish image is currently not required to start

        return True

    #---------------------------------------------------
    # SLOT CALLBACK METHODS
    #---------------------------------------------------

    def startstop(self):
        """Toggle the experiment: begin it when 'off', otherwise quit it.

        Runs under the mutex and stores the current time in self.t first, so
        the state callbacks triggered by the transition see an up-to-date
        time.  Any exception is reported and drops into the ipdb debugger;
        the mutex is always released.
        """
        self.mutex.acquire()
        try:
            self.t = time.time()

            if self.is_off():
                self.begin()
            else:
                self.quit()
        except Exception:
            print('Start Failed:')
            traceback.print_exc()
            QtCore.pyqtRemoveInputHook()
            ipdb.set_trace()
        finally:
            self.mutex.release()

    def getArenaCameraPosition(self):
        """Start (or restart) picking the arena corners by clicking the display.

        Resets the click counter and the corner list and routes display
        clicks to handleArenaClicks.  A connection left over from an
        unfinished earlier selection is removed first, so every click is
        delivered exactly once.
        """
        self.arenaMain.statusBar().showMessage('Click on the corners of the arena on side 1.')
        self.currArenaClick = 0
        self.arenaCamCorners = []
        clicked = self.arenaMain.ftDisp.clicked
        try:
            clicked.disconnect(self.handleArenaClicks)
        except (TypeError, RuntimeError):
            #the slot was not connected, nothing to undo
            pass
        clicked.connect(self.handleArenaClicks)

    @QtCore.pyqtSlot(int, int) #not critical, but good practice to specify which functions are slots.
    def handleArenaClicks(self, x, y):
        """Record one clicked arena corner (x, y) and prompt for the next.

        The first four clicks are appended to arenaCamCorners.  On the fourth
        the slot disconnects itself, clears the status bar and passes the
        collected corners to setArenaCamCorners.  Further clicks are ignored.
        """
        self.currArenaClick += 1
        clickCount = self.currArenaClick
        if not clickCount < 5:
            return

        self.arenaCamCorners.append((x,y))

        if clickCount == 4:
            #selection complete
            self.arenaMain.ftDisp.clicked.disconnect(self.handleArenaClicks)
            self.arenaMain.statusBar().showMessage('')
            self.setArenaCamCorners(self.arenaCamCorners)
            return

        prompts = {
            1: 'Click on the other corner of the arena on side 1.',
            2: 'Now, click on the corners of the arena on side 2.',
            3: 'Click on the other corner of the arena on side 2.',
        }
        prompt = prompts.get(clickCount)
        if prompt is not None:
            self.arenaMain.statusBar().showMessage(prompt)
               
    def setArenaCamCorners(self, corners):
        """Store the arena corners and derive everything that depends on them.

        corners must be a tuple of tuples (not a list or np.array).

        Marks corner selection as complete (currArenaClick = 4), computes the
        dividing line and side-1 sign, the axis-aligned bounding box arenaBB
        ([[left, top], [right, bottom]]), then rebuilds the arena mask and
        asks the tracker to set its background image.
        """
        self.currArenaClick = 4
        self.arenaCamCorners = corners
        print(corners)
        [self.arenaMidLine, self.arenaSide1Sign] = self.processArenaCorners(self.arenaCamCorners, .5)

        #compute bounding box with vertical and horizontal sides.
        xs = [p[0] for p in corners]
        ys = [p[1] for p in corners]
        left, top, right, bottom = min(xs), min(ys), max(xs), max(ys)
        #bounding box needs to have even heights and widths in order to save as mp4
        # NOTE(review): growing the box by one pixel may step past the frame
        # edge when a corner lies on the last row/column -- confirm.
        if (right - left) % 2:
            right += 1
        if (bottom - top) % 2:
            bottom += 1
        self.arenaBB = [[left, top], [right, bottom]]

        self.getArenaMask()
        self.trackWidget.setBackgroundImage()
                                             
    def getFishSize(self):
        """Start (or restart) measuring the fish by clicking its tail and head.

        Resets the click counter and the fishSize list and routes display
        clicks to handleFishClicks.  A connection left over from an
        unfinished earlier measurement is removed first, so every click is
        delivered exactly once.
        """
        self.arenaMain.statusBar().showMessage('Click on the tip of the fish tail.')
        self.currFishClick = 0
        self.fishSize = []
        clicked = self.arenaMain.ftDisp.clicked
        try:
            clicked.disconnect(self.handleFishClicks)
        except (TypeError, RuntimeError):
            #the slot was not connected, nothing to undo
            pass
        clicked.connect(self.handleFishClicks)

    @QtCore.pyqtSlot(int, int) #not critical, but good practice to specify which functions are slots.
    def handleFishClicks(self, x, y):
        """Record one end of the fish at the clicked position (x, y).

        The first click stores the tail position and prompts for the head.
        The second click disconnects the slot, stores the head position,
        clones the current frame into fishImg and clears the status bar.
        Any further click is ignored.
        """
        self.currFishClick += 1
        clickCount = self.currFishClick
        if clickCount not in (1, 2):
            return

        if clickCount == 1:
            self.fishSize.append((x,y))
            self.arenaMain.statusBar().showMessage('Click on the tip of the fish head.')
            return

        #second click: measurement complete
        self.arenaMain.ftDisp.clicked.disconnect(self.handleFishClicks)
        self.fishSize.append((x,y))
        self.fishImg = cv.CloneImage(self.currCvFrame)
        self.arenaMain.statusBar().showMessage('')

    #---------------------------------------------------
    # HELPER METHODS
    #---------------------------------------------------

    def processArenaCorners(self, arenaCorners, linePosition):
        """Return the line dividing the arena and the sign that defines side 1.

        The line runs from a point between corners 0 and 3 to a point between
        corners 1 and 2; ``linePosition`` is the weight given to corners 3
        and 2 (so 0.5 puts the line half way between each pair).

        Returns a tuple ``(arenaDivideLine, side1Sign)`` where the line is a
        list of two (x, y) tuples and the sign is 1 or -1, chosen so that
        corner 1 counts as being on side 1 according to isOnSide.
        """
        pts = np.array(arenaCorners)
        nearWeight = 1-linePosition
        farWeight = linePosition

        lineStart = (nearWeight*pts[0,0]+farWeight*pts[3,0], nearWeight*pts[0,1]+farWeight*pts[3,1])
        lineEnd = (nearWeight*pts[1,0]+farWeight*pts[2,0], nearWeight*pts[1,1]+farWeight*pts[2,1])
        arenaDivideLine = [lineStart, lineEnd]

        # try the positive sign first; flip it when corner 1 is not on that side
        if self.isOnSide(arenaCorners[1], arenaDivideLine, 1):
            side1Sign = 1
        else:
            side1Sign = -1
        return (arenaDivideLine, side1Sign)

    def isOnSide(self, point, line, sideSign):
        """Return whether ``point`` lies on the side of ``line`` named by ``sideSign``.

        point    -- (x, y) position to test (e.g. the fish position).
        line     -- two (x, y) points defining the dividing line.
        sideSign -- 1 or -1; the sign of the cross product that identifies
                    the side of interest.

        Returns True when the sign of the 2-D cross product of the line
        direction with the vector from the line start to ``point`` equals
        ``sideSign``. A point exactly on the line has sign 0 and therefore
        matches neither 1 nor -1.
        """
        start, end = line[0], line[1]
        cross = (end[0] - start[0]) * (point[1] - start[1]) - (end[1] - start[1]) * (point[0] - start[0])
        # Explicit sign computation: the cmp() builtin no longer exists in
        # Python 3. int() keeps this valid for numpy scalars, whose boolean
        # results do not support subtraction.
        sign = int(cross > 0) - int(cross < 0)
        return sign == sideSign

    def getArenaMask(self):
        """Convert the arena corners into a mask image (arena=255, not=0).

        Does nothing unless ``self.arenaView`` is truthy. Otherwise creates
        an image with the size, depth and channel count of the current
        frame, zeroes it, fills the arena polygon with 255 on every channel,
        stores it in ``self.arenaCvMask`` and passes it to the track widget.
        """
        if not self.arenaView:
            return

        frame = self.currCvFrame
        mask = cv.CreateImage((frame.width, frame.height), frame.depth, frame.channels)
        self.arenaCvMask = mask
        cv.SetZero(mask)
        fillValue = (255,)*frame.channels
        cv.FillConvexPoly(mask, self.arenaCamCorners, fillValue)
        self.trackWidget.setTrackMask(mask)

    def initArenaData(self):
        """Prepare the output data structure and save the experiment images.

        Fills ``self.arenaData`` with the fish information, the experiment
        parameters read from the parameter widgets, the tracking and
        projector parameters, the tank size and empty lists for the
        per-trial results. Then saves the tracker's background image (and
        the fish image, when there is one) as time-stamped .tiff files in
        the directory given by the ``infoDir`` widget.
        """
        birthday = self.infoDOB.date().toPyDate()

        data = self.arenaData = {}
        data['fishbirthday'] = str(birthday)
        data['fishage'] = (datetime.date.today() - birthday).days
        data['fishstrain'] = str(self.infoType.text())
        data['fishsize'] = self.fishSize

        data['parameters'] = {
            'Num Sessions': self.paramNumSession.value(),
            'Between (s)': self.paramBetween.value(),
            'OMRPerSession': self.paramOMRPerSession.value(),
            'OMR Duration (s)': self.paramOMRDuration.value(),
            'OMR Interval (s)': self.paramOMRInterval.value(),
            'OMR Period (mm)': self.paramOMRPeriod.value(),
            'OMR Duty Cycle': self.paramOMRDutyCycle.value(),
            'OMR Velocity': self.paramOMRVelocity.value(),
            'NeutralColor': str(self.paramColorNeutral.currentText()),
            'OMRColor': str(self.paramColorOMR.currentText()),
            'states': self.states.keys(),
            'CodeVersion': None,
        }

        trackingParams = self.trackWidget.getParameterDictionary()
        data['trackingParameters'] = trackingParams
        trackingParams['arenaPoly'] = self.arenaCamCorners
        trackingParams['arenaDivideLine'] = self.arenaMidLine
        trackingParams['arenaSide1Sign'] = self.arenaSide1Sign

        data['projectorParameters'] = {
            'position': [self.projX.value(), self.projY.value()],
            'size': [self.projLen.value(), self.projWid.value()],
            'rotation': self.projRot.value(),
        }
        data['tankSize_mm'] = [self.tankLength.value(), self.tankWidth.value()]

        data['trials'] = list()     # outcome on each trial
        data['tracking'] = list()   # list of tuples (frametime, posx, posy)
        data['video'] = list()      # list of tuples (frametime, filename)
        data['stateinfo'] = list()  # list of times at switch stimulus flipped.
        data['shockinfo'] = list()

        # save experiment images; both files share one timestamp
        stamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        prefix = str(self.infoDir.text()) + os.sep + self.fnResults

        self.bcvImgFileName = prefix + '_BackImg_' + stamp + '.tiff'
        cv.SaveImage(self.bcvImgFileName, self.trackWidget.getBackgroundImage())
        if self.fishImg:
            self.fishImgFileName = prefix + '_FishImg_' + stamp + '.tiff'
            cv.SaveImage(self.fishImgFileName, self.fishImg)

    def saveResults(self):
        """Write ``self.arenaData`` as JSON to ``self.jsonFileName``.

        The file is opened for writing (replacing any existing content) and
        is closed even if serialisation fails.
        """
        # The path is passed positionally: the keyword was called 'name' in
        # Python 2 but 'file' in Python 3, so only the positional form is
        # portable. The with-block guarantees the handle is closed.
        with open(self.jsonFileName, 'w') as f:
            json.dump(self.arenaData, f)

    def getBrush(self, colorNdx):
        """Return the QBrush that corresponds to a colour index.

        0 white, 1 red, 2 blue, 3 mid grey (128,128,128), 4 magenta,
        5 black. Any other index also yields a black brush.
        """
        palette = (
            (0, QtCore.Qt.white),
            (1, QtCore.Qt.red),
            (2, QtCore.Qt.blue),
            (3, QtGui.QColor(128,128,128)),
            (4, QtCore.Qt.magenta),
        )
        for ndx, color in palette:
            if colorNdx == ndx:
                return QtGui.QBrush(color)
        # index 5 and every unrecognised index share the black brush
        return QtGui.QBrush(QtCore.Qt.black)

    def useDebugParams(self):
        """Placeholder hook; currently does nothing."""
        return None
Beispiel #13
0
class CameraControl(QGroupBox):
    """
    Widget to control camera preview, camera object and evaluate button inputs

    Owns the camera object (a ``VimbaCamera`` when Vimba is available and the
    test image is not forced, otherwise a ``TestCamera``), forwards new
    frames to the preview widget and optionally records them to a video file.
    """
    update_image_signal = Signal(np.ndarray, bool)

    def __init__(self, parent=None):
        """Load settings, pick the camera implementation and reset state.

        :param parent: parent widget, passed on to ``QGroupBox``
        """
        super(CameraControl, self).__init__(parent)
        # loading settings
        settings = QSettings()

        # load UI components
        self.ui: Ui_main = self.window().ui
        self._first_show = True  # whether form is shown for the first time

        # video path and writer object to record videos
        self.video_dir = settings.value("camera_control/video_dir", ".", str)
        self.recorder: VidWriter = None

        # initialize camera object
        self.cam: AbstractCamera = None  # AbstractCamera as interface class for different cameras
        # if vimba software is installed
        if HAS_VIMBA and not USE_TEST_IMAGE:
            self.cam = VimbaCamera()
            logging.info("Using Vimba Camera")
        else:
            self.cam = TestCamera()
            if not USE_TEST_IMAGE:
                logging.error('No camera found! Fallback to test cam!')
            else:
                logging.info("Using Test Camera")
        self.update()
        self._oneshot_eval = False
        logging.debug("initialized camera control")

    def __del__(self):
        # self.cam.stop_streaming()
        del self.cam

    def showEvent(self, event):
        """
        custom show event

        On the first show only: connects the signals, requests an initial
        camera image and triggers camera view setup.
        """
        if self._first_show:
            # Clear the flag so later show events do not connect every
            # signal a second time (which would run each slot repeatedly).
            self._first_show = False
            self.connect_signals()
            # on first show take snapshot of camera to display
            self.cam.snapshot()
            # prep preview window
            self.ui.camera_prev.prepare()

    def closeEvent(self, event: QtGui.QCloseEvent):
        """
        custom close event

        stops camera and video recorder
        """
        # close camera stream and recorder object
        if self.recorder is not None:
            self.recorder.close()
            self.recorder = None
        self.cam.stop_streaming()

    def connect_signals(self):
        """ connect all the signals """
        self.cam.new_image_available.connect(self.update_image)
        self.update_image_signal.connect(self.ui.camera_prev.update_image)
        # button signals
        self.ui.startCamBtn.clicked.connect(self.prev_start_pushed)
        self.ui.oneshotEvalBtn.clicked.connect(self.oneshot_eval)
        self.ui.setROIBtn.clicked.connect(self.apply_roi)
        self.ui.resetROIBtn.clicked.connect(self.cam.reset_roi)
        self.ui.syr_mask_chk.stateChanged.connect(self.needle_mask_changed)
        # action menu signals
        self.ui.actionVideo_Path.triggered.connect(self.set_video_path)
        self.ui.actionKalibrate_Size.triggered.connect(self.calib_size)
        self.ui.actionDelete_Size_Calibration.triggered.connect(
            self.remove_size_calib)
        self.ui.actionSave_Image.triggered.connect(self.save_image_dialog)

    def is_streaming(self) -> bool:
        """ 
        Return whether camera object is acquiring frames 
        
        :returns: True if camera is streaming, otherwise False
        """
        return self.cam.is_running

    @Slot(int)
    def needle_mask_changed(self, state):
        """called when needle mask checkbox is clicked  
        enables or disables masking of syringe

        :param state: check state of checkbox
        :type state: int
        """
        if (state == Qt.Unchecked):
            self.ui.camera_prev.hide_mask()
        else:
            self.ui.camera_prev.show_mask()

    @Slot()
    def prev_start_pushed(self, event=None):
        """
        Start Button pushed event

        If camera is not streaming:

        - Start video recorder if corresponding checkbox is ticked

        - Start camera streaming

        - Lock ROI buttons and checkboxes

        If camera is streaming:

        - Stop video recorder if it was active

        - Stop camera streaming

        - Unlock buttons

        :param event: unused; optional so the slot can be invoked with or
            without the signal's argument
        """
        if self.ui.startCamBtn.text() != 'Stop':
            if self.ui.record_chk.isChecked():
                self.start_video_recorder()
            self.cam.start_streaming()
            logging.info("Started camera stream")
            self.ui.startCamBtn.setText('Stop')
            self.ui.record_chk.setEnabled(False)
            self.ui.frameInfoLbl.setText('Running')
            self.ui.setROIBtn.setEnabled(False)
            self.ui.resetROIBtn.setEnabled(False)
        else:
            self.cam.stop_streaming()
            if self.ui.record_chk.isChecked():
                self.stop_video_recorder()
            self.ui.record_chk.setEnabled(True)
            logging.info("Stop camera stream")
            self.ui.startCamBtn.setText('Start')
            self.ui.setROIBtn.setEnabled(True)
            self.ui.resetROIBtn.setEnabled(True)
            self.cam.snapshot()
            self.ui.frameInfoLbl.setText('Stopped')
            self.ui.drpltDataLbl.setText(str(self.ui.camera_prev._droplet))

    def oneshot_eval(self):
        """Request evaluation of a single frame.

        Sets the one-shot flag and, when the camera is not running, takes a
        snapshot so that there is a frame to evaluate.
        """
        self._oneshot_eval = True
        if not self.cam.is_running:
            self.cam.snapshot()

    def start_video_recorder(self):
        """
        start the `VidWriter` video recorder

        used to record the video from the camera
        """
        now = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.recorder = VidWriter(filename=self.video_dir + f"/{now}.mp4",
                                  size=self.cam.get_resolution(),
                                  fps=self.cam.get_framerate(),
                                  codec='mpeg4',
                                  preset='ultrafast',
                                  bitrate='5000k')
        logging.info(
            f"Start video recording. File: {self.video_dir}/{now}.mp4 Resolution:{str(self.cam.get_resolution())}@{self.cam.get_framerate()}"
        )

    def stop_video_recorder(self):
        """
        stops the video recorder

        Does nothing when no recorder is active. The recorder reference is
        dropped after closing so it cannot be closed or written to again.
        """
        if self.recorder is None:
            return
        self.recorder.close()
        self.recorder = None
        logging.info("Stoped video recording")

    @Slot()
    def apply_roi(self):
        """ Apply the ROI selected by the rubberband rectangle """
        x, y, w, h = self.ui.camera_prev.get_roi()
        logging.info(f"Applying ROI pos:({x}, {y}), size:({w}, {h})")
        self.cam.set_roi(x, y, w, h)

    @Slot(np.ndarray)
    def update_image(self, cv_img: np.ndarray):
        """ 
        Slot that gets called when a new image is available from the camera 

        handles the signal :attr:`camera.AbstractCamera.new_image_available`
        
        :param cv_img: the image array from the camera
        :type cv_img: np.ndarray, numpy 2D array

        .. seealso:: :py:meth:`camera_preview.CameraPreview.update_image`, :attr:`camera.AbstractCamera.new_image_available`
        """
        # block image signal to prevent overloading
        blocker = QSignalBlocker(self.cam)
        try:
            if self.cam.is_running:
                # evaluate droplet if checkbox checked
                do_eval = self.ui.evalChk.isChecked() or self._oneshot_eval
                self._oneshot_eval = False
                # display current fps
                self.ui.frameInfoLbl.setText('Running | FPS: ' +
                                             str(self.cam.get_framerate()))
                # save image frame if recording (and a recorder exists)
                if self.ui.record_chk.isChecked() and self.recorder is not None:
                    self.recorder.write_frame(
                        cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
            elif self._oneshot_eval:
                # enable evaluate for one frame (eg snapshots)
                do_eval = True
                self._oneshot_eval = False
            else:
                do_eval = False

            # if ROI size changed, cause update of internal variables for new image dimensions
            if self.cam._image_size_invalid:
                self.ui.camera_prev.invalidate_imagesize()
                self.cam._image_size_invalid = False

            # update preview image
            self.ui.camera_prev.update_image(cv_img, do_eval)

            # display droplet parameters
            self.ui.drpltDataLbl.setText(str(self.ui.camera_prev._droplet))
        finally:
            # unblock signals from cam, even if processing the frame failed
            blocker.unblock()

    @Slot()
    def set_video_path(self):
        """ update the save path for videos """
        settings = QSettings()
        res = QFileDialog.getExistingDirectory(
            self, "Select default video directory", ".")
        # an empty result means the dialog was cancelled
        if res:
            self.video_dir = res
            settings.setValue("camera_control/video_dir", res)
            logging.info(f"set default videodirectory to {res}")

    @Slot()
    def calib_size(self):
        """ 
        map pixels to mm 
        
        - place object with known size at default distance from camera
        - call this fcn
        - enter size of object
        - fcn calculates scale factor and saves it to droplet object

        The calibration is skipped (with an error log) when the droplet has
        no usable height after the snapshot.
        """
        # do oneshot eval and extract height from droplet, then calc scale and set in droplet
        res, ok = QInputDialog.getDouble(
            self, "Size of calib element",
            "Please enter the height of the test subject in mm:", 0, 0, 100)
        if not ok or res == 0.0:
            return
        self._oneshot_eval = True
        droplt = Droplet()  # singleton
        self.cam.snapshot()
        height = droplt._height
        if not height:
            # avoid dividing by a missing or zero height
            logging.error("size calibration failed: no droplet height available")
            return
        scale = res / height
        droplt.set_scale(scale)
        logging.info(f"set image to real scae to {scale}")

    @Slot()
    def remove_size_calib(self):
        """ Remove the size calibration by resetting the droplet scale to None """
        drplt = Droplet()
        drplt.set_scale(None)

    @Slot()
    def save_image_dialog(self):
        """
        Ask whether to save the raw or the displayed image, then save it

        Shows a confirmation box with a "Save raw image" checkbox followed by
        a file chooser. Returns without saving when either dialog is
        cancelled or when no image could be grabbed from the preview.
        """
        checkBox = QCheckBox("Save raw image")
        checkBox.setChecked(True)
        msgBox = QMessageBox(self)
        msgBox.setWindowTitle("Save Image")
        msgBox.setText(
            "Save the image displayed in the preview or the raw image from the camera"
        )
        msgBox.setIcon(QMessageBox.Question)
        msgBox.addButton(QMessageBox.Ok)
        msgBox.addButton(QMessageBox.Cancel)
        msgBox.setDefaultButton(QMessageBox.Cancel)
        msgBox.setCheckBox(checkBox)

        val = msgBox.exec()
        if val == QMessageBox.Cancel:
            return

        raw_image = msgBox.checkBox().isChecked()
        now = datetime.now().strftime("%y-%m-%d_%H-%M")
        save_path, _selected_filter = QFileDialog.getSaveFileName(
            self, "Choose save file",
            f"{userpaths.get_my_pictures()}/screenshot_{now}.png",
            "Images (*.png *.jpg *.bmp)")
        # a cancelled file dialog yields an empty path rather than None
        if not save_path:
            return

        qimg = self.ui.camera_prev.grab_image(raw=raw_image)
        if qimg is None:
            return

        qimg.save(save_path, quality=100)
class ExtensibleVideoMaker(Qt.QMainWindow):
    """Main window that renders a graphics scene frame by frame into a video.

    Subclasses implement ``populate_scene`` (called once from ``__init__``)
    and ``update_scene_for_current_frame`` (called once per frame). Setting
    ``started`` to True begins rendering; each rendered frame is written to
    ``video_out_fpath`` through an ``FFMPEG_VideoWriter`` when
    ``write_output`` is True.
    """
    frame_index_changed = Qt.pyqtSignal(int)
    started_changed = Qt.pyqtSignal(bool)
    paused_changed = Qt.pyqtSignal(bool)
    write_output_changed = Qt.pyqtSignal(bool)
    operation_completed = Qt.pyqtSignal(bool) # Parameter is True if successful, False otherwise
    _advance_frame = Qt.pyqtSignal()

    def __init__(self,
                 video_out_fpath,
                 scale_factor=0.5,
                 video_fps=10,
                 parent=None,
                 ControlPaneClass=ControlPane,
                 GSClass=Qt.QGraphicsScene,
                 GVClass=GV):
        """Build the scene, the view and the docked control pane.

        video_out_fpath  -- path of the video file to write.
        scale_factor     -- factor applied to the scene size to get the
                            video frame size.
        video_fps        -- frame rate passed to the video writer.
        parent           -- parent widget.
        ControlPaneClass -- class instantiated (with this window) as the
                            docked control pane.
        GSClass, GVClass -- graphics scene and graphics view classes.
        """
        super().__init__(parent)
        self.ffmpeg_writer = None
        self.video_out_fpath = Path(video_out_fpath)
        self.scale_factor = scale_factor
        self.video_fps = video_fps
        self._started = False
        self._paused = False
        self._write_output = True
        self._displayed_frame_idx = -1
        self.setWindowTitle('Video Maker')
        self.gs = GSClass(self)
        self.gs.setBackgroundBrush(Qt.QBrush(Qt.QColor(Qt.Qt.black)))
        self.populate_scene()
        self.gv = GVClass(self.gs)
        self.setCentralWidget(self.gv)
        self.control_pane_dock_widget = Qt.QDockWidget('Control Pane', self)
        self.control_pane_widget = ControlPaneClass(self)
        self.control_pane_dock_widget.setWidget(self.control_pane_widget)
        self.control_pane_dock_widget.setAllowedAreas(Qt.Qt.LeftDockWidgetArea | Qt.Qt.RightDockWidgetArea)
        self.addDockWidget(Qt.Qt.RightDockWidgetArea, self.control_pane_dock_widget)
        # QueuedConnection so that a stop or pause click can sneak in between frames and so that our
        # call stack doesn't grow with the number of sequentially rendered frames
        self._advance_frame.connect(self.on_advance_frame, Qt.Qt.QueuedConnection)

    def __del__(self):
        if self.ffmpeg_writer is not None:
            self.ffmpeg_writer.close()

    @staticmethod
    def normalize_intensity(im):
        """Scale ``im`` so its maximum becomes 255 and return it as 3-channel uint8.

        The single-channel result is stacked three times along a new last
        axis. An image whose maximum is 0 is left unscaled, which avoids a
        division by zero.
        """
        imn = im.astype(numpy.float32)
        peak = im.max()
        if peak != 0:
            imn *= 255/peak
        imn = imn.astype(numpy.uint8)
        return numpy.dstack((imn, imn, imn))

    @classmethod
    def read_image_file_into_qpixmap_item(cls, im_fpath, qpixmap_item, required_size=None):
        """Read an image file, normalize it and show it in ``qpixmap_item``.

        Raises ValueError when ``required_size`` is given and the image has
        a different size.
        """
        im = cls.normalize_intensity(freeimage.read(str(im_fpath)))
        qim = Qt.QImage(sip.voidptr(im.ctypes.data), im.shape[0], im.shape[1], Qt.QImage.Format_RGB888)
        if required_size is not None and qim.size() != required_size:
            raise ValueError('Expected {}x{} image, but "{}" is {}x{}.'.format(required_size.width(), required_size.height(), str(im_fpath), qim.size().width(), qim.size().height()))
        qpixmap_item.setPixmap(Qt.QPixmap(qim))

    def populate_scene(self):
        """Add the items to ``self.gs``; must be implemented by subclasses."""
        raise NotImplementedError()

    def on_start(self):
        """Prepare the render buffer and the video writer, then reset the frame index.

        When ``write_output`` is True, any existing output file is removed,
        the frame size is derived from the scene rectangle and
        ``scale_factor`` (rounded up to even values), and a render buffer
        plus an ``FFMPEG_VideoWriter`` are created.
        """
        if self._write_output:
            if self.video_out_fpath.exists():
                self.video_out_fpath.unlink()
            scene_rect = self.gs.sceneRect()
            scene_size = numpy.array((scene_rect.width(), scene_rect.height()), dtype=int)
            # Scale out of place and cast back explicitly: an in-place
            # multiply of an integer array by a float is rejected by NumPy's
            # default 'same_kind' casting rule.
            desired_size = (scene_size * self.scale_factor).astype(int)
            # Non-divisble-by-two value width or height causes problems for some codecs
            desired_size += desired_size % 2
            width, height = (int(v) for v in desired_size)
            # 24-bit RGB QImage rows are padded to 32-bit chunks, which we must match
            row_stride = width * 3
            remainder = row_stride % 4
            if remainder > 0:
                row_stride += 4 - remainder
            # Making the buffer wide enough to accommodate any padding and feeding a view of the
            # buffer that excludes the padded regions works everywhere.
            self._buffer = numpy.empty(row_stride * height, dtype=numpy.uint8)
            bdr = self._buffer.reshape((height, row_stride))
            bdr = bdr[:, :width*3]
            self._buffer_data_region = bdr.reshape((height, width, 3))
            self._qbuffer = Qt.QImage(sip.voidptr(self._buffer.ctypes.data), width, height, Qt.QImage.Format_RGB888)

            self.ffmpeg_writer = FFMPEG_VideoWriter(str(self.video_out_fpath), (width, height), fps=self.video_fps, codec='mpeg4', preset='veryslow', bitrate='15000k')
        self._displayed_frame_idx = -1
        self.frame_index_changed.emit(self._displayed_frame_idx)

    def on_stop(self):
        """Close and drop the video writer, if one is open."""
        if self._write_output and self.ffmpeg_writer is not None:
            self.ffmpeg_writer.close()
            self.ffmpeg_writer = None

    def on_advance_frame(self):
        """Render one frame and queue the next.

        Does nothing unless started and not paused. When the subclass
        reports that there are no more frames, rendering stops and
        ``operation_completed(True)`` is emitted. On any exception rendering
        stops, ``operation_completed(False)`` is emitted and the exception
        is re-raised.
        """
        if not self._started or self._paused:
            return
        try:
            if not self.update_scene_for_current_frame():
                self.started = False
                self.operation_completed.emit(True)
                return
            self.gs.invalidate()
            if self._write_output:
                self._buffer[:] = 0
                qpainter = Qt.QPainter()
                qpainter.begin(self._qbuffer)
                self.gs.render(qpainter)
                qpainter.end()
                self.ffmpeg_writer.write_frame(self._buffer_data_region)
            self._displayed_frame_idx += 1
            self.frame_index_changed.emit(self._displayed_frame_idx)
            self._advance_frame.emit()
        except Exception:
            self.started = False
            self.operation_completed.emit(False)
            raise

    def update_scene_for_current_frame(self):
        """Update the scene for the next frame; must be implemented by subclasses.

        Return True if scene was updated and is prepared to be rendered.
        Return False if the last frame has already been rendered, the video file should be
        finalized and closed, and video creation is considered to have completed successfully.
        Throw an exception if video creation should not or can not continue and should be
        considered failed.
        """
        raise NotImplementedError()

    @property
    def started(self):
        """Whether rendering has been started."""
        return self._started

    @started.setter
    def started(self, started):
        if started != self._started:
            if started:
                self.on_start()
                self._started = True
                self.started_changed.emit(started)
                self._advance_frame.emit()
            else:
                self._started = False
                self.started_changed.emit(started)
                self.on_stop()

    @property
    def paused(self):
        """Whether rendering is paused."""
        return self._paused

    @paused.setter
    def paused(self, paused):
        if paused != self._paused:
            self._paused = paused
            self.paused_changed.emit(self._paused)
            # resuming while started: kick the frame loop again
            if not paused and self._started:
                self._advance_frame.emit()

    @property
    def write_output(self):
        """Whether rendered frames are written to the video file."""
        return self._write_output

    @write_output.setter
    def write_output(self, write_output):
        if write_output != self._write_output:
            if self._started:
                raise ValueError('write_output property can not be modified while started is True.')
            self._write_output = write_output
            self.write_output_changed.emit(write_output)

    @property
    def frame_index(self):
        """Index of last frame rendered since started."""
        return self._displayed_frame_idx
def optimize_input(obj,
                   model,
                   param_f,
                   transforms,
                   lr=0.05,
                   step_n=512,
                   num_output_channels=4,
                   do_render=False,
                   out_name="out"):
    """Optimise an input for ``obj`` on ``model`` and return the result.

    Runs ``step_n`` optimisation steps with an Adam optimiser at learning
    rate ``lr``. A KeyboardInterrupt stops the loop early without aborting
    the function.

    When ``do_render`` is True, every step is written as a frame to
    ``<out_name>.mp4``, progress is printed and shown every 50 steps, and
    afterwards the trainable variables are saved to ``<out_name>.npy`` and a
    final image to ``<out_name>.jpg``.

    Returns the evaluated ``T("input")`` tensor. The session is closed on
    every exit path, including when an exception propagates.
    """
    sess = create_session()
    try:
        # Set up optimization problem
        size = 84
        t_size = tf.placeholder_with_default(size, [])
        T = render.make_vis_T(
            model,
            obj,
            param_f=param_f,
            transforms=transforms,
            optimizer=tf.train.AdamOptimizer(lr),
        )

        tf.global_variables_initializer().run()

        if do_render:
            video_fn = out_name + '.mp4'
            # NOTE(review): the frame layout hard-codes 4 vertically stacked
            # channels regardless of num_output_channels - confirm.
            writer = FFMPEG_VideoWriter(video_fn, (size, size * 4), 60.0)

        # Optimization loop
        try:
            for i in range(step_n):
                _, loss, img = sess.run([T("vis_op"), T("loss"), T("input")])

                if do_render:
                    # if outputting only one channel...
                    if num_output_channels == 1:
                        img = img[..., -1:]
                        img = np.tile(img, 3)
                    else:
                        # move channels in front of the spatial axes, stack
                        # them vertically, then repeat to 3 colour channels
                        img = img.transpose([0, 3, 1, 2])
                        img = img.reshape([size * 4, size, 1])
                        img = np.tile(img, 3)
                    writer.write_frame(_normalize_array(img))
                    if i > 0 and i % 50 == 0:
                        clear_output()
                        print("%d / %d  score: %f" % (i, step_n, loss))
                        show(img)

        except KeyboardInterrupt:
            pass
        finally:
            if do_render:
                print("closing...")
                writer.close()

        # Save trained variables
        if do_render:
            train_vars = sess.graph.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES)
            params = np.array(sess.run(train_vars), object)
            save(params, out_name + '.npy')

            # Save final image
            final_img = T("input").eval({t_size: 600})[..., -1:]  #change size
            save(final_img, out_name + '.jpg', quality=90)

        return T("input").eval({t_size: size})
    finally:
        # release the session even when setup, optimisation or saving fails
        sess.close()