Esempio n. 1
0
    def setup(self):
        """Create a CPU ``Logger`` under ``./test_log`` and list scalar names.

        ``self.scalars`` holds the names later consulted (presumably by the
        test cases — confirm) when checking what the logger records.
        """
        # The log directory sits beside this source file.
        this_dir = os.path.dirname(os.path.abspath(__file__))
        self.log_dir = os.path.join(this_dir, "test_log")

        logger_kwargs = dict(
            log_dir=self.log_dir,
            num_steps=100,
            dataset_size=50000,
            flush_secs=30,
            device=torch.device('cpu'),
        )
        self.logger = logger.Logger(**logger_kwargs)

        self.scalars = ['errG', 'errD', 'D(x)', 'D(G(z))', 'img', 'lr_D',
                        'lr_G']
Esempio n. 2
0
    def __init__(self,
                 netD,
                 netG,
                 optD,
                 optG,
                 dataloader,
                 num_steps,
                 n_gpus,
                 log_dir='./log',
                 n_dis=1,
                 netG_ckpt_file=None,
                 netD_ckpt_file=None,
                 lr_decay=None,
                 device=None,
                 **kwargs):
        """Set up models, checkpoint paths, logging and LR scheduling.

        Args:
            netD: Discriminator model; moved to the resolved device if its
                ``device`` attribute differs.
            netG: Generator model; moved likewise.
            optD: Optimizer for ``netD``; forwarded to the LR scheduler.
            optG: Optimizer for ``netG``; forwarded to the LR scheduler.
            dataloader: Training data loader. ``len(dataloader)`` and
                ``dataloader.batch_size`` are read here.
            num_steps: Total number of training steps.
            n_gpus: Accepted for interface compatibility; unused here.
            log_dir: Root directory for logs and checkpoints; created if
                missing.
            n_dis: Stored as ``self.n_dis``. Presumably the number of
                discriminator updates per generator update; confirm in the
                training loop.
            netG_ckpt_file: Explicit generator checkpoint path. If falsy, the
                latest checkpoint under ``<log_dir>/checkpoints/netG`` is used
                (may be None).
            netD_ckpt_file: Same as ``netG_ckpt_file``, for the discriminator.
            lr_decay: Forwarded to ``scheduler.LRScheduler``.
            device: Torch device. If falsy, ``cuda:0`` is used when CUDA is
                available, otherwise CPU.
            **kwargs: Optional overrides for ``print_steps``, ``vis_steps``,
                ``flush_secs``, ``log_steps``, ``save_steps`` and
                ``save_when_end``. Any other key is ignored.
        """
        self.netD = netD
        self.netG = netG
        self.optD = optD
        self.optG = optG
        self.n_dis = n_dis
        self.lr_decay = lr_decay
        self.dataloader = dataloader
        self.num_steps = num_steps
        self.device = device

        # Log directory. exist_ok=True removes the race between checking for
        # the directory and creating it.
        self.log_dir = log_dir
        os.makedirs(self.log_dir, exist_ok=True)

        # Obtain custom or latest checkpoint files
        if netG_ckpt_file:
            self.netG_ckpt_dir = os.path.dirname(netG_ckpt_file)
            self.netG_ckpt_file = netG_ckpt_file
        else:
            self.netG_ckpt_dir = os.path.join(self.log_dir, 'checkpoints',
                                              'netG')
            self.netG_ckpt_file = self._get_latest_checkpoint(
                self.netG_ckpt_dir)  # can be None

        if netD_ckpt_file:
            self.netD_ckpt_dir = os.path.dirname(netD_ckpt_file)
            self.netD_ckpt_file = netD_ckpt_file
        else:
            self.netD_ckpt_dir = os.path.join(self.log_dir, 'checkpoints',
                                              'netD')
            self.netD_ckpt_file = self._get_latest_checkpoint(
                self.netD_ckpt_dir)  # can be None

        # Default parameters, unless provided by kwargs. Only these keys are
        # read from kwargs; anything else passed in is ignored.
        default_params = {
            'print_steps': kwargs.get('print_steps', 1),
            'vis_steps': kwargs.get('vis_steps', 500),
            'flush_secs': kwargs.get('flush_secs', 30),
            'log_steps': kwargs.get('log_steps', 50),
            'save_steps': kwargs.get('save_steps', 5000),
            'save_when_end': kwargs.get('save_when_end', True),
        }
        for name, value in default_params.items():
            setattr(self, name, value)

        # Hyperparameters for logging experiments
        self.params = {
            'log_dir': self.log_dir,
            'num_steps': self.num_steps,
            'batch_size': self.dataloader.batch_size,
            'n_dis': self.n_dis,
            'lr_decay': self.lr_decay,
            'optD': optD.__repr__(),
            'optG': optG.__repr__(),
        }
        self.params.update(default_params)

        # Log training hyperparameters
        self._log_params(self.params)

        # Device for hosting model and data
        if not self.device:
            self.device = torch.device(
                'cuda:0' if torch.cuda.is_available() else "cpu")

        # Ensure model and data are in the same device
        for net in [self.netD, self.netG]:
            if net.device != self.device:
                net.to(self.device)

        # Training helper objects. Both are built after device resolution, so
        # the logger receives a concrete device.
        self.logger = logger.Logger(log_dir=self.log_dir,
                                    num_steps=self.num_steps,
                                    dataset_size=len(self.dataloader),
                                    flush_secs=self.flush_secs,
                                    device=self.device)

        self.scheduler = scheduler.LRScheduler(lr_decay=self.lr_decay,
                                               optD=self.optD,
                                               optG=self.optG,
                                               num_steps=self.num_steps)
