def session_factory(
    agent: Type[AgentType] = None,
    config=None,
    *,
    session: Union[Type[EnvironmentSessionType], EnvironmentSession],
    save: bool = True,
    has_x_server: bool = True,
    skip_confirmation: bool = True,
    **kwargs,
):
    r"""
    Entry point for starting a training session; parses command-line arguments, optionally asks for
    confirmation of the configuration to use, and allows overwriting of default training configuration
    values before training.
    """
    if config is None:
        config = {}

    if isinstance(config, dict):
        config = NOD(**config)
    else:
        config = NOD(config.__dict__)

    if has_x_server:
        display_env = getenv("DISPLAY", None)
        if display_env is None:
            config.RENDER_ENVIRONMENT = False
            has_x_server = False

    config_mapping = config_to_mapping(config)
    config_mapping.update(**kwargs)
    config_mapping.update(save=save, has_x_server=has_x_server)

    if not skip_confirmation:
        sprint(f"\nUsing config: {config}\n", highlight=True, color="yellow")
        for key, arg in config_mapping.items():
            print(f"{key} = {arg}")
        input("\nPress Enter to begin... ")

    if session is None:
        raise NoProcedure
    elif inspect.isclass(session):
        session = session(**config_mapping)  # Use passed config arguments
    elif isinstance(session, GDKC):
        # Assume some kw parameters were set prior to passing the session;
        # only override with explicit overrides
        session = session(**kwargs)

    try:
        session(agent, **config_mapping)
    except KeyboardInterrupt:
        print("Stopping")

    torch.cuda.empty_cache()
    exit(0)
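# Usage sketch (hypothetical, not part of the library): how session_factory
# might be invoked from a training script. `MyAgent` and `SingleAgentSession`
# are stand-ins for whatever agent/session classes the surrounding package
# provides.
#
#     session_factory(
#         agent=MyAgent,
#         config=dict(RENDER_ENVIRONMENT=False),
#         session=SingleAgentSession,
#         save=True,
#         skip_confirmation=True,
#     )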
def build(
    self,
    observation_space: ObservationSpace,
    action_space: ActionSpace,
    signal_space: SignalSpace,
    *,
    metric_writer: Writer = MockWriter(),
    print_model_repr: bool = True,
    verbose: bool = False,
    **kwargs,
) -> None:
    """
    :param observation_space:
    :param action_space:
    :param signal_space:
    :param metric_writer:
    :param print_model_repr:
    :param verbose:
    :param kwargs:
    :return:
    """
    super().build(
        observation_space,
        action_space,
        signal_space,
        print_model_repr=print_model_repr,
        metric_writer=metric_writer,
        **kwargs,
    )

    if print_model_repr:
        for k, w in self.models.items():
            sprint(f"{k}: {w}", highlight=True, color="cyan")

            if metric_writer:
                try:
                    model = copy.deepcopy(w).to("cpu")
                    dummy_input = model.sample_input()
                    sprint(f"{k} input: {dummy_input.shape}")

                    import contextlib

                    with contextlib.redirect_stdout(
                        None
                    ):  # So much useless frame info is printed... Suppress it
                        if isinstance(metric_writer, GraphWriterMixin):
                            metric_writer.graph(
                                model, dummy_input, verbose=verbose
                            )  # No naming available at the moment...
                except RuntimeError as ex:
                    sprint(
                        f"Tensorboard(Pytorch) does not support your model! No graph was added: {str(ex).splitlines()[0]}",
                        color="red",
                        highlight=True,
                    )
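# A minimal, self-contained sketch of the graph-logging step above, assuming
# only the stock torch.utils.tensorboard writer in place of the
# GraphWriterMixin used by this package; `dummy_input` stands in for
# model.sample_input():


def _graph_logging_sketch(log_dir: str = "runs/sketch") -> None:
    import copy

    import torch
    from torch.utils.tensorboard import SummaryWriter

    model = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.ReLU())
    dummy_input = torch.randn(1, 4)
    with SummaryWriter(log_dir) as writer:
        # Log the model graph from a CPU copy, as the build method above does
        writer.add_graph(copy.deepcopy(model).to("cpu"), dummy_input)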
def __infer_io_shapes(
    self,
    observation_space: ObservationSpace,
    action_space: ActionSpace,
    signal_space: SignalSpace,
    print_inferred_io_shapes: bool = True,
) -> None:
    """
    Tries to infer the input and output sizes from the environment if either _input_shape or
    _output_shape is None or -1 (int)

    :rtype: object
    """
    if self._input_shape is None or self._input_shape == -1:
        self._input_shape = observation_space.shape

    if self._output_shape is None or self._output_shape == -1:
        self._output_shape = action_space.shape

    # region print

    if print_inferred_io_shapes:
        sprint(
            f"input shape: {self._input_shape}\n"
            f"observation space: {observation_space}\n",
            color="green",
            bold=True,
            highlight=True,
        )

        sprint(
            f"output shape: {self._output_shape}\n"
            f"action space: {action_space}\n",
            color="yellow",
            bold=True,
            highlight=True,
        )

        sprint(
            f"signal shape: {signal_space}\n",
            color="blue",
            bold=True,
            highlight=True,
        )

    # endregion
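# A minimal sketch of the sentinel-based inference above, with plain tuples
# standing in for the space objects:


def _infer_shape(configured, space_shape: tuple) -> tuple:
    """Return `configured` unless it is the None/-1 sentinel, else fall back to the space shape."""
    return space_shape if configured is None or configured == -1 else configured


# e.g. _infer_shape(None, (84, 84, 3)) -> (84, 84, 3); _infer_shape((10,), (84,)) -> (10,)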
def export_detection_model(
    cfg: NOD,
    model_ckpt: Path,
    model_export_path: Path = Path("torch_model"),
    verbose: bool = True,
    onnx_export: bool = False,
    strict_jit: bool = False,
) -> None:
    """
    :param cfg:
    :type cfg:
    :param model_ckpt:
    :type model_ckpt:
    :param model_export_path:
    :type model_export_path:
    :param verbose:
    :type verbose:
    :param onnx_export:
    :type onnx_export:
    :param strict_jit:
    :type strict_jit:
    :return:
    :rtype:
    """
    model = SingleShotDectectionNms(cfg)

    checkpointer = CheckPointer(
        model, save_dir=ensure_existence(PROJECT_APP_PATH.user_data / "results")
    )
    checkpointer.load(model_ckpt, use_latest=model_ckpt is None)
    print(
        f"Loaded weights from {model_ckpt if model_ckpt else checkpointer.get_checkpoint_file()}"
    )

    model.post_init()
    model.to(global_torch_device())

    transforms = SSDTransform(
        cfg.input.image_size, cfg.input.pixel_mean, split=Split.Testing
    )
    model.eval()

    pre_quantize_model = False
    if pre_quantize_model:  # Accuracy may drop!
        model = quantization.quantize_dynamic(model, dtype=torch.qint8)
        # model = quantization.quantize(model)

    frame_g = frame_generator(cv2.VideoCapture(0))
    for image in tqdm(frame_g):
        example_input = (transforms(image)[0].unsqueeze(0).to(global_torch_device()),)
        try:
            traced_script_module = torch.jit.script(
                model,
                # example_input,
            )
            exp_path = model_export_path.with_suffix(".compiled")
            traced_script_module.save(str(exp_path))
            print(f"Traced Ops used {torch.jit.export_opnames(traced_script_module)}")
            sprint(
                f"Successfully exported JIT Traced model at {exp_path}", color="green"
            )
        except Exception as e_i:
            sprint(f"Torch JIT Trace export does not work!, {e_i}", color="red")
        break
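# A self-contained sketch of the torch.jit.script call above, on a toy module
# (scripting, unlike tracing, needs no example input); assumes only torch:


def _jit_script_sketch() -> None:
    import torch

    toy = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.ReLU())
    scripted = torch.jit.script(toy)  # compiles the module from its source
    scripted.save("toy_scripted.compiled")
    print(f"Ops used: {torch.jit.export_opnames(scripted)}")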
def export_detection_model(
    cfg: NOD,
    model_checkpoint: Path,
    model_export_path: Path = Path("torch_model"),
    verbose: bool = True,
    onnx_export: bool = False,
    strict_jit: bool = False,
) -> None:
    """
    :param cfg:
    :type cfg:
    :param model_checkpoint:
    :type model_checkpoint:
    :param model_export_path:
    :type model_export_path:
    :param verbose:
    :type verbose:
    :param onnx_export:
    :type onnx_export:
    :param strict_jit:
    :type strict_jit:
    :return:
    :rtype:
    """
    model = SingleShotDetection(cfg)

    checkpointer = CheckPointer(
        model, save_dir=ensure_existence(PROJECT_APP_PATH.user_data / "results")
    )
    checkpointer.load(model_checkpoint, use_latest=model_checkpoint is None)
    print(
        f"Loaded weights from {model_checkpoint if model_checkpoint else checkpointer.get_checkpoint_file()}"
    )

    model.post_init()
    model.to(global_torch_device())

    transforms = SSDTransform(
        cfg.input.image_size, cfg.input.pixel_mean, split=SplitEnum.testing
    )
    model.eval()  # Important!

    fuse_quantize_model = False
    if fuse_quantize_model:
        modules_to_fuse = [
            ["conv", "bn", "relu"]
        ]  # Names of modules to fuse, maybe supply directly for architecture class/declaration
        model = torch.quantization.fuse_modules(
            model, modules_to_fuse=modules_to_fuse, inplace=False
        )

    pre_quantize_model = False
    if pre_quantize_model:  # Accuracy may drop!
        model = quantization.quantize_dynamic(model, dtype=torch.qint8)
        # model = quantization.quantize(model)

    frame_g = frame_generator(cv2.VideoCapture(0))
    for image in tqdm(frame_g):
        example_input = (
            transforms(image)[0].unsqueeze(0).to(global_torch_device()),
        )
        try:
            if onnx_export:
                exp_path = model_export_path.with_suffix(".onnx")
                output = onnx.export(
                    model,
                    example_input,
                    str(exp_path),
                    verbose=verbose,
                    # export_params=True,  # store the trained parameter weights inside the model file
                    # opset_version=10,  # the onnx version to export the model to
                    # do_constant_folding=True,  # whether to execute constant folding for optimisation
                    # input_names=["input"],  # the model's input names
                    # output_names=["output"],  # the model's output names
                    # dynamic_axes={
                    #     "input": {0: "batch_size"},  # variable length axes
                    #     "output": {0: "batch_size"},
                    # },
                )
                sprint(f"Successfully exported ONNX model at {exp_path}", color="blue")
            else:
                raise Exception("Just trace instead, ignore exception")
        except Exception as e:
            sprint(f"Torch ONNX export does not work, {e}", color="red")

        try:
            traced_script_module = torch.jit.trace(
                model,
                example_input,
                # strict=strict_jit,
                check_inputs=(
                    transforms(next(frame_g))[0].unsqueeze(0).to(global_torch_device()),
                    transforms(next(frame_g))[0].unsqueeze(0).to(global_torch_device()),
                ),
            )
            exp_path = model_export_path.with_suffix(".traced")
            traced_script_module.save(str(exp_path))
            print(
                f"Traced Ops used {torch.jit.export_opnames(traced_script_module)}"
            )
            sprint(
                f"Successfully exported JIT Traced model at {exp_path}",
                color="green",
            )
        except Exception as e_i:
            sprint(f"Torch JIT Trace export does not work!, {e_i}", color="red")
        break
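# A self-contained sketch of the torch.onnx.export call above on a toy module,
# with the commonly used input/output naming and dynamic batch axis that the
# commented-out arguments hint at; assumes only torch (plus an ONNX runtime of
# your choice for loading the result):


def _onnx_export_sketch() -> None:
    import torch

    toy = torch.nn.Sequential(torch.nn.Linear(4, 2))
    dummy_input = torch.randn(1, 4)
    torch.onnx.export(
        toy,
        (dummy_input,),
        "toy.onnx",
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )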
    env.close()


if __name__ == "__main__":
    from neodroidagent.configs import (
        parse_arguments,
        get_upper_case_vars_or_protected_of,
    )

    config = parse_arguments("Regular small grid world experiment", C)

    for key, arg in config.__dict__.items():
        setattr(C, key, arg)

    draugr.sprint(f"\nUsing config: {C}\n", highlight=True, color="yellow")
    if not config.skip_confirmation:
        for key, arg in get_upper_case_vars_or_protected_of(C).items():
            print(f"{key} = {arg}")
        input("\nPress Enter to begin... ")

    _agent = C.AGENT_TYPE(C)

    try:
        train_agent(C, _agent)
    except KeyboardInterrupt:
        print("Stopping")

    torch.cuda.empty_cache()
def export_detection_model(
    model_export_path: Path = ensure_existence(
        PROJECT_APP_PATH.user_data / "penn_fudan_segmentation"
    )
    / "seg_skip_fis",
    SEED: int = 87539842,
) -> None:
    """
    :param model_export_path:
    :type model_export_path:
    :param SEED:
    :type SEED:
    :return:
    :rtype:
    """
    model = OutputActivationModule(
        SkipHourglassFission(input_channels=3, output_heads=(1,), encoding_depth=1)
    )

    with TorchDeviceSession(device=global_torch_device("cpu"), model=model):
        with TorchEvalSession(model):
            seed_stack(SEED)

            # standard PyTorch mean-std input image normalisation
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(
                        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
                    ),
                ]
            )

            frame_g = frame_generator(cv2.VideoCapture(0))
            for image in tqdm(frame_g):
                example_input = (
                    transform(image).unsqueeze(0).to(global_torch_device()),
                )
                try:
                    traced_script_module = torch.jit.trace(
                        model,
                        example_input,
                        # strict=strict_jit,
                        check_inputs=(
                            transform(next(frame_g))
                            .unsqueeze(0)
                            .to(global_torch_device()),
                            transform(next(frame_g))
                            .unsqueeze(0)
                            .to(global_torch_device()),
                        ),
                    )
                    exp_path = model_export_path.with_suffix(".traced")
                    traced_script_module.save(str(exp_path))
                    print(
                        f"Traced Ops used {torch.jit.export_opnames(traced_script_module)}"
                    )
                    sprint(
                        f"Successfully exported JIT Traced model at {exp_path}",
                        color="green",
                    )
                except Exception as e_i:
                    sprint(
                        f"Torch JIT Trace export does not work!, {e_i}", color="red"
                    )
                break
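# A self-contained sketch of torch.jit.trace with check_inputs, as used above:
# the extra input tuples re-run the traced graph and warn if the trace fails to
# generalise. Assumes only torch; the toy conv model is size-agnostic, so the
# differently sized check input should pass.


def _jit_trace_sketch() -> None:
    import torch

    toy = torch.nn.Sequential(torch.nn.Conv2d(3, 1, 3), torch.nn.Sigmoid())
    example_input = (torch.randn(1, 3, 32, 32),)
    traced = torch.jit.trace(
        toy,
        example_input,
        check_inputs=[(torch.randn(1, 3, 64, 64),)],  # verify the trace generalises
    )
    traced.save("toy_traced.traced")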
def __call__(
    self,
    agent: Type[Agent],
    *,
    load_time: Any = str(int(time.time())),
    seed: int = 0,
    save_ending_model: bool = False,
    save_training_resume: bool = False,
    continue_training: bool = True,
    train_agent: bool = True,
    debug: bool = False,
    num_envs: int = cpu_count(),
    **kwargs,
):
    """
    Starts a session: builds the Agent, starts/connects the environment(s), and runs the Procedure

    :param agent:
    :param kwargs:
    :return:
    """
    kwargs.update(num_envs=num_envs)
    kwargs.update(train_agent=train_agent)
    kwargs.update(debug=debug)
    kwargs.update(environment=self._environment)

    with ContextWrapper(torchsnooper.snoop, debug):
        with ContextWrapper(torch.autograd.detect_anomaly, debug):
            if agent is None:
                raise NoAgent

            if inspect.isclass(agent):
                sprint("Instantiating Agent", color="crimson", bold=True, italic=True)
                torch_seed(seed)
                self._environment.seed(seed)
                agent = agent(load_time=load_time, seed=seed, **kwargs)

            agent_class_name = agent.__class__.__name__
            total_shape = "_".join(
                [
                    str(i)
                    for i in (
                        self._environment.observation_space.shape
                        + self._environment.action_space.shape
                        + self._environment.signal_space.shape
                    )
                ]
            )

            environment_name = f"{self._environment.environment_name}_{total_shape}"
            save_directory = (
                PROJECT_APP_PATH.user_data / environment_name / agent_class_name
            )
            log_directory = (
                PROJECT_APP_PATH.user_log
                / environment_name
                / agent_class_name
                / load_time
            )

            if self._environment.action_space.is_discrete:
                rollout_drawer = GDKC(
                    DiscreteScrollPlot,
                    num_actions=self._environment.action_space.discrete_steps,
                    default_delta=None,
                )
            else:
                rollout_drawer = GDKC(
                    SeriesScrollPlot, window_length=100, default_delta=None
                )

            if train_agent:  # TODO: allow metric writing while not training with flag
                metric_writer = GDKC(TensorBoardPytorchWriter, path=log_directory)
            else:
                metric_writer = GDKC(MockWriter)

            with ContextWrapper(metric_writer, train_agent) as metric_writer:
                with ContextWrapper(rollout_drawer, num_envs == 1) as rollout_drawer:
                    agent.build(
                        self._environment.observation_space,
                        self._environment.action_space,
                        self._environment.signal_space,
                        metric_writer=metric_writer,
                    )

                    kwargs.update(
                        environment_name=(self._environment.environment_name,),
                        save_directory=save_directory,
                        log_directory=log_directory,
                        load_time=load_time,
                        seed=seed,
                        train_agent=train_agent,
                    )

                    found = False
                    if continue_training:
                        sprint(
                            "Searching for previously trained models for initialisation for this configuration "
                            "(Architecture, Action Space, Observation Space, ...)",
                            color="crimson",
                            bold=True,
                            italic=True,
                        )
                        found = agent.load(
                            save_directory=save_directory, evaluation=not train_agent
                        )
                        if not found:
                            sprint(
                                "Did not find any previously trained models for this configuration",
                                color="crimson",
                                bold=True,
                                italic=True,
                            )

                    if not train_agent:
                        agent.eval()
                    else:
                        agent.train()

                    if not found:
                        sprint(
                            "Training from new initialisation",
                            color="crimson",
                            bold=True,
                            italic=True,
                        )

                    session_proc = self._procedure(agent, **kwargs)

                    with CaptureEarlyStop(
                        callbacks=self._procedure.stop_procedure, **kwargs
                    ):
                        with StopWatch() as timer:
                            with suppress(KeyboardInterrupt):
                                training_resume = session_proc(
                                    metric_writer=metric_writer,
                                    rollout_drawer=rollout_drawer,
                                    **kwargs,
                                )
                                if (
                                    training_resume
                                    and "stats" in training_resume
                                    and save_training_resume
                                ):
                                    training_resume.stats.save(**kwargs)

                    end_message = f"Training ended, time elapsed: {timer // 60:.0f}m {timer % 60:.0f}s"
                    line_width = 9
                    sprint(
                        f'\n{"-" * line_width} {end_message} {"-" * line_width}\n',
                        color="crimson",
                        bold=True,
                        italic=True,
                    )

                    if save_ending_model:
                        agent.save(**kwargs)

                    try:
                        self._environment.close()
                    except BrokenPipeError:
                        pass

                    exit(0)
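# The ContextWrapper pattern above (enter a context manager only when a flag is
# set) can be approximated with the standard library; a minimal sketch, assuming
# ContextWrapper behaves like draugr's conditional-context utility:


def _maybe_context(context_factory, enabled: bool):
    import contextlib

    # Enter the real context when enabled, otherwise a no-op context
    return context_factory() if enabled else contextlib.nullcontext()


# e.g. `with _maybe_context(torch.autograd.detect_anomaly, debug): ...`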