Exemplo n.º 1
0
    def test_train_steps(self):
        """One D step and one G step should both run and log float metrics.

        Loads a small real batch, performs a single discriminator and a
        single generator training step on CPU, then checks every logged
        metric is a named float.
        """
        real_batch = common.load_images(self.N, size=self.H)

        # Setup optimizers
        optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
        optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))

        # Log statistics to check
        log_data = metric_log.MetricLog()

        # Test D train step
        log_data = self.netD.train_step(real_batch=real_batch,
                                        netG=self.netG,
                                        optD=optD,
                                        device='cpu',
                                        log_data=log_data)

        log_data = self.netG.train_step(real_batch=real_batch,
                                        netD=self.netD,
                                        optG=optG,
                                        log_data=log_data,
                                        device='cpu')

        # isinstance is the idiomatic type check; `type(x) == T` rejects
        # subclasses (e.g. numpy floats) and is flagged by linters.
        for name, metric_dict in log_data.items():
            assert isinstance(name, str)
            assert isinstance(metric_dict['value'], float)
Exemplo n.º 2
0
    def test_no_decay(self):
        """Learning rates must stay at their initial values when decay is off."""
        opt_dis = optim.Adam(self.netD.parameters(), self.lr_D, betas=(0.0, 0.9))
        opt_gen = optim.Adam(self.netG.parameters(), self.lr_G, betas=(0.0, 0.9))

        sched = scheduler.LRScheduler(lr_decay='None',
                                      optD=opt_dis,
                                      optG=opt_gen,
                                      num_steps=self.num_steps)

        log_data = metric_log.MetricLog()
        step = 1
        while step <= self.num_steps:
            sched.step(log_data, step)

            # With decay disabled, the LR read back from each optimizer
            # must equal the one it was constructed with.
            assert self.get_lr(opt_dis) == self.lr_D
            assert self.get_lr(opt_gen) == self.lr_G
            step += 1
Exemplo n.º 3
0
    def test_linear_decay(self):
        """LR should decay linearly towards zero over ``num_steps``.

        Steps the scheduler through the whole schedule and checks the
        optimizer LRs track the closed-form linear decay.
        """
        optD = optim.Adam(self.netD.parameters(), self.lr_D, betas=(0.0, 0.9))
        optG = optim.Adam(self.netG.parameters(), self.lr_G, betas=(0.0, 0.9))

        lr_scheduler = scheduler.LRScheduler(lr_decay='linear',
                                             optD=optD,
                                             optG=optG,
                                             num_steps=self.num_steps)

        log_data = metric_log.MetricLog()
        for step in range(1, self.num_steps + 1):
            lr_scheduler.step(log_data, step)

            # NOTE(review): the expected value uses lr_D for both
            # optimizers — verify lr_G == lr_D in the fixture.
            curr_lr = ((1 - step / self.num_steps) * self.lr_D)

            # Bug fix: the original compared the *signed* difference
            # (`curr_lr - lr < 1e-5`), which passes vacuously whenever the
            # actual LR exceeds the expected one. Compare absolute error.
            assert abs(curr_lr - self.get_lr(optD)) < 1e-5
            assert abs(curr_lr - self.get_lr(optG)) < 1e-5
Exemplo n.º 4
0
    def test_print_log(self):
        """print_log must format every scalar metric into one summary line."""
        log_data = metric_log.MetricLog()
        global_step = 10

        # Fill each printable scalar with a dummy value ('img' is an image
        # metric and has no scalar representation).
        for scalar in self.scalars:
            if scalar != 'img':
                log_data.add_metric(scalar, 1.0)

        printed = self.logger.print_log(global_step=global_step,
                                        log_data=log_data,
                                        time_taken=10)

        expected = ('INFO: [Epoch 1/1][Global Step: 10/100] '
                    '\n| D(G(z)): 1.0\n| D(x): 1.0\n| errD: 1.0\n| errG: 1.0'
                    '\n| lr_D: 1.0\n| lr_G: 1.0\n| (10.0000 sec/idx)')
        assert printed == expected
