Пример #1
0
def start_end_collate(batch):
    """Collate a list of dataset examples into one batch.

    Each example is a dict with a "meta" entry and a "model_inputs" dict.
    The set of input fields is taken from the first example.

    Args:
        batch: non-empty list of example dicts.

    Returns:
        A ``(metas, collated)`` pair. ``metas`` is the list of per-example
        "meta" entries, passed through unchanged. ``collated`` maps every
        field whose name contains "feat" to the result of
        ``pad_sequences_1d`` over that field, and "st_ed_indices" (when
        present) to the per-example tensors stacked along a new dim 0.
        Fields matching neither rule are dropped.
    """
    # Meta entries are kept as a plain list; they are not merged in any way.
    metas = [example["meta"] for example in batch]

    def gather(field):
        # One value per example for the given model-input field.
        return [example["model_inputs"][field] for example in batch]

    field_names = batch[0]["model_inputs"].keys()

    # NOTE(review): prepare_batch_inputs indexes these values with [0] and [1],
    # so pad_sequences_1d presumably returns a (padded, mask) pair -- confirm.
    collated = {
        field: pad_sequences_1d(gather(field), dtype=torch.float32, fixed_length=None)
        for field in field_names
        if "feat" in field
    }

    if "st_ed_indices" in field_names:
        collated["st_ed_indices"] = torch.stack(gather("st_ed_indices"), dim=0)

    return metas, collated


def prepare_batch_inputs(batched_model_inputs, device, non_blocking=False):
    """Move collated model inputs onto ``device`` and flatten feature pairs.

    Args:
        batched_model_inputs: dict of collated inputs. A value stored under
            a name containing "feat" is indexed as ``value[0]`` and
            ``value[1]``; any other value must itself support ``.to``.
        device: target passed to ``.to``.
        non_blocking: forwarded to every ``.to`` call.

    Returns:
        A new dict. For each "feat" name, ``value[0]`` is stored under the
        original name and ``value[1]`` under the name with "feat" replaced
        by "mask". Every other entry is stored under its original name.
    """
    def transfer(tensor):
        return tensor.to(device, non_blocking=non_blocking)

    moved = dict()
    for name, value in batched_model_inputs.items():
        if "feat" not in name:
            moved[name] = transfer(value)
            continue
        # Feature entries carry two items: element 0 keeps the feature name,
        # element 1 is stored under the matching "mask" name.
        moved[name] = transfer(value[0])
        mask_name = name.replace("feat", "mask")
        moved[mask_name] = transfer(value[1])
    return moved


if __name__ == '__main__':
    # Runs only when this file is executed as a script, not when it is imported.
    # BaseOptions is imported here, inside the guard, rather than at module top.
    from baselines.crossmodal_moment_localization.config import BaseOptions
    # Parse the run options; what parse() reads from is defined by BaseOptions.
    # NOTE(review): `options` is not used within the visible lines -- the rest
    # of this entry point appears to be cut off; confirm against the full file.
    options = BaseOptions().parse()
