Example #1
0
def cal_mae(img_root, model_param_path):
    '''
    Calculate the MAE of the test data.

    img_root: the root of test image data.
    model_param_path: the path of specific CSRNet parameters.
    '''
    # Fix: fall back to CPU when CUDA is unavailable instead of hard-coding
    # "cuda", and use the chosen device consistently below (the original
    # mixed torch.device(...) with bare .cuda() calls).
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = CSRNet()
    # map_location lets the checkpoint load on CPU-only machines as well.
    model.load_state_dict(torch.load(model_param_path, map_location=device))
    model.to(device)
    dataset = create_test_dataloader(img_root)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=False)
    model.eval()  # inference mode: fixed batch-norm stats, no dropout
    mae = 0
    with torch.no_grad():
        for i, data in enumerate(tqdm(dataloader)):
            image = data['image'].to(device)
            gt_densitymap = data['densitymap'].to(device)
            # forward propagation
            et_dmap = model(image)
            # A count is the integral (sum) of a density map; accumulate the
            # per-image absolute count error.
            mae += abs(et_dmap.data.sum() - gt_densitymap.data.sum()).item()
            del image, gt_densitymap, et_dmap
    print("model_param_path:" + model_param_path + " mae:" +
          str(mae / len(dataloader)))
Example #2
0
def count2(img_root, model_param_path):
    """Run CSRNet on the test set (CPU) and save each predicted density map.

    img_root: root of the test image data; renderings are written to
        ``<img_root>/test_data/result/<i>_dmp.jpg``.
    model_param_path: path to the trained CSRNet weights.
    """
    # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    device = torch.device("cpu")
    model = CSRNet()
    model.load_state_dict(torch.load(model_param_path))
    model.to(device)
    # Fix: switch to inference mode so batch-norm/dropout layers behave
    # deterministically during evaluation.
    model.eval()
    test_dataloader = create_test_dataloader(img_root)  # dataloader

    # Fix: disable autograd — the original kept a gradient graph alive for
    # every forward pass, wasting memory for no benefit.
    with torch.no_grad():
        for i, data in enumerate(tqdm(test_dataloader, ncols=50)):
            image = data['image'].to(device)
            et_dmp = model(image).detach()
            et_dmp = et_dmp.numpy()
            et_dmp = et_dmp[0][0]  # first (only) batch item, first channel
            # Predicted count is the integral (sum) of the density map.
            count = np.sum(et_dmp)
            plt.figure(i)
            plt.axis("off")
            plt.imshow(et_dmp, cmap=CM.jet)
            # Remove the axes ticks.
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            # Output image border settings.
            plt.subplots_adjust(top=1,
                                bottom=0,
                                left=0,
                                right=1,
                                hspace=0,
                                wspace=0)
            plt.margins(0, 0)
            plt.savefig(img_root + "/test_data/result/" + str(i + 1) + "_dmp" +
                        ".jpg")
            print(str(i + 1) + "_" + "renshu:", count)
Example #3
0
def cal_mae(img_root, gt_dmap_root, model_param_path):
    '''
    Calculate the MAE of the test data.
    img_root: the root of test image data.
        NOTE(review): currently ignored — the dataloader is built from
        cfg.dataset_root instead; confirm which path should win.
    gt_dmap_root: the root of test ground truth density-map data.
        NOTE(review): also unused in this variant.
    model_param_path: the path of specific mcnn parameters.
    '''
    model = CSRNet()
    # map_location lets the checkpoint load on whatever device cfg selects.
    model.load_state_dict(torch.load(model_param_path,
                                     map_location=cfg.device))
    model.to(cfg.device)
    test_dataloader = create_test_dataloader(cfg.dataset_root)  # dataloader
    model.eval()  # inference mode: fixed batch-norm stats, no dropout
    sum_mae = 0
    with torch.no_grad():
        for i, data in enumerate(tqdm(test_dataloader)):
            image = data['image'].to(cfg.device)
            gt_densitymap = data['densitymap'].to(cfg.device)
            # forward propagation
            et_densitymap = model(image).detach()
            # Per-image absolute error between predicted and true counts
            # (a count is the integral/sum of a density map).
            mae = abs(et_densitymap.data.sum() - gt_densitymap.data.sum())
            sum_mae += mae.item()
            # clear mem
            del i, data, image, gt_densitymap, et_densitymap
            torch.cuda.empty_cache()

    print("model_param_path:" + model_param_path + " mae:" +
          str(sum_mae / len(test_dataloader)))
