コード例 #1
0
ファイル: vae.py プロジェクト: manik-hossain/torchsupport
 def checkpoint(self):
     """Performs a checkpoint for all encoders and decoders.

     Every network named in ``self.network_names`` is looked up as an
     attribute of ``self``. If a network is wrapped in
     ``torch.nn.DataParallel``, its inner ``module`` is saved instead of
     the wrapper. Files are named
     ``{checkpoint_path}-{name}-epoch-{epoch_id}-step-{step_id}.torch``.
     After all networks are written, ``self.each_checkpoint()`` is called.
     """
     import torch  # local import keeps this method self-contained

     for name in self.network_names:
         the_net = getattr(self, name)
         # Save the underlying module rather than the DataParallel container.
         if isinstance(the_net, torch.nn.DataParallel):
             the_net = the_net.module
         netwrite(
             the_net,
             f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
         )
     self.each_checkpoint()
コード例 #2
0
ファイル: training.py プロジェクト: vrydeep/torchsupport
 def checkpoint(self):
     """Write a checkpoint of ``self.net``, then run ``self.each_checkpoint()``.

     If ``self.net`` is wrapped in ``torch.nn.DataParallel``, the inner
     ``module`` is saved instead of the wrapper. The file is named
     ``{checkpoint_path}-epoch-{epoch_id}-step-{step_id}.torch``.
     """
     the_net = self.net
     if isinstance(the_net, torch.nn.DataParallel):
         the_net = the_net.module
     # Write the unwrapped network (not ``self.net``) so that the
     # DataParallel unwrapping above actually takes effect.
     netwrite(
         the_net,
         f"{self.checkpoint_path}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
     )
     self.each_checkpoint()
コード例 #3
0
ファイル: training.py プロジェクト: mjendrusch/torchsupport
 def run_checkpoint(self):
   """Save every network registered in ``self.checkpoint_names``.

   ``self.checkpoint_names`` is iterated as ``(name, network)`` pairs. A
   ``torch.nn.DataParallel`` wrapper is stripped before saving. Each network
   is written to ``{full_path}-{name}-epoch-{epoch_id}-step-{step_id}.torch``,
   and the ``each_checkpoint`` hook runs once all networks are written.
   """
   for label, net in self.checkpoint_names.items():
     module = net.module if isinstance(net, torch.nn.DataParallel) else net
     target = f"{self.full_path}-{label}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
     netwrite(module, target)
   self.each_checkpoint()
コード例 #4
0
 def checkpoint(self):
     """Write a checkpoint of the encoder network ``self.net``.

     If ``self.net`` is wrapped in ``torch.nn.DataParallel``, the inner
     ``module`` is saved instead of the wrapper. The file is named
     ``{network_name}-encoder-epoch-{epoch_id}-step-{step_id}.torch``.
     Afterwards the ``each_checkpoint`` hook is called.
     """
     import torch  # local import keeps this method self-contained

     the_net = self.net
     # Save the underlying module rather than the DataParallel container.
     if isinstance(the_net, torch.nn.DataParallel):
         the_net = the_net.module
     netwrite(
         the_net,
         f"{self.network_name}-encoder-epoch-{self.epoch_id}-step-{self.step_id}.torch"
     )
     # NOTE(review): decoder checkpointing is disabled. If re-enabled, unwrap
     # a DataParallel-wrapped self.decoder the same way as the encoder above.
     # netwrite(
     #   self.decoder,
     #   f"{self.network_name}-decoder-epoch-{self.epoch_id}-step-{self.step_id}.torch"
     # )
     self.each_checkpoint()
コード例 #5
0
ファイル: vae.py プロジェクト: vrydeep/torchsupport
 def checkpoint(self):
     """Save every encoder/decoder listed in ``self.network_names``.

     Each network is fetched from ``self`` by attribute name. A
     ``torch.nn.DataParallel`` wrapper is removed so that the inner module
     is written. Output files follow
     ``{checkpoint_path}-{name}-epoch-{epoch_id}-step-{step_id}.torch``.
     The ``each_checkpoint`` hook runs after the last file is written.
     """
     for net_name in self.network_names:
         model = getattr(self, net_name)
         if isinstance(model, torch.nn.DataParallel):
             model = model.module
         out_path = f"{self.checkpoint_path}-{net_name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
         netwrite(model, out_path)
     self.each_checkpoint()
コード例 #6
0
ファイル: gan.py プロジェクト: manik-hossain/torchsupport
 def checkpoint(self):
     """Performs a checkpoint of all generators and discriminators.

     Every network named in ``self.generator_names``, and then in
     ``self.discriminator_names``, is looked up as an attribute of ``self``.
     If a network is wrapped in ``torch.nn.DataParallel``, its inner
     ``module`` is saved instead of the wrapper. Files are named
     ``{checkpoint_path}-{name}-epoch-{epoch_id}-step-{step_id}.torch``.
     Afterwards the ``each_checkpoint`` hook is called.
     """
     import torch  # local import keeps this method self-contained

     # Generators first, then discriminators; both groups share one save path.
     for names in (self.generator_names, self.discriminator_names):
         for name in names:
             the_net = getattr(self, name)
             # Save the underlying module rather than the DataParallel container.
             if isinstance(the_net, torch.nn.DataParallel):
                 the_net = the_net.module
             netwrite(
                 the_net,
                 f"{self.checkpoint_path}-{name}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
             )
     self.each_checkpoint()
