Beispiel #1
0
def get_config(model, fake=False):
    """Build the ``TrainConfig`` used to train ``model``.

    The global batch size ``args.batch`` is divided evenly over the towers
    reported by ``get_num_gpu()`` (at least one tower is always used).

    Args:
        model: model object, forwarded unchanged to ``TrainConfig``.
        fake: if true, feed ``FakeData`` with no callbacks and use a
            100-step epoch; otherwise feed ``get_data('train', ...)`` and
            register model saving, learning-rate scheduling and validation
            callbacks.

    Returns:
        The assembled ``TrainConfig``.

    Raises:
        AssertionError: if ``args.batch`` is not divisible by the number of
            towers.
    """
    nr_tower = max(get_num_gpu(), 1)
    assert args.batch % nr_tower == 0
    batch = args.batch // nr_tower

    logger.info("Running on {} towers. Batch size per tower: {}".format(
        nr_tower, batch))
    if batch < 32 or batch > 64:
        logger.warn(
            "Batch size per tower not in [32, 64]. This probably will lead to worse accuracy than reported."
        )

    if fake:
        data = QueueInput(
            FakeData([[batch, 224, 224, 3], [batch]],
                     1000,
                     random=False,
                     dtype='uint8'))
        callbacks = []
    else:
        data = QueueInput(get_data('train', batch))

        # The base learning rate grows linearly with the global batch size,
        # relative to a reference batch of 256.
        START_LR = 0.1
        BASE_LR = START_LR * (args.batch / 256.0)
        callbacks = [
            ModelSaver(),
            EstimatedTimeLeft(),
            # Step-wise decay; the first element of each pair is the
            # schedule point (presumably an epoch index -- confirm against
            # the ScheduledHyperParamSetter documentation).
            ScheduledHyperParamSetter('learning_rate',
                                      [(0, min(START_LR, BASE_LR)),
                                       (30, BASE_LR * 1e-1),
                                       (60, BASE_LR * 1e-2),
                                       (90, BASE_LR * 1e-3),
                                       (100, BASE_LR * 1e-4)]),
        ]
        if BASE_LR > START_LR:
            # Large batches: ramp linearly from START_LR up to BASE_LR
            # between schedule points 0 and 5.
            callbacks.append(
                ScheduledHyperParamSetter('learning_rate', [(0, START_LR),
                                                            (5, BASE_LR)],
                                          interp='linear'))

        infs = [
            ClassificationError('wrong-top1', 'val-error-top1'),
            ClassificationError('wrong-top5', 'val-error-top5')
        ]
        dataset_val = get_data('val', batch)
        if nr_tower == 1:
            # single-GPU inference with queue prefetch
            callbacks.append(InferenceRunner(QueueInput(dataset_val), infs))
        else:
            # multi-GPU inference (with mandatory queue prefetch)
            callbacks.append(
                DataParallelInferenceRunner(dataset_val, infs,
                                            list(range(nr_tower))))

    return TrainConfig(
        model=model,
        data=data,
        callbacks=callbacks,
        # Follow the ``fake`` argument (not the global ``args.fake``) so the
        # epoch length always matches the data source selected above.
        # 1281167 is presumably the size of the training set -- confirm.
        steps_per_epoch=100 if fake else 1281167 // args.batch,
        max_epoch=105,
    )
def load_from_cache(name, ctime=0):
    """Load the ``'dataset'`` entry of the pickle cached under ``name``.

    The cache file is ``<cfg.DATA.CACHEDIR>/<name>.pkl``.

    Args:
        name: cache key; ``'.pkl'`` is appended to build the file name.
        ctime: timestamp compared with the cache file's modification time.
            If it is newer than the cache file a warning is logged, but the
            cache is still loaded.

    Returns:
        The object stored under the ``'dataset'`` key of the unpickled
        mapping.

    Raises:
        IOError: if the cache file cannot be opened or unpickled, or has no
            ``'dataset'`` entry.
        OSError: propagated from ``os.path.getmtime`` when the cache file
            cannot be stat'ed (e.g. it does not exist).
    """
    fn_cache = os.path.join(cfg.DATA.CACHEDIR, name + '.pkl')
    if ctime > os.path.getmtime(fn_cache):
        # Deliberately only a warning: a stale cache is still used.
        logger.warn('cache file is older than the dataset file')
    try:
        with open(fn_cache, 'rb') as fh:
            # NOTE: unpickling can execute arbitrary code; the cache
            # directory must only contain trusted files.
            dataset = pickle.load(fh)['dataset']
    except Exception as err:
        # Catch ``Exception`` rather than everything so KeyboardInterrupt and
        # SystemExit still propagate; callers keep seeing IOError, now with
        # the failing path and the underlying cause in the message.
        raise IOError('failed to load cache file {}: {!r}'.format(
            fn_cache, err))
    return dataset