Example #4
0
def one_count(img_path, model_param_path):
    """Predict the crowd count for a single image and save three renderings.

    Writes ``<num>_src.jpg`` (the input), ``<num>_et.jpg`` (the estimated
    density map) and ``<num>_overlap.jpg`` (a 20/80 blend of the two) under
    ``./data_test/test_data/result/<num>/``, then returns the predicted count.

    img_path: path to the input image; the file stem names the output folder.
    model_param_path: path to the trained CSRNet weights.
    """
    filename = img_path.split('/')[-1]
    filenum = filename.split('.')[0]
    save_path = './data_test/test_data/result/' + filenum
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    img_src_save_path = save_path + '/' + filenum + '_src' + '.jpg'
    img_et_save_path = img_src_save_path.replace('src', 'et')
    img_overlap_save_path = img_src_save_path.replace('src', 'overlap')

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = CSRNet()
    model.load_state_dict(torch.load(model_param_path))
    model.to(device)
    img_src = open_img(img_path)
    img = open_img(img_path)
    # NOTE(review): mean [0.5]*3 / std [0.225]*3 differs from the ImageNet
    # statistics used elsewhere in this file — confirm it matches training.
    img_trans = Compose([ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.225, 0.225, 0.225])])
    img = img_trans(img)
    # Add a batch dimension; Variable is the legacy autograd wrapper.
    img = Variable(torch.unsqueeze(img, dim=0).float(), requires_grad=False)
    # print('img_src size:', img.shape)
    et_dmap = model(img)
    et_dmap = et_dmap.detach().numpy()
    et_dmap = et_dmap[0][0]
    # Predicted count is the integral (sum) of the density map.
    people_num = np.sum(et_dmap)
    # print(et_dmap.shape)
    print(filenum + '_num:', '\t', people_num)

    # img_src
    plt.figure(0)
    plt.imshow(img_src)
    plt.axis('off')
    # Remove the axes ticks.
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    # Output image border settings.
    plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
    plt.margins(0, 0)
    plt.savefig(img_src_save_path, bbox_inches='tight', dpi=100, pad_inches=-0.04)
    #
    # # img_et
    plt.figure(1)
    plt.imshow(et_dmap, cmap=CM.jet)
    plt.axis('off')
    # Remove the axes ticks.
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    # Output image border settings.
    plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
    plt.margins(0, 0)
    plt.savefig(img_et_save_path, bbox_inches='tight', dpi=100, pad_inches=-0.04)
    #
    # Re-read the two saved renderings and blend them for the overlay image.
    img_src = cv2.imread(img_src_save_path)
    img_et = cv2.imread(img_et_save_path)
    img_et = cv2.resize(img_et, (img_src.shape[1], img_src.shape[0]))
    img_overlap = cv2.addWeighted(img_src, 0.2, img_et, 0.8, 0)
    cv2.imwrite(img_overlap_save_path, img_overlap)

    return people_num
