import copy
import math
from typing import List

import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

# `Timer` is a project-local helper: a labeled context manager that accumulates
# wall-clock seconds per label in `timer.totals`. It is assumed to be imported
# from elsewhere in this package.


# Trainer for plain Adagrad (the enclosing class definition is omitted in this
# excerpt; `self.inference` computes the loss for a batch of inputs).
def __call__(
    self, net: nn.Module, input_names: List[str], data_loaders
) -> None:
    optimizer = torch.optim.Adagrad(
        net.parameters(),
        lr=self.learning_rate,
        weight_decay=self.weight_decay,
    )
    writer = SummaryWriter(self.tensorboard_path)
    timer = Timer()
    training_iter = iter(data_loaders["training_data_loader"])
    avg_epoch_grad = 0.0

    for epoch_no in range(self.epochs):
        if self.decreasing_step_size:
            for param_group in optimizer.param_groups:
                # Note: `*=` compounds across epochs (lr_e = lr_0 / sqrt((e+1)!)).
                # For a plain 1/sqrt(e+1) schedule, assign instead of multiplying.
                param_group["lr"] *= 1 / math.sqrt(epoch_no + 1)

        for batch_no in range(self.num_batches_per_epoch):
            with timer("gradient oracle"):
                data_entry = next(training_iter)
                optimizer.zero_grad()
                inputs = [data_entry[k].to(self.device) for k in input_names]
                loss = self.inference(net, inputs)
                loss.backward()
                optimizer.step()

        # Compute the gradient norm and loss over the training set.
        avg_epoch_loss = 0.0
        full_batch_iter = iter(data_loaders["full_batch_loader"])
        net.zero_grad()
        for i, data_entry in enumerate(full_batch_iter):
            inputs = [data_entry[k].to(self.device) for k in input_names]
            loss = self.inference(net, inputs)
            loss.backward()
            avg_epoch_loss += loss.item()
        avg_epoch_loss /= i + 1
        epoch_grad = 0.0
        for p in net.parameters():
            if p.grad is None:
                continue
            epoch_grad += torch.norm(p.grad.data / (i + 1)).item()
        net.zero_grad()

        # Compute the validation loss.
        validation_loss = None
        if self.eval_model and epoch_no % self.validation_freq == 0:
            validation_iter = iter(data_loaders["validation_data_loader"])
            validation_loss = 0.0
            with torch.no_grad():
                for i, data_entry in enumerate(validation_iter):
                    inputs = [
                        data_entry[k].to(self.device) for k in input_names
                    ]
                    loss = self.inference(net, inputs)
                    validation_loss += loss.item()
            validation_loss /= i + 1

        # Log metrics against iterations, gradient-oracle calls, and time.
        num_iters = (
            self.num_batches_per_epoch * (epoch_no + 1) * self.batch_size
        )
        avg_epoch_grad = (avg_epoch_grad * epoch_no + epoch_grad) / (
            epoch_no + 1
        )
        time_in_ms = timer.totals["gradient oracle"] * 1000
        writer.add_scalar(
            "gradnorm/iters",
            avg_epoch_grad,
            (epoch_no + 1) * self.num_batches_per_epoch,
        )
        writer.add_scalar("gradnorm/grads", avg_epoch_grad, num_iters)
        writer.add_scalar("gradnorm/time", avg_epoch_grad, time_in_ms)
        writer.add_scalar(
            "train_loss/iters",
            avg_epoch_loss,
            (epoch_no + 1) * self.num_batches_per_epoch,
        )
        writer.add_scalar("train_loss/grads", avg_epoch_loss, num_iters)
        writer.add_scalar("train_loss/time", avg_epoch_loss, time_in_ms)
        if self.eval_model and epoch_no % self.validation_freq == 0:
            writer.add_scalar(
                "val_loss/iters",
                validation_loss,
                (epoch_no + 1) * self.num_batches_per_epoch,
            )
            writer.add_scalar("val_loss/grads", validation_loss, num_iters)
            writer.add_scalar("val_loss/time", validation_loss, time_in_ms)
            print(
                f"\nTraining Loss: {avg_epoch_loss:.4f}, "
                f"Validation Loss: {validation_loss:.4f}\n"
            )
        else:
            print(f"\nTraining Loss: {avg_epoch_loss:.4f}\n")
        print(f"Epoch {epoch_no} is done!")

    writer.close()
    print(
        f"task: {self.task_name} on Adagrad with lr={self.learning_rate} is done!"
    )
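
# --- Usage sketch (illustrative; not part of the source) --------------------
# The trainer above consumes `training_data_loader` through a single iterator
# across all epochs, so that loader must yield at least
# `epochs * num_batches_per_epoch` batches; wrapping a finite DataLoader in
# `itertools.cycle` is one way to guarantee this. The `full_batch_loader` and
# `validation_data_loader` entries are re-`iter()`-ed each epoch, so they must
# be re-iterable (e.g. lists of batches). `AdagradTrainer` is a hypothetical
# name for the enclosing class, which this excerpt does not show.
import itertools

from torch.utils.data import DataLoader, TensorDataset

_X, _y = torch.randn(256, 8), torch.randn(256, 1)
_dataset = TensorDataset(_X, _y)

def _as_dict_batches(loader):
    # The trainer indexes each batch by name, so expose batches as dicts.
    return [{"x": xb, "y": yb} for xb, yb in loader]

data_loaders = {
    "training_data_loader": itertools.cycle(
        _as_dict_batches(DataLoader(_dataset, batch_size=32, shuffle=True))
    ),
    "full_batch_loader": _as_dict_batches(DataLoader(_dataset, batch_size=256)),
    "validation_data_loader": _as_dict_batches(
        DataLoader(_dataset, batch_size=256)
    ),
}

# trainer = AdagradTrainer(...)  # hypothetical constructor; see class definition
# trainer(net=my_net, input_names=["x", "y"], data_loaders=data_loaders)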
# Trainer for SAdam: Adam steps taken on an SVRG-style variance-reduced
# gradient with an adaptively refreshed anchor (enclosing class omitted in
# this excerpt).
def __call__(
    self, net: nn.Module, input_names: List[str], data_loaders
) -> None:
    optimizer = torch.optim.Adam(
        net.parameters(),
        lr=self.learning_rate,
        weight_decay=self.weight_decay,
    )
    writer = SummaryWriter(self.tensorboard_path)
    timer = Timer()
    training_iter = iter(data_loaders["training_data_loader"])
    anchor_iter = iter(data_loaders["anchor_data_loader"])
    group_ratio = data_loaders["group_ratio"] if self.weighted_batch else None
    avg_epoch_grad = 0.0
    v_0_norm = 0.0
    v_t_norm = 0.0

    for epoch_no in range(self.epochs):
        if self.decreasing_step_size:
            for param_group in optimizer.param_groups:
                # Note: `*=` compounds across epochs; assign instead of
                # multiplying for a plain 1/sqrt(e+1) schedule.
                param_group["lr"] *= 1 / math.sqrt(epoch_no + 1)

        for batch_no in range(self.num_batches_per_epoch):
            iter_n = epoch_no * self.num_batches_per_epoch + batch_no
            # Refresh the anchor gradient at the first iteration, at a fixed
            # frequency, or when the accumulated variance-reduced gradient
            # norm has dropped below a fraction of the anchor gradient norm.
            if (
                iter_n == 0
                or v_t_norm <= self.gamma * v_0_norm
                or iter_n % self.freq == 0
            ):
                anchor_model = copy.deepcopy(net)
                sg_model = copy.deepcopy(net)
                anchor_model.zero_grad()
                with timer("gradient oracle"):
                    data_entry = next(anchor_iter)
                    inputs = [
                        data_entry[k].to(self.device) for k in input_names
                    ]
                    loss = self.inference(
                        anchor_model,
                        inputs,
                        weighted_batch=self.weighted_batch,
                        group_ratio=group_ratio,
                    )
                    loss.backward()
                # Reset so the norm reflects the new anchor gradient only.
                v_0_norm = 0.0
                for p in anchor_model.parameters():
                    if p.grad is None:
                        continue
                    v_0_norm += (torch.norm(p.grad.data) ** 2).item()
                v_t_norm = 0.0

            data_entry = next(training_iter)
            optimizer.zero_grad()
            with timer("gradient oracle"):
                inputs = [data_entry[k].to(self.device) for k in input_names]
                # Defensive copy so the two backward passes see independent
                # input tensors.
                inputs_ = copy.deepcopy(inputs)
                net.zero_grad()
                sg_model.zero_grad()
                loss = self.inference(sg_model, inputs)
                loss.backward()
                loss = self.inference(net, inputs_)
                loss.backward()
            with timer("gradient oracle"):
                # SVRG-style estimator:
                #   v_t = grad f_B(x_t) - grad f_B(x_0) + grad f_A(x_0),
                # where x_0 is the anchor point (sg_model) and f_A the anchor
                # batch (anchor_model). `v_t_norm` accumulates ||v_t||^2 since
                # the last anchor refresh.
                for p1, p2, p3 in zip(
                    net.parameters(),
                    sg_model.parameters(),
                    anchor_model.parameters(),
                ):
                    if p1.grad is None or p2.grad is None or p3.grad is None:
                        continue
                    v_t = torch.zeros_like(p1.grad.data, device=p1.device)
                    v_t.add_(p1.grad.data - p2.grad.data + p3.grad.data)
                    p1.grad.data.zero_().add_(v_t)
                    v_t_norm += (torch.norm(v_t) ** 2).item()
            optimizer.step()

        # Compute the gradient norm and loss over the training set.
        avg_epoch_loss = 0.0
        full_batch_iter = iter(data_loaders["full_batch_loader"])
        net.zero_grad()
        for i, data_entry in enumerate(full_batch_iter):
            inputs = [data_entry[k].to(self.device) for k in input_names]
            loss = self.inference(net, inputs)
            loss.backward()
            avg_epoch_loss += loss.item()
        avg_epoch_loss /= i + 1
        epoch_grad = 0.0
        for p in net.parameters():
            if p.grad is None:
                continue
            epoch_grad += torch.norm(p.grad.data / (i + 1)).item()
        net.zero_grad()

        # Compute the validation loss.
        validation_loss = None
        if self.eval_model and epoch_no % self.validation_freq == 0:
            net_validate = copy.deepcopy(net)
            validation_iter = iter(data_loaders["validation_data_loader"])
            validation_loss = 0.0
            with torch.no_grad():
                for i, data_entry in enumerate(validation_iter):
                    inputs = [
                        data_entry[k].to(self.device) for k in input_names
                    ]
                    loss = self.inference(net_validate, inputs)
                    validation_loss += loss.item()
            validation_loss /= i + 1

        # Log metrics against iterations, gradient-oracle calls, and time.
        # Each iteration costs two minibatch gradients; the second term counts
        # the large-batch anchor refreshes taken every `freq` iterations.
        num_iters = (
            self.num_batches_per_epoch * (epoch_no + 1) * 2 * self.batch_size
            + self.num_batches_per_epoch
            / self.freq
            * self.num_strata
            * (epoch_no + 1)
            * self.batch_size
        )
        avg_epoch_grad = (avg_epoch_grad * epoch_no + epoch_grad) / (
            epoch_no + 1
        )
        time_in_ms = timer.totals["gradient oracle"] * 1000
        writer.add_scalar(
            "gradnorm/iters",
            avg_epoch_grad,
            (epoch_no + 1) * self.num_batches_per_epoch,
        )
        writer.add_scalar("gradnorm/grads", avg_epoch_grad, num_iters)
        writer.add_scalar("gradnorm/time", avg_epoch_grad, time_in_ms)
        writer.add_scalar(
            "train_loss/iters",
            avg_epoch_loss,
            (epoch_no + 1) * self.num_batches_per_epoch,
        )
        writer.add_scalar("train_loss/grads", avg_epoch_loss, num_iters)
        writer.add_scalar("train_loss/time", avg_epoch_loss, time_in_ms)
        if self.eval_model and epoch_no % self.validation_freq == 0:
            writer.add_scalar(
                "val_loss/iters",
                validation_loss,
                (epoch_no + 1) * self.num_batches_per_epoch,
            )
            writer.add_scalar("val_loss/grads", validation_loss, num_iters)
            writer.add_scalar("val_loss/time", validation_loss, time_in_ms)
            print(
                f"\nTraining Loss: {avg_epoch_loss:.4f}, "
                f"Validation Loss: {validation_loss:.4f}\n"
            )
        else:
            print(f"\nTraining Loss: {avg_epoch_loss:.4f}\n")
        print(f"Epoch {epoch_no} is done!")

    writer.close()
    print(
        f"task: {self.task_name} on SAdam with lr={self.learning_rate} is done!"
    )
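
# --- Illustration of the variance-reduced update above (not from the source) --
# The inner loop of the SAdam trainer builds an SVRG-style estimator: with an
# anchor point x_0, an anchor gradient v_0 = grad f_A(x_0), and a minibatch B,
#
#     v_t = grad f_B(x_t) - grad f_B(x_0) + v_0,
#
# which equals grad f(x_t) in expectation when v_0 is the full gradient, and
# whose variance shrinks as x_t stays close to x_0. The sketch below replays
# this on a toy least-squares problem with plain tensors instead of nn.Module
# copies; all names here are local to the sketch, and a plain gradient step
# stands in for the Adam step the trainer actually takes on v_t.
import torch

torch.manual_seed(0)
A = torch.randn(512, 4)
b = torch.randn(512)

def grad(x, idx):
    # Gradient of 0.5 * ||A[idx] @ x - b[idx]||^2 / len(idx) with respect to x.
    r = A[idx] @ x - b[idx]
    return A[idx].T @ r / len(idx)

x_t = torch.zeros(4)
x_0 = x_t.clone()                  # anchor point (plays the role of sg_model)
full = torch.arange(512)
v_0 = grad(x_0, full)              # anchor gradient (anchor_model's .grad)
lr = 0.1

for step in range(100):
    batch = torch.randint(0, 512, (32,))
    # Variance-reduced estimator, as in the trainer's parameter loop.
    v_t = grad(x_t, batch) - grad(x_0, batch) + v_0
    x_t = x_t - lr * v_t           # the trainer feeds v_t into Adam instead
    if (step + 1) % 20 == 0:       # periodic anchor refresh (cf. self.freq)
        x_0 = x_t.clone()
        v_0 = grad(x_0, full)

print("full-gradient norm after training:", grad(x_t, full).norm().item())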