def get_config(model, fake=False):
    """Assemble the training configuration for ``model``.

    ``args.batch`` is the global batch size; it is split evenly between the
    towers returned by ``get_num_gpu()`` (never fewer than one).

    Args:
        model: model object handed to ``TrainConfig`` as-is.
        fake: when true, train on ``FakeData`` without callbacks and with a
            100-step epoch. When false, train on ``get_data('train', ...)``
            with checkpointing, a learning-rate schedule and validation
            inference.

    Returns:
        A ``TrainConfig`` instance.

    Raises:
        AssertionError: if the global batch size is not a multiple of the
            tower count.
    """
    nr_tower = max(get_num_gpu(), 1)
    assert args.batch % nr_tower == 0
    batch = args.batch // nr_tower

    logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))
    if batch < 32 or batch > 64:
        logger.warn("Batch size per tower not in [32, 64]. This probably will lead to worse accuracy than reported.")

    if fake:
        data = QueueInput(FakeData(
            [[batch, 224, 224, 3], [batch]], 1000, random=False, dtype='uint8'))
        callbacks = []
    else:
        data = QueueInput(get_data('train', batch))

        # Learning rate is scaled linearly with the global batch size
        # (reference batch size: 256).
        START_LR = 0.1
        BASE_LR = START_LR * (args.batch / 256.0)
        callbacks = [
            ModelSaver(),
            EstimatedTimeLeft(),
            # Step decay by a factor of 10 at each listed schedule point.
            ScheduledHyperParamSetter(
                'learning_rate', [
                    (0, min(START_LR, BASE_LR)), (30, BASE_LR * 1e-1), (60, BASE_LR * 1e-2),
                    (90, BASE_LR * 1e-3), (100, BASE_LR * 1e-4)]),
        ]
        if BASE_LR > START_LR:
            # Linear ramp from START_LR to BASE_LR over schedule points 0..5;
            # only needed when the scaled rate exceeds the starting rate.
            callbacks.append(
                ScheduledHyperParamSetter(
                    'learning_rate', [(0, START_LR), (5, BASE_LR)], interp='linear'))

        infs = [ClassificationError('wrong-top1', 'val-error-top1'),
                ClassificationError('wrong-top5', 'val-error-top5')]
        dataset_val = get_data('val', batch)
        if nr_tower == 1:
            # single-GPU inference with queue prefetch
            callbacks.append(InferenceRunner(QueueInput(dataset_val), infs))
        else:
            # multi-GPU inference (with mandatory queue prefetch)
            callbacks.append(DataParallelInferenceRunner(
                dataset_val, infs, list(range(nr_tower))))

    # The epoch length is keyed on the ``fake`` argument rather than the
    # global ``args.fake`` so that it cannot disagree with the data source
    # chosen above.
    steps_per_epoch = 100 if fake else 1281167 // args.batch

    return TrainConfig(
        model=model,
        data=data,
        callbacks=callbacks,
        steps_per_epoch=steps_per_epoch,
        max_epoch=105,
    )
    def display(self, return_rgb_array=False):
        """Render the current plane beside the ground-truth plane.

        Both planes are concatenated horizontally, drawn in the viewer and
        annotated with the distance error, the spacing and the reward. If
        ``self.saveGif`` is set, the rendered frame is appended to
        ``self.gif_buffer`` and the GIF is written while the episode is not
        terminal. If ``self.saveVideo`` is set, the frame is saved as a PNG
        in a temporary directory, and once ``self.terminal`` is true the
        frames are encoded to ``<self.filename>.mp4`` with ``ffmpeg`` and the
        directory is removed.

        Args:
            return_rgb_array: unused by this implementation.

        Raises:
            OSError: if the temporary video directory already exists and the
                user does not choose to delete it.
            subprocess.CalledProcessError: if ``ffmpeg`` exits with a
                non-zero status.
        """
        image_size = self._image_dims
        current_plane = Plane(*getPlane(self.sitk_image,
                                        self._origin3d_point,
                                        self._plane.params,
                                        image_size,
                                        spacing=[1, 1, 1]))

        # Take the middle slice along the last axis of both grids.
        mid_slice = int(image_size[2] / 2)
        plane = current_plane.grid[:, :, mid_slice]
        gt_plane = self.groundTruth_plane_iso.grid[:, :, mid_slice]
        # concatenate two planes side by side
        plane = np.concatenate((plane, gt_plane), axis=1)
        img = cv2.cvtColor(plane, cv2.COLOR_GRAY2RGB)  # convert to rgb

        # The image itself is drawn unscaled; these factors only offset the
        # text overlays below.
        scale_x = 5
        scale_y = 5

        # Create the viewer lazily, the first time visualisation is needed.
        if (not self.viewer) and self.viz:
            from viewer import SimpleImageViewer
            self.viewer = SimpleImageViewer(arr=img,
                                            scale_x=1,
                                            scale_y=1,
                                            filepath=self.filename)
            self.gif_buffer = []

        # display image and the two captions
        self.viewer.draw_image(img)
        caption_color = (0, 0, 204, 255)
        self.viewer.display_text('Current Plane',
                                 color=caption_color,
                                 x=int(0.7 * img.shape[1] / 7),
                                 y=img.shape[0] - 3)
        self.viewer.display_text('Ground Truth',
                                 color=caption_color,
                                 x=int(4.3 * img.shape[1] / 7),
                                 y=img.shape[0] - 3)

        # Distance error: green when the current distance is below the
        # second-to-last history entry, red otherwise.
        dist_color_flag = False
        if len(self._dist_history) > 1:
            dist_color_flag = self.cur_dist < self._dist_history[-2]
        color_dist = (0, 204, 0, 255) if dist_color_flag else (204, 0, 0, 255)
        text = 'Error ' + str(round(self.cur_dist, 3)) + 'mm'
        self.viewer.display_text(text,
                                 color=color_dist,
                                 x=int(3 * img.shape[1] / 8),
                                 y=5 * scale_y)

        text = 'Spacing ' + str(round(self.spacing[0], 3)) + 'mm'
        self.viewer.display_text(text,
                                 color=(204, 204, 0, 255),
                                 x=int(6 * img.shape[1] / 8),
                                 y=5 * scale_y)

        color_reward = (0, 204, 0, 255) if self.reward > 0 else (204, 0, 0,
                                                                 255)
        text = 'Reward ' + "%+d" % round(self.reward, 3)
        self.viewer.display_text(text,
                                 color=color_reward,
                                 x=2 * scale_x,
                                 y=5 * scale_y)

        # render and wait (viz) time between frames
        self.viewer.render()

        # save gif
        if self.saveGif:
            image_data = pyglet.image.get_buffer_manager().get_color_buffer(
            ).get_image_data()
            data = image_data.get_data('RGB', image_data.width * 3)
            arr = np.array(bytearray(data)).astype('uint8')
            # Reshape the flat buffer to rows x columns x channels and flip
            # it along the first axis.
            arr = np.flip(
                np.reshape(arr, (image_data.height, image_data.width, -1)), 0)
            im = Image.fromarray(arr).convert('P')
            self.gif_buffer.append(im)

            if not self.terminal:
                # Strip extensions from the file name only, so dots in the
                # directory part (e.g. "./data/img.nii.gz") no longer
                # truncate the output path.
                out_dir, base_name = os.path.split(self.filename)
                gifname = os.path.join(out_dir,
                                       base_name.split('.')[0]) + '.gif'
                self.viewer.savegif(gifname,
                                    arr=self.gif_buffer,
                                    duration=self.viz)

        if self.saveVideo:
            dirname = 'tmp_video'
            if self.cnt <= 1:
                if os.path.isdir(dirname):
                    logger.warn(
                        """Log directory {} exists! Use 'd' to delete it. """.
                        format(dirname))
                    act = input("select action: d (delete) / q (quit): "
                                ).lower().strip()
                    if act == 'd':
                        shutil.rmtree(dirname, ignore_errors=True)
                    else:
                        raise OSError("Directory {} exists!".format(dirname))
                os.mkdir(dirname)

            frame = dirname + '/' + '%04d' % self.cnt + '.png'
            pyglet.image.get_buffer_manager().get_color_buffer().save(frame)
            if self.terminal:
                # Encode the numbered PNG frames, then drop the temp folder.
                save_cmd = [
                    'ffmpeg', '-f', 'image2', '-framerate', '30',
                    '-pattern_type', 'sequence', '-start_number', '0', '-r',
                    '3', '-i', dirname + '/%04d.png', '-s', '1280x720',
                    '-vcodec', 'libx264', '-b:v', '2567k',
                    self.filename + '.mp4'
                ]
                subprocess.check_output(save_cmd)
                shutil.rmtree(dirname, ignore_errors=True)
