示例#1
0
    def save(self, states: Tensor, masks: Tensor):
        """Trace ``self.network`` and save it as TorchScript in a fresh directory.

        The directory is ``<save_path>/<unix seconds>``. If that directory
        already exists (e.g. two saves within the same second), a numeric
        suffix ``_1``, ``_2``, ... is appended so that an existing save is
        never overwritten. Does nothing when ``self.save_path`` is None.

        Args:
            states: First example input used to trace ``self.network``.
            masks: Second example input used to trace ``self.network``.
        """
        if self.save_path is None:
            return

        base_dir = os.path.join(self.save_path, str(int(time())))
        save_dir = base_dir
        suffix = 0
        while True:
            try:
                # exist_ok=False keeps the "never overwrite" guarantee; on a
                # collision we pick the next free suffix instead of crashing.
                os.makedirs(save_dir, exist_ok=False)
                break
            except FileExistsError:
                suffix += 1
                save_dir = "{}_{}".format(base_dir, suffix)

        model = jit.trace(self.network, (states, masks))
        jit.save(model, os.path.join(save_dir, "network.pt"))
示例#2
0
def save_traced_model(
    model: jit.ScriptModule,
    logdir: Union[str, Path] = None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    out_dir: Union[str, Path] = None,
    out_model: Union[str, Path] = None,
    checkpoint_name: str = None,
) -> None:
    """Saves traced model.

    The destination is resolved in priority order:
    ``out_model`` (exact file path) > ``out_dir`` > ``<logdir>/trace``.
    For the directory cases the file name comes from ``get_trace_name`` and
    the directory is created if missing.

    Args:
        model (ScriptModule): Traced model
        logdir (Union[str, Path]): Path to experiment
        method_name (str): Name of the method was traced
        mode (str): Model's mode - `train` or `eval`
        requires_grad (bool): Whether model was traced with require_grad or not
        opt_level (str): Apex FP16 init level used during tracing
        out_dir (Union[str, Path]): Directory to save model to
            (overrides logdir)
        out_model (Union[str, Path]): Path to save model to
            (overrides logdir & out_dir)
        checkpoint_name (str): Checkpoint name used to restore the model

    Raises:
        ValueError: if nothing out of `logdir`, `out_dir` or `out_model`
          is specified.
    """
    if out_model is not None:
        jit.save(model, str(out_model))
        return

    # Validate before building the file name so a missing destination is
    # reported as the documented ValueError.
    if out_dir is not None:
        # Wrap in Path: out_dir may be a plain str, which has no .mkdir / "/".
        output = Path(out_dir)
    elif logdir is not None:
        output = Path(logdir) / "trace"
    else:
        raise ValueError(
            "One of `logdir`, `out_dir` or `out_model` "
            "should be specified"
        )

    file_name = get_trace_name(
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        additional_string=checkpoint_name,
    )

    output.mkdir(exist_ok=True, parents=True)
    jit.save(model, str(output / file_name))
示例#3
0
 def check_point(self, is_backup: bool) -> None:
     """Write ``self.model`` and ``self.cp`` to disk, then move the model to cuda:0.

     Three files are written, each named ``<template.format(kind)><self.name><ext>``:
     the pickled module (``torch.save``, kind ``"reg"``), a TorchScript
     version (``jit.script``, kind ``"jit"``), and ``self.cp`` as JSON
     (kind ``"cp"``). The template is ``STONK_REPO``, or
     ``backup(STONK_REPO)`` when ``is_backup`` is true (presumably a backup
     location -- confirm against ``backup``).

     Args:
         is_backup: Select the backup path template instead of the default.
     """
     # Serialize from CPU so the saved tensors are not tied to a GPU device.
     self.model = self.model.to(torch.device("cpu"))
     try:
         repo = backup(STONK_REPO) if is_backup else STONK_REPO
         torch.save(self.model, repo.format("reg") + self.name + ".pt")
         jit.save(
             cast(ScriptModule, jit.script(self.model)),
             repo.format("jit") + self.name + ".pt",
         )
         with open(repo.format("cp") + self.name + ".json", "w",
                   encoding="utf-8") as f:
             json.dump(self.cp, f, indent=4)
     finally:
         # Always move the model back, even if a save step failed, so a
         # failed checkpoint does not silently leave training on the CPU.
         self.model = self.model.to(torch.device("cuda:0"))
示例#4
0
def main(model_path='model.pt', name='densenet201', out_name='model.trcd',
         example_shape=(4, 3, 256, 256)):
    """Rebuild a baseline network from a saved wrapper model and export it via tracing.

    The file at ``model_path`` must contain a pickled module with exactly one
    child. That child's ``state_dict`` is loaded into ``get_baseline(name=name)``,
    which is switched to eval mode, traced with a random input of
    ``example_shape``, and saved to ``out_name`` with ``jit.save``.

    Args:
        model_path: Path of the pickled wrapper module.
        name: Architecture name passed to ``get_baseline``.
        out_name: Destination path for the traced TorchScript model.
        example_shape: Shape of the random example input used for tracing.

    Raises:
        ValueError: If the loaded module does not have exactly one child.
    """
    # NOTE: torch.load unpickles arbitrary objects -- only load trusted files.
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
    # load_state_dict copies the weights into ``base`` regardless of device.
    wrapper = torch.load(model_path, map_location="cpu")
    children = list(wrapper.children())
    if len(children) != 1:
        raise ValueError(
            "expected the saved model to have exactly one child module, "
            "got {}".format(len(children))
        )
    state = children[0].state_dict()

    base = get_baseline(name=name)
    base.load_state_dict(state)
    base.eval()

    traced = jit.trace(base, example_inputs=(torch.rand(*example_shape), ))
    jit.save(traced, out_name)
示例#5
0
    model = AlphaCompile(train[0].observation_space, len(FLAGS) + 1)
    logging.info(model)

    # --- Training: play every training program each epoch. ---
    model.train()
    interrupted = False
    for epoch in range(1, EPOCHS + 1):
        shuffle(train)
        for program in train:
            try:
                model.play(program)
            except Exception as err:
                # Log and skip programs the model fails on; keep training.
                logging.exception("%r: %s", program, err)
            except KeyboardInterrupt:
                # Persist progress, then stop training entirely. A bare
                # ``break`` only left the inner loop, so training used to
                # resume at the next epoch.
                save(model.state_dict(), join('source', 'AlphaCompile.pth'))
                interrupted = True
                break
            else:
                # Checkpoint after every successful program.
                logging.info("Saving the model")
                save(model.state_dict(), join('source', 'AlphaCompile.pth'))
        if interrupted:
            break

    # --- Evaluation and TorchScript export. ---
    model.eval()
    for program in test:
        try:
            model.play(program)
        except Exception as err:
            logging.exception("%r: %s", program, err)
            continue
        # NOTE(review): this re-traces and overwrites AlphaCompile.pt after
        # every successful test program, so the file ends up traced with the
        # last successful program's observation_space.
        jit.save(jit.trace(model, (randn(program.observation_space), )),
                 join('source', 'AlphaCompile.pt'))