def main():
    """Evaluate a pretrained GazeNet checkpoint on the GazeFollow test split."""
    dataset = GazeDataset(
        root_dir=
        '/home/emannuell/Documentos/mestrado/GazeFollowing/GazeFollowData/data_new/',
        mat_file='',
        training='test')
    loader = DataLoader(dataset,
                        batch_size=1,
                        shuffle=False,
                        num_workers=8)

    model = DataParallel(GazeNet())
    model.cuda()

    # Load the checkpoint and merge only the weights whose names exist in
    # the current model, so extra/missing keys are tolerated.
    checkpoint = torch.load(
        '/home/emannuell/Documentos/mestrado/GazeFollowing/model/test04/epoch_15_loss_0.0558342523873.pkl'
    )
    state = model.state_dict()
    state.update({k: v for k, v in checkpoint.items() if k in state})
    model.load_state_dict(state)

    test(model, loader)
# ---- Example no. 2 ----
def main():
    """Run gaze-following inference on the image given as sys.argv[1].

    The head position is found automatically with detect_head; the predicted
    gaze point is drawn on the image and printed to stdout.
    """
    model = DataParallel(GazeNet())
    model.cuda()

    # Load pretrained weights, keeping only keys present in this model.
    checkpoint = torch.load('../model/trained_model.pkl')
    state = model.state_dict()
    state.update({k: v for k, v in checkpoint.items() if k in state})
    model.load_state_dict(state)

    image_path = sys.argv[1]
    head_x, head_y = detect_head(image_path)
    eye = (float(head_x), float(head_y))

    heatmap, p_x, p_y = test(model, image_path, eye)
    draw_result(image_path, eye, heatmap, (p_x, p_y))

    print(p_x, p_y)
def main():
    """Evaluate the multi-FPN GazeNet over the 16 spatial sub-datasets.

    Sub-dataset ``i`` is routed to FPN head ``i // area_in_network``.
    Per-area predictions and heatmaps are saved individually and also
    concatenated into single npz files at the end.
    """
    dis_test_sets = dataset_wrapper(
        root_dir='../../test_data/',
        mat_file='../../test_data/test2_annotations.mat',
        training='test')

    net = GazeNet()
    net = DataParallel(net)
    net.cuda()

    net = load_pretrained_model(net, '../model/pretrained_models/')

    # Two consecutive sub-datasets share one FPN head (16 areas / 8 heads).
    area_in_network = 2

    # FIX: removed dead locals `area_count` and `dis_test_data_loaders`
    # (assigned but never used in the original).
    all_losses, all_errors = [], []
    info_lists, heatmaps = [], []
    for i in range(16):
        test_data_loader = DataLoader(dis_test_sets[i],
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=8)

        # Switch the network to the FPN head responsible for this area and
        # make sure its parameters live on the GPU.
        area_idx = i // area_in_network
        net.module.change_fpn(area_idx)
        if not next(net.module.fpn_net.parameters()).is_cuda:
            net.module.fpn_net.cuda()

        cur_loss, cur_error, info_list, heatmap = test(net, test_data_loader)
        all_losses.append(cur_loss)
        all_errors.append(cur_error)

        np.savez('../npzs/multi_scale_concat_prediction_{}.npz'.format(i),
                 info_list=info_list)
        np.savez('../npzs/multi_scale_concat_heatmaps_{}.npz'.format(i),
                 heatmaps=heatmap)
        info_lists.extend(info_list)
        heatmaps.extend(heatmap)

    # Mean loss / error over all 16 sub-datasets.
    print(np.mean(all_losses, axis=0))
    print(np.mean(all_errors, axis=0))

    info_lists, heatmaps = np.array(info_lists), np.array(heatmaps)
    np.savez('../npzs/multi_scale_concat_prediction.npz', info_list=info_lists)
    np.savez('../npzs/multi_scale_concat_heatmaps.npz', heatmaps=heatmaps)