Beispiel #5
0
from tensorpack import TowerContext, logger, PlaceholderInput
from tensorpack.tfutils import varmanip, get_model_loader

# Command line: two optional graph sources (--config / --meta) followed by
# two positionals, the model to load and the output path.
parser = argparse.ArgumentParser()
parser.add_argument('--config', help='config file')
parser.add_argument('--meta', help='metagraph file')
parser.add_argument(dest='model')
parser.add_argument(dest='output')
args = parser.parse_args()

# At least one way of rebuilding the graph is required.
assert args.config or args.meta, "Either config or metagraph must be present!"

with tf.Graph().as_default() as G:
    if args.config:
        # Rebuild the graph by importing the config script as module
        # 'config_script' and instantiating its ``Model`` attribute.
        logger.warn(
            "Using a config script is not reliable. Please use metagraph.")
        MODEL = imp.load_source('config_script', args.config).Model
        M = MODEL()
        with TowerContext('', is_training=False):
            # NOTE: this name shadows the builtin ``input`` at module level.
            input = PlaceholderInput()
            input.setup(M.get_inputs_desc())
            M.build_graph(input)
    else:
        # Otherwise restore the graph structure from the metagraph file.
        tf.train.import_meta_graph(args.meta)

    # loading...
    # Initialise all global and local variables first, then let the loader
    # obtained for ``args.model`` initialise the session.
    init = get_model_loader(args.model)
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    init.init(sess)
Beispiel #6
0
    def display(self, return_rgb_array=False):
        """Render the agent's current z-plane with location overlays.

        Draws the plane at the agent's z coordinate, the agent position, the
        rectangle ``self.rectangle`` and, unless ``self.task`` is ``'play'``,
        the target position and the current distance error. If
        ``self.saveGif`` is set, the rendered frame is appended to
        ``self.gif_buffer`` and the GIF is written while the episode is not
        terminal. If ``self.saveVideo`` is set, the frame is saved as a PNG
        in a temporary directory, and once ``self.terminal`` is true the
        frames are encoded to ``<self.filename>.mp4`` with ``ffmpeg`` and the
        directory is removed.

        Args:
            return_rgb_array: unused by this implementation.

        Raises:
            OSError: if the temporary video directory already exists and the
                user does not choose to delete it.
            subprocess.CalledProcessError: if ``ffmpeg`` exits with a
                non-zero status.
        """
        current_point = self._location
        target_point = self._target_loc

        # get image and convert it to pyglet
        plane = self.get_plane(current_point[2])  # z-plane

        # rescale image
        # INTER_NEAREST, INTER_LINEAR, INTER_AREA, INTER_CUBIC, INTER_LANCZOS4
        scale_x = 1
        scale_y = 1
        new_size = (int(scale_x * plane.shape[1]),
                    int(scale_y * plane.shape[0]))
        img = cv2.resize(plane, new_size, interpolation=cv2.INTER_LINEAR)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # convert to rgb

        # Create the viewer lazily, the first time visualisation is needed.
        if (not self.viewer) and self.viz:
            from viewer import SimpleImageViewer
            self.viewer = SimpleImageViewer(arr=img,
                                            scale_x=1,
                                            scale_y=1,
                                            filepath=self.filename)
            self.gif_buffer = []

        # display image
        self.viewer.draw_image(img)

        # draw current point
        self.viewer.draw_circle(radius=scale_x * 1,
                                pos_x=scale_x * current_point[0],
                                pos_y=scale_y * current_point[1],
                                color=(0.0, 0.0, 1.0, 1.0))

        # draw a box around the agent - what the network sees ROI
        self.viewer.draw_rect(scale_x * self.rectangle.xmin,
                              scale_y * self.rectangle.ymin,
                              scale_x * self.rectangle.xmax,
                              scale_y * self.rectangle.ymax)

        info_color = (204, 204, 0, 255)
        self.viewer.display_text('Agent ', color=info_color,
                                 x=self.rectangle.xmin - 15,
                                 y=self.rectangle.ymin)
        # display info
        text = 'Spacing ' + str(self.xscale)
        self.viewer.display_text(text, color=info_color,
                                 x=10, y=self._image_dims[1] - 80)

        # ---------------------------------------------------------------------
        if self.task != 'play':
            # draw a transparent circle around target point with variable
            # radius based on the difference z-direction
            diff_z = scale_x * abs(current_point[2] - target_point[2])
            self.viewer.draw_circle(radius=diff_z,
                                    pos_x=scale_x * target_point[0],
                                    pos_y=scale_y * target_point[1],
                                    color=(1.0, 0.0, 0.0, 0.2))
            # draw target point
            self.viewer.draw_circle(radius=scale_x * 1,
                                    pos_x=scale_x * target_point[0],
                                    pos_y=scale_y * target_point[1],
                                    color=(1.0, 0.0, 0.0, 1.0))
            # display info: green for a positive reward, red otherwise
            color = (0, 204, 0, 255) if self.reward > 0 else (204, 0, 0, 255)
            text = 'Error ' + str(round(self.cur_dist, 3)) + 'mm'
            self.viewer.display_text(text, color=color, x=10, y=20)
        # ---------------------------------------------------------------------

        # render and wait (viz) time between frames
        self.viewer.render()

        # save gif
        if self.saveGif:
            image_data = pyglet.image.get_buffer_manager().get_color_buffer(
            ).get_image_data()
            data = image_data.get_data('RGB', image_data.width * 3)

            arr = np.array(bytearray(data)).astype('uint8')
            # Reshape the flat buffer to rows x columns x channels and flip
            # it along the first axis.
            arr = np.flip(
                np.reshape(arr, (image_data.height, image_data.width, -1)), 0)

            im = Image.fromarray(arr)
            self.gif_buffer.append(im)

            if not self.terminal:
                # Strip extensions from the file name only, so dots in the
                # directory part (e.g. "./data/img.nii.gz") no longer
                # truncate the output path.
                out_dir, base_name = os.path.split(self.filename)
                gifname = os.path.join(out_dir,
                                       base_name.split('.')[0]) + '.gif'
                self.viewer.saveGif(gifname, arr=self.gif_buffer,
                                    duration=self.viz)

        if self.saveVideo:
            dirname = 'tmp_video'
            if self.cnt <= 1:
                if os.path.isdir(dirname):
                    logger.warn("""Log directory {} exists! Use 'd' to delete it. """.format(dirname))
                    act = input("select action: d (delete) / q (quit): ").lower().strip()
                    if act == 'd':
                        shutil.rmtree(dirname, ignore_errors=True)
                    else:
                        raise OSError("Directory {} exists!".format(dirname))
                os.mkdir(dirname)

            frame = dirname + '/' + '%04d' % self.cnt + '.png'
            pyglet.image.get_buffer_manager().get_color_buffer().save(frame)
            if self.terminal:
                # Encode the numbered PNG frames at three times the viewer's
                # image size, then drop the temp folder.
                resolution = str(3 * self.viewer.img_width) + 'x' + str(
                    3 * self.viewer.img_height)
                save_cmd = ['ffmpeg', '-f', 'image2', '-framerate', '30',
                            '-pattern_type', 'sequence', '-start_number', '0', '-r',
                            '6', '-i', dirname + '/%04d.png', '-s', resolution,
                            '-vcodec', 'libx264', '-b:v', '2567k', self.filename + '.mp4']
                subprocess.check_output(save_cmd)
                shutil.rmtree(dirname, ignore_errors=True)
