def _train_epoch(
        self,
        data: DataLoader,
        model: nn.Module,
        optimizer: optim.Optimizer,
        criterion: Callable,
        scheduler: optim.lr_scheduler._LRScheduler = None,
        clip: float = 1.0
):
    model.train()
    losses = []
    for i, inputs in enumerate(data):
        inputs = self.to_device(inputs)
        x, y = inputs['features'], inputs['targets']
        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(outputs, y)
        losses.append(loss.item())
        loss.backward()
        if clip is not None:
            nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
    return losses

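# A minimal, self-contained sketch of how _train_epoch above might be driven.
# The _ToyTrainer/_ToyDictDataset names and the 'features'/'targets' batch
# layout are hypothetical stand-ins for the surrounding project code.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset


class _ToyDictDataset(Dataset):
    """Yields dicts with the 'features'/'targets' keys _train_epoch expects."""

    def __len__(self):
        return 32

    def __getitem__(self, idx):
        return {'features': torch.randn(4), 'targets': torch.randn(1)}


class _ToyTrainer:
    _train_epoch = _train_epoch  # reuse the function above as a method

    def to_device(self, inputs):
        return inputs  # assumed no-op for this CPU-only sketch


model = nn.Linear(4, 1)
losses = _ToyTrainer()._train_epoch(
    data=DataLoader(_ToyDictDataset(), batch_size=8),
    model=model,
    optimizer=optim.SGD(model.parameters(), lr=0.1),
    criterion=nn.MSELoss())
print(f'mean batch loss: {sum(losses) / len(losses):.4f}')
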
def joint_train(self, epoch: int = 0,
                optimizer: optim.Optimizer = None,
                lr_scheduler: optim.lr_scheduler._LRScheduler = None,
                poison_loader=None, discrim_loader=None,
                save=False, **kwargs):
    in_dim = self.model._model.classifier[0].in_features
    D = nn.Sequential(OrderedDict([
        ('fc1', nn.Linear(in_dim, 256)),
        ('bn1', nn.BatchNorm1d(256)),
        ('relu1', nn.LeakyReLU()),
        ('fc2', nn.Linear(256, 128)),
        ('bn2', nn.BatchNorm1d(128)),
        ('relu2', nn.ReLU()),
        ('fc3', nn.Linear(128, 2))]))
    if env['num_gpus']:
        D.cuda()
    optim_params: list[nn.Parameter] = []
    for param_group in optimizer.param_groups:
        optim_params.extend(param_group['params'])
    optimizer.zero_grad()
    best_acc = 0.0
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    for _epoch in range(epoch):
        self.discrim_train(epoch=100, D=D, discrim_loader=discrim_loader)
        self.model.train()
        self.model.activate_params(optim_params)
        for data in poison_loader:
            optimizer.zero_grad()
            _input, _label_f, _label_d = self.bypass_get_data(data)
            out_f = self.model(_input)
            loss_f = self.model.criterion(out_f, _label_f)
            out_d = D(self.model.get_final_fm(_input))
            loss_d = self.model.criterion(out_d, _label_d)
            loss = loss_f - self.lambd * loss_d
            loss.backward()
            optimizer.step()
        optimizer.zero_grad()
        if lr_scheduler:
            lr_scheduler.step()
        self.model.activate_params([])
        self.model.eval()
        _, cur_acc = self.validate_fn(get_data_fn=self.bypass_get_data)
        if cur_acc >= best_acc:
            prints('best result update!', indent=0)
            prints(f'Current Acc: {cur_acc:.3f} '
                   f'Previous Best Acc: {best_acc:.3f}', indent=0)
            best_acc = cur_acc
            if save:
                self.save()
        print('-' * 50)

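# joint_train above optimizes the classifier against a feature-space
# discriminator D: minimizing loss_f - lambd * loss_d keeps task accuracy
# while *raising* D's loss through the shared features, so the backbone
# learns features D cannot separate. A minimal sketch of that combined
# objective with toy modules; every name here is a hypothetical stand-in.
import torch
import torch.nn as nn

backbone = nn.Linear(8, 16)        # produces the final feature maps
classifier = nn.Linear(16, 10)     # task head (loss_f)
discriminator = nn.Linear(16, 2)   # D: poisoned-vs-clean head (loss_d)
criterion = nn.CrossEntropyLoss()
lambd = 0.5

x = torch.randn(4, 8)
label_f = torch.randint(0, 10, (4,))  # task labels
label_d = torch.randint(0, 2, (4,))   # poison-indicator labels

fm = backbone(x)
loss_f = criterion(classifier(fm), label_f)
loss_d = criterion(discriminator(fm), label_d)
loss = loss_f - lambd * loss_d  # subtracting loss_d pushes features to fool D
loss.backward()
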
def attack(self, epoch: int,
           lr_scheduler: optim.lr_scheduler._LRScheduler = None,
           save: bool = False, **kwargs):
    print('Sample Data')
    poison_loader, discrim_loader = self.sample_data()  # with poisoned images
    print('Joint Training')
    super().attack(epoch=10, lr_scheduler=lr_scheduler, **kwargs)
    if isinstance(lr_scheduler, optim.lr_scheduler._LRScheduler):
        lr_scheduler.step(0)
    self.joint_train(epoch=epoch, poison_loader=poison_loader,
                     discrim_loader=discrim_loader, save=save,
                     lr_scheduler=lr_scheduler, **kwargs)

def save(
    self,
    model: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: optim.lr_scheduler._LRScheduler,
    epoch: int,
    metric: float,
):
    if self.best_metric < metric:
        self.best_metric = metric
        self.best_epoch = epoch
        is_best = True
    else:
        is_best = False
    os.makedirs(self.root_dir, exist_ok=True)
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "epoch": epoch,
            "best_epoch": self.best_epoch,
            "best_metric": self.best_metric,
        },
        osp.join(self.root_dir, f"{epoch:02d}.pth"),
    )
    if is_best:
        shutil.copy(
            osp.join(self.root_dir, f"{epoch:02d}.pth"),
            osp.join(self.root_dir, "best.pth"),
        )

def log_checkpoints(
    checkpoint_dir: Path,
    model: Union[nn.Module, nn.DataParallel],
    optimizer: Optimizer,
    scheduler: optim.lr_scheduler._LRScheduler,
    epoch: int,
) -> None:
    """
    Serialize a PyTorch model in the `checkpoint_dir`.

    Args:
        checkpoint_dir: the directory to store checkpoints
        model: the model to serialize
        optimizer: the optimizer to be saved
        scheduler: the LR scheduler to be saved
        epoch: the epoch number
    """
    checkpoint_file = 'checkpoint_{}.pt'.format(epoch)
    checkpoint_dir.mkdir(exist_ok=True, parents=True)
    file_path = checkpoint_dir / checkpoint_file
    if isinstance(model, nn.DataParallel):
        model_state_dict = model.module.state_dict()
    else:
        model_state_dict = model.state_dict()
    torch.save(  # type: ignore
        {
            'epoch': epoch,
            'model_state_dict': model_state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        },
        file_path,
    )

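# A short usage sketch for log_checkpoints; the model, optimizer, scheduler,
# and directory choices here are arbitrary examples, not part of the original.
import torch.nn as nn
import torch.optim as optim
from pathlib import Path

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30)
log_checkpoints(Path('checkpoints'), model, optimizer, scheduler, epoch=0)
# writes checkpoints/checkpoint_0.pt with model/optimizer/scheduler state
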
def _restore(
    mdl: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: optim.lr_scheduler._LRScheduler,
    ckpt_loc: str
) -> t.Tuple[nn.Module, optim.Optimizer,
             optim.lr_scheduler._LRScheduler, int, float]:
    """Restore model training state

    Args:
        mdl (nn.Module): The randomly initialized model
        optimizer (optim.Optimizer): The optimizer
        scheduler (optim.lr_scheduler._LRScheduler): The scheduler for learning rate
        ckpt_loc (str): Location to store model checkpoints

    Returns:
        t.Tuple[nn.Module, optim.Optimizer, optim.lr_scheduler._LRScheduler, int, float]:
            The restored status
    """
    # Restore model checkpoint
    mdl.load_state_dict(torch.load(os.path.join(ckpt_loc, 'mdl.ckpt')))
    optimizer.load_state_dict(
        torch.load(os.path.join(ckpt_loc, 'optimizer.ckpt')))
    scheduler.load_state_dict(
        torch.load(os.path.join(ckpt_loc, 'scheduler.ckpt')))

    # Restore timer and step counter
    with open(os.path.join(ckpt_loc, 'log.out')) as f:
        records = f.readlines()
    if records[-1] != 'Training finished\n':
        final_record = records[-1]
    else:
        final_record = records[-2]
    global_counter, t_final = final_record.split('\t')[:2]
    global_counter = int(global_counter)
    t_final = float(t_final)
    t0 = time.time() - t_final * 60

    return mdl, optimizer, scheduler, global_counter, t0

def train_run(model: nn.Module,
              train_dl: torch.utils.data.dataloader.DataLoader,
              criterion: nn.Module,
              optimizer: optim.Optimizer,
              scheduler: optim.lr_scheduler._LRScheduler,
              num_it: int,
              on_batch_end: Callable[[int, float, float], None],
              device: torch.device):
    'TODO: docstring'
    iterator = iter(train_dl)
    bar = tqdm(range(num_it))
    for i in bar:
        try:
            xs, ys = next(iterator)
        except StopIteration:
            iterator = iter(train_dl)
            xs, ys = next(iterator)
        xs = xs.to(device)
        ys = ys.to(device)
        loss = train_batch(xs, ys, model, criterion, optimizer)
        on_batch_end(bar, i, loss, scheduler.get_lr()[0])
        scheduler.step()

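# The try/except StopIteration pattern above re-creates the DataLoader
# iterator so a fixed number of iterations can span epoch boundaries. The
# same idea as a reusable generator (a sketch; infinite_batches is a
# hypothetical helper, not part of the original code):
from typing import Iterator

def infinite_batches(loader) -> Iterator:
    """Yield batches forever, restarting the loader at each epoch boundary."""
    while True:
        for batch in loader:
            yield batch

# batches = infinite_batches(train_dl)
# xs, ys = next(batches)  # never raises StopIteration
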
def _train_step(mdl: nn.Module,
                optimizer: optim.Optimizer,
                scheduler: optim.lr_scheduler._LRScheduler,
                min_lr: float,
                clip_grad: float,
                device: torch.device,
                iter_train: t.Iterator):
    """Helper function to perform one step of training

    Args:
        mdl (nn.Module): The randomly initialized model
        optimizer (optim.Optimizer): The optimizer
        scheduler (optim.lr_scheduler._LRScheduler): The scheduler for learning rate
        min_lr (float): The minimum learning rate
        clip_grad (float): Gradient clipping
        device (torch.device): The device where tensors should be initialized
        iter_train (t.Iterator): The iterator for trainer
    """
    # Prepare for training
    optimizer.zero_grad()  # Clear gradient
    if all(params_group['lr'] > min_lr
           for params_group in optimizer.param_groups):
        # Update learning rate if it is still larger than min_lr
        scheduler.step()

    # Get data
    mol_array, log_p = next(iter_train)
    loss = _loss(mol_array, log_p, mdl, device)
    loss.backward()

    # Clip gradient
    torch.nn.utils.clip_grad_value_(mdl.parameters(), clip_grad)
    optimizer.step()

    return loss

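# A standalone sketch of the min_lr gate used above: keep decaying with
# ExponentialLR, but stop stepping the scheduler once every param group has
# reached the floor (the rate may overshoot the floor by one decay step).
# The model, gamma, and min_lr values here are illustrative only.
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=1e-2)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)
min_lr = 1e-3

for step in range(10):
    optimizer.step()  # normally preceded by a forward/backward pass
    if all(g['lr'] > min_lr for g in optimizer.param_groups):
        scheduler.step()
# lr halves each step (1e-2, 5e-3, 2.5e-3, ...) until it crosses min_lr.
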
def _train_step(
        self, rank: int, dataset: Dataset, model: nn.Module,
        optimizer: optim.Optimizer,
        scheduler: optim.lr_scheduler._LRScheduler) -> Dict[str, float]:
    model.train()
    optimizer.zero_grad()
    data = self._fetch_from(dataset, rank, self.config.batch_train)
    metrics = self.spec.train_objective(data, model)
    loss = metrics['loss']
    if self.config.use_amp:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    optimizer.step()
    scheduler.step()
    return {k: self._to_value(v) for k, v in metrics.items()}

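# The AMP branch above uses NVIDIA Apex's amp.scale_loss. For reference, a
# sketch of the equivalent pattern with PyTorch's built-in torch.cuda.amp
# (not the API the original code uses); GradScaler and autocast degrade to
# no-ops when enabled=False, so this also runs on CPU.
import torch
import torch.nn as nn
import torch.optim as optim

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Linear(4, 1).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))

x = torch.randn(8, 4, device=device)
y = torch.randn(8, 1, device=device)

optimizer.zero_grad()
with torch.autocast(device, enabled=(device == 'cuda')):
    loss = nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()  # scales the loss to avoid fp16 underflow
scaler.step(optimizer)         # unscales gradients, then runs optimizer.step()
scaler.update()
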
def _run_epoch(self, optimizer: optim.Optimizer,
               scheduler: optim.lr_scheduler._LRScheduler,
               criterion: nn.Module):
    lr_step = self.min_lr
    with tqdm(self.loader,
              postfix=["Current state: ", dict(loss=0, lr=lr_step)]) as t:
        for data, target in t:
            data = data.to(self.device)
            target = target.to(self.device)
            optimizer.zero_grad()
            output = self.model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            scheduler.step()
            lr_step = optimizer.state_dict()['param_groups'][0]['lr']
            self.lr_s += [lr_step]
            self.losses += [loss.item()]
            t.postfix[1]['loss'] = loss.item()
            t.postfix[1]['lr'] = lr_step
            t.update()

def _save(mdl: nn.Module,
          optimizer: optim.Optimizer,
          scheduler: optim.lr_scheduler._LRScheduler,
          global_counter: int,
          t0: float,
          loss: float,
          current_lr: float,
          ckpt_loc: str) -> str:
    """Saving checkpoint to file

    Args:
        mdl (nn.Module): The randomly initialized model
        optimizer (optim.Optimizer): The optimizer
        scheduler (optim.lr_scheduler._LRScheduler): The scheduler for learning rate
        global_counter (int): The global counter for training
        t0 (float): The time training was started
        loss (float): The loss of the model
        current_lr (float): The current learning rate
        ckpt_loc (str): Location to store model checkpoints

    Returns:
        str: The message string
    """
    # Save status
    torch.save(mdl.state_dict(), os.path.join(ckpt_loc, 'mdl.ckpt'))
    torch.save(optimizer.state_dict(),
               os.path.join(ckpt_loc, 'optimizer.ckpt'))
    torch.save(scheduler.state_dict(),
               os.path.join(ckpt_loc, 'scheduler.ckpt'))
    message_str = (f'{global_counter}\t'
                   f'{float(time.time() - t0) / 60}\t'
                   f'{loss}\t'
                   f'{current_lr}\n')
    return message_str

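# _save pairs with _restore above: _save writes the three .ckpt files and
# returns one tab-separated log line (counter, elapsed minutes, loss, lr)
# that the caller is expected to append to log.out, which _restore later
# parses to recover the step counter and timer. A round-trip sketch with a
# throwaway model (the values are arbitrary):
import os
import tempfile
import time

import torch
import torch.nn as nn
import torch.optim as optim

mdl = nn.Linear(2, 2)
optimizer = optim.SGD(mdl.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)

ckpt_loc = tempfile.mkdtemp()
line = _save(mdl, optimizer, scheduler, global_counter=42, t0=time.time(),
             loss=0.5, current_lr=0.1, ckpt_loc=ckpt_loc)
with open(os.path.join(ckpt_loc, 'log.out'), 'w') as f:
    f.write(line)

mdl, optimizer, scheduler, counter, t0 = _restore(
    mdl, optimizer, scheduler, ckpt_loc)
assert counter == 42
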
def adv_train(self, epoch: int, optimizer: optim.Optimizer,
              lr_scheduler: optim.lr_scheduler._LRScheduler = None,
              validate_interval=10, save=False,
              verbose=True, indent=0,
              epoch_fn: Callable = None, **kwargs):
    loader_train = self.dataset.loader['train']
    file_path = self.folder_path + self.get_filename() + '.pth'
    _, best_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    params = [param_group['params'] for param_group in optimizer.param_groups]
    for _epoch in range(epoch):
        if callable(epoch_fn):
            self.model.activate_params([])
            epoch_fn()
            self.model.activate_params(params)
        losses.reset()
        top1.reset()
        top5.reset()
        epoch_start = time.perf_counter()
        if verbose and env['tqdm']:
            loader_train = tqdm(loader_train)
        optimizer.zero_grad()
        for data in loader_train:
            _input, _label = self.model.get_data(data)
            noise = torch.zeros_like(_input)
            poison_input, poison_label = self.get_poison_data(data)

            def loss_fn(X: torch.FloatTensor):
                return -self.model.loss(X, _label)

            adv_x = _input
            self.model.train()
            loss = self.model.loss(adv_x, _label)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            for m in range(self.pgd.iteration):
                self.model.eval()
                adv_x, _ = self.pgd.optimize(_input=_input, noise=noise,
                                             loss_fn=loss_fn, iteration=1)
                optimizer.zero_grad()
                self.model.train()
                x = torch.cat((adv_x, poison_input))
                y = torch.cat((_label, poison_label))
                loss = self.model.loss(x, y)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            with torch.no_grad():
                _output = self.model.get_logits(_input)
            acc1, acc5 = self.model.accuracy(_output, _label, topk=(1, 5))
            batch_size = int(_label.size(0))
            losses.update(loss.item(), batch_size)
            top1.update(acc1, batch_size)
            top5.update(acc5, batch_size)
        epoch_time = str(datetime.timedelta(seconds=int(
            time.perf_counter() - epoch_start)))
        self.model.eval()
        self.model.activate_params([])
        if verbose:
            pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                output_iter(_epoch + 1, epoch),
                **ansi).ljust(64 if env['color'] else 35)
            _str = ' '.join([
                f'Loss: {losses.avg:.4f},'.ljust(20),
                f'Top1 Clean Acc: {top1.avg:.3f}, '.ljust(30),
                f'Top5 Clean Acc: {top5.avg:.3f},'.ljust(30),
                f'Time: {epoch_time},'.ljust(20),
            ])
            prints(pre_str, _str,
                   prefix='{upline}{clear_line}'.format(**ansi)
                   if env['tqdm'] else '',
                   indent=indent)
        if lr_scheduler:
            lr_scheduler.step()
        if validate_interval != 0:
            if (_epoch + 1) % validate_interval == 0 or _epoch == epoch - 1:
                _, cur_acc = self.validate_fn(verbose=verbose,
                                              indent=indent, **kwargs)
                if cur_acc < best_acc:
                    prints('best result update!', indent=indent)
                    prints(f'Current Acc: {cur_acc:.3f} '
                           f'Previous Best Acc: {best_acc:.3f}',
                           indent=indent)
                    best_acc = cur_acc
                    if save:
                        self.save()
                if verbose:
                    print('-' * 50)
    self.model.zero_grad()

def adv_train(self, epochs: int, optimizer: optim.Optimizer,
              lr_scheduler: optim.lr_scheduler._LRScheduler = None,
              validate_interval=10, save=False,
              verbose=True, indent=0, **kwargs):
    loader_train = self.dataset.loader['train']
    file_path = os.path.join(self.folder_path, self.get_filename() + '.pth')
    best_acc, _ = self.validate_fn(verbose=verbose, indent=indent, **kwargs)
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    params: list[nn.Parameter] = []
    for param_group in optimizer.param_groups:
        params.extend(param_group['params'])
    for _epoch in range(epochs):
        losses.reset()
        top1.reset()
        top5.reset()
        epoch_start = time.perf_counter()
        if verbose and env['tqdm']:
            loader_train = tqdm(loader_train)
        self.model.activate_params(params)
        optimizer.zero_grad()
        for data in loader_train:
            _input, _label = self.model.get_data(data)
            noise = torch.zeros_like(_input)
            adv_x = _input
            self.model.train()
            loss = self.model.loss(adv_x, _label)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            for m in range(self.pgd.iteration):
                self.model.eval()
                adv_x, _ = self.pgd.optimize(_input=_input, noise=noise,
                                             target=_label, iteration=1)
                optimizer.zero_grad()
                self.model.train()
                loss = self.model.loss(adv_x, _label)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            with torch.no_grad():
                _output = self.model(_input)
            acc1, acc5 = self.model.accuracy(_output, _label, topk=(1, 5))
            batch_size = int(_label.size(0))
            losses.update(loss.item(), batch_size)
            top1.update(acc1, batch_size)
            top5.update(acc5, batch_size)
        epoch_time = str(datetime.timedelta(
            seconds=int(time.perf_counter() - epoch_start)))
        self.model.eval()
        self.model.activate_params([])
        if verbose:
            pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                output_iter(_epoch + 1, epochs),
                **ansi).ljust(64 if env['color'] else 35)
            _str = ' '.join([
                f'Loss: {losses.avg:.4f},'.ljust(20),
                f'Top1 Clean Acc: {top1.avg:.3f}, '.ljust(30),
                f'Top5 Clean Acc: {top5.avg:.3f},'.ljust(30),
                f'Time: {epoch_time},'.ljust(20),
            ])
            prints(pre_str, _str,
                   prefix='{upline}{clear_line}'.format(**ansi)
                   if env['tqdm'] else '',
                   indent=indent)
        if lr_scheduler:
            lr_scheduler.step()
        if validate_interval != 0:
            if (_epoch + 1) % validate_interval == 0 or _epoch == epochs - 1:
                adv_acc, _ = self.validate_fn(verbose=verbose,
                                              indent=indent, **kwargs)
                if adv_acc < best_acc:
                    prints('{purple}best result update!{reset}'.format(**ansi),
                           indent=indent)
                    prints(f'Current Acc: {adv_acc:.3f} '
                           f'Previous Best Acc: {best_acc:.3f}',
                           indent=indent)
                    best_acc = adv_acc
                    if save:
                        self.model.save(file_path=file_path, verbose=verbose)
                if verbose:
                    print('-' * 50)
    self.model.zero_grad()

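# Both adv_train variants above interleave PGD updates on the inputs with
# model updates (free-adversarial-training style); self.pgd.optimize is
# project-specific. A minimal, self-contained sketch of one L-inf PGD step
# on a batch (pgd_step, the toy model, and the eps/alpha values are
# hypothetical, not the original API):
import torch
import torch.nn as nn
import torch.nn.functional as F

def pgd_step(model: nn.Module, x: torch.Tensor, y: torch.Tensor,
             eps: float = 8 / 255, alpha: float = 2 / 255) -> torch.Tensor:
    """One ascent step on the loss w.r.t. the input, projected to an eps-ball."""
    x_adv = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x_adv), y)
    grad = torch.autograd.grad(loss, x_adv)[0]
    x_adv = x_adv.detach() + alpha * grad.sign()  # ascend the loss
    x_adv = x + (x_adv - x).clamp(-eps, eps)      # project to the eps-ball
    return x_adv.clamp(0, 1)                      # keep a valid image range

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
x = torch.rand(4, 3, 8, 8)
y = torch.randint(0, 10, (4,))
adv_x = pgd_step(model, x, y)  # then train on (adv_x, y) as in adv_train
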
def _iter_impl(epoch: int, phase: str, data_loader: DataLoader, device: str,
               model: nn.Module, criterion: nn.Module,
               optimizer: optim.Optimizer,
               scheduler: optim.lr_scheduler._LRScheduler,
               larger_holder: LargerHolder, baseline_flops: float,
               flops_tester: FLOPs, logger: logging.Logger,
               output_directory: str, writer: SummaryWriter,
               log_frequency: int):
    start = datetime.now()
    clear_statistics(model)
    model.train(phase == "train")
    loss_metric = AverageMetric()
    accuracy_metric = AccuracyMetric(topk=(1, 5))
    for iter_, (datas, targets) in enumerate(data_loader, start=1):
        datas, targets = datas.to(device=device), targets.to(device=device)
        with torch.set_grad_enabled(phase == "train"):
            outputs = model(datas)
            loss = criterion(outputs, targets)
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        loss_metric.update(loss)
        accuracy_metric.update(targets, outputs)
        if iter_ % log_frequency == 0:
            logger.info(
                f"{phase.upper()}, epoch={epoch:03d}, "
                f"iter={iter_}/{len(data_loader)}, "
                f"loss={loss_metric.last:.4f}({loss_metric.value:.4f}), "
                f"accuracy@1={accuracy_metric.last_accuracy(1).rate*100:.2f}%"
                f"({accuracy_metric.accuracy(1).rate*100:.2f}%), "
                f"accuracy@5={accuracy_metric.last_accuracy(5).rate*100:.2f}%"
                f"({accuracy_metric.accuracy(5).rate*100:.2f}%), ")
    if phase != "train":
        acc = accuracy_metric.accuracy(1).rate
        if larger_holder.update(new_value=acc, metadata=dict(epoch=epoch)):
            if output_directory is not None:
                torch.save(model.state_dict(),
                           os.path.join(output_directory, "best_model.pth"))
    if scheduler is not None:
        scheduler.step()
    flops = flops_tester.compute()
    logger.info(
        f"{phase.upper()} Complete, epoch={epoch:03d}, "
        f"loss={loss_metric.value:.4f}, "
        f"accuracy@1={accuracy_metric.accuracy(1).rate*100:.2f}%, "
        f"accuracy@5={accuracy_metric.accuracy(5).rate*100:.2f}%, "
        f"flops={flops/1e6:.2f}M({flops/baseline_flops*100:.2f}%), "
        f"best_accuracy={larger_holder.value*100:.2f}%"
        f"(epoch={larger_holder.metadata['epoch']:03d}), "
        f"proportions={network_proportion(model)}, "
        f"elapsed time={datetime.now()-start}.")
    writer.add_scalar(f"{phase}/loss", loss_metric.value, epoch)
    writer.add_scalar(f"{phase}/accuracy@1",
                      accuracy_metric.accuracy(1).rate, epoch)
    writer.add_scalar(f"{phase}/accuracy@5",
                      accuracy_metric.accuracy(5).rate, epoch)

def train_model(
    train_ds: tf.data.Dataset,
    dev_ds: tf.data.Dataset,
    model: nn.Module,
    optimizer: optim.Optimizer,
    lr_scheduler: optim.lr_scheduler._LRScheduler,
    args: argparse.Namespace,
) -> nn.Module:
    device = model_utils.get_device()
    loss_fn = model_utils.depth_proportional_loss
    val_loss_fn = model_utils.l1_norm_loss
    best_val_loss = torch.tensor(float('inf'))
    saved_checkpoints = []
    writer = SummaryWriter(log_dir=f'{args.log_dir}/{args.experiment}')
    cos = nn.CosineSimilarity(dim=1, eps=0)
    get_gradient: nn.Module = sobel.Sobel().to(device)

    for e in range(1, args.train_epochs + 1):
        print(f'Training epoch {e}...')

        if args.use_scheduler:
            lr_scheduler.step()

        # Training portion
        torch.cuda.empty_cache()
        torch.set_grad_enabled(True)
        with tqdm(total=args.train_batch_size * len(train_ds)) as progress_bar:
            model.train()
            for i, (x_batch_orig, y_batch) in enumerate(
                    train_ds.as_numpy_iterator()):
                x_batch, y_batch = model_utils.preprocess_training_example(
                    x_batch_orig, y_batch)
                y_blurred = model_utils.blur_depth_map(y_batch)
                ones = torch.ones(y_batch.shape, dtype=torch.float32,
                                  device=device)

                # Forward pass on model
                optimizer.zero_grad()
                y_pred = model(x_batch)

                depth_grad = get_gradient(y_blurred)
                output_grad = get_gradient(y_pred)
                depth_grad_dx = depth_grad[:, 0, :, :].contiguous().view_as(
                    y_blurred)
                depth_grad_dy = depth_grad[:, 1, :, :].contiguous().view_as(
                    y_batch)
                output_grad_dx = output_grad[:, 0, :, :].contiguous().view_as(
                    y_blurred)
                output_grad_dy = output_grad[:, 1, :, :].contiguous().view_as(
                    y_batch)
                depth_normal = torch.cat(
                    (-depth_grad_dx, -depth_grad_dy, ones), 1)
                output_normal = torch.cat(
                    (-output_grad_dx, -output_grad_dy, ones), 1)

                loss_depth = torch.log(
                    torch.abs(y_pred - y_batch) + 0.5).mean()
                loss_dx = torch.log(
                    torch.abs(output_grad_dx - depth_grad_dx) + 0.5).mean()
                loss_dy = torch.log(
                    torch.abs(output_grad_dy - depth_grad_dy) + 0.5).mean()
                loss_normal = torch.abs(
                    1 - cos(output_normal, depth_normal)).mean()
                loss = loss_depth + loss_normal + (loss_dx + loss_dy)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()

                progress_bar.update(len(x_batch))
                progress_bar.set_postfix(loss=loss.item())
                writer.add_scalar(
                    "train/Loss", loss,
                    ((e - 1) * len(train_ds) + i) * args.train_batch_size)

                # Periodically save a diagram
                if (i + 1) % args.picture_frequency == 0:
                    model_utils.make_diagram(
                        np.transpose(x_batch_orig, (0, 3, 1, 2)),
                        x_batch.cpu().numpy(),
                        y_batch.cpu().numpy(),
                        y_pred.cpu().detach().numpy(),
                        f'{args.save_path}/{args.experiment}/diagram_{e}_{i+1}.png',
                    )

                del x_batch
                del y_batch
                del y_blurred
                del y_pred
                del loss

        # Validation portion
        torch.cuda.empty_cache()
        torch.set_grad_enabled(False)
        with tqdm(total=args.dev_batch_size * len(dev_ds)) as progress_bar:
            model.eval()
            val_loss = 0.0
            num_batches_processed = 0
            total_pixels = 0
            total_examples = 0
            squared_error = 0
            rel_error = 0
            log_error = 0
            threshold1 = 0  # 1.25
            threshold2 = 0  # 1.25^2
            threshold3 = 0  # corresponds to 1.25^3
            for i, (x_batch, y_batch) in enumerate(dev_ds.as_numpy_iterator()):
                x_batch, y_batch = model_utils.preprocess_test_example(
                    x_batch, y_batch)

                # Forward pass on model in validation environment
                y_pred = model(x_batch)

                # TODO: Process y_pred in whatever way inference requires.
                loss = val_loss_fn(y_pred, y_batch)
                val_loss += loss.item()
                num_batches_processed += 1

                nanmask = getNanMask(y_batch)
                total_pixels = torch.sum(~nanmask)
                total_examples += x_batch.shape[0]

                # RMS, REL, LOG10, threshold calculation
                squared_error += (torch.sum(
                    torch.pow(y_pred - y_batch, 2)).item() / total_pixels)**0.5
                rel_error += torch.sum(removeNans(
                    torch.abs(y_pred - y_batch) / y_batch)).item() / total_pixels
                log_error += torch.sum(torch.abs(
                    removeNans(torch.log10(y_pred)) -
                    removeNans(torch.log10(y_batch)))).item() / total_pixels
                threshold1 += torch.sum(
                    torch.max(y_pred / y_batch, y_batch / y_pred)
                    < 1.25).item() / total_pixels
                threshold2 += torch.sum(
                    torch.max(y_pred / y_batch, y_batch / y_pred)
                    < 1.25**2).item() / total_pixels
                threshold3 += torch.sum(
                    torch.max(y_pred / y_batch, y_batch / y_pred)
                    < 1.25**3).item() / total_pixels

                progress_bar.update(len(x_batch))
                progress_bar.set_postfix(
                    val_loss=val_loss / num_batches_processed)
                writer.add_scalar(
                    "Val/Loss", loss,
                    ((e - 1) * len(dev_ds) + i) * args.dev_batch_size)

                del x_batch
                del y_batch
                del y_pred
                del loss

            writer.add_scalar("Val/RMS", squared_error / total_examples, e)
            writer.add_scalar("Val/REL", rel_error / total_examples, e)
            writer.add_scalar("Val/LOG10", log_error / total_examples, e)
            writer.add_scalar("Val/delta1", threshold1 / total_examples, e)
            writer.add_scalar("Val/delta2", threshold2 / total_examples, e)
            writer.add_scalar("Val/delta3", threshold3 / total_examples, e)

        # Save model if it's the best one yet.
        if val_loss / num_batches_processed < best_val_loss:
            best_val_loss = val_loss / num_batches_processed
            filename = (f'{args.save_path}/{args.experiment}/'
                        f'{model.__class__.__name__}_best_val.checkpoint')
            model_utils.save_model(model, filename)
            print(f'Model saved!')
            print(f'Best validation loss yet: {best_val_loss}')

        # Save model on checkpoints.
        if e % args.checkpoint_freq == 0:
            filename = (f'{args.save_path}/{args.experiment}/'
                        f'{model.__class__.__name__}_epoch_{e}.checkpoint')
            model_utils.save_model(model, filename)
            print(f'Model checkpoint reached!')
            saved_checkpoints.append(filename)
            # Delete checkpoints if there are too many
            while len(saved_checkpoints) > args.num_checkpoints:
                os.remove(saved_checkpoints.pop(0))

    return model

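# The delta thresholds above are the standard monocular-depth metrics: the
# fraction of pixels where max(pred/gt, gt/pred) < 1.25**k for k in {1,2,3}.
# A self-contained sketch of that computation (depth_thresholds is a
# hypothetical helper, not part of the original code):
import torch

def depth_thresholds(pred: torch.Tensor, gt: torch.Tensor):
    """Return (delta1, delta2, delta3) accuracy ratios for positive depths."""
    ratio = torch.max(pred / gt, gt / pred)
    return tuple((ratio < 1.25 ** k).float().mean().item() for k in (1, 2, 3))

pred = torch.rand(2, 1, 4, 4) + 0.5
gt = torch.rand(2, 1, 4, 4) + 0.5
d1, d2, d3 = depth_thresholds(pred, gt)  # d1 <= d2 <= d3 by construction
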
def train(
    logger: lavd.Logger,
    model: nn.Module,
    optimiser: optim.Optimizer,  # type: ignore
    train_data_loader: DataLoader,
    validation_data_loaders: DataLoader,
    lr_scheduler: optim.lr_scheduler._LRScheduler,
    device: torch.device,
    checkpoint: Dict,
    num_epochs: int = num_epochs,
    model_kind: str = default_model,
    amp_scaler: Optional[amp.GradScaler] = None,
    masked_lm: bool = True,
):
    start_epoch = checkpoint["epoch"]
    train_stats = checkpoint["train"]
    validation_cp = checkpoint["validation"]
    outdated_validations = checkpoint["outdated_validation"]

    validation_results_dict: Dict[str, Dict] = OrderedDict()
    for val_data_loader in validation_data_loaders:
        val_name = val_data_loader.dataset.name
        val_result = (validation_cp[val_name]
                      if val_name in validation_cp
                      else OrderedDict(start=start_epoch,
                                       stats=OrderedDict(loss=[],
                                                         perplexity=[])))
        validation_results_dict[val_name] = val_result
    # All validations that are no longer used are stored in
    # outdated_validation, just to have them available.
    outdated_validations.append(
        OrderedDict({
            k: v
            for k, v in validation_cp.items()
            if k not in validation_results_dict
        }))

    tokeniser = train_data_loader.dataset.tokeniser  # type: ignore

    for epoch in range(num_epochs):
        actual_epoch = start_epoch + epoch + 1
        epoch_text = "[{current:>{pad}}/{end}] Epoch {epoch}".format(
            current=epoch + 1,
            end=num_epochs,
            epoch=actual_epoch,
            pad=len(str(num_epochs)),
        )
        logger.set_prefix(epoch_text)
        logger.start(epoch_text, prefix=False)
        start_time = time.time()

        logger.start("Train")
        train_result = run_epoch(
            train_data_loader,
            model,
            optimiser,
            device=device,
            epoch=epoch,
            train=True,
            name="Train",
            logger=logger,
            amp_scaler=amp_scaler,
            masked_lm=masked_lm,
        )
        train_stats["stats"]["loss"].append(train_result["loss"])
        train_stats["stats"]["perplexity"].append(train_result["perplexity"])
        epoch_lr = lr_scheduler.get_last_lr()[0]  # type: ignore
        train_stats["lr"].append(epoch_lr)
        lr_scheduler.step()
        logger.end("Train")

        validation_results = []
        for val_data_loader in validation_data_loaders:
            val_name = val_data_loader.dataset.name
            val_text = "Validation: {}".format(val_name)
            logger.start(val_text)
            validation_result = run_epoch(
                val_data_loader,
                model,
                optimiser,
                device=device,
                epoch=epoch,
                train=False,
                name=val_text,
                logger=logger,
                amp_scaler=amp_scaler,
                masked_lm=masked_lm,
            )
            validation_results.append(
                OrderedDict(name=val_name, stats=validation_result))
            validation_results_dict[val_name]["stats"]["loss"].append(
                validation_result["loss"])
            validation_results_dict[val_name]["stats"]["perplexity"].append(
                validation_result["perplexity"])
            logger.end(val_text)

        with logger.spinner("Checkpoint", placement="right"):
            # Multi-gpu models wrap the original model. To make the checkpoint
            # compatible with the original model, the state dict of .module
            # is saved.
            model_unwrapped = (model.module
                               if isinstance(model, DistributedDataParallel)
                               else model)
            save_checkpoint(
                logger,
                model_unwrapped,
                tokeniser,
                stats=OrderedDict(
                    epoch=actual_epoch,
                    train=train_stats,
                    validation=validation_results_dict,
                    outdated_validation=outdated_validations,
                    model=OrderedDict(kind=model_kind),
                ),
                step=actual_epoch,
            )

        with logger.spinner("Logging Data", placement="right"):
            log_results(
                logger,
                actual_epoch,
                OrderedDict(lr=epoch_lr, stats=train_result),
                validation_results,
                model_unwrapped,
            )

        with logger.spinner("Best Checkpoints", placement="right"):
            val_stats = OrderedDict({
                val_name: {
                    "name": val_name,
                    "start": val_result["start"],
                    "stats": val_result["stats"],
                }
                for val_name, val_result in validation_results_dict.items()
            })
            log_top_checkpoints(logger, val_stats, metrics)

        time_difference = time.time() - start_time
        epoch_results = [OrderedDict(name="Train", stats=train_result)
                         ] + validation_results
        log_epoch_stats(logger,
                        epoch_results,
                        metrics,
                        lr=epoch_lr,
                        time_elapsed=time_difference)
        logger.end(epoch_text, prefix=False)

def train_model(
    train_graph: pyg.torch_geometric.data.Data,
    valid_graph: pyg.torch_geometric.data.Data,
    train_dl: data.DataLoader,
    dev_dl: data.DataLoader,
    evaluator: Evaluator,
    model: nn.Module,
    optimizer: optim.Optimizer,
    lr_scheduler: optim.lr_scheduler._LRScheduler,
    args: argparse.Namespace,
) -> nn.Module:
    device = model_utils.get_device()
    loss_fn = nn.functional.binary_cross_entropy
    val_loss_fn = nn.functional.binary_cross_entropy
    best_val_loss = torch.tensor(float('inf'))
    best_val_hits = torch.tensor(0.0)
    saved_checkpoints = []
    writer = SummaryWriter(log_dir=f'{args.log_dir}/{args.experiment}')

    for e in range(1, args.train_epochs + 1):
        print(f'Training epoch {e}...')

        # Training portion
        torch.cuda.empty_cache()
        torch.set_grad_enabled(True)
        with tqdm(total=args.train_batch_size * len(train_dl)) as progress_bar:
            model.train()
            # Load graph into GPU
            adj_t = train_graph.adj_t.to(device)
            edge_index = train_graph.edge_index.to(device)
            x = train_graph.x.to(device)
            pos_pred = []
            neg_pred = []
            for i, (y_pos_edges,) in enumerate(train_dl):
                y_pos_edges = y_pos_edges.to(device).T
                y_neg_edges = negative_sampling(
                    edge_index,
                    num_nodes=train_graph.num_nodes,
                    num_neg_samples=y_pos_edges.shape[1]).to(device)
                # Ground truth edge labels (1 or 0)
                y_batch = torch.cat(
                    [torch.ones(y_pos_edges.shape[1]),
                     torch.zeros(y_neg_edges.shape[1])], dim=0).to(device)

                # Forward pass on model
                optimizer.zero_grad()
                y_pred = model(adj_t,
                               torch.cat([y_pos_edges, y_neg_edges], dim=1))
                loss = loss_fn(y_pred, y_batch)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()
                if args.use_scheduler:
                    lr_scheduler.step(loss)

                batch_acc = torch.mean(
                    1 - torch.abs(y_batch.detach()
                                  - torch.round(y_pred.detach()))).item()
                pos_pred += [y_pred[y_batch == 1].detach()]
                neg_pred += [y_pred[y_batch == 0].detach()]

                progress_bar.update(y_pos_edges.shape[1])
                progress_bar.set_postfix(loss=loss.item(), acc=batch_acc)
                writer.add_scalar(
                    "train/Loss", loss,
                    ((e - 1) * len(train_dl) + i) * args.train_batch_size)
                writer.add_scalar(
                    "train/Accuracy", batch_acc,
                    ((e - 1) * len(train_dl) + i) * args.train_batch_size)

                del y_pos_edges
                del y_neg_edges
                del y_pred
                del loss
            del adj_t
            del edge_index
            del x

        # Training set evaluation Hits@K Metrics
        pos_pred = torch.cat(pos_pred, dim=0)
        neg_pred = torch.cat(neg_pred, dim=0)
        results = {}
        for K in [10, 20, 30]:
            evaluator.K = K
            hits = evaluator.eval({
                'y_pred_pos': pos_pred,
                'y_pred_neg': neg_pred,
            })[f'hits@{K}']
            results[f'Hits@{K}'] = hits
        print()
        print(f'Train Statistics')
        print('*' * 30)
        for k, v in results.items():
            print(f'{k}: {v}')
            writer.add_scalar(
                f"train/{k}", v, (pos_pred.shape[0] + neg_pred.shape[0]) * e)
        print('*' * 30)
        del pos_pred
        del neg_pred

        # Validation portion
        torch.cuda.empty_cache()
        torch.set_grad_enabled(False)
        with tqdm(total=args.val_batch_size * len(dev_dl)) as progress_bar:
            model.eval()
            adj_t = valid_graph.adj_t.to(device)
            edge_index = valid_graph.edge_index.to(device)
            x = valid_graph.x.to(device)
            val_loss = 0.0
            accuracy = 0
            num_samples_processed = 0
            pos_pred = []
            neg_pred = []
            for i, (edges_batch, y_batch) in enumerate(dev_dl):
                edges_batch = edges_batch.T.to(device)
                y_batch = y_batch.to(device)

                # Forward pass on model in validation environment
                y_pred = model(adj_t, edges_batch)
                loss = val_loss_fn(y_pred, y_batch)

                num_samples_processed += edges_batch.shape[1]
                batch_acc = torch.mean(
                    1 - torch.abs(y_batch - torch.round(y_pred))).item()
                accuracy += batch_acc * edges_batch.shape[1]
                val_loss += loss.item() * edges_batch.shape[1]
                pos_pred += [y_pred[y_batch == 1].detach()]
                neg_pred += [y_pred[y_batch == 0].detach()]

                progress_bar.update(edges_batch.shape[1])
                progress_bar.set_postfix(
                    val_loss=val_loss / num_samples_processed,
                    acc=accuracy / num_samples_processed)
                writer.add_scalar(
                    "Val/Loss", loss,
                    ((e - 1) * len(dev_dl) + i) * args.val_batch_size)
                writer.add_scalar(
                    "Val/Accuracy", batch_acc,
                    ((e - 1) * len(dev_dl) + i) * args.val_batch_size)

                del edges_batch
                del y_batch
                del y_pred
                del loss
            del adj_t
            del edge_index
            del x

        # Validation evaluation Hits@K Metrics
        pos_pred = torch.cat(pos_pred, dim=0)
        neg_pred = torch.cat(neg_pred, dim=0)
        results = {}
        for K in [10, 20, 30]:
            evaluator.K = K
            hits = evaluator.eval({
                'y_pred_pos': pos_pred,
                'y_pred_neg': neg_pred,
            })[f'hits@{K}']
            results[f'Hits@{K}'] = hits
        print()
        print(f'Validation Statistics')
        print('*' * 30)
        for k, v in results.items():
            print(f'{k}: {v}')
            writer.add_scalar(
                f"Val/{k}", v, (pos_pred.shape[0] + neg_pred.shape[0]) * e)
        print('*' * 30)
        del pos_pred
        del neg_pred

        # Save model if it's the best one yet.
        if results['Hits@20'] > best_val_hits:
            best_val_hits = results['Hits@20']
            filename = (f'{args.save_path}/{args.experiment}/'
                        f'{model.__class__.__name__}_best_val.checkpoint')
            model_utils.save_model(model, filename)
            print(f'Model saved!')
            print(f'Best validation Hits@20 yet: {best_val_hits}')

        # Save model on checkpoints.
        if e % args.checkpoint_freq == 0:
            filename = (f'{args.save_path}/{args.experiment}/'
                        f'{model.__class__.__name__}_epoch_{e}.checkpoint')
            model_utils.save_model(model, filename)
            print(f'Model checkpoint reached!')
            saved_checkpoints.append(filename)
            # Delete checkpoints if there are too many
            while len(saved_checkpoints) > args.num_checkpoints:
                os.remove(saved_checkpoints.pop(0))

    return model

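# The OGB Evaluator's hits@K above is the fraction of positive edge scores
# ranked above the K-th highest negative score. A self-contained sketch of
# that metric (hits_at_k is a hypothetical helper, not the OGB API):
import torch

def hits_at_k(pos_scores: torch.Tensor, neg_scores: torch.Tensor,
              k: int) -> float:
    """Fraction of positives scoring strictly above the k-th best negative."""
    if neg_scores.numel() < k:
        return 1.0
    kth_neg = torch.topk(neg_scores, k).values[-1]
    return (pos_scores > kth_neg).float().mean().item()

pos = torch.tensor([0.9, 0.8, 0.3])
neg = torch.tensor([0.7, 0.5, 0.4, 0.2])
print(hits_at_k(pos, neg, k=2))  # 0.667: two of three positives beat 0.5
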
def train(epoch: int, model: nn.Module, loader: data.DataLoader,
          criterion: nn.modules.loss._Loss, optimizer: optim.Optimizer,
          scheduler: optim.lr_scheduler._LRScheduler, only_epoch_sche: bool,
          use_amp: bool, accumulated_steps: int, device: str,
          log_interval: int):
    model.train()

    scaler = GradScaler() if use_amp else None
    gradient_accumulator = GradientAccumulator(accumulated_steps)

    loss_metric = AverageMetric("loss")
    accuracy_metric = AccuracyMetric(topk=(1, 5))
    ETA = EstimatedTimeArrival(len(loader))
    speed_tester = SpeedTester()

    lr = optimizer.param_groups[0]['lr']
    _logger.info(f"Train start, epoch={epoch:04d}, lr={lr:.6f}")

    for time_cost, iter_, (inputs, targets) in time_enumerate(loader, start=1):
        inputs, targets = inputs.to(device=device), targets.to(device=device)

        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            outputs = model(inputs)
            loss: torch.Tensor = criterion(outputs, targets)
        gradient_accumulator.backward_step(model, loss, optimizer, scaler)

        if scheduler is not None:
            if only_epoch_sche:
                # Step once per epoch (on the first iteration only).
                if iter_ == 1:
                    scheduler.step()
            else:
                scheduler.step()

        loss_metric.update(loss)
        accuracy_metric.update(outputs, targets)
        ETA.step()
        speed_tester.update(inputs)

        if iter_ % log_interval == 0 or iter_ == len(loader):
            _logger.info(", ".join([
                "TRAIN",
                f"epoch={epoch:04d}",
                f"iter={iter_:05d}/{len(loader):05d}",
                f"fetch data time cost={time_cost*1000:.2f}ms",
                f"fps={speed_tester.compute()*world_size():.0f} images/s",
                f"{loss_metric}",
                f"{accuracy_metric}",
                f"{ETA}",
            ]))
            speed_tester.reset()

    return {
        "lr": lr,
        "train/loss": loss_metric.compute(),
        "train/top1_acc": accuracy_metric.at(1).rate,
        "train/top5_acc": accuracy_metric.at(5).rate,
    }

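# GradientAccumulator above is project-specific. The underlying pattern:
# scale each micro-batch loss by 1/N and call optimizer.step() every N
# batches, so the effective batch size is N times larger. A plain-PyTorch
# sketch (accum_steps and the toy model are illustrative only):
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
accum_steps = 4
batches = [(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(8)]

optimizer.zero_grad()
for i, (x, y) in enumerate(batches, start=1):
    loss = nn.functional.mse_loss(model(x), y) / accum_steps  # average over N
    loss.backward()              # gradients accumulate in .grad buffers
    if i % accum_steps == 0:
        optimizer.step()         # one update per N micro-batches
        optimizer.zero_grad()
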
def train_model(
    train_dl: data.DataLoader,
    dev_dl: data.DataLoader,
    model: nn.Module,
    optimizer: optim.Optimizer,
    lr_scheduler: optim.lr_scheduler._LRScheduler,
    args: argparse.Namespace,
) -> nn.Module:
    device = model_utils.get_device()
    # loss_fn = nn.functional.binary_cross_entropy
    loss_fn = model_utils.l1_norm_loss
    val_loss_fn = model_utils.l1_norm_loss
    best_val_loss = torch.tensor(float('inf'))
    saved_checkpoints = []
    writer = SummaryWriter(log_dir=f'{args.log_dir}/{args.experiment}')
    scalar_rand = torch.distributions.uniform.Uniform(0.5, 1.5)

    for e in range(1, args.train_epochs + 1):
        print(f'Training epoch {e}...')

        # Training portion
        torch.cuda.empty_cache()
        with tqdm(total=args.train_batch_size * len(train_dl)) as progress_bar:
            model.train()
            for i, (x_batch, y_batch_biden, y_batch_trump,
                    _) in enumerate(train_dl):
                # trump_scale = scalar_rand.sample()
                # biden_scale = scalar_rand.sample()
                # y_batch_biden = y_batch_biden * biden_scale
                # y_batch_trump = y_batch_trump * trump_scale
                # x_batch = (y_batch_trump + y_batch_biden).abs().to(device)
                x_batch = x_batch.abs().to(device)
                y_batch_biden = y_batch_biden.abs().to(device)
                y_batch_trump = y_batch_trump.abs().to(device)

                # Forward pass on model
                optimizer.zero_grad()
                y_pred_b, y_pred_t = model(x_batch)
                if args.train_trump:
                    # loss = loss_fn(y_pred_t * x_batch, y_batch_trump)
                    loss = loss_fn(y_pred_t, y_batch_trump)
                else:
                    # loss = loss_fn(y_pred_b * x_batch, y_batch_biden)
                    loss = loss_fn(y_pred_b, y_batch_biden)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()
                if args.use_scheduler:
                    lr_scheduler.step(loss)

                progress_bar.update(len(x_batch))
                progress_bar.set_postfix(loss=loss.item())
                writer.add_scalar(
                    "train/Loss", loss,
                    ((e - 1) * len(train_dl) + i) * args.train_batch_size)

                del x_batch
                del y_batch_biden
                del y_batch_trump
                del y_pred_b
                del y_pred_t
                del loss

        # Validation portion
        torch.cuda.empty_cache()
        with tqdm(total=args.val_batch_size * len(dev_dl)) as progress_bar:
            model.eval()
            val_loss = 0.0
            num_batches_processed = 0
            for i, (x_batch, y_batch_biden, y_batch_trump,
                    _) in enumerate(dev_dl):
                x_batch = x_batch.abs().to(device)
                y_batch_biden = y_batch_biden.abs().to(device)
                y_batch_trump = y_batch_trump.abs().to(device)

                # Forward pass on model
                y_pred_b, y_pred_t = model(x_batch)
                # y_pred_b_mask = torch.ones_like(y_pred_b) * (y_pred_b > args.alpha)
                # y_pred_t_mask = torch.ones_like(y_pred_t) * (y_pred_t > args.alpha)
                y_pred_b_mask = torch.clamp(y_pred_b / x_batch, 0, 1)
                y_pred_t_mask = torch.clamp(y_pred_t / x_batch, 0, 1)
                loss_trump = val_loss_fn(y_pred_t_mask * x_batch,
                                         y_batch_trump)
                loss_biden = val_loss_fn(y_pred_b_mask * x_batch,
                                         y_batch_biden)
                if args.train_trump:
                    val_loss += loss_trump.item()
                else:
                    val_loss += loss_biden.item()
                num_batches_processed += 1

                progress_bar.update(len(x_batch))
                progress_bar.set_postfix(
                    val_loss=val_loss / num_batches_processed)
                writer.add_scalar(
                    "Val/Biden Loss", loss_biden,
                    ((e - 1) * len(dev_dl) + i) * args.val_batch_size)
                writer.add_scalar(
                    "Val/Trump Loss", loss_trump,
                    ((e - 1) * len(dev_dl) + i) * args.val_batch_size)

                del x_batch
                del y_batch_biden
                del y_batch_trump
                del y_pred_b
                del y_pred_t
                del loss_trump
                del loss_biden

        # Save model if it's the best one yet.
        if val_loss / num_batches_processed < best_val_loss:
            best_val_loss = val_loss / num_batches_processed
            filename = (f'{args.save_path}/{args.experiment}/'
                        f'{model.__class__.__name__}_best_val.checkpoint')
            model_utils.save_model(model, filename)
            print(f'Model saved!')
            print(f'Best validation loss yet: {best_val_loss}')

        # Save model on checkpoints.
        if e % args.checkpoint_freq == 0:
            filename = (f'{args.save_path}/{args.experiment}/'
                        f'{model.__class__.__name__}_epoch_{e}.checkpoint')
            model_utils.save_model(model, filename)
            print(f'Model checkpoint reached!')
            saved_checkpoints.append(filename)
            # Delete checkpoints if there are too many
            while len(saved_checkpoints) > args.num_checkpoints:
                os.remove(saved_checkpoints.pop(0))

    return model