# ---- Example no. 4 ----
def home():
    """Flask view: accept an uploaded image (POST), run gaze following on it
    and render the result page.

    On non-POST requests the page is rendered without any processing.
    """
    #python3 inference.py ../images/00004844.jpg 0.35636 0.23724
    net = GazeNet()
    net = DataParallel(net)
    net.cuda()

    # Merge pretrained weights whose keys match the current model.
    pretrained_dict = torch.load('../model/trained_model.pkl')
    model_dict = net.state_dict()
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    model_dict.update(pretrained_dict)
    net.load_state_dict(model_dict)

    # FIX: the original fell through and used `f` even on GET requests,
    # raising NameError. Render the page directly when no file was posted.
    if request.method != 'POST':
        return render_template('index.html')

    f = request.files['file']
    f.save(secure_filename(f.filename))
    fname = f.filename
    print(fname)

    test_image_path = os.path.join('.', f.filename)
    xi, yi = detect_head(test_image_path)
    x = float(xi)
    y = float(yi)
    print(test_image_path)

    heatmap, p_x, p_y = test(net, test_image_path, (x, y))

    resultimg = draw_result(test_image_path, (x, y), heatmap, (p_x, p_y))

    img = Image.fromarray(resultimg, 'RGB')
    # NOTE(review): `img` is built but never saved here; the copy below
    # assumes draw_result already wrote 'tmp.png' — confirm.

    print(p_x, p_y)
    outim = ['tmp.png']
    for o in outim:
        shutil.copy(o, './static')
    return render_template('index.html')
# ---- Example no. 5 ----
def main():
    """Evaluate a pretrained GazeNet on the GazeFollow test annotations."""
    dataset = GazeDataset(root_dir='../GazeFollowData/',
                          mat_file='../GazeFollowData/test2_annotations.mat',
                          training='test')
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8)

    model = DataParallel(GazeNet())
    model.cuda()

    # Merge pretrained weights whose keys exist in the current model.
    checkpoint = torch.load('../model/pretrained_model.pkl')
    state = model.state_dict()
    state.update({k: v for k, v in checkpoint.items() if k in state})
    model.load_state_dict(state)

    test(model, loader)
# ---- Example no. 6 ----
def main():
    """Run inference: argv[1] = image path, argv[2]/argv[3] = eye x/y in [0, 1]."""
    model = DataParallel(GazeNet())
    model.cuda()

    # Merge pretrained weights whose keys exist in the current model.
    checkpoint = torch.load('../model/pretrained_model.pkl')
    state = model.state_dict()
    state.update({k: v for k, v in checkpoint.items() if k in state})
    model.load_state_dict(state)

    image_path = sys.argv[1]
    eye = (float(sys.argv[2]), float(sys.argv[3]))

    heatmap, p_x, p_y = test(model, image_path, eye)
    draw_result(image_path, eye, heatmap, (p_x, p_y))

    print(p_x, p_y)
import operator
import itertools
from scipy.io import loadmat
import logging
import matplotlib.pyplot as plt
import frame_helper as ef
import frame_to_video as fv
import multiprocessing as mp
import imageio

from scipy import signal

from utils import data_transforms
from utils import get_paste_kernel, kernel_map

# Module-level setup: build GazeNet on the CPU and load pretrained weights
# once at import time.
net = GazeNet()
net = DataParallel(net)
net.cpu()
# NOTE(review): the device object returned here is discarded, so this line
# has no effect — possibly a leftover; confirm intent.
torch.device('cpu')
# map_location keeps checkpoint tensors on the CPU even if they were saved
# from a GPU run.
pretrained_dict = torch.load('./savedmodels/pretrained_model.pkl',
                             map_location=lambda storage, loc: storage)
model_dict = net.state_dict()
# Keep only weights whose names exist in the current model.
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)


def generate_data_field(eye_point):
    """eye_point is (x, y) and between 0 and 1"""
    # 224x224 grid of pixel x-coordinates: one identical row per image row.
    # NOTE(review): the block appears truncated here (no y_grid / return);
    # the remainder of the function is not visible in this chunk.
    height, width = 224, 224
    x_grid = np.array(range(width)).reshape([1, width]).repeat(height, axis=0)
