def test_linear_decay_start_step(self):
    optD = optim.Adam(self.netD.parameters(), self.lr_D, betas=(0.0, 0.9))
    optG = optim.Adam(self.netG.parameters(), self.lr_G, betas=(0.0, 0.9))
    lr_scheduler = scheduler.LRScheduler(lr_decay='linear',
                                         optD=optD,
                                         optG=optG,
                                         num_steps=self.num_steps,
                                         start_step=5)
    log_data = metric_log.MetricLog()

    for step in range(self.num_steps):
        lr_scheduler.step(log_data, step)

        if step < lr_scheduler.start_step:
            # Before start_step is reached, the learning rates must stay
            # at their initial values.
            assert abs(self.lr_D - self.get_lr(optD)) < 1e-5
            assert abs(self.lr_G - self.get_lr(optG)) < 1e-5
        else:
            # From start_step onwards, the learning rates decay linearly
            # towards 0 over the remaining steps.
            decay = 1 - (step - lr_scheduler.start_step) / (
                self.num_steps - lr_scheduler.start_step)
            assert abs(decay * self.lr_D - self.get_lr(optD)) < 1e-5
            assert abs(decay * self.lr_G - self.get_lr(optG)) < 1e-5

def test_arguments(self):
    optD = optim.Adam(self.netD.parameters(), self.lr_D, betas=(0.0, 0.9))
    optG = optim.Adam(self.netG.parameters(), self.lr_G, betas=(0.0, 0.9))

    # Only the scheduler construction is expected to raise.
    with pytest.raises(NotImplementedError):
        scheduler.LRScheduler(lr_decay='does_not_exist',
                              optD=optD,
                              optG=optG,
                              num_steps=self.num_steps)

def test_no_decay(self):
    optD = optim.Adam(self.netD.parameters(), self.lr_D, betas=(0.0, 0.9))
    optG = optim.Adam(self.netG.parameters(), self.lr_G, betas=(0.0, 0.9))
    lr_scheduler = scheduler.LRScheduler(lr_decay='None',
                                         optD=optD,
                                         optG=optG,
                                         num_steps=self.num_steps)
    log_data = metric_log.MetricLog()

    for step in range(1, self.num_steps + 1):
        lr_scheduler.step(log_data, step)

        # Without decay, the learning rates must never change.
        assert self.lr_D == self.get_lr(optD)
        assert self.lr_G == self.get_lr(optG)

def test_linear_decay(self):
    optD = optim.Adam(self.netD.parameters(), self.lr_D, betas=(0.0, 0.9))
    optG = optim.Adam(self.netG.parameters(), self.lr_G, betas=(0.0, 0.9))
    lr_scheduler = scheduler.LRScheduler(lr_decay='linear',
                                         optD=optD,
                                         optG=optG,
                                         num_steps=self.num_steps)
    log_data = metric_log.MetricLog()

    for step in range(1, self.num_steps + 1):
        lr_scheduler.step(log_data, step)

        # Learning rates decay linearly from their initial values to 0.
        decay = 1 - step / self.num_steps
        assert abs(decay * self.lr_D - self.get_lr(optD)) < 1e-5
        assert abs(decay * self.lr_G - self.get_lr(optG)) < 1e-5

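# For reference, the decay behaviour pinned down by the tests above can be
# reproduced by a scheduler along the following lines. This is a minimal
# sketch inferred from the test suite, not the library's implementation:
# the class name, the start_step default of 0, and the decision to accept
# but ignore log_data are assumptions.
class LinearLRScheduler:
    """Minimal sketch of a scheduler supporting 'None' and 'linear' decay."""

    def __init__(self, lr_decay, optD, optG, num_steps, start_step=0):
        if lr_decay not in ['None', 'linear', None]:
            raise NotImplementedError(
                "lr_decay {} is not implemented.".format(lr_decay))

        self.lr_decay = lr_decay
        self.optD = optD
        self.optG = optG
        self.num_steps = num_steps
        self.start_step = start_step

        # Remember the initial learning rates so decay is always computed
        # from them rather than compounded step to step.
        self.init_lr = {
            id(opt): opt.param_groups[0]['lr']
            for opt in (optD, optG)
        }

    def step(self, log_data, global_step):
        # log_data is accepted for API compatibility with the tests; the
        # real scheduler presumably records the current learning rates there.
        if self.lr_decay != 'linear':
            return

        # Hold the lr constant until start_step, then anneal linearly to 0
        # over the remaining steps.
        progress = max(0, global_step - self.start_step) / (
            self.num_steps - self.start_step)
        for opt in (self.optD, self.optG):
            new_lr = (1 - progress) * self.init_lr[id(opt)]
            for group in opt.param_groups:
                group['lr'] = new_lr
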
def __init__(self,
             netD,
             netG,
             optD,
             optG,
             dataloader,
             num_steps,
             n_gpus,
             log_dir='./log',
             n_dis=1,
             netG_ckpt_file=None,
             netD_ckpt_file=None,
             lr_decay=None,
             device=None,
             **kwargs):
    self.netD = netD
    self.netG = netG
    self.optD = optD
    self.optG = optG
    self.n_dis = n_dis
    self.lr_decay = lr_decay
    self.dataloader = dataloader
    self.num_steps = num_steps
    self.device = device

    # Log directory
    self.log_dir = log_dir
    if not os.path.exists(self.log_dir):
        os.makedirs(self.log_dir)

    # Obtain custom or latest checkpoint files
    if netG_ckpt_file:
        self.netG_ckpt_dir = os.path.dirname(netG_ckpt_file)
        self.netG_ckpt_file = netG_ckpt_file
    else:
        self.netG_ckpt_dir = os.path.join(self.log_dir, 'checkpoints',
                                          'netG')
        self.netG_ckpt_file = self._get_latest_checkpoint(
            self.netG_ckpt_dir)  # can be None

    if netD_ckpt_file:
        self.netD_ckpt_dir = os.path.dirname(netD_ckpt_file)
        self.netD_ckpt_file = netD_ckpt_file
    else:
        self.netD_ckpt_dir = os.path.join(self.log_dir, 'checkpoints',
                                          'netD')
        self.netD_ckpt_file = self._get_latest_checkpoint(
            self.netD_ckpt_dir)  # can be None

    # Default parameters, unless overridden by kwargs
    default_params = {
        'print_steps': kwargs.get('print_steps', 1),
        'vis_steps': kwargs.get('vis_steps', 500),
        'flush_secs': kwargs.get('flush_secs', 30),
        'log_steps': kwargs.get('log_steps', 50),
        'save_steps': kwargs.get('save_steps', 5000),
        'save_when_end': kwargs.get('save_when_end', True),
    }
    for param, value in default_params.items():
        setattr(self, param, value)

    # Hyperparameters for logging experiments
    self.params = {
        'log_dir': self.log_dir,
        'num_steps': self.num_steps,
        'batch_size': self.dataloader.batch_size,
        'n_dis': self.n_dis,
        'lr_decay': self.lr_decay,
        'optD': repr(optD),
        'optG': repr(optG),
    }
    self.params.update(default_params)

    # Log training hyperparameters
    self._log_params(self.params)

    # Device for hosting model and data
    if not self.device:
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')

    # Ensure models and data are on the same device
    for net in [self.netD, self.netG]:
        if net.device != self.device:
            net.to(self.device)

    # Training helper objects
    self.logger = logger.Logger(log_dir=self.log_dir,
                                num_steps=self.num_steps,
                                dataset_size=len(self.dataloader),
                                flush_secs=self.flush_secs,
                                device=self.device)
    self.scheduler = scheduler.LRScheduler(lr_decay=self.lr_decay,
                                           optD=self.optD,
                                           optG=self.optG,
                                           num_steps=self.num_steps)

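# The constructors in this section call a _get_latest_checkpoint helper that
# is not shown here. A plausible minimal sketch follows; the filename
# convention (a global step embedded in the name, e.g. 'netG_5000_steps.pth')
# is an assumption, not a documented contract of the library.
import os
import re


def _get_latest_checkpoint(self, ckpt_dir):
    """Returns the most recent checkpoint in ckpt_dir, or None if none exist."""
    if not os.path.exists(ckpt_dir):
        return None

    def step_of(fname):
        # Treat the last integer in the filename as the global step.
        steps = re.findall(r'\d+', fname)
        return int(steps[-1]) if steps else -1

    ckpt_files = [f for f in os.listdir(ckpt_dir) if f.endswith('.pth')]
    if not ckpt_files:
        return None

    return os.path.join(ckpt_dir, max(ckpt_files, key=step_of))
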
def __init__(self,
             netD,
             netG,
             optD,
             optG,
             dataloader,
             num_steps,
             log_dir='./log',
             n_dis=1,
             lr_decay=None,
             device=None,
             netG_ckpt_file=None,
             netD_ckpt_file=None,
             print_steps=1,
             vis_steps=1000,
             log_steps=100,
             save_steps=10000,
             flush_secs=60):
    # Input value checks
    ints_to_check = {
        'num_steps': num_steps,
        'n_dis': n_dis,
        'print_steps': print_steps,
        'vis_steps': vis_steps,
        'log_steps': log_steps,
        'save_steps': save_steps,
        'flush_secs': flush_secs,
    }
    for name, var in ints_to_check.items():
        if var < 1:
            raise ValueError('{} must be at least 1 but got {}.'.format(
                name, var))

    self.netD = netD
    self.netG = netG
    self.optD = optD
    self.optG = optG
    self.n_dis = n_dis
    self.lr_decay = lr_decay
    self.dataloader = dataloader
    self.num_steps = num_steps
    self.device = device
    self.log_dir = log_dir
    self.netG_ckpt_file = netG_ckpt_file
    self.netD_ckpt_file = netD_ckpt_file
    self.print_steps = print_steps
    self.vis_steps = vis_steps
    self.log_steps = log_steps
    self.save_steps = save_steps

    if not os.path.exists(self.log_dir):
        os.makedirs(self.log_dir)

    # Training helper objects
    self.logger = logger.Logger(log_dir=self.log_dir,
                                num_steps=self.num_steps,
                                dataset_size=len(self.dataloader),
                                flush_secs=flush_secs,
                                device=self.device)
    self.scheduler = scheduler.LRScheduler(lr_decay=self.lr_decay,
                                           optD=self.optD,
                                           optG=self.optG,
                                           num_steps=self.num_steps)

    # Obtain custom or latest checkpoint files
    if self.netG_ckpt_file:
        self.netG_ckpt_dir = os.path.dirname(netG_ckpt_file)
        self.netG_ckpt_file = netG_ckpt_file
    else:
        self.netG_ckpt_dir = os.path.join(self.log_dir, 'checkpoints',
                                          'netG')
        self.netG_ckpt_file = self._get_latest_checkpoint(
            self.netG_ckpt_dir)  # can be None

    if self.netD_ckpt_file:
        self.netD_ckpt_dir = os.path.dirname(netD_ckpt_file)
        self.netD_ckpt_file = netD_ckpt_file
    else:
        self.netD_ckpt_dir = os.path.join(self.log_dir, 'checkpoints',
                                          'netD')
        self.netD_ckpt_file = self._get_latest_checkpoint(
            self.netD_ckpt_dir)  # can be None

    # Log hyperparameters for experiments
    self.params = {
        'log_dir': self.log_dir,
        'batch_size': self.dataloader.batch_size,
        'n_dis': self.n_dis,
        'lr_decay': self.lr_decay,
        'optD': repr(optD),
        'optG': repr(optG),
        'save_steps': self.save_steps,
    }
    self._log_params(self.params)

    # Device for hosting model and data
    if not self.device:
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')

    # Ensure models and data are on the same device
    for net in [self.netD, self.netG]:
        if net.device != self.device:
            net.to(self.device)

def __init__(
    self,
    netD,
    netG,
    optD,
    optG,
    dataloader,
    num_steps,
    log_dir="./log",
    n_dis=1,
    lr_decay=None,
    device=None,
    netG_ckpt_file=None,
    netD_ckpt_file=None,
    print_steps=1,
    vis_steps=500,
    log_steps=50,
    save_steps=5000,
    flush_secs=30,
):
    # Input value checks
    ints_to_check = {
        "num_steps": num_steps,
        "n_dis": n_dis,
        "print_steps": print_steps,
        "vis_steps": vis_steps,
        "log_steps": log_steps,
        "save_steps": save_steps,
        "flush_secs": flush_secs,
    }
    for name, var in ints_to_check.items():
        if var < 1:
            raise ValueError("{} must be at least 1 but got {}.".format(
                name, var))

    self.netD = netD
    self.netG = netG
    self.optD = optD
    self.optG = optG
    self.n_dis = n_dis
    self.lr_decay = lr_decay
    self.dataloader = dataloader
    self.num_steps = num_steps
    self.device = device
    self.log_dir = log_dir
    self.netG_ckpt_file = netG_ckpt_file
    self.netD_ckpt_file = netD_ckpt_file
    self.print_steps = print_steps
    self.vis_steps = vis_steps
    self.log_steps = log_steps
    self.save_steps = save_steps

    if not os.path.exists(self.log_dir):
        os.makedirs(self.log_dir)

    # Training helper objects
    self.logger = logger.Logger(
        log_dir=self.log_dir,
        num_steps=self.num_steps,
        dataset_size=len(self.dataloader),
        flush_secs=flush_secs,
        device=self.device,
    )
    self.scheduler = scheduler.LRScheduler(
        lr_decay=self.lr_decay,
        optD=self.optD,
        optG=self.optG,
        num_steps=self.num_steps,
    )

    # Obtain custom or latest checkpoint files
    if self.netG_ckpt_file:
        self.netG_ckpt_dir = os.path.dirname(netG_ckpt_file)
        self.netG_ckpt_file = netG_ckpt_file
    else:
        self.netG_ckpt_dir = os.path.join(self.log_dir, "checkpoints",
                                          "netG")
        self.netG_ckpt_file = self._get_latest_checkpoint(
            self.netG_ckpt_dir)  # can be None

    if self.netD_ckpt_file:
        self.netD_ckpt_dir = os.path.dirname(netD_ckpt_file)
        self.netD_ckpt_file = netD_ckpt_file
    else:
        self.netD_ckpt_dir = os.path.join(self.log_dir, "checkpoints",
                                          "netD")
        self.netD_ckpt_file = self._get_latest_checkpoint(
            self.netD_ckpt_dir)  # can be None

    # Log hyperparameters for experiments
    self.params = {
        "log_dir": self.log_dir,
        "num_steps": self.num_steps,
        "batch_size": self.dataloader.batch_size,
        "n_dis": self.n_dis,
        "lr_decay": self.lr_decay,
        "optD": repr(optD),
        "optG": repr(optG),
        "save_steps": self.save_steps,
    }
    self._log_params(self.params)

    # Device for hosting model and data
    if not self.device:
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

    # Ensure models and data are on the same device
    for net in [self.netD, self.netG]:
        if net.device != self.device:
            net.to(self.device)

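# To make the constructor's contract concrete, in particular that each
# network must expose a `device` attribute (which the final loop reads
# before moving models), here is a hypothetical end-to-end setup. The
# ToyNet class and the random dataset are illustrative stand-ins, and the
# name `Trainer` for the enclosing class of the __init__ above is an
# assumption; this sketch also assumes the repo's logger and scheduler
# modules are importable.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


class ToyNet(nn.Module):
    """Stand-in network exposing the `device` attribute the trainer expects."""

    def __init__(self):
        super().__init__()
        self.main = nn.Linear(10, 1)
        self.device = torch.device('cpu')

    def to(self, device):
        # Keep the attribute in sync when the trainer moves the model.
        self.device = device
        return super().to(device)

    def forward(self, x):
        return self.main(x)


netD, netG = ToyNet(), ToyNet()
optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))
dataloader = DataLoader(TensorDataset(torch.randn(64, 10)), batch_size=8)

trainer = Trainer(netD=netD,
                  netG=netG,
                  optD=optD,
                  optG=optG,
                  dataloader=dataloader,
                  num_steps=100,
                  lr_decay='linear',
                  log_dir='./log/example')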