Beispiel #1
0
    def __init__(self, config: ArchitectureConfig, quiet: bool = False):
        """Tanh actor-critic MLP with 30 inputs and a single output.

        Construction is currently disabled: the first statement raises, so
        none of the set-up code that follows it is executed.

        :param config: architecture configuration; ``log_std`` and
            ``output_path`` would be read from it.
        :param quiet: when False, a logger would be created and
            ``initialize_architecture`` called.
        :raises NotImplementedError: always.
        """
        raise NotImplementedError(
            'Currently this should not work as the actions recorded are 6d '
            'but this network only return 1d')
        # Everything below is unreachable while the guard above is in place.
        super().__init__(config=config, quiet=True)
        self.input_size = (30, )
        self.output_size = (1, )
        self.action_min, self.action_max = -1, +1

        n_inputs, n_outputs = self.input_size[0], self.output_size[0]
        hidden_layers = [64, 64]
        self._actor = mlp_creator(sizes=[n_inputs, *hidden_layers, n_outputs],
                                  activation=nn.Tanh(),
                                  output_activation=nn.Tanh())
        self._critic = mlp_creator(sizes=[n_inputs, *hidden_layers, 1],
                                   activation=nn.Tanh(),
                                   output_activation=None)

        # 'default' in the config selects a log standard deviation of -0.5.
        log_std = self._config.log_std
        if log_std == 'default':
            log_std = -0.5
        initial_log_std = torch.ones(self.output_size,
                                     dtype=torch.float32) * log_std
        self.log_std = torch.nn.Parameter(initial_log_std, requires_grad=True)

        if quiet:
            return
        self._logger = get_logger(
            name=get_filename_without_extension(__file__),
            output_path=config.output_path,
            quiet=False)
        cprint(f'Started.', self._logger)
        self.initialize_architecture()
    def __init__(self, config: ArchitectureConfig, quiet: bool = False):
        """Share the convolution parameters of ``residual_1`` with blocks 2-4.

        After the parent constructor has run, the ``weight`` and ``bias`` of
        ``conv_0`` and ``conv_1`` in ``residual_2``, ``residual_3`` and
        ``residual_4`` are rebound to the corresponding parameters of
        ``residual_1``.

        :param config: architecture configuration passed on to the parent.
        :param quiet: when False, create a logger, initialise the
            architecture and log a start message.
        """
        super().__init__(config=config, quiet=True)

        reference_net = self.residual_1.residual_net
        for tied_block in (self.residual_2, self.residual_3, self.residual_4):
            tied_net = tied_block.residual_net
            for conv_name in ('conv_0', 'conv_1'):
                reference_conv = getattr(reference_net, conv_name)
                tied_conv = getattr(tied_net, conv_name)
                tied_conv.weight = reference_conv.weight
                tied_conv.bias = reference_conv.bias

        if quiet:
            return
        self._logger = get_logger(
            name=get_filename_without_extension(__file__),
            output_path=config.output_path,
            quiet=False)
        self.initialize_architecture()
        cprint(f'Started.', self._logger)
