def main():
    """Run the trained PSPNet over the VOC test split and save colorized masks."""
    # Build the network and restore the requested snapshot.
    net = PSPNet(num_classes=voc.num_classes).cuda()
    print('load model ' + args['snapshot'])
    snapshot_path = os.path.join(ckpt_path, args['exp_name'], args['snapshot'])
    net.load_state_dict(torch.load(snapshot_path))
    net.eval()

    # ImageNet channel statistics (must match training-time normalization).
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std),
    ])

    test_loader = DataLoader(
        voc.VOC('test', transform=val_input_transform),
        batch_size=1, num_workers=8, shuffle=False)

    out_dir = os.path.join(ckpt_path, args['exp_name'], 'test')
    check_mkdir(out_dir)

    total = len(test_loader)
    for vi, (names, img) in enumerate(test_loader):
        img_name = names[0]

        # volatile=True: legacy (pre-0.4 PyTorch) inference mode.
        img = Variable(img, volatile=True).cuda()
        output = net(img)

        # Per-pixel argmax over the class channel -> HxW integer label map.
        prediction = output.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()
        voc.colorize_mask(prediction).save(
            os.path.join(out_dir, img_name + '.png'))

        print('%d / %d' % (vi + 1, total))
# ===== Пример #2 (Example #2, score 0) =====
def main():
    """Evaluate a PSPNet snapshot on the VOC test set, writing PNG masks."""
    net = PSPNet(num_classes=voc.num_classes).cuda()
    print('load model ' + args['snapshot'])
    weights = torch.load(
        os.path.join(ckpt_path, args['exp_name'], args['snapshot']))
    net.load_state_dict(weights)
    net.eval()

    # Normalize with the ImageNet statistics used during training.
    normalize = standard_transforms.Normalize(
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    val_input_transform = standard_transforms.Compose(
        [standard_transforms.ToTensor(), normalize])

    test_set = voc.VOC('test', transform=val_input_transform)
    test_loader = DataLoader(test_set, batch_size=1, num_workers=8,
                             shuffle=False)

    save_dir = os.path.join(ckpt_path, args['exp_name'], 'test')
    check_mkdir(save_dir)

    n_batches = len(test_loader)
    for vi, data in enumerate(test_loader):
        names, img = data
        img_name = names[0]

        # Legacy inference mode (pre-0.4 PyTorch Variable API).
        img = Variable(img, volatile=True).cuda()
        logits = net(img)

        # Highest-scoring class per pixel -> HxW label map.
        mask = logits.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()
        colorized = voc.colorize_mask(mask)
        colorized.save(os.path.join(save_dir, img_name + '.png'))

        print('%d / %d' % (vi + 1, n_batches))
# ===== Пример #3 (Example #3, score 0) =====
def main():
    """Evaluate FCN8s on the VOC 'eval' split.

    Saves a colorized prediction PNG per image and prints both the
    `evaluate(...)` mean IoU and the running per-image IoU mean.
    """
    net = fcn8s.FCN8s(num_classes=voc.num_classes, pretrained=False).cuda()
    #net = deeplab_resnet.Res_Deeplab().cuda()
    #net = dl.Res_Deeplab().cuda()

    net.load_state_dict(torch.load(os.path.join(root, model, pth)))
    net.eval()

    # ImageNet channel statistics (must match training-time normalization).
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    test_set = voc.VOC('eval', transform=val_input_transform,
                       target_transform=target_transform)
    test_loader = DataLoader(test_set, batch_size=1, num_workers=8,
                             shuffle=False)

    predictions = []
    masks = []
    ious = np.array([])
    for vi, data in enumerate(test_loader):
        img_name, img, msk = data
        img_name = img_name[0]

        # Optional square-resize before/after the net (currently disabled).
        H, W = img.size()[2:]
        L = min(H, W)
        interp_before = nn.UpsamplingBilinear2d(size=(L, L))
        interp_after = nn.UpsamplingBilinear2d(size=(H, W))

        img = Variable(img, volatile=True).cuda()  # legacy inference mode
        msk = Variable(msk, volatile=True).cuda()
        masks.append(msk.data.squeeze_(0).cpu().numpy())

        #img = interp_before(img)
        output = net(img)
        #output = interp_after(output[3])

        # Per-pixel argmax over classes -> HxW label map.
        prediction = output.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()
        #prediction = output.data.max(1)[1].squeeze().cpu().numpy()
        ious = np.append(ious, get_iou(prediction, masks[-1]))

        predictions.append(prediction)
        # BUGFIX: the original source had a line-wrapped comment that glued
        # `prediction = voc.colorize_mask(prediction)` onto a commented-out
        # save() call, producing a syntax error; the intended statements are
        # restored here.
        prediction = voc.colorize_mask(prediction)
        prediction.save(os.path.join(root, 'segmented-images',
                                     model + '-' + img_name + '.png'))
        #if vi == 10:
        #    break
        print('%d / %d' % (vi + 1, len(test_loader)))
    results = evaluate(predictions, masks, voc.num_classes)
    print('mean iou = {}'.format(results[2]))
    print(ious.mean())
# ===== Пример #4 (Example #4, score 0) =====
def main():
    """Run Decoder34 (ResNet-34 backbone) on the test split, saving masks."""
    backbone = ResNet()
    # strict=False: the ImageNet checkpoint's classifier head is not part
    # of the backbone we are loading into.
    backbone.load_state_dict(torch.load('./weight/resnet34-333f7ec4.pth'),
                             strict=False)
    net = Decoder34(num_classes=13, backbone=backbone).cuda()
    print('load model ' + snapshot)  # BUGFIX: restored missing space in log
    net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name,
                                                snapshot)))
    net.eval()

    # ImageNet channel statistics (must match training-time normalization).
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])

    # BUGFIX: the original passed an undefined `input_transform` and then
    # iterated `train_set` instead of the `test_set` built here.
    # NOTE(review): `target_transform` is assumed to be module-level --
    # confirm it exists; it is forwarded unchanged.
    test_set = wp.Wp('test',
                     transform=val_input_transform,
                     target_transform=target_transform)
    test_loader = DataLoader(test_set,
                             batch_size=1,
                             num_workers=4,
                             shuffle=False)
    check_mkdir(os.path.join(ckpt_path, args['exp_name'], 'test'))

    for vi, data in enumerate(test_loader):
        img_name, img = data
        img_name = img_name[0]

        img = Variable(img, volatile=True).cuda()  # legacy inference mode
        output = net(img)

        # Per-pixel argmax over classes -> HxW label map.
        prediction = output.data.max(1)[1].squeeze_(1).squeeze_(
            0).cpu().numpy()
        prediction = voc.colorize_mask(prediction)
        prediction.save(
            os.path.join(ckpt_path, args['exp_name'], 'test',
                         img_name + '.png'))

        print('%d / %d' % (vi + 1, len(test_loader)))