Beispiel #7
0
    def display(self, return_rgb_array=False):
        """Render the z-planes of all agents side by side with overlays.

        Each agent's plane is transposed and flipped vertically, and the
        planes are stacked horizontally into one image. For every agent the
        current location, its rectangle ``self.rectangle[i]`` and, unless
        ``self.task`` is ``'play'``, the target location and distance error
        are drawn. If ``self.saveGif`` is set, the rendered frame is
        appended to ``self.gif_buffer`` and the GIF is written once all
        agents are terminal. If ``self.saveVideo`` is set, the frame is saved
        as a PNG in a temporary directory, and once all agents are terminal
        the frames are encoded to an mp4 with ``ffmpeg`` and the directory is
        removed.

        Args:
            return_rgb_array: unused by this implementation.

        Raises:
            OSError: if the temporary video directory already exists and the
                user does not choose to delete it.
            subprocess.CalledProcessError: if ``ffmpeg`` exits with a
                non-zero status.
        """
        # Initializations
        planes = np.flipud(
            np.transpose(self.get_plane(self._location[0][2], agent=0)))
        shape = np.shape(planes)

        target_points = []
        current_points = []

        for i in range(self.agents):
            # get landmarks
            current_points.append(self._location[i])
            if self.task != 'play':
                target_points.append(self._target_loc[i])
            else:
                target_points.append(None)
            # get current plane
            current_plane = np.flipud(
                np.transpose(self.get_plane(current_points[i][2], agent=i)))

            if i > 0:
                # get image in z-axis
                planes = np.hstack((planes, current_plane))

        # Index of the last agent; it is used below for output naming.
        last_agent = self.agents - 1

        # Horizontal offset of each agent's tile inside the stacked image
        # (based on the width of the last plane); no vertical offset.
        tile_width = np.shape(current_plane)[1]
        shifts_x = [tile_width * i for i in range(self.agents)]
        shifts_y = [0] * self.agents

        # rescale image
        # INTER_NEAREST, INTER_LINEAR, INTER_AREA, INTER_CUBIC, INTER_LANCZOS4
        scale_y = 1
        scale_x = 1
        img = cv2.resize(
            planes,
            (int(scale_y * planes.shape[1]), int(scale_x * planes.shape[0])),
            interpolation=cv2.INTER_LINEAR)

        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

        # Create the viewer lazily, the first time visualisation is needed.
        if (not self.viewer) and self.viz:
            from viewer import SimpleImageViewer
            self.viewer = SimpleImageViewer(
                arr=img,
                scale_y=1,
                scale_x=1,
                filepath=self.filename[last_agent] + str(last_agent))
            self.gif_buffer = []
        # display image
        self.viewer.draw_image(img)

        info_color = (204, 204, 0, 255)

        # plot landmarks
        for i in range(self.agents):
            # get landmarks - correct location if image is flipped and
            # transposed
            current_point = (shape[0] - current_points[i][1] + shifts_y[i],
                             current_points[i][0] + shifts_x[i],
                             current_points[i][2])
            if self.task != 'play':
                target_point = (shape[0] - target_points[i][1] + shifts_y[i],
                                target_points[i][0] + shifts_x[i],
                                target_points[i][2])
            # draw current point
            self.viewer.draw_circle(radius=scale_x * 1,
                                    pos_y=scale_y * current_point[1],
                                    pos_x=scale_x * current_point[0],
                                    color=(0.0, 0.0, 1.0, 1.0))
            # draw a box around the agent - what the network sees ROI
            # - correct location if image is flipped
            self.viewer.draw_rect(
                scale_y * (shape[0] - self.rectangle[i].ymin + shifts_y[i]),
                scale_x * (self.rectangle[i].xmin + shifts_x[i]),
                scale_y * (shape[0] - self.rectangle[i].ymax + shifts_y[i]),
                scale_x * (self.rectangle[i].xmax + shifts_x[i]))
            self.viewer.display_text(
                'Agent ' + str(i),
                color=info_color,
                x=scale_y * (shape[0] - self.rectangle[i].ymin + shifts_y[i]),
                y=scale_x * (self.rectangle[i].xmin + shifts_x[i]))
            # display info
            text = 'Spacing ' + str(self.xscale)
            self.viewer.display_text(text, color=info_color, x=8, y=8)

            # -----------------------------------------------------------------
            if self.task != 'play':
                # draw a transparent circle around target point with variable
                # radius based on the difference z-direction
                diff_z = scale_x * abs(current_point[2] - target_point[2])
                self.viewer.draw_circle(radius=diff_z,
                                        pos_x=scale_x * target_point[0],
                                        pos_y=scale_y * target_point[1],
                                        color=(1.0, 0.0, 0.0, 0.2))
                # draw target point
                self.viewer.draw_circle(radius=scale_x * 1,
                                        pos_x=scale_x * target_point[0],
                                        pos_y=scale_y * target_point[1],
                                        color=(1.0, 0.0, 0.0, 1.0))
                # display info: green for a positive reward, red otherwise
                color = (0, 204, 0, 255) if self.reward[i] > 0 else (204, 0, 0,
                                                                     255)
                text = 'Error - ' + 'Agent ' + str(i) + ' - ' + str(
                    round(self.cur_dist[i], 3)) + 'mm'
                self.viewer.display_text(
                    text,
                    color=color,
                    x=scale_y * (int(1.0 * shape[0]) - 15 + shifts_y[i]),
                    y=scale_x * (8 + shifts_x[i]))
        # -----------------------------------------------------------------

        # render and wait (viz) time between frames
        self.viewer.render()

        # save gif
        if self.saveGif:
            image_data = pyglet.image.get_buffer_manager().get_color_buffer(
            ).get_image_data()
            data = image_data.get_data('RGB', image_data.width * 3)
            arr = np.array(bytearray(data)).astype('uint8')
            # Reshape the flat buffer to rows x columns x channels and flip
            # it along the first axis.
            arr = np.flip(
                np.reshape(arr, (image_data.height, image_data.width, -1)), 0)
            im = Image.fromarray(arr)
            self.gif_buffer.append(im)

            if all(self.terminal):
                # Strip extensions from the file name only, so dots in the
                # directory part (e.g. "./data/img.nii.gz") no longer
                # truncate the output path.
                out_dir, base_name = os.path.split(self.filename[0])
                gifname = os.path.join(
                    out_dir,
                    base_name.split('.')[0]) + '_{}.gif'.format(last_agent)
                self.viewer.saveGif(gifname,
                                    arr=self.gif_buffer,
                                    duration=self.viz)

        if self.saveVideo:
            dirname = 'tmp_video_cardiac'
            if self.cnt <= 1:
                if os.path.isdir(dirname):
                    logger.warn(
                        """Log directory {} exists! Use 'd' to delete it. """.
                        format(dirname))
                    act = input("select action: d (delete) / q (quit): "
                                ).lower().strip()
                    if act == 'd':
                        shutil.rmtree(dirname, ignore_errors=True)
                    else:
                        raise OSError("Directory {} exists!".format(dirname))
                os.mkdir(dirname)

            frame = dirname + '/' + '%04d' % self.cnt + '.png'
            pyglet.image.get_buffer_manager().get_color_buffer().save(frame)
            if all(self.terminal):
                # Encode the numbered PNG frames at three times the viewer's
                # image size, then drop the temp folder. The output name
                # carries the number of agents.
                resolution = str(3 * self.viewer.img_width) + 'x' + str(
                    3 * self.viewer.img_height)
                save_cmd = [
                    'ffmpeg', '-f', 'image2', '-framerate', '30',
                    '-pattern_type', 'sequence', '-start_number', '0', '-r',
                    '6', '-i', dirname + '/%04d.png', '-s', resolution,
                    '-vcodec', 'libx264', '-b:v', '2567k',
                    self.filename[0] + '_{}_agents.mp4'.format(last_agent + 1)
                ]
                subprocess.check_output(save_cmd)
                shutil.rmtree(dirname, ignore_errors=True)