# ---- Example no. 8 ----
def main():
    """Two-stage GazeFollow training loop.

    Stage schedule (selected by epoch):
      * epochs 0-6:  stage 1 — face / eye-position / fusion branches,
                     angle loss only.
      * epochs 7-14: stage 2 — FPN heatmap head, heatmap loss only.
      * epoch 15+:   stage 3 — whole network at 10x lower LR, combined loss.
    A checkpoint is saved and the test set evaluated after every epoch.
    """
    train_set = GazeDataset(root_dir='../GazeFollowData/',
                            mat_file='../GazeFollowData/train_annotations.mat',
                            training='train')
    train_data_loader = DataLoader(train_set,
                                   batch_size=32 * 4,
                                   shuffle=True,
                                   num_workers=16)

    test_set = GazeDataset(root_dir='../GazeFollowData/',
                           mat_file='../GazeFollowData/test2_annotations.mat',
                           training='test')
    test_data_loader = DataLoader(test_set,
                                  batch_size=32 * 4,
                                  shuffle=False,
                                  num_workers=8)

    net = GazeNet()
    net = DataParallel(net)
    net.cuda()

    resume_training = False
    if resume_training:
        # Evaluate an existing checkpoint and stop instead of training.
        pretrained_dict = torch.load('../model/pretrained_model.pkl')
        model_dict = net.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        net.load_state_dict(model_dict)
        test(net, test_data_loader)
        exit()

    method = 'Adam'
    learning_rate = 0.0001

    # Stage 1 optimizer: only the head/eye/fusion branches.
    optimizer_s1 = optim.Adam(
        [{
            'params': net.module.face_net.parameters(),
            'initial_lr': learning_rate
        }, {
            'params': net.module.face_process.parameters(),
            'initial_lr': learning_rate
        }, {
            'params': net.module.eye_position_transform.parameters(),
            'initial_lr': learning_rate
        }, {
            'params': net.module.fusion.parameters(),
            'initial_lr': learning_rate
        }],
        lr=learning_rate,
        weight_decay=0.0001)
    # Stage 2 optimizer: FPN heatmap head only.
    optimizer_s2 = optim.Adam([{
        'params': net.module.fpn_net.parameters(),
        'initial_lr': learning_rate
    }],
                              lr=learning_rate,
                              weight_decay=0.0001)

    # Stage 3 optimizer: whole network, 10x lower LR for fine-tuning.
    optimizer_s3 = optim.Adam([{
        'params': net.parameters(),
        'initial_lr': learning_rate
    }],
                              lr=learning_rate * 0.1,
                              weight_decay=0.0001)

    lr_scheduler_s1 = optim.lr_scheduler.StepLR(optimizer_s1,
                                                step_size=5,
                                                gamma=0.1,
                                                last_epoch=-1)
    lr_scheduler_s2 = optim.lr_scheduler.StepLR(optimizer_s2,
                                                step_size=5,
                                                gamma=0.1,
                                                last_epoch=-1)
    lr_scheduler_s3 = optim.lr_scheduler.StepLR(optimizer_s3,
                                                step_size=5,
                                                gamma=0.1,
                                                last_epoch=-1)

    max_epoch = 25

    epoch = 0
    while epoch < max_epoch:
        # Switch optimizer/scheduler at the stage boundaries.
        if epoch == 0:
            lr_scheduler = lr_scheduler_s1
            optimizer = optimizer_s1
        elif epoch == 7:
            lr_scheduler = lr_scheduler_s2
            optimizer = optimizer_s2
        elif epoch == 15:
            lr_scheduler = lr_scheduler_s3
            optimizer = optimizer_s3

        lr_scheduler.step()

        running_loss = []
        for i, data in tqdm(enumerate(train_data_loader)):
            image, face_image, gaze_field, eye_position, gt_position, gt_heatmap = \
                data['image'], data['face_image'], data['gaze_field'], data['eye_position'], data['gt_position'], data['gt_heatmap']
            image, face_image, gaze_field, eye_position, gt_position, gt_heatmap = \
                map(lambda x: Variable(x.cuda()), [image, face_image, gaze_field, eye_position, gt_position, gt_heatmap])

            optimizer.zero_grad()

            direction, predict_heatmap = net(
                [image, face_image, gaze_field, eye_position])

            heatmap_loss, m_angle_loss = \
                F_loss(direction, predict_heatmap, eye_position, gt_position, gt_heatmap)

            # Loss selection mirrors the stage schedule above.
            if epoch == 0:
                loss = m_angle_loss
            elif epoch >= 7 and epoch <= 14:
                loss = heatmap_loss
            else:
                loss = m_angle_loss + heatmap_loss

            loss.backward()
            optimizer.step()

            # FIX: `.data[0]` is the pre-0.4 PyTorch idiom and fails on
            # 0-dim tensors in modern versions; use `.item()` (consistent
            # with the other training loop in this file).
            running_loss.append(
                [heatmap_loss.item(), m_angle_loss.item(), loss.item()])
            if i % 10 == 9:
                logging.info('%s %s %s' % (str(np.mean(running_loss, axis=0)),
                                           method, str(lr_scheduler.get_lr())))
                running_loss = []

        epoch += 1

        save_path = '../model/two_stage_fpn_concat_multi_scale_' + method
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        torch.save(net.state_dict(),
                   save_path + '/model_epoch{}.pkl'.format(epoch))

        test(net, test_data_loader)
