Example #1
import sys
import os
import time
from datetime import timedelta
from itertools import islice

import torch

# assumed repo-level modules (paths not shown on this page):
# TASED_v2 (model), KLDLoss (loss), DHF1KDataset and InfiniteDataLoader (data)
from model import TASED_v2
from loss import KLDLoss
from dataset import DHF1KDataset, InfiniteDataLoader


def main():
    ''' concise script for training '''
    # optional two command-line arguments
    path_indata = './DHF1K'
    path_output = './output'
    if len(sys.argv) > 1:
        path_indata = sys.argv[1]
        if len(sys.argv) > 2:
            path_output = sys.argv[2]

    # we checked that using only 2 GPUs is enough to produce similar results
    num_gpu = 2
    pile = 5  # number of mini-batches to accumulate gradients over per optimizer step
    batch_size = 8
    num_iters = 1000  # total number of optimizer steps
    len_temporal = 32  # number of frames per input clip
    file_weight = './S3D_kinetics400.pt'
    path_output = os.path.join(path_output, time.strftime("%m-%d_%H-%M-%S"))
    if not os.path.isdir(path_output):
        os.makedirs(path_output)

    model = TASED_v2()

    # load the weight file and copy the parameters
    if os.path.isfile(file_weight):
        print('loading weight file')
        weight_dict = torch.load(file_weight)
        model_dict = model.state_dict()
        for name, param in weight_dict.items():
            if 'module' in name:
                # strip the 'module.' prefix left by nn.DataParallel checkpoints
                name = '.'.join(name.split('.')[1:])
            if 'base.' in name:
                # remap 'base.<n>.' parameters of the S3D Kinetics-400 checkpoint
                # onto the four encoder stages base1..base4 of TASED_v2
                bn = int(name.split('.')[1])
                sn_list = [0, 5, 8, 14]
                sn = sn_list[0]
                if bn >= sn_list[1] and bn < sn_list[2]:
                    sn = sn_list[1]
                elif bn >= sn_list[2] and bn < sn_list[3]:
                    sn = sn_list[2]
                elif bn >= sn_list[3]:
                    sn = sn_list[3]
                name = '.'.join(name.split('.')[2:])
                name = 'base%d.%d.' % (sn_list.index(sn) + 1, bn - sn) + name
            if name in model_dict:
                if param.size() == model_dict[name].size():
                    model_dict[name].copy_(param)
                else:
                    print(' size? ' + name, param.size(),
                          model_dict[name].size())
            else:
                print(' name? ' + name)

        print(' loaded')
    else:
        print('weight file?')

    # parameter setting for fine-tuning:
    # new decoder ('convtsp') layers use the optimizer's base lr (0.1),
    # while the pretrained encoder layers are fine-tuned with a smaller lr (0.001)
    params = []
    for key, value in dict(model.named_parameters()).items():
        if 'convtsp' in key:
            params += [{'params': [value], 'key': key + '(new)'}]
        else:
            params += [{'params': [value], 'lr': 0.001, 'key': key}]

    optimizer = torch.optim.SGD(params,
                                lr=0.1,
                                momentum=0.9,
                                weight_decay=2e-7)
    criterion = KLDLoss()

    model = model.cuda()
    model = torch.nn.DataParallel(model, device_ids=range(num_gpu))
    torch.backends.cudnn.benchmark = False
    model.train()

    train_loader = InfiniteDataLoader(DHF1KDataset(path_indata, len_temporal),
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=24)

    i, step = 0, 0
    loss_sum = 0
    start_time = time.time()
    for clip, annt in islice(train_loader, num_iters * pile):
        with torch.set_grad_enabled(True):
            output = model(clip.cuda())
            loss = criterion(output, annt.cuda())

        loss_sum += loss.item()
        loss.backward()
        if (i + 1) % pile == 0:
            optimizer.step()
            optimizer.zero_grad()
            step += 1

            # whole process takes less than 3 hours
            print('iteration: [%4d/%4d], loss: %.4f, %s' %
                  (step, num_iters, loss_sum / pile,
                   timedelta(seconds=int(time.time() - start_time))),
                  flush=True)
            loss_sum = 0

            # adjust learning rate
            if step in [750, 950]:
                for opt in optimizer.param_groups:
                    if 'new' in opt['key']:
                        opt['lr'] *= 0.1

            if step % 25 == 0:
                torch.save(model.state_dict(),
                           os.path.join(path_output, 'iter_%04d.pt' % step))

        i += 1
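
The training loop above accumulates gradients over pile mini-batches before each optimizer step and uses KLDLoss, whose definition is not shown on this page. Below is a minimal sketch of such a KL-divergence loss between a predicted saliency map and a ground-truth fixation density map, assuming each map only needs per-sample normalization; the repo's actual loss.py may differ in normalization details and numerical-stability constants.

import torch
import torch.nn as nn

class KLDLoss(nn.Module):
    ''' sketch of a KL-divergence saliency loss (not the repo's exact code) '''
    def __init__(self, eps=1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        # flatten spatial dimensions and normalize each map to sum to 1
        pred = pred.view(pred.size(0), -1)
        target = target.view(target.size(0), -1)
        pred = pred / (pred.sum(dim=1, keepdim=True) + self.eps)
        target = target / (target.sum(dim=1, keepdim=True) + self.eps)
        # KL(target || pred), averaged over the batch
        kld = (target * torch.log(target / (pred + self.eps) + self.eps)).sum(dim=1)
        return kld.mean()
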
Example #2
import sys
import os

import cv2
import torch

# assumed: TASED_v2 comes from the repo's model module; transform() and
# process() are helper functions defined elsewhere in this script
from model import TASED_v2


def main():
    ''' read frames in path_indata and generate frame-wise saliency maps in path_output '''
    # optional two command-line arguments
    path_indata = './example'
    path_output = './output'
    if len(sys.argv) > 1:
        path_indata = sys.argv[1]
        if len(sys.argv) > 2:
            path_output = sys.argv[2]
    if not os.path.isdir(path_output):
        os.makedirs(path_output)

    len_temporal = 32
    file_weight = './TASED_updated.pt'

    model = TASED_v2()

    # load the weight file and copy the parameters
    if os.path.isfile(file_weight):
        print ('loading weight file')
        weight_dict = torch.load(file_weight)
        model_dict = model.state_dict()
        for name, param in weight_dict.items():
            if 'module' in name:
                name = '.'.join(name.split('.')[1:])
            if name in model_dict:
                if param.size() == model_dict[name].size():
                    model_dict[name].copy_(param)
                else:
                    print (' size? ' + name, param.size(), model_dict[name].size())
            else:
                print (' name? ' + name)

        print (' loaded')
    else:
        print ('weight file?')

    model = model.cuda()
    torch.backends.cudnn.benchmark = False
    model.eval()

    # iterate over the path_indata directory
    list_indata = [d for d in os.listdir(path_indata) if os.path.isdir(os.path.join(path_indata, d))]
    list_indata.sort()
    for dname in list_indata:
        print ('processing ' + dname)
        list_frames = [f for f in os.listdir(os.path.join(path_indata, dname)) if os.path.isfile(os.path.join(path_indata, dname, f))]
        list_frames.sort()

        # process in a sliding window fashion
        if len(list_frames) >= 2*len_temporal-1:
            path_outdata = os.path.join(path_output, dname)
            if not os.path.isdir(path_outdata):
                os.makedirs(path_outdata)

            snippet = []
            for i in range(len(list_frames)):
                img = cv2.imread(os.path.join(path_indata, dname, list_frames[i]))
                img = cv2.resize(img, (384, 224))
                img = img[...,::-1]  # BGR -> RGB
                snippet.append(img)

                if i >= len_temporal-1:
                    clip = transform(snippet)

                    process(model, clip, path_outdata, i)

                    # the first (len_temporal-1) frames lack a full temporal
                    # window of their own, so they are predicted from the
                    # temporally flipped clip
                    if i < 2*len_temporal-2:
                        process(model, torch.flip(clip, [1]), path_outdata, i-len_temporal+1)

                    del snippet[0]

        else:
            print(' skipped: at least 2*len_temporal-1 frames are needed')
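
Example #2 relies on two helpers, transform and process, that are defined elsewhere in the same script and are not shown on this page. The sketch below illustrates what they might look like, assuming the model takes a (1, 3, T, H, W) clip scaled to [-1, 1] and returns one saliency map per clip; the repo's actual helpers may normalize and post-process differently (e.g. with Gaussian smoothing).

import os
import cv2
import numpy as np
import torch

def transform(snippet):
    ''' sketch: stack a list of H x W x 3 RGB frames into a (1, 3, T, H, W) float tensor '''
    clip = np.stack(snippet, axis=0).astype(np.float32)   # (T, H, W, 3)
    clip = torch.from_numpy(clip).permute(3, 0, 1, 2)     # (3, T, H, W)
    clip = clip / 255.0 * 2.0 - 1.0                       # scale to [-1, 1]
    return clip.unsqueeze(0)                              # add batch dimension

def process(model, clip, path_outdata, idx):
    ''' sketch: predict one saliency map and save it as a PNG named after the frame index '''
    with torch.no_grad():
        smap = model(clip.cuda()).cpu().data[0]
    smap = smap.numpy()
    smap = (smap / (smap.max() + 1e-7) * 255.0).astype(np.uint8)
    cv2.imwrite(os.path.join(path_outdata, '%04d.png' % (idx + 1)), smap)
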
Example #3
import os

import cv2
import numpy as np
import torch
import yaml
import matplotlib.pyplot as plt
from easydict import EasyDict as edict

# assumed: TASED_v2 comes from the repo's model module; parse_args(), transform()
# and process() are defined elsewhere in this script
from model import TASED_v2


def main():
    ''' read the videos listed in the config and generate frame-wise saliency maps in path_output '''
    # read paths from the YAML config given on the command line
    args = parse_args()
    config_path = args.config
    with open(config_path) as f:
        config = edict(yaml.safe_load(f))
    path_output = config.output_path
    model_path = config.model_path

    if not os.path.isdir(path_output):
        os.makedirs(path_output)

    len_temporal = 32
    file_weight = os.path.join(model_path, 'TASED_v2.pt')

    model = TASED_v2()
    # load the weight file and copy the parameters
    if os.path.isfile(file_weight):
        print('loading weight file')
        weight_dict = torch.load(file_weight)
        model_dict = model.state_dict()
        for name, param in weight_dict.items():
            if 'module' in name:
                name = '.'.join(name.split('.')[1:])
            if name in model_dict:
                if param.size() == model_dict[name].size():
                    model_dict[name].copy_(param)
                else:
                    print(' size? ' + name, param.size(),
                          model_dict[name].size())
            else:
                print(' name? ' + name)

        print(' loaded')
    else:
        print('weight file?')

    model = model.cuda()
    torch.backends.cudnn.benchmark = False
    model.eval()

    # iterate over the videos listed in the config
    list_video = [
        os.path.join(config.data.base_path, v) for v in config.data.video_list
    ]
    for i_vname, vname in enumerate(list_video):
        for dash_idx in range(config.data.quailty_num[i_vname]):
            smaps = []
            v_name = vname + str(dash_idx).zfill(2)
            print('processing ' + v_name)
            v_name = v_name + '.' + config.data.video_suffix
            capture = cv2.VideoCapture(v_name)
            read_flag, img = capture.read()
            i = 0
            # process in a sliding window fashion
            # assume every video has at least 2*len_temporal-1 frames
            path_outdata = os.path.join(path_output,
                                        v_name.split('/')[-1].split('.')[0])
            encoded_vid_path = os.path.join(path_outdata, "sailency.mp4")
            saliency_map_path = os.path.join(path_outdata, "sailency")

            if not os.path.isdir(path_outdata):
                os.makedirs(path_outdata)

            f, axarr = plt.subplots(1, 2, figsize=(10, 3))
            snippet = []
            while read_flag:
                img = cv2.resize(img, (384, 224))
                img = img[..., ::-1]  # BGR -> RGB
                snippet.append(img)

                if i >= len_temporal - 1:
                    clip = transform(snippet)

                    smaps.append(
                        process(model, clip, path_outdata, i, snippet[-1], f,
                                axarr))

                    # process first (len_temporal-1) frames
                    if i < 2 * len_temporal - 2:
                        smaps.append(
                            process(model, torch.flip(clip, [1]), path_outdata,
                                    i - len_temporal + 1, snippet[-1], f,
                                    axarr))

                    del snippet[0]
                read_flag, img = capture.read()
                i += 1
            capture.release()
            plt.close(f)  # release the figure created for this video
            smaps = np.asarray(smaps)
            np.save(saliency_map_path, smaps)
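
Example #3 expects a YAML config that exposes output_path, model_path, and a data section with base_path, video_list, quailty_num (spelled as in the script), and video_suffix. The snippet below builds a hypothetical config of that shape; all paths and video names are placeholders, not values from the source.

import yaml
from easydict import EasyDict as edict

example_yaml = """
output_path: ./output
model_path: ./models
data:
  base_path: ./videos
  video_list: [clip_a_, clip_b_]   # per-video name prefixes; '00', '01', ... and the suffix are appended
  quailty_num: [3, 3]              # number of quality variants per video
  video_suffix: mp4
"""

config = edict(yaml.safe_load(example_yaml))
print(config.data.video_list, config.data.quailty_num)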