def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch and return the epoch's aggregated metrics.

    Args:
        train_loader: Iterable of ``(input, target)`` batches. Its ``dataset``
            must expose ``response_shape[0]``, which is passed to
            ``intersection_and_union_gpu`` as the number of classes.
        model: Network being trained; switched to ``train()`` mode here.
        criterion: Loss used when neither mixup nor label smoothing is enabled.
        optimizer: Optimizer stepped once per batch.
        epoch: Zero-based epoch index, used for logging and for the
            remaining-time estimate.

    Returns:
        Tuple ``(loss_avg, mIoU, mAcc, allAcc, top1_avg, top5_avg)``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    model.train()
    end = time.time()
    max_iter = CONFIG.epochs * len(train_loader)
    for i, (input, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        if CONFIG.mixup_alpha:
            eps = CONFIG.label_smoothing if CONFIG.label_smoothing else 0.0
            input, target_a, target_b, lam = mixup_data(
                input, target, CONFIG.mixup_alpha
            )
            output = model(input)
            loss = mixup_loss(output, target_a, target_b, lam, eps)
        else:
            output = model(input)
            loss = (
                smooth_loss(output, target, CONFIG.label_smoothing)
                if CONFIG.label_smoothing
                else criterion(output, target)
            )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        top1, top5 = cal_accuracy(output, target, topk=(1, 5))
        n = input.size(0)
        if CONFIG.multiprocessing_distributed:
            # Scale per-rank means to sums, all-reduce them together with the
            # per-rank sample count, then divide to get the global batch mean.
            with torch.no_grad():
                loss, top1, top5 = loss.detach() * n, top1 * n, top5 * n
                count = target.new_tensor([n], dtype=torch.long)
                distributed.all_reduce(loss)
                distributed.all_reduce(top1)
                distributed.all_reduce(top5)
                distributed.all_reduce(count)
                n = count.item()
                loss, top1, top5 = loss / n, top1 / n, top5 / n
        # Convert to Python floats once; reused for the meters and the writer.
        loss_value = loss.item()
        top1_value = top1.item()
        top5_value = top5.item()
        loss_meter.update(loss_value, n)
        top1_meter.update(top1_value, n)
        top5_meter.update(top5_value, n)

        output = output.max(1)[1]  # class index with the highest score
        intersection, union, target = intersection_and_union_gpu(
            output,
            target,
            train_loader.dataset.response_shape[0],
            CONFIG.ignore_label,
        )
        if CONFIG.multiprocessing_distributed:
            distributed.all_reduce(intersection)
            distributed.all_reduce(union)
            distributed.all_reduce(target)
        intersection, union, target = (
            intersection.cpu().numpy(),
            union.cpu().numpy(),
            target.cpu().numpy(),
        )
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(target)
        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        batch_time.update(time.time() - end)
        end = time.time()

        # calculate remain time
        current_iter = epoch * len(train_loader) + i + 1
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = f"{int(t_h):02d}:{int(t_m):02d}:{int(t_s):02d}"

        if ((i + 1) % CONFIG.print_freq == 0) and is_main_process():
            logger.info(
                f"Epoch: [{epoch + 1}/{CONFIG.epochs}][{i + 1}/{len(train_loader)}] Data {data_time.val:.3f} ("
                f"{data_time.avg:.3f}) Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) Remain {remain_time} Loss "
                f"{loss_meter.val:.4f} Accuracy {accuracy:.4f} Acc@1 {top1_meter.val:.3f} ({top1_meter.avg:.3f}) "
                f"Acc@5 {top5_meter.val:.3f} ({top5_meter.avg:.3f}).")
        if is_main_process():
            writer.scalar("loss_train_batch", loss_meter.val, current_iter)
            writer.scalar(
                "mIoU_train_batch",
                numpy.mean(intersection / (union + 1e-10)),
                current_iter,
            )
            writer.scalar(
                "mAcc_train_batch",
                numpy.mean(intersection / (target + 1e-10)),
                current_iter,
            )
            writer.scalar("allAcc_train_batch", accuracy, current_iter)
            # Log plain floats (like the loss above) rather than device tensors.
            writer.scalar("top1_train_batch", top1_value, current_iter)
            writer.scalar("top5_train_batch", top5_value, current_iter)

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = numpy.mean(iou_class)
    mAcc = numpy.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    if is_main_process():
        logger.info(
            f"Train result at epoch [{epoch + 1}/{CONFIG.epochs}]: mIoU/mAcc/allAcc/top1/top5 {mIoU:.4f}/"
            f"{mAcc:.4f}/{allAcc:.4f}/{top1_meter.avg:.4f}/{top5_meter.avg:.4f}."
        )
    return loss_meter.avg, mIoU, mAcc, allAcc, top1_meter.avg, top5_meter.avg
def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` and return aggregated metrics.

    The main process logs running metrics every ``CONFIG.print_freq``
    batches and per-class IoU/accuracy at the end.

    Args:
        val_loader: Iterable of ``(input, target)`` batches. Its ``dataset``
            must expose ``response_shape[0]``, used as the number of classes.
        model: Network to evaluate; switched to ``eval()`` mode here.
        criterion: Loss function applied to ``(output, target)``.

    Returns:
        Tuple ``(loss_avg, mIoU, mAcc, allAcc, top1_avg, top5_avg)``.
    """
    if is_main_process():
        logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    model.eval()
    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        data_time.update(time.time() - end)
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        # No gradients are needed during evaluation; without no_grad every
        # batch would build (and hold) a full autograd graph.
        with torch.no_grad():
            output = model(input)
            loss = criterion(output, target)

        top1, top5 = cal_accuracy(output, target, topk=(1, 5))
        n = input.size(0)
        if CONFIG.multiprocessing_distributed:
            # Scale per-rank means to sums, all-reduce them together with the
            # per-rank sample count, then divide to get the global batch mean.
            with torch.no_grad():
                loss, top1, top5 = loss.detach() * n, top1 * n, top5 * n
                count = target.new_tensor([n], dtype=torch.long)
                distributed.all_reduce(loss)
                distributed.all_reduce(top1)
                distributed.all_reduce(top5)
                distributed.all_reduce(count)
                n = count.item()
                loss, top1, top5 = loss / n, top1 / n, top5 / n
        loss_meter.update(loss.item(), n)
        top1_meter.update(top1.item(), n)
        top5_meter.update(top5.item(), n)

        output = output.max(1)[1]  # class index with the highest score
        intersection, union, target = intersection_and_union_gpu(
            output,
            target,
            val_loader.dataset.response_shape[0],
            CONFIG.ignore_label,
        )
        if CONFIG.multiprocessing_distributed:
            distributed.all_reduce(intersection)
            distributed.all_reduce(union)
            distributed.all_reduce(target)
        intersection, union, target = (
            intersection.cpu().numpy(),
            union.cpu().numpy(),
            target.cpu().numpy(),
        )
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(target)
        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        batch_time.update(time.time() - end)
        end = time.time()
        if ((i + 1) % CONFIG.print_freq == 0) and is_main_process():
            logger.info(
                f"Test: [{i + 1}/{len(val_loader)}] Data {data_time.val:.3f} ({data_time.avg:.3f}) Batch "
                f"{batch_time.val:.3f} ({batch_time.avg:.3f}) Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f}) "
                f"Accuracy {accuracy:.4f} Acc@1 {top1_meter.val:.3f} ({top1_meter.avg:.3f}) Acc@5 "
                f"{top5_meter.val:.3f} ({top5_meter.avg:.3f}).")

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = numpy.mean(iou_class)
    mAcc = numpy.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    if is_main_process():
        logger.info(
            f"Val result: mIoU/mAcc/allAcc/top1/top5 {mIoU:.4f}/{mAcc:.4f}/{allAcc:.4f}/{top1_meter.avg:.4f}/"
            f"{top5_meter.avg:.4f}.")
        for i in range(val_loader.dataset.response_shape[0]):
            # Skip classes that never appeared in the targets.
            if target_meter.sum[i] > 0:
                logger.info(
                    f"Class_{i} Result: iou/accuracy {iou_class[i]:.4f}/{accuracy_class[i]:.4f} Count:{target_meter.sum[i]}"
                )
        logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
    return loss_meter.avg, mIoU, mAcc, allAcc, top1_meter.avg, top5_meter.avg
def main(shuffle: bool = True, how_many_batches=10, batch_size=1):
    """Evaluate a pretrained SAN checkpoint on the validation split, visually.

    For every evaluated batch, the first image is shown with matplotlib,
    titled with its predicted and ground-truth category names. Running
    metrics are logged, and the final aggregates are printed.

    Args:
        shuffle: Whether the validation loader shuffles samples.
        how_many_batches: Number of batches to evaluate. A falsy value
            (``0`` / ``None``) iterates the whole loader.
        batch_size: Validation batch size.

    Raises:
        RuntimeError: If no checkpoint file exists at ``SAN_CONFIG.model_path``.
    """

    def get_logger():
        """Return the "main-logger" logger, attaching a stream handler once."""
        logger_name = "main-logger"
        logger = logging.getLogger(logger_name)
        logger.setLevel(logging.INFO)
        # Loggers are process-global: without this guard every call would
        # stack another handler and duplicate each emitted line.
        if not logger.handlers:
            handler = logging.StreamHandler()
            fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
            handler.setFormatter(logging.Formatter(fmt))
            logger.addHandler(handler)
        return logger

    from samples.classification.san.configs.imagenet_san10_patchwise import (
        SAN_CONFIG,
    )

    # CUDA_VISIBLE_DEVICES is only honoured if set before CUDA initialises,
    # so set it before any other setup (dataset construction is opaque here).
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
        str(x) for x in SAN_CONFIG.test_gpu)

    dataset = SAN_CONFIG.dataset_type(SAN_CONFIG.dataset_path, Split.Validation)
    logger = get_logger()
    logger.info(SAN_CONFIG)
    logger.info("=> creating model ...")
    logger.info(f"Classes: {dataset.response_shape[0]}")

    model = make_san(
        self_attention_type=SelfAttentionTypeEnum(
            SAN_CONFIG.self_attention_type),
        layers=SAN_CONFIG.layers,
        kernels=SAN_CONFIG.kernels,
        num_classes=dataset.response_shape[0],
    )
    logger.info(model)
    model = torch.nn.DataParallel(model.cuda())

    # Check the file that is actually loaded (previously the save directory
    # was checked, which neither guarantees nor requires the file to exist).
    if not os.path.isfile(SAN_CONFIG.model_path):
        raise RuntimeError(
            f"=> no checkpoint found at '{SAN_CONFIG.model_path}'")
    logger.info(f"=> loading checkpoint '{SAN_CONFIG.model_path}'")
    checkpoint = torch.load(SAN_CONFIG.model_path)
    model.load_state_dict(checkpoint["state_dict"], strict=True)
    logger.info(f"=> loaded checkpoint '{SAN_CONFIG.model_path}'")

    criterion = nn.CrossEntropyLoss(ignore_index=SAN_CONFIG.ignore_label)
    val_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=SAN_CONFIG.test_workers,
        pin_memory=True,
    )

    logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    model.eval()
    end = time.time()
    # zip() stops at the shorter iterable: a finite range caps the number of
    # batches, while an unbounded counter lets the loader run to exhaustion.
    batch_indices = range(how_many_batches) if how_many_batches else count()
    for i, (input, target) in zip(batch_indices, val_loader):
        data_time.update(time.time() - end)
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        with torch.no_grad():
            output = model(input)

        pyplot.imshow(dataset.inverse_base_transform(input[0].cpu()))
        pyplot.title(
            f"pred:{dataset.category_names[output.max(1)[1][0].item()]} truth:{dataset.category_names[target[0].item()]}"
        )
        pyplot.show()

        loss = criterion(output, target)
        top1, top5 = cal_accuracy(output, target, topk=(1, 5))
        n = input.size(0)
        loss_meter.update(loss.item(), n)
        top1_meter.update(top1.item(), n)
        top5_meter.update(top5.item(), n)

        intersection, union, target = intersection_and_union_gpu(
            output.max(1)[1],
            target,
            dataset.response_shape[0],
            SAN_CONFIG.ignore_label,
        )
        intersection, union, target = (
            intersection.cpu().numpy(),
            union.cpu().numpy(),
            target.cpu().numpy(),
        )
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(target)
        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        batch_time.update(time.time() - end)
        end = time.time()
        if (i + 1) % SAN_CONFIG.print_freq == 0:
            logger.info(
                f"Test: [{i + 1}/{len(val_loader)}] Data {data_time.val:.3f} ({data_time.avg:.3f}) Batch "
                f"{batch_time.val:.3f} ({batch_time.avg:.3f}) Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f}) "
                f"Accuracy {accuracy:.4f} Acc@1 {top1_meter.val:.3f} ({top1_meter.avg:.3f}) Acc@5 "
                f"{top5_meter.val:.3f} ({top5_meter.avg:.3f}).")

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = numpy.mean(iou_class)
    mAcc = numpy.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

    logger.info(
        f"Val result: mIoU/mAcc/allAcc/top1/top5 {mIoU:.4f}/{mAcc:.4f}/{allAcc:.4f}/{top1_meter.avg:.4f}/"
        f"{top5_meter.avg:.4f}.")
    for i in range(dataset.response_shape[0]):
        # Skip classes that never appeared in the evaluated targets.
        if target_meter.sum[i] > 0:
            logger.info(
                f"Class_{i} Result: iou/accuracy {iou_class[i]:.4f}/{accuracy_class[i]:.4f} Count:{target_meter.sum[i]}"
            )
    logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
    print(loss_meter.avg, mIoU, mAcc, allAcc, top1_meter.avg, top5_meter.avg)