def setup_dataloaders(cfg, tokenizer): LOGGER.info("Init. train_loader and val_loader...") train_loaders = {} for db in cfg.train_datasets: train_loaders[db.name] = mk_captions_pretrain_dataloader( dataset_name=db.name, vis_format=db.vis_format, anno_path=db.txt, img_lmdb_dir=db.img, cfg=cfg, tokenizer=tokenizer, is_train=True) if "ratio" in db: train_loaders[db.name] = (train_loaders[db.name], db.ratio) val_loaders = {} for db in cfg.val_datasets: val_loaders[db.name] = mk_captions_pretrain_dataloader( dataset_name=db.name, vis_format=db.vis_format, anno_path=db.txt, img_lmdb_dir=db.img, cfg=cfg, tokenizer=tokenizer, is_train=False) return train_loaders, val_loaders
def __init__(self, opts, model, optimizer):
    if exists(f"{opts.output_dir}/log/args.json"):
        restore_opts = json.load(
            open(f'{opts.output_dir}/log/args.json', 'r'))
        with open(join(
                opts.output_dir, 'log', 'restore_args.json'), 'w') as writer:
            json.dump(vars(opts), writer, indent=4)
        # assert opts == edict(restore_opts)

    # keep 2 checkpoints in case of corrupted
    self.save_path = f'{opts.output_dir}/restore.pt'
    self.backup_path = f'{opts.output_dir}/restore_backup.pt'
    self.model = model
    self.optimizer = optimizer
    self.save_steps = int(opts.save_steps_ratio * opts.num_train_steps)
    self.amp = opts.fp16
    # since saving to or loading from azure blob fails sometimes
    self.max_save_load_trial = 10
    if exists(self.save_path) or exists(self.backup_path):
        LOGGER.info('found previous checkpoint. try to resume...')
        # with retrial, as azure blob fails occasionally.
        restore_trial = 0
        while restore_trial < self.max_save_load_trial:
            LOGGER.info(f"TrainingRestorer restore trial NO. {restore_trial}")
            try:
                self.restore(opts)
                break
            except Exception as e:
                restore_trial += 1
    else:
        self.global_step = 0
def __init__(self, opts, **ckpt_dict):
    if exists(opts.output_dir):
        restore_opts = json.load(open(
            f'{opts.output_dir}/log/args.json', 'r'))
        assert opts == edict(restore_opts)

    # keep 2 checkpoints in case of corrupted
    self.save_path = f'{opts.output_dir}/restore.pt'
    self.backup_path = f'{opts.output_dir}/restore_backup.pt'
    self.ckpt_dict = ckpt_dict
    self.save_steps = opts.save_steps
    self.amp = opts.fp16
    # since saving to or loading from azure blob fails sometimes
    self.max_save_load_trial = 10
    if exists(self.save_path) or exists(self.backup_path):
        LOGGER.info('found previous checkpoint. try to resume...')
        # with retrial, as azure blob fails occasionally.
        restore_trial = 0
        while restore_trial < self.max_save_load_trial:
            LOGGER.info(f"TrainingRestorer restore trial NO. {restore_trial}")
            try:
                self.restore()
                break
            except Exception as e:
                restore_trial += 1
    else:
        self.global_step = 0
def generate(self, history):
    try:
        print(history)
        history_ids = [self.tokenizer.encode(v) for v in history]
        input_ids = [self.tokenizer.cls_token_id]
        for history_id, history_utr in enumerate(history_ids):
            input_ids.extend(history_utr)
            input_ids.append(self.tokenizer.sep_token_id)
        # print(history_ids)
        # print(input_ids)
        input_ids = [copy.deepcopy(input_ids) for _ in range(self.batch_size)]
        curr_input_tensors = torch.tensor(input_ids).long().to(self.device)
        candidate_responses = self._make_dialogue_response(curr_input_tensors)
        assert len(candidate_responses) >= 1
        best_response_ids = self._make_mmi_output(candidate_responses, history_ids)
        best_response_chars = self.tokenizer.convert_ids_to_tokens(best_response_ids)
        return best_response_chars
    except Exception as e:
        LOGGER.error("FAIL GEN: {}".format(str(e)))
        traceback.print_exc()
        return []
def __init__(self):
    try:
        self.device = 'cuda' if config_model.use_cuda else 'cpu'
        LOGGER.info('using device: {}'.format(self.device))
        if self.device == 'cuda':
            os.environ["CUDA_VISIBLE_DEVICES"] = config_model.device_nums
        self.tokenizer = BertTokenizer(config_model.vocab_path)
        # dialogue model
        self.dialogue_model = GPT2LMHeadModel.from_pretrained(config_model.dialogue_model_path)
        self.dialogue_model.to(self.device)
        self.dialogue_model.eval()
        # mmi model
        self.mmi_model = GPT2LMHeadModel.from_pretrained(config_model.mmi_model_path)
        self.mmi_model.to(self.device)
        self.mmi_model.eval()
        self.max_sequence_len = config_model.max_len
        self.batch_size = config_model.batch_size
        self.repetition_penalty = config_model.repetition_penalty
        self.temperature = config_model.temperature
        self.debug = config_model.debug
        self.topk = config_model.topk
        self.topp = config_model.topp
    except Exception as e:
        LOGGER.error("FAIL INIT: {}".format(str(e)))
        traceback.print_exc()
        sys.exit(-1)
def __getitem__(self, index):
    if self.vis_format == "image":
        # one image/video with multiple examples
        vis_id, examples = self.datalist[index]
        img_array = self._load_img(vis_id)  # tensor, (T=1, C, H, W)
    else:  # video
        num_retries = 3  # skip error videos
        for _ in range(num_retries):
            vis_id, examples = self.datalist[index]
            img_array, _ = self._load_video(vis_id)  # tensor, (T=num_frm, C, H, W)
            # Select a random video if the current video could not be accessed.
            if img_array is None:
                LOGGER.info(
                    f"Failed to load examples with video: {vis_id}. "
                    f"Will randomly sample an example as a replacement.")
                index = random.randint(0, len(self) - 1)
                continue
            else:
                break
        else:
            raise RuntimeError(
                f"Failed to fetch video after {num_retries} retries.")
    examples = [self._get_single_example(e, index) for e in examples]
    return dict(
        img=img_array,  # (T, C, H, W)
        examples=examples,
        n_examples=len(examples)  # used to create image feature copies.
    )
def get_history(self, session_id):
    try:
        if session_id not in self.history_dict or "history" not in self.history_dict[session_id]:
            return []
        else:
            return self.history_dict[session_id]["history"][-self.max_history_len:]
    except Exception as e:
        LOGGER.error("FAIL get history: session_id: {}, error: {}".format(str(session_id), str(e)))
        return []
def validate(model, val_loader, eval_loader, cfg, train_global_step,
             eval_filepath):
    """use eval_score=False when doing inference on test sets where answers are not available"""
    model.eval()

    loss = 0.
    n_ex = 0
    n_corrects = 0
    st = time.time()
    debug_step = 5
    for val_step, batch in enumerate(val_loader):
        # forward pass
        del batch["caption_ids"]
        outputs = forward_step(model, batch, cfg)
        targets = batch['labels']
        loss += outputs["loss"].sum().item() if isinstance(
            outputs["loss"], torch.Tensor) else 0
        n_ex += len(targets)

        if outputs["logits"].shape[1] == 2:
            n_corrects += (outputs["logits"].max(
                dim=-1)[1] == targets).sum().item()
        else:
            predictions = (torch.sigmoid(outputs["logits"]) > 0.5).long()
            predictions = predictions.view(outputs["loss"].shape[0], -1)
            targets = targets.view(outputs["loss"].shape[0], -1)
            matched = predictions[:, 0].squeeze() == targets[:, 0].squeeze()
            n_corrects += matched.sum().item()

        if cfg.debug and val_step >= debug_step:
            break

    loss = sum(all_gather_list(loss))
    n_ex = sum(all_gather_list(n_ex))
    n_corrects = sum(all_gather_list(n_corrects))

    _, retrieval_metrics = inference_retrieval(
        model, eval_loader, eval_filepath, cfg)

    model.train()

    if hvd.rank() == 0:
        # average loss for each example
        acc = float(n_corrects / n_ex)
        val_log = {'valid/loss': float(loss / n_ex), 'valid/acc': acc}
        for ret_type, ret_m in retrieval_metrics.items():
            val_log.update({
                f"valid/{ret_type}_{k}": round(v, 4)
                for k, v in ret_m.items()
            })
        TB_LOGGER.log_scalar_dict(val_log)
        LOGGER.info(f"validation finished in {int(time.time() - st)} seconds. "
                    f"itm_acc: {acc}. Retrieval res {retrieval_metrics}")
def restore(self):
    try:
        checkpoint = torch.load(self.save_path)
    except Exception:
        checkpoint = torch.load(self.backup_path)
    self.global_step = checkpoint['global_step']
    for k in self.ckpt_dict:
        self.ckpt_dict[k].load_state_dict(_to_cuda(checkpoint[k]))
    if self.amp:
        amp.load_state_dict(checkpoint['amp_state_dict'])
    LOGGER.info(f'resume training from step {self.global_step}')
def step(self):
    self.global_step += 1
    if self.global_step % self.save_steps == 0:
        # with retrial, as azure blob fails occasionally.
        save_trial = 0
        while save_trial < self.max_save_load_trial:
            LOGGER.info(f"TrainingRestorer save trial NO. {save_trial}")
            try:
                self.save()
                break
            except Exception as e:
                save_trial += 1
def restore(self, opts):
    try:
        checkpoint = torch.load(self.save_path)
    except Exception:
        checkpoint = torch.load(self.backup_path)
    self.global_step = checkpoint['global_step']
    self.model.load_state_dict(_to_cuda(checkpoint['model_state_dict']))
    self.optimizer.load_state_dict(
        _to_cuda(checkpoint['optim_state_dict']))
    if self.amp:
        amp.load_state_dict(checkpoint['amp_state_dict'])
    LOGGER.info(f'resume training from step {self.global_step}')
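# Note: both restore() variants above call a `_to_cuda` helper that does not appear in
# this section. The sketch below is an assumption of what such a helper typically does
# (recursively moving checkpoint tensors to the GPU before load_state_dict), not the
# repo's verbatim implementation.
import torch


def _to_cuda(state):
    """Recursively move all tensors in a (possibly nested) checkpoint structure to CUDA."""
    if isinstance(state, torch.Tensor):
        return state.cuda()
    if isinstance(state, dict):
        return {k: _to_cuda(v) for k, v in state.items()}
    if isinstance(state, (list, tuple)):
        return type(state)(_to_cuda(v) for v in state)
    # non-tensor leaves (ints, floats, strings) are returned unchanged
    return state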
def update_history(self, session_id, new_input_text):
    try:
        if session_id not in self.history_dict:
            self.history_dict[session_id] = {
                "history": [],
                "modified_time": time.time()
            }
        self.history_dict[session_id]["history"].append(new_input_text)
        self.history_dict[session_id]["modified_time"] = time.time()
        return True
    except Exception as e:
        LOGGER.error("FAIL update history: session_id: {}, error: {}".format(str(session_id), str(e)))
        return False
def main(args):
    LOGGER.info('start')
    app = tornado.web.Application(handlers=urls, debug=False)
    http_server = tornado.httpserver.HTTPServer(app)
    http_server.bind(args.port)
    http_server.start(args.threads)
    # http_server.listen(options.port)
    print("START")
    tornado.ioloop.IOLoop.instance().start()
    print("INSTANT START")
def _make_dialogue_response(self, input_tensors):
    try:
        generated = []
        # Tracks which responses have finished: once response i generates
        # sep_token_id, i is added to finish_set.
        finish_set = set()
        # Generate at most max_sequence_len tokens.
        for _ in range(self.max_sequence_len):
            outputs = self.dialogue_model(input_ids=input_tensors)
            next_token_logits = outputs[0][:, -1, :]
            # Apply a repetition penalty to every token already generated,
            # lowering its probability of being sampled again.
            for index in range(self.batch_size):
                for token_id in set([token_ids[index] for token_ids in generated]):
                    next_token_logits[index][token_id] /= self.repetition_penalty
            next_token_logits = next_token_logits / self.temperature
            # Set the logit of [UNK] to -inf so the model can never predict [UNK].
            for next_token_logit in next_token_logits:
                next_token_logit[self.tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
            filtered_logits = self._top_k_top_p_filtering(
                next_token_logits, top_k=self.topk, top_p=self.topp)
            # torch.multinomial samples num_samples indices from the candidate set
            # without replacement; higher weight means higher sampling probability.
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            # Mark any response that has just generated [SEP].
            for index, token_id in enumerate(next_token[:, 0]):
                if token_id == self.tokenizer.sep_token_id:
                    finish_set.add(index)
            # Check whether every response in the batch has generated [SEP].
            finish_flag = True
            for index in range(self.batch_size):
                if index not in finish_set:  # at least one response is unfinished
                    finish_flag = False
                    break
            if finish_flag:
                break
            generated.append([token.item() for token in next_token[:, 0]])
            # Append the newly generated tokens to the inputs.
            input_tensors = torch.cat((input_tensors, next_token), dim=-1)
        # Collect all generated candidate responses.
        candidate_responses = []
        for batch_index in range(self.batch_size):
            response = []
            for token_index in range(len(generated)):
                if generated[token_index][batch_index] != self.tokenizer.sep_token_id:
                    response.append(generated[token_index][batch_index])
                else:
                    break
            candidate_responses.append(response)
        return candidate_responses
    except Exception as e:
        LOGGER.error("FAIL make response: {}".format(str(e)))
        traceback.print_exc()
        return []
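# Note: `_make_dialogue_response` relies on `self._top_k_top_p_filtering`, which is not
# shown in this section. The sketch below follows the widely used top-k / nucleus
# (top-p) filtering recipe that such a method usually implements; the exact signature
# and in-place masking are assumptions, not necessarily this repo's own code.
def _top_k_top_p_filtering(self, logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Mask a batch of next-token logits so that sampling only considers the top-k
    tokens and/or the smallest set of tokens whose cumulative probability exceeds top_p."""
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # drop every token whose logit is below the k-th largest logit in its row
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # drop tokens beyond the nucleus, but always keep the most probable token
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits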
def setup_dataloaders(cfg, tokenizer): LOGGER.info("Init. train_loader and val_loader...") train_loader = mk_video_ret_dataloader(anno_path=cfg.train_datasets[0].txt, lmdb_dir=cfg.train_datasets[0].img, cfg=cfg, tokenizer=tokenizer, is_train=True) val_loader = mk_video_ret_dataloader(anno_path=cfg.val_datasets[0].txt, lmdb_dir=cfg.val_datasets[0].img, cfg=cfg, tokenizer=tokenizer, is_train=False) img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std) train_loader = PrefetchLoader(train_loader, img_norm) val_loader = PrefetchLoader(val_loader, img_norm) return train_loader, val_loader
def _get_random_negative_caption(self, gt_index):
    gt_img_id, _ = self.datalist[gt_index]
    max_trials = 5
    while max_trials > 0:
        neg_index = int(random.random() * len(self))
        neg_img_id, neg_examples = self.datalist[neg_index]
        if neg_img_id == gt_img_id:
            max_trials -= 1
            continue
        else:
            break
    if max_trials == 0:
        LOGGER.info(f"gt_filepath {gt_img_id} index {gt_index}, "
                    f"neg_data filepath {neg_examples} index {neg_index}")
        raise Warning(
            f"The negative sampler cannot sample a true negative within 5 trials")
    neg_data = neg_examples[int(random.random() * len(neg_examples))]
    return neg_data["txt"]
def run(self):
    while True:
        time.sleep(1800)
        cur_update_time = time.time()
        expire_list = []
        # iterate over a snapshot of the keys so expired entries can be removed safely
        for key in list(self.history_dict.keys()):
            try:
                entry = self.history_dict[key]
                if "history" not in entry:
                    self.history_dict.pop(key)
                    expire_list.append(json.dumps({
                        "session_id": key,
                        "history": [],
                        "last_modified": time.time() - 1800
                    }))
                    continue
                if "modified_time" in entry and isinstance(entry["modified_time"], float):
                    if cur_update_time - entry["modified_time"] > 1800:
                        self.history_dict.pop(key)
                        expire_list.append(json.dumps({
                            "session_id": key,
                            "history": entry["history"],
                            "last_modified": entry["modified_time"]
                        }))
                else:
                    self.history_dict.pop(key)
                    expire_list.append(json.dumps({
                        "session_id": key,
                        "history": entry["history"],
                        "last_modified": time.time() - 1800
                    }))
            except Exception as e:
                LOGGER.error("bad exec: {}, reason: {}".format(str(key), str(e)))
                traceback.print_exc()
                continue
        with open(self.expire_save_path, 'a') as fw:
            for expire_session in expire_list:
                fw.write(expire_session + '\n')
        with open(self.history_save_path, 'w') as fw:
            json.dump(self.history_dict, fw)
def save(self, step, model, optimizer=None, prefix="model"):
    model_path = join(self.output_dir, f"{prefix}_step_{step}.pt")
    state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v
                  for k, v in model.state_dict().items()}
    # with retrial, as azure blob fails occasionally.
    save_trial = 0
    while save_trial < self.max_save_load_trial:
        try:
            LOGGER.info(f"ModelSaver save trial NO. {save_trial}")
            torch.save(state_dict, model_path)
            if optimizer is not None:
                optimizer_state_dict = \
                    {k: v.cpu() if isinstance(v, torch.Tensor) else v
                     for k, v in optimizer.state_dict().items()}
                dump = {'step': step, 'optimizer': optimizer_state_dict}
                torch.save(
                    dump,
                    f'{self.output_dir}/{prefix}_step_{step}_train_state.pt')
            break
        except Exception as e:
            save_trial += 1
def mk_video_ret_datalist(raw_datalist, cfg):
    """
    Args:
        raw_datalist: list(dict)
        cfg:

    Returns:

    """
    LOGGER.info(f"Loaded data size {len(raw_datalist)}")
    if cfg.data_ratio != 1.0:
        random.shuffle(raw_datalist)
        raw_datalist = raw_datalist[:int(len(raw_datalist) * cfg.data_ratio)]
        LOGGER.info(
            f"Use {100 * cfg.data_ratio}% of the loaded data: {len(raw_datalist)}")

    datalist = []
    qid = 0
    for raw_d in raw_datalist:
        d = dict(
            id=qid,
            txt=raw_d["caption"],
            vid_id=raw_d["clip_name"])
        qid += 1
        datalist.append(d)
    LOGGER.info(f"datalist {len(datalist)}")
    return datalist
def mk_msrvtt_mc_datalist(raw_datalist, cfg):
    """
    Args:
        raw_datalist: list(dict)
        cfg:

    Returns:

    """
    LOGGER.info(f"Loaded data size {len(raw_datalist)}")
    if cfg.data_ratio != 1.0:
        random.shuffle(raw_datalist)
        raw_datalist = raw_datalist[:int(len(raw_datalist) * cfg.data_ratio)]
        LOGGER.info(
            f"Use {100 * cfg.data_ratio}% of the loaded data: {len(raw_datalist)}")

    datalist = []
    for raw_d in raw_datalist:
        d = dict(
            id=raw_d["qid"],
            vid_id=raw_d["clip_name"],
            answer=raw_d["answer"],
            options=raw_d["options"],
        )
        datalist.append(d)
    LOGGER.info(f"datalist {len(datalist)}")
    return datalist
def compare_dict_difference(dict1, dict2, dict1_name="dict1", dict2_name="dict2",
                            print_value_diff=True, verbose=False):
    """
    Args:
        dict1:
        dict2:
        dict1_name:
        dict2_name:
        print_value_diff: bool, output dict value difference within shared keys
            for dict1 and dict2. In effect only when verbose == True
        verbose:
    """
    keys1 = set(dict1.keys())
    keys2 = set(dict2.keys())
    shared_keys = keys1.intersection(keys2)
    keys1_unique = keys1.difference(shared_keys)
    keys2_unique = keys2.difference(shared_keys)
    key_diff_list = list(keys1_unique) + list(keys2_unique)

    # value difference in the shared keys in dict1 and dict2
    value_diff_dict = {}
    for k in shared_keys:
        if dict1[k] != dict2[k]:
            value_diff_dict[k] = [(dict1_name, dict1[k]), (dict2_name, dict2[k])]

    if verbose:
        LOGGER.info("=" * 30 + "key difference")
        LOGGER.info(f"keys in {dict1_name} but not in {dict2_name}: "
                    f"total {len(keys1_unique)}, {sorted(keys1_unique)}")
        LOGGER.info(f"keys in {dict2_name} but not in {dict1_name}: "
                    f"total {len(keys2_unique)}, {sorted(keys2_unique)}")

    if verbose and print_value_diff:
        LOGGER.info("=" * 30 + "value difference")
        LOGGER.info(f"{json.dumps(value_diff_dict, indent=4)}")

    return value_diff_dict, key_diff_list
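# Hypothetical usage of compare_dict_difference; the dict names and values below are
# illustrative only and do not come from the repo.
stored_cfg = {"lr": 1e-4, "batch_size": 32, "fp16": True}
cli_cfg = {"lr": 5e-5, "batch_size": 32, "num_frm": 2}
value_diff, key_diff = compare_dict_difference(
    stored_cfg, cli_cfg, dict1_name="stored_cfg", dict2_name="cli_cfg", verbose=True)
# value_diff == {"lr": [("stored_cfg", 0.0001), ("cli_cfg", 5e-05)]}
# key_diff contains "fp16" (only in stored_cfg) and "num_frm" (only in cli_cfg)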
def save_training_meta(args):
    # args is an EasyDict object, treat it the same as a normal dict
    os.makedirs(join(args.output_dir, 'log'), exist_ok=True)
    os.makedirs(join(args.output_dir, 'ckpt'), exist_ok=True)

    # training args
    save_args_path = join(args.output_dir, 'log', 'args.json')
    save_json(args, save_args_path, save_pretty=True)

    # model args
    model_config = json.load(open(args.model_config))
    save_model_config_path = join(args.output_dir, 'log', 'model_config.json')
    save_json(model_config, save_model_config_path, save_pretty=True)

    # save a copy of the codebase.
    # !!! Do not store heavy files in your codebase when using it.
    code_dir = dirname(dirname(dirname(os.path.realpath(__file__))))
    code_zip_filename = os.path.join(args.output_dir, "code.zip")
    LOGGER.info(f"Saving code from {code_dir} to {code_zip_filename}...")
    make_zipfile(code_dir, code_zip_filename,
                 enclosing_dir="code",
                 exclude_dirs_substring="results",
                 exclude_dirs=["results", "debug_results", "__pycache__"],
                 exclude_extensions=[".pyc", ".ipynb", ".swap"])
    LOGGER.info(f"Saving code done.")
def mk_video_ret_dataloader(anno_path, lmdb_dir, cfg, tokenizer, is_train=True):
    """"""
    raw_datalist = load_jsonl(anno_path)
    datalist = mk_video_ret_datalist(raw_datalist, cfg)
    grouped = defaultdict(list)  # examples grouped by image/video id
    for d in datalist:
        grouped[d["vid_id"]].append(d)
    LOGGER.info(f"grouped {len(grouped)}")

    # each group has a single image with multiple questions
    group_datalist = mk_input_group(
        grouped,
        max_n_example_per_group=cfg.max_n_example_per_group if is_train else 1,  # force 1 in eval
        is_train=is_train)
    LOGGER.info(f"group_datalist {len(group_datalist)}")

    frm_sampling_strategy = cfg.frm_sampling_strategy
    if not is_train and frm_sampling_strategy == "rand":
        frm_sampling_strategy = "middle"
    dataset = ClipBertVideoRetrievalDataset(
        datalist=group_datalist,
        tokenizer=tokenizer,
        img_lmdb_dir=lmdb_dir,
        max_img_size=cfg.max_img_size,
        max_txt_len=cfg.max_txt_len,
        fps=cfg.fps,
        num_frm=cfg.num_frm,
        frm_sampling_strategy=frm_sampling_strategy,
        itm_neg_size=cfg.itm_neg_size,
        ensemble_n_clips=cfg.train_n_clips,
        random_sample_clips=cfg.random_sample_clips)
    LOGGER.info(f"is_train {is_train}, dataset size {len(dataset)} groups, "
                f"each group {cfg.max_n_example_per_group if is_train else 1}")
    if cfg.do_inference:
        batch_size = cfg.inference_batch_size
    else:
        batch_size = cfg.train_batch_size if is_train else cfg.val_batch_size
    sampler = DistributedSampler(
        dataset, num_replicas=hvd.size(), rank=hvd.rank(),
        shuffle=is_train)
    vqa_collator = VideoRetrievalCollator(
        tokenizer=tokenizer, max_length=cfg.max_txt_len)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            sampler=sampler,
                            num_workers=cfg.n_workers,
                            pin_memory=cfg.pin_mem,
                            collate_fn=vqa_collator.collate_batch)
    return dataloader
def post(self):
    response = {'status': 0, 'data': {}, 'message': 'fail'}
    try:
        session_id = self.get_argument("sessionId")
        input_text = self.get_argument("text")
    except Exception as e:
        LOGGER.error("FAIL receive args: {}".format(str(e)))
        response['message'] = str(e)
        self.finish(response)
        return
    try:
        st = time.time()
        session_id = int(session_id)
        keeper_partition = session_id % config_instance.num_keepers
        keepers[keeper_partition].update_history(session_id=session_id,
                                                 new_input_text=input_text)
        history = keepers[keeper_partition].get_history(session_id=session_id)
        generate_chars = worker.generate(history)
        print(generate_chars)
        if len(generate_chars) == 0:
            response['message'] = "fail generate response text"
            self.finish(response)
            return
        generate = "".join(generate_chars)
        keepers[keeper_partition].update_history(session_id=session_id,
                                                 new_input_text=generate)
        body_info = {
            'sessionId': session_id,
            'input': input_text,
            'output': generate
        }
        print(body_info)
        LOGGER.info(
            "receive: session_id: {}, input_text: {}, back: {}, cost: {} ms"
            .format(str(session_id), input_text, json.dumps(body_info),
                    (time.time() - st) * 1000))
        response['data'] = body_info
        response['status'] = 1
        response['message'] = 'success'
        self.finish(response)
    except Exception as e:
        LOGGER.error("FAIL make response: {}".format(str(e)))
        response['message'] = str(e)
        self.finish(response)
        return
SEED = 42
DATA = 'TREC'
CUDA = False
DEBUG = False
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODE = 'static'  # choices = [static, nonstatic]
WORD_VECTORS = 'rand'  # choices = [rand, word2vec]
DIM = 300

# set seed
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # if use multi-GPU

LOGGER.info(MODE)
LOGGER.info(WORD_VECTORS)


#########################
# Custom class Setting
#########################
class GSST(SST):
    urls = ['http://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip']
    dirname = 'trees'
    name = 'sst'

    @staticmethod
    def sort_key(ex):
        return len(ex.text)
def start_training(cfg):
    set_random_seed(cfg.seed)

    n_gpu = hvd.size()
    cfg.n_gpu = n_gpu
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    if hvd.rank() != 0:
        LOGGER.disabled = True
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), bool(cfg.fp16)))

    model = setup_model(cfg, device=device)
    model.train()
    optimizer = setup_e2e_optimizer(model, cfg)

    # Horovod: (optional) compression algorithm.
    compression = hvd.Compression.none
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters(),
        compression=compression)

    # Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    model, optimizer = amp.initialize(
        model, optimizer, enabled=cfg.fp16, opt_level='O2',
        keep_batchnorm_fp32=True)

    # prepare data
    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
    train_loader, val_loader = setup_dataloaders(cfg, tokenizer)
    eval_loader = mk_video_ret_eval_dataloader(
        anno_path=cfg.val_datasets[0].txt,
        lmdb_dir=cfg.val_datasets[0].img,
        cfg=cfg, tokenizer=tokenizer,
    )

    # compute the number of steps and update cfg
    total_n_examples = len(train_loader.dataset) * cfg.max_n_example_per_group
    total_train_batch_size = int(
        n_gpu * cfg.train_batch_size *
        cfg.gradient_accumulation_steps * cfg.max_n_example_per_group)
    cfg.num_train_steps = int(math.ceil(
        1. * cfg.num_train_epochs * total_n_examples / total_train_batch_size))

    cfg.valid_steps = int(math.ceil(
        1. * cfg.num_train_steps / cfg.num_valid /
        cfg.min_valid_steps)) * cfg.min_valid_steps
    actual_num_valid = int(math.floor(
        1. * cfg.num_train_steps / cfg.valid_steps)) + 1

    # restore
    restorer = TrainingRestorer(cfg, model, optimizer)
    global_step = restorer.global_step
    TB_LOGGER.global_step = global_step
    if hvd.rank() == 0:
        LOGGER.info("Saving training meta...")
        save_training_meta(cfg)
        path = join(cfg.output_dir, 'log', "detectron2_model_cfg.yaml")
        with open(path, "w") as f:
            f.write(model.cnn.config_file)
        LOGGER.info("Saving training done...")
        TB_LOGGER.create(join(cfg.output_dir, 'log'))
        pbar = tqdm(total=cfg.num_train_steps)
        model_saver = ModelSaver(join(cfg.output_dir, "ckpt"))
        add_log_to_file(join(cfg.output_dir, "log", "log.txt"))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()
        restorer = NoOp()

    if global_step > 0:
        pbar.update(global_step)

    LOGGER.info(cfg)
    LOGGER.info("Starting training...")
    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info(f" Single-GPU Non-Accumulated batch size = {cfg.train_batch_size}")
    LOGGER.info(f" max_n_example_per_group = {cfg.max_n_example_per_group}")
    LOGGER.info(f" Accumulate steps = {cfg.gradient_accumulation_steps}")
    LOGGER.info(f" Total batch size = #GPUs * Single-GPU batch size * "
                f"max_n_example_per_group * Accumulate steps [Image] = {total_train_batch_size}")
    LOGGER.info(f" Total #epochs = {cfg.num_train_epochs}")
    LOGGER.info(f" Total #steps = {cfg.num_train_steps}")
    LOGGER.info(f" Validate every {cfg.valid_steps} steps, in total {actual_num_valid} times")

    # quick hack for amp delay_unscale bug
    with optimizer.skip_synchronize():
        optimizer.zero_grad()
        if global_step == 0:
            optimizer.step()
    debug_step = 3
    running_loss = RunningMeter('train_loss')

    for step, batch in enumerate(InfiniteIterator(train_loader)):
        # forward pass
        del batch["caption_ids"]
        mini_batch = dict()
        for k, v in batch.items():
            if k != "visual_inputs":
                mini_batch[k] = v

        pool_method = cfg.score_agg_func
        # could be 1, where only a single clip is used
        num_clips = cfg.train_n_clips
        num_frm = cfg.num_frm
        # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
        bsz = batch["visual_inputs"].shape[0]
        new_visual_shape = (bsz, num_clips, num_frm) + batch["visual_inputs"].shape[2:]
        visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
        logits = []
        for clip_idx in range(num_clips):
            # (B, num_frm, C, H, W)
            mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
            mini_batch["n_examples_list"] = batch["n_examples_list"]
            outputs = forward_step(model, mini_batch, cfg)
            logits.append(outputs["logits"])
        # the losses are cross entropy and mse, no need to * num_labels

        logits = torch.stack(logits)  # (num_frm, B, 5)
        if pool_method == "mean":
            logits = logits.mean(0)  # (B, 5)
        elif pool_method == "max":
            logits = logits.max(0)[0]  # (B, 5)
        elif pool_method == "lse":
            logits = logits.permute(
                1, 0, 2).contiguous()  # (B, num_frm, 5), pooling will be done in CE
        else:
            raise ValueError(
                f"Invalid value for pool_method, "
                f"got {pool_method}, expect one of [`mean`, `max`, `lse`]")

        if pool_method == "lse":
            out = torch.logsumexp(logits.view(logits.shape[0], -1), dim=-1, keepdim=True) \
                - torch.logsumexp(logits, dim=1)
            loss = torch.gather(out, -1, batch["labels"].view(-1, 1))
        else:
            _, loss = model.transformer.calc_loss(
                logits, batch["labels"],
                sample_size=len(batch["n_examples_list"]))
        loss = loss.mean()

        running_loss(loss.item())
        # backward pass
        delay_unscale = (step + 1) % cfg.gradient_accumulation_steps != 0
        with amp.scale_loss(
                loss, optimizer, delay_unscale=delay_unscale) as scaled_loss:
            scaled_loss.backward()
            zero_none_grad(model)
            optimizer.synchronize()

        # optimizer
        if (step + 1) % cfg.gradient_accumulation_steps == 0:
            global_step += 1

            # learning rate scheduling
            n_epoch = int(1. * total_train_batch_size * global_step
                          / total_n_examples)
            # learning rate scheduling transformer
            lr_this_step_transformer = get_lr_sched(
                global_step, cfg.decay, cfg.learning_rate,
                cfg.num_train_steps, warmup_ratio=cfg.warmup_ratio,
                decay_epochs=cfg.step_decay_epochs,
                multi_step_epoch=n_epoch)

            # learning rate scheduling cnn
            lr_this_step_cnn = get_lr_sched(
                global_step, cfg.cnn_lr_decay, cfg.cnn_learning_rate,
                cfg.num_train_steps, warmup_ratio=cfg.warmup_ratio,
                decay_epochs=cfg.cnn_step_decay_epochs,
                multi_step_epoch=n_epoch)

            # Hardcoded param group length
            assert len(optimizer.param_groups) == 8
            for pg_n, param_group in enumerate(optimizer.param_groups):
                if pg_n in [0, 1]:
                    param_group['lr'] = (
                        cfg.transformer_lr_mul * lr_this_step_transformer)
                elif pg_n in [2, 3]:
                    param_group['lr'] = lr_this_step_transformer
                elif pg_n in [4, 5]:
                    param_group['lr'] = (
                        cfg.cnn_lr_mul * lr_this_step_cnn)
                else:
                    param_group['lr'] = lr_this_step_cnn
            TB_LOGGER.add_scalar(
                "train/lr_transformer", lr_this_step_transformer, global_step)
            TB_LOGGER.add_scalar(
                "train/lr_cnn", lr_this_step_cnn, global_step)

            TB_LOGGER.add_scalar('train/loss', running_loss.val, global_step)

            # update model params
            if cfg.grad_norm != -1:
                grad_norm = clip_grad_norm_(
                    amp.master_params(optimizer), cfg.grad_norm)
                TB_LOGGER.add_scalar("train/grad_norm", grad_norm, global_step)
            TB_LOGGER.step()

            # Check if there is None grad
            none_grads = [
                p[0] for p in model.named_parameters()
                if p[1].requires_grad and p[1].grad is None]
            assert len(none_grads) == 0, f"{none_grads}"

            with optimizer.skip_synchronize():
                optimizer.step()
                optimizer.zero_grad()
            restorer.step()
            pbar.update(1)

            # checkpoint
            if global_step % cfg.valid_steps == 0:
                LOGGER.info(f'Step {global_step}: start validation')
                validate(model, val_loader, eval_loader, cfg, global_step,
                         eval_filepath=cfg.val_datasets[0].txt)
                model_saver.save(step=global_step, model=model)
        if global_step >= cfg.num_train_steps:
            break

        if cfg.debug and global_step >= debug_step:
            break

    if global_step % cfg.valid_steps != 0:
        LOGGER.info(f'Step {global_step}: start validation')
        validate(model, val_loader, eval_loader, cfg, global_step,
                 eval_filepath=cfg.val_datasets[0].txt)
        model_saver.save(step=global_step, model=model)
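# Illustrative check (not part of the repo): when `score_agg_func == "lse"`, the loss
# computed above via logsumexp/gather equals standard cross-entropy applied to the
# clip-pooled logits. Shapes and values below are made up for the demonstration.
import torch
import torch.nn.functional as F

_logits = torch.randn(4, 3, 5)             # (B=4, num_clips=3, n_options=5)
_labels = torch.randint(0, 5, (4,))

_pooled = torch.logsumexp(_logits, dim=1)   # LSE over clips -> (B, 5)
_ce = F.cross_entropy(_pooled, _labels, reduction='none')

_out = torch.logsumexp(_logits.view(4, -1), dim=-1, keepdim=True) \
    - torch.logsumexp(_logits, dim=1)       # == -log_softmax(_pooled), shape (B, 5)
_nll = torch.gather(_out, -1, _labels.view(-1, 1)).squeeze(-1)
assert torch.allclose(_ce, _nll, atol=1e-5)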
def start_inference(cfg):
    set_random_seed(cfg.seed)
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    if hvd.rank() != 0:
        LOGGER.disabled = True

    inference_res_dir = join(
        cfg.output_dir,
        f"mc_results_{os.path.splitext(os.path.basename(cfg.inference_txt_db))[0]}/"
        f"step_{cfg.inference_model_step}_{cfg.inference_n_clips}_{cfg.score_agg_func}")

    if hvd.rank() == 0:
        os.makedirs(inference_res_dir, exist_ok=True)
        save_json(cfg, join(inference_res_dir, "raw_args.json"),
                  save_pretty=True)

    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), bool(cfg.fp16)))

    # overwrite cfg with stored_cfg,
    # but skip keys containing the keyword 'inference'
    stored_cfg_path = join(cfg.output_dir, "log/args.json")
    stored_cfg = edict(load_json(stored_cfg_path))
    for k, v in cfg.items():
        if k in stored_cfg and "inference" not in k and "output_dir" not in k:
            setattr(cfg, k, stored_cfg[k])

    # setup models
    cfg.model_config = join(cfg.output_dir, "log/model_config.json")
    e2e_weights_path = join(
        cfg.output_dir, f"ckpt/model_step_{cfg.inference_model_step}.pt")
    cfg.e2e_weights_path = e2e_weights_path
    model = setup_model(cfg, device=device)
    model.eval()

    # FIXME separate scaling for each loss
    model = amp.initialize(model, enabled=cfg.fp16, opt_level='O2')

    global_step = 0
    # prepare data
    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
    cfg.data_ratio = 1.
    val_loader = mk_msrvtt_mc_eval_dataloader(
        anno_path=cfg.inference_txt_db,
        lmdb_dir=cfg.inference_img_db,
        cfg=cfg, tokenizer=tokenizer,
    )

    LOGGER.info(cfg)
    LOGGER.info("Starting inference...")
    LOGGER.info(f"***** Running inference with {n_gpu} GPUs *****")
    LOGGER.info(f" Batch size = {cfg.inference_batch_size}")

    LOGGER.info(f'Step {global_step}: start validation')
    ret_results, ret_scores = inference_retrieval_mc(
        model, val_loader, cfg.inference_txt_db, cfg)

    if hvd.rank() == 0:
        save_json(cfg, join(inference_res_dir, "merged_args.json"),
                  save_pretty=True)
        save_json(ret_results, join(inference_res_dir, "mc_test_results.json"),
                  save_pretty=True)
        save_json(ret_scores, join(inference_res_dir, "mc_test_scores.json"),
                  save_pretty=True)
def inference_retrieval_mc(model, val_loader, eval_file_path, cfg, n_options=5):
    model.eval()
    pred_id2ans = dict()
    st = time.time()
    LOGGER.info(f"Evaluate retrieval MC: {len(val_loader)}")
    if hvd.rank() == 0:
        pbar = tqdm(total=len(val_loader), desc="eval")

    for batch in val_loader:
        # compile shared text part
        question_ids = batch["question_ids"]
        bsz = len(question_ids)
        del batch["question_ids"]
        mini_batch = dict()
        for k, v in batch.items():
            if k not in ["visual_inputs", "meta"]:
                mini_batch[k] = v

        # multi-frame test, scores across frames of the same video will be pooled together
        # batch["visual_inputs"]  (B, T, C, H, W)
        pool_method = cfg.score_agg_func
        # could be 1, where only a single clip is evaluated
        num_clips = cfg.inference_n_clips
        num_frm = cfg.num_frm
        # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
        new_visual_shape = (bsz, num_clips, num_frm) + batch["visual_inputs"].shape[2:]
        visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
        logits = []
        for clip_idx in range(num_clips):
            mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
            mini_batch["n_examples_list"] = batch["n_examples_list"]
            outputs = forward_step(model, mini_batch, cfg, n_options=n_options)
            logits.append(outputs["logits"].cpu())
        logits = torch.stack(logits)  # (num_frm, B, 1 or 2)
        if pool_method == "mean":
            logits = logits.mean(0)  # (B, 1 or 2)
        elif pool_method == "max":
            logits = logits.max(0)[0]  # (B, 1 or 2)
        elif pool_method == "lse":
            logits = logits.permute(
                1, 0, 2).contiguous()  # (B, num_frm, 5), pooling will be done in CE
            logits = torch.logsumexp(
                logits, dim=1)  # torch.exp alone might be too large and unstable
        else:
            raise ValueError(
                f"Invalid value for pool_method, "
                f"got {pool_method}, expect one of [`mean`, `max`, `lse`]")

        if logits.shape[1] == 2:
            probs = F.softmax(logits, dim=1)[:, 1]
        else:
            probs = torch.sigmoid(logits.squeeze())  # B
        probs = probs.view(-1, n_options)  # (B, 5)
        pred_answers = probs.max(1)[1].tolist()  # (B, )
        for qid, pred_ans in zip(question_ids, pred_answers):
            pred_id2ans[qid] = int(pred_ans)

        if hvd.rank() == 0:
            pbar.update(1)

    # ###### Saving with Horovod ####################
    # dummy sync
    _ = None
    all_gather_list(_)
    n_gpu = hvd.size()
    eval_dir = join(
        cfg.output_dir,
        f"results_mc_{os.path.splitext(os.path.basename(eval_file_path))[0]}")
    os.makedirs(eval_dir, exist_ok=True)
    if n_gpu > 1:
        # with retrial, as azure blob fails occasionally.
        max_save_load_trial = 10
        save_trial = 0
        while save_trial < max_save_load_trial:
            try:
                LOGGER.info(f"Save results trial NO. {save_trial}")
                save_json(
                    pred_id2ans,
                    join(eval_dir, f"tmp_results_mc_rank{hvd.rank()}.json"))
                break
            except Exception as e:
                print(f"Saving exception: {e}")
                save_trial += 1

    # dummy sync
    _ = None
    all_gather_list(_)
    # join results
    if n_gpu > 1 and hvd.rank() == 0:
        pred_id2ans = []
        for rk in range(n_gpu):
            pred_id2ans.append(
                load_json(join(eval_dir, f"tmp_results_mc_rank{rk}.json")))
        pred_id2ans = merge_dicts(pred_id2ans)
        LOGGER.info('results joined')

    if hvd.rank() == 0:
        retrieval_qa_metrics = val_loader.dataset.evaluate_qa_accuracy(
            pred_id2ans, force_same=True)
        LOGGER.info(
            f"validation finished in {int(time.time() - st)} seconds. "
            f"scores: {retrieval_qa_metrics}")
    else:
        retrieval_qa_metrics = None

    model.train()
    return pred_id2ans, retrieval_qa_metrics
def setup_model(cfg, device=None):
    LOGGER.info("Setup model...")
    # has to be a BertConfig instance
    model_cfg = load_json(cfg.model_config)
    model_cfg = BertConfig(**model_cfg)
    # add downstream model config
    add_attr_list = [
        "num_labels", "classifier", "cls_hidden_scale",
        "loss_type", "margin",
    ]
    for k in add_attr_list:
        setattr(model_cfg, k, cfg[k])
    transformer_model_cls = ClipBertForVideoTextRetrieval

    # we separate the CNN and the transformer in order to use different optimizer for each
    # transformer still has a CNN layer inside, used to down sample grid.
    LOGGER.info("setup e2e model")
    model = ClipBert(
        model_cfg, input_format=cfg.img_input_format,
        detectron2_model_cfg=cfg.detectron2_model_cfg,
        transformer_cls=transformer_model_cls)
    if cfg.e2e_weights_path:
        LOGGER.info(f"Loading e2e weights from {cfg.e2e_weights_path}")
        load_state_dict_with_mismatch(model, cfg.e2e_weights_path)
    else:
        LOGGER.info(f"Loading cnn weights from {cfg.detectron2_weights_path}")
        LOGGER.info(f"Loading bert weights from {cfg.bert_weights_path}")
        model.load_separate_ckpt(
            cnn_weights_path=cfg.detectron2_weights_path,
            bert_weights_path=cfg.bert_weights_path)
    if cfg.freeze_cnn:
        model.freeze_cnn_backbone()
    model.to(device)

    LOGGER.info("Setup model done!")
    return model
def inference_retrieval(model, val_loader, eval_file_path, cfg):
    model.eval()
    retrieval_res = []  # list(dict): dict(vid_id=..., txt_id=..., score=...)
    st = time.time()
    eval_bsz = cfg.inference_batch_size if cfg.do_inference else cfg.eval_retrieval_batch_size
    LOGGER.info(f"Evaluate retrieval #video per GPU: {len(val_loader)}")
    if hvd.rank() == 0:
        pbar = tqdm(total=len(val_loader), desc="eval")

    for batch in val_loader:
        # each batch contains 1 video and N (=1000) captions
        n_mini_batches = math.ceil(len(batch["caption_ids"]) / eval_bsz)
        vid_id = batch["vid_id"]
        for idx in range(n_mini_batches):
            # compile shared text part
            mini_batch = dict()
            for k in ["text_input_ids", "text_input_mask", "labels"]:
                if batch[k] is not None:
                    mini_batch[k] = batch[k][idx * eval_bsz:(idx + 1) * eval_bsz]
                else:
                    mini_batch[k] = None
            caption_ids = batch["caption_ids"][idx * eval_bsz:(idx + 1) * eval_bsz]
            # bsz = len(caption_ids)
            mini_batch["n_examples_list"] = [len(caption_ids)]

            # multi-frame test, scores across frames of the same video will be pooled together
            pool_method = cfg.score_agg_func
            # could be 1, where only a single clip is evaluated
            num_clips = cfg.inference_n_clips
            num_frm = cfg.num_frm
            # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
            new_visual_shape = (1, num_clips, num_frm) + batch["visual_inputs"].shape[2:]
            visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
            logits = []
            for clip_idx in range(num_clips):
                mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
                outputs = forward_step(model, mini_batch, cfg)
                logits.append(outputs["logits"].cpu())
            logits = torch.stack(logits)  # (num_frm, B, 1 or 2)
            if pool_method == "mean":
                logits = logits.mean(0)  # (B, 1 or 2)
            elif pool_method == "max":
                logits = logits.max(0)[0]  # (B, 1 or 2)
            elif pool_method == "lse":
                logits = logits.permute(
                    1, 0, 2).contiguous()  # (B, num_frm, 5), pooling will be done in CE
                logits = torch.logsumexp(
                    logits, dim=1)  # torch.exp alone might be too large and unstable
            else:
                raise ValueError(
                    f"Invalid value for pool_method, "
                    f"got {pool_method}, expect one of [`mean`, `max`, `lse`]")

            if logits.shape[1] == 2:
                probs = F.softmax(logits, dim=1)[:, 1].tolist()
            else:
                probs = torch.sigmoid(logits.squeeze()).tolist()  # B
            for cap_id, score in zip(caption_ids, probs):
                retrieval_res.append(dict(
                    vid_id=vid_id,
                    txt_id=cap_id,
                    score=round(score, 4)))

        if hvd.rank() == 0:
            pbar.update(1)

    # ###### Saving with Horovod ####################
    # dummy sync
    _ = None
    all_gather_list(_)
    n_gpu = hvd.size()
    eval_dir = join(
        cfg.output_dir,
        f"results_{os.path.splitext(os.path.basename(eval_file_path))[0]}")
    os.makedirs(eval_dir, exist_ok=True)
    if n_gpu > 1:
        # with retrial, as azure blob fails occasionally.
        max_save_load_trial = 10
        save_trial = 0
        while save_trial < max_save_load_trial:
            try:
                LOGGER.info(f"Save results trial NO. {save_trial}")
                save_json(
                    retrieval_res,
                    join(eval_dir, f"tmp_results_rank{hvd.rank()}.json"))
                break
            except Exception as e:
                print(f"Saving exception: {e}")
                save_trial += 1

    # dummy sync
    _ = None
    all_gather_list(_)
    # join results
    if n_gpu > 1 and hvd.rank() == 0:
        retrieval_res = []
        for rk in range(n_gpu):
            retrieval_res.extend(
                load_json(join(eval_dir, f"tmp_results_rank{rk}.json")))
        LOGGER.info('results joined')

    if hvd.rank() == 0:
        retrieval_metrics = eval_retrieval(
            retrieval_res, val_loader.dataset.gt_cap_id2vid_id,
            val_loader.dataset.id2data)
        LOGGER.info(
            f"validation finished in {int(time.time() - st)} seconds. "
            f"scores: {retrieval_metrics}")
    else:
        retrieval_metrics = None

    model.train()
    return retrieval_res, retrieval_metrics