def sssss(test_image_dir, test_gt_dir, model):
    """Evaluate ``model`` on a test split and return counting metrics.

    Builds an ``ImageDataset`` (``train=False``) over ``test_image_dir`` /
    ``test_gt_dir``. It then runs the model in eval mode under
    ``torch.no_grad()`` and compares per-image counts (the sum of each
    image's density map) between the ground truth and the model's first
    output.

    Reads the module-level ``args`` (``batch_size``, ``workers``, ``device``)
    and ``terminal_file`` globals. Progress text is printed both to stdout
    and to ``terminal_file``.

    Args:
        test_image_dir: Image directory passed to ``ImageDataset``.
        test_gt_dir: Ground-truth directory passed to ``ImageDataset``.
        model: Module called as ``model(img, gt_density_map)`` that returns a
            ``(predict_density_map, refine_density_map)`` pair. Only the
            first element is evaluated here.

    Returns:
        A tuple ``(mae, mse, gt_mean, predict_mean)``:

        - ``mae``: mean absolute count error per image.
        - ``mse``: the square root of the mean squared count error per
          image (i.e. an RMSE, despite the name).
        - ``gt_mean``: mean ground-truth count per image.
        - ``predict_mean``: mean predicted count per image.

    Raises:
        ZeroDivisionError: If the test dataset is empty.
    """
    print('begin test')
    print('begin test', file=terminal_file)
    testset = ImageDataset(
        img_dir=test_image_dir,
        gt_dir=test_gt_dir,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
    )
    test_loader = torch.utils.data.DataLoader(testset,
                                              shuffle=False,
                                              batch_size=args.batch_size,
                                              num_workers=args.workers)
    model.eval()

    def _counts(density_map):
        # One count per image: flatten every non-batch dimension and sum.
        # Summing the whole batch tensor instead would yield a batch total,
        # which makes MAE/MSE wrong whenever batch_size > 1.
        arr = density_map.detach().cpu().numpy()
        return arr.reshape(arr.shape[0], -1).sum(axis=1, dtype=np.float64)

    mae = 0.0
    mse = 0.0
    gt_sum = 0.0
    predict_sum = 0.0
    with torch.no_grad():
        for img, gt_density_map in test_loader:
            img = img.to(args.device)
            gt_density_map = gt_density_map.to(args.device)
            # NOTE(review): the model also receives the GT map at eval time;
            # presumably its forward signature requires it — confirm.
            predict_density_map, _refine_density_map = model(
                img, gt_density_map)

            gt_counts = _counts(gt_density_map)
            predict_counts = _counts(predict_density_map)
            errors = gt_counts - predict_counts
            mae += float(np.abs(errors).sum())
            mse += float(np.square(errors).sum())
            gt_sum += float(gt_counts.sum())
            predict_sum += float(predict_counts.sum())

    num_images = len(test_loader.dataset)
    mae = mae / num_images
    mse = np.sqrt(mse / num_images)
    gt_sum = gt_sum / num_images
    predict_sum = predict_sum / num_images

    return mae, mse, gt_sum, predict_sum
def train(train_image_dir, train_gt_dir, model, criterion, optimizer, epoch):
    """Run one training epoch and return counting metrics on the train split.

    Builds a shuffled ``ImageDataset`` (``train=True``) loader. For each batch
    it computes::

        loss1 = criterion(refine, gt)
        loss2 = criterion(refine, predict)
        loss  = args.alpha * loss1 + args.beta * loss2

    It then takes one optimizer step per batch. Running averages of
    ``loss``, ``loss1`` and ``loss2`` are printed every ``args.print_freq``
    iterations to ``terminal_file`` and stdout.

    Reads the module-level ``args`` (``batch_size``, ``workers``, ``device``,
    ``dataset``, ``lr``, ``alpha``, ``beta``, ``print_freq``) and
    ``terminal_file`` globals.

    Args:
        train_image_dir: Image directory passed to ``ImageDataset``.
        train_gt_dir: Ground-truth directory passed to ``ImageDataset``.
        model: Module called as ``model(img, gt_density_map)`` that returns a
            ``(predict_density_map, refine_density_map)`` pair.
        criterion: Loss function applied to pairs of density maps.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch index. Used only for logging.

    Returns:
        A tuple ``(mae, mse, gt_mean, predict_mean)`` computed from
        ``predict_density_map`` during the epoch:

        - ``mae``: mean absolute count error per image.
        - ``mse``: the square root of the mean squared count error per image
          (an RMSE, despite the name).
        - ``gt_mean``: mean ground-truth count per image.
        - ``predict_mean``: mean predicted count per image.

    Raises:
        ZeroDivisionError: If the training dataset is empty.
    """
    losses = AverageMeter()
    losses1 = AverageMeter()
    losses2 = AverageMeter()

    trainset = ImageDataset(
        img_dir=train_image_dir,
        gt_dir=train_gt_dir,
        train=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
    )
    train_loader = torch.utils.data.DataLoader(trainset,
                                               shuffle=True,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers)
    num_images = len(train_loader.dataset)
    header = ('epoch %d, processed %d samples, dataset %s, lr %.10f' %
              (epoch, epoch * num_images, args.dataset, args.lr))
    print(header)
    print(header, file=terminal_file)

    model.train()

    def _counts(density_map):
        # One count per image: flatten every non-batch dimension and sum.
        # Summing the whole batch tensor instead would yield a batch total,
        # which makes MAE/MSE wrong whenever batch_size > 1.
        arr = density_map.detach().cpu().numpy()
        return arr.reshape(arr.shape[0], -1).sum(axis=1, dtype=np.float64)

    train_mae = 0.0
    train_mse = 0.0
    train_gt_sum = 0.0
    train_predict_sum = 0.0
    for i, (img, gt_density_map) in enumerate(train_loader):
        img = img.to(args.device)
        gt_density_map = gt_density_map.to(args.device)
        # NOTE(review): an earlier comment claimed predict has shape [64, 2];
        # _counts() sums all non-batch dims, so it works for any rank — confirm.
        predict_density_map, refine_density_map = model(
            img, gt_density_map)

        loss1 = criterion(refine_density_map, gt_density_map)
        loss2 = criterion(refine_density_map, predict_density_map)
        loss = args.alpha * loss1 + args.beta * loss2

        batch_size = img.size(0)
        losses.update(loss.item(), batch_size)
        # Previously these meters were never updated, so the logged
        # Loss1/Loss2 values never reflected the actual losses.
        losses1.update(loss1.item(), batch_size)
        losses2.update(loss2.item(), batch_size)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % args.print_freq == 0:
            message = ('Epoch: [{0}][{1}/{2}]\t'
                       'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                       'Loss1 {loss1.val:.4f} ({loss1.avg:.4f})\t'
                       'Loss2 {loss2.val:.4f} ({loss2.avg:.4f})\t'.format(
                           epoch,
                           i,
                           len(train_loader),
                           loss=losses,
                           loss1=losses1,
                           loss2=losses2))
            print(message, file=terminal_file)
            print(message)

        train_gt_counts = _counts(gt_density_map)
        train_predict_counts = _counts(predict_density_map)
        errors = train_gt_counts - train_predict_counts
        train_mae += float(np.abs(errors).sum())
        train_mse += float(np.square(errors).sum())
        train_gt_sum += float(train_gt_counts.sum())
        train_predict_sum += float(train_predict_counts.sum())

    train_mae = train_mae / num_images
    train_mse = np.sqrt(train_mse / num_images)
    train_gt_sum = train_gt_sum / num_images
    train_predict_sum = train_predict_sum / num_images

    return train_mae, train_mse, train_gt_sum, train_predict_sum