Exemplo n.º 5
0
    def train(self):
        """
        Runs the training pipeline with all given parameters in Trainer.

        Resumes from the last checkpoint, then loops until
        ``self.num_steps`` global steps have been taken. Each global step
        performs ``self.n_dis`` discriminator updates and one generator
        update, then handles scheduling, logging, visualisation and
        checkpointing at their configured intervals. A KeyboardInterrupt
        saves a checkpoint before exiting; summary writers are always
        closed.
        """
        # Restore models
        global_step = self._restore_models_and_step()
        print("INFO: Starting training from global step {}...".format(
            global_step))

        try:
            start_time = time.time()

            # Iterate through data
            iter_dataloader = iter(self.dataloader)
            while global_step < self.num_steps:
                log_data = metric_log.MetricLog()  # log data for tensorboard

                # -------------------------
                #   One Training Step
                # -------------------------
                # Update n_dis times for D
                for i in range(self.n_dis):
                    iter_dataloader, real_batch = self._fetch_data(
                        iter_dataloader=iter_dataloader)

                    # -----------------------
                    #   Update G Network
                    # -----------------------
                    # Update G, but only once.
                    # NOTE(review): G is updated on the FIRST of the n_dis
                    # iterations (i == 0), i.e. before any D update in this
                    # step; other trainers in this file update G on the
                    # LAST iteration — confirm which ordering is intended.
                    if i == 0:
                        log_data = self.netG.train_step(
                            real_batch=real_batch,
                            netD=self.netD,
                            optG=self.optG,
                            global_step=global_step,
                            log_data=log_data,
                            device=self.device)

                    # ------------------------
                    #   Update D Network
                    # -----------------------
                    log_data = self.netD.train_step(real_batch=real_batch,
                                                    netG=self.netG,
                                                    optD=self.optD,
                                                    log_data=log_data,
                                                    global_step=global_step,
                                                    device=self.device)

                # --------------------------------
                #   Update Training Variables
                # -------------------------------
                global_step += 1

                # Scheduler may decay the LRs and records them in log_data.
                log_data = self.scheduler.step(log_data=log_data,
                                               global_step=global_step)

                # -------------------------
                #   Logging and Metrics
                # -------------------------
                if global_step % self.log_steps == 0:
                    self.logger.write_summaries(log_data=log_data,
                                                global_step=global_step)

                if global_step % self.print_steps == 0:
                    curr_time = time.time()
                    # time_taken is the average wall time per step over the
                    # interval since the last print.
                    self.logger.print_log(global_step=global_step,
                                          log_data=log_data,
                                          time_taken=(curr_time - start_time) /
                                          self.print_steps)
                    start_time = curr_time

                if global_step % self.vis_steps == 0:
                    self.logger.vis_images(netG=self.netG,
                                           global_step=global_step)

                if global_step % self.save_steps == 0:
                    print("INFO: Saving checkpoints...")
                    self.netG.save_checkpoint(directory=self.netG_ckpt_dir,
                                              global_step=global_step,
                                              optimizer=self.optG)

                    self.netD.save_checkpoint(directory=self.netD_ckpt_dir,
                                              global_step=global_step,
                                              optimizer=self.optD)

            # Save models at the very end of training
            if self.save_when_end:
                print("INFO: Saving final checkpoints...")
                self.netG.save_checkpoint(directory=self.netG_ckpt_dir,
                                          global_step=global_step,
                                          optimizer=self.optG)

                self.netD.save_checkpoint(directory=self.netD_ckpt_dir,
                                          global_step=global_step,
                                          optimizer=self.optD)

        except KeyboardInterrupt:
            print("INFO: Saving checkpoints from keyboard interrupt...")
            self.netG.save_checkpoint(directory=self.netG_ckpt_dir,
                                      global_step=global_step,
                                      optimizer=self.optG)

            self.netD.save_checkpoint(directory=self.netD_ckpt_dir,
                                      global_step=global_step,
                                      optimizer=self.optD)

        finally:
            self.logger.close_writers()

        print("INFO: Training Ended.")
