def trainer(self, train_loader, criterion, optimizer, epoch, cum_epochs, update):
    """Run one training epoch with a triplet objective.

    Each batch yields an (anchor, positive) pair plus a negative sample,
    together with their dataset ids. When ``update`` is truthy, a per-batch
    score from ``update_cos`` is fed back into the dataset's sampling state,
    and the dataset's probability matrix is logged and reset at the end of
    the epoch.

    Args:
        train_loader: loader yielding ((ida, xa, xp), (idn, xn)) tuples.
        criterion: triplet loss taking (anchor, positive, negative) embeddings.
        optimizer: optimizer stepping this model's parameters.
        epoch: index of the current epoch within this run.
        cum_epochs: epochs completed in previous runs (for log numbering).
        update: if truthy, update the dataset's sampling probabilities.
    """
    # NOTE(review): source formatting was collapsed; the placement of the
    # per-batch `if update:` inside the loop (it uses per-batch outputs)
    # and the final `if update:` after the loop (it resets the dataset)
    # is inferred from the variables each branch touches — confirm.
    logging.info('\n' + '-' * 200 + '\n' + '\t' * 10 + 'TRAINING\n')
    losses = AverageMeter()
    self.train()
    for i, ((ida, xa, xp), (idn, xn)) in enumerate(train_loader):
        # Insert a channel dimension: (B, L) -> (B, 1, L).
        xa, xp, xn = xa.unsqueeze(1), xp.unsqueeze(1), xn.unsqueeze(1)
        xa, xp, xn = xa.to(self.device), xp.to(self.device), xn.to(
            self.device)
        outa, outp, outn = self.forward(xa, xp, xn)
        optimizer.zero_grad()
        loss = criterion(outa, outp, outn)
        losses.update(loss.item(), xa.size(0))
        loss.backward()
        optimizer.step()
        if (i + 1) % self.show_every == 0:
            logging.info('\tEpoch: [{0}][{1}/{2}]\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                             cum_epochs + epoch + 1, i + 1,
                             len(train_loader), loss=losses))
        if update:
            # Score anchors/negatives against a widened margin (1.5x) and
            # push the result into the dataset's sampling machinery.
            score_update = update_cos(outa, outp, outn, 1.5 * self.margin)
            train_loader.dataset.update(score_update, ida, idn)
    if update:
        # Log the accumulated sampling probabilities, then reset for the
        # next epoch.
        logging.info('\n{}'.format(
            np.array_str(train_loader.dataset.probability_matrix,
                         precision=2, suppress_small=True)))
        train_loader.dataset.reset()
    logging.info('\n' + '-' * 200)
def trainer(self, train_loader, criterion, optimizer, epoch, cum_epochs):
    """Run one training epoch on paired inputs.

    Batches yield (x1, x2, label); both inputs get a channel dimension
    and are pushed through the twin forward pass, and the pairwise
    criterion is applied to the squeezed outputs. When the model does
    not normalize, labels are remapped from {0, 1} to {-1, +1} first.

    Args:
        train_loader: loader yielding (x1, x2, label) batches.
        criterion: pairwise loss taking (out1, out2, label).
        optimizer: optimizer stepping this model's parameters.
        epoch: index of the current epoch within this run.
        cum_epochs: epochs completed in previous runs (for log numbering).
    """
    logging.info('\n' + '-' * 200 + '\n' + '\t' * 10 + 'TRAINING\n')
    loss_meter = AverageMeter()
    self.train()
    for batch_idx, (left, right, label) in enumerate(train_loader):
        left = left.unsqueeze(1)
        right = right.unsqueeze(1)
        if not self.normalize:
            # 2*y - 1 maps label 0 -> -1 and label 1 -> +1.
            label = 2 * label - 1
        left = left.to(self.device)
        right = right.to(self.device)
        label = label.to(self.device)
        emb_left, emb_right = self.forward(left, right)
        optimizer.zero_grad()
        loss = criterion(emb_left.squeeze(), emb_right.squeeze(),
                         label.squeeze().float())
        loss_meter.update(loss.item(), left.size(0))
        loss.backward()
        optimizer.step()
        if (batch_idx + 1) % self.show_every == 0:
            logging.info('\tEpoch: [{0}][{1}/{2}]\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                             cum_epochs + epoch + 1, batch_idx + 1,
                             len(train_loader), loss=loss_meter))
    logging.info('\n' + '-' * 200)
