예제 #1
0
def main():
    """Parse command-line options and run vision-language BERT training.

    The function performs, in order:

    1. Argument parsing and validation.
    2. Creation of ``args.output_dir``, dumping the parsed options to
       ``opt.json`` and configuring file logging inside that directory.
    3. Device selection and, when ``--local_rank`` is not -1, initialisation
       of the NCCL process group.
    4. Seeding of ``random``, ``numpy`` and ``torch``.
    5. Construction of the tokenizer, the preprocessing pipelines, the
       dataset and the data loader.
    6. Construction of the model, either from the pre-trained BERT weights or
       from a checkpoint found in ``args.output_dir`` /
       ``args.model_recover_path``.
    7. Construction of the optimizer and, when available, recovery of its
       state.
    8. The epoch loop; after every epoch ``model.<epoch>.bin`` is written by
       the process whose ``--global_rank`` is -1 or 0.

    Raises:
        ValueError: if ``--do_train`` is not given, or if
            ``--gradient_accumulation_steps`` is smaller than 1.
        AssertionError: if the argument combination is unsupported (see the
            "arguments inspection" section below).
        ImportError: if fp16 training is requested and the required apex /
            fp16 optimizer modules cannot be imported.
    """
    parser = argparse.ArgumentParser()

    # General
    parser.add_argument(
        "--bert_model",
        default="bert-base-cased",
        type=str,
        help=
        "Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default='tmp',
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_file",
        default="training.log",
        type=str,
        help="The name of the log file, written inside output_dir.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument(
        "--do_train",
        action='store_true',
        help="Whether to run training. This should ALWAYS be set to True.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="Label smoothing factor passed to the model.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=30,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--global_rank",
                        type=int,
                        default=-1,
                        help="global_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--len_vis_input',
                        type=int,
                        default=100,
                        help="The length of visual token input")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=20,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='b',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=3,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=4,
                        type=int,
                        help="Number of workers for the data loader.")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")

    # Others for VLP
    parser.add_argument(
        "--src_file",
        default=['/mnt/dat/COCO/annotations/dataset_coco.json'],
        type=str,
        nargs='+',
        help="The input data file name.")
    parser.add_argument('--enable_visdom', action='store_true')
    parser.add_argument('--visdom_port', type=int, default=8888)
    # parser.add_argument('--resnet_model', type=str, default='imagenet_weights/resnet101.pth')
    parser.add_argument('--image_root',
                        type=str,
                        default='/mnt/dat/COCO/images')
    parser.add_argument('--dataset',
                        default='coco',
                        type=str,
                        help='coco | flickr30k | cc')
    parser.add_argument('--split',
                        type=str,
                        nargs='+',
                        default=['train', 'restval'])

    parser.add_argument('--world_size',
                        default=1,
                        type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url',
                        default='file://[PT_OUTPUT_DIR]/nonexistent_file',
                        type=str,
                        help='url used to set up distributed training')
    parser.add_argument(
        '--file_valid_jpgs',
        default='/mnt/dat/COCO/annotations/coco_valid_jpgs.json',
        type=str)
    parser.add_argument('--sche_mode',
                        default='warmup_linear',
                        type=str,
                        help="warmup_linear | warmup_constant | warmup_cosine")
    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--use_num_imgs', default=-1, type=int)
    parser.add_argument('--vis_mask_prob', default=0, type=float)
    parser.add_argument('--max_drop_worst_ratio', default=0, type=float)
    parser.add_argument('--drop_after', default=6, type=int)

    parser.add_argument(
        '--s2s_prob',
        default=1,
        type=float,
        help="Percentage of examples that are bi-uni-directional LM (seq2seq)."
    )
    parser.add_argument(
        '--bi_prob',
        default=0,
        type=float,
        help="Percentage of examples that are bidirectional LM.")
    parser.add_argument('--enable_butd',
                        action='store_true',
                        help='set to take in region features')
    parser.add_argument(
        '--region_bbox_file',
        default=
        'coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5',
        type=str)
    parser.add_argument(
        '--region_det_file_prefix',
        default=
        'feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval',
        type=str)
    parser.add_argument('--tasks', default='img2txt', help='img2txt | vqa2')
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--scst',
                        action='store_true',
                        help='Self-critical sequence training')

    args = parser.parse_args()

    print('global_rank: {}, local rank: {}'.format(args.global_rank,
                                                   args.local_rank))

    args.max_seq_length = args.max_len_b + args.len_vis_input + 3  # +3 for 2x[SEP] and [CLS]
    args.mask_image_regions = (args.vis_mask_prob > 0
                               )  # whether to mask out image regions
    args.dist_url = args.dist_url.replace('[PT_OUTPUT_DIR]', args.output_dir)

    # arguments inspection
    if not args.do_train:
        # The data loader is only built when --do_train is set, and the step
        # count / training loop below depend on it. Fail early with a clear
        # message instead of a NameError further down.
        raise ValueError(
            "--do_train must be set: this script only supports training.")
    assert (args.tasks in ('img2txt', 'vqa2'))
    assert args.enable_butd, 'only support region attn! featmap attn deprecated'
    assert (
        not args.scst) or args.dataset == 'coco', 'scst support on coco only!'
    if args.scst:
        assert args.max_pred == 0 and args.mask_prob == 0, 'no mask for scst!'
        rl_crit = RewardCriterion()

    if args.enable_butd:
        assert (args.len_vis_input == 100)
        args.region_bbox_file = os.path.join(args.image_root,
                                             args.region_bbox_file)
        # The detection-file prefix is only resolved for the 'cc' and 'coco'
        # datasets; otherwise it is blanked out.
        args.region_det_file_prefix = os.path.join(
            args.image_root, args.region_det_file_prefix) if args.dataset in (
                'cc', 'coco') and args.region_det_file_prefix != '' else ''

    # output config
    os.makedirs(args.output_dir, exist_ok=True)
    # Use a context manager so the options file is flushed and closed.
    with open(os.path.join(args.output_dir, 'opt.json'), 'w') as opt_file:
        json.dump(args.__dict__, opt_file, sort_keys=True, indent=2)

    logging.basicConfig(
        filename=os.path.join(args.output_dir, args.log_file),
        filemode='w',
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger = logging.getLogger(__name__)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        # NOTE(review): the rendezvous address is hard-coded here; the
        # args.dist_url value prepared above is not used.
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://localhost:10001',  #args.dist_url,
            world_size=args.world_size,
            rank=args.global_rank)
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # From here on train_batch_size is the per-step (micro) batch size.
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # plotting loss, optional
    if args.enable_visdom:
        import visdom
        vis = visdom.Visdom(port=args.visdom_port, env=args.output_dir)
        vis_window = {'iter': None, 'score': None}

    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model,
        do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir +
        '/.pretrained_model_{}'.format(args.global_rank))
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer(
    ) if args.tokenized_input else tokenizer

    if args.do_train:
        # One preprocessing pipeline per LM mode; the two are configured
        # identically apart from `mode`.
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_image_regions=args.mask_image_regions,
                mode=lm_mode,
                len_vis_input=args.len_vis_input,
                vis_mask_prob=args.vis_mask_prob,
                enable_butd=args.enable_butd,
                region_bbox_file=args.region_bbox_file,
                region_det_file_prefix=args.region_det_file_prefix,
                local_rank=args.local_rank,
                load_vqa_ann=(args.tasks == 'vqa2'))
            for lm_mode in ("s2s", "bi")
        ]

        train_dataset = seq2seq_loader.Img2txtDataset(
            args.src_file,
            args.image_root,
            args.split,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_valid_jpgs=args.file_valid_jpgs,
            bi_uni_pipeline=bi_uni_pipeline,
            use_num_imgs=args.use_num_imgs,
            s2s_prob=args.s2s_prob,
            bi_prob=args.bi_prob,
            enable_butd=args.enable_butd,
            tasks=args.tasks)

        if args.world_size == 1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
        else:
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.train_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=batch_list_to_batch_tensors,
            pin_memory=True)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # Total number of optimizer updates over the whole run.
    t_total = int(
        len(train_dataloader) * args.num_train_epochs * 1. /
        args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    relax_projection = 4 if args.relax_projection else 0
    task_idx_proj = 3 if args.tasks == 'img2txt' else 0
    mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[PAD]"])  # index in BERT vocab: 103, 102, 0

    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        assert not args.scst, 'must init from maximum likelihood training'
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size,
            relax_projection=relax_projection,
            config_path=args.config_path,
            task_idx=task_idx_proj,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            cache_dir=args.output_dir +
            '/.pretrained_model_{}'.format(args.global_rank),
            drop_prob=args.drop_prob,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input,
            tasks=args.tasks)
        global_step = 0
    else:
        # A checkpoint in output_dir takes precedence over
        # --model_recover_path.
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(
                os.path.join(args.output_dir,
                             "model.{0}.bin".format(recover_step)))
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total * 1. /
                                     args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path)
            global_step = 0
        if not args.scst:
            model = BertForPreTrainingLossMask.from_pretrained(
                args.bert_model,
                state_dict=model_recover,
                num_labels=cls_num_labels,
                type_vocab_size=type_vocab_size,
                relax_projection=relax_projection,
                config_path=args.config_path,
                task_idx=task_idx_proj,
                max_position_embeddings=args.max_position_embeddings,
                label_smoothing=args.label_smoothing,
                fp32_embedding=args.fp32_embedding,
                cache_dir=args.output_dir +
                '/.pretrained_model_{}'.format(args.global_rank),
                drop_prob=args.drop_prob,
                enable_butd=args.enable_butd,
                len_vis_input=args.len_vis_input,
                tasks=args.tasks)
        else:
            model = BertForSeq2SeqDecoder.from_pretrained(
                args.bert_model,
                max_position_embeddings=args.max_position_embeddings,
                config_path=args.config_path,
                state_dict=model_recover,
                num_labels=cls_num_labels,
                type_vocab_size=type_vocab_size,
                task_idx=task_idx_proj,
                mask_word_id=mask_word_id,
                search_beam_size=1,
                eos_id=eos_word_ids,
                enable_butd=args.enable_butd,
                len_vis_input=args.len_vis_input)

        # The raw checkpoint is no longer needed once the model is built.
        del model_recover
        torch.cuda.empty_cache()

    # deprecated
    # from vlp.resnet import resnet
    # cnn = resnet(args.resnet_model, _num_layers=101, _fixed_block=4, pretrained=True) # no finetuning

    if args.fp16:
        model.half()
        # cnn.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    # cnn.to(device)
    if args.local_rank != -1:
        try:
            # from apex.parallel import DistributedDataParallel as DDP
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
        # cnn = DDP(cnn)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)
        # cnn = DataParallelImbalance(cnn)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    # Parameters whose name contains one of these substrings get no decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        # Honour --weight_decay (its default, 0.01, matches the value that
        # used to be hard-coded here).
        'weight_decay':
        args.weight_decay
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             schedule=args.sche_mode,
                             t_total=t_total)

    if recover_step:
        optim_recover_file = os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step))
        # The optimizer checkpoint is not written by the save step below
        # (that line is disabled), so it may legitimately be missing; resume
        # with a fresh optimizer state in that case instead of crashing.
        if os.path.isfile(optim_recover_file):
            logger.info("***** Recover optimizer: %d *****", recover_step)
            optim_recover = torch.load(optim_recover_file)
            if hasattr(optim_recover, 'state_dict'):
                optim_recover = optim_recover.state_dict()
            optimizer.load_state_dict(optim_recover)
            if args.loss_scale == 0:
                logger.info(
                    "***** Recover optimizer: dynamic_loss_scale *****")
                optimizer.dynamic_loss_scale = True
        else:
            logger.warning(
                "Optimizer checkpoint %s not found; "
                "continuing with a freshly initialized optimizer.",
                optim_recover_file)

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)
        logger.info("  Loader length = %d", len(train_dataloader))

        model.train()
        # Epochs are 1-based; resume right after the recovered epoch.
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch,
                              args.num_train_epochs + 1,
                              desc="Epoch"):
            # Only samplers that expose set_epoch (the DistributedSampler
            # chosen when world_size > 1) can be re-seeded per epoch; the
            # RandomSampler used for world_size == 1 has no such method.
            if args.local_rank >= 0 and hasattr(train_sampler, 'set_epoch'):
                train_sampler.set_epoch(i_epoch - 1)
            iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)')
            nbatches = len(train_dataloader)
            # Per-epoch running histories used for the logged means.
            train_loss = []
            pretext_loss = []
            vqa2_loss = []
            scst_reward = []
            for step, batch in enumerate(iter_bar):
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, img, vis_masked_pos, vis_pe, ans_labels = batch

                if args.fp16:
                    img = img.half()
                    vis_pe = vis_pe.half()

                if args.enable_butd:
                    conv_feats = img.data  # Bx100x2048
                    vis_pe = vis_pe.data
                else:
                    # NOTE(review): `cnn` is not defined in this function
                    # (its construction is commented out above). This branch
                    # is unreachable as long as the enable_butd assertion at
                    # the top of the function is active.
                    conv_feats, _ = cnn(img.data)  # Bx2048x7x7
                    conv_feats = conv_feats.view(conv_feats.size(0),
                                                 conv_feats.size(1),
                                                 -1).permute(0, 2,
                                                             1).contiguous()

                if not args.scst:
                    loss_tuple = model(
                        conv_feats,
                        vis_pe,
                        input_ids,
                        segment_ids,
                        input_mask,
                        lm_label_ids,
                        ans_labels,
                        is_next,
                        masked_pos=masked_pos,
                        masked_weights=masked_weights,
                        task_idx=task_idx,
                        vis_masked_pos=vis_masked_pos,
                        mask_image_regions=args.mask_image_regions,
                        drop_worst_ratio=args.max_drop_worst_ratio
                        if i_epoch > args.drop_after else 0)
                    # No reward outside SCST; log a zero placeholder.
                    mean_reward = loss_tuple[0].new(1).fill_(0)
                else:
                    # scst training
                    model.eval()
                    position_ids = torch.arange(
                        input_ids.size(1),
                        dtype=input_ids.dtype,
                        device=input_ids.device).unsqueeze(0).expand_as(
                            input_ids)
                    input_dummy = input_ids[:, :args.len_vis_input +
                                            2]  # +2 for [CLS] and [SEP]
                    # Zero-filled buffers that receive the decoded tokens.
                    greedy_res = input_ids.new(
                        input_ids.size(0),
                        input_ids.size(1) - args.len_vis_input - 2).fill_(0)
                    gen_result = input_ids.new(
                        input_ids.size(0),
                        input_ids.size(1) - args.len_vis_input - 2).fill_(0)

                    # Greedy decode (no gradients) provides the baseline.
                    with torch.no_grad():
                        greedy_res_raw, _ = model(conv_feats,
                                                  vis_pe,
                                                  input_dummy,
                                                  segment_ids,
                                                  position_ids,
                                                  input_mask,
                                                  task_idx=task_idx,
                                                  sample_mode='greedy')
                        # Copy tokens up to and including the first [SEP];
                        # stop at the first [SEP] or [PAD].
                        for b in range(greedy_res_raw.size(0)):
                            for idx in range(greedy_res_raw.size(1)):
                                if greedy_res_raw[b][idx] not in [
                                        eos_word_ids, pad_word_ids
                                ]:
                                    greedy_res[b][idx] = greedy_res_raw[b][idx]
                                else:
                                    if greedy_res_raw[b][idx] == eos_word_ids:
                                        greedy_res[b][idx] = eos_word_ids
                                    break
                    model.train()
                    # Sampled decode, with log-probabilities for the loss.
                    gen_result_raw, sample_logprobs = model(
                        conv_feats,
                        vis_pe,
                        input_dummy,
                        segment_ids,
                        position_ids,
                        input_mask,
                        task_idx=task_idx,
                        sample_mode='sample')
                    # Same truncation rule as for the greedy output.
                    for b in range(gen_result_raw.size(0)):
                        for idx in range(gen_result_raw.size(1)):
                            if gen_result_raw[b][idx] not in [
                                    eos_word_ids, pad_word_ids
                            ]:
                                gen_result[b][idx] = gen_result_raw[b][idx]
                            else:
                                if gen_result_raw[b][idx] == eos_word_ids:
                                    gen_result[b][idx] = eos_word_ids
                                break

                    # Ground-truth text tokens follow the visual prefix.
                    gt_ids = input_ids[:, args.len_vis_input + 2:]
                    reward = get_self_critical_reward(greedy_res,
                                                      gt_ids, gen_result,
                                                      gt_ids.size(0))
                    reward = torch.from_numpy(reward).float().to(
                        gen_result.device)
                    mean_reward = reward.mean()
                    loss = rl_crit(sample_logprobs, gen_result.data, reward)

                    # Pad to the three-element layout unpacked below.
                    loss_tuple = [
                        loss,
                        loss.new(1).fill_(0.),
                        loss.new(1).fill_(0.)
                    ]

                # disable pretext_loss_deprecated for now
                masked_lm_loss, pretext_loss_deprecated, ans_loss = loss_tuple
                if n_gpu > 1:  # mean() to average on multi-gpu. For dist, this is done through gradient addition.
                    masked_lm_loss = masked_lm_loss.mean()
                    pretext_loss_deprecated = pretext_loss_deprecated.mean()
                    ans_loss = ans_loss.mean()
                loss = masked_lm_loss + pretext_loss_deprecated + ans_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())
                train_loss.append(loss.item())
                pretext_loss.append(pretext_loss_deprecated.item())
                vqa2_loss.append(ans_loss.item())
                scst_reward.append(mean_reward.item())
                if step % 100 == 0:
                    logger.info(
                        "Epoch {}, Iter {}, Loss {:.2f}, Pretext {:.2f}, VQA2 {:.2f}, Mean R {:.3f}\n"
                        .format(i_epoch, step, np.mean(train_loss),
                                np.mean(pretext_loss), np.mean(vqa2_loss),
                                np.mean(scst_reward)))

                if args.enable_visdom:
                    # The first call creates the plot window; later calls
                    # append one point (global iteration, running mean loss).
                    if vis_window['iter'] is None:
                        vis_window['iter'] = vis.line(
                            X=np.tile(
                                np.arange((i_epoch - 1) * nbatches + step,
                                          (i_epoch - 1) * nbatches + step + 1),
                                (1, 1)).T,
                            Y=np.column_stack(
                                (np.asarray([np.mean(train_loss)]), )),
                            opts=dict(title='Training Loss',
                                      xlabel='Training Iteration',
                                      ylabel='Loss',
                                      legend=['total']))
                    else:
                        vis.line(X=np.tile(
                            np.arange((i_epoch - 1) * nbatches + step,
                                      (i_epoch - 1) * nbatches + step + 1),
                            (1, 1)).T,
                                 Y=np.column_stack(
                                     (np.asarray([np.mean(train_loss)]), )),
                                 opts=dict(title='Training Loss',
                                           xlabel='Training Iteration',
                                           ylabel='Loss',
                                           legend=['total']),
                                 win=vis_window['iter'],
                                 update='append')

                # ensure that accumlated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                # Update the weights once per accumulation window.
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step/t_total,
                                      args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Save a trained model
            logger.info(
                "** ** * Saving fine-tuned model and optimizer ** ** * ")
            model_to_save = model.module if hasattr(
                model, 'module') else model  # Only save the model it-self
            output_model_file = os.path.join(args.output_dir,
                                             "model.{0}.bin".format(i_epoch))
            output_optim_file = os.path.join(args.output_dir,
                                             "optim.{0}.bin".format(i_epoch))
            if args.global_rank in (
                    -1, 0):  # save model if the first device or no dist
                # Save a CPU copy so the live model stays on its device.
                torch.save(
                    copy.deepcopy(model_to_save).cpu().state_dict(),
                    output_model_file)
                # torch.save(optimizer.state_dict(), output_optim_file) # disable for now, need to sanitize state and ship everthing back to cpu

            logger.info("***** CUDA.empty_cache() *****")
            torch.cuda.empty_cache()

            # Keep all processes in step before starting the next epoch.
            if args.world_size > 1:
                torch.distributed.barrier()
