def setup(self):
    """Create a CPU-backed Logger writing into a ``test_log`` dir next to this file."""
    base_dir = os.path.dirname(os.path.abspath(__file__))
    self.log_dir = os.path.join(base_dir, "test_log")
    self.logger = logger.Logger(
        log_dir=self.log_dir,
        num_steps=100,
        dataset_size=50000,
        flush_secs=30,
        device=torch.device('cpu'),
    )
    # Scalar tags the logger under test is expected to handle.
    self.scalars = [
        'errG',
        'errD',
        'D(x)',
        'D(G(z))',
        'img',
        'lr_D',
        'lr_G',
    ]
def __init__(self,
             netD,
             netG,
             optD,
             optG,
             dataloader,
             num_steps,
             n_gpus,
             log_dir='./log',
             n_dis=1,
             netG_ckpt_file=None,
             netD_ckpt_file=None,
             lr_decay=None,
             device=None,
             **kwargs):
    """Set up GAN training state: models, optimizers, checkpoint locations,
    logging/scheduling helpers, and the hosting device.

    Args:
        netD, netG: Discriminator / generator networks (project models with a
            ``.device`` attribute and ``.to()``).
        optD, optG: Their respective optimizers.
        dataloader: Training data loader; ``batch_size`` and ``len()`` are read.
        num_steps (int): Total number of training steps.
        n_gpus: Accepted for interface compatibility; not used in this body.
        log_dir (str): Root directory for logs and checkpoints.
        n_dis (int): Discriminator updates per generator update.
        netG_ckpt_file, netD_ckpt_file: Optional explicit checkpoint files;
            when absent, the latest checkpoint under ``log_dir`` is used.
        lr_decay: Learning-rate decay policy forwarded to the scheduler.
        device: Torch device; defaults to cuda:0 when available, else CPU.
        **kwargs: Optional overrides for print/vis/log/save cadence settings.

    Raises:
        OSError: If the log directory cannot be created.
    """
    self.netD = netD
    self.netG = netG
    self.optD = optD
    self.optG = optG
    self.n_dis = n_dis
    self.lr_decay = lr_decay
    self.dataloader = dataloader
    self.num_steps = num_steps
    self.device = device

    # Log directory. exist_ok avoids the check-then-create race of the
    # previous `if not os.path.exists(...): os.makedirs(...)` idiom.
    self.log_dir = log_dir
    os.makedirs(self.log_dir, exist_ok=True)

    # Obtain custom or latest checkpoint files.
    if netG_ckpt_file:
        self.netG_ckpt_dir = os.path.dirname(netG_ckpt_file)
        self.netG_ckpt_file = netG_ckpt_file
    else:
        self.netG_ckpt_dir = os.path.join(self.log_dir, 'checkpoints', 'netG')
        self.netG_ckpt_file = self._get_latest_checkpoint(
            self.netG_ckpt_dir)  # can be None

    if netD_ckpt_file:
        self.netD_ckpt_dir = os.path.dirname(netD_ckpt_file)
        self.netD_ckpt_file = netD_ckpt_file
    else:
        self.netD_ckpt_dir = os.path.join(self.log_dir, 'checkpoints', 'netD')
        self.netD_ckpt_file = self._get_latest_checkpoint(
            self.netD_ckpt_dir)  # can be None

    # Default parameters, unless provided by kwargs.
    default_params = {
        'print_steps': kwargs.get('print_steps', 1),
        'vis_steps': kwargs.get('vis_steps', 500),
        'flush_secs': kwargs.get('flush_secs', 30),
        'log_steps': kwargs.get('log_steps', 50),
        'save_steps': kwargs.get('save_steps', 5000),
        'save_when_end': kwargs.get('save_when_end', True),
    }
    # setattr instead of writing self.__dict__ directly: same effect for
    # plain attributes, but honors descriptors and reads more clearly.
    for param, value in default_params.items():
        setattr(self, param, value)

    # Hyperparameters for logging experiments.
    self.params = {
        'log_dir': self.log_dir,
        'num_steps': self.num_steps,
        'batch_size': self.dataloader.batch_size,
        'n_dis': self.n_dis,
        'lr_decay': self.lr_decay,
        'optD': repr(optD),
        'optG': repr(optG),
    }
    self.params.update(default_params)

    # Log training hyperparameters.
    self._log_params(self.params)

    # Device for hosting model and data.
    if not self.device:
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else "cpu")

    # Ensure model and data are in the same device.
    for net in [self.netD, self.netG]:
        if net.device != self.device:
            net.to(self.device)

    # Training helper objects. Note: dataset_size is len(dataloader) —
    # presumably the number of batches; confirm against Logger's expectation.
    self.logger = logger.Logger(log_dir=self.log_dir,
                                num_steps=self.num_steps,
                                dataset_size=len(self.dataloader),
                                flush_secs=self.flush_secs,
                                device=self.device)
    self.scheduler = scheduler.LRScheduler(lr_decay=self.lr_decay,
                                           optD=self.optD,
                                           optG=self.optG,
                                           num_steps=self.num_steps)