def train_fn(train_loader, model, optimizer, criterion, scheduler, device,
             print_log: int = 10):
    """Train the model for one epoch on the plant/disease multi-label task.

    Plant and disease labels are merged into a single combined class id,
    the model is optimized against it with ``criterion``, and the hamming
    loss of the argmax predictions is tracked. The scheduler is stepped
    once at the end of the epoch.

    Args:
        train_loader: loader yielding (_, image, plant, disease) batches.
        model: classifier producing logits over the combined label space.
        optimizer: optimizer stepping the model's parameters.
        criterion: classification loss on (logits, combined_target).
        scheduler: LR scheduler, stepped once per epoch.
        device: device the images/targets are moved to.
        print_log: print running stats every ``print_log`` batches.
    """
    losses = AverageMeter()
    hammings = AverageMeter()
    model.train()
    for i, (_, image, plant, disease) in enumerate(train_loader):
        image = image.to(device)
        combined = multi_label_tensors_to_single_label_tensor(plant, disease)
        combined = combined.to(device)
        optimizer.zero_grad()
        outputs = model(image)
        loss = criterion(outputs, combined)
        loss.backward()
        optimizer.step()
        _, preds = torch.max(outputs, 1)
        acc = cal_hamming_loss(plant, disease, preds)
        if (i + 1) % print_log == 0:
            print(f' Train - loss: {loss.item():.4f} hamming loss: {acc}')
        # FIX: store Python floats in the meters, not tensors. Storing the
        # live loss tensor keeps the whole autograd graph alive per batch
        # (memory leak); the sibling train()/evaluate() loops already use
        # .item() — this makes train_fn consistent with them.
        losses.update(loss.item(), image.size(0))
        hammings.update(acc.item(), image.size(0))
    scheduler.step()
def train(train_loader, model, optimizer, criterion, device, scheduler):
    """Run one training epoch with a tqdm progress bar.

    Plant and disease labels are fused into one combined class id, the
    model is trained against it, and running loss / hamming-loss averages
    are shown on the progress bar. The scheduler is stepped once at the
    end of the epoch.

    Args:
        train_loader: loader yielding (_, image, plant, disease) batches.
        model: classifier over the combined label space.
        optimizer: optimizer stepping the model's parameters.
        criterion: classification loss on (logits, combined_target).
        device: device the images/targets are moved to.
        scheduler: LR scheduler, stepped once per epoch.
    """
    model.train()
    loss_meter = AverageMeter()
    hamming_meter = AverageMeter()
    progress = tqdm(train_loader, total=len(train_loader))
    for _, batch_images, plant_labels, disease_labels in progress:
        batch_images = batch_images.to(device)
        targets = multi_label_tensors_to_single_label_tensor(
            plant_labels, disease_labels)
        targets = targets.to(device)
        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()
        predictions = torch.max(logits, 1)[1]
        batch_hamming = cal_hamming_loss(plant_labels, disease_labels,
                                         predictions)
        n = batch_images.size(0)
        loss_meter.update(batch_loss.item(), n)
        hamming_meter.update(batch_hamming.item(), n)
        progress.set_postfix(loss=loss_meter.avg, hamming=hamming_meter.avg)
    scheduler.step()
