Пример #1
0
def inference():
    """Evaluate a RedNet checkpoint on the SUNRGBD evaluation split.

    Loads the weights named by ``args.last_ckpt``, runs the network over every
    sample with gradients disabled and prints: one accuracy line per sample,
    the per-class IoU, the mean of the mAcc ratio, and a final summary with
    mean IoU and overall accuracy. Relies on the module-level ``args`` and
    ``device`` objects; returns nothing.
    """
    net = DFSeg_model.RedNet(num_classes=40, pretrained=False)
    # net = nn.DataParallel(net)
    load_ckpt(net, None, args.last_ckpt, device)
    net.eval()
    net.to(device)

    # Deterministic preprocessing only: rescale, tensor conversion, normalisation.
    preprocess = torchvision.transforms.Compose([scaleNorm(),
                                                 ToTensor(),
                                                 Normalize()])
    dataset = SUNRGBD(transform=preprocess,
                      phase_train=False,
                      data_dir=args.data_dir)
    loader = DataLoader(dataset,
                        batch_size=1,
                        shuffle=False,
                        num_workers=1,
                        pin_memory=True)

    # Running accumulators for the metrics reported at the end.
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    a_meter = AverageMeter()
    b_meter = AverageMeter()

    with torch.no_grad():
        for step, batch in enumerate(loader):
            rgb = batch['image'].to(device)
            dep = batch['depth'].to(device)
            target = batch['label'].numpy()

            logits = net(rgb, dep)

            # Index of the highest-scoring channel, shifted by one before the
            # comparison with the labels (presumably labels are 1-based with 0
            # reserved -- confirm against the dataset).
            _, winners = torch.max(logits, 1)
            prediction = (winners + 1).squeeze(0).cpu().numpy()

            sample_acc, pixel_count = accuracy(prediction, target)
            inter, uni = intersectionAndUnion(prediction, target, args.num_class)
            acc_meter.update(sample_acc, pixel_count)
            a_part, b_part = macc(prediction, target, args.num_class)
            intersection_meter.update(inter)
            union_meter.update(uni)
            a_meter.update(a_part)
            b_meter.update(b_part)

            timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            print('[{}] iter {}, accuracy: {}'.format(timestamp, step, sample_acc))

    # The small epsilon keeps classes that never occur from dividing by zero.
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for class_idx, class_iou in enumerate(iou):
        print('class [{}], IoU: {}'.format(class_idx, class_iou))

    mAcc = a_meter.average() / (b_meter.average() + 1e-10)
    print(mAcc.mean())
    print('[Eval Summary]:')
    summary = 'Mean IoU: {:.4}, Accuracy: {:.2f}%'
    print(summary.format(iou.mean(), acc_meter.average() * 100))
