Example #1
def session_factory(
    agent: Type[AgentType] = None,
    config=None,
    *,
    session: Union[Type[EnvironmentSessionType], EnvironmentSession],
    save: bool = True,
    has_x_server: bool = True,
    skip_confirmation: bool = True,
    **kwargs,
):
    r"""
Entry point start a starting a training session with the functionality of parsing cmdline arguments and
confirming configuration to use before training and overwriting of default training configurations
"""

    if config is None:
        config = {}

    if isinstance(config, dict):
        config = NOD(**config)
    else:
        config = NOD(config.__dict__)

    if has_x_server:
        display_env = getenv("DISPLAY", None)
        if display_env is None:
            config.RENDER_ENVIRONMENT = False
            has_x_server = False

    config_mapping = config_to_mapping(config)
    config_mapping.update(**kwargs)

    config_mapping.update(save=save, has_x_server=has_x_server)

    if not skip_confirmation:
        sprint(f"\nUsing config: {config}\n", highlight=True, color="yellow")
        for key, arg in config_mapping.items():
            print(f"{key} = {arg}")

        input("\nPress Enter to begin... ")

    if session is None:
        raise NoProcedure
    elif inspect.isclass(session):
        session = session(**config_mapping)  # Use passed config arguments
    elif isinstance(session, GDKC):
        session = session(
            **kwargs
        )  # Assume some keyword parameters were set prior to passing the session; only apply explicit overrides

    try:
        session(agent, **config_mapping)
    except KeyboardInterrupt:
        print("Stopping")

    torch.cuda.empty_cache()

    exit(0)
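A minimal usage sketch for the factory above; MyAgent and MySession are hypothetical stand-ins for an Agent class and an EnvironmentSession subclass, not names from the snippet:

if __name__ == "__main__":
    # MyAgent / MySession are illustrative placeholders for project classes.
    session_factory(
        agent=MyAgent,
        config={"RENDER_ENVIRONMENT": True},
        session=MySession,
        save=False,
        skip_confirmation=True,
    )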
Example #2
    def build(
        self,
        observation_space: ObservationSpace,
        action_space: ActionSpace,
        signal_space: SignalSpace,
        *,
        metric_writer: Writer = MockWriter(),
        print_model_repr: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> None:
        """

@param observation_space:
@param action_space:
@param signal_space:
@param metric_writer:
@param print_model_repr:
@param kwargs:
@return:
        :param verbose:
"""
        super().build(
            observation_space,
            action_space,
            signal_space,
            print_model_repr=print_model_repr,
            metric_writer=metric_writer,
            **kwargs,
        )

        if print_model_repr:
            for k, w in self.models.items():
                sprint(f"{k}: {w}", highlight=True, color="cyan")

                if metric_writer:
                    try:
                        model = copy.deepcopy(w).to("cpu")
                        dummy_input = model.sample_input()
                        sprint(f"{k} input: {dummy_input.shape}")

                        import contextlib
                        import io

                        with contextlib.redirect_stdout(
                            io.StringIO()
                        ):  # So much useless frame info is printed... suppress it
                            if isinstance(metric_writer, GraphWriterMixin):
                                metric_writer.graph(model, dummy_input, verbose=verbose)  # No naming available at the moment...
                    except RuntimeError as ex:
                        sprint(
                            f"Tensorboard(Pytorch) does not support you model! No graph added: {str(ex).splitlines()[0]}",
                            color="red",
                            highlight=True,
                        )
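The graph-logging step above can be reproduced with a plain TensorBoard writer; a minimal sketch, assuming torch and the tensorboard package are installed (the toy model and input are illustrative):

import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))
dummy_input = torch.randn(1, 4)  # plays the role of model.sample_input() above

writer = SummaryWriter(log_dir="runs/graph_demo")
writer.add_graph(model, dummy_input)  # traces the model and records its graph
writer.close()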
Example #3
    def __infer_io_shapes(
        self,
        observation_space: ObservationSpace,
        action_space: ActionSpace,
        signal_space: SignalSpace,
        print_inferred_io_shapes: bool = True,
    ) -> None:
        """
Tries to infer input and output size from env if either _input_shape or _output_shape, is None or -1 (int)

:rtype: object
"""

        if self._input_shape is None or self._input_shape == -1:
            self._input_shape = observation_space.shape

        if self._output_shape is None or self._output_shape == -1:
            self._output_shape = action_space.shape

        # region print

        if print_inferred_io_shapes:
            sprint(
                f"input shape: {self._input_shape}\n"
                f"observation space: {observation_space}\n",
                color="green",
                bold=True,
                highlight=True,
            )

            sprint(
                f"output shape: {self._output_shape}\n"
                f"action space: {action_space}\n",
                color="yellow",
                bold=True,
                highlight=True,
            )

            sprint(
                f"signal shape: {signal_space}\n",
                color="blue",
                bold=True,
                highlight=True,
            )
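The fallback logic above reduces to a few lines; a standalone sketch with plain tuples standing in for the space objects:

def infer_io_shapes(input_shape, output_shape, observation_shape, action_shape):
    # Fall back to the environment's shapes when a shape is unset (None or -1).
    if input_shape is None or input_shape == -1:
        input_shape = observation_shape
    if output_shape is None or output_shape == -1:
        output_shape = action_shape
    return input_shape, output_shape

print(infer_io_shapes(None, -1, (84, 84, 3), (4,)))  # -> ((84, 84, 3), (4,))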
Example #4
def export_detection_model(
    cfg: NOD,
    model_ckpt: Path,
    model_export_path: Path = Path("torch_model"),
    verbose: bool = True,
    onnx_export: bool = False,
    strict_jit: bool = False,
    ) -> None:
  """

:param verbose:
:type verbose:
:param cfg:
:type cfg:
:param model_ckpt:
:type model_ckpt:
:param model_export_path:
:type model_export_path:
:return:
:rtype:
"""
  model = SingleShotDectectionNms(cfg)

  checkpointer = CheckPointer(
      model, save_dir=ensure_existence(PROJECT_APP_PATH.user_data / "results")
      )
  checkpointer.load(model_ckpt, use_latest=model_ckpt is None)
  print(
      f"Loaded weights from {model_ckpt if model_ckpt else checkpointer.get_checkpoint_file()}"
      )

  model.post_init()
  model.to(global_torch_device())

  transforms = SSDTransform(
      cfg.input.image_size, cfg.input.pixel_mean, split=Split.Testing
      )
  model.eval()

  pre_quantize_model = False
  if pre_quantize_model:  # Accuracy may drop!
    model = quantization.quantize_dynamic(model, dtype=torch.qint8)
    # Alternatively: model = quantization.quantize(model)

  frame_g = frame_generator(cv2.VideoCapture(0))
  for image in tqdm(frame_g):
    example_input = (transforms(image)[0].unsqueeze(0).to(global_torch_device()),)
    try:
      traced_script_module = torch.jit.script(
          model,
          # example_input,
          )
      exp_path = model_export_path.with_suffix(".compiled")
      traced_script_module.save(str(exp_path))
      print(f"Traced Ops used {torch.jit.export_opnames(traced_script_module)}")
      sprint(
          f"Successfully exported JIT Traced model at {exp_path}", color="green"
          )
    except Exception as e_i:
      sprint(f"Torch JIT Trace export does not work!, {e_i}", color="red")

    break
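The scripting path in isolation; a self-contained sketch with a toy module (TinyNet and the file name are illustrative):

import torch
from torch import nn

class TinyNet(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) * 2

scripted = torch.jit.script(TinyNet())  # compiled from source; no example input required
scripted.save("tiny_net.compiled")
print(f"Ops used {torch.jit.export_opnames(scripted)}")
restored = torch.jit.load("tiny_net.compiled")  # round-trip check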
Example #5
def export_detection_model(
    cfg: NOD,
    model_checkpoint: Path,
    model_export_path: Path = Path("torch_model"),
    verbose: bool = True,
    onnx_export: bool = False,
    strict_jit: bool = False,
) -> None:
    """

    :param verbose:
    :type verbose:
    :param cfg:
    :type cfg:
    :param model_checkpoint:
    :type model_checkpoint:
    :param model_export_path:
    :type model_export_path:
    :return:
    :rtype:"""
    model = SingleShotDetection(cfg)

    checkpointer = CheckPointer(model,
                                save_dir=ensure_existence(
                                    PROJECT_APP_PATH.user_data / "results"))
    checkpointer.load(model_checkpoint, use_latest=model_checkpoint is None)
    print(
        f"Loaded weights from {model_checkpoint if model_checkpoint else checkpointer.get_checkpoint_file()}"
    )

    model.post_init()
    model.to(global_torch_device())

    transforms = SSDTransform(cfg.input.image_size,
                              cfg.input.pixel_mean,
                              split=SplitEnum.testing)
    model.eval()  # Important!

    fuse_quantize_model = False
    if fuse_quantize_model:
        modules_to_fuse = [
            ["conv", "bn", "relu"]
        ]  # Names of modules to fuse, maybe supply directly for architecture class/declaration
        model = torch.quantization.fuse_modules(
            model, modules_to_fuse=modules_to_fuse, inplace=False)

    pre_quantize_model = False
    if pre_quantize_model:  # Accuracy may drop!
        model = quantization.quantize_dynamic(model, dtype=torch.qint8)
        # Alternatively: model = quantization.quantize(model)

    frame_g = frame_generator(cv2.VideoCapture(0))
    for image in tqdm(frame_g):
        example_input = (transforms(image)[0].unsqueeze(0).to(
            global_torch_device()), )
        try:
            if onnx_export:
                exp_path = model_export_path.with_suffix(".onnx")
                output = onnx.export(
                    model,
                    example_input,
                    str(exp_path),
                    verbose=verbose,
                    # export_params=True,  # store the trained parameter weights inside the model file
                    # opset_version=10,  # the onnx version to export the model to
                    # do_constant_folding=True,  # whether to execute constant folding for optimization
                    # input_names=["input"],  # the model's input names
                    # output_names=["output"],  # the model's output names
                    # dynamic_axes={
                    #  "input": {0: "batch_size"},  # variable lenght axes
                    #  "output": {0: "batch_size"},
                    #  }
                )
                sprint(f"Successfully exported ONNX model at {exp_path}",
                       color="blue")
            else:
                raise Exception("Just trace instead, ignore exception")
        except Exception as e:
            sprint(f"Torch ONNX export does not work, {e}", color="red")
            try:
                traced_script_module = torch.jit.trace(
                    model,
                    example_input,
                    # strict=strict_jit,
                    check_inputs=(
                        transforms(next(frame_g))[0].unsqueeze(0).to(
                            global_torch_device()),
                        transforms(next(frame_g))[0].unsqueeze(0).to(
                            global_torch_device()),
                    ),
                )
                exp_path = model_export_path.with_suffix(".traced")
                traced_script_module.save(str(exp_path))
                print(
                    f"Traced Ops used {torch.jit.export_opnames(traced_script_module)}"
                )
                sprint(
                    f"Successfully exported JIT Traced model at {exp_path}",
                    color="green",
                )
            except Exception as e_i:
                sprint(f"Torch JIT Trace export does not work!, {e_i}",
                       color="red")

        break
    """
Example #6
    env.close()


if __name__ == "__main__":

    from neodroidagent.configs import (
        parse_arguments,
        get_upper_case_vars_or_protected_of,
    )

    config = parse_arguments("Regular small grid world experiment", C)

    for key, arg in config.__dict__.items():
        setattr(C, key, arg)

    draugr.sprint(f"\nUsing config: {C}\n", highlight=True, color="yellow")
    if not config.skip_confirmation:
        for key, arg in get_upper_case_vars_or_protected_of(C).items():
            print(f"{key} = {arg}")
        input("\nPress Enter to begin... ")

    _agent = C.AGENT_TYPE(C)

    try:
        train_agent(C, _agent)
    except KeyboardInterrupt:
        print("Stopping")

    torch.cuda.empty_cache()
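A sketch of the config-override pattern used above, with argparse standing in for the project's parse_arguments helper and a hypothetical config class C:

import argparse

class C:  # stand-in for the project's configuration class
    SEED = 0
    ENVIRONMENT_NAME = "grid_world"

parser = argparse.ArgumentParser(description="Regular small grid world experiment")
parser.add_argument("--SEED", type=int, default=C.SEED)
parser.add_argument("--skip_confirmation", action="store_true")
config = parser.parse_args()

for key, arg in config.__dict__.items():  # copy parsed values onto the config
    setattr(C, key, arg)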
Example #7
def export_detection_model(
    model_export_path: Path = ensure_existence(
        PROJECT_APP_PATH.user_data / "penn_fudan_segmentation"
    )
    / "seg_skip_fis",
    SEED: int = 87539842,
) -> None:
    """

    :param model_export_path:
    :type model_export_path:
    :return:
    :rtype:"""

    model = OutputActivationModule(
        SkipHourglassFission(input_channels=3, output_heads=(1,), encoding_depth=1)
    )

    with TorchDeviceSession(device=global_torch_device("cpu"), model=model):
        with TorchEvalSession(model):

            seed_stack(SEED)

            # standard PyTorch mean-std input image normalization
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                ]
            )

            frame_g = frame_generator(cv2.VideoCapture(0))

            for image in tqdm(frame_g):
                example_input = (
                    transform(image).unsqueeze(0).to(global_torch_device()),
                )

                try:
                    traced_script_module = torch.jit.trace(
                        model,
                        example_input,
                        # strict=strict_jit,
                        check_inputs=(
                            transform(next(frame_g))
                            .unsqueeze(0)
                            .to(global_torch_device()),
                            transform(next(frame_g))
                            .unsqueeze(0)
                            .to(global_torch_device()),
                        ),
                    )
                    exp_path = model_export_path.with_suffix(".traced")
                    traced_script_module.save(str(exp_path))
                    print(
                        f"Traced Ops used {torch.jit.export_opnames(traced_script_module)}"
                    )
                    sprint(
                        f"Successfully exported JIT Traced model at {exp_path}",
                        color="green",
                    )
                except Exception as e_i:
                    sprint(f"Torch JIT Trace export does not work!, {e_i}", color="red")

                break
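The trace-with-verification pattern in isolation; a sketch on a toy MLP, where check_inputs re-runs the trace on extra inputs to catch input-dependent control flow (model and file name are illustrative):

import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 4), nn.Tanh()).eval()
example_input = torch.randn(1, 8)

traced = torch.jit.trace(
    model,
    example_input,
    check_inputs=[(torch.randn(1, 8),), (torch.randn(2, 8),)],  # extra verification runs
)
traced.save("tiny_mlp.traced")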
Example #8
    def __call__(
            self,
            agent: Type[Agent],
            *,
            load_time: Any = str(int(time.time())),
            seed: int = 0,
            save_ending_model: bool = False,
            save_training_resume: bool = False,
            continue_training: bool = True,
            train_agent: bool = True,
            debug: bool = False,
            num_envs: int = cpu_count(),
            **kwargs,
    ):
        """
Start a session, builds Agent and starts/connect environment(s), and runs Procedure


:param args:
:param kwargs:
:return:
"""
        kwargs.update(
            num_envs=num_envs,
            train_agent=train_agent,
            debug=debug,
            environment=self._environment,
        )

        with ContextWrapper(torchsnooper.snoop, debug):
            with ContextWrapper(torch.autograd.detect_anomaly, debug):

                if agent is None:
                    raise NoAgent

                if inspect.isclass(agent):
                    sprint("Instantiating Agent",
                           color="crimson",
                           bold=True,
                           italic=True)
                    torch_seed(seed)
                    self._environment.seed(seed)

                    agent = agent(load_time=load_time, seed=seed, **kwargs)

                agent_class_name = agent.__class__.__name__

                total_shape = "_".join([
                    str(i)
                    for i in (self._environment.observation_space.shape +
                              self._environment.action_space.shape +
                              self._environment.signal_space.shape)
                ])

                environment_name = f"{self._environment.environment_name}_{total_shape}"

                save_directory = (PROJECT_APP_PATH.user_data /
                                  environment_name / agent_class_name)
                log_directory = (PROJECT_APP_PATH.user_log / environment_name /
                                 agent_class_name / load_time)

                if self._environment.action_space.is_discrete:
                    rollout_drawer = GDKC(DiscreteScrollPlot,
                                          num_actions=self._environment.
                                          action_space.discrete_steps,
                                          default_delta=None)
                else:
                    rollout_drawer = GDKC(SeriesScrollPlot,
                                          window_length=100,
                                          default_delta=None)

                if train_agent:  # TODO: allow metric writing while not training with flag
                    metric_writer = GDKC(TensorBoardPytorchWriter,
                                         path=log_directory)
                else:
                    metric_writer = GDKC(MockWriter)

                with ContextWrapper(metric_writer,
                                    train_agent) as metric_writer:
                    with ContextWrapper(rollout_drawer,
                                        num_envs == 1) as rollout_drawer:

                        agent.build(
                            self._environment.observation_space,
                            self._environment.action_space,
                            self._environment.signal_space,
                            metric_writer=metric_writer,
                        )

                        kwargs.update(
                            environment_name=(
                                self._environment.environment_name, ),
                            save_directory=save_directory,
                            log_directory=log_directory,
                            load_time=load_time,
                            seed=seed,
                            train_agent=train_agent,
                        )

                        found = False
                        if continue_training:
                            sprint(
                                "Searching for previously trained models for initialisation for this configuration "
                                "(Architecture, Action Space, Observation Space, ...)",
                                color="crimson",
                                bold=True,
                                italic=True,
                            )
                            found = agent.load(save_directory=save_directory,
                                               evaluation=not train_agent)
                            if not found:
                                sprint(
                                    "Did not find any previously trained models for this configuration",
                                    color="crimson",
                                    bold=True,
                                    italic=True,
                                )

                        if not train_agent:
                            agent.eval()
                        else:
                            agent.train()

                        if not found:
                            sprint(
                                "Training from new initialisation",
                                color="crimson",
                                bold=True,
                                italic=True,
                            )

                        session_proc = self._procedure(agent, **kwargs)

                        with CaptureEarlyStop(
                                callbacks=self._procedure.stop_procedure,
                                **kwargs):
                            with StopWatch() as timer:
                                with suppress(KeyboardInterrupt):
                                    training_resume = session_proc(
                                        metric_writer=metric_writer,
                                        rollout_drawer=rollout_drawer,
                                        **kwargs)
                                    if training_resume and "stats" in training_resume and save_training_resume:
                                        training_resume.stats.save(**kwargs)

                        end_message = f"Training ended, time elapsed: {timer // 60:.0f}m {timer % 60:.0f}s"
                        line_width = 9
                        sprint(
                            f'\n{"-" * line_width} {end_message} {"-" * line_width}\n',
                            color="crimson",
                            bold=True,
                            italic=True,
                        )

                        if save_ending_model:
                            agent.save(**kwargs)

                        try:
                            self._environment.close()
                        except BrokenPipeError:
                            pass

                        exit(0)
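The ContextWrapper(ctx, flag) idiom used throughout this session can be approximated with contextlib.nullcontext; a minimal sketch (the maybe helper is hypothetical):

import contextlib

import torch

def maybe(context_factory, enabled: bool):
    # Enter the real context manager only when the flag is set.
    return context_factory() if enabled else contextlib.nullcontext()

debug = True
with maybe(torch.autograd.detect_anomaly, debug):
    y = (torch.randn(3, requires_grad=True) * 2).sum()
    y.backward()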