Exemplo n.º 1
0
def joint_classification(args):
    """Train the joint classifier on A2D and snapshot the model every epoch.

    Builds train/val loaders with ImageNet-style normalization (dataset-specific
    mean/std), trains with a weighted BCE loss, logs per-epoch summed losses,
    saves a loss curve to ./save/joint_classification/imgs/ and a checkpoint to
    ./save/joint_classification/snap/ after each epoch.

    Args:
        args: Namespace with at least input_size, batch_size, num_workers,
            cuda, and max_epoches. Also consumed by A2DClassification,
            getJointClassifier, bce_weight_loss and get_finetune_optimizer.
    """
    train_transform = transforms.Compose([
        # transforms.RandomResizedCrop((args.input_size, args.input_size)),
        transforms.Resize((args.input_size, args.input_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.4569, 0.4335, 0.3892],
                             [0.2093, 0.2065, 0.2046])
    ])

    val_transform = transforms.Compose([
        transforms.Resize((args.input_size, args.input_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.4569, 0.4335, 0.3892],
                             [0.2093, 0.2065, 0.2046])
    ])

    train_dataset = A2DClassification(args, train_transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers,
                              pin_memory=True, drop_last=True, shuffle=True)

    val_dataset = A2DClassification(args, val_transform, mode='val')
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers,
                            pin_memory=True, drop_last=False, shuffle=False)

    model = getJointClassifier(args)
    if args.cuda:
        model = model.cuda()

    criterion = bce_weight_loss(args=args)

    opt = get_finetune_optimizer(args, model)

    # Start from a clean snapshot directory so stale checkpoints from a
    # previous run cannot be mistaken for this run's output.
    if os.path.exists('./save/joint_classification/snap/'):
        shutil.rmtree('./save/joint_classification/snap/')
    os.makedirs('./save/joint_classification/snap/')

    train_loss = []
    val_loss = []
    for epoch in range(args.max_epoches):
        train_loss.append(0)
        val_loss.append(0)

        # Fix: explicitly enable training mode each epoch so BatchNorm/Dropout
        # layers behave correctly (eval() is set below for validation).
        model.train()
        for _, pack in enumerate(train_loader):
            imgs = pack[1]  # (N,t,c,m,n)
            labels = pack[2]  # (N,t,c,m,n)

            if args.cuda:
                imgs = imgs.cuda()
                labels = labels.cuda()

            out, _ = model(imgs)
            loss = criterion.get_loss(out, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

            train_loss[-1] += loss.item()

        # Fix: validation must run in eval mode (frozen BatchNorm statistics,
        # disabled Dropout); the original kept the model in training mode.
        model.eval()
        with torch.no_grad():
            for _, pack in enumerate(val_loader):
                imgs = pack[1]  # (N,t,c,m,n)
                labels = pack[2]  # (N,t,c,m,n)

                if args.cuda:
                    imgs = imgs.cuda()
                    labels = labels.cuda()

                out, _ = model(imgs)
                loss = criterion.get_loss(out, labels)

                val_loss[-1] += loss.item()

        print('epoch:{} train_loss:{:.3f} val_loss:{:.4f}'.format(
            epoch, train_loss[-1], val_loss[-1]), flush=True)

        # Re-plot the full loss history every epoch so the curve is usable
        # even if the run is interrupted.
        plt.figure()
        plt.plot(range(len(train_loss)), train_loss, label='train_loss')
        plt.plot(range(len(val_loss)), val_loss, label='val_loss')
        plt.legend()
        if not os.path.exists('./save/joint_classification/imgs'):
            os.makedirs('./save/joint_classification/imgs')
        plt.savefig('./save/joint_classification/imgs/train_line.png')
        plt.close()

        snap_shot = {'epoch': epoch, 'train_loss': train_loss,
                     'val_loss': val_loss, 'state_dict': model.state_dict()}
        torch.save(
            snap_shot, './save/joint_classification/snap/snap_{}.pth.tar'.format(epoch))
Exemplo n.º 2
0
def joint_plot(args, thd=0.669):
    """Render Grad-CAM visualizations for up to 300 validation samples.

    Samples are binned by prediction outcome into three subdirectories of
    <save_root>/joint_classification/imgs/res_samples:
        '1'  : label 1, predicted 1 (true positive)
        '0'  : label 1, predicted 0 (false negative)
        '10' : label 0, predicted 1 (false positive)

    Args:
        args: Namespace providing save_root, a2d_root, input_size, and
            whatever A2DClassification / getJointClassifier / GradCam need.
        thd: Sigmoid threshold used to binarize the classifier output.
    """
    dir_path = os.path.join(args.save_root, 'joint_classification', 'imgs',
                            'res_samples')

    # Fix: guard the rmtree — the original crashed with FileNotFoundError
    # when the output directory did not exist yet.
    if os.path.exists(dir_path):
        shutil.rmtree(dir_path)
    os.makedirs(dir_path)
    os.makedirs(os.path.join(dir_path, '1'))
    os.makedirs(os.path.join(dir_path, '0'))
    os.makedirs(os.path.join(dir_path, '10'))

    val_transform = transforms.Compose([
        transforms.Resize((args.input_size, args.input_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.4569, 0.4335, 0.3892],
                             [0.2093, 0.2065, 0.2046])
    ])

    val_dataset = A2DClassification(args, val_transform, mode='val')
    val_loader = DataLoader(val_dataset,
                            batch_size=1,
                            num_workers=0,
                            pin_memory=True,
                            drop_last=False,
                            shuffle=False)

    model = getJointClassifier(args)
    model.load_state_dict(
        torch.load(os.path.join(args.save_root,
                                'joint_classification/snap_25.pth.tar'),
                   map_location='cpu')['state_dict'])
    grad_model = GradCam(model=model, target_layer_names=["7"], args=args)

    for step, pack in enumerate(val_loader):
        if step == 300:
            break
        part_path = pack[0][-1]
        imgs = pack[1]  # (N,t,c,m,n)
        labels = pack[2]  # (N,t,c,m,n)

        # 1: label 1 res 1
        # 0: label 1 res 0
        # 10: label 0 res 1
        gt = np.where(labels.squeeze().cpu().numpy() == 1)[0]
        res = None
        for index in gt:
            mask, output = grad_model(imgs, index)
            output = torch.sigmoid(output).detach().cpu().numpy()
            output[output >= thd] = 1
            output[output < thd] = 0
            show_cam_on_image(
                args, os.path.join(args.a2d_root, part_path), mask,
                os.path.join(
                    dir_path, '1',
                    part_path.split('/')[-1][:-4] + '_' + str(index) +
                    '.jpg') if output[index] == 1 else os.path.join(
                        dir_path, '0',
                        part_path.split('/')[-1][:-4] + '_' + str(index) +
                        '.jpg'))

            res = output

        # Fix: the original read the loop variable `output` here, which is
        # undefined (NameError) when `gt` is empty; use the tracked `res`
        # and skip false-positive rendering when no positive label exists.
        if res is None:
            continue
        predicted_pos = np.where(res == 1)[0]
        for index in predicted_pos:
            if labels.squeeze()[index] == 0:
                mask, output = grad_model(imgs, index)
                show_cam_on_image(
                    args, os.path.join(args.a2d_root, part_path), mask,
                    os.path.join(
                        dir_path, '10',
                        part_path.split('/')[-1][:-4] + '_' + str(index) +
                        '.jpg'))
Exemplo n.º 3
0
def eval_joint_classification(args, model_path='joint_classification/snap/'):
    """Evaluate every snapshot in a directory and report the best by mAP.

    Iterates over snap_*.pth.tar files (sorted by epoch number), computes the
    evaluation metric on the validation set for each, prints per-snapshot
    results, and saves a metric-vs-epoch curve to
    ./save/joint_classification/imgs/.

    Args:
        args: Namespace with input_size, batch_size, cuda, save_root, etc.
        model_path: Snapshot directory, relative to args.save_root.
    """
    val_transform = transforms.Compose([
        transforms.Resize((args.input_size, args.input_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.4569, 0.4335, 0.3892],
                             [0.2093, 0.2065, 0.2046])
    ])

    val_dataset = A2DClassification(args, val_transform, mode='val')
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            num_workers=0,
                            pin_memory=True,
                            drop_last=False,
                            shuffle=False)

    # Sort snapshots numerically by epoch ('snap_10' would otherwise sort
    # before 'snap_2' lexicographically).
    snapList = os.listdir(os.path.join(args.save_root, model_path))
    snapList.sort(key=lambda x: int(x.split('.')[0].split('_')[1]))
    resList = []
    bestmodelNum = None
    best_res = None
    mode = 'mAP'

    for snap in snapList:
        model = getJointClassifier(args)
        if args.cuda:
            model = model.cuda()
        model.load_state_dict(
            torch.load(os.path.join(args.save_root, model_path, snap),
                       map_location='cpu')['state_dict'])

        total_res = []
        total_label = []
        with torch.no_grad():
            for _, pack in enumerate(val_loader):
                imgs = pack[1]  # (N,t,c,m,n)
                labels = pack[2]  # (N,t,c,m,n)

                if args.cuda:
                    imgs = imgs.cuda()
                    labels = labels.cuda()

                out, fc = model(imgs)
                # Fix: F.sigmoid is deprecated (removed in recent PyTorch);
                # torch.sigmoid is the supported equivalent.
                out = torch.sigmoid(out)
                total_res.append(out.detach().cpu().numpy())
                total_label.append(labels.cpu().numpy())

        total_res = np.concatenate(total_res, axis=0)
        total_label = np.concatenate(total_label, axis=0)

        res = get_eval(total_res, total_label, mode=mode)
        print('snap:{} {}:{}'.format(snap, mode, res))
        resList.append(res)
        if best_res is None or res > best_res:
            best_res = res
            bestmodelNum = int(snap.split('.')[0].split('_')[1])

    print('best {}:{}'.format(mode, best_res))
    print('best model:{}'.format('snap_' + str(bestmodelNum)))
    plt.figure()
    plt.plot(range(len(resList)), resList, label='joint_model {}'.format(mode))
    plt.legend()
    if not os.path.exists('./save/joint_classification/imgs'):
        os.makedirs('./save/joint_classification/imgs')
    plt.savefig(
        './save/joint_classification/imgs/snaps_eval_{}.png'.format(mode))
    plt.close()
