Beispiel #1
0
def main():
    """Command-line entry point: train (and per-epoch evaluate) KG-BART.

    Parses the command line, writes the resolved options to
    ``<output_dir>/opt.json``, loads the entity/relation embeddings and the
    train/dev datasets, builds the model, optimizer and LR scheduler
    (resuming from the newest epoch checkpoint found in ``output_dir`` if
    any), then runs the training loop.  After every epoch the main process
    optionally evaluates on the dev set, stores the best model under
    ``<output_dir>/best_model`` and writes per-epoch model/optimizer/scheduler
    checkpoints, pruning all but the newest ``--keep_last_epochs``.

    Raises:
        ValueError: if ``--gradient_accumulation_steps`` is < 1 or both
            ``--do_train`` and ``--do_eval`` are false.
        ImportError: if ``--fp16`` is enabled and apex is not installed.
    """

    def _str2bool(value):
        """Convert a command-line token to a real boolean.

        ``type=bool`` treats every non-empty string (including "False") as
        True, so boolean options could never be switched off explicitly.
        The empty string still maps to False, as it did with ``bool``.
        """
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ("true", "t", "yes", "y", "on", "1"):
            return True
        if text in ("false", "f", "no", "n", "off", "0", ""):
            return False
        raise argparse.ArgumentTypeError(
            "Boolean value expected, got {!r}".format(value))

    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default="../../../dataset/final_data/commongen",
        type=str,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--src_file",
                        default=None,
                        type=str,
                        help="The input data file name.")
    parser.add_argument("--tgt_file",
                        default=None,
                        type=str,
                        help="The output data file name.")
    parser.add_argument(
        "--bart_model",
        default="facebook/bart-large",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default="../../../output/train_kgbart",
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default="../../../log/train_kgbart",
        type=str,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")
    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=64,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        default=True,
                        type=_str2bool,
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        default=True,
                        type=_str2bool,
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        default=False,
        type=_str2bool,
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=48,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=2,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    # NOTE(review): this value is parsed but never forwarded to the model;
    # the training loop below passes label_smoothing=False.
    parser.add_argument("--label_smoothing",
                        default=0.1,
                        type=float,
                        help="The label smoothing factor.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=10.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=6,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        default=True,
        type=_str2bool,
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        default=True,
                        type=_str2bool,
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=64,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=64,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        default=True,
        type=_str2bool,
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.70,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=30,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=5,
                        type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument('--pretraining_KG',
                        action='store_true',
                        help="Whether to pretraining KG. ")
    parser.add_argument('--train_pretraining_num',
                        type=int,
                        default=20,
                        help="The number of sample training pretraining KG. ")
    parser.add_argument('--val_pretraining_num',
                        type=int,
                        default=20,
                        help="The number of sample validing pretraining KG.")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=64,
                        help="max position embeddings")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        default=False,
        type=_str2bool,
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        default=False,
                        type=_str2bool,
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        default=False,
        type=_str2bool,
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--keep_last_epochs",
                        default=5,
                        type=int,
                        help="Keep the last few epochs.")

    args = parser.parse_args()

    # assert Path(args.model_recover_path).exists(
    # ), "--model_recover_path doesn't exist"

    # Allow a literal "[PT_OUTPUT_DIR]" placeholder in the paths to be filled
    # in from the environment.
    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    # Record the resolved options next to the checkpoints; the context
    # manager makes sure the file is flushed and closed.
    with open(os.path.join(args.output_dir, 'opt.json'), 'w') as opt_file:
        json.dump(vars(args), opt_file, sort_keys=True, indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        # Do not count GPUs when running on CPU, otherwise --no_cuda on a
        # multi-GPU machine would still wrap the model in data-parallel.
        n_gpu = torch.cuda.device_count() if device.type == "cuda" else 0
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # From here on train_batch_size is the per-step (micro) batch size.
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BartTokenizer.from_pretrained(args.bart_model,
                                              do_lower_case=args.do_lower_case)
    # if args.max_position_embeddings:
    #     tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = tokenizer  # WhitespaceTokenizer() if args.tokenized_input else
    if args.local_rank == 0:
        dist.barrier()

    bi_uni_pipeline = [
        seq2seq_loader.Preprocess4Pretrain(
            list(tokenizer.encoder.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'max_len_a': args.max_len_a,
                'max_len_b': args.max_len_b,
                'trunc_seg': args.trunc_seg,
                'always_truncate_tail': args.always_truncate_tail
            },
            mode="s2s",
            pretraining_KG=args.pretraining_KG,
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift)
    ]
    file_oracle = None
    # NOTE(review): when --tgt_file is given, that one name is reused for
    # every KG resource below as well as for the target and one-hop files.
    # Kept as-is for compatibility; confirm whether this is intended.
    entity_id = os.path.join(
        args.data_dir, args.tgt_file
        if args.tgt_file else 'CommonGen_KG/commongen_entity2id.txt')
    entity_embedding_path = os.path.join(
        args.data_dir, args.tgt_file
        if args.tgt_file else 'CommonGen_KG/commongen_ent_embeddings')

    relation_id = os.path.join(
        args.data_dir, args.tgt_file
        if args.tgt_file else 'CommonGen_KG/commongen_relation2id.txt')
    relation_embedding_path = os.path.join(
        args.data_dir, args.tgt_file
        if args.tgt_file else 'CommonGen_KG/commongen_rel_embeddings')

    # SECURITY: pickle.load executes arbitrary code from the file; only point
    # these paths at embedding files you trust.
    with open(entity_embedding_path, "rb") as emb_file:
        entity_embedding = np.array(pickle.load(emb_file))
    # Prepend four all-zero rows of width 1024 ahead of the loaded entity
    # vectors (presumably slots for special ids -- TODO confirm).
    entity_embedding = np.array(
        list(np.zeros((4, 1024))) + list(entity_embedding))
    with open(relation_embedding_path, "rb") as emb_file:
        relation_embedding = np.array(pickle.load(emb_file))

    if args.do_train:
        print("Loading Train Dataset", args.data_dir)
        if args.pretraining_KG:
            file_oracle = os.path.join(args.data_dir,
                                       'CommonGen_KG/commongen_entity2id.txt')
        fn_src = os.path.join(
            args.data_dir,
            args.src_file if args.src_file else 'commongen.train.src_new.txt')
        fn_tgt = os.path.join(
            args.data_dir,
            args.tgt_file if args.tgt_file else 'commongen.train.tgt.txt')
        fn_onehop = os.path.join(
            args.data_dir,
            args.tgt_file if args.tgt_file else 'commongen.train.onehop_5.txt')
        train_dataset = seq2seq_loader.Seq2SeqDataset(
            fn_src,
            fn_tgt,
            entity_id,
            relation_id,
            fn_onehop,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            pretraining_KG=file_oracle,
            pretraining_num=args.train_pretraining_num,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)
    if args.do_eval:
        print("Loading Dev Dataset", args.data_dir)
        if args.pretraining_KG:
            file_oracle = os.path.join(args.data_dir,
                                       'CommonGen_KG/commongen_entity2id.txt')
        fn_src = os.path.join(
            args.data_dir,
            args.src_file if args.src_file else 'commongen.dev.src_new.txt')
        fn_tgt = os.path.join(
            args.data_dir,
            args.tgt_file if args.tgt_file else 'commongen.dev.tgt.txt')
        fn_onehop = os.path.join(
            args.data_dir,
            args.tgt_file if args.tgt_file else 'commongen.dev.onehop_5.txt')
        dev_dataset = seq2seq_loader.Seq2SeqDataset(
            fn_src,
            fn_tgt,
            entity_id,
            relation_id,
            fn_onehop,
            args.eval_batch_size,
            data_tokenizer,
            args.max_seq_length,
            pretraining_KG=file_oracle,
            pretraining_num=args.val_pretraining_num,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            dev_sampler = RandomSampler(dev_dataset, replacement=False)
            _batch_size = args.eval_batch_size
        else:
            dev_sampler = DistributedSampler(dev_dataset)
            _batch_size = args.eval_batch_size // dist.get_world_size()
        dev_dataloader = torch.utils.data.DataLoader(
            dev_dataset,
            batch_size=_batch_size,
            sampler=dev_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # Total number of optimizer updates.  Without a training set there is
    # nothing to schedule (and train_dataloader does not exist), so use 0.
    if args.do_train:
        t_total = int(
            len(train_dataloader) * args.num_train_epochs /
            args.gradient_accumulation_steps)
    else:
        t_total = 0

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # Fresh start from the pretrained BART weights.
        model = KGBartForConditionalGeneration.from_pretrained(
            args.bart_model,
            entity_weight=entity_embedding,
            relation_weight=relation_embedding)
        global_step = 0
    else:
        if recover_step:
            # An epoch checkpoint in output_dir takes precedence over
            # --model_recover_path.
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.output_dir, "model.{0}.bin".format(recover_step)),
                                       map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total /
                                     args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path,
                                       map_location='cpu')
            global_step = 0
        model = KGBartForConditionalGeneration.from_pretrained(
            args.bart_model,
            state_dict=model_recover,
            entity_weight=entity_embedding,
            relation_weight=relation_embedding)
    if args.local_rank == 0:
        dist.barrier()

    model.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)

    # Prepare optimizer: biases and LayerNorm parameters are excluded from
    # weight decay; everything else uses --weight_decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        args.weight_decay
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step)),
                                   map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        schedule_recover = torch.load(os.path.join(
            args.output_dir, "sched.{0}.bin".format(recover_step)),
                                      map_location='cpu')
        scheduler.load_state_dict(schedule_recover)

        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    # Any real loss beats infinity, so the first evaluated epoch always
    # produces a "best" checkpoint.
    best_dev_loss = float("inf")

    output_eval_file = os.path.join(args.log_dir, "eval_results.txt")
    # Start from an empty results file; each epoch appends to it and closes
    # it again so results survive a crash and no handle is leaked.
    with open(output_eval_file, "w"):
        pass

    def checkpoint_paths(path, pattern=r"model(\d+)\.pt"):
        """Retrieves all checkpoints found in `path` directory.

        Checkpoints are identified by matching filename to the specified pattern. If
        the pattern contains groups, the result will be sorted by the first group in
        descending order; otherwise the second dot-separated field of the
        file name is used as the (integer) sort key.
        """
        pt_regexp = re.compile(pattern)

        entries = []
        for f in os.listdir(path):
            m = pt_regexp.fullmatch(f)
            if m is not None:
                idx = float(m.group(1)) if len(m.groups()) > 0 else int(
                    f.split(".")[1])
                entries.append((idx, m.group(0)))
        return [
            os.path.join(path, x[1]) for x in sorted(entries, reverse=True)
        ]

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        # Epochs are 1-based; resume right after the recovered one.
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch,
                              int(args.num_train_epochs) + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            model.train()
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            for step, batch in enumerate(
                    tqdm(train_dataloader,
                         desc="Training",
                         position=0,
                         leave=True)):
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                # The batch layout is the same with and without
                # --pretraining_KG.
                (input_ids, input_entity_ids, subword_mask, word_mask,
                 word_subword, decoder_input_ids, decoder_attention_mask,
                 labels) = batch

                loss_output = model(
                    input_ids,
                    input_entity_ids=input_entity_ids,
                    attention_mask=subword_mask,
                    word_mask=word_mask,
                    word_subword=word_subword,
                    decoder_input_ids=decoder_input_ids,
                    decoder_attention_mask=decoder_attention_mask,
                    labels=labels,
                    label_smoothing=False)

                masked_lm_loss = loss_output.loss
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    masked_lm_loss = masked_lm_loss.mean()
                loss = masked_lm_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                if step % 1000 == 0:
                    print('Iter %d  (Gen_loss=%5.3f)' % (i_epoch, loss.item()))

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # Only step the optimizer once per accumulation window.
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)

                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

            logger.info("***** CUDA.empty_cache() *****")
            torch.cuda.empty_cache()

            # Evaluation and checkpointing happen on the main process only.
            is_main_process = (args.local_rank == -1
                               or torch.distributed.get_rank() == 0)
            if not is_main_process:
                continue

            if args.do_eval:
                model.eval()
                cur_dev_loss = []
                with torch.no_grad():
                    for batch in tqdm(dev_dataloader,
                                      desc="Evaluating",
                                      position=0,
                                      leave=True):
                        batch = [
                            t.to(device) if t is not None else None
                            for t in batch
                        ]
                        (input_ids, input_entity_ids, subword_mask,
                         word_mask, word_subword, decoder_input_ids,
                         decoder_attention_mask, labels) = batch

                        loss_output = model(
                            input_ids,
                            input_entity_ids=input_entity_ids,
                            attention_mask=subword_mask,
                            word_mask=word_mask,
                            word_subword=word_subword,
                            decoder_input_ids=decoder_input_ids,
                            decoder_attention_mask=decoder_attention_mask,
                            labels=labels,
                            label_smoothing=False)

                        masked_lm_loss = loss_output.loss
                        if n_gpu > 1:  # mean() to average on multi-gpu.
                            masked_lm_loss = masked_lm_loss.mean()
                        cur_dev_loss.append(float(masked_lm_loss.item()))

                dev_loss = sum(cur_dev_loss) / float(len(cur_dev_loss))
                print("the epoch {} DEV loss is {}".format(i_epoch, dev_loss))
                if best_dev_loss > dev_loss:
                    best_dev_loss = dev_loss
                    model_to_save = model.module if hasattr(
                        model,
                        'module') else model  # Only save the model it-self
                    os.makedirs(args.output_dir + "/best_model",
                                exist_ok=True)
                    output_model_file = os.path.join(
                        args.output_dir, "best_model/model.best.bin")
                    torch.save(model_to_save.state_dict(), output_model_file)

                # NOTE(review): the value written here is the mean dev loss
                # even though the label says "accuracy"; the text is kept
                # unchanged in case something parses this file.
                with open(output_eval_file, "a") as writer:
                    writer.write("epoch " + str(i_epoch) + "\n")
                    writer.write("the current eval accuracy is: " +
                                 str(dev_loss) + "\n")

            # Epoch checkpoints are written whether or not evaluation ran,
            # so that training with --do_eval False can still be resumed.
            logger.info(
                "** ** * Saving fine-tuned model and optimizer ** ** * ")
            model_to_save = model.module if hasattr(
                model, 'module') else model  # Only save the model it-self
            output_model_file = os.path.join(
                args.output_dir, "model.{0}.bin".format(i_epoch))
            output_optim_file = os.path.join(
                args.output_dir, "optim.{0}.bin".format(i_epoch))
            output_schedule_file = os.path.join(
                args.output_dir, "sched.{0}.bin".format(i_epoch))
            torch.save(model_to_save.state_dict(), output_model_file)
            torch.save(optimizer.state_dict(), output_optim_file)
            torch.save(scheduler.state_dict(), output_schedule_file)

            if args.keep_last_epochs > 0:
                # remove old epoch checkpoints; checkpoints are sorted in descending order
                for prefix in ("model", "optim", "sched"):
                    checkpoints = checkpoint_paths(
                        args.output_dir,
                        pattern=r"{}\.(\d+)\.bin".format(prefix))
                    for old_chk in checkpoints[args.keep_last_epochs:]:
                        if os.path.lexists(old_chk):
                            os.remove(old_chk)

            logger.info("***** CUDA.empty_cache() *****")
            torch.cuda.empty_cache()