def __init__(
    self,
    netD,
    netG,
    optD,
    optG,
    dataloader,
    num_steps,
    log_dir="./log",
    n_dis=1,
    lr_decay=None,
    device=None,
    netG_ckpt_file=None,
    netD_ckpt_file=None,
    print_steps=1,
    vis_steps=500,
    log_steps=50,
    save_steps=5000,
    flush_secs=30,
):
    """Set up GAN training state: models, optimizers, checkpoint locations,
    logging/scheduling helpers, and the hosting device.

    Args:
        netD, netG: Discriminator / generator networks (project models with a
            ``.device`` attribute and ``.to()``).
        optD, optG: Their respective optimizers.
        dataloader: Training data loader; ``batch_size`` and ``len()`` are read.
        num_steps (int): Total number of training steps (>= 1).
        log_dir (str): Root directory for logs and checkpoints.
        n_dis (int): Discriminator updates per generator update (>= 1).
        lr_decay: Learning-rate decay policy forwarded to the scheduler.
        device: Torch device; defaults to cuda:0 when available, else CPU.
        netG_ckpt_file, netD_ckpt_file: Optional explicit checkpoint files;
            when absent, the latest checkpoint under ``log_dir`` is used.
        print_steps, vis_steps, log_steps, save_steps, flush_secs (int):
            Cadence settings; each must be >= 1.

    Raises:
        ValueError: If any cadence/step argument is below 1.
        OSError: If the log directory cannot be created.
    """
    # Input values checks
    ints_to_check = {
        "num_steps": num_steps,
        "n_dis": n_dis,
        "print_steps": print_steps,
        "vis_steps": vis_steps,
        "log_steps": log_steps,
        "save_steps": save_steps,
        "flush_secs": flush_secs,
    }
    for name, var in ints_to_check.items():
        if var < 1:
            raise ValueError("{} must be at least 1 but got {}.".format(
                name, var))

    self.netD = netD
    self.netG = netG
    self.optD = optD
    self.optG = optG
    self.n_dis = n_dis
    self.lr_decay = lr_decay
    self.dataloader = dataloader
    self.num_steps = num_steps
    self.device = device
    self.log_dir = log_dir
    self.netG_ckpt_file = netG_ckpt_file
    self.netD_ckpt_file = netD_ckpt_file
    self.print_steps = print_steps
    self.vis_steps = vis_steps
    self.log_steps = log_steps
    self.save_steps = save_steps

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(self.log_dir, exist_ok=True)

    # Resolve the hosting device *before* constructing helper objects.
    # Previously the Logger was built with device=None whenever no device
    # was passed in, because the fallback ran at the end of __init__.
    if not self.device:
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

    # Ensure model and data are in the same device
    for net in [self.netD, self.netG]:
        if net.device != self.device:
            net.to(self.device)

    # Training helper objects. dataset_size is len(dataloader) — presumably
    # the number of batches; confirm against Logger's expectation.
    self.logger = logger.Logger(
        log_dir=self.log_dir,
        num_steps=self.num_steps,
        dataset_size=len(self.dataloader),
        flush_secs=flush_secs,
        device=self.device,
    )
    self.scheduler = scheduler.LRScheduler(
        lr_decay=self.lr_decay,
        optD=self.optD,
        optG=self.optG,
        num_steps=self.num_steps,
    )

    # Obtain custom or latest checkpoint files
    if self.netG_ckpt_file:
        self.netG_ckpt_dir = os.path.dirname(netG_ckpt_file)
        self.netG_ckpt_file = netG_ckpt_file
    else:
        self.netG_ckpt_dir = os.path.join(self.log_dir, "checkpoints",
                                          "netG")
        self.netG_ckpt_file = self._get_latest_checkpoint(
            self.netG_ckpt_dir)  # can be None

    if self.netD_ckpt_file:
        self.netD_ckpt_dir = os.path.dirname(netD_ckpt_file)
        self.netD_ckpt_file = netD_ckpt_file
    else:
        self.netD_ckpt_dir = os.path.join(self.log_dir, "checkpoints",
                                          "netD")
        self.netD_ckpt_file = self._get_latest_checkpoint(
            self.netD_ckpt_dir)

    # Log hyperparameters for experiments
    self.params = {
        "log_dir": self.log_dir,
        "num_steps": self.num_steps,
        "batch_size": self.dataloader.batch_size,
        "n_dis": self.n_dis,
        "lr_decay": self.lr_decay,
        "optD": repr(optD),
        "optG": repr(optG),
        "save_steps": self.save_steps,
    }
    self._log_params(self.params)
