def train_fn(model: torch.nn.Module, data_loader: DataLoader, optimizer: optim.Optimizer,
             device: torch.device, epoch: int):
    model.train()
    start_time = datetime.datetime.now()
    num_images: int = 0
    for i, (images, targets) in enumerate(data_loader):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        images = torch.stack(images)
        num_images += len(images)

        optimizer.zero_grad()
        loss_dict: Dict[str, torch.Tensor] = model(images, targets)
        loss: torch.Tensor = sum(loss for loss in loss_dict.values())
        loss.backward()
        optimizer.step()

        if (i + 1) % 10 == 0:
            print('-' * 50)
            print(
                f'Epoch {epoch+1} [{num_images:,}/{len(data_loader.dataset):,} '
                f'({(num_images / len(data_loader.dataset)) * 100:.2f}%)] '
                f'- Elapsed time: {datetime.datetime.now() - start_time}\n'
                f' - loss: classifier={loss_dict["loss_classifier"]:.6f}, box_reg={loss_dict["loss_box_reg"]:.6f}, '
                f'objectness={loss_dict["loss_objectness"]:.6f}, rpn_box_reg={loss_dict["loss_rpn_box_reg"]:.6f}'
            )

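# Usage sketch for train_fn (illustrative, not from the original project): wires it to a
# torchvision Faster R-CNN, whose loss dict uses exactly the keys printed above. The
# `dataset` argument stands in for a detection dataset yielding (image, target) pairs;
# the collate_fn keeps variable-size targets grouped as tuples.
def run_detection_training(dataset, n_epochs: int = 10):
    import torchvision

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(num_classes=2).to(device)
    loader = DataLoader(dataset, batch_size=2,
                        collate_fn=lambda batch: tuple(zip(*batch)))
    sgd = optim.SGD(detector.parameters(), lr=0.005, momentum=0.9, weight_decay=5e-4)
    for epoch in range(n_epochs):
        train_fn(detector, loader, sgd, device, epoch)
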
def pretrain_value_model(model: torch.nn.Module, optimizer: torch.optim.Optimizer,
                         epochs=1, single_batch=False):
    print('Pretraining Value model')
    data = minerl.data.make('MineRLTreechop-v0', data_dir=os.environ['DATASET_DIR'])
    criterion = MSELoss()
    for idx, (frames, target_rewards) in enumerate(next_batch(data, epochs, 512), start=1):
        frames = frames.to(DEVICE)
        target_rewards = target_rewards.to(DEVICE)

        # Clear gradients
        optimizer.zero_grad()

        prediction = model(frames)
        loss = criterion(prediction.squeeze(), target_rewards)
        loss.backward()
        optimizer.step()

        if single_batch:
            del data
            return

def train(model, optimizer: torch.optim.Optimizer, data: torch_geometric.data.Data,
          attacked_nodes: torch.Tensor, attacked_x: torch.Tensor, adv_scale: int = 1):
    """
    Trains the model with both losses, clean and adversarial, for one epoch.

    Parameters
    ----------
    model: Model
    optimizer: torch.optim.Optimizer
    data: torch_geometric.data.Data
    attacked_nodes: torch.Tensor - the victim nodes
    attacked_x: torch.Tensor - the feature matrix after the attack
    adv_scale: int - the lambda scale hyperparameter between the two losses
    """
    model.train()
    optimizer.zero_grad()

    basic_loss = F.nll_loss(model()[data.train_mask], data.y[data.train_mask])
    adv_loss = F.nll_loss(model(attacked_x)[attacked_nodes], data.y[attacked_nodes])
    loss = basic_loss + adv_scale * adv_loss

    loss.backward()
    optimizer.step()
    model.eval()

def clf_train(net, tloader, opti: torch.optim.Optimizer, crit: nn.Module, **kwargs):
    # TODO Fix this
    if kwargs['topk'] != (1, 5):
        raise Exception('topk other than (1, 5) not supported for now.')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    net.train()

    a1mtr = AvgMeter('train_acc1')
    a5mtr = AvgMeter('train_acc5')
    tloss = 0

    # Accept either a criterion class or an already-instantiated criterion.
    try:
        crit = crit()
    except TypeError:
        pass

    for ii, (data, labl) in enumerate(tqdm(tloader)):
        data, labl = data.to(device), labl.to(device)
        out = net(data)
        loss = crit(out, labl)

        opti.zero_grad()
        loss.backward()
        opti.step()

        with torch.no_grad():
            tloss += loss.item()
            acc1, acc5 = accuracy(out, labl, topk=kwargs['topk'])
            a1mtr(acc1, data.size(0))
            a5mtr(acc5, data.size(0))

    tloss /= len(tloader)
    return (a1mtr.avg, a5mtr.avg), tloss

def train(model: nn.Module, iterator: BucketIterator, optimizer: optim.Optimizer,
          criterion: nn.Module, clip: float) -> float:
    """
    Trains the NCN model for a single epoch.
    Based on: https://github.com/bentrevett/pytorch-seq2seq.

    ## Parameters:

    - **model** *(nn.Module)*: The model optimized by this function.
    - **iterator** *(BucketIterator)*: Bucketized iterator containing the training data.
    - **optimizer** *(optim.Optimizer)*: Torch gradient descent optimizer used to train the model.
    - **criterion** *(nn.Module)*: Loss function for training the model.
    - **clip** *(float)*: Apply gradient clipping at the given value.

    ## Output:

    - **loss** *(float)*: Epoch loss.
    """
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        # unpack and move to GPU if available
        cntxt, citing, ttl, cited = batch.context, batch.authors_citing, batch.title_cited, batch.authors_cited
        cntxt = cntxt.to(DEVICE)
        citing = citing.to(DEVICE)
        ttl = ttl.to(DEVICE)
        cited = cited.to(DEVICE)

        optimizer.zero_grad()
        output = model(context=cntxt, title=ttl, authors_citing=citing, authors_cited=cited)

        # ttl = [trg sent len, batch size]
        # output = [trg sent len, batch size, output dim]
        output = output[1:].view(-1, output.shape[-1])
        ttl = ttl[1:].view(-1)

        # ttl = [(trg sent len - 1) * batch size]
        # output = [(trg sent len - 1) * batch size, output dim]
        loss = criterion(output, ttl)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()

    return epoch_loss / len(iterator)

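# Illustrative driver for the NCN train() above. This is a sketch: the names
# `ncn_model`, `train_iterator`, and `pad_idx` are assumptions, and CrossEntropyLoss
# is one plausible criterion, not necessarily the one the original project used.
def run_ncn_training(ncn_model: nn.Module, train_iterator: BucketIterator,
                     pad_idx: int, n_epochs: int = 10, clip: float = 1.0):
    opt = optim.Adam(ncn_model.parameters())
    # ignore_index keeps padded title positions out of the loss
    crit = nn.CrossEntropyLoss(ignore_index=pad_idx)
    for epoch in range(n_epochs):
        epoch_loss = train(ncn_model, train_iterator, opt, crit, clip)
        print(f'Epoch {epoch + 1}: train loss = {epoch_loss:.4f}')
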
def gradient_update(
        self,
        device: str,
        optimizer: torch.optim.Optimizer,
        gamma: float,
        batch: List[Tuple[np.ndarray, np.ndarray, int, float]]
) -> float:
    self.model.train()
    preds = self.model.forward_np_array(
        device=device, x=np.array([x[0] for x in batch])
    )
    labels = preds.clone().detach()
    labels = labels.to(device)
    # Terminal transitions have no next frame, so reuse the current one;
    # the resulting prediction is ignored in the loop below.
    next_frames_preds = self.model_target.forward_np_array(
        device=device,
        x=np.array([x[1] if x[1] is not None else x[0] for x in batch])
    ).detach()
    for i, b in enumerate(batch):
        _, next_frame, action, reward = b
        if next_frame is None:  # terminal state
            labels[i][action] = reward
        else:
            labels[i][action] = reward + gamma * max(next_frames_preds[i])
    loss = self.criterion(preds, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return float(loss)

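# Worked example of the target construction in gradient_update (values invented for
# illustration): with gamma = 0.99, reward = 1.0, and target-network Q-values
# [0.2, 0.5] for the next frame, the label for the taken action becomes
# 1.0 + 0.99 * 0.5 = 1.495, while every other action keeps the online network's own
# prediction, so only the taken action contributes to the MSE loss.
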
def train(model, targeted: bool, attacked_nodes: torch.Tensor, y_targets: torch.Tensor,
          optimizer: torch.optim.Optimizer):
    """
    trains the attack for one epoch

    Parameters
    ----------
    model: Model
    targeted: bool
    attacked_nodes: torch.Tensor
    y_targets: torch.Tensor - the target labels of the attack
    optimizer: torch.optim.Optimizer
    """
    model.train()
    optimizer.zero_grad()

    attacked_nodes = [attacked_nodes.item()]
    model_output = model()[attacked_nodes]
    # This appears to guard against a degenerate output (every score equal to the
    # target-class score, hence no gradient signal) by recomputing in eval mode.
    if torch.sum(model_output - model_output[:y_targets.shape[0], y_targets]) == 0:
        model.eval()
        model_output = model()[attacked_nodes]

    loss = F.nll_loss(model_output, y_targets)
    loss = loss if targeted else -loss

    loss.backward()
    optimizer.step()
    model.eval()

def train(model, optimizer: torch.optim.Optimizer, data: torch_geometric.data.Data,
          perturbation, gamma):
    """
    trains the model for one epoch

    Parameters
    ----------
    model: Model
    optimizer: torch.optim.Optimizer
    data: torch_geometric.data.Data
    """
    model.train()
    optimizer.zero_grad()

    y_hat, R = model.forward(perturbation=perturbation, grad_perturbation=False)
    # NOTE: accuracy goes up, then down; needs investigation.
    y_hat = y_hat[model.data.train_mask]
    loss = F.nll_loss(y_hat, data.y[model.data.train_mask]) + gamma * R
    # loss = F.nll_loss(y_hat, data.y) + gamma * R

    loss.backward()
    optimizer.step()
    model.eval()

def train_embedding_model(model: nn.Module, data: Iterable, n_epochs: int,
                          criterion: nn.Module, optimizer: torch.optim.Optimizer) -> None:
    """
    :param model - class inherited from nn.Module with DL model
    :param data - iterator for batching data
    :param n_epochs - number of epochs
    :param criterion - loss for model
    :param optimizer - optimizer from torch for model
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)  # keep the model on the same device as the batches
    for epoch in range(n_epochs):
        for y, cont_x, cat_x in data:
            cat_x = cat_x.to(device)
            cont_x = cont_x.to(device)
            y = y.to(device)

            preds = model(cont_x, cat_x)
            loss = criterion(preds, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f'loss on epoch {epoch} is {loss.item()}')

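# Minimal usage sketch for train_embedding_model. The model and loader are
# assumptions: any nn.Module taking (cont_x, cat_x) and any DataLoader yielding
# (y, cont_x, cat_x) batches, matching the loop above, would work.
def fit_embedding_net(model: nn.Module, loader: torch.utils.data.DataLoader) -> None:
    train_embedding_model(
        model=model,
        data=loader,
        n_epochs=5,
        criterion=nn.MSELoss(),
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    )
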
def train_loop(model, opt: torch.optim.Optimizer, train_loader, val_loader=None,
               scheduler=None, batch_size=2, dice_loss_beta=1., num_epochs=60,
               save=False, validate=True, device='cpu', save_name='unet'):
    training_scores = []
    train_batch_gen = torch.utils.data.DataLoader(train_loader, batch_size=batch_size, shuffle=True)

    if validate and val_loader is not None:
        validation_scores = []
        val_batch_gen = torch.utils.data.DataLoader(val_loader, batch_size=batch_size, shuffle=False)
    else:
        validate = False  # no validation data available

    for epoch in range(num_epochs):
        print(f'Epoch number: {epoch}')
        start_time = time.time()
        model.train()
        epoch_train_loss = []

        for (X_image, X_mask) in train_batch_gen:
            X_image = X_image.to(device)
            pred_mask = model(X_image)[:, 0].contiguous().view(-1)
            true_mask = X_mask[:, 0].contiguous().view(-1).to(device)
            loss = dice_loss(true_mask, pred_mask, dice_loss_beta)
            loss.backward()
            opt.step()
            opt.zero_grad()
            if scheduler is not None:
                scheduler.step()
            epoch_train_loss.append(loss.data.cpu().numpy())

        training_scores.append(np.mean(epoch_train_loss))

        if validate:
            model.eval()
            masks_2_pred = predict_val(val_batch_gen, model, device)
            val_preds = np.vstack(list(masks_2_pred.values()))
            val_true = np.vstack(list(masks_2_pred.keys()))
            val_iou = calc_iou(val_preds, val_true)
            validation_scores.append(val_iou)
            print('validation iou is {}'.format(val_iou))

        print(f'Training epoch loss: {training_scores[-1]}')
        print("Epoch {} of {} took {:.3f}s".format(
            epoch + 1, num_epochs, time.time() - start_time))

    if save:
        torch.save(model.state_dict(), f'{save_name}')
    return (training_scores, validation_scores) if validate else training_scores

def train(self, epoch_s: int, epoch_e: int, data: Data, n_samples: int,
          optimizer: torch.optim.Optimizer, device: torch.device,
          strategy: str = 'max', mode: bool = True) -> None:
    train_time = time.time()
    prefix_sav = f'./model_save/WNGat_{train_time}'
    loss_list = []
    super().train()
    negloss = NEGLoss(data.x, data.edge_index, n_samples)
    for epoch in range(epoch_s, epoch_e):
        optimizer.zero_grad()
        oup = self.forward(data.x, data.edge_index)
        loss = negloss(oup, data.edge_index)
        loss_list.append(loss.data)
        sr_params = {'oup': oup}
        sr_rls = sr_test(device, self.emb, strategy, **sr_params)
        save_model(epoch, self, optimizer, loss_list, prefix_sav, oup, sr=sr_rls)
        loss.backward()
        optimizer.step()

def train(model: nn.Module, train_loader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer, epoch: int):
    train_loss = 0
    train_loss_list = []
    batch_list = []
    num_data = 0
    device = torch_device(model)
    model.train()
    for X, target in train_loader:
        batch_size = X.size(0)
        num_data += batch_size
        X, target = X.to(device), target.to(device)
        output = model(X)
        loss = _loss_DeepAnT(output, target)
        train_loss += loss.item()
        train_loss_list.append(loss.item())
        batch_list.append(epoch - 1 + (num_data / len(train_loader.sampler)))

        # backpropagation and weight update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    avg_train_loss = train_loss / num_data
    return avg_train_loss, train_loss_list, batch_list

def train(model: nn.Module, optimizer: optim.Optimizer, loss_fn,
          train_loader: DataLoader, test_loader: DataLoader,
          params: utils.Params, epoch: int) -> np.ndarray:
    '''Train the model on one epoch by batches.
    Args:
        model: (torch.nn.Module) the neural network
        optimizer: (torch.optim.Optimizer) optimizer for parameters of model
        loss_fn: a function that takes outputs and labels per timestep, and then computes the loss for the batch
        train_loader: load train data and labels
        test_loader: load test data and labels
        params: (Params) hyperparameters
        epoch: (int) the current training epoch
    Returns:
        loss_epoch: (np.ndarray) per-batch losses for the epoch
    '''
    model.train()
    loss_epoch = np.zeros(len(train_loader))
    # train_loader yields:
    #   train_batch ([batch_size, train_window, 1+cov_dim]): z_{0:T-1} + x_{1:T}, note that z_0 = 0;
    #   idx ([batch_size]): one integer denoting the time series id;
    #   labels_batch ([batch_size, train_window]): z_{1:T}.
    for i, (train_batch, idx, labels_batch) in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        batch_size = train_batch.shape[0]

        train_batch = train_batch.permute(1, 0, 2).to(torch.float32).to(params.device)  # not scaled
        labels_batch = labels_batch.permute(1, 0).to(torch.float32).to(params.device)  # not scaled
        idx = idx.unsqueeze(0).to(params.device)

        loss = torch.zeros(1, device=params.device)
        hidden = model.init_hidden(batch_size)
        cell = model.init_cell(batch_size)

        for t in range(params.train_window):
            # if z_t is missing, replace it by output mu from the last time step
            zero_index = (train_batch[t, :, 0] == 0)
            if t > 0 and torch.sum(zero_index) > 0:
                train_batch[t, zero_index, 0] = mu[zero_index]
            mu, sigma, hidden, cell = model(
                train_batch[t].unsqueeze_(0).clone(), idx, hidden, cell)
            loss += loss_fn(mu, sigma, labels_batch[t])

        loss.backward()
        optimizer.step()
        loss = loss.item() / params.train_window  # loss per timestep
        loss_epoch[i] = loss
        if i % 1000 == 0:
            test_metrics = evaluate(model, loss_fn, test_loader, params, epoch,
                                    sample=args.sampling)
            model.train()
            logger.info(f'train_loss: {loss}')
        if i == 0:
            logger.info(f'train_loss: {loss}')
    return loss_epoch

def train(self, model: nn.Module, criterion: nn.Module,
          optimizer: torch.optim.Optimizer, train_dataset: ImageFoldersDataset,
          test_dataset: ImageFoldersDataset, n_epochs: int = 25, batch_size: int = 32,
          shuffle: bool = True, *args, **kwargs):
    # TODO(lukasz): add scheduler for learning rate
    metrics = defaultdict(list)
    best_score_test = 0.
    model = model.to(self.device)  # move once, before training
    for epoch in range(n_epochs):
        model.train()
        running_loss = 0.
        for data_idx, data in enumerate(
                train_dataset.loader(
                    batch_size=batch_size,
                    shuffle=shuffle
                    # TODO(lukasz): add sampler for imbalanced dataset
                )):
            inputs, labels = data
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # TODO(lukasz): add as argument
            if data_idx % 100 == 0:
                msg = '[%d, %5d] loss: %.3f'
                print(msg % (epoch + 1, data_idx + 1, running_loss / 100))
                running_loss = 0.

        score_train = self.score(model, train_dataset)
        score_test = self.score(model, test_dataset)
        metrics['score_train'].append(score_train)
        metrics['score_test'].append(score_test)
        msg = '[%d] train score: %.3f, test score: %.3f'
        print(msg % (epoch + 1, score_train, score_test))

        # save model (make sure that Google Colab does not destroy your results)
        if score_test > best_score_test:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()
                }, self.save_experiment)
            best_score_test = score_test

    self.metrics = metrics
    return self

def normal_train(model: nn.Module, optimizer: torch.optim.Optimizer,
                 data_loader: torch.utils.data.DataLoader):
    model.train()
    epoch_loss = 0
    for input, target in data_loader:
        input, target = variable(input), variable(target)
        optimizer.zero_grad()
        output = model(input)
        loss = F.cross_entropy(output, target)
        epoch_loss += loss.item()  # loss.data[0] is deprecated
        loss.backward()
        optimizer.step()
    return epoch_loss / len(data_loader)

def train(model: nn.Module, device, train_loader: DataLoader,
          optimizer: torch.optim.Optimizer, epoch):
    model.train()
    for batch_id, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if batch_id % 10 == 0:
            print(f"batch: {batch_id} epoch: {epoch} loss: {loss.item()}")

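# Usage sketch for the train() above with torchvision MNIST (the transform and
# hyperparameters are illustrative choices, not taken from the original code).
def run_mnist_training(model: nn.Module, n_epochs: int = 3):
    from torchvision import datasets, transforms

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    loader = DataLoader(
        datasets.MNIST('data', train=True, download=True,
                       transform=transforms.ToTensor()),
        batch_size=64, shuffle=True)
    opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    for epoch in range(n_epochs):
        train(model, device, loader, opt, epoch)
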
def ewc_train(model: nn.Module, optimizer: torch.optim.Optimizer,
              data_loader: torch.utils.data.DataLoader, ewc: EWC, importance: float):
    model.train()
    epoch_loss = 0
    for input, target in data_loader:
        input, target = variable(input), variable(target)
        optimizer.zero_grad()
        output = model(input)
        loss = F.cross_entropy(output, target) + importance * ewc.penalty(model)
        epoch_loss += loss.item()  # loss.data[0] is deprecated
        loss.backward()
        optimizer.step()
    return epoch_loss / len(data_loader)

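# Sketch of the continual-learning protocol normal_train/ewc_train implement: train
# normally on the first task, estimate parameter importance with EWC from old data,
# then train on the next task with the quadratic penalty anchoring important weights.
# `old_loader`/`new_loader` are assumptions, and the EWC constructor call is only a
# guess at its signature; check the actual class before using.
def train_two_tasks(model: nn.Module, old_loader, new_loader,
                    epochs: int = 5, importance: float = 1000.0):
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(epochs):
        normal_train(model, opt, old_loader)
    ewc = EWC(model, old_loader.dataset)  # Fisher estimate; exact signature may differ
    for _ in range(epochs):
        ewc_train(model, opt, new_loader, ewc, importance)
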
def fit(self, epochs: int, train_dl: DataLoader, test_dl: DataLoader,
        criterion: nn.Module, optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler._LRScheduler = None):
    train_losses = []
    eval_losses = []
    for epoch in tqdm(range(epochs), desc="Epochs"):
        # train
        self.train()
        batch_losses = []
        batches = len(train_dl)
        for batch_input in tqdm(train_dl, total=batches, desc="- Remaining batches"):
            batch_input = [x.to(self.device) for x in batch_input]
            input_ids, att_masks, labels = batch_input

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = self(input_ids, att_masks)
            loss = criterion(outputs.squeeze(), labels)
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            batch_losses.append(loss.item())

        train_loss = np.mean(batch_losses)
        self.last_train_loss = train_loss

        # evaluate
        tqdm.write(f"Epoch: {epoch+1}")
        _, eval_loss = self.evaluate(test_dl, criterion)

        train_losses.append(train_loss)
        eval_losses.append(eval_loss)
    return train_losses, eval_losses

def update_weights(optimizer: torch.optim.Optimizer, network: Network, batch):
    optimizer.zero_grad()
    p_loss, v_loss = 0, 0

    image, actions, target_values, target_policies = batch
    image, actions, target_values, target_policies = (
        torch.FloatTensor(image), torch.FloatTensor(actions),
        torch.FloatTensor(target_values), torch.FloatTensor(target_policies))
    image, actions, target_values, target_policies = (
        image.to(device), actions.to(device),
        target_values.to(device), target_policies.to(device))

    # Initial step, from the real observation.
    value, policy, hidden_state = network.initial_inference(image)
    p_value, p_policy = [], []
    p_value.append(value)
    p_policy.append(policy)

    # Recurrent steps, from action and previous hidden state.
    for action in actions:
        value, policy, hidden_state = network.recurrent_inference(hidden_state, action)
        p_value.append(value)
        p_policy.append(policy)

    p_value = torch.stack(p_value).squeeze()
    p_policy = torch.stack(p_policy)
    # p_value = p_value.view(config.batch_size, config.num_unroll_steps+1)
    # p_policy = p_policy.view(config.batch_size, config.num_unroll_steps+1, config.action_space_size)

    target_policies = target_policies.transpose(0, 1)
    p_policy = p_policy.transpose(0, 1)
    target_values = target_values.transpose(0, 1)
    p_value = p_value.transpose(0, 1)

    p_loss += torch.mean(torch.sum(-target_policies * torch.log(p_policy), dim=2))
    v_loss += torch.mean(torch.sum((target_values - p_value) ** 2, dim=1))

    total_loss = p_loss + v_loss
    total_loss.backward()
    optimizer.step()

    if network.steps % 10 == 0:
        print('step {}: p_loss {:.4f} v_loss {:.4f}'.format(network.steps, p_loss, v_loss))
    network.steps += 1

def train(epoch: int, network: nn.Module, optimizer: torch.optim.Optimizer,
          train_loader: torch.utils.data.DataLoader, loss: nn.Module):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    network.train()
    for data, target in train_loader:
        optimizer.zero_grad()
        output = network(data.to(device))
        loss_value = loss(output, target.to(device))
        loss_value.backward()
        optimizer.step()
    print('Train Epoch: {} Length {} \tLoss: {:.6f}'.format(
        epoch, len(train_loader), loss_value.item()))

def train(model: torch.nn.Module, train_loader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer, loss_fn) -> float:
    model.train()
    train_loss = []
    for batch_idx, (inputs, target) in enumerate(train_loader):
        inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(inputs.float())
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
    return sum(train_loss) / len(train_loss)

def our_train(model: nn.Module, labels: list, optimizer: torch.optim.Optimizer,
              data_loader: torch.utils.data.DataLoader, ewcs: list, lam: float,
              gpu: torch.device, cut_idx, if_freeze):
    # TODO: still needs a loss-based check; if true: freeze
    # --------------------- freeze
    if if_freeze == 1:
        for idx, param in enumerate(model.parameters()):
            if idx >= cut_idx:
                continue
            param.requires_grad = False
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()))  # no need to add: lr=0.1?
    # ----------------------
    model.train()
    model.apply(set_bn_eval)  # freeze BN layers and their running statistics
    epoch_loss = 0
    for data, target in data_loader:
        data, target = Variable(data).cuda(gpu), Variable(target).cuda(gpu)
        optimizer.zero_grad()
        output = model(data)
        for idx in range(output.size(1)):
            if idx not in labels:
                output[range(len(output)), idx] = 0
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, target)
        # print('loss:', loss.item())
        for ewc in ewcs:
            loss += (lam / 2) * ewc.penalty(model)
        # print('ewc loss:', loss.item())
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
    # ----------------------------- unfreeze
    if if_freeze == 1:
        for idx, param in enumerate(model.parameters()):
            if idx >= cut_idx:
                continue
            param.requires_grad = True
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()))  # no need to add: lr=0.1?
    # -------------------------------
    return epoch_loss / len(data_loader)

def train(model, optimizer: torch.optim.Optimizer, data: torch_geometric.data.Data):
    """
    trains the model for one epoch

    Parameters
    ----------
    model: Model
    optimizer: torch.optim.Optimizer
    data: torch_geometric.data.Data
    """
    model.train()
    optimizer.zero_grad()
    F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()
    optimizer.step()
    model.eval()

def normal_train(model: nn.Module, optimizer: torch.optim.Optimizer,
                 data_loader: torch.utils.data.DataLoader, epoch: int):
    model.train()
    epoch_loss = 0
    for i, (input, target) in enumerate(data_loader):
        input, target = variable(input), variable(target)
        optimizer.zero_grad()
        output = model(input)
        loss = F.cross_entropy(output, target)
        epoch_loss += loss.item()  # formerly loss.data[0]
        loss.backward()
        optimizer.step()
        # if (i + 1) % 100 == 0:
        #     print('Loss: {:.4f}'.format(loss.item()))
    return epoch_loss / len(data_loader)

def normal_train(model: nn.Module, labels: list, optimizer: torch.optim.Optimizer,
                 data_loader: torch.utils.data.DataLoader, gpu: torch.device):
    model.train()
    model.apply(set_bn_eval)  # freeze BN layers and their running statistics
    epoch_loss = 0
    for data, target in data_loader:
        data, target = Variable(data).cuda(gpu), Variable(target).cuda(gpu)
        optimizer.zero_grad()
        output = model(data)
        for idx in range(output.size(1)):
            if idx not in labels:
                output[range(len(output)), idx] = 0
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, target)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
    return epoch_loss / len(data_loader)

def ewc_train(model: nn.Module, optimizer: torch.optim.Optimizer,
              data_loader: torch.utils.data.DataLoader, ewc: EWC,
              importance: float, epoch: int):
    model.train()
    epoch_loss = 0
    for i, (input, target) in enumerate(data_loader):
        input, target = variable(input), variable(target)
        optimizer.zero_grad()
        output = model(input)
        loss = F.cross_entropy(output, target) + importance * ewc.penalty(model)
        epoch_loss += loss.item()  # formerly loss.data[0]
        loss.backward()
        optimizer.step()
        # if (i + 1) % 100 == 0:
        #     print('Loss: {:.4f}'.format(loss.item()))
    # rv = torch.true_divide(epoch_loss, len(data_loader))
    return epoch_loss / len(data_loader)

def train_attr(model, optimizer_attr: torch.optim.Optimizer, data: torch_geometric.data.Data):
    model.train()
    optimizer_attr.zero_grad()

    labels = data.y.to(model.device)
    x, pos_edge_index = data.x, data.train_pos_edge_index

    _edge_index, _ = remove_self_loops(pos_edge_index)
    pos_edge_index_with_self_loops, _ = add_self_loops(_edge_index, num_nodes=x.size(0))

    neg_edge_index = negative_sampling(
        edge_index=pos_edge_index_with_self_loops, num_nodes=x.size(0),
        num_neg_samples=pos_edge_index.size(1))

    F.nll_loss(model(pos_edge_index, neg_edge_index)[1][data.train_mask],
               labels[data.train_mask]).backward()
    optimizer_attr.step()
    model.eval()

def normal_train(model: nn.Module, opt: torch.optim.Optimizer, loss_func: nn.Module,
                 data_loader: torch.utils.data.DataLoader, device):
    epoch_loss = 0
    for i, (inputs, labels) in enumerate(data_loader):
        inputs = inputs.to(device).long()
        labels = labels.to(device).float()
        opt.zero_grad()
        output = model(inputs)
        loss = loss_func(output.view(-1), labels)
        epoch_loss += loss.item()
        loss.backward()
        opt.step()
    return epoch_loss / len(data_loader)

def ewc_train(model: nn.Module, optimizer: torch.optim.Optimizer,
              data_loader: torch.utils.data.DataLoader, fisher_info: dict,
              importance: float):
    model.train()
    epoch_loss = 0
    params = {n: p.detach() for n, p in model.named_parameters()}
    for input, target in data_loader:
        input, target = variable(input), variable(target)
        optimizer.zero_grad()
        output = model(input)
        xent_loss = F.cross_entropy(output, target)
        ewc_loss = importance * ewc_penalty(params, model, fisher_info)
        loss = xent_loss + ewc_loss
        print('cls loss {}, ewc loss {}'.format(xent_loss, ewc_loss))
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
    return epoch_loss / len(data_loader)

def update_weights(
        optimizer: torch.optim.Optimizer,
        network: Network,
        data_loader,
):
    optimizer.zero_grad()
    p_loss, v_loss = 0, 0
    for image, actions, target_values, target_rewards, target_policies in data_loader:
        image = image.to(device)

        # Initial step, from the real observation.
        net_output = network.initial_inference(image)
        # Each entry stores (gradient scale, value, reward, policy logits);
        # note the scale is recorded but not applied in the loss below.
        predictions = [(1.0, net_output.value, net_output.reward, net_output.policy_logits)]
        hidden_state = net_output.hidden_state

        # Recurrent steps, from action and previous hidden state.
        for action in actions:
            action = action.to(device)
            net_output = network.recurrent_inference(hidden_state, action)
            predictions.append((1.0 / len(actions), net_output.value,
                                net_output.reward, net_output.policy_logits))
            hidden_state = net_output.hidden_state

        for prediction, target_value, target_reward, target_policy in zip(
                predictions, target_values, target_rewards, target_policies):
            target_value, target_reward, target_policy = (
                target_value.to(device), target_reward.to(device), target_policy.to(device))
            _, value, reward, policy_logits = prediction
            p_loss += torch.mean(torch.sum(-target_policy * torch.log(policy_logits), dim=1))
            v_loss += torch.mean(torch.sum((target_value - value) ** 2, dim=1))

    total_loss = p_loss + v_loss
    total_loss.backward()
    optimizer.step()
    print('step %d: p_loss %f v_loss %f' %
          (network.steps % config.checkpoint_interval, p_loss, v_loss))
    network.steps += 1