Example #5
0
def estimate_density_map_no_gt(img_root, gt_dmap_root, model_param_path,
                               index):
    '''
    Show one estimated density-map.
    img_root: the root of test image data.
        NOTE(review): unused — the dataloader reads cfg.dataset_root.
    gt_dmap_root: the root of test ground truth density-map data.
        NOTE(review): unused; this "no ground truth" variant skips the gt map.
    model_param_path: the path of specific mcnn parameters.
    index: the order of the test image in test dataset.
        NOTE(review): unused — every image in the loader is exported.
    '''
    image_export_folder = 'export_images_extra'
    model = CSRNet()
    model.load_state_dict(torch.load(model_param_path,
                                     map_location=cfg.device))
    model.to(cfg.device)
    test_dataloader = create_test_extra_dataloader(
        cfg.dataset_root)  # dataloader
    model.eval()  # inference mode: fixed batch-norm stats, no dropout
    with torch.no_grad():
        for i, data in enumerate(tqdm(test_dataloader)):
            image = data['image'].to(cfg.device)
            # gt_densitymap = data['densitymap'].to(cfg.device)
            # forward propagation
            et_densitymap = model(image).detach()
            # Predicted count is the integral (sum) of the density map.
            pred_count = et_densitymap.data.sum().cpu()
            # actual_count = gt_densitymap.data.sum().cpu()
            # Placeholder count embedded in filenames; no ground truth exists.
            actual_count = 999
            et_densitymap = et_densitymap.squeeze(0).squeeze(0).cpu().numpy()
            # gt_densitymap = gt_densitymap.squeeze(0).squeeze(0).cpu().numpy()
            image = image[0].cpu()  # denormalize(image[0].cpu())
            print(et_densitymap.shape)
            # et is the estimated density
            # Output name: <idx>_<pred>_<actual>_etdm.png
            plt.imshow(et_densitymap, cmap=CM.jet)
            plt.savefig("{}/{}_{}_{}_{}".format(image_export_folder,
                                                str(i).zfill(3),
                                                str(int(pred_count)),
                                                str(int(actual_count)),
                                                'etdm.png'))
            # # gt is the ground truth density
            # plt.imshow(gt_densitymap, cmap=CM.jet)
            # plt.savefig("{}/{}_{}_{}_{}".format(image_export_folder,
            #                                     str(i).zfill(3),
            #                                     str(int(pred_count)),
            #                                     str(int(actual_count)), 'gtdm.png'))
            # image
            plt.imshow(image.permute(1, 2, 0))
            plt.savefig("{}/{}_{}_{}_{}".format(image_export_folder,
                                                str(i).zfill(3),
                                                str(int(pred_count)),
                                                str(int(actual_count)),
                                                'image.png'))

            # clear mem
            del i, data, image, et_densitymap, pred_count, actual_count
            torch.cuda.empty_cache()
Example #6
0
def count1(img_root, model_param_path):
    """Run CSRNet over the test set, formatting each image's predicted count."""
    use_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if use_gpu else "cpu")

    net = CSRNet()
    net.load_state_dict(torch.load(model_param_path))
    net.to(device)

    loader = create_test_dataloader(img_root)  # dataloader

    # Progress bar over the whole test split.
    for _, batch in enumerate(tqdm(loader, ncols=50)):
        frame = batch['image'].to(device)
        density = net(frame).detach()
        # Predicted count is the integral (sum) of the density map.
        count = str('%.2f' % (density[0].cpu().sum()))