def __init__(self,
             netD,
             netG,
             optD,
             optG,
             dataloader,
             num_steps,
             log_dir='./log',
             n_dis=1,
             lr_decay=None,
             device=None,
             netG_ckpt_file=None,
             netD_ckpt_file=None,
             print_steps=1,
             vis_steps=1000,
             log_steps=100,
             save_steps=10000,
             flush_secs=60):
    """Set up GAN training state: models, optimizers, checkpoint locations,
    logging/scheduling helpers, and the hosting device.

    Args:
        netD, netG: Discriminator / generator networks (project models with a
            ``.device`` attribute and ``.to()``).
        optD, optG: Their respective optimizers.
        dataloader: Training data loader; ``batch_size`` and ``len()`` are read.
        num_steps (int): Total number of training steps (>= 1).
        log_dir (str): Root directory for logs and checkpoints.
        n_dis (int): Discriminator updates per generator update (>= 1).
        lr_decay: Learning-rate decay policy forwarded to the scheduler.
        device: Torch device; defaults to cuda:0 when available, else CPU.
        netG_ckpt_file, netD_ckpt_file: Optional explicit checkpoint files;
            when absent, the latest checkpoint under ``log_dir`` is used.
        print_steps, vis_steps, log_steps, save_steps, flush_secs (int):
            Cadence settings; each must be >= 1.

    Raises:
        ValueError: If any cadence/step argument is below 1.
        OSError: If the log directory cannot be created.
    """
    # Input values checks
    ints_to_check = {
        'num_steps': num_steps,
        'n_dis': n_dis,
        'print_steps': print_steps,
        'vis_steps': vis_steps,
        'log_steps': log_steps,
        'save_steps': save_steps,
        'flush_secs': flush_secs
    }
    for name, var in ints_to_check.items():
        if var < 1:
            raise ValueError('{} must be at least 1 but got {}.'.format(
                name, var))

    self.netD = netD
    self.netG = netG
    self.optD = optD
    self.optG = optG
    self.n_dis = n_dis
    self.lr_decay = lr_decay
    self.dataloader = dataloader
    self.num_steps = num_steps
    self.device = device
    self.log_dir = log_dir
    self.netG_ckpt_file = netG_ckpt_file
    self.netD_ckpt_file = netD_ckpt_file
    self.print_steps = print_steps
    self.vis_steps = vis_steps
    self.log_steps = log_steps
    self.save_steps = save_steps

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(self.log_dir, exist_ok=True)

    # Resolve the hosting device *before* constructing helper objects.
    # Previously the Logger was built with device=None whenever no device
    # was passed in, because the fallback ran at the end of __init__.
    if not self.device:
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else "cpu")

    # Ensure model and data are in the same device
    for net in [self.netD, self.netG]:
        if net.device != self.device:
            net.to(self.device)

    # Training helper objects. dataset_size is len(dataloader) — presumably
    # the number of batches; confirm against Logger's expectation.
    self.logger = logger.Logger(log_dir=self.log_dir,
                                num_steps=self.num_steps,
                                dataset_size=len(self.dataloader),
                                flush_secs=flush_secs,
                                device=self.device)
    self.scheduler = scheduler.LRScheduler(lr_decay=self.lr_decay,
                                           optD=self.optD,
                                           optG=self.optG,
                                           num_steps=self.num_steps)

    # Obtain custom or latest checkpoint files
    if self.netG_ckpt_file:
        self.netG_ckpt_dir = os.path.dirname(netG_ckpt_file)
        self.netG_ckpt_file = netG_ckpt_file
    else:
        self.netG_ckpt_dir = os.path.join(self.log_dir, 'checkpoints',
                                          'netG')
        self.netG_ckpt_file = self._get_latest_checkpoint(
            self.netG_ckpt_dir)  # can be None

    if self.netD_ckpt_file:
        self.netD_ckpt_dir = os.path.dirname(netD_ckpt_file)
        self.netD_ckpt_file = netD_ckpt_file
    else:
        self.netD_ckpt_dir = os.path.join(self.log_dir, 'checkpoints',
                                          'netD')
        self.netD_ckpt_file = self._get_latest_checkpoint(
            self.netD_ckpt_dir)

    # Log hyperparameters for experiments. 'num_steps' added for
    # consistency with the sibling constructor's logged params.
    self.params = {
        'log_dir': self.log_dir,
        'num_steps': self.num_steps,
        'batch_size': self.dataloader.batch_size,
        'n_dis': self.n_dis,
        'lr_decay': self.lr_decay,
        'optD': repr(optD),
        'optG': repr(optG),
        'save_steps': self.save_steps,
    }
    self._log_params(self.params)