Beispiel #3
0
    def update_fields(self):
        """Read one key and trigger every binding registered for it.

        The key is checked against four optional binding tables, in order:

        * ``methodBindings``: the named method of ``self`` is called;
        * ``topicBindings``: an ``Empty`` message is published on the
          publisher stored under the key;
        * ``serviceBindings``: the stored proxy is called with the stored
          message;
        * ``moveBindings``: the six control fields (x, y, z, roll, pitch,
          yaw) are set from the bound sequence; when the key has no move
          binding the control fields are reset instead.

        :return: the key obtained from ``self.get_key()``.
        """
        # Local import so the block does not depend on a file-level import.
        from operator import attrgetter

        key = self.get_key()
        if self.methodBindings is not None and key in self.methodBindings:
            # Resolve the (possibly dotted) method name by attribute lookup
            # rather than evaluating a string of source code.
            attrgetter(self.methodBindings[key])(self)()
        if self.topicBindings is not None and key in self.topicBindings:
            self.publishers[key].publish(Empty())
            cprint(f'publish {self.publishers[key]}', self._logger)
        if self.serviceBindings is not None and key in self.serviceBindings:
            service = self.serviceBindings[key]
            service['proxy'](service['message'])
            cprint(
                f'{self.serviceBindings[key]["proxy"]}({self.serviceBindings[key]["message"]})',
                self._logger)
        if self.moveBindings is not None and key in self.moveBindings:
            motion = self.moveBindings[key]
            self.x = motion[0]
            self.y = motion[1]
            self.z = motion[2]
            self.roll = motion[3]
            self.pitch = motion[4]
            self.yaw = motion[5]
        else:
            self.reset_control_fields()

        return key
    def __init__(self, config: ArchitectureConfig, quiet: bool = False):
        """Set up a 1x1-convolution encoder/decoder pair.

        Config values of an unexpected type fall back to defaults: no
        dropout, no batch normalisation, a latent dimension of 32 and
        ``vae`` disabled. The encoder and decoder are only built when
        ``quiet`` is False.

        :param config: architecture configuration (``dropout``,
            ``batch_normalisation``, ``latent_dim``, ``vae``,
            ``output_path`` are read).
        :param quiet: when True, skip building the layers and the logger.
        """
        super().__init__(config=config, quiet=True)
        self.input_scope = 'default'
        self.discrete = False

        if isinstance(config.dropout, float):
            self.dropout = nn.Dropout(p=config.dropout)
        else:
            self.dropout = None

        if isinstance(config.batch_normalisation, bool):
            use_batch_norm = config.batch_normalisation
        else:
            use_batch_norm = False
        self._config.batch_normalisation = use_batch_norm

        self.h = self._config.latent_dim if isinstance(config.latent_dim,
                                                       int) else 32
        self.vae = self._config.vae if isinstance(config.vae, bool) else False

        if quiet:
            return

        # With vae enabled the encoder emits twice as many channels as the
        # decoder consumes.
        encoder_channels = self.h * (2 if self.vae else 1)
        encoder_layers = [nn.Conv2d(1, encoder_channels, 1, stride=1)]
        if self._config.batch_normalisation:
            encoder_layers.append(nn.BatchNorm2d(encoder_channels))
        encoder_layers.append(nn.LeakyReLU())
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(nn.Conv2d(self.h, 1, 1, stride=1))
        self.initialize_architecture()

        self._logger = get_logger(
            name=get_filename_without_extension(__file__),
            output_path=config.output_path,
            quiet=False)
        cprint(f'Started.', self._logger)
 def _store_frame(self, data: Union[np.ndarray, float], dst: str,
                  time_stamp: int) -> None:
     if not isinstance(data, np.ndarray):
         data = np.asarray(data)
     try:
         if len(data.shape) in [2, 3]:
             if not os.path.isdir(
                     os.path.join(self._config.saving_directory, dst)):
                 os.makedirs(os.path.join(self._config.saving_directory,
                                          dst),
                             exist_ok=True)
             store_image(
                 data=data,
                 file_name=os.path.join(self._config.saving_directory, dst,
                                        timestamp_to_filename(time_stamp)) +
                 '.jpg')
         elif len(data.shape) in [0, 1]:
             store_array_to_file(data=data,
                                 file_name=os.path.join(
                                     self._config.saving_directory,
                                     dst + '.data'),
                                 time_stamp=time_stamp)
     except Exception as e:
         cprint(f'Failed to store frame: {e}',
                self._logger,
                msg_type=MessageType.error)
 def _update_fsm_state(self, msg: String):
     """Track the fsm state announced in ``msg``.

     When the state goes from Running to Terminated the image is written
     first; any change of state is logged.

     :param msg: message whose ``data`` field names an ``FsmState`` member.
     """
     incoming_state = FsmState[msg.data]
     was_running = self._fsm_state == FsmState.Running
     if was_running and incoming_state == FsmState.Terminated:
         self._write_image()
     if self._fsm_state != incoming_state:
         cprint(f'update fsm state {msg.data}', self._logger)
     self._fsm_state = incoming_state
    def _setup(self):
        """Create the publishers, subscribers and motor services per robot.

        * Robots whose name contains 'turtle', 'default' or 'real' need no
          altitude control: the process logs this and exits with status 0.
        * A single 'quadrotor' gets one agent registered as 'default'.
        * A list of robots gets a 'tracking' and a 'fleeing' agent.
        * Any other value sets up no agent.
        """
        self._fsm_state = FsmState.Unknown
        rospy.Subscriber(name='/fsm/state',
                         data_class=String,
                         callback=self._set_fsm_state)

        # Becomes True once motors are enabled; no control should be
        # published while they are disabled.
        self._motors_enabled = False

        # Window over the most recent commands, used to detect a stable point.
        self._control_norm_window_length = 10
        self._control_norm_window = []

        if any(tag in self._robot for tag in ('turtle', 'default', 'real')):
            cprint(
                f'altitude control not required for {self._robot}, shutting down...',
                self._logger)
            sys.exit(0)

        # One row per agent:
        # (key, velocity topic, sensor parameter prefix, motor service).
        if self._robot == 'quadrotor':
            agents = [('default', 'cmd_vel', '/robot/', '/enable_motors')]
        elif isinstance(self._robot, list):
            agents = [
                ('tracking', 'cmd_vel', '/robot/tracking_',
                 '/tracking/enable_motors'),
                ('fleeing', 'cmd_vel_1', '/robot/fleeing_',
                 '/fleeing/enable_motors'),
            ]
        else:
            agents = []

        # All publishers are created before any sensor is subscribed to.
        for agent, velocity_topic, _, _ in agents:
            self._publishers[agent] = rospy.Publisher(velocity_topic,
                                                      Twist,
                                                      queue_size=10)

        for agent, _, param_prefix, motor_service in agents:
            sensor = SensorType.position
            sensor_topic = rospy.get_param(
                f'{param_prefix}{sensor.name}_sensor/topic')
            sensor_type = rospy.get_param(
                f'{param_prefix}{sensor.name}_sensor/type')
            # NOTE(review): both eval calls execute text taken from ROS
            # parameters; only safe while the parameter server is trusted.
            message_class = eval(sensor_type)
            sensor_callback = eval(
                f'self._process_{camelcase_to_snake_format(sensor_type)}')
            rospy.Subscriber(name=sensor_topic,
                             data_class=message_class,
                             callback=sensor_callback,
                             callback_args=agent)
            rospy.wait_for_service(motor_service)
            self._enable_motors_services[agent] = rospy.ServiceProxy(
                motor_service, EnableMotors)
    def __init__(self):
        """Start the 'ros_expert' node and configure it from ROS parameters.

        Waits up to 60 seconds for '/actor/ros_expert/specs' to appear,
        stores a copy of the specs as 'ros_expert_specs.yml' in the output
        path, prepares the control adjustments, optional noise and the
        'cmd_vel' publisher, and finally subscribes to its topics.
        """
        self.count = 0
        rospy.init_node('ros_expert')

        deadline = time.time() + 60
        while not rospy.has_param(
                '/actor/ros_expert/specs') and time.time() < deadline:
            time.sleep(0.01)
        self._specs = rospy.get_param('/actor/ros_expert/specs')

        super().__init__(
            config=ActorConfig(name='ros_expert', specs=self._specs))
        self._output_path = get_output_path()
        self._logger = get_logger(get_filename_without_extension(__file__),
                                  self._output_path)
        cprint(f'ros specifications: {self._specs}', self._logger)
        specs_file = os.path.join(self._output_path, 'ros_expert_specs.yml')
        with open(specs_file, 'w') as f:
            yaml.dump(self._specs, f)

        self._reference_height = rospy.get_param('/world/starting_height', 1)
        self._adjust_height = 0
        self._adjust_yaw_collision_avoidance = 0
        self._adjust_yaw_waypoint_following = 0
        self._rate_fps = self._specs.get('rate_fps', 10)
        self._next_waypoint = []

        noise_config = self._specs.get('noise', {})
        if noise_config:
            # NOTE(review): the noise class name comes from the specs and is
            # evaluated as code; only safe while the specs are trusted.
            self._noise = eval(
                f"{noise_config['name']}(**noise_config['args'])")
        else:
            self._noise = None

        self._publisher = rospy.Publisher('cmd_vel', Twist, queue_size=10)
        self._subscribe()
    def create_train_validation_hdf5_files(
            self,
            runs: List[str] = None,
            input_size: List[int] = None) -> None:
        """Write 'train.hdf5' and 'validation.hdf5' to the output path.

        The runs are split in order: the first
        ``training_validation_split`` fraction goes to training, the
        remainder to validation. Each partition is loaded through a
        ``DataLoader`` and dumped with ``create_hdf5_file_from_dataset``.

        :param runs: run directories to use; when None, ``self._get_runs()``
            supplies them.
        :param input_size: forwarded to the data loader configuration.
        """
        all_runs = self._get_runs() if runs is None else runs
        split_index = int(self._config.training_validation_split *
                          len(all_runs))
        partitions = {
            'train': all_runs[:split_index],
            'validation': all_runs[split_index:],
        }

        for file_name, partition_runs in partitions.items():
            loader_settings = {
                'data_directories': partition_runs,
                'output_path': self._config.output_path,
                'subsample': self._config.subsample_hdf5,
                'input_size': input_size
            }
            config = DataLoaderConfig().create(config_dict=loader_settings)
            data_loader = DataLoader(config=config)
            data_loader.load_dataset()
            hdf5_path = os.path.join(self._config.output_path,
                                     file_name + '.hdf5')
            create_hdf5_file_from_dataset(filename=hdf5_path,
                                          dataset=data_loader.get_dataset())
            cprint(f'created {file_name}.hdf5', self._logger)
