# Standard imports used by the functions below; the project-local helpers
# (prepare_batch_inputs, cal_performance, RCDataset, sort_res, logger, the model
# and translator classes) are assumed to be imported/defined elsewhere in this module.
from collections import defaultdict

import torch
import torch.nn as nn
from tqdm import tqdm


def eval_epoch(model, validation_data, device, opt):
    """Evaluate with teacher forcing, i.e., the same setting as training:
    the ground-truth word x_{t-1} is used to predict the next word x_{t}.
    This is not realistic for real inference; see `run_translate` for that."""
    model.eval()

    total_loss = 0
    n_word_total = 0
    n_word_correct = 0
    with torch.no_grad():
        for batch in tqdm(validation_data, mininterval=2, desc=" Validation =>"):
            if opt.recurrent:  # recurrent model: a batch is a list of decoding steps
                # prepare data
                batched_data = [prepare_batch_inputs(step_data, device=device,
                                                     non_blocking=opt.pin_memory)
                                for step_data in batch[0]]
                input_ids_list = [e["input_ids"] for e in batched_data]
                video_features_list = [e["video_feature"] for e in batched_data]
                input_masks_list = [e["input_mask"] for e in batched_data]
                token_type_ids_list = [e["token_type_ids"] for e in batched_data]
                input_labels_list = [e["input_labels"] for e in batched_data]

                loss, pred_scores_list = model(input_ids_list, video_features_list,
                                               input_masks_list, token_type_ids_list,
                                               input_labels_list)
            else:  # single sentence
                if opt.untied or opt.mtrans:
                    # prepare data
                    batched_data = prepare_batch_inputs(batch[0], device=device,
                                                        non_blocking=opt.pin_memory)
                    video_feature = batched_data["video_feature"]
                    video_mask = batched_data["video_mask"]
                    text_ids = batched_data["text_ids"]
                    text_mask = batched_data["text_mask"]
                    text_labels = batched_data["text_labels"]

                    loss, pred_scores = model(video_feature, video_mask,
                                              text_ids, text_mask, text_labels)
                    # make it consistent with the recurrent config
                    pred_scores_list = [pred_scores]
                    input_labels_list = [text_labels]
                else:
                    # prepare data
                    batched_data = prepare_batch_inputs(batch[0], device=device,
                                                        non_blocking=opt.pin_memory)
                    input_ids = batched_data["input_ids"]
                    video_features = batched_data["video_feature"]
                    input_masks = batched_data["input_mask"]
                    token_type_ids = batched_data["token_type_ids"]
                    input_labels = batched_data["input_labels"]

                    loss, pred_scores = model(input_ids, video_features, input_masks,
                                              token_type_ids, input_labels)
                    # make it consistent with the recurrent config
                    pred_scores_list = [pred_scores]
                    input_labels_list = [input_labels]

            # keep logs: count correct words over non-ignored label positions
            n_correct = 0
            n_word = 0
            for pred, gold in zip(pred_scores_list, input_labels_list):
                n_correct += cal_performance(pred, gold)
                valid_label_mask = gold.ne(RCDataset.IGNORE)
                n_word += valid_label_mask.sum().item()

            n_word_total += n_word
            n_word_correct += n_correct
            total_loss += loss.item()

            if opt.debug:
                break

    # per-word averages over the whole validation epoch
    loss_per_word = 1.0 * total_loss / n_word_total
    accuracy = 1.0 * n_word_correct / n_word_total
    return loss_per_word, accuracy
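# `cal_performance` is assumed to count the number of correctly predicted words,
# consistent with how it is paired with `gold.ne(RCDataset.IGNORE)` above.
# A minimal sketch of such a helper (name and exact shapes are assumptions):
def cal_performance_sketch(pred, gold):
    """pred: (N, L, vocab_size) scores; gold: (N, L) label indices.
    Returns the number of positions where the argmax prediction matches the
    gold label, ignoring positions labeled RCDataset.IGNORE."""
    pred_ids = pred.max(2)[1].contiguous().view(-1)  # greedy word choice per position
    gold = gold.contiguous().view(-1)
    valid_label_mask = gold.ne(RCDataset.IGNORE)  # skip padded/ignored positions
    return pred_ids.eq(gold).masked_select(valid_label_mask).sum().item()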
def train_epoch(model, training_data, optimizer, ema, device, opt, writer, epoch):
    model.train()

    total_loss = 0
    n_word_total = 0
    n_word_correct = 0

    # anomaly detection is a debugging aid; it slows training down noticeably
    torch.autograd.set_detect_anomaly(True)
    for batch_idx, batch in tqdm(enumerate(training_data), mininterval=2,
                                 desc=" Training =>", total=len(training_data)):
        niter = epoch * len(training_data) + batch_idx
        writer.add_scalar("Train/LearningRate",
                          float(optimizer.param_groups[0]["lr"]), niter)

        if opt.recurrent:  # recurrent model: a batch is a list of decoding steps
            # prepare data
            batched_data = [prepare_batch_inputs(step_data, device=device,
                                                 non_blocking=opt.pin_memory)
                            for step_data in batch[0]]
            input_ids_list = [e["input_ids"] for e in batched_data]
            video_features_list = [e["video_feature"] for e in batched_data]
            input_masks_list = [e["input_mask"] for e in batched_data]
            token_type_ids_list = [e["token_type_ids"] for e in batched_data]
            input_labels_list = [e["input_labels"] for e in batched_data]

            if opt.debug:
                def print_info(batched_data, step_idx, batch_idx):
                    cur_data = batched_data[step_idx]
                    logger.info("input_ids \n{}".format(cur_data["input_ids"][batch_idx]))
                    logger.info("input_mask \n{}".format(cur_data["input_mask"][batch_idx]))
                    logger.info("input_labels \n{}".format(cur_data["input_labels"][batch_idx]))
                    logger.info("token_type_ids \n{}".format(cur_data["token_type_ids"][batch_idx]))
                print_info(batched_data, 0, 0)

            # forward & backward
            optimizer.zero_grad()
            loss, pred_scores_list = model(input_ids_list, video_features_list,
                                           input_masks_list, token_type_ids_list,
                                           input_labels_list)
        else:  # single sentence
            if opt.untied or opt.mtrans:
                # prepare data
                batched_data = prepare_batch_inputs(batch[0], device=device,
                                                    non_blocking=opt.pin_memory)
                video_feature = batched_data["video_feature"]
                video_mask = batched_data["video_mask"]
                text_ids = batched_data["text_ids"]
                text_mask = batched_data["text_mask"]
                text_labels = batched_data["text_labels"]

                if opt.debug:
                    def print_info(cur_data, batch_idx):
                        logger.info("text_ids \n{}".format(cur_data["text_ids"][batch_idx]))
                        logger.info("text_mask \n{}".format(cur_data["text_mask"][batch_idx]))
                        logger.info("text_labels \n{}".format(cur_data["text_labels"][batch_idx]))
                    print_info(batched_data, 0)

                # forward & backward
                optimizer.zero_grad()
                loss, pred_scores = model(video_feature, video_mask,
                                          text_ids, text_mask, text_labels)
                # make it consistent with other configs
                pred_scores_list = [pred_scores]
                input_labels_list = [text_labels]
            else:
                # prepare data
                batched_data = prepare_batch_inputs(batch[0], device=device,
                                                    non_blocking=opt.pin_memory)
                input_ids = batched_data["input_ids"]
                video_features = batched_data["video_feature"]
                input_masks = batched_data["input_mask"]
                token_type_ids = batched_data["token_type_ids"]
                input_labels = batched_data["input_labels"]

                if opt.debug:
                    def print_info(cur_data, batch_idx):
                        logger.info("input_ids \n{}".format(cur_data["input_ids"][batch_idx]))
                        logger.info("input_mask \n{}".format(cur_data["input_mask"][batch_idx]))
                        logger.info("input_labels \n{}".format(cur_data["input_labels"][batch_idx]))
                        logger.info("token_type_ids \n{}".format(cur_data["token_type_ids"][batch_idx]))
                    print_info(batched_data, 0)

                # forward & backward
                optimizer.zero_grad()
                loss, pred_scores = model(input_ids, video_features, input_masks,
                                          token_type_ids, input_labels)
                # make it consistent with other configs
                pred_scores_list = [pred_scores]
                input_labels_list = [input_labels]

        loss.backward()
        if opt.grad_clip != -1:  # -1 disables gradient clipping
            nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        optimizer.step()

        # update the exponential moving average of the model parameters
        if ema is not None:
            ema(model, niter)

        # keep logs: count correct words over non-ignored label positions
        n_correct = 0
        n_word = 0
        for pred, gold in zip(pred_scores_list, input_labels_list):
            n_correct += cal_performance(pred, gold)
            valid_label_mask = gold.ne(RCDataset.IGNORE)
            n_word += valid_label_mask.sum().item()

        n_word_total += n_word
        n_word_correct += n_correct
        total_loss += loss.item()

        if opt.debug:
            break

    torch.autograd.set_detect_anomaly(False)
    # per-word averages over the whole training epoch
    loss_per_word = 1.0 * total_loss / n_word_total
    accuracy = 1.0 * n_word_correct / n_word_total
    return loss_per_word, accuracy
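# The `ema` object above is called as `ema(model, niter)` after each optimizer step.
# A minimal sketch of such an exponential-moving-average helper; the class name,
# `register` API, and warm-up schedule are assumptions, not the project's exact code:
class EMASketch(object):
    def __init__(self, decay=0.9999):
        self.decay = decay
        self.shadow = {}  # parameter name -> averaged tensor

    def register(self, model):
        # take an initial snapshot of all trainable parameters;
        # must be called once before the first update
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def __call__(self, model, step):
        # ramp the decay up from ~0.1 so early averages are not
        # dominated by the random initialization
        decay = min(self.decay, (1.0 + step) / (10.0 + step))
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = \
                    (1.0 - decay) * param.data + decay * self.shadow[name]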
def run_translate(eval_data_loader, translator, opt):
    # submission template
    batch_res = {"version": "VERSION 1.0",
                 "results": defaultdict(list),
                 "external_data": {"used": "true", "details": "ay"}}
    for raw_batch in tqdm(eval_data_loader, mininterval=2, desc=" - (Translate)"):
        if opt.recurrent:  # recurrent model
            # prepare data
            step_sizes = raw_batch[1]  # list(int), len == bsz
            meta = raw_batch[2]  # list(dict), len == bsz
            batch = [prepare_batch_inputs(step_data, device=translator.device)
                     for step_data in raw_batch[0]]
            model_inputs = [
                [e["input_ids"] for e in batch],
                [e["video_feature"] for e in batch],
                [e["input_mask"] for e in batch],
                [e["token_type_ids"] for e in batch]
            ]

            dec_seq_list = translator.translate_batch(
                model_inputs, use_beam=opt.use_beam, recurrent=True,
                untied=False, xl=opt.xl)

            # example_idx indicates which example is in the batch
            for example_idx, (step_size, cur_meta) in enumerate(zip(step_sizes, meta)):
                # step_idx or we can also call it sen_idx
                for step_idx, step_batch in enumerate(dec_seq_list[:step_size]):
                    batch_res["results"][cur_meta["name"]].append({
                        "sentence": eval_data_loader.dataset.convert_ids_to_sentence(
                            step_batch[example_idx].cpu().tolist()).encode("ascii", "ignore"),
                        "timestamp": cur_meta["timestamp"][step_idx],
                        "gt_sentence": cur_meta["gt_sentence"][step_idx]
                    })
        else:  # single sentence
            meta = raw_batch[2]  # list(dict), len == bsz
            batched_data = prepare_batch_inputs(raw_batch[0], device=translator.device)
            if opt.untied or opt.mtrans:
                model_inputs = [
                    batched_data["video_feature"],
                    batched_data["video_mask"],
                    batched_data["text_ids"],
                    batched_data["text_mask"],
                    batched_data["text_labels"]
                ]
            else:
                model_inputs = [
                    batched_data["input_ids"],
                    batched_data["video_feature"],
                    batched_data["input_mask"],
                    batched_data["token_type_ids"]
                ]

            dec_seq = translator.translate_batch(
                model_inputs, use_beam=opt.use_beam, recurrent=False,
                untied=opt.untied or opt.mtrans)

            # example_idx indicates which example is in the batch
            for example_idx, (cur_gen_sen, cur_meta) in enumerate(zip(dec_seq, meta)):
                cur_data = {
                    "sentence": eval_data_loader.dataset.convert_ids_to_sentence(
                        cur_gen_sen.cpu().tolist()).encode("ascii", "ignore"),
                    "timestamp": cur_meta["timestamp"],
                    "gt_sentence": cur_meta["gt_sentence"]
                }
                batch_res["results"][cur_meta["name"]].append(cur_data)

        if opt.debug:
            break

    batch_res["results"] = sort_res(batch_res["results"])
    return batch_res
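# `sort_res` is assumed to order each video's generated captions chronologically,
# so the submission lists sentences in event order. A minimal sketch, keyed on the
# start time of each event's [start, end] timestamp (the name `sort_res_sketch`
# and the timestamp layout are assumptions based on how results are built above):
def sort_res_sketch(res_dict):
    """res_dict: {video_name: [{"sentence": ..., "timestamp": [st, ed], ...}, ...]}"""
    return {name: sorted(entries, key=lambda e: float(e["timestamp"][0]))
            for name, entries in res_dict.items()}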