예제 #2
0
def main():
    """Command-line entry point: fine-tune a BERT-based seq2seq model.

    The function parses the command line, writes the resolved options to
    ``<output_dir>/opt.json``, selects the device (optionally initialising an
    NCCL process group when ``--local_rank`` is set), builds the training
    dataset and data loader, restores the model either from the latest
    ``model.<epoch>.bin`` found in ``output_dir`` or from
    ``--model_recover_path``, builds the optimizer, and runs the training
    loop.  After every epoch the model and optimizer state dicts are saved to
    ``model.<epoch>.bin`` / ``optim.<epoch>.bin`` in ``output_dir`` by the
    first process only.

    Raises:
        AssertionError: if ``--model_recover_path`` does not exist.
        ValueError: if ``--gradient_accumulation_steps`` is < 1, or neither
            ``--do_train`` nor ``--do_eval`` is given.
        ImportError: if a component required for distributed or fp16
            training cannot be imported.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--src_file",
                        default=None,
                        type=str,
                        help="The input data file name.")
    parser.add_argument("--tgt_file",
                        default=None,
                        type=str,
                        help="The output data file name.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")

    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    # Fixed help text: it was a copy of the --do_train description.
    parser.add_argument("--local_debug",
                        action='store_true',
                        help="Pass local_debug=True to the model.")

    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    # Fixed help text: it was a copy of the --learning_rate description.
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="Label smoothing value passed to the model.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()
    assert Path(
        args.model_recover_path).exists(), "--model_recover_path doesn't exist"

    # Expand the '[PT_OUTPUT_DIR]' placeholder from the environment (empty
    # string when the variable is unset).
    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    # Persist the resolved options next to the checkpoints.  The context
    # manager guarantees the file is flushed and closed (the handle used to
    # be leaked).
    with open(os.path.join(args.output_dir, 'opt.json'), 'w') as opt_file:
        json.dump(args.__dict__, opt_file, sort_keys=True, indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # From here on train_batch_size is the per-forward-pass batch size.
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer(
    ) if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        # Release the other processes waiting on the barrier above.
        dist.barrier()

    if args.do_train:
        print("Loading Train Dataset", args.data_dir)
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_a': args.max_len_a,
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_source_words=args.mask_source_words,
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                mode="s2s",
                has_oracle=args.has_sentence_oracle,
                num_qkv=args.num_qkv,
                s2s_special_token=args.s2s_special_token,
                s2s_add_segment=args.s2s_add_segment,
                s2s_share_segment=args.s2s_share_segment,
                pos_shift=args.pos_shift)
        ]
        file_oracle = None
        if args.has_sentence_oracle:
            file_oracle = os.path.join(args.data_dir, 'train.oracle')
        fn_src = os.path.join(args.data_dir,
                              args.src_file if args.src_file else 'train.src')
        fn_tgt = os.path.join(args.data_dir,
                              args.tgt_file if args.tgt_file else 'train.tgt')
        # NOTE(review): `corpus_preprocessors` is not defined in this
        # function; presumably a module-level name -- confirm.
        train_dataset = seq2seq_loader.Seq2SeqDataset(
            fn_src,
            fn_tgt,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=bi_uni_pipeline,
            corpus_preprocessors=corpus_preprocessors)
        train_dataset.initial()
        print(len(train_dataset.ex_list))
        print(train_dataset.batch_size)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # NOTE(review): train_dataloader only exists when --do_train is set, so
    # an eval-only invocation fails on the next statement with NameError.
    t_total = int(
        len(train_dataloader) * args.num_train_epochs /
        args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    # Parenthesised for readability; same value as the original expression.
    type_vocab_size = (6 + (1 if args.s2s_add_segment else 0)
                       ) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            local_debug=args.local_debug)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.output_dir, "model.{0}.bin".format(recover_step)),
                                       map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total /
                                     args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path,
                                       map_location='cpu')
            global_step = 0
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            local_debug=args.local_debug)
    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        model.half()
        if args.fp32_embedding:
            # Convert the three embedding tables back to fp32.
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError as err:
            # Keep the original failure as the cause for easier debugging.
            raise ImportError("DistributedDataParallel") from err
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        model = DataParallelImbalance(model)

    # Prepare optimizer: parameters whose name contains one of the
    # `no_decay` substrings get no weight decay; every other parameter uses
    # --weight_decay (default 0.01).  The rate used to be hard-coded to 0.01,
    # which silently ignored the command-line flag.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        args.weight_decay
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            ) from err

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step)),
                                   map_location='cpu')
        # The checkpoint may hold an optimizer object rather than a plain
        # state dict; normalise to a state dict before loading.
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        model.train()
        # Epochs are 1-based; resume right after the last saved epoch.
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch,
                              int(args.num_train_epochs) + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            iter_bar = tqdm(train_dataloader,
                            desc='Iter (loss=X.XXX)',
                            disable=args.local_rank not in (-1, 0))
            for step, batch in enumerate(iter_bar):
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                if args.has_sentence_oracle:
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, sop_label, oracle_pos, oracle_weights, oracle_labels = batch
                else:
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, sop_label = batch
                    # NOTE(review): per-step debug output, kept as-is.
                    print(sop_label)
                    print(task_idx)
                    oracle_pos, oracle_weights, oracle_labels = None, None, None
                # `sop_label` is passed as the fifth positional argument;
                # `is_next` is unpacked above but not used.
                loss_tuple = model(input_ids,
                                   segment_ids,
                                   input_mask,
                                   lm_label_ids,
                                   sop_label,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   masked_pos_2=oracle_pos,
                                   masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels,
                                   mask_qkv=mask_qkv)
                masked_lm_loss, next_sentence_loss = loss_tuple
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    masked_lm_loss = masked_lm_loss.mean()
                    next_sentence_loss = next_sentence_loss.mean()
                print('mask_lm_loss {}'.format(masked_lm_loss))
                print('next_sentence_loss {}'.format(next_sentence_loss))
                print('----------------------------------------------')
                loss = masked_lm_loss + next_sentence_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                # Only step the optimizer once per accumulation window.
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step/t_total,
                                      args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Save a trained model (single-process run or distributed rank 0 only)
            if (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                logger.info(
                    "** ** * Saving fine-tuned model and optimizer ** ** * ")
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model it-self
                output_model_file = os.path.join(
                    args.output_dir, "model.{0}.bin".format(i_epoch))
                torch.save(model_to_save.state_dict(), output_model_file)
                output_optim_file = os.path.join(
                    args.output_dir, "optim.{0}.bin".format(i_epoch))
                torch.save(optimizer.state_dict(), output_optim_file)

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
예제 #3
0
def main():
    """Run VQA inference for every checkpoint matched by ``--model_recover_path``.

    The function parses the command-line options, builds the BERT tokenizer
    and a single "bi" preprocessing pipeline, and then, for each checkpoint
    file matched by the ``--model_recover_path`` glob pattern:

    * loads the checkpoint into ``BertForPreTrainingLossMask``,
    * predicts one answer per question found in ``--src_file``,
    * writes the predictions as JSON into ``--output_dir``,
    * and, unless ``--split`` is ``test2015``, launches the external
      evaluation script on the written file.

    The process exits with status 1 when no checkpoint matches the pattern
    (including when ``--model_recover_path`` was not given at all).
    """
    parser = argparse.ArgumentParser()

    # General
    parser.add_argument("--bert_model", default="bert-base-cased", type=str,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")

    # For decoding
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument('--seed', type=int, default=123,
                        help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size', type=int, default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--max_tgt_length', type=int, default=20,
                        help="maximum length of target sequence")

    # Others for VLP
    parser.add_argument("--src_file", default='/mnt/dat/COCO/annotations/dataset_coco.json', type=str,
                        help="The input data file name.")
    parser.add_argument("--ref_file", default='pythia/data/v2_mscoco_val2014_annotations.json', type=str,
                        help="The annotation reference file name.")
    parser.add_argument('--dataset', default='coco', type=str,
                        help='coco | flickr30k | cc')
    parser.add_argument('--len_vis_input', type=int, default=100)
    # parser.add_argument('--resnet_model', type=str, default='imagenet_weights/resnet101.pth')
    parser.add_argument('--image_root', type=str, default='/mnt/dat/COCO/images')
    parser.add_argument('--split', type=str, default='val')
    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--enable_butd', action='store_true',
                        help='set to take in region features')
    parser.add_argument('--region_bbox_file', default='coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5', type=str)
    parser.add_argument('--region_det_file_prefix', default='feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval', type=str)
    parser.add_argument("--output_dir",
                        default='tmp',
                        type=str,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument('--file_valid_jpgs', default='', type=str)

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    if args.enable_butd:
        assert(args.len_vis_input == 100)
        # Region feature files are given relative to the image root.
        args.region_bbox_file = os.path.join(args.image_root, args.region_bbox_file)
        args.region_det_file_prefix = os.path.join(args.image_root, args.region_det_file_prefix) if args.dataset in ('cc', 'coco') and args.region_det_file_prefix != '' else ''

    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case)

    args.max_seq_length = args.max_tgt_length + args.len_vis_input + 3 # +3 for 2x[SEP] and [CLS]
    tokenizer.max_len = args.max_seq_length

    # A single bidirectional ("bi") preprocessing step is used for inference.
    bi_uni_pipeline = [seq2seq_loader.Preprocess4Seq2seq(0, 0,
        list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length,
        new_segment_ids=args.new_segment_ids, truncate_config={'max_len_a': args.len_vis_input,
        'max_len_b': args.max_tgt_length, 'trunc_seg': 'b', 'always_truncate_tail': True},
        mode="bi", len_vis_input=args.len_vis_input, enable_butd=args.enable_butd,
        region_bbox_file=args.region_bbox_file, region_det_file_prefix=args.region_det_file_prefix,
        load_vqa_ann=True)]

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    logger.info('Attempting to recover models from: {}'.format(args.model_recover_path))
    # --model_recover_path defaults to None; treat a missing value like a
    # pattern that matches nothing instead of crashing on None.strip().
    recover_pattern = (args.model_recover_path or '').strip()
    model_paths = glob.glob(recover_pattern)
    if not model_paths:
        logger.error('There are no models to recover. The program will exit.')
        sys.exit(1)
    for model_recover_path in model_paths:
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model, state_dict=model_recover, num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size, task_idx=0,
            max_position_embeddings=512, cache_dir=args.output_dir+'/.pretrained_model_{}'.format(-1),
            drop_prob=args.drop_prob, enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input, tasks='vqa2')
        # The state dict is no longer needed once the model has been built.
        del model_recover

        # deprecated
        # from vlp.resnet import resnet
        # cnn = resnet(args.resnet_model, _num_layers=101, _fixed_block=4, pretrained=True) # no finetuning

        if args.fp16:
            model.half()
            # cnn.half()
        model.to(device)
        # cnn.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)
            # cnn = torch.nn.DataParallel(cnn)

        torch.cuda.empty_cache()
        model.eval()
        # cnn.eval()

        # Build the evaluation list: (running index, feature path, question tokens, question id).
        eval_lst = []

        img_dat = np.load(args.src_file, allow_pickle=True)
        img_idx = 0
        # NOTE(review): entry 0 is skipped — presumably a header/metadata record; confirm.
        for i in range(1, img_dat.shape[0]):
            if args.enable_butd:
                src_tk = os.path.join(args.image_root, img_dat[i]['image_name'].split('_')[1],
                    img_dat[i]['feature_path'])
            else:
                # Only pre-extracted region features are supported here.
                raise NotImplementedError
            tgt_tk = tokenizer.tokenize(img_dat[i]['question_str'])
            eval_lst.append((img_idx, src_tk, tgt_tk, img_dat[i]['question_id']))
            img_idx += 1
        input_lines = eval_lst

        next_i = 0
        total_batch = math.ceil(len(input_lines) / args.batch_size)
        predictions = []

        print('start the VQA evaluation...')
        with tqdm(total=total_batch) as pbar:
            while next_i < len(input_lines):
                _chunk = input_lines[next_i:next_i + args.batch_size]
                buf = [(x[1], x[2]) for x in _chunk]
                buf_id = [(x[0], x[3]) for x in _chunk]
                next_i += args.batch_size
                instances = []
                for instance in buf:
                    for proc in bi_uni_pipeline:
                        # The pipeline expects an answer field; a placeholder is supplied.
                        instances.append(proc(instance[:2]+({'answers': ['dummy']},)))
                with torch.no_grad():
                    batch = batch_list_to_batch_tensors(
                        instances)
                    batch = [t.to(device) for t in batch]
                    input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, img, vis_masked_pos, vis_pe, _ = batch

                    if args.fp16:
                        img = img.half()
                        vis_pe = vis_pe.half()

                    if args.enable_butd:
                        conv_feats = img.data # Bx100x2048
                        vis_pe = vis_pe.data
                    else:
                        # Not reachable in practice: building eval_lst above already
                        # raises NotImplementedError when region features are disabled,
                        # and `cnn` is not defined in this function.
                        conv_feats, _ = cnn(img.data) # Bx2048x7x7
                        conv_feats = conv_feats.view(conv_feats.size(0), conv_feats.size(1),
                            -1).permute(0,2,1).contiguous()

                    ans_idx = model(conv_feats, vis_pe, input_ids, segment_ids,
                        input_mask, lm_label_ids, None, is_next, masked_pos=masked_pos,
                        masked_weights=masked_weights, task_idx=task_idx,
                        vis_masked_pos=vis_masked_pos, drop_worst_ratio=0,
                        vqa_inference=True)

                    for ind, (_, ques_id) in enumerate(buf_id):
                        predictions.append({'question_id': ques_id, 'answer': bi_uni_pipeline[0].ans_proc.idx2word(ans_idx[ind])})

                pbar.update(1)

        # Name the results after the checkpoint that was actually loaded (its parent
        # directory and file stem), not after the raw glob pattern, so that several
        # matched checkpoints do not overwrite each other's predictions.
        path_parts = model_recover_path.split('/')
        results_file = os.path.join(args.output_dir, 'vqa2-results-'+path_parts[-2]+'-'+args.split+'-'+path_parts[-1].split('.')[-2]+'.json')
        with open(results_file, 'w') as results_fh:
            json.dump(predictions, results_fh)

        if args.split == 'test2015':
            print('*'*80)
            print('[WARNING] Evaluation unavailable for the test set!\
    \n          Please submit your saved JSON file named\
    \n          `{}`\
    \n          to the VQA 2.0 server:\
    \n          https://evalai.cloudcv.org/web/challenges/challenge-page/163/submission'.format(results_file))
            print('*'*80)
        else:
            import subprocess
            print('Evaluating result file {}'.format(results_file))
            # Launched without waiting; the evaluation runs as a separate process.
            subprocess.Popen(['python', 'pythia/pythia/legacy/eval_model/eval_demo.py', args.ref_file, results_file])
예제 #4
0
def main():
    """Train ``BertForPreTrainingLossMask`` on a ``VisdialDataset``.

    Steps performed, in order:

    1. Read options via ``process_args()`` and sanity-check combinations of
       them (loss type, negatives, boundary tokens, region features).
    2. Create the output directory, dump the options to ``opt.json`` and
       configure file + stdout logging.
    3. Select the device, optionally initialise the distributed process
       group, and seed all random number generators.
    4. Build the tokenizer, the two preprocessing pipelines ("s2s" and "bi"),
       the dataset, sampler and data loader.
    5. Build the model (fresh, or from ``args.model_recover_path``), wrap it
       for multi-GPU / distributed use, and create the optimizer.
    6. Run the training loop with gradient accumulation, periodic logging and
       optional visdom plotting, saving a checkpoint after every epoch on the
       first process.
    """
    args = process_args()
    if args.loss_type == 'mlm':
        assert args.neg_num == 0 and args.multiple_neg == 0
    elif args.loss_type == 'nsp':
        assert int(args.bi_prob) == 1 and args.max_pred == 0 and args.neg_num > 0

    if args.adaptive_weight == 1:
        assert args.neg_num > 1
    if args.add_boundary == 1:
        assert args.inc_full_hist

    if args.world_size > 1:
        print('global_rank: {}, local rank: {}'.format(args.global_rank, args.local_rank))

    # Input format: [CLS] img [SEP] hist [SEP_0] ques [SEP_1] ans [SEP]
    args.max_seq_length = args.len_vis_input + 2 + args.max_len_hist_ques + 2 + args.max_len_ans + 1

    args.mask_image_regions = (args.vis_mask_prob > 0)  # whether to mask out image regions
    args.dist_url = args.dist_url.replace('[PT_OUTPUT_DIR]', args.output_dir)

    # arguments inspection
    assert args.enable_butd, 'only support region attn! featmap attn deprecated'

    if args.enable_butd:
        if args.visdial_v == '1.0':
            assert (args.len_vis_input == 36) or (args.len_vis_input == 0)
        elif args.visdial_v == '0.9':
            if (args.len_vis_input == 100):
                # Region feature files are given relative to the image root.
                args.region_bbox_file = os.path.join(args.image_root, args.region_bbox_file)
                args.region_det_file_prefix = os.path.join(args.image_root,
                                                           args.region_det_file_prefix) if args.dataset in (
                    'cc', 'coco') and args.region_det_file_prefix != '' else ''

    # output config
    os.makedirs(args.output_dir, exist_ok=True)
    # Use a context manager so the options file is flushed and closed right away.
    with open(os.path.join(args.output_dir, 'opt.json'), 'w') as opt_fh:
        json.dump(args.__dict__, opt_fh, sort_keys=True, indent=2)

    logging.basicConfig(
        filename=os.path.join(args.output_dir, args.log_file),
        filemode='w',
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger = logging.getLogger(__name__)
    # Mirror log records to stdout in addition to the log file.
    ch = logging.StreamHandler(sys.stdout)
    ch.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(name)s -   %(message)s'))
    ch.setLevel(logging.INFO)
    logger.addHandler(ch)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl', init_method=args.dist_url,
                                             world_size=args.world_size, rank=args.global_rank)
    logger.info('Arguments: %s\n' % (' '.join(sys.argv[:])))
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    # From here on train_batch_size is the per-step (micro) batch size.
    args.train_batch_size = int(
        args.train_batch_size / args.gradient_accumulation_steps)

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # plotting loss, optional
    if args.enable_visdom:
        import visdom
        vis = visdom.Visdom(port=args.visdom_port, env=args.output_dir)
        vis_window = {'iter': None, 'score': None}

    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank))
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
    assert args.do_train
    logger.info('Max seq length: %d, batch size: %d\n' % (args.max_seq_length, args.train_batch_size))
    # Two preprocessing pipelines that differ only in `mode`: "s2s" and "bi".
    bi_uni_pipeline = [Preprocess4TrainVisdial(args.max_pred, args.mask_prob,
                                               list(tokenizer.vocab.keys()),
                                               tokenizer.convert_tokens_to_ids, args.max_seq_length,
                                               new_segment_ids=args.new_segment_ids,
                                               truncate_config={'len_vis_input': args.len_vis_input,
                                                                'max_len_hist_ques': args.max_len_hist_ques,
                                                                'max_len_ans': args.max_len_ans},
                                               mask_image_regions=args.mask_image_regions,
                                               mode="s2s", vis_mask_prob=args.vis_mask_prob,
                                               region_bbox_file=args.region_bbox_file,
                                               region_det_file_prefix=args.region_det_file_prefix,
                                               image_features_hdfpath=args.image_features_hdfpath,
                                               visdial_v=args.visdial_v, pad_hist=args.pad_hist,
                                               finetune=args.finetune,
                                               only_mask_ans=args.only_mask_ans,
                                               add_boundary=args.add_boundary,
                                               only_qa=args.only_qa),
                       Preprocess4TrainVisdial(args.max_pred, args.mask_prob,
                                               list(tokenizer.vocab.keys()),
                                               tokenizer.convert_tokens_to_ids, args.max_seq_length,
                                               new_segment_ids=args.new_segment_ids,
                                               truncate_config={'len_vis_input': args.len_vis_input,
                                                                'max_len_hist_ques': args.max_len_hist_ques,
                                                                'max_len_ans': args.max_len_ans},
                                               mask_image_regions=args.mask_image_regions,
                                               mode="bi", vis_mask_prob=args.vis_mask_prob,
                                               region_bbox_file=args.region_bbox_file,
                                               region_det_file_prefix=args.region_det_file_prefix,
                                               image_features_hdfpath=args.image_features_hdfpath,
                                               visdial_v=args.visdial_v, pad_hist=args.pad_hist,
                                               finetune=args.finetune,
                                               only_mask_ans=args.only_mask_ans,
                                               add_boundary=args.add_boundary,
                                               only_qa=args.only_qa)]

    train_dataset = VisdialDataset(
        args.src_file, args.train_batch_size, data_tokenizer, use_num_imgs=args.use_num_imgs,
        bi_uni_pipeline=bi_uni_pipeline, s2s_prob=args.s2s_prob, bi_prob=args.bi_prob,
        is_train=args.do_train, neg_num=args.neg_num, inc_gt_rel=args.inc_gt_rel,
        inc_full_hist=args.inc_full_hist, just_for_pretrain=args.just_for_pretrain, sub_sample=args.sub_sample)

    if args.world_size == 1:
        train_sampler = RandomSampler(train_dataset, replacement=False)
    else:
        train_sampler = DistributedSampler(train_dataset)

    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.train_batch_size, sampler=train_sampler,
                                                   num_workers=args.num_workers,
                                                   collate_fn=batch_list_to_batch_tensors, pin_memory=True)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # t_total is the number of optimizer updates over the whole run.
    t_total = int(len(train_dataloader) * args.num_train_epochs * 1. /
                  args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    relax_projection = 4 if args.relax_projection else 0
    task_idx_proj = 3 if args.tasks == 'img2txt' else 0
    mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[PAD]"])  # index in BERT vocab: 103, 102, 0

    if (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        assert args.scst == False, 'must init from maximum likelihood training'
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model, state_dict=_state_dict, num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size, relax_projection=relax_projection,
            config_path=args.config_path, task_idx=task_idx_proj,
            max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank),
            drop_prob=args.drop_prob, enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input, visdial_v=args.visdial_v, loss_type=args.loss_type,
            neg_num=args.neg_num, adaptive_weight=args.adaptive_weight, add_attn_fuse=args.add_attn_fuse,
            no_h0=args.no_h0, no_vision=args.no_vision)
        global_step = 0
    else:
        # Always load here: a path that is set but unusable (e.g. an empty string)
        # should fail inside torch.load with a file error rather than leave
        # `model_recover` unbound for the call below.
        logger.info("***** Recover model: %s *****",
                    args.model_recover_path)
        model_recover = torch.load(args.model_recover_path)
        global_step = 0

        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model, state_dict=model_recover, num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size, relax_projection=relax_projection,
            config_path=args.config_path, task_idx=task_idx_proj,
            max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank),
            drop_prob=args.drop_prob, enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input, visdial_v=args.visdial_v, loss_type=args.loss_type,
            neg_num=args.neg_num, adaptive_weight=args.adaptive_weight, add_attn_fuse=args.add_attn_fuse,
            no_h0=args.no_h0, no_vision=args.no_vision)

        # The state dict is no longer needed once the model has been built.
        del model_recover
        torch.cuda.empty_cache()

    if args.fp16:
        model.half()
        # cnn.half()
        if args.fp32_embedding:
            # Keep the embedding tables in full precision.
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
        model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
    elif n_gpu > 1:
        model = DataParallelImbalance(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    # Parameters whose names contain any of these substrings get no weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(
            nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(
            nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        # loss_scale == 0 selects dynamic loss scaling; any other value is used as a static scale.
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(
                optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(
                optimizer, static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             schedule=args.sche_mode,
                             t_total=t_total)


    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)
        logger.info("  Loader length = %d", len(train_dataloader))

        model.train()

        start_epoch = 1
        logger.info("Begin training from epoch = %d", start_epoch)
        t0 = time.time()
        for i_epoch in trange(start_epoch, args.num_train_epochs + 1, desc="Epoch"):
            if args.multiple_neg and i_epoch > 1:
                # Rebuild the dataset (and its sampler/loader) for every epoch after the first.
                train_dataset = VisdialDataset(
                    args.src_file, args.train_batch_size, data_tokenizer, use_num_imgs=args.use_num_imgs,
                    bi_uni_pipeline=bi_uni_pipeline, s2s_prob=args.s2s_prob, bi_prob=args.bi_prob,
                    is_train=args.do_train, neg_num=args.neg_num, inc_gt_rel=args.inc_gt_rel,
                    inc_full_hist=args.inc_full_hist, just_for_pretrain=args.just_for_pretrain, sub_sample=args.sub_sample)

                if args.world_size == 1:
                    train_sampler = RandomSampler(train_dataset, replacement=False)
                else:
                    train_sampler = DistributedSampler(train_dataset)

                train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                               batch_size=args.train_batch_size, sampler=train_sampler,
                                                               num_workers=args.num_workers,
                                                               collate_fn=batch_list_to_batch_tensors, pin_memory=True)
            # Only DistributedSampler has set_epoch; with world_size == 1 a RandomSampler
            # is used even when local_rank >= 0, and calling set_epoch on it would fail.
            if args.local_rank >= 0 and isinstance(train_sampler, DistributedSampler):
                train_sampler.set_epoch(i_epoch - 1)
            iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)')
            nbatches = len(train_dataloader)
            # Per-epoch running records of the loss components (used for averaged logging).
            losses = []
            pretext_loss = []
            mlm_losses = []
            nsp_losses = []
            for step, batch in enumerate(iter_bar):
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, is_next, \
                task_idx, vis_masked_pos, img, vis_pe = batch

                if args.fp16:
                    img = img.half()
                    vis_pe = vis_pe.half()

                # enable_butd is asserted above, so conv_feats is always bound here.
                if args.enable_butd:
                    conv_feats = img.data  # Bx100x2048
                    vis_pe = vis_pe.data

                loss_tuple = model(conv_feats, vis_pe, input_ids, segment_ids,
                                   input_mask, lm_label_ids, is_next, masked_pos=masked_pos,
                                   masked_weights=masked_weights, task_idx=task_idx,
                                   vis_masked_pos=vis_masked_pos, mask_image_regions=args.mask_image_regions,
                                   drop_worst_ratio=args.max_drop_worst_ratio if i_epoch > args.drop_after else 0)

                # disable pretext_loss_deprecated for now
                masked_lm_loss, pretext_loss_deprecated, nsp_loss = loss_tuple
                if n_gpu > 1:  # mean() to average on multi-gpu. For dist, this is done through gradient addition.
                    masked_lm_loss = masked_lm_loss.mean()
                    pretext_loss_deprecated = pretext_loss_deprecated.mean()
                    nsp_loss = nsp_loss.mean()
                loss = masked_lm_loss + pretext_loss_deprecated + nsp_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())
                losses.append(loss.item())
                mlm_losses.append(masked_lm_loss.item())
                pretext_loss.append(pretext_loss_deprecated.item())
                nsp_losses.append(nsp_loss.item())

                # Log roughly ten times per epoch (every step for very short epochs).
                if step % max(1, nbatches // 10) == 0:
                    logger.info(
                        "Epoch {}, Iter {}, Loss {:.4f}, MLM {:.4f}, NSP {:.4f}, Elapse time {:.2f}\n".format(
                            i_epoch, step, np.mean(losses), np.mean(mlm_losses), np.mean(nsp_losses), time.time() - t0))

                if args.enable_visdom:
                    # The first call creates the plot window; later calls append to it.
                    if vis_window['iter'] is None:
                        vis_window['iter'] = vis.line(
                            X=np.tile(np.arange((i_epoch - 1) * nbatches + step,
                                                (i_epoch - 1) * nbatches + step + 1), (1, 1)).T,
                            Y=np.column_stack((np.asarray([np.mean(losses)]),)),
                            opts=dict(title='Training Loss',
                                      xlabel='Training Iteration',
                                      ylabel='Loss',
                                      legend=['total'])
                        )
                    else:
                        vis.line(
                            X=np.tile(np.arange((i_epoch - 1) * nbatches + step,
                                                (i_epoch - 1) * nbatches + step + 1), (1, 1)).T,
                            Y=np.column_stack((np.asarray([np.mean(losses)]),)),
                            opts=dict(title='Training Loss',
                                      xlabel='Training Iteration',
                                      ylabel='Loss',
                                      legend=['total']),
                            win=vis_window['iter'],
                            update='append'
                        )

                # ensure that accumlated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                # Apply an optimizer update once every gradient_accumulation_steps micro-batches.
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                                   warmup_linear(global_step / t_total,
                                                 args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Save a trained model
            logger.info(
                "** ** * Saving fine-tuned model and optimizer ** ** * ")
            model_to_save = model.module if hasattr(
                model, 'module') else model  # Only save the model it-self
            # The checkpoint name carries the epoch number and the epoch's mean loss.
            output_model_file = os.path.join(
                args.output_dir, "model.%d.%.3f.bin" % (i_epoch, np.mean(losses)))
            if args.global_rank in (-1, 0):  # save model if the first device or no dist
                torch.save(copy.deepcopy(model_to_save).cpu().state_dict(), output_model_file)
                logger.info("Save model to %s", output_model_file)
            logger.info("Finish training epoch %d, avg loss: %.2f and takes %.2f seconds" % (
                i_epoch, np.mean(losses), time.time() - t0))
            logger.info("***** CUDA.empty_cache() *****")
            torch.cuda.empty_cache()

            # Keep all processes in step while the first one writes the checkpoint.
            if args.world_size > 1:
                torch.distributed.barrier()
예제 #5
0
def main():
    """Jointly fine-tune a UniLM seq2seq model and a GSM topic model.

    The function parses the command line, dumps the parsed options to
    ``<output_dir>/opt.json``, builds the seq2seq training data pipeline,
    restores the UniLM weights and the topic-model weights from their
    checkpoints, builds the optimizer and runs the training loop.

    After every epoch the process with ``local_rank == -1`` (or distributed
    rank 0) writes ``unilm.<epoch>.bin`` and ``topic.<epoch>.ckpt`` into
    ``output_dir``. When training finishes, the (smoothed) topic loss curve
    and the UniLM loss curve are written there as PNG files.

    When ``--topic_model`` is true only the topic-model parameters are handed
    to the optimizer and only the topic loss is back-propagated.

    Raises:
        AssertionError: if ``--model_recover_path`` does not exist.
        ValueError: if ``--gradient_accumulation_steps`` is smaller than 1, or
            if neither ``--do_train`` nor ``--do_eval`` is given.
    """

    def _str2bool(value):
        """Convert a command-line token to a bool.

        ``type=bool`` would turn every non-empty string (including "False")
        into ``True``, so the accepted spellings are matched explicitly.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(
            "Boolean value expected, got {!r}".format(value))

    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--src_file",
                        default=None,
                        type=str,
                        help="The input data file name.")
    parser.add_argument("--topic_model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining topic model.")
    parser.add_argument("--topic_model_dict_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining topic model.")
    parser.add_argument("--tgt_file",
                        default=None,
                        type=str,
                        help="The output data file name.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")
    parser.add_argument('--topic_mode',
                        default=1,
                        type=float,
                        help="1:idea1 1.1:idea1_wo_theta 2:idea2 ")
    # Parsed with _str2bool so that "--topic_model False" really means False.
    parser.add_argument('--topic_model',
                        default=False,
                        type=_str2bool,
                        help="if only use topic model")
    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=192,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument(
        "--train_batch_size",
        default=32,
        type=int,
        help="Total batch size for training.")  # batch_size = batch_size/n_gpus
    parser.add_argument("--eval_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="Label smoothing value passed to the model.")
    # NOTE(review): --weight_decay is parsed but the optimizer groups below
    # use a literal 0.01; confirm whether the flag is meant to be honoured.
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=30,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()
    assert Path(
        args.model_recover_path).exists(), "--model_recover_path doesn't exist"

    # Expand the "[PT_OUTPUT_DIR]" placeholder from the environment.
    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    # Record the effective options next to the checkpoints; the context
    # manager guarantees the file is flushed and closed.
    with open(os.path.join(args.output_dir, 'opt.json'), 'w') as opt_file:
        json.dump(args.__dict__, opt_file, sort_keys=True, indent=2)
    print("args.local_rank", args.local_rank)
    print("args.no_cuda", args.no_cuda)
    if args.local_rank == -1 or args.no_cuda:
        # Single-process mode: use every visible GPU (or the CPU).
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
        print("n_gpu_1", n_gpu)
    else:
        # Distributed mode: one process per GPU.
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
        print("n_gpu_1", n_gpu)
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))
    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))
    # From here on train_batch_size is the per-forward-pass batch size.
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer(
    ) if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()
    # NOTE(review): train_dataset / train_dataloader are only created when
    # --do_train is set, but they are used unconditionally further down, so
    # running with --do_eval alone fails with a NameError.
    if args.do_train:
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_a': args.max_len_a,
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_source_words=args.mask_source_words,
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                mode="s2s",
                has_oracle=args.has_sentence_oracle,
                num_qkv=args.num_qkv,
                s2s_special_token=args.s2s_special_token,
                s2s_add_segment=args.s2s_add_segment,
                s2s_share_segment=args.s2s_share_segment,
                pos_shift=args.pos_shift)
        ]
        file_oracle = None
        if args.has_sentence_oracle:
            file_oracle = os.path.join(args.data_dir, 'train.oracle')
        fn_src = os.path.join(args.data_dir,
                              args.src_file if args.src_file else 'train.src')
        fn_tgt = os.path.join(args.data_dir,
                              args.tgt_file if args.tgt_file else 'train.tgt')
        train_dataset = seq2seq_loader.Seq2SeqDataset(
            fn_src,
            fn_tgt,
            args.data_dir,
            args.topic_model_dict_path,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # Total number of optimizer updates over the whole run.
    t_total = int(
        len(train_dataloader) * args.num_train_epochs /
        args.gradient_accumulation_steps)
    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    # NOTE(review): recovery below reads "model.<n>.bin" / "optim.<n>.bin",
    # whereas this function saves "unilm.<epoch>.bin" and never writes an
    # optimizer file; confirm what _get_max_epoch_model looks for.
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = (6 + (1 if args.s2s_add_segment else 0)
                       ) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # if the state dict is {}, the parameters are randomly initialized
        # if the state dict is None, the parameters are initialized with bert-init
        model_state = {} if args.from_scratch else None
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_state = torch.load(os.path.join(
                args.output_dir, "model.{0}.bin".format(recover_step)),
                                     map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total /
                                     args.num_train_epochs)
        elif args.model_recover_path:  # the usual entry point
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_state = torch.load(args.model_recover_path,
                                     map_location='cpu')
            global_step = 0
    # Both initialisation paths differ only in the state dict they pass.
    unilm = BertForPreTrainingLossMask.from_pretrained(
        args.bert_model,
        state_dict=model_state,
        num_labels=cls_num_labels,
        num_rel=0,
        type_vocab_size=type_vocab_size,
        config_path=args.config_path,
        task_idx=3,
        num_sentlvl_labels=num_sentlvl_labels,
        max_position_embeddings=args.max_position_embeddings,
        label_smoothing=args.label_smoothing,
        fp32_embedding=args.fp32_embedding,
        relax_projection=relax_projection,
        new_pos_ids=args.new_pos_ids,
        ffn_type=args.ffn_type,
        hidden_dropout_prob=args.hidden_dropout_prob,
        attention_probs_dropout_prob=args.attention_probs_dropout_prob,
        num_qkv=args.num_qkv,
        seg_emb=args.seg_emb)
    # 1. Initialise the topic model from its checkpoint. Loaded onto the CPU
    # like the other checkpoints; it is moved to `device` below.
    gsm = GSM(train_dataset.vocabsize)
    gsm_checkpoint = torch.load(args.topic_model_recover_path,
                                map_location='cpu')
    gsm.load_state_dict(gsm_checkpoint["net"])

    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        unilm.half()
        gsm.half()
        if args.fp32_embedding:
            unilm.bert.embeddings.word_embeddings.float()
            unilm.bert.embeddings.position_embeddings.float()
            unilm.bert.embeddings.token_type_embeddings.float()
    unilm.to(device)
    gsm.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        # NOTE(review): only unilm is wrapped for distributed training; gsm
        # is left unwrapped here - confirm this is intended.
        unilm = DDP(unilm,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        unilm = DataParallelImbalance(unilm)
        gsm = DataParallelImbalance(gsm)
    # Prepare optimizer
    total = 0
    param_optimizer = list(unilm.named_parameters())
    param_optimizer_topic = list(gsm.named_parameters())
    # Count UniLM parameters; of those whose name contains "idea", only the
    # ones that also contain "11" and "idea2" are counted.
    for name, parameters in unilm.named_parameters():
        if "idea" in name:
            if "11" in name and "idea2" in name:
                total += np.prod(parameters.size())
        else:
            total += np.prod(parameters.size())

    print("gsm have {} paramerters in total".format(
        sum(x.numel() for x in gsm.parameters())))
    print("Number of parameter: %.6fM" % (total / 1e6))
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # Parameter groups carry a custom 'topic' flag so the learning-rate
    # schedule applied in the fp16 path can skip the topic-model group.
    topic_group = {
        'params': [p for n, p in param_optimizer_topic],
        'weight_decay': 0.0,
        'lr': 1e-3,
        'topic': True
    }
    if not args.topic_model:
        # Some UniLM parameters use weight decay, the rest do not.
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.01,
            'topic': False
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
            'topic': False
        }, topic_group]
    else:
        optimizer_grouped_parameters = [topic_group]
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)
    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step)),
                                   map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        unilm.train()
        gsm.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        print("000000", args.local_rank, start_epoch,
              int(args.num_train_epochs) + 1)
        # Per-step loss histories, plotted after training.
        topicloss_lst = []
        unilmloss_lst = []
        for i_epoch in trange(start_epoch,
                              int(args.num_train_epochs) + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):

            # Per-epoch accumulators for the perplexity report.
            loss_sum = 0.0
            ppx_sum = 0.0
            word_count = 0.0
            doc_count = 0.0
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            iter_bar = tqdm(train_dataloader,
                            desc='Iter (loss=X.XXX)',
                            disable=args.local_rank not in (-1, 0))
            for step, batch in enumerate(iter_bar):
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                if args.has_sentence_oracle:
                    # NOTE(review): `bows` is not unpacked in this branch, so
                    # the gsm(bows) call below cannot work on the first step
                    # when --has_sentence_oracle is set.
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, oracle_pos, oracle_weights, oracle_labels = batch
                else:  # the batch additionally carries the bag-of-words input
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, bows = batch
                    oracle_pos, oracle_weights, oracle_labels = None, None, None
                p_x, mus, log_vars, theta, beta, topic_embedding = gsm(bows)
                if not args.topic_model:
                    loss_tuple = unilm(input_ids,
                                       theta,
                                       beta,
                                       topic_embedding,
                                       args.topic_mode,
                                       segment_ids,
                                       input_mask,
                                       lm_label_ids,
                                       is_next,
                                       masked_pos=masked_pos,
                                       masked_weights=masked_weights,
                                       task_idx=task_idx,
                                       masked_pos_2=oracle_pos,
                                       masked_weights_2=oracle_weights,
                                       masked_labels_2=oracle_labels,
                                       mask_qkv=mask_qkv)
                    masked_lm_loss, next_sentence_loss = loss_tuple

                ## topic loss
                logsoftmax = torch.log(p_x + 1e-10)
                # bows * logsoftmax is [batch_size, |V|]; torch.sum adds up
                # every entry here, it could also reduce a single dimension.
                rec_loss = -1.0 * torch.sum(bows * logsoftmax)
                rec_loss_per = -1.0 * torch.sum(bows * logsoftmax, dim=1)
                rec_loss_per = rec_loss_per.cpu().detach().numpy()

                kl_div = -0.5 * torch.sum(1 + log_vars - mus.pow(2) -
                                          log_vars.exp())
                loss_topic = rec_loss + kl_div
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    loss_topic = loss_topic.mean()
                    if not args.topic_model:
                        masked_lm_loss = masked_lm_loss.mean()
                        next_sentence_loss = next_sentence_loss.mean()
                if not args.topic_model:
                    loss_unilm = masked_lm_loss + next_sentence_loss

                # cal perplexity
                loss_sum += loss_topic.item()
                # Word count of every document in the batch, obtained with a
                # single reduction and one device-to-host copy instead of one
                # copy per document.
                word_count_np = torch.sum(bows, dim=1).cpu().numpy()
                # Accumulate in float64 so fp16 inputs do not lose precision.
                word_count += word_count_np.sum(dtype=np.float64)
                doc_count += len(bows)
                batch_ppl = np.sum(np.true_divide(rec_loss_per, word_count_np))
                ppx_sum += batch_ppl
                topicloss_lst.append(loss_topic.item() / len(bows))
                if not args.topic_model:
                    unilmloss_lst.append(loss_unilm.item())
                # topic_loss end
                if not args.topic_model:
                    loss = loss_unilm + loss_topic
                else:
                    loss = loss_topic
                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step/t_total,
                                      args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses;
                        # the topic-model group keeps its own learning rate
                        for param_group in optimizer.param_groups:
                            if not param_group['topic']:
                                param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
                if not args.topic_model:
                    iter_bar.set_description(
                        'Iter (loss_unilm=%5.3f),Iter (ppl=%5.3f)' %
                        (loss_unilm.item(), batch_ppl))
                else:
                    iter_bar.set_description(
                        'Iter (loss_topic=%5.3f), (ppl=%5.3f)' %
                        (loss_topic.item(), batch_ppl))
            # End-of-epoch perplexity report.
            ppx_word = np.exp(loss_sum / word_count)
            ppx_document = np.exp(ppx_sum / doc_count)
            print("********")
            print("word_count", word_count)
            print("ppx_word", ppx_word)
            print("ppx_document", ppx_document)
            # Save a trained model (first process only).
            if (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                # save unilm model
                logger.info(
                    "** ** * Saving fine-tuned model and optimizer ** ** * ")
                unilm_model_to_save = unilm.module if hasattr(
                    unilm, 'module') else unilm  # Only save the model it-self
                output_unilm_model_file = os.path.join(
                    args.output_dir, "unilm.{0}.bin".format(i_epoch))
                torch.save(unilm_model_to_save.state_dict(),
                           output_unilm_model_file)
                # save topic model
                # NOTE(review): the topic checkpoint is loaded above through
                # checkpoint["net"] but written here as a bare state_dict;
                # confirm the consumers of topic.<epoch>.ckpt expect that.
                logger.info(
                    "** ** * Saving topic model and optimizer ** ** * ")
                topic_model_to_save = gsm.module if hasattr(
                    gsm, 'module') else gsm  # Only save the model it-self
                output_topic_model_file = os.path.join(
                    args.output_dir, "topic.{0}.ckpt".format(i_epoch))
                torch.save(topic_model_to_save.state_dict(),
                           output_topic_model_file)

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
        # Plot the per-step loss histories collected during training.
        smth_pts = smooth_curve(topicloss_lst)
        # plt.plot(range(len(topicloss_lst)), topicloss_lst)
        plt.plot(range(len(smth_pts)), smth_pts)
        plt.xlabel('epochs')
        plt.title('Topic Model Train Loss')
        plt.savefig(args.output_dir + '/topic_loss.png')
        plt.cla()
        plt.plot(range(len(unilmloss_lst)), unilmloss_lst)
        plt.xlabel('epochs')
        plt.title('Unilm Train Loss')
        plt.savefig(args.output_dir + '/unilm_loss.png')
예제 #6
0
def main():
    """Compute masked-LM perplexity of a VLP BERT checkpoint over a dataset split.

    The function parses the command line, configures logging, device and
    random seeds, builds the tokenizer / preprocessing pipelines / data
    loader, restores a model (from the pre-trained BERT weights, from the
    newest ``model.<epoch>.bin`` found in ``--output_dir``, or from
    ``--model_recover_path``), builds an optimizer (restoring its state when
    an epoch checkpoint was found), and finally runs one pass over the data
    loader in eval mode.  The exponential of the mean per-batch masked-LM
    loss is printed as the perplexity of ``--split``.

    Raises:
        ValueError: if ``--do_train`` is not given (the data loader is only
            built under that flag) or if ``--gradient_accumulation_steps``
            is smaller than 1.
        AssertionError: if the argument combination is unsupported
            (unknown task, region features disabled, invalid scst setup).
        ImportError: if ``--fp16`` is requested but the required fp16
            optimizer pieces cannot be imported.
    """
    parser = argparse.ArgumentParser()

    # General
    parser.add_argument(
        "--bert_model",
        default="bert-base-cased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased.",
    )
    parser.add_argument(
        "--config_path", default=None, type=str, help="Bert config file path."
    )
    parser.add_argument(
        "--output_dir",
        default="tmp",
        type=str,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--log_file",
        default="eval.log",
        type=str,
        help="Name of the log file, created inside --output_dir.",
    )
    parser.add_argument(
        "--model_recover_path",
        default=None,
        type=str,
        help="The file of fine-tuned pretraining model.",
    )
    parser.add_argument(
        "--do_train",
        action="store_true",
        help="Whether to run training. This should ALWAYS be set to True.",
    )
    parser.add_argument(
        "--do_lower_case",
        action="store_true",
        help="Set this flag if you are using an uncased model.",
    )
    parser.add_argument(
        "--train_batch_size",
        default=64,
        type=int,
        help="Total batch size for training.",
    )
    parser.add_argument(
        "--learning_rate",
        default=3e-5,
        type=float,
        help="The initial learning rate for Adam.",
    )
    parser.add_argument(
        "--label_smoothing",
        default=0,
        type=float,
        help="Label smoothing factor passed to the model.",
    )
    parser.add_argument(
        "--weight_decay",
        default=0.01,
        type=float,
        help="The weight decay rate for Adam.",
    )
    parser.add_argument(
        "--finetune_decay",
        action="store_true",
        help="Weight decay to the original weights.",
    )
    parser.add_argument(
        "--num_train_epochs",
        default=30,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.",
    )
    parser.add_argument(
        "--no_cuda", action="store_true", help="Whether not to use CUDA when available"
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local_rank for distributed training on gpus",
    )
    parser.add_argument(
        "--global_rank",
        type=int,
        default=-1,
        help="global_rank for distributed training on gpus",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="random seed for initialization"
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit float precision instead of 32-bit",
    )
    parser.add_argument(
        "--fp32_embedding",
        action="store_true",
        help="Whether to use 32-bit float precision instead of 16-bit for embeddings",
    )
    parser.add_argument(
        "--loss_scale",
        type=float,
        default=0,
        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n",
    )
    parser.add_argument(
        "--amp", action="store_true", help="Whether to use amp for fp16"
    )
    parser.add_argument(
        "--from_scratch",
        action="store_true",
        help="Initialize parameters with random values (i.e., training from scratch).",
    )
    parser.add_argument(
        "--new_segment_ids",
        action="store_true",
        help="Use new segment ids for bi-uni-directional LM.",
    )
    parser.add_argument(
        "--tokenized_input", action="store_true", help="Whether the input is tokenized."
    )
    parser.add_argument(
        "--len_vis_input",
        type=int,
        default=100,
        help="The length of visual token input",
    )
    parser.add_argument(
        "--max_len_b",
        type=int,
        default=20,
        help="Truncate_config: maximum length of segment B.",
    )
    parser.add_argument(
        "--trunc_seg",
        default="b",
        help="Truncate_config: first truncate segment A/B (option: a, b).",
    )
    parser.add_argument(
        "--always_truncate_tail",
        action="store_true",
        help="Truncate_config: Whether we should always truncate tail.",
    )
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help="Number of prediction is sometimes less than max_pred when sequence is short.",
    )
    parser.add_argument(
        "--max_pred", type=int, default=3, help="Max tokens of prediction."
    )
    parser.add_argument(
        "--num_workers",
        default=4,
        type=int,
        help="Number of workers for the data loader.",
    )
    parser.add_argument(
        "--max_position_embeddings",
        type=int,
        default=None,
        help="max position embeddings",
    )

    # Others for VLP
    parser.add_argument(
        "--src_file",
        default=["/mnt/dat/COCO/annotations/dataset_coco.json"],
        type=str,
        nargs="+",
        help="The input data file name.",
    )
    parser.add_argument("--enable_visdom", action="store_true")
    parser.add_argument("--visdom_port", type=int, default=8888)
    # parser.add_argument('--resnet_model', type=str, default='imagenet_weights/resnet101.pth')
    parser.add_argument("--image_root", type=str, default="/mnt/dat/COCO/images")
    parser.add_argument(
        "--dataset", default="coco", type=str, help="coco | flickr30k | cc"
    )
    parser.add_argument("--split", type=str, nargs="+", default=["train", "restval"])

    parser.add_argument(
        "--world_size", default=1, type=int, help="number of distributed processes"
    )
    parser.add_argument(
        "--dist_url",
        default="file://[PT_OUTPUT_DIR]/nonexistent_file",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--file_valid_jpgs",
        default="/mnt/dat/COCO/annotations/coco_valid_jpgs.json",
        type=str,
    )
    parser.add_argument(
        "--sche_mode",
        default="warmup_linear",
        type=str,
        help="warmup_linear | warmup_constant | warmup_cosine",
    )
    parser.add_argument("--drop_prob", default=0.1, type=float)
    parser.add_argument("--use_num_imgs", default=-1, type=int)
    parser.add_argument("--vis_mask_prob", default=0, type=float)
    parser.add_argument("--max_drop_worst_ratio", default=0, type=float)
    parser.add_argument("--drop_after", default=6, type=int)

    parser.add_argument(
        "--s2s_prob",
        default=1,
        type=float,
        help="Percentage of examples that are bi-uni-directional LM (seq2seq).",
    )
    parser.add_argument(
        "--bi_prob",
        default=0,
        type=float,
        help="Percentage of examples that are bidirectional LM.",
    )
    parser.add_argument(
        "--enable_butd", action="store_true", help="set to take in region features"
    )
    parser.add_argument(
        "--region_bbox_file",
        default="coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5",
        type=str,
    )
    parser.add_argument(
        "--region_det_file_prefix",
        default="feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval",
        type=str,
    )
    parser.add_argument("--tasks", default="img2txt", help="img2txt | vqa2")
    parser.add_argument(
        "--relax_projection",
        action="store_true",
        help="Use different projection layers for tasks.",
    )
    parser.add_argument(
        "--scst", action="store_true", help="Self-critical sequence training"
    )

    args = parser.parse_args()

    print("global_rank: {}, local rank: {}".format(args.global_rank, args.local_rank))

    args.max_seq_length = (
        args.max_len_b + args.len_vis_input + 3
    )  # +3 for 2x[SEP] and [CLS]
    args.mask_image_regions = (
        args.vis_mask_prob > 0
    )  # whether to mask out image regions
    args.dist_url = args.dist_url.replace("[PT_OUTPUT_DIR]", args.output_dir)

    # arguments inspection
    if not args.do_train:
        # The dataset and data loader are only built under --do_train, and
        # everything downstream (t_total, the evaluation loop) needs them.
        # Fail here with a clear message instead of a NameError later on.
        raise ValueError(
            "--do_train must be set: the data loader used for evaluation is "
            "only built when this flag is given."
        )
    assert args.tasks in ("img2txt", "vqa2")
    assert args.enable_butd, "only support region attn! featmap attn deprecated"
    assert (not args.scst) or args.dataset == "coco", "scst support on coco only!"
    if args.scst:
        assert args.max_pred == 0 and args.mask_prob == 0, "no mask for scst!"

    if args.enable_butd:
        assert args.len_vis_input == 100
        args.region_bbox_file = os.path.join(args.image_root, args.region_bbox_file)
        args.region_det_file_prefix = (
            os.path.join(args.image_root, args.region_det_file_prefix)
            if args.dataset in ("cc", "coco") and args.region_det_file_prefix != ""
            else ""
        )

    # output config
    os.makedirs(args.output_dir, exist_ok=True)
    # Use a context manager so the options file is flushed and closed.
    with open(os.path.join(args.output_dir, "eval_opt.json"), "w") as opt_file:
        json.dump(args.__dict__, opt_file, sort_keys=True, indent=2)

    logging.basicConfig(
        filename=os.path.join(args.output_dir, args.log_file),
        filemode="w",
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger = logging.getLogger(__name__)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        )
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(
            backend="nccl",
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.global_rank,
        )
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
            device, n_gpu, bool(args.local_rank != -1), args.fp16
        )
    )

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                args.gradient_accumulation_steps
            )
        )

    # Per-step batch size; the "total" batch size is split across accumulation steps.
    args.train_batch_size = int(
        args.train_batch_size / args.gradient_accumulation_steps
    )

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # plotting loss, optional
    if args.enable_visdom:
        import visdom

        vis = visdom.Visdom(port=args.visdom_port, env=args.output_dir)
        vis_window = {"iter": None, "score": None}

    # preprocessing/data loader
    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model,
        do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + "/.pretrained_model_{}".format(args.global_rank),
    )
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer

    if args.do_train:
        # Two preprocessing pipelines that differ only in `mode`: "s2s" and "bi".
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    "max_len_b": args.max_len_b,
                    "trunc_seg": args.trunc_seg,
                    "always_truncate_tail": args.always_truncate_tail,
                },
                mask_image_regions=args.mask_image_regions,
                mode="s2s",
                len_vis_input=args.len_vis_input,
                vis_mask_prob=args.vis_mask_prob,
                enable_butd=args.enable_butd,
                region_bbox_file=args.region_bbox_file,
                region_det_file_prefix=args.region_det_file_prefix,
                local_rank=args.local_rank,
                load_vqa_ann=(args.tasks == "vqa2"),
            )
        ]
        bi_uni_pipeline.append(
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    "max_len_b": args.max_len_b,
                    "trunc_seg": args.trunc_seg,
                    "always_truncate_tail": args.always_truncate_tail,
                },
                mask_image_regions=args.mask_image_regions,
                mode="bi",
                len_vis_input=args.len_vis_input,
                vis_mask_prob=args.vis_mask_prob,
                enable_butd=args.enable_butd,
                region_bbox_file=args.region_bbox_file,
                region_det_file_prefix=args.region_det_file_prefix,
                local_rank=args.local_rank,
                load_vqa_ann=(args.tasks == "vqa2"),
            )
        )

        train_dataset = seq2seq_loader.Img2txtDataset(
            args.src_file,
            args.image_root,
            args.split,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_valid_jpgs=args.file_valid_jpgs,
            bi_uni_pipeline=bi_uni_pipeline,
            use_num_imgs=args.use_num_imgs,
            s2s_prob=args.s2s_prob,
            bi_prob=args.bi_prob,
            enable_butd=args.enable_butd,
            tasks=args.tasks,
        )

        if args.world_size == 1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
        else:
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.train_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=batch_list_to_batch_tensors,
            pin_memory=True,
        )

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    t_total = int(
        len(train_dataloader)
        * args.num_train_epochs
        * 1.0
        / args.gradient_accumulation_steps
    )

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp

        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    relax_projection = 4 if args.relax_projection else 0
    task_idx_proj = 3 if args.tasks == "img2txt" else 0
    mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[PAD]"]
    )  # index in BERT vocab: 103, 102, 0

    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        assert not args.scst, "must init from maximum likelihood training"
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size,
            relax_projection=relax_projection,
            config_path=args.config_path,
            task_idx=task_idx_proj,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            cache_dir=args.output_dir
            + "/.pretrained_model_{}".format(args.global_rank),
            drop_prob=args.drop_prob,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input,
            tasks=args.tasks,
        )
        global_step = 0
    else:
        # A checkpoint found in output_dir takes precedence over --model_recover_path.
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(
                os.path.join(args.output_dir, "model.{0}.bin".format(recover_step))
            )
            # recover_step == number of epochs
            global_step = math.floor(
                recover_step * t_total * 1.0 / args.num_train_epochs
            )
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****", args.model_recover_path)
            model_recover = torch.load(args.model_recover_path)
            global_step = 0
        if not args.scst:
            model = BertForPreTrainingLossMask.from_pretrained(
                args.bert_model,
                state_dict=model_recover,
                num_labels=cls_num_labels,
                type_vocab_size=type_vocab_size,
                relax_projection=relax_projection,
                config_path=args.config_path,
                task_idx=task_idx_proj,
                max_position_embeddings=args.max_position_embeddings,
                label_smoothing=args.label_smoothing,
                fp32_embedding=args.fp32_embedding,
                cache_dir=args.output_dir
                + "/.pretrained_model_{}".format(args.global_rank),
                drop_prob=args.drop_prob,
                enable_butd=args.enable_butd,
                len_vis_input=args.len_vis_input,
                tasks=args.tasks,
            )
        else:
            model = BertForSeq2SeqDecoder.from_pretrained(
                args.bert_model,
                max_position_embeddings=args.max_position_embeddings,
                config_path=args.config_path,
                state_dict=model_recover,
                num_labels=cls_num_labels,
                type_vocab_size=type_vocab_size,
                task_idx=task_idx_proj,
                mask_word_id=mask_word_id,
                search_beam_size=1,
                eos_id=eos_word_ids,
                mode="s2s",
                enable_butd=args.enable_butd,
                len_vis_input=args.len_vis_input,
            )

        # Drop the loaded state dict before moving the model to the device.
        del model_recover
        torch.cuda.empty_cache()

    # deprecated
    # from vlp.resnet import resnet
    # cnn = resnet(args.resnet_model, _num_layers=101, _fixed_block=4, pretrained=True) # no finetuning

    if args.fp16:
        model.half()
        # cnn.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    # cnn.to(device)
    if args.local_rank != -1:
        try:
            # from apex.parallel import DistributedDataParallel as DDP
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )
        # cnn = DDP(cnn)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)
        # cnn = DataParallelImbalance(cnn)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    # NOTE(review): the decay rate is hard-coded to 0.01 here; --weight_decay
    # is parsed above but never read in this function.
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.01,
        },
        {
            "params": [
                p for n, p in param_optimizer if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            bias_correction=False,
            max_grad_norm=1.0,
        )
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(
                optimizer, static_loss_scale=args.loss_scale
            )
    else:
        optimizer = BertAdam(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            warmup=args.warmup_proportion,
            schedule=args.sche_mode,
            t_total=t_total,
        )

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(
            os.path.join(args.output_dir, "optim.{0}.bin".format(recover_step))
        )
        # The checkpoint may hold either an optimizer object or its state dict.
        if hasattr(optim_recover, "state_dict"):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        model.eval()

        losses = []
        # Evaluation only: no backward pass follows, so disable autograd to
        # avoid building and holding a graph for every batch.
        with torch.no_grad():
            for batch in tqdm(train_dataloader):
                # wrangle batch
                batch = [t.to(device) for t in batch]
                (
                    input_ids,
                    segment_ids,
                    input_mask,
                    lm_label_ids,
                    masked_pos,
                    masked_weights,
                    is_next,
                    task_idx,
                    img,
                    vis_masked_pos,
                    vis_pe,
                    ans_labels,
                ) = batch

                if args.fp16:
                    img = img.half()
                    vis_pe = vis_pe.half()

                if args.enable_butd:
                    conv_feats = img.data  # Bx100x2048
                    vis_pe = vis_pe.data
                else:
                    # NOTE(review): `cnn` is not defined in this function (its
                    # construction is commented out above as deprecated); this
                    # branch is only reachable if the enable_butd assert is
                    # skipped.
                    conv_feats, _ = cnn(img.data)  # Bx2048x7x7
                    conv_feats = (
                        conv_feats.view(conv_feats.size(0), conv_feats.size(1), -1)
                        .permute(0, 2, 1)
                        .contiguous()
                    )

                # compute loss
                masked_lm_loss, _, _ = model(
                    conv_feats,
                    vis_pe,
                    input_ids,
                    segment_ids,
                    input_mask,
                    lm_label_ids,
                    ans_labels,
                    is_next,
                    masked_pos=masked_pos,
                    masked_weights=masked_weights,
                    task_idx=task_idx,
                    vis_masked_pos=vis_masked_pos,
                    mask_image_regions=args.mask_image_regions,
                    drop_worst_ratio=args.max_drop_worst_ratio,
                )

                # average across multiple GPUs
                if n_gpu > 1:
                    masked_lm_loss = masked_lm_loss.mean()

                losses.append(masked_lm_loss.item())

        # Perplexity = exp(mean of the per-batch masked-LM losses).
        print(args.split, "perplexity:", np.exp(np.mean(losses)))
예제 #7
0
파일: train.py 프로젝트: kyang888/unilm
def main():
    parser = argparse.ArgumentParser()
    # Path parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The raw data dir.")
    parser.add_argument("--vocab_path",
                        default=None,
                        type=str,
                        required=True,
                        help="bert vocab path")
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--model_output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The param init of pretrain or finetune")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")
    # Data Process Parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")
    # Model Paramters
    parser.add_argument("--sop",
                        action='store_true',
                        help="whether use sop task.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")

    # Train Eval Test Paramters

    parser.add_argument("--checkpoint_steps",
                        required=True,
                        type=int,
                        help="save model eyery checkpoint_steps")

    parser.add_argument("--total_steps",
                        required=True,
                        type=int,
                        help="all steps of training model")

    parser.add_argument("--max_checkpoint",
                        required=True,
                        type=int,
                        help="max saved model in model_output_dir")

    parser.add_argument(
        "--examples_size_once",
        type=int,
        default=1000,
        help="read how many examples every time in pretrain or finetune")

    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="process rank in local")
    parser.add_argument("--local_debug",
                        action='store_true',
                        help="whether debug")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--fine_tune",
                        action='store_true',
                        help="Whether to run fine_tune.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates   accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=str,
        default='dynamic',
        help=
        '(float or str, optional, default=None):  Optional property override.  '
        'If passed as a string,must be a string representing a number, e.g., "128.0", or the string "dynamic".'
    )
    parser.add_argument(
        '--opt_level',
        type=str,
        default='O1',
        help=
        ' (str, optional, default="O1"):  Pure or mixed precision optimization level.  '
        'Accepted values are "O0", "O1", "O2", and "O3", explained in detail above.'
    )
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )

    # Other Patameters
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--rank',
                        type=int,
                        default=0,
                        help="global rank of current process")
    parser.add_argument("--world_size",
                        default=2,
                        type=int,
                        help="Number of process(显卡)")

    args = parser.parse_args()
    cur_env = os.environ
    args.rank = int(cur_env.get('RANK', -1))
    args.world_size = int(cur_env.get('WORLD_SIZE', -1))
    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)
    assert args.train_batch_size >= 1, 'batch_size < 1 '

    # 更新一次模型参数需要多少个样本
    examples_per_update = args.world_size * args.train_batch_size * args.gradient_accumulation_steps
    args.examples_size_once = args.examples_size_once // examples_per_update * examples_per_update
    if args.fine_tune:
        args.examples_size_once = examples_per_update

    os.makedirs(args.model_output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    json.dump(args.__dict__,
              open(os.path.join(args.model_output_dir, 'unilm_config.json'),
                   'w'),
              sort_keys=True,
              indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = torch.cuda.device_count()
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=args.world_size,
                                rank=args.rank)
    logger.info(
        "world_size:{}, rank:{}, device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}"
        .format(args.world_size, args.rank, device, n_gpu,
                bool(args.world_size > 1), args.fp16))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
    if not args.fine_tune and not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.vocab_path,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    if args.local_rank == 0:
        dist.barrier()
    bi_uni_pipeline = [
        Preprocess4Seq2seq(args.max_pred,
                           args.mask_prob,
                           list(tokenizer.vocab.keys()),
                           tokenizer.convert_tokens_to_ids,
                           args.max_seq_length,
                           new_segment_ids=args.new_segment_ids,
                           truncate_config={
                               'max_len_a': args.max_len_a,
                               'max_len_b': args.max_len_b,
                               'trunc_seg': args.trunc_seg,
                               'always_truncate_tail':
                               args.always_truncate_tail
                           },
                           mask_source_words=args.mask_source_words,
                           skipgram_prb=args.skipgram_prb,
                           skipgram_size=args.skipgram_size,
                           mask_whole_word=args.mask_whole_word,
                           mode="s2s",
                           has_oracle=args.has_sentence_oracle,
                           num_qkv=args.num_qkv,
                           s2s_special_token=args.s2s_special_token,
                           s2s_add_segment=args.s2s_add_segment,
                           s2s_share_segment=args.s2s_share_segment,
                           pos_shift=args.pos_shift,
                           fine_tune=args.fine_tune)
    ]
    file_oracle = None
    if args.has_sentence_oracle:
        file_oracle = os.path.join(args.data_dir, 'train.oracle')

    # t_total表示模型参数更新的次数
    # t_total = args.train_steps
    # Prepare model
    recover_step = _get_max_epoch_model(args.model_output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            local_debug=args.local_debug)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.output_model_dir, "model.{0}.bin".format(recover_step)),
                                       map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * args.checkpoint_step)
        # 预训练时模型的参数初始化,比如使用chinese-bert-base的模型参数进行初始化
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path,
                                       map_location='cpu')
            global_step = 0
        model = BertForPreTrainingLossMask.from_pretrained(
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            local_debug=args.local_debug)

    total_trainable_params = sum(p.numel() for p in model.parameters()
                                 if p.requires_grad)
    logger.info("模型参数: {}".format(total_trainable_params))
    if args.local_rank == 0:
        dist.barrier()

    model.to(device)
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=args.total_steps)
    if args.amp and args.fp16:
        from apex import amp
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale=args.loss_scale)
        from apex.parallel import DistributedDataParallel as DDP
        model = DDP(model)
    else:
        from torch.nn.parallel import DistributedDataParallel as DDP
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)

    if recover_step:
        logger.info("** ** * Recover optimizer: %d * ** **", recover_step)
        optim_recover = torch.load(os.path.join(
            args.model_output_dir, "optim.{0}.bin".format(recover_step)),
                                   map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.fp16 and args.amp:
            amp_recover = torch.load(os.path.join(
                args.model_output_dir, "amp.{0}.bin".format(recover_step)),
                                     map_location='cpu')
            logger.info("** ** * Recover amp: %d * ** **", recover_step)
            amp.load_state_dict(amp_recover)
    logger.info("** ** * CUDA.empty_cache() * ** **")
    torch.cuda.empty_cache()

    if args.rank == 0:
        writer = SummaryWriter(log_dir=args.log_dir)
    logger.info("***** Running training *****")
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Param Update Num = %d", args.total_steps)
    model.train()

    PRE = "rank{},local_rank {},".format(args.rank, args.local_rank)
    step = 1
    start = time.time()
    train_data_loader = TrainDataLoader(
        bi_uni_pipline=bi_uni_pipeline,
        examples_size_once=args.examples_size_once,
        world_size=args.world_size,
        train_batch_size=args.train_batch_size,
        num_workers=args.num_workers,
        data_dir=args.data_dir,
        tokenizer=tokenizer,
        max_len=args.max_seq_length)
    best_result = -float('inf')
    for global_step, batch in enumerate(train_data_loader, start=global_step):
        batch = [t.to(device) if t is not None else None for t in batch]
        if args.has_sentence_oracle:
            input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, task_idx, sop_label, oracle_pos, oracle_weights, oracle_labels = batch
        else:
            input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, task_idx, sop_label = batch
            oracle_pos, oracle_weights, oracle_labels = None, None, None
        if not args.sop:
            # 不使用sop训练任务
            sop_label = None
        loss_tuple = model(input_ids,
                           segment_ids,
                           input_mask,
                           masked_lm_labels=lm_label_ids,
                           next_sentence_label=sop_label,
                           masked_pos=masked_pos,
                           masked_weights=masked_weights,
                           task_idx=task_idx,
                           masked_pos_2=oracle_pos,
                           masked_weights_2=oracle_weights,
                           masked_labels_2=oracle_labels,
                           mask_qkv=mask_qkv)
        masked_lm_loss, next_sentence_loss = loss_tuple
        # mean() to average on multi-gpu.
        if n_gpu > 1:
            masked_lm_loss = masked_lm_loss.mean()
            next_sentence_loss = next_sentence_loss.mean()
        # ensure that accumlated gradients are normalized
        if args.gradient_accumulation_steps > 1:
            masked_lm_loss = masked_lm_loss / args.gradient_accumulation_steps
            next_sentence_loss = next_sentence_loss / args.gradient_accumulation_steps
        if not args.sop:
            loss = masked_lm_loss
        else:
            loss = masked_lm_loss + next_sentence_loss
        if args.fp16 and args.amp:
            with amp.scale_loss(loss, optimizer) as scale_loss:
                scale_loss.backward()
        else:
            loss.backward()
        if (global_step + 1) % args.gradient_accumulation_steps == 0:
            if args.rank == 0:
                writer.add_scalar('unilm/mlm_loss', masked_lm_loss,
                                  global_step)
                writer.add_scalar('unilm/sop_loss', next_sentence_loss,
                                  global_step)
            lr_this_step = args.learning_rate * warmup_linear(
                global_step / args.total_steps, args.warmup_proportion)
            if args.fp16:
                # modify learning rate with special warm up BERT uses
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
            optimizer.step()
            optimizer.zero_grad()
            #global_step += 1
            #更新一次模型参数花费的时间,单位:秒
            cost_time_per_update = time.time() - start
            # 更新完所有参数花费的时间,单位:小时
            need_time = cost_time_per_update * (args.total_steps -
                                                global_step) / 3600.0
            cost_time_per_chectpoint = cost_time_per_update * args.checkpoint_steps / 3600.0
            start = time.time()
            if args.local_rank in [-1, 0]:
                INFO = PRE + '当前/chcklpoint_steps/total:{}/{}/{},loss{}/{},更新一次参数{}秒,checkpoint_steps {}小时,' \
                             '训练完成{}小时\n'.format(global_step, args.checkpoint_steps, args.total_steps,
                                                 round(masked_lm_loss.item(), 5),
                                                 round(next_sentence_loss.item(), 5), round(cost_time_per_update, 4),
                                                 round(cost_time_per_chectpoint, 3), round(need_time, 3))
                print(INFO)
        # Save a trained model
        if (global_step + 1) % args.checkpoint_steps == 0:
            checkpoint_index = (global_step + 1) % args.checkpoint_steps
            if args.rank >= 0:
                train_data_loader.train_sampler.set_epoch(checkpoint_index)
            # if args.eval:
            #     # 如果是pretrain,验证MLM;如果微调,验证评价指标
            #     result = None
            #if best_result < result and _get_checkpont_num(args.model_output_num):
            if args.rank in [0, -1]:
                logger.info("** ** * Saving  model and optimizer * ** **")

                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model it-self
                output_model_file = os.path.join(
                    args.model_output_dir,
                    "model.{0}.bin".format(checkpoint_index))
                torch.save(model_to_save.state_dict(), output_model_file)
                output_optim_file = os.path.join(
                    args.model_output_dir,
                    "optim.{0}.bin".format(checkpoint_index))
                torch.save(optimizer.state_dict(), output_optim_file)
                if args.fp16 and args.amp:
                    logger.info("** ** * Saving  amp state  * ** **")
                    output_amp_file = os.path.join(
                        args.model_output_dir,
                        "amp.{0}.bin".format(checkpoint_index))
                    torch.save(amp.state_dict(), output_amp_file)
                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
    if args.rank == 0:
        writer.close()
        print('** ** * train finished * ** **')
예제 #8
0
def _read_visdial_test_dialogs(src_file, tokenizer, use_num_imgs):
    """Load a VisDial-style JSON file and tokenize its dialogs.

    Args:
        src_file: Path to a JSON file whose top-level ``data`` object holds
            ``dialogs``, ``questions`` and ``answers``.
        tokenizer: Object exposing ``tokenize(str)``.
        use_num_imgs: Maximum number of dialogs to keep; ``-1`` keeps all.

    Returns:
        A list of tuples ``(image_id, caption_tokens, question_tokens,
        answer_tokens, answer_option_tokens, turn_num)`` where ``turn_num``
        is the number of questions in the dialog.

    Raises:
        AssertionError: If a dialog does not contain exactly one more
            question than answers.
    """
    with open(src_file, "r", encoding='utf-8') as f_src:
        data = json.load(f_src)['data']
    dialogs = data['dialogs']
    questions = data['questions']
    answers = data['answers']

    eval_lst = []
    for dialog in tqdm(dialogs):
        # Every processed dialog is appended, so len(eval_lst) is the number
        # of images consumed so far; stop scanning once the quota is filled.
        if use_num_imgs != -1 and len(eval_lst) >= use_num_imgs:
            break

        turns = dialog['dialog']
        img_id = dialog['image_id']
        cap_tokens = tokenizer.tokenize(dialog['caption'])

        ques_ids = [turn['question'] for turn in turns]
        ques_tokens = [
            tokenizer.tokenize(questions[ques_id] + '?')
            for ques_id in ques_ids
        ]
        turn_num = len(ques_ids)

        ans_ids = [turn['answer'] for turn in turns if 'answer' in turn]
        ans_tokens = [tokenizer.tokenize(answers[ans_id]) for ans_id in ans_ids]

        # Exactly one question must be left without an answer.
        assert len(ques_ids) == len(ans_ids) + 1

        # The candidate answers are read from the turn that follows the
        # answered ones.
        ans_opts = turns[len(ans_ids)]['answer_options']
        ans_opts_tokens = [
            tokenizer.tokenize(answers[opt_id]) for opt_id in ans_opts
        ]

        eval_lst.append((img_id, cap_tokens, ques_tokens, ans_tokens,
                         ans_opts_tokens, turn_num))
    return eval_lst


def main():
    """Rank VisDial answer options with every checkpoint matching a glob.

    Reads the run configuration from ``process_args()``, builds the tokenizer
    and the preprocessing pipeline, then for each checkpoint matched by
    ``args.model_recover_path`` loads the model, scores the answer options of
    every dialog in ``args.src_file`` batch by batch and writes the ranks to
    ``args.save_ranks_path`` and ranks plus scores to the sibling
    ``*_score.json`` file. Entries are only collected when
    ``args.split == "test"``; otherwise the files contain empty lists.

    Exits the process with status 1 when the glob matches no checkpoint.
    """
    args = process_args()
    os.makedirs(args.output_dir, exist_ok=True)

    if args.enable_butd:
        if args.visdial_v == '1.0':
            assert (args.len_vis_input == 36)
        elif args.visdial_v == '0.9':
            assert (args.len_vis_input == 100)
            args.region_bbox_file = os.path.join(args.image_root,
                                                 args.region_bbox_file)
            if args.dataset in (
                    'cc', 'coco') and args.region_det_file_prefix != '':
                args.region_det_file_prefix = os.path.join(
                    args.image_root, args.region_det_file_prefix)
            else:
                args.region_det_file_prefix = ''

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    # Visual input + history/question + answer, plus the extra slots the
    # original formula reserves between them (presumably special tokens).
    args.max_seq_length = (args.len_vis_input + 2 + args.max_len_hist_ques +
                           2 + args.max_len_ans + 1)

    tokenizer.max_len = args.max_seq_length
    bi_uni_pipeline = [
        Preprocess4TestVisdialDiscTest(
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'len_vis_input': args.len_vis_input,
                'max_len_hist_ques': args.max_len_hist_ques,
                'max_len_ans': args.max_len_ans
            },
            mode="bi",
            region_bbox_file=args.region_bbox_file,
            region_det_file_prefix=args.region_det_file_prefix,
            image_features_hdfpath=args.image_features_hdfpath,
            visdial_v=args.visdial_v,
            pad_hist=args.pad_hist,
            inc_full_hist=args.inc_full_hist)
    ]
    preprocess = bi_uni_pipeline[0]

    # The handle is not used afterwards; the call is kept for its side
    # effects and the reference is held for the lifetime of main().
    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    logger.info('Attempting to recover models from: {}'.format(
        args.model_recover_path))
    # Resolve the glob once so the emptiness check and the loop agree.
    model_paths = glob.glob(args.model_recover_path.strip())
    if not model_paths:
        logger.error('There are no models to recover. The program will exit.')
        sys.exit(1)

    score_fn = args.save_ranks_path.replace('.json', '_score.json')
    # The evaluation data does not depend on the checkpoint: it is read after
    # the first model has loaded (so a bad checkpoint still fails fast) and
    # reused for the remaining ones.
    input_lines = None

    # NOTE(review): every checkpoint writes to the same two output paths, so
    # only the results of the last matched checkpoint survive.
    for model_recover_path in model_paths:
        logger.info("***** Recover model: %s *****", model_recover_path)
        # Load onto the CPU so checkpoints saved on a GPU that is missing on
        # this host can still be restored; the model is moved to `device`
        # below.
        model_recover = torch.load(model_recover_path, map_location='cpu')
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size,
            task_idx=0,
            max_position_embeddings=512,
            cache_dir=args.output_dir + '/.pretrained_model_{}'.format(-1),
            drop_prob=args.drop_prob,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input,
            visdial_v=args.visdial_v,
            loss_type=args.loss_type,
            eval_disc=True)

        # Release the raw state dict before moving the model to the device.
        del model_recover

        if args.fp16:
            model.half()
        model.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        torch.cuda.empty_cache()
        model.eval()

        if input_lines is None:
            input_lines = _read_visdial_test_dialogs(args.src_file, tokenizer,
                                                     args.use_num_imgs)
        total_batch = math.ceil(len(input_lines) / args.batch_size)

        print('start the visdial decode evaluation...')
        ranks_json = []
        scores_json = []
        collect_results = args.split == "test"
        with tqdm(total=total_batch) as pbar:
            for batch_start in range(0, len(input_lines), args.batch_size):
                _chunk = input_lines[batch_start:batch_start +
                                     args.batch_size]
                buf_id = [x[0] for x in _chunk]
                # The trailing turn count is metadata, not pipeline input.
                buf = [x[:-1] for x in _chunk]
                turn_id = [x[-1] for x in _chunk]
                instances = [preprocess(instance) for instance in buf]

                with torch.no_grad():
                    # Transpose the per-instance tuples into per-field tuples.
                    batch_data = list(zip(*instances))
                    img, vis_pe = (torch.stack(x).to(device)
                                   for x in batch_data[-2:])
                    task_idx = torch.tensor(batch_data[-3],
                                            dtype=torch.long).to(device)

                    # Original author's note: Bx100x2048 (not verifiable
                    # from this function alone).
                    conv_feats = img.data
                    vis_pe = vis_pe.data

                    input_ids = torch.tensor(batch_data[0],
                                             dtype=torch.long).to(device)
                    segment_ids = torch.tensor(batch_data[1],
                                               dtype=torch.long).to(device)
                    input_mask = torch.stack(batch_data[2]).to(device)

                    output_scores = model(conv_feats,
                                          vis_pe,
                                          input_ids,
                                          segment_ids,
                                          input_mask,
                                          task_idx=task_idx)

                    # Keep index 1 of the last axis; per the original note
                    # the result is [batch_size, num_options].
                    output_scores = output_scores[:, :, 1]

                    # Per the original note: [batch_size, num_rounds,
                    # num_options] with a single round axis inserted here.
                    ranks = scores_to_ranks(output_scores.unsqueeze(1))

                    if collect_results:
                        # Cast into plain Python types explicitly to ensure
                        # no errors in the JSON schema. round_id is the
                        # number of questions in the dialog, i.e. 1-based.
                        for i, image_id in enumerate(buf_id):
                            rank_list = [rank.item() for rank in ranks[i][0]]
                            ranks_json.append({
                                "image_id": image_id,
                                "round_id": turn_id[i],
                                "ranks": rank_list
                            })
                            scores_json.append({
                                "image_id":
                                image_id,
                                "round_id":
                                turn_id[i],
                                "ranks":
                                list(rank_list),
                                "scores":
                                [score.item() for score in output_scores[i]]
                            })

                pbar.update(1)

        # Context managers guarantee the files are flushed and closed.
        with open(args.save_ranks_path, "w") as f_ranks:
            json.dump(ranks_json, f_ranks)
        with open(score_fn, "w") as f_scores:
            json.dump(scores_json, f_scores)
        logger.info("Finish writing rankings into %s", args.save_ranks_path)
        logger.info("Finish writing scores into %s", score_fn)