Пример #2
0
def train():
    """Train a MultiTaskCNN_Atten model and validate it after every epoch.

    Reads its configuration from the module-level ``args`` (data directories,
    batch size, optimizer name, learning rate, epochs, checkpoint paths) and
    uses the module-level ``device``, ``image_h``, ``image_w`` and ``log_file``.

    Per epoch the function:
      * sets the learning rate from ``poly_lr_scheduler`` (the base rate is
        halved after 10 epochs in which validation mIoU did not improve),
      * trains on ``train_loader`` with the sum of three cross-entropy losses
        (main output plus two auxiliary outputs),
      * periodically logs histograms, the graph and sample images to TensorBoard,
      * runs validation, appends the epoch loss and mAcc to ``log_file``, and
      * saves a checkpoint when the validation mIoU is at least the best so far.

    Returns:
        None. Also returns ``None`` early when ``args.optimizer`` is not one
        of ``'rmsprop'``, ``'sgd'`` or ``'adam'``.
    """
    # Writers that record data for display in TensorBoard.
    writer_loss = SummaryWriter(os.path.join(args.summary_dir, 'loss'))
    # writer_loss1 = SummaryWriter(os.path.join(args.summary_dir, 'loss', 'loss1'))
    # writer_loss2 = SummaryWriter(os.path.join(args.summary_dir, 'loss', 'loss2'))
    # writer_loss3 = SummaryWriter(os.path.join(args.summary_dir, 'loss', 'loss3'))
    writer_acc = SummaryWriter(os.path.join(args.summary_dir, 'macc'))

    # Prepare the datasets.
    train_data = data_eval.ReadData(transform=transforms.Compose([
        data_eval.scaleNorm(),
        data_eval.RandomScale((1.0, 1.4)),
        data_eval.RandomHSV((0.9, 1.1), (0.9, 1.1), (25, 25)),
        data_eval.RandomCrop(image_h, image_w),
        data_eval.RandomFlip(),
        data_eval.ToTensor(),
        data_eval.Normalize()
    ]),
                                    data_dir=args.train_data_dir)
    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=False,
                              drop_last=True)
    # NOTE(review): the validation pipeline also applies RandomScale/RandomCrop
    # and the loader shuffles and drops the last partial batch, so validation
    # metrics are not computed on a fixed, complete set -- confirm this is intended.
    val_data = data_eval.ReadData(transform=transforms.Compose([
        data_eval.scaleNorm(),
        data_eval.RandomScale((1.0, 1.4)),
        data_eval.RandomCrop(image_h, image_w),
        data_eval.ToTensor(),
        data_eval.Normalize()
    ]),
                                  data_dir=args.val_data_dir)
    val_loader = DataLoader(val_data,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.workers,
                            pin_memory=False,
                            drop_last=True)
    num_train = len(train_data)
    # num_val = len(val_data)

    # Build the model. When resuming from a checkpoint the pretrained backbone
    # weights are not requested (the checkpoint is loaded further below).
    if args.last_ckpt:
        model = MultiTaskCNN_Atten(38,
                                   depth_channel=1,
                                   pretrained=False,
                                   arch='resnet50',
                                   use_aspp=True)
    else:
        model = MultiTaskCNN_Atten(38,
                                   depth_channel=1,
                                   pretrained=True,
                                   arch='resnet50',
                                   use_aspp=True)

    # Build the optimizer selected by args.optimizer.
    if args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(), args.lr)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=0.9,
                                    weight_decay=1e-4)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), args.lr)
    else:  # any other name is unsupported: report it and abort training
        print('not supported optimizer \n')
        return None
    global_step = 0
    max_miou_val = 0  # best validation mIoU seen so far
    loss_count = 0  # epochs since the last time validation mIoU reached the best
    # If a checkpoint is given, restore global_step and start_epoch from it.
    if args.last_ckpt:
        global_step, args.start_epoch = load_ckpt(model, optimizer,
                                                  args.last_ckpt, device)
    # if torch.cuda.device_count() > 1 and args.cuda and torch.cuda.is_available():
    #     print("Let's use", torch.cuda.device_count(), "GPUs!")
    #     model = torch.nn.DataParallel(model).to(device)
    model = model.to(device)
    model.train()
    # cal_param(model, data)
    loss_func = nn.CrossEntropyLoss()
    for epoch in range(int(args.start_epoch), args.epochs):
        torch.cuda.empty_cache()
        # if epoch <= freeze_epoch:
        #     for layer in [model.conv1, model.maxpool,model.layer1, model.layer2, model.layer3, model.layer4]:
        #         for param in layer.parameters():
        #             param.requires_grad = False
        # Progress bar counted in samples (batches * batch size).
        tq = tqdm(total=len(train_loader) * args.batch_size)
        # After 10 non-improving epochs, halve the base learning rate.
        if loss_count >= 10:
            args.lr = 0.5 * args.lr
            loss_count = 0
        lr = poly_lr_scheduler(optimizer,
                               args.lr,
                               iter=epoch,
                               max_iter=args.epochs)
        # Only the first parameter group is assigned the scheduled rate here.
        optimizer.param_groups[0]['lr'] = lr
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 30, gamma=0.5)
        # NOTE(review): this displays the base rate args.lr, not the scheduled `lr`.
        tq.set_description('epoch %d, lr %f' % (epoch, args.lr))
        loss_record = []  # per-batch total loss values for this epoch
        # loss1_record = []
        # loss2_record = []
        # loss3_record = []
        local_count = 0  # number of samples processed in this epoch
        # print('1')
        for batch_idx, data in enumerate(train_loader):
            image = data['image'].to(device)
            depth = data['depth'].to(device)
            label = data['label'].long().to(device)
            # print('label', label.shape)
            # The model returns the main output and two auxiliary outputs; each
            # is scored against the same label and the three losses are summed.
            output, output_sup1, output_sup2 = model(image, depth)
            loss1 = loss_func(output, label)
            loss2 = loss_func(output_sup1, label)
            loss3 = loss_func(output_sup2, label)
            loss = loss1 + loss2 + loss3
            tq.update(args.batch_size)
            tq.set_postfix(loss='%.6f' % loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_step += 1
            local_count += image.data.shape[0]
            # writer_loss.add_scalar('loss_step', loss, global_step)
            # writer_loss1.add_scalar('loss1_step', loss1, global_step)
            # writer_loss2.add_scalar('loss2_step', loss2, global_step)
            # writer_loss3.add_scalar('loss3_step', loss3, global_step)
            loss_record.append(loss.item())
            # loss1_record.append(loss1.item())
            # loss2_record.append(loss2.item())
            # loss3_record.append(loss3.item())
            # Every print_freq steps (and on the very first step) dump parameter
            # histograms, the model graph and up to three sample images.
            if global_step % args.print_freq == 0 or global_step == 1:
                for name, param in model.named_parameters():
                    writer_loss.add_histogram(name,
                                              param.clone().cpu().data.numpy(),
                                              global_step,
                                              bins='doane')
                writer_loss.add_graph(model, [image, depth])
                grid_image1 = make_grid(image[:3].clone().cpu().data,
                                        3,
                                        normalize=True)
                writer_loss.add_image('image', grid_image1, global_step)
                grid_image2 = make_grid(depth[:3].clone().cpu().data,
                                        3,
                                        normalize=True)
                writer_loss.add_image('depth', grid_image2, global_step)
                grid_image3 = make_grid(utils.color_label(
                    torch.max(output[:3], 1)[1]),
                                        3,
                                        normalize=False,
                                        range=(0, 255))
                writer_loss.add_image('Predicted label', grid_image3,
                                      global_step)
                grid_image4 = make_grid(utils.color_label(label[:3]),
                                        3,
                                        normalize=False,
                                        range=(0, 255))
                writer_loss.add_image('Groundtruth label', grid_image4,
                                      global_step)

        tq.close()
        loss_train_mean = np.mean(loss_record)
        # Start this epoch's log line; the validation mAcc and the newline are
        # appended after validation below.
        with open(log_file, 'a') as f:
            f.write(str(epoch) + '\t' + str(loss_train_mean))
        # loss1_train_mean = np.mean(loss1_record)
        # loss2_train_mean = np.mean(loss2_record)
        # loss3_train_mean = np.mean(loss3_record)
        writer_loss.add_scalar('epoch/loss_epoch_train',
                               float(loss_train_mean), epoch)
        # writer_loss1.add_scalar('epoch/sub_loss_epoch_train', float(loss1_train_mean), epoch)
        # writer_loss2.add_scalar('epoch/sub_loss_epoch_train', float(loss2_train_mean), epoch)
        # writer_loss3.add_scalar('epoch/sub_loss_epoch_train', float(loss3_train_mean), epoch)
        print('loss for train : %f' % loss_train_mean)
        print('----validation starting----')
        # tq_val = tqdm(total=len(val_loader) * args.batch_size)
        # tq_val.set_description('epoch %d' % epoch)
        model.eval()

        val_total_time = 0  # accumulated wall-clock time of the model calls
        with torch.no_grad():
            sys.stdout.flush()
            tbar = tqdm(val_loader)
            acc_meter = AverageMeter()
            intersection_meter = AverageMeter()
            union_meter = AverageMeter()
            a_meter = AverageMeter()
            b_meter = AverageMeter()
            for batch_idx, sample in enumerate(tbar):

                # origin_image = sample['origin_image'].numpy()
                # origin_depth = sample['origin_depth'].numpy()
                image_val = sample['image'].to(device)
                depth_val = sample['depth'].to(device)
                label_val = sample['label'].numpy()

                # Time only the forward pass. (This inner no_grad is redundant
                # inside the enclosing one.)
                with torch.no_grad():
                    start = time.time()
                    pred = model(image_val, depth_val)
                    end = time.time()
                    duration = end - start
                    val_total_time += duration
                # tq_val.set_postfix(fps ='%.4f' % (args.batch_size / (end - start)))
                print_str = 'Test step [{}/{}].'.format(
                    batch_idx + 1, len(val_loader))
                tbar.set_description(print_str)

                # Predicted class = index of the highest-scoring channel.
                # NOTE(review): squeeze(0) only drops the batch axis when the
                # batch size is 1 -- confirm the metric helpers accept batches.
                output_val = torch.max(pred, 1)[1]
                output_val = output_val.squeeze(0).cpu().numpy()

                acc, pix = accuracy(output_val, label_val)
                intersection, union = intersectionAndUnion(
                    output_val, label_val, args.num_class)
                acc_meter.update(acc, pix)
                a_m, b_m = macc(output_val, label_val, args.num_class)
                intersection_meter.update(intersection)
                union_meter.update(union)
                a_meter.update(a_m)
                b_meter.update(b_m)
        # Batches per second of forward-pass time (not images per second).
        fps = len(val_loader) / val_total_time
        print('fps = %.4f' % fps)
        tbar.close()
        mAcc = (a_meter.average() / (b_meter.average() + 1e-10))
        # Finish the log line started after training with the mean mAcc.
        with open(log_file, 'a') as f:
            f.write('                    ' + str(mAcc.mean()) + '\n')
        iou = intersection_meter.sum / (union_meter.sum + 1e-10)
        # NOTE(review): the tag says "train" but the value comes from validation.
        writer_acc.add_scalar('epoch/Acc_epoch_train', mAcc.mean(), epoch)
        print('----validation finished----')
        model.train()
        # Save a checkpoint whenever the validation mIoU matches or beats the
        # best so far (an older note here described saving every
        # save_epoch_freq epochs). The first epoch of a run is never saved.
        if epoch != args.start_epoch:
            if iou.mean() >= max_miou_val:
                print('mIoU:', iou.mean())
                if not os.path.isdir(args.ckpt_dir):
                    os.mkdir(args.ckpt_dir)
                save_ckpt(args.ckpt_dir, model, optimizer, global_step, epoch,
                          local_count, num_train)
                max_miou_val = iou.mean()
                # max_macc_val = mAcc.mean()
            else:
                # No improvement: count towards the learning-rate halving above.
                loss_count += 1
        torch.cuda.empty_cache()
Пример #3
0
def evaluate(nets, loader, args):
    """Evaluate a two-output network on ``loader`` and print loss, accuracy and IoU.

    Args:
        nets: Iterable of modules. Each one is switched to eval mode and the
            whole collection is handed to ``forward_multiscale``.
        loader: Iterable of batches; every batch is passed unchanged to
            ``forward_multiscale``, ``accuracy``, ``intersectionAndUnion`` and
            (optionally) ``visualize_result``.
        args: Options object; ``num_classes`` and ``visualize`` are read here.

    Returns:
        None. Per-iteration results, per-class IoU for both predictions and a
        final summary are printed to stdout.
    """
    loss_pred1_meter = AverageMeter()
    loss_pred2_meter = AverageMeter()

    acc_pred1_meter = AverageMeter()
    acc_pred2_meter = AverageMeter()

    intersection_pred1_meter = AverageMeter()
    intersection_pred2_meter = AverageMeter()

    union_pred1_meter = AverageMeter()
    union_pred2_meter = AverageMeter()

    for model in nets:
        model.eval()

    for i, batch_data in enumerate(loader):
        # Progress marker every 100 batches.
        if i % 100 == 0:
            print('{:d} processd'.format(i))

        # Forward pass: two predictions and their losses.
        pred1, pred2, loss_pred1, loss_pred2 = forward_multiscale(nets, batch_data, args)

        # Convert the scalar losses to Python floats once. ``loss.data[0]``
        # raises IndexError on 0-dim tensors (PyTorch >= 0.4); ``.item()`` is
        # the supported way to read a scalar tensor.
        loss1_value = loss_pred1.item()
        loss2_value = loss_pred2.item()
        loss_pred1_meter.update(loss1_value)
        loss_pred2_meter.update(loss2_value)

        # Accuracy and intersection/union statistics for each prediction.
        acc_pred1, pix_pred1 = accuracy(batch_data, pred1)
        intersection_pred1, union_pred1 = intersectionAndUnion(batch_data, pred1, args.num_classes)

        acc_pred2, pix_pred2 = accuracy(batch_data, pred2)
        intersection_pred2, union_pred2 = intersectionAndUnion(batch_data, pred2, args.num_classes)

        acc_pred1_meter.update(acc_pred1, pix_pred1)
        intersection_pred1_meter.update(intersection_pred1)
        union_pred1_meter.update(union_pred1)

        acc_pred2_meter.update(acc_pred2, pix_pred2)
        intersection_pred2_meter.update(intersection_pred2)
        union_pred2_meter.update(union_pred2)

        print('[{}] iter {}, loss_pred1: {} loss_pred2: {}, Accurarcy_pred1: {} Accurarcy_pred2: {}'
              .format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), i, loss1_value, loss2_value, acc_pred1, acc_pred2))

        # Optional visualization of both predictions.
        if args.visualize:
            visualize_result(batch_data, pred1, pred2, args)

    # Per-class IoU from the accumulated sums; the epsilon avoids division by
    # zero for classes that never appear.
    iou_pred1 = intersection_pred1_meter.sum / (union_pred1_meter.sum + 1e-10)
    iou_pred2 = intersection_pred2_meter.sum / (union_pred2_meter.sum + 1e-10)

    for i, (_iou_pred1, _iou_pred2) in enumerate(zip(iou_pred1, iou_pred2)):
        print('class [{}],\n IoU_pred1: {},\n IoU_pred2: {}\n'.format(i, _iou_pred1, _iou_pred2))

    print('[Eval Summary]:')
    print('Loss_pred1: {},\n Loss_pred2: {},\n Mean IoU_pred1: {:.2f}%,\n Mean IoU_pred2: {:.2f}%,\n  Accurarcy_pred1: {:.2f}%,\n Accurarcy_pred2: {:.2f}%,\n'
          .format(loss_pred1_meter.average(), loss_pred2_meter.average(), iou_pred1.mean()*100, iou_pred2.mean()*100, acc_pred1_meter.average()*100, acc_pred2_meter.average()*100))