Exemplo n.º 6
0
    def train(self, upload_path=None):
        """
        Runs the training pipeline with all given parameters in Trainer.

        Args:
            upload_path (str): Optional rsync destination. When given, the
                log directory is mirrored there after each periodic
                checkpoint save.
        """
        # Restore models
        global_step = self._restore_models_and_step()
        print("INFO: Starting training from global step {}...".format(
            global_step))

        try:
            start_time = time.time()

            # Iterate through data
            iter_dataloader = iter(self.dataloader)
            while global_step < self.num_steps:
                log_data = metric_log.MetricLog()  # log data for tensorboard

                # -------------------------
                #   One Training Step
                # -------------------------
                # Update n_dis times for D
                for i in range(self.n_dis):
                    iter_dataloader, real_batch = self._fetch_data(
                        iter_dataloader=iter_dataloader)

                    # ------------------------
                    #   Update D Network
                    # -----------------------
                    log_data = self.netD.train_step(
                        real_batch=real_batch,
                        netG=self.netG,
                        optD=self.optD,
                        log_data=log_data,
                        global_step=global_step,
                        device=self.device,
                    )

                    # -----------------------
                    #   Update G Network
                    # -----------------------
                    # Update G, but only once (on the last D iteration).
                    if i == (self.n_dis - 1):
                        log_data = self.netG.train_step(
                            real_batch=real_batch,
                            netD=self.netD,
                            optG=self.optG,
                            global_step=global_step,
                            log_data=log_data,
                            device=self.device,
                        )

                # --------------------------------
                #   Update Training Variables
                # -------------------------------
                global_step += 1

                log_data = self.scheduler.step(log_data=log_data,
                                               global_step=global_step)

                # -------------------------
                #   Logging and Metrics
                # -------------------------
                if global_step % self.log_steps == 0:
                    self.logger.write_summaries(log_data=log_data,
                                                global_step=global_step)

                if global_step % self.print_steps == 0:
                    curr_time = time.time()
                    self.logger.print_log(
                        global_step=global_step,
                        log_data=log_data,
                        time_taken=(curr_time - start_time) / self.print_steps,
                    )
                    start_time = curr_time

                if global_step % self.vis_steps == 0:
                    self.logger.vis_images(netG=self.netG,
                                           global_step=global_step)

                if global_step % self.save_steps == 0:
                    print("INFO: Saving checkpoints...")
                    self._save_model_checkpoints(global_step)
                    if upload_path is not None:
                        # Pass an argument list with shell=False so paths
                        # containing spaces or quotes cannot break the
                        # command line or be interpreted by a shell.
                        result = subprocess.call(
                            ["rsync", "--update", "-arq",
                             self.log_dir, upload_path])
                        # Bug fix: the original used `is 0` / `is 1`, which
                        # compares object identity rather than value (a
                        # SyntaxWarning on Python >= 3.8). Also report ANY
                        # non-zero rsync exit code as a failure instead of
                        # only code 1.
                        if result == 0:
                            print("INFO: Upload checkpoints SUCCESSFUL")
                        else:
                            print("INFO: Upload checkpoints FAILED")

            print("INFO: Saving final checkpoints...")
            self._save_model_checkpoints(global_step)

        except KeyboardInterrupt:
            print("INFO: Saving checkpoints from keyboard interrupt...")
            self._save_model_checkpoints(global_step)

        finally:
            self.logger.close_writers()

        print("INFO: Training Ended.")