예제 #9
0
def main():
    """Evaluate discriminative VisDial checkpoints and dump rankings/metrics.

    For every checkpoint file matching ``args.model_recover_path`` (a glob
    pattern) the function:

    1. rebuilds ``BertForPreTrainingLossMask`` from the checkpoint,
    2. reads the dialogs from ``args.src_file`` and scores the answer
       options of each of the 10 rounds, batch by batch,
    3. accumulates sparse metrics (and NDCG when ``args.gt_rel_file`` is
       given),
    4. writes the rankings to ``args.save_ranks_path`` and, for the ``val``
       split, the metrics to ``<save_ranks_path minus .json>_results.txt``.

    Raises:
        SystemExit: with status 1 when no checkpoint matches
            ``args.model_recover_path``.
        AssertionError: when ``len_vis_input`` does not match the selected
            VisDial version, or a dialog does not have 10 rounds with 100
            answer options each.
    """
    args = process_args()
    os.makedirs(args.output_dir, exist_ok=True)

    if args.enable_butd:
        if args.visdial_v == '1.0' and not args.no_vision:
            assert (args.len_vis_input == 36)
        elif args.visdial_v == '0.9':
            assert (args.len_vis_input == 100)
            args.region_bbox_file = os.path.join(args.image_root,
                                                 args.region_bbox_file)
            args.region_det_file_prefix = os.path.join(
                args.image_root,
                args.region_det_file_prefix) if args.dataset in (
                    'cc', 'coco') and args.region_det_file_prefix != '' else ''

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    # Visual slots + history/question + answer, plus 5 extra positions
    # (presumably the special tokens -- confirm against the preprocessor).
    args.max_seq_length = args.len_vis_input + 2 + args.max_len_hist_ques + 2 + args.max_len_ans + 1

    tokenizer.max_len = args.max_seq_length
    bi_uni_pipeline = [
        Preprocess4TestVisdialDisc(
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'len_vis_input': args.len_vis_input,
                'max_len_hist_ques': args.max_len_hist_ques,
                'max_len_ans': args.max_len_ans
            },
            mode="bi",
            region_bbox_file=args.region_bbox_file,
            region_det_file_prefix=args.region_det_file_prefix,
            image_features_hdfpath=args.image_features_hdfpath,
            visdial_v=args.visdial_v,
            pad_hist=args.pad_hist,
            inc_full_hist=args.inc_full_hist,
            only_qa=args.only_qa)
    ]

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    # Every dialog is asserted to have exactly this many rounds (see read_data).
    num_rounds = 10
    logger.info('Attempting to recover models from: {}'.format(
        args.model_recover_path))
    model_recover_paths = glob.glob(args.model_recover_path.strip())
    if not model_recover_paths:
        logger.error('There are no models to recover. The program will exit.')
        sys.exit(1)
    for model_recover_path in model_recover_paths:
        logger.info("***** Recover model: %s *****", model_recover_path)
        # Load onto the CPU first so checkpoints saved from a GPU can still be
        # restored on a CPU-only host; the model is moved to `device` below.
        model_recover = torch.load(model_recover_path, map_location="cpu")
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size,
            task_idx=0,
            max_position_embeddings=512,
            cache_dir=args.output_dir + '/.pretrained_model_{}'.format(-1),
            drop_prob=args.drop_prob,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input,
            visdial_v=args.visdial_v,
            loss_type=args.loss_type,
            eval_disc=True,
            add_attn_fuse=args.add_attn_fuse,
            no_vision=args.no_vision)

        # The weights now live inside the model; release the raw state dict.
        del model_recover

        if args.fp16:
            model.half()
            # cnn.half()
        model.to(device)
        # cnn.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)
            # cnn = torch.nn.DataParallel(cnn)

        torch.cuda.empty_cache()
        model.eval()

        def read_data(src_file):
            """Read and tokenize the dialogs stored in ``src_file``.

            Returns a list with one tuple per dialog:
            ``(image_id, caption_tokens, question_tokens, answer_tokens,
            answer_option_tokens, gt_indices)``. At most
            ``args.use_num_imgs`` dialogs are read (``-1`` means all).
            """
            eval_lst = []
            with open(src_file, "r", encoding='utf-8') as f_src:
                data = json.load(f_src)['data']
                dialogs = data['dialogs']
                questions = data['questions']
                answers = data['answers']
                img_idx = 0
                for dialog in tqdm(dialogs):
                    # Stop as soon as the requested number of images is read
                    # instead of idling through the remaining dialogs.
                    if args.use_num_imgs != -1 and img_idx >= args.use_num_imgs:
                        break
                    img_id = dialog['image_id']
                    rounds = dialog['dialog']

                    cap_tokens = tokenizer.tokenize(dialog['caption'])

                    # Questions/answers are stored as indices into the
                    # shared `questions` / `answers` tables.
                    ques_tokens = [
                        tokenizer.tokenize(questions[item['question']] + '?')
                        for item in rounds
                    ]
                    ans_tokens = [
                        tokenizer.tokenize(answers[item['answer']])
                        for item in rounds
                    ]
                    gt_id = [item['gt_index'] for item in rounds]
                    ans_opts_tokens = [[
                        tokenizer.tokenize(answers[opt_id])
                        for opt_id in item['answer_options']
                    ] for item in rounds]

                    assert len(ques_tokens) == len(ans_tokens) == len(ans_opts_tokens) == num_rounds, \
                        "ques num: %d, ans num: %d, ans opt num: %d" % (
                            len(ques_tokens), len(ans_tokens), len(ans_opts_tokens))
                    assert all([
                        len(ans_opt) == 100 for ans_opt in ans_opts_tokens
                    ]), "all the answer have 100 options"
                    eval_lst.append((img_id, cap_tokens, ques_tokens,
                                     ans_tokens, ans_opts_tokens, gt_id))

                    img_idx += 1
            return eval_lst

        def get_gt_rel_dict(fname):
            """Map ``image_id -> (round_id, gt_relevance)`` from ``fname``."""
            with open(fname) as f_rel:
                gt_rel_data = json.load(f_rel)
            rel_by_image = {}
            for item in gt_rel_data:
                # each image only at most has one turn having dense annotation,
                # so the first entry seen for an image wins
                rel_by_image.setdefault(
                    item['image_id'],
                    (item['round_id'], item['gt_relevance']))
            return rel_by_image

        # Same truthiness test as the NDCG branch below, so an unset (None or
        # empty) path consistently disables dense-annotation scoring.
        gt_rel_dict = get_gt_rel_dict(
            args.gt_rel_file) if args.gt_rel_file else {}

        # Re-read per checkpoint: the preprocessing pipeline may modify the
        # token lists it is given, so the data is not shared across models.
        input_lines = read_data(args.src_file)
        next_i = 0
        total_batch = math.ceil(len(input_lines) / args.batch_size)

        print('start the visdial decode evaluation...')
        t0 = time.time()
        ranks_json = []
        sparse_metrics = SparseGTMetrics()
        ndcg = NDCG()
        with tqdm(total=total_batch) as pbar:
            while next_i < len(input_lines):
                _chunk = input_lines[next_i:next_i + args.batch_size]
                buf_id = [x[0] for x in _chunk]
                buf = [x[:-1] for x in _chunk]  # everything but gt indices
                buf_gt_id = [x[-1] for x in _chunk]
                next_i += args.batch_size
                instances = [bi_uni_pipeline[0](instance) for instance in buf]

                with torch.no_grad():
                    buf_gt_id = torch.tensor(buf_gt_id).long().to(device)
                    # Transpose: one tuple per field instead of per instance.
                    # Fields used here: [0] input ids, [1] segment ids,
                    # [2] input masks, [-3] task idx, [-2:] image feats / vis_pe.
                    batch_data = list(zip(*instances))
                    task_idx = torch.tensor(batch_data[-3],
                                            dtype=torch.long).to(device)

                    if args.no_vision:
                        conv_feats = []
                        vis_pe = []
                    else:
                        img, vis_pe = (torch.stack(x).to(device)
                                       for x in batch_data[-2:])
                        conv_feats = img.data  # Bx100x2048
                        vis_pe = vis_pe.data

                    output_scores_turn = []

                    # Regroup the per-instance fields by dialog round.
                    input_ids_turns = [[x[turn_i] for x in batch_data[0]]
                                       for turn_i in range(num_rounds)]
                    segment_ids_turns = [[x[turn_i] for x in batch_data[1]]
                                         for turn_i in range(num_rounds)]
                    input_mask_turns = [[x[turn_i] for x in batch_data[2]]
                                        for turn_i in range(num_rounds)]

                    for turn_i in range(num_rounds):
                        input_ids = torch.tensor(input_ids_turns[turn_i],
                                                 dtype=torch.long).to(device)
                        segment_ids = torch.tensor(segment_ids_turns[turn_i],
                                                   dtype=torch.long).to(device)
                        input_mask = torch.stack(
                            input_mask_turns[turn_i]).to(device)

                        output_scores = model(conv_feats,
                                              vis_pe,
                                              input_ids,
                                              segment_ids,
                                              input_mask,
                                              task_idx=task_idx)

                        # Keep only index 1 of the last (2-label) axis.
                        output_scores = output_scores[:, :,
                                                      1]  # [batch_size, num_options]
                        output_scores_turn.append(output_scores)

                    output_scores_turn = torch.stack(
                        output_scores_turn,
                        1)  # [batch_size, num_rounds, num_options]
                    ranks = scores_to_ranks(output_scores_turn)
                    sparse_metrics.observe(output_scores_turn, buf_gt_id)

                    if args.split == "val":
                        for i, image_id in enumerate(buf_id):
                            # Cast into types explicitly to ensure no errors in schema.
                            # Round ids are 1-10, not 0-9
                            for j in range(num_rounds):
                                ranks_json.append({
                                    "image_id":
                                    image_id,
                                    "round_id":
                                    int(j + 1),
                                    "ranks":
                                    [rank.item() for rank in ranks[i][j]],
                                })

                    if args.gt_rel_file:
                        scores = []
                        gt_rels = []
                        for i, image_id in enumerate(buf_id):
                            if image_id in gt_rel_dict:
                                turn_idx, gt_rel = gt_rel_dict[image_id]
                                # round ids are 1-based, tensor axis is 0-based
                                scores.append(output_scores_turn[i, turn_idx -
                                                                 1, :])
                                gt_rels.append(
                                    torch.tensor(
                                        gt_rel,
                                        dtype=torch.float32).to(device))
                        # torch.stack rejects an empty list; a batch without
                        # any densely annotated image contributes nothing.
                        if scores:
                            scores = torch.stack(scores)
                            gt_rels = torch.stack(gt_rels)
                            ndcg.observe(scores, gt_rels)

                pbar.update(1)

        # Context managers guarantee the outputs are flushed and closed.
        with open(args.save_ranks_path, "w") as f_ranks:
            json.dump(ranks_json, f_ranks)
        logger.info("Finish writing rankings into %s" % (args.save_ranks_path))

        if args.split == "val":
            with open(args.save_ranks_path.replace('.json', '_results.txt'),
                      "w") as fw:
                all_metrics = {}
                all_metrics.update(sparse_metrics.retrieve(reset=True))
                if args.gt_rel_file:
                    all_metrics.update(ndcg.retrieve(reset=True))
                for metric_name, metric_value in all_metrics.items():
                    print(f"{metric_name}: {metric_value}")
                    fw.write("%s: %.6f\n" % (metric_name, metric_value))