Example #7
0
def main():
    """Predict and visualize the crowd count for one image (OneFlow).

    Loads weights from --modelPath, runs the image at --picPath through
    CSRNet, shows the predicted density map, the ground-truth density map
    from the HDF5 file at --picDensity, and the original image, printing
    both counts along the way.
    """
    transform = ST.Compose(
        [
            ST.ToNumpyForVal(),
            # ImageNet normalization statistics.
            ST.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    global args
    args = parser.parse_args()
    model = CSRNet()
    model = model.to("cuda")
    # checkpoint = flow.load('checkpoint/Shanghai_BestModelA/shanghaiA_bestmodel')
    checkpoint = flow.load(args.modelPath)
    model.load_state_dict(checkpoint)
    img = transform(Image.open(args.picPath).convert("RGB"))
    img = flow.Tensor(img)
    img = img.to("cuda")
    # unsqueeze adds the batch dimension; the count is the sum of the map.
    output = model(img.unsqueeze(0))
    print("Predicted Count : ", int(output.detach().to("cpu").sum().numpy()))
    temp = output.view(output.shape[2], output.shape[3])
    temp = temp.numpy()
    plt.title("Predicted Count")
    plt.imshow(temp, cmap=c.jet)
    plt.show()
    # Ground truth lives in the HDF5 file's "density" dataset.
    temp = h5py.File(args.picDensity, "r")
    temp_1 = np.asarray(temp["density"])
    plt.title("Original Count")
    plt.imshow(temp_1, cmap=c.jet)
    # NOTE(review): the "+ 1" looks like a rounding/bias tweak — confirm intent.
    print("Original Count : ", int(np.sum(temp_1)) + 1)
    plt.show()
    print("Original Image")
    plt.title("Original Image")
    plt.imshow(plt.imread(args.picPath))
    plt.show()
Example #8
0
    def __init__(self):
        """Set up training state: saver, optional TensorBoard writer, data
        loaders, model (with optional checkpoint resume), loss, optimizer
        and LR scheduler, all driven by the module-level ``opt`` config.
        """
        # Lower is better: tracks the best validation result seen so far.
        self.best_pred = 1e6

        # Define Saver
        self.saver = Saver(opt)
        self.saver.save_experiment_config()

        # visualize
        if opt.visualize:
            # vis_legend = ["Loss", "MAE"]
            # batch_plot = create_vis_plot(vis, 'Batch', 'Loss', 'batch loss', vis_legend[0:1])
            # val_plot = create_vis_plot(vis, 'Epoch', 'result', 'val result', vis_legend[1:2])
            # Define Tensorboard Summary
            self.summary = TensorboardSummary(self.saver.experiment_dir)
            self.writer = self.summary.create_summary()

        # Dataset dataloader
        self.train_dataset = SHTDataset(opt.train_dir, train=True)
        self.train_loader = DataLoader(self.train_dataset,
                                       num_workers=opt.workers,
                                       shuffle=True,
                                       batch_size=opt.batch_size)  # must be 1
        self.test_dataset = SHTDataset(opt.test_dir, train=False)
        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset, shuffle=False, batch_size=opt.batch_size
        )  # must be 1, because per image size is different

        torch.cuda.manual_seed(opt.seed)

        model = CSRNet()
        self.model = model.to(opt.device)
        if opt.resume:
            # Restore epoch counter, best score and weights from a checkpoint.
            if os.path.isfile(opt.pre):
                print("=> loading checkpoint '{}'".format(opt.pre))
                checkpoint = torch.load(opt.pre)
                opt.start_epoch = checkpoint['epoch']
                self.best_pred = checkpoint['best_pred']
                self.model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    opt.pre, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(opt.pre))
        if opt.use_mulgpu:
            # Wrap for multi-GPU; DataParallel must wrap the resumed weights.
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=opt.gpu_id)
        self.criterion = nn.MSELoss(reduction='mean').to(opt.device)
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         opt.lr,
                                         momentum=opt.momentum,
                                         weight_decay=opt.decay)
        # Define lr scheduler
        self.scheduler = lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=[round(opt.epochs * x) for x in opt.steps],
            gamma=opt.scales)
        # Align scheduler bookkeeping when resuming mid-training.
        self.scheduler.last_epoch = opt.start_epoch - 1
Example #9
0
def count3(img_root, model_param_path):
    """Predict density maps for the test set and log them to TensorBoard."""
    writer = SummaryWriter()
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    net = CSRNet()
    net.load_state_dict(torch.load(model_param_path))
    net.to(device)

    loader = create_test_dataloader(img_root)  # dataloader

    for idx, sample in enumerate(tqdm(loader, ncols=50)):
        img = sample['image'].to(device)
        dmap = net(img).detach()
        # Predicted count is the integral (sum) of the density map.
        count = str('%.2f' % (dmap[0].cpu().sum()))
        # Log the denormalized input and the max-normalized density map.
        writer.add_image(str(idx) + '/img:', denormalize(img[0].cpu()))
        writer.add_image(
            str(idx) + "/dmp_count:" + count,
            dmap[0] / torch.max(dmap[0]))
        print(str(idx + 1) + "_img count success")
Example #10
0
def cal_mae(img_root, gt_dmap_root, model_param_path):
    '''
    Calculate the MAE of the test data.
    img_root: the root of test image data (note: the loader actually reads
        cfg.dataset_root from the Config object, not this argument).
    gt_dmap_root: the root of test ground truth density-map data (unused).
    model_param_path: the path of specific mcnn parameters.
    '''
    cfg = Config()
    device = cfg.device

    net = CSRNet()
    net.load_state_dict(torch.load(model_param_path))  # GPU
    net.to(device)

    # Data location comes from the Config object rather than img_root.
    loader = create_test_dataloader(cfg.dataset_root)
    net.eval()

    total_err = 0
    with torch.no_grad():
        for _, batch in enumerate(tqdm(loader)):
            img = batch['image'].to(device)
            gt_dmap = batch['densitymap'].to(device)
            # forward propagation
            pred = net(img)
            # Absolute difference between predicted and true counts.
            total_err += abs(pred.data.sum() - gt_dmap.data.sum()).item()
            del img, gt_dmap, pred

    print("model_param_path:" + model_param_path + " mae:" +
          str(total_err / len(loader)))