Beispiel #2
0
class Trainer:
    """Training driver that delegates task-specific work to a callback."""

    def __init__(self, callback: TrainerCallback):
        """Link the trainer with its callback and configure logging.

        Args:
            callback: Hook object used throughout the trainer. It receives a
                back-reference to this trainer via its ``trainer`` attribute.
        """
        logging.basicConfig(level=logging.INFO)
        callback.trainer = self
        self.callback = callback

    def parse_args(self):
        self.parser = argparse.ArgumentParser()
        self.parser.add_argument('--train', action='store_true')
        self.parser.add_argument('--dev', action='store_true')
        self.parser.add_argument('--test', action='store_true')
        self.parser.add_argument('--debug', action='store_true')
        self.parser.add_argument("--per_gpu_train_batch_size",
                                 default=8,
                                 type=int)
        self.parser.add_argument("--per_gpu_eval_batch_size",
                                 default=8,
                                 type=int)
        self.parser.add_argument("--learning_rate", default=5e-5, type=float)
        self.parser.add_argument("--gradient_accumulation_steps",
                                 type=int,
                                 default=1)
        self.parser.add_argument("--weight_decay", default=0.0, type=float)
        self.parser.add_argument("--adam_epsilon", default=1e-8, type=float)
        self.parser.add_argument("--max_grad_norm", default=1.0, type=float)
        self.parser.add_argument("--epochs", default=2, type=int)
        self.parser.add_argument("--warmup_ratio", default=0.1, type=float)
        self.parser.add_argument("--logging_steps", type=int, default=500)
        self.parser.add_argument("--save_steps", type=int, default=10000)
        self.parser.add_argument("--seed", type=int, default=42)
        self.parser.add_argument("--num_workers", type=int, default=0)
        self.parser.add_argument("--local_rank", type=int, default=-1)
        self.parser.add_argument("--fp16", action="store_true")
        self.parser.add_argument("--fp16_opt_level", type=str, default="O1")
        self.parser.add_argument("--no_cuda", action="store_true")
        self.parser.add_argument("--load_checkpoint", default=None, type=str)
        self.parser.add_argument("--ignore_progress", action='store_true')
        self.parser.add_argument("--dataset_ratio", type=float, default=1.0)
        self.parser.add_argument("--no_save", action="store_true")
        self.callback.on_argument(self.parser)
        self.args = self.parser.parse_args()
        keys = list(self.args.__dict__.keys())
        for key in keys:
            value = getattr(self.args, key)
            if type(value) == str and os.path.exists(value):
                setattr(self.args, key, os.path.abspath(value))
        if not self.args.train:
            self.args.epochs = 1
        self.train = self.args.train
        self.dev = self.args.dev
        self.test = self.args.test
        self.debug = self.args.debug
        self.per_gpu_train_batch_size = self.args.per_gpu_train_batch_size
        self.per_gpu_eval_batch_size = self.args.per_gpu_eval_batch_size
        self.learning_rate = self.args.learning_rate
        self.gradient_accumulation_steps = self.args.gradient_accumulation_steps
        self.weight_decay = self.args.weight_decay
        self.adam_epsilon = self.args.adam_epsilon
        self.max_grad_norm = self.args.max_grad_norm
        self.epochs = self.args.epochs
        self.warmup_ratio = self.args.warmup_ratio
        self.logging_steps = self.args.logging_steps
        self.save_steps = self.args.save_steps
        self.seed = self.args.seed
        self.num_workers = self.args.num_workers
        self.local_rank = self.args.local_rank
        self.fp16 = self.args.fp16
        self.fp16_opt_level = self.args.fp16_opt_level
        self.no_cuda = self.args.no_cuda
        self.load_checkpoint = self.args.load_checkpoint
        self.ignore_progress = self.args.ignore_progress
        self.dataset_ratio = self.args.dataset_ratio
        self.no_save = self.args.no_save
        self.callback.args = self.args

    def set_env(self):
        """Select the device, create a fresh run directory and enter it.

        Side effects, in order:

        * with ``--debug``, installs an IPython verbose traceback hook that
          drops into pdb on uncaught exceptions;
        * picks the device and ``n_gpu`` (single-process mode uses every
          visible GPU unless ``--no_cuda``; distributed mode pins one GPU per
          process and initialises the NCCL process group);
        * seeds the RNGs through ``set_seed``;
        * creates ``r/<next id>/`` with ``output`` and ``tmp`` sub-directories
          and copies every ``*.py`` file of the current directory into it;
        * changes the working directory to the newest ``r/<id>``;
        * writes ``sys.argv`` to ``output/args.json``;
        * derives the effective train/eval batch sizes and creates a CUDA
          stream.
        """
        if self.debug:
            sys.excepthook = IPython.core.ultratb.FormattedTB(
                mode='Verbose', color_scheme='Linux', call_pdb=1)
        if self.local_rank == -1 or self.no_cuda:
            device = torch.device("cuda" if torch.cuda.is_available()
                                  and not self.no_cuda else "cpu")
            self.n_gpu = 0 if self.no_cuda else torch.cuda.device_count()
        else:
            torch.cuda.set_device(self.local_rank)
            device = torch.device("cuda", self.local_rank)
            torch.distributed.init_process_group(backend="nccl")
            self.n_gpu = 1
        set_seed(self.seed, self.n_gpu)
        self.device = device

        def existing_run_ids():
            # Ignore stray entries (e.g. ".DS_Store") instead of crashing in
            # int(); only purely numeric names are run directories.
            return [int(c) for c in os.listdir('r') if c.isdecimal()]

        # NOTE(review): once_barrier() presumably lets a single process run
        # the body while the others wait -- confirm against OnceBarrier.
        with self.once_barrier():
            if not os.path.exists('r'):
                os.mkdir('r')
            i = max(existing_run_ids(), default=-1) + 1
            run_dir = os.path.join('r', str(i))
            os.mkdir(run_dir)
            # Snapshot the sources next to the run for reproducibility.
            src_names = [
                source for source in os.listdir() if source.endswith('.py')
            ]
            for src_name in src_names:
                shutil.copy(src_name, os.path.join(run_dir, src_name))
            os.mkdir(os.path.join(run_dir, 'output'))
            os.mkdir(os.path.join(run_dir, 'tmp'))
        # Every process enters the highest-numbered run directory.
        i = max(existing_run_ids())
        os.chdir(os.path.join('r', str(i)))
        with self.once_barrier():
            # Close the handle explicitly so the file is flushed before the
            # block (and any synchronisation it performs) ends.
            with open('output/args.json', 'w') as args_file:
                json.dump(sys.argv, args_file)
        logging.info(
            "Process rank: {}, device: {}, n_gpu: {}, distributed training: {}, 16-bits training: {}"
            .format(self.local_rank, device, self.n_gpu,
                    bool(self.local_rank != -1), self.fp16))
        self.train_batch_size = self.per_gpu_train_batch_size * max(
            1, self.n_gpu)
        self.eval_batch_size = self.per_gpu_eval_batch_size * max(
            1, self.n_gpu)
        if self.fp16:
            apex.amp.register_half_function(torch, "einsum")
        self.stream = torch.cuda.Stream()

    def set_model(self):
        """Load the model, build its optimizer and wrap it for parallelism.

        The model comes from ``callback.load_model()`` and is moved to
        ``self.device``. Parameters whose name contains ``"bias"`` or
        ``"LayerNorm.weight"`` get no weight decay; all others use
        ``self.weight_decay``. With ``fp16`` the pair is passed through
        ``apex.amp.initialize``. The model is then wrapped in ``DataParallel``
        when more than one GPU is in use and in ``DistributedDataParallel``
        when ``local_rank`` is set.
        """
        self.model = self.callback.load_model()
        self.model.to(self.device)

        undecayed_tags = ["bias", "LayerNorm.weight"]
        with_decay, without_decay = [], []
        for param_name, param in self.model.named_parameters():
            if any(tag in param_name for tag in undecayed_tags):
                without_decay.append(param)
            else:
                with_decay.append(param)
        param_groups = [
            {"params": with_decay, "weight_decay": self.weight_decay},
            {"params": without_decay, "weight_decay": 0.0},
        ]
        self.optimizer = AdamW(param_groups,
                               lr=self.learning_rate,
                               eps=self.adam_epsilon)

        if self.fp16:
            # amp must see the bare model, so this precedes any wrapping.
            amp_model, amp_optimizer = apex.amp.initialize(
                self.model, self.optimizer, opt_level=self.fp16_opt_level)
            self.model, self.optimizer = amp_model, amp_optimizer
        if self.n_gpu > 1:
            self.model = torch.nn.DataParallel(self.model)
        if self.local_rank != -1:
            rank = self.local_rank
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[rank],
                output_device=rank,
                find_unused_parameters=True)

    def once(self):
        """Return a ``Once`` object for this process's ``local_rank``.

        Used as a context manager elsewhere in this class.
        NOTE(review): presumably restricts the body to one process -- confirm
        against the ``Once`` implementation.
        """
        rank = self.local_rank
        return Once(rank)

    def once_barrier(self):
        """Return an ``OnceBarrier`` object for this process's ``local_rank``.

        Used as a context manager elsewhere in this class.
        NOTE(review): presumably runs the body in one process while the
        others wait -- confirm against the ``OnceBarrier`` implementation.
        """
        rank = self.local_rank
        return OnceBarrier(rank)

    def cache(self):
        """Return a ``Cache`` object built from this process's ``local_rank``.

        NOTE(review): the semantics of ``Cache`` are not visible here --
        confirm against its implementation.
        """
        rank = self.local_rank
        return Cache(rank)

    def load_data(self):
        """Create datasets, samplers and prefetching loaders for each split.

        Resets the progress counters, asks the callback for the three
        datasets and their collate functions, and for every truthy dataset
        stores ``<split>_dataset``, ``<split>_sampler`` and
        ``<split>_dataloader`` on the trainer. When ``dataset_ratio`` is
        below 1 only the leading fraction of each dataset is kept. For the
        train split the total number of optimizer steps (``t_total``) and
        the linear warmup scheduler are created as well.
        """
        self.train_step = 1
        self.epochs_trained = 0
        self.steps_trained_in_current_epoch = 0
        train_dataset, dev_dataset, test_dataset = self.callback.load_data()
        train_fn, dev_fn, test_fn = self.callback.collate_fn()
        single_process = self.local_rank == -1

        def leading_fraction(dataset):
            # Keep only the first ``dataset_ratio`` share of the examples.
            if self.dataset_ratio < 1:
                kept = int(len(dataset) * self.dataset_ratio)
                return torch.utils.data.Subset(dataset, list(range(kept)))
            return dataset

        def prefetching_loader(dataset, sampler, batch_size, collate):
            loader = DataLoader(dataset,
                                sampler=sampler,
                                batch_size=batch_size,
                                collate_fn=collate,
                                num_workers=self.num_workers)
            return Prefetcher(loader, self.stream)

        if train_dataset:
            self.train_dataset = leading_fraction(train_dataset)
            if single_process:
                self.train_sampler = RandomSampler(self.train_dataset)
            else:
                self.train_sampler = DistributedSampler(self.train_dataset)
            self.train_dataloader = prefetching_loader(
                self.train_dataset, self.train_sampler,
                self.train_batch_size, train_fn)
            updates_per_epoch = (len(self.train_dataloader) //
                                 self.gradient_accumulation_steps)
            self.t_total = updates_per_epoch * self.epochs
            warmup_steps = int(self.t_total * self.warmup_ratio)
            self.scheduler = get_linear_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=self.t_total)

        if dev_dataset:
            self.dev_dataset = leading_fraction(dev_dataset)
            if single_process:
                self.dev_sampler = SequentialSampler(self.dev_dataset)
            else:
                self.dev_sampler = DistributedSampler(self.dev_dataset)
            self.dev_dataloader = prefetching_loader(
                self.dev_dataset, self.dev_sampler, self.eval_batch_size,
                dev_fn)

        if test_dataset:
            self.test_dataset = leading_fraction(test_dataset)
            if single_process:
                self.test_sampler = SequentialSampler(self.test_dataset)
            else:
                self.test_sampler = DistributedSampler(self.test_dataset)
            self.test_dataloader = prefetching_loader(
                self.test_dataset, self.test_sampler, self.eval_batch_size,
                test_fn)

    def restore_checkpoint(self, path, ignore_progress=False):
        if self.no_save:
            return
        model_to_load = self.model.module if hasattr(self.model,
                                                     "module") else self.model
        model_to_load.load_state_dict(
            torch.load(os.path.join(path, 'pytorch_model.bin'),
                       map_location=self.device))
        self.optimizer.load_state_dict(
            torch.load(os.path.join(path, "optimizer.pt"),
                       map_location=self.device))
        self.scheduler.load_state_dict(
            torch.load(os.path.join(path, "scheduler.pt"),
                       map_location=self.device))
        self.callback.on_load(path)
        if not ignore_progress:
            self.train_step = int(path.split("-")[-1])
            self.epochs_trained = self.train_step // (
                len(self.train_dataloader) // self.gradient_accumulation_steps)
            self.steps_trained_in_current_epoch = self.train_step % (
                len(self.train_dataloader) // self.gradient_accumulation_steps)
        logging.info(
            "  Continuing training from checkpoint, will skip to saved train_step"
        )
        logging.info("  Continuing training from epoch %d",
                     self.epochs_trained)
        logging.info("  Continuing training from train step %d",
                     self.train_step)
        logging.info("  Will skip the first %d steps in the first epoch",
                     self.steps_trained_in_current_epoch)

    def save_checkpoint(self):
        if self.no_save:
            return
        output_dir = os.path.join('output',
                                  "checkpoint-{}".format(self.train_step))
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)
        model_to_save = self.model.module if hasattr(self.model,
                                                     "module") else self.model
        torch.save(model_to_save.state_dict(),
                   os.path.join(output_dir, 'pytorch_model.bin'))
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
        torch.save(self.optimizer.state_dict(),
                   os.path.join(output_dir, "optimizer.pt"))
        torch.save(self.scheduler.state_dict(),
                   os.path.join(output_dir, "scheduler.pt"))
        self.callback.on_save(output_dir)

    def run(self):
        """Run the full pipeline: setup, optional training, dev and test.

        In order: parse arguments, prepare the environment, create the
        summary writer, build the model/optimizer, load data and optionally
        restore ``--load_checkpoint``. Then, for every epoch not already
        covered by the restored checkpoint, run a training pass (``--train``)
        and a dev pass (``--dev``). The dev callback's return value is
        compared against the best seen so far to remember the best train
        step. Finally, with ``--test``, the checkpoint of the best step is
        reloaded (if training ran and a best step was recorded) and one test
        pass is executed. ``output/f.json`` is written at the very end as a
        completion marker.
        """
        self.parse_args()
        self.set_env()
        with self.once():
            self.writer = SummaryWriter()
        self.set_model()
        self.load_data()
        if self.load_checkpoint is not None:
            self.restore_checkpoint(self.load_checkpoint, self.ignore_progress)
        best_performance = 0
        best_step = -1
        for epoch in range(self.epochs):
            if epoch < self.epochs_trained:
                # Epoch was completed before the restored checkpoint.
                continue
            with self.once():
                logging.info('epoch %d', epoch)
            if self.train:
                tr_loss, logging_loss = 0.0, 0.0
                self.model.zero_grad()
                self.model.train()
                self.callback.on_train_epoch_start(epoch)
                if self.local_rank >= 0:
                    self.train_sampler.set_epoch(epoch)
                for step, batch in enumerate(
                        tqdm(self.train_dataloader,
                             disable=self.local_rank > 0)):
                    # Skip batches already consumed before the checkpoint.
                    if step < self.steps_trained_in_current_epoch:
                        continue
                    inputs, extra = self.callback.process_train_data(batch)
                    outputs = self.model(**inputs)
                    loss = outputs[0]
                    if self.n_gpu > 1:
                        # DataParallel returns one loss per replica.
                        loss = loss.mean()
                    if self.gradient_accumulation_steps > 1:
                        loss = loss / self.gradient_accumulation_steps
                    if self.local_rank < 0 or (
                            step + 1) % self.gradient_accumulation_steps == 0:
                        if self.fp16:
                            with apex.amp.scale_loss(
                                    loss, self.optimizer) as scaled_loss:
                                scaled_loss.backward()
                        else:
                            loss.backward()
                    else:
                        # Distributed run on a pure accumulation step: the
                        # backward pass runs under the wrapper's no_sync().
                        with self.model.no_sync():
                            if self.fp16:
                                with apex.amp.scale_loss(
                                        loss, self.optimizer) as scaled_loss:
                                    scaled_loss.backward()
                            else:
                                loss.backward()
                    tr_loss += loss.item()
                    if (step + 1) % self.gradient_accumulation_steps == 0:
                        if self.fp16:
                            torch.nn.utils.clip_grad_norm_(
                                apex.amp.master_params(self.optimizer),
                                self.max_grad_norm)
                        else:
                            torch.nn.utils.clip_grad_norm_(
                                self.model.parameters(), self.max_grad_norm)
                        self.optimizer.step()
                        self.scheduler.step()
                        self.model.zero_grad()
                        self.train_step += 1
                        with self.once():
                            if self.train_step % self.logging_steps == 0:
                                self.writer.add_scalar(
                                    "lr",
                                    self.scheduler.get_lr()[0],
                                    self.train_step)
                                self.writer.add_scalar(
                                    "loss", (tr_loss - logging_loss) /
                                    self.logging_steps, self.train_step)
                                logging_loss = tr_loss
                            if self.train_step % self.save_steps == 0:
                                self.save_checkpoint()
                    self.callback.on_train_step(step, self.train_step, inputs,
                                                extra, loss.item(), outputs)
                # The skip applies only to the first epoch after a resume;
                # without this reset every later epoch would also drop its
                # first batches.
                self.steps_trained_in_current_epoch = 0
                with self.once():
                    self.save_checkpoint()
                self.callback.on_train_epoch_end(epoch)
            if self.dev:
                with torch.no_grad():
                    self.model.eval()
                    self.callback.on_dev_epoch_start(epoch)
                    for step, batch in enumerate(
                            tqdm(self.dev_dataloader,
                                 disable=self.local_rank > 0)):
                        inputs, extra = self.callback.process_dev_data(batch)
                        outputs = self.model(**inputs)
                        self.callback.on_dev_step(step, inputs, extra, outputs)
                    performance = self.callback.on_dev_epoch_end(epoch)
                    if performance > best_performance:
                        best_performance = performance
                        # An end-of-epoch checkpoint exists for this step
                        # when training ran (unless --no_save).
                        best_step = self.train_step
        if self.test:
            with torch.no_grad():
                if best_step > 0 and self.train:
                    self.restore_checkpoint(
                        os.path.join('output',
                                     "checkpoint-{}".format(best_step)))
                self.model.eval()
                # ``epoch`` still holds the last value of the epoch loop.
                self.callback.on_test_epoch_start(epoch)
                for step, batch in enumerate(
                        tqdm(self.test_dataloader,
                             disable=self.local_rank > 0)):
                    inputs, extra = self.callback.process_test_data(batch)
                    outputs = self.model(**inputs)
                    self.callback.on_test_step(step, inputs, extra, outputs)
                self.callback.on_test_epoch_end(epoch)
        with self.once():
            self.writer.close()
        # Completion marker; close the handle so the file is flushed.
        with open('output/f.json', 'w') as done_file:
            json.dump(True, done_file)

    def distributed_broadcast(self, l):
        assert type(l) == list or type(l) == dict
        if self.local_rank < 0:
            return l
        else:
            torch.distributed.barrier()
            process_number = torch.distributed.get_world_size()
            json.dump(l, open(f'tmp/{self.local_rank}.json', 'w'))
            torch.distributed.barrier()
            objs = list()
            for i in range(process_number):
                objs.append(json.load(open(f'tmp/{i}.json')))
            if type(objs[0]) == list:
                ret = list()
                for i in range(process_number):
                    ret.extend(objs[i])
            else:
                ret = dict()
                for i in range(process_number):
                    for k, v in objs.items():
                        assert k not in ret
                        ret[k] = v
            torch.distributed.barrier()
            return ret

    def distributed_merge(self, l):
        assert type(l) == list or type(l) == dict
        if self.local_rank < 0:
            return l
        else:
            torch.distributed.barrier()
            process_number = torch.distributed.get_world_size()
            json.dump(l, open(f'tmp/{self.local_rank}.json', 'w'))
            torch.distributed.barrier()
            if self.local_rank == 0:
                objs = list()
                for i in range(process_number):
                    objs.append(json.load(open(f'tmp/{i}.json')))
                if type(objs[0]) == list:
                    ret = list()
                    for i in range(process_number):
                        ret.extend(objs[i])
                else:
                    ret = dict()
                    for i in range(process_number):
                        for k, v in objs.items():
                            assert k not in ret
                            ret[k] = v
            else:
                ret = None
            torch.distributed.barrier()
            return ret

    def distributed_get(self, v):
        if self.local_rank < 0:
            return v
        else:
            torch.distributed.barrier()
            if self.local_rank == 0:
                json.dump(v, open('tmp/v.json', 'w'))
            torch.distributed.barrier()
            v = json.load(open('tmp/v.json'))
            torch.distributed.barrier()
            return v
Beispiel #3
0
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples,
                            batch_first=True,
                            padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    model = model.module if hasattr(
        model,
        "module") else model  # Take care of distributed/parallel training
    model.resize_token_embeddings(len(tokenizer))

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    non_multi_model = model
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(non_multi_model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to gobal_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility
    best_perplexity = float('inf')
    for i, epoch in enumerate(train_iterator):
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])

        if args.local_rank != -1:
            train_sampler.set_epoch(epoch)

        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, labels = mask_tokens(batch, tokenizer,
                                         args) if args.mlm else (batch, batch)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs,
                            masked_lm_labels=labels) if args.mlm else model(
                                inputs, labels=labels)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

        if args.do_eval:
            file_path = Path(args.data_dir, args.eval_data_file)

            out_file_path = Path(args.data_dir,
                                 "output_" + args.eval_data_file)
            id_to_json_map = {}
            with open(file_path, encoding="utf-8") as f:
                lines = []
                i = 0

                eval_loss = 0.0
                nb_eval_steps = 0
                for line in tqdm(f, desc="Evaluating"):
                    out_json = {}
                    line = json.loads(line)
                    example_id = line.get("example_id")
                    question_text = line.get("question_text")

                    prompt_text = question_text + " " + args.sep_token + " "
                    encoded_prompt = tokenizer.encode(prompt_text,
                                                      add_special_tokens=False,
                                                      return_tensors="pt")
                    encoded_prompt = encoded_prompt.to(args.device)

                    output_sequences = non_multi_model.generate(
                        input_ids=encoded_prompt,
                        max_length=args.length + len(encoded_prompt[0]),
                        temperature=args.temperature,
                        top_k=args.k,
                        top_p=args.p,
                        repetition_penalty=args.repetition_penalty,
                        do_sample=True,
                        num_return_sequences=args.num_return_sequences,
                    )
                    if len(output_sequences.shape) > 2:
                        output_sequences.squeeze_()

                    generated_sequences = []

                    for generated_sequence_idx, generated_sequence in enumerate(
                            output_sequences):
                        # print("=== GENERATED SEQUENCE {} ===".format(generated_sequence_idx + 1))
                        # generated_sequence = output_sequences[0]

                        generated_sequence = generated_sequence.tolist()

                        # Decode text
                        text = tokenizer.decode(
                            generated_sequence,
                            clean_up_tokenization_spaces=True)

                        # Remove all text after the stop token
                        if args.stop_token:
                            text = text[:text.find(args.stop_token)]

                        # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
                        total_sequence = (prompt_text + text[len(
                            tokenizer.decode(encoded_prompt[0],
                                             clean_up_tokenization_spaces=True)
                        ):])

                        # print(total_sequence)

                        out_json["journaling_input"], out_json[
                            "reflection_output"] = total_sequence.split(
                                args.sep_token)[:2]

                        sample_dataset = GenerateTextDataset(
                            tokenizer, total_sequence, args.block_size)

                        def collate(examples: List[torch.Tensor]):
                            if tokenizer._pad_token is None:
                                return pad_sequence(examples, batch_first=True)
                            return pad_sequence(
                                examples,
                                batch_first=True,
                                padding_value=tokenizer.pad_token_id)

                        eval_sampler = SequentialSampler(sample_dataset)
                        eval_dataloader = DataLoader(sample_dataset,
                                                     sampler=eval_sampler,
                                                     batch_size=1,
                                                     collate_fn=collate)

                        model_lm = model
                        if args.n_gpu > 1:
                            model_lm = torch.nn.DataParallel(model_lm)

                        model_lm.eval()

                        for batch in eval_dataloader:
                            inputs, labels = mask_tokens(
                                batch, tokenizer,
                                args) if args.mlm else (batch, batch)
                            inputs = inputs.to(args.device)
                            labels = labels.to(args.device)

                            with torch.no_grad():
                                outputs = model_lm(inputs,
                                                   masked_lm_labels=labels
                                                   ) if args.mlm else model_lm(
                                                       inputs, labels=labels)
                                lm_loss = outputs[0]
                                example_loss = lm_loss.mean().item()
                                eval_loss += example_loss
                            nb_eval_steps += 1

                        perplexity = torch.exp(
                            torch.tensor(example_loss)).item()
                        # print(perplexity)
                        out_json["perplexity"] = perplexity

                        example_id += "-" + str(generated_sequence_idx)
                        id_to_json_map[example_id] = json.dumps(
                            out_json, ensure_ascii=False)

                    # result = {"perplexity": perplexity}

                eval_loss = eval_loss / nb_eval_steps
                total_perplexity = torch.exp(torch.tensor(eval_loss))
                logger.info(f"total_loss:: {eval_loss}")
                logger.info(
                    f"total_perplexity:: {torch.exp(torch.tensor(eval_loss))}")
                if total_perplexity < best_perplexity:
                    logger.info(
                        f"Current best epoch::: {i}, with perplexity:: {total_perplexity}"
                    )
                    best_perplexity = total_perplexity

                    with open(out_file_path, "w+",
                              encoding="utf-8") as out_file:
                        for _, out_json in id_to_json_map.items():
                            out_file.write(out_json + "\n")

                    model_to_save = model.module if hasattr(
                        model,
                        'module') else model  # Only save the model it-self

                    # If we save using the predefined names, we can load using `from_pretrained`
                    output_model_file = os.path.join(args.output_dir,
                                                     WEIGHTS_NAME)
                    output_config_file = os.path.join(args.output_dir,
                                                      CONFIG_NAME)
                    torch.save(model_to_save.state_dict(), output_model_file)
                    model_to_save.config.to_json_file(output_config_file)
                    tokenizer.save_vocabulary(args.output_dir)

    return global_step, tr_loss / global_step
Beispiel #4
0
def _mean_dev_loss(model, dev_dataloader, device, n_gpu):
    """Return the mean loss of ``model`` over every batch of ``dev_dataloader``.

    The model is switched to eval mode for the duration of the pass and its
    previous train/eval mode is restored afterwards. Batches are unpacked and
    fed to the model exactly as in the training loop of ``main``.

    Args:
        model: the (possibly DataParallel/DDP-wrapped) model returning a loss.
        dev_dataloader: iterable yielding 7-item batches of tensors (or None).
        device: device the batch tensors are moved to.
        n_gpu: number of GPUs; when > 1 the per-GPU losses are averaged.

    Returns:
        The average loss as a Python float, or ``nan`` when the loader
        yields no batch.
    """
    was_training = model.training
    model.eval()
    losses = []
    try:
        with torch.no_grad():
            for dev_batch in dev_dataloader:
                dev_batch = [
                    t.to(device) if t is not None else None for t in dev_batch]
                (input_ids, segment_ids, input_mask, lm_label_ids,
                 masked_pos, masked_weights, _) = dev_batch
                masked_lm_loss = model(input_ids, segment_ids, input_mask, lm_label_ids,
                                       masked_pos=masked_pos, masked_weights=masked_weights)
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    masked_lm_loss = masked_lm_loss.mean()
                losses.append(masked_lm_loss.item())
    finally:
        # Always hand the model back in the mode the caller left it in.
        model.train(was_training)
    if not losses:
        return float('nan')
    # Divide by the number of batches actually seen (not the last index).
    return sum(losses) / len(losses)