def main():
    """Head-pose-variant GazeFollow training loop.

    Stage schedule: epoch 0-14 angle loss (head-pose/eye/fusion branches),
    epochs 15-20 heatmap loss (FPN), epoch 20+ combined loss on the whole
    network. A checkpoint named with epoch and final batch loss is saved
    after every epoch.
    """
    dataset_path = '/home/emannuell/Documentos/mestrado/dataset/data_new/'
    output_path = 'output/convergeTest'
    train_set = GazeDataset(root_dir=dataset_path, training='train')
    train_data_loader = DataLoader(train_set,
                                   batch_size=10,
                                   shuffle=True,
                                   num_workers=4)

    test_set = GazeDataset(root_dir=dataset_path, training='test')
    test_data_loader = DataLoader(test_set,
                                  batch_size=2,
                                  shuffle=False,
                                  num_workers=4)

    net = GazeNet()
    net = DataParallel(net)
    net.cuda()

    resume_training = False
    if resume_training:
        # Evaluate an existing checkpoint and stop instead of training.
        pretrained_dict = torch.load('model/pretrained_model.pkl')
        model_dict = net.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        net.load_state_dict(model_dict)
        test(net, test_data_loader)
        exit()

    method = 'Adam'
    learning_rate = 0.001

    # Stage 1 optimizer: head-pose / eye-position / fusion branches.
    optimizer_s1 = optim.Adam(
        [{
            'params': net.module.head_pose_transform.parameters(),
            'initial_lr': learning_rate
        }, {
            'params': net.module.eye_position_transform.parameters(),
            'initial_lr': learning_rate
        }, {
            'params': net.module.fusion.parameters(),
            'initial_lr': learning_rate
        }],
        lr=learning_rate,
        weight_decay=0.0001)
    # Stage 2 optimizer: FPN heatmap head only.
    optimizer_s2 = optim.Adam([{
        'params': net.module.fpn_net.parameters(),
        'initial_lr': learning_rate
    }],
                              lr=learning_rate,
                              weight_decay=0.0001)

    # Stage 3 optimizer: whole network, 10x lower LR for fine-tuning.
    optimizer_s3 = optim.Adam([{
        'params': net.parameters(),
        'initial_lr': learning_rate
    }],
                              lr=learning_rate * 0.1,
                              weight_decay=0.0001)

    lr_scheduler_s1 = optim.lr_scheduler.StepLR(optimizer_s1,
                                                step_size=5,
                                                gamma=0.1,
                                                last_epoch=-1)
    lr_scheduler_s2 = optim.lr_scheduler.StepLR(optimizer_s2,
                                                step_size=5,
                                                gamma=0.1,
                                                last_epoch=-1)
    lr_scheduler_s3 = optim.lr_scheduler.StepLR(optimizer_s3,
                                                step_size=5,
                                                gamma=0.1,
                                                last_epoch=-1)

    max_epoch = 25

    epoch = 0
    while epoch < max_epoch:
        # Switch optimizer/scheduler at the stage boundaries.
        if epoch == 0:
            lr_scheduler = lr_scheduler_s1
            optimizer = optimizer_s1
        elif epoch == 15:
            lr_scheduler = lr_scheduler_s2
            optimizer = optimizer_s2
        elif epoch == 20:
            lr_scheduler = lr_scheduler_s3
            optimizer = optimizer_s3
        lr_scheduler.step()

        running_loss = []
        ep_heatmap_loss = []
        ep_m_angle_loss = []
        for i, data in tqdm(enumerate(train_data_loader)):
            image, gaze_field, eye_position, gt_position, gt_heatmap, head_pose = \
                data['image'],  data['gaze_field'], data['eye_position'], data['gt_position'], data['gt_heatmap'], data['head_pose']
            image, gaze_field, eye_position, gt_position, gt_heatmap, head_pose = \
                map(lambda x: Variable(x.cuda()), [image, gaze_field, eye_position, gt_position, gt_heatmap, head_pose])

            optimizer.zero_grad()
            direction, predict_heatmap = net(
                [image, gaze_field, eye_position, head_pose])

            heatmap_loss, m_angle_loss = F_loss(direction, predict_heatmap,
                                                eye_position, gt_position,
                                                gt_heatmap)
            # FIX: collect plain Python floats instead of tensors so the
            # epoch means and the checkpoint filename are clean numbers
            # (the original embedded `tensor(...)` reprs in filenames,
            # unlike the numeric-loss names used elsewhere in this file).
            ep_heatmap_loss.append(heatmap_loss.item())
            ep_m_angle_loss.append(m_angle_loss.item())

            # Loss selection mirrors the stage schedule above.
            if epoch == 0:
                loss = m_angle_loss
            elif epoch >= 15 and epoch <= 20:
                loss = heatmap_loss
            else:
                loss = m_angle_loss + heatmap_loss

            loss.backward()
            optimizer.step()

            running_loss.append(
                [heatmap_loss.item(), m_angle_loss.item(), loss.item()])

        epoch += 1
        print('==== Training loss ====')
        logging.info('Epoch: %s' % epoch)
        logging.info('heatmap loss: %s' %
                     str(np.mean(np.array(ep_heatmap_loss))))
        logging.info('mean angle loss: %s' %
                     str(np.mean(np.array(ep_m_angle_loss))))
        logging.info('file: %s' % output_path +
                     '/epoch_{}_loss_{}.pkl'.format(epoch, loss.item()))

        if not os.path.exists(output_path):
            os.makedirs(output_path)
        print('Saving model to output path: ',
              output_path + '/epoch_{}_loss_{}.pkl'.format(epoch, loss.item()))
        torch.save(
            net.state_dict(),
            output_path + '/epoch_{}_loss_{}.pkl'.format(epoch, loss.item()))