Esempio n. 3
0
    def __init__(
        self,
        netD,
        netG,
        optD,
        optG,
        dataloader,
        num_steps,
        log_dir="./log",
        n_dis=1,
        lr_decay=None,
        device=None,
        netG_ckpt_file=None,
        netD_ckpt_file=None,
        print_steps=1,
        vis_steps=500,
        log_steps=50,
        save_steps=5000,
        flush_secs=30,
    ):
        """Validate settings, then set up models, logging, LR scheduling and checkpoints.

        Args:
            netD: Discriminator model; moved to the resolved device if its
                ``device`` attribute differs.
            netG: Generator model; moved likewise.
            optD: Optimizer for ``netD``; forwarded to the LR scheduler.
            optG: Optimizer for ``netG``; forwarded to the LR scheduler.
            dataloader: Training data loader. ``len(dataloader)`` and
                ``dataloader.batch_size`` are read here.
            num_steps: Total number of training steps (must be >= 1).
            log_dir: Directory for logs and checkpoints; created if missing.
            n_dis: Stored as ``self.n_dis`` (must be >= 1).
            lr_decay: Forwarded to ``scheduler.LRScheduler``.
            device: Torch device. If falsy, ``cuda:0`` is used when CUDA is
                available, otherwise CPU.
            netG_ckpt_file: Explicit generator checkpoint path. If falsy, the
                latest checkpoint under ``<log_dir>/checkpoints/netG`` is used
                (may be None).
            netD_ckpt_file: Same as ``netG_ckpt_file``, under ``.../netD``.
            print_steps, vis_steps, log_steps, save_steps: Step intervals
                stored on the instance (each must be >= 1).
            flush_secs: Passed to the logger (must be >= 1).

        Raises:
            ValueError: If any integer setting listed above is less than 1.
        """
        # Input values checks
        ints_to_check = {
            "num_steps": num_steps,
            "n_dis": n_dis,
            "print_steps": print_steps,
            "vis_steps": vis_steps,
            "log_steps": log_steps,
            "save_steps": save_steps,
            "flush_secs": flush_secs,
        }
        for name, var in ints_to_check.items():
            if var < 1:
                raise ValueError("{} must be at least 1 but got {}.".format(
                    name, var))

        self.netD = netD
        self.netG = netG
        self.optD = optD
        self.optG = optG
        self.n_dis = n_dis
        self.lr_decay = lr_decay
        self.dataloader = dataloader
        self.num_steps = num_steps
        self.device = device
        self.log_dir = log_dir
        self.netG_ckpt_file = netG_ckpt_file
        self.netD_ckpt_file = netD_ckpt_file
        self.print_steps = print_steps
        self.vis_steps = vis_steps
        self.log_steps = log_steps
        self.save_steps = save_steps

        # exist_ok=True removes the race between checking for the directory
        # and creating it.
        os.makedirs(self.log_dir, exist_ok=True)

        # Device for hosting model and data. This is resolved before the
        # logger is built, so the logger gets the same concrete device the
        # models are moved to, never None.
        if not self.device:
            self.device = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")

        # Training helper objects
        self.logger = logger.Logger(
            log_dir=self.log_dir,
            num_steps=self.num_steps,
            dataset_size=len(self.dataloader),
            flush_secs=flush_secs,
            device=self.device,
        )

        self.scheduler = scheduler.LRScheduler(
            lr_decay=self.lr_decay,
            optD=self.optD,
            optG=self.optG,
            num_steps=self.num_steps,
        )

        # Obtain custom or latest checkpoint files
        if self.netG_ckpt_file:
            self.netG_ckpt_dir = os.path.dirname(self.netG_ckpt_file)
        else:
            self.netG_ckpt_dir = os.path.join(self.log_dir, "checkpoints",
                                              "netG")
            self.netG_ckpt_file = self._get_latest_checkpoint(
                self.netG_ckpt_dir)  # can be None

        if self.netD_ckpt_file:
            self.netD_ckpt_dir = os.path.dirname(self.netD_ckpt_file)
        else:
            self.netD_ckpt_dir = os.path.join(self.log_dir, "checkpoints",
                                              "netD")
            self.netD_ckpt_file = self._get_latest_checkpoint(
                self.netD_ckpt_dir)  # can be None

        # Log hyperparameters for experiments
        self.params = {
            "log_dir": self.log_dir,
            "num_steps": self.num_steps,
            "batch_size": self.dataloader.batch_size,
            "n_dis": self.n_dis,
            "lr_decay": self.lr_decay,
            "optD": optD.__repr__(),
            "optG": optG.__repr__(),
            "save_steps": self.save_steps,
        }
        self._log_params(self.params)

        # Ensure model and data are in the same device
        for net in [self.netD, self.netG]:
            if net.device != self.device:
                net.to(self.device)
