def main():
    ''' concise script for training '''
    # optional two command-line arguments
    path_indata = './DHF1K'
    path_output = './output'
    if len(sys.argv) > 1:
        path_indata = sys.argv[1]
        if len(sys.argv) > 2:
            path_output = sys.argv[2]

    # we checked that using only 2 gpus is enough to produce similar results
    num_gpu = 2
    pile = 5
    batch_size = 8
    num_iters = 1000
    len_temporal = 32
    file_weight = './S3D_kinetics400.pt'

    path_output = os.path.join(path_output, time.strftime("%m-%d_%H-%M-%S"))
    if not os.path.isdir(path_output):
        os.makedirs(path_output)

    model = TASED_v2()

    # load the weight file and copy the parameters
    if os.path.isfile(file_weight):
        print('loading weight file')
        weight_dict = torch.load(file_weight)
        model_dict = model.state_dict()
        for name, param in weight_dict.items():
            if 'module' in name:
                name = '.'.join(name.split('.')[1:])
            if 'base.' in name:
                # remap S3D 'base.<n>.*' parameter names onto the base1..base4 blocks of TASED_v2
                bn = int(name.split('.')[1])
                sn_list = [0, 5, 8, 14]
                sn = sn_list[0]
                if bn >= sn_list[1] and bn < sn_list[2]:
                    sn = sn_list[1]
                elif bn >= sn_list[2] and bn < sn_list[3]:
                    sn = sn_list[2]
                elif bn >= sn_list[3]:
                    sn = sn_list[3]
                name = '.'.join(name.split('.')[2:])
                name = 'base%d.%d.' % (sn_list.index(sn) + 1, bn - sn) + name
            if name in model_dict:
                if param.size() == model_dict[name].size():
                    model_dict[name].copy_(param)
                else:
                    print(' size? ' + name, param.size(), model_dict[name].size())
            else:
                print(' name? ' + name)
        print(' loaded')
    else:
        print('weight file?')

    # parameter setting for fine-tuning
    params = []
    for key, value in dict(model.named_parameters()).items():
        if 'convtsp' in key:
            params += [{'params': [value], 'key': key + '(new)'}]
        else:
            params += [{'params': [value], 'lr': 0.001, 'key': key}]

    optimizer = torch.optim.SGD(params, lr=0.1, momentum=0.9, weight_decay=2e-7)
    criterion = KLDLoss()

    model = model.cuda()
    model = torch.nn.DataParallel(model, device_ids=range(num_gpu))
    torch.backends.cudnn.benchmark = False
    model.train()

    train_loader = InfiniteDataLoader(DHF1KDataset(path_indata, len_temporal),
                                      batch_size=batch_size, shuffle=True, num_workers=24)

    i, step = 0, 0
    loss_sum = 0
    start_time = time.time()
    for clip, annt in islice(train_loader, num_iters * pile):
        with torch.set_grad_enabled(True):
            output = model(clip.cuda())
            loss = criterion(output, annt.cuda())

        loss_sum += loss.item()
        loss.backward()
        # accumulate gradients over `pile` mini-batches before each optimizer step
        if (i + 1) % pile == 0:
            optimizer.step()
            optimizer.zero_grad()
            step += 1

            # whole process takes less than 3 hours
            print('iteration: [%4d/%4d], loss: %.4f, %s'
                  % (step, num_iters, loss_sum / pile,
                     timedelta(seconds=int(time.time() - start_time))), flush=True)
            loss_sum = 0

            # adjust learning rate
            if step in [750, 950]:
                for opt in optimizer.param_groups:
                    if 'new' in opt['key']:
                        opt['lr'] *= 0.1

            if step % 25 == 0:
                torch.save(model.state_dict(), os.path.join(path_output, 'iter_%04d.pt' % step))

        i += 1
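
# --- Illustrative sketch (not part of the original script) -------------------
# The training loop above depends on KLDLoss, InfiniteDataLoader, DHF1KDataset,
# and TASED_v2, which are defined elsewhere in the repository. As a rough guide
# to what the criterion computes, a minimal KL-divergence loss between predicted
# and ground-truth saliency maps with the same call signature could look like
# the hypothetical KLDLossSketch below; it is an assumption, not the
# repository's exact implementation.
import torch
import torch.nn as nn

class KLDLossSketch(nn.Module):
    ''' KL divergence between normalized saliency maps (illustrative only). '''
    def __init__(self, eps=1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        # flatten each map and normalize it into a probability distribution
        pred = pred.view(pred.size(0), -1)
        target = target.view(target.size(0), -1)
        pred = pred / (pred.sum(dim=1, keepdim=True) + self.eps)
        target = target / (target.sum(dim=1, keepdim=True) + self.eps)
        # KL(target || pred), averaged over the batch
        kld = (target * torch.log(target / (pred + self.eps) + self.eps)).sum(dim=1)
        return kld.mean()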
def main():
    ''' read frames in path_indata and generate frame-wise saliency maps in path_output '''
    # optional two command-line arguments
    path_indata = './example'
    path_output = './output'
    if len(sys.argv) > 1:
        path_indata = sys.argv[1]
        if len(sys.argv) > 2:
            path_output = sys.argv[2]
    if not os.path.isdir(path_output):
        os.makedirs(path_output)

    len_temporal = 32
    file_weight = './TASED_updated.pt'

    model = TASED_v2()

    # load the weight file and copy the parameters
    if os.path.isfile(file_weight):
        print('loading weight file')
        weight_dict = torch.load(file_weight)
        model_dict = model.state_dict()
        for name, param in weight_dict.items():
            if 'module' in name:
                name = '.'.join(name.split('.')[1:])
            if name in model_dict:
                if param.size() == model_dict[name].size():
                    model_dict[name].copy_(param)
                else:
                    print(' size? ' + name, param.size(), model_dict[name].size())
            else:
                print(' name? ' + name)
        print(' loaded')
    else:
        print('weight file?')

    model = model.cuda()
    torch.backends.cudnn.benchmark = False
    model.eval()

    # iterate over the path_indata directory
    list_indata = [d for d in os.listdir(path_indata) if os.path.isdir(os.path.join(path_indata, d))]
    list_indata.sort()
    for dname in list_indata:
        print('processing ' + dname)
        list_frames = [f for f in os.listdir(os.path.join(path_indata, dname))
                       if os.path.isfile(os.path.join(path_indata, dname, f))]
        list_frames.sort()

        # process in a sliding-window fashion
        if len(list_frames) >= 2 * len_temporal - 1:
            path_outdata = os.path.join(path_output, dname)
            if not os.path.isdir(path_outdata):
                os.makedirs(path_outdata)

            snippet = []
            for i in range(len(list_frames)):
                img = cv2.imread(os.path.join(path_indata, dname, list_frames[i]))
                img = cv2.resize(img, (384, 224))
                img = img[..., ::-1]  # BGR -> RGB
                snippet.append(img)

                if i >= len_temporal - 1:
                    clip = transform(snippet)

                    process(model, clip, path_outdata, i)

                    # process the first (len_temporal-1) frames with a temporally flipped clip
                    if i < 2 * len_temporal - 2:
                        process(model, torch.flip(clip, [1]), path_outdata, i - len_temporal + 1)

                    del snippet[0]
        else:
            print(' more frames are needed')
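
# --- Illustrative sketch (not part of the original script) -------------------
# The sliding-window loop above calls transform() and process(), which are
# defined elsewhere in this script. Assuming transform() stacks the buffered RGB
# frames into a (3, T, H, W) float tensor scaled to [-1, 1] (so that
# torch.flip(clip, [1]) above reverses the temporal order), a plausible sketch
# is shown below; it is an assumption, not the script's exact implementation.
import numpy as np
import torch

def transform_sketch(snippet):
    ''' stack a list of H x W x 3 RGB frames into a (3, T, H, W) clip tensor (illustrative only). '''
    clip = np.stack(snippet, axis=0).astype(np.float32)  # (T, H, W, 3)
    clip = torch.from_numpy(clip).permute(3, 0, 1, 2)    # (3, T, H, W); dim 1 is time
    clip = clip / 255.0 * 2.0 - 1.0                      # scale pixel values to [-1, 1]
    return clip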
def main():
    ''' read frames of the videos listed in the config and generate frame-wise saliency maps in path_output '''
    # single command-line argument: path to the YAML config
    args = parse_args()
    config_path = args.config
    config = edict(yaml.safe_load(open(config_path)))  # safe_load: yaml.load without a Loader is deprecated

    path_output = config.output_path
    model_path = config.model_path
    if not os.path.isdir(path_output):
        os.makedirs(path_output)

    len_temporal = 32
    file_weight = os.path.join(model_path, 'TASED_v2.pt')

    model = TASED_v2()

    # load the weight file and copy the parameters
    if os.path.isfile(file_weight):
        print('loading weight file')
        weight_dict = torch.load(file_weight)
        model_dict = model.state_dict()
        for name, param in weight_dict.items():
            if 'module' in name:
                name = '.'.join(name.split('.')[1:])
            if name in model_dict:
                if param.size() == model_dict[name].size():
                    model_dict[name].copy_(param)
                else:
                    print(' size? ' + name, param.size(), model_dict[name].size())
            else:
                print(' name? ' + name)
        print(' loaded')
    else:
        print('weight file?')

    model = model.cuda()
    torch.backends.cudnn.benchmark = False
    model.eval()

    # iterate over the videos listed in the config
    list_video = [os.path.join(config.data.base_path, v) for v in config.data.video_list]
    for i_vname, vname in enumerate(list_video):
        # one encoded version per quality level of the video
        for dash_idx in range(config.data.quailty_num[i_vname]):
            smaps = []
            v_name = vname + str(dash_idx).zfill(2)
            print('processing ' + v_name)
            v_name = v_name + '.' + config.data.video_suffix

            capture = cv2.VideoCapture(v_name)
            read_flag, img = capture.read()
            i = 0

            # process in a sliding-window fashion;
            # each video is assumed to have more than 2*len_temporal frames
            path_outdata = os.path.join(path_output, v_name.split('/')[-1].split('.')[0])
            encoded_vid_path = os.path.join(path_outdata, "sailency.mp4")  # currently unused
            saliency_map_path = os.path.join(path_outdata, "sailency")
            if not os.path.isdir(path_outdata):
                os.makedirs(path_outdata)
            f, axarr = plt.subplots(1, 2, figsize=(10, 3))

            snippet = []
            while read_flag:
                img = cv2.resize(img, (384, 224))
                img = img[..., ::-1]  # BGR -> RGB
                snippet.append(img)

                if i >= len_temporal - 1:
                    clip = transform(snippet)

                    smaps.append(process(model, clip, path_outdata, i, snippet[-1], f, axarr))

                    # process the first (len_temporal-1) frames with a temporally flipped clip
                    if i < 2 * len_temporal - 2:
                        smaps.append(process(model, torch.flip(clip, [1]), path_outdata,
                                             i - len_temporal + 1, snippet[-1], f, axarr))

                    del snippet[0]

                read_flag, img = capture.read()
                i += 1

            capture.release()
            smaps = np.asarray(smaps)
            np.save(saliency_map_path, smaps)
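
# --- Illustrative sketch (not part of the original script) -------------------
# parse_args() and the layout of the YAML config are not shown here. The main()
# above reads config.output_path, config.model_path, config.data.base_path,
# config.data.video_list, config.data.quailty_num, and config.data.video_suffix.
# A minimal, hypothetical parser exposing the single --config argument used
# above could look like:
import argparse

def parse_args_sketch():
    ''' minimal CLI: one --config argument pointing at the YAML file (illustrative only). '''
    parser = argparse.ArgumentParser(description='TASED-Net video saliency inference')
    parser.add_argument('--config', type=str, default='./config.yaml',
                        help='YAML file providing output_path, model_path and a data section')
    return parser.parse_args()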