def validation_step(self, summary_dev):
    """Run a full pass over the validation loader.

    Computes the criterion loss per batch, and collects per-sample
    predictions: argmax class ids when ``hparams.multi_cls`` is set,
    sigmoid probabilities otherwise.

    Args:
        summary_dev: dict updated in place with the average loss
            under key ``'loss'``.

    Returns:
        Tuple of (summary_dev, output_, target_) where ``output_`` and
        ``target_`` are numpy arrays concatenated across all batches.
    """
    losses = AverageMeter()
    # NOTE: set_grad_enabled(False) persists beyond this call; the paired
    # training_step re-enables it. Kept for symmetry with that pattern.
    torch.set_grad_enabled(False)
    self.model.eval()
    output_ = np.array([])
    target_ = np.array([])
    with torch.no_grad():
        for i, (inputs, target, _) in enumerate(self.valid_loader):
            target = target.to(self.device)
            if isinstance(inputs, tuple):
                # Mixed tuples may carry non-tensor metadata; only move
                # the tensor elements to the device.
                inputs = tuple([
                    e.to(self.device) if type(e) == torch.Tensor else e
                    for e in inputs
                ])
            else:
                inputs = inputs.to(self.device)
            logits = self.forward(inputs)
            loss = self.criterion(logits, target)
            losses.update(loss.item(), target.size(0))
            if self.hparams.multi_cls:
                # FIX: pass dim explicitly — F.softmax without `dim` is
                # deprecated and relies on implicit dim selection. dim=1
                # is the class dimension for (batch, classes) logits.
                output = F.softmax(logits, dim=1)
                _, output = torch.max(output, 1)
            else:
                output = torch.sigmoid(logits)
            target = target.detach().to('cpu').numpy()
            # First batch initializes the array; later batches concatenate.
            target_ = np.concatenate(
                (target_, target), axis=0) if len(target_) > 0 else target
            y_pred = output.detach().to('cpu').numpy()
            output_ = np.concatenate(
                (output_, y_pred), axis=0) if len(output_) > 0 else y_pred
    summary_dev['loss'] = losses.avg
    return summary_dev, output_, target_
def evaluate(self, loader):
    """Evaluate reconstruction quality over ``loader``.

    For each batch, runs the model via ``set_input``/``forward`` and,
    when the L1 reconstruction weight ``opts.wr_L1`` is positive,
    accumulates PSNR/SSIM between the reconstruction and the full target
    image, collecting reconstructions, ground truths and subsampled
    inputs for later inspection.

    Side effects:
        Sets ``self.psnr_recon``, ``self.ssim_recon`` to the epoch
        averages, and fills ``self.results`` with stacked 'recon', 'gt'
        and 'input' arrays (only when ``opts.wr_L1 > 0``).
    """
    val_bar = tqdm(loader)
    avg_psnr = AverageMeter()
    avg_ssim = AverageMeter()
    recon_images = []
    gt_images = []
    input_images = []
    for data in val_bar:
        self.set_input(data)
        self.forward()
        if self.opts.wr_L1 > 0:
            # PSNR on magnitude images (complex data reduced via
            # complex_abs_eval).
            psnr_recon = psnr(complex_abs_eval(self.recon),
                              complex_abs_eval(self.tag_image_full))
            avg_psnr.update(psnr_recon)
            # SSIM on the first sample/channel only — presumably batch
            # size 1 during evaluation; TODO confirm.
            ssim_recon = ssim(
                complex_abs_eval(self.recon)[0, 0, :, :].cpu().numpy(),
                complex_abs_eval(self.tag_image_full)[0, 0, :, :].cpu().numpy())
            avg_ssim.update(ssim_recon)
            recon_images.append(self.recon[0].cpu())
            gt_images.append(self.tag_image_full[0].cpu())
            input_images.append(self.tag_image_sub[0].cpu())
        # Show running averages on the progress bar.
        message = 'PSNR: {:4f} '.format(avg_psnr.avg)
        message += 'SSIM: {:4f} '.format(avg_ssim.avg)
        val_bar.set_description(desc=message)
    self.psnr_recon = avg_psnr.avg
    self.ssim_recon = avg_ssim.avg
    self.results = {}
    if self.opts.wr_L1 > 0:
        self.results['recon'] = torch.stack(recon_images).squeeze().numpy()
        self.results['gt'] = torch.stack(gt_images).squeeze().numpy()
        self.results['input'] = torch.stack(input_images).squeeze().numpy()