Esempio n. 4
0
    def __init__(self,
                 netD,
                 netG,
                 optD,
                 optG,
                 dataloader,
                 num_steps,
                 log_dir='./log',
                 n_dis=1,
                 lr_decay=None,
                 device=None,
                 netG_ckpt_file=None,
                 netD_ckpt_file=None,
                 print_steps=1,
                 vis_steps=1000,
                 log_steps=100,
                 save_steps=10000,
                 flush_secs=60):
        """Validate settings, then set up models, logging, LR scheduling and checkpoints.

        Args:
            netD: Discriminator model; moved to the resolved device if its
                ``device`` attribute differs.
            netG: Generator model; moved likewise.
            optD: Optimizer for ``netD``; forwarded to the LR scheduler.
            optG: Optimizer for ``netG``; forwarded to the LR scheduler.
            dataloader: Training data loader. ``len(dataloader)`` and
                ``dataloader.batch_size`` are read here.
            num_steps: Total number of training steps (must be >= 1).
            log_dir: Directory for logs and checkpoints; created if missing.
            n_dis: Stored as ``self.n_dis`` (must be >= 1).
            lr_decay: Forwarded to ``scheduler.LRScheduler``.
            device: Torch device. If falsy, ``cuda:0`` is used when CUDA is
                available, otherwise CPU.
            netG_ckpt_file: Explicit generator checkpoint path. If falsy, the
                latest checkpoint under ``<log_dir>/checkpoints/netG`` is used
                (may be None).
            netD_ckpt_file: Same as ``netG_ckpt_file``, under ``.../netD``.
            print_steps, vis_steps, log_steps, save_steps: Step intervals
                stored on the instance (each must be >= 1).
            flush_secs: Passed to the logger (must be >= 1).

        Raises:
            ValueError: If any integer setting listed above is less than 1.
        """
        # Input values checks
        ints_to_check = {
            'num_steps': num_steps,
            'n_dis': n_dis,
            'print_steps': print_steps,
            'vis_steps': vis_steps,
            'log_steps': log_steps,
            'save_steps': save_steps,
            'flush_secs': flush_secs
        }
        for name, var in ints_to_check.items():
            if var < 1:
                raise ValueError('{} must be at least 1 but got {}.'.format(
                    name, var))

        self.netD = netD
        self.netG = netG
        self.optD = optD
        self.optG = optG
        self.n_dis = n_dis
        self.lr_decay = lr_decay
        self.dataloader = dataloader
        self.num_steps = num_steps
        self.device = device
        self.log_dir = log_dir
        self.netG_ckpt_file = netG_ckpt_file
        self.netD_ckpt_file = netD_ckpt_file
        self.print_steps = print_steps
        self.vis_steps = vis_steps
        self.log_steps = log_steps
        self.save_steps = save_steps

        # exist_ok=True removes the race between checking for the directory
        # and creating it.
        os.makedirs(self.log_dir, exist_ok=True)

        # Device for hosting model and data. This is resolved before the
        # logger is built, so the logger gets the same concrete device the
        # models are moved to, never None.
        if not self.device:
            self.device = torch.device(
                'cuda:0' if torch.cuda.is_available() else "cpu")

        # Training helper objects
        self.logger = logger.Logger(log_dir=self.log_dir,
                                    num_steps=self.num_steps,
                                    dataset_size=len(self.dataloader),
                                    flush_secs=flush_secs,
                                    device=self.device)

        self.scheduler = scheduler.LRScheduler(lr_decay=self.lr_decay,
                                               optD=self.optD,
                                               optG=self.optG,
                                               num_steps=self.num_steps)

        # Obtain custom or latest checkpoint files
        if self.netG_ckpt_file:
            self.netG_ckpt_dir = os.path.dirname(self.netG_ckpt_file)
        else:
            self.netG_ckpt_dir = os.path.join(self.log_dir, 'checkpoints',
                                              'netG')
            self.netG_ckpt_file = self._get_latest_checkpoint(
                self.netG_ckpt_dir)  # can be None

        if self.netD_ckpt_file:
            self.netD_ckpt_dir = os.path.dirname(self.netD_ckpt_file)
        else:
            self.netD_ckpt_dir = os.path.join(self.log_dir, 'checkpoints',
                                              'netD')
            self.netD_ckpt_file = self._get_latest_checkpoint(
                self.netD_ckpt_dir)  # can be None

        # Log hyperparameters for experiments
        self.params = {
            'log_dir': self.log_dir,
            'batch_size': self.dataloader.batch_size,
            'n_dis': self.n_dis,
            'lr_decay': self.lr_decay,
            'optD': optD.__repr__(),
            'optG': optG.__repr__(),
            'save_steps': self.save_steps,
        }
        self._log_params(self.params)

        # Ensure model and data are in the same device
        for net in [self.netD, self.netG]:
            if net.device != self.device:
                net.to(self.device)