Exemplo n.º 4
0
def _collect_features(model, loader, use_cuda, store):
    """Run model.gen_feature over a loader; call store(sample_id, feature) per sample."""
    for pack in loader:
        ids = pack[0]
        imgs = pack[1]  # (N,t,c,m,n)

        if use_cuda:
            imgs = imgs.cuda()

        fc = model.gen_feature(imgs).detach().cpu().numpy()
        # Fix: the original used ids.index(s_id), which is O(batch^2) and
        # returns the wrong row whenever a batch contains duplicate ids;
        # enumerate pairs each id with its own feature row.
        for row, s_id in enumerate(ids):
            store(s_id, fc[row])


def gen_joint_feature_aug(args):
    """Extract image features (plus horizontally-flipped variants) to a .npy file.

    For every train and val sample, stores [feature]; for every train sample a
    feature of its horizontally-flipped image is appended, so train entries end
    up as [feature, flipped_feature]. The id -> feature-list dict is saved to
    ./repo/joint_img_feature.npy.

    Args:
        args: Namespace with input_size, num_workers, cuda, save_root, and
            whatever A2DClassification / getJointClassifier need.
    """
    train_transform = transforms.Compose([
        transforms.Resize((args.input_size, args.input_size)),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.4569, 0.4335, 0.3892],
                             [0.2093, 0.2065, 0.2046])
    ])

    # p=1.0 makes the flip deterministic: every sample is mirrored.
    train_hp_transform = transforms.Compose([
        transforms.Resize((args.input_size, args.input_size)),
        transforms.RandomHorizontalFlip(1.0),
        transforms.ToTensor(),
        transforms.Normalize([0.4569, 0.4335, 0.3892],
                             [0.2093, 0.2065, 0.2046])
    ])

    val_transform = transforms.Compose([
        transforms.Resize((args.input_size, args.input_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.4569, 0.4335, 0.3892],
                             [0.2093, 0.2065, 0.2046])
    ])

    train_dataset = A2DClassification(args, train_transform, mode='train')
    train_loader = DataLoader(train_dataset,
                              batch_size=64,
                              num_workers=args.num_workers,
                              pin_memory=True,
                              drop_last=False,
                              shuffle=False)

    train_hp_dataset = A2DClassification(args,
                                         train_hp_transform,
                                         mode='train')
    train_hp_loader = DataLoader(train_hp_dataset,
                                 batch_size=64,
                                 num_workers=args.num_workers,
                                 pin_memory=True,
                                 drop_last=False,
                                 shuffle=False)

    val_dataset = A2DClassification(args, val_transform, mode='val')
    val_loader = DataLoader(val_dataset,
                            batch_size=64,
                            num_workers=args.num_workers,
                            pin_memory=True,
                            drop_last=False,
                            shuffle=False)

    model = getJointClassifier(args)
    model.load_state_dict(
        torch.load(os.path.join(args.save_root,
                                'joint_classification/snap/snap_43.pth.tar'),
                   map_location='cpu')['state_dict'])
    if args.cuda:
        model = model.cuda()

    res = {}

    def init_entry(s_id, feature):
        # First pass over a sample: start its feature list.
        res[s_id] = [feature]

    def append_entry(s_id, feature):
        # Flipped pass: the unflipped feature must already be present.
        res[s_id].append(feature)

    with torch.no_grad():
        _collect_features(model, train_loader, args.cuda, init_entry)
        _collect_features(model, val_loader, args.cuda, init_entry)
        # Must run after the plain train pass so append_entry finds its entry.
        _collect_features(model, train_hp_loader, args.cuda, append_entry)

    np.save('./repo/joint_img_feature.npy', res)