Exemplo n.º 7
0
    def train(self):
        """
        Runs the training pipeline with all given parameters in Trainer.

        This variant updates D via ``advtrain_step`` ("DAT" — adversarial
        training of the discriminator) instead of the standard
        ``train_step``, and additionally computes FID, IS and KID
        summaries at every visualisation interval.
        """
        # Restore models
        global_step = self._restore_models_and_step()
        print("INFO: Starting training from global step {}...".format(
            global_step))

        try:
            start_time = time.time()

            # Iterate through data
            iter_dataloader = iter(self.dataloader)
            while global_step < self.num_steps:
                log_data = metric_log.MetricLog()  # log data for tensorboard

                # -------------------------
                #   One Training Step
                # -------------------------
                # Update n_dis times for D

                #R=np.random.randint(18,22,1)

                for i in range(self.n_dis):
                    iter_dataloader, real_batch = self._fetch_data(
                        iter_dataloader=iter_dataloader)

                    # ------------------------
                    #   Update D Network
                    # -----------------------
                    '''
                    normal training
                    log_data = self.netD.train_step(real_batch=real_batch,
                                                    netG=self.netG,
                                                    optD=self.optD,
                                                    log_data=log_data,
                                                    global_step=global_step,
                                                    #radius=R,
                                                    device=self.device)
                    '''
                    #DAT
                    # NOTE(review): the disabled block above is the
                    # non-adversarial alternative; this trainer uses the
                    # adversarial D update instead.
                    log_data = self.netD.advtrain_step(real_batch=real_batch,
                                                       netG=self.netG,
                                                       optD=self.optD,
                                                       log_data=log_data,
                                                       global_step=global_step,
                                                       device=self.device)

                    # -----------------------
                    #   Update G Network
                    # -----------------------
                    # Update G, but only once.
                    if i == (self.n_dis - 1):
                        log_data = self.netG.train_step(
                            real_batch=real_batch,
                            netD=self.netD,
                            optG=self.optG,
                            global_step=global_step,
                            log_data=log_data,
                            #radius=R,
                            device=self.device)

                # --------------------------------
                #   Update Training Variables
                # -------------------------------
                global_step += 1

                log_data = self.scheduler.step(log_data=log_data,
                                               global_step=global_step)

                # -------------------------
                #   Logging and Metrics
                # -------------------------
                if global_step % self.log_steps == 0:
                    self.logger.write_summaries(log_data=log_data,
                                                global_step=global_step)

                if global_step % self.print_steps == 0:
                    curr_time = time.time()
                    # time_taken is the average wall time per step over the
                    # interval since the last print.
                    self.logger.print_log(global_step=global_step,
                                          log_data=log_data,
                                          time_taken=(curr_time - start_time) /
                                          self.print_steps)
                    start_time = curr_time

                if global_step % self.vis_steps == 0:
                    self.logger.vis_images(netG=self.netG,
                                           global_step=global_step)

                    # Evaluation metrics are computed on the same cadence
                    # as image visualisation.
                    self.logger.summary_fid(netG=self.netG,
                                            dataset=self.dataset,
                                            global_step=global_step)
                    self.logger.summary_IS(netG=self.netG,
                                           global_step=global_step)
                    self.logger.summary_KID(netG=self.netG,
                                            dataset=self.dataset,
                                            global_step=global_step)

                if global_step % self.save_steps == 0:
                    print("INFO: Saving checkpoints...")
                    self._save_model_checkpoints(global_step)

            print("INFO: Saving final checkpoints...")
            self._save_model_checkpoints(global_step)

        except KeyboardInterrupt:
            print("INFO: Saving checkpoints from keyboard interrupt...")
            self._save_model_checkpoints(global_step)

        finally:
            self.logger.close_writers()

        print("INFO: Training Ended.")