# ---- Example no. 10 ----
def deepMain():
    """Process 'video.mp4' frame by frame: detect faces, estimate head pose,
    run GazeNet and display the annotated frame next to the previous one.

    Fixes over the original: `eye_center` / `img2` / `face.score` were all
    undefined or invalid (NameError / AttributeError at runtime), and the
    face-center `y` was clobbered by the head-pose yaw before being used.
    """
    net = GazeNet()
    net = DataParallel(net)
    net.cuda()
    face_detector = cv2.CascadeClassifier(
        'model/lbpcascade_frontalface_improved.xml')
    # Load pretrained gaze following model, merging only matching keys.
    pretrained_dict = torch.load('model/epoch_15_loss_0.0558342523873.pkl')
    model_dict = net.state_dict()
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    model_dict.update(pretrained_dict)
    net.load_state_dict(model_dict)

    cap = cv2.VideoCapture('video.mp4')
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    videoFrames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    print(width, height, videoFrames, fps)
    print('Iniciando processamento do video...')
    prev_frame = None  # FIX: previous frame buffer, defined before first use
    while True:
        frameId = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
        if frameId >= videoFrames:
            break
        print(frameId, '/', videoFrames)
        ret, img = cap.read()
        if not ret:
            break
        originalImg = img
        img = cv2.resize(img, (640, 480))
        height, width, _ = img.shape
        faces = face_detector.detectMultiScale(img, 1.05, 3)
        for (x_, y_, w_, h_) in faces:
            # detectMultiScale returns plain (x, y, w, h) rectangles.
            print(x_, y_, w_, h_)
            # Face center in pixels; kept separate so the head-pose yaw
            # below cannot overwrite it (bug in the original).
            cx, cy = int(x_ + w_ / 2), int(y_ + h_ / 2)
            print(cx, cy)
            cv2.rectangle(img, (x_, y_), (x_ + w_, y_ + h_), (0, 0, 255),
                          2)
            # Head pose estimation
            faceBbox = x_, y_, x_ + w_, y_ + h_
            yaw, pitch, roll = headpose.detectHeadPose(img, faceBbox)
            # Normalize angles using min/max values observed in the dataset:
            # append the measured angle and min-max scale the 3-element list.
            MinMaxY = [-86.73414612, 87.62715149, yaw]
            MinMaxP = [-54.0966301, 36.50032043, pitch]
            MinMaxR = [-42.67565918, 42.16217041, roll]
            yn = min_max_scaler.fit_transform(
                np.array(MinMaxY).reshape((-1, 1)))
            pn = min_max_scaler.fit_transform(
                np.array(MinMaxP).reshape((-1, 1)))
            rn = min_max_scaler.fit_transform(
                np.array(MinMaxR).reshape((-1, 1)))
            headPoseAngles = float(yn[-1]), float(pn[-1]), float(rn[-1])
            print('Head pose normalized: ', headPoseAngles)
            # FIX: original used undefined `eye_center`; the face center is
            # the obvious intended anchor — TODO confirm against `test()`.
            center_x = cx / width
            center_y = cy / height
            img = cv2.circle(img, (cx, cy), 2, (255, 0, 255),
                             thickness=2)  # magenta: face center
            heatmap, p_x, p_y = test(net, originalImg, (cx, cy),
                                     headPoseAngles)
            img = cv2.circle(img, (int(p_x * width), int(p_y * height)),
                             2, (255, 0, 0),
                             thickness=2)  # blue: predicted gaze point
            img = draw_result(img, (center_x, center_y), heatmap,
                              (p_x, p_y))
        # Show the current frame side by side with the previous one
        # (first frame is shown alone).
        if prev_frame is not None:
            display = np.concatenate((img, prev_frame), axis=1)
        else:
            display = img
        prev_frame = img
        cv2.imshow('result', display)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cap.release()
    cv2.destroyAllWindows()
