Example #1
File: video.py Project: ebarns/cv-fun
def __init__(self, num_of_cameras=1, record_video=False):
    self.is_capturing_video = True
    self.video_cameras = VideoCameras(num_of_cameras)
    self.record_video = record_video
    self.face_detection = FaceDetection()
    self.face_frame_morpher = FaceFrameMorpher()
    self.video_recorder = VideoRecorder()
Example #2
    def __init__(
        self,
        is_record_topic,
        is_memory_usage_exceeded_topic,
        image_topic,
        video_type,
        video_dimensions,
        frames_per_second,
        out_directory,
    ):

        rospy.init_node('video_recorder', anonymous=True)
        self._video_recorder = VideoRecorder(
            video_type=video_type,
            video_dimensions=video_dimensions,
            frames_per_second=frames_per_second,
            out_directory=out_directory,
        )

        self._is_record_subscriber = rospy.Subscriber(is_record_topic, Bool,
                                                      self._is_record_callback)
        self._image_subscriber = rospy.Subscriber(image_topic, Image,
                                                  self._image_callback)
        self._memory_watch_subscriber = rospy.Subscriber(
            is_memory_usage_exceeded_topic, Bool, self._memory_check_callback)
        self._bridge = CvBridge()

        # This flag is used to block recording if memory exceeds limits
        self._allow_recording = True
Example #3
File: video.py Project: ebarns/cv-fun
class Spoopy:
    def __init__(self, num_of_cameras=1, record_video=False):
        self.is_capturing_video = True
        self.video_cameras = VideoCameras(num_of_cameras)
        self.record_video = record_video
        self.face_detection = FaceDetection()
        self.face_frame_morpher = FaceFrameMorpher()
        self.video_recorder = VideoRecorder()

    def capture_video(self):
        while self.is_capturing_video:
            frames = self.video_cameras.get_frames()

            frames = self.face_frame_morpher.morph_frame_faces(frames)

            self.video_cameras.display_frames(frames)
            self.kill_capture()

    def write_video(self, frame):
        if self.record_video:
            self.video_recorder.write_video_frame(frame)

    def kill_capture(self):
        if cv2.waitKey(33) == ord('a'):
            print("Tearing down capture")
            self.is_capturing_video = False
            self.video_cameras.release()
            self.video_recorder.stop_video_record()
            cv2.destroyAllWindows()
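
A minimal driver for this class might look like the sketch below; the entry-point guard and argument values are assumptions, since the original file is not shown in full.

# Hypothetical usage sketch, not taken from the project: run the
# capture loop until the 'a' key is pressed in the OpenCV window.
if __name__ == "__main__":
    app = Spoopy(num_of_cameras=1, record_video=True)
    app.capture_video()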
Example #4
def __init__(self, preview=False, max_video_length=MAX_VIDEO_LENGTH):
    log.info("booting up..")
    self.final_dir = self._setup_dirs()
    self.max_video_length = max_video_length
    self.video_recorder = VideoRecorder(preview=preview)
    self.audio_recorder = AudioRecorder()
    time.sleep(2)
    log.info("ready!")
Example #5
def test():

    # Set TF / Keras dtype
    tf.keras.backend.set_floatx(Params.DTYPE)

    # Load model
    model = tf.keras.models.load_model(ParamsTest.MODEL_PATH, compile=False)

    # Construct logdir
    log_dir = f"{datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}_test"

    # Init Logger
    if Params.DO_LOGGING:
        logger = Logger("logs/" + log_dir)
    else:
        logger = None

    # Init Video Recorder
    if Params.RECORD_VIDEO:
        video_recorder = VideoRecorder("recorded/" + log_dir)
    else:
        video_recorder = None

    # Init Env
    env = eval(Params.ENV_NAME)()

    run(logger, env, model, video_recorder)
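
test() relies on two parameter holders that the snippet does not show. From the attribute accesses above, a plausible (assumed) shape is:

# Hypothetical parameter holders inferred from the attribute accesses in
# test(); names match the snippet, values are placeholders.
class Params:
    DTYPE = "float32"   # passed to tf.keras.backend.set_floatx
    DO_LOGGING = True
    RECORD_VIDEO = True
    ENV_NAME = "MyEnv"  # eval'd and called, so it must name a class in scope

class ParamsTest:
    MODEL_PATH = "models/saved_model"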
Example #6
def real_time_lrp(conf):
    """Method to display feature relevance scores in real time.

    Args:
        conf: Dictionary consisting of configuration parameters.
    """
    record_video = conf["playback"]["record_video"]

    webcam = Webcam()
    lrp = RelevancePropagation(conf)

    if record_video:
        recorder = VideoRecorder(conf)

    while True:
        t0 = time.time()

        frame = webcam.get_frame()
        heatmap = lrp.run(frame)
        heatmap = post_processing(frame, heatmap, conf)
        cv2.imshow("LRP", heatmap)

        if record_video:
            recorder.record(heatmap)

        t1 = time.time()
        fps = 1.0 / (t1 - t0)
        print("{:.1f} FPS".format(fps))

        if cv2.waitKey(1) % 256 == 27:
            print("Escape pressed.")
            break

    if record_video:
        recorder.release()

    webcam.turn_off()
    cv2.destroyAllWindows()
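
Only conf["playback"]["record_video"] is dereferenced directly here; the rest of the dictionary is consumed by RelevancePropagation, post_processing, and VideoRecorder. A minimal sketch of the assumed shape:

# Hypothetical configuration sketch; only the "playback" key is confirmed
# by the snippet, every other key is an assumption.
conf = {
    "playback": {
        "record_video": False,  # set True to also construct VideoRecorder(conf)
    },
    # ... keys consumed by RelevancePropagation, post_processing,
    # and VideoRecorder would go here ...
}
real_time_lrp(conf)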
Example #7
class App:
    def __init__(self, preview=False, max_video_length=MAX_VIDEO_LENGTH):
        log.info("booting up..")
        self.final_dir = self._setup_dirs()
        self.max_video_length = max_video_length
        self.video_recorder = VideoRecorder(preview=preview)
        self.audio_recorder = AudioRecorder()
        time.sleep(2)
        log.info("ready!")

    def _setup_dirs(self):
        final_dir = os.path.expanduser('~/media/')
        if not os.path.isdir(final_dir):
            os.mkdir(final_dir)
        return final_dir

    def _make_filename(self):
        return datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")

    def has_space(self):
        statvfs = os.statvfs("/")
        megabytes_available = int(statvfs.f_frsize * statvfs.f_bavail / 1024 /
                                  1024)
        log.info(f"Still {megabytes_available}MB left on device")
        return megabytes_available > MIN_DISK_SPACE_MB

    def on_keyboard_release(self, key):
        if key == keyboard.Key.enter:
            if lock.locked():
                self.stop_recording()
            elif self.has_space():
                self.start_recording()
            else:
                return False
        if key == keyboard.Key.esc:
            if lock.locked():
                self.stop_recording()
            return False

    def timer(self, seconds, current_video):
        log.info(f"going to sleep for {seconds}s and then stop recording")
        for i in range(seconds):
            if not lock.locked():
                log.info("looks like recording has ended before timeout")
                return
            elif current_video != self.file_name:
                log.info("there is a different ongoing recording")
                return
            time.sleep(1)
        log.info("time's up!, stopping recording")
        self.stop_recording()

    def start_recording(self):
        lock.acquire()
        self.start_datetime = datetime.datetime.now()
        self.file_name = self._make_filename()
        timer_thread = threading.Thread(target=self.timer,
                                        args=(self.max_video_length,
                                              self.file_name))
        timer_thread.start()
        self.tmp_dir = tempfile.mkdtemp()

        log.info("starting threads...")
        self.video_recorder.start(self.file_name, self.tmp_dir)
        self.audio_recorder.start(self.file_name, self.tmp_dir)

    def stop_recording(self):
        log.info("stopping threads...")
        if not self.audio_recorder.stop():
            return
        if not self.video_recorder.stop():
            return
        now = datetime.datetime.now()
        video_length = (now - self.start_datetime).seconds
        if video_length > MIN_VIDEO_LENGTH:
            log.info("starting mux...")
            cmd = (
                f"ffmpeg -i {self.tmp_dir}/{self.file_name}.wav -i {self.tmp_dir}/{self.file_name}.h264 "
                f"-c:v copy -c:a aac -strict experimental {self.final_dir}/{self.file_name}.mp4"
            )
            subprocess.run(cmd, capture_output=True, shell=True)
            log.info(f"{self.file_name}.mp4 is ready!")
        else:
            log.info(f"Video was to short: {video_length}, removing it")
        shutil.rmtree(self.tmp_dir)
        log.info(f"{self.tmp_dir} removed")
        lock.release()

    def run(self):
        def on_release(button):
            if lock.locked():
                self.stop_recording()
            elif self.has_space():
                self.start_recording()
            else:
                return False

        button = Button(2)
        button.when_released = on_release
        listener = keyboard.Listener(on_release=self.on_keyboard_release)
        listener.start()
        pause()
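
The methods above reference several module-level names (log, lock, the keyboard module, Button, pause, and three constants) that the snippet does not define. A plausible reconstruction, with all concrete values assumed:

# Assumed module-level scaffolding for the App class; the names come from
# the snippet (keyboard.Listener/keyboard.Key suggest pynput, Button
# suggests gpiozero), but every value here is a guess.
import datetime
import logging
import os
import shutil
import subprocess
import tempfile
import threading
import time
from signal import pause

from gpiozero import Button
from pynput import keyboard

# VideoRecorder and AudioRecorder are project-local modules.

log = logging.getLogger(__name__)
lock = threading.Lock()

MAX_VIDEO_LENGTH = 600   # seconds; assumed
MIN_VIDEO_LENGTH = 3     # seconds; assumed
MIN_DISK_SPACE_MB = 500  # assumed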
Example #8
args = vars(ap.parse_args())

# Filter warnings.
warnings.filterwarnings("ignore")

# Load the configuration.
conf = json.load(open(args["conf"]))

# Initialize the camera and grab a reference to the raw camera capture.
camera = PiCamera()
camera.resolution = tuple(conf["resolution"])
camera.framerate = conf["fps"]
rawCapture = PiRGBArray(camera, size=tuple(conf["resolution"]))

# Pass the camera object to the Video Recorder.
VideoRecorder.set_camera(camera)

# Set the dir to write the video files to in the Video Recorder
VideoRecorder.set_videos_dir(conf["write_dir"])

# Allow the camera to warmup, then initialize the average frame, last
# uploaded timestamp, and frame motion counter.
print("[INFO] warming up...")
time.sleep(conf["camera_warmup_time"])
avg = None

# Initialize to a long time ago (in a galaxy far, far away...).
last_active_time = datetime.datetime(datetime.MINYEAR, 1, 1)

# Get the min and max x and y values for the timestamp exclusion
# code below.  If the resolution is changed in the conf.json file,
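
The snippet cuts off before the capture loop. With picamera, frames are typically pulled via capture_continuous; a sketch of the assumed continuation (the motion-detection body itself is not shown in the source):

# Hypothetical continuation using standard picamera idioms; the
# motion-detection logic against `avg` is omitted because the original
# snippet is truncated.
for f in camera.capture_continuous(rawCapture, format="bgr",
                                   use_video_port=True):
    frame = f.array  # the current frame as a BGR NumPy array

    # ... motion detection and VideoRecorder calls would go here ...

    # Clear the stream in preparation for the next frame.
    rawCapture.truncate(0)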
Example #9
    def __init__(self):
        super(MainWindow, self).__init__()
        self.ui = Ui_dialog()
        self.ui.setupUi(self)
        self.setFixedSize(self.width(), self.height())
        self.onBindingUI()

        ##
        self.record_header_labels = [
            'who', 'what', 'when', 'where', 'speaking'
        ]

        ## Create detailed speaking-information windows
        self.speaking_ui_cam1 = SpeakingMainWindow()
        self.speaking_ui_cam2 = SpeakingMainWindow()

        ## For audio and video analysis from camera
        global video_thread_cam1
        global video_thread_cam2
        global audio_thread_cam1
        global audio_thread_cam2

        ## For showing Quadro W record
        global nobody_record
        global chung_chih_record
        global yu_sung_record
        global chang_yu_record
        global i_hsin_record
        global tzu_yuan_record
        global chih_yu_record
        global other_record

        ## For Q-learning
        global nobody_q_table
        global chung_chih_q_table
        global yu_sung_q_table
        global chang_yu_q_table
        global i_hsin_q_table
        global tzu_yuan_q_table
        global other_q_table

        ## Create video and audio process threads and run them to get the Quadro W info
        audio_thread_cam1 = AudioRecorder(
            cam_ip='rtsp://*****:*****@192.168.65.66:554/stream1',
            audio_spath='temporarily_rtsp_data/for_speaking/cam1',
            audio_stime=5,
            pos_path='nlp_recognize/pos.json',
            dict_path='nlp_recognize/extra_dict.txt')
        video_thread_cam1 = VideoRecorder(
            cam_ip='rtsp://*****:*****@192.168.65.66:554/stream1',
            record_path='record/',
            qtable_path='q_table/',
            compress_result=True,
            fps=15,
            analysis_sec=1,
            prefix_name="cam1",
            audio_thread=audio_thread_cam1,
            human_model_path=
            "trans_model/model_simulated_RGB_mgpu_scaling_append.0024.pb",
            face_model_path='trans_model/face_new3.pb',
            place_model_path='trans_model/place_new3.pb',
            object_model_path=
            'object_recognize/code/workspace/training_demo/model/pb/frozen_inference_graph.pb',
            face_category_path='face_recognize/categories_human_uscc.txt',
            place_category_path='places_recognize/categories_places_uscc.txt',
            object_category_path=
            'object_recognize/code/workspace/training_demo/annotations/label_map.pbtxt'
        )
        audio_thread_cam2 = AudioRecorder(
            cam_ip='rtsp://*****:*****@192.168.65.41:554/stream1',
            audio_spath='temporarily_rtsp_data/for_speaking/cam2',
            audio_stime=5,
            pos_path='nlp_recognize/pos.json',
            dict_path='nlp_recognize/extra_dict.txt')
        video_thread_cam2 = VideoRecorder(
            cam_ip='rtsp://*****:*****@192.168.65.41:554/stream1',
            record_path='record/',
            qtable_path='q_table/',
            compress_result=True,
            fps=15,
            analysis_sec=1,
            prefix_name="cam2",
            audio_thread=audio_thread_cam2,
            human_model_path=
            "trans_model/model_simulated_RGB_mgpu_scaling_append.0024.pb",
            face_model_path='trans_model/face_new3.pb',
            place_model_path='trans_model/place_new3.pb',
            object_model_path=
            'object_recognize/code/workspace/training_demo/model/pb/frozen_inference_graph.pb',
            face_category_path='face_recognize/categories_human_uscc.txt',
            place_category_path='places_recognize/categories_places_uscc.txt',
            object_category_path=
            'object_recognize/code/workspace/training_demo/annotations/label_map.pbtxt'
        )

        ## Create threads to show records of the Quadro W info
        nobody_record = RecordUpdata(csv_path="record/Nobody_record.csv")
        chung_chih_record = RecordUpdata(
            csv_path="record/chung-chih_record.csv")
        yu_sung_record = RecordUpdata(csv_path="record/yu-sung_record.csv")
        chang_yu_record = RecordUpdata(csv_path="record/chang-yu_record.csv")
        i_hsin_record = RecordUpdata(csv_path="record/i-hsin_record.csv")
        tzu_yuan_record = RecordUpdata(csv_path="record/tzu-yuan_record.csv")
        chih_yu_record = RecordUpdata(csv_path="record/chih-yu_record.csv")

        ## Create threads to calculate and display Q values
        nobody_q_table = QLearningUpdata(
            in_record_path="record/Nobody_record.csv",
            out_table_path="q_table/Nobody_qtable.csv",
            where_pool=[[2], [1], []],
            where_category_path="places_recognize/categories_places_uscc.txt",
            care_number=100,
            decay_reward=0.98,
            base_reward=100,
            lower_limit=1,
            decay_qvalue=0.9,
            learning_rate=0.1)
        chung_chih_q_table = QLearningUpdata(
            in_record_path="record/chung-chih_record.csv",
            out_table_path="q_table/chung-chih_qtable.csv",
            where_pool=[[2], [1], []],
            where_category_path="places_recognize/categories_places_uscc.txt",
            care_number=100,
            decay_reward=0.98,
            base_reward=100,
            lower_limit=1,
            decay_qvalue=0.9,
            learning_rate=0.1)
        yu_sung_q_table = QLearningUpdata(
            in_record_path="record/yu-sung_record.csv",
            out_table_path="q_table/yu-sung_qtable.csv",
            where_pool=[[2], [1], []],
            where_category_path="places_recognize/categories_places_uscc.txt",
            care_number=100,
            decay_reward=0.98,
            base_reward=100,
            lower_limit=1,
            decay_qvalue=0.9,
            learning_rate=0.1)
        chang_yu_q_table = QLearningUpdata(
            in_record_path="record/chang-yu_record.csv",
            out_table_path="q_table/chang-yu_qtable.csv",
            where_pool=[[2], [1], []],
            where_category_path="places_recognize/categories_places_uscc.txt",
            care_number=100,
            decay_reward=0.98,
            base_reward=100,
            lower_limit=1,
            decay_qvalue=0.9,
            learning_rate=0.1)
        i_hsin_q_table = QLearningUpdata(
            in_record_path="record/i-hsin_record.csv",
            out_table_path="q_table/i-hsin_qtable.csv",
            where_pool=[[2], [1], []],
            where_category_path="places_recognize/categories_places_uscc.txt",
            care_number=100,
            decay_reward=0.98,
            base_reward=100,
            lower_limit=1,
            decay_qvalue=0.9,
            learning_rate=0.1)
        tzu_yuan_q_table = QLearningUpdata(
            in_record_path="record/tzu-yuan_record.csv",
            out_table_path="q_table/tzu-yuan_qtable.csv",
            where_pool=[[2], [1], []],
            where_category_path="places_recognize/categories_places_uscc.txt",
            care_number=100,
            decay_reward=0.98,
            base_reward=100,
            lower_limit=1,
            decay_qvalue=0.9,
            learning_rate=0.1)
        chih_yu_q_table = QLearningUpdata(
            in_record_path="record/chih-yu_record.csv",
            out_table_path="q_table/chih-yu_qtable.csv",
            where_pool=[[2], [1], []],
            where_category_path="places_recognize/categories_places_uscc.txt",
            care_number=100,
            decay_reward=0.98,
            base_reward=100,
            lower_limit=1,
            decay_qvalue=0.9,
            learning_rate=0.1)

        ## Start all threads
        video_thread_cam1.start()
        video_thread_cam2.start()
        audio_thread_cam1.start()
        audio_thread_cam2.start()

        nobody_record.start()
        chung_chih_record.start()
        yu_sung_record.start()
        chang_yu_record.start()
        i_hsin_record.start()
        tzu_yuan_record.start()
        chih_yu_record.start()

        nobody_q_table.start()
        chih_yu_q_table.start()
        yu_sung_q_table.start()
        chang_yu_q_table.start()
        i_hsin_q_table.start()
        tzu_yuan_q_table.start()
        chung_chih_q_table.start()

        self.timer = QTimer(self)
        self.timer.timeout.connect(self.start_webcam)
        self.timer.start(0)
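
Launching the window is not shown. Assuming PyQt5 (suggested by QTimer and the Ui_dialog pattern), a typical entry point would be:

# Hypothetical entry point; PyQt5 is an assumption based on the
# QTimer/Ui_dialog usage above.
if __name__ == "__main__":
    import sys
    from PyQt5.QtWidgets import QApplication
    app = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    sys.exit(app.exec_())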
Example #10
    def __init__(self, cfg):
        self.cfg = cfg
        if not os.path.exists(cfg.workdir): os.makedirs(cfg.workdir)
        Tracer.FILE = cfg.workdir + "log_" + datetime.now().strftime(
            "%Y%m%d%H%M%S") + ".txt"
        self.progress = OO(epoch=1,
                           exploration_noise=1,
                           num_train_iteration=0,
                           num_critic_update_iteration=0,
                           num_actor_update_iteration=0)
        self.progress.update(episode=1,
                             episode_step=0,
                             global_step=0,
                             env_top_reward=0,
                             evaluate_reward=0,
                             eval_top_reward=-9e99,
                             train_running_reward=0,
                             exploration_rate_epsilon=1)

        self.workdir = cfg.workdir
        self.meter = Meter(cfg, cfg.workdir)
        self.env = gym.make(cfg.env.name)
        env = self.env

        # Values defined here override command-line arguments, call
        # arguments, and anything set in the reloaded config file.
        cfg.override(
            device=torch.device(cfg.device),
            agent=OO(device="${device}", ).__dict__,
        )
        cfg.override(env=OO(
            spec=env.spec,
            reward_range=env.reward_range,
            state_space_dim=env.observation_space.shape,
            state_line_dim=np.prod(env.observation_space.shape),
            state_digest_dim=128,
            # 1 if the action space is discrete, 0 if continuous
            action_discrete=1 if isinstance(env.action_space, gym.spaces.Discrete) else 0,
            # action dimensionality
            action_dim=env.action_space.n if isinstance(env.action_space, gym.spaces.Discrete) else env.action_space.shape[0],
            # lower bound of the action range
            action_min=0 if isinstance(env.action_space, gym.spaces.Discrete) else float(env.action_space.low[0]),
            # upper bound of the action range
            action_max=env.action_space.n if isinstance(env.action_space, gym.spaces.Discrete) else float(env.action_space.high[0]),
        ))
        cfg.override(agent=OO(
            state_space_dim=cfg.env.state_space_dim,
            state_line_dim=cfg.env.state_line_dim,
            state_digest_dim=cfg.env.state_digest_dim,
            action_discrete=cfg.env.action_discrete,
            action_dim=cfg.env.action_dim,
            action_min=cfg.env.action_min,
            action_max=cfg.env.action_max,
            actor=OO(
                lr=1e-2,
                betas=[0.9, 0.999],
            ),
            critic=OO(
                lr=1e-2,
                betas=[0.9, 0.999],
            ),
        ))

        assert 0 <= self.cfg.exploration_rate_init < 1
        Tracer.trace(self.cfg)

        self.video_recorder = VideoRecorder()
        self.replay_buffer = ReplayBuffer(int(cfg.replay_buffer_capacity))
        self.agent = cfg.agent_class(cfg, self.progress, self.meter,
                                     self.replay_buffer)

        self.temp_buffer = []

        self.load()
        self.last_save_time = time.time()
        self.last_rest_time = time.time()
        self.last_log_time = time.time()
        self.last_meter_time = time.time()

        print("-" * 150)
Example #11
class Workspace():
    def __init__(self, cfg):
        self.cfg = cfg
        if not os.path.exists(cfg.workdir): os.makedirs(cfg.workdir)
        Tracer.FILE = cfg.workdir + "log_" + datetime.now().strftime(
            "%Y%m%d%H%M%S") + ".txt"
        self.progress = OO(epoch=1,
                           exploration_noise=1,
                           num_train_iteration=0,
                           num_critic_update_iteration=0,
                           num_actor_update_iteration=0)
        self.progress.update(episode=1,
                             episode_step=0,
                             global_step=0,
                             env_top_reward=0,
                             evaluate_reward=0,
                             eval_top_reward=-9e99,
                             train_running_reward=0,
                             exploration_rate_epsilon=1)

        self.workdir = cfg.workdir
        self.meter = Meter(cfg, cfg.workdir)
        self.env = gym.make(cfg.env.name)
        env = self.env

        # Values defined here override command-line arguments, call
        # arguments, and anything set in the reloaded config file.
        cfg.override(
            device=torch.device(cfg.device),
            agent=OO(device="${device}", ).__dict__,
        )
        cfg.override(env=OO(
            spec=env.spec,
            reward_range=env.reward_range,
            state_space_dim=env.observation_space.shape,
            state_line_dim=np.prod(env.observation_space.shape),
            state_digest_dim=128,
            # 1 if the action space is discrete, 0 if continuous
            action_discrete=1 if isinstance(env.action_space, gym.spaces.Discrete) else 0,
            # action dimensionality
            action_dim=env.action_space.n if isinstance(env.action_space, gym.spaces.Discrete) else env.action_space.shape[0],
            # lower bound of the action range
            action_min=0 if isinstance(env.action_space, gym.spaces.Discrete) else float(env.action_space.low[0]),
            # upper bound of the action range
            action_max=env.action_space.n if isinstance(env.action_space, gym.spaces.Discrete) else float(env.action_space.high[0]),
        ))
        cfg.override(agent=OO(
            state_space_dim=cfg.env.state_space_dim,
            state_line_dim=cfg.env.state_line_dim,
            state_digest_dim=cfg.env.state_digest_dim,
            action_discrete=cfg.env.action_discrete,
            action_dim=cfg.env.action_dim,
            action_min=cfg.env.action_min,
            action_max=cfg.env.action_max,
            actor=OO(
                lr=1e-2,
                betas=[0.9, 0.999],
            ),
            critic=OO(
                lr=1e-2,
                betas=[0.9, 0.999],
            ),
        ))

        assert 0 <= self.cfg.exploration_rate_init < 1
        Tracer.trace(self.cfg)

        self.video_recorder = VideoRecorder()
        self.replay_buffer = ReplayBuffer(int(cfg.replay_buffer_capacity))
        self.agent = cfg.agent_class(cfg, self.progress, self.meter,
                                     self.replay_buffer)

        self.temp_buffer = []

        self.load()
        self.last_save_time = time.time()
        self.last_rest_time = time.time()
        self.last_log_time = time.time()
        self.last_meter_time = time.time()

        print("-" * 150)

    def save(self, best_model=False):
        def safe_torch_save(obj, filename):
            torch.save(obj, filename + '.tmp')
            if os.access(filename, os.F_OK): os.remove(filename)
            os.rename(filename + '.tmp', filename)

        agent_model = {}
        for (name, mod) in self.agent.modules.items():
            agent_model[name] = mod.state_dict()
        safe_torch_save(agent_model, self.cfg.workdir + '/training_model.drl')
        if best_model:
            safe_torch_save(agent_model, self.cfg.workdir + '/agent_model.drl')
        progress = {
            'progress': self.progress.__dict__,
            'meters.train': self.meter.train_mg.data,
            'meters.eval': self.meter.eval_mg.data,
        }
        for (name, optim) in self.agent.optimizers.items():
            progress[name] = optim.state_dict()
        safe_torch_save(progress, self.cfg.workdir + '/progress.drl')
        safe_torch_save({
            'memory': self.replay_buffer.__dict__,
        }, self.cfg.workdir + '/replay_buffer.drl')
        if 'save_prompt_message' in self.cfg and self.cfg.save_prompt_message:
            Tracer.trace("Model has been saved...")

    def load(self):
        def safe_torch_load(filename):
            if os.path.exists(filename):
                return torch.load(filename)
            elif os.path.exists(filename + '.tmp'):
                return torch.load(filename + '.tmp')
            else:
                return None

        o = safe_torch_load(self.cfg.workdir + '/agent_model.drl')
        if o is not None:
            for (name, mod) in self.agent.modules.items():
                if name in o: mod.load_state_dict(o[name])
        o = safe_torch_load(self.cfg.workdir + '/progress.drl')
        if o is not None:
            for (name, optim) in self.agent.optimizers.items():
                if name in o: optim.load_state_dict(o[name])
        if o is not None and 'progress' in o:
            self.progress.__dict__.update(o['progress'])
        if o is not None and 'meters.train' in o:
            self.meter.train_mg.data = o['meters.train']
        if o is not None and 'meters.eval' in o:
            self.meter.eval_mg.data = o['meters.eval']
        o = safe_torch_load(self.cfg.workdir + '/replay_buffer.drl')
        if o is not None and 'memory' in o:
            self.replay_buffer.__dict__.update(o['memory'])
        Tracer.trace("model has been loaded...")

    def evaluate(self):
        self.video_recorder.init()
        average_episode_reward = 0
        for episode in range(self.cfg.num_eval_epochs):
            state = self.env.reset()
            pre_state = state
            self.agent.reset()
            done = False
            episode_reward = 0
            step = 0
            while not done and step < self.cfg.eval_max_frame:
                with eval_mode(self.agent):
                    action, action_prob = self.agent.action(
                        pre_state, state, 0)
                pre_state = state
                state, reward, done, _ = self.env.step(action)
                self.video_recorder.record(self.env)
                episode_reward += reward
                step += 1
            work_time = time.time() - self.last_rest_time
            time.sleep((work_time *
                        self.cfg.running_idle_rate) if work_time > 0 else 0.01)
            self.last_rest_time = time.time()
            average_episode_reward += episode_reward
        average_episode_reward /= self.cfg.num_eval_epochs
        self.meter.log('eval/episode_reward', average_episode_reward,
                       self.progress.global_step)
        self.meter.dump(self.progress.global_step)
        return average_episode_reward

    def push_replay_buffer(self, info):
        if self.cfg.reward_forward == 0:
            self.replay_buffer.push(info)
        else:
            self.temp_buffer += [info]
            if info[-1][0] != 0:
                for i in reversed(range(len(self.temp_buffer) - 1)):
                    if self.temp_buffer[i][-1][0] == 0:
                        self.temp_buffer[i][-2][
                            0] += self.cfg.reward_forward * self.temp_buffer[
                                i + 1][-2][0]
                    else:
                        self.temp_buffer[i][-2][0] += 0
                for i in range(len(self.temp_buffer)):
                    self.replay_buffer.push(self.temp_buffer[i])
                self.temp_buffer.clear()

    def train_episode(self, train_max_frame, exploration_rate_epsilon):
        progress = self.progress
        state = self.env.reset()
        pre_state = state
        self.agent.reset()
        done_or_stop = False
        episode_reward = 0
        episode_step = 0
        self.meter.log('train/episode', progress.episode, progress.global_step)
        while not done_or_stop:
            if self.cfg.env.render:
                self.env.render()
            # sample action
            action, action_prob, state_digest = self.agent.action(
                pre_state, state, exploration_rate_epsilon)
            # one step
            next_state, reward, done, info = self.env.step(action)
            # Sometimes the observation returned by the env may not match the
            # declared observation_space, because the action can consist of
            # several continuous actions.
            assert next_state.shape == self.env.observation_space.shape
            episode_reward += reward
            # if cfg.env_name == "Breakout-ram-v4":
            #     s = next_state * (state != next_state)
            #     s[90] = 0
            #     if np.sum(s) == 0:
            #         done = True
            done_or_stop = done or (train_max_frame > 0
                                    and episode_step == train_max_frame)
            self.push_replay_buffer(
                ([str(progress.global_step)], pre_state, state, state_digest,
                 next_state, action, action_prob, [reward],
                 [float(done_or_stop)]))
            # run training update
            if self.replay_buffer.data_count >= self.cfg.batch_size:
                for i in range(self.cfg.batch_train_episodes):
                    self.agent.update(self.cfg.batch_size)
                    self.progress.num_train_iteration += 1
                    work_time = time.time() - self.last_rest_time
                    time.sleep((
                        work_time *
                        self.cfg.running_idle_rate) if work_time > 0 else 0.01)
                    self.last_rest_time = time.time()
                    if time.time() < self.last_save_time or time.time(
                    ) - self.last_save_time >= self.cfg.save_exceed_seconds:
                        self.save()
                        self.last_save_time = time.time()

            pre_state = state
            state = next_state
            episode_step += 1
            progress.global_step += 1
        return episode_reward

    def train(self):
        def set_random_seed(seed):
            if seed is not None:
                random.seed(seed)
                self.env.seed(seed)
                torch.manual_seed(seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed_all(seed)
                np.random.seed(seed)

        # training
        set_random_seed(self.cfg.seed)
        self.agent.train()
        progress = self.progress
        while (self.cfg.num_train_epochs < 0 or progress.episode <= self.cfg.num_train_epochs) \
                and (self.cfg.train_stop_reward < 0 or progress.train_running_reward < self.cfg.train_stop_reward) \
                and progress.train_running_reward < progress.env_top_reward:
            self.cfg.reload()
            train_max_frame = self.cfg.train_max_frame
            if isinstance(
                    self.env.spec.max_episode_steps,
                    int) and self.env.spec.max_episode_steps < train_max_frame:
                train_max_frame = self.env.spec.max_episode_steps

            if self.progress.global_step < self.cfg.num_seed_steps:
                exploration_rate_epsilon = 1
            elif self.progress.exploration_rate_epsilon == 1:
                exploration_rate_epsilon = self.progress.exploration_rate_epsilon = self.cfg.exploration_rate_init
            else:
                exploration_rate_epsilon = self.progress.exploration_rate_epsilon

            start_time = time.time()
            episode_reward = self.train_episode(train_max_frame,
                                                exploration_rate_epsilon)
            if progress.episode > 1:
                progress.train_running_reward = 0.05 * episode_reward + (
                    1 - 0.05) * progress.train_running_reward
            else:
                progress.train_running_reward = episode_reward

            self.meter.log('train/epsilon', exploration_rate_epsilon,
                           progress.global_step)
            self.meter.log('train/episode_reward', episode_reward,
                           progress.global_step)
            self.meter.log('train/running_reward',
                           progress.train_running_reward, progress.global_step)
            self.meter.log('train/duration',
                           time.time() - start_time, progress.global_step)
            self.meter.dump(progress.global_step)

            if self.progress.exploration_rate_epsilon < 1:
                if self.progress.exploration_rate_epsilon > self.cfg.exploration_rate_min:
                    self.progress.exploration_rate_epsilon *= self.cfg.exploration_rate_decay
                if self.progress.exploration_rate_epsilon < self.cfg.exploration_rate_min:
                    self.progress.exploration_rate_epsilon = self.cfg.exploration_rate_min

            if progress.train_running_reward > progress.env_top_reward * self.cfg.eval_expect_top_reward_percent:
                self.meter.log('eval/episode', progress.episode,
                               progress.global_step)
                self.progress.evaluate_reward = self.evaluate()
                if self.progress.evaluate_reward > self.progress.eval_top_reward:
                    self.save(best_model=True)
                    if self.progress.evaluate_reward > self.cfg.save_video_exceed_reward:
                        vrfile = f'{self.workdir}/eval_top_reward.mp4'
                        self.video_recorder.save(vrfile)
                        self.progress.eval_top_reward = self.progress.evaluate_reward

            work_time = time.time() - self.last_rest_time
            time.sleep((work_time *
                        self.cfg.running_idle_rate) if work_time > 0 else 0.01)
            self.last_rest_time = time.time()

            progress.episode += 1
            # Finished One Episode
        return progress.train_running_reward

    def run(self):
        while True:
            if isinstance(
                    self.env.spec.reward_threshold, float
            ) and self.env.spec.reward_threshold > self.progress.env_top_reward:
                self.progress.env_top_reward = self.env.spec.reward_threshold
            elif self.cfg.train_stop_reward > self.progress.env_top_reward:
                self.progress.env_top_reward = self.cfg.train_stop_reward
            elif self.progress.evaluate_reward > self.progress.env_top_reward:
                self.progress.env_top_reward = self.progress.evaluate_reward
            if self.progress.train_running_reward < self.progress.env_top_reward:
                self.progress.train_running_reward = self.train()
            else:
                self.progress.evaluate_reward = self.evaluate()
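
Driving the class is not shown in the snippet; a plausible entry point, with the config loader assumed:

# Hypothetical usage; `load_config` stands in for whatever produces the
# cfg object (it must supply workdir, env.name, agent_class, seed, etc.).
if __name__ == "__main__":
    cfg = load_config("config.yaml")  # assumed helper, not in the snippet
    Workspace(cfg).run()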
Example #12
class VideoCapture:
    def __init__(
        self,
        is_record_topic,
        is_memory_usage_exceeded_topic,
        image_topic,
        video_type,
        video_dimensions,
        frames_per_second,
        out_directory,
    ):

        rospy.init_node('video_recorder', anonymous=True)
        self._video_recorder = VideoRecorder(
            video_type=video_type,
            video_dimensions=video_dimensions,
            frames_per_second=frames_per_second,
            out_directory=out_directory,
        )

        self._is_record_subscriber = rospy.Subscriber(is_record_topic, Bool,
                                                      self._is_record_callback)
        self._image_subscriber = rospy.Subscriber(image_topic, Image,
                                                  self._image_callback)
        self._memory_watch_subscriber = rospy.Subscriber(
            is_memory_usage_exceeded_topic, Bool, self._memory_check_callback)
        self._bridge = CvBridge()

        # This flag is used to block recording if memory exceeds limits
        self._allow_recording = True

    def _image_callback(self, data):

        try:
            cv_image = self._bridge.imgmsg_to_cv2(data, "bgr8")
        except CvBridgeError as e:
            raise e

        self._video_recorder.add_image(cv_image,
                                       is_throw_error_if_not_recording=False)

    def _is_record_callback(self, data):

        is_record = data.data
        try:
            if is_record:
                if self._allow_recording:
                    rospy.loginfo("Starting to record video")
                    self._video_recorder.start_recording()
                else:
                    rospy.logerr(
                        "Recording will not start because the memory "
                        "limit was exceeded"
                    )

            else:
                if self._video_recorder._is_recording:
                    rospy.loginfo("Stopped recording video")
                    self._video_recorder.stop_recording()

        except RuntimeError as e:
            rospy.logerr(e)

    def _memory_check_callback(self, data):
        is_memory_usage_exceeded = data.data

        if is_memory_usage_exceeded:
            self._allow_recording = False
            if self._video_recorder._is_recording:
                self._video_recorder.stop_recording()
                rospy.logerr(
                    "Stopped video recording because memory utilization "
                    "exceeded the limit"
                )
            else:
                rospy.loginfo(
                    "Memory utilization exceeded the set limit; "
                    "recording will not start"
                )

        else:
            self._allow_recording = True
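
A typical way to launch this node might look like the following; every topic name and video parameter below is an assumption:

# Hypothetical launch code; topic names, video type, dimensions, frame
# rate, and output directory are all placeholders.
if __name__ == "__main__":
    VideoCapture(
        is_record_topic="/video_recorder/is_record",
        is_memory_usage_exceeded_topic="/video_recorder/memory_exceeded",
        image_topic="/camera/image_raw",
        video_type="avi",
        video_dimensions=(640, 480),
        frames_per_second=30,
        out_directory="/tmp/videos",
    )
    rospy.spin()  # keep the node alive so subscriber callbacks fire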