Example #11
0
def main():
    """Evaluate CSRNet (OneFlow) on a ShanghaiTech test split and print MAE.

    --picSrc selects part_A_test or part_B_test under ./dataset/;
    --modelPath points to the trained weights. Ground-truth density maps are
    read from .h5 files mirroring each image path.
    """
    transform = ST.Compose([
        ST.ToNumpyForVal(),
        # ImageNet normalization statistics.
        ST.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    global args
    args = parser.parse_args()
    root = "./dataset/"
    # now generate the ShanghaiA's ground truth
    part_A_train = os.path.join(root, "part_A_final/train_data", "images")
    part_A_test = os.path.join(root, "part_A_final/test_data", "images")
    part_B_train = os.path.join(root, "part_B_final/train_data", "images")
    part_B_test = os.path.join(root, "part_B_final/test_data", "images")
    path_sets = []
    if args.picSrc == "part_A_test":
        path_sets = [part_A_test]
    elif args.picSrc == "part_B_test":
        path_sets = [part_B_test]
    img_paths = []
    for path in path_sets:
        for img_path in glob.glob(os.path.join(path, "*.jpg")):
            img_paths.append(img_path)
    model = CSRNet()
    model = model.to("cuda")
    checkpoint = flow.load(args.modelPath)
    model.load_state_dict(checkpoint)
    MAE = []
    for i in range(len(img_paths)):
        img = transform(Image.open(img_paths[i]).convert("RGB"))
        img = np.asarray(img).astype(np.float32)
        img = flow.Tensor(img, dtype=flow.float32, device="cuda")
        img = img.to("cuda")
        # Density map path mirrors the image: images/x.jpg -> ground_truth/x.h5
        gt_file = h5py.File(
            img_paths[i].replace(".jpg",
                                 ".h5").replace("images", "ground_truth"), "r")
        groundtruth = np.asarray(gt_file["density"])
        with flow.no_grad():
            output = model(img.unsqueeze(0))
        # Absolute error between predicted count (sum of predicted map) and
        # true count (sum of the ground-truth map).
        mae = abs(output.sum().numpy() - np.sum(groundtruth))
        MAE.append(mae)
    avg_MAE = sum(MAE) / len(MAE)
    print("test result: MAE:{:2f}".format(avg_MAE))
def main():
    """Train CSRNet (PyTorch): build train/val file lists from directories,
    optionally resume a checkpoint, then run the epoch loop with LR
    adjustment and best-MAE checkpointing.
    """
    global args, best_prec1
    # Lower is better; start with a sentinel "worst" MAE.
    best_prec1 = 1e6

    args = parser.parse_args()
    args.original_lr = 1e-7
    args.lr = 1e-7
    args.batch_size = 1
    args.momentum = 0.95
    args.decay = 5 * 1e-4
    args.start_epoch = 0
    args.epochs = 400
    args.steps = [-1, 1, 100, 150]
    args.scales = [1, 1, 1, 1]
    args.workers = 4
    args.seed = time.time()
    args.print_freq = 30

    # with open(args.train_json, 'r') as outfile:
    #     train_list = json.load(outfile)

    train_list = [
        os.path.join(args.train_path, i) for i in os.listdir(args.train_path)
    ]
    # with open(args.test_json, 'r') as outfile:
    #     val_list = json.load(outfile)

    print(train_list)
    # Fix: validation files live under args.test_path; the original joined
    # test-set filenames onto args.train_path, yielding non-existent paths.
    val_list = [
        os.path.join(args.test_path, j) for j in os.listdir(args.test_path)
    ]

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    torch.cuda.manual_seed(args.seed)

    model = CSRNet()

    model = model.to(device)

    # reduction='sum' is the modern spelling of the deprecated
    # size_average=False (identical behavior).
    criterion = nn.MSELoss(reduction='sum').to(device)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.decay)

    if args.pre:
        # Resume epoch counter, best score, weights and optimizer state.
        if os.path.isfile(args.pre):
            print("=> loading checkpoint '{}'".format(args.pre))
            checkpoint = torch.load(args.pre)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.pre, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pre))

    for epoch in range(args.start_epoch, args.epochs):

        adjust_learning_rate(optimizer, epoch)

        train(train_list, model, criterion, optimizer, epoch)
        prec1 = validate(val_list, model, criterion)

        # Checkpoint every epoch, flagging new best (lowest) MAE.
        is_best = prec1 < best_prec1
        best_prec1 = min(prec1, best_prec1)
        print(' * best MAE {mae:.3f} '.format(mae=best_prec1))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.pre,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.task)
