Exemplo n.º 1
0
    # Path handed to PanopticDataset through its close_cams_file argument.
    close_cams_file = '/home/yogesh/kara/data/panoptic/closecams.list'

    # Clip / loader settings used by both datasets below.
    # NOTE(review): CHANNELS and VIEW_DIST_MAX are not referenced anywhere in
    # this section; confirm whether they are needed.
    RANDOM_ALL = True
    BATCH_SIZE = 32
    CHANNELS = 3
    FRAMES = 8
    SKIP_LEN = 2
    HEIGHT = 112
    WIDTH = 112
    PRECROP = False
    CLOSE_VIEWS = True
    VIEW_DIST_MAX = 10

    trainset = PanopticDataset(root_dir=data_root_dir, data_file=train_split,
                               resize_height=HEIGHT, resize_width=WIDTH,
                               clip_len=FRAMES, skip_len=SKIP_LEN,
                               random_all=RANDOM_ALL, close_views=CLOSE_VIEWS,
                               close_cams_file=close_cams_file, precrop=PRECROP)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                              shuffle=False, num_workers=2)

    # NOTE(review): the test set is built from train_split as well; no separate
    # test split is visible in this section -- confirm that this is intended.
    testset = PanopticDataset(root_dir=data_root_dir, data_file=train_split,
                              resize_height=HEIGHT, resize_width=WIDTH,
                              clip_len=FRAMES, skip_len=SKIP_LEN,
                              random_all=RANDOM_ALL, close_views=CLOSE_VIEWS,
                              close_cams_file=close_cams_file, precrop=PRECROP)
    # The loader must wrap testset; wrapping trainset a second time would leave
    # testset constructed but never iterated.
    testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                             shuffle=False, num_workers=2)

    print('TRAINING...')
    # Each batch unpacks into three tensors; only their sizes are reported.
    for batch_idx, (vp_diff, vid1, vid2) in enumerate(trainloader):
        print('{} {} {} {}'.format(batch_idx, vp_diff.size(), vid1.size(), vid2.size()))
Exemplo n.º 2
0
                             clip_len=FRAMES, skip_len=SKIP_LEN,
                             random_all=RANDOM_ALL, precrop=PRECROP)
        testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    elif DATASET.lower() == 'panoptic':
        # All paths for the panoptic run come from panoptic_config().
        data_root_dir, test_split, weights_path, output_video_dir = panoptic_config()

        # generator
        model = FullNetwork(vp_value_count=3,
                            output_shape=(BATCH_SIZE, CHANNELS, FRAMES, HEIGHT, WIDTH))
        # map_location remaps the checkpoint's tensors onto the device used by
        # this run, so weights saved on a GPU can still be restored when that
        # GPU is not available (e.g. a CPU-only machine).
        model.load_state_dict(torch.load(weights_path, map_location=device))
        model = model.to(device)

        # NOTE(review): the DataParallel wrapper is bound to `net`, while the
        # rest of this section keeps using `model`; confirm `net` is consumed
        # elsewhere. Also presumes `device` is the plain string 'cuda' -- verify.
        if device == 'cuda':
            net = torch.nn.DataParallel(model)
            cudnn.benchmark = True

        criterion = nn.MSELoss()

        testset = PanopticDataset(root_dir=data_root_dir, data_file=test_split,
                                  resize_height=HEIGHT, resize_width=WIDTH,
                                  clip_len=FRAMES, skip_len=SKIP_LEN,
                                  random_all=RANDOM_ALL, precrop=PRECROP)
        testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                                 shuffle=False, num_workers=2)

    else:
        print('This network has only been set up to run on the NTU and panoptic datasets.')

    print_params()
    print(model)
    test_model()
Exemplo n.º 3
0
                             random_all=RANDOM_ALL, precrop=PRECROP)
        testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    elif DATASET.lower() == 'panoptic':
        # All paths for the panoptic run come from panoptic_config().
        data_root_dir, test_split, weights_path, output_video_dir = panoptic_config()

        # generator
        model = FullNetwork(vp_value_count=3,
                            output_shape=(BATCH_SIZE, CHANNELS, FRAMES, HEIGHT, WIDTH))
        # map_location remaps the checkpoint's tensors onto the device used by
        # this run, so weights saved on a GPU can still be restored when that
        # GPU is not available (e.g. a CPU-only machine).
        model.load_state_dict(torch.load(weights_path, map_location=device))
        model = model.to(device)

        # NOTE(review): the DataParallel wrapper is bound to `net`, while the
        # rest of this section keeps using `model`; confirm `net` is consumed
        # elsewhere. Also presumes `device` is the plain string 'cuda' -- verify.
        if device == 'cuda':
            net = torch.nn.DataParallel(model)
            cudnn.benchmark = True

        criterion = nn.MSELoss()

        testset = PanopticDataset(root_dir=data_root_dir, data_file=test_split,
                                  resize_height=HEIGHT, resize_width=WIDTH,
                                  clip_len=FRAMES, skip_len=SKIP_LEN,
                                  random_all=RANDOM_ALL, close_views=CLOSE_VIEWS,
                                  view_dist_max=VIEW_DIST_MAX,
                                  precrop=PRECROP)
        testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                                 shuffle=False, num_workers=2)

    else:
        print('This network has only been set up to run on the NTU and panoptic datasets.')

    print_params()
    print(model)
    test_model()