Beispiel #8
0
    def display(self, return_rgb_array=False):
        """Send the three orthogonal planes and agent state to the GUI.

        The planes through the agent's location are resized by a factor of
        two and converted to RGB, the agent/target coordinates and
        ``self.rectangle`` are scaled by the same factor, and everything is
        emitted through ``self.viewer.widget.agent_signal``. Unless
        ``self.task`` is ``'browse'``, the method first waits while the
        viewer's worker thread is paused (exiting the process if it is
        flagged to terminate), afterwards sleeps according to the selected
        speed, and optionally records GIF/video frames.

        Args:
            return_rgb_array: unused by this implementation.

        Raises:
            OSError: if the temporary video directory already exists and the
                user does not choose to delete it.
            subprocess.CalledProcessError: if ``ffmpeg`` exits with a
                non-zero status.
        """
        current_point = self._location
        target_point = self._target_loc

        # get the three planes through the current location
        plane = self.get_plane_z(current_point[2])
        plane_x = self.get_plane_x(current_point[0])
        plane_y = self.get_plane_y(current_point[1])

        # rescale image
        # INTER_NEAREST, INTER_LINEAR, INTER_AREA, INTER_CUBIC, INTER_LANCZOS4
        scale_x = 2
        scale_y = 2
        scale_z = 2
        current_point = (current_point[0] * scale_x,
                         current_point[1] * scale_y,
                         current_point[2] * scale_z)
        if target_point is not None:
            target_point = (target_point[0] * scale_x,
                            target_point[1] * scale_y,
                            target_point[2] * scale_z)
        # NOTE(review): this overwrites ``self.rectangle`` with the scaled
        # values on every call -- presumably it is recomputed before the next
        # call; confirm against the code that sets it.
        self.rectangle = (self.rectangle[0] * scale_x,
                          self.rectangle[1] * scale_x,
                          self.rectangle[2] * scale_y,
                          self.rectangle[3] * scale_y,
                          self.rectangle[4] * scale_z,
                          self.rectangle[5] * scale_z)
        img = cv2.resize(
            plane,
            (int(scale_x * plane.shape[1]), int(scale_y * plane.shape[0])),
            interpolation=cv2.INTER_LINEAR)
        img_x = cv2.resize(
            plane_x,
            (int(scale_x * plane_x.shape[1]), int(scale_y * plane_x.shape[0])),
            interpolation=cv2.INTER_LINEAR)
        img_y = cv2.resize(
            plane_y,
            (int(scale_y * plane_y.shape[1]), int(scale_y * plane_y.shape[0])),
            interpolation=cv2.INTER_LINEAR)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # convert to rgb
        img_x = cv2.cvtColor(img_x, cv2.COLOR_GRAY2RGB)  # convert to rgb
        img_y = cv2.cvtColor(img_y, cv2.COLOR_GRAY2RGB)  # convert to rgb

        ########################################################################
        # PyQt GUI Code Section

        # Outside browse mode: sleep while the worker thread is paused.
        if self.task != 'browse':
            worker = self.viewer.right_widget.automatic_mode.thread
            while worker.pause:
                time.sleep(0.5)

                # Check whether thread should be killed (pause)
                if worker.terminate:
                    exit()

            # Check whether thread should be killed (general)
            if worker.terminate:
                exit()

        # Need to emit signal here (to draw images)
        self.viewer.widget.agent_signal.emit({
            "arrs": (img, img_x, img_y),
            "agent_loc": current_point,
            "target": target_point,
            "error": self.cur_dist,
            "scale": self.xscale,
            "rect": self.rectangle,
            "task": self.task,
            "is_terminal": self.terminal,
            "cnt": self.cnt
        })

        if self.task != 'browse':
            # Control agent speed
            speed = self.viewer.right_widget.automatic_mode.thread.speed
            if speed == WorkerThread.FAST:
                time.sleep(0)
            elif speed == WorkerThread.MEDIUM:
                time.sleep(0.5)
            else:
                time.sleep(1.5)

            ########################################################################

            # save gif
            if self.saveGif:
                image_data = pyglet.image.get_buffer_manager(
                ).get_color_buffer().get_image_data()
                data = image_data.get_data('RGB', image_data.width * 3)
                arr = np.array(bytearray(data)).astype('uint8')
                # Reshape the flat buffer to rows x columns x channels and
                # flip it along the first axis.
                arr = np.flip(
                    np.reshape(arr, (image_data.height, image_data.width, -1)),
                    0)
                im = Image.fromarray(arr)
                self.gif_buffer.append(im)

                if not self.terminal:
                    # Strip extensions from the file name only, so dots in
                    # the directory part (e.g. "./data/img.nii.gz") no
                    # longer truncate the output path.
                    out_dir, base_name = os.path.split(self.filename)
                    gifname = os.path.join(
                        out_dir, base_name.split('.')[0]) + '.gif'
                    self.viewer.saveGif(gifname,
                                        arr=self.gif_buffer,
                                        duration=self.viz)
            if self.saveVideo:
                dirname = 'tmp_video'
                if self.cnt <= 1:
                    if os.path.isdir(dirname):
                        logger.warn(
                            """Log directory {} exists! Use 'd' to delete it. """
                            .format(dirname))
                        act = input("select action: d (delete) / q (quit): "
                                    ).lower().strip()
                        if act == 'd':
                            shutil.rmtree(dirname, ignore_errors=True)
                        else:
                            raise OSError(
                                "Directory {} exists!".format(dirname))
                    os.mkdir(dirname)

                frame = dirname + '/' + '%04d' % self.cnt + '.png'
                pyglet.image.get_buffer_manager().get_color_buffer().save(
                    frame)
                if self.terminal:
                    # Encode the numbered PNG frames at three times the
                    # viewer's image size, then drop the temp folder.
                    resolution = str(3 * self.viewer.img_width) + 'x' + str(
                        3 * self.viewer.img_height)
                    save_cmd = [
                        'ffmpeg', '-f', 'image2', '-framerate', '30',
                        '-pattern_type', 'sequence', '-start_number', '0',
                        '-r', '6', '-i', dirname + '/%04d.png', '-s',
                        resolution, '-vcodec', 'libx264', '-b:v', '2567k',
                        self.filename + '.mp4'
                    ]
                    subprocess.check_output(save_cmd)
                    shutil.rmtree(dirname, ignore_errors=True)