Esempio n. 5
0
    def __init__(
        self,
        output_path,
        netD,
        netG,
        optD,
        optG,
        dataloader,
        num_steps,
        netD_drs=None,
        optD_drs=None,
        dataloader_drs=None,
        netD_drs_ckpt_file=None,
        log_dir='./log',
        n_dis=1,
        lr_decay=None,
        device=None,
        netG_ckpt_file=None,
        netD_ckpt_file=None,
        print_steps=1,
        vis_steps=500,
        log_steps=50,
        save_steps=5000,
        flush_secs=30,
        logit_save_steps=500,
        amp=False,
        save_logits=True,
        topk=False,
        gold=False,
        gold_step=None,
        save_logit_after=0,
        stop_save_logit_after=100000,
        save_eval_logits=True,
    ):
        """Store training configuration and build the logger, scheduler and checkpoint dirs.

        Args:
            output_path: Stored as ``self.output_path``.
            netD, netG: Discriminator / generator models. Each is moved to the
                resolved device if its ``device`` attribute differs.
            optD, optG: Optimizers, passed to ``DRS_LRScheduler``.
            dataloader: Training data loader. ``len(dataloader)`` is passed
                to the logger.
            num_steps: Total number of training steps (must be >= 1).
            netD_drs, optD_drs, dataloader_drs: Optional extra DRS
                discriminator, its optimizer and its data loader. DRS is
                presumably discriminator rejection sampling; confirm. If
                ``netD_drs`` is given, the other two are required and
                ``self.train_drs`` is True.
            netD_drs_ckpt_file: Stored as ``self.netD_drs_ckpt_file``.
            log_dir: Directory for logs and checkpoints; created if missing.
            n_dis: Stored as ``self.n_dis`` (must be >= 1).
            lr_decay: Forwarded to ``DRS_LRScheduler``.
            device: Torch device. If falsy, ``cuda:0`` is used when CUDA is
                available, otherwise CPU.
            netG_ckpt_file, netD_ckpt_file: Stored as given. Checkpoint
                directories are always ``<log_dir>/checkpoints/{netG,netD}``.
            print_steps, vis_steps, log_steps, save_steps: Step intervals
                stored on the instance (each must be >= 1).
            flush_secs: Passed to the logger (must be >= 1).
            logit_save_steps, amp, save_logits, topk, gold, save_logit_after,
            stop_save_logit_after, save_eval_logits: Stored on the instance
                unchanged.
            gold_step: Stored on the instance; required when ``gold`` is
                truthy.

        Raises:
            AssertionError: If ``gold`` is set without ``gold_step``, or if
                ``netD_drs`` is given without both ``dataloader_drs`` and
                ``optD_drs``.
            ValueError: If any integer setting checked below is less than 1.
        """
        self.output_path = output_path
        self.logit_save_steps = logit_save_steps

        self.netD = netD
        self.netG = netG
        self.optD = optD
        self.optG = optG
        self.n_dis = n_dis
        self.lr_decay = lr_decay
        self.dataloader = dataloader
        self.num_steps = num_steps
        self.device = device
        self.log_dir = log_dir
        self.netG_ckpt_file = netG_ckpt_file
        self.netD_ckpt_file = netD_ckpt_file
        self.print_steps = print_steps
        self.vis_steps = vis_steps
        self.log_steps = log_steps
        self.save_steps = save_steps
        self.amp = amp
        self.save_logits = save_logits
        self.save_logit_after = save_logit_after
        self.stop_save_logit_after = stop_save_logit_after
        self.save_eval_logits = save_eval_logits

        # for DRS
        self.netD_drs = netD_drs
        self.dataloader_drs = dataloader_drs
        self.optD_drs = optD_drs
        self.netD_drs_ckpt_file = netD_drs_ckpt_file

        self.topk = topk
        self.gold = gold
        self.gold_step = gold_step

        if self.gold:
            assert self.gold_step is not None, \
                'gold_step must be provided when gold is enabled.'

        if self.netD_drs is not None:
            assert self.dataloader_drs is not None and self.optD_drs is not None, \
                'dataloader_drs and optD_drs are required when netD_drs is given.'
            self.train_drs = True
        else:
            self.train_drs = False

        # Input values checks
        ints_to_check = {
            'num_steps': num_steps,
            'n_dis': n_dis,
            'print_steps': print_steps,
            'vis_steps': vis_steps,
            'log_steps': log_steps,
            'save_steps': save_steps,
            'flush_secs': flush_secs
        }
        for name, var in ints_to_check.items():
            if var < 1:
                raise ValueError('{} must be at least 1 but got {}.'.format(
                    name, var))

        # exist_ok=True removes the race between checking for the directory
        # and creating it.
        os.makedirs(self.log_dir, exist_ok=True)

        # Device for hosting model and data. This is resolved before the
        # logger is built, so the logger gets the same concrete device the
        # models are moved to, never None.
        if not self.device:
            self.device = torch.device(
                'cuda:0' if torch.cuda.is_available() else "cpu")

        # Training helper objects
        self.logger = logger.Logger(log_dir=self.log_dir,
                                    num_steps=self.num_steps,
                                    dataset_size=len(self.dataloader),
                                    flush_secs=flush_secs,
                                    device=self.device)

        # Only the optimizers that exist are scheduled; optD_drs may be None.
        self.scheduler = DRS_LRScheduler(
            lr_decay=self.lr_decay,
            optimizers=[o for o in [self.optD, self.optD_drs, self.optG] if o is not None],
            num_steps=self.num_steps)

        self.netG_ckpt_dir = os.path.join(self.log_dir, 'checkpoints', 'netG')
        self.netD_ckpt_dir = os.path.join(self.log_dir, 'checkpoints', 'netD')
        self.netD_drs_ckpt_dir = os.path.join(self.log_dir, 'checkpoints', 'netD_drs') if self.train_drs else None

        # Ensure model and data are in the same device (netD_drs is optional)
        for net in [self.netD, self.netG, self.netD_drs]:
            if net is not None and net.device != self.device:
                net.to(self.device)