def __init__(
    self,
    output_path,
    netD,
    netG,
    optD,
    optG,
    dataloader,
    num_steps,
    netD_drs=None,
    optD_drs=None,
    dataloader_drs=None,
    netD_drs_ckpt_file=None,
    log_dir='./log',
    n_dis=1,
    lr_decay=None,
    device=None,
    netG_ckpt_file=None,
    netD_ckpt_file=None,
    print_steps=1,
    vis_steps=500,
    log_steps=50,
    save_steps=5000,
    flush_secs=30,
    logit_save_steps=500,
    amp=False,
    save_logits=True,
    topk=False,
    gold=False,
    gold_step=None,
    save_logit_after=0,
    stop_save_logit_after=100000,
    save_eval_logits=True,
):
    """Set up GAN training state with optional DRS (second discriminator)
    support, logit saving, and AMP/top-k/GOLD flags.

    Args:
        output_path: Destination for trainer outputs (stored as-is).
        netD, netG: Discriminator / generator networks (project models with a
            ``.device`` attribute and ``.to()``).
        optD, optG: Their respective optimizers.
        dataloader: Training data loader; ``len()`` is read for the logger.
        num_steps (int): Total number of training steps (>= 1).
        netD_drs, optD_drs, dataloader_drs: Optional DRS discriminator and
            its optimizer/loader; all three must be given together.
        netD_drs_ckpt_file: Optional checkpoint file for the DRS net.
        log_dir (str): Root directory for logs and checkpoints.
        n_dis (int): Discriminator updates per generator update (>= 1).
        lr_decay: Learning-rate decay policy forwarded to the scheduler.
        device: Torch device; defaults to cuda:0 when available, else CPU.
        netG_ckpt_file, netD_ckpt_file: Optional explicit checkpoint files.
        print_steps, vis_steps, log_steps, save_steps, flush_secs (int):
            Cadence settings; each must be >= 1.
        logit_save_steps, save_logit_after, stop_save_logit_after (int):
            Logit-saving cadence and active step window.
        amp, save_logits, topk, gold, save_eval_logits (bool): Feature flags.
        gold_step: Required when ``gold`` is True.

    Raises:
        ValueError: If any cadence/step argument is below 1, if ``gold`` is
            set without ``gold_step``, or if ``netD_drs`` is given without
            ``dataloader_drs``/``optD_drs``.
        OSError: If the log directory cannot be created.
    """
    self.output_path = output_path
    self.logit_save_steps = logit_save_steps
    self.netD = netD
    self.netG = netG
    self.optD = optD
    self.optG = optG
    self.n_dis = n_dis
    self.lr_decay = lr_decay
    self.dataloader = dataloader
    self.num_steps = num_steps
    self.device = device
    self.log_dir = log_dir
    self.netG_ckpt_file = netG_ckpt_file
    self.netD_ckpt_file = netD_ckpt_file
    self.print_steps = print_steps
    self.vis_steps = vis_steps
    self.log_steps = log_steps
    self.save_steps = save_steps
    self.amp = amp
    self.save_logits = save_logits
    self.save_logit_after = save_logit_after
    self.stop_save_logit_after = stop_save_logit_after
    self.save_eval_logits = save_eval_logits

    # for DRS
    self.netD_drs = netD_drs
    self.dataloader_drs = dataloader_drs
    self.optD_drs = optD_drs
    self.netD_drs_ckpt_file = netD_drs_ckpt_file
    self.topk = topk
    self.gold = gold
    self.gold_step = gold_step

    # Raise instead of assert: asserts are stripped under `python -O`, so
    # these invariants would silently stop being checked.
    if self.gold and self.gold_step is None:
        raise ValueError('gold_step must be provided when gold is enabled.')
    if self.netD_drs is not None:
        if self.dataloader_drs is None or self.optD_drs is None:
            raise ValueError(
                'dataloader_drs and optD_drs must be provided when '
                'netD_drs is given.')
        self.train_drs = True
    else:
        self.train_drs = False

    # Input values checks
    ints_to_check = {
        'num_steps': num_steps,
        'n_dis': n_dis,
        'print_steps': print_steps,
        'vis_steps': vis_steps,
        'log_steps': log_steps,
        'save_steps': save_steps,
        'flush_secs': flush_secs
    }
    for name, var in ints_to_check.items():
        if var < 1:
            raise ValueError('{} must be at least 1 but got {}.'.format(
                name, var))

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(self.log_dir, exist_ok=True)

    # Resolve the hosting device *before* constructing helper objects.
    # Previously the Logger was built with device=None whenever no device
    # was passed in, because the fallback ran at the end of __init__.
    if not self.device:
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else "cpu")

    # Ensure model and data are in the same device
    for net in [self.netD, self.netG, self.netD_drs]:
        if net is not None and net.device != self.device:
            net.to(self.device)

    # Training helper objects. dataset_size is len(dataloader) — presumably
    # the number of batches; confirm against Logger's expectation.
    self.logger = logger.Logger(log_dir=self.log_dir,
                                num_steps=self.num_steps,
                                dataset_size=len(self.dataloader),
                                flush_secs=flush_secs,
                                device=self.device)
    self.scheduler = DRS_LRScheduler(
        lr_decay=self.lr_decay,
        optimizers=[
            o for o in [self.optD, self.optD_drs, self.optG]
            if o is not None
        ],
        num_steps=self.num_steps)

    self.netG_ckpt_dir = os.path.join(self.log_dir, 'checkpoints', 'netG')
    self.netD_ckpt_dir = os.path.join(self.log_dir, 'checkpoints', 'netD')
    self.netD_drs_ckpt_dir = (os.path.join(self.log_dir, 'checkpoints',
                                           'netD_drs')
                              if self.train_drs else None)