Ejemplo n.º 1
0
 def checkpoint(self):
     """Performs a checkpoint for all encoders and decoders.

     Every attribute named in ``self.network_names`` is written with
     ``netwrite`` to
     ``{checkpoint_path}-{name}-epoch-{epoch_id}-step-{step_id}.torch``.
     A network wrapped in ``torch.nn.DataParallel`` is unwrapped first, so
     the underlying ``.module`` is written rather than the wrapper.
     After all networks are written, ``self.each_checkpoint()`` is called.
     """
     # Local import: this method may live in a module without a top-level
     # torch import.
     import torch

     for name in self.network_names:
         the_net = getattr(self, name)
         # Save the wrapped module, not the DataParallel container.
         if isinstance(the_net, torch.nn.DataParallel):
             the_net = the_net.module
         netwrite(
             the_net,
             f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
         )
     self.each_checkpoint()
Ejemplo n.º 2
0
 def checkpoint(self):
     """Write a checkpoint of ``self.net`` and run the per-checkpoint hook.

     If ``self.net`` is wrapped in ``torch.nn.DataParallel``, the wrapped
     ``.module`` is written instead of the wrapper. The file is named
     ``{checkpoint_path}-epoch-{epoch_id}-step-{step_id}.torch``.
     Afterwards ``self.each_checkpoint()`` is called.
     """
     the_net = self.net
     if isinstance(the_net, torch.nn.DataParallel):
         the_net = the_net.module
     # Pass the unwrapped network; passing ``self.net`` here would make the
     # DataParallel check above a no-op.
     netwrite(
         the_net,
         f"{self.checkpoint_path}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
     )
     self.each_checkpoint()
Ejemplo n.º 3
0
 def run_checkpoint(self):
   """Write every network in ``self.checkpoint_names`` and run the hook.

   ``self.checkpoint_names`` maps a label to a network. A network wrapped in
   ``torch.nn.DataParallel`` is replaced by its ``.module`` before writing.
   Each file is named
   ``{full_path}-{label}-epoch-{epoch_id}-step-{step_id}.torch``.
   After all writes, ``self.each_checkpoint()`` is called.
   """
   for label, net in self.checkpoint_names.items():
     target = net.module if isinstance(net, torch.nn.DataParallel) else net
     netwrite(
       target,
       f"{self.full_path}-{label}-epoch-{self.epoch_id}-step-{self.step_id}.torch",
     )
   self.each_checkpoint()
Ejemplo n.º 4
0
 def checkpoint(self):
     """Write a checkpoint of the encoder ``self.net`` and run the hook.

     The file is named
     ``{network_name}-encoder-epoch-{epoch_id}-step-{step_id}.torch``.
     If ``self.net`` is wrapped in ``torch.nn.DataParallel``, its ``.module``
     is written instead of the wrapper. Only the encoder is written here;
     no decoder checkpoint is produced by this method.
     Afterwards ``self.each_checkpoint()`` is called.
     """
     # Local import: this method may live in a module without a top-level
     # torch import.
     import torch

     the_net = self.net
     # Save the wrapped module, not the DataParallel container.
     if isinstance(the_net, torch.nn.DataParallel):
         the_net = the_net.module
     netwrite(
         the_net,
         f"{self.network_name}-encoder-epoch-{self.epoch_id}-step-{self.step_id}.torch"
     )
     self.each_checkpoint()
Ejemplo n.º 5
0
 def checkpoint(self):
     """Save each network named in ``self.network_names``, then run the hook.

     Networks are fetched as attributes of ``self``. A network wrapped in
     ``torch.nn.DataParallel`` is replaced by its ``.module`` before being
     written to
     ``{checkpoint_path}-{name}-epoch-{epoch_id}-step-{step_id}.torch``.
     """
     for net_name in self.network_names:
         net = getattr(self, net_name)
         net = net.module if isinstance(net, torch.nn.DataParallel) else net
         netwrite(
             net,
             f"{self.checkpoint_path}-{net_name}-epoch-{self.epoch_id}-step-{self.step_id}.torch",
         )
     self.each_checkpoint()
Ejemplo n.º 6
0
 def checkpoint(self):
     """Performs a checkpoint of all generators and discriminators.

     Every attribute named in ``self.generator_names``, then every attribute
     named in ``self.discriminator_names``, is written with ``netwrite`` to
     ``{checkpoint_path}-{name}-epoch-{epoch_id}-step-{step_id}.torch``.
     A network wrapped in ``torch.nn.DataParallel`` is unwrapped first, so
     its ``.module`` is written rather than the wrapper.
     Afterwards ``self.each_checkpoint()`` is called.
     """
     # Local import: this method may live in a module without a top-level
     # torch import.
     import torch

     # Generators first, then discriminators (same order as before).
     for names in (self.generator_names, self.discriminator_names):
         for name in names:
             the_net = getattr(self, name)
             # Save the wrapped module, not the DataParallel container.
             if isinstance(the_net, torch.nn.DataParallel):
                 the_net = the_net.module
             netwrite(
                 the_net,
                 f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
             )
     self.each_checkpoint()
