import torch

# NOTE: pad_sequences_1d is a repo-local helper (imported with the rest of this
# module's dependencies) that pads a list of variable-length tensors and
# returns a (padded_tensor, mask) pair.


def start_end_collate(batch):
    batch_meta = [e["meta"] for e in batch]  # kept as a list of raw dicts; no need to collate

    model_inputs_keys = batch[0]["model_inputs"].keys()
    batched_data = dict()
    for k in model_inputs_keys:
        if "feat" in k:
            # pad variable-length feature sequences; each entry becomes a
            # (padded_tensor, mask) pair
            batched_data[k] = pad_sequences_1d(
                [e["model_inputs"][k] for e in batch], dtype=torch.float32, fixed_length=None)
    if "st_ed_indices" in model_inputs_keys:
        batched_data["st_ed_indices"] = torch.stack(
            [e["model_inputs"]["st_ed_indices"] for e in batch], dim=0)
    return batch_meta, batched_data


def prepare_batch_inputs(batched_model_inputs, device, non_blocking=False):
    model_inputs = {}
    for k, v in batched_model_inputs.items():
        if "feat" in k:  # (padded_tensor, mask) pairs produced by the collate fn
            model_inputs[k] = v[0].to(device, non_blocking=non_blocking)
            model_inputs[k.replace("feat", "mask")] = v[1].to(device, non_blocking=non_blocking)
        else:
            model_inputs[k] = v.to(device, non_blocking=non_blocking)
    return model_inputs


if __name__ == '__main__':
    from baselines.crossmodal_moment_localization.config import BaseOptions
    options = BaseOptions().parse()
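# A minimal usage sketch (illustrative assumption, not part of this repo): how
# the two helpers above are typically wired into a PyTorch DataLoader. The
# `bsz` / `num_workers` option names are hypothetical placeholders, and
# pin_memory is enabled so the non_blocking transfers can overlap with compute.
#
#     from torch.utils.data import DataLoader
#
#     loader = DataLoader(train_dataset, batch_size=options.bsz, shuffle=True,
#                         num_workers=options.num_workers, pin_memory=True,
#                         collate_fn=start_end_collate)
#     for batch_meta, batched_data in loader:
#         model_inputs = prepare_batch_inputs(batched_data, device="cuda",
#                                             non_blocking=True)
#         # model_inputs now maps each "*_feat" key to a device tensor plus a
#         # matching "*_mask" tensor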
def start_training():
    logger.info("Setup config, data and model...")
    opt = BaseOptions().parse()
    set_seed(opt.seed)
    if opt.debug:  # keep the model running deterministically
        # 'cudnn.benchmark = True' enables auto-tuning to find the best algorithm
        # for a specific input/net config; enable it only when input sizes are fixed.
        cudnn.benchmark = False
        cudnn.deterministic = True

    opt.writer = SummaryWriter(opt.tensorboard_log_dir)
    opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
    opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Metrics] {eval_metrics_str}\n"

    train_dataset = StartEndDataset(
        dset_name=opt.dset_name,
        data_path=opt.train_path,
        desc_bert_path_or_handler=opt.desc_bert_path,
        sub_bert_path_or_handler=opt.sub_bert_path,
        max_desc_len=opt.max_desc_l,
        max_ctx_len=opt.max_ctx_l,
        vid_feat_path_or_handler=opt.vid_feat_path,
        clip_length=opt.clip_length,
        ctx_mode=opt.ctx_mode,
        h5driver=opt.h5driver,
        data_ratio=opt.data_ratio,
        normalize_vfeat=not opt.no_norm_vfeat,
        normalize_tfeat=not opt.no_norm_tfeat,
    )

    if opt.eval_path is not None:
        # val dataset, used to get the eval loss
        train_eval_dataset = StartEndDataset(
            dset_name=opt.dset_name,
            data_path=opt.eval_path,
            desc_bert_path_or_handler=train_dataset.desc_bert_h5,
            sub_bert_path_or_handler=train_dataset.sub_bert_h5 if "sub" in opt.ctx_mode else None,
            max_desc_len=opt.max_desc_l,
            max_ctx_len=opt.max_ctx_l,
            vid_feat_path_or_handler=train_dataset.vid_feat_h5 if "video" in opt.ctx_mode else None,
            clip_length=opt.clip_length,
            ctx_mode=opt.ctx_mode,
            h5driver=opt.h5driver,
            data_ratio=opt.data_ratio,
            normalize_vfeat=not opt.no_norm_vfeat,
            normalize_tfeat=not opt.no_norm_tfeat
        )
        eval_dataset = StartEndEvalDataset(
            dset_name=opt.dset_name,
            eval_split_name=opt.eval_split_name,  # should only be the val set
            data_path=opt.eval_path,
            desc_bert_path_or_handler=train_dataset.desc_bert_h5,
            sub_bert_path_or_handler=train_dataset.sub_bert_h5 if "sub" in opt.ctx_mode else None,
            max_desc_len=opt.max_desc_l,
            max_ctx_len=opt.max_ctx_l,
            video_duration_idx_path=opt.video_duration_idx_path,
            vid_feat_path_or_handler=train_dataset.vid_feat_h5 if "video" in opt.ctx_mode else None,
            clip_length=opt.clip_length,
            ctx_mode=opt.ctx_mode,
            data_mode="query",
            h5driver=opt.h5driver,
            data_ratio=opt.data_ratio,
            normalize_vfeat=not opt.no_norm_vfeat,
            normalize_tfeat=not opt.no_norm_tfeat
        )
    else:
        train_eval_dataset = None  # set explicitly; train() below expects it to exist
        eval_dataset = None

    model_config = EDict(
        merge_two_stream=not opt.no_merge_two_stream,  # merge video and subtitles
        cross_att=not opt.no_cross_att,  # use cross-attention when encoding video and subtitles
        span_predictor_type=opt.span_predictor_type,
        encoder_type=opt.encoder_type,  # gru, lstm, transformer
        add_pe_rnn=opt.add_pe_rnn,  # add positional encoding for RNNs
        pe_type=opt.pe_type,
        visual_input_size=opt.vid_feat_size,
        sub_input_size=opt.sub_feat_size,  # for both desc and subtitles
        query_input_size=opt.q_feat_size,  # for both desc and subtitles
        hidden_size=opt.hidden_size,
        stack_conv_predictor_conv_kernel_sizes=opt.stack_conv_predictor_conv_kernel_sizes,
        conv_kernel_size=opt.conv_kernel_size,
        conv_stride=opt.conv_stride,
        max_ctx_l=opt.max_ctx_l,
        max_desc_l=opt.max_desc_l,
        input_drop=opt.input_drop,
        cross_att_drop=opt.cross_att_drop,
        drop=opt.drop,
        n_heads=opt.n_heads,  # self-attention heads
        initializer_range=opt.initializer_range,  # for linear layers
        ctx_mode=opt.ctx_mode,  # video, sub or video_sub
        margin=opt.margin,  # margin for ranking loss
        ranking_loss_type=opt.ranking_loss_type,  # loss type, 'hinge' or 'lse'
        lw_neg_q=opt.lw_neg_q,  # loss weight for neg. query and pos. context
        lw_neg_ctx=opt.lw_neg_ctx,  # loss weight for pos. query and neg. context
        lw_st_ed=0,  # will be assigned dynamically at training time
        use_hard_negative=False,  # reset at each epoch
        hard_pool_size=opt.hard_pool_size,
        use_self_attention=not opt.no_self_att,  # whether to use self attention
        no_modular=opt.no_modular
    )
    logger.info("model_config {}".format(model_config))
    model = XML(model_config)
    count_parameters(model)

    logger.info("Start Training...")
    train(model, train_dataset, train_eval_dataset, eval_dataset, opt)
    return opt.results_dir, opt.eval_split_name, opt.eval_path, opt.debug
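# A minimal entry-point sketch, assuming this training script is launched
# directly; the unpacked names mirror start_training()'s return tuple
# (results_dir, eval_split_name, eval_path, debug). Any post-training
# inference step is elided.
if __name__ == '__main__':
    results_dir, eval_split_name, eval_path, debug = start_training()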