Пример #4
0
def inference():
    """Evaluate a MultiTaskCNN checkpoint and log sample predictions.

    Loads the weights named by ``args.last_ckpt``, runs the model over the data
    in ``args.data_dir`` with gradients disabled, and accumulates accuracy,
    intersection/union and mAcc statistics. Every 50th sample is written to
    TensorBoard (input image, depth, predicted and ground-truth label maps)
    and its accuracy is printed. Finally prints the mean forward time per
    sample, the resulting fps, per-class IoU, mean mAcc and a summary line.
    Relies on the module-level ``args`` and ``device``; returns nothing.
    """
    tb_writer = SummaryWriter(os.path.join(args.summary_dir, 'segtest'))

    net = MultiTaskCNN(38,
                       depth_channel=1,
                       pretrained=False,
                       arch='resnet50',
                       use_aspp=False)
    load_ckpt(net, None, args.last_ckpt, device)
    net.eval()
    net = net.to(device)

    preprocess = torchvision.transforms.Compose([
        data_eval.scaleNorm(),
        data_eval.ToTensor(),
        Normalize(),
    ])
    val_data = data_eval.ReadData(transform=preprocess, data_dir=args.data_dir)
    loader = DataLoader(val_data,
                        batch_size=1,
                        shuffle=False,
                        num_workers=4,
                        pin_memory=False)

    # Running accumulators for the metrics reported at the end.
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    a_meter = AverageMeter()
    b_meter = AverageMeter()
    forward_seconds = 0

    with torch.no_grad():
        for step, batch in enumerate(loader):
            rgb = batch['image'].to(device)
            dep = batch['depth'].to(device)
            target = batch['label'].numpy()
            # Device copy of the label, used only for the TensorBoard image.
            target_on_device = batch['label'].long().to(device)

            # Only the forward pass is timed.
            t_start = time.time()
            logits = net(rgb, dep)
            t_end = time.time()
            forward_seconds += (t_end - t_start)

            # Predicted class = index of the highest-scoring channel; the
            # batch axis is kept.
            _, winners = torch.max(logits, 1)
            prediction = winners.cpu().numpy()

            sample_acc, pixel_count = accuracy(prediction, target)
            inter, uni = intersectionAndUnion(prediction, target,
                                              args.num_class)
            acc_meter.update(sample_acc, pixel_count)
            a_part, b_part = macc(prediction, target, args.num_class)
            intersection_meter.update(inter)
            union_meter.update(uni)
            a_meter.update(a_part)
            b_meter.update(b_part)

            if batch_idx_is_logged := (step % 50 == 0):
                rgb_grid = make_grid(rgb[:1].clone().cpu().data,
                                     1,
                                     normalize=True)
                tb_writer.add_image('image', rgb_grid, step)

                depth_grid = make_grid(dep[:1].clone().cpu().data,
                                       1,
                                       normalize=True)
                tb_writer.add_image('depth', depth_grid, step)

                predicted_colors = utils.color_label(torch.max(logits[:1], 1)[1])
                pred_grid = make_grid(predicted_colors,
                                      1,
                                      normalize=False,
                                      range=(0, 255))
                tb_writer.add_image('Predicted label', pred_grid, step)

                truth_colors = utils.color_label(target_on_device[:1])
                truth_grid = make_grid(truth_colors,
                                       1,
                                       normalize=False,
                                       range=(0, 255))
                tb_writer.add_image('Groundtruth label', truth_grid, step)

                timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                print('[{}] iter {}, accuracy: {}'.format(timestamp, step,
                                                          sample_acc))

    # Mean forward time per sample and the corresponding throughput.
    sample_total = len(val_data)
    print('推理时间:', forward_seconds / sample_total, '\nfps:',
          sample_total / forward_seconds)

    # The small epsilon keeps classes that never occur from dividing by zero.
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for class_idx, class_iou in enumerate(iou):
        print('class [{}], IoU: {}'.format(class_idx, class_iou))

    # mAcc: ratio of the two quantities accumulated from macc(), then averaged.
    mAcc = a_meter.average() / (b_meter.average() + 1e-10)
    print(mAcc.mean())
    print('[Eval Summary]:')
    summary = 'Mean IoU: {:.4}, Accuracy: {:.2f}%'
    print(summary.format(iou.mean(), acc_meter.average() * 100))