def main():
    """Command-line entry point: fine-tune a seq2seq model and report dev loss.

    Parses the command line, sets up the device (single process, DataParallel
    or distributed), builds config/tokenizer/model/optimizer/scheduler,
    optionally recovers model, optimizer and scheduler state from
    ``args.output_dir`` (or the model from ``args.model_recover_path``), and,
    with ``--do_train``, runs the training loop. During training a checkpoint
    is written every 100 steps and at the end of each epoch, each followed by
    a dev-set loss evaluation that is printed. With only ``--do_eval`` a
    single dev-set loss evaluation is run.

    Raises:
        ValueError: if ``gradient_accumulation_steps < 1`` or neither
            ``--do_train`` nor ``--do_eval`` is given.
        ImportError: if ``--fp16`` is requested but apex is not installed.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--data_dir", default='/data/lq/tianchi/qg/model/unilm/data_file/', type=str, required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--src_file", default='src_file/train_data.json', type=str,
                        help="The input data file name.")
    parser.add_argument("--dev_file", default='dev_data.json', type=str, help="dev file.")
    parser.add_argument("--model_type", default='unilm', type=str, required=True,
                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--model_name_or_path", default='/data/lq/tianchi/qg/model/unilm/torch-model/', type=str, required=True,
                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
    parser.add_argument("--output_dir", default='/data/lq/tianchi/qg/model/unilm/output_dir/', type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--log_dir", default='', type=str,
                        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path", default=None, type=str,
                        help="The file of pretraining optimizer.")
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--tokenizer_name", default="", type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name")

    # Other parameters
    # Parsed as int: the value is used in integer arithmetic and as a DataLoader batch size.
    parser.add_argument("--dev_batch_size", default=20, type=int, help="dev batch size.")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument('--max_position_embeddings', type=int, default=512,
                        help="max position embeddings")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=32, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=64, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing", default=0, type=float,
                        help="Label smoothing value passed to the model config.")
    parser.add_argument("--weight_decay", default=0.01, type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=777,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a', type=int, default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b', type=int, default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument('--trunc_seg', default='',
                        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument('--always_truncate_tail', action='store_true',
                        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument("--mask_prob", default=0.20, type=float,
                        help="Number of prediction is sometimes less than max_pred when sequence is short.")
    parser.add_argument("--mask_prob_eos", default=0, type=float,
                        help="Number of prediction is sometimes less than max_pred when sequence is short.")
    parser.add_argument('--max_pred', type=int, default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers", default=0, type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words', action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb', type=float, default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size', type=int, default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word', action='store_true',
                        help="Whether masking a whole word.")

    args = parser.parse_args()

    # Ignore a recover path that is unset or does not exist on disk.
    if not (args.model_recover_path and Path(args.model_recover_path).exists()):
        args.model_recover_path = None

    args.output_dir = args.output_dir.replace(
        '[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace(
        '[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    if args.log_dir:
        os.makedirs(args.log_dir, exist_ok=True)
    # Persist the effective options next to the checkpoints; the context
    # manager guarantees the file is flushed and closed.
    with open(os.path.join(args.output_dir, 'opt.json'), 'w') as opt_file:
        json.dump(vars(args), opt_file, sort_keys=True, indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    # From here on train_batch_size is the per-forward-pass (micro) batch size.
    args.train_batch_size = int(
        args.train_batch_size / args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing)
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    # Training-only objects stay None in eval-only runs so later code can
    # test for them instead of hitting an undefined name.
    train_dataloader = None
    train_sampler = None
    if args.do_train:
        print("Loading Train Dataset", args.data_dir)
        bi_uni_pipeline = [utils_seq2seq.Preprocess4Seq2seq(
            args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids, args.max_seq_length,
            mask_source_words=False, skipgram_prb=args.skipgram_prb,
            skipgram_size=args.skipgram_size,
            mask_whole_word=args.mask_whole_word, tokenizer=data_tokenizer)]

        train_file = os.path.join(
            args.data_dir, args.src_file if args.src_file else 'train.tgt')
        train_dataset = utils_seq2seq.Seq2SeqDataset(
            train_file, args.train_batch_size, data_tokenizer, args.max_seq_length,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=_batch_size, sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=utils_seq2seq.batch_list_to_batch_tensors, pin_memory=False)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    if train_dataloader is not None:
        t_total = int(len(train_dataloader) * args.num_train_epochs /
                      args.gradient_accumulation_steps)
    else:
        # Eval-only run: there is no training schedule.
        t_total = 0

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    global_step = 0
    if (recover_step is None) and (args.model_recover_path is None):
        model_recover = None
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.output_dir, "model.{0}.bin".format(recover_step)), map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(
                recover_step * t_total / args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(
                args.model_recover_path, map_location='cpu')
    model = model_class.from_pretrained(
        args.model_name_or_path, state_dict=model_recover, config=config)
    if args.local_rank == 0:
        dist.barrier()

    model.to(device)

    # Prepare optimizer: no weight decay on biases and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(
            nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in param_optimizer if any(
            nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=int(args.warmup_proportion * t_total),
        num_training_steps=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") from err
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=args.fp16_opt_level)

    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError as err:
            raise ImportError("DistributedDataParallel") from err
        model = DDP(model, device_ids=[
                    args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step)), map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)

        # NOTE: the amp state written at the end of each epoch (fp16 runs)
        # is not restored here.

        logger.info("***** Recover scheduler: %d *****", recover_step)
        scheduler_recover = torch.load(os.path.join(
            args.output_dir, "sched.{0}.bin".format(recover_step)), map_location='cpu')
        scheduler.load_state_dict(scheduler_recover)

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    #####################
    ### Load the validation (dev) data
    #####################
    print("Loading Dev Dataset", args.data_dir)
    bi_uni_pipeline = [utils_seq2seq.Preprocess4Seq2seq(args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()),
                                                        tokenizer.convert_tokens_to_ids, args.max_seq_length,
                                                        mask_source_words=False, skipgram_prb=args.skipgram_prb,
                                                        skipgram_size=args.skipgram_size,
                                                        mask_whole_word=args.mask_whole_word, tokenizer=data_tokenizer)]
    file_dev = os.path.join(args.data_dir, args.dev_file if args.dev_file else 'train.tgt')
    dev_dataset = utils_seq2seq.Seq2SeqDataset(file_dev, args.dev_batch_size, data_tokenizer, args.max_seq_length,
                                               bi_uni_pipeline=bi_uni_pipeline)
    if args.local_rank == -1:
        dev_sampler = RandomSampler(dev_dataset, replacement=False)
        _batch_size = args.dev_batch_size
    else:
        dev_sampler = DistributedSampler(dev_dataset)
        _batch_size = args.dev_batch_size // dist.get_world_size()

    dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=int(_batch_size), sampler=dev_sampler,
                                                 num_workers=args.num_workers,
                                                 collate_fn=utils_seq2seq.batch_list_to_batch_tensors,
                                                 pin_memory=False)

    #####################
    ###### Training starts
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        # Only the main process (or a non-distributed run) writes checkpoints.
        is_main_process = (args.local_rank == -1
                           or torch.distributed.get_rank() == 0)

        model.train()
        start_epoch = recover_step + 1 if recover_step else 1
        for i_epoch in trange(start_epoch, int(args.num_train_epochs) + 1, desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):

            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)(lr=X.XXXX)',
                            disable=args.local_rank not in (-1, 0))
            for step, batch in enumerate(iter_bar):
                batch = [
                    t.to(device) if t is not None else None for t in batch]
                input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, _ = batch
                masked_lm_loss = model(input_ids, segment_ids, input_mask, lm_label_ids,
                                       masked_pos=masked_pos, masked_weights=masked_weights)
                if n_gpu > 1:    # mean() to average on multi-gpu.
                    masked_lm_loss = masked_lm_loss.mean()
                loss = masked_lm_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss={:.2f})(lr={:0.2e})'.format(loss.item(), scheduler.get_lr()[0]))

                # ensure that accumlated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                # NOTE(review): clipping runs on every micro-batch, i.e. on
                # partially accumulated gradients when
                # gradient_accumulation_steps > 1.
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.max_grad_norm)

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    optimizer.zero_grad()
                    global_step += 1

                if step % 100 == 0:
                    # The path is computed on every rank because it is also
                    # printed with the dev loss below.
                    output_model_file = os.path.join(
                        args.output_dir, "model.{}.{}.bin".format(i_epoch, step))
                    # Save a trained model
                    if is_main_process:
                        logger.info(
                            "** ** * Saving fine-tuned model and optimizer  :step{}** ** * ".format(step))
                        model_to_save = model.module if hasattr(
                            model, 'module') else model  # Only save the model it-self
                        torch.save(model_to_save.state_dict(), output_model_file)
                        output_optim_file = os.path.join(
                            args.output_dir, "optim.{}.{}.bin".format(i_epoch, step))
                        torch.save(optimizer.state_dict(), output_optim_file)
                        output_sched_file = os.path.join(
                            args.output_dir, "sched.{}.{}.bin".format(i_epoch, step))
                        torch.save(scheduler.state_dict(), output_sched_file)

                        logger.info("***** CUDA.empty_cache() *****")
                        torch.cuda.empty_cache()

                    # Evaluate the current in-memory model on the dev set
                    # (the checkpoint written above is not reloaded).
                    dev_loss_fin = _mean_dev_loss(
                        model, dev_dataloader, device, n_gpu)
                    print('####   Dev loss:', dev_loss_fin, output_model_file)

            # End of epoch: save model / optimizer / (amp) / scheduler state.
            output_model_file = os.path.join(
                args.output_dir, "model.{0}.bin".format(i_epoch))
            if is_main_process:
                logger.info(
                    "** ** * Saving fine-tuned model and optimizer ** ** * ")
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model it-self
                torch.save(model_to_save.state_dict(), output_model_file)
                output_optim_file = os.path.join(
                    args.output_dir, "optim.{0}.bin".format(i_epoch))
                torch.save(optimizer.state_dict(), output_optim_file)
                if args.fp16:
                    output_amp_file = os.path.join(
                        args.output_dir, "amp.{0}.bin".format(i_epoch))
                    torch.save(amp.state_dict(), output_amp_file)
                output_sched_file = os.path.join(
                    args.output_dir, "sched.{0}.bin".format(i_epoch))
                torch.save(scheduler.state_dict(), output_sched_file)

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()

            # Dev loss of the model as it stands at the end of the epoch.
            dev_loss_fin = _mean_dev_loss(model, dev_dataloader, device, n_gpu)
            print('####   Dev loss:', dev_loss_fin, output_model_file)
    elif args.do_eval:
        # Eval-only run: report the dev loss of the loaded/recovered model once.
        dev_loss_fin = _mean_dev_loss(model, dev_dataloader, device, n_gpu)
        print('####   Dev loss:', dev_loss_fin)
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """Train the model and return the step count and the mean training loss.

    Wraps ``train_dataset`` in a ``DataLoader``, builds an AdamW optimizer
    with a linear warmup/decay schedule, optionally restores optimizer and
    scheduler state found under ``args.model_name_or_path``, and runs the
    training loop with gradient accumulation, optional apex fp16, periodic
    TensorBoard logging and periodic checkpointing.

    Args:
        args: Run options. Attributes read here include ``local_rank``,
            ``n_gpu``, ``per_gpu_train_batch_size``, ``max_steps``,
            ``num_train_epochs``, ``gradient_accumulation_steps``,
            ``weight_decay``, ``learning_rate``, ``adam_epsilon``,
            ``warmup_steps``, ``model_name_or_path``, ``fp16``,
            ``fp16_opt_level``, ``device``, ``mlm``, ``max_grad_norm``,
            ``logging_steps``, ``evaluate_during_training``, ``save_steps``
            and ``output_dir``. ``args.train_batch_size`` is written, and
            ``args.num_train_epochs`` is overwritten when ``max_steps > 0``.
        train_dataset: Dataset of token-id tensors; each batch is padded by
            the local ``collate`` function.
        model: Model to train. A ``.module`` wrapper, if present, is
            unwrapped before the token embeddings are resized.
        tokenizer: Tokenizer providing the vocabulary size and the padding
            id; it is also saved next to every checkpoint.

    Returns:
        ``(global_step, average_loss)`` where ``average_loss`` is the
        accumulated loss divided by ``global_step`` (``0.0`` when no
        optimizer step was taken).

    Raises:
        ImportError: If ``args.fp16`` is set and apex cannot be imported.
    """
    # Only the main process (local_rank -1 or 0) writes TensorBoard events.
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        # Pad every example of the batch to the longest one, using the
        # tokenizer's pad id when it defines one (pad_sequence default
        # otherwise).
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples,
                            batch_first=True,
                            padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate)

    # t_total is the number of optimizer updates the schedule is built for.
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    model = model.module if hasattr(
        model,
        "module") else model  # Take care of distributed/parallel training
    model.resize_token_embeddings(len(tokenizer))

    # Prepare optimizer and schedule (linear warmup and decay)
    # Parameters whose name contains one of these substrings get no weight
    # decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            # (the checkpoint directory is expected to end in "-<global_step>")
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)

            # The value above counts optimizer updates, but the loop below
            # consumes one unit per *batch*; each update spans
            # gradient_accumulation_steps batches, so convert the units.
            steps_trained_in_current_epoch *= args.gradient_accumulation_steps
        except ValueError:
            # The path does not end in an integer: not a resumable checkpoint.
            logger.info("  Starting fine-tuning.")

    # tr_loss accumulates the loss over the whole run; logging_loss is its
    # value at the previous logging point.
    tr_loss, logging_loss = 0.0, 0.0

    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])

        if args.local_rank != -1:
            train_sampler.set_epoch(epoch)

        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained batches if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            # With MLM the batch is masked by mask_tokens; otherwise the
            # inputs double as the labels.
            inputs, labels = mask_tokens(batch, tokenizer,
                                         args) if args.mlm else (batch, batch)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs,
                            masked_lm_labels=labels) if args.mlm else model(
                                inputs, labels=labels)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            # Apply an optimizer update once every
            # gradient_accumulation_steps batches.
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir,
                        "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    # No optimizer step may have run (for example when an epoch has fewer
    # batches than gradient_accumulation_steps); avoid dividing by zero.
    average_loss = tr_loss / global_step if global_step > 0 else 0.0
    return global_step, average_loss
Beispiel #6
0
def train(rank, args, model, model_t, train_dataset_qa, test_dataset_qa, scale_tune):
    """Train ``model``, optionally supervised by the teacher ``model_t``.

    Every call increments the module-level ``train_count``. Parameters whose
    name contains an entry of ``args.freeze_list`` are excluded from tuning.
    Parameters with ``.scale`` in their name are optimized with a fixed
    learning rate of 0.01, the rest with ``args.learning_rate``; the learning
    rate factor decays linearly from 1 over ``steps_total`` steps. After each
    epoch, rank -1/0 evaluates the model on ``test_dataset_qa`` and logs the
    results.

    Args:
        rank: Process rank; a negative value means single-process training,
            otherwise ``torch.distributed`` is used.
        args: Run options. Attributes read here: ``per_gpu_train_batch_size``,
            ``total_train_batch_size``, ``num_train_epochs``, ``freeze_list``,
            ``learning_rate``, ``n_gpu``, ``loss_cfg``, ``device``,
            ``kd_weight`` and ``supervision_weight``.
        model: The model being trained (the student when ``model_t`` is set).
        model_t: Optional teacher model. When not ``None``, a distillation
            loss on the first two outputs (if ``args.kd_weight > 0``) and a
            mean-squared-error loss between hidden states are added.
        train_dataset_qa: Training dataset.
        test_dataset_qa: Dataset passed to ``evaluate`` after every epoch.
        scale_tune: When true, only the ``.scale`` parameters are tuned, for a
            single epoch and without gradient accumulation.

    Returns:
        None.
    """
    global train_count
    train_count += 1
    world_size = 1 if rank < 0 else torch.distributed.get_world_size()

    if rank in [-1, 0]:
        printlog("Train model",train_count)
        printlog(model)

    per_gpu_train_batch_size = args.per_gpu_train_batch_size
    train_batch_size = per_gpu_train_batch_size * world_size
    # The floor division is 0 when the requested total batch is smaller than
    # one batch across all processes; clamp to 1 so the divisions and modulo
    # below stay defined.
    gradient_accumulation_steps = max(1, args.total_train_batch_size // train_batch_size)
    num_train_epochs = args.num_train_epochs

    if scale_tune:
        gradient_accumulation_steps = 1
        num_train_epochs = 1

    if rank < 0:
        #single process take all samples
        sampler = RandomSampler(train_dataset_qa)
        dataloader = DataLoader(train_dataset_qa, sampler=sampler, batch_size=train_batch_size, num_workers=4)
    else:
        #special sampler that divide samples between processes
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset_qa, rank=rank)
        dataloader = DataLoader(train_dataset_qa, sampler=sampler, batch_size=per_gpu_train_batch_size)

    steps_total = int(len(dataloader) // gradient_accumulation_steps * num_train_epochs)

    # Prepare optimizer and schedule
    freeze_list = args.freeze_list.split(',') if args.freeze_list else []
    named_params = []
    for n, p in model.named_parameters():
        # NOTE(review): the "none" test is applied to the parameter name, not
        # to the freeze_list entry - confirm this is the intended check.
        if n.lower()!="none" and any(fn in n for fn in freeze_list):
            if rank in [-1, 0]:
                logger.warning("rank {} {} param is frozen and excluded from tune".format(rank,n))
            continue
        named_params.append( (n, p) )

    # split parameters to scale and the rest
    named_params_scale = [(n, p) for n, p in named_params if '.scale' in n]
    named_params_rest = [(n, p) for n, p in named_params if '.scale' not in n]

    if scale_tune:
        #keep only scale parameters
        named_params = named_params_scale
        named_params_rest = []

    groups = []
    if named_params_scale:
        groups.append({'params': [p for n, p in named_params_scale], 'lr': 0.01})
    if named_params_rest:
        groups.append({'params': [p for n, p in named_params_rest],  'lr': args.learning_rate})

    optimizer = AdamW(
        groups,
        eps=1e-08,
        lr=args.learning_rate,
        weight_decay=0)

    def lr_lambda(current_step):
        # Linear decay of the learning-rate factor from 1 towards 0 at
        # steps_total. LambdaLR calls this at construction, so guard against
        # steps_total == 0 (fewer batches than accumulation steps).
        p = float(current_step) / float(max(1, steps_total))
        return 1 - p

    scheduler = LambdaLR(optimizer, lr_lambda)

    if rank in [-1, 0]:
        for n,p in named_params:
            printlog('param for tune',n)
        printlog("scale_tune", scale_tune )
        printlog("dataset size", len(train_dataset_qa) )
        printlog("epoches", num_train_epochs )
        printlog("per_gpu_train_batch_size", per_gpu_train_batch_size )
        printlog("n_gpu", args.n_gpu )
        printlog("world_size", world_size )
        printlog("gradient_accumulation_steps", gradient_accumulation_steps )
        printlog("total train batch size", train_batch_size * gradient_accumulation_steps )
        printlog("steps_total",steps_total )

    global_step = 0
    model.zero_grad()
    indicators = collections.defaultdict(list)

    # NOTE(review): loss_cfg is parsed but not used anywhere in this function.
    loss_cfg = dict([t.split(':') for t in args.loss_cfg.split(',')]) if args.loss_cfg else dict()

    def align_and_loss_outputs(out_s, out_t):
        # Return one mean-squared-error term per student output, pairing each
        # with a teacher output.
        if len(out_s) != len(out_t):
            #the student and teacher outputs are not aligned. try to find teacher output for each student output
            n_s, n_t = len(out_s), len(out_t)
            out_t = [out_t[(i*(n_t-1))//(n_s-1)] for i in range(n_s)]
        assert len(out_s) == len(out_t), "can not align number of outputs between student and teacher"
        assert all(s[0] == s[1] for s in zip(out_s[0].shape, out_t[0].shape)), "output shapes for student and teacher are not the same"
        return [(s - t.detach()).pow(2).mean() for s,t in zip(out_s, out_t)]

    # Timestamp of the previous log line; None until the first one is written.
    time_last = None

    # num_train_epochs may be fractional: iterate over the ceiling and stop
    # inside the last epoch once the fractional position passes it.
    for epoch in range(math.ceil(num_train_epochs)):
        indicators = collections.defaultdict(list)
        model.train()
        set_output_hidden_states(rank, model, (model_t is not None))
        utils.sync_models(rank, model)
        if model_t is not None:
            set_output_hidden_states(rank, model_t, True)
            model_t.train()
        if rank > -1:
            #set epoch to make different samples division between process for different epochs
            sampler.set_epoch(epoch)

        for step, batch in enumerate(dataloader):
            # Fractional position in training, e.g. 1.5 = halfway through the
            # second epoch.
            epoch_fp = epoch + step/len(dataloader)
            if epoch_fp > num_train_epochs:
                break

            losses = []

            inputs = get_inputs(batch, args.device)
            targets = get_targets(batch, args.device)
            outputs = model(**inputs, **targets, output_hidden_states=(model_t is not None))
            # First output is the task loss; the remaining outputs are kept
            # for the teacher-based losses below.
            losses.append(outputs[0])
            outputs = outputs[1:]

            if model_t is not None:
                with torch.no_grad():
                    outputs_t = model_t(**inputs, output_hidden_states=True)
                    hidden_t = outputs_t[2]
                    assert isinstance(hidden_t, (tuple,list)), "hidden states output is not detected right"
                    assert len(hidden_t) == model_t.config.num_hidden_layers+1, "hidden states output is not detected right"

                if args.kd_weight>0:
                    # Calculate knowledge distillation loss: cross entropy of the
                    # student log-probabilities against the teacher
                    # probabilities for each of the first two outputs
                    # (presumably the QA start/end logits - TODO confirm).
                    kd_losses = []
                    for logit_s,logit_t in zip(outputs[0:2],outputs_t[0:2]):
                        T = 1
                        prob_t = torch.nn.functional.softmax(logit_t.detach() / T, dim=1)
                        logprob_s = torch.nn.functional.log_softmax(logit_s / T, dim=1)
                        kd_losses.append( -(logprob_s * prob_t).mean() * (T * T * prob_t.shape[1]) )
                    losses.append(args.kd_weight*sum(kd_losses)/len(kd_losses))


                hidden_s = outputs[2]
                assert isinstance(hidden_s, (tuple,list)), "hidden states output is not detected right"
                assert len(hidden_s) == model.config.num_hidden_layers+1, "hidden states output is not detected right"

                sw_losses = align_and_loss_outputs(hidden_s,hidden_t)

                losses.extend([args.supervision_weight*l for l in sw_losses])

            #average over batch
            losses = [l.mean() for l in losses]

            l = sum(losses)/len(losses)
            indicators['loss'].append(l.item())
            indicators['ll'].append([lll.item() for lll in losses])

            (l/gradient_accumulation_steps).backward()

            del l

            if (step + 1) % gradient_accumulation_steps == 0:
                global_step += 1

                utils.sync_grads(rank, named_params, report_no_grad_params=(global_step==1))
                torch.nn.utils.clip_grad_norm_([p for n, p in named_params], 1)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()


                if global_step % 50 == 0:
                    # Log metrics
                    lrp = " ".join(['{:.2f}'.format(t) for t in np.log10(scheduler.get_last_lr())])
                    str_out = "{} ep {:.2f} lrp {}".format(train_count, epoch_fp, lrp)

                    for k,v in indicators.items():
                        v = np.array(v)
                        if len(v.shape)==1:
                            v = v[:,None]

                        if rank>-1:
                            #sync indicators
                            vt = torch.tensor(v).to(args.device)
                            torch.distributed.all_reduce(vt, op=torch.distributed.ReduceOp.SUM)
                            v = vt.cpu().numpy() / float(world_size)

                        str_out += " {} {}".format(k," ".join(["{:.3f}".format(t) for t in v.mean(0)]))


                    if time_last is not None:
                        #estimate processing times
                        dt_iter = (time.time() - time_last) / len(indicators['loss'])
                        dt_ep = dt_iter * len(dataloader)
                        str_out += " it {:.1f}s".format(dt_iter)
                        str_out += " ep {:.1f}m".format(dt_ep / (60))
                        str_out += " eta {:.1f}h".format(dt_ep * (num_train_epochs - epoch_fp) / (60 * 60))
                    time_last = time.time()

                    # Start a fresh window of indicators for the next log line.
                    indicators = collections.defaultdict(list)
                    if rank in [-1, 0]:
                        logger.info(str_out)

        if rank in [-1, 0]:
            check_point_name = 'checkpoint-{:02}'.format(train_count)
            check_point_name = check_point_name + '-{:02}'.format(epoch + 1)
            model.eval()
            set_output_hidden_states(rank, model, False)
            result_s = evaluate(args, model, test_dataset_qa)
            for k,v in result_s.items():
                logger.info("{} {} {}".format(check_point_name, k, v))
        if rank>-1:
            # Make the other ranks wait until rank 0 has finished evaluating.
            torch.distributed.barrier()
Beispiel #7
0
def train(local_rank, config):
    """Pre-train a GPT-2 language model on the configured corpora.

    Intended to run once per process. The function sets up the device (and
    the NCCL process group when ``config.world_size > 1`` on CUDA), builds
    the tokenizer, model, AdamW optimizer and linear warmup schedule,
    optionally restores them from ``config.checkpoint_path``, then iterates
    over epochs and groups of training files. Global rank 0 periodically
    evaluates on the dev data, writes logs, and saves the latest and best
    checkpoints.

    Args:
        local_rank: Index of this process on its node; combined with
            ``config.node_rank`` and ``config.n_gpus`` into the global rank.
        config: Run configuration. Attributes read here include
            ``node_rank``, ``n_gpus``, ``world_size``, ``corpora``,
            ``small_data``, ``eval_batch_size``, ``model_config_filepath``,
            ``checkpoint_path``, ``resume_training``, ``use_amp``,
            ``init_lr``, ``l2_penalty``, ``n_warmup_steps``,
            ``n_training_steps``, ``model_size``, ``seed``,
            ``filename_note``, ``enable_log``, ``n_epochs``,
            ``n_train_files_per_group``, ``batch_size``, ``n_accum_steps``,
            ``max_grad_norm``, ``check_loss_after_n_step``,
            ``validate_after_n_step`` and ``save_model``.

    Returns:
        None.
    """
    global_rank = config.node_rank * config.n_gpus + local_rank
    print(f"local rank: {[local_rank]}, global_rank: {[global_rank]}")

    # multi-gpu init
    if torch.cuda.is_available():
        if config.world_size > 1:
            dist.init_process_group(
                backend='nccl',
                init_method='env://',
                world_size=config.world_size,
                rank=global_rank
            )
            torch.cuda.set_device(local_rank)
            DEVICE = torch.device("cuda", local_rank)
        else:
            DEVICE = torch.device("cuda")
    else:
        DEVICE = torch.device("cpu")

    # build tokenizer
    tokenizer = T5Tokenizer(
        vocab_file="../data/tokenizer/google_sp.model",
        bos_token="<s>",
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="[PAD]",
        cls_token="[CLS]",
        sep_token="[SEP]",
        mask_token="[MASK]",
        extra_ids=0,
        additional_special_tokens=(),
        do_lower_case=True
    )

    # build data source and reporters
    trn_reporter = StatisticsReporter()
    dev_reporter = StatisticsReporter()

    # get data filepaths
    train_filepaths = []
    dev_filepaths = []
    for corpus in config.corpora:
        if corpus == "jp_cc100":
            from corpus.jp_cc100.config import Config
            corpus_config = Config()

            # The file at this position of the sorted listing is held out as
            # dev data; every other file is used for training.
            dev_file_idx = 42

            corpus_filepaths = sorted(list(filter(
                lambda x: x.endswith(".txt"),
                os.listdir(corpus_config.doc_data_dir)
            )))
            for file_idx, filepath in enumerate(corpus_filepaths):
                if file_idx == dev_file_idx:
                    dev_filepaths.append(f"{corpus_config.doc_data_dir}/{corpus_filepaths[file_idx]}")
                else:
                    train_filepaths.append(f"{corpus_config.doc_data_dir}/{corpus_filepaths[file_idx]}")

            if config.small_data:
                train_filepaths = train_filepaths[:2]

    # load dev data (only global rank 0 evaluates)
    if global_rank == 0:
        dev_docs = []
        for dev_filepath in dev_filepaths:
            dev_docs += load_docs_from_filepath(dev_filepath, tokenizer)
        dev_docs = dev_docs[:10000]
        mp_print("----- Loading dev data -----", global_rank)
        dev_data_source = DataSource(config, tokenizer, dev_docs, "dev", randomize=False)
        mp_print(str(dev_data_source.statistics), global_rank)
        dev_dataloader = torch.utils.data.DataLoader(
            dev_data_source,
            batch_size=config.eval_batch_size,
            num_workers=0,
            collate_fn=collate_fn,
            pin_memory=False
        )

    # build model
    model_config = PretrainedConfig.from_json_file(config.model_config_filepath)
    model = GPT2LMHeadModel(model_config)
    model = model.to(DEVICE)

    # load model from checkpoint
    if config.checkpoint_path:
        mp_print("----- Checkpoint loaded -----", global_rank)
        mp_print("checkpoint path: {}".format(config.checkpoint_path), global_rank)
        checkpoint = torch.load(config.checkpoint_path, map_location=DEVICE)
        mp_print("loading model state dict...", global_rank)
        model.load_state_dict(checkpoint["model"])
        model.tie_weights()  # NOTE: don't forget to tie weights after loading weights

    # use mixed precision
    if config.use_amp:
        scaler = amp.GradScaler()

    # use multi gpus
    if config.world_size > 1:
        model = DDP(
            model,
            device_ids=[local_rank],
            find_unused_parameters=True
        )

    # build optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.init_lr,
        weight_decay=config.l2_penalty
    )

    # build lr scheduler
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=config.n_warmup_steps,
        num_training_steps=config.n_training_steps,
    )

    # init environment or load from checkpoint
    if config.checkpoint_path:
        if config.resume_training:
            mp_print("loading optimizer state dict...", global_rank)
            optimizer.load_state_dict(checkpoint["optimizer"])
            mp_print("recovering lr scheduler...", global_rank)
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            mp_print("recovering others...", global_rank)
            n_step = checkpoint["n_step"]
            start_n_epoch = checkpoint["n_epoch"]
            start_train_file_idx = checkpoint["start_train_file_idx"]
            # The checkpoint is a dict, so the value must be read with a key
            # lookup; getattr() would always fall back to the default.
            best_ppl = checkpoint.get("best_ppl", float("inf"))
        else:
            n_step = 0
            start_n_epoch = 0
            start_train_file_idx = 0
            best_ppl = float("inf")
        OUTPUT_FILEID = checkpoint["output_fileid"]
        del checkpoint
    else:
        n_step = 0
        start_n_epoch = 0
        start_train_file_idx = 0
        best_ppl = float("inf")

        # names
        OUTPUT_FILEID = "gpt2-ja-{}.seed_{}.{}".format(
            config.model_size,
            config.seed,
            time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
        )
    if config.filename_note:
        OUTPUT_FILEID += f".{config.filename_note}"

    # define logger
    def mlog(s):
        # On global rank 0: append the message to the run's log file when
        # logging is enabled, then pass it to mp_print.
        if global_rank == 0:
            if config.enable_log:
                if not os.path.exists("../log/pretrain"):
                    os.makedirs("../log/pretrain")
                with open(f"../log/pretrain/{OUTPUT_FILEID}.log", "a+", encoding="utf-8") as log_f:
                    log_f.write(s+"\n")
            mp_print(s, global_rank)
    if config.enable_log:
        if global_rank == 0:
            tb_writer = SummaryWriter(
                log_dir=f"../log/pretrain/{OUTPUT_FILEID}",
                max_queue=5
            )

    # log hyper parameters
    start_time = time.time()
    mlog("----- Hyper-parameters -----")
    for k, v in sorted(dict(config.__dict__).items()):
        mlog("{}: {}".format(k, v))

    for epoch_idx in range(start_n_epoch, config.n_epochs):
        for train_file_idx in range(start_train_file_idx, len(train_filepaths), config.n_train_files_per_group):

            group_train_filepaths = train_filepaths[train_file_idx:train_file_idx+config.n_train_files_per_group]

            # Load the files of this group in parallel, one worker per file.
            with mp.Pool(processes=config.n_train_files_per_group) as pool:
                group_train_docs = pool.starmap(
                    load_docs_from_filepath,
                    [(train_filepath, tokenizer) for train_filepath in group_train_filepaths]
                )
                train_docs = [doc for docs in group_train_docs for doc in docs]

            train_data_source = DataSource(config, tokenizer, train_docs, "train", randomize=True)
            mp_print(str(train_data_source.statistics), global_rank)
            # single gpu or cpu
            if config.world_size == 1 or not torch.cuda.is_available():
                train_data_sampler = RandomSampler(
                    train_data_source,
                    replacement=False
                )
                train_dataloader = torch.utils.data.DataLoader(
                    train_data_source,
                    batch_size=config.batch_size,
                    sampler=train_data_sampler,
                    num_workers=0,
                    collate_fn=collate_fn,
                    pin_memory=True
                )
            # multi gpus
            else:
                train_data_sampler = DistributedSampler(
                    train_data_source,
                    num_replicas=config.world_size,
                    rank=global_rank
                )
                train_dataloader = torch.utils.data.DataLoader(
                    train_data_source,
                    batch_size=config.batch_size,
                    sampler=train_data_sampler,
                    num_workers=0,
                    collate_fn=collate_fn,
                    pin_memory=False
                )

            if isinstance(train_data_sampler, DistributedSampler):
                train_data_sampler.set_epoch(epoch_idx)

            for batch_data in train_dataloader:
                n_step += 1

                # stop if reaches the maximum training step
                if n_step >= config.n_training_steps:
                    break

                # forward
                model.train()
                if config.use_amp:
                    with amp.autocast():
                        loss, ppl = forward_step(model, tokenizer, batch_data)
                else:
                    loss, ppl = forward_step(model, tokenizer, batch_data)

                # update statistics
                trn_reporter.update_data({"ppl": ppl.item(), "loss": loss.item()})

                # backward
                loss /= config.n_accum_steps
                if config.use_amp:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
                del loss

                # Apply an optimizer update every n_accum_steps batches.
                if n_step % config.n_accum_steps == 0:
                    # clip gradient
                    if config.max_grad_norm > 0.0:
                        if config.use_amp:
                            # Gradients must be unscaled before clipping.
                            scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)

                    # update model parameters
                    if config.use_amp:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()

                    # zero gradients
                    optimizer.zero_grad()

                # check loss
                if n_step > 0 and n_step % config.check_loss_after_n_step == 0:
                    lr = list(lr_scheduler.optimizer.param_groups)[0]["lr"]
                    log_s = f"{time.time()-start_time:.2f}s Epoch {epoch_idx}, step {n_step}, lr {lr:.5g} - "
                    log_s += trn_reporter.to_string()
                    mlog(log_s)

                    if config.enable_log and global_rank == 0:
                        for k, v in trn_reporter.items():
                            tb_writer.add_scalar(f"{k}/train", np.mean(v), n_step)

                    trn_reporter.clear()

                # evaluation on dev dataset
                if global_rank == 0 and n_step > 0 and n_step % config.validate_after_n_step == 0:

                    # forward
                    with torch.no_grad():
                        model.eval()

                        # use only 1 gpu for evaluation in multi-gpu situation
                        if config.world_size > 1:
                            eval_model = model.module
                        else:
                            eval_model = model

                        for eval_batch_idx, eval_batch_data in enumerate(dev_dataloader):
                            if config.use_amp:
                                with amp.autocast():
                                    loss, ppl = forward_step(eval_model, tokenizer, eval_batch_data)
                            else:
                                loss, ppl = forward_step(eval_model, tokenizer, eval_batch_data)
                            dev_reporter.update_data({"ppl": ppl.item(), "loss": loss.item()})

                            if eval_batch_idx == len(dev_dataloader) - 1:
                                break

                    log_s = f"\n<Dev> - {time.time()-start_time:.3f}s - "
                    log_s += dev_reporter.to_string()
                    mlog(log_s)

                    # Save model if it has better monitor measurement
                    if config.save_model:
                        if not os.path.exists("../data/model/pretrain"):
                            os.makedirs("../data/model/pretrain")

                        model_to_save = model.module if hasattr(model, 'module') else model

                        # Update the running best before building the
                        # checkpoint so the stored "best_ppl" already accounts
                        # for this evaluation when training is resumed.
                        cur_ppl = dev_reporter.get_value("ppl")
                        is_best = cur_ppl < best_ppl
                        if is_best:
                            best_ppl = cur_ppl

                        # save current model
                        checkpoint = {
                            "model": model_to_save.state_dict(),
                            "optimizer": optimizer.state_dict(),
                            "lr_scheduler": lr_scheduler.state_dict(),
                            "n_epoch": epoch_idx,
                            "n_step": n_step,
                            "start_train_file_idx": train_file_idx,
                            "output_fileid": OUTPUT_FILEID,
                            "best_ppl": best_ppl
                        }
                        torch.save(
                            checkpoint,
                            f"../data/model/pretrain/{OUTPUT_FILEID}.checkpoint"
                        )
                        mlog(f"checkpoint saved to data/model/pretrain/{OUTPUT_FILEID}.checkpoint")

                        # save best model
                        if is_best:
                            torch.save(
                                checkpoint,
                                f"../data/model/pretrain/{OUTPUT_FILEID}.best.checkpoint"
                            )
                            mlog(f"best checkpoint saved to data/model/pretrain/{OUTPUT_FILEID}.best.checkpoint")

                    if config.enable_log:
                        for k, v in dev_reporter.items():
                            tb_writer.add_scalar(f"{k}/dev", np.mean(v), n_step)

                    dev_reporter.clear()

                # decay learning rate
                # The schedule is a function of the step count only, so it is
                # advanced without an argument; a value passed to step() would
                # be taken as the deprecated ``epoch`` index.
                lr_scheduler.step()

            # The step budget is exhausted: do not load further file groups.
            if n_step >= config.n_training_steps:
                break

        # reset starting training file index for every epoch (it might be set to a larger value if resuming from a checkpoint)
        start_train_file_idx = 0

        # The step budget is exhausted: do not start further epochs.
        if n_step >= config.n_training_steps:
            break
Beispiel #8
0
def main():
    """Run VisDial BERT training driven by the options from ``process_args()``.

    The function validates option combinations, writes the resolved options
    to ``<output_dir>/opt.json``, configures logging (file + stdout), selects
    the device (optionally initialising the NCCL process group), seeds the
    RNGs, builds the tokenizer / preprocessing pipelines / data loader,
    creates the model (from a pretrained BERT, from scratch, or from
    ``args.model_recover_path``) and the optimizer, and then trains for
    ``args.num_train_epochs`` epochs. After every epoch the model weights are
    saved to ``<output_dir>/model.<epoch>.<mean loss>.bin`` by the process
    whose ``global_rank`` is -1 or 0.

    Raises:
        AssertionError: if the command-line options are inconsistent.
        ValueError: if ``args.gradient_accumulation_steps`` is smaller than 1.
        ImportError: if fp16 training is requested but the fp16 optimizer
            dependencies cannot be imported.
    """
    args = process_args()
    if args.loss_type == 'mlm':
        assert args.neg_num == 0 and args.multiple_neg == 0
    elif args.loss_type == 'nsp':
        assert int(args.bi_prob) == 1 and args.max_pred == 0 and args.neg_num > 0

    if args.adaptive_weight == 1:
        assert args.neg_num > 1
    if args.add_boundary == 1:
        assert args.inc_full_hist

    if args.world_size > 1:
        print('global_rank: {}, local rank: {}'.format(args.global_rank, args.local_rank))

    # Input format: [CLS] img [SEP] hist [SEP_0] ques [SEP_1] ans [SEP]
    args.max_seq_length = args.len_vis_input + 2 + args.max_len_hist_ques + 2 + args.max_len_ans + 1

    args.mask_image_regions = (args.vis_mask_prob > 0)  # whether to mask out image regions
    args.dist_url = args.dist_url.replace('[PT_OUTPUT_DIR]', args.output_dir)

    # arguments inspection
    assert args.enable_butd, 'only support region attn! featmap attn deprecated'

    if args.enable_butd:
        if args.visdial_v == '1.0':
            assert (args.len_vis_input == 36) or (args.len_vis_input == 0)
        elif args.visdial_v == '0.9':
            if (args.len_vis_input == 100):
                args.region_bbox_file = os.path.join(args.image_root, args.region_bbox_file)
                args.region_det_file_prefix = os.path.join(args.image_root,
                                                           args.region_det_file_prefix) if args.dataset in (
                    'cc', 'coco') and args.region_det_file_prefix != '' else ''

    # output config
    os.makedirs(args.output_dir, exist_ok=True)
    # Use a context manager so the options file is flushed and closed right away.
    with open(os.path.join(args.output_dir, 'opt.json'), 'w') as opt_file:
        json.dump(args.__dict__, opt_file, sort_keys=True, indent=2)

    logging.basicConfig(
        filename=os.path.join(args.output_dir, args.log_file),
        filemode='w',
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger = logging.getLogger(__name__)
    # Mirror the log records on stdout in addition to the log file.
    ch = logging.StreamHandler(sys.stdout)
    ch.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(name)s -   %(message)s'))
    ch.setLevel(logging.INFO)
    logger.addHandler(ch)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl', init_method=args.dist_url,
                                             world_size=args.world_size, rank=args.global_rank)
    logger.info('Arguments: %s\n' % (' '.join(sys.argv[:])))
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    # From here on train_batch_size is the per-forward-pass batch size.
    args.train_batch_size = int(
        args.train_batch_size / args.gradient_accumulation_steps)

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # plotting loss, optional
    if args.enable_visdom:
        import visdom
        vis = visdom.Visdom(port=args.visdom_port, env=args.output_dir)
        vis_window = {'iter': None, 'score': None}

    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank))
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
    assert args.do_train
    logger.info('Max seq length: %d, batch size: %d\n' % (args.max_seq_length, args.train_batch_size))

    def _build_preprocessor(mode):
        # Build one preprocessing pipeline; the pipelines differ only in `mode`.
        return Preprocess4TrainVisdial(args.max_pred, args.mask_prob,
                                       list(tokenizer.vocab.keys()),
                                       tokenizer.convert_tokens_to_ids, args.max_seq_length,
                                       new_segment_ids=args.new_segment_ids,
                                       truncate_config={'len_vis_input': args.len_vis_input,
                                                        'max_len_hist_ques': args.max_len_hist_ques,
                                                        'max_len_ans': args.max_len_ans},
                                       mask_image_regions=args.mask_image_regions,
                                       mode=mode, vis_mask_prob=args.vis_mask_prob,
                                       region_bbox_file=args.region_bbox_file,
                                       region_det_file_prefix=args.region_det_file_prefix,
                                       image_features_hdfpath=args.image_features_hdfpath,
                                       visdial_v=args.visdial_v, pad_hist=args.pad_hist,
                                       finetune=args.finetune,
                                       only_mask_ans=args.only_mask_ans,
                                       add_boundary=args.add_boundary,
                                       only_qa=args.only_qa)

    bi_uni_pipeline = [_build_preprocessor("s2s"), _build_preprocessor("bi")]

    def _build_train_loader():
        # Build the dataset, its sampler and the data loader from the current args.
        dataset = VisdialDataset(
            args.src_file, args.train_batch_size, data_tokenizer, use_num_imgs=args.use_num_imgs,
            bi_uni_pipeline=bi_uni_pipeline, s2s_prob=args.s2s_prob, bi_prob=args.bi_prob,
            is_train=args.do_train, neg_num=args.neg_num, inc_gt_rel=args.inc_gt_rel,
            inc_full_hist=args.inc_full_hist, just_for_pretrain=args.just_for_pretrain,
            sub_sample=args.sub_sample)

        if args.world_size == 1:
            sampler = RandomSampler(dataset, replacement=False)
        else:
            sampler = DistributedSampler(dataset)

        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.train_batch_size, sampler=sampler,
                                             num_workers=args.num_workers,
                                             collate_fn=batch_list_to_batch_tensors, pin_memory=True)
        return dataset, sampler, loader

    train_dataset, train_sampler, train_dataloader = _build_train_loader()

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    t_total = int(len(train_dataloader) * args.num_train_epochs * 1. /
                  args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    relax_projection = 4 if args.relax_projection else 0
    task_idx_proj = 3 if args.tasks == 'img2txt' else 0
    mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[PAD]"])  # index in BERT vocab: 103, 102, 0

    recovering = args.model_recover_path is not None
    if not recovering:
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        assert args.scst == False, 'must init from maximum likelihood training'
        init_state_dict = {} if args.from_scratch else None
    else:
        logger.info("***** Recover model: %s *****",
                    args.model_recover_path)
        # Load onto the CPU first; the model is moved to `device` further below.
        # This avoids every process deserialising onto the GPU the checkpoint
        # was saved from.
        init_state_dict = torch.load(args.model_recover_path, map_location='cpu')
    global_step = 0

    model = BertForPreTrainingLossMask.from_pretrained(
        args.bert_model, state_dict=init_state_dict, num_labels=cls_num_labels,
        type_vocab_size=type_vocab_size, relax_projection=relax_projection,
        config_path=args.config_path, task_idx=task_idx_proj,
        max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing,
        fp32_embedding=args.fp32_embedding,
        cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank),
        drop_prob=args.drop_prob, enable_butd=args.enable_butd,
        len_vis_input=args.len_vis_input, visdial_v=args.visdial_v, loss_type=args.loss_type,
        neg_num=args.neg_num, adaptive_weight=args.adaptive_weight, add_attn_fuse=args.add_attn_fuse,
        no_h0=args.no_h0, no_vision=args.no_vision)

    if recovering:
        # Release the loaded checkpoint once its weights are inside the model.
        del init_state_dict
        torch.cuda.empty_cache()

    if args.fp16:
        model.half()
        # cnn.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    if args.local_rank != -1:
        from torch.nn.parallel import DistributedDataParallel as DDP
        model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
    elif n_gpu > 1:
        model = DataParallelImbalance(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    # Parameters whose name contains one of these substrings get no weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(
            nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(
            nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(
                optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(
                optimizer, static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             schedule=args.sche_mode,
                             t_total=t_total)

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)
        logger.info("  Loader length = %d", len(train_dataloader))

        model.train()

        start_epoch = 1
        logger.info("Begin training from epoch = %d", start_epoch)
        t0 = time.time()
        for i_epoch in trange(start_epoch, args.num_train_epochs + 1, desc="Epoch"):
            if args.multiple_neg and i_epoch > 1:
                # Rebuild dataset, sampler and loader for every epoch after the first.
                train_dataset, train_sampler, train_dataloader = _build_train_loader()
            # Only DistributedSampler has set_epoch(); the sampler type is picked
            # from world_size, so check the sampler itself rather than local_rank
            # to avoid an AttributeError on RandomSampler.
            if args.local_rank >= 0 and isinstance(train_sampler, DistributedSampler):
                train_sampler.set_epoch(i_epoch - 1)
            iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)')
            nbatches = len(train_dataloader)
            losses = []
            pretext_loss = []
            mlm_losses = []
            nsp_losses = []
            for step, batch in enumerate(iter_bar):
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, is_next, \
                task_idx, vis_masked_pos, img, vis_pe = batch

                if args.fp16:
                    img = img.half()
                    vis_pe = vis_pe.half()

                if args.enable_butd:
                    conv_feats = img.data  # Bx100x2048
                    vis_pe = vis_pe.data

                loss_tuple = model(conv_feats, vis_pe, input_ids, segment_ids,
                                   input_mask, lm_label_ids, is_next, masked_pos=masked_pos,
                                   masked_weights=masked_weights, task_idx=task_idx,
                                   vis_masked_pos=vis_masked_pos, mask_image_regions=args.mask_image_regions,
                                   drop_worst_ratio=args.max_drop_worst_ratio if i_epoch > args.drop_after else 0)

                # disable pretext_loss_deprecated for now
                masked_lm_loss, pretext_loss_deprecated, nsp_loss = loss_tuple
                if n_gpu > 1:  # mean() to average on multi-gpu. For dist, this is done through gradient addition.
                    masked_lm_loss = masked_lm_loss.mean()
                    pretext_loss_deprecated = pretext_loss_deprecated.mean()
                    nsp_loss = nsp_loss.mean()
                loss = masked_lm_loss + pretext_loss_deprecated + nsp_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())
                losses.append(loss.item())
                mlm_losses.append(masked_lm_loss.item())
                pretext_loss.append(pretext_loss_deprecated.item())
                nsp_losses.append(nsp_loss.item())

                # Log running means roughly ten times per epoch.
                if step % max(1, nbatches // 10) == 0:
                    logger.info(
                        "Epoch {}, Iter {}, Loss {:.4f}, MLM {:.4f}, NSP {:.4f}, Elapse time {:.2f}\n".format(
                            i_epoch, step, np.mean(losses), np.mean(mlm_losses), np.mean(nsp_losses), time.time() - t0))

                if args.enable_visdom:
                    if vis_window['iter'] is None:
                        vis_window['iter'] = vis.line(
                            X=np.tile(np.arange((i_epoch - 1) * nbatches + step,
                                                (i_epoch - 1) * nbatches + step + 1), (1, 1)).T,
                            Y=np.column_stack((np.asarray([np.mean(losses)]),)),
                            opts=dict(title='Training Loss',
                                      xlabel='Training Iteration',
                                      ylabel='Loss',
                                      legend=['total'])
                        )
                    else:
                        vis.line(
                            X=np.tile(np.arange((i_epoch - 1) * nbatches + step,
                                                (i_epoch - 1) * nbatches + step + 1), (1, 1)).T,
                            Y=np.column_stack((np.asarray([np.mean(losses)]),)),
                            opts=dict(title='Training Loss',
                                      xlabel='Training Iteration',
                                      ylabel='Loss',
                                      legend=['total']),
                            win=vis_window['iter'],
                            update='append'
                        )

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                # Apply an optimizer update once every gradient_accumulation_steps batches.
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                                   warmup_linear(global_step / t_total,
                                                 args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Save a trained model
            logger.info(
                "** ** * Saving fine-tuned model and optimizer ** ** * ")
            model_to_save = model.module if hasattr(
                model, 'module') else model  # Only save the model it-self
            output_model_file = os.path.join(
                args.output_dir, "model.%d.%.3f.bin" % (i_epoch, np.mean(losses)))
            if args.global_rank in (-1, 0):  # save model if the first device or no dist
                torch.save(copy.deepcopy(model_to_save).cpu().state_dict(), output_model_file)
                logger.info("Save model to %s", output_model_file)
            logger.info("Finish training epoch %d, avg loss: %.2f and takes %.2f seconds" % (
                i_epoch, np.mean(losses), time.time() - t0))
            logger.info("***** CUDA.empty_cache() *****")
            torch.cuda.empty_cache()

            # Keep all processes in step before starting the next epoch.
            if args.world_size > 1:
                torch.distributed.barrier()
Beispiel #9
0
def main():
    parser = argparse.ArgumentParser()

    # General
    parser.add_argument(
        "--bert_model",
        default="bert-base-cased",
        type=str,
        help=
        "Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default='tmp',
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_file",
        default="training.log",
        type=str,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument(
        "--do_train",
        action='store_true',
        help="Whether to run training. This should ALWAYS be set to True.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=30,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--global_rank",
                        type=int,
                        default=-1,
                        help="global_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 32-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--len_vis_input',
                        type=int,
                        default=100,
                        help="The length of visual token input")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=20,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='b',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=3,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=4,
                        type=int,
                        help="Number of workers for the data loader.")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")

    # Others for VLP
    parser.add_argument(
        "--src_file",
        default=['/mnt/dat/COCO/annotations/dataset_coco.json'],
        type=str,
        nargs='+',
        help="The input data file name.")
    parser.add_argument('--enable_visdom', action='store_true')
    parser.add_argument('--visdom_port', type=int, default=8888)
    # parser.add_argument('--resnet_model', type=str, default='imagenet_weights/resnet101.pth')
    parser.add_argument('--image_root',
                        type=str,
                        default='/mnt/dat/COCO/images')
    parser.add_argument('--dataset',
                        default='coco',
                        type=str,
                        help='coco | flickr30k | cc')
    parser.add_argument('--split',
                        type=str,
                        nargs='+',
                        default=['train', 'restval'])

    parser.add_argument('--world_size',
                        default=1,
                        type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url',
                        default='file://[PT_OUTPUT_DIR]/nonexistent_file',
                        type=str,
                        help='url used to set up distributed training')
    parser.add_argument(
        '--file_valid_jpgs',
        default='/mnt/dat/COCO/annotations/coco_valid_jpgs.json',
        type=str)
    parser.add_argument('--sche_mode',
                        default='warmup_linear',
                        type=str,
                        help="warmup_linear | warmup_constant | warmup_cosine")
    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--use_num_imgs', default=-1, type=int)
    parser.add_argument('--vis_mask_prob', default=0, type=float)
    parser.add_argument('--max_drop_worst_ratio', default=0, type=float)
    parser.add_argument('--drop_after', default=6, type=int)

    parser.add_argument(
        '--s2s_prob',
        default=1,
        type=float,
        help="Percentage of examples that are bi-uni-directional LM (seq2seq)."
    )
    parser.add_argument(
        '--bi_prob',
        default=0,
        type=float,
        help="Percentage of examples that are bidirectional LM.")
    parser.add_argument('--enable_butd',
                        action='store_true',
                        help='set to take in region features')
    parser.add_argument(
        '--region_bbox_file',
        default=
        'coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5',
        type=str)
    parser.add_argument(
        '--region_det_file_prefix',
        default=
        'feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval',
        type=str)
    parser.add_argument('--tasks', default='img2txt', help='img2txt | vqa2')
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--scst',
                        action='store_true',
                        help='Self-critical sequence training')

    args = parser.parse_args()

    print('global_rank: {}, local rank: {}'.format(args.global_rank,
                                                   args.local_rank))

    args.max_seq_length = args.max_len_b + args.len_vis_input + 3  # +3 for 2x[SEP] and [CLS]
    args.mask_image_regions = (args.vis_mask_prob > 0
                               )  # whether to mask out image regions
    args.dist_url = args.dist_url.replace('[PT_OUTPUT_DIR]', args.output_dir)

    # arguments inspection
    assert (args.tasks in ('img2txt', 'vqa2'))
    assert args.enable_butd == True, 'only support region attn! featmap attn deprecated'
    assert (
        not args.scst) or args.dataset == 'coco', 'scst support on coco only!'
    if args.scst:
        assert args.dataset == 'coco', 'scst support on coco only!'
        assert args.max_pred == 0 and args.mask_prob == 0, 'no mask for scst!'
        rl_crit = RewardCriterion()

    if args.enable_butd:
        assert (args.len_vis_input == 100)
        args.region_bbox_file = os.path.join(args.image_root,
                                             args.region_bbox_file)
        args.region_det_file_prefix = os.path.join(
            args.image_root, args.region_det_file_prefix) if args.dataset in (
                'cc', 'coco') and args.region_det_file_prefix != '' else ''

    # output config
    os.makedirs(args.output_dir, exist_ok=True)
    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)

    logging.basicConfig(
        filename=os.path.join(args.output_dir, args.log_file),
        filemode='w',
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger = logging.getLogger(__name__)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://localhost:10001',  #args.dist_url,
            world_size=args.world_size,
            rank=args.global_rank)
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # plotting loss, optional
    if args.enable_visdom:
        import visdom
        vis = visdom.Visdom(port=args.visdom_port, env=args.output_dir)
        vis_window = {'iter': None, 'score': None}

    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model,
        do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir +
        '/.pretrained_model_{}'.format(args.global_rank))
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer(
    ) if args.tokenized_input else tokenizer

    if args.do_train:
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_image_regions=args.mask_image_regions,
                mode="s2s",
                len_vis_input=args.len_vis_input,
                vis_mask_prob=args.vis_mask_prob,
                enable_butd=args.enable_butd,
                region_bbox_file=args.region_bbox_file,
                region_det_file_prefix=args.region_det_file_prefix,
                local_rank=args.local_rank,
                load_vqa_ann=(args.tasks == 'vqa2'))
        ]
        bi_uni_pipeline.append(
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_image_regions=args.mask_image_regions,
                mode="bi",
                len_vis_input=args.len_vis_input,
                vis_mask_prob=args.vis_mask_prob,
                enable_butd=args.enable_butd,
                region_bbox_file=args.region_bbox_file,
                region_det_file_prefix=args.region_det_file_prefix,
                local_rank=args.local_rank,
                load_vqa_ann=(args.tasks == 'vqa2')))

        train_dataset = seq2seq_loader.Img2txtDataset(
            args.src_file,
            args.image_root,
            args.split,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_valid_jpgs=args.file_valid_jpgs,
            bi_uni_pipeline=bi_uni_pipeline,
            use_num_imgs=args.use_num_imgs,
            s2s_prob=args.s2s_prob,
            bi_prob=args.bi_prob,
            enable_butd=args.enable_butd,
            tasks=args.tasks)

        if args.world_size == 1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
        else:
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.train_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=batch_list_to_batch_tensors,
            pin_memory=True)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    t_total = int(
        len(train_dataloader) * args.num_train_epochs * 1. /
        args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    relax_projection = 4 if args.relax_projection else 0
    task_idx_proj = 3 if args.tasks == 'img2txt' else 0
    mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[PAD]"])  # index in BERT vocab: 103, 102, 0

    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        assert args.scst == False, 'must init from maximum likelihood training'
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size,
            relax_projection=relax_projection,
            config_path=args.config_path,
            task_idx=task_idx_proj,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            cache_dir=args.output_dir +
            '/.pretrained_model_{}'.format(args.global_rank),
            drop_prob=args.drop_prob,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input,
            tasks=args.tasks)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(
                os.path.join(args.output_dir,
                             "model.{0}.bin".format(recover_step)))
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total * 1. /
                                     args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path)
            global_step = 0
        if not args.scst:
            model = BertForPreTrainingLossMask.from_pretrained(
                args.bert_model,
                state_dict=model_recover,
                num_labels=cls_num_labels,
                type_vocab_size=type_vocab_size,
                relax_projection=relax_projection,
                config_path=args.config_path,
                task_idx=task_idx_proj,
                max_position_embeddings=args.max_position_embeddings,
                label_smoothing=args.label_smoothing,
                fp32_embedding=args.fp32_embedding,
                cache_dir=args.output_dir +
                '/.pretrained_model_{}'.format(args.global_rank),
                drop_prob=args.drop_prob,
                enable_butd=args.enable_butd,
                len_vis_input=args.len_vis_input,
                tasks=args.tasks)
        else:
            model = BertForSeq2SeqDecoder.from_pretrained(
                args.bert_model,
                max_position_embeddings=args.max_position_embeddings,
                config_path=args.config_path,
                state_dict=model_recover,
                num_labels=cls_num_labels,
                type_vocab_size=type_vocab_size,
                task_idx=task_idx_proj,
                mask_word_id=mask_word_id,
                search_beam_size=1,
                eos_id=eos_word_ids,
                enable_butd=args.enable_butd,
                len_vis_input=args.len_vis_input)

        del model_recover
        torch.cuda.empty_cache()

    # deprecated
    # from vlp.resnet import resnet
    # cnn = resnet(args.resnet_model, _num_layers=101, _fixed_block=4, pretrained=True) # no finetuning

    if args.fp16:
        model.half()
        # cnn.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    # cnn.to(device)
    if args.local_rank != -1:
        try:
            # from apex.parallel import DistributedDataParallel as DDP
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
        # cnn = DDP(cnn)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)
        # cnn = DataParallelImbalance(cnn)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             schedule=args.sche_mode,
                             t_total=t_total)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(
            os.path.join(args.output_dir,
                         "optim.{0}.bin".format(recover_step)))
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)
        logger.info("  Loader length = %d", len(train_dataloader))

        model.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch,
                              args.num_train_epochs + 1,
                              desc="Epoch"):
            if args.local_rank >= 0:
                train_sampler.set_epoch(i_epoch - 1)
            iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)')
            nbatches = len(train_dataloader)
            train_loss = []
            pretext_loss = []
            vqa2_loss = []
            scst_reward = []
            for step, batch in enumerate(iter_bar):
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, img, vis_masked_pos, vis_pe, ans_labels = batch

                if args.fp16:
                    img = img.half()
                    vis_pe = vis_pe.half()

                if args.enable_butd:
                    conv_feats = img.data  # Bx100x2048
                    vis_pe = vis_pe.data
                else:
                    conv_feats, _ = cnn(img.data)  # Bx2048x7x7
                    conv_feats = conv_feats.view(conv_feats.size(0),
                                                 conv_feats.size(1),
                                                 -1).permute(0, 2,
                                                             1).contiguous()

                if not args.scst:
                    loss_tuple = model(
                        conv_feats,
                        vis_pe,
                        input_ids,
                        segment_ids,
                        input_mask,
                        lm_label_ids,
                        ans_labels,
                        is_next,
                        masked_pos=masked_pos,
                        masked_weights=masked_weights,
                        task_idx=task_idx,
                        vis_masked_pos=vis_masked_pos,
                        mask_image_regions=args.mask_image_regions,
                        drop_worst_ratio=args.max_drop_worst_ratio
                        if i_epoch > args.drop_after else 0)
                    mean_reward = loss_tuple[0].new(1).fill_(0)
                else:
                    # scst training
                    model.eval()
                    position_ids = torch.arange(
                        input_ids.size(1),
                        dtype=input_ids.dtype,
                        device=input_ids.device).unsqueeze(0).expand_as(
                            input_ids)
                    input_dummy = input_ids[:, :args.len_vis_input +
                                            2]  # +2 for [CLS] and [SEP]
                    greedy_res = input_ids.new(
                        input_ids.size(0),
                        input_ids.size(1) - args.len_vis_input - 2).fill_(0)
                    gen_result = input_ids.new(
                        input_ids.size(0),
                        input_ids.size(1) - args.len_vis_input - 2).fill_(0)

                    with torch.no_grad():
                        greedy_res_raw, _ = model(conv_feats,
                                                  vis_pe,
                                                  input_dummy,
                                                  segment_ids,
                                                  position_ids,
                                                  input_mask,
                                                  task_idx=task_idx,
                                                  sample_mode='greedy')
                        for b in range(greedy_res_raw.size(0)):
                            for idx in range(greedy_res_raw.size(1)):
                                if greedy_res_raw[b][idx] not in [
                                        eos_word_ids, pad_word_ids
                                ]:
                                    greedy_res[b][idx] = greedy_res_raw[b][idx]
                                else:
                                    if greedy_res_raw[b][idx] == eos_word_ids:
                                        greedy_res[b][idx] = eos_word_ids
                                    break
                    model.train()
                    gen_result_raw, sample_logprobs = model(
                        conv_feats,
                        vis_pe,
                        input_dummy,
                        segment_ids,
                        position_ids,
                        input_mask,
                        task_idx=task_idx,
                        sample_mode='sample')
                    for b in range(gen_result_raw.size(0)):
                        for idx in range(gen_result_raw.size(1)):
                            if gen_result_raw[b][idx] not in [
                                    eos_word_ids, pad_word_ids
                            ]:
                                gen_result[b][idx] = gen_result_raw[b][idx]
                            else:
                                if gen_result_raw[b][idx] == eos_word_ids:
                                    gen_result[b][idx] = eos_word_ids
                                break

                    gt_ids = input_ids[:, args.len_vis_input + 2:]
                    reward = get_self_critical_reward(greedy_res,
                                                      gt_ids, gen_result,
                                                      gt_ids.size(0))
                    reward = torch.from_numpy(reward).float().to(
                        gen_result.device)
                    mean_reward = reward.mean()
                    loss = rl_crit(sample_logprobs, gen_result.data, reward)

                    loss_tuple = [
                        loss,
                        loss.new(1).fill_(0.),
                        loss.new(1).fill_(0.)
                    ]

                # disable pretext_loss_deprecated for now
                masked_lm_loss, pretext_loss_deprecated, ans_loss = loss_tuple
                if n_gpu > 1:  # mean() to average on multi-gpu. For dist, this is done through gradient addition.
                    masked_lm_loss = masked_lm_loss.mean()
                    pretext_loss_deprecated = pretext_loss_deprecated.mean()
                    ans_loss = ans_loss.mean()
                loss = masked_lm_loss + pretext_loss_deprecated + ans_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())
                train_loss.append(loss.item())
                pretext_loss.append(pretext_loss_deprecated.item())
                vqa2_loss.append(ans_loss.item())
                scst_reward.append(mean_reward.item())
                if step % 100 == 0:
                    logger.info(
                        "Epoch {}, Iter {}, Loss {:.2f}, Pretext {:.2f}, VQA2 {:.2f}, Mean R {:.3f}\n"
                        .format(i_epoch, step, np.mean(train_loss),
                                np.mean(pretext_loss), np.mean(vqa2_loss),
                                np.mean(scst_reward)))

                if args.enable_visdom:
                    if vis_window['iter'] is None:
                        vis_window['iter'] = vis.line(
                            X=np.tile(
                                np.arange((i_epoch - 1) * nbatches + step,
                                          (i_epoch - 1) * nbatches + step + 1),
                                (1, 1)).T,
                            Y=np.column_stack(
                                (np.asarray([np.mean(train_loss)]), )),
                            opts=dict(title='Training Loss',
                                      xlabel='Training Iteration',
                                      ylabel='Loss',
                                      legend=['total']))
                    else:
                        vis.line(X=np.tile(
                            np.arange((i_epoch - 1) * nbatches + step,
                                      (i_epoch - 1) * nbatches + step + 1),
                            (1, 1)).T,
                                 Y=np.column_stack(
                                     (np.asarray([np.mean(train_loss)]), )),
                                 opts=dict(title='Training Loss',
                                           xlabel='Training Iteration',
                                           ylabel='Loss',
                                           legend=['total']),
                                 win=vis_window['iter'],
                                 update='append')

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step/t_total,
                                      args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Save a trained model
            logger.info(
                "** ** * Saving fine-tuned model and optimizer ** ** * ")
            model_to_save = model.module if hasattr(
                model, 'module') else model  # Only save the model it-self
            output_model_file = os.path.join(args.output_dir,
                                             "model.{0}.bin".format(i_epoch))
            output_optim_file = os.path.join(args.output_dir,
                                             "optim.{0}.bin".format(i_epoch))
            if args.global_rank in (
                    -1, 0):  # save model if the first device or no dist
                torch.save(
                    copy.deepcopy(model_to_save).cpu().state_dict(),
                    output_model_file)
                # torch.save(optimizer.state_dict(), output_optim_file) # disable for now, need to sanitize state and ship everything back to cpu

            logger.info("***** CUDA.empty_cache() *****")
            torch.cuda.empty_cache()

            if args.world_size > 1:
                torch.distributed.barrier()
Beispiel #10
0
def main() -> None:
    """Train a ``RobertaForSequenceClassification`` model on the NSMC train split.

    Parses the command-line arguments, builds the tokenizer, the
    ``NSMCDataSet`` training split, its sampler and dataloader, and then runs
    ``args.num_train_epochs`` epochs of training with ``nn.BCEWithLogitsLoss``,
    gradient accumulation, optional gradient-norm clipping, ``AdamW`` and the
    schedule returned by ``get_linear_schedule_with_warmup``.

    On the master process (``args.is_master``) a checkpoint holding the model's
    ``state_dict`` is written to ``args.save_checkpoints_dir`` every
    ``args.save_checkpoints_steps`` optimizer steps (``step_<n>.ckpt``) and at
    the end of every epoch (``epoch_<n>.ckpt``).
    """
    args = _get_parser().parse_args()

    # ``--device_ids`` arrives as a comma-separated string of integers.
    args.device_ids = [
        int(device_id) for device_id in args.device_ids.split(',')
    ]

    set_seed(args)
    sanity_checks(args)
    # NOTE(review): presumably this is where args.multi_gpu / args.local_rank /
    # args.is_master get populated -- confirm against init_gpu_params.
    init_gpu_params(args)

    tokenizer = Tokenizer(
        os.path.join(args.bert_model, "senti_vocab.txt"),
        os.path.join(args.bert_model, "RoBERTa_Sentiment_kor"))

    train_dataset = NSMCDataSet(data_split="train",
                                tokenizer=tokenizer,
                                max_seq_length=args.max_seq_length,
                                pad_to_max=args.pad_to_max)

    # Each distributed worker must see its own shard of the data.
    if args.multi_gpu:
        train_sampler = DistributedSampler(train_dataset)
    else:
        train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.per_gpu_train_batch_size,
                                  collate_fn=train_dataset.collate_fn)

    model = RobertaForSequenceClassification(
        classifier_dropout=args.classifier_dropout,
        bert_model_dir=args.bert_model,
        pre_trained_model=args.pretrained_bert_model)

    # Parameters whose name contains one of these substrings get no weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # Only complete accumulation windows trigger an optimizer step, hence the
    # floor division when counting the total number of updates.
    t_total = len(train_dataloader
                  ) // args.gradient_accumulation_steps * args.num_train_epochs
    warmup_steps = math.ceil(t_total * args.warmup_proportion)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=t_total)

    model.zero_grad()
    model.cuda()

    if args.multi_gpu:
        model = DistributedDataParallel(
            model,
            device_ids=[args.device_ids[args.local_rank]],
            output_device=args.device_ids[args.local_rank])

    def _save_checkpoint(filename):
        """Write the unwrapped model's ``state_dict`` to the checkpoint dir."""
        # A wrapped model (e.g. DistributedDataParallel) exposes the real
        # network as ``.module``; save that so the keys carry no wrapper prefix.
        model_to_save = model.module if hasattr(model, 'module') else model
        # torch.save does not create missing directories.
        os.makedirs(args.save_checkpoints_dir, exist_ok=True)
        torch.save(model_to_save.state_dict(),
                   os.path.join(args.save_checkpoints_dir, filename))

    if args.is_master:
        logger.info(json.dumps(vars(args), indent=4))
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Num Epochs = %d", args.num_train_epochs)
        logger.info("  Instantaneous batch size per GPU = %d",
                    args.per_gpu_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            args.per_gpu_train_batch_size * args.gradient_accumulation_steps *
            (torch.distributed.get_world_size() if args.multi_gpu else 1),
        )
        logger.info("  Gradient Accumulation steps = %d",
                    args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

    # The criterion is stateless, so a single instance serves every epoch.
    loss_bce = nn.BCEWithLogitsLoss()

    global_steps = 0
    for epoch in range(args.num_train_epochs):
        if args.multi_gpu:
            # Reseed the distributed sampler so each epoch uses a new shuffle.
            train_sampler.set_epoch(epoch)

        # When the number of batches is not a multiple of the accumulation
        # steps, the tail batches of the previous epoch were back-propagated
        # but never applied. Drop those gradients so they do not leak into the
        # first update of this epoch (consistent with how t_total is counted).
        optimizer.zero_grad()

        iter_loss = 0

        model.train()

        pbar = tqdm(train_dataloader, desc="Iter", disable=not args.is_master)
        for step, batch in enumerate(pbar):
            input_ids, attention_mask, labels = batch

            inputs = {
                "input_ids":
                torch.tensor(input_ids, dtype=torch.long).cuda(),
                "attention_mask":
                torch.tensor(attention_mask, dtype=torch.long).cuda()
            }
            logits = model(**inputs)

            # BCEWithLogitsLoss expects float targets.
            labels = torch.tensor(labels, dtype=torch.float).cuda()

            loss = loss_bce(input=logits.view(-1), target=labels.view(-1))
            if args.gradient_accumulation_steps > 1:
                # Scale so the summed gradients of one window average the
                # micro-batch losses.
                loss = loss / args.gradient_accumulation_steps
            loss.backward()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

                global_steps += 1
                if global_steps % args.save_checkpoints_steps == 0 and args.is_master:
                    _save_checkpoint(f"step_{global_steps}.ckpt")

            step_loss = loss.item()
            iter_loss += step_loss
            # Losses are multiplied back by the accumulation steps so the
            # progress bar shows the unscaled per-batch loss.
            pbar.set_postfix({
                "epoch":
                epoch,
                "global_steps":
                global_steps,
                "learning_rate":
                f"{scheduler.get_last_lr()[0]:.10f}",
                "avg_iter_loss":
                f"{iter_loss / (step + 1) * args.gradient_accumulation_steps:.5f}",
                "last_loss":
                f"{step_loss * args.gradient_accumulation_steps:.5f}"
            })
        pbar.close()

        if args.is_master:
            _save_checkpoint(f"epoch_{epoch+1}.ckpt")
def main():
    """Parse command-line options and train the image-to-text BERT model.

    The function performs the whole training run end to end:

    1. Builds the argument parser and derives ``max_seq_length``,
       ``dist_url``, ``src_file`` and ``file_valid_jpgs`` from the options.
    2. Starts a wandb run, writes the options to ``<output_dir>/opt.json``
       and configures file logging.
    3. Selects the device (initialising the NCCL process group when
       ``--local_rank`` is set) and seeds the random number generators.
    4. Builds the tokenizer, the preprocessing pipelines, the dataset and
       the data loader.
    5. Builds ``BertForPreTrainingLossMask`` (optionally from the weights
       at ``--model_recover_path``), wraps it for fp16 / multi-GPU use and
       creates the ``BertAdam`` optimizer.
    6. Runs the epoch loop, logging the mean training loss to wandb and
       writing ``config.json`` and ``model.<epoch>.bin`` after every epoch.

    Raises:
        ValueError: if ``--gradient_accumulation_steps`` is smaller than 1.
        FileNotFoundError: if ``--model_recover_path`` is given but matches
            no file.
        ImportError: if ``DistributedDataParallel`` cannot be imported for
            a distributed run.
    """
    parser = argparse.ArgumentParser()
    # NOTE: ``help`` must be a string; a list breaks argparse help rendering.
    parser.add_argument("--generation_dataset",
                        default='openi',
                        type=str,
                        help="mimic-cxr | openi")
    parser.add_argument("--vqa_rad",
                        default="all",
                        type=str,
                        choices=["all", "chest", "head", "abd"])
    parser.add_argument("--data_set",
                        default="train",
                        type=str,
                        help="train | valid")
    parser.add_argument('--img_hidden_sz',
                        type=int,
                        default=2048,
                        help="Image hidden size.")

    parser.add_argument(
        "--bert_model",
        default="bert-base-uncased",
        type=str,
        help=
        "Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased."
    )

    parser.add_argument(
        "--mlm_task",
        type=str,
        default=True,
        help="The model will train only mlm task!! | True | False")
    parser.add_argument("--train_batch_size",
                        default=2,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--num_train_epochs",
                        default=5,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        default=False,
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument("--img_encoding",
                        type=str,
                        default='fully_use_cnn',
                        choices=['random_sample', 'fully_use_cnn'])
    # If the fixed length of the visual tokens is 100, there are 100
    # <Unknown> tokens and up to 100 words can be generated.
    parser.add_argument('--len_vis_input',
                        type=int,
                        default=256,
                        help="The length of visual token input")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=253,
                        help="Truncate_config: maximum length of segment B.")

    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )

    parser.add_argument('--max_pred',
                        type=int,
                        default=10,
                        help="Max tokens of prediction.")
    parser.add_argument(
        '--s2s_prob',
        default=1,
        type=float,
        help=
        "Percentage of examples that are bi-uni-directional LM (seq2seq). This must be turned off!!!!!!! because this is not for seq2seq model!!!"
    )
    parser.add_argument(
        '--bi_prob',
        default=0,
        type=float,
        help="Percentage of examples that are bidirectional LM.")
    parser.add_argument('--hidden_size', type=int, default=768)
    parser.add_argument('--bar', default=False, type=str, help="True or False")

    parser.add_argument("--config_path",
                        default='./pretrained_model/non_cross/config.json',
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--model_recover_path",
        default='./pretrained_model/non_cross/pytorch_model.bin',
        type=str,
        help="The file of fine-tuned pretraining model.")  # model load
    parser.add_argument(
        "--output_dir",
        default='./output_model/base_noncross_mimic_2',
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    parser.add_argument(
        "--log_file",
        default="training.log",
        type=str,
        help="The output directory where the log will be written.")

    parser.add_argument('--img_postion',
                        default=True,
                        help="It will produce img_position.")
    parser.add_argument(
        "--do_train",
        action='store_true',
        default=True,
        help="Whether to run training. This should ALWAYS be set to True.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )

    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="Label smoothing value passed to the model.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--global_rank",
                        type=int,
                        default=-1,
                        help="global_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")

    parser.add_argument(
        '--fp16',
        action='store_true',
        default=False,
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        default=False,
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        default=False,
                        help="Whether to use amp for fp16")

    parser.add_argument('--new_segment_ids',
                        default=False,
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument(
        '--trunc_seg',
        default='b',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")

    parser.add_argument("--num_workers",
                        default=20,
                        type=int,
                        help="Number of workers for the data loader.")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")

    parser.add_argument(
        '--image_root',
        type=str,
        default='/home/mimic-cxr/dataset/image_preprocessing/re_512_3ch/Train')
    parser.add_argument('--split',
                        type=str,
                        nargs='+',
                        default=['train', 'valid'])

    parser.add_argument('--world_size',
                        default=1,
                        type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url',
                        default='file://[PT_OUTPUT_DIR]/nonexistent_file',
                        type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--sche_mode',
                        default='warmup_linear',
                        type=str,
                        help="warmup_linear | warmup_constant | warmup_cosine")
    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--use_num_imgs', default=-1, type=int)
    parser.add_argument('--max_drop_worst_ratio', default=0, type=float)
    parser.add_argument('--drop_after', default=6, type=int)
    parser.add_argument('--tasks',
                        default='report_generation',
                        help='report_generation | vqa')
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")

    args = parser.parse_args()

    print('global_rank: {}, local rank: {}'.format(args.global_rank,
                                                   args.local_rank))
    # +3 for 2x[SEP] and [CLS]
    args.max_seq_length = args.max_len_b + args.len_vis_input + 3
    args.dist_url = args.dist_url.replace('[PT_OUTPUT_DIR]', args.output_dir)

    # wandb.init is called before the data paths are attached to ``args``,
    # exactly as before, so the logged config does not contain them.
    if args.tasks == 'vqa':
        wandb.init(config=args, project="VQA")
        wandb.config["more"] = "custom"
        args.src_file = '/home/mimic-cxr/dataset/data_RAD'
        args.file_valid_jpgs = '/home/mimic-cxr/dataset/vqa_rad_original_set.json'
    else:
        wandb.init(config=args, project="report_generation")
        wandb.config["more"] = "custom"
        if args.generation_dataset == 'mimic-cxr':
            args.src_file = '/home/mimic-cxr/new_dset/Train_253.jsonl'
            args.file_valid_jpgs = '/home/mimic-cxr/new_dset/Train_253.jsonl'
        else:
            args.src_file = '/home/mimic-cxr/dataset/open_i/Train_openi.jsonl'
            args.file_valid_jpgs = '/home/mimic-cxr/dataset/open_i/Valid_openi.jsonl'

    print(" # PID :", os.getpid())
    os.makedirs(args.output_dir, exist_ok=True)
    # Use a context manager so the options file is flushed and closed.
    with open(os.path.join(args.output_dir, 'opt.json'), 'w') as opt_file:
        json.dump(args.__dict__, opt_file, sort_keys=True, indent=2)

    logging.basicConfig(
        filename=os.path.join(args.output_dir, args.log_file),
        filemode='w',
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger = logging.getLogger(__name__)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        print("device", device)
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        print("device", device)
        n_gpu = 1
        torch.distributed.init_process_group(backend='nccl',
                                             init_method=args.dist_url,
                                             world_size=args.world_size,
                                             rank=args.global_rank)

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # The configured batch size is the total per optimizer update; each
    # forward/backward pass only sees its share of it.
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=True)

    if args.do_train:
        print("args.mask_prob", args.mask_prob)
        print("args.train_batch_size", args.train_batch_size)
        # One preprocessing pipeline per mode, in the order "s2s" then "bi".
        bi_uni_pipeline = [
            data_loader.Preprocess4Seq2seq(
                args,
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                args.bar,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mode=pipeline_mode,
                len_vis_input=args.len_vis_input,
                local_rank=args.local_rank,
                load_vqa_set=(args.tasks == 'vqa'))
            for pipeline_mode in ("s2s", "bi")
        ]

        train_dataset = data_loader.Img2txtDataset(
            args,
            args.data_set,
            args.src_file,
            args.image_root,
            args.split,
            args.train_batch_size,
            tokenizer,
            args.max_seq_length,
            file_valid_jpgs=args.file_valid_jpgs,
            bi_uni_pipeline=bi_uni_pipeline,
            use_num_imgs=args.use_num_imgs,
            s2s_prob=args.s2s_prob,  # this must be set to 1.
            bi_prob=args.bi_prob,
            tasks=args.tasks)

        if args.world_size == 1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
        else:
            train_sampler = DistributedSampler(train_dataset)

        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.train_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=batch_list_to_batch_tensors,
            pin_memory=True)

    # Total number of optimizer updates over the whole run.
    t_total = int(
        len(train_dataloader) * args.num_train_epochs * 1. /
        args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    relax_projection = 4 if args.relax_projection else 0
    task_idx_proj = 3 if args.tasks == 'report_generation' else 0

    mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[PAD]"])  # index in BERT vocab: 103, 102, 0

    # Keyword arguments shared by both model construction paths below.
    model_kwargs = dict(
        args=args,
        num_labels=cls_num_labels,
        type_vocab_size=type_vocab_size,
        relax_projection=relax_projection,
        config_path=args.config_path,
        task_idx=task_idx_proj,
        max_position_embeddings=args.max_position_embeddings,
        label_smoothing=args.label_smoothing,
        fp32_embedding=args.fp32_embedding,
        cache_dir=args.output_dir +
        '/.pretrained_model_{}'.format(args.global_rank),
        drop_prob=args.drop_prob,
        len_vis_input=args.len_vis_input,
        tasks=args.tasks)
    global_step = 0

    if args.model_recover_path is None:
        # An empty state dict is always passed on this path (the value
        # derived from --from_scratch was immediately overwritten before).
        model = BertForPreTrainingLossMask.from_pretrained(args.bert_model,
                                                           state_dict={},
                                                           **model_kwargs)

        print("scratch model's statedict : ")
        for param_tensor in model.state_dict():
            print(param_tensor, "\t", model.state_dict()[param_tensor].size())
        print("The model will train from scratch")

    else:
        print("Task :", args.tasks, args.s2s_prob)
        print("Recoverd model :", args.model_recover_path)
        recover_files = glob.glob(args.model_recover_path.strip())
        if not recover_files:
            # Fail with a clear message instead of an unbound-variable error.
            raise FileNotFoundError(
                "No checkpoint matches --model_recover_path: {}".format(
                    args.model_recover_path))

        # When the pattern matches several files the last one wins.
        for model_recover_path in recover_files:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(model_recover_path)

            # Rename checkpoint keys: drop 'enc.' and map 'mlm.' to 'cls.'.
            for key in list(model_recover.keys()):
                model_recover[key.replace('enc.', '').replace(
                    'mlm.', 'cls.')] = model_recover.pop(key)

        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model, state_dict=model_recover, **model_kwargs)

        model.load_state_dict(model_recover, strict=False)

        print("The pretrained model loaded and fine-tuning.")
        del model_recover
        torch.cuda.empty_cache()

    if args.fp16:
        model.half()
        if args.fp32_embedding:
            # Keep the embedding tables in full precision.
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError as err:
            raise ImportError(
                "torch.nn.parallel.DistributedDataParallel is unavailable; "
                "a PyTorch build with distributed support is required for "
                "distributed training.") from err
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)

    elif n_gpu > 1:
        model = DataParallelImbalance(model)

    wandb.watch(model)
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # Honour --weight_decay (default 0.01, the value that used to be
    # hard-coded here); bias and LayerNorm parameters are never decayed.
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        args.weight_decay
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         schedule=args.sche_mode,
                         t_total=t_total)
    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(
            os.path.join(args.output_dir,
                         "optim.{0}.bin".format(recover_step)))
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        model.train()
        print("Total Parameters:",
              sum([p.nelement() for p in model.parameters()]))

        if recover_step:
            start_epoch = recover_step + 1
            print("start_epoch", start_epoch)
        else:
            start_epoch = 1

        for i_epoch in trange(start_epoch,
                              args.num_train_epochs + 1,
                              desc="Epoch"):
            # Only samplers that support it (DistributedSampler) are
            # re-seeded; RandomSampler has no set_epoch method.
            if args.local_rank >= 0 and hasattr(train_sampler, 'set_epoch'):
                train_sampler.set_epoch(i_epoch - 1)
            iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)')
            train_loss = []

            # Dropping the worst examples only starts after --drop_after.
            drop_worst_ratio = (args.max_drop_worst_ratio
                                if i_epoch > args.drop_after else 0)

            for step, batch in enumerate(iter_bar):
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, task_idx, img, vis_pe, ans_labels, ans_type, organ = batch
                if args.fp16:
                    img = img.half()
                    vis_pe = vis_pe.half()

                loss_tuple = model(img,
                                   vis_pe,
                                   input_ids,
                                   segment_ids,
                                   input_mask,
                                   lm_label_ids,
                                   ans_labels,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   drop_worst_ratio=drop_worst_ratio,
                                   ans_type=ans_type)

                masked_lm_loss, vqa_loss = loss_tuple

                # Only the loss of the selected task is optimised.
                if args.tasks == 'report_generation':
                    masked_lm_loss = masked_lm_loss.mean()
                    loss = masked_lm_loss
                else:
                    vqa_loss = vqa_loss.mean()
                    loss = vqa_loss

                iter_bar.set_description('Iter (loss=%5.3f)' % (loss.item()))
                train_loss.append(loss.item())

                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step/t_total,
                                    args.warmup_proportion)
                    # The learning rate is only overridden by hand for fp16.
                    if args.fp16:
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            wandb.log({"train_loss": np.mean(train_loss)})
            logger.info(
                "** ** * Saving fine-tuned model and optimizer ** ** * ")
            model_to_save = model.module if hasattr(
                model, 'module') else model  # Only save the model it-self
            output_config_file = os.path.join(args.output_dir, 'config.json')

            with open(output_config_file, 'w') as f:
                f.write(model_to_save.config.to_json_string())

            output_model_file = os.path.join(args.output_dir,
                                             "model.{0}.bin".format(i_epoch))
            # NOTE(review): the optimizer state is not written here even
            # though the resume branch above loads "optim.{step}.bin";
            # confirm how _get_max_epoch_model detects checkpoints before
            # adding an optimizer save.
            if args.global_rank in (
                    -1, 0):  # save model if the first device or no dist
                torch.save(
                    copy.deepcopy(model_to_save).cpu().state_dict(),
                    output_model_file)

            logger.info("***** CUDA.empty_cache() *****")
            torch.cuda.empty_cache()

            if args.world_size > 1:
                torch.distributed.barrier()
