Ejemplo n.º 1
0
def main():
    """Dump intermediate retrieval data produced by a fine-tuned checkpoint.

    Parses command-line options, then loads the BERT config, tokenizer and
    ``ImageBertForSequenceClassification`` weights from ``--eval_model_dir``
    exactly once.

    When ``--do_test`` or ``--do_eval`` is given, it also:

    * restores the training-time settings;
    * builds the ``RetrievalDataset`` for ``--test_split``;
    * calls ``get_intermediate_data`` on that dataset;
    * stores ``str(result)`` with ``torch.save`` in ``mediate_file.txt``,
      relative to the current working directory.

    Raises:
        AssertionError: if ``--eval_model_dir`` is not an existing directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir",
                        default='./datasets/coco_ir/',
                        type=str,
                        required=False,
                        help="The input data dir with all required files.")
    parser.add_argument("--img_feat_file",
                        default='/disk2/11811112/Oscar/coco_ir/features.tsv',
                        type=str,
                        required=False,
                        help="The absolute address of the image feature file.")
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=False,
        help="Path to pre-trained model or model type. required for training.")
    parser.add_argument(
        "--output_dir",
        default='output/',
        type=str,
        required=False,
        help="The output directory to save checkpoint and test results.")
    parser.add_argument("--loss_type",
                        default='sfmx',
                        type=str,
                        help="Loss function types: support kl, sfmx")
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name.")
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name."
    )
    parser.add_argument(
        "--max_seq_length",
        default=70,
        type=int,
        help="The maximum total input sequence length after tokenization. "
        "Sequences longer than this will be truncated, "
        "sequences shorter will be padded. "
        "This number is calculated on COCO dataset. "
        "If add object detection labels, the suggested length should be 70.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_test",
                        action='store_true',
                        help="Whether to run inference.")
    parser.add_argument(
        "--do_eval",
        action='store_true',
        help="Whether to run performance evaluation. "
        "do not activate if we want to inference on dataset without gt labels."
    )
    parser.add_argument("--test_split",
                        default='test',
                        type=str,
                        help='data split name.')
    parser.add_argument(
        "--eval_img_keys_file",
        default='',
        type=str,
        help="image key tsv to select a subset of images for evaluation. "
        "This is useful in 5-folds evaluation. The topn index file is not "
        "needed in this case.")
    parser.add_argument(
        "--eval_caption_index_file",
        default='',
        type=str,
        help="index of a list of (img_key, cap_idx) for each image. "
        "this is used to perform re-rank using hard negative samples. "
        "useful for validation set to monitor the performance during training."
    )
    parser.add_argument(
        "--cross_image_eval",
        action='store_true',
        help=
        "perform cross image inference, ie. each image with all texts from other images."
    )
    parser.add_argument("--add_od_labels",
                        default=False,
                        action='store_true',
                        help="Whether to add object detection labels or not.")
    parser.add_argument("--od_label_type",
                        default='vg',
                        type=str,
                        help="label type, support vg, gt, oid")
    parser.add_argument(
        "--att_mask_type",
        default='CLR',
        type=str,
        help="attention mask type, support ['CL', 'CR', 'LR', 'CLR'] "
        "C: caption, L: labels, R: image regions; CLR is full attention by default. "
        "CL means attention between caption and labels. "
        "please pay attention to the order CLR, which is the default concat order."
    )
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--drop_out",
                        default=0.1,
                        type=float,
                        help="Drop out in BERT.")
    parser.add_argument("--max_img_seq_length",
                        default=50,
                        type=int,
                        help="The maximum total input image sequence length.")
    parser.add_argument("--img_feature_dim",
                        default=2054,
                        type=int,
                        help="The Image Feature Dimension.")
    parser.add_argument("--img_feature_type",
                        default='frcnn',
                        type=str,
                        help="Image feature type.")
    parser.add_argument("--use_img_layernorm",
                        type=int,
                        default=1,
                        help="Normalize image features with bertlayernorm")
    parser.add_argument("--img_layer_norm_eps",
                        default=1e-12,
                        type=float,
                        help="The eps in image feature layernorm layer")
    parser.add_argument("--per_gpu_train_batch_size",
                        default=2,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=2,
                        type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument(
        "--output_mode",
        default='classification',
        type=str,
        help="output mode, support classification or regression.")
    parser.add_argument(
        "--num_labels",
        default=2,
        type=int,
        help="num_labels is 2 for classification and 1 for regression.")
    parser.add_argument(
        "--num_captions_per_img_train",
        default=5,
        type=int,
        help="number of positive matched captions for each training image.")
    parser.add_argument("--num_captions_per_img_val",
                        default=5,
                        type=int,
                        help="number of captions for each testing image.")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help="Number of updates steps to accumulate before backward.")
    parser.add_argument("--learning_rate",
                        default=2e-5,
                        type=float,
                        help="The initial lr.")
    parser.add_argument("--weight_decay",
                        default=0.05,
                        type=float,
                        help="Weight decay.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup.")
    parser.add_argument("--scheduler",
                        default='linear',
                        type=str,
                        help="constant or linear.")
    parser.add_argument("--num_workers",
                        default=4,
                        type=int,
                        help="Workers in dataloader.")
    parser.add_argument("--num_train_epochs",
                        default=20,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="Total number of training steps. Override num_train_epochs.")
    parser.add_argument('--logging_steps',
                        type=int,
                        default=20,
                        help="Log every X steps.")
    parser.add_argument(
        '--save_steps',
        type=int,
        default=-1,
        help="Save checkpoint every X steps. Will also perform evaluation.")
    parser.add_argument(
        "--evaluate_during_training",
        action='store_true',
        help="Run evaluation during training at each save_steps.")
    parser.add_argument("--eval_model_dir",
                        type=str,
                        default='./output0320/checkpoint-29-66390/',
                        help="Model directory for evaluation.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA.")
    parser.add_argument('--seed',
                        type=int,
                        default=88,
                        help="random seed for initialization.")
    args = parser.parse_args()

    global logger
    mkdir(args.output_dir)
    logger = setup_logger("vlpretrain", args.output_dir, 0)

    args.device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()
    set_seed(args.seed, args.n_gpu)
    logger.warning("Device: %s, n_gpu: %s", args.device, args.n_gpu)
    logger.info('output_mode: {}, #Labels: {}'.format(args.output_mode,
                                                      args.num_labels))

    # Load config, tokenizer and weights a single time; the config does not
    # depend on the restored training settings, so one load is sufficient.
    config_class, tokenizer_class = BertConfig, BertTokenizer
    model_class = ImageBertForSequenceClassification
    checkpoint = args.eval_model_dir
    assert op.isdir(checkpoint)
    config = config_class.from_pretrained(checkpoint)
    tokenizer = tokenizer_class.from_pretrained(checkpoint)
    model = model_class.from_pretrained(checkpoint, config=config)
    model.to(args.device)

    # inference and evaluation
    if args.do_test or args.do_eval:
        args = restore_training_settings(args)
        test_dataset = RetrievalDataset(tokenizer,
                                        args,
                                        args.test_split,
                                        is_train=False)
        logger.info("Evaluate the following checkpoint: %s", checkpoint)
        # get_intermediate_data is handed the bare model, never a
        # DataParallel wrapper. Reaching for `.module` on an unwrapped model
        # (n_gpu <= 1) raises AttributeError, so no wrapping is done here.
        result = get_intermediate_data(args, model,
                                       test_dataset)  # obtain intermediate data
        # The string form of `result` is serialized with torch.save (pickle),
        # despite the .txt extension; the file lands in the working directory.
        mediate_file = "mediate_file.txt"
        torch.save(str(result), mediate_file)
        logger.info("Prediction results saved to {}.".format(mediate_file))
Ejemplo n.º 2
0
def main():
    """Fine-tune and/or run inference for image captioning.

    Reads options from ``get_args()``. It then sets up either a
    single-process run (``local_rank == -1`` or ``--no_cuda``) or one NCCL
    process per GPU.

    * With ``do_train``: loads ``BertForImageCaptioning`` from
      ``model_name_or_path`` and writes logs/outputs to a timestamped
      sub-directory of ``output_dir``. It trains on the 'train' split and
      validates on 'dev'.
    * Otherwise: loads config/tokenizer/model from ``eval_model_dir`` and
      logs to ``test_log.txt`` next to it.

    Afterwards, on the main process only, ``do_test`` predicts the 'test'
    split and ``do_eval`` predicts the 'dev' split.
    """
    args = get_args()

    global logger

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:
        # Initializes the distributed backend which will take care of
        # synchronizing nodes/GPUs.
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    if args.do_train:
        mkdir(args.output_dir)

        # Every training run gets its own month_day_hour_minute_second folder.
        t = datetime.today()
        args.output_dir = op.join(
            args.output_dir,
            f"{t.month}_{t.day}_{t.hour}_{t.minute}_{t.second}")
        if not op.exists(args.output_dir):
            mkdir(args.output_dir)

        logger = setup_logger("vlpretrain", args.output_dir, args.local_rank)
    else:
        logger = setup_logger("vlpretrain",
                              os.path.dirname(args.eval_model_dir),
                              args.local_rank, 'test_log.txt')

    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s",
        args.local_rank, device, args.n_gpu, bool(args.local_rank != -1))

    set_seed(args.seed, args.n_gpu)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will
        # download model & vocab; the other ranks wait at this barrier.
        torch.distributed.barrier()

    config_class, model_class, tokenizer_class = BertConfig, BertForImageCaptioning, BertTokenizer
    if args.do_train:
        assert args.model_name_or_path is not None
        config = config_class.from_pretrained(
            args.config_name if args.config_name else args.model_name_or_path,
            num_labels=args.num_labels,
            finetuning_task='image_captioning')
        if args.scst:
            # avoid using too much memory
            config.output_hidden_states = True
        tokenizer = tokenizer_class.from_pretrained(
            args.tokenizer_name
            if args.tokenizer_name else args.model_name_or_path,
            do_lower_case=args.do_lower_case)
        config.img_feature_dim = args.img_feature_dim
        config.img_feature_type = args.img_feature_type
        config.hidden_dropout_prob = args.drop_out
        config.loss_type = args.loss_type
        model = model_class.from_pretrained(
            args.model_name_or_path,
            from_tf=bool('.ckpt' in args.model_name_or_path),
            config=config)
    else:
        assert op.isdir(args.eval_model_dir)
        config = config_class.from_pretrained(args.eval_model_dir)
        config.output_hidden_states = args.output_hidden_states
        tokenizer = tokenizer_class.from_pretrained(args.eval_model_dir)
        logger.info("Evaluate the following checkpoint: %s",
                    args.eval_model_dir)
        model = model_class.from_pretrained(args.eval_model_dir, config=config)

    if args.local_rank == 0:
        # The first process has finished loading; release the waiting ranks.
        torch.distributed.barrier()

    model.to(args.device)
    logger.info("Training/evaluation parameters %s", args)
    if args.do_train:
        train_dataset = build_dataset('train', tokenizer, args)
        val_dataset = build_dataset('dev', tokenizer, args, is_train=False)
        global_step, avg_loss = train(args, train_dataset, val_dataset, model,
                                      tokenizer)
        logger.info("Training done: total_step = %s, avg loss = %s",
                    global_step, avg_loss)

    if args.do_test and args.local_rank in [-1, 0]:
        args = restore_training_settings(args)
        test_dataset = build_dataset('test', tokenizer, args, is_train=False)
        # Wrap at most once: with both --do_test and --do_eval the eval block
        # below would otherwise nest DataParallel inside DataParallel.
        if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
            model = torch.nn.DataParallel(model)

        predict_file = get_predict_file('test', args.eval_model_dir, args)
        test(args, test_dataset, model, tokenizer, predict_file)
        logger.info("Prediction results saved to: {}".format(predict_file))

    if args.do_eval and args.local_rank in [-1, 0]:
        args = restore_training_settings(args)
        dev_dataset = build_dataset('dev', tokenizer, args, is_train=False)
        if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
            model = torch.nn.DataParallel(model)

        predict_file = get_predict_file('dev', args.eval_model_dir, args)
        test(args, dev_dataset, model, tokenizer, predict_file)
        logger.info("Prediction results saved to: {}".format(predict_file))
def main():
    """Fine-tune and/or run inference for image captioning on yaml datasets.

    Parses command-line options.

    * With ``--do_train``: loads ``BertForImageCaptioning`` from
      ``--model_name_or_path`` and trains on
      ``data_dir/train_yaml``, validating on ``data_dir/val_yaml``.
    * Otherwise: loads config, tokenizer and model from ``--eval_model_dir``.

    With ``--do_test`` (and without ``--do_eval``), predictions for
    ``data_dir/test_yaml`` are written via ``test``. With ``--do_eval``,
    ``evaluate`` is run on the same dataset instead.

    Raises:
        AssertionError: if training without ``--model_name_or_path``, or
            evaluating with a non-existent ``--eval_model_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir",
                        default='datasets/coco_caption',
                        type=str,
                        required=False,
                        help="The input data dir with all required files.")
    parser.add_argument("--train_yaml",
                        default='train.yaml',
                        type=str,
                        required=False,
                        help="yaml file for training.")
    parser.add_argument("--test_yaml",
                        default='test.yaml',
                        type=str,
                        required=False,
                        help="yaml file for testing.")
    parser.add_argument("--val_yaml",
                        default='val.yaml',
                        type=str,
                        required=False,
                        help="yaml file used for validation during training.")
    parser.add_argument("--model_name_or_path",
                        default=None,
                        type=str,
                        required=False,
                        help="Path to pre-trained model or model type.")
    parser.add_argument(
        "--output_dir",
        default='output/',
        type=str,
        required=False,
        help="The output directory to save checkpoint and test results.")
    parser.add_argument("--loss_type",
                        default='sfmx',
                        type=str,
                        help="Loss function types: support kl, x2, sfmx")
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name.")
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name."
    )
    parser.add_argument(
        "--max_seq_length",
        default=70,
        type=int,
        help="The maximum total input sequence length after tokenization. "
        "Sequences longer than this will be truncated, "
        "sequences shorter will be padded.")
    parser.add_argument("--max_seq_a_length",
                        default=40,
                        type=int,
                        help="The maximum sequence length for caption.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_test",
                        action='store_true',
                        help="Whether to run inference.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run evaluation.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help="Probability to mask input sentence during training.")
    parser.add_argument("--max_masked_tokens",
                        type=int,
                        default=3,
                        help="The max number of masked tokens per sentence.")
    parser.add_argument("--add_od_labels",
                        default=False,
                        action='store_true',
                        help="Whether to add object detection labels or not")
    parser.add_argument(
        "--disable_img_features",
        default=False,
        action='store_true',
        help="Whether to disable image feature in finetuning state or not")
    parser.add_argument(
        '--keep_top_percentage_tag_conf_threshold',
        type=float,
        default=0.3,
        help="Confidence threshold k for keep_top_percentage_tag")
    parser.add_argument(
        '--keep_top_percentage_tag',
        type=float,
        default=1,
        help=
        "Keep input percentage features at inference time given that >= k confidence"
    )
    parser.add_argument("--drop_out",
                        default=0.1,
                        type=float,
                        help="Drop out in BERT.")
    parser.add_argument("--max_img_seq_length",
                        default=50,
                        type=int,
                        help="The maximum total input image sequence length.")
    parser.add_argument("--img_feature_dim",
                        default=2054,
                        type=int,
                        help="The Image Feature Dimension.")
    parser.add_argument("--img_feature_type",
                        default='frcnn',
                        type=str,
                        help="Image feature type.")
    parser.add_argument("--per_gpu_train_batch_size",
                        default=64,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=64,
                        type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument(
        "--output_mode",
        default='classification',
        type=str,
        help="output mode, support classification or regression.")
    parser.add_argument(
        "--num_labels",
        default=2,
        type=int,
        help="num_labels is 2 for classification and 1 for regression.")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help="Number of updates steps to accumulate before backward.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial lr.")
    parser.add_argument("--weight_decay",
                        default=0.05,
                        type=float,
                        help="Weight decay.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup.")
    parser.add_argument("--scheduler",
                        default='linear',
                        type=str,
                        help="constant or linear or")
    parser.add_argument("--num_workers",
                        default=4,
                        type=int,
                        help="Workers in dataloader.")
    parser.add_argument("--num_train_epochs",
                        default=40,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="Total number of training steps. Override num_train_epochs.")
    parser.add_argument('--logging_steps',
                        type=int,
                        default=20,
                        help="Log every X steps.")
    parser.add_argument(
        '--save_steps',
        type=int,
        default=-1,
        help="Save checkpoint every X steps. Will also perform evaluation.")
    parser.add_argument(
        "--evaluate_during_training",
        action='store_true',
        help="Run evaluation during training at each save_steps.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA.")
    parser.add_argument('--seed',
                        type=int,
                        default=88,
                        help="random seed for initialization.")
    parser.add_argument('--scst',
                        action='store_true',
                        help='Self-critical sequence training')
    # for generation
    parser.add_argument("--eval_model_dir",
                        type=str,
                        default='',
                        help="Model directory for evaluation.")
    parser.add_argument('--max_gen_length',
                        type=int,
                        default=20,
                        help="max length of generated sentences")
    parser.add_argument('--output_hidden_states',
                        action='store_true',
                        help="Turn on for fast decoding")
    parser.add_argument('--num_return_sequences',
                        type=int,
                        default=1,
                        help="repeating times per image")
    parser.add_argument('--num_beams',
                        type=int,
                        default=5,
                        help="beam search width")
    parser.add_argument('--num_keep_best',
                        type=int,
                        default=1,
                        help="number of hypotheses to keep in beam search")
    parser.add_argument('--temperature',
                        type=float,
                        default=1,
                        help="temperature in softmax for sampling")
    parser.add_argument('--top_k',
                        type=int,
                        default=0,
                        help="filter distribution for sampling")
    parser.add_argument('--top_p',
                        type=float,
                        default=1,
                        help="filter distribution for sampling")
    parser.add_argument(
        '--repetition_penalty',
        type=int,
        default=1,
        help=
        "repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)"
    )
    parser.add_argument('--length_penalty',
                        type=int,
                        default=1,
                        help="beam search length penalty")
    # for Constrained Beam Search
    parser.add_argument('--use_cbs',
                        action='store_true',
                        help='Use constrained beam search for decoding')
    parser.add_argument('--min_constraints_to_satisfy',
                        type=int,
                        default=2,
                        help="minimum number of constraints to satisfy")
    args = parser.parse_args()

    global logger

    args.device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()

    output_dir = args.output_dir
    mkdir(output_dir)

    logger = setup_logger("vlpretrain", output_dir, 0)
    logger.warning("Device: %s, n_gpu: %s", args.device, args.n_gpu)
    set_seed(args.seed, args.n_gpu)

    # Load pretrained model and tokenizer
    config_class, model_class, tokenizer_class = BertConfig, BertForImageCaptioning, BertTokenizer
    if args.do_train:
        assert args.model_name_or_path is not None
        config = config_class.from_pretrained(
            args.config_name if args.config_name else args.model_name_or_path,
            num_labels=args.num_labels,
            finetuning_task='image_captioning')
        if args.scst:
            # avoid using too much memory
            config.output_hidden_states = True
        tokenizer = tokenizer_class.from_pretrained(
            args.tokenizer_name
            if args.tokenizer_name else args.model_name_or_path,
            do_lower_case=args.do_lower_case)
        config.img_feature_dim = args.img_feature_dim
        config.img_feature_type = args.img_feature_type
        config.hidden_dropout_prob = args.drop_out
        config.loss_type = args.loss_type
        model = model_class.from_pretrained(
            args.model_name_or_path,
            from_tf=bool('.ckpt' in args.model_name_or_path),
            config=config)
    else:
        checkpoint = args.eval_model_dir
        assert op.isdir(checkpoint)
        config = config_class.from_pretrained(checkpoint)
        config.output_hidden_states = args.output_hidden_states
        tokenizer = tokenizer_class.from_pretrained(checkpoint)
        logger.info("Evaluate the following checkpoint: %s", checkpoint)
        model = model_class.from_pretrained(checkpoint, config=config)

    model.to(args.device)
    logger.info("Training/evaluation parameters %s", args)
    if args.do_train:
        train_dataset = build_dataset(op.join(args.data_dir, args.train_yaml),
                                      tokenizer, args)
        val_dataset = build_dataset(op.join(args.data_dir, args.val_yaml),
                                    tokenizer,
                                    args,
                                    is_train=False)
        global_step, avg_loss = train(args, train_dataset, val_dataset, model,
                                      tokenizer)
        logger.info("Training done: total_step = %s, avg loss = %s",
                    global_step, avg_loss)
        # Training mode has no eval checkpoint; `checkpoint` was previously
        # unbound here, so --do_train with --do_test/--do_eval crashed.
        # Direct the test/eval outputs to the training output directory.
        checkpoint = args.output_dir

    # inference and evaluation
    if args.do_test or args.do_eval:
        args = restore_training_settings(args)
        test_dataset = build_dataset(op.join(args.data_dir, args.test_yaml),
                                     tokenizer,
                                     args,
                                     is_train=False)
        if args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        if not args.do_eval:
            predict_file = get_predict_file(checkpoint, test_dataset.yaml_file,
                                            args)
            test(args, test_dataset, model, tokenizer, predict_file)
            logger.info("Prediction results saved to: {}".format(predict_file))
        else:
            evaluate_file = evaluate(args, test_dataset, model, tokenizer,
                                     checkpoint)
            logger.info(
                "Evaluation results saved to: {}".format(evaluate_file))
Ejemplo n.º 4
0
def main():
    """Fine-tune and/or run inference for image captioning on yaml datasets.

    Reads options from ``get_args()``.

    * With ``do_train``: loads ``BertForImageCaptioning`` from
      ``model_name_or_path`` and trains on ``data_dir/train_yaml``,
      validating on ``data_dir/val_yaml``.
    * Otherwise: loads config, tokenizer and model from ``eval_model_dir``.

    With ``do_test`` (and without ``do_eval``), predictions for
    ``data_dir/test_yaml`` are written via ``test``. With ``do_eval``,
    ``evaluate`` is run on the same dataset instead.

    Raises:
        AssertionError: if training without ``model_name_or_path``, or
            evaluating with a non-existent ``eval_model_dir``.
    """
    args = get_args()

    global logger

    args.device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()

    mkdir(args.output_dir)

    logger = setup_logger("vlpretrain", args.output_dir, 0)
    logger.warning("Device: %s, n_gpu: %s", args.device, args.n_gpu)
    set_seed(args.seed, args.n_gpu)

    # Load pretrained model and tokenizer
    config_class, model_class, tokenizer_class = BertConfig, BertForImageCaptioning, BertTokenizer
    if args.do_train:
        assert args.model_name_or_path is not None
        config = config_class.from_pretrained(
            args.config_name if args.config_name else args.model_name_or_path,
            num_labels=args.num_labels,
            finetuning_task='image_captioning')
        if args.scst:
            # avoid using too much memory
            config.output_hidden_states = True
        tokenizer = tokenizer_class.from_pretrained(
            args.tokenizer_name
            if args.tokenizer_name else args.model_name_or_path,
            do_lower_case=args.do_lower_case)
        config.img_feature_dim = args.img_feature_dim
        config.img_feature_type = args.img_feature_type
        config.hidden_dropout_prob = args.drop_out
        config.loss_type = args.loss_type
        model = model_class.from_pretrained(
            args.model_name_or_path,
            from_tf=bool('.ckpt' in args.model_name_or_path),
            config=config)
    else:
        checkpoint = args.eval_model_dir
        assert op.isdir(checkpoint)
        config = config_class.from_pretrained(checkpoint)
        config.output_hidden_states = args.output_hidden_states
        tokenizer = tokenizer_class.from_pretrained(checkpoint)
        logger.info("Evaluate the following checkpoint: %s", checkpoint)
        model = model_class.from_pretrained(checkpoint, config=config)

    model.to(args.device)
    logger.info("Training/evaluation parameters %s", args)
    if args.do_train:
        train_dataset = build_dataset(
            op.join(args.data_dir, args.train_yaml), tokenizer, args)
        val_dataset = build_dataset(op.join(args.data_dir, args.val_yaml),
                                    tokenizer, args, is_train=False)
        global_step, avg_loss = train(
            args, train_dataset, val_dataset, model, tokenizer)
        logger.info("Training done: total_step = %s, avg loss = %s",
                    global_step, avg_loss)
        # Training mode has no eval checkpoint; `checkpoint` was previously
        # unbound here, so do_train together with do_test/do_eval crashed.
        # Direct the test/eval outputs to the training output directory.
        checkpoint = args.output_dir

    # inference and evaluation
    if args.do_test or args.do_eval:
        args = restore_training_settings(args)
        test_dataset = build_dataset(op.join(args.data_dir, args.test_yaml),
                                     tokenizer, args, is_train=False)
        if args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        if not args.do_eval:
            predict_file = get_predict_file(
                checkpoint, test_dataset.yaml_file, args)
            test(args, test_dataset, model, tokenizer, predict_file)
            logger.info("Prediction results saved to: {}".format(predict_file))
        else:
            evaluate_file = evaluate(args, test_dataset, model, tokenizer,
                                     checkpoint)
            logger.info(
                "Evaluation results saved to: {}".format(evaluate_file))