Exemplo n.º 8
0
    def train(self):
        """
        Runs the training pipeline with all given parameters in Trainer.

        Optional features, each driven by a flag on ``self``:
          - ``self.amp``: mixed-precision training via a
            ``torch.cuda.amp.GradScaler`` passed into every train step.
          - ``self.train_drs``: trains an auxiliary discriminator
            (``self.netD_drs``) on its own dataloader, e.g. for
            discriminator rejection sampling.
          - ``self.topk``: decays the generator's top-k rate each step.
          - ``self.gold``: switches ``netD.use_gold`` on once
            ``self.gold_step`` is reached (or immediately on resume past it).
          - ``self.save_logits``: periodically collects discriminator
            logits into ``self.logit_results`` and persists them alongside
            checkpoints.
        """
        # Restore models
        global_step = self._restore_models_and_step()

        # If we resumed past the GOLD activation step, enable it right away.
        if self.gold and global_step >= self.gold_step:
            self.netD.use_gold = True

        print("INFO: Starting training from global step {}...".format(
            global_step))
        logit_save_num = 0

        # Maps '<netD name>_<mode>' -> {global_step: [logits]}.
        self.logit_results = defaultdict(dict)

        try:
            start_time = time.time()

            # Mixed precision
            if self.amp:
                print("INFO: Using mixed precision training...")
                scaler = torch.cuda.amp.GradScaler()
            else:
                scaler = None

            # Iterate through data
            iter_dataloader = iter(self.dataloader)
            if self.train_drs:
                iter_dataloader_drs = iter(self.dataloader_drs)
            while global_step < self.num_steps:
                log_data = metric_log.MetricLog()  # log data for tensorboard

                if self.topk:
                    self.netG.decay_topk_rate(global_step, epoch_steps=len(self.dataloader))

                # Enable GOLD exactly when the activation step is reached.
                if self.gold and global_step == self.gold_step:
                    self.netD.use_gold = True
                # -------------------------
                #   One Training Step
                # -------------------------
                # Update n_dis times for D
                for i in range(self.n_dis):
                    iter_dataloader, real_batch = self._fetch_data(
                        iter_dataloader=iter_dataloader)

                    # ------------------------
                    #   Update D Network
                    # -----------------------
                    log_data = self.netD.train_step(
                        real_batch=real_batch,
                        netG=self.netG,
                        optD=self.optD,
                        log_data=log_data,
                        global_step=global_step,
                        device=self.device,
                        scaler=scaler)

                    # train netD2 for DRS (on its own data stream)
                    if self.train_drs:
                        iter_dataloader_drs, real_batch_drs = self._fetch_data(
                            iter_dataloader=iter_dataloader_drs)
                        log_data = self.netD_drs.train_step(
                            real_batch=real_batch_drs,
                            netG=self.netG,
                            optD=self.optD_drs,
                            log_data=log_data,
                            global_step=global_step,
                            device=self.device,
                            scaler=scaler)

                    # -----------------------
                    #   Update G Network
                    # -----------------------
                    # Update G, but only once (on the last D iteration).
                    if i == (self.n_dis - 1):
                        log_data = self.netG.train_step(
                            real_batch=real_batch,
                            netD=self.netD,
                            optG=self.optG,
                            global_step=global_step,
                            log_data=log_data,
                            device=self.device,
                            scaler=scaler)

                # --------------------------------
                #   Update Training Variables
                # -------------------------------
                global_step += 1

                log_data = self.scheduler.step(log_data=log_data,
                                               global_step=global_step)

                # -------------------------
                #   Logging and Metrics
                # -------------------------
                if global_step % self.log_steps == 0:
                    self.logger.write_summaries(log_data=log_data,
                                                global_step=global_step)

                if global_step % self.print_steps == 0:
                    curr_time = time.time()
                    # Default to 1 (no filtering) when G has no top-k rate.
                    topk_rate = self.netG.topk_rate if hasattr(self.netG, 'topk_rate') else 1
                    # Fix: plain string literal — the original used an
                    # f-string with no placeholders (ruff F541).
                    log_data.add_metric('topk_rate', topk_rate, group='topk_rate', precision=6)
                    self.logger.print_log(global_step=global_step,
                                          log_data=log_data,
                                          time_taken=(curr_time - start_time) /
                                          self.print_steps)
                    start_time = curr_time

                if global_step % self.vis_steps == 0:
                    if 'gaussian' in self.log_dir:
                        plot_gaussian_samples(netG=self.netG,
                                              global_step=global_step,
                                              log_dir=self.log_dir,
                                              device=self.device)
                    else:
                        self.logger.vis_images(netG=self.netG,
                                               global_step=global_step)

                # Collect D logits within the configured step window.
                if self.save_logits and global_step % self.logit_save_steps == 0 and global_step >= self.save_logit_after and global_step <= self.stop_save_logit_after:
                    # Prefer the DRS discriminator's logits when it exists.
                    if self.train_drs:
                        netD = self.netD_drs
                        netD_name = 'netD_drs'
                    else:
                        netD = self.netD
                        netD_name = 'netD'
                    mode = 'eval' if self.save_eval_logits else 'train'
                    print(f"INFO: logit saving {mode} netD: {netD_name}...")
                    logit_list = self._get_logit(netD=netD, eval_mode=mode=='eval')
                    self.logit_results[f'{netD_name}_{mode}'][global_step] = logit_list

                    logit_save_num += 1

                if global_step % self.save_steps == 0:
                    print("INFO: Saving checkpoints...")
                    self._save_model_checkpoints(global_step)
                    if self.save_logits and global_step >= self.save_logit_after:
                        self._save_logit(self.logit_results)

            print("INFO: Saving final checkpoints...")
            self._save_model_checkpoints(global_step)
            if self.save_logits and global_step >= self.save_logit_after:
                self._save_logit(self.logit_results)

        except KeyboardInterrupt:
            print("INFO: Saving checkpoints from keyboard interrupt...")
            self._save_model_checkpoints(global_step)
            if self.save_logits and global_step >= self.save_logit_after:
                self._save_logit(self.logit_results)

        finally:
            self.logger.close_writers()

        print("INFO: Training Ended.")
Exemplo n.º 9
0
 def setup(self):
     """Create a fresh MetricLog for each test (test-framework setup hook)."""
     self.log_data = metric_log.MetricLog()