# ===== Пример #5 (Example #5, score 0) =====
def validate(val_loader, net, criterion, optimizer, epoch, train_args,
             visualize):
    """Sliding-window validation over full images stitched from crops.

    Each loader item is one image's stack of overlapping crops ("slices")
    plus the placement info needed to reassemble per-slice predictions onto
    a full-size canvas. Unconditionally overwrites train_args['best_record'],
    saves net/optimizer snapshots, logs metrics and a gt/prediction image
    grid to the writer, and returns the average validation loss.
    """
    # the following code is written assuming that batch size is 1
    net.eval()

    val_loss = AverageMeter()

    # One (H, 2W) canvas per validation image, H = args['shorter_size'];
    # images are assumed twice as wide as they are tall.
    gts_all = np.zeros(
        (len(val_loader), args['shorter_size'], 2 * args['shorter_size']),
        dtype=float)
    predictions_all = np.zeros(
        (len(val_loader), args['shorter_size'], 2 * args['shorter_size']),
        dtype=int)
    for vi, data in enumerate(val_loader):
        input, gt, slices_info = data
        # input: (1, n_slices, C, h, w); gt: (1, n_slices, h, w);
        # slices_info: (1, n_slices, 6) -- assumed from the asserts below;
        # TODO confirm against the dataset implementation.
        assert len(input.size()) == 5 and len(gt.size()) == 4 and len(
            slices_info.size()) == 3
        # Put the slice dimension first so we can iterate slice-by-slice.
        input.transpose_(0, 1)
        gt.transpose_(0, 1)
        slices_info.squeeze_(0)
        assert input.size()[3:] == gt.size()[2:]

        # Per-pixel hit count and summed class scores, used to average the
        # predictions of overlapping slices.
        count = torch.zeros(args['shorter_size'],
                            2 * args['shorter_size']).cuda()
        output = torch.zeros(voc.num_classes, args['shorter_size'],
                             2 * args['shorter_size']).cuda()

        slice_batch_pixel_size = input.size(1) * input.size(3) * input.size(4)

        for input_slice, gt_slice, info in zip(input, gt, slices_info):
            input_slice = Variable(input_slice).cuda()
            gt_slice = Variable(gt_slice).cuda()

            output_slice = net(input_slice)
            assert output_slice.size()[2:] == gt_slice.size()[1:]
            assert output_slice.size()[1] == voc.num_classes
            # print('size: ', output[:, info[0]: info[1], info[2]: info[3]].size(), output_slice[0, :, :info[4], :info[5]].data.size())
            # info = (y0, y1, x0, x1, valid_h, valid_w): paste the valid
            # region of this slice onto the canvas -- presumed layout,
            # verify against the loader.
            output[:, info[0]:info[1], info[2]:info[3]] += output_slice[
                0, :, :info[4], :info[5]].data
            gts_all[vi, info[0]:info[1], info[2]:info[3]] += gt_slice[
                0, :info[4], :info[5]].data.cpu().numpy()

            count[info[0]:info[1], info[2]:info[3]] += 1

            val_loss.update(
                criterion(output_slice, gt_slice).item(),
                slice_batch_pixel_size)

        # Average overlapping contributions. Ground-truth labels are
        # averaged too, which only round-trips cleanly where overlapping
        # slices agree on the label.
        output /= count
        # print(count.cpu().numpy())
        gts_all[vi, :, :] /= count.cpu().numpy().astype(float)
        predictions_all[vi, :, :] = output.max(0)[1].squeeze_(
            0).cpu().numpy().astype(int)

        print('validating: %d / %d' % (vi + 1, len(val_loader)))

    acc, acc_cls, mean_iu, fwavacc = evaluate(predictions_all, gts_all,
                                              voc.num_classes)

    # NOTE(review): best_record is overwritten unconditionally (no
    # comparison with the previous best) -- confirm this is intended.
    train_args['best_record']['val_loss'] = val_loss.avg
    train_args['best_record']['epoch'] = epoch
    train_args['best_record']['acc'] = acc
    train_args['best_record']['acc_cls'] = acc_cls
    train_args['best_record']['mean_iu'] = mean_iu
    train_args['best_record']['fwavacc'] = fwavacc
    snapshot_name = 'epoch_%d_loss_%.5f_acc_%.5f_acc-cls_%.5f_mean-iu_%.5f_fwavacc_%.5f_lr_%.10f' % (
        epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc,
        optimizer.param_groups[1]['lr'])
    # Snapshot both network and optimizer state under a metrics-encoded name.
    torch.save(net.state_dict(),
               os.path.join(ckpt_path, exp_name, snapshot_name + '.pth'))
    torch.save(
        optimizer.state_dict(),
        os.path.join(ckpt_path, exp_name, 'opt_' + snapshot_name + '.pth'))
    print("saving model: ",
          os.path.join(ckpt_path, exp_name, snapshot_name + '.pth'))

    if train_args['val_save_to_img_file']:
        to_save_dir = os.path.join(ckpt_path, exp_name, str(epoch))
        check_mkdir(to_save_dir)

    # Build a (gt, prediction) image grid for TensorBoard; optionally also
    # save the individual PNGs to disk.
    val_visual = []
    for idx, data in enumerate(zip(gts_all, predictions_all)):
        gt_pil = voc.colorize_mask(data[0])
        predictions_pil = voc.colorize_mask(data[1])
        if train_args['val_save_to_img_file']:
            predictions_pil.save(
                os.path.join(to_save_dir, '%d_prediction.png' % idx))
            gt_pil.save(os.path.join(to_save_dir, '%d_gt.png' % idx))
        val_visual.extend([
            visualize(gt_pil.convert('RGB')),
            visualize(predictions_pil.convert('RGB'))
        ])
    val_visual = torch.stack(val_visual, 0)
    val_visual = vutils.make_grid(val_visual, nrow=2, padding=5)
    writer.add_image(snapshot_name, val_visual)

    print(
        '-----------------------------------------------------------------------------------------------------------'
    )
    print(
        '[epoch %d], [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f]'
        % (epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc))

    print(
        'best record: [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f], [epoch %d]'
        % (train_args['best_record']['val_loss'],
           train_args['best_record']['acc'],
           train_args['best_record']['acc_cls'],
           train_args['best_record']['mean_iu'],
           train_args['best_record']['fwavacc'],
           train_args['best_record']['epoch']))

    print(
        '-----------------------------------------------------------------------------------------------------------'
    )

    writer.add_scalar('val_loss', val_loss.avg, epoch)
    writer.add_scalar('acc', acc, epoch)
    writer.add_scalar('acc_cls', acc_cls, epoch)
    writer.add_scalar('mean_iu', mean_iu, epoch)
    writer.add_scalar('fwavacc', fwavacc, epoch)

    # Restore training mode for the caller.
    net.train()
    return val_loss.avg
