def train_step(
        model: FlowModel,
        config: TrainConfig,
        action: ActionFn,
        optimizer: optim.Optimizer,
        batch_size: int,
        scheduler: Any = None,
        scaler: GradScaler = None,
        pre_model: FlowModel = None,
        dkl_factor: float = 1.,
        xi: torch.Tensor = None,
):
    """Perform a single training step.

    TODO: Add `torch.device` to arguments for DDP.
    """
    t0 = time.time()
    optimizer.zero_grad()

    loss_dkl = torch.tensor(0.0)
    if torch.cuda.is_available():
        loss_dkl = loss_dkl.cuda()

    if pre_model is not None:
        pre_xi = pre_model.prior.sample_n(batch_size)
        x = qed.ft_flow(pre_model.layers, pre_xi)
        xi = qed.ft_flow_inv(pre_model.layers, x)

    x, xi, logq = apply_flow_to_prior(model.prior, model.layers,
                                      xi=xi, batch_size=batch_size)
    logp = (-1.) * action(x)
    dkl = calc_dkl(logp, logq)
    ess = calc_ess(logp, logq)

    qi = qed.batch_charges(xi)
    q = qed.batch_charges(x)
    plaq = logp / (config.beta * config.volume)
    dq = torch.sqrt((q - qi) ** 2)  # elementwise |q - qi|

    loss_dkl = dkl_factor * dkl
    if scaler is not None:
        scaler.scale(loss_dkl).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss_dkl.backward()
        optimizer.step()

    if scheduler is not None:
        scheduler.step(loss_dkl)

    metrics = {
        'dt': time.time() - t0,
        'ess': grab(ess),
        'logp': grab(logp),
        'logq': grab(logq),
        'loss_dkl': grab(loss_dkl),
        'q': grab(q),
        'dq': grab(dq),
        'plaq': grab(plaq),
    }
    return metrics
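# The `calc_dkl` and `calc_ess` helpers used above are assumed to come from
# the surrounding package. A minimal sketch consistent with how they are
# called (reverse KL between the flow density q and the target density p, and
# the normalized effective sample size of the importance weights w = p / q):
def calc_dkl(logp: torch.Tensor, logq: torch.Tensor) -> torch.Tensor:
    """Reverse KL estimate D_KL(q || p) from per-sample log-densities."""
    return (logq - logp).mean()


def calc_ess(logp: torch.Tensor, logq: torch.Tensor) -> torch.Tensor:
    """Normalized effective sample size, (sum w)^2 / (n * sum w^2), in (0, 1]."""
    logw = logp - logq
    log_ess = 2. * torch.logsumexp(logw, dim=0) - torch.logsumexp(2. * logw, dim=0)
    return torch.exp(log_ess) / len(logw)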
def train_model(
    train_dl: data.DataLoader,
    dev_dl: data.DataLoader,
    model: nn.Module,
    optimizer: optim.Optimizer,
    lr_scheduler: optim.lr_scheduler._LRScheduler,
    args: argparse.Namespace,
) -> nn.Module:
    device = model_utils.get_device()
    # loss_fn = nn.functional.binary_cross_entropy
    loss_fn = model_utils.l1_norm_loss
    val_loss_fn = model_utils.l1_norm_loss
    best_val_loss = float('inf')
    saved_checkpoints = []
    writer = SummaryWriter(log_dir=f'{args.log_dir}/{args.experiment}')
    scalar_rand = torch.distributions.uniform.Uniform(0.5, 1.5)

    for e in range(1, args.train_epochs + 1):
        print(f'Training epoch {e}...')

        # Training portion
        torch.cuda.empty_cache()
        with tqdm(total=args.train_batch_size * len(train_dl)) as progress_bar:
            model.train()
            for i, (x_batch, y_batch_biden, y_batch_trump, _) in enumerate(train_dl):
                # trump_scale = scalar_rand.sample()
                # biden_scale = scalar_rand.sample()
                # y_batch_biden = y_batch_biden * biden_scale
                # y_batch_trump = y_batch_trump * trump_scale
                # x_batch = (y_batch_trump + y_batch_biden).abs().to(device)
                x_batch = x_batch.abs().to(device)
                y_batch_biden = y_batch_biden.abs().to(device)
                y_batch_trump = y_batch_trump.abs().to(device)

                # Forward pass on model
                optimizer.zero_grad()
                y_pred_b, y_pred_t = model(x_batch)
                if args.train_trump:
                    # loss = loss_fn(y_pred_t * x_batch, y_batch_trump)
                    loss = loss_fn(y_pred_t, y_batch_trump)
                else:
                    # loss = loss_fn(y_pred_b * x_batch, y_batch_biden)
                    loss = loss_fn(y_pred_b, y_batch_biden)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()
                if args.use_scheduler:
                    lr_scheduler.step(loss)

                progress_bar.update(len(x_batch))
                progress_bar.set_postfix(loss=loss.item())
                writer.add_scalar("train/Loss", loss,
                                  ((e - 1) * len(train_dl) + i) * args.train_batch_size)

                del x_batch, y_batch_biden, y_batch_trump
                del y_pred_b, y_pred_t, loss

        # Validation portion
        torch.cuda.empty_cache()
        with tqdm(total=args.val_batch_size * len(dev_dl)) as progress_bar:
            model.eval()
            val_loss = 0.0
            num_batches_processed = 0
            for i, (x_batch, y_batch_biden, y_batch_trump, _) in enumerate(dev_dl):
                x_batch = x_batch.abs().to(device)
                y_batch_biden = y_batch_biden.abs().to(device)
                y_batch_trump = y_batch_trump.abs().to(device)

                # Forward pass on model
                y_pred_b, y_pred_t = model(x_batch)
                # y_pred_b_mask = torch.ones_like(y_pred_b) * (y_pred_b > args.alpha)
                # y_pred_t_mask = torch.ones_like(y_pred_t) * (y_pred_t > args.alpha)
                y_pred_b_mask = torch.clamp(y_pred_b / x_batch, 0, 1)
                y_pred_t_mask = torch.clamp(y_pred_t / x_batch, 0, 1)

                loss_trump = val_loss_fn(y_pred_t_mask * x_batch, y_batch_trump)
                loss_biden = val_loss_fn(y_pred_b_mask * x_batch, y_batch_biden)
                if args.train_trump:
                    val_loss += loss_trump.item()
                else:
                    val_loss += loss_biden.item()
                num_batches_processed += 1

                progress_bar.update(len(x_batch))
                progress_bar.set_postfix(val_loss=val_loss / num_batches_processed)
                writer.add_scalar("Val/Biden Loss", loss_biden,
                                  ((e - 1) * len(dev_dl) + i) * args.val_batch_size)
                writer.add_scalar("Val/Trump Loss", loss_trump,
                                  ((e - 1) * len(dev_dl) + i) * args.val_batch_size)

                del x_batch, y_batch_biden, y_batch_trump
                del y_pred_b, y_pred_t, loss_trump, loss_biden

        # Save model if it's the best one yet.
        if val_loss / num_batches_processed < best_val_loss:
            best_val_loss = val_loss / num_batches_processed
            filename = f'{args.save_path}/{args.experiment}/{model.__class__.__name__}_best_val.checkpoint'
            model_utils.save_model(model, filename)
            print('Model saved!')
            print(f'Best validation loss yet: {best_val_loss}')

        # Save model on checkpoints.
        if e % args.checkpoint_freq == 0:
            filename = f'{args.save_path}/{args.experiment}/{model.__class__.__name__}_epoch_{e}.checkpoint'
            model_utils.save_model(model, filename)
            print('Model checkpoint reached!')
            saved_checkpoints.append(filename)

            # Delete checkpoints if there are too many
            while len(saved_checkpoints) > args.num_checkpoints:
                os.remove(saved_checkpoints.pop(0))

    return model
def train(loader: DataLoader, model: torch.nn.Module, criterion,
          optimizer: Optimizer, epoch: int, device, print_freq, display=False):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()

    for i, (inputs, targets) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs = inputs.to(device)
        targets = targets.to(device)

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(acc1.item(), inputs.size(0))
        top5.update(acc5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0 and display:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))

    return (losses.avg, top1.avg)
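# Several loops in this file rely on `AverageMeter` and `accuracy` helpers
# (the usual PyTorch ImageNet-example utilities). They are assumed to be
# defined elsewhere in the repo; a minimal sketch:
class AverageMeter:
    """Tracks the latest value and a running average."""

    def __init__(self, name: str = '', fmt: str = ':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        self.val = self.avg = self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output: torch.Tensor, target: torch.Tensor, topk=(1,)):
    """Computes top-k accuracy (in percent) for each k in `topk`."""
    with torch.no_grad():
        maxk = max(topk)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        return [correct[:k].reshape(-1).float().sum() * (100.0 / target.size(0))
                for k in topk]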
def update_theta(epoch: int, baseline: MovingAverageMetric, entropy_coeff,
                 grad_clip: int, data_loader, device: str,
                 master_pair: MasterPairs, architecture: NASNetwork,
                 optimizer: optim.Optimizer, writer: SummaryWriter,
                 log_frequency: int):
    start = datetime.now()
    policy_loss_metric = AverageMetric()
    accuracy_metric = AccuracyMetric(topk=(1, 5))
    normal_logp_metric = AverageMetric()
    node_normal_entropy_metric = AverageMetric()
    op_normal_entropy_metric = AverageMetric()
    reduced_logp_metric = AverageMetric()
    node_reduced_entropy_metric = AverageMetric()
    op_reduced_entropy_metric = AverageMetric()

    node_normal_entropy_coeff, node_reduced_entropy_coeff, \
        op_normal_entropy_coeff, op_reduced_entropy_coeff = [entropy_coeff, ] * 4

    master_pair.unset_force_uniform()
    for iter_, (datas, targets) in enumerate(data_loader, start=1):
        datas, targets = datas.to(device=device), targets.to(device=device)

        (normal_arch, normal_logp, node_normal_entropy, op_normal_entropy), \
            (reduced_arch, reduced_logp, node_reduced_entropy, op_reduced_entropy) = master_pair()

        with torch.no_grad():
            outputs = architecture(datas, normal_arch, reduced_arch)

        accuracy_metric.update(targets, outputs)
        accuracy_1 = accuracy_metric.last_accuracy(1).rate
        baseline.update(accuracy_1)
        reward = accuracy_1 - baseline.value

        policy_loss = -(normal_logp + reduced_logp) * reward \
            - (node_normal_entropy * node_normal_entropy_coeff
               + op_normal_entropy * op_normal_entropy_coeff
               + node_reduced_entropy * node_reduced_entropy_coeff
               + op_reduced_entropy * op_reduced_entropy_coeff)

        optimizer.zero_grad()
        policy_loss.backward()
        if grad_clip is not None:
            nn.utils.clip_grad_norm_(master_pair.parameters(), grad_clip)
        optimizer.step()

        # update metrics
        policy_loss_metric.update(policy_loss)
        normal_logp_metric.update(normal_logp)
        node_normal_entropy_metric.update(node_normal_entropy)
        op_normal_entropy_metric.update(op_normal_entropy)
        reduced_logp_metric.update(reduced_logp)
        node_reduced_entropy_metric.update(node_reduced_entropy)
        op_reduced_entropy_metric.update(op_reduced_entropy)

        # iteration log
        if iter_ % log_frequency == 0 or iter_ == len(data_loader):
            message = f"UPDATE THETA, epoch={epoch:03d}, iter={iter_}/{len(data_loader)}, "
            message += f"reward={reward:.4f}, "
            message += f"policy loss={policy_loss_metric.last:.4f}({policy_loss_metric.value:.4f}), "
            message += f"moving accuracy={baseline.value*100:.2f}%, "
            message += f"normal_logp={normal_logp_metric.last:.4f}({normal_logp_metric.value:.4f}), "
            message += f"node_normal_entropy={node_normal_entropy_metric.last:.4f}({node_normal_entropy_metric.value:.4f}), "
            message += f"op_normal_entropy={op_normal_entropy_metric.last:.4f}({op_normal_entropy_metric.value:.4f}), "
            message += f"reduced_logp={reduced_logp_metric.last:.4f}({reduced_logp_metric.value:.4f}), "
            message += f"node_reduced_entropy={node_reduced_entropy_metric.last:.4f}({node_reduced_entropy_metric.value:.4f}), "
            message += f"op_reduced_entropy={op_reduced_entropy_metric.last:.4f}({op_reduced_entropy_metric.value:.4f})."
            if iter_ == len(data_loader):
                message += f" Elapsed time={datetime.now()-start}."
            utils.logger.info(message)

    writer.add_scalar("update_theta/policy_loss", policy_loss_metric.value, epoch)
    writer.add_scalar("update_theta/baseline", baseline.value, epoch)
    writer.add_scalar("update_theta/accuracy@1", accuracy_metric.accuracy(1).rate, epoch)
    writer.add_scalar("update_theta/accuracy@5", accuracy_metric.accuracy(5).rate, epoch)
    writer.add_scalar("update_theta/normal_logp", normal_logp_metric.value, epoch)
    writer.add_scalar("update_theta/node_normal_entropy", node_normal_entropy_metric.value, epoch)
    writer.add_scalar("update_theta/op_normal_entropy", op_normal_entropy_metric.value, epoch)
    writer.add_scalar("update_theta/reduced_logp", reduced_logp_metric.value, epoch)
    writer.add_scalar("update_theta/node_reduced_entropy", node_reduced_entropy_metric.value, epoch)
    writer.add_scalar("update_theta/op_reduced_entropy", op_reduced_entropy_metric.value, epoch)
def train(
    model: nn.Module,
    num_epochs: int,
    dataloader: DataLoader,
    optimizer: Optimizer,
    lr_scheduler: Optional[LRScheduler] = None,
    num_gradient_accumulation_steps: Optional[int] = 1,
    max_gradient_norm: Optional[float] = None,
    device: Optional[torch.device] = torch.device('cpu'),
    local_rank: Optional[int] = 0,
    use_distributed: Optional[bool] = False,
    is_master: Optional[bool] = True,
    use_tqdm: Optional[bool] = True,
    logger: Optional[Logger] = None,
) -> None:
    # put model in train mode
    model.train()

    # keep track of the last loss
    last_loss = 0

    for epoch in range(num_epochs):
        # synchronize all processes
        if use_distributed:
            dist.barrier()

        if is_master and logger is not None:
            logger.info(f'Starting with epoch {epoch+1}/{num_epochs}')

        # initialize the progress bar
        if is_master and use_tqdm:
            pbar = tqdm(
                desc=f'Training [epoch {epoch+1}/{num_epochs}]',
                total=len(dataloader),
                unit='batch',
            )

        for step, batch in enumerate(dataloader):
            # unpack batch
            sequences, attention_masks, _, start_positions, end_positions, _, _, _ = batch

            # send sequences, attention_masks, start_positions and end_positions to device
            sequences = sequences.to(device)
            attention_masks = attention_masks.to(device)
            start_positions = start_positions.to(device)
            end_positions = end_positions.to(device)

            # forward pass (loss computation included)
            outputs = model(input_ids=sequences, attention_mask=attention_masks,
                            start_positions=start_positions, end_positions=end_positions)
            loss = outputs[0]
            last_loss = loss.item()

            if use_distributed:
                loss = loss.mean()

            # rescale the loss
            loss /= num_gradient_accumulation_steps

            # backward pass
            loss.backward()

            # update only once every `num_gradient_accumulation_steps` batches
            # (step + 1 so the first update happens after a full accumulation window)
            if (step + 1) % num_gradient_accumulation_steps == 0:
                # clip the gradient
                if max_gradient_norm is not None:
                    clip_grad_norm_(model.parameters(), max_gradient_norm)

                # update the parameters
                optimizer.step()
                if lr_scheduler is not None:
                    lr_scheduler.step()

                # clear all gradients
                optimizer.zero_grad()

            # update the progress bar
            if is_master and use_tqdm:
                pbar.update()
                pbar.set_postfix({'last_loss': last_loss})

        # close the progress bar
        if is_master and use_tqdm:
            pbar.close()
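# Hypothetical usage sketch (names and values below are illustrative, not from
# the source): with gradient accumulation, the effective batch size is the
# per-step batch size times num_gradient_accumulation_steps, since gradients
# from that many micro-batches are summed before each optimizer step.
#
# train(
#     model=model,
#     num_epochs=3,
#     dataloader=train_dataloader,          # e.g. batch_size=8
#     optimizer=optimizer,
#     num_gradient_accumulation_steps=4,    # effective batch size 8 * 4 = 32
#     max_gradient_norm=1.0,
#     device=torch.device('cuda'),
# )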
def adv_train(self, epoch: int, optimizer: optim.Optimizer,
              lr_scheduler: optim.lr_scheduler._LRScheduler = None,
              validate_interval=10, save=False,
              verbose=True, indent=0, **kwargs):
    loader_train = self.dataset.loader['train']
    file_path = os.path.join(self.folder_path, self.get_filename() + '.pth')

    _, best_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)

    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    params: list[nn.Parameter] = []
    for param_group in optimizer.param_groups:
        params.extend(param_group['params'])

    for _epoch in range(epoch):
        losses.reset()
        top1.reset()
        top5.reset()
        epoch_start = time.perf_counter()
        if verbose and env['tqdm']:
            loader_train = tqdm(loader_train)
        self.model.activate_params(params)
        optimizer.zero_grad()
        for data in loader_train:
            _input, _label = self.model.get_data(data)
            noise = torch.zeros_like(_input)

            def loss_fn(X: torch.FloatTensor):
                return -self.model.loss(X, _label)

            # one clean training step before the adversarial steps
            adv_x = _input
            self.model.train()
            loss = self.model.loss(adv_x, _label)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            for m in range(self.pgd.iteration):
                self.model.eval()
                adv_x, _ = self.pgd.optimize(_input=_input, noise=noise,
                                             loss_fn=loss_fn, iteration=1)
                optimizer.zero_grad()
                self.model.train()
                loss = self.model.loss(adv_x, _label)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            with torch.no_grad():
                _output = self.model.get_logits(_input)
            acc1, acc5 = self.model.accuracy(_output, _label, topk=(1, 5))
            batch_size = int(_label.size(0))
            losses.update(loss.item(), batch_size)
            top1.update(acc1, batch_size)
            top5.update(acc5, batch_size)

        epoch_time = str(datetime.timedelta(seconds=int(
            time.perf_counter() - epoch_start)))
        self.model.eval()
        self.model.activate_params([])
        if verbose:
            pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                output_iter(_epoch + 1, epoch),
                **ansi).ljust(64 if env['color'] else 35)
            _str = ' '.join([
                f'Loss: {losses.avg:.4f},'.ljust(20),
                f'Top1 Clean Acc: {top1.avg:.3f},'.ljust(30),
                f'Top5 Clean Acc: {top5.avg:.3f},'.ljust(30),
                f'Time: {epoch_time},'.ljust(20),
            ])
            prints(pre_str, _str,
                   prefix='{upline}{clear_line}'.format(**ansi) if env['tqdm'] else '',
                   indent=indent)
        if lr_scheduler:
            lr_scheduler.step()

        if validate_interval != 0:
            if (_epoch + 1) % validate_interval == 0 or _epoch == epoch - 1:
                _, cur_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)
                if cur_acc >= best_acc:  # keep the best (highest) accuracy
                    prints('best result update!', indent=indent)
                    prints(f'Current Acc: {cur_acc:.3f}    '
                           f'Previous Best Acc: {best_acc:.3f}', indent=indent)
                    best_acc = cur_acc
                    if save:
                        self.model.save(file_path=file_path, verbose=verbose)
                if verbose:
                    print('-' * 50)
    self.model.zero_grad()
def train(loader: DataLoader, model: torch.nn.Module, criterion,
          optimizer: Optimizer, epoch: int, args):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()

    m = Bernoulli(torch.tensor([args.calibrated_alpha]).cuda())

    for i, (inputs, targets) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs = inputs.cuda()
        targets = targets.cuda()

        # make MNIST binary
        if args.dataset == 'mnist':
            inputs = (inputs > 0.5).type(torch.cuda.FloatTensor)

        # augment inputs with noise
        if args.perturb == 'bernoulli':
            mask = m.sample(inputs.shape).squeeze(-1)
            # make sure that the value is normalized
            rand_inputs = torch.randint_like(
                inputs, low=0, high=args.K + 1, device='cuda') / float(args.K)
            inputs = inputs * mask + rand_inputs * (1 - mask)
        elif args.perturb == 'gaussian':
            inputs = inputs + torch.randn_like(inputs, device='cuda') * args.sigma

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(acc1.item(), inputs.size(0))
        top5.update(acc5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i + 1, len(loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))

    return (losses.avg, top1.avg)
def train(model: nn.Module, optimizer: optim.Optimizer, train_data: DataLoader,
          use_cuda: bool = True, scheduler=None, bce_weight: float = 1,
          mse_weight: float = 0.1, misclass_weight: float = 1,
          corclass_weight: float = 1, threshold: float = 0.7,
          gci_threshold: float = 0.5):
    model.train()
    loss_sum = 0
    bce_loss = 0
    mse_loss = 0
    gci_misclass = 0
    misses = 0

    # Variable and .data[0] are long deprecated; plain tensors and .item()
    # are used instead.
    bce_weight = th.tensor([bce_weight])
    mse_weight = th.tensor([mse_weight])
    misclass_weight = th.tensor([misclass_weight])
    corclass_weight = th.tensor([corclass_weight])
    thresh = th.tensor([threshold])
    gci_thresh = th.tensor([gci_threshold])
    batches = len(train_data)

    if use_cuda:
        if th.cuda.is_available():
            model.cuda()
        else:
            print('Warning: GPU not available, running on CPU')

    for data, target in train_data:
        if scheduler is not None:
            scheduler.step()
        if use_cuda:
            data, target = data.cuda(), target.cuda()
            bce_weight = bce_weight.cuda()
            mse_weight = mse_weight.cuda()
            misclass_weight = misclass_weight.cuda()
            corclass_weight = corclass_weight.cuda()
            thresh = thresh.cuda()
            gci_thresh = gci_thresh.cuda()

        optimizer.zero_grad()

        peak_distance_target = target[:, 0]
        peak_indicator_target = target[:, 1]

        output = model(data)
        distance = output[:, 1]
        probabilities = output[:, 0]

        loss_bce = F.binary_cross_entropy_with_logits(probabilities,
                                                      peak_indicator_target)

        # MSE only on positions where a peak is actually present
        loss_mse = (distance * peak_indicator_target
                    - peak_distance_target * peak_indicator_target) ** 2
        loss_mse = loss_mse.sum() / peak_indicator_target.sum()

        out = (th.sigmoid(probabilities) > gci_thresh).float()
        loss_misclass = (1 - peak_indicator_target) * (th.sigmoid(probabilities) ** 2)
        loss_misclass = loss_misclass.mean()

        misses_temp = (1 - peak_indicator_target) * out
        misses += misses_temp.mean().item()

        gci_misclass_temp = peak_indicator_target * (1 - out)
        gci_misclass += gci_misclass_temp.mean().item()

        loss_corrclass = peak_indicator_target * ((1 - th.sigmoid(probabilities)) ** 2)
        loss_corrclass = loss_corrclass.mean()

        net_loss = bce_weight * loss_bce + mse_weight * loss_mse
        loss_sum += net_loss.item()
        bce_loss += loss_bce.item()
        mse_loss += loss_mse.item()

        net_loss.backward()
        # TODO: gradient clipping
        optimizer.step()

    return (loss_sum / batches, bce_loss / batches, mse_loss / batches,
            gci_misclass / batches, misses / batches)
def train(model_G: nn.Module, model_D: nn.Module,
          optimizer_G: optim.Optimizer, optimizer_D: optim.Optimizer,
          train_data: DataLoader, use_cuda: bool = True):
    model_G.train()
    model_D.train()
    loss_sum = 0
    loss_D = 0
    loss_G = 0
    D_real_prob = 0
    D_fake_prob = 0
    batches = len(train_data)

    if use_cuda:
        if th.cuda.is_available():
            model_G.cuda()
            model_D.cuda()
        else:
            print('Warning: GPU not available, running on CPU')

    for x_train, y_train in train_data:
        if use_cuda:
            y_train = y_train.type(th.LongTensor)
            x_train, y_train = x_train.cuda(), y_train.cuda()
        batch_size = x_train.shape[0]

        optimizer_G.zero_grad()
        optimizer_D.zero_grad()

        # Training the DISCRIMINATOR
        z = th.randn(batch_size, 1).cuda()
        ones_label = th.ones(batch_size, 1).cuda()
        zeros_label = th.zeros(batch_size, 1).cuda()

        image = model_G(z)
        D_real = model_D(x_train)
        D_fake = model_D(image)

        D_loss_real = F.binary_cross_entropy(D_real, ones_label)
        D_loss_fake = F.binary_cross_entropy(D_fake, zeros_label)
        D_loss = D_loss_real + D_loss_fake
        D_real_prob += D_real.mean().item()
        D_fake_prob += D_fake.mean().item()

        D_loss.backward()
        optimizer_D.step()
        optimizer_D.zero_grad()
        optimizer_G.zero_grad()

        # Training the GENERATOR (10 generator steps per discriminator step)
        for i in range(10):
            z = th.randn(batch_size, 1).cuda()
            image = model_G(z)
            D_fake = model_D(image)
            G_loss = F.binary_cross_entropy(D_fake, ones_label)
            G_loss.backward()
            optimizer_G.step()
            optimizer_D.zero_grad()
            optimizer_G.zero_grad()

        loss_D += D_loss.item()
        loss_G += G_loss.item()
        # accumulate the per-batch losses, not the running sums
        loss_sum += D_loss.item() + G_loss.item()
        th.cuda.empty_cache()

    return (loss_sum / batches, loss_D / batches, loss_G / batches,
            D_real_prob / batches, D_fake_prob / batches)
def train(epoch: int, model: nn.Module, loader: data.DataLoader,
          criterion: nn.modules.loss._Loss, optimizer: optim.Optimizer,
          scheduler: optim.lr_scheduler._LRScheduler, only_epoch_sche: bool,
          use_amp: bool, accumulated_steps: int, device: str, log_interval: int):
    model.train()

    scaler = GradScaler() if use_amp else None
    gradient_accumulator = GradientAccumulator(accumulated_steps)

    loss_metric = AverageMetric("loss")
    accuracy_metric = AccuracyMetric(topk=(1, 5))
    ETA = EstimatedTimeArrival(len(loader))
    speed_tester = SpeedTester()

    lr = optimizer.param_groups[0]['lr']
    _logger.info(f"Train start, epoch={epoch:04d}, lr={lr:.6f}")

    for time_cost, iter_, (inputs, targets) in time_enumerate(loader, start=1):
        inputs, targets = inputs.to(device=device), targets.to(device=device)

        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            outputs = model(inputs)
            loss: torch.Tensor = criterion(outputs, targets)

        gradient_accumulator.backward_step(model, loss, optimizer, scaler)

        if scheduler is not None:
            # step once per epoch, or once per iteration
            if only_epoch_sche:
                if iter_ == 1:
                    scheduler.step()
            else:
                scheduler.step()

        loss_metric.update(loss)
        accuracy_metric.update(outputs, targets)
        ETA.step()
        speed_tester.update(inputs)

        if iter_ % log_interval == 0 or iter_ == len(loader):
            _logger.info(", ".join([
                "TRAIN",
                f"epoch={epoch:04d}",
                f"iter={iter_:05d}/{len(loader):05d}",
                f"fetch data time cost={time_cost*1000:.2f}ms",
                f"fps={speed_tester.compute()*world_size():.0f} images/s",
                f"{loss_metric}",
                f"{accuracy_metric}",
                f"{ETA}",
            ]))
            speed_tester.reset()

    return {
        "lr": lr,
        "train/loss": loss_metric.compute(),
        "train/top1_acc": accuracy_metric.at(1).rate,
        "train/top5_acc": accuracy_metric.at(5).rate,
    }
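# `GradientAccumulator` is assumed to be a project-local helper; a minimal
# sketch consistent with the `backward_step(model, loss, optimizer, scaler)`
# call above (scale the loss, defer optimizer.step() until enough
# micro-batches have accumulated, route through GradScaler when AMP is on).
# Note that the loop above also calls optimizer.zero_grad() each iteration,
# so the real helper likely manages gradient state itself; this sketch
# assumes zeroing is left to the accumulator.
class GradientAccumulator:
    def __init__(self, steps: int = 1):
        self.steps = steps
        self._counter = 0

    def backward_step(self, model, loss, optimizer, scaler=None):
        loss = loss / self.steps  # average gradients over the accumulation window
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        self._counter += 1
        if self._counter % self.steps == 0:
            if scaler is not None:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()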
def train(args, train_data_loader: DataLoader, valid_data_loader: DataLoader,
          model: Net, criterion, optimizer: optim.Optimizer, device):
    # save model
    if args.save_model:
        if not os.path.exists(args.save_directory):
            os.makedirs(args.save_directory)

    epochs = args.epochs
    train_losses = []
    valid_losses = []
    for epoch_id in range(epochs):
        ######################
        # training the model #
        ######################
        model.train()
        train_batch_cnt = 0
        train_mean_pts_loss = 0.0
        for batch_idx, batch in enumerate(train_data_loader):
            train_batch_cnt += 1
            img = batch['image']
            landmark = batch['landmarks']  # ground truth

            input_img = img.to(device)
            target_pts = landmark.to(device)

            # clear the gradients of all optimized variables
            optimizer.zero_grad()

            output_pts = model(input_img)
            loss = criterion(output_pts, target_pts)

            # backpropagate and update the optimizer's parameters
            loss.backward()
            optimizer.step()

            train_mean_pts_loss += loss.item()

            # show log info
            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\t pts_loss: {:.6f}'.format(
                    epoch_id,
                    batch_idx * len(img),
                    len(train_data_loader.dataset),
                    100. * batch_idx / len(train_data_loader),
                    loss.item()))

        train_mean_pts_loss /= train_batch_cnt
        train_losses.append(train_mean_pts_loss)

        ######################
        # validate the model #
        ######################
        valid_mean_pts_loss = 0.0
        model.eval()  # prepare model for evaluation
        with torch.no_grad():
            valid_batch_cnt = 0
            for valid_batch_idx, batch in enumerate(valid_data_loader):
                valid_batch_cnt += 1
                valid_img = batch['image']
                landmark = batch['landmarks']

                input_img = valid_img.to(device)
                target_pts = landmark.to(device)

                output_pts = model(input_img)
                valid_loss = criterion(output_pts, target_pts)
                valid_mean_pts_loss += valid_loss.item()

            valid_mean_pts_loss /= valid_batch_cnt * 1.0
            valid_losses.append(valid_mean_pts_loss)
            print('Valid: pts_loss: {:.6f}'.format(valid_mean_pts_loss))
        print('====================================================')

        if args.save_model:
            saved_model_name = os.path.join(
                args.save_directory,
                # f'detector_epoch_{epoch_id}_{train_mean_pts_loss}_{valid_mean_pts_loss}.pt')
                f'detector_epoch_{args.phase}_{epoch_id}.pt')
            torch.save(model.state_dict(), saved_model_name)

    draw_loss(train_losses, valid_losses, args.phase)
    return train_losses, valid_losses
def train_epoch(self, dataset: TLGDataset, batch_size: int,
                criterion: Callable[[FloatTensor, LongTensor], FloatTensor],
                optimizer: optim.Optimizer, train_indices: List[int],
                n_print=100, epoch=0) -> Tuple[float, int, int, int, int]:
    self.train()

    permutation = np.random.permutation(train_indices)

    loss = 0.
    BS, BTS, BW, BTW = 0, 0, 0, 0
    running_batch_time = 0.0

    for i in range(len(permutation)):
        start_time = time.time()
        optimizer.zero_grad()

        batch_x = dataset.X[permutation[i]]
        batch_y = dataset.Y[permutation[i]]

        lens = torch.sum((batch_x.word != dataset.x_pad_token).long(),
                         dim=1).to(self.device)
        batch_e = F.embedding(batch_y.to(self.device),
                              self.transformer.embedding_matrix)

        # mask encoder positions beyond each sequence's true length
        encoder_mask = torch.ones(batch_y.shape[0], batch_y.shape[1], batch_x.shape[1])
        for j, l in enumerate(lens):  # note: must not reuse the outer index `i` here
            encoder_mask[j, :, l:] = 0
        encoder_mask = encoder_mask.to(self.device)
        decoder_mask = Mask(
            (batch_x.shape[0], batch_y.shape[1], batch_y.shape[1])).to(self.device)

        batch_p = self.forward(batch_x, batch_e, encoder_mask, decoder_mask)
        batch_loss = criterion(batch_p[:, :-1].permute(0, 2, 1),
                               batch_y[:, 1:].to(self.device)) / lens.float().sum()
        loss += batch_loss.item()
        batch_loss.backward()
        optimizer.step()

        # shift predictions and targets by one position before scoring
        argmaxes = batch_p.argmax(dim=-1)[:, :-1]
        y = batch_y[:, 1:]
        (bs, bts), (bw, btw) = accuracy(argmaxes, y.to(self.device),
                                        dataset.type_dict[PAD])
        BS += bs
        BTS += bts
        BW += bw
        BTW += btw

        running_batch_time += time.time() - start_time
        if i % n_print == n_print - 1:  # print every n mini-batches
            batch_time = running_batch_time / n_print
            print('[%d, %5d] loss: %.3f | acc: %.3f | %.1f %s | %.1f %s' %
                  (epoch + 1, i + 1, loss / n_print, BTW / BW,
                   batch_time if batch_time >= 1 else 1 / batch_time,
                   's/batch' if batch_time >= 1 else 'batch(es)/s',
                   batch_time / batch_size if batch_time / batch_size >= 1
                   else batch_size / batch_time,
                   's/expl' if batch_time / batch_size >= 1 else 'expl(s)/s'),
                  file=sys.stderr)
            running_batch_time = 0.0

    return loss, BS, BTS, BW, BTW
def train(model: nn.Module, optimizer: optim.Optimizer, dataloader: DataLoader,
          epochs: int, loss_criterion: str, model_dir: str, plateau_limit: int,
          apply_nested_dropout: bool, reconstruct: bool, **kwargs):
    print(f'The model has {utils.get_num_parameters(model):,} parameters')
    testloader = kwargs.pop('testloader', None)
    lr_scheduler = kwargs.pop('lr_scheduler', None)

    loss_function = getattr(nn, loss_criterion)()
    batch_print = len(dataloader) // 5

    model.train()
    device = utils.get_device()
    model.to(device)  # TODO check if this actually does anything

    losses = []
    accuracies = []
    best_loss = float('inf')
    best_accuracy = 0
    plateau = 0
    train_time = 0
    for epoch in range(epochs):
        epoch_start = time.time()
        line = f'\tEpoch {epoch + 1}/{epochs}'
        if apply_nested_dropout and epoch > 0:
            line += f' ({model.get_converged_unit()}/{model.get_dropout_dim()} converged units)'
        print(line)

        batch_losses = []
        for i, (X, y) in enumerate(dataloader):
            optimizer.zero_grad()
            X = X.to(device)
            y = y.to(device)

            prediction = model(X)
            if reconstruct:
                loss = loss_function(prediction, X)
            else:
                loss = loss_function(prediction, y)

            loss.backward()
            optimizer.step()

            batch_losses.append(loss.item())
            if (i + 1) % batch_print == 0:
                batch_loss = utils.format_number(np.average(batch_losses[-batch_print:]))
                print(f'Batch {i + 1} loss: {batch_loss}')

            if apply_nested_dropout:
                model(X)
                if model.has_converged():
                    break

        epoch_loss = utils.format_number(np.average(batch_losses))
        losses.append(epoch_loss)

        epoch_time = time.time() - epoch_start
        train_time += epoch_time
        print(f'\tEpoch loss {epoch_loss}')

        model_save_kwargs = dict(**kwargs, epoch=epoch,
                                 train_time=utils.format_time(train_time),
                                 losses=losses)
        has_improved = False
        if testloader is not None:
            model.eval()
            eval_accuracy = round(utils.get_model_accuracy(model, testloader, device), 3)
            model.train()

            accuracies.append(eval_accuracy)
            print(f'\tEvaluation accuracy {eval_accuracy}')

            if eval_accuracy > best_accuracy:
                best_accuracy = eval_accuracy
                has_improved = True
                model_save_kwargs.update(accuracies=accuracies,
                                         best_accuracy=best_accuracy)
            if lr_scheduler is not None:
                lr_scheduler.step(eval_accuracy)
        elif epoch_loss < best_loss:
            best_loss = epoch_loss
            has_improved = True
            model_save_kwargs.update(best_loss=best_loss)
        print(f'\tEpoch time {utils.format_time(epoch_time)}\n')

        if has_improved:
            utils.save_model(model, optimizer, f'{model_dir}/model', **model_save_kwargs)
            plateau = 0
        else:
            plateau += 1

        if (plateau == plateau_limit) or (apply_nested_dropout and model.has_converged()):
            break

    if apply_nested_dropout and model.has_converged():
        end = 'nested dropout has converged'
        print('Nested dropout has converged!')
    elif plateau == plateau_limit:
        end = 'has plateaued'
        print('The model has plateaued...')
    else:
        end = f'reached max number of epochs ({epochs})'
        print('The maximum number of epochs has been reached...')
    utils.update_save(f'{model_dir}/model', end=end)

    return losses
def train(loader: DataLoader, model: torch.nn.Module, criterion,
          optimizer: Optimizer, epoch: int, noise_sd: float):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()

    for i, (inputs, targets) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs = inputs.cuda()
        targets = targets.cuda()

        # augment inputs with noise
        inputs = inputs + randgn_like(inputs, p=args.p, device='cuda') * noise_sd
        if args.scale_down != 1:
            inputs = torch.nn.functional.interpolate(
                inputs, scale_factor=args.scale_down)

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(acc1.item(), inputs.size(0))
        top5.update(acc5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))

    return (losses.avg, top1.avg)
def train(
    model: Union[nn.Module, nn.DataParallel],
    train_loader: DataLoader,
    metrics: Dict[str, Metric],
    optimizer: Optimizer,
    scheduler: _LRScheduler,
    device: torch.device,
    epoch: int,
    log_interval: int,
    hooks: Optional[Sequence[Hook]] = None,
    teacher: Optional[Union[nn.Module, nn.DataParallel]] = None,
) -> Dict[str, float]:
    """
    Train a model on some data using some criterion and with some optimizer.

    Args:
        model: Model to train
        train_loader: Data loader for loading training data
        metrics: A dict mapping evaluation metric names to metrics classes
        optimizer: PyTorch optimizer
        scheduler: PyTorch scheduler
        device: PyTorch device object
        epoch: Current epoch, where the first epoch should start at 1
        log_interval: Number of batches before printing loss
        hooks: A sequence of functions that can implement custom behavior
        teacher: teacher network for knowledge distillation, if any

    Returns:
        A dictionary mapping evaluation metric names to computed values for
        the training set.
    """
    if hooks is None:
        hooks = []

    model.train()
    for metric in metrics.values():
        metric.reset()

    loss_fn = model.module.loss_fn if isinstance(
        model, nn.DataParallel) else model.loss_fn

    seen_examples = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)

        if teacher is None:
            teacher_output = None
            loss = loss_fn(output, target)  # type: ignore
        else:
            teacher_output = teacher(data)
            loss = loss_fn(output, teacher_output, target)  # type: ignore

        loss.backward()
        optimizer.step()
        project(optimizer)
        scheduler.step()  # type: ignore

        with torch.no_grad():
            for metric in metrics.values():
                metric.update(output, target, teacher_output=teacher_output)

        for hook in hooks:
            hook(
                epoch=epoch,
                global_step=1 + (epoch - 1) * len(train_loader.dataset) + batch_idx,
                values_dict={'lr': _get_lr(optimizer)},
                log_interval=log_interval,
            )

        seen_examples += len(data)
        if batch_idx % log_interval == 0:
            logger.info(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tBatch Loss: {:.6f}'.format(
                    epoch,
                    seen_examples,
                    len(train_loader.dataset),
                    100 * batch_idx / len(train_loader),
                    loss.item(),
                ))

    # Computing evaluation metrics for training set
    computed_metrics = {
        name: metric.compute() for name, metric in metrics.items()
    }
    logger.info('Training set evaluation metrics:')
    for name, metric in metrics.items():
        logger.info(f'{name}: {metric}')

    return computed_metrics
def routine(
    model: nn.Module,
    dataloader: data.DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer = None,
    adversary: nn.Module = None,
    inverse: bool = False,
    descents: int = None,
    flow: bool = False,
    mask_sampler: nn.Module = None,
    clip: float = None,
) -> Tuple[float, torch.Tensor]:  # (time, losses)
    r"""Training routine"""

    if adversary is None:
        adversary = Dummy()

    losses = []
    start = time()

    for theta, theta_prime, x in islice(dataloader, descents):
        y = model.embedding(x)
        adv_y = adversary.embedding(x)

        if flow:
            prob = model(theta, y)
            l = criterion(prob)
        else:
            if mask_sampler is None:
                ratio, ratio_prime = model(
                    torch.stack((theta, theta_prime)),
                    torch.stack((y, y)),
                )

                with torch.no_grad():
                    adv_ratio = adversary(theta if inverse else theta_prime, adv_y)
            else:
                if model.hyper is None:
                    mask = mask_sampler(theta.shape[:1])
                else:
                    mask = mask_sampler()

                ratio, ratio_prime = model(
                    torch.stack((theta, theta_prime)),
                    torch.stack((y, y)),
                    torch.stack((mask, mask)) if model.hyper is None else mask,
                )

                with torch.no_grad():
                    adv_ratio = adversary(theta if inverse else theta_prime,
                                          adv_y, mask)

            if adv_ratio is not None:
                adv_ratio = (-adv_ratio if inverse else adv_ratio).exp()

            if inverse:
                l = criterion(ratio, adv_ratio) + criterion(-ratio_prime)
            else:
                l = criterion(ratio) + criterion(-ratio_prime, adv_ratio)

        if not l.isfinite():
            continue

        if optimizer is not None:
            optimizer.zero_grad()
            l.backward()

            if clip is not None:
                tot = nn.utils.clip_grad_norm_(model.parameters(), clip)
                if not tot.isfinite():
                    continue

            optimizer.step()

        losses.append(l.item())

    end = time()

    return end - start, torch.tensor(losses)
def meta_gradient_step(model: Module,
                       optimiser: Optimizer,
                       loss_fn: Callable,
                       x: torch.Tensor,
                       y: torch.Tensor,
                       n_shot: int,
                       k_way: int,
                       q_queries: int,
                       order: int,
                       inner_train_steps: int,
                       inner_lr: float,
                       train: bool,
                       device: Union[str, torch.device]):
    """
    Perform a gradient step on a meta-learner.

    # Arguments
        model: Base model of the meta-learner being trained
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        x: Input samples for all few shot tasks
        y: Input labels of all few shot tasks
        n_shot: Number of examples per class in the support set of each task
        k_way: Number of classes in the few shot classification task of each task
        q_queries: Number of examples per class in the query set of each task.
            The query set is used to calculate meta-gradients after applying
            the update to the fast weights.
        order: Whether to use 1st order MAML (update meta-learner weights with
            gradients of the updated weights on the query set) or 2nd order MAML
            (use 2nd order updates by differentiating through the gradients of
            the updated weights on the query with respect to the original weights).
        inner_train_steps: Number of gradient steps to fit the fast weights
            during each inner update
        inner_lr: Learning rate used to update the fast weights on the inner update
        train: Whether to update the meta-learner weights at the end of the episode.
        device: Device on which to run computation
    """
    data_shape = x.shape[2:]
    create_graph = (order == 2) and train

    task_gradients = []
    task_losses = []
    task_predictions = []
    for meta_batch in x:
        # By construction x is a 5D tensor of shape:
        # (meta_batch_size, n*k + q*k, channels, width, height).
        # Hence when we iterate over the first dimension we are iterating
        # through the meta batches.
        x_task_train = meta_batch[:n_shot * k_way]
        x_task_val = meta_batch[n_shot * k_way:]

        # Create a fast model using the current meta model weights
        fast_weights = OrderedDict(model.named_parameters())

        # Train the model for `inner_train_steps` iterations
        for inner_batch in range(inner_train_steps):
            # Perform update of model weights
            y = create_nshot_task_label(k_way, n_shot).to(device)
            logits = model.functional_forward(x_task_train, fast_weights)
            loss = loss_fn(logits, y)
            gradients = torch.autograd.grad(loss, fast_weights.values(),
                                            create_graph=create_graph)

            # Update weights manually
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), gradients)
            )

        # Do a pass of the model on the validation data from the current task
        y = create_nshot_task_label(k_way, q_queries).to(device)
        logits = model.functional_forward(x_task_val, fast_weights)
        loss = loss_fn(logits, y)
        loss.backward(retain_graph=True)

        # Get post-update accuracies
        y_pred = logits.softmax(dim=1)
        task_predictions.append(y_pred)

        # Accumulate losses and gradients
        task_losses.append(loss)
        gradients = torch.autograd.grad(loss, fast_weights.values(),
                                        create_graph=create_graph)
        named_grads = {name: g for ((name, _), g)
                       in zip(fast_weights.items(), gradients)}
        task_gradients.append(named_grads)

    if order == 1:
        if train:
            sum_task_gradients = {
                k: torch.stack([grad[k] for grad in task_gradients]).mean(dim=0)
                for k in task_gradients[0].keys()
            }
            hooks = []
            for name, param in model.named_parameters():
                hooks.append(
                    param.register_hook(replace_grad(sum_task_gradients, name))
                )

            model.train()
            optimiser.zero_grad()
            # Dummy pass in order to create `loss` variable;
            # replace dummy gradients with mean task gradients using hooks
            logits = model(torch.zeros((k_way, ) + data_shape).to(device, dtype=torch.double))
            loss = loss_fn(logits, create_nshot_task_label(k_way, 1).to(device))
            loss.backward()
            optimiser.step()

            for h in hooks:
                h.remove()

        return torch.stack(task_losses).mean(), torch.cat(task_predictions)

    elif order == 2:
        model.train()
        optimiser.zero_grad()
        meta_batch_loss = torch.stack(task_losses).mean()

        if train:
            meta_batch_loss.backward()
            optimiser.step()

        return meta_batch_loss, torch.cat(task_predictions)
    else:
        raise ValueError('Order must be either 1 or 2.')
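# `replace_grad` is assumed to be a small helper from the same repo; a minimal
# sketch consistent with its use above (a hook factory whose hook discards the
# dummy gradient and returns the pre-computed mean task gradient instead):
def replace_grad(parameter_gradients, parameter_name):
    def hook(grad):
        return parameter_gradients[parameter_name]
    return hook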
def forward(self, x, opt: optim.Optimizer, step,
            summary_writer: torch.utils.tensorboard.SummaryWriter = None,
            sample_gpu=None):
    """Train inside forward."""
    opt.zero_grad()
    batch_size, num_pts = x.shape[:2]
    z_mu, z_sigma = self.encoder(x)

    # Compute Q(z|X) and entropy H{Q(z|X)}
    if self.use_deterministic_encoder:
        z = z_mu + 0 * z_sigma  # ? why, the original code added this 0 multiplier
        entropy = torch.zeros(batch_size).to(z)
    else:
        z = self.reparametrized_gaussian(z_mu, z_sigma)
        entropy = self.gaussian_entropy(z_sigma)

    # Compute prior P(z)
    if self.use_latent_flow:
        w, dlog_pw = self.latentCNF(z, None, torch.zeros(batch_size, 1).to(z))
        log_pw = standard_normal_logp(w).view(batch_size, -1).sum(dim=1, keepdim=True)
        dlog_pw = dlog_pw.view(batch_size, 1).to(z)
        log_pz = log_pw - dlog_pw
    else:
        log_pz = torch.zeros(batch_size, 1).to(z)

    # Compute reconstruction likelihood P(X|z)
    z_new = z.view(z.shape) + (log_pz * 0.).mean()  # ? why
    y, dlog_py = self.pointCNF(x, z_new, torch.zeros(batch_size, num_pts, 1).to(x))
    log_py = standard_normal_logp(y).view(batch_size, -1).sum(dim=1, keepdim=True)
    dlog_py = dlog_py.view(batch_size, num_pts, 1).to(x)
    log_px = log_py - dlog_py

    # Loss
    entropy_loss = -entropy.mean() * self.entropy_w
    recon_loss = -log_px.mean() * self.recon_w
    prior_loss = -log_pz.mean() * self.prior_w
    loss = entropy_loss + recon_loss + prior_loss
    loss.backward()
    opt.step()

    # Write logs
    if self.distributed:
        raise NotImplementedError("Distributed training not implemented!")
    else:
        entropy_log = entropy.mean()
        recon_log = -log_px.mean()
        prior_log = -log_pz.mean()

    recon_nats = recon_log / float(x.size(1) * x.size(2))
    prior_nats = prior_log / float(self.fz)

    # reconstruct to save
    with torch.no_grad():
        recon_pc = self.reconstruct(x, truncate_std=True)
        recon_im = visualize(recon_pc, path='/home/tmp/screenshot.png', samples=1)

    # sample to save
    if self.use_latent_flow:
        with torch.no_grad():
            sample_pc = self.sample(1, 1024, gpu=sample_gpu)
            sample_im = visualize(sample_pc, samples=1, path='/home/tmp/screenshot.png')

    record_dict = {
        'train/entropy': entropy_log.cpu().detach().item()
        if not isinstance(entropy_log, float) else entropy_log,
        'train/prior': prior_log,
        'train/recon': recon_log,
        'train/recon-nats': recon_nats,
        'train/prior-nats': prior_nats,
    }

    if summary_writer is not None:
        for key, value in record_dict.items():  # iterate items, not keys only
            summary_writer.add_scalar(key, value, step)
        record_dict['train/sample-reconstructed'] = recon_im
        summary_writer.add_images('train/sample-reconstructed', recon_im, step,
                                  dataformats='NHWC')
        if self.use_latent_flow:  # sample_im only exists when the latent flow is used
            record_dict['train/sample-sampled'] = sample_im
            summary_writer.add_images('train/sample-sampled', sample_im, step,
                                      dataformats='NHWC')

    return record_dict
def train_controller(max_iter: int, database: DataBase, entropy_coeff: float,
                     grad_clip: int, controller: NASBenchController, nac: NAC,
                     optimizer: optim.Optimizer, writer: tensorboard.SummaryWriter,
                     alternate_train, alternate_evaluate, random_baseline=False,
                     log_frequence: int = 10, search_space=None):
    controller.train()
    nac.eval()
    optimizer.zero_grad()

    policy_loss_avg = MovingAverageMetric()
    entropy_mavg = MovingAverageMetric()
    logp_mavg = MovingAverageMetric()
    score_avg = MovingAverageMetric()
    pseudo_architecture_set = None

    with torch.no_grad():
        *arch_seq, _, _ = controller(force_uniform=True)
        raw_arch = seq2arch_fn(arch_seq)
        baseline_arch = [tensorize_fn(raw_arch, device=device)]
        best_collect_archs = [arch_seq]

    for iter_ in range(max_iter):
        if iter_ % args.n_iteration_update_pseudoset == 0 and args.pseudo_ratio != 0:
            if pseudo_architecture_set is None:
                pseudo_architecture_set = \
                    generate_architecture_with_pseudo_labels(
                        nac, controller,
                        2 * int(args.pseudo_ratio * args.train_batch_size),
                        int(args.pseudo_ratio * args.train_batch_size))
            else:
                pseudo_architecture_set = list_concat(
                    pseudo_architecture_set,
                    generate_architecture_with_pseudo_labels(
                        nac, controller,
                        2 * args.n_sample_architectures,
                        args.n_sample_architectures))

            epoch = args.nac_epochs + iter_
            accuracy, rank_loss = alternate_train(
                epoch=epoch, pseudo_set=pseudo_architecture_set)
            writer.add_scalar("nac/train_accuracy", accuracy, epoch)
            writer.add_scalar("nac/loss", rank_loss, epoch)
            KTau = alternate_evaluate(epoch=epoch)
            writer.add_scalar("nac/ktau", KTau, epoch)

        *arch_seq, logp, entropy = controller()
        with torch.no_grad():
            sample_arch = [tensorize_fn(seq2arch_fn(arch_seq), device=device)]
            score = nac(batchify(sample_arch), batchify(baseline_arch))
            score = score.mean().item()

        policy_loss = -logp * score - entropy_coeff * entropy

        optimizer.zero_grad()
        policy_loss.backward()
        # clip after backward, once gradients actually exist
        if grad_clip is not None:
            nn.utils.clip_grad_norm_(controller.parameters(), grad_clip)
        optimizer.step()

        policy_loss_avg.update(policy_loss)
        entropy_mavg.update(entropy)
        logp_mavg.update(logp)
        score_avg.update(score)

        if iter_ % log_frequence == 0:
            logger.info(", ".join([
                "Policy Learning",
                f"iter={iter_:03d}",
                f"policy loss={policy_loss_avg.compute():.4f}",
                f"entropy={entropy_mavg.compute():.4f}",
                f"logp={logp_mavg.compute():.4f}",
            ]))
            writer.add_scalar("policy_learning/loss", policy_loss_avg.compute(), iter_)
            writer.add_scalar("policy_learning/entropy", entropy_mavg.compute(), iter_)
            writer.add_scalar("policy_learning/logp", logp_mavg.compute(), iter_)
            writer.add_scalar("policy_learning/reward", score_avg.compute(), iter_)

        if iter_ % args.evaluate_controller_freq == 0:
            baseline_arch, best_collect_archs = derive(
                iter_, controller, nac, 10, database, writer,
                best_collect_archs, random_baseline, search_space)
            torch.save(controller.state_dict(),
                       os.path.join(args.output, f"controller-{iter_}.pth"))
def train(loader: DataLoader, model: torch.nn.Module, criterion,
          optimizer: Optimizer, epoch: int, noise_sd: float):
    """
    Run one training epoch.

    :param loader: DataLoader: dataloader (train)
    :param model: torch.nn.Module: the classifier being trained
    :param criterion: the loss function
    :param optimizer: Optimizer: the optimizer used during training
    :param epoch: int: the current epoch number (for logging)
    :param noise_sd: float: the std-dev of the Gaussian noise perturbation of the input
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()

    for i, (inputs, targets) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs = inputs.cuda()
        targets = targets.cuda()

        # augment inputs with noise
        inputs = inputs + torch.randn_like(inputs, device='cuda') * noise_sd

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(acc1.item(), inputs.size(0))
        top5.update(acc5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))

    return (losses.avg, top1.avg)
def update_w(epoch: int, data_loader, device: str, master_pair: MasterPairs,
             architecture: NASNetwork, criterion: nn.Module,
             optimizer: optim.Optimizer, force_uniform: bool,
             writer: SummaryWriter, log_frequency: int):
    start = datetime.now()
    loss_metric = AverageMetric()
    accuracy_metric = AccuracyMetric(topk=(1, 5))
    normal_logp_metric = AverageMetric()
    node_normal_entropy_metric = AverageMetric()
    op_normal_entropy_metric = AverageMetric()
    reduced_logp_metric = AverageMetric()
    node_reduced_entropy_metric = AverageMetric()
    op_reduced_entropy_metric = AverageMetric()

    master_pair.set_force_uniform(force_uniform=force_uniform)
    for iter_, (datas, targets) in enumerate(data_loader, start=1):
        datas, targets = datas.to(device=device), targets.to(device=device)

        with torch.no_grad():
            (normal_arch, normal_logp, node_normal_entropy, op_normal_entropy), \
                (reduced_arch, reduced_logp, node_reduced_entropy, op_reduced_entropy) = master_pair()

        outputs = architecture(datas, normal_arch, reduced_arch)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update metrics
        loss_metric.update(loss)
        accuracy_metric.update(targets, outputs)
        normal_logp_metric.update(normal_logp)
        node_normal_entropy_metric.update(node_normal_entropy)
        op_normal_entropy_metric.update(op_normal_entropy)
        reduced_logp_metric.update(reduced_logp)
        node_reduced_entropy_metric.update(node_reduced_entropy)
        op_reduced_entropy_metric.update(op_reduced_entropy)

        # iteration log
        if iter_ % log_frequency == 0 or iter_ == len(data_loader):
            message = f"UPDATE W, epoch={epoch:03d}, iter={iter_}/{len(data_loader)}, "
            message += f"celoss={loss_metric.last:.4f}({loss_metric.value:.4f}), "
            message += f"accuracy@1={accuracy_metric.last_accuracy(1).rate*100:.2f}%"
            message += f"({accuracy_metric.accuracy(1).rate*100:.2f}%), "
            message += f"accuracy@5={accuracy_metric.last_accuracy(5).rate*100:.2f}%"
            message += f"({accuracy_metric.accuracy(5).rate*100:.2f}%), "
            message += f"normal_logp={normal_logp_metric.last:.4f}({normal_logp_metric.value:.4f}), "
            message += f"node_normal_entropy={node_normal_entropy_metric.last:.4f}({node_normal_entropy_metric.value:.4f}), "
            message += f"op_normal_entropy={op_normal_entropy_metric.last:.4f}({op_normal_entropy_metric.value:.4f}), "
            message += f"reduced_logp={reduced_logp_metric.last:.4f}({reduced_logp_metric.value:.4f}), "
            message += f"node_reduced_entropy={node_reduced_entropy_metric.last:.4f}({node_reduced_entropy_metric.value:.4f}), "
            message += f"op_reduced_entropy={op_reduced_entropy_metric.last:.4f}({op_reduced_entropy_metric.value:.4f})."
            if iter_ == len(data_loader):
                message += f" Elapsed time={datetime.now()-start}."
            utils.logger.info(message)

    writer.add_scalar("update_w/celoss", loss_metric.value, epoch)
    writer.add_scalar("update_w/accuracy@1", accuracy_metric.accuracy(1).rate, epoch)
    writer.add_scalar("update_w/accuracy@5", accuracy_metric.accuracy(5).rate, epoch)
    writer.add_scalar("update_w/normal_logp", normal_logp_metric.value, epoch)
    writer.add_scalar("update_w/node_normal_entropy", node_normal_entropy_metric.value, epoch)
    writer.add_scalar("update_w/op_normal_entropy", op_normal_entropy_metric.value, epoch)
    writer.add_scalar("update_w/reduced_logp", reduced_logp_metric.value, epoch)
    writer.add_scalar("update_w/node_reduced_entropy", node_reduced_entropy_metric.value, epoch)
    writer.add_scalar("update_w/op_reduced_entropy", op_reduced_entropy_metric.value, epoch)
def matching_net_episode(model: Module,
                         optimiser: Optimizer,
                         loss_fn: Loss,
                         x: torch.Tensor,
                         y: torch.Tensor,
                         n_shot: int,
                         k_way: int,
                         q_queries: int,
                         distance: str,
                         fce: bool,
                         train: bool):
    """Performs a single training episode for a Matching Network.

    # Arguments
        model: Matching Network to be trained.
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        x: Input samples of few shot classification task
        y: Input labels of few shot classification task
        n_shot: Number of examples per class in the support set
        k_way: Number of classes in the few shot classification task
        q_queries: Number of examples per class in the query set
        distance: Distance metric to use when calculating distance between
            support and query set samples
        fce: Whether or not to use fully conditional embeddings
        train: Whether (True) or not (False) to perform a parameter update

    # Returns
        loss: Loss of the Matching Network on this task
        y_pred: Predicted class probabilities for the query set on this task
    """
    if train:
        # Zero gradients
        model.train()
        optimiser.zero_grad()
    else:
        model.eval()

    # Embed all samples
    embeddings = model.encoder(x)

    # Samples are ordered by the NShotWrapper class as follows:
    #   k lots of n support samples from a particular class
    #   k lots of q query samples from those classes
    support = embeddings[:n_shot * k_way]
    queries = embeddings[n_shot * k_way:]

    # Optionally apply full context embeddings
    if fce:
        # LSTM requires input of shape (seq_len, batch, input_size). `support`
        # is of shape (k_way * n_shot, embedding_dim) and we want the LSTM to
        # treat the support set as a sequence, so add a single dimension to
        # transform the support set to shape (k_way * n_shot, 1, embedding_dim)
        # and then remove the batch dimension afterwards.

        # Calculate the fully conditional embedding, g, for support set samples
        # as described in appendix A.2 of the paper. g takes the form of a
        # bidirectional LSTM with a skip connection from inputs to outputs.
        support, _, _ = model.g(support.unsqueeze(1))
        support = support.squeeze(1)

        # Calculate the fully conditional embedding, f, for the query set
        # samples as described in appendix A.1 of the paper.
        queries = model.f(support, queries)

    # Efficiently calculate distance between all queries and all support samples.
    # Output should have shape (q_queries * k_way, k_way * n_shot)
    distances = pairwise_distances(queries, support, distance)

    # Calculate "attention" as softmax over support-query distances
    attention = (-distances).softmax(dim=1)

    # Calculate predictions as in equation (1) from Matching Networks:
    #   y_hat = \sum_{i=1}^{k} a(x_hat, x_i) y_i
    y_pred = matching_net_predictions(attention, n_shot, k_way, q_queries)

    # Calculate loss with negative log likelihood.
    # Clip predictions for numerical stability.
    clipped_y_pred = y_pred.clamp(EPSILON, 1 - EPSILON)
    loss = loss_fn(clipped_y_pred.log(), y)

    if train:
        # Backpropagate gradients
        loss.backward()
        # I found training to be quite unstable so I clip the norm
        # of the gradient to be at most 1
        clip_grad_norm_(model.parameters(), 1)
        # Take gradient step
        optimiser.step()

    return loss, y_pred
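# `matching_net_predictions` is assumed to be defined alongside this function;
# a minimal sketch of equation (1) from the Matching Networks paper (query
# predictions as attention-weighted sums of one-hot support labels), assuming
# the NShotWrapper ordering described above:
def matching_net_predictions(attention: torch.Tensor, n: int, k: int, q: int) -> torch.Tensor:
    """attention: (q * k, k * n) softmax weights over the support samples."""
    if attention.shape != (q * k, k * n):
        raise ValueError(f'Expected attention of shape {(q * k, k * n)}')
    # One-hot labels of the support set: k blocks of n identical labels
    y = torch.arange(0, k, 1 / n).long().to(attention.device)
    y_onehot = torch.zeros(k * n, k, device=attention.device)
    y_onehot = y_onehot.scatter(1, y.unsqueeze(-1), 1)
    return torch.mm(attention, y_onehot)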
def update_model(
    optimizer: optim.Optimizer,
    scaler: amp.grad_scaler.GradScaler,
    buffer: Buffer,
    state: TSP2OPTState,
    done: bool,
    epoch: int,
    count: int,
    learn_count: int,
    global_step: int,
    logger: SummaryWriter,
    args,
):
    rewards = torch.stack(buffer.rewards, dim=0)  # [horizon, batch_size, 1]
    returns = discounted_return(rewards, args.gamma, count)  # [horizon, batch_size, 1]
    if not args.no_norm_return:
        r_mean = returns.mean()
        r_std = returns.std()
        eps = torch.finfo(torch.float).eps  # small number to avoid div/0
        returns = (returns - r_mean) / (r_std + eps)

    values = torch.stack(buffer.values, dim=0)  # [horizon, batch_size, 1]
    advantages = (returns - values).detach()  # [horizon, batch_size, 1]

    logps = torch.stack(buffer.log_probs, dim=0)  # [horizon, batch_size, 2, graph_size]
    actions = torch.stack(buffer.actions, dim=0)  # [horizon, batch_size, 2, 1]
    log_likelihood = logps.gather(-1, actions).squeeze(-1)  # [horizon, batch_size, 2]
    log_likelihood = log_likelihood.mean(2).unsqueeze(2)  # [horizon, batch_size, 1]
    entropies = log_p_to_entropy(logps).mean(2).unsqueeze(2)  # [horizon, batch_size, 1]

    p_loss = (-log_likelihood * advantages).mean()
    v_loss = args.value_beta * (returns - values).pow(2).mean()
    e_loss = (0.9 ** (epoch + 1)) * args.entropy_beta * entropies.sum(0).mean()
    r_loss = -e_loss + v_loss
    loss = p_loss + r_loss

    optimizer.zero_grad()
    scaler.scale(p_loss).backward(retain_graph=True)
    # scaler.unscale_(optimizer)
    grad_norms = clip_grad_norms(optimizer.param_groups)  # , args.max_grad_norm)
    scaler.scale(r_loss).backward(retain_graph=False)
    scaler.step(optimizer)
    scaler.update()

    buffer.clear_buffer()

    log_values(
        cost=state.best_tour_len,
        grad_norms=grad_norms,
        done=done,
        epoch=epoch,
        global_step=global_step,
        learn_count=learn_count,
        p_loss=p_loss,
        v_loss=v_loss,
        e_loss=e_loss,
        loss=loss,
        returns=returns.mean(),
        value=values.mean(),
        entropy=entropies.detach().mean(),
        logger=logger,
        args=args,
    )
    learn_count += 1
    return learn_count
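# `discounted_return` is assumed to be a project helper; a minimal sketch that
# matches its use above (discounted sum of future rewards along the horizon
# dimension of a [horizon, batch_size, 1] tensor) might be:
def discounted_return(rewards: torch.Tensor, gamma: float, horizon: int) -> torch.Tensor:
    returns = torch.zeros_like(rewards)
    running = torch.zeros_like(rewards[0])
    for t in reversed(range(horizon)):
        running = rewards[t] + gamma * running  # R_t = r_t + gamma * R_{t+1}
        returns[t] = running
    return returns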
def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: Callable,
    optimizer: optim.Optimizer,
    device: torch.device,
    train_eval_freq: int = 50,
    clip_grad_norm: float = 1.0,
    verbose: bool = True,
) -> DefaultDict[str, List[float]]:
    """
    Training loop on one epoch.

    :param nn.Module model: PyTorch Neural Network
    :param DataLoader dataloader: PyTorch DataLoader
    :param Callable criterion: PyTorch Criterion
    :param optim.Optimizer optimizer: PyTorch Optimizer
    :param torch.device device: PyTorch Device
    :param int train_eval_freq: evaluation frequency (number of batches) (default: 50)
    :param float clip_grad_norm: max_norm parameter in clip_grad_norm (default: 1.0)
    :param bool verbose: verbose (default: True)
    :return: metrics dict
    :rtype: DefaultDict[str, List[float]]
    """
    metrics: DefaultDict[str, List[float]] = defaultdict(list)

    char2idx = dataloader.dataset.char2idx
    # BOS and EOS
    bos_id = char2idx[BOS]
    eos_id = char2idx[EOS]

    if verbose:
        dataloader = tqdm(dataloader, desc="iter dataloader")

    model.train()
    for i, sentence in enumerate(dataloader):
        sentence = sentence.to(device)

        # lengths and mask
        targets = sentence[:, 1:]  # clip left
        lengths = infer_lengths(sentence, bos_id=bos_id, eos_id=eos_id)
        mask = masking(lengths + 1)  # incl. EOS

        # forward pass
        outputs = model(
            sentence[:, :-1],  # clip right
            lengths + 1,  # incl. BOS
        )
        loss_matrix = criterion(
            input=outputs.transpose(1, 2),
            target=targets,
        )
        loss = (loss_matrix * mask).sum() / mask.sum()

        # backward pass
        loss.backward()

        # clip grad norm
        grad_norm = nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=clip_grad_norm,
        )

        # optimizer step
        optimizer.step()
        optimizer.zero_grad()

        # calculate metrics
        metrics["loss"].append(loss.item())
        metrics["grad_norm"].append(grad_norm.item())

        if verbose:
            if i % train_eval_freq == 0:
                generated_sequence = generate(
                    model=model,
                    char2idx=char2idx,
                    prefix="",
                    temperature=0.5,  # hardcoded
                    max_length=100,  # hardcoded
                )
                model.train()  # generate() leaves the model in eval mode; switch back to train

                for metric_name, metric_list in metrics.items():
                    print(
                        f"{metric_name}: {np.mean(metric_list[-train_eval_freq:])}"
                    )
                print(f"inference: {generated_sequence}\n")

    return metrics
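# Sketches of the `infer_lengths` / `masking` helpers the loop relies on, under
# the assumption that batches look like [BOS, t_1 .. t_L, EOS, PAD ...] and
# `lengths` counts only the t_i tokens. Illustrative, not the project's code.
def masking_sketch(lengths: torch.Tensor) -> torch.Tensor:
    """lengths: [batch] -> float mask [batch, max(lengths)] with 1.0 on valid positions."""
    positions = torch.arange(int(lengths.max()), device=lengths.device).unsqueeze(0)
    return (positions < lengths.unsqueeze(1)).float()

def infer_lengths_sketch(sentence: torch.Tensor, bos_id: int, eos_id: int) -> torch.Tensor:
    """Number of tokens strictly between BOS (index 0) and the first EOS, per row."""
    return (sentence == eos_id).int().argmax(dim=1) - 1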
def train(loader: DataLoader, model: torch.nn.Module, criterion, optimizer: Optimizer,
          epoch: int, noise_sd: float, attacker: Attacker, device: torch.device, writer=None):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_reg = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()
    requires_grad_(model, True)

    for i, batch in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        mini_batches = _chunk_minibatch(batch, args.num_noise_vec)
        for inputs, targets in mini_batches:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size = inputs.size(0)

            noises = [
                torch.randn_like(inputs, device=device) * noise_sd
                for _ in range(args.num_noise_vec)
            ]

            if args.adv_training:
                requires_grad_(model, False)
                model.eval()
                inputs = attacker.attack(model, inputs, targets, noises=noises)
                model.train()
                requires_grad_(model, True)

            # augment inputs with noise
            inputs_c = torch.cat([inputs + noise for noise in noises], dim=0)
            targets_c = targets.repeat(args.num_noise_vec)

            logits = model(inputs_c)
            loss_xent = criterion(logits, targets_c)

            logits_chunk = torch.chunk(logits, args.num_noise_vec, dim=0)
            loss_con = consistency_loss(logits_chunk, args.lbd, args.eta)
            loss = loss_xent + loss_con

            acc1, acc5 = accuracy(logits, targets_c, topk=(1, 5))
            losses.update(loss_xent.item(), batch_size)
            losses_reg.update(loss_con.item(), batch_size)
            top1.update(acc1.item(), batch_size)
            top5.update(acc5.item(), batch_size)

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.avg:.3f}\t'
                  'Data {data_time.avg:.3f}\t'
                  'Loss {loss.avg:.4f}\t'
                  'Acc@1 {top1.avg:.3f}\t'
                  'Acc@5 {top5.avg:.3f}'.format(epoch, i, len(loader),
                                                batch_time=batch_time,
                                                data_time=data_time,
                                                loss=losses,
                                                top1=top1,
                                                top5=top5))

    # `writer` defaults to None, so guard the epoch-level logging
    if writer is not None:
        writer.add_scalar('loss/train', losses.avg, epoch)
        writer.add_scalar('loss/consistency', losses_reg.avg, epoch)
        writer.add_scalar('batch_time', batch_time.avg, epoch)
        writer.add_scalar('accuracy/train@1', top1.avg, epoch)
        writer.add_scalar('accuracy/train@5', top5.avg, epoch)

    return (losses.avg, top1.avg)
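# Hedged sketch of a consistency regularizer in the spirit of the
# `consistency_loss(logits_chunk, lbd, eta)` call above: pull each noisy copy's
# predictive distribution towards their mean, plus an entropy term on the mean.
# The exact form in the source may differ; `lbd` and `eta` weight the two terms here.
import torch.nn.functional as F  # also used elsewhere in this file

def consistency_loss_sketch(logits_chunk, lbd: float, eta: float) -> torch.Tensor:
    probs = [F.softmax(logits, dim=1) for logits in logits_chunk]
    mean_p = torch.stack(probs, dim=0).mean(dim=0).clamp(min=1e-8)
    # KL(p_i || mean) averaged over the noise copies
    kl = sum(F.kl_div(mean_p.log(), p, reduction='batchmean') for p in probs) / len(probs)
    # Entropy of the mean prediction
    entropy = -(mean_p * mean_p.log()).sum(dim=1).mean()
    return lbd * kl + eta * entropy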
def proto_net_episode(model: Module,
                      optimiser: Optimizer,
                      loss_fn: Callable,
                      x: torch.Tensor,
                      y: torch.Tensor,
                      n_shot: int,
                      k_way: int,
                      q_queries: int,
                      distance: str,
                      train: bool):
    """Performs a single training episode for a Prototypical Network.

    # Arguments
        model: Prototypical Network to be trained.
        optimiser: Optimiser to calculate gradient step
        loss_fn: Loss function to calculate between predictions and outputs. Should be cross-entropy
        x: Input samples of few shot classification task
        y: Input labels of few shot classification task
        n_shot: Number of examples per class in the support set
        k_way: Number of classes in the few shot classification task
        q_queries: Number of examples per class in the query set
        distance: Distance metric to use when calculating distance between class prototypes and queries
        train: Whether (True) or not (False) to perform a parameter update

    # Returns
        loss: Loss of the Prototypical Network on this task
        y_pred: Predicted class probabilities for the query set on this task
    """
    if train:
        # Zero gradients
        model.train()
        optimiser.zero_grad()
    else:
        model.eval()

    # Embed all samples
    embeddings = model(x)

    # Samples are ordered by the NShotWrapper class as follows:
    # k lots of n support samples from a particular class
    # k lots of q query samples from those classes
    support = embeddings[:n_shot * k_way]  # shape: (n_shot * k_way, embedding_dim)
    queries = embeddings[n_shot * k_way:]  # shape: (q_queries * k_way, embedding_dim)
    prototypes = compute_prototypes(support, k_way, n_shot)

    # Calculate squared distances between all queries and all prototypes
    # Output should have shape (q_queries * k_way, k_way) = (num_queries, k_way)
    distances = pairwise_distances(queries, prototypes, distance)

    # Calculate log p_{phi} (y = k | x)
    log_p_y = (-distances).log_softmax(dim=1)
    loss = loss_fn(log_p_y, y)

    # Prediction probabilities are softmax over distances
    y_pred = (-distances).softmax(dim=1)

    if train:
        # Take gradient step
        loss.backward()
        optimiser.step()

    return loss, y_pred
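# `compute_prototypes` above reduces each class's support embeddings to their
# mean, relying on the same ordering (k contiguous blocks of n samples).
# Minimal sketch, illustrative only.
def compute_prototypes_sketch(support: torch.Tensor, k: int, n: int) -> torch.Tensor:
    """support: [k * n, d] class-ordered embeddings -> prototypes: [k, d]"""
    return support.reshape(k, n, -1).mean(dim=1)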
def train_stego(*, stegoanalyser: nn.Module, train_iterator: DataBatchIterator,
                val_iterator: DataBatchIterator, text_iterator: Iterator, n_epoch: int,
                stegoanalyser_opt: Optimizer, callbacks: Sequence[Callable] = None,
                logger: TBLogger, encoder: SigmoidTorchEncoder):
    criterion = F.binary_cross_entropy_with_logits
    callbacks = callbacks or []

    for epoch in tqdm(range(n_epoch)):
        stegoanalyser_losses = []

        with train_iterator as iterator:
            for real_batch, _ in iterator:
                batch_size = len(real_batch)
                # Randomly choose which images will carry a hidden message (label 1)
                labels = np.random.choice([0, 1], (batch_size, 1, 1, 1))
                encoded_images = []
                for image, label in zip(real_batch, labels):
                    if label == 1:
                        msg = bytes_to_bits(next(text_iterator))
                        key = generate_random_key(image.shape[1:], len(msg))
                        image = encoder.encode(transform_encoder(image), msg, key)
                        image = inverse_transform_encoder(image)
                    encoded_images.append(image)

                encoded_images = torch.stack(encoded_images)
                labels = torch.from_numpy(labels).float()

                # train stegoanalyser
                stegoanalyser_opt.zero_grad()
                stegoanalyser_losses.append(
                    process_batch(encoded_images.detach(), labels, stegoanalyser, criterion))
                stegoanalyser_opt.step()

        with val_iterator as iterator:
            accuracy = []
            for real_batch, _ in iterator:
                batch_size = len(real_batch)
                labels = np.random.choice([0, 1], batch_size)
                encoded_images = []
                for image, label in zip(real_batch, labels):
                    if label == 1:
                        msg = bytes_to_bits(next(text_iterator))
                        key = generate_random_key(image.shape[1:], len(msg))
                        image = encoder.encode(transform_encoder(image), msg, key)
                        image = inverse_transform_encoder(image)
                    encoded_images.append(image)

                encoded_images = torch.stack(encoded_images)

                # evaluate stegoanalyser
                out = inference_step(encoded_images, stegoanalyser).cpu().detach()
                out = torch.sigmoid(out) > 0.5
                out = out.reshape(len(encoded_images)).numpy()
                accuracy_score = sklearn.metrics.accuracy_score(labels, out)
                accuracy.append(accuracy_score)

        mean_accuracy = np.mean(accuracy)
        print(f'validation accuracy score {mean_accuracy}')

        losses = {'Stegoanalyser loss': np.mean(stegoanalyser_losses),
                  'Val accuracy': mean_accuracy}
        logger.policies(losses, epoch)

        # run callbacks
        for callback in callbacks:
            callback(epoch)
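# Hypothetical sketch of the `bytes_to_bits` helper used when embedding a
# message: unpack each byte into its 8 bits, most significant bit first. The
# real encoder may use a different bit order; this is an assumption.
def bytes_to_bits_sketch(data: bytes) -> List[int]:
    return [int(bit) for byte in data for bit in format(byte, '08b')]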
def train(loader: DataLoader, model: torch.nn.Module, criterion, optimizer: Optimizer,
          epoch: int, noise_sd: float, attacker: Attacker = None):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()
    requires_grad_(model, True)

    for i, batch in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        mini_batches = get_minibatches(batch, args.num_noise_vec)
        noisy_inputs_list = []
        for inputs, targets in mini_batches:
            inputs = inputs.cuda()
            targets = targets.cuda()

            # Repeat each image num_noise_vec times (consecutively) so every
            # copy receives an independent noise draw below
            inputs = inputs.repeat((1, args.num_noise_vec, 1, 1)).view(batch[0].shape)

            # augment inputs with noise
            noise = torch.randn_like(inputs, device='cuda') * noise_sd

            if args.adv_training:
                requires_grad_(model, False)
                model.eval()
                inputs = attacker.attack(model, inputs, targets, noise=noise,
                                         num_noise_vectors=args.num_noise_vec,
                                         no_grad=args.no_grad_attack)
                model.train()
                requires_grad_(model, True)

            if args.train_multi_noise:
                noisy_inputs = inputs + noise
                targets = targets.unsqueeze(1).repeat(1, args.num_noise_vec).reshape(-1, 1).squeeze()
                outputs = model(noisy_inputs)
                loss = criterion(outputs, targets)

                acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
                losses.update(loss.item(), noisy_inputs.size(0))
                top1.update(acc1.item(), noisy_inputs.size(0))
                top5.update(acc5.item(), noisy_inputs.size(0))

                # compute gradient and do SGD step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            else:
                # keep one copy (and its noise draw) per original image
                inputs = inputs[::args.num_noise_vec]
                noise = noise[::args.num_noise_vec]
                noisy_inputs_list.append(inputs + noise)

        if not args.train_multi_noise:
            noisy_inputs = torch.cat(noisy_inputs_list)
            targets = batch[1].cuda()
            assert len(targets) == len(noisy_inputs)

            outputs = model(noisy_inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), noisy_inputs.size(0))
            top1.update(acc1.item(), noisy_inputs.size(0))
            top5.update(acc5.item(), noisy_inputs.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))

    return (losses.avg, top1.avg)
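# Both noisy-training loops assume a chunking helper (`get_minibatches` /
# `_chunk_minibatch`) that splits one (inputs, targets) batch into
# num_noise_vec consecutive sub-batches, one per noise draw. Illustrative
# sketch, assuming the batch size divides evenly:
def get_minibatches_sketch(batch, num_batches: int):
    inputs, targets = batch
    chunk = len(inputs) // num_batches
    for i in range(num_batches):
        yield inputs[i * chunk:(i + 1) * chunk], targets[i * chunk:(i + 1) * chunk]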
def torch_single_train(model: PyTorchForecast,
                       opt: optim.Optimizer,
                       criterion: Type[torch.nn.modules.loss._Loss],
                       data_loader: DataLoader,
                       takes_target: bool,
                       meta_data_model: PyTorchForecast,
                       meta_data_model_representation: torch.Tensor,
                       meta_loss=None,
                       multi_targets=1,
                       forward_params: Dict = None) -> float:
    # Avoid a mutable default argument: `forward_params` is mutated below, so a
    # shared default dict would leak state across calls.
    if forward_params is None:
        forward_params = {}
    probabilistic = None
    if "probabilistic" in model.params["model_params"]:
        probabilistic = True
    print('running torch_single_train')
    i = 0
    output_std = None
    running_loss = 0.0
    for src, trg in data_loader:
        opt.zero_grad()
        # Move data to the model's device (CPU/GPU/TPU)
        src = src.to(model.device)
        trg = trg.to(model.device)
        if meta_data_model:
            representation = meta_data_model.model.generate_representation(
                meta_data_model_representation)
            forward_params["meta_data"] = representation
            if meta_loss:
                output = meta_data_model.model(meta_data_model_representation)
                met_loss = compute_loss(meta_data_model_representation, output,
                                        torch.rand(2, 3, 2), meta_loss, None)
                met_loss.backward()
        if takes_target:
            forward_params["t"] = trg
        output = model.model(src, **forward_params)
        if multi_targets == 1:
            labels = trg[:, :, 0]
        elif multi_targets > 1:
            labels = trg[:, :, 0:multi_targets]
        if probabilistic:
            output1 = output
            output = output.mean
            output_std = output1.stddev
        loss = compute_loss(labels, output, src, criterion, None, probabilistic,
                            output_std, m=multi_targets)
        # Validate the loss before stepping so a bad update is never applied
        if torch.isnan(loss) or torch.isinf(loss):
            raise ValueError(
                "Error: infinite or NaN loss detected. Try normalizing data or performing interpolation"
            )
        if loss > 100:
            print("Warning: high loss detected")
        loss.backward()
        opt.step()
        running_loss += loss.item()
        i += 1
    print(f"The running loss is: {running_loss}")
    print(f"The number of items in train is: {i}")
    total_loss = running_loss / float(i)
    return total_loss
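# Reduced sketch of the non-probabilistic path through the `compute_loss` call
# above, keeping only the arguments that matter for it; the real helper also
# handles probabilistic outputs, scaling and the extra positional arguments.
# Treat every name here as an assumption.
def compute_loss_sketch(labels: torch.Tensor, output: torch.Tensor, criterion, m: int = 1) -> torch.Tensor:
    # Align shapes for multi-target forecasting before applying the criterion
    if m > 1 and output.shape != labels.shape:
        output = output[..., :m]
    return criterion(output, labels.float())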
def train_model( train_ds: tf.data.Dataset, dev_ds: tf.data.Dataset, model: nn.Module, optimizer: optim.Optimizer, lr_scheduler: optim.lr_scheduler._LRScheduler, args: argparse.Namespace, ) -> nn.Module: device = model_utils.get_device() loss_fn = model_utils.depth_proportional_loss val_loss_fn = model_utils.l1_norm_loss best_val_loss = torch.tensor(float('inf')) saved_checkpoints = [] writer = SummaryWriter(log_dir=f'{args.log_dir}/{args.experiment}') cos = nn.CosineSimilarity(dim=1, eps=0) get_gradient: nn.Module = sobel.Sobel().to(device) for e in range(1, args.train_epochs + 1): print(f'Training epoch {e}...') if args.use_scheduler: lr_scheduler.step() # Training portion torch.cuda.empty_cache() torch.set_grad_enabled(True) with tqdm(total=args.train_batch_size * len(train_ds)) as progress_bar: model.train() for i, (x_batch_orig, y_batch) in enumerate(train_ds.as_numpy_iterator()): x_batch, y_batch = model_utils.preprocess_training_example( x_batch_orig, y_batch) y_blurred = model_utils.blur_depth_map(y_batch) ones = torch.ones(y_batch.shape, dtype=torch.float32, device=device) # Forward pass on model optimizer.zero_grad() y_pred = model(x_batch) depth_grad = get_gradient(y_blurred) output_grad = get_gradient(y_pred) depth_grad_dx = depth_grad[:, 0, :, :].contiguous().view_as( y_blurred) depth_grad_dy = depth_grad[:, 1, :, :].contiguous().view_as( y_batch) output_grad_dx = output_grad[:, 0, :, :].contiguous().view_as( y_blurred) output_grad_dy = output_grad[:, 1, :, :].contiguous().view_as( y_batch) depth_normal = torch.cat( (-depth_grad_dx, -depth_grad_dy, ones), 1) output_normal = torch.cat( (-output_grad_dx, -output_grad_dy, ones), 1) loss_depth = torch.log(torch.abs(y_pred - y_batch) + 0.5).mean() loss_dx = torch.log( torch.abs(output_grad_dx - depth_grad_dx) + 0.5).mean() loss_dy = torch.log( torch.abs(output_grad_dy - depth_grad_dy) + 0.5).mean() loss_normal = torch.abs( 1 - cos(output_normal, depth_normal)).mean() loss = loss_depth + loss_normal + (loss_dx + loss_dy) # Backward pass and optimization loss.backward() optimizer.step() progress_bar.update(len(x_batch)) progress_bar.set_postfix(loss=loss.item()) writer.add_scalar("train/Loss", loss, ((e - 1) * len(train_ds) + i) * args.train_batch_size) # Periodically save a diagram if (i + 1) % args.picture_frequency == 0: model_utils.make_diagram( np.transpose(x_batch_orig, (0, 3, 1, 2)), x_batch.cpu().numpy(), y_batch.cpu().numpy(), y_pred.cpu().detach().numpy(), f'{args.save_path}/{args.experiment}/diagram_{e}_{i+1}.png', ) del x_batch del y_batch del y_blurred del y_pred del loss # Validation portion torch.cuda.empty_cache() torch.set_grad_enabled(False) with tqdm(total=args.dev_batch_size * len(dev_ds)) as progress_bar: model.eval() val_loss = 0.0 num_batches_processed = 0 total_pixels = 0 total_examples = 0 squared_error = 0 rel_error = 0 log_error = 0 threshold1 = 0 # 1.25 threshold2 = 0 # 1.25^2 threshold3 = 0 # corresponds to 1.25^3 for i, (x_batch, y_batch) in enumerate(dev_ds.as_numpy_iterator()): x_batch, y_batch = model_utils.preprocess_test_example( x_batch, y_batch) # Forward pass on model in validation environment y_pred = model(x_batch) # TODO: Process y_pred in whatever way inference requires. 
                loss = val_loss_fn(y_pred, y_batch)
                val_loss += loss.item()
                num_batches_processed += 1

                nanmask = getNanMask(y_batch)
                total_pixels = torch.sum(~nanmask).item()  # valid (non-NaN) pixels; recomputed per batch
                total_examples += x_batch.shape[0]

                # RMS, REL, LOG10, threshold calculation
                squared_error += (torch.sum(torch.pow(y_pred - y_batch, 2)).item()
                                  / total_pixels)**0.5
                rel_error += torch.sum(
                    removeNans(torch.abs(y_pred - y_batch) / y_batch)).item() / total_pixels
                log_error += torch.sum(
                    torch.abs(
                        removeNans(torch.log10(y_pred)) -
                        removeNans(torch.log10(y_batch)))).item() / total_pixels
                threshold1 += torch.sum(
                    torch.max(y_pred / y_batch,
                              y_batch / y_pred) < 1.25).item() / total_pixels
                threshold2 += torch.sum(
                    torch.max(y_pred / y_batch,
                              y_batch / y_pred) < 1.25**2).item() / total_pixels
                threshold3 += torch.sum(
                    torch.max(y_pred / y_batch,
                              y_batch / y_pred) < 1.25**3).item() / total_pixels

                progress_bar.update(len(x_batch))
                progress_bar.set_postfix(val_loss=val_loss / num_batches_processed)
                writer.add_scalar("Val/Loss", loss,
                                  ((e - 1) * len(dev_ds) + i) * args.dev_batch_size)

                del x_batch
                del y_batch
                del y_pred
                del loss

        writer.add_scalar("Val/RMS", squared_error / total_examples, e)
        writer.add_scalar("Val/REL", rel_error / total_examples, e)
        writer.add_scalar("Val/LOG10", log_error / total_examples, e)
        writer.add_scalar("Val/delta1", threshold1 / total_examples, e)
        writer.add_scalar("Val/delta2", threshold2 / total_examples, e)
        writer.add_scalar("Val/delta3", threshold3 / total_examples, e)

        # Save model if it's the best one yet.
        if val_loss / num_batches_processed < best_val_loss:
            best_val_loss = val_loss / num_batches_processed
            filename = f'{args.save_path}/{args.experiment}/{model.__class__.__name__}_best_val.checkpoint'
            model_utils.save_model(model, filename)
            print('Model saved!')
            print(f'Best validation loss yet: {best_val_loss}')

        # Save model on checkpoints.
        if e % args.checkpoint_freq == 0:
            filename = f'{args.save_path}/{args.experiment}/{model.__class__.__name__}_epoch_{e}.checkpoint'
            model_utils.save_model(model, filename)
            print('Model checkpoint reached!')
            saved_checkpoints.append(filename)
            # Delete checkpoints if there are too many
            while len(saved_checkpoints) > args.num_checkpoints:
                os.remove(saved_checkpoints.pop(0))

    return model
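# Plausible sketches of the `getNanMask` / `removeNans` helpers the metric
# block above assumes: mark NaN ground-truth pixels and zero NaNs out before
# reductions. Hypothetical implementations for illustration.
def getNanMask_sketch(t: torch.Tensor) -> torch.Tensor:
    return torch.isnan(t)

def removeNans_sketch(t: torch.Tensor) -> torch.Tensor:
    return torch.where(torch.isnan(t), torch.zeros_like(t), t)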