# ---- Example no. 11 ----
def main():
    """Batch-evaluate a GazeNet checkpoint over a pickled test set.

    For each test sample: run inference, save a visualization PNG, and
    collect predictions/inputs/ground truth into 'allresults.pkl'.
    """
    # DEFINE PARAMETERS
    CHECKPOINT_PATH = '/media/samsung2080pc/New Volume/SAMSUNG/gazefollowing/trial1_Adam'
    EPOCH = 25
    LOAD_PATH = os.path.join(CHECKPOINT_PATH,
                             'model_epoch' + str(EPOCH) + '.pkl')
    DATASET_PATH = '/home/samsung2080pc/Documents/ObjectOfInterestV22Dataset'
    TEST_PATH = '/home/samsung2080pc/Documents/ObjectOfInterestV22Dataset/test.pickle'
    NUM_TEST = 0  # 0 means "evaluate the whole test set"

    net = GazeNet()
    net = DataParallel(net)
    net.cuda()

    # Merge checkpoint weights whose keys exist in the current model.
    pretrained_dict = torch.load(LOAD_PATH)
    model_dict = net.state_dict()
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    model_dict.update(pretrained_dict)
    net.load_state_dict(model_dict)

    # FIX: close the pickle file deterministically (the original leaked
    # the open handle).
    with open(TEST_PATH, 'rb') as f:
        test_data = pickle.load(f)
    test_img_path = os.path.join(DATASET_PATH, test_data[0]['filename'])
    # Image size taken from the first image — assumes all test images share
    # one resolution; TODO confirm.
    h, w = cv2.imread(test_img_path).shape[:2]
    if NUM_TEST == 0:
        NUM_TEST = len(test_data)

    save_path = os.path.join(CHECKPOINT_PATH, 'epoch_' + str(EPOCH))
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    start_time = time.time()
    results_pkl = []
    for i in tqdm(range(NUM_TEST)):

        save_path = os.path.join(CHECKPOINT_PATH, 'epoch_' + str(EPOCH),
                                 'out' + str(i).zfill(4) + '.png')

        test_image_path = os.path.join(DATASET_PATH, test_data[i]['filename'])
        # Head position normalized to [0, 1] by the image size.
        x = test_data[i]['hx'] / w
        y = test_data[i]['hy'] / h
        heatmap, p_x, p_y = test(net, test_image_path, (x, y))
        output = {  # PREDICTIONS
            'predictions': {
                'heatmap': heatmap,
                'p_x': p_x,
                'p_y': p_y
            },
            # INPUTS
            'inputs': {
                'image_path': test_data[i]['filename'],
                'eye_x': x,
                'eye_y': y,
            },
            # GROUND TRUTH
            'gt': {
                'gaze_cx': test_data[i]['gaze_cx'],
                'gaze_cy': test_data[i]['gaze_cy']
            }
        }
        results_pkl.append(output)
        draw_result(test_image_path, (x, y), heatmap, (p_x, p_y), save_path)

    end_time = time.time()
    process_time = end_time - start_time

    outfilename = os.path.join(CHECKPOINT_PATH, 'epoch_' + str(EPOCH),
                               'allresults.pkl')

    with open(outfilename, 'wb') as outfile:
        pickle.dump(results_pkl, outfile)
    print(results_pkl)
    print('Processed %i images in %f seconds.' % (NUM_TEST, process_time))
