def inference():
    model = DFSeg_model.RedNet(num_classes=40, pretrained=False)
    # model = nn.DataParallel(model)
    load_ckpt(model, None, args.last_ckpt, device)
    model.eval()
    model.to(device)

    val_data = SUNRGBD(transform=torchvision.transforms.Compose([scaleNorm(),
                                                                 ToTensor(),
                                                                 Normalize()]),
                       phase_train=False,
                       data_dir=args.data_dir)
    val_loader = DataLoader(val_data, batch_size=1, shuffle=False,
                            num_workers=1, pin_memory=True)

    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    a_meter = AverageMeter()
    b_meter = AverageMeter()

    with torch.no_grad():
        for batch_idx, sample in enumerate(val_loader):
            # origin_image = sample['origin_image'].numpy()
            # origin_depth = sample['origin_depth'].numpy()
            image = sample['image'].to(device)
            depth = sample['depth'].to(device)
            label = sample['label'].numpy()

            pred = model(image, depth)

            # Shift predictions to 1-based class indices to match the labels.
            output = torch.max(pred, 1)[1] + 1
            output = output.squeeze(0).cpu().numpy()

            acc, pix = accuracy(output, label)
            intersection, union = intersectionAndUnion(output, label, args.num_class)
            acc_meter.update(acc, pix)
            a_m, b_m = macc(output, label, args.num_class)
            intersection_meter.update(intersection)
            union_meter.update(union)
            a_meter.update(a_m)
            b_meter.update(b_m)
            print('[{}] iter {}, accuracy: {}'
                  .format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                          batch_idx, acc))

            # img = image.cpu().numpy()
            # print('origin image: ', type(origin_image))
            # if args.visualize:
            #     visualize_result(origin_image, origin_depth, label - 1, output - 1, batch_idx, args)

    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {}'.format(i, _iou))

    mAcc = a_meter.average() / (b_meter.average() + 1e-10)
    print(mAcc.mean())
    print('[Eval Summary]:')
    print('Mean IoU: {:.4}, Accuracy: {:.2f}%'
          .format(iou.mean(), acc_meter.average() * 100))
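# The meters above (acc_meter, intersection_meter, ...) come from a utils
# module that is not part of this file. Below is a minimal sketch of a
# compatible AverageMeter, assuming the update/average/sum interface these
# scripts rely on; the project's own implementation may differ in detail.
class AverageMeter(object):
    """Tracks a running (optionally weighted) sum and average of scalars or numpy arrays."""

    def __init__(self):
        self.sum = 0
        self.count = 0

    def update(self, val, weight=1):
        # Accumulate `val` weighted by `weight` (e.g. the pixel count for accuracy).
        self.sum = self.sum + val * weight
        self.count = self.count + weight

    def average(self):
        return self.sum / (self.count + 1e-10)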
def train():
    # Log training curves to TensorBoard.
    writer_loss = SummaryWriter(os.path.join(args.summary_dir, 'loss'))
    # writer_loss1 = SummaryWriter(os.path.join(args.summary_dir, 'loss', 'loss1'))
    # writer_loss2 = SummaryWriter(os.path.join(args.summary_dir, 'loss', 'loss2'))
    # writer_loss3 = SummaryWriter(os.path.join(args.summary_dir, 'loss', 'loss3'))
    writer_acc = SummaryWriter(os.path.join(args.summary_dir, 'macc'))

    # Build the datasets.
    train_data = data_eval.ReadData(transform=transforms.Compose([
        data_eval.scaleNorm(),
        data_eval.RandomScale((1.0, 1.4)),
        data_eval.RandomHSV((0.9, 1.1), (0.9, 1.1), (25, 25)),
        data_eval.RandomCrop(image_h, image_w),
        data_eval.RandomFlip(),
        data_eval.ToTensor(),
        data_eval.Normalize()
    ]), data_dir=args.train_data_dir)
    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, pin_memory=False, drop_last=True)
    val_data = data_eval.ReadData(transform=transforms.Compose([
        data_eval.scaleNorm(),
        data_eval.RandomScale((1.0, 1.4)),
        data_eval.RandomCrop(image_h, image_w),
        data_eval.ToTensor(),
        data_eval.Normalize()
    ]), data_dir=args.val_data_dir)
    val_loader = DataLoader(val_data, batch_size=args.batch_size, shuffle=True,
                            num_workers=args.workers, pin_memory=False, drop_last=True)
    num_train = len(train_data)
    # num_val = len(val_data)

    # Build the model; skip ImageNet pretraining when resuming from a checkpoint.
    if args.last_ckpt:
        model = MultiTaskCNN_Atten(38, depth_channel=1, pretrained=False,
                                   arch='resnet50', use_aspp=True)
    else:
        model = MultiTaskCNN_Atten(38, depth_channel=1, pretrained=True,
                                   arch='resnet50', use_aspp=True)

    # Build the optimizer.
    if args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(), args.lr)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=0.9, weight_decay=1e-4)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), args.lr)
    else:
        print('unsupported optimizer\n')
        return None

    global_step = 0
    max_miou_val = 0
    loss_count = 0
    # If a checkpoint exists, restore global_step and start_epoch from it.
    if args.last_ckpt:
        global_step, args.start_epoch = load_ckpt(model, optimizer, args.last_ckpt, device)

    # if torch.cuda.device_count() > 1 and args.cuda and torch.cuda.is_available():
    #     print("Let's use", torch.cuda.device_count(), "GPUs!")
    #     model = torch.nn.DataParallel(model).to(device)
    model = model.to(device)
    model.train()
    # cal_param(model, data)

    loss_func = nn.CrossEntropyLoss()

    for epoch in range(int(args.start_epoch), args.epochs):
        torch.cuda.empty_cache()
        # if epoch <= freeze_epoch:
        #     for layer in [model.conv1, model.maxpool, model.layer1, model.layer2, model.layer3, model.layer4]:
        #         for param in layer.parameters():
        #             param.requires_grad = False
        tq = tqdm(total=len(train_loader) * args.batch_size)
        # Halve the base learning rate after 10 consecutive epochs without
        # improvement in validation mIoU (see loss_count below).
        if loss_count >= 10:
            args.lr = 0.5 * args.lr
            loss_count = 0
        lr = poly_lr_scheduler(optimizer, args.lr, iter=epoch, max_iter=args.epochs)
        optimizer.param_groups[0]['lr'] = lr
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 30, gamma=0.5)
        tq.set_description('epoch %d, lr %f' % (epoch, args.lr))
        loss_record = []
        # loss1_record = []
        # loss2_record = []
        # loss3_record = []
        local_count = 0

        for batch_idx, data in enumerate(train_loader):
            image = data['image'].to(device)
            depth = data['depth'].to(device)
            label = data['label'].long().to(device)
            # print('label', label.shape)

            # Main output plus two auxiliary (deep-supervision) outputs.
            output, output_sup1, output_sup2 = model(image, depth)
            loss1 = loss_func(output, label)
            loss2 = loss_func(output_sup1, label)
            loss3 = loss_func(output_sup2, label)
            loss = loss1 + loss2 + loss3
            tq.update(args.batch_size)
            tq.set_postfix(loss='%.6f' % loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_step += 1
            local_count += image.data.shape[0]
            # writer_loss.add_scalar('loss_step', loss, global_step)
            # writer_loss1.add_scalar('loss1_step', loss1, global_step)
            # writer_loss2.add_scalar('loss2_step', loss2, global_step)
            # writer_loss3.add_scalar('loss3_step', loss3, global_step)
            loss_record.append(loss.item())
            # loss1_record.append(loss1.item())
            # loss2_record.append(loss2.item())
            # loss3_record.append(loss3.item())

            if global_step % args.print_freq == 0 or global_step == 1:
                for name, param in model.named_parameters():
                    writer_loss.add_histogram(name, param.clone().cpu().data.numpy(),
                                              global_step, bins='doane')
                writer_loss.add_graph(model, [image, depth])
                grid_image1 = make_grid(image[:3].clone().cpu().data, 3, normalize=True)
                writer_loss.add_image('image', grid_image1, global_step)
                grid_image2 = make_grid(depth[:3].clone().cpu().data, 3, normalize=True)
                writer_loss.add_image('depth', grid_image2, global_step)
                grid_image3 = make_grid(utils.color_label(torch.max(output[:3], 1)[1]),
                                        3, normalize=False, range=(0, 255))
                writer_loss.add_image('Predicted label', grid_image3, global_step)
                grid_image4 = make_grid(utils.color_label(label[:3]),
                                        3, normalize=False, range=(0, 255))
                writer_loss.add_image('Groundtruth label', grid_image4, global_step)

        tq.close()
        loss_train_mean = np.mean(loss_record)
        with open(log_file, 'a') as f:
            f.write(str(epoch) + '\t' + str(loss_train_mean))
        # loss1_train_mean = np.mean(loss1_record)
        # loss2_train_mean = np.mean(loss2_record)
        # loss3_train_mean = np.mean(loss3_record)
        writer_loss.add_scalar('epoch/loss_epoch_train', float(loss_train_mean), epoch)
        # writer_loss1.add_scalar('epoch/sub_loss_epoch_train', float(loss1_train_mean), epoch)
        # writer_loss2.add_scalar('epoch/sub_loss_epoch_train', float(loss2_train_mean), epoch)
        # writer_loss3.add_scalar('epoch/sub_loss_epoch_train', float(loss3_train_mean), epoch)
        print('loss for train : %f' % loss_train_mean)

        print('----validation starting----')
        # tq_val = tqdm(total=len(val_loader) * args.batch_size)
        # tq_val.set_description('epoch %d' % epoch)
        model.eval()
        val_total_time = 0
        with torch.no_grad():
            sys.stdout.flush()
            tbar = tqdm(val_loader)
            acc_meter = AverageMeter()
            intersection_meter = AverageMeter()
            union_meter = AverageMeter()
            a_meter = AverageMeter()
            b_meter = AverageMeter()
            for batch_idx, sample in enumerate(tbar):
                # origin_image = sample['origin_image'].numpy()
                # origin_depth = sample['origin_depth'].numpy()
                image_val = sample['image'].to(device)
                depth_val = sample['depth'].to(device)
                label_val = sample['label'].numpy()

                start = time.time()
                pred = model(image_val, depth_val)
                end = time.time()
                duration = end - start
                val_total_time += duration
                # tq_val.set_postfix(fps='%.4f' % (args.batch_size / (end - start)))
                print_str = 'Test step [{}/{}].'.format(batch_idx + 1, len(val_loader))
                tbar.set_description(print_str)
                output_val = torch.max(pred, 1)[1]
                output_val = output_val.squeeze(0).cpu().numpy()

                acc, pix = accuracy(output_val, label_val)
                intersection, union = intersectionAndUnion(output_val, label_val,
                                                           args.num_class)
                acc_meter.update(acc, pix)
                a_m, b_m = macc(output_val, label_val, args.num_class)
                intersection_meter.update(intersection)
                union_meter.update(union)
                a_meter.update(a_m)
                b_meter.update(b_m)
            fps = len(val_loader) / val_total_time
            print('fps = %.4f' % fps)
            tbar.close()
            mAcc = a_meter.average() / (b_meter.average() + 1e-10)
            with open(log_file, 'a') as f:
                f.write(' ' + str(mAcc.mean()) + '\n')
            iou = intersection_meter.sum / (union_meter.sum + 1e-10)
            writer_acc.add_scalar('epoch/Acc_epoch_train', mAcc.mean(), epoch)
        print('----validation finished----')
        model.train()

        # Save a checkpoint whenever validation mIoU improves; otherwise count
        # the epoch toward the learning-rate decay trigger above.
        if epoch != args.start_epoch:
            if iou.mean() >= max_miou_val:
                print('mIoU:', iou.mean())
                if not os.path.isdir(args.ckpt_dir):
                    os.mkdir(args.ckpt_dir)
                save_ckpt(args.ckpt_dir, model, optimizer, global_step, epoch,
                          local_count, num_train)
                max_miou_val = iou.mean()
                # max_macc_val = mAcc.mean()
            else:
                loss_count += 1
        torch.cuda.empty_cache()
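# train() calls poly_lr_scheduler() without defining it. A common "poly" decay,
# lr = base_lr * (1 - iter / max_iter) ** power, is sketched below under the
# assumption power=0.9; the project's own scheduler may use other arguments or
# a per-iteration (rather than per-epoch) schedule.
def poly_lr_scheduler(optimizer, base_lr, iter, max_iter, power=0.9):
    """Polynomial learning-rate decay; returns the lr for the current iteration."""
    lr = base_lr * (1 - iter / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr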
def inference():
    writer_image = SummaryWriter(os.path.join(args.summary_dir, 'segtest'))
    model = MultiTaskCNN(38, depth_channel=1, pretrained=False,
                         arch='resnet50', use_aspp=False)
    load_ckpt(model, None, args.last_ckpt, device)
    model.eval()
    model = model.to(device)

    val_data = data_eval.ReadData(transform=torchvision.transforms.Compose([
        data_eval.scaleNorm(),
        data_eval.ToTensor(),
        data_eval.Normalize()
    ]), data_dir=args.data_dir)
    val_loader = DataLoader(val_data, batch_size=1, shuffle=False,
                            num_workers=4, pin_memory=False)

    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    a_meter = AverageMeter()
    b_meter = AverageMeter()
    test_total_time = 0

    with torch.no_grad():
        for batch_idx, sample in enumerate(val_loader):
            # origin_image = sample['origin_image'].to(device)
            # origin_depth = sample['origin_depth'].to(device)
            image = sample['image'].to(device)
            depth = sample['depth'].to(device)
            label = sample['label'].numpy()
            show_label = sample['label'].long().to(device)

            time1 = time.time()
            pred = model(image, depth)
            time2 = time.time()
            test_total_time += (time2 - time1)

            output = torch.max(pred, 1)[1]
            # output = output.squeeze(0).cpu().numpy()
            output = output.cpu().numpy()

            acc, pix = accuracy(output, label)
            intersection, union = intersectionAndUnion(output, label, args.num_class)
            acc_meter.update(acc, pix)
            a_m, b_m = macc(output, label, args.num_class)
            intersection_meter.update(intersection)
            union_meter.update(union)
            a_meter.update(a_m)
            b_meter.update(b_m)

            # Periodically log input, depth, prediction and ground truth to TensorBoard.
            if batch_idx % 50 == 0:
                grid_image1 = make_grid(image[:1].clone().cpu().data, 1, normalize=True)
                writer_image.add_image('image', grid_image1, batch_idx)
                grid_image2 = make_grid(depth[:1].clone().cpu().data, 1, normalize=True)
                writer_image.add_image('depth', grid_image2, batch_idx)
                grid_image3 = make_grid(utils.color_label(torch.max(pred[:1], 1)[1]),
                                        1, normalize=False, range=(0, 255))
                writer_image.add_image('Predicted label', grid_image3, batch_idx)
                grid_image4 = make_grid(utils.color_label(show_label[:1]),
                                        1, normalize=False, range=(0, 255))
                writer_image.add_image('Groundtruth label', grid_image4, batch_idx)
            print('[{}] iter {}, accuracy: {}'.format(
                datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), batch_idx, acc))

            # if batch_idx % 1 == 0:
            #     if args.visualize:
            #         visualize_result(origin_image, origin_depth, label, output, batch_idx, args)
            #         visualize_result(origin_image, origin_depth, label - 1, output - 1, batch_idx, args)

    print('Inference time per image:', test_total_time / len(val_data),
          '\nfps:', len(val_data) / test_total_time)
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {}'.format(i, _iou))

    # mAcc: per-pixel classification accuracy of predictions against ground
    # truth, computed per class and then averaged over classes.
    mAcc = a_meter.average() / (b_meter.average() + 1e-10)
    print(mAcc.mean())
    print('[Eval Summary]:')
    print('Mean IoU: {:.4}, Accuracy: {:.2f}%'.format(
        iou.mean(), acc_meter.average() * 100))
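# accuracy(), intersectionAndUnion() and macc() are imported from elsewhere.
# The sketches below show the usual semantics of these metrics, assuming label
# index 0 marks unlabeled pixels and classes occupy 1..num_class (which is why
# some scripts above shift predictions by +1); the project's own helpers may
# differ in edge-case handling.
import numpy as np

def accuracy(preds, label):
    """Overall pixel accuracy over the valid (label > 0) region."""
    valid = (label > 0)
    acc_sum = (valid * (preds == label)).sum()
    pixel_sum = valid.sum()
    return acc_sum / (pixel_sum + 1e-10), pixel_sum

def intersectionAndUnion(preds, label, num_class):
    """Per-class intersection and union counts for IoU."""
    preds = preds * (label > 0)  # zero out predictions on unlabeled pixels
    intersection = preds * (preds == label)
    # Bin edges at k +/- 0.5 so each integer class id falls in its own bin.
    bins = num_class
    rng = (0.5, num_class + 0.5)
    area_inter, _ = np.histogram(intersection, bins=bins, range=rng)
    area_pred, _ = np.histogram(preds, bins=bins, range=rng)
    area_lab, _ = np.histogram(label, bins=bins, range=rng)
    return area_inter, area_pred + area_lab - area_inter

def macc(preds, label, num_class):
    """Per-class correct and total pixel counts; their ratio averages to mAcc."""
    a = np.zeros(num_class)
    b = np.zeros(num_class)
    for c in range(num_class):
        mask = (label == c + 1)
        a[c] = np.logical_and(preds == c + 1, mask).sum()
        b[c] = mask.sum()
    return a, b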
def evaluate():
    model = ACNet_models_V1.ACNet(num_class=5, pretrained=False)
    load_ckpt(model, None, args.last_ckpt, device)
    model.eval()
    model.to(device)

    val_data = ACNet_data.FreiburgForest(
        transform=torchvision.transforms.Compose([
            ACNet_data.ScaleNorm(),
            ACNet_data.ToTensor(),
            ACNet_data.Normalize()
        ]),
        data_dirs=[args.test_dir],
        modal1_name=args.modal1,
        modal2_name=args.modal2,
    )
    val_loader = DataLoader(val_data, batch_size=1, shuffle=False,
                            num_workers=1, pin_memory=True)

    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    a_meter = AverageMeter()
    b_meter = AverageMeter()

    with torch.no_grad():
        for batch_idx, sample in enumerate(val_loader):
            modal1 = sample['modal1'].to(device)
            modal2 = sample['modal2'].to(device)
            label = sample['label'].numpy()
            basename = sample['basename'][0]

            pred = model(modal1, modal2)

            # Shift predictions to 1-based class indices to match the labels.
            output = torch.argmax(pred, 1) + 1
            output = output.squeeze(0).cpu().numpy()

            acc, pix = accuracy(output, label)
            intersection, union = intersectionAndUnion(output, label, args.num_class)
            acc_meter.update(acc, pix)
            a_m, b_m = macc(output, label, args.num_class)
            intersection_meter.update(intersection)
            union_meter.update(union)
            a_meter.update(a_m)
            b_meter.update(b_m)
            print('[{}] iter {}, accuracy: {}'
                  .format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                          batch_idx, acc))

            if args.visualize:
                visualize_result(modal1, modal2, label, output, batch_idx, args)

            if args.save_predictions:
                colored_output = utils.color_label_eval(output).astype(np.uint8)
                imageio.imwrite(f'{args.output_dir}/{basename}_pred.png',
                                colored_output.transpose([1, 2, 0]))

    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {}'.format(i, _iou))

    mAcc = a_meter.average() / (b_meter.average() + 1e-10)
    print(mAcc.mean())
    print('[Eval Summary]:')
    print('Mean IoU: {:.4}, Accuracy: {:.2f}%'
          .format(iou.mean(), acc_meter.average() * 100))
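# load_ckpt()/save_ckpt() are shared helpers used by every script above. A
# minimal sketch of the assumed checkpoint format (model state plus optional
# optimizer state, with global_step/epoch counters) is given below; the
# project's actual helpers may store additional fields or name them differently.
import os
import torch

def save_ckpt(ckpt_dir, model, optimizer, global_step, epoch, local_count, num_train):
    """Write a checkpoint named by epoch so successive best-mIoU saves do not collide."""
    # local_count/num_train are kept for signature compatibility with train().
    state = {
        'global_step': global_step,
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state, os.path.join(ckpt_dir, 'ckpt_epoch_{}.pth'.format(epoch)))

def load_ckpt(model, optimizer, model_file, device):
    """Restore model (and, if given, optimizer) state; return (global_step, epoch)."""
    checkpoint = torch.load(model_file, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None and 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['global_step'], checkpoint['epoch']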