Beispiel #9
0
                _import_external_ops(e.message)
            else:
                break

    # loading...
    if input.endswith('.npz'):
        dic = np.load(input)
    else:
        dic = varmanip.load_chkpt_vars(input)
    dic = {get_op_tensor_name(k)[1]: v for k, v in six.iteritems(dic)}

    if args.meta is not None:
        # save variables that are GLOBAL, and either TRAINABLE or MODEL
        var_to_dump = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        var_to_dump.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
        if len(set(var_to_dump)) != len(var_to_dump):
            logger.warn("TRAINABLE and MODEL variables have duplication!")
        var_to_dump = list(set(var_to_dump))
        globvarname = set([k.name for k in tf.global_variables()])
        var_to_dump = set(
            [k.name for k in var_to_dump if k.name in globvarname])

        for name in var_to_dump:
            assert name in dic, "Variable {} not found in the model!".format(
                name)
    else:
        var_to_dump = set(dic.keys())

    dic_to_dump = {k: v for k, v in six.iteritems(dic) if k in var_to_dump}
    varmanip.save_chkpt_vars(dic_to_dump, args.output)
Beispiel #10
0
    def display(self, return_rgb_array=False):
        """Render the current view of every agent in the image viewer.

        For each agent, the z-plane through its current location is converted
        to RGB, upscaled by a fixed factor of 2 and drawn together with
        overlays: the agent position, its ROI rectangle, an "Agent <i>" label
        and the spacing text. Unless ``self.task == "play"``, the target point
        and the current error are drawn as well. When ``self.saveGif`` or
        ``self.saveVideo`` is set, the rendered frame is additionally captured
        (see ``_save_gif_frame`` / ``_save_video_frame``).

        Args:
            return_rgb_array: Accepted for interface compatibility; this
                method does not use it.

        Raises:
            OSError: If the temporary video directory already exists and the
                user does not choose to delete it.
            subprocess.CalledProcessError: If the ffmpeg encoding command
                exits with a non-zero status.
        """
        # Display scale applied to the image and to every overlay coordinate.
        scale_x = 2
        scale_y = 2

        for i in range(self.agents):
            current_point = self._location[i]
            target_point = None
            if self.task != "play":
                target_point = self._target_loc[i]

            # Fetch the z-plane at the agent's position and convert it to RGB.
            plane = self.get_plane(current_point[2], agent=i)
            img = cv2.cvtColor(plane, cv2.COLOR_GRAY2RGB)
            img = cv2.resize(
                img,
                (int(scale_x * img.shape[1]), int(scale_y * img.shape[0])),
                interpolation=cv2.INTER_LINEAR,
            )

            # Lazily create the viewer; skipped when one is already open.
            # NOTE(review): a single viewer is shared by all agents and is
            # named after the agent that triggers its creation.
            if (not self.viewer) and self.viz:
                from viewer import SimpleImageViewer

                self.viewer = SimpleImageViewer(
                    arr=img,
                    scale_x=1,
                    scale_y=1,
                    filepath=self.filepath[i] + str(i),
                )
                self.gif_buffer = []

            self.viewer.draw_image(img)
            # Current agent position.
            self.viewer.draw_circle(
                radius=scale_x,
                pos_x=scale_x * current_point[0],
                pos_y=scale_y * current_point[1],
                color=(0.0, 0.0, 1.0, 1.0),
            )
            # Box around the agent: the ROI the network sees.
            roi = self.rectangle[i]
            self.viewer.draw_rect(
                scale_x * roi.xmin,
                scale_y * roi.ymin,
                scale_x * roi.xmax,
                scale_y * roi.ymax,
            )
            self.viewer.display_text(
                "Agent " + str(i),
                color=(204, 204, 0, 255),
                x=scale_x * roi.xmin - 15,
                y=scale_y * roi.ymin,
            )
            self.viewer.display_text(
                "Spacing " + str(self.xscale),
                color=(204, 204, 0, 255),
                x=10,
                y=self._image_dims[1] - 80,
            )

            if self.task != "play":
                # Transparent circle around the target whose radius grows
                # with the distance to the target along z.
                diff_z = scale_x * abs(current_point[2] - target_point[2])
                self.viewer.draw_circle(
                    radius=diff_z,
                    pos_x=scale_x * target_point[0],
                    pos_y=scale_y * target_point[1],
                    color=(1.0, 0.0, 0.0, 0.2),
                )
                # Target point itself.
                self.viewer.draw_circle(
                    radius=scale_x,
                    pos_x=scale_x * target_point[0],
                    pos_y=scale_y * target_point[1],
                    color=(1.0, 0.0, 0.0, 1.0),
                )
                # Error text: green for a positive reward, red otherwise.
                if self.reward[i] > 0:
                    color = (0, 204, 0, 255)
                else:
                    color = (204, 0, 0, 255)
                text = "Error " + str(round(self.cur_dist[i], 3)) + "mm"
                self.viewer.display_text(text, color=color, x=10, y=20)

            self.viewer.render()

            if self.saveGif:
                self._save_gif_frame(i)
            if self.saveVideo:
                self._save_video_frame(i)

    def _save_gif_frame(self, i):
        """Append the rendered frame to the GIF buffer and write the GIF."""
        image_data = (pyglet.image.get_buffer_manager().get_color_buffer()
                      .get_image_data())
        data = image_data.get_data("RGB", image_data.width * 3)
        arr = np.array(bytearray(data)).astype("uint8")
        # Rows are reversed before building the PIL image; presumably the
        # color buffer is stored bottom-up -- TODO confirm.
        arr = np.flip(
            np.reshape(arr, (image_data.height, image_data.width, -1)), 0)
        self.gif_buffer.append(Image.fromarray(arr))

        # NOTE(review): the GIF is rewritten on every non-terminal frame and
        # always uses the first agent's file path -- confirm this is intended.
        if not self.terminal[i]:
            gifname = self.filepath[0] + ".gif"
            self.viewer.saveGif(gifname,
                                arr=self.gif_buffer,
                                duration=self.viz)

    def _save_video_frame(self, i):
        """Dump the rendered frame as PNG; encode an mp4 once agent i is done."""
        dirname = "tmp_video"
        # Prepare the frame directory once per episode. Restricting this to
        # the first agent keeps later agents in the same call from finding
        # the directory that was just created and prompting the user again.
        if self.cnt <= 1 and i == 0:
            if os.path.isdir(dirname):
                logger.warn(
                    """Log directory {} exists! Use 'd' to delete it. """
                    .format(dirname))
                act = input(
                    "select action: d (delete) / q (quit): ").lower().strip()
                if act != "d":
                    raise OSError("Directory {} exists!".format(dirname))
                shutil.rmtree(dirname, ignore_errors=True)
            os.mkdir(dirname)

        frame = "%s/%04d.png" % (dirname, self.cnt)
        pyglet.image.get_buffer_manager().get_color_buffer().save(frame)

        if self.terminal[i]:
            # Encode the PNG sequence at 3x the viewer resolution, then drop
            # the temporary frames.
            # NOTE(review): the directory is shared by all agents, so agents
            # that are still running lose it once the first agent terminates.
            resolution = "{}x{}".format(3 * self.viewer.img_width,
                                        3 * self.viewer.img_height)
            save_cmd = [
                "ffmpeg",
                "-f", "image2",
                "-framerate", "30",
                "-pattern_type", "sequence",
                "-start_number", "0",
                "-r", "6",
                "-i", dirname + "/%04d.png",
                "-s", resolution,
                "-vcodec", "libx264",
                "-b:v", "2567k",
                self.filepath[i] + ".mp4",
            ]
            subprocess.check_output(save_cmd)
            shutil.rmtree(dirname, ignore_errors=True)