def validate(val_loader, net, criterion, optimizer, epoch, train_args, visualize):
    """Slice-stitching validation (compact variant of the function above).

    Reassembles overlapping per-slice predictions onto full-size canvases,
    unconditionally overwrites train_args['best_record'], saves snapshots,
    logs metrics/visuals to the writer, and returns the average loss.
    Uses the legacy `.data[0]` scalar access (pre-0.4 PyTorch).
    """
    # the following code is written assuming that batch size is 1
    net.eval()

    val_loss = AverageMeter()

    # NOTE(review): gts_all is dtype=int here (the sibling variant uses
    # float); the in-place '/=' further down true-divides into an int
    # array, which raises a casting error on modern NumPy -- confirm the
    # intended dtype.
    gts_all = np.zeros((len(val_loader), args['shorter_size'], 2 * args['shorter_size']), dtype=int)
    predictions_all = np.zeros((len(val_loader), args['shorter_size'], 2 * args['shorter_size']), dtype=int)
    for vi, data in enumerate(val_loader):
        input, gt, slices_info = data
        # input: (1, n_slices, C, h, w); gt: (1, n_slices, h, w);
        # slices_info: (1, n_slices, 6) -- TODO confirm against loader.
        assert len(input.size()) == 5 and len(gt.size()) == 4 and len(slices_info.size()) == 3
        # Put the slice dimension first so we can iterate slice-by-slice.
        input.transpose_(0, 1)
        gt.transpose_(0, 1)
        slices_info.squeeze_(0)
        assert input.size()[3:] == gt.size()[2:]

        # Per-pixel hit count and summed class scores for averaging overlaps.
        count = torch.zeros(args['shorter_size'], 2 * args['shorter_size']).cuda()
        output = torch.zeros(voc.num_classes, args['shorter_size'], 2 * args['shorter_size']).cuda()

        slice_batch_pixel_size = input.size(1) * input.size(3) * input.size(4)

        for input_slice, gt_slice, info in zip(input, gt, slices_info):
            input_slice = Variable(input_slice).cuda()
            gt_slice = Variable(gt_slice).cuda()

            output_slice = net(input_slice)
            assert output_slice.size()[2:] == gt_slice.size()[1:]
            assert output_slice.size()[1] == voc.num_classes
            # Paste the valid region of this slice onto the canvas;
            # info presumed to be (y0, y1, x0, x1, valid_h, valid_w).
            output[:, info[0]: info[1], info[2]: info[3]] += output_slice[0, :, :info[4], :info[5]].data
            gts_all[vi, info[0]: info[1], info[2]: info[3]] += gt_slice[0, :info[4], :info[5]].data.cpu().numpy()

            count[info[0]: info[1], info[2]: info[3]] += 1

            # .data[0]: legacy scalar extraction (use .item() on PyTorch >= 0.4).
            val_loss.update(criterion(output_slice, gt_slice).data[0], slice_batch_pixel_size)

        # Average overlapping contributions (gt labels are averaged too).
        output /= count
        gts_all[vi, :, :] /= count.cpu().numpy().astype(int)
        predictions_all[vi, :, :] = output.max(0)[1].squeeze_(0).cpu().numpy()

        print('validating: %d / %d' % (vi + 1, len(val_loader)))

    acc, acc_cls, mean_iu, fwavacc = evaluate(predictions_all, gts_all, voc.num_classes)

    # NOTE(review): best_record is overwritten unconditionally -- confirm.
    train_args['best_record']['val_loss'] = val_loss.avg
    train_args['best_record']['epoch'] = epoch
    train_args['best_record']['acc'] = acc
    train_args['best_record']['acc_cls'] = acc_cls
    train_args['best_record']['mean_iu'] = mean_iu
    train_args['best_record']['fwavacc'] = fwavacc
    snapshot_name = 'epoch_%d_loss_%.5f_acc_%.5f_acc-cls_%.5f_mean-iu_%.5f_fwavacc_%.5f_lr_%.10f' % (
        epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc, optimizer.param_groups[1]['lr'])
    # Snapshot both network and optimizer state under a metrics-encoded name.
    torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, snapshot_name + '.pth'))
    torch.save(optimizer.state_dict(), os.path.join(ckpt_path, exp_name, 'opt_' + snapshot_name + '.pth'))

    if train_args['val_save_to_img_file']:
        to_save_dir = os.path.join(ckpt_path, exp_name, str(epoch))
        check_mkdir(to_save_dir)

    # Build a (gt, prediction) grid for TensorBoard; optionally save PNGs.
    val_visual = []
    for idx, data in enumerate(zip(gts_all, predictions_all)):
        gt_pil = voc.colorize_mask(data[0])
        predictions_pil = voc.colorize_mask(data[1])
        if train_args['val_save_to_img_file']:
            predictions_pil.save(os.path.join(to_save_dir, '%d_prediction.png' % idx))
            gt_pil.save(os.path.join(to_save_dir, '%d_gt.png' % idx))
        val_visual.extend([visualize(gt_pil.convert('RGB')),
                           visualize(predictions_pil.convert('RGB'))])
    val_visual = torch.stack(val_visual, 0)
    val_visual = vutils.make_grid(val_visual, nrow=2, padding=5)
    writer.add_image(snapshot_name, val_visual)

    print('-----------------------------------------------------------------------------------------------------------')
    print('[epoch %d], [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f]' % (
        epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc))

    print('best record: [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f], [epoch %d]' % (
        train_args['best_record']['val_loss'], train_args['best_record']['acc'], train_args['best_record']['acc_cls'],
        train_args['best_record']['mean_iu'], train_args['best_record']['fwavacc'], train_args['best_record']['epoch']))

    print('-----------------------------------------------------------------------------------------------------------')

    writer.add_scalar('val_loss', val_loss.avg, epoch)
    writer.add_scalar('acc', acc, epoch)
    writer.add_scalar('acc_cls', acc_cls, epoch)
    writer.add_scalar('mean_iu', mean_iu, epoch)
    writer.add_scalar('fwavacc', fwavacc, epoch)

    # Restore training mode for the caller.
    net.train()
    return val_loss.avg