コード例 #7
0
 def checkpoint(self):
     """Save every generator and discriminator network.

     Generators (``self.generator_names``) are written first, then
     discriminators (``self.discriminator_names``). Each network is fetched
     by attribute name; if it is wrapped in ``torch.nn.DataParallel``, its
     inner ``module`` is saved instead. Files are named
     ``{checkpoint_path}-{name}-epoch-{epoch_id}-step-{step_id}.torch``.
     The ``each_checkpoint`` hook runs at the end.
     """
     def save_all(attr_names):
         # Shared save logic for one group of network attribute names.
         for attr in attr_names:
             net = getattr(self, attr)
             unwrapped = net.module if isinstance(net, torch.nn.DataParallel) else net
             netwrite(
                 unwrapped,
                 f"{self.checkpoint_path}-{attr}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
             )

     save_all(self.generator_names)
     save_all(self.discriminator_names)
     self.each_checkpoint()
コード例 #8
0
ファイル: clustering.py プロジェクト: vrydeep/torchsupport
    def checkpoint(self):
        """Save the encoder, the decoder and every cluster classifier.

        ``self.net`` is saved as the encoder and ``self.decoder`` as the
        decoder. Each entry of ``self.cluster_embeddings`` is saved as
        ``classifier-{idx}``. A ``torch.nn.DataParallel`` wrapper is stripped
        before each save. Every file name starts with ``self.network_name``
        and ends with ``-epoch-{epoch_id}-step-{step_id}.torch``. The
        ``each_checkpoint`` hook runs afterwards.
        """
        def save(model, tag):
            # Strip a DataParallel wrapper, then write the inner model.
            if isinstance(model, torch.nn.DataParallel):
                model = model.module
            netwrite(
                model,
                f"{self.network_name}-{tag}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
            )

        save(self.net, "encoder")
        save(self.decoder, "decoder")
        for idx, classifier in enumerate(self.cluster_embeddings):
            save(classifier, f"classifier-{idx}")
        self.each_checkpoint()
コード例 #9
0
 def checkpoint(self):
     """Write checkpoints for the encoder, decoder and cluster classifiers.

     ``self.net`` is saved as the encoder and ``self.decoder`` as the
     decoder. Each entry of ``self.cluster_embeddings`` is saved as
     ``classifier-{idx}``. If a network is wrapped in
     ``torch.nn.DataParallel``, its inner ``module`` is saved instead of
     the wrapper. Every file name starts with ``self.network_name`` and
     ends with ``-epoch-{epoch_id}-step-{step_id}.torch``. Afterwards the
     ``each_checkpoint`` hook is called.
     """
     import torch  # local import keeps this method self-contained

     def unwrap(net):
         # Save the underlying module rather than the DataParallel container.
         return net.module if isinstance(net, torch.nn.DataParallel) else net

     netwrite(
         unwrap(self.net),
         f"{self.network_name}-encoder-epoch-{self.epoch_id}-step-{self.step_id}.torch"
     )
     netwrite(
         unwrap(self.decoder),
         f"{self.network_name}-decoder-epoch-{self.epoch_id}-step-{self.step_id}.torch"
     )
     for idx, classifier in enumerate(self.cluster_embeddings):
         netwrite(
             unwrap(classifier),
             f"{self.network_name}-classifier-{idx}-epoch-{self.epoch_id}-step-{self.step_id}.torch"
         )
     self.each_checkpoint()
コード例 #10
0
ファイル: off_policy.py プロジェクト: twang15/torchsupport
 def each_checkpoint(self):
     """Save ``self.agent`` to ``{path}-checkpoint-{step_id}.torch``."""
     agent = self.agent
     target = f"{self.path}-checkpoint-{self.step_id}.torch"
     netwrite(agent, target)
コード例 #11
0
ファイル: checkpoint.py プロジェクト: mjendrusch/torchsupport
 def checkpoint(self):
     """Save every network in ``self.checkpoint_names`` under ``self.ctx.path``.

     ``self.checkpoint_names`` is iterated as ``(name, network)`` pairs. A
     ``torch.nn.DataParallel`` wrapper is stripped before saving. Files are
     named ``{ctx.path}-{name}-step-{ctx.step_id}.torch``.
     """
     for key, model in self.checkpoint_names.items():
         inner = model.module if isinstance(model, torch.nn.DataParallel) else model
         netwrite(inner, f"{self.ctx.path}-{key}-step-{self.ctx.step_id}.torch")