Ejemplo n.º 1
0
    model = EventDetector(pretrain=True,
                          width_mult=1.,
                          lstm_layers=1,
                          lstm_hidden=256,
                          device=device,
                          bidirectional=True,
                          dropout=False,
                          use_no_element=use_no_element
                          )
    #print('model.py, class EventDetector()')

    freeze_layers(k, model)
    #print('utils.py, func freeze_laters()')
    model.train()
    model.to(device)
    print('Loading Data')


    # TODO: vid_dirのpathをかえる。stsqの動画を切り出したimage全部が含まれているdirにする
    if use_no_element == False:
        dataset = StsqDB(data_file='data/no_ele/seq_length_{}/train_split_{}.pkl'.format(args.seq_length, args.split),
                        vid_dir='/home/akiho/projects/golfdb/data/videos_40/',
                        seq_length=int(seq_length),
                        transform=transforms.Compose([ToTensor(),
                                                    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
                        train=True)
    else:
        dataset = StsqDB(data_file='data/seq_length_{}/train_split_{}.pkl'.format(args.seq_length, args.split),
                    vid_dir='/home/akiho/projects/golfdb/data/videos_40/',
                    seq_length=int(seq_length),
Ejemplo n.º 2
0
        # transform=transforms.Compose(
        #     [ToTensor(), Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
        # ),
        train=True,
        base_list=base_list,
        event_th=event_th,
        median=True,
    )

    model.cnn[0][0] = nn.Conv2d(len(base_list),
                                32,
                                3,
                                stride=2,
                                padding=1,
                                bias=False)
    model = model.to(device)
    # freeze_layers(k, model)
    model.train()
    # model.cuda()

    data_loader = DataLoader(dataset,
                             batch_size=bs,
                             shuffle=True,
                             num_workers=n_cpu,
                             drop_last=True)

    # the 8 golf swing events are classes 0 through 7, no-event is class 8
    # the ratio of events to no-events is approximately 1:35 so weight classes accordingly:
    weights = torch.FloatTensor(
        [1 / 8, 1 / 8, 1 / 8, 1 / 8, 1 / 8, 1 / 8, 1 / 8, 1 / 8,
         1 / 35]).to(device)