def setup(self):
        """initial the datasets, model, loss and optimizer"""
        args = self.args
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            self.device_count = torch.cuda.device_count()
            # for conciseness, only the single-GPU version is released
            assert self.device_count == 1
            logging.info('using {} gpu'.format(self.device_count))
        else:
            raise Exception("GPU is not available")

        self.downsample_ratio = args.downsample_ratio
        self.datasets = {
            x: Crowd(os.path.join(args.data_dir, x), args.crop_size,
                     args.downsample_ratio, args.is_gray, x)
            for x in ['train', 'val']
        }
        self.dataloaders = {
            x: DataLoader(self.datasets[x],
                          collate_fn=(train_collate
                                      if x == 'train' else default_collate),
                          batch_size=(args.batch_size if x == 'train' else 1),
                          shuffle=(x == 'train'),
                          num_workers=args.num_workers * self.device_count,
                          pin_memory=(x == 'train'))
            for x in ['train', 'val']
        }
        self.model = vgg19()
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=args.lr,
                                    weight_decay=args.weight_decay)

        self.start_epoch = 0
        if args.resume:
            suf = args.resume.rsplit('.', 1)[-1]
            if suf == 'tar':
                checkpoint = torch.load(args.resume,
                                        map_location=self.device)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.optimizer.load_state_dict(
                    checkpoint['optimizer_state_dict'])
                self.start_epoch = checkpoint['epoch'] + 1
            elif suf == 'pth':
                self.model.load_state_dict(
                    torch.load(args.resume, map_location=self.device))

        self.post_prob = Post_Prob(args.sigma, args.crop_size,
                                   args.downsample_ratio,
                                   args.background_ratio, args.use_background,
                                   self.device)
        self.criterion = Bay_Loss(args.use_background, self.device)
        self.save_list = Save_Handle(max_num=args.max_model_num)
        self.best_mae = np.inf
        self.best_mse = np.inf
        self.best_mae_1 = np.inf
        self.best_mse_1 = np.inf
        self.best_count = 0
        self.best_count_1 = 0
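
The resume branch in setup() expects '.tar' checkpoints carrying the model state, the optimizer state, and the last finished epoch. A minimal sketch of the matching save side (save_checkpoint and its save_path default are illustrative names, not from the original):

def save_checkpoint(model, optimizer, epoch, save_path='ckpt.tar'):
    # keys mirror those read back in setup(); training resumes at epoch + 1
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, save_path)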
Example 2
def main(args):
    # restrict training to the selected GPU; after masking with
    # CUDA_VISIBLE_DEVICES, that GPU is always visible as cuda:0
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    cur_device = torch.device('cuda:0')
    if args.loss == 'bayes':
        root = '/home/datamining/Datasets/CrowdCounting/sha_bayes_512/'
        train_path = root + 'train/'
        test_path = root + 'test/'
    elif args.bn:
        root = '/home/datamining/Datasets/CrowdCounting/sha_512_a/'
        train_path = root + 'train/'
        test_path = root + 'test/'
    else:
        if args.dataset == 'sha':
            root = '/home/datamining/Datasets/CrowdCounting/shanghaitech/part_A_final/'
            train_path = root + 'train_data/images'
            test_path = root + 'test_data/images/'
        elif args.dataset == 'shb':
            root = '/home/datamining/Datasets/CrowdCounting/shb_1024_f15/'
            train_path = root + 'train/'
            test_path = root + 'test/'
        elif args.dataset == 'qnrf':
            root = '/home/datamining/Datasets/CrowdCounting/qnrf_1024_a/'
            train_path = root + 'train/'
            test_path = root + 'test/'
        else:
            raise ValueError('unknown dataset: {}'.format(args.dataset))

    downsample_ratio = args.downsample
    train_loader, test_loader, train_img_paths, test_img_paths = get_loader(
        train_path, test_path, downsample_ratio, args)

    model_dict = {
        'VGG16_13': M_CSRNet,
        'DefCcNet': DefCcNet,
        'Res50_back3': Res50,
        'InceptionV3': Inception3CC,
        'CAN': CANNet
    }
    model_name = args.model
    dataset_name = args.dataset
    net = model_dict[model_name](downsample=args.downsample,
                                 bn=args.bn > 0,
                                 objective=args.objective,
                                 sp=(args.sp > 0),
                                 se=(args.se > 0),
                                 NL=args.nl)
    net.cuda()
    if args.bn > 0:
        save_name = '{}_{}_bn{}_ps{}_{}'.format(model_name, dataset_name,
                                                str(int(args.bn)),
                                                str(args.crop_size),
                                                args.loss)
    else:
        save_name = '{}_d{}{}{}{}{}_{}_{}_cr{}_{}{}{}{}{}{}'.format(
            model_name, str(args.downsample), '_sp' if args.sp else '',
            '_se' if args.se else '',
            '_' + args.nl if args.nl != 'relu' else '',
            '_vp' if args.val_patch else '', dataset_name, args.crop_mode,
            str(args.crop_scale), args.loss, '_wu' if args.warm_up else '',
            '_cl' if args.curriculum == 'W' else '', '_v' +
            str(int(args.value_factor)) if args.value_factor != 1 else '',
            '_amp' + str(args.amp_k) if args.objective == 'dmp+amp' else '',
            '_bg' if args.use_bg else '')
    save_path = "/home/datamining/Models/CrowdCounting/" + save_name + ".pth"
    logger = get_logger('logs/' + save_name + '.txt')
    for k, v in args.__dict__.items():  # save args
        logger.info("{}: {}".format(k, v))
    if os.path.exists(save_path) and args.resume:
        net.load_state_dict(torch.load(save_path))
        print('{} loaded!'.format(save_path))

    value_factor = args.value_factor
    freq = 100  # print training status every freq iterations

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(net.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.decay)
    elif args.optimizer == 'SGD':
        # note: SGD tends not to converge on this task
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=args.lr,
                                    momentum=0.95,
                                    weight_decay=args.decay)
    else:
        raise ValueError('unknown optimizer: {}'.format(args.optimizer))

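    # loss setup: the Bayesian loss needs per-point posterior probabilities
    # (Post_Prob), while every other option trains against density maps with
    # MSE as the base term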
    if args.loss == 'bayes':
        bayes_criterion = Bay_Loss(True, cur_device)
        post_prob = Post_Prob(sigma=8.0,
                              c_size=args.crop_size,
                              stride=1,
                              background_ratio=0.15,
                              use_background=True,
                              device=cur_device)
    else:
        mse_criterion = nn.MSELoss().cuda()

    if args.scheduler == 'plt':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   mode='min',
                                                   factor=0.9,
                                                   patience=10,
                                                   verbose=True)
    elif args.scheduler == 'cos':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   T_max=50,
                                                   eta_min=0)
    elif args.scheduler == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.8)
    elif args.scheduler == 'exp':
        scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
    elif args.scheduler == 'cyclic' and args.optimizer == 'SGD':
        scheduler = lr_scheduler.CyclicLR(
            optimizer,
            base_lr=args.lr * 0.01,
            max_lr=args.lr,
            step_size_up=25,
        )
    elif args.scheduler == 'None':
        scheduler = None
    else:
        raise ValueError('unknown scheduler: {}'.format(args.scheduler))

    if args.val_patch:
        best_mae, best_rmse = val_patch(net, test_loader, value_factor)
    elif args.loss == 'bayes':
        best_mae, best_rmse = val_bayes(net, test_loader, value_factor)
    else:
        best_mae, best_rmse = val(net, test_loader, value_factor)
    if args.scheduler == 'plt':
        scheduler.step(best_mae)
    ssim_loss = pytorch_ssim.SSIM(window_size=11)
    for epoch in range(args.epochs):
        if args.crop_mode == 'curriculum':
            # every 20% of the epochs, rebuild the training set with an
            # updated curriculum progress
            if (epoch + 1) % (args.epochs // 5) == 0:
                print('change dataset')
                single_dataset = RawDataset(
                    train_img_paths, transform, args.crop_mode,
                    downsample_ratio, args.crop_scale,
                    (epoch + 1.0 + args.epochs // 5) / args.epochs)
                train_loader = torch.utils.data.DataLoader(single_dataset,
                                                           shuffle=True,
                                                           batch_size=1,
                                                           num_workers=8)

        train_loss = 0.0
        if args.loss == 'bayes':
            epoch_mae = AverageMeter()
            epoch_mse = AverageMeter()
        net.train()
        if args.warm_up and epoch < args.warm_up_steps:
            linear_warm_up_lr(optimizer, epoch, args.warm_up_steps, args.lr)
        for it, data in enumerate(train_loader):
            if args.loss == 'bayes':
                inputs, points, targets, st_sizes = data
                img = inputs.to(cur_device)
                st_sizes = st_sizes.to(cur_device)
                gd_count = np.array([len(p) for p in points], dtype=np.float32)
                points = [p.to(cur_device) for p in points]
                targets = [t.to(cur_device) for t in targets]
            else:
                img, target, _, amp_gt = data
                img = img.cuda()
                target = value_factor * target.float().unsqueeze(1).cuda()
                amp_gt = amp_gt.cuda()
            optimizer.zero_grad()

            if args.objective == 'dmp+amp':
                output, amp = net(img)
                output = output * amp
            else:
                output = net(img)

            # the pixel-wise MSE term is only defined for density-map targets,
            # so skip it when training with the Bayesian loss
            if args.loss != 'bayes':
                if args.curriculum == 'W':
                    # curriculum weighting: down-weight pixels whose prediction
                    # already exceeds the epoch-dependent threshold T
                    delta = (output - target)**2
                    k_w = 2e-3 * args.value_factor * args.downsample**2
                    b_w = 5e-3 * args.value_factor * args.downsample**2
                    T = torch.ones_like(target,
                                        dtype=torch.float32) * epoch * k_w + b_w
                    W = T / torch.max(T, output)
                    mse_loss = torch.mean(delta * W)
                else:
                    mse_loss = mse_criterion(output, target)

            if args.loss == 'mse+lc':
                loss = mse_loss + 1e2 * cal_lc_loss(output,
                                                    target) * args.downsample
            elif args.loss == 'ssim':
                loss = 1 - ssim_loss(output, target)
            elif args.loss == 'mse+ssim':
                loss = 100 * mse_loss + 1e-2 * (1 - ssim_loss(output, target))
            elif args.loss == 'mse+la':
                loss = mse_loss + cal_spatial_abstraction_loss(output, target)
            elif args.loss == 'la':
                loss = cal_spatial_abstraction_loss(output, target)
            elif args.loss == 'ms-ssim':
                # TODO: multi-scale SSIM loss is not implemented yet
                raise NotImplementedError('ms-ssim loss is not implemented')
            elif args.loss == 'adversial':
                # TODO: adversarial loss is not implemented yet
                raise NotImplementedError('adversial loss is not implemented')
            elif args.loss == 'bayes':
                prob_list = post_prob(points, st_sizes)
                loss = bayes_criterion(prob_list, targets, output)
            else:
                loss = mse_loss

            # add the binary cross-entropy loss for the attention map; the
            # amp_gt target is only unpacked for density-map losses, so the
            # term is skipped under the Bayesian loss
            if args.objective == 'dmp+amp' and args.loss != 'bayes':
                amp = amp.clamp(1e-6, 1 - 1e-6)  # avoid log(0)
                cross_entropy = -(amp_gt * torch.log(amp) +
                                  (1 - amp_gt) * torch.log(1 - amp))
                cross_entropy_loss = torch.mean(cross_entropy)
                loss = loss + cross_entropy_loss * args.amp_k

            loss.backward()
            optimizer.step()
            data_loss = loss.item()
            train_loss += data_loss
            if args.loss == 'bayes':
                N = inputs.size(0)
                pre_count = torch.sum(output.view(N, -1),
                                      dim=1).detach().cpu().numpy()
                res = pre_count - gd_count
                epoch_mse.update(np.mean(res * res), N)
                epoch_mae.update(np.mean(abs(res)), N)

            if args.loss != 'bayes' and it % freq == 0:
                print(
                    '[ep:{}], [it:{}], [loss:{:.8f}], [output:{:.2f}, target:{:.2f}]'
                    .format(epoch + 1, it, data_loss, output[0].sum().item(),
                            target[0].sum().item()))
        if args.val_patch:
            mae, rmse = val_patch(net, test_loader, value_factor)
        elif args.loss == 'bayes':
            mae, rmse = val_bayes(net, test_loader, value_factor)
        else:
            mae, rmse = val(net, test_loader, value_factor)
        if not (args.warm_up and epoch < args.warm_up_steps):
            if args.scheduler == 'plt':
                # the plateau scheduler should track the current validation MAE
                scheduler.step(mae)
            elif args.scheduler != 'None':
                scheduler.step()

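        # model selection: keep the checkpoint with the best combined
        # MAE + 0.1 * RMSE score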
        if mae + 0.1 * rmse < best_mae + 0.1 * best_rmse:
            best_mae, best_rmse = mae, rmse
            torch.save(net.state_dict(), save_path)

        if args.loss == 'bayes':
            logger.info(
                '{} Epoch {}/{} Loss:{:.8f}, MAE:{:.2f}, RMSE:{:.2f}, lr:{:.8f}, [CUR]:{mae:.1f}, {rmse:.1f}, [Best]:{b_mae:.1f}, {b_rmse:.1f}'
                .format(model_name,
                        epoch + 1,
                        args.epochs,
                        train_loss / len(train_loader),
                        epoch_mae.get_avg(),
                        np.sqrt(epoch_mse.get_avg()),
                        optimizer.param_groups[0]['lr'],
                        mae=mae,
                        rmse=rmse,
                        b_mae=best_mae,
                        b_rmse=best_rmse))
        else:
            logger.info(
                '{} Epoch {}/{} Loss:{:.8f}, lr:{:.8f}, [CUR]:{mae:.1f}, {rmse:.1f}, [Best]:{b_mae:.1f}, {b_rmse:.1f}'
                .format(model_name,
                        epoch + 1,
                        args.epochs,
                        train_loss / len(train_loader),
                        optimizer.param_groups[0]['lr'],
                        mae=mae,
                        rmse=rmse,
                        b_mae=best_mae,
                        b_rmse=best_rmse))
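
Both linear_warm_up_lr and AverageMeter are called above but not defined in this excerpt. Plausible implementations consistent with how they are used (assumptions, not the original code):

def linear_warm_up_lr(optimizer, epoch, warm_up_steps, base_lr):
    # ramp the learning rate linearly up to base_lr over the warm-up steps
    lr = base_lr * (epoch + 1) / warm_up_steps
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

class AverageMeter:
    # running average over (value, count) updates, matching the
    # update()/get_avg() calls in the bayes branch
    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n

    def get_avg(self):
        return self.sum / max(self.count, 1)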