# Module-level imports assumed by the excerpts below. Project-specific helpers
# (get_dataloader, balanced_loader, get_tqdm_config, make_epoch_description)
# are referenced throughout but defined elsewhere in the repository.
import collections
import os

import torch
import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def run(self,
        train_set,
        valid_set,
        epochs: int,
        batch_size: int,
        num_workers: int = 0,
        device: str = 'cuda',
        **kwargs):  # pylint: disable=unused-argument

    assert isinstance(train_set, torch.utils.data.Dataset)
    assert isinstance(valid_set, torch.utils.data.Dataset)
    assert isinstance(epochs, int)
    assert isinstance(batch_size, int)
    assert isinstance(num_workers, int)
    assert device.startswith('cuda') or device == 'cpu'

    logger = kwargs.get('logger', None)

    self.backbone = self.backbone.to(device)
    self.projector = self.projector.to(device)

    train_loader = get_dataloader(train_set, batch_size, num_workers=num_workers)
    valid_loader = get_dataloader(valid_set, batch_size, num_workers=num_workers)

    with tqdm.tqdm(**get_tqdm_config(total=epochs, leave=True, color='blue')) as pbar:

        best_valid_loss = float('inf')
        best_epoch = 0

        for epoch in range(1, epochs + 1):

            # 0. Train & evaluate
            train_history = self.train(train_loader, device=device)
            valid_history = self.evaluate(valid_loader, device=device)

            # 1. Epoch history (loss)
            epoch_history = {
                'loss': {
                    'train': train_history.get('loss'),
                    'valid': valid_history.get('loss'),
                }
            }

            # 2. Epoch history (other metrics if provided)
            if self.metrics is not None:
                raise NotImplementedError

            # 3. TensorBoard
            if self.writer is not None:
                for metric_name, metric_dict in epoch_history.items():
                    self.writer.add_scalars(
                        main_tag=metric_name,
                        tag_scalar_dict=metric_dict,
                        global_step=epoch
                    )
                if self.scheduler is not None:
                    self.writer.add_scalar(
                        tag='lr',
                        scalar_value=self.scheduler.get_last_lr()[0],
                        global_step=epoch
                    )

            # 4. Save model if it is the current best
            valid_loss = epoch_history['loss']['valid']
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                best_epoch = epoch
                self.save_checkpoint(self.best_ckpt, epoch=epoch, **epoch_history)
            if kwargs.get('save_every', False):
                new_ckpt = os.path.join(self.checkpoint_dir,
                                        f'epoch_{epoch:04d}.loss_{valid_loss:.4f}.pt')
                self.save_checkpoint(new_ckpt, epoch=epoch, **epoch_history)

            # 5. Update learning rate scheduler
            if self.scheduler is not None:
                self.scheduler.step()

            # 6. Logging
            desc = make_epoch_description(
                history=epoch_history,
                current=epoch,
                total=epochs,
                best=best_epoch
            )
            pbar.set_description_str(desc)
            pbar.update(1)
            if logger is not None:
                logger.info(desc)

    # 7. Save last model
    self.save_checkpoint(self.last_ckpt, epoch=epoch, **epoch_history)

    # 8. Test model (optional)
    if 'test_set' in kwargs:
        test_loader = get_dataloader(kwargs.get('test_set'),
                                     batch_size=batch_size,
                                     num_workers=num_workers)
        self.test(test_loader, device=device, logger=logger)
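# NOTE: `make_epoch_description` is called by every trainer in this excerpt but
# defined elsewhere in the repository. Below is a minimal sketch of what it
# might look like, assuming `history` maps metric names to per-split dicts
# ({'train': ..., 'valid': ...} or {'train': ..., 'eval': ...}) as built in the
# loops above; the real helper may format its output differently.
def make_epoch_description(history: dict, current: int, total: int, best: int) -> str:
    """Hypothetical helper: renders one epoch's metrics as a single log line."""
    parts = [f"Epoch {current:>3d}/{total:>3d} (best: {best:>3d})"]
    for metric_name, splits in history.items():
        rendered = ", ".join(f"{split}: {value:.4f}"
                             for split, value in splits.items()
                             if value is not None)
        parts.append(f"{metric_name} [{rendered}]")
    return " | ".join(parts)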
def run(self,
        train_set,
        valid_set,
        epochs: int,
        batch_size: int,
        num_workers: int = 0,
        **kwargs):
    """Train, evaluate and optionally test."""

    logger = kwargs.get('logger', None)

    self.backbone.to(self.local_rank)
    self.classifier.to(self.local_rank)

    train_loader = balanced_loader(train_set, batch_size, num_workers=num_workers,
                                   shuffle=False, pin_memory=False)
    valid_loader = DataLoader(valid_set, batch_size, num_workers=num_workers,
                              shuffle=True, drop_last=False, pin_memory=False)

    with tqdm.tqdm(**get_tqdm_config(total=epochs, leave=True, color='blue')) as pbar:

        best_valid_loss = float('inf')
        best_epoch = 0

        for epoch in range(1, epochs + 1):

            # 0. Train & evaluate
            train_history = self.train(train_loader)
            valid_history = self.evaluate(valid_loader)

            # 1. Epoch history (loss)
            epoch_history = {
                'loss': {
                    'train': train_history.get('loss'),
                    'valid': valid_history.get('loss'),
                }
            }

            # 2. Epoch history (other metrics if provided)
            if isinstance(self.metrics, dict):
                for metric_name in self.metrics:
                    epoch_history[metric_name] = {
                        'train': train_history[metric_name],
                        'valid': valid_history[metric_name],
                    }

            # 3. TensorBoard
            if self.writer is not None:
                for metric_name, metric_dict in epoch_history.items():
                    self.writer.add_scalars(main_tag=metric_name,
                                            tag_scalar_dict=metric_dict,
                                            global_step=epoch)
                if self.scheduler is not None:
                    self.writer.add_scalar(tag='lr',
                                           scalar_value=self.scheduler.get_last_lr()[0],
                                           global_step=epoch)

            # 4. Save model if it is the current best
            valid_loss = epoch_history['loss']['valid']
            if valid_loss <= best_valid_loss:
                best_valid_loss = valid_loss
                best_epoch = epoch
                if self.local_rank == 0:
                    self.save_checkpoint(self.best_ckpt, epoch=epoch, **epoch_history)

            # 5. Update learning rate scheduler (optional)
            if self.scheduler is not None:
                self.scheduler.step()

            # 6. Logging
            desc = make_epoch_description(
                history=epoch_history,
                current=epoch,
                total=epochs,
                best=best_epoch,
            )
            pbar.set_description_str(desc)
            pbar.update(1)
            if logger is not None:
                logger.info(desc)

    # 7. Save last model
    self.save_checkpoint(self.last_ckpt, epoch=epoch, **epoch_history)

    # 8. Test model (optional)
    if self.local_rank == 0:
        if 'test_set' in kwargs:
            test_set = kwargs['test_set']
            test_loader = DataLoader(test_set, batch_size, num_workers=num_workers,
                                     shuffle=True, drop_last=False, pin_memory=False)
            self.test(test_loader, logger=logger)
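# NOTE: `balanced_loader` is called above with shuffle=False but is not defined
# in this excerpt. A minimal sketch of a class-balanced loader built on
# WeightedRandomSampler (shuffle must stay False because DataLoader forbids
# combining a sampler with shuffling). The `dataset.targets` attribute is an
# assumption about the dataset interface, not the repository's actual code.
def balanced_loader(dataset: torch.utils.data.Dataset,
                    batch_size: int,
                    num_workers: int = 0,
                    shuffle: bool = False,
                    pin_memory: bool = False) -> DataLoader:
    """Hypothetical helper: oversamples rare classes via inverse-frequency weights."""
    assert not shuffle, "a custom sampler replaces shuffling"
    targets = torch.as_tensor(dataset.targets)          # assumed integer labels
    class_counts = torch.bincount(targets)
    sample_weights = (1.0 / class_counts.float())[targets]
    sampler = torch.utils.data.WeightedRandomSampler(sample_weights,
                                                     num_samples=len(targets),
                                                     replacement=True)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                      num_workers=num_workers, pin_memory=pin_memory)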
def run(self,
        train_set,
        eval_set,
        test_set: torch.utils.data.Dataset = None,
        save_every: int = 10,
        finetune: bool = False,
        **kwargs):  # pylint: disable=unused-argument

    epochs = self.epochs
    batch_size = self.batch_size
    num_workers = self.num_workers

    if not self.prepared:
        raise RuntimeError("Training not prepared.")

    # DataLoader (train, eval, test)
    sampler = DistributedSampler(train_set) if self.distributed else None
    shuffle = not self.distributed
    train_loader = DataLoader(train_set, batch_size=batch_size, sampler=sampler,
                              shuffle=shuffle, num_workers=num_workers,
                              drop_last=False, pin_memory=True)
    eval_loader = DataLoader(eval_set, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, drop_last=False, pin_memory=True)

    # Logging
    logger = kwargs.get('logger', None)

    # Supervised training
    best_eval_loss = float('inf')
    best_epoch = 0

    for epoch in range(1, epochs + 1):

        if self.distributed:
            sampler.set_epoch(epoch)

        # Train & evaluate
        train_history = self.train(train_loader, finetune=finetune)
        eval_history = self.evaluate(eval_loader)
        epoch_history = collections.defaultdict(dict)
        for k, v1 in train_history.items():
            epoch_history[k]['train'] = v1
            try:
                v2 = eval_history[k]
                epoch_history[k]['eval'] = v2
            except KeyError:
                continue

        # Write TensorBoard summary
        if self.writer is not None:
            for k, v in epoch_history.items():
                self.writer.add_scalars(k, v, global_step=epoch)
            if self.scheduler is not None:
                lr = self.scheduler.get_last_lr()[0]
                self.writer.add_scalar('lr', lr, global_step=epoch)

        # Save best model checkpoint
        eval_loss = eval_history['loss']
        if eval_loss <= best_eval_loss:
            best_eval_loss = eval_loss
            best_epoch = epoch
            if self.local_rank == 0:
                ckpt = os.path.join(self.ckpt_dir, "ckpt.best.pth.tar")
                self.save_checkpoint(ckpt, epoch=epoch)

        # Save intermediate model checkpoints
        if (epoch % save_every == 0) and (self.local_rank == 0):
            ckpt = os.path.join(self.ckpt_dir, f"ckpt.{epoch}.pth.tar")
            self.save_checkpoint(ckpt, epoch=epoch)

        # Write logs
        log = make_epoch_description(
            history=epoch_history,
            current=epoch,
            total=epochs,
            best=best_epoch,
        )
        if logger is not None:
            logger.info(log)

        # Update learning rate
        if self.scheduler is not None:
            self.scheduler.step()

    # Save final model checkpoint
    ckpt = os.path.join(self.ckpt_dir, "ckpt.last.pth.tar")
    self.save_checkpoint(ckpt, epoch=epoch)

    # Test (optional)
    if test_set is not None:
        test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False,
                                 num_workers=num_workers, drop_last=False,
                                 pin_memory=False)
        test_history = self.evaluate(test_loader)
        if (self.local_rank == 0) and (logger is not None):
            log = "Test: "
            for k, v in test_history.items():
                log += f" {k}: {v:.4f} |"
            logger.info(log)
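# NOTE: standalone demonstration (not part of the trainer) of why the loop
# above calls `sampler.set_epoch(epoch)`: DistributedSampler derives its
# shuffling seed from the epoch, so without set_epoch every epoch replays the
# same permutation on every rank. This runs without an initialized process
# group because num_replicas/rank are passed explicitly.
def demo_sampler_set_epoch():
    """Show that set_epoch changes the per-rank index order across epochs."""
    dataset = torch.utils.data.TensorDataset(torch.arange(8))
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
    sampler.set_epoch(0)
    order_epoch0 = list(sampler)   # indices assigned to rank 0 at epoch 0
    sampler.set_epoch(1)
    order_epoch1 = list(sampler)   # almost surely a different permutation
    return order_epoch0, order_epoch1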
def run(self,
        train_set,
        valid_set,
        epochs: int,
        batch_size: int,
        num_workers: int = 0,
        device: str = 'cuda',
        **kwargs):
    """Train, evaluate and optionally test."""

    assert isinstance(train_set, torch.utils.data.Dataset)
    assert isinstance(valid_set, torch.utils.data.Dataset)
    assert isinstance(epochs, int)
    assert isinstance(batch_size, int)
    assert isinstance(num_workers, int)
    assert device.startswith('cuda') or device == 'cpu'

    logger = kwargs.get('logger', None)
    disable_mixup = kwargs.get('disable_mixup', False)

    self.backbone = self.backbone.to(device)
    self.classifier = self.classifier.to(device)

    balance = kwargs.get('balance', False)
    if logger is not None:
        logger.info(f"Class balance: {balance}")
    shuffle = not balance

    train_loader = get_dataloader(train_set, batch_size, num_workers=num_workers,
                                  shuffle=shuffle, balance=balance)
    valid_loader = get_dataloader(valid_set, batch_size, num_workers=num_workers,
                                  balance=False)

    with tqdm.tqdm(**get_tqdm_config(total=epochs, leave=True, color='blue')) as pbar:

        # Determine model selection metric. Defaults to 'loss'.
        eval_metric = kwargs.get('eval_metric', 'loss')
        if eval_metric == 'loss':
            best_metric_val = float('inf')
        elif eval_metric in ['accuracy', 'precision', 'recall', 'f1', 'auroc', 'auprc']:
            best_metric_val = 0
        else:
            raise NotImplementedError
        best_epoch = 0

        for epoch in range(1, epochs + 1):

            # 0. Train & evaluate
            if disable_mixup:
                train_history = self.train(train_loader, device)
            else:
                train_history = self.train_with_mixup(train_loader, device)
            valid_history = self.evaluate(valid_loader, device)

            # 1. Epoch history (loss)
            epoch_history = {
                'loss': {
                    'train': train_history.get('loss'),
                    'valid': valid_history.get('loss'),
                }
            }

            # 2. Epoch history (other metrics if provided)
            if isinstance(self.metrics, dict):
                for metric_name in self.metrics:
                    epoch_history[metric_name] = {
                        'train': train_history[metric_name],
                        'valid': valid_history[metric_name],
                    }

            # 3. TensorBoard
            if self.writer is not None:
                for metric_name, metric_dict in epoch_history.items():
                    self.writer.add_scalars(main_tag=metric_name,
                                            tag_scalar_dict=metric_dict,
                                            global_step=epoch)
                if self.scheduler is not None:
                    self.writer.add_scalar(tag='lr',
                                           scalar_value=self.scheduler.get_last_lr()[0],
                                           global_step=epoch)

            # 4. Save model if it is the current best (lower is better for
            #    'loss', higher is better for the other supported metrics).
            metric_val = epoch_history[eval_metric]['valid']
            if eval_metric == 'loss':
                if metric_val <= best_metric_val:
                    best_metric_val = metric_val
                    best_epoch = epoch
                    self.save_checkpoint(self.best_ckpt, epoch=epoch, **epoch_history)
            elif eval_metric in ['accuracy', 'precision', 'recall', 'f1', 'auroc', 'auprc']:
                if metric_val >= best_metric_val:
                    best_metric_val = metric_val
                    best_epoch = epoch
                    self.save_checkpoint(self.best_ckpt, epoch=epoch, **epoch_history)
            else:
                raise NotImplementedError

            # 5. Update learning rate scheduler (optional)
            if self.scheduler is not None:
                self.scheduler.step()

            # 6. Logging
            desc = make_epoch_description(
                history=epoch_history,
                current=epoch,
                total=epochs,
                best=best_epoch,
            )
            pbar.set_description_str(desc)
            pbar.update(1)
            if logger is not None:
                logger.info(desc)

    # 7. Save last model
    self.save_checkpoint(self.last_ckpt, epoch=epoch, **epoch_history)

    # 8. Test model (optional)
    if 'test_set' in kwargs:
        test_loader = get_dataloader(kwargs.get('test_set'), batch_size,
                                     num_workers=num_workers)
        self.test(test_loader, device=device, logger=logger)
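# NOTE: `train_with_mixup` is not shown in this excerpt. A minimal sketch of
# the standard mixup step (Zhang et al., 2018) that such a method typically
# performs: `backbone`, `classifier`, and `optimizer` are stand-ins for the
# trainer's attributes, and alpha=0.2 is an assumed hyperparameter.
import numpy as np
import torch.nn.functional as F


def mixup_step(backbone: torch.nn.Module,
               classifier: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               x: torch.Tensor,
               y: torch.Tensor,
               alpha: float = 0.2) -> torch.Tensor:
    """One hypothetical mixup step: convex combination of inputs and losses."""
    lam = np.random.beta(alpha, alpha)                      # mixing coefficient
    index = torch.randperm(x.size(0), device=x.device)      # pairing permutation
    mixed_x = lam * x + (1.0 - lam) * x[index]
    logits = classifier(backbone(mixed_x))
    loss = lam * F.cross_entropy(logits, y) \
        + (1.0 - lam) * F.cross_entropy(logits, y[index])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()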
def run(self,
        train_set: torch.utils.data.Dataset,
        valid_set: torch.utils.data.Dataset,
        epochs: int,
        batch_size: int,
        num_workers: int = 0,
        **kwargs):

    logger = kwargs.get('logger', None)
    save_every = kwargs.get('save_every', epochs)

    self.backbone.to(self.local_rank)
    self.projector.to(self.local_rank)

    if self.distributed:
        raise NotImplementedError
    else:
        train_loader = DataLoader(train_set, batch_size, num_workers=num_workers,
                                  shuffle=True, pin_memory=False)
        valid_loader = DataLoader(valid_set, batch_size, num_workers=num_workers,
                                  shuffle=True, pin_memory=False)

    # Initialize memory representations for the training data
    if not self.memory.initialized:
        self.memory.initialize(self.backbone, self.projector, train_loader)

    with tqdm.tqdm(**get_tqdm_config(total=epochs, leave=True, color='blue')) as pbar:

        best_valid_loss = float('inf')
        best_epoch = 0

        for epoch in range(1, epochs + 1):

            # 0. Train & evaluate
            train_history = self.train(train_loader)
            valid_history = self.evaluate(valid_loader)

            # 1. Epoch history (loss)
            epoch_history = {
                'loss': {
                    'train': train_history.get('loss'),
                    'valid': valid_history.get('loss'),
                },
            }

            # 2. Epoch history (other metrics if provided)
            if self.metrics is not None:
                assert isinstance(self.metrics, dict)
                for metric_name in self.metrics:
                    epoch_history[metric_name] = {
                        'train': train_history.get(metric_name),
                        'valid': valid_history.get(metric_name),
                    }

            # 3. TensorBoard
            if self.writer is not None:
                for metric_name, metric_dict in epoch_history.items():
                    self.writer.add_scalars(main_tag=metric_name,
                                            tag_scalar_dict=metric_dict,
                                            global_step=epoch)
                if self.scheduler is not None:
                    self.writer.add_scalar(tag='lr',
                                           scalar_value=self.scheduler.get_last_lr()[0],
                                           global_step=epoch)

            # 4-1. Save model if it is the current best
            valid_loss = epoch_history['loss']['valid']
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                best_epoch = epoch
                if self.local_rank == 0:
                    self.save_checkpoint(self.best_ckpt, epoch=epoch, **epoch_history)
                    self.memory.save(os.path.join(os.path.dirname(self.best_ckpt),
                                                  'best_memory.pt'), epoch=epoch)

            # 4-2. Save intermediate models
            if epoch % save_every == 0:
                if self.local_rank == 0:
                    new_ckpt = os.path.join(self.checkpoint_dir,
                                            f'epoch_{epoch:04d}.loss_{valid_loss:.4f}.pt')
                    self.save_checkpoint(new_ckpt, epoch=epoch, **epoch_history)

            # 5. Update learning rate scheduler
            if self.scheduler is not None:
                self.scheduler.step()

            # 6. Logging
            desc = make_epoch_description(history=epoch_history,
                                          current=epoch,
                                          total=epochs,
                                          best=best_epoch)
            pbar.set_description_str(desc)
            pbar.update(1)
            if logger is not None:
                logger.info(desc)

    # 7. Save last model
    if self.local_rank == 0:
        self.save_checkpoint(self.last_ckpt, epoch=epoch, **epoch_history)
        self.memory.save(os.path.join(os.path.dirname(self.last_ckpt),
                                      'last_memory.pt'), epoch=epoch)
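# NOTE: the `memory` object used above is not defined in this excerpt. A
# minimal sketch of what a PIRL/NPID-style memory bank's `initialize` and
# `save` might do: one no-grad pass over the training data caching a projected,
# L2-normalized vector per example. The (input, target, index) batch layout is
# a guess about the dataset interface, not the repository's actual code.
class FeatureMemory:
    """Hypothetical memory bank: one normalized feature vector per example."""

    def __init__(self, size: int, dim: int, device: str = 'cpu'):
        self.bank = torch.zeros(size, dim, device=device)
        self.initialized = False

    @torch.no_grad()
    def initialize(self, backbone, projector, loader):
        backbone.eval()
        projector.eval()
        for x, _, idx in loader:  # assumed (input, target, index) batches
            z = projector(backbone(x.to(self.bank.device)))
            self.bank[idx] = torch.nn.functional.normalize(z, dim=1)
        self.initialized = True

    def save(self, path: str, epoch: int):
        torch.save({'memory': self.bank.cpu(), 'epoch': epoch}, path)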