Пример #5
0
def evaluate(models, val_loader, interp, criterion, args):
    """Run multi-scale evaluation and print loss, per-class IoU, accuracy and timing.

    For every batch the input images are rescaled to each factor in
    ``args.scales``, passed through ``models``, resized by ``interp`` and
    averaged into a single prediction, which is then scored against the labels.

    Args:
        models: Callable network with an ``eval()`` method; its call returns a
            pair whose first element is the segmentation output.
        val_loader: Iterable yielding ``(images, labels, _)`` batches.
        interp: Callable applied to each per-scale output before averaging.
        criterion: Loss callable taking ``(prediction, labels)``.
        args: Options object; reads ``num_classes``, ``scales``, ``gpu_id`` and
            ``visualize``.

    Returns:
        None. Results are printed to stdout.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    time_meter = AverageMeter()

    models.eval()

    num_scales = len(args.scales)

    for i, batch_data in enumerate(val_loader):
        images, labels, _ = batch_data

        # Synchronize so the timer covers only this batch's GPU work.
        torch.cuda.synchronize()
        tic = time.perf_counter()

        # Evaluation needs no gradients: without no_grad() every accumulation
        # into pred_seg would keep an autograd graph alive and waste memory.
        with torch.no_grad():
            pred_seg = torch.zeros(images.size(0), args.num_classes,
                                   labels.size(1), labels.size(2))
            pred_seg = pred_seg.cuda(args.gpu_id, non_blocking=True)

            for scale in args.scales:
                # Rescale only the two trailing axes of the image batch.
                imgs_scale = zoom(images.numpy(), (1., 1., scale, scale),
                                  order=1,
                                  prefilter=False,
                                  mode='nearest')

                input_images = torch.from_numpy(imgs_scale)
                # NOTE(review): the other .cuda()/synchronize() calls in this
                # function are unconditional, so this guard alone does not
                # make the function usable without a GPU.
                if args.gpu_id is not None:
                    input_images = input_images.cuda(args.gpu_id,
                                                     non_blocking=True)

                pred_scale, _ = models(input_images)  # change
                pred_scale = interp(pred_scale)

                # Average the per-scale outputs.
                pred_seg = pred_seg + pred_scale / num_scales

            seg_labels = labels.cuda(args.gpu_id, non_blocking=True)
            loss = criterion(pred_seg, seg_labels)

        loss_value = loss.item()
        loss_meter.update(loss_value)
        print('[Eval] iter {}, loss: {}'.format(i, loss_value))

        labels = as_numpy(labels)
        # Predicted class = index of the highest-scoring channel.
        _, pred = torch.max(pred_seg, dim=1)
        pred = as_numpy(pred.squeeze(0).cpu())

        # Accuracy and intersection/union statistics.
        acc, pix = accuracy(pred, labels)
        intersection, union = intersectionAndUnion(pred, labels,
                                                   args.num_classes)
        acc_meter.update(acc, pix)
        intersection_meter.update(intersection)
        union_meter.update(union)

        torch.cuda.synchronize()
        time_meter.update(time.perf_counter() - tic)

        if args.visualize:
            visualize_result(batch_data, pred_seg, args)

    # Summary: per-class IoU from the accumulated sums; the epsilon avoids
    # division by zero for classes that never appear.
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [ {} ], IoU: {:.4f}'.format(i, _iou))

    print('[Eval Summary]:')
    print(
        'loss: {:.6f}, Mean IoU: {:.2f}, Accuracy: {:.2f}%, Inference Time: {:.4f}s'
        .format(loss_meter.average(),
                iou.mean() * 100,
                acc_meter.average() * 100, time_meter.average()))
Пример #6
0
def evaluate():
    """Evaluate an ACNet checkpoint on the FreiburgForest test data.

    Loads the weights named by ``args.last_ckpt``, runs the two-modality model
    over ``args.test_dir`` with gradients disabled, prints an accuracy line per
    sample and, depending on ``args.visualize`` / ``args.save_predictions``,
    visualizes the result or writes a colored prediction PNG to
    ``args.output_dir``. Ends by printing per-class IoU, the mean mAcc and a
    summary line. Relies on the module-level ``args`` and ``device``; returns
    nothing.
    """
    net = ACNet_models_V1.ACNet(num_class=5, pretrained=False)
    load_ckpt(net, None, None, args.last_ckpt, device)
    net.eval()
    net.to(device)

    preprocess = torchvision.transforms.Compose([
        ACNet_data.ScaleNorm(),
        ACNet_data.ToTensor(),
        ACNet_data.Normalize(),
    ])
    dataset = ACNet_data.FreiburgForest(
        transform=preprocess,
        data_dirs=[args.test_dir],
        modal1_name=args.modal1,
        modal2_name=args.modal2,
    )
    loader = DataLoader(dataset,
                        batch_size=1,
                        shuffle=False,
                        num_workers=1,
                        pin_memory=True)

    # Running accumulators for the metrics reported at the end.
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    a_meter = AverageMeter()
    b_meter = AverageMeter()

    with torch.no_grad():
        for step, batch in enumerate(loader):
            first_modal = batch['modal1'].to(device)
            second_modal = batch['modal2'].to(device)
            target = batch['label'].numpy()
            # The batch holds a single sample, so take its one basename.
            stem = batch['basename'][0]

            logits = net(first_modal, second_modal)

            # Highest-scoring channel index, shifted by one before comparing
            # with the labels (presumably labels are 1-based with 0 reserved
            # -- confirm against the dataset).
            prediction = (torch.argmax(logits, 1) + 1).squeeze(0).cpu().numpy()

            sample_acc, pixel_count = accuracy(prediction, target)
            inter, uni = intersectionAndUnion(prediction, target, args.num_class)
            acc_meter.update(sample_acc, pixel_count)
            a_part, b_part = macc(prediction, target, args.num_class)
            intersection_meter.update(inter)
            union_meter.update(uni)
            a_meter.update(a_part)
            b_meter.update(b_part)

            timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            print('[{}] iter {}, accuracy: {}'.format(timestamp, step, sample_acc))

            if args.visualize:
                visualize_result(first_modal, second_modal, target, prediction, step, args)

            if args.save_predictions:
                colored = utils.color_label_eval(prediction).astype(np.uint8)
                # Move the first axis to the end before writing the image.
                imageio.imwrite(f'{args.output_dir}/{stem}_pred.png', colored.transpose([1, 2, 0]))

    # The small epsilon keeps classes that never occur from dividing by zero.
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for class_idx, class_iou in enumerate(iou):
        print('class [{}], IoU: {}'.format(class_idx, class_iou))

    mAcc = a_meter.average() / (b_meter.average() + 1e-10)
    print(mAcc.mean())
    print('[Eval Summary]:')
    summary = 'Mean IoU: {:.4}, Accuracy: {:.2f}%'
    print(summary.format(iou.mean(), acc_meter.average() * 100))