Example #13
0
def main():
    """Train CSRNet with OneFlow: image lists come from JSON files given by
    --train_json/--test_json; runs 400 epochs with step LR adjustment and a
    checkpoint written every epoch.
    """
    global args, best_prec1
    # Lower is better; start with a sentinel "worst" MAE.
    best_prec1 = 1e6
    args = parser.parse_args()
    args.original_lr = 1e-7
    args.lr = 1e-7
    args.batch_size = 1
    args.momentum = 0.95
    args.decay = 5 * 1e-4
    args.start_epoch = 0
    args.epochs = 400
    args.steps = [-1, 1, 100, 150]
    args.scales = [1, 1, 1, 1]
    args.workers = 0
    args.seed = time.time()
    args.print_freq = 30
    with open(args.train_json, "r") as outfile:
        train_list = json.load(outfile)
    with open(args.test_json, "r") as outfile:
        val_list = json.load(outfile)

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    model = CSRNet()
    model = model.to("cuda")
    criterion = nn.MSELoss(reduction="sum").to("cuda")
    optimizer = flow.optim.SGD(model.parameters(),
                               args.lr,
                               momentum=args.momentum,
                               weight_decay=args.decay)
    if args.pre:
        # Resume epoch counter, best score, weights and optimizer state.
        if os.path.isfile(args.pre):
            print("=> loading checkpoint '{}'".format(args.pre))
            checkpoint = flow.load(args.pre)
            args.start_epoch = checkpoint["epoch"]
            best_prec1 = checkpoint["best_prec1"]
            model.load_state_dict(checkpoint["state_dict"])

            optimizer.load_state_dict(checkpoint["optimizer"])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.pre, checkpoint["epoch"]))
        else:
            print("=> no checkpoint found at '{}'".format(args.pre))

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        train(train_list, model, criterion, optimizer, epoch)
        prec1 = validate(val_list, model, criterion)
        # Checkpoint every epoch, flagging new best (lowest) MAE.
        is_best = prec1 < best_prec1
        best_prec1 = min(prec1, best_prec1)
        print(" * best MAE {mae:.3f} ".format(mae=best_prec1))
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": args.pre,
                "state_dict": model.state_dict(),
                "best_prec1": best_prec1,
            },
            is_best,
            str(epoch + 1),
            args.modelPath,
        )
Example #14
0
def count(path):
    """
    Evaluate the number of larvae present in the input.

    The input is either an image or a video. For an image the evaluation is
    done once over the image; for a video every 10th frame is evaluated and
    the per-frame results are averaged to produce the final count.

    :param path: a path to an image or a video (local file, or a URL that is
        downloaded into the ``videos`` directory first)
    :return: count
    """

    # Define the device(processor) type
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print('Current procesor is GPU')
    else:
        device = torch.device('cpu')
        print('Current procesor is CPU')

    # Define the model to use for calculations
    model = CSRNet()
    model.load_state_dict(torch.load('model_wgts.pth'))
    model.to(device)
    model.eval()

    # Load the image or video
    im_list = []
    try:
        im_list.append(Image.open(path))
    except OSError:
        if 'http' in path:
            wget.download(path, out='videos')
        # Fix: os.listdir returns bare file names; join with the directory so
        # OpenCV receives a valid path regardless of the working directory.
        cap = cv2.VideoCapture(os.path.join('videos', os.listdir('videos')[0]))
        frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fc = 0
        ret = True
        im_list = []
        while (fc < frameCount and ret):
            ret, im = cap.read()
            if fc % 10 == 0:
                # Swap OpenCV's BGR channel order to RGB for PIL.
                new_im = np.zeros_like(im)
                new_im[:, :, 0] = im[:, :, 2]
                new_im[:, :, 1] = im[:, :, 1]
                new_im[:, :, 2] = im[:, :, 0]
                im_list.append(Image.fromarray(new_im.astype('uint8'), 'RGB'))
            fc += 1

    # Disable gradients
    with torch.no_grad():
        # Prepare data for model (ImageNet normalization statistics)
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        transform_eval = T.Compose([
            T.Resize(255, interpolation=Image.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean, std)
        ])
        model_input = torch.stack([transform_eval(im) for im in im_list])
        # Fix: Tensor.to() is not in-place; the original discarded the result,
        # so the batch never actually moved to the selected device.
        model_input = model_input.to(device)

        results, densities = model(model_input)
        if len(results) > 1:
            results = results.mean()
        return results
