def clf_train(net, tloader, opti: torch.optim, crit: nn.Module, **kwargs):
    # TODO Fix this
    if kwargs['topk'] != (1, 5):
        raise Exception('topk other than (1, 5) not supported for now.')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    net.train()
    a1mtr = AvgMeter('train_acc1')
    a5mtr = AvgMeter('train_acc5')
    tloss = 0
    # Instantiate the criterion if a class rather than an instance was passed.
    try:
        crit = crit()
    except TypeError:
        pass
    for ii, (data, labl) in enumerate(tqdm(tloader)):
        data, labl = data.to(device), labl.to(device)
        out = net(data)
        loss = crit(out, labl)
        opti.zero_grad()
        loss.backward()
        opti.step()
        with torch.no_grad():
            tloss += loss.item()
            acc1, acc5 = accuracy(out, labl, topk=kwargs['topk'])
            a1mtr(acc1, data.size(0))
            a5mtr(acc5, data.size(0))
    tloss /= len(tloader)
    return (a1mtr.avg, a5mtr.avg), tloss
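# clf_train relies on an external AvgMeter with a call-style update and an
# .avg attribute. A minimal sketch matching that interface - an assumption
# about the helper, not its original code:
class AvgMeter:
    def __init__(self, name: str):
        self.name = name
        self.sum = 0.
        self.count = 0

    def __call__(self, val: float, n: int = 1):
        # accumulate a value weighted by the batch size n
        self.sum += val * n
        self.count += n

    @property
    def avg(self) -> float:
        return self.sum / max(self.count, 1)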
def train_drug_qed(
        device: torch.device,
        drug_qed_net: nn.Module,
        data_loader: torch.utils.data.DataLoader,
        max_num_batches: int,
        loss_func: callable,
        optimizer: torch.optim,
):
    drug_qed_net.train()
    total_loss = 0.
    num_samples = 0
    for batch_idx, (drug_feature, target) in enumerate(data_loader):
        if batch_idx >= max_num_batches:
            break
        drug_feature, target = drug_feature.to(device), target.to(device)
        drug_qed_net.zero_grad()
        pred_target = drug_qed_net(drug_feature)
        loss = loss_func(pred_target, target)
        loss.backward()
        optimizer.step()
        num_samples += target.shape[0]
        total_loss += loss.item() * target.shape[0]
    print('\tDrug Weighted QED Regression Loss: %8.6f'
          % (total_loss / num_samples))
def train_fn(model: torch.nn, data_loader: DataLoader, optimizer: optim,
             device: torch.device, epoch: int):
    model.train()
    start_time = datetime.datetime.now()
    num_images: int = 0
    for i, (images, targets) in enumerate(data_loader):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        images = torch.stack(images)
        num_images += len(images)
        optimizer.zero_grad()
        loss_dict: Dict[str, torch.Tensor] = model(images, targets)
        loss: torch.Tensor = sum(loss for loss in loss_dict.values())
        loss.backward()
        optimizer.step()
        if (i + 1) % 10 == 0:
            print('-' * 50)
            print(
                f'Epoch {epoch+1} [{num_images:,}/{len(data_loader.dataset):,} '
                f'({(num_images/len(data_loader.dataset))*100:.2f}%)] '
                f'- Elapsed time: {datetime.datetime.now() - start_time}\n'
                f' - loss: classifier={loss_dict["loss_classifier"]:.6f}, '
                f'box_reg={loss_dict["loss_box_reg"]:.6f}, '
                f'objectness={loss_dict["loss_objectness"]:.6f}, '
                f'rpn_box_reg={loss_dict["loss_rpn_box_reg"]:.6f}'
            )
def loadCheckpoint(checkpoint_path: str, model: nn.Module, optimizer: optim,
                   scheduler: optim.lr_scheduler.MultiStepLR):
    """
    Load a training checkpoint from a .pth file

    Parameters
    ----------
    checkpoint_path : str
        path to the checkpoint file
    model, optimizer, scheduler :
        the instances whose states are restored from the checkpoint

    Return
    ------
    model, optimizer, resume_epoch, resume_iteration, scheduler
    """
    state = torch.load(checkpoint_path)
    resume_epoch = state['epoch']
    resume_iteration = state['iteration']
    model.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    scheduler.load_state_dict(state['scheduler'])
    return model, optimizer, resume_epoch, resume_iteration, scheduler
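# The checkpoint keys above ('epoch', 'iteration', 'state_dict', 'optimizer',
# 'scheduler') imply a matching save routine. A minimal sketch of such a
# counterpart, assuming the same key layout (the name saveCheckpoint is
# hypothetical, not part of the original code):
def saveCheckpoint(checkpoint_path: str, epoch: int, iteration: int,
                   model: nn.Module, optimizer: optim,
                   scheduler: optim.lr_scheduler.MultiStepLR):
    state = {
        'epoch': epoch,
        'iteration': iteration,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
    }
    torch.save(state, checkpoint_path)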
def train_on_batch(model: Tree2Seq, criterion: nn.modules.loss,
                   optimizer: torch.optim, scheduler: torch.optim.lr_scheduler,
                   graph: dgl.BatchedDGLGraph, labels: List[str], params: Dict,
                   device: torch.device) -> Dict:
    model.train()
    root_indexes = get_root_indexes(graph).to(device)

    # Model step
    model.zero_grad()
    root_logits, ground_truth = model(graph, root_indexes, labels,
                                      params['teacher_force'], device)
    root_logits = root_logits[1:]
    ground_truth = ground_truth[1:]
    loss = criterion(root_logits.view(-1, root_logits.shape[-1]),
                     ground_truth.view(-1))
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), params['clip_norm'])
    optimizer.step()
    scheduler.step()

    # Calculate metrics
    prediction = model.predict(root_logits)
    batch_train_info = {
        'loss': loss.item(),
        'statistics': calculate_batch_statistics(
            ground_truth, prediction,
            [model.decoder.label_to_id[token] for token in [PAD, UNK, EOS]])
    }
    return batch_train_info
def load_model_checkpoint(model: torch.nn.Module, filename: str, inference: bool,
                          map_location=None, optimizer: torch.optim = None):
    """
    Load a model checkpoint
    :param model: model instance to restore the weights into
    :param filename: path to the checkpoint file
    :param inference: if True, put the model in eval mode, otherwise train mode
    :param map_location: passed through to torch.load
    :param optimizer: optional optimizer whose state is also restored
    :return: the model with the restored state
    """
    checkpoint = torch.load(filename, map_location=map_location)
    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.load_state_dict(checkpoint['model_state_dict'])
    if inference:
        model.eval()
    else:
        model.train()
    return model
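# Example usage of load_model_checkpoint, assuming a checkpoint produced with
# torch.save({'model_state_dict': ..., 'optimizer_state_dict': ...}, path);
# MyModel and 'model.pth' are illustrative names only:
#
#   model = MyModel()
#   optimizer = torch.optim.Adam(model.parameters())
#   model = load_model_checkpoint(model, 'model.pth', inference=False,
#                                 map_location='cpu', optimizer=optimizer)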
def gradient_update(
        self,
        device: str,
        optimizer: torch.optim,
        gamma: float,
        batch: List[Tuple[np.ndarray, np.ndarray, int, float]]
) -> float:
    self.model.train()
    preds = self.model.forward_np_array(
        device=device,
        x=np.array([x[0] for x in batch])
    )
    labels = preds.clone().detach()
    labels = labels.to(device)
    next_frames_preds = self.model_target.forward_np_array(
        device=device,
        x=np.array([x[1] if x[1] is not None else x[0] for x in batch])
    ).detach()
    for i, b in enumerate(batch):
        _, next_frame, action, reward = b
        if next_frame is None:  # terminal state
            labels[i][action] = reward
        else:
            labels[i][action] = reward + gamma * max(next_frames_preds[i])
    loss = self.criterion(preds, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return float(loss)
def train(model, optimizer: torch.optim, data: torch_geometric.data.Data,
          perturbation, gamma):
    """
    trains the model for one epoch

    Parameters
    ----------
    model: Model
    optimizer: torch.optim
    data: torch_geometric.data.Data
    perturbation: the perturbation applied in the forward pass
    gamma: weight of the regularization term R in the loss
    """
    model.train()
    optimizer.zero_grad()
    y_hat, R = model.forward(perturbation=perturbation, grad_perturbation=False)
    # NOTE: accuracy goes up, then down - needs investigation
    y_hat = y_hat[model.data.train_mask]
    loss = F.nll_loss(y_hat, data.y[model.data.train_mask]) + gamma * R
    loss.backward()
    optimizer.step()
    model.eval()
def load_optim(optimizer: torch.optim, checkpoint_path: str,
               device: torch.device) -> torch.optim:
    """
    Load optimizer state to continue training

    Args:
        optimizer      : initialized optimizer
        checkpoint_path: path to the checkpoint
        device         : device to send optimizer state to
                         (must be the same as in the model)

    Note: must be called after initializing the model

    Output:
        optimizer with the loaded state
    """
    checkpoint = torch.load(checkpoint_path)
    optimizer.load_state_dict(checkpoint['optimizer'])
    # Move all optimizer state tensors (e.g. momentum buffers) to the device.
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to(device)
    for param_group in optimizer.param_groups:
        print('learning_rate: {}'.format(param_group['lr']))
    print('Loaded optimizer {} state from {}'.format(optimizer, checkpoint_path))
    return optimizer
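# Example usage of load_optim. As the docstring notes, the optimizer must be
# built over the already-initialized model before its state is restored;
# build_model, 'checkpoint.pth', and the Adam choice are illustrative only:
#
#   model = build_model().to(device)          # model weights restored first
#   optimizer = torch.optim.Adam(model.parameters())
#   optimizer = load_optim(optimizer, 'checkpoint.pth', device)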
def train_loop(model, opt: torch.optim, train_loader, val_loader=None,
               scheduler=None, batch_size=2, dice_loss_beta=1., num_epochs=60,
               save=False, validate=True, device='cpu', save_name='unet'):
    training_scores = []
    train_batch_gen = torch.utils.data.DataLoader(train_loader,
                                                  batch_size=batch_size,
                                                  shuffle=True)
    validate = validate and val_loader is not None
    if validate:
        validation_scores = []
        val_batch_gen = torch.utils.data.DataLoader(val_loader,
                                                    batch_size=batch_size,
                                                    shuffle=False)
    for epoch in range(num_epochs):
        print(f'Epoch number: {epoch}')
        start_time = time.time()
        model.train()
        epoch_train_loss = []
        for (X_image, X_mask) in train_batch_gen:
            X_image = X_image.to(device)
            pred_mask = model(X_image)[:, 0].contiguous().view(-1)
            true_mask = X_mask[:, 0].contiguous().view(-1).to(device)
            loss = dice_loss(true_mask, pred_mask, dice_loss_beta)
            loss.backward()
            opt.step()
            opt.zero_grad()
            if scheduler is not None:
                scheduler.step()
            epoch_train_loss.append(loss.data.cpu().numpy())
        training_scores.append(np.mean(epoch_train_loss))
        if validate:
            model.eval()
            masks_2_pred = predict_val(val_batch_gen, model, device)
            val_preds = np.vstack(list(masks_2_pred.values()))
            val_true = np.vstack(list(masks_2_pred.keys()))
            val_iou = calc_iou(val_preds, val_true)
            validation_scores.append(val_iou)
            print('validation iou is {}'.format(val_iou))
        print(f'Training epoch loss: {training_scores[-1]}')
        print("Epoch {} of {} took {:.3f}s".format(
            epoch + 1, num_epochs, time.time() - start_time))
    if save:
        torch.save(model.state_dict(), f'{save_name}')
    return (training_scores, validation_scores) if validate else training_scores
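# train_loop relies on an external dice_loss(true, pred, beta). A minimal
# sketch of one common formulation it could correspond to - a soft
# F-beta/Dice loss over flattened masks - assuming pred_mask holds
# probabilities in [0, 1] (an assumption, not the original implementation):
def dice_loss(true_mask: torch.Tensor, pred_mask: torch.Tensor,
              beta: float = 1., eps: float = 1e-7) -> torch.Tensor:
    tp = (true_mask * pred_mask).sum()          # soft true positives
    fp = ((1 - true_mask) * pred_mask).sum()    # soft false positives
    fn = (true_mask * (1 - pred_mask)).sum()    # soft false negatives
    beta2 = beta ** 2
    score = ((1 + beta2) * tp + eps) / ((1 + beta2) * tp + beta2 * fn + fp + eps)
    return 1 - score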
def _loop_inference(
        self, gt_labels: torch.Tensor, x: torch.Tensor, y: torch.Tensor,
        optim_inf: torch.optim, training: bool
) -> torch.Tensor:
    if gt_labels is not None:  # Adversarial
        output = self.model(x, y)
        oracle = self.oracle_value(y, gt_labels, training)
        # this is the BCE loss with logits
        value = self.loss_fn(output, oracle)
    else:
        output = self.model(x, y)
        value = torch.sigmoid(output)
    grad = torch.autograd.grad(value, y, grad_outputs=torch.ones_like(value),
                               only_inputs=True)
    y_grad = grad[0].detach()
    if gt_labels is None and self.use_hamming_metric:
        # We want to *reduce* the Hamming loss in this case
        y = y - optim_inf.update(y_grad)
    else:
        y = y + optim_inf.update(y_grad)
    # Project back to the valid range
    y = torch.clamp(y, 0, 1)
    return y
def train_embedding_model(model: nn.Module, data: iter, n_epochs: int,
                          criterion: nn.modules.loss,
                          optimizer: torch.optim) -> None:
    """
    :param model - class inherited from nn.Module with DL model
    :param data - iterator for batching data
    :param n_epochs - number of epochs
    :param criterion - loss for model
    :param optimizer - optimizer from torch for model
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    for epoch in range(n_epochs):
        for y, cont_x, cat_x in data:
            cat_x = cat_x.to(device)
            cont_x = cont_x.to(device)
            y = y.to(device)
            preds = model(cont_x, cat_x)
            loss = criterion(preds, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f'loss on epoch {epoch} is {loss}')
def train_resp(
        device: torch.device,
        resp_net: nn.Module,
        data_loader: torch.utils.data.DataLoader,
        max_num_batches: int,
        loss_func: callable,
        optimizer: torch.optim,
):
    resp_net.train()
    total_loss = 0.
    num_samples = 0
    for batch_idx, (*ids, rnaseq, drug_feature, conc, grth) \
            in enumerate(data_loader):
        if batch_idx >= max_num_batches:
            break
        rnaseq, drug_feature, conc, grth = \
            rnaseq.to(device), drug_feature.to(device), \
            conc.to(device), grth.to(device)
        resp_net.zero_grad()
        pred_growth = resp_net(rnaseq, drug_feature, conc)
        loss = loss_func(pred_growth, grth)
        loss.backward()
        optimizer.step()
        num_samples += conc.shape[0]
        total_loss += loss.item() * conc.shape[0]
    print('\tDrug Response Regression Loss: %8.2f' % (total_loss / num_samples))
def train(self, epoch_s: int, epoch_e: int, data: Data, n_samples: int,
          optimizer: torch.optim, device: torch.device,
          strategy: str = 'max', mode: bool = True) -> None:
    train_time = time.time()
    prefix_sav = f'./model_save/WNGat_{train_time}'
    loss_list = []
    super().train()
    negloss = NEGLoss(data.x, data.edge_index, n_samples)
    for epoch in range(epoch_s, epoch_e):
        optimizer.zero_grad()
        oup = self.forward(data.x, data.edge_index)
        loss = negloss(oup, data.edge_index)
        loss_list.append(loss.data)
        sr_params = {'oup': oup}
        sr_rls = sr_test(device, self.emb, strategy, **sr_params)
        save_model(epoch, self, optimizer, loss_list, prefix_sav, oup, sr=sr_rls)
        loss.backward()
        optimizer.step()
def train(model: nn.Module, iterator: BucketIterator, optimizer: optim,
          criterion: nn.Module, clip: float) -> float:
    """
    Trains the NCN model for a single epoch.
    Based on: https://github.com/bentrevett/pytorch-seq2seq.

    ## Parameters:

    - **model** *(nn.Module)*: The model optimized by this function.
    - **iterator** *(BucketIterator)*: Bucketized iterator containing the training data.
    - **optimizer** *(optim)*: Torch gradient descent optimizer used to train the model.
    - **criterion** *(nn.Module)*: Loss function for training the model.
    - **clip** *(float)*: Apply gradient clipping at the given value.

    ## Output:

    - **loss** *(float)*: Epoch loss.
    """
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        # unpack and move to GPU if available
        cntxt, citing, ttl, cited = (batch.context, batch.authors_citing,
                                     batch.title_cited, batch.authors_cited)
        cntxt = cntxt.to(DEVICE)
        citing = citing.to(DEVICE)
        ttl = ttl.to(DEVICE)
        cited = cited.to(DEVICE)

        optimizer.zero_grad()
        output = model(context=cntxt, title=ttl, authors_citing=citing,
                       authors_cited=cited)

        # ttl = [trg sent len, batch size]
        # output = [trg sent len, batch size, output dim]
        output = output[1:].view(-1, output.shape[-1])
        ttl = ttl[1:].view(-1)

        # ttl = [(trg sent len - 1) * batch size]
        # output = [(trg sent len - 1) * batch size, output dim]
        loss = criterion(output, ttl)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()

    return epoch_loss / len(iterator)
def pretrain_value_model(model: torch.nn.Module, optimizer: torch.optim,
                         epochs=1, single_batch=False):
    print('Pretraining Value model')
    data = minerl.data.make('MineRLTreechop-v0', data_dir=os.environ['DATASET_DIR'])
    criterion = MSELoss()
    for idx, (frames, target_rewards) in enumerate(next_batch(data, epochs, 512), start=1):
        frames = frames.to(DEVICE)
        target_rewards = target_rewards.to(DEVICE)

        # Clear gradients
        optimizer.zero_grad()

        prediction = model(frames)
        loss = criterion(prediction.squeeze(), target_rewards)
        loss.backward()
        optimizer.step()

        if single_batch:
            del data
            return
def train(model: nn.Module, train_loader: torch.utils.data.dataloader.DataLoader,
          optimizer: torch.optim, epoch: int):
    train_loss = 0
    train_loss_list = []
    batch_list = []
    num_data = 0
    device = torch_device(model)
    model.train()
    for X, target in train_loader:
        batch_size = X.size(0)
        num_data += batch_size
        X, target = X.to(device), target.to(device)
        output = model(X)
        loss = _loss_DeepAnT(output, target)
        train_loss += loss.item()
        train_loss_list.append(loss.item())
        batch_list.append(epoch - 1 + (num_data / len(train_loader.sampler)))
        # backpropagation and weight update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    avg_train_loss = train_loss / num_data
    return avg_train_loss, train_loss_list, batch_list
def train(model, targeted: bool, attacked_nodes: torch.Tensor,
          y_targets: torch.Tensor, optimizer: torch.optim):
    """
    trains the attack for one epoch

    Parameters
    ----------
    model: Model
    targeted: bool
    attacked_nodes: torch.Tensor
    y_targets: torch.Tensor - the target labels of the attack
    optimizer: torch.optim
    """
    model.train()
    optimizer.zero_grad()
    attacked_nodes = [attacked_nodes.item()]
    model_output = model()[attacked_nodes]
    if torch.sum(model_output - model_output[:y_targets.shape[0], y_targets]) == 0:
        model.eval()
        model_output = model()[attacked_nodes]
    loss = F.nll_loss(model_output, y_targets)
    loss = loss if targeted else -loss
    loss.backward()
    optimizer.step()
    model.eval()
def train(model, optimizer: torch.optim, data: torch_geometric.data.Data,
          attacked_nodes: torch.Tensor, attacked_x: torch.Tensor,
          adv_scale: int = 1):
    """
    trains the model with both losses - clean and adversarial, for one epoch

    Parameters
    ----------
    model: Model
    optimizer: torch.optim
    data: torch_geometric.data.Data
    attacked_nodes: torch.Tensor - the victim nodes
    attacked_x: torch.Tensor - the feature matrices after the attack
    adv_scale: int - the lambda scale hyperparameter between the two losses
    """
    model.train()
    optimizer.zero_grad()

    basic_loss = F.nll_loss(model()[data.train_mask], data.y[data.train_mask])
    adv_loss = F.nll_loss(model(attacked_x)[attacked_nodes],
                          data.y[attacked_nodes])
    loss = basic_loss + adv_scale * adv_loss

    loss.backward()
    optimizer.step()
    model.eval()
def train(model: torch.nn.Module, preprocessing: PreProcessing,
          optimizer: torch.optim, loss_fn) -> torch.Tensor:
    model.train()
    train_loss = []
    for batch_idx, (inputs, target) in enumerate(
            preprocessing.dataloader.train_loader):
        inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()
        if args.model_type == 'vae':
            output, mu, logvar = model(inputs.float())
            loss_vector = cl.vae_loss(output.float(), target.float(), mu, logvar)
            loss_per_dim = torch.sum(loss_vector, dim=0)
        else:
            output = model(inputs.float())
            loss_per_dim = loss_fn(output.float(), target.float())
        train_loss.append(sum(loss_per_dim) / len(loss_per_dim))
        # accumulate gradients from each output dimension, then update once
        for loss in loss_per_dim:
            loss.backward(retain_graph=True)
        optimizer.step()
    mean_loss = sum(train_loss) / (batch_idx + 1)
    mean_loss = mean_loss.detach()
    return mean_loss
def train(model: nn.Module, optimizer: optim, loss_fn,
          train_loader: DataLoader, test_loader: DataLoader,
          params: utils.Params, epoch: int) -> np.ndarray:
    '''Train the model on one epoch by batches.
    Args:
        model: (torch.nn.Module) the neural network
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: a function that takes outputs and labels per timestep,
            and then computes the loss for the batch
        train_loader: load train data and labels
        test_loader: load test data and labels
        params: (Params) hyperparameters
        epoch: (int) the current training epoch
    '''
    model.train()
    loss_epoch = np.zeros(len(train_loader))
    # Train_loader:
    # train_batch ([batch_size, train_window, 1+cov_dim]): z_{0:T-1} + x_{1:T}, note that z_0 = 0;
    # idx ([batch_size]): one integer denoting the time series id;
    # labels_batch ([batch_size, train_window]): z_{1:T}.
    for i, (train_batch, idx, labels_batch) in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        batch_size = train_batch.shape[0]

        train_batch = train_batch.permute(1, 0, 2).to(torch.float32).to(params.device)  # not scaled
        labels_batch = labels_batch.permute(1, 0).to(torch.float32).to(params.device)  # not scaled
        idx = idx.unsqueeze(0).to(params.device)

        loss = torch.zeros(1, device=params.device)
        hidden = model.init_hidden(batch_size)
        cell = model.init_cell(batch_size)

        for t in range(params.train_window):
            # if z_t is missing, replace it by output mu from the last time step
            zero_index = (train_batch[t, :, 0] == 0)
            if t > 0 and torch.sum(zero_index) > 0:
                train_batch[t, zero_index, 0] = mu[zero_index]
            mu, sigma, hidden, cell = model(
                train_batch[t].unsqueeze_(0).clone(), idx, hidden, cell)
            loss += loss_fn(mu, sigma, labels_batch[t])

        loss.backward()
        optimizer.step()
        loss = loss.item() / params.train_window  # loss per timestep
        loss_epoch[i] = loss
        if i % 1000 == 0:
            test_metrics = evaluate(model, loss_fn, test_loader, params, epoch,
                                    sample=args.sampling)
            model.train()
            logger.info(f'train_loss: {loss}')
        if i == 0:
            logger.info(f'train_loss: {loss}')
    return loss_epoch
def loadCheckpoint(checkpoint_path: str, model: nn.Module, optimizer: optim,
                   scheduler: optim.lr_scheduler.MultiStepLR):
    state = torch.load(checkpoint_path)
    resume_epoch = state['epoch']
    model.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    scheduler.load_state_dict(state['scheduler'])
    return model, optimizer, resume_epoch, scheduler
def train(self, model: torchvision.models, criterion: torch.nn,
          optimizer: torch.optim, train_dataset: ImageFoldersDataset,
          test_dataset: ImageFoldersDataset, n_epochs: int = 25,
          batch_size: int = 32, shuffle: bool = True, *args, **kwargs):
    # TODO(lukasz): add scheduler for learning rate
    metrics = defaultdict(list)
    best_score_test = 0.
    model = model.to(self.device)
    for epoch in range(n_epochs):
        model.train()
        running_loss = 0.
        for data_idx, data in enumerate(
                train_dataset.loader(
                    batch_size=batch_size,
                    shuffle=shuffle
                    # TODO(lukasz): add sampler for imbalanced dataset
                )):
            inputs, labels = data
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            # TODO(lukasz): add as argument
            if data_idx % 100 == 0:
                msg = '[%d, %5d] loss: %.3f'
                print(msg % (epoch + 1, data_idx + 1, running_loss / 100))
                running_loss = 0.
        score_train = self.score(model, train_dataset)
        score_test = self.score(model, test_dataset)
        metrics['score_train'].append(score_train)
        metrics['score_test'].append(score_test)
        msg = '[%d] train score: %.3f, test score: %.3f'
        print(msg % (epoch + 1, score_train, score_test))
        # save model (make sure that Google Colab does not destroy your results)
        if score_test > best_score_test:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()
                }, self.save_experiment)
            best_score_test = score_test
    self.metrics = metrics
    return self
def load_checkpoint(self, model: torch.nn.Module, optimizer: torch.optim):
    state = torch.load(self.state_dir)
    try:
        model.load_state_dict(state['model_state_dict'])
    except RuntimeError:
        # The checkpoint was saved from a DataParallel-wrapped model; strip
        # the 'module.' prefix from every key and retry.
        new_state_dict = OrderedDict()
        for k, v in state['model_state_dict'].items():
            name = k[7:]  # drop leading 'module.'
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
    optimizer.load_state_dict(state['optimizer_state_dict'])
    return model, optimizer
def normal_train(model: nn.Module, optimizer: torch.optim,
                 data_loader: torch.utils.data.DataLoader):
    model.train()
    epoch_loss = 0
    for input, target in data_loader:
        input, target = variable(input), variable(target)
        optimizer.zero_grad()
        output = model(input)
        loss = F.cross_entropy(output, target)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
    return epoch_loss / len(data_loader)
def model_save(model: torch.nn.Module, encoder_optimizer: torch.optim,
               decoder_optimizer: torch.optim, loss, latent_dim, ckpt_dir):
    torch.save(
        {
            'model_state_dict': model.state_dict(),
            'encoder_optimizer_state_dict': encoder_optimizer.state_dict(),
            'decoder_optimizer_state_dict': decoder_optimizer.state_dict(),
            'loss': loss,
            'latent_dim': latent_dim,
            'model': model
        }, ckpt_dir)
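# A matching restore sketch for the checkpoint written by model_save, reading
# back the same keys (the name model_load is hypothetical; the pickled 'model'
# entry is ignored in favor of loading the state dict into a fresh instance):
def model_load(model: torch.nn.Module, encoder_optimizer: torch.optim,
               decoder_optimizer: torch.optim, ckpt_path: str):
    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt['model_state_dict'])
    encoder_optimizer.load_state_dict(ckpt['encoder_optimizer_state_dict'])
    decoder_optimizer.load_state_dict(ckpt['decoder_optimizer_state_dict'])
    return model, encoder_optimizer, decoder_optimizer, ckpt['loss'], ckpt['latent_dim']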
def clf_fit(net: nn.Module, crit: nn.Module, opti: torch.optim,
            tloader, vloader, **kwargs):
    """
    This function is used to train the classification networks.
    """
    epochs = kwargs['epochs']
    lr = kwargs['lr']
    lr_step = kwargs['lr_step']
    lr_decay = kwargs['lr_decay']
    seed = kwargs['seed'] if kwargs['seed'] else np.random.randint(100)
    bloss = float('inf')

    torch.manual_seed(seed)
    np.random.seed(seed)
    print('[INFO] Setting torch seed to {}'.format(seed))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tlist = []
    vlist = []
    for e in range(1, epochs + 1):
        if lr_step is not None and isinstance(lr_step, int) and e % lr_step == 0:
            lr = adjust_lr(opti, lr, lr_decay)
        if lr_step is not None and isinstance(lr_step, list) and e in lr_step:
            lr = adjust_lr(opti, lr, lr_decay)

        tacc, tloss = clf_train(net, tloader, opti, crit, topk=kwargs['topk'])
        vacc, vloss = clf_test(net, vloader, crit, topk=kwargs['topk'])
        tlist.append((tacc, tloss))
        vlist.append((vacc, vloss))

        if vloss < bloss:
            bloss = vloss
            torch.save({
                'net': net.state_dict(),
                'opti': opti.state_dict()
            }, 'best_net-{}-{:.2f}.pth'.format(e, vacc[0]))

        # TODO The tloss and vloss needs a recheck.
        print('Epoch: {}/{} - Train Loss: {:.3f} - Train Acc@1: {:.3f} '
              '- Train Acc@5: {:.3f} - Val Loss: {:.3f} - Val Acc@1: {:.3f} '
              '- Val Acc@5: {:.3f}'.format(e, epochs, tloss, tacc[0], tacc[1],
                                           vloss, vacc[0], vacc[1]))

    torch.save({
        'net': net.cpu().state_dict(),
        'opti': opti.state_dict()
    }, 'net-{}-{:.2f}.pth'.format(e, vacc[0]))
    return tlist, vlist
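# clf_fit relies on an external adjust_lr helper. A minimal sketch consistent
# with how it is called above (write the decayed rate into the optimizer and
# return it) - an assumption, not the original helper:
def adjust_lr(opti: torch.optim, lr: float, lr_decay: float) -> float:
    new_lr = lr * lr_decay
    for param_group in opti.param_groups:
        param_group['lr'] = new_lr
    return new_lr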
def train(model: nn.Module, device, train_loader: DataLoader,
          optimizer: torch.optim, epoch):
    model.train()
    for batch_id, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if batch_id % 10 == 0:
            print(f"batch: {batch_id} epoch: {epoch} loss: {loss.item()}")
def load_model_checkpoint(model: torch.nn.Module, filename: str, inference: bool,
                          map_location=None, optimizer: torch.optim = None):
    checkpoint = torch.load(filename, map_location=map_location)
    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.load_state_dict(checkpoint['model_state_dict'])
    if inference:
        model.eval()
    else:
        model.train()
    return model
def ewc_train(model: nn.Module, optimizer: torch.optim,
              data_loader: torch.utils.data.DataLoader, ewc: EWC,
              importance: float):
    model.train()
    epoch_loss = 0
    for input, target in data_loader:
        input, target = variable(input), variable(target)
        optimizer.zero_grad()
        output = model(input)
        # standard cross-entropy plus the importance-weighted EWC penalty
        loss = F.cross_entropy(output, target) + importance * ewc.penalty(model)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
    return epoch_loss / len(data_loader)
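# ewc_train adds importance * ewc.penalty(model) to the cross-entropy loss.
# A minimal sketch of the textbook (diagonal-Fisher) EWC penalty that call
# could correspond to - an assumption about the EWC class, not its actual
# code: penalty(theta) = sum_i F_i * (theta_i - theta_star_i)^2
def ewc_penalty_sketch(model: nn.Module, fisher: dict,
                       old_params: dict) -> torch.Tensor:
    loss = 0.
    for name, param in model.named_parameters():
        # fisher[name]: diagonal Fisher information estimate;
        # old_params[name]: parameters frozen after the previous task
        loss = loss + (fisher[name] * (param - old_params[name]) ** 2).sum()
    return loss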