Example #1
    def test_linear_decay(self):
        optD = optim.Adam(self.netD.parameters(), self.lr_D, betas=(0.0, 0.9))
        optG = optim.Adam(self.netG.parameters(), self.lr_G, betas=(0.0, 0.9))

        lr_scheduler = scheduler.LRScheduler(lr_decay='linear',
                                             optD=optD,
                                             optG=optG,
                                             num_steps=self.num_steps,
                                             start_step=5)

        log_data = metric_log.MetricLog()
        for step in range(self.num_steps):
            lr_scheduler.step(log_data, step)

            if step < lr_scheduler.start_step:
                # Before start_step the rates must stay at their initial values.
                assert abs(self.lr_D - self.get_lr(optD)) < 1e-5
                assert abs(self.lr_G - self.get_lr(optG)) < 1e-5

            else:
                curr_lr = ((1 - (max(0, step - lr_scheduler.start_step) /
                                 (self.num_steps - lr_scheduler.start_step))) *
                           self.lr_D)

                assert abs(curr_lr - self.get_lr(optD)) < 1e-5
                assert abs(curr_lr - self.get_lr(optG)) < 1e-5
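Both branches of this test call a get_lr helper that the listing never shows. A minimal sketch, assuming a single parameter group per optimizer (reading param_groups[0]['lr'] is the standard PyTorch way to inspect the current rate):

    def get_lr(self, optimizer):
        # Hypothetical helper: the tests assume one parameter group, so the
        # first group's 'lr' entry is the scheduler's current rate.
        return optimizer.param_groups[0]['lr']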
Example #2
    def test_arguments(self):
        optD = optim.Adam(self.netD.parameters(),
                          self.lr_D,
                          betas=(0.0, 0.9))
        optG = optim.Adam(self.netG.parameters(),
                          self.lr_G,
                          betas=(0.0, 0.9))

        # Only the scheduler construction should raise, so the optimizer
        # setup stays outside the pytest.raises block.
        with pytest.raises(NotImplementedError):
            scheduler.LRScheduler(lr_decay='does_not_exist',
                                  optD=optD,
                                  optG=optG,
                                  num_steps=self.num_steps)
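The NotImplementedError this test expects implies a whitelist check on lr_decay at construction time. A minimal sketch of that check, with the accepted values inferred from the other examples here rather than taken from the library's source:

class LRScheduler:
    def __init__(self, lr_decay, optD, optG, num_steps, start_step=0):
        # Hypothetical reconstruction: 'None' and 'linear' are the only
        # values the surrounding tests exercise successfully.
        if lr_decay not in (None, 'None', 'linear'):
            raise NotImplementedError(
                "lr_decay '{}' is not supported.".format(lr_decay))
        self.lr_decay = lr_decay
        self.optD = optD
        self.optG = optG
        self.num_steps = num_steps
        self.start_step = start_step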
Example #3
    def test_no_decay(self):
        optD = optim.Adam(self.netD.parameters(), self.lr_D, betas=(0.0, 0.9))
        optG = optim.Adam(self.netG.parameters(), self.lr_G, betas=(0.0, 0.9))

        lr_scheduler = scheduler.LRScheduler(lr_decay='None',
                                             optD=optD,
                                             optG=optG,
                                             num_steps=self.num_steps)

        log_data = metric_log.MetricLog()
        for step in range(1, self.num_steps + 1):
            lr_scheduler.step(log_data, step)

            # With lr_decay='None' the rates must never change.
            assert self.lr_D == self.get_lr(optD)
            assert self.lr_G == self.get_lr(optG)
Example #4
    def test_linear_decay(self):
        optD = optim.Adam(self.netD.parameters(), self.lr_D, betas=(0.0, 0.9))
        optG = optim.Adam(self.netG.parameters(), self.lr_G, betas=(0.0, 0.9))

        lr_scheduler = scheduler.LRScheduler(lr_decay='linear',
                                             optD=optD,
                                             optG=optG,
                                             num_steps=self.num_steps)

        log_data = metric_log.MetricLog()
        for step in range(1, self.num_steps + 1):
            lr_scheduler.step(log_data, step)

            curr_lr = (1 - step / self.num_steps) * self.lr_D

            # abs() is required here: without it the assertion passes
            # whenever the actual rate merely exceeds the expected one.
            assert abs(curr_lr - self.get_lr(optD)) < 1e-5
            assert abs(curr_lr - self.get_lr(optG)) < 1e-5
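Examples #1 and #4 encode the same linear schedule: the rate holds at its initial value until start_step, then falls linearly to zero at num_steps. A worked sketch of that rule, assuming the scheduler rescales every parameter group in place (the function names here are illustrative, not the library's):

def linear_decayed_lr(base_lr, step, num_steps, start_step=0):
    # Fraction of the decay window already consumed; zero before start_step.
    progress = max(0, step - start_step) / (num_steps - start_step)
    return (1.0 - progress) * base_lr

def apply_lr(optimizer, lr):
    # Standard PyTorch idiom for setting the rate on every parameter group.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr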
Example #5
    def __init__(self,
                 netD,
                 netG,
                 optD,
                 optG,
                 dataloader,
                 num_steps,
                 n_gpus,
                 log_dir='./log',
                 n_dis=1,
                 netG_ckpt_file=None,
                 netD_ckpt_file=None,
                 lr_decay=None,
                 device=None,
                 **kwargs):
        self.netD = netD
        self.netG = netG
        self.optD = optD
        self.optG = optG
        self.n_dis = n_dis
        self.lr_decay = lr_decay
        self.dataloader = dataloader
        self.num_steps = num_steps
        self.device = device

        # Log directory
        self.log_dir = log_dir
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        # Obtain custom or latest checkpoint files
        if netG_ckpt_file:
            self.netG_ckpt_dir = os.path.dirname(netG_ckpt_file)
            self.netG_ckpt_file = netG_ckpt_file
        else:
            self.netG_ckpt_dir = os.path.join(self.log_dir, 'checkpoints',
                                              'netG')
            self.netG_ckpt_file = self._get_latest_checkpoint(
                self.netG_ckpt_dir)  # can be None

        if netD_ckpt_file:
            self.netD_ckpt_dir = os.path.dirname(netD_ckpt_file)
            self.netD_ckpt_file = netD_ckpt_file
        else:
            self.netD_ckpt_dir = os.path.join(self.log_dir, 'checkpoints',
                                              'netD')
            self.netD_ckpt_file = self._get_latest_checkpoint(
                self.netD_ckpt_dir)  # can be None

        # Default parameters, unless provided by kwargs
        default_params = {
            'print_steps': kwargs.get('print_steps', 1),
            'vis_steps': kwargs.get('vis_steps', 500),
            'flush_secs': kwargs.get('flush_secs', 30),
            'log_steps': kwargs.get('log_steps', 50),
            'save_steps': kwargs.get('save_steps', 5000),
            'save_when_end': kwargs.get('save_when_end', True),
        }
        for name, value in default_params.items():
            setattr(self, name, value)

        # Hyperparameters for logging experiments
        self.params = {
            'log_dir': self.log_dir,
            'num_steps': self.num_steps,
            'batch_size': self.dataloader.batch_size,
            'n_dis': self.n_dis,
            'lr_decay': self.lr_decay,
            'optD': repr(optD),
            'optG': repr(optG),
        }
        self.params.update(default_params)

        # Log training hyperparameters
        self._log_params(self.params)

        # Device for hosting model and data
        if not self.device:
            self.device = torch.device(
                'cuda:0' if torch.cuda.is_available() else 'cpu')

        # Ensure model and data are on the same device
        for net in [self.netD, self.netG]:
            if net.device != self.device:
                net.to(self.device)

        # Training helper objects
        self.logger = logger.Logger(log_dir=self.log_dir,
                                    num_steps=self.num_steps,
                                    dataset_size=len(self.dataloader),
                                    flush_secs=self.flush_secs,
                                    device=self.device)

        self.scheduler = scheduler.LRScheduler(lr_decay=self.lr_decay,
                                               optD=self.optD,
                                               optG=self.optG,
                                               num_steps=self.num_steps)
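This constructor and the two variants below fall back to _get_latest_checkpoint when no explicit checkpoint file is supplied. A minimal sketch of such a helper, assuming checkpoints are .pth files that modification time orders correctly (the real implementation may instead parse step numbers out of filenames; this sketch needs import glob alongside the os import already in use):

    @staticmethod
    def _get_latest_checkpoint(ckpt_dir):
        # Hypothetical helper matching its usage above: returns the newest
        # checkpoint in ckpt_dir, or None if none exists yet.
        ckpts = glob.glob(os.path.join(ckpt_dir, '*.pth'))
        if not ckpts:
            return None
        return max(ckpts, key=os.path.getmtime)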
Example #6
    def __init__(self,
                 netD,
                 netG,
                 optD,
                 optG,
                 dataloader,
                 num_steps,
                 log_dir='./log',
                 n_dis=1,
                 lr_decay=None,
                 device=None,
                 netG_ckpt_file=None,
                 netD_ckpt_file=None,
                 print_steps=1,
                 vis_steps=1000,
                 log_steps=100,
                 save_steps=10000,
                 flush_secs=60):
        # Input values checks
        ints_to_check = {
            'num_steps': num_steps,
            'n_dis': n_dis,
            'print_steps': print_steps,
            'vis_steps': vis_steps,
            'log_steps': log_steps,
            'save_steps': save_steps,
            'flush_secs': flush_secs
        }
        for name, var in ints_to_check.items():
            if var < 1:
                raise ValueError('{} must be at least 1 but got {}.'.format(
                    name, var))

        self.netD = netD
        self.netG = netG
        self.optD = optD
        self.optG = optG
        self.n_dis = n_dis
        self.lr_decay = lr_decay
        self.dataloader = dataloader
        self.num_steps = num_steps
        self.device = device
        self.log_dir = log_dir
        self.netG_ckpt_file = netG_ckpt_file
        self.netD_ckpt_file = netD_ckpt_file
        self.print_steps = print_steps
        self.vis_steps = vis_steps
        self.log_steps = log_steps
        self.save_steps = save_steps

        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        # Training helper objects
        self.logger = logger.Logger(log_dir=self.log_dir,
                                    num_steps=self.num_steps,
                                    dataset_size=len(self.dataloader),
                                    flush_secs=flush_secs,
                                    device=self.device)

        self.scheduler = scheduler.LRScheduler(lr_decay=self.lr_decay,
                                               optD=self.optD,
                                               optG=self.optG,
                                               num_steps=self.num_steps)

        # Obtain custom or latest checkpoint files
        if self.netG_ckpt_file:
            self.netG_ckpt_dir = os.path.dirname(netG_ckpt_file)
            self.netG_ckpt_file = netG_ckpt_file
        else:
            self.netG_ckpt_dir = os.path.join(self.log_dir, 'checkpoints',
                                              'netG')
            self.netG_ckpt_file = self._get_latest_checkpoint(
                self.netG_ckpt_dir)  # can be None

        if self.netD_ckpt_file:
            self.netD_ckpt_dir = os.path.dirname(netD_ckpt_file)
            self.netD_ckpt_file = netD_ckpt_file
        else:
            self.netD_ckpt_dir = os.path.join(self.log_dir, 'checkpoints',
                                              'netD')
            self.netD_ckpt_file = self._get_latest_checkpoint(
                self.netD_ckpt_dir)  # can be None

        # Log hyperparameters for experiments
        self.params = {
            'log_dir': self.log_dir,
            'batch_size': self.dataloader.batch_size,
            'n_dis': self.n_dis,
            'lr_decay': self.lr_decay,
            'optD': repr(optD),
            'optG': repr(optG),
            'save_steps': self.save_steps,
        }
        self._log_params(self.params)

        # Device for hosting model and data
        if not self.device:
            self.device = torch.device(
                'cuda:0' if torch.cuda.is_available() else 'cpu')

        # Ensure model and data are on the same device
        for net in [self.netD, self.netG]:
            if net.device != self.device:
                net.to(self.device)
Example #7
    def __init__(
        self,
        netD,
        netG,
        optD,
        optG,
        dataloader,
        num_steps,
        log_dir="./log",
        n_dis=1,
        lr_decay=None,
        device=None,
        netG_ckpt_file=None,
        netD_ckpt_file=None,
        print_steps=1,
        vis_steps=500,
        log_steps=50,
        save_steps=5000,
        flush_secs=30,
    ):
        # Input values checks
        ints_to_check = {
            "num_steps": num_steps,
            "n_dis": n_dis,
            "print_steps": print_steps,
            "vis_steps": vis_steps,
            "log_steps": log_steps,
            "save_steps": save_steps,
            "flush_secs": flush_secs,
        }
        for name, var in ints_to_check.items():
            if var < 1:
                raise ValueError("{} must be at least 1 but got {}.".format(
                    name, var))

        self.netD = netD
        self.netG = netG
        self.optD = optD
        self.optG = optG
        self.n_dis = n_dis
        self.lr_decay = lr_decay
        self.dataloader = dataloader
        self.num_steps = num_steps
        self.device = device
        self.log_dir = log_dir
        self.netG_ckpt_file = netG_ckpt_file
        self.netD_ckpt_file = netD_ckpt_file
        self.print_steps = print_steps
        self.vis_steps = vis_steps
        self.log_steps = log_steps
        self.save_steps = save_steps

        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        # Training helper objects
        self.logger = logger.Logger(
            log_dir=self.log_dir,
            num_steps=self.num_steps,
            dataset_size=len(self.dataloader),
            flush_secs=flush_secs,
            device=self.device,
        )

        self.scheduler = scheduler.LRScheduler(
            lr_decay=self.lr_decay,
            optD=self.optD,
            optG=self.optG,
            num_steps=self.num_steps,
        )

        # Obtain custom or latest checkpoint files
        if self.netG_ckpt_file:
            self.netG_ckpt_dir = os.path.dirname(netG_ckpt_file)
            self.netG_ckpt_file = netG_ckpt_file
        else:
            self.netG_ckpt_dir = os.path.join(self.log_dir, "checkpoints",
                                              "netG")
            self.netG_ckpt_file = self._get_latest_checkpoint(
                self.netG_ckpt_dir)  # can be None

        if self.netD_ckpt_file:
            self.netD_ckpt_dir = os.path.dirname(netD_ckpt_file)
            self.netD_ckpt_file = netD_ckpt_file
        else:
            self.netD_ckpt_dir = os.path.join(self.log_dir, "checkpoints",
                                              "netD")
            self.netD_ckpt_file = self._get_latest_checkpoint(
                self.netD_ckpt_dir)  # can be None

        # Log hyperparameters for experiments
        self.params = {
            "log_dir": self.log_dir,
            "num_steps": self.num_steps,
            "batch_size": self.dataloader.batch_size,
            "n_dis": self.n_dis,
            "lr_decay": self.lr_decay,
            "optD": optD.__repr__(),
            "optG": optG.__repr__(),
            "save_steps": self.save_steps,
        }
        self._log_params(self.params)

        # Device for hosting model and data
        if not self.device:
            self.device = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")

        # Ensure model and data are on the same device
        for net in [self.netD, self.netG]:
            if net.device != self.device:
                net.to(self.device)
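For orientation, a hedged construction sketch for a trainer with this signature. The class name Trainer, the stand-in networks, and the synthetic dataloader are assumptions for illustration only; note that the constructor reads net.device, so plain nn.Module subclasses must expose that attribute themselves:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

class TinyNet(nn.Module):
    # Stand-in for netD/netG; exposes .device as the constructor expects.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 1)

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, x):
        return self.fc(x)

netD, netG = TinyNet(), TinyNet()
optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))
dataloader = DataLoader(TensorDataset(torch.randn(64, 16)), batch_size=8)

trainer = Trainer(netD=netD,  # 'Trainer' is an assumed class name
                  netG=netG,
                  optD=optD,
                  optG=optG,
                  dataloader=dataloader,
                  num_steps=100,
                  lr_decay='linear',
                  log_dir='./log')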