def valid_fn(data_loader, model, optimizer, criterion, scheduler, device):
    """Validate the model on the plant/disease multi-label task.

    Runs the model without gradients, tracking the criterion loss and
    the hamming loss of the argmax predictions, and prints the epoch
    averages.

    Args:
        data_loader: loader yielding (_, image, plant, disease) batches.
        model: classifier over the combined label space.
        optimizer: unused; kept for signature compatibility with callers.
        criterion: classification loss on (logits, combined_target).
        scheduler: unused; kept for signature compatibility with callers.
        device: device the images/targets are moved to.

    Returns:
        Average validation loss as a float.
    """
    with torch.no_grad():
        model.eval()
        losses = AverageMeter()
        hammings = AverageMeter()
        for _, image, plant, disease in data_loader:
            image = image.to(device)
            combined = multi_label_tensors_to_single_label_tensor(
                plant, disease)
            combined = combined.to(device)
            outputs = model(image)
            loss = criterion(outputs, combined)
            _, preds = torch.max(outputs, 1)
            acc = cal_hamming_loss(plant, disease, preds)
            # FIX: store Python floats, not tensors — storing the tensor
            # keeps references alive and makes the returned losses.avg a
            # tensor. The sibling evaluate() loop already uses .item().
            losses.update(loss.item(), image.size(0))
            hammings.update(acc.item(), image.size(0))
    print(f' Valid - loss: {losses.avg:.4f} hamming loss: {hammings.avg}')
    return losses.avg
def evaluate(data_loader, model, criterion, device):
    """Evaluate the model and return the average hamming loss.

    Runs without gradients over ``data_loader``, fusing plant/disease
    labels into a single combined class id, and shows running loss /
    hamming-loss averages on a tqdm progress bar.

    Args:
        data_loader: loader yielding (_, image, plant, disease) batches.
        model: classifier over the combined label space.
        criterion: classification loss on (logits, combined_target).
        device: device the images/targets are moved to.

    Returns:
        Average hamming loss across the epoch.
    """
    model.eval()
    loss_meter = AverageMeter()
    hamming_meter = AverageMeter()
    progress = tqdm(data_loader, total=len(data_loader))
    with torch.no_grad():
        for _, batch_images, plant_labels, disease_labels in progress:
            batch_images = batch_images.to(device)
            targets = multi_label_tensors_to_single_label_tensor(
                plant_labels, disease_labels)
            targets = targets.to(device)
            logits = model(batch_images)
            batch_loss = criterion(logits, targets)
            predictions = torch.max(logits, 1)[1]
            batch_hamming = cal_hamming_loss(plant_labels, disease_labels,
                                             predictions)
            n = batch_images.size(0)
            loss_meter.update(batch_loss.item(), n)
            hamming_meter.update(batch_hamming.item(), n)
            progress.set_postfix(loss=loss_meter.avg,
                                 hamming=hamming_meter.avg)
    return hamming_meter.avg
def evaluate(self, loader):
    """Evaluate translated images against ground truth over ``loader``.

    For each batch, runs the forward pass on ``self.IH`` and accumulates
    PSNR, MSE and SSIM between the generated ``IT_fake`` and the target
    ``IT`` (both shifted by +1 before comparison), collecting predictions,
    targets and inputs for later inspection.

    Side effects:
        Sets ``self.psnr``, ``self.ssim``, ``self.mse`` to the epoch
        averages and fills ``self.results`` with stacked 'pred_IT',
        'gt_IT' and 'gt_IH' arrays.
    """
    progress = tqdm(loader)
    psnr_meter = AverageMeter()
    ssim_meter = AverageMeter()
    mse_meter = AverageMeter()
    predictions = []
    targets = []
    inputs = []
    for batch in progress:
        self.set_input(batch)
        self.forward(self.IH)
        fake_shifted = self.IT_fake + 1
        real_shifted = self.IT + 1
        psnr_meter.update(psnr(fake_shifted, real_shifted))
        mse_meter.update(mse(fake_shifted, real_shifted))
        ssim_meter.update(
            ssim(self.IT_fake[0, 0, ...].cpu().numpy() + 1,
                 self.IT[0, 0, ...].cpu().numpy() + 1))
        predictions.append(self.IT_fake[0].cpu())
        targets.append(self.IT[0].cpu())
        inputs.append(self.IH[0].cpu())
        # Running averages shown on the progress bar.
        message = ('PSNR: {:4f} '.format(psnr_meter.avg)
                   + 'SSIM: {:4f} '.format(ssim_meter.avg)
                   + 'MSE: {:4f} '.format(mse_meter.avg))
        progress.set_description(desc=message)
    self.psnr = psnr_meter.avg
    self.ssim = ssim_meter.avg
    self.mse = mse_meter.avg
    self.results = {}
    self.results['pred_IT'] = torch.stack(predictions).squeeze().numpy()
    self.results['gt_IT'] = torch.stack(targets).squeeze().numpy()
    self.results['gt_IH'] = torch.stack(inputs).squeeze().numpy()