Beispiel #12
0
def train(rank, args, model, model_t, train_dataset_qc, test_dataset_qc,
          fq_tune_only, model_controller):
    """ Train the model """
    global train_count
    train_count += 1

    world_size = 1 if rank < 0 else torch.distributed.get_world_size()

    if rank in [-1, 0]:
        printlog("Train model", train_count)
        printlog(model)

    q_dataset = train_dataset_qc.q_dataset

    per_gpu_train_batch_size = args.per_gpu_train_batch_size
    train_batch_size = per_gpu_train_batch_size * world_size

    if fq_tune_only:
        gradient_accumulation_steps = 1
        num_train_epochs = 1
    else:
        gradient_accumulation_steps = args.total_train_batch_size // train_batch_size
        num_train_epochs = args.num_train_epochs

    if rank < 0:
        #single process take all
        q_sampler = RandomSampler(q_dataset)
        q_dataloader = DataLoader(q_dataset,
                                  sampler=q_sampler,
                                  batch_size=train_batch_size,
                                  num_workers=4)
    else:
        #special sampler that divide samples between processes
        q_sampler = torch.utils.data.distributed.DistributedSampler(q_dataset,
                                                                    rank=rank)
        q_dataloader = DataLoader(q_dataset,
                                  sampler=q_sampler,
                                  batch_size=per_gpu_train_batch_size)

    steps_total = int(
        len(q_dataloader) // gradient_accumulation_steps * num_train_epochs)

    # Prepare optimizer and schedule
    named_params, groups = utils.make_param_groups(
        rank,
        model,
        args.
        freeze_list,  #list or str with subnames to define frozen parameters
        args.learning_rate,  #learning rate for no FQ parameters
        0.01,  # learning rate for FQ parameters
        fq_tune_only,  #true if only FQ parameters will be optimized
        model_controller)

    optimizer = AdamW(groups, eps=1e-08, lr=args.learning_rate, weight_decay=0)

    def lr_lambda(current_step):
        p = float(current_step) / float(steps_total)
        return 1 - p

    scheduler = LambdaLR(optimizer, lr_lambda)

    if rank in [-1, 0]:
        for n, p in named_params:
            printlog('param for tune', n)
        printlog("fq_tune_only", fq_tune_only)
        printlog("dataset size", len(q_dataset))
        printlog("epoches", num_train_epochs)
        printlog("per_gpu_train_batch_size", per_gpu_train_batch_size)
        printlog("n_gpu", args.n_gpu)
        printlog("world_size", world_size)
        printlog("gradient_accumulation_steps", gradient_accumulation_steps)
        printlog("total train batch size",
                 train_batch_size * gradient_accumulation_steps)
        printlog("steps_total", steps_total)

    global_step = 1
    model.zero_grad()
    indicators = collections.defaultdict(list)

    softplus = torch.nn.Softplus()

    loss_cfg = dict([t.split(':') for t in args.loss_cfg.split(',')])

    hnm_hist = {}

    for epoch in range(math.ceil(num_train_epochs)):
        indicators = collections.defaultdict(list)
        model.train()
        if model_t:
            model_t.train()
        if rank > -1:
            #set epoch to make different samples division betwen process for different epoches
            q_sampler.set_epoch(epoch)

        utils.sync_models(rank, model)
        for step, q_batch in enumerate(q_dataloader):
            epoch_fp = epoch + step / len(q_dataloader)
            if epoch_fp > num_train_epochs:
                break

            losses = []

            context_ids_pos = q_batch[3]
            q_inputs = get_inputs(q_batch, args.device)
            q_outputs = model(**q_inputs,
                              output_hidden_states=(model_t is not None))
            q_vec = q_outputs[0]

            #get positive embeddings
            c_batch = train_dataset_qc.c_dataset[context_ids_pos.detach().data]
            c_inputs = get_inputs(c_batch, args.device)
            c_outputs = model(**c_inputs,
                              output_hidden_states=(model_t is not None))
            c_vec_pos = c_outputs[0]

            if model_t is not None:
                q_emb_s, q_hidden_s = q_outputs
                c_emb_s, c_hidden_s = c_outputs
                with torch.no_grad():
                    q_emb_t, q_hidden_t = model_t(**q_inputs,
                                                  output_hidden_states=True)
                    c_emb_t, c_hidden_t = model_t(**c_inputs,
                                                  output_hidden_states=True)

                def align_and_loss_outputs(out_s, out_t):
                    if len(out_s) != len(out_t):
                        #the student and teacher outputs are not aligned. try to find teacher output for each student output
                        n_s, n_t = len(out_s), len(out_t)
                        out_t = [
                            out_t[(i * (n_t - 1)) // (n_s - 1)]
                            for i in range(n_s)
                        ]
                    assert len(out_s) == len(
                        out_t
                    ), "can not align number of outputs between student and teacher"
                    assert all(
                        s[0] == s[1]
                        for s in zip(out_s[0].shape, out_t[0].shape)
                    ), "output shapes for student and teacher are not the same"
                    return [(s - t.detach()).pow(2).mean()
                            for s, t in zip(out_s, out_t)]

                l_q = align_and_loss_outputs(q_hidden_s, q_hidden_t)
                l_c = align_and_loss_outputs(c_hidden_s, c_hidden_t)

                emb_loss = loss_cfg.get('emb_loss', '')
                if emb_loss == 'L2':
                    l_q.append((q_emb_s - q_emb_t.detach()).pow(2).mean())
                    l_c.append((c_emb_s - c_emb_t.detach()).pow(2).mean())
                elif emb_loss == 'L1':
                    l_q.append((q_emb_s - q_emb_t.detach()).abs().mean())
                    l_c.append((c_emb_s - c_emb_t.detach()).abs().mean())
                elif emb_loss.lower() not in ['', 'none', '0', 'disable']:
                    raise Exception(
                        'emb_loss={} is unsupported'.format(emb_loss))

                losses.extend([args.supervision_weight * l for l in l_c + l_q])

            triplet_num = int(loss_cfg.get('triplet_num', 1))
            if fq_tune_only:
                triplet_num = 0

            if triplet_num > 0:
                #disable grad to select negatives
                with torch.no_grad():
                    hnm_scores = []
                    hnm_idxs = []

                    #check that the current step has no HNM context vector yet
                    if global_step not in hnm_hist and args.hnm_num > 0:
                        #generate the new one

                        if world_size > 1 and (args.hnm_num % world_size) != 0:
                            #align hnm_num so it divides evenly across replicas
                            hnm_plus = world_size - (args.hnm_num % world_size)
                            args.hnm_num += hnm_plus
                            logger.warning(
                                "rank {} args.hnm_num increased by {} from {} to {} to be the same after division by {} replicas."
                                .format(rank, hnm_plus,
                                        args.hnm_num - hnm_plus, args.hnm_num,
                                        world_size))

                        # generate random contexts to calc embedding
                        context_ids_all = torch.randint(
                            low=0,
                            high=len(train_dataset_qc.c_dataset),
                            size=[args.hnm_num])

                        if rank < 0:  #single process take all
                            context_ids = context_ids_all
                        else:
                            #broadcast a single set of indices to all processes
                            context_ids_all = context_ids_all.to(args.device)
                            torch.distributed.broadcast(context_ids_all, 0)
                            context_ids_all = context_ids_all.cpu()

                            #each process take only small part to calc embedding
                            s = ((rank + 0) * args.hnm_num) // world_size
                            e = ((rank + 1) * args.hnm_num) // world_size
                            context_ids = context_ids_all[s:e]

                        batch_size = min(args.hnm_batch_size,
                                         context_ids.shape[0])

                        s, e = 0, batch_size
                        c_outputs = []
                        while e > s:
                            idx = context_ids.detach()[s:e]
                            c_batch = train_dataset_qc.c_dataset[idx]
                            inputs = get_inputs(c_batch, args.device)
                            outputs = model(**inputs,
                                            output_hidden_states=False)
                            c_outputs.append(outputs[0])
                            s, e = e, min(e + batch_size, context_ids.shape[0])

                        context_emb = torch.cat(c_outputs, dim=0)

                        if rank < 0:
                            # single process calculated all
                            context_emb_all = context_emb
                        else:
                            context_emb_list = [
                                torch.zeros_like(context_emb)
                                for _ in range(world_size)
                            ]
                            torch.distributed.all_gather(
                                context_emb_list, context_emb)
                            context_emb_all = torch.cat(context_emb_list,
                                                        dim=0)

                        hnm_hist[global_step] = (context_ids_all,
                                                 context_emb_all)

                        #check history size and crop the oldest one
                        if len(hnm_hist) > args.hnm_hist_num:
                            del hnm_hist[min(hnm_hist.keys())]

                    #calc HNM scores for current question batch
                    for hist_step, (c_idx, c_vec) in hnm_hist.items():
                        w = args.hnm_hist_alpha**(global_step - hist_step)
                        t1 = q_vec[:, None, :]
                        t2 = c_vec[None, :, :]
                        d = (t1 - t2)
                        score = -d.norm(2, dim=-1)
                        score = score * w

                        hnm_scores.append(score)
                        hnm_idxs.append(c_idx)

                    if hnm_scores:
                        #choose the hardest negative if we have scores
                        score = torch.cat(hnm_scores, dim=-1)
                        idx = torch.cat(hnm_idxs, dim=-1)
                        score = score.cpu()
                        pos_mask = (context_ids_pos[:,
                                                    None] == idx[None, :]).to(
                                                        dtype=score.dtype,
                                                        device=score.device)
                        score = (1 - pos_mask) * score + pos_mask * score.min(
                        )  #make positive context with small score to avoid chose it as hard neg
                        hn_idx = score.argmax(dim=1, keepdim=True)

                        context_ids_neg = idx[hn_idx]
                    else:
                        #just random selection in case of no scores for HNM
                        size = (context_ids_pos.shape[0], 1)
                        context_ids_neg = torch.randint(
                            0,
                            len(train_dataset_qc.c_dataset) - 1, size)
                        shift = (context_ids_neg >= context_ids_pos[:, None])
                        context_ids_neg = context_ids_neg + shift.to(
                            dtype=context_ids_neg.dtype)

                d_pos = (q_vec - c_vec_pos).norm(2, dim=-1)
                # get negative embeddings and calc losses
                for neg_index in range(context_ids_neg.shape[1]):
                    ids = context_ids_neg[:, neg_index]
                    c_batch = train_dataset_qc.c_dataset[ids.detach()]
                    inputs = get_inputs(c_batch, args.device)

                    outputs = model(**inputs, output_hidden_states=False)
                    c_vec_neg = outputs[0]

                    for triplet_index in range(triplet_num):

                        if triplet_index == 0:
                            d_neg = (q_vec - c_vec_neg).norm(2, dim=-1)
                        if triplet_index == 1:
                            d_neg = (c_vec_pos - c_vec_neg).norm(2, dim=-1)

                        d_diff = d_pos - d_neg

                        indicators['dd' + str(triplet_index)].append(
                            [v.mean().item() for v in (d_pos, d_neg, d_diff)])

                        l = softplus(d_diff)
                        losses.append(l)

                        del d_neg
                del d_pos

                #average over batch
                losses = [l.mean() for l in losses]

            l = sum(losses) / len(losses)
            (l / gradient_accumulation_steps).backward()

            indicators['loss'].append(l.item())
            indicators['ll'].append([lll.item() for lll in losses])

            #del losses
            del l

            if (step + 1) % gradient_accumulation_steps == 0:

                utils.sync_grads(rank,
                                 named_params,
                                 report_no_grad_params=(global_step == 1))
                torch.nn.utils.clip_grad_norm_([p for n, p in named_params], 1)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if global_step % 10 == 0:
                    # Log metrics
                    wall_time = epoch + step / len(q_dataloader)

                    lrp = [
                        '{:.2f}'.format(i)
                        for i in np.log10(scheduler.get_last_lr())
                    ]

                    str_out = "{} ep {:.2f} lrp {}".format(
                        train_count, epoch_fp, " ".join(lrp))

                    for k, v in indicators.items():
                        v = np.array(v)
                        if len(v.shape) == 1:
                            v = v[:, None]

                        if rank > -1:
                            #sync indicators
                            vt = torch.tensor(v).to(args.device)
                            torch.distributed.all_reduce(
                                vt, op=torch.distributed.ReduceOp.SUM)
                            v = vt.cpu().numpy() / float(world_size)

                        str_out += " {} {}".format(
                            k,
                            " ".join(["{:.3f}".format(t) for t in v.mean(0)]))

                    if 'score' in locals():
                        str_out += " SS {}".format(list(score.shape))

                    if 'time_last' in locals():
                        dt_iter = (time.time() - time_last) / len(
                            indicators['loss'])
                        dt_ep = dt_iter * len(q_dataloader)
                        str_out += " it {:.1f}s".format(dt_iter)
                        str_out += " ep {:.1f}m".format(dt_ep / (60))
                        str_out += " eta {:.1f}h".format(
                            dt_ep * (num_train_epochs - epoch_fp) / (60 * 60))
                    time_last = time.time()

                    indicators = collections.defaultdict(list)
                    if rank in [-1, 0]:
                        logger.info(str_out)

        if rank in [-1, 0]:
            check_point_name = 'checkpoint-{:02}'.format(train_count)
            check_point_name = check_point_name + '-{:02}'.format(epoch + 1)
            result_s = evaluate(args, model.eval(), test_dataset_qc)
            for k, v in result_s.items():
                logger.info("{} {} {}".format(check_point_name, k, v))
        if rank > -1:
            torch.distributed.barrier()
Beispiel #13
0
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--src_file",
                        default=None,
                        type=str,
                        help="The input data file name.")
    parser.add_argument("--topic_model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining topic model.")
    parser.add_argument("--topic_model_dict_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining topic model.")
    parser.add_argument("--tgt_file",
                        default=None,
                        type=str,
                        help="The output data file name.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")
    parser.add_argument('--topic_mode',
                        default=1,
                        type=float,
                        help="1:idea1 1.1:idea1_wo_theta 2:idea2 ")
    parser.add_argument('--topic_model',
                        default=False,
                        type=bool,
                        help="if only use topic model")
    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=192,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument(
        "--train_batch_size",
        default=32,
        type=int,
        help="Total batch size for training.")  #batch_size = batch_size/n_gpus
    parser.add_argument("--eval_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=30,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()
    assert Path(
        args.model_recover_path).exists(), "--model_recover_path doesn't exist"

    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)
    print("args.local_rank", args.local_rank)
    print("args.no_cuda", args.no_cuda)
    if args.local_rank == -1 or args.no_cuda:  #-1 False
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")  #device = cuda
        n_gpu = torch.cuda.device_count()
        print("n_gpu_1", n_gpu)
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
        print("n_gpu_1", n_gpu)
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))
    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer(
    ) if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()
    if args.do_train:
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_a': args.max_len_a,
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_source_words=args.mask_source_words,
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                mode="s2s",
                has_oracle=args.has_sentence_oracle,
                num_qkv=args.num_qkv,
                s2s_special_token=args.s2s_special_token,
                s2s_add_segment=args.s2s_add_segment,
                s2s_share_segment=args.s2s_share_segment,
                pos_shift=args.pos_shift)
        ]
        file_oracle = None
        if args.has_sentence_oracle:
            file_oracle = os.path.join(args.data_dir, 'train.oracle')
        fn_src = os.path.join(args.data_dir,
                              args.src_file if args.src_file else 'train.src')
        fn_tgt = os.path.join(args.data_dir,
                              args.tgt_file if args.tgt_file else 'train.tgt')
        train_dataset = seq2seq_loader.Seq2SeqDataset(
            fn_src,
            fn_tgt,
            args.data_dir,
            args.topic_model_dict_path,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size)
    t_total = int(
        len(train_dataloader) * args.num_train_epochs /
        args.gradient_accumulation_steps)
    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2    ### type_vocab_size=6
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        _state_dict = {} if args.from_scratch else None
        unilm = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb)
        global_step = 0

    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.output_dir, "model.{0}.bin".format(recover_step)),
                                       map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total /
                                     args.num_train_epochs)
        elif args.model_recover_path:  # here is the entrance
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path,
                                       map_location='cpu')
            global_step = 0
        unilm = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb)
    # 1. Model initialization; define the entry point.
    gsm = GSM(train_dataset.vocabsize)
    gsm_checkpoint = torch.load(args.topic_model_recover_path)
    gsm.load_state_dict(gsm_checkpoint["net"])

    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        unilm.half()
        gsm.half()
        if args.fp32_embedding:
            unilm.bert.embeddings.word_embeddings.float()
            unilm.bert.embeddings.position_embeddings.float()
            unilm.bert.embeddings.token_type_embeddings.float()
    unilm.to(device)
    gsm.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        unilm = DDP(unilm,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        unilm = DataParallelImbalance(unilm)
        gsm = DataParallelImbalance(gsm)
    # Prepare optimizer
    total = 0
    param_optimizer = list(unilm.named_parameters())
    param_optimizer_topic = list(gsm.named_parameters())
    for name, parameters in unilm.named_parameters():
        if "idea" in name:
            if "11" in name and "idea2" in name:
                total += np.prod(parameters.size())
                # print(name, ':', parameters.size())
        else:
            total += np.prod(parameters.size())
            # print(name, ':', parameters.size())

    print("gsm have {} paramerters in total".format(
        sum(x.numel() for x in gsm.parameters())))
    print("Number of parameter: %.6fM" % (total / 1e6))
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    if not args.topic_model:
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01,
            'topic':
            False
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0,
            'topic':
            False
        }, {
            'params': [p for n, p in param_optimizer_topic],
            'weight_decay':
            0.0,
            'lr':
            1e-3,
            'topic':
            True
        }]
    else:
        optimizer_grouped_parameters = [{
            'params': [p for n, p in param_optimizer_topic],
            'weight_decay':
            0.0,
            'lr':
            1e-3,
            'topic':
            True
        }]
    #一部分是有weight的,一部分是没有weight_dacay的
    # print("optimizer_grouped_parameters", optimizer_grouped_parameters)
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)
    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step)),
                                   map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        unilm.train()
        gsm.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        print("000000", args.local_rank, start_epoch,
              int(args.num_train_epochs) + 1)
        topicloss = []
        unilmloss = []
        topicloss_lst = []
        unilmloss_lst = []
        for i_epoch in trange(start_epoch,
                              int(args.num_train_epochs) + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):

            loss_sum = 0.0
            ppx_sum = 0.0
            word_count = 0.0
            doc_count = 0.0
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            iter_bar = tqdm(train_dataloader,
                            desc='Iter (loss=X.XXX)',
                            disable=args.local_rank not in (-1, 0))
            for step, batch in enumerate(iter_bar):
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                if args.has_sentence_oracle:  #false
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, oracle_pos, oracle_weights, oracle_labels = batch
                else:  #这里加了bows
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, bows = batch
                    oracle_pos, oracle_weights, oracle_labels = None, None, None
                p_x, mus, log_vars, theta, beta, topic_embedding = gsm(bows)
                if not args.topic_model:
                    loss_tuple = unilm(input_ids,
                                       theta,
                                       beta,
                                       topic_embedding,
                                       args.topic_mode,
                                       segment_ids,
                                       input_mask,
                                       lm_label_ids,
                                       is_next,
                                       masked_pos=masked_pos,
                                       masked_weights=masked_weights,
                                       task_idx=task_idx,
                                       masked_pos_2=oracle_pos,
                                       masked_weights_2=oracle_weights,
                                       masked_labels_2=oracle_labels,
                                       mask_qkv=mask_qkv)
                    masked_lm_loss, next_sentence_loss = loss_tuple

                ## topic loss
                logsoftmax = torch.log(p_x + 1e-10)
                rec_loss = -1.0 * torch.sum(
                    bows * logsoftmax
                )  #bows*logsoftmax = [batch_size, |V|], 其中torch.sum 把所有的loss全部加起来了,也可以只用加某一维度。
                rec_loss_per = -1.0 * torch.sum(bows * logsoftmax, dim=1)
                rec_loss_per = rec_loss_per.cpu().detach().numpy()

                kl_div = -0.5 * torch.sum(1 + log_vars - mus.pow(2) -
                                          log_vars.exp())
                loss_topic = rec_loss + kl_div
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    loss_topic = loss_topic.mean()
                    if not args.topic_model:
                        masked_lm_loss = masked_lm_loss.mean()
                        next_sentence_loss = next_sentence_loss.mean()
                if not args.topic_model:
                    loss_unilm = masked_lm_loss + next_sentence_loss

                # cal perplexity
                word_count_list = []
                loss_sum += loss_topic.item()
                for bow in bows:
                    word_num = torch.sum(bow).cpu().numpy()
                    word_count_list.append(word_num)
                    word_count += word_num

                word_count_np = np.array(word_count_list)
                doc_count += len(bows)
                ppx_sum += np.sum(np.true_divide(rec_loss_per, word_count_np))
                topicloss_lst.append(loss_topic.item() / len(bows))
                if not args.topic_model:
                    unilmloss_lst.append(loss_unilm.item())
                #topic_loss end
                if not args.topic_model:
                    loss = loss_unilm + loss_topic
                else:
                    loss = loss_topic
                # ensure that accumlated gradients are normalized
                if args.gradient_accumulation_steps > 1:  # =1
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step/t_total,
                                      args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            if not param_group['topic']:
                                param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
                if not args.topic_model:
                    iter_bar.set_description(
                        'Iter (loss_unilm=%5.3f),Iter (ppl=%5.3f)' %
                        (loss_unilm.item(),
                         np.sum(np.true_divide(rec_loss_per, word_count_np))))
                else:
                    iter_bar.set_description(
                        'Iter (loss_topic=%5.3f), (ppl=%5.3f)' %
                        (loss_topic.item(),
                         np.sum(np.true_divide(rec_loss_per, word_count_np))))
            #Save a trained model
            ppx_word = np.exp(loss_sum / word_count)
            ppx_document = np.exp(ppx_sum / doc_count)
            print("********")
            print("word_count", word_count)
            print("ppx_word", ppx_word)
            print("ppx_document", ppx_document)
            if (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                #save unilm model
                logger.info(
                    "** ** * Saving fine-tuned model and optimizer ** ** * ")
                unilm_model_to_save = unilm.module if hasattr(
                    unilm, 'module') else unilm  # Only save the model it-self
                output_unilm_model_file = os.path.join(
                    args.output_dir, "unilm.{0}.bin".format(i_epoch))
                torch.save(unilm_model_to_save.state_dict(),
                           output_unilm_model_file)
                #save topic model
                logger.info(
                    "** ** * Saving topic model and optimizer ** ** * ")
                topic_model_to_save = gsm.module if hasattr(
                    gsm, 'module') else gsm  # Only save the model it-self
                output_topic_model_file = os.path.join(
                    args.output_dir, "topic.{0}.ckpt".format(i_epoch))
                torch.save(topic_model_to_save.state_dict(),
                           output_topic_model_file)

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
        smth_pts = smooth_curve(topicloss_lst)
        # plt.plot(range(len(topicloss_lst)), topicloss_lst)
        plt.plot(range(len(smth_pts)), smth_pts)
        plt.xlabel('epochs')
        plt.title('Topic Model Train Loss')
        plt.savefig(args.output_dir + '/topic_loss.png')
        plt.cla()
        plt.plot(range(len(unilmloss_lst)), unilmloss_lst)
        plt.xlabel('epochs')
        plt.title('Unilm Train Loss')
        plt.savefig(args.output_dir + '/unilm_loss.png')
Beispiel #14
0
def train(args, train_dataset, label_list, model, tokenizer):
    """Train ``model`` on ``train_dataset`` and return training statistics.

    On the main process (``args.local_rank`` in ``[-1, 0]``) a TensorBoard
    writer and an append-mode ``evaluate_logs.txt`` file in
    ``args.output_dir`` are opened for the duration of training. Both are
    closed when training ends, including when it ends with an exception.

    Args:
        args: Run configuration namespace. It is mutated:
            ``args.train_batch_size`` is always set, and
            ``args.num_train_epochs`` is overwritten when
            ``args.max_steps > 0``.
        train_dataset: Dataset fed to a ``DataLoader``. Its ``shuffle()``
            method is called from the second epoch on unless
            ``args.use_all_samples_per_epoch`` is set.
        label_list: Forwarded unchanged to ``evaluate``.
        model: Model to optimise. It is called as ``model(**inputs)`` and
            the first element of its output is used as the loss.
        tokenizer: Saved next to every checkpoint and forwarded to
            ``evaluate``.

    Returns:
        A ``(global_step, average_loss)`` tuple, where ``average_loss`` is
        the accumulated loss divided by ``global_step + 1``.

    Raises:
        ImportError: If ``args.fp16`` is set and apex is not installed.
    """
    tb_writer = None
    log_writer = None
    try:
        if args.local_rank in [-1, 0]:
            if args.log_dir:
                tb_writer = SummaryWriter(args.log_dir)
            else:
                tb_writer = SummaryWriter()
            log_writer = open(
                os.path.join(args.output_dir, "evaluate_logs.txt"), 'a')
        return _run_training(args, train_dataset, label_list, model,
                             tokenizer, tb_writer, log_writer)
    finally:
        # Release the writers even when training aborts with an exception.
        if tb_writer is not None:
            tb_writer.close()
        if log_writer is not None:
            log_writer.close()


def _run_training(args, train_dataset, label_list, model, tokenizer,
                  tb_writer, log_writer):
    """Run the optimisation loop for ``train``; the caller owns both writers.

    ``tb_writer`` and ``log_writer`` are ``None`` on non-main processes and
    are only touched behind ``args.local_rank in [-1, 0]`` checks.
    """
    is_main_process = args.local_rank in [-1, 0]

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
    else:
        train_sampler = DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    # Number of optimizer updates (not batches) performed per epoch.
    steps_per_epoch = len(
        train_dataloader) // args.gradient_accumulation_steps

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // steps_per_epoch + 1
    else:
        t_total = steps_per_epoch * args.num_train_epochs

    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # model recover
    recover_step = get_max_steps(args.output_dir)
    model_recover_path = None
    if recover_step:
        model_recover_path = os.path.join(args.output_dir,
                                          "checkpoint-{}".format(recover_step))
        logger.info(" ** Recover model checkpoint in %s ** ",
                    model_recover_path)
        model.load_state_dict(
            torch.load(os.path.join(model_recover_path, WEIGHTS_NAME),
                       map_location='cpu'))
        model.to(args.device)

        # Restore optimizer and scheduler only when both files are present.
        optimizer_file = os.path.join(model_recover_path, "optimizer.pt")
        scheduler_file = os.path.join(model_recover_path, "scheduler.pt")
        if os.path.isfile(optimizer_file) and os.path.isfile(scheduler_file):
            optimizer.load_state_dict(
                torch.load(optimizer_file, map_location='cpu'))
            scheduler.load_state_dict(torch.load(scheduler_file))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    logger.info("  Logging steps = %d", args.logging_steps)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if model_recover_path and os.path.exists(model_recover_path):
        # set global_step to global_step of last saved checkpoint
        global_step = recover_step
        epochs_trained = global_step // steps_per_epoch
        # The skip counter below is decremented once per *batch*, while
        # global_step counts optimizer updates, so convert updates to batches.
        steps_trained_in_current_epoch = (
            global_step % steps_per_epoch) * args.gradient_accumulation_steps

        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    tr_loss, logging_loss, best_avg = 0.0, 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility

    def log_eval_results():
        """Evaluate the model and record metrics, lr and loss (main process only)."""
        results = evaluate(args,
                           model,
                           tokenizer,
                           label_list,
                           single_gpu=True,
                           splits=args.eval_splits.split(','))
        for task, result in results.items():
            for key, value in result.items():
                tb_writer.add_scalar("eval_{}_{}".format(task, key), value,
                                     global_step)
        log_writer.write("{0}\t{1}\n".format(global_step, json.dumps(results)))
        log_writer.flush()
        tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
        # logging_steps <= 0 means periodic loss logging is disabled; skip
        # the average instead of dividing by zero.
        if args.logging_steps > 0:
            tb_writer.add_scalar("loss",
                                 (tr_loss - logging_loss) / args.logging_steps,
                                 global_step)
        return results

    def save_checkpoint(cur_step):
        """Write model, tokenizer, args, optimizer and scheduler for ``cur_step``."""
        output_dir = os.path.join(args.output_dir,
                                  "checkpoint-{}".format(cur_step))
        logger.info("Saving model checkpoint to %s", output_dir)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        # Take care of distributed/parallel training
        model_to_save = model.module if hasattr(model, "module") else model
        model_to_save.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        torch.save(args, os.path.join(output_dir, "training_args.bin"))

        torch.save(optimizer.state_dict(),
                   os.path.join(output_dir, "optimizer.pt"))
        torch.save(scheduler.state_dict(),
                   os.path.join(output_dir, "scheduler.pt"))
        logger.info("Saving optimizer and scheduler states to %s", output_dir)

    for epc in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        if not args.use_all_samples_per_epoch:
            if args.local_rank != -1:
                train_sampler.set_epoch(epc)
            if epc > 0:
                train_dataset.shuffle()

        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            if args.filter_k == 0:
                # no cross-attention layer
                batch = tuple(t.to(args.device) for t in batch)
                inputs = {
                    "input_ids": batch[0],
                    "attention_mask": batch[1],
                    "labels": batch[3]
                }
            else:
                # Here each element of ``batch`` is itself an indexable group
                # of tensors; labels are taken from the first group only.
                input_ids = [d[0].to(args.device) for d in batch]
                attention_mask = [d[1].to(args.device) for d in batch]
                labels = [d[3].to(args.device) for d in batch][0]
                inputs = {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "labels": labels
                }

                if args.alpha > 0:
                    soft_labels = [d[4].to(args.device) for d in batch][0]
                    inputs["soft_labels"] = soft_labels

            if args.model_type != "distilbert":
                # XLM and DistilBERT don't use segment_ids
                # NOTE(review): when filter_k != 0, batch[2] is the third
                # group of the batch, not a token_type_ids tensor, and it was
                # never moved to the device - confirm this is intended for
                # model_type == "bert".
                inputs["token_type_ids"] = (
                    batch[2] if args.model_type in ["bert"] else None)
            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]

            if args.n_gpu > 1:
                # mean() to average on multi-gpu parallel training
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if (is_main_process and args.logging_steps > 0
                        and global_step % args.logging_steps == 0):
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                # Save model checkpoint
                if (is_main_process and args.save_steps > 0
                        and global_step % args.save_steps == 0):
                    save_checkpoint(global_step)

                    if args.evaluate_during_training:
                        cur_result = log_eval_results()
                        logger.info(cur_result)

            # Stop as soon as max_steps updates are done; the scheduler was
            # built for exactly t_total == max_steps updates.
            if args.max_steps > 0 and global_step >= args.max_steps:
                epoch_iterator.close()
                break

        if is_main_process and args.logging_each_epoch:
            logging_loss = tr_loss
            save_checkpoint(global_step)

        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

    # The +1 keeps the division defined when no optimizer step was taken.
    return global_step, tr_loss / (global_step + 1)