def validate(val_loader, net, criterion, optimizer, epoch, train_args, restore,
             visualize):
    """Per-image validation with best-record tracking and visualization.

    Runs the net over the validation loader, computes evaluation metrics,
    and only when mean IoU improves updates train_args['best_record'] and
    builds an (input, gt, prediction) image grid for the writer (checkpoint
    saving is commented out in this variant). Returns the average loss.
    Uses legacy pre-0.4 PyTorch idioms (Variable, volatile, .data[0]).
    """
    net.eval()

    val_loss = AverageMeter()
    inputs_all, gts_all, predictions_all = [], [], []

    for vi, data in enumerate(val_loader):
        inputs, gts = data
        N = inputs.size(0)
        inputs = Variable(inputs, volatile=True).cuda()
        gts = Variable(gts, volatile=True).cuda()

        outputs = net(inputs)
        # Per-pixel argmax over classes; squeeze batch dim (batch size 1
        # assumed for the squeeze_(0) to yield an HxW map).
        predictions = outputs.data.max(1)[1].squeeze_(1).squeeze_(
            0).cpu().numpy()

        # .data[0]: legacy scalar extraction (use .item() on >= 0.4).
        val_loss.update(criterion(outputs, gts).data[0] / N, N)

        # Randomly subsample which inputs are kept for visualization;
        # gts/predictions are always kept for metric computation.
        if random.random() > train_args['val_img_sample_rate']:
            inputs_all.append(None)
        else:
            inputs_all.append(inputs.data.squeeze_(0).cpu())
        gts_all.append(gts.data.squeeze_(0).cpu().numpy())
        predictions_all.append(predictions)

    acc, acc_cls, mean_iu, fwavacc = evaluate(predictions_all, gts_all,
                                              voc.num_classes)

    # Only record/visualize when this epoch beats the best mean IoU so far.
    if mean_iu > train_args['best_record']['mean_iu']:
        train_args['best_record']['val_loss'] = val_loss.avg
        train_args['best_record']['epoch'] = epoch
        train_args['best_record']['acc'] = acc
        train_args['best_record']['acc_cls'] = acc_cls
        train_args['best_record']['mean_iu'] = mean_iu
        train_args['best_record']['fwavacc'] = fwavacc
        snapshot_name = 'epoch_%d_loss_%.5f_acc_%.5f_acc-cls_%.5f_mean-iu_%.5f_fwavacc_%.5f_lr_%.10f' % (
            epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc,
            optimizer.param_groups[1]['lr'])
        # torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, snapshot_name + '.pth'))
        # torch.save(optimizer.state_dict(), os.path.join(ckpt_path, exp_name, 'opt_' + snapshot_name + '.pth'))

        if train_args['val_save_to_img_file']:
            to_save_dir = os.path.join(ckpt_path, exp_name, str(epoch))
            check_mkdir(to_save_dir)

        # Build an (input, gt, prediction) triplet grid; `restore`
        # de-normalizes the input tensor back to a PIL image.
        val_visual = []
        for idx, data in enumerate(zip(inputs_all, gts_all, predictions_all)):
            if data[0] is None:
                # Input was not sampled for visualization above.
                continue
            input_pil = restore(data[0])
            gt_pil = voc.colorize_mask(data[1])
            predictions_pil = voc.colorize_mask(data[2])
            if train_args['val_save_to_img_file']:
                input_pil.save(os.path.join(to_save_dir, '%d_input.png' % idx))
                predictions_pil.save(
                    os.path.join(to_save_dir, '%d_prediction.png' % idx))
                gt_pil.save(os.path.join(to_save_dir, '%d_gt.png' % idx))
            val_visual.extend([
                visualize(input_pil.convert('RGB')),
                visualize(gt_pil.convert('RGB')),
                visualize(predictions_pil.convert('RGB'))
            ])
        # NOTE(review): torch.stack raises on an empty list -- if no input
        # is sampled this epoch this will fail; confirm sample rate > 0.
        val_visual = torch.stack(val_visual, 0)
        val_visual = vutils.make_grid(val_visual, nrow=3, padding=5)
        writer.add_image(snapshot_name, val_visual)

    print(
        '--------------------------------------------------------------------')
    print(
        '[epoch %d], [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f]'
        % (epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc))

    print(
        'best record: [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f], [epoch %d]'
        % (train_args['best_record']['val_loss'],
           train_args['best_record']['acc'],
           train_args['best_record']['acc_cls'],
           train_args['best_record']['mean_iu'],
           train_args['best_record']['fwavacc'],
           train_args['best_record']['epoch']))

    print(
        '--------------------------------------------------------------------')

    # optimizer.param_groups[1]: assumes at least two param groups exist.
    writer.add_scalar('val_loss', val_loss.avg, epoch)
    writer.add_scalar('acc', acc, epoch)
    writer.add_scalar('acc_cls', acc_cls, epoch)
    writer.add_scalar('mean_iu', mean_iu, epoch)
    writer.add_scalar('fwavacc', fwavacc, epoch)
    writer.add_scalar('lr', optimizer.param_groups[1]['lr'], epoch)

    # Restore training mode for the caller.
    net.train()
    return val_loss.avg