def val_epoch(epoch, data_loader, model, criterion, opt, logger):
    """Validate for one epoch and return the average loss.

    Args:
        epoch: epoch number, used for logging only.
        data_loader: loader yielding (inputs, targets, video_ids) batches.
        model: video classifier (wrapped in DataParallel — uses .module).
        criterion: classification loss on (outputs, targets).
        opt: options namespace; uses ``opt.save_features``.
        logger: epoch logger with a ``log(dict)`` method.

    Returns:
        Average validation loss across the epoch.
    """
    print('validation at epoch {}'.format(epoch))
    model.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()
    end_time = time.time()
    # FIX: Variable(..., volatile=True) has been removed from PyTorch;
    # torch.no_grad() is the supported way to disable autograd here.
    with torch.no_grad():
        for i, (inputs, targets, video_ids) in enumerate(data_loader):
            data_time.update(time.time() - end_time)
            # NOTE(review): the original left the targets-to-GPU transfer
            # commented out; behavior preserved — confirm targets are on
            # the same device as the model's outputs.
            if opt.save_features:
                # Tag the wrapped model so it can name the features it
                # saves after the sample's video id + label.
                model.module.label = video_ids[0] + str(targets.tolist()[0])
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            acc = calculate_accuracy(outputs, targets)
            # FIX: loss.data[0] is legacy indexing of a 0-dim tensor and
            # raises on modern PyTorch; .item() is the correct accessor.
            losses.update(loss.item(), inputs.size(0))
            accuracies.update(acc, inputs.size(0))
            batch_time.update(time.time() - end_time)
            end_time = time.time()
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(
                      epoch, i + 1, len(data_loader),
                      batch_time=batch_time, data_time=data_time,
                      loss=losses, acc=accuracies))
    logger.log({'epoch': epoch, 'loss': losses.avg, 'acc': accuracies.avg})
    return losses.avg