Beispiel #15
0
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--src_file",
                        default=None,
                        type=str,
                        help="The input data file name.")
    parser.add_argument("--tgt_file",
                        default=None,
                        type=str,
                        help="The output data file name.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")

    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--local_debug",
                        action='store_true',
                        help="Whether to run training.")

    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()
    assert Path(
        args.model_recover_path).exists(), "--model_recover_path doesn't exist"

    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer(
    ) if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    if args.do_train:
        print("Loading Train Dataset", args.data_dir)
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_a': args.max_len_a,
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_source_words=args.mask_source_words,
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                mode="s2s",
                has_oracle=args.has_sentence_oracle,
                num_qkv=args.num_qkv,
                s2s_special_token=args.s2s_special_token,
                s2s_add_segment=args.s2s_add_segment,
                s2s_share_segment=args.s2s_share_segment,
                pos_shift=args.pos_shift)
        ]
        file_oracle = None
        if args.has_sentence_oracle:
            file_oracle = os.path.join(args.data_dir, 'train.oracle')
        fn_src = os.path.join(args.data_dir,
                              args.src_file if args.src_file else 'train.src')
        fn_tgt = os.path.join(args.data_dir,
                              args.tgt_file if args.tgt_file else 'train.tgt')
        train_dataset = seq2seq_loader.Seq2SeqDataset(
            fn_src,
            fn_tgt,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=bi_uni_pipeline,
            corpus_preprocessors=corpus_preprocessors)
        train_dataset.initial()
        print(len(train_dataset.ex_list))
        print(train_dataset.batch_size)

        # assert 1==0

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

        # c = 0
        # for i_epoch in trange(0, int(args.num_train_epochs)+1, desc="Epoch", disable=args.local_rank not in (-1, 0)):
        #     if args.local_rank != -1:
        #         train_sampler.set_epoch(i_epoch)
        #     iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)',
        #                     disable=args.local_rank not in (-1, 0))
        #     for step, batch in enumerate(iter_bar):
        #         batch = [
        #             t.to(device) if t is not None else None for t in batch]
        #         if args.has_sentence_oracle:
        #             input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, sop_label, oracle_pos, oracle_weights, oracle_labels = batch
        #         else:
        #             input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, sop_label = batch
        #             oracle_pos, oracle_weights, oracle_labels = None, None, None
        #         c += input_ids.shape[0]
        #
        #         # print(input_ids)
        #         #
        #         # print(input_ids.shape)
        #         # print(segment_ids)
        #         # print(segment_ids.shape)
        #         # print(is_next)
        #         # print(task_idx)
        #         # print(sop_label)
        #         # print(task_idx.shape)
        #         # for i in range(input_mask.shape[0]):
        #         #     print(input_mask[i])
        #     print(c)
        #     print(train_dataset.c)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size)
    t_total = int(
        len(train_dataloader) * args.num_train_epochs /
        args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            local_debug=args.local_debug)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.output_dir, "model.{0}.bin".format(recover_step)),
                                       map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total /
                                     args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path,
                                       map_location='cpu')
            global_step = 0
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            local_debug=args.local_debug)
    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        model.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step)),
                                   map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        model.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch,
                              int(args.num_train_epochs) + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            iter_bar = tqdm(train_dataloader,
                            desc='Iter (loss=X.XXX)',
                            disable=args.local_rank not in (-1, 0))
            for step, batch in enumerate(iter_bar):
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                if args.has_sentence_oracle:
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, sop_label, oracle_pos, oracle_weights, oracle_labels = batch
                else:
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, sop_label = batch
                    print(sop_label)
                    print(task_idx)
                    oracle_pos, oracle_weights, oracle_labels = None, None, None
                # loss_tuple = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next,
                #                    masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx,
                #                    masked_pos_2=oracle_pos, masked_weights_2=oracle_weights,
                #                    masked_labels_2=oracle_labels, mask_qkv=mask_qkv)
                loss_tuple = model(input_ids,
                                   segment_ids,
                                   input_mask,
                                   lm_label_ids,
                                   sop_label,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   masked_pos_2=oracle_pos,
                                   masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels,
                                   mask_qkv=mask_qkv)
                masked_lm_loss, next_sentence_loss = loss_tuple
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    # loss = loss.mean()
                    masked_lm_loss = masked_lm_loss.mean()
                    next_sentence_loss = next_sentence_loss.mean()
                print('mask_lm_loss {}'.format(masked_lm_loss))
                print('next_sentence_loss {}'.format(next_sentence_loss))
                print('----------------------------------------------')
                loss = masked_lm_loss + next_sentence_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())

                # ensure that accumlated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step/t_total,
                                      args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Save a trained model
            if (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                logger.info(
                    "** ** * Saving fine-tuned model and optimizer ** ** * ")
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model it-self
                output_model_file = os.path.join(
                    args.output_dir, "model.{0}.bin".format(i_epoch))
                torch.save(model_to_save.state_dict(), output_model_file)
                output_optim_file = os.path.join(
                    args.output_dir, "optim.{0}.bin".format(i_epoch))
                torch.save(optimizer.state_dict(), output_optim_file)

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
Beispiel #16
0
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """Train ``model`` on ``train_dataset`` and return training statistics.

    Every dataset item must be a 9-tuple laid out as
    ``(text_ids, text_label, text_type_ids,
       video_ids, video_label, video_type_ids,
       joint_ids, joint_label, joint_type_ids)``.
    The ``*_ids`` / ``*_type_ids`` entries are padded per batch with
    ``pad_sequence``; the labels are stacked with ``torch.tensor``.

    Args:
        args: Run configuration. Fields read here: ``per_gpu_train_batch_size``,
            ``n_gpu``, ``local_rank``, ``max_steps``, ``num_train_epochs``,
            ``gradient_accumulation_steps``, ``weight_decay``,
            ``learning_rate``, ``adam_epsilon``, ``warmup_steps``,
            ``model_name_or_path``, ``mlm``, ``device``, ``max_grad_norm``,
            ``logging_steps``, ``save_steps`` and ``output_dir``.
            ``args.train_batch_size`` is always overwritten and
            ``args.num_train_epochs`` is overwritten when ``max_steps > 0``.
        train_dataset: Dataset yielding the 9-tuples described above.
        model: Model to optimise. It may be wrapped in a parallel container
            exposing the real model as ``.module``.
        tokenizer: Tokenizer supplying the padding id; it is also saved next
            to every checkpoint.

    Returns:
        ``(global_step, mean_loss)`` where ``mean_loss`` is the accumulated
        loss divided by ``global_step`` (``0.0`` if no optimizer step ran).
    """
    # Function-scope import so the block does not depend on a file-level
    # import that may not exist; only used for multi-process runs.
    from torch.utils.data.distributed import DistributedSampler

    is_main_process = args.local_rank in [-1, 0]
    # Only the main process logs to TensorBoard, so only it owns a writer.
    tb_writer = SummaryWriter() if is_main_process else None

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    # pad_sequence pads with 0 by default; that value is used whenever the
    # tokenizer defines no pad token.
    pad_token_id = (tokenizer.pad_token_id
                    if tokenizer._pad_token is not None else 0)

    def pad_examples(examples, padding_value=pad_token_id):
        """Right-pad a list of 1-D tensors into one batch-first tensor."""
        return pad_sequence(examples,
                            batch_first=True,
                            padding_value=padding_value)

    def pad_with_mask(sequences):
        """Pad ``sequences`` and build the matching int64 attention mask."""
        padded = pad_examples(sequences)
        # Compare against the value that was actually used for padding rather
        # than a hard-coded 0, so tokenizers with a non-zero pad id work too.
        attention_mask = (padded != pad_token_id).to(torch.int64)
        return padded, attention_mask

    def collate(examples):
        """Turn a list of 9-tuples into the 12 batch tensors used below."""
        (text_examples, text_labels, text_type_ids, video_examples,
         video_labels, video_type_ids, joint_examples, joint_labels,
         joint_type_ids) = [list(column) for column in zip(*examples)]

        padded_text_ids, text_attention_mask = pad_with_mask(text_examples)
        padded_video_ids, video_attention_mask = pad_with_mask(video_examples)
        padded_joint_ids, joint_attention_mask = pad_with_mask(joint_examples)

        return (
            padded_text_ids,
            torch.tensor(text_labels, dtype=torch.int64),
            pad_examples(text_type_ids, padding_value=0),
            text_attention_mask,
            padded_video_ids,
            torch.tensor(video_labels, dtype=torch.int64),
            pad_examples(video_type_ids, padding_value=0),
            video_attention_mask,
            padded_joint_ids,
            torch.tensor(joint_labels, dtype=torch.int64),
            pad_examples(joint_type_ids, padding_value=0),
            joint_attention_mask,
        )

    # set_epoch() is called on the sampler for distributed runs further down;
    # RandomSampler has no such method, so a DistributedSampler is required.
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
    else:
        train_sampler = DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate)

    # Number of optimizer updates performed per pass over the dataloader.
    update_steps_per_epoch = (len(train_dataloader) //
                              args.gradient_accumulation_steps)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // update_steps_per_epoch + 1
    else:
        t_total = update_steps_per_epoch * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Restore optimizer/scheduler state when both files sit next to the model.
    if args.model_name_or_path:
        optimizer_path = os.path.join(args.model_name_or_path, "optimizer.pt")
        scheduler_path = os.path.join(args.model_name_or_path, "scheduler.pt")
        if os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path):
            optimizer.load_state_dict(torch.load(optimizer_path))
            scheduler.load_state_dict(torch.load(scheduler_path))

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # The checkpoint directory name is expected to end in
            # "-<global_step>"; anything else raises ValueError below and is
            # treated as a fresh fine-tuning run.
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // update_steps_per_epoch
            steps_trained_in_current_epoch = (global_step %
                                              update_steps_per_epoch)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
            # The value above counts optimizer updates, but the skip logic in
            # the loop consumes one *batch* per iteration, so convert it.
            steps_trained_in_current_epoch *= args.gradient_accumulation_steps
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    # Parallel wrappers expose the real model as `.module`; attribute access
    # (``.bert``) and ``save_pretrained`` must go through the unwrapped model.
    unwrapped_model = model.module if hasattr(model, "module") else model

    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=not is_main_process)
    set_seed(args)  # Added here for reproducibility
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=not is_main_process)

        if args.local_rank != -1:
            # Re-seed the distributed shuffle for this epoch.
            train_sampler.set_epoch(epoch)

        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained batches if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            (text_ids, text_seq_labels, text_token_type_ids,
             text_attention_masks, video_ids, video_seq_labels,
             video_token_type_ids, video_attention_masks, joint_ids,
             joint_labels, joint_token_type_ids,
             joint_attention_masks) = batch[:12]

            # Without MLM the inputs double as their own labels. The masking
            # order (text, video, joint) is kept fixed so random draws are
            # consumed in a stable order.
            text_inputs, text_mask_labels = mask_tokens(
                text_ids, tokenizer, args) if args.mlm else (text_ids,
                                                             text_ids)
            video_inputs, video_mask_labels = mask_tokens(
                video_ids, tokenizer, args) if args.mlm else (video_ids,
                                                              video_ids)
            joint_inputs, joint_mask_labels = mask_tokens(
                joint_ids, tokenizer, args) if args.mlm else (joint_ids,
                                                              joint_ids)

            text_inputs = text_inputs.to(args.device)
            text_mask_labels = text_mask_labels.to(args.device)
            text_seq_labels = text_seq_labels.to(args.device)

            text_token_type_ids = text_token_type_ids.to(args.device)
            video_token_type_ids = video_token_type_ids.to(args.device)
            joint_token_type_ids = joint_token_type_ids.to(args.device)

            text_attention_masks = text_attention_masks.to(args.device)
            video_attention_masks = video_attention_masks.to(args.device)
            joint_attention_masks = joint_attention_masks.to(args.device)

            video_inputs = video_inputs.to(args.device)
            video_mask_labels = video_mask_labels.to(args.device)
            video_seq_labels = video_seq_labels.to(args.device)

            joint_inputs = joint_inputs.to(args.device)
            joint_mask_labels = joint_mask_labels.to(args.device)
            joint_labels = joint_labels.to(args.device)

            model.train()

            outputs = model(
                text_input_ids=text_inputs,
                video_input_ids=video_inputs,
                joint_input_ids=joint_inputs,
                text_token_type_ids=text_token_type_ids,
                video_token_type_ids=video_token_type_ids,
                joint_token_type_ids=joint_token_type_ids,
                text_attention_mask=text_attention_masks,
                video_attention_mask=video_attention_masks,
                joint_attention_mask=joint_attention_masks,
                text_masked_lm_labels=text_mask_labels,
                video_masked_lm_labels=video_mask_labels,
                joint_masked_lm_labels=joint_mask_labels,
                text_next_sentence_label=text_seq_labels,
                video_next_sentence_label=video_seq_labels,
                joint_vis_lin_label=joint_labels,
            )

            # The first four outputs are the total loss followed by the
            # text, video and joint losses (the latter three are only logged).
            loss = outputs[0]
            text_loss = outputs[1]
            video_loss = outputs[2]
            joint_loss = outputs[3]

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

                print('loss:', loss.item(), 'text loss:', text_loss.item(),
                      'video loss:', video_loss.item(), 'joint loss:',
                      joint_loss.item())

                # keep BERT embeddings frozen: zero the gradient rows listed
                # in globals.frozen_indices before the optimizer update
                unwrapped_model.bert.embeddings.word_embeddings.weight.grad[
                    globals.frozen_indices] = 0

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if (is_main_process and args.logging_steps > 0
                        and global_step % args.logging_steps == 0):
                    print('writing tf logs...')
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    tb_writer.add_scalar("text_loss", text_loss.item(),
                                         global_step)
                    tb_writer.add_scalar("video_loss", video_loss.item(),
                                         global_step)
                    tb_writer.add_scalar("joint_loss", joint_loss.item(),
                                         global_step)
                    logging_loss = tr_loss

                if (is_main_process and args.save_steps > 0
                        and global_step % args.save_steps == 0):
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir,
                        "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    unwrapped_model.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            # NOTE(review): `>` (not `>=`) means training stops only once
            # global_step has exceeded max_steps, i.e. one optimizer step past
            # it. Left unchanged so step counts and checkpoints stay the same.
            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if tb_writer is not None:
        tb_writer.close()

    # Guard against a run in which no optimizer step was ever taken.
    mean_loss = tr_loss / global_step if global_step > 0 else 0.0
    return global_step, mean_loss