def validate(val_loader, net, criterion, optimizer, epoch, train_args, restore, visualize):
    """Per-image validation with best-record tracking and checkpointing.

    Same structure as the variant above, but this one actually saves
    net/optimizer snapshots when mean IoU improves. Returns the average
    validation loss. Uses legacy pre-0.4 PyTorch idioms (Variable,
    volatile, .data[0]).
    """
    net.eval()

    val_loss = AverageMeter()
    inputs_all, gts_all, predictions_all = [], [], []

    for vi, data in enumerate(val_loader):
        inputs, gts = data
        N = inputs.size(0)
        inputs = Variable(inputs, volatile=True).cuda()
        gts = Variable(gts, volatile=True).cuda()

        outputs = net(inputs)
        # Per-pixel argmax over classes; batch size 1 assumed for the
        # squeeze_(0) to yield an HxW map.
        predictions = outputs.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()

        # .data[0]: legacy scalar extraction (use .item() on >= 0.4).
        val_loss.update(criterion(outputs, gts).data[0] / N, N)

        # Randomly subsample inputs kept for visualization; gts/predictions
        # are always kept for metric computation.
        if random.random() > train_args['val_img_sample_rate']:
            inputs_all.append(None)
        else:
            inputs_all.append(inputs.data.squeeze_(0).cpu())
        gts_all.append(gts.data.squeeze_(0).cpu().numpy())
        predictions_all.append(predictions)

    acc, acc_cls, mean_iu, fwavacc = evaluate(predictions_all, gts_all, voc.num_classes)

    # Only checkpoint/record/visualize when mean IoU improves.
    if mean_iu > train_args['best_record']['mean_iu']:
        train_args['best_record']['val_loss'] = val_loss.avg
        train_args['best_record']['epoch'] = epoch
        train_args['best_record']['acc'] = acc
        train_args['best_record']['acc_cls'] = acc_cls
        train_args['best_record']['mean_iu'] = mean_iu
        train_args['best_record']['fwavacc'] = fwavacc
        snapshot_name = 'epoch_%d_loss_%.5f_acc_%.5f_acc-cls_%.5f_mean-iu_%.5f_fwavacc_%.5f_lr_%.10f' % (
            epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc, optimizer.param_groups[1]['lr']
        )
        # Snapshot both network and optimizer state under a metrics-encoded name.
        torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, snapshot_name + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(ckpt_path, exp_name, 'opt_' + snapshot_name + '.pth'))

        if train_args['val_save_to_img_file']:
            to_save_dir = os.path.join(ckpt_path, exp_name, str(epoch))
            check_mkdir(to_save_dir)

        # Build an (input, gt, prediction) triplet grid; `restore`
        # de-normalizes the input tensor back to a PIL image.
        val_visual = []
        for idx, data in enumerate(zip(inputs_all, gts_all, predictions_all)):
            if data[0] is None:
                # Input was not sampled for visualization above.
                continue
            input_pil = restore(data[0])
            gt_pil = voc.colorize_mask(data[1])
            predictions_pil = voc.colorize_mask(data[2])
            if train_args['val_save_to_img_file']:
                input_pil.save(os.path.join(to_save_dir, '%d_input.png' % idx))
                predictions_pil.save(os.path.join(to_save_dir, '%d_prediction.png' % idx))
                gt_pil.save(os.path.join(to_save_dir, '%d_gt.png' % idx))
            val_visual.extend([visualize(input_pil.convert('RGB')), visualize(gt_pil.convert('RGB')),
                               visualize(predictions_pil.convert('RGB'))])
        # NOTE(review): torch.stack raises on an empty list -- if no input
        # is sampled this epoch this will fail; confirm sample rate > 0.
        val_visual = torch.stack(val_visual, 0)
        val_visual = vutils.make_grid(val_visual, nrow=3, padding=5)
        writer.add_image(snapshot_name, val_visual)

    print('--------------------------------------------------------------------')
    print('[epoch %d], [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f]' % (
        epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc))

    print('best record: [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f], [epoch %d]' % (
        train_args['best_record']['val_loss'], train_args['best_record']['acc'], train_args['best_record']['acc_cls'],
        train_args['best_record']['mean_iu'], train_args['best_record']['fwavacc'], train_args['best_record']['epoch']))

    print('--------------------------------------------------------------------')

    # optimizer.param_groups[1]: assumes at least two param groups exist.
    writer.add_scalar('val_loss', val_loss.avg, epoch)
    writer.add_scalar('acc', acc, epoch)
    writer.add_scalar('acc_cls', acc_cls, epoch)
    writer.add_scalar('mean_iu', mean_iu, epoch)
    writer.add_scalar('fwavacc', fwavacc, epoch)
    writer.add_scalar('lr', optimizer.param_groups[1]['lr'], epoch)

    # Restore training mode for the caller.
    net.train()
    return val_loss.avg