# Example #1
def main():
    """Evaluate a trained A2J model on the test split and report its error.

    Loads the weights stored at ``model_dir`` into ``model.A2J_model``, runs
    every batch of ``test_dataloaders`` through the network and the anchor
    post-processor, then scores the predictions with ``errorCompute`` against
    ``keypointsUVD_test`` / ``center_test`` and writes them with ``writeTxt``.
    All of these are module-level names defined elsewhere in the script.
    """
    net = model.A2J_model(num_classes=keypointsNumber)
    # Load onto the CPU first so a checkpoint saved on a different GPU index
    # still loads; ``net.cuda()`` moves the weights to the GPU afterwards.
    net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    net = net.cuda()
    net.eval()

    # The anchor grid is the crop size divided by the stride of 16.
    post_process = anchor.post_process(
        shape=[cropHeight // 16, cropWidth // 16],
        stride=16,
        P_h=None,
        P_w=None)

    # Collect per-batch predictions and concatenate once at the end:
    # concatenating inside the loop re-copies every earlier batch on each
    # iteration, which is quadratic in the size of the test set.
    batch_preds = []
    with torch.no_grad():
        # Labels are not needed for inference, so they stay on the CPU.
        for img, _label in tqdm(test_dataloaders):
            heads = net(img.cuda())
            pred_keypoints = post_process(heads, voting=False)
            batch_preds.append(pred_keypoints.detach().cpu())

    # An empty loader still yields an (empty) float tensor, as before.
    output = torch.cat(batch_preds, 0) if batch_preds else torch.FloatTensor()
    result = output.numpy()
    errTotal = errorCompute(result, keypointsUVD_test, center_test)
    writeTxt(result, center_test)

    print('Error:', errTotal)
# Example #2
def main():
    """Evaluate a trained 2-D A2J model and print ``evaluationPDJ`` accuracy.

    Builds the network and post-processor with ``is_3D=False`` and loads the
    weights from ``model_dir``. It then predicts keypoints for every batch of
    ``test_dataloaders`` and prints the ``evaluationPDJ`` accuracy against
    ``keypointsTest`` / ``bndboxTest`` at thresholds 0.05, 0.10, 0.15 and 0.20.
    All referenced names are module-level globals defined elsewhere.
    """
    net = model.A2J_model(num_classes=keypointsNumber, is_3D=False)
    # Load onto the CPU first so a checkpoint saved on a different GPU index
    # still loads; ``net.cuda()`` moves the weights to the GPU afterwards.
    net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    net = net.cuda()
    net.eval()

    # The anchor grid is the crop size divided by the stride of 16.
    post_process = anchor.post_process(
        shape=[cropHeight // 16, cropWidth // 16],
        stride=16,
        P_h=None,
        P_w=None,
        is_3D=False)

    # Collect per-batch predictions and concatenate once at the end instead
    # of re-concatenating (and re-copying) the running result every batch.
    batch_preds = []
    with torch.no_grad():
        # Labels are not needed for inference, so they stay on the CPU.
        for img, _label in tqdm(test_dataloaders):
            heads = net(img.cuda())
            pred_keypoints = post_process(heads, voting=False)
            batch_preds.append(pred_keypoints.detach().cpu())

    # An empty loader still yields an (empty) float tensor, as before.
    output = torch.cat(batch_preds, 0) if batch_preds else torch.FloatTensor()
    result = output.numpy()
    for threshold in (0.05, 0.10, 0.15, 0.20):
        # Hand each evaluation its own copy, in case evaluationPDJ modifies
        # the prediction array in place.
        print('Accuracy %.2f:' % threshold,
              evaluationPDJ(result.copy(), keypointsTest, bndboxTest,
                            threshold))


def main():
    """Evaluate a trained A2J model with ``evaluation10CMRule``.

    Loads the weights stored at ``model_dir``, predicts keypoints for every
    batch of ``test_dataloaders`` and prints the overall accuracy returned by
    ``evaluation10CMRule``. It then calls ``evaluation10CMRule_perJoint`` with
    the same arguments; that call's return value is discarded, so presumably
    it reports per-joint results itself (verify). All referenced names are
    module-level globals defined elsewhere.
    """
    net = model.A2J_model(num_classes=keypointsNumber)
    # Load onto the CPU first so a checkpoint saved on a different GPU index
    # still loads; ``net.cuda()`` moves the weights to the GPU afterwards.
    net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    net = net.cuda()
    net.eval()

    # The anchor grid is the crop size divided by the stride of 16.
    post_process = anchor.post_process(
        shape=[cropHeight // 16, cropWidth // 16],
        stride=16,
        P_h=None,
        P_w=None)

    # Collect per-batch predictions and concatenate once at the end instead
    # of re-concatenating (and re-copying) the running result every batch.
    batch_preds = []
    with torch.no_grad():
        # Labels are not needed for inference, so they stay on the CPU.
        for img, _label in tqdm(test_dataloaders):
            heads = net(img.cuda())
            pred_keypoints = post_process(heads, voting=False)
            batch_preds.append(pred_keypoints.detach().cpu())

    # An empty loader still yields an (empty) float tensor, as before.
    output = torch.cat(batch_preds, 0) if batch_preds else torch.FloatTensor()
    result = output.numpy()
    Accuracy_test = evaluation10CMRule(result, keypointsWorldtest, bndbox_test,
                                       center_test)
    print('Accuracy:', Accuracy_test)
    evaluation10CMRule_perJoint(result, keypointsWorldtest, bndbox_test,
                                center_test)
# Example #4
def train():
    """Train the A2J model, evaluating and checkpointing after every epoch.

    Optimises ``model.A2J_model`` with Adam on ``anchor.A2J_loss``, using the
    total loss ``Cls_loss + RegLossFactor * Reg_loss``. The learning rate is
    decayed by a StepLR schedule (x0.2 every 10 scheduler steps).

    After each epoch the function:

    * prints per-sample timing and mean losses;
    * computes the test error with ``errorCompute``;
    * saves the ``state_dict`` under ``save_dir``, with the hyper-parameters
      encoded in the file name;
    * appends a summary line to ``save_dir/train.log``.

    All configuration (``nepoch``, ``learning_rate``, ``TrainImgFrames``,
    data loaders, ...) comes from module-level globals defined elsewhere.
    """
    net = model.A2J_model(num_classes=keypointsNumber)
    net = net.cuda()

    # The anchor grid is the crop size divided by the stride of 16.
    post_process = anchor.post_process(
        shape=[cropHeight // 16, cropWidth // 16],
        stride=16,
        P_h=None,
        P_w=None)
    criterion = anchor.A2J_loss(shape=[cropHeight // 16, cropWidth // 16],
                                thres=[16.0, 32.0],
                                stride=16,
                                spatialFactor=spatialFactor,
                                img_shape=[cropHeight, cropWidth],
                                P_h=None,
                                P_w=None)
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=learning_rate,
                                 weight_decay=Weight_Decay)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.2)

    logging.basicConfig(format='%(asctime)s %(message)s',
                        datefmt='%Y/%m/%d %H:%M:%S',
                        filename=os.path.join(save_dir, 'train.log'),
                        level=logging.INFO)
    logging.info('======================================================')

    for epoch in range(nepoch):
        net = net.train()
        train_loss_add = 0.0
        Cls_loss_add = 0.0
        Reg_loss_add = 0.0
        timer = time.time()

        # Training loop
        for i, (img, label) in enumerate(train_dataloaders):
            # Synchronise so the wall-clock timer only measures finished GPU work.
            torch.cuda.synchronize()

            img, label = img.cuda(), label.cuda()

            heads = net(img)
            optimizer.zero_grad()

            Cls_loss, Reg_loss = criterion(heads, label)

            loss = Cls_loss + Reg_loss * RegLossFactor
            loss.backward()
            optimizer.step()

            torch.cuda.synchronize()

            # Read each scalar once. The losses are presumably batch means, so
            # they are weighted by batch size before the per-sample averages
            # below (TODO confirm against A2J_loss).
            loss_val = loss.item()
            cls_val = Cls_loss.item()
            reg_val = Reg_loss.item()
            batch_size = len(img)
            train_loss_add += loss_val * batch_size
            Cls_loss_add += cls_val * batch_size
            Reg_loss_add += reg_val * batch_size

            # printing loss info
            if i % 10 == 0:
                print('epoch: ', epoch, ' step: ', i, 'Cls_loss ',
                      cls_val, 'Reg_loss ', reg_val,
                      ' total loss ', loss_val)

        # NOTE(review): passing ``epoch`` to ``step()`` is deprecated in recent
        # PyTorch. It sets the step counter to ``epoch`` rather than
        # ``epoch + 1``, so each decay lands one epoch later than with a plain
        # ``scheduler.step()``. It is kept to preserve the existing schedule.
        scheduler.step(epoch)

        # time taken
        torch.cuda.synchronize()
        timer = time.time() - timer
        timer = timer / TrainImgFrames
        print('==> time to learn 1 sample = %f (ms)' % (timer * 1000))

        train_loss_add = train_loss_add / TrainImgFrames
        Cls_loss_add = Cls_loss_add / TrainImgFrames
        Reg_loss_add = Reg_loss_add / TrainImgFrames
        print('mean train_loss_add of 1 sample: %f, #train_indexes = %d' %
              (train_loss_add, TrainImgFrames))
        print('mean Cls_loss_add of 1 sample: %f, #train_indexes = %d' %
              (Cls_loss_add, TrainImgFrames))
        print('mean Reg_loss_add of 1 sample: %f, #train_indexes = %d' %
              (Reg_loss_add, TrainImgFrames))

        # Evaluate on the test split after every epoch. Per-batch predictions
        # are concatenated once at the end rather than re-concatenated (and
        # re-copied) on every batch.
        net = net.eval()
        batch_preds = []
        with torch.no_grad():
            # Labels are not needed for inference, so they stay on the CPU.
            for img, _label in tqdm(test_dataloaders):
                heads = net(img.cuda())
                pred_keypoints = post_process(heads, voting=False)
                batch_preds.append(pred_keypoints.detach().cpu())

        output = (torch.cat(batch_preds, 0) if batch_preds
                  else torch.FloatTensor())
        result = output.numpy()
        Error_test = errorCompute(result, keypointsUVD_test, center_test)
        print('epoch: ', epoch, 'Test error:', Error_test)
        saveNamePrefix = (
            '%s/net_%d_wetD_%s_depFact_%s_RegFact_%s_rndShft_%s' %
            (save_dir, epoch, Weight_Decay, spatialFactor, RegLossFactor,
             RandCropShift))
        torch.save(net.state_dict(), saveNamePrefix + '.pth')

        # Log the learning rate the optimizer actually holds after this
        # epoch's scheduler step. ``scheduler.get_lr()`` called outside
        # ``step()`` is deprecated, and for StepLR it reports the rate with an
        # extra ``gamma`` applied whenever the step counter is a multiple of
        # ``step_size``.
        current_lr = optimizer.param_groups[0]['lr']

        # log
        logging.info(
            'Epoch#%d: total loss=%.4f, Cls_loss=%.4f, Reg_loss=%.4f, Err_test=%.4f, lr = %.6f'
            % (epoch, train_loss_add, Cls_loss_add, Reg_loss_add, Error_test,
               current_lr))