# ---- Example no. 12 ----
def main():
    '''
    train_set = GazeDataset(root_dir='../../data/',
                            mat_file='../../data/train_annotations.mat',
                            training='train')
    train_data_loader = DataLoader(train_set, batch_size=48,
                                   shuffle=True, num_workers=8)

    test_set = GazeDataset(root_dir='../../test_data/',
                           mat_file='../../test_data/test2_annotations.mat',
                           training='test')
    test_data_loader = DataLoader(test_set, batch_size=32,
                                  shuffle=False, num_workers=8)
    '''
    # Multi-FPN training: 16 spatial sub-datasets, grouped so that every
    # `area_in_network` consecutive sub-datasets share one of `area_count`
    # FPN heads. Stage schedule: epoch 0 = angle loss (head branches),
    # epochs 7-14 = heatmap loss (per-area FPNs, switched per data loader
    # from epoch 10), epoch 15+ = combined loss on the whole network.
    dis_train_sets = dataset_wrapper(
        root_dir='../../data/',
        mat_file='../../data/train_annotations.mat',
        training='train')
    #dis_train_data_loader = DataLoader(dis_train_sets[0], batch_size=48,
    #                                   shuffle=True, num_workers=8)

    dis_test_sets = dataset_wrapper(
        root_dir='../../test_data/',
        mat_file='../../test_data/test2_annotations.mat',
        training='test')
    #dis_test_data_loader = DataLoader(dis_test_sets[0], batch_size=32,
    #                                  shuffle=False, num_workers=8)

    # One train/test loader per spatial sub-dataset.
    dis_train_data_loaders, dis_test_data_loaders = [], []
    for i in range(16):
        dis_train_data_loaders.append(
            DataLoader(dis_train_sets[i],
                       batch_size=40,
                       shuffle=True,
                       num_workers=8))
        dis_test_data_loaders.append(
            DataLoader(dis_test_sets[i],
                       batch_size=16,
                       shuffle=False,
                       num_workers=1))

    net = GazeNet()
    net = DataParallel(net)
    net.cuda()

    #print(next(net.module.fpn_net.parameters()).is_cuda)
    ##print(next(net.module.fpn_net.parameters()).is_cuda)
    area_count = 8
    area_in_network = int(16 / area_count)
    cur_area_idx = 0
    fpn_weights_transferred = False
    # Touch every FPN head once so all of them end up on the GPU.
    for i in range(area_count):
        net.module.change_fpn(i)
        if not next(net.module.fpn_net.parameters()).is_cuda:
            net.module.fpn_net.cuda()
    net.module.change_fpn(cur_area_idx)
    ##print(next(net.module.fpn_net.parameters()).is_cuda)
    #exit(0)

    resume_training = False
    if resume_training:
        pretrained_dict = torch.load('../model/pretrained_model.pkl')
        model_dict = net.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        net.load_state_dict(model_dict)
        # NOTE(review): `test_data_loader` is undefined in this function —
        # enabling resume_training would raise NameError here; confirm.
        test(net, test_data_loader)
        exit()

    method = 'Adam'
    learning_rate = 0.0001

    # Stage 1 optimizer: face / eye-position / fusion branches only.
    optimizer_s1 = optim.Adam(
        [{
            'params': net.module.face_net.parameters(),
            'initial_lr': learning_rate
        }, {
            'params': net.module.face_process.parameters(),
            'initial_lr': learning_rate
        }, {
            'params': net.module.eye_position_transform.parameters(),
            'initial_lr': learning_rate
        }, {
            'params': net.module.fusion.parameters(),
            'initial_lr': learning_rate
        }],
        lr=learning_rate,
        weight_decay=0.0001)
    #optimizer_s2 = optim.Adam([{'params': net.module.fpn_net.parameters(),
    #                            'initial_lr': learning_rate}],
    #                          lr=learning_rate, weight_decay=0.0001)
    # One stage-2 (per-area FPN) and one stage-3 (full-network, 10x lower
    # LR) optimizer per FPN head.
    optimizer_s2s, optimizer_s3s = [], []
    for i in range(area_count):
        net.module.change_fpn(i)
        optimizer_s2 = optim.Adam(
            [{
                'params': net.module.fpn_nets[i].parameters(),
                'initial_lr': learning_rate
            }],
            lr=learning_rate,
            weight_decay=0.0001)
        optimizer_s3 = optim.Adam([{
            'params': net.parameters(),
            'initial_lr': learning_rate
        }],
                                  lr=learning_rate * 0.1,
                                  weight_decay=0.0001)
        optimizer_s2s.append(optimizer_s2)
        optimizer_s3s.append(optimizer_s3)
    optimizer_s2 = optimizer_s2s[0]
    optimizer_s3 = optimizer_s3s[0]

    lr_scheduler_s1 = optim.lr_scheduler.StepLR(optimizer_s1,
                                                step_size=5,
                                                gamma=0.1,
                                                last_epoch=-1)
    #lr_scheduler_s2 = optim.lr_scheduler.StepLR(optimizer_s2, step_size=5, gamma=0.1, last_epoch=-1)
    # Matching per-head LR schedulers for the stage-2/3 optimizers.
    lr_scheduler_s2s, lr_scheduler_s3s = [], []
    for i in range(area_count):
        lr_scheduler_s2 = optim.lr_scheduler.StepLR(optimizer_s2s[i],
                                                    step_size=5,
                                                    gamma=0.1,
                                                    last_epoch=-1)
        lr_scheduler_s3 = optim.lr_scheduler.StepLR(optimizer_s3s[i],
                                                    step_size=5,
                                                    gamma=0.1,
                                                    last_epoch=-1)
        lr_scheduler_s2s.append(lr_scheduler_s2)
        lr_scheduler_s3s.append(lr_scheduler_s3)
    lr_scheduler_s2 = lr_scheduler_s2s[0]
    lr_scheduler_s3 = lr_scheduler_s3s[0]

    # Set the model to use the first FPN
    net.module.change_fpn(cur_area_idx)

    max_epoch = 30

    epoch = 0
    #epoch = 7
    while epoch < max_epoch:
        logging.info('\n--- Epoch: %s\n' % str(epoch))
        # Switch optimizer/scheduler at the stage boundaries.
        if epoch == 0:
            lr_scheduler = lr_scheduler_s1
            optimizer = optimizer_s1
        elif epoch == 7:
            lr_scheduler = lr_scheduler_s2
            optimizer = optimizer_s2
        elif epoch == 15:
            lr_scheduler = lr_scheduler_s3
            optimizer = optimizer_s3

        #lr_scheduler.step()
        #lr_scheduler.step()

        running_loss = []

        #for data_loader_idx in range(len(dis_train_data_loaders)):
        for data_loader_idx in range(len(dis_train_data_loaders)):
            train_data_loader = dis_train_data_loaders[data_loader_idx]

            if epoch >= 10:
                #if epoch >= 7:
                # Copy the shared FPN weights into the per-area heads once,
                # then train each head on its own areas.
                if not fpn_weights_transferred:
                    net.module.transfer_fpn_weights()
                    fpn_weights_transferred = True

                # Route this sub-dataset to its FPN head and pick the
                # matching optimizer/scheduler for the current stage.
                area_idx = int(data_loader_idx / area_in_network)
                if cur_area_idx != area_idx:
                    cur_area_idx = area_idx
                    net.module.change_fpn(cur_area_idx)
                    if epoch < 15:
                        lr_scheduler = lr_scheduler_s2s[cur_area_idx]
                        optimizer = optimizer_s2s[cur_area_idx]
                    else:
                        lr_scheduler = lr_scheduler_s3s[cur_area_idx]
                        optimizer = optimizer_s3s[cur_area_idx]

            #if not next(net.module.fpn_net.parameters()).is_cuda:
            #    net.module.fpn_net.cuda()

            #test_data_loader = dis_test_data_loaders[data_loader_idx]
            #train_data_loader = DataLoader(dis_train_sets[data_loader_idx], batch_size=48,
            #                              shuffle=True, num_workers=2)
            #test_data_loaders = DataLoader(dis_test_sets[data_loader_idx], batch_size=32,
            #                              shuffle=False, num_workers=2)

            for i, data in tqdm(enumerate(train_data_loader)):
                image, face_image, gaze_field, eye_position, gt_position, gt_heatmap = \
                    data['image'], data['face_image'], data['gaze_field'], data['eye_position'], data['gt_position'], data['gt_heatmap']
                image, face_image, gaze_field, eye_position, gt_position, gt_heatmap = \
                    map(lambda x: x.cuda(), [image, face_image, gaze_field, eye_position, gt_position, gt_heatmap])
                # for var in [image, face_image, gaze_field, eye_position, gt_position]:
                #    print var.shape

                optimizer.zero_grad()

                direction, predict_heatmap = net(
                    [image, face_image, gaze_field, eye_position])

                heatmap_loss, m_angle_loss = \
                    F_loss(direction, predict_heatmap, eye_position, gt_position, gt_heatmap)

                # Loss selection mirrors the stage schedule above.
                if epoch == 0:
                    #if epoch < 7:
                    loss = m_angle_loss
                elif epoch >= 7 and epoch <= 14:
                    loss = heatmap_loss
                else:
                    loss = m_angle_loss + heatmap_loss

                loss.backward()
                optimizer.step()

                # running_loss.append([heatmap_loss.data[0],
                #                     m_angle_loss.data[0], loss.data[0]])
                running_loss.append(
                    [heatmap_loss.item(),
                     m_angle_loss.item(),
                     loss.item()])
                if i % 10 == 9:
                    logging.info('%s %s %s' %
                                 (str(np.mean(running_loss, axis=0)), method,
                                  str(lr_scheduler.get_last_lr())))
                    running_loss = []

        # Step the active scheduler once per epoch, after optimizer steps.
        lr_scheduler.step()
        epoch += 1

        save_path = '../model/two_stage_fpn_concat_multi_scale_' + method
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        if epoch % 5 == 0:
            torch.save(net.state_dict(),
                       save_path + '/model_epoch{}.pkl'.format(epoch))

        # Per-head FPN weights are saved every epoch.
        for i in range(16):
            torch.save(net.module.fpn_nets[i].state_dict(),
                       save_path + '/fpn_{}.pkl'.format(i))

        # Evaluate every sub-dataset with its own FPN head.
        for data_loader_idx in range(len(dis_test_data_loaders)):
            test_data_loader = dis_test_data_loaders[data_loader_idx]
            if epoch > 10:
                area_idx = int(data_loader_idx / area_in_network)
                net.module.change_fpn(area_idx)
                cur_area_idx = area_idx
            test(net, test_data_loader)