Example #15
0
def main():
    """Train CSRNet from CSV-listed samples (--train_csv/--test_csv),
    checkpointing the best MAE and the full per-epoch MAE history.
    """

    global args, best_prec1

    # Lower is better; start with a sentinel "worst" MAE.
    best_prec1 = 1e6

    args = parser.parse_args()
    args.original_lr = 1e-5
    args.lr = 1e-5
    args.batch_size = 1
    args.momentum = 0.95
    args.decay = 5 * 1e-4
    args.start_epoch = 0
    args.epochs = 100
    # LR decays x0.1 at epochs 20, 40 and 60 (applied by adjust_learning_rate).
    args.steps = [-1, 20, 40, 60]
    args.scales = [1, 0.1, 0.1, 0.1]
    args.workers = 4
    args.seed = time.time()
    args.print_freq = 30
    # with open(args.train_json, 'r') as outfile:
    #     train_list = json.load(outfile)
    # with open(args.test_json, 'r') as outfile:
    #     val_list = json.load(outfile)

    csv_train_path = args.train_csv
    csv_test_path = args.test_csv

    # os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # torch.cuda.manual_seed(args.seed)

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = CSRNet()

    #summary(model, (3, 256, 256))

    model = model.to(device)

    # NOTE(review): size_average=False is deprecated; reduction='sum' is the
    # equivalent modern spelling.
    criterion = nn.MSELoss(size_average=False).to(device)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.decay)

    if args.pre:
        # Resume epoch counter, best score, weights and optimizer state.
        if os.path.isfile(args.pre):
            print("=> loading checkpoint '{}'".format(args.pre))
            checkpoint = torch.load(args.pre)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.pre, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pre))
    # Per-epoch validation MAE history, persisted in every checkpoint.
    precs = []
    for epoch in range(args.start_epoch, args.epochs):

        adjust_learning_rate(optimizer, epoch)

        train(csv_train_path, model, criterion, optimizer, epoch)
        prec1 = validate(csv_test_path, model, criterion)
        precs.append(prec1)
        is_best = prec1 < best_prec1
        best_prec1 = min(prec1, best_prec1)
        print(' * best MAE {mae:.3f} '.format(mae=best_prec1))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.pre,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
                'MAE_history': precs
            }, is_best, args.task)
Example #16
0
def main():
    """Train CSRNet with a fresh random 80/20 train/validation split drawn
    from the training pool at every epoch.
    """
    global args, best_prec1

    best_prec1 = 1e6

    args = parser.parse_args()
    print(args)
    args.original_lr = 1e-7
    args.lr = 1e-7
    # batch_size is left to its command-line value (the hard-coded 9 is off).
    args.momentum = 0.95
    args.decay = 5 * 1e-4
    args.start_epoch = 0
    args.epochs = 400
    args.steps = [-1, 1, 100, 150]
    args.scales = [1, 1, 1, 1]
    args.workers = 4
    args.seed = time.time()
    args.print_freq = 30

    train_list, test_list = getTrainAndTestListFromPath(args.train_path,
                                                        args.test_path)
    splitRatio = 0.8

    print('batch size is ', args.batch_size)
    print('cuda available? {}'.format(torch.cuda.is_available()))

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    model = CSRNet().to(device)

    criterion = nn.MSELoss(size_average=False).to(device)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.decay)

    if args.pre:
        # Resume epoch counter, best score, weights and optimizer state.
        if os.path.isfile(args.pre):
            print("=> loading checkpoint '{}'".format(args.pre))
            checkpoint = torch.load(args.pre)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.pre, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pre))

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # Re-draw the train/validation split from the training pool.
        subsetTrain, subsetValid = getTrainAndValidateList(train_list,
                                                          splitRatio)

        train(subsetTrain, model, criterion, optimizer, epoch, device)
        prec1 = validate(subsetValid, model, criterion, device)

        is_best = prec1 < best_prec1
        best_prec1 = min(prec1, best_prec1)
        print(' * best MAE {mae:.3f} '
              .format(mae=best_prec1))
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.pre,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.task)