Beispiel #10
0
    def __init__(self,
                 config: EvaluatorConfig,
                 network: BaseNet,
                 quiet: bool = False):
        """Prepare data loader, device and loss criterion for evaluation.

        :param config: evaluator configuration (``data_loader_config``,
            ``device``, ``criterion``, ``criterion_args_str`` and
            ``output_path`` are read).
        :param network: network to evaluate; may be None.
        :param quiet: when False, a start message is logged. A logger is
            only created when the instance is exactly an ``Evaluator``;
            otherwise ``_logger`` is set to None here.
        """
        self._config = config
        self._net = network
        self.data_loader = DataLoader(config=self._config.data_loader_config)

        if not quiet:
            if type(self) == Evaluator:
                self._logger = get_logger(
                    name=get_filename_without_extension(__file__),
                    output_path=config.output_path,
                    quiet=False)
            else:
                self._logger = None
            cprint(f'Started.', self._logger)

        wants_gpu = self._config.device in ['gpu', 'cuda']
        device_name = "cuda" if wants_gpu and torch.cuda.is_available(
        ) else "cpu"
        self._device = torch.device(device_name)

        # NOTE(review): the criterion is built by evaluating text from the
        # configuration; only safe while the configuration is trusted.
        criterion_source = f'{self._config.criterion}(reduction=\'none\', {self._config.criterion_args_str})'
        self._criterion = eval(criterion_source)
        self._criterion.to(self._device)
        self._lowest_validation_loss = None
        self.data_loader.load_dataset()

        self._minimum_error = float(10**6)
        if self._net is not None:
            self._original_model_device = self._net.get_device()
        else:
            self._original_model_device = None
    def _subscribe(self):
        """Subscribe to the fsm state, the reference topics and the position sensor.

        The position sensor is only subscribed to when its topic parameter
        exists; its callback is looked up by name from the message type.
        """
        self._fsm_state = FsmState.Unknown
        rospy.Subscriber(name='/fsm/state',
                         data_class=String,
                         callback=self._set_fsm_state)

        # Desired next reference point, delivered on two topics that share
        # one callback.
        rospy.Subscriber('/reference_pose', PointStamped,
                         self._reference_update)
        rospy.Subscriber(name='/waypoint_indicator/current_waypoint',
                         data_class=Float32MultiArray,
                         callback=self._reference_update)

        # Robot sensors:
        for sensor in [SensorType.position]:
            if not rospy.has_param(f'/robot/{sensor.name}_sensor/topic'):
                continue
            sensor_topic = rospy.get_param(
                f'/robot/{sensor.name}_sensor/topic')
            sensor_type = rospy.get_param(f'/robot/{sensor.name}_sensor/type')
            sensor_callback = f'_process_{camelcase_to_snake_format(sensor_type)}'
            if sensor_callback not in self.__dir__():
                # NOTE(review): a missing callback is only logged here; the
                # lookup further down still raises AttributeError.
                cprint(f'Could not find sensor_callback {sensor_callback}',
                       self._logger)
            if rospy.has_param(f'/robot/{sensor.name}_sensor/stats'):
                sensor_stats = rospy.get_param(
                    f'/robot/{sensor.name}_sensor/stats')
            else:
                sensor_stats = {}
            # NOTE(review): both eval calls execute text derived from ROS
            # parameters; only safe while the parameter server is trusted.
            message_class = eval(sensor_type)
            bound_callback = eval(f'self.{sensor_callback}')
            rospy.Subscriber(name=sensor_topic,
                             data_class=message_class,
                             callback=bound_callback,
                             callback_args=(sensor_topic, sensor_stats))
    def __init__(self, config: ArchitectureConfig, quiet: bool = False):
        """Actor-critic MLP with 4 inputs and 5 discrete actions.

        Both networks use one hidden layer of 10 tanh units. The five actor
        outputs are mapped onto five 4-dimensional action vectors by a
        ``DiscreteActionMapper``.

        :param config: architecture configuration (``log_std`` and
            ``output_path`` are read).
        :param quiet: when False, create a logger and log a start message.
        """
        super().__init__(config=config, quiet=True)
        self.input_size = (4,)
        self.output_size = (5,)
        self.discrete = False

        # 'default' in the config selects a log standard deviation of -0.5.
        log_std = self._config.log_std
        if log_std == 'default':
            log_std = -0.5
        initial_log_std = torch.ones(self.output_size,
                                     dtype=torch.float32) * log_std
        self.log_std = torch.nn.Parameter(initial_log_std, requires_grad=True)

        n_inputs, n_outputs = self.input_size[0], self.output_size[0]
        self._actor = mlp_creator(sizes=[n_inputs, 10, n_outputs],
                                  activation=nn.Tanh(),
                                  output_activation=None)
        self._critic = mlp_creator(sizes=[n_inputs, 10, 1],
                                   activation=nn.Tanh(),
                                   output_activation=None)

        self.initialize_architecture()

        action_vectors = (
            [0.0, 0.0, 0.0, 0.0],
            [-1.0, 0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0],
            [0.0, -1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
        )
        self.discrete_action_mapper = DiscreteActionMapper(
            [torch.as_tensor(vector) for vector in action_vectors])

        if quiet:
            return
        self._logger = get_logger(name=get_filename_without_extension(__file__),
                                  output_path=config.output_path,
                                  quiet=False)
        cprint(f'Started.', self._logger)
    def __init__(self, config: ArchitectureConfig, quiet: bool = False):
        """Actor-critic MLP with 3 inputs and one action bounded to [-2, 2].

        :param config: architecture configuration (``log_std`` and
            ``output_path`` are read).
        :param quiet: when False, create a logger, log a start message and
            initialise the architecture.
        """
        # NOTE(review): unlike the sibling architectures, the parent
        # constructor is called without quiet=True here - confirm intended.
        super().__init__(config=config)
        self.input_size = (3, )
        self.output_size = (1, 1)
        self.action_min, self.action_max = -2, 2

        n_inputs, n_outputs = self.input_size[0], self.output_size[0]
        hidden_layers = [64, 64]
        self._actor = mlp_creator(sizes=[n_inputs, *hidden_layers, n_outputs],
                                  activation=nn.Tanh(),
                                  output_activation=None)
        self._critic = mlp_creator(sizes=[n_inputs, *hidden_layers, 1],
                                   activation=nn.Tanh(),
                                   output_activation=None)

        # 'default' in the config selects a log standard deviation of -0.5.
        log_std = self._config.log_std
        if log_std == 'default':
            log_std = -0.5
        initial_log_std = torch.ones(self.output_size,
                                     dtype=torch.float32) * log_std
        self.log_std = torch.nn.Parameter(initial_log_std, requires_grad=True)

        if quiet:
            return
        self._logger = get_logger(
            name=get_filename_without_extension(__file__),
            output_path=config.output_path,
            quiet=False)
        cprint(f'Started.', self._logger)
        self.initialize_architecture()
    def __init__(self):
        """Wire the modified-state subscriber to the frame publisher.

        Waits up to 60 seconds for '/modified_state_publisher/mode', then
        subscribes to the modified state topic, advertises
        '/modified_state_frame' and finally initialises the
        'modified_state_frame_visualizer' node.
        """
        deadline = time.time() + 60
        while not rospy.has_param(
                '/modified_state_publisher/mode') and time.time() < deadline:
            time.sleep(0.01)

        self._output_path = get_output_path()
        self._logger = get_logger(get_filename_without_extension(__file__),
                                  self._output_path)

        # (parameter name, fallback) pairs for the state sensor.
        topic_param = ('/robot/modified_state_sensor/topic', '/modified_state')
        type_param = ('/robot/modified_state_sensor/type',
                      'CombinedGlobalPoses')

        state_topic = rospy.get_param(*topic_param)
        # NOTE(review): the message class name is evaluated as code; only
        # safe while the parameter server is trusted.
        state_class = eval(rospy.get_param(*type_param))
        rospy.Subscriber(state_topic, state_class,
                         self._process_state_and_publish_frame)
        self._publisher = rospy.Publisher('/modified_state_frame',
                                          Image,
                                          queue_size=10)
        cprint(f"subscribe to {rospy.get_param(*topic_param)}", self._logger)
        rospy.init_node('modified_state_frame_visualizer')
    def __init__(self, config: ArchitectureConfig, quiet: bool = False):
        """Actor-critic MLP with 30 inputs and 3 discrete actions.

        The three actor outputs are mapped onto 6-dimensional action vectors
        that share a first component of 0.2 and differ in the last component
        (-0.2, 0.0, 0.2).

        :param config: architecture configuration (``output_path`` is read).
        :param quiet: when False, create a logger and log a start message.
        """
        super().__init__(config=config, quiet=True)
        self.input_size = (30, )
        self.output_size = (3, )

        n_inputs, n_outputs = self.input_size[0], self.output_size[0]
        hidden_layers = [64, 64]
        self._actor = mlp_creator(sizes=[n_inputs, *hidden_layers, n_outputs],
                                  activation=nn.Tanh(),
                                  output_activation=None)
        self._critic = mlp_creator(sizes=[n_inputs, *hidden_layers, 1],
                                   activation=nn.Tanh(),
                                   output_activation=None)
        self.initialize_architecture()

        action_vectors = (
            [0.2, 0.0, 0.0, 0.0, 0.0, -0.2],
            [0.2, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.2, 0.0, 0.0, 0.0, 0.0, 0.2],
        )
        self.discrete_action_mapper = DiscreteActionMapper(
            [torch.as_tensor(vector) for vector in action_vectors])

        if quiet:
            return
        self._logger = get_logger(
            name=get_filename_without_extension(__file__),
            output_path=config.output_path,
            quiet=False)
        cprint(f'Started.', self._logger)
Beispiel #16
0
 def __init__(self, config: ArchitectureConfig, quiet: bool = False):
     """Convolutional encoder followed by an MLP decoder with 6 tanh outputs.

     The logger is always created; ``quiet`` only suppresses the start
     message.

     :param config: architecture configuration (``dropout`` and
         ``output_path`` are read).
     :param quiet: when True, do not log the start message.
     """
     super().__init__(config=config, quiet=True)
     self._logger = get_logger(
         name=get_filename_without_extension(__file__),
         output_path=config.output_path,
         quiet=False)
     if not quiet:
         cprint(f'Started.', self._logger)

     self.input_size = (3, 128, 128)
     self.output_size = (6, )
     self.discrete = False
     if config.dropout != 'default':
         self.dropout = nn.Dropout(p=config.dropout)
     else:
         self.dropout = None

     # Four stride-2 4x4 convolutions, each followed by a ReLU.
     encoder_layers = []
     for in_channels, out_channels in ((3, 32), (32, 64), (64, 128),
                                       (128, 256)):
         encoder_layers.append(
             nn.Conv2d(in_channels, out_channels, 4, stride=2))
         encoder_layers.append(nn.ReLU())
     self.encoder = nn.Sequential(*encoder_layers)

     self.decoder = mlp_creator(
         sizes=[256 * 6 * 6, 128, 128, self.output_size[0]],
         activation=nn.ReLU(),
         output_activation=nn.Tanh(),
         bias_in_last_layer=False)
     self.initialize_architecture()
 def reset(self) -> Tuple[Experience, np.ndarray]:
     """
     Reset filters and counters, optionally reset gazebo, publish a reset
     request and spin until the fsm is running with an observation and a
     known terminal state available.
     :return: the first experience (done, observation, time stamp, empty
         info) together with a copy of the filtered observation.
     """
     cprint(f'resetting', self._logger)
     self._reset_filters()
     self._step = 0
     self._return = 0
     if self._config.ros_config.ros_launch_config.gazebo:
         self._reset_gazebo()
     self._reset_publisher.publish(Empty())
     self._clear_experience_values()

     # Keep spinning until every field of a fresh episode has arrived.
     while not (self.fsm_state == FsmState.Running
                and self.observation is not None
                and self.terminal_state is not None
                and self.terminal_state is not TerminationType.Unknown):
         self._run_shortly()

     self.observation = self._filter_observation(self.observation)
     first_experience = Experience(
         done=deepcopy(self.terminal_state),
         observation=deepcopy(self.observation),
         time_stamp=int(rospy.get_time() * 10**3),
         info={})
     self._current_experience = first_experience
     self._previous_observation = deepcopy(self.observation)
     return first_experience, deepcopy(self.observation)
 def _signal_handler(self, signal_number: int, _) -> None:
     """Remove the process on a signal, log the outcome and exit with 0.

     The message is logged as info when ``remove`` reports
     ``ProcessState.Terminated`` and as error otherwise.

     :param signal_number: number of the received signal.
     :param _: unused second handler argument.
     """
     outcome = self.remove()
     if outcome == ProcessState.Terminated:
         severity = MessageType.info
     else:
         severity = MessageType.error
     cprint(f'received signal {signal_number}.',
            self._logger,
            msg_type=severity)
     sys.exit(0)
Beispiel #19
0
    def __init__(self,
                 config: TrainerConfig,
                 network: BaseNet,
                 quiet: bool = False):
        """Trainer with separate optimizers for the network and its discriminator.

        Nothing beyond the parent constructor runs when ``quiet`` is True.

        :param config: trainer configuration (``optimizer``,
            ``learning_rate``, ``critic_learning_rate``, ``weight_decay``,
            ``scheduler_config`` and ``output_path`` are read).
        :param network: network providing ``deeply_supervised_parameters``
            and ``discriminator_parameters``.
        :param quiet: when True, skip optimizers, scheduler and logger.
        """
        super().__init__(config, network, quiet=True)
        if quiet:
            return

        # NOTE(review): the optimizer class is resolved by evaluating text
        # from the configuration; only safe while the config is trusted.
        optimizer_class = eval(f'torch.optim.{self._config.optimizer}')
        self._optimizer = optimizer_class(
            params=self._net.deeply_supervised_parameters(),
            lr=self._config.learning_rate,
            weight_decay=self._config.weight_decay)

        def linear_decay(epoch):
            # Learning-rate factor falling linearly from 1 to 0 over the
            # configured number of epochs.
            return 1 - epoch / self._config.scheduler_config.number_of_epochs

        if self._config.scheduler_config is not None:
            self._scheduler = torch.optim.lr_scheduler.LambdaLR(
                self._optimizer, lr_lambda=linear_decay)
        else:
            self._scheduler = None

        discriminator_class = eval(f'torch.optim.{self._config.optimizer}')
        # A critic learning rate of -1 means "reuse the main learning rate".
        if self._config.critic_learning_rate != -1:
            discriminator_lr = self._config.critic_learning_rate
        else:
            discriminator_lr = self._config.learning_rate
        self._discriminator_optimizer = discriminator_class(
            params=self._net.discriminator_parameters(),
            lr=discriminator_lr,
            weight_decay=self._config.weight_decay)

        self._logger = get_logger(
            name=get_filename_without_extension(__file__),
            output_path=config.output_path,
            quiet=False)
        cprint(f'Started.', self._logger)
 def _process_position(self, pose: Pose):
     """Project the robot's local frame into the image and record it.

     The local frame points are transformed to the global frame with the
     robot pose, then into the camera frame with the inverted camera pose,
     and projected with the camera intrinsics. The projected points are
     appended to ``_frame_points`` for the first pose and whenever the
     projected origin moved more than ``_minimum_distance_px`` since the
     last stored one.

     :param pose: robot pose with ``position`` and ``orientation`` fields.
     """
     # Throttle the debug output: log only when the first decimal digit of
     # the ROS time is 0 or 5.
     first_decimal = f"{rospy.get_time()}".split('.')[-1][0]
     if float(first_decimal) % 5 == 0:
         cprint(f'received pose {pose}',
                self._logger,
                msg_type=MessageType.debug)

     position, quaternion = pose.position, pose.orientation
     robot_translation = np.asarray([position.x, position.y, position.z])
     robot_orientation = rotation_from_quaternion(
         (quaternion.x, quaternion.y, quaternion.z, quaternion.w))

     in_global_frame = transform(points=self._local_frame,
                                 orientation=robot_orientation,
                                 translation=robot_translation)
     # The camera pose is given in the global frame, so it is inverted to
     # go from global to camera coordinates.
     in_camera_frame = transform(
         points=in_global_frame,
         orientation=self._camera_global_orientation,
         translation=self._camera_global_translation,
         invert=True)
     in_image_frame = project(points=in_camera_frame,
                              fx=self._fx,
                              fy=self._fy,
                              cx=self._cx,
                              cy=self._cy)

     origin_px = in_image_frame[0]
     if self._previous_position is None or get_distance(
             self._previous_position,
             origin_px) > self._minimum_distance_px:
         self._previous_position = origin_px  # store origin of position
         self._frame_points.append(in_image_frame)
 def _internal_update_terminal_state(self):
     """Mark the episode as Done once the step limit is reached.

     Only applies while the fsm is running; a ``max_number_of_steps`` of -1
     disables the limit.
     """
     if not (self.fsm_state == FsmState.Running):
         return
     step_limit = self._config.max_number_of_steps
     if step_limit == -1 or not (step_limit <= self._step):
         return
     self.terminal_state = TerminationType.Done
     cprint(
         f'reach max number of steps {self._config.max_number_of_steps} < {self._step}',
         self._logger)
 def run(self):
     """Spin at ``_rate_fps`` until ROS shuts down, logging status periodically.

     While the fsm is running, the iteration counter is incremented and a
     status line (reference pose, estimated pose, last command) is logged
     once every ``10 * _rate_fps`` iterations.
     """
     while not rospy.is_shutdown():
         if self._fsm_state == FsmState.Running:
             self.count += 1
             # The parentheses matter: `count % 10 * rate` parses as
             # `(count % 10) * rate` and would log every 10 iterations
             # regardless of the rate.
             if self.count % (10 * self._rate_fps) == 0:
                 msg = f'<<reference: {self.pose_ref}, \n<<pose: {self.pose_est} \n control: {self.last_cmd}'
                 cprint(msg, self._logger)
         rospy.sleep(duration=1 / self._rate_fps)
Beispiel #23
0
 def _check_wrench_stamped(self, msg: WrenchStamped) -> None:
     """End the run as a collision when the z force is negative.

     Nothing happens while evaluation is still delayed.

     :param msg: wrench message; only ``wrench.force.z`` is inspected.
     """
     if self._delay_evaluation():
         return
     vertical_force = msg.wrench.force.z
     if not (vertical_force < 0):
         return
     cprint(
         f"found drag force: {vertical_force}, so robot must be upside-down.",
         self._logger)
     self._occasion = 'on_collision'
     self._shutdown_run()
 def get_checkpoint(self) -> dict:
     """
     Log the current checksum and collect the checkpoint contents.
     :return: a dictionary with 'global_step' and the 'model_state' of the neural network.
     """
     checksum = self.get_checksum()
     cprint(f'checksum: {checksum}', self._logger)
     checkpoint = dict(global_step=self.global_step,
                       model_state=self.state_dict())
     return checkpoint
 def test_multiple_loggers(self):
     """Two differently named loggers must leave two files in 'log_files'."""
     loggers = [
         get_logger(name=module_name, output_path=self.TEST_DIR)
         for module_name in ('module_a', 'module_b')
     ]
     for logger in loggers:
         cprint('started', logger)
     log_pattern = os.path.join(self.TEST_DIR, 'log_files', '*')
     self.assertEqual(len(glob(log_pattern)), 2)
Beispiel #26
0
 def load_checkpoint(self, checkpoint_file: str):
     """Restore the epoch and each experiment element from a checkpoint file.

     The file is loaded onto the CPU. The epoch defaults to 0 when the
     checkpoint has none; network, trainer and environment are restored
     only when they exist and the checkpoint holds their entry.

     :param checkpoint_file: path of the checkpoint to load.
     """
     state = torch.load(checkpoint_file, map_location=torch.device('cpu'))
     self._epoch = state['epoch'] if 'epoch' in state else 0
     restorable = (
         ('net_ckpt', self._net),
         ('trainer_ckpt', self._trainer),
         ('environment_ckpt', self._environment),
     )
     for entry_name, element in restorable:
         if element is None or entry_name not in state:
             continue
         element.load_checkpoint(state[entry_name])
     cprint(f'loaded network from {checkpoint_file}', self._logger)
    def __init__(self, config: ArchitectureConfig, quiet: bool = False):
        """Variant of the parent architecture flagged as discrete.

        :param config: architecture configuration passed on to the parent.
        :param quiet: when False, create a logger, log a start message and
            initialise the architecture.
        """
        super().__init__(config=config, quiet=True)
        self.discrete = True
        if quiet:
            return
        self._logger = get_logger(name=get_filename_without_extension(__file__),
                                  output_path=config.output_path,
                                  quiet=False)
        cprint(f'Started.', self._logger)
        self.initialize_architecture()
    def test_cprint(self):
        """A message sent through cprint must appear in the first log line."""
        module_name = get_filename_without_extension(__file__)
        current_logger = get_logger(name=module_name,
                                    output_path=self.TEST_DIR,
                                    quiet=True)
        cprint('HELP', current_logger)

        log_files = glob(os.path.join(self.TEST_DIR, 'log_files', '*'))
        with open(log_files[0], 'r') as f:
            first_line = f.readlines()[0].strip()
            self.assertTrue('HELP' in first_line)
 def load_checkpoint(self, checkpoint) -> None:
     """
     Restore the global step and the model state from a checkpoint, move the
     model to its device and log the resulting checksum. Missing entries
     raise a KeyError.
     :param checkpoint: dictionary containing 'global_step' and 'model_state'
     :return: None
     """
     self.global_step = checkpoint['global_step']
     model_state = checkpoint['model_state']
     self.load_state_dict(model_state)
     self.set_device(self._device)
     checksum = self.get_checksum()
     cprint(f'checksum: {checksum}', self._logger)
    def _subscribe(self):
        """Subscribe to every topic whose latest value is recorded.

        Covers the forward camera (when configured), the fsm reset, the
        applied command (when configured), the fsm state, the reward, the
        current waypoint and the battery state. Each recorded field is
        initialised before its subscriber is created.
        """
        listen = rospy.Subscriber

        # Robot sensors:
        sensor = '/robot/forward_camera'
        if rospy.has_param(f'{sensor}_topic'):
            sensor_topic = rospy.get_param(f'{sensor}_topic')
            sensor_type = rospy.get_param(f'{sensor}_type')
            sensor_callback = f'_process_{camelcase_to_snake_format(sensor_type)}'
            if sensor_callback not in self.__dir__():
                # NOTE(review): a missing callback is only logged here; the
                # lookup further down still raises AttributeError.
                cprint(f'Could not find sensor_callback {sensor_callback}',
                       self._logger)
            if rospy.has_param(f'{sensor}_stats'):
                sensor_stats = rospy.get_param(f'{sensor}_stats')
            else:
                sensor_stats = {}
            # NOTE(review): both eval calls execute text derived from ROS
            # parameters; only safe while the parameter server is trusted.
            message_class = eval(sensor_type)
            bound_callback = eval(f'self.{sensor_callback}')
            listen(name=sensor_topic,
                   data_class=message_class,
                   callback=bound_callback,
                   callback_args=(sensor_topic, sensor_stats))
        listen('/fsm/reset', Empty, self._reset)

        # Applied action
        self._action = None
        if rospy.has_param('/robot/command_topic'):
            command_topic = rospy.get_param('/robot/command_topic')
            listen(name=command_topic,
                   data_class=Twist,
                   callback=self._set_field,
                   callback_args=('action', {}))

        # fsm state
        self._fsm_state = None
        state_topic = rospy.get_param('/fsm/state_topic')
        listen(name=state_topic,
               data_class=String,
               callback=self._set_field,
               callback_args=('fsm_state', {}))

        # Reward topic
        self._reward = None
        self._terminal_state = TerminationType.Unknown
        reward_topic = rospy.get_param('/fsm/reward_topic', '')
        listen(name=reward_topic,
               data_class=RosReward,
               callback=self._set_field,
               callback_args=('reward', {}))

        # waypoint
        self._waypoint = None
        listen(name='/waypoint_indicator/current_waypoint',
               data_class=Float32MultiArray,
               callback=self._set_field,
               callback_args=('waypoint', {}))

        # battery state
        self._battery = None
        listen(name='/bebop/states/common/CommonState/BatteryStateChanged',
               data_class=CommonCommonStateBatteryStateChanged,
               callback=self._set_field,
               callback_args=('battery', {}))