Ejemplo n.º 7
0
 def checkpoint(self):
     """Save all generators, then all discriminators, then run the hook.

     The attribute lists ``generator_names`` and ``discriminator_names`` are
     visited in that order. Each named network is unwrapped from
     ``torch.nn.DataParallel`` if needed and written to
     ``{checkpoint_path}-{name}-epoch-{epoch_id}-step-{step_id}.torch``.
     """
     # Each name list is read lazily, right before it is iterated.
     for group in ("generator_names", "discriminator_names"):
         for name in getattr(self, group):
             net = getattr(self, name)
             if isinstance(net, torch.nn.DataParallel):
                 net = net.module
             netwrite(
                 net,
                 f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch",
             )
     self.each_checkpoint()
Ejemplo n.º 8
0
    def checkpoint(self):
        """Save the encoder, decoder and each cluster classifier, then run the hook.

        Any of these networks wrapped in ``torch.nn.DataParallel`` is
        replaced by its ``.module`` before being passed to ``netwrite``.
        File names start with ``self.network_name`` and carry the current
        epoch and step; classifier files also carry the classifier's index in
        ``self.cluster_embeddings``.
        """
        def _unwrapped(net):
            # Hand netwrite the inner module rather than a DataParallel wrapper.
            return net.module if isinstance(net, torch.nn.DataParallel) else net

        netwrite(
            _unwrapped(self.net),
            f"{self.network_name}-encoder-epoch-{self.epoch_id}-step-{self.step_id}.torch"
        )
        netwrite(
            _unwrapped(self.decoder),
            f"{self.network_name}-decoder-epoch-{self.epoch_id}-step-{self.step_id}.torch"
        )
        for position, classifier in enumerate(self.cluster_embeddings):
            netwrite(
                _unwrapped(classifier),
                f"{self.network_name}-classifier-{position}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
            )
        self.each_checkpoint()
Ejemplo n.º 9
0
 def checkpoint(self):
     """Save the encoder, decoder and each cluster classifier, then run the hook.

     Writes ``self.net`` (encoder), ``self.decoder`` and every entry of
     ``self.cluster_embeddings`` with ``netwrite``. File names start with
     ``self.network_name`` and carry the current epoch and step; classifier
     files also carry the classifier's index. A network wrapped in
     ``torch.nn.DataParallel`` is unwrapped first, so its ``.module`` is
     written rather than the wrapper.
     Afterwards ``self.each_checkpoint()`` is called.
     """
     # Local import: this method may live in a module without a top-level
     # torch import.
     import torch

     def unwrap(net):
         # Save the wrapped module, not the DataParallel container.
         if isinstance(net, torch.nn.DataParallel):
             return net.module
         return net

     netwrite(
         unwrap(self.net),
         f"{self.network_name}-encoder-epoch-{self.epoch_id}-step-{self.step_id}.torch"
     )
     netwrite(
         unwrap(self.decoder),
         f"{self.network_name}-decoder-epoch-{self.epoch_id}-step-{self.step_id}.torch"
     )
     for idx, classifier in enumerate(self.cluster_embeddings):
         netwrite(
             unwrap(classifier),
             f"{self.network_name}-classifier-{idx}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
         )
     self.each_checkpoint()
Ejemplo n.º 10
0
 def each_checkpoint(self):
     """Write ``self.agent`` to ``{path}-checkpoint-{step_id}.torch``."""
     agent = self.agent
     target_path = f"{self.path}-checkpoint-{self.step_id}.torch"
     netwrite(agent, target_path)
Ejemplo n.º 11
0
 def checkpoint(self):
     """Write every network in ``self.checkpoint_names`` to a step-tagged file.

     ``self.checkpoint_names`` maps a label to a network. A network wrapped
     in ``torch.nn.DataParallel`` is replaced by its ``.module`` before
     writing. The path prefix and step come from ``self.ctx`` and are read
     on every iteration. Each file is named
     ``{ctx.path}-{label}-step-{ctx.step_id}.torch``. This method does not
     call ``each_checkpoint``.
     """
     for label, net in self.checkpoint_names.items():
         target = net.module if isinstance(net, torch.nn.DataParallel) else net
         netwrite(
             target,
             f"{self.ctx.path}-{label}-step-{self.ctx.step_id}.torch",
         )