Beispiel #17
0
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    #Train File
    parser.add_argument("--src_file",
                        default=None,
                        type=str,
                        help="The input data src file name.")
    parser.add_argument("--tgt_file",
                        default=None,
                        type=str,
                        help="The input data tgt file name.")
    parser.add_argument("--check_file",
                        default=None,
                        type=str,
                        help="The input check knowledge data file name")

    #KS File
    parser.add_argument("--ks_src_file",
                        default=None,
                        type=str,
                        help="The input ks data src file name.")
    parser.add_argument("--ks_tgt_file",
                        default=None,
                        type=str,
                        help="The input ks data tgt file name.")

    parser.add_argument("--predict_input_file",
                        default=None,
                        type=str,
                        help="predict_input_file")
    parser.add_argument("--predict_output_file",
                        default=None,
                        type=str,
                        help="predict_output_file")

    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")
    parser.add_argument("--predict_bleu",
                        default=0.2,
                        type=float,
                        help="The Predicted Bleu for KS Predict ")

    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        action='store_true',
                        help="Whether to run ks predict.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--train_avg_bpe_length",
                        default=25,
                        type=int,
                        help="average bpe length for train.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion_step",
        default=300,
        type=int,
        help=
        "Proportion of training to perform linear learning rate warmup for. ")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=67,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()

    assert Path(
        args.model_recover_path).exists(), "--model_recover_path doesn't exist"

    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    handler = logging.FileHandler(os.path.join(args.log_dir, "train.log"),
                                  encoding='UTF-8')
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)

    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)

    logger.addHandler(handler)
    logger.addHandler(console)

    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    #Random Seed

    #torch.backends.cudnn.enabled = False
    #torch.backends.cudnn.benchmark = False
    #torch.backends.cudnn.deterministic = True
    # if n_gpu > 0:
    # 	torch.cuda.manual_seed_all(args.seed)

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer(
    ) if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    #Data process pipelines
    bi_uni_pipeline = [
        seq2seq_loader.Preprocess4Seq2seq(
            args.max_pred,
            args.mask_prob,
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'max_len_a': args.max_len_a,
                'max_len_b': args.max_len_b,
                'trunc_seg': args.trunc_seg,
                'always_truncate_tail': args.always_truncate_tail
            },
            mask_source_words=args.mask_source_words,
            skipgram_prb=args.skipgram_prb,
            skipgram_size=args.skipgram_size,
            mask_whole_word=args.mask_whole_word,
            mode="s2s",
            has_oracle=args.has_sentence_oracle,
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift)
    ]
    C_bi_uni_pipeline = [
        seq2seq_loader.C_Preprocess4Seq2seq(
            args.max_pred,
            args.mask_prob,
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'max_len_a': args.max_len_a,
                'max_len_b': args.max_len_b,
                'trunc_seg': args.trunc_seg,
                'always_truncate_tail': args.always_truncate_tail
            },
            mask_source_words=args.mask_source_words,
            skipgram_prb=args.skipgram_prb,
            skipgram_size=args.skipgram_size,
            mask_whole_word=args.mask_whole_word,
            mode="s2s",
            has_oracle=args.has_sentence_oracle,
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift)
    ]
    ks_predict_bi_uni_pipeline = [
        seq2seq_loader.Preprocess4Seq2seq_predict(
            args.max_pred,
            args.mask_prob,
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'max_len_a': args.max_len_a,
                'max_len_b': args.max_len_b,
                'trunc_seg': args.trunc_seg,
                'always_truncate_tail': args.always_truncate_tail
            },
            mask_source_words=args.mask_source_words,
            skipgram_prb=args.skipgram_prb,
            skipgram_size=args.skipgram_size,
            mask_whole_word=args.mask_whole_word,
            mode="s2s",
            has_oracle=args.has_sentence_oracle,
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift)
    ]

    if args.do_train:
        print("Loading QKR Train Dataset", args.data_dir)
        file_oracle = None
        if args.has_sentence_oracle:
            file_oracle = os.path.join(args.data_dir, 'train.oracle')
        fn_src = os.path.join(args.data_dir,
                              args.src_file if args.src_file else 'train.src')
        fn_tgt = os.path.join(args.data_dir,
                              args.tgt_file if args.tgt_file else 'train.tgt')
        fn_check = os.path.join(args.data_dir, args.check_file)

        train_dataset = seq2seq_loader.C_Seq2SeqDataset(
            fn_src,
            fn_tgt,
            fn_check,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=C_bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

        print("Loading KS Train Dataset", args.data_dir)
        ks_fn_src = os.path.join(args.data_dir, args.ks_src_file)
        ks_fn_tgt = os.path.join(args.data_dir, args.ks_tgt_file)
        ks_train_dataset = seq2seq_loader.Seq2SeqDataset(
            ks_fn_src,
            ks_fn_tgt,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            ks_train_sampler = RandomSampler(ks_train_dataset,
                                             replacement=False)
            _batch_size = args.train_batch_size
        else:
            ks_train_sampler = DistributedSampler(ks_train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        ks_train_dataloader = torch.utils.data.DataLoader(
            ks_train_dataset,
            batch_size=_batch_size,
            sampler=ks_train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

        # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
        t_total = int(
            len(train_dataloader) * args.num_train_epochs /
            args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 + (
        1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()

    #Recover model
    if args.model_recover_path:
        logger.info(" ** ** * Recover model: %s ** ** * ",
                    args.model_recover_path)
        model_recover = torch.load(args.model_recover_path, map_location='cpu')
        global_step = 0

    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[S2S_SOS]"])

    model = BertForPreTrainingLossMask.from_pretrained(
        args.bert_model,
        state_dict=model_recover,
        num_labels=cls_num_labels,
        num_rel=0,
        type_vocab_size=type_vocab_size,
        config_path=args.config_path,
        task_idx=3,
        num_sentlvl_labels=num_sentlvl_labels,
        max_position_embeddings=args.max_position_embeddings,
        label_smoothing=args.label_smoothing,
        fp32_embedding=args.fp32_embedding,
        relax_projection=relax_projection,
        new_pos_ids=args.new_pos_ids,
        ffn_type=args.ffn_type,
        hidden_dropout_prob=args.hidden_dropout_prob,
        attention_probs_dropout_prob=args.attention_probs_dropout_prob,
        num_qkv=args.num_qkv,
        seg_emb=args.seg_emb,
        mask_word_id=mask_word_id,
        search_beam_size=5,
        length_penalty=0,
        eos_id=eos_word_ids,
        sos_id=sos_word_id,
        forbid_duplicate_ngrams=True,
        forbid_ignore_set=None,
        mode="s2s")

    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        model.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)

    model.tmp_bert_emb.word_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.word_embeddings.weight.clone())
    model.tmp_bert_emb.token_type_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.token_type_embeddings.weight.clone())
    model.tmp_bert_emb.position_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.position_embeddings.weight.clone())
    model.mul_bert_emb.word_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.word_embeddings.weight.clone())
    model.mul_bert_emb.token_type_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.token_type_embeddings.weight.clone())
    model.mul_bert_emb.position_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.position_embeddings.weight.clone())
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        model = DataParallelImbalance(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            from pytorch_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    if args.optim_recover_path is not None:
        logger.info(" ** ** * Recover optimizer from : {} ** ** * ".format(
            args.optim_recover_path))
        optim_recover = torch.load(args.optim_recover_path, map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info(
                " ** ** * Recover optimizer: dynamic_loss_scale ** ** * ")
            optimizer.dynamic_loss_scale = True

    #logger.info(" ** ** * CUDA.empty_cache() ** ** * ")
    torch.cuda.empty_cache()

    # ################# TRAIN ############################ #
    if args.do_train:
        max_F1 = 0
        best_step = 0
        logger.info(" ** ** * Running training ** ** * ")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        model.train()
        start_epoch = 1

        for i_epoch in trange(start_epoch,
                              start_epoch + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)

            step = 0
            for batch, ks_batch in zip(train_dataloader, ks_train_dataloader):
                # ################# E step + M step + Mutual Information Loss ############################ #
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]

                input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, tgt_pos, labels, ks_labels, check_ids = batch
                oracle_pos, oracle_weights, oracle_labels = None, None, None

                loss_tuple = model(input_ids,
                                   segment_ids,
                                   input_mask,
                                   lm_label_ids,
                                   is_next,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   masked_pos_2=oracle_pos,
                                   masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels,
                                   mask_qkv=mask_qkv,
                                   tgt_pos=tgt_pos,
                                   labels=labels.half(),
                                   ks_labels=ks_labels,
                                   check_ids=check_ids)

                masked_lm_loss, next_sentence_loss, KL_loss, Mutual_loss, Golden_loss, predict_kl_loss = loss_tuple
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    masked_lm_loss = masked_lm_loss.mean()
                    next_sentence_loss = next_sentence_loss.mean()
                    Mutual_loss = Mutual_loss.mean()
                    Golden_loss = Golden_loss.mean()
                    KL_loss = KL_loss.mean()
                    predict_kl_loss = predict_kl_loss.mean()

                loss = masked_lm_loss + next_sentence_loss + KL_loss + predict_kl_loss + Mutual_loss + Golden_loss
                logger.info("In{}step, masked_lm_loss:{}".format(
                    step, masked_lm_loss))
                logger.info("In{}step, KL_loss:{}".format(step, KL_loss))
                logger.info("In{}step, Mutual_loss:{}".format(
                    step, Mutual_loss))
                logger.info("In{}step, Golden_loss:{}".format(
                    step, Golden_loss))
                logger.info("In{}step, predict_kl_loss:{}".format(
                    step, predict_kl_loss))

                logger.info("******************************************* ")

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * warmup_linear(
                        global_step / t_total,
                        args.warmup_proportion_step / t_total)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                # ################# Knowledge Selection Loss ############################ #
                if random.randint(0, 4) == 0:
                    ks_batch = [
                        t.to(device) if t is not None else None
                        for t in ks_batch
                    ]

                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, _, labels, ks_labels = ks_batch
                    oracle_pos, oracle_weights, oracle_labels = None, None, None
                    loss_tuple = model(input_ids,
                                       segment_ids,
                                       input_mask,
                                       lm_label_ids,
                                       is_next,
                                       masked_pos=masked_pos,
                                       masked_weights=masked_weights,
                                       task_idx=task_idx,
                                       masked_pos_2=oracle_pos,
                                       masked_weights_2=oracle_weights,
                                       masked_labels_2=oracle_labels,
                                       mask_qkv=mask_qkv,
                                       labels=labels,
                                       ks_labels=ks_labels,
                                       train_ks=True)

                    ks_loss, _ = loss_tuple
                    if n_gpu > 1:  # mean() to average on multi-gpu.
                        ks_loss = ks_loss.mean()
                    loss = ks_loss

                    logger.info("In{}step, ks_loss:{}".format(step, ks_loss))
                    logger.info("******************************************* ")

                    # ensure that accumulated gradients are normalized
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    if args.fp16:
                        optimizer.backward(loss)
                        if amp_handle:
                            amp_handle._clear_cache()
                    else:
                        loss.backward()
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / t_total,
                            args.warmup_proportion_step / t_total)
                        if args.fp16:
                            # modify learning rate with special warm up BERT uses
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = lr_this_step
                        optimizer.step()
                        optimizer.zero_grad()

                step += 1
                ###################### Eval Every 5000 Step ############################ #
                if (global_step + 1) % 5000 == 0:
                    next_i = 0
                    model.eval()

                    # Know Rank Stage
                    logger.info(" ** ** * DEV Know Selection Begin ** ** * ")
                    with open(os.path.join(args.data_dir,
                                           args.predict_input_file),
                              "r",
                              encoding="utf-8") as file:
                        src_file = file.readlines()
                    with open(os.path.join(args.data_dir,
                                           "train_tgt_pad.empty"),
                              "r",
                              encoding="utf-8") as file:
                        tgt_file = file.readlines()
                    with open(os.path.join(args.data_dir,
                                           args.predict_output_file),
                              "w",
                              encoding="utf-8") as out:
                        while next_i < len(src_file):
                            batch_src = src_file[next_i:next_i +
                                                 args.eval_batch_size]
                            batch_tgt = tgt_file[next_i:next_i +
                                                 args.eval_batch_size]

                            next_i += args.eval_batch_size

                            ex_list = []
                            for src, tgt in zip(batch_src, batch_tgt):
                                src_tk = data_tokenizer.tokenize(src.strip())
                                tgt_tk = data_tokenizer.tokenize(tgt.strip())
                                ex_list.append((src_tk, tgt_tk))

                            batch = []
                            for idx in range(len(ex_list)):
                                instance = ex_list[idx]
                                for proc in ks_predict_bi_uni_pipeline:
                                    instance = proc(instance)
                                    batch.append(instance)

                            batch_tensor = seq2seq_loader.batch_list_to_batch_tensors(
                                batch)
                            batch = [
                                t.to(device) if t is not None else None
                                for t in batch_tensor
                            ]

                            input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx = batch

                            predict_bleu = args.predict_bleu * torch.ones(
                                [input_ids.shape[0]], device=input_ids.device)
                            oracle_pos, oracle_weights, oracle_labels = None, None, None
                            with torch.no_grad():
                                logits = model(input_ids,
                                               segment_ids,
                                               input_mask,
                                               lm_label_ids,
                                               is_next,
                                               masked_pos=masked_pos,
                                               masked_weights=masked_weights,
                                               task_idx=task_idx,
                                               masked_pos_2=oracle_pos,
                                               masked_weights_2=oracle_weights,
                                               masked_labels_2=oracle_labels,
                                               mask_qkv=mask_qkv,
                                               labels=predict_bleu,
                                               train_ks=True)

                                logits = torch.nn.functional.softmax(logits,
                                                                     dim=1)
                                labels = logits[:, 1].cpu().numpy()
                                for i in range(len(labels)):
                                    line = batch_src[i].strip()
                                    line += "\t"
                                    line += str(labels[i])
                                    out.write(line)
                                    out.write("\n")

                    data_path = os.path.join(args.data_dir,
                                             "qkr_dev.ks_score.tk")
                    src_path = os.path.join(args.data_dir, "qkr_dev.src.tk")
                    src_out_path = os.path.join(args.data_dir,
                                                "rank_qkr_dev.src.tk")
                    tgt_path = os.path.join(args.data_dir, "qkr_dev.tgt")

                    knowledge_selection(data_path, src_path, src_out_path)
                    logger.info(" ** ** * DEV Know Selection End ** ** * ")

                    # Decode Stage
                    logger.info(" ** ** * Dev Decode Begin ** ** * ")
                    with open(src_out_path, encoding="utf-8") as file:
                        dev_src_lines = file.readlines()
                    with open(tgt_path, encoding="utf-8") as file:
                        golden_response_lines = file.readlines()

                    decode_result = decode_batch(model, dev_src_lines)
                    logger.info(" ** ** * Dev Decode End ** ** * ")

                    # Compute dev F1
                    assert len(decode_result) == len(golden_response_lines)
                    C_F1 = f_one(decode_result, golden_response_lines)[0]
                    logger.info(
                        "** ** * Current F1 is {} ** ** * ".format(C_F1))
                    if C_F1 < max_F1:
                        logger.info(
                            "** ** * Current F1 is lower than Previous F1. So Stop Training ** ** * "
                        )
                        logger.info(
                            "** ** * The best model is {} ** ** * ".format(
                                best_step))
                        break
                    else:
                        max_F1 = C_F1
                        best_step = step
                        logger.info(
                            "** ** * Current F1 is larger than Previous F1. So Continue Training ** ** * "
                        )

                    # Save trained model
                    if (args.local_rank == -1
                            or torch.distributed.get_rank() == 0):
                        logger.info(
                            "** ** * Saving fine-tuned model and optimizer ** ** * "
                        )
                        model_to_save = model.module if hasattr(
                            model,
                            'module') else model  # Only save the model it-self
                        output_model_file = os.path.join(
                            args.output_dir,
                            "model.{}_{}.bin".format(i_epoch, global_step))
                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                        output_optim_file = os.path.join(
                            args.output_dir, "optim.bin")
                        torch.save(optimizer.state_dict(), output_optim_file)

                        #logger.info(" ** ** * CUDA.empty_cache() ** ** * ")
                        torch.cuda.empty_cache()

    # ################# Predict ############################ #
    if args.do_predict:
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq_predict(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_a': args.max_len_a,
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_source_words=args.mask_source_words,
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                mode="s2s",
                has_oracle=args.has_sentence_oracle,
                num_qkv=args.num_qkv,
                s2s_special_token=args.s2s_special_token,
                s2s_add_segment=args.s2s_add_segment,
                s2s_share_segment=args.s2s_share_segment,
                pos_shift=args.pos_shift)
        ]

        next_i = 0
        model.eval()

        with open(os.path.join(args.data_dir, args.predict_input_file),
                  "r",
                  encoding="utf-8") as file:
            src_file = file.readlines()
        with open("train_tgt_pad.empty", "r", encoding="utf-8") as file:
            tgt_file = file.readlines()
        with open(os.path.join(args.data_dir, args.predict_output_file),
                  "w",
                  encoding="utf-8") as out:
            logger.info("** ** * Continue knowledge ranking ** ** * ")
            for next_i in tqdm(
                    range(len(src_file) // args.eval_batch_size + 1)):
                #while next_i < len(src_file):
                batch_src = src_file[next_i *
                                     args.eval_batch_size:(next_i + 1) *
                                     args.eval_batch_size]
                batch_tgt = tgt_file[next_i *
                                     args.eval_batch_size:(next_i + 1) *
                                     args.eval_batch_size]
                #next_i += args.eval_batch_size

                ex_list = []
                for src, tgt in zip(batch_src, batch_tgt):
                    src_tk = data_tokenizer.tokenize(src.strip())
                    tgt_tk = data_tokenizer.tokenize(tgt.strip())
                    ex_list.append((src_tk, tgt_tk))

                batch = []
                for idx in range(len(ex_list)):
                    instance = ex_list[idx]
                    for proc in bi_uni_pipeline:
                        instance = proc(instance)
                        batch.append(instance)

                batch_tensor = seq2seq_loader.batch_list_to_batch_tensors(
                    batch)
                batch = [
                    t.to(device) if t is not None else None
                    for t in batch_tensor
                ]

                input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx = batch

                predict_bleu = args.predict_bleu * torch.ones(
                    [input_ids.shape[0]], device=input_ids.device)
                oracle_pos, oracle_weights, oracle_labels = None, None, None
                with torch.no_grad():
                    logits = model(input_ids,
                                   segment_ids,
                                   input_mask,
                                   lm_label_ids,
                                   is_next,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   masked_pos_2=oracle_pos,
                                   masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels,
                                   mask_qkv=mask_qkv,
                                   labels=predict_bleu,
                                   train_ks=True)

                    logits = torch.nn.functional.softmax(logits, dim=1)
                    labels = logits[:, 1].cpu().numpy()
                    for i in range(len(labels)):
                        line = batch_src[i].strip()
                        line += "\t"
                        line += str(labels[i])
                        out.write(line)
                        out.write("\n")
Beispiel #18
0
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)
    elif output_mode == "regression":
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.float)

    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                               all_label_ids)
    if not is_distributed:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)

    for epoch_num in trange(int(args.num_train_epochs), desc="Epoch"):
        model.train()
        train_sampler.set_epoch(epoch_num)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch

            # define a new function to compute loss values for both output_modes
            logits = model(input_ids, segment_ids, input_mask, labels=None)

            if output_mode == "classification":
                if args.focal:
                    loss_fct = FocalLoss(class_num=num_labels,