def validate(): bs = 256 # create model model = create_model('vit_base_patch16_224', pretrained=True, num_classes=1000) criterion = nn.CrossEntropyLoss() dataset = create_val_dataset(root='/data/imagenet', batch_size=bs, num_workers=4, img_size=224) batch_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() model.eval() with jt.no_grad(): input = jt.random((bs, 3, 224, 224)) model(input) end = time.time() for batch_idx, (input, target) in enumerate(dataset): # dataset.display_worker_status() batch_size = input.shape[0] # compute output output = model(input) loss = criterion(output, target) # measure accuracy and record loss acc1, acc5 = accuracy(output, target, topk=(1, 5)) losses.update(loss, batch_size) top1.update(acc1, batch_size) top5.update(acc5, batch_size) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if batch_idx % 10 == 0: # jt.sync_all(True) print( 'Test: [{0:>4d}/{1}] ' 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) ' 'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format( batch_idx, len(dataset), batch_time=batch_time, rate_avg=batch_size / batch_time.avg, loss=losses, top1=top1, top5=top5)) # if batch_idx>50:break top1a, top5a = top1.avg, top5.avg top1 = round(top1a, 4) top1_err = round(100 - top1a, 4) top5 = round(top5a, 4) top5_err = round(100 - top5a, 4) print(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format( top1, top1_err, top5, top5_err))
def training_step(self, summary_train, summary_dev, best_dict):
    """Train for one epoch, with optional n-crops, mixup and JSD branches.

    Two top-level modes, selected by ``cfg.no_jsd``:
      * no_jsd: a single forward pass, optionally reshaping n-crop inputs
        to (bs*n_crops, c, h, w) and optionally applying mixup
        (``hparams.mixtype``) before computing the criterion.
      * JSD: inputs is a tuple of views; cross-entropy on the clean view
        plus a Jensen-Shannon-style consistency term over the softmaxed
        logits of all three views.

    Periodically logs training loss (every ``hparams.log_every`` steps)
    and runs validation (every ``hparams.test_every`` steps), restoring
    train mode and grad tracking afterwards.

    Args:
        summary_train: dict with 'step'/'epoch'/'total_step'; mutated.
        summary_dev: dict passed through to validation_end.
        best_dict: best-checkpoint bookkeeping passed to validation_end.

    Returns:
        Tuple of (summary_train, best_dict), both mutated in place.
    """
    losses = AverageMeter()
    torch.set_grad_enabled(True)
    self.model.train()
    time_now = time.time()
    for i, (inputs, target, _) in enumerate(self.train_loader):
        if isinstance(inputs, tuple):
            # Mixed tuples may carry non-tensor metadata; only move the
            # tensor elements to the device.
            inputs = tuple([
                e.to(self.device) if type(e) == torch.Tensor else e
                for e in inputs
            ])
        else:
            inputs = inputs.to(self.device)
        target = target.to(self.device)
        self.optimizer.zero_grad()
        if self.cfg.no_jsd:
            if self.cfg.n_crops:
                # Flatten the crops dimension into the batch dimension so
                # each crop is scored independently.
                bs, n_crops, c, h, w = inputs.size()
                inputs = inputs.view(-1, c, h, w)
                if len(self.hparams.mixtype) > 0:
                    # Mixup path: targets must be repeated to match the
                    # flattened crops before mixing.
                    if self.hparams.multi_cls:
                        target = target.view(target.size()[0], -1)
                        inputs, targets_a, targets_b, lam = self.mix_data(
                            inputs,
                            target.repeat(1, n_crops).view(-1),
                            self.device, self.hparams.alpha)
                    else:
                        inputs, targets_a, targets_b, lam = self.mix_data(
                            inputs,
                            target.repeat(1, n_crops).view(
                                -1, len(self.num_tasks)),
                            self.device, self.hparams.alpha)
                logits = self.forward(inputs)
                if len(self.hparams.mixtype) > 0:
                    # Mixup loss: convex combination of losses on the two
                    # mixed target sets, weighted by lam.
                    loss_func = self.mixup_criterion(
                        targets_a, targets_b, lam)
                    loss = loss_func(self.criterion, logits)
                else:
                    if self.hparams.multi_cls:
                        target = target.view(target.size()[0], -1)
                        loss = self.criterion(
                            logits,
                            target.repeat(1, n_crops).view(-1))
                    else:
                        loss = self.criterion(
                            logits,
                            target.repeat(1, n_crops).view(
                                -1, len(self.num_tasks)))
            else:
                # Single-crop path, with the same optional mixup handling.
                if len(self.hparams.mixtype) > 0:
                    inputs, targets_a, targets_b, lam = self.mix_data(
                        inputs, target, self.device, self.hparams.alpha)
                logits = self.forward(inputs)
                if len(self.hparams.mixtype) > 0:
                    loss_func = self.mixup_criterion(
                        targets_a, targets_b, lam)
                    loss = loss_func(self.criterion, logits)
                else:
                    loss = self.criterion(logits, target)
        else:
            # JSD consistency branch: inputs is (clean, aug1, aug2);
            # batch them into one forward pass and split the logits back.
            images_all = torch.cat(inputs, 0)
            logits_all = self.forward(images_all)
            logits_clean, logits_aug1, logits_aug2 = torch.split(
                logits_all, inputs[0].size(0))
            # Cross-entropy is only computed on clean images
            loss = F.cross_entropy(logits_clean, target)
            p_clean, p_aug1, p_aug2 = F.softmax(
                logits_clean, dim=1), F.softmax(
                    logits_aug1, dim=1), F.softmax(logits_aug2, dim=1)
            # Clamp mixture distribution to avoid exploding KL divergence
            p_mixture = torch.clamp(
                (p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()
            # Symmetric KL of each view against the (log) mixture,
            # averaged and scaled by 12.
            loss += 12 * (
                F.kl_div(p_mixture, p_clean, reduction='batchmean') +
                F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
                F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3.
        assert not np.isnan(
            loss.item()), 'Model diverged with losses = NaN'
        loss.backward()
        self.optimizer.step()
        summary_train['step'] += 1
        losses.update(loss.item(), target.size(0))
        if summary_train['step'] % self.hparams.log_every == 0:
            time_spent = time.time() - time_now
            time_now = time.time()
            logging.info('Train, '
                         'Epoch : {}, '
                         'Step : {}/{}, '
                         'Loss: {loss.val:.4f} ({loss.avg:.4f}), '
                         'Run Time : {runtime:.2f} sec'.format(
                             summary_train['epoch'] + 1,
                             summary_train['step'],
                             summary_train['total_step'],
                             loss=losses, runtime=time_spent))
            print('Train, '
                  'Epoch : {}, '
                  'Step : {}/{}, '
                  'Loss: {loss.val:.4f} ({loss.avg:.4f}), '
                  'Run Time : {runtime:.2f} sec'.format(
                      summary_train['epoch'] + 1,
                      summary_train['step'],
                      summary_train['total_step'],
                      loss=losses, runtime=time_spent))
        if summary_train['step'] % self.hparams.test_every == 0:
            # Mid-epoch validation; restore train mode and grad tracking
            # afterwards since validation disables both.
            self.validation_end(summary_dev, summary_train, best_dict)
            self.model.train()
            torch.set_grad_enabled(True)
    summary_train['epoch'] += 1
    return summary_train, best_dict
def train_epoch(epoch, data_loader, model, criterion, optimizer, opt,
                epoch_logger, batch_logger, save_features):
    """Train for one epoch, logging per-batch and per-epoch statistics.

    Also saves a checkpoint every ``opt.checkpoint`` epochs.

    Args:
        epoch: 1-based epoch number (used for logging and checkpointing).
        data_loader: loader yielding (inputs, targets, video_ids) batches.
        model: video classifier (wrapped in DataParallel — uses .module).
        criterion: classification loss on (outputs, targets).
        optimizer: optimizer stepping the model's parameters.
        opt: options namespace (no_cuda, save_features, checkpoint, ...).
        epoch_logger: logger with a ``log(dict)`` method, called per epoch.
        batch_logger: logger with a ``log(dict)`` method, called per batch.
        save_features: unused here; feature saving is driven by
            ``opt.save_features`` (kept for signature compatibility).
    """
    print('train at epoch {}'.format(epoch))
    model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()
    end_time = time.time()
    for i, (inputs, targets, video_ids) in enumerate(data_loader):
        data_time.update(time.time() - end_time)
        if not opt.no_cuda:
            # FIX: `async` became a reserved keyword in Python 3.7, so
            # targets.cuda(async=True) is a SyntaxError; non_blocking is
            # the replacement parameter name.
            targets = targets.cuda(non_blocking=True)
        # FIX: dropped the deprecated Variable wrappers — tensors carry
        # autograd state directly on modern PyTorch.
        if opt.save_features:
            # Tag the wrapped model so it can name the features it saves
            # after the sample's video id + label.
            model.module.label = video_ids[0] + str(targets.tolist()[0])
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        acc = calculate_accuracy(outputs, targets)
        # FIX: loss.data[0] is legacy indexing of a 0-dim tensor and
        # raises on modern PyTorch; .item() is the correct accessor.
        losses.update(loss.item(), inputs.size(0))
        accuracies.update(acc, inputs.size(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end_time)
        end_time = time.time()
        batch_logger.log({
            'epoch': epoch,
            'batch': i + 1,
            'iter': (epoch - 1) * len(data_loader) + (i + 1),
            'loss': losses.val,
            'acc': accuracies.val,
            'lr': optimizer.param_groups[0]['lr']
        })
        print('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
              'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(
                  epoch, i + 1, len(data_loader),
                  batch_time=batch_time, data_time=data_time,
                  loss=losses, acc=accuracies))
    epoch_logger.log({
        'epoch': epoch,
        'loss': losses.avg,
        'acc': accuracies.avg,
        'lr': optimizer.param_groups[0]['lr']
    })
    if epoch % opt.checkpoint == 0:
        save_file_path = os.path.join(opt.result_path,
                                      'save_{}.pth'.format(epoch))
        states = {
            'epoch': epoch + 1,
            'arch': opt.arch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(states, save_file_path)