Пример #2
0
def start_training():
    """Parse options, build the datasets and the XML model, then run training.

    The training dataset is always built. When ``opt.eval_path`` is set, two
    more datasets are built on the same eval file, reusing the handlers
    exposed by the training dataset (``desc_bert_h5``, ``sub_bert_h5``,
    ``vid_feat_h5``): ``train_eval_dataset`` (a ``StartEndDataset``) and
    ``eval_dataset`` (a ``StartEndEvalDataset`` in "query" data mode). When
    ``opt.eval_path`` is ``None`` both are passed to ``train`` as ``None``.

    Returns:
        tuple: ``(opt.results_dir, opt.eval_split_name, opt.eval_path, opt.debug)``.
    """
    logger.info("Setup config, data and model...")
    opt = BaseOptions().parse()
    set_seed(opt.seed)
    if opt.debug:  # keep the model run deterministically
        # 'cudnn.benchmark = True' enables auto-finding the best algorithm for a
        # specific input/net config. Enable it only when the input size is fixed.
        cudnn.benchmark = False
        cudnn.deterministic = True

    opt.writer = SummaryWriter(opt.tensorboard_log_dir)
    opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
    opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Metrics] {eval_metrics_str}\n"

    # Keyword arguments that are identical for every dataset built below.
    shared_dataset_kwargs = dict(
        dset_name=opt.dset_name,
        max_desc_len=opt.max_desc_l,
        max_ctx_len=opt.max_ctx_l,
        clip_length=opt.clip_length,
        ctx_mode=opt.ctx_mode,
        h5driver=opt.h5driver,
        data_ratio=opt.data_ratio,
        normalize_vfeat=not opt.no_norm_vfeat,
        normalize_tfeat=not opt.no_norm_tfeat,
    )

    train_dataset = StartEndDataset(
        data_path=opt.train_path,
        desc_bert_path_or_handler=opt.desc_bert_path,
        sub_bert_path_or_handler=opt.sub_bert_path,
        vid_feat_path_or_handler=opt.vid_feat_path,
        **shared_dataset_kwargs
    )

    # Both names must be bound on every path: train() below receives them
    # unconditionally. Previously only eval_dataset was set when no eval path
    # was given, so the train() call raised UnboundLocalError.
    # NOTE(review): confirm train() accepts None for train_eval_dataset, as it
    # already has to for eval_dataset.
    train_eval_dataset = None
    eval_dataset = None
    if opt.eval_path is not None:
        # Pass the training dataset's handlers instead of paths; the subtitle
        # and video handlers are passed only when ctx_mode mentions them.
        eval_dataset_kwargs = dict(
            shared_dataset_kwargs,
            data_path=opt.eval_path,
            desc_bert_path_or_handler=train_dataset.desc_bert_h5,
            sub_bert_path_or_handler=train_dataset.sub_bert_h5 if "sub" in opt.ctx_mode else None,
            vid_feat_path_or_handler=train_dataset.vid_feat_h5 if "video" in opt.ctx_mode else None,
        )

        # val dataset, used to get eval loss
        train_eval_dataset = StartEndDataset(**eval_dataset_kwargs)

        eval_dataset = StartEndEvalDataset(
            eval_split_name=opt.eval_split_name,  # should only be val set
            video_duration_idx_path=opt.video_duration_idx_path,
            data_mode="query",
            **eval_dataset_kwargs
        )

    model_config = EDict(
        merge_two_stream=not opt.no_merge_two_stream,  # merge video and subtitles
        cross_att=not opt.no_cross_att,  # use cross-attention when encoding video and subtitles
        span_predictor_type=opt.span_predictor_type,
        encoder_type=opt.encoder_type,  # gru, lstm, transformer
        add_pe_rnn=opt.add_pe_rnn,  # add pe for RNNs
        pe_type=opt.pe_type,
        visual_input_size=opt.vid_feat_size,
        sub_input_size=opt.sub_feat_size,  # for both desc and subtitles
        query_input_size=opt.q_feat_size,  # for both desc and subtitles
        hidden_size=opt.hidden_size,
        stack_conv_predictor_conv_kernel_sizes=opt.stack_conv_predictor_conv_kernel_sizes,
        conv_kernel_size=opt.conv_kernel_size,
        conv_stride=opt.conv_stride,
        max_ctx_l=opt.max_ctx_l,
        max_desc_l=opt.max_desc_l,
        input_drop=opt.input_drop,
        cross_att_drop=opt.cross_att_drop,
        drop=opt.drop,
        n_heads=opt.n_heads,  # self-att heads
        initializer_range=opt.initializer_range,  # for linear layer
        ctx_mode=opt.ctx_mode,  # video, sub or video_sub
        margin=opt.margin,  # margin for ranking loss
        ranking_loss_type=opt.ranking_loss_type,  # loss type, 'hinge' or 'lse'
        lw_neg_q=opt.lw_neg_q,  # loss weight for neg. query and pos. context
        lw_neg_ctx=opt.lw_neg_ctx,  # loss weight for pos. query and neg. context
        lw_st_ed=0,  # will be assigned dynamically at training time
        use_hard_negative=False,  # reset at each epoch
        hard_pool_size=opt.hard_pool_size,
        use_self_attention=not opt.no_self_att,  # whether to use self attention
        no_modular=opt.no_modular
    )
    logger.info("model_config {}".format(model_config))
    model = XML(model_config)
    count_parameters(model)
    logger.info("Start Training...")
    train(model, train_dataset, train_eval_dataset, eval_dataset, opt)
    return opt.results_dir, opt.eval_split_name, opt.eval_path, opt.debug