Esempio n. 1
0
 def load_model(self):
     """Build the GPT-2 sequence classifier and its optimizer, optionally
     enabling apex AMP and resuming from a checkpoint.

     Sets ``self.model``, ``self.opt`` and ``self.current_ephoc``.

     Apex requires ``amp.initialize`` to be called *before* model, optimizer
     and amp state dicts are restored. The fp16 conversion therefore runs
     first and the checkpoint is loaded afterwards.
     """
     self.model = GPT2SequenceClassifierModel(
         hidden_size=self.embed_dim,
         num_classes=len(self.classes),
         gpt_model_name=self.model_name,
         max_seq_length=self.max_seq_length,
         finetune_GPT2=self.finetune_GPT2)
     self.model.to(device)
     self.opt = self.optimizer(self.model.parameters())

     # Bound up front so the checkpoint branch can test for it. Previously the
     # function-local `from apex import amp` came *after* its first use, which
     # made `amp.load_state_dict` raise UnboundLocalError.
     amp = None
     if self.fp16:
         from apex import amp
         # inspired by: https://github.com/huggingface/transformers/blob/master/examples/question-answering/run_squad.py
         # Ensure fp16 execution of torch.einsum when fp16 is set; otherwise it
         # defaults to "promote" mode and runs in fp32. Running with opt level
         # "O2" removes the need for this, but it is still valid. This must
         # happen before amp.initialize.
         amp.register_half_function(torch, "einsum")
         print("Converting models and optimizer to FP16")
         self.model, self.opt = amp.initialize(self.model,
                                               self.opt,
                                               opt_level="O1")

     if self.checkpoint_path:
         # load_checkpoint() yields (model state, optimizer state, epoch, amp state).
         model_chk, opt_chk, self.current_ephoc, amp_chk = self.load_checkpoint()
         print("continuing training from checkpoint at ephoc: ",
               self.current_ephoc)
         self.model.load_state_dict(model_chk)
         self.opt.load_state_dict(opt_chk)
         if amp_chk:
             if amp is not None:
                 amp.load_state_dict(amp_chk)
             else:
                 # The checkpoint came from an fp16 run but fp16 is disabled
                 # now; the AMP loss-scaler state has nothing to apply to.
                 print("ignoring AMP state in checkpoint because fp16 is disabled")
     else:
         self.current_ephoc = 0
def main():
    """Load the question-generation (GPT-2) and question-answering (BERT)
    models, optionally wrap them with apex AMP, and run ``generate``.

    Raises:
        ValueError: if ``args.output_dir`` exists, is non-empty, and
            ``--overwrite_output_dir`` was not given.
        ImportError: if ``--fp16`` is requested but apex is not installed.
    """
    parser = get_parser()
    args = parser.parse_args()

    if (os.path.exists(args.output_dir) and os.listdir(args.output_dir)
            and not args.overwrite_output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    # Set device
    args.device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)
    logging.getLogger("transformers.generation_utils").setLevel(logging.ERROR)

    # Load pretrained question generation model and tokenizer
    GPT2_tokenizer = GPT2Tokenizer.from_pretrained(
        args.question_generation_model, do_lower_case=args.do_lower_case)
    GPT2_model = GPT2LMHeadModel.from_pretrained(
        args.question_generation_model)
    GPT2_model.prepare_inputs_for_generation = prepare_inputs_for_generation
    GPT2_model.eval()
    GPT2_model.to(args.device)

    # Load pretrained question answering model and tokenizer
    BERT_tokenizer = BertTokenizer.from_pretrained(
        args.answering_model, do_lower_case=args.do_lower_case)
    BERT_model = BertForQuestionAnswering.from_pretrained(args.answering_model)
    BERT_model.eval()
    BERT_model.to(args.device)

    logging.info("Parameters %s", args)

    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err
        # Ensure fp16 execution of torch.einsum; otherwise it defaults to
        # "promote" mode and runs in fp32. Running `--fp16_opt_level="O2"`
        # removes the need for this, but it is still valid. Must precede
        # amp.initialize.
        amp.register_half_function(torch, "einsum")
        # apex documents that amp.initialize should be called only once per
        # process; it accepts a list of models and returns them in order.
        GPT2_model, BERT_model = amp.initialize(
            [GPT2_model, BERT_model], opt_level=args.fp16_opt_level)

    generate(args, GPT2_tokenizer, GPT2_model, BERT_tokenizer, BERT_model)
Esempio n. 3
0
    def __init__(self,
                 args,
                 device):
        """Load the reader model and tokenizer selected by ``args.reader_version``.

        Args:
            args: parsed options. Reads ``reader_version``, ``reader_path``,
                ``do_lower_case``, ``fp16`` and ``fp16_opt_level``.
            device: torch device the model is moved to; stored as ``self.device``.

        Raises:
            RuntimeError: if ``args.reader_version`` is not a supported value.
            ImportError: if ``args.fp16`` is set but apex is not installed.
        """
        print('initializing Reader...', flush=True)

        # Pick the model class first, then load it with a single call.
        # Only the chosen class name is ever evaluated.
        version = args.reader_version
        if version == 'bert':
            model_cls = BertForQuestionAnsweringConfidence
        elif version == 'iter':
            model_cls = IterBertForQuestionAnsweringConfidence
        elif version == 'iter_v2':
            model_cls = IterBertForQuestionAnsweringConfidenceV2
        elif version == 'iter_v3':
            model_cls = IterBertForQuestionAnsweringConfidenceV3
        elif version == 'iter_v4':
            model_cls = IterBertForQuestionAnsweringConfidenceV4
        elif version == 'roberta':
            model_cls = RobertaForQuestionAnsweringConfidence
        elif version == 'roberta_iter':
            model_cls = IterRobertaForQuestionAnsweringConfidence
        else:
            raise RuntimeError(
                "Unsupported reader_version: {!r} (expected one of 'bert', "
                "'iter', 'iter_v2', 'iter_v3', 'iter_v4', 'roberta', "
                "'roberta_iter')".format(version))
        self.model = model_cls.from_pretrained(args.reader_path,
                                               num_labels=4,
                                               no_masking=True)

        if version == 'bert':
            self.tokenizer = BertTokenizer.from_pretrained(args.reader_path, do_lower_case=args.do_lower_case)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(args.reader_path)
        self.device = device

        self.model.to(device)

        if args.fp16:
            try:
                from apex import amp
            except ImportError as err:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                ) from err

            # Under O1, einsum would otherwise be "promoted" to fp32.
            if args.fp16_opt_level == 'O1':
                amp.register_half_function(torch, "einsum")

            self.model = amp.initialize(self.model, opt_level=args.fp16_opt_level)

        self.model.eval()
        print('Done!', flush=True)
Esempio n. 4
0
def get_amp(fp16):
    """Return the apex ``amp`` module when fp16 is requested, else ``None``.

    When ``fp16`` is truthy, ``torch.einsum`` is registered as a
    half-precision function. Without this, under opt level "O1" it defaults
    to "promote" mode, where the operation runs in fp32. Running with
    ``--fp16_opt_level="O2"`` removes the need for this registration. The
    registration must happen before ``amp.initialize`` is called on any
    model.

    Args:
        fp16: whether apex mixed-precision execution is requested.

    Returns:
        The ``apex.amp`` module, or ``None`` if ``fp16`` is falsy.

    Raises:
        ImportError: if ``fp16`` is truthy but apex is not installed.
    """
    if not fp16:
        return None
    try:
        from apex import amp
    except ImportError as err:
        raise ImportError(
            "Please install apex from https://www.github.com/nvidia/apex") from err
    amp.register_half_function(torch, "einsum")
    return amp
Esempio n. 5
0
 def __init__(self,
              model: GWNet,
              scaler,
              lrate,
              wdecay,
              clip=5,
              lr_decay_rate=.97,
              fp16='',
              end_conv_lr=None):
     """Set up the Adam optimizer, per-epoch LR decay and optional apex AMP.

     Args:
         model: network to train; stored as ``self.model``.
         scaler: stored unchanged as ``self.scaler``.
         lrate: base Adam learning rate. When ``end_conv_lr`` is given and
             ``lrate`` is not positive, the second parameter group from
             ``model.conv_group`` is left out of the optimizer.
         wdecay: Adam weight decay.
         clip: stored as ``self.clip``.
         lr_decay_rate: multiplicative LR factor, raised to the epoch index
             by ``LambdaLR``.
         fp16: apex AMP opt level (e.g. ``"O1"``); a falsy value disables AMP.
         end_conv_lr: if truthy, learning rate for the first group of
             ``model.conv_group``. It also triggers ``model.freeze_group_b()``.
     """
     self.model = model
     if not end_conv_lr:
         self.optimizer = optim.Adam(self.model.parameters(),
                                     lr=lrate,
                                     weight_decay=wdecay)
     else:
         end_conv2, other_params = model.conv_group
         param_groups = [{'params': end_conv2, 'lr': end_conv_lr}]
         if lrate > 0:
             # Remaining parameters use the optimizer-wide default lr.
             param_groups.append({'params': other_params})
         self.model.freeze_group_b()
         self.optimizer = optim.Adam(param_groups,
                                     lr=lrate,
                                     weight_decay=wdecay)

     self.scaler = scaler
     self.clip = clip
     self.fp16 = fp16

     def _epoch_decay(epoch):
         return lr_decay_rate**epoch

     self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer,
                                                  lr_lambda=_epoch_decay)

     if not self.fp16:
         return

     try:
         from apex import amp  # Apex is only required if we use fp16 training
     except ImportError:
         raise ImportError(
             "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
         )
     # Register einsum for half precision before wrapping model and optimizer.
     amp.register_half_function(torch, 'einsum')
     self.model, self.optimizer = amp.initialize(self.model,
                                                 self.optimizer,
                                                 opt_level=self.fp16)
    def __init__(self,
                 in_feature=128,
                 out_feature=10575,
                 s=32.0,
                 m=0.50,
                 easy_margin=False):
        """Create the class-weight matrix and precompute margin constants.

        Presumably an ArcFace-style additive angular margin head (going by
        the class name and the cos(theta + m) constants); ``forward`` is not
        shown here.

        Args:
            in_feature: input feature size (columns of ``self.weight``).
            out_feature: number of output classes (rows of ``self.weight``).
            s: scale, stored as ``self.s``.
            m: angular margin used for the trig constants below.
            easy_margin: stored as ``self.easy_margin``.
        """
        super(ArcMarginProduct, self).__init__()
        self.in_feature, self.out_feature = in_feature, out_feature
        self.s, self.m = s, m

        self.weight = Parameter(torch.Tensor(out_feature, in_feature))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m, self.sin_m = math.cos(m), math.sin(m)

        # make the function cos(theta+m) monotonic decreasing while theta in [0°,180°]
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

        # NOTE(review): relies on module-level `args` and `amp` that are not
        # parameters of this constructor. apex expects register_half_function
        # to be called before amp.initialize.
        if args.use_amp == True:
            amp.register_half_function(torch, 'where')
Esempio n. 7
0
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default='bert-base-uncased',
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        '--task',
        type=str,
        default=None,
        required=True,
        help="Task code in {hotpot_open, hotpot_distractor, squad, nq}")

    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=378,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=1,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=5,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam. (def: 5e-5)")
    parser.add_argument("--num_train_epochs",
                        default=5.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument('--local_rank', default=-1, type=int)

    # RNN graph retriever-specific parameters
    parser.add_argument("--example_limit", default=None, type=int)

    parser.add_argument("--max_para_num", default=10, type=int)
    parser.add_argument(
        "--neg_chunk",
        default=8,
        type=int,
        help="The chunk size of negative examples during training (to "
        "reduce GPU memory consumption with negative sampling)")
    parser.add_argument(
        "--eval_chunk",
        default=100000,
        type=int,
        help=
        "The chunk size of evaluation examples (to reduce RAM consumption during evaluation)"
    )
    parser.add_argument(
        "--split_chunk",
        default=300,
        type=int,
        help=
        "The chunk size of BERT encoding during inference (to reduce GPU memory consumption)"
    )

    parser.add_argument('--train_file_path',
                        type=str,
                        default=None,
                        help="File path to the training data")
    parser.add_argument('--dev_file_path',
                        type=str,
                        default=None,
                        help="File path to the eval data")

    parser.add_argument('--beam', type=int, default=1, help="Beam size")
    parser.add_argument('--min_select_num',
                        type=int,
                        default=1,
                        help="Minimum number of selected paragraphs")
    parser.add_argument('--max_select_num',
                        type=int,
                        default=3,
                        help="Maximum number of selected paragraphs")
    parser.add_argument(
        "--use_redundant",
        action='store_true',
        help="Whether to use simulated seqs (only for training)")
    parser.add_argument(
        "--use_multiple_redundant",
        action='store_true',
        help="Whether to use multiple simulated seqs (only for training)")
    parser.add_argument(
        '--max_redundant_num',
        type=int,
        default=100000,
        help=
        "Whether to limit the number of the initial TF-IDF pool (only for open-domain eval)"
    )
    parser.add_argument(
        "--no_links",
        action='store_true',
        help=
        "Whether to omit any links (or in other words, only use TF-IDF-based paragraphs)"
    )
    parser.add_argument("--pruning_by_links",
                        action='store_true',
                        help="Whether to do pruning by links (and top 1)")
    parser.add_argument(
        "--expand_links",
        action='store_true',
        help=
        "Whether to expand links with paragraphs in the same article (for NQ)")
    parser.add_argument(
        '--tfidf_limit',
        type=int,
        default=None,
        help=
        "Whether to limit the number of the initial TF-IDF pool (only for open-domain eval)"
    )

    parser.add_argument("--pred_file",
                        default=None,
                        type=str,
                        help="File name to write paragraph selection results")
    parser.add_argument("--tagme",
                        action='store_true',
                        help="Whether to use tagme at inference")
    parser.add_argument(
        '--topk',
        type=int,
        default=2,
        help="Whether to use how many paragraphs from the previous steps")

    parser.add_argument(
        "--model_suffix",
        default=None,
        type=str,
        help="Suffix to load a model file ('pytorch_model_' + suffix +'.bin')")

    parser.add_argument("--db_save_path",
                        default=None,
                        type=str,
                        help="File path to DB")
    parser.add_argument("--fp16", default=False, action='store_true')
    parser.add_argument("--fp16_opt_level", default="O1", type=str)
    parser.add_argument("--do_label",
                        default=False,
                        action='store_true',
                        help="For pre-processing features only.")

    parser.add_argument("--oss_cache_dir", default=None, type=str)
    parser.add_argument("--cache_dir", default=None, type=str)
    parser.add_argument("--dist",
                        default=False,
                        action='store_true',
                        help='use distributed training.')
    parser.add_argument("--save_steps", default=5000, type=int)
    parser.add_argument("--resume", default=None, type=int)
    parser.add_argument("--oss_pretrain", default=None, type=str)
    parser.add_argument("--model_version", default='v1', type=str)
    parser.add_argument("--disable_rnn_layer_norm",
                        default=False,
                        action='store_true')

    args = parser.parse_args()

    if args.dist:
        dist.init_process_group(backend='nccl')
        print(f"local rank: {args.local_rank}")
        print(f"global rank: {dist.get_rank()}")
        print(f"world size: {dist.get_world_size()}")

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')

    if args.dist:
        global_rank = dist.get_rank()
        world_size = dist.get_world_size()
        if world_size > 1:
            args.local_rank = global_rank

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if args.train_file_path is not None:
        do_train = True

        if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".
                format(args.output_dir))
        if args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir, exist_ok=True)

    elif args.dev_file_path is not None:
        do_train = False

    else:
        raise ValueError(
            'One of train_file_path: {} or dev_file_path: {} must be non-None'.
            format(args.train_file_path, args.dev_file_path))

    processor = DataProcessor()

    # Configurations of the graph retriever
    graph_retriever_config = GraphRetrieverConfig(
        example_limit=args.example_limit,
        task=args.task,
        max_seq_length=args.max_seq_length,
        max_select_num=args.max_select_num,
        max_para_num=args.max_para_num,
        tfidf_limit=args.tfidf_limit,
        train_file_path=args.train_file_path,
        use_redundant=args.use_redundant,
        use_multiple_redundant=args.use_multiple_redundant,
        max_redundant_num=args.max_redundant_num,
        dev_file_path=args.dev_file_path,
        beam=args.beam,
        min_select_num=args.min_select_num,
        no_links=args.no_links,
        pruning_by_links=args.pruning_by_links,
        expand_links=args.expand_links,
        eval_chunk=args.eval_chunk,
        tagme=args.tagme,
        topk=args.topk,
        db_save_path=args.db_save_path,
        disable_rnn_layer_norm=args.disable_rnn_layer_norm)

    logger.info(graph_retriever_config)
    logger.info(args)

    tokenizer = AutoTokenizer.from_pretrained(args.bert_model)

    if args.model_version == 'roberta':
        from modeling_graph_retriever_roberta import RobertaForGraphRetriever
    elif args.model_version == 'v3':
        from modeling_graph_retriever_roberta import RobertaForGraphRetrieverIterV3 as RobertaForGraphRetriever
    else:
        raise RuntimeError()

    ##############################
    # Training                   #
    ##############################
    if do_train:
        _model_state_dict = None
        if args.oss_pretrain is not None:
            _model_state_dict = torch.load(load_pretrain_from_oss(
                args.oss_pretrain),
                                           map_location='cpu')
            logger.info(f"Loaded pretrained model from {args.oss_pretrain}")

        if args.resume is not None:
            _model_state_dict = torch.load(load_buffer_from_oss(
                os.path.join(args.oss_cache_dir,
                             f"pytorch_model_{args.resume}.bin")),
                                           map_location='cpu')

        model = RobertaForGraphRetriever.from_pretrained(
            args.bert_model,
            graph_retriever_config=graph_retriever_config,
            state_dict=_model_state_dict)

        model.to(device)

        global_step = 0

        POSITIVE = 1.0
        NEGATIVE = 0.0

        _cache_file_name = f"cache_roberta_train_{args.max_seq_length}_{args.max_para_num}"
        _examples_cache_file_name = f"examples_{_cache_file_name}"
        _features_cache_file_name = f"features_{_cache_file_name}"

        # Load training examples
        logger.info(f"Loading training examples and features.")
        try:
            if args.cache_dir is not None and os.path.exists(
                    os.path.join(args.cache_dir, _features_cache_file_name)):
                logger.info(
                    f"Loading pre-processed features from {os.path.join(args.cache_dir, _features_cache_file_name)}"
                )
                train_features = torch.load(
                    os.path.join(args.cache_dir, _features_cache_file_name))
            else:
                # train_examples = torch.load(load_buffer_from_oss(os.path.join(oss_features_cache_dir,
                #                                                               _examples_cache_file_name)))
                train_features = torch.load(
                    load_buffer_from_oss(
                        os.path.join(oss_features_cache_dir,
                                     _features_cache_file_name)))
                logger.info(
                    f"Pre-processed features are loaded from oss: "
                    f"{os.path.join(oss_features_cache_dir, _features_cache_file_name)}"
                )
        except:
            train_examples = processor.get_train_examples(
                graph_retriever_config)
            train_features = convert_examples_to_features(
                train_examples,
                args.max_seq_length,
                args.max_para_num,
                graph_retriever_config,
                tokenizer,
                train=True)
            logger.info(
                f"Saving pre-processed features into oss: {oss_features_cache_dir}"
            )
            torch_save_to_oss(
                train_examples,
                os.path.join(oss_features_cache_dir,
                             _examples_cache_file_name))
            torch_save_to_oss(
                train_features,
                os.path.join(oss_features_cache_dir,
                             _features_cache_file_name))

        if args.do_label:
            logger.info("Finished.")
            return

        # len(train_examples) and len(train_features) can be different, depending on the redundant setting
        num_train_steps = int(
            len(train_features) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

        # Prepare optimizer
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'layer_norm']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        t_total = num_train_steps
        if args.local_rank != -1:
            t_total = t_total // dist.get_world_size()

        optimizer = AdamW(optimizer_grouped_parameters,
                          betas=(0.9, 0.98),
                          lr=args.learning_rate)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, int(t_total * args.warmup_proportion), t_total)

        logger.info(optimizer)
        if args.fp16:
            from apex import amp
            amp.register_half_function(torch, "einsum")

            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.fp16_opt_level)

        if args.local_rank != -1:
            if args.fp16_opt_level == 'O2':
                try:
                    import apex
                    model = apex.parallel.DistributedDataParallel(
                        model, delay_allreduce=True)
                except ImportError:
                    model = torch.nn.parallel.DistributedDataParallel(
                        model, find_unused_parameters=True)
            else:
                model = torch.nn.parallel.DistributedDataParallel(
                    model, find_unused_parameters=True)

        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        if args.resume is not None:
            _amp_state_dict = os.path.join(args.oss_cache_dir,
                                           f"amp_{args.resume}.bin")
            _optimizer_state_dict = os.path.join(
                args.oss_cache_dir, f"optimizer_{args.resume}.pt")
            _scheduler_state_dict = os.path.join(
                args.oss_cache_dir, f"scheduler_{args.resume}.pt")

            amp.load_state_dict(
                torch.load(load_buffer_from_oss(_amp_state_dict)))
            optimizer.load_state_dict(
                torch.load(load_buffer_from_oss(_optimizer_state_dict)))
            scheduler.load_state_dict(
                torch.load(load_buffer_from_oss(_scheduler_state_dict)))

            logger.info(f"Loaded resumed state dict of step {args.resume}")

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_features))
        logger.info("  Instantaneous batch size per GPU = %d",
                    args.train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            args.train_batch_size * args.gradient_accumulation_steps *
            (dist.get_world_size() if args.local_rank != -1 else 1),
        )
        logger.info("  Gradient Accumulation steps = %d",
                    args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        model.train()
        epc = 0
        # test
        if args.local_rank in [-1, 0]:
            if args.fp16:
                amp_file = os.path.join(args.oss_cache_dir,
                                        f"amp_{global_step}.bin")
                torch_save_to_oss(amp.state_dict(), amp_file)
            optimizer_file = os.path.join(args.oss_cache_dir,
                                          f"optimizer_{global_step}.pt")
            torch_save_to_oss(optimizer.state_dict(), optimizer_file)
            scheduler_file = os.path.join(args.oss_cache_dir,
                                          f"scheduler_{global_step}.pt")
            torch_save_to_oss(scheduler.state_dict(), scheduler_file)

        tr_loss = 0
        for _ in range(int(args.num_train_epochs)):
            logger.info('Epoch ' + str(epc + 1))

            TOTAL_NUM = len(train_features)
            train_start_index = 0
            CHUNK_NUM = 8
            train_chunk = TOTAL_NUM // CHUNK_NUM
            chunk_index = 0

            random.shuffle(train_features)

            save_retry = False
            while train_start_index < TOTAL_NUM:
                train_end_index = min(train_start_index + train_chunk - 1,
                                      TOTAL_NUM - 1)
                chunk_len = train_end_index - train_start_index + 1

                if args.resume is not None and global_step < args.resume:
                    _chunk_steps = int(
                        math.ceil(chunk_len * 1.0 / args.train_batch_size /
                                  (1 if args.local_rank == -1 else
                                   dist.get_world_size())))
                    _chunk_steps = _chunk_steps // args.gradient_accumulation_steps
                    if global_step + _chunk_steps <= args.resume:
                        global_step += _chunk_steps
                        train_start_index = train_end_index + 1
                        continue

                train_features_ = train_features[
                    train_start_index:train_start_index + chunk_len]

                all_input_ids = torch.tensor(
                    [f.input_ids for f in train_features_], dtype=torch.long)
                all_input_masks = torch.tensor(
                    [f.input_masks for f in train_features_], dtype=torch.long)
                all_segment_ids = torch.tensor(
                    [f.segment_ids for f in train_features_], dtype=torch.long)
                all_output_masks = torch.tensor(
                    [f.output_masks for f in train_features_],
                    dtype=torch.float)
                all_num_paragraphs = torch.tensor(
                    [f.num_paragraphs for f in train_features_],
                    dtype=torch.long)
                all_num_steps = torch.tensor(
                    [f.num_steps for f in train_features_], dtype=torch.long)
                train_data = TensorDataset(all_input_ids, all_input_masks,
                                           all_segment_ids, all_output_masks,
                                           all_num_paragraphs, all_num_steps)

                if args.local_rank != -1:
                    train_sampler = torch.utils.data.DistributedSampler(
                        train_data)
                else:
                    train_sampler = RandomSampler(train_data)
                train_dataloader = DataLoader(train_data,
                                              sampler=train_sampler,
                                              batch_size=args.train_batch_size,
                                              pin_memory=True,
                                              num_workers=4)

                if args.local_rank != -1:
                    train_dataloader.sampler.set_epoch(epc)

                logger.info('Examples from ' + str(train_start_index) +
                            ' to ' + str(train_end_index))
                for step, batch in enumerate(
                        tqdm(train_dataloader,
                             desc="Iteration",
                             disable=args.local_rank not in [-1, 0])):
                    if args.resume is not None and global_step < args.resume:
                        if (step + 1) % args.gradient_accumulation_steps == 0:
                            global_step += 1
                        continue

                    input_masks = batch[1]
                    batch_max_len = input_masks.sum(dim=2).max().item()

                    num_paragraphs = batch[4]
                    batch_max_para_num = num_paragraphs.max().item()

                    num_steps = batch[5]
                    batch_max_steps = num_steps.max().item()

                    # output_masks_cpu = (batch[3])[:, :batch_max_steps, :batch_max_para_num + 1]

                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_masks, segment_ids, output_masks, _, _ = batch
                    B = input_ids.size(0)

                    input_ids = input_ids[:, :batch_max_para_num, :
                                          batch_max_len]
                    input_masks = input_masks[:, :batch_max_para_num, :
                                              batch_max_len]
                    segment_ids = segment_ids[:, :batch_max_para_num, :
                                              batch_max_len]
                    output_masks = output_masks[:, :batch_max_steps, :
                                                batch_max_para_num +
                                                1]  # 1 for EOE

                    target = torch.zeros(output_masks.size()).fill_(
                        NEGATIVE)  # (B, NUM_STEPS, |P|+1) <- 1 for EOE
                    for i in range(B):
                        output_masks[i, :num_steps[i], -1] = 1.0  # for EOE

                        for j in range(num_steps[i].item() - 1):
                            target[i, j, j].fill_(POSITIVE)

                        target[i, num_steps[i] - 1, -1].fill_(POSITIVE)
                    target = target.to(device)

                    neg_start = batch_max_steps - 1
                    while neg_start < batch_max_para_num:
                        neg_end = min(neg_start + args.neg_chunk - 1,
                                      batch_max_para_num - 1)
                        neg_len = (neg_end - neg_start + 1)

                        input_ids_ = torch.cat(
                            (input_ids[:, :batch_max_steps - 1, :],
                             input_ids[:, neg_start:neg_start + neg_len, :]),
                            dim=1)
                        input_masks_ = torch.cat(
                            (input_masks[:, :batch_max_steps - 1, :],
                             input_masks[:, neg_start:neg_start + neg_len, :]),
                            dim=1)
                        segment_ids_ = torch.cat(
                            (segment_ids[:, :batch_max_steps - 1, :],
                             segment_ids[:, neg_start:neg_start + neg_len, :]),
                            dim=1)
                        output_masks_ = torch.cat(
                            (output_masks[:, :, :batch_max_steps - 1],
                             output_masks[:, :, neg_start:neg_start + neg_len],
                             output_masks[:, :, batch_max_para_num:
                                          batch_max_para_num + 1]),
                            dim=2)
                        target_ = torch.cat(
                            (target[:, :, :batch_max_steps - 1],
                             target[:, :, neg_start:neg_start + neg_len],
                             target[:, :,
                                    batch_max_para_num:batch_max_para_num +
                                    1]),
                            dim=2)

                        if neg_start != batch_max_steps - 1:
                            output_masks_[:, :, :batch_max_steps - 1] = 0.0
                            output_masks_[:, :, -1] = 0.0

                        loss = model(input_ids_, segment_ids_, input_masks_,
                                     output_masks_, target_, batch_max_steps)

                        if n_gpu > 1:
                            loss = loss.mean(
                            )  # mean() to average on multi-gpu.
                        if args.gradient_accumulation_steps > 1:
                            loss = loss / args.gradient_accumulation_steps

                        if args.fp16:
                            with amp.scale_loss(loss,
                                                optimizer) as scaled_loss:
                                scaled_loss.backward()
                        else:
                            loss.backward()

                        tr_loss += loss.item()
                        neg_start = neg_end + 1

                        # del input_ids_
                        # del input_masks_
                        # del segment_ids_
                        # del output_masks_
                        # del target_

                    if (step + 1) % args.gradient_accumulation_steps == 0:

                        if args.fp16:
                            torch.nn.utils.clip_grad_norm_(
                                amp.master_params(optimizer), 1.0)
                        else:
                            torch.nn.utils.clip_grad_norm_(
                                model.parameters(), 1.0)

                        optimizer.step()
                        scheduler.step()
                        # optimizer.zero_grad()
                        model.zero_grad()
                        global_step += 1

                        if global_step % 50 == 0:
                            _cur_steps = global_step if args.resume is None else global_step - args.resume
                            logger.info(
                                f"Training loss: {tr_loss / _cur_steps}\t"
                                f"Learning rate: {scheduler.get_lr()[0]}\t"
                                f"Global step: {global_step}")

                        if global_step % args.save_steps == 0:
                            if args.local_rank in [-1, 0]:
                                model_to_save = model.module if hasattr(
                                    model, 'module') else model
                                output_model_file = os.path.join(
                                    args.oss_cache_dir,
                                    f"pytorch_model_{global_step}.bin")
                                torch_save_to_oss(model_to_save.state_dict(),
                                                  output_model_file)

                            _suffix = "" if args.local_rank == -1 else f"_{args.local_rank}"
                            if args.fp16:
                                amp_file = os.path.join(
                                    args.oss_cache_dir,
                                    f"amp_{global_step}{_suffix}.bin")
                                torch_save_to_oss(amp.state_dict(), amp_file)
                            optimizer_file = os.path.join(
                                args.oss_cache_dir,
                                f"optimizer_{global_step}{_suffix}.pt")
                            torch_save_to_oss(optimizer.state_dict(),
                                              optimizer_file)
                            scheduler_file = os.path.join(
                                args.oss_cache_dir,
                                f"scheduler_{global_step}{_suffix}.pt")
                            torch_save_to_oss(scheduler.state_dict(),
                                              scheduler_file)

                            logger.info(
                                f"checkpoint of step {global_step} is saved to oss."
                            )

                    # del input_ids
                    # del input_masks
                    # del segment_ids
                    # del output_masks
                    # del target
                    # del batch

                chunk_index += 1
                train_start_index = train_end_index + 1

                # Save the model at the half of the epoch
                if (chunk_index == CHUNK_NUM // 2
                        or save_retry) and args.local_rank in [-1, 0]:
                    status = save(model, args.output_dir, str(epc + 0.5))
                    save_retry = (not status)

                del train_features_
                del all_input_ids
                del all_input_masks
                del all_segment_ids
                del all_output_masks
                del all_num_paragraphs
                del all_num_steps
                del train_data
                del train_sampler
                del train_dataloader
                gc.collect()

            # Save the model at the end of the epoch
            if args.local_rank in [-1, 0]:
                save(model, args.output_dir, str(epc + 1))
                # model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
                # output_model_file = os.path.join(args.oss_cache_dir, "pytorch_model_" + str(epc + 1) + ".bin")
                # torch_save_to_oss(model_to_save.state_dict(), output_model_file)

            epc += 1

    if do_train:
        return

    ##############################
    # Evaluation                 #
    ##############################
    assert args.model_suffix is not None

    if graph_retriever_config.db_save_path is not None:
        import sys
        sys.path.append('../')
        from pipeline.tfidf_retriever import TfidfRetriever
        tfidf_retriever = TfidfRetriever(graph_retriever_config.db_save_path,
                                         None)
    else:
        tfidf_retriever = None

    if args.oss_cache_dir is not None:
        file_name = 'pytorch_model_' + args.model_suffix + '.bin'
        model_state_dict = torch.load(
            load_buffer_from_oss(os.path.join(args.oss_cache_dir, file_name)))
    else:
        model_state_dict = load(args.output_dir, args.model_suffix)

    model = RobertaForGraphRetriever.from_pretrained(
        args.bert_model,
        state_dict=model_state_dict,
        graph_retriever_config=graph_retriever_config)
    model.to(device)

    model.eval()

    if args.pred_file is not None:
        pred_output = []

    eval_examples = processor.get_dev_examples(graph_retriever_config)

    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", args.eval_batch_size)

    TOTAL_NUM = len(eval_examples)
    eval_start_index = 0

    while eval_start_index < TOTAL_NUM:
        eval_end_index = min(
            eval_start_index + graph_retriever_config.eval_chunk - 1,
            TOTAL_NUM - 1)
        chunk_len = eval_end_index - eval_start_index + 1

        eval_features = convert_examples_to_features(
            eval_examples[eval_start_index:eval_start_index + chunk_len],
            args.max_seq_length, args.max_para_num, graph_retriever_config,
            tokenizer)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_masks = torch.tensor([f.input_masks for f in eval_features],
                                       dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_output_masks = torch.tensor(
            [f.output_masks for f in eval_features], dtype=torch.float)
        all_num_paragraphs = torch.tensor(
            [f.num_paragraphs for f in eval_features], dtype=torch.long)
        all_num_steps = torch.tensor([f.num_steps for f in eval_features],
                                     dtype=torch.long)
        all_ex_indices = torch.tensor([f.ex_index for f in eval_features],
                                      dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_masks,
                                  all_segment_ids, all_output_masks,
                                  all_num_paragraphs, all_num_steps,
                                  all_ex_indices)

        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        for input_ids, input_masks, segment_ids, output_masks, num_paragraphs, num_steps, ex_indices in tqdm(
                eval_dataloader, desc="Evaluating"):
            batch_max_len = input_masks.sum(dim=2).max().item()
            batch_max_para_num = num_paragraphs.max().item()

            batch_max_steps = num_steps.max().item()

            input_ids = input_ids[:, :batch_max_para_num, :batch_max_len]
            input_masks = input_masks[:, :batch_max_para_num, :batch_max_len]
            segment_ids = segment_ids[:, :batch_max_para_num, :batch_max_len]
            output_masks = output_masks[:, :batch_max_para_num +
                                        2, :batch_max_para_num + 1]
            output_masks[:, 1:, -1] = 1.0  # Ignore EOE in the first step

            input_ids = input_ids.to(device)
            input_masks = input_masks.to(device)
            segment_ids = segment_ids.to(device)
            output_masks = output_masks.to(device)

            examples = [
                eval_examples[eval_start_index + ex_indices[i].item()]
                for i in range(input_ids.size(0))
            ]

            with torch.no_grad():
                pred, prob, topk_pred, topk_prob = model.beam_search(
                    input_ids,
                    segment_ids,
                    input_masks,
                    examples=examples,
                    tokenizer=tokenizer,
                    retriever=tfidf_retriever,
                    split_chunk=args.split_chunk)

            for i in range(len(pred)):
                e = examples[i]
                titles = [e.title_order[p] for p in pred[i]]

                # Output predictions to a file
                if args.pred_file is not None:
                    pred_output.append({})
                    pred_output[-1]['q_id'] = e.guid

                    pred_output[-1]['titles'] = titles
                    pred_output[-1]['probs'] = []
                    for prob_ in prob[i]:
                        entry = {'EOE': prob_[-1]}
                        for j in range(len(e.title_order)):
                            entry[e.title_order[j]] = prob_[j]
                        pred_output[-1]['probs'].append(entry)

                    topk_titles = [[e.title_order[p] for p in topk_pred[i][j]]
                                   for j in range(len(topk_pred[i]))]
                    pred_output[-1]['topk_titles'] = topk_titles

                    topk_probs = []
                    for k in range(len(topk_prob[i])):
                        topk_probs.append([])
                        for prob_ in topk_prob[i][k]:
                            entry = {'EOE': prob_[-1]}
                            for j in range(len(e.title_order)):
                                entry[e.title_order[j]] = prob_[j]
                            topk_probs[-1].append(entry)
                    pred_output[-1]['topk_probs'] = topk_probs

                    # Output the selected paragraphs
                    context = {}
                    for ts in topk_titles:
                        for t in ts:
                            context[t] = e.all_paras[t]
                    pred_output[-1]['context'] = context

        eval_start_index = eval_end_index + 1

        del eval_features
        del all_input_ids
        del all_input_masks
        del all_segment_ids
        del all_output_masks
        del all_num_paragraphs
        del all_num_steps
        del all_ex_indices
        del eval_data

    if args.pred_file is not None:
        json.dump(pred_output, open(args.pred_file, 'w'))
Esempio n. 8
0
def _log_qa_scalars(writer, split, step, loss_value, ans_span_accuracy,
                    yes_no_span_accuracy, running_loss,
                    running_ans_span_accuracy, running_yes_no_span_accuracy):
    """Write one batch's QA metrics for ``split`` ('train' or 'val') via ``writer.add_scalar``.

    Per-batch accuracies are only written when truthy, mirroring how such
    batches are left out of the running accuracy averages in ``main``.
    """
    writer.add_scalar(f'loss/{split}', loss_value, step)
    if ans_span_accuracy:
        writer.add_scalar(f'ans_span_accuracy/{split}', ans_span_accuracy, step)
    if yes_no_span_accuracy:
        writer.add_scalar(f'yes_no_span_accuracy/{split}', yes_no_span_accuracy, step)
    writer.add_scalar(f'running_loss/{split}', running_loss, step)
    writer.add_scalar(f'running_ans_span_accuracy/{split}', running_ans_span_accuracy, step)
    writer.add_scalar(f'running_yes_no_span_accuracy/{split}', running_yes_no_span_accuracy, step)


def main(args):
    """Train a HotpotQA question-answering classifier, validating after every epoch.

    Each epoch runs one pass over the dataset's 'train' split (one optimizer
    step and one ReduceLROnPlateau step per batch) and one gradient-free pass
    over its 'val' split. Per-batch and running loss/accuracy values are
    written to a ``SummaryWriter``. Unless ``args.use_mini`` is set,
    ``update_train_state`` is called after each epoch, and training stops when
    it sets ``train_state['stop_early']``.

    Args:
        args: parsed command-line namespace. Fields read here include
            ``device``, ``cuda``, ``fp16``, ``pretrained_model_path``,
            ``header_mode``, ``freeze_layer_name``, ``weight_decay``,
            ``learning_rate``, ``adam_epsilon``, ``reload_from_files``,
            ``model_state_file``, ``json_train_path``, ``topN_sents``,
            ``max_length``, ``uncased``, ``permutations``, ``seed``,
            ``batch_size``, ``log_dir``, ``flush_secs``, ``num_epochs`` and
            ``use_mini``.

    Error handling:
        ``KeyboardInterrupt`` ends training gracefully and, when
        ``args.use_mini`` is set, deletes ``args.log_dir`` afterwards. Any
        other ``Exception`` is printed with its traceback and swallowed.
        ``SystemExit`` from the loss-explosion debug dumps now propagates.
        The writer and progress bars are always closed.
    """
    set_envs(args)
    print("Using: {}".format(args.device))

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_path, local_files_only=True)
    classifier = AutoQuestionAnswering.from_pretrained(model_path=args.pretrained_model_path,
                                                        header_mode=args.header_mode,
                                                        cls_index=tokenizer.cls_token_id)
    classifier.freeze_to_layer_by_name(args.freeze_layer_name)
    classifier.train()

    loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

    # Parameters whose names contain these substrings get no weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in classifier.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in classifier.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

    # Initialization
    opt_level = 'O1'
    if args.cuda:
        classifier = classifier.to(args.device)
        if args.fp16:
            # Apex only honours function registrations made *before*
            # amp.initialize, so register einsum first.
            amp.register_half_function(torch, "einsum")
            classifier, optimizer = amp.initialize(classifier, optimizer, opt_level=opt_level)

    checkpoint = None
    if args.reload_from_files:
        checkpoint = torch.load(args.model_state_file)
        classifier.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # amp state can only be restored into an fp16 (amp-initialized) run.
        if args.fp16 and checkpoint.get('amp') is not None:
            amp.load_state_dict(checkpoint['amp'])

    train_state = make_train_state(args)

    dataset = HotpotQA_QA_Dataset.build_dataset(args.json_train_path)
    dataset.set_parameters(tokenizer=tokenizer, topN_sents=args.topN_sents,
                           max_length=args.max_length, uncased=args.uncased,
                           permutations=args.permutations, random_seed=args.seed)
    print(dataset)

    # The scheduler is stepped once per training batch (see below), so
    # patience counts batches: 1/50 of the batch count of the dataset's
    # current split. Floor division is equivalent here because
    # ReduceLROnPlateau triggers on `num_bad_epochs > patience`.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                     mode='min',
                                                     factor=0.7,
                                                     patience=dataset.get_num_batches(args.batch_size) // 50)

    # Bind every name the handlers / cleanup use before entering `try`, so a
    # failure before the loops start cannot raise NameError in the handler.
    writer = None
    bars = []
    epoch_index = None
    batch_index = None
    interrupted = False
    try:
        writer = SummaryWriter(log_dir=args.log_dir, flush_secs=args.flush_secs)
        epoch_bar = tqdm(desc='training routine',
                         total=args.num_epochs,
                         position=0)
        bars.append(epoch_bar)

        dataset.set_split('train')
        train_bar = tqdm(desc='split=train',
                         total=dataset.get_num_batches(args.batch_size),
                         position=1)
        bars.append(train_bar)

        dataset.set_split('val')
        val_bar = tqdm(desc='split=val',
                       total=dataset.get_num_batches(args.batch_size),
                       position=1)
        bars.append(val_bar)

        # Continue the logging step counters stored in the checkpoint, if any.
        cursor_train = 0
        cursor_val = 0
        if checkpoint is not None and 'cursor_train' in checkpoint:
            cursor_train = checkpoint['cursor_train'] + 1
            cursor_val = checkpoint['cursor_val'] + 1

        for epoch_index in range(args.num_epochs):
            train_bar.n = 0
            val_bar.n = 0
            train_state['epoch_index'] = epoch_index

            # ---------------- training pass ----------------
            dataset.set_split('train')
            # A distinct, reproducible dataset seed for every epoch.
            dataset.random_seed = args.seed + epoch_index
            batch_generator = generate_QA_batches(dataset, shuffle=True,
                                                  batch_size=args.batch_size,
                                                  device=args.device)
            running_loss = 0.0
            running_ans_span_accuracy = 0.0
            running_yes_no_span_accuracy = 0.0

            # Batches whose accuracy is falsy are left out of the matching
            # running average (intent: skip batches whose accuracy
            # denominator is 0). NOTE(review): a genuine 0.0 accuracy is
            # skipped too.
            batch_index_for_yesnospan = 0
            batch_index_for_span = 0

            classifier.train()

            for batch_index, batch_dict in enumerate(batch_generator):
                optimizer.zero_grad()
                yes_no_span = batch_dict.pop('yes_no_span')
                res = classifier(**batch_dict)
                start_logits, end_logits, cls_logits = res[0], res[1], res[2]

                start_loss = loss_fct(start_logits, batch_dict['start_positions'])
                end_loss = loss_fct(end_logits, batch_dict['end_positions'])
                start_end_loss = (start_loss + end_loss) / 2
                if start_end_loss > 1e5:
                    # Debug dump for an exploding span loss, then abort the run.
                    print(start_logits.gather(-1, batch_dict['start_positions'].view(-1, 1)))
                    print(batch_dict['special_tokens_mask'].gather(-1, batch_dict['start_positions'].view(-1, 1)))
                    print(batch_dict['start_positions'])
                    print('')
                    print(end_logits.gather(-1, batch_dict['end_positions'].view(-1, 1)))
                    print(batch_dict['special_tokens_mask'].gather(-1, batch_dict['end_positions'].view(-1, 1)))
                    print(batch_dict['end_positions'])
                    exit()

                yes_no_span_loss = loss_fct(cls_logits, yes_no_span) / 2
                if yes_no_span_loss > 1e5:
                    # Debug dump for an exploding yes/no/span loss, then abort.
                    print(cls_logits)
                    print(yes_no_span)
                    exit()

                ans_span_accuracy = compute_span_accuracy(start_logits, batch_dict['start_positions'],
                                                          end_logits, batch_dict['end_positions'])
                yes_no_span_accuracy = compute_accuracy(cls_logits, yes_no_span)

                loss = start_end_loss + yes_no_span_loss
                # Incremental mean over the batches seen so far this epoch.
                running_loss += (loss.item() - running_loss) / (batch_index + 1)

                if ans_span_accuracy:
                    running_ans_span_accuracy += \
                        (ans_span_accuracy - running_ans_span_accuracy) / (batch_index_for_span + 1)
                    batch_index_for_span += 1

                if yes_no_span_accuracy:
                    running_yes_no_span_accuracy += \
                        (yes_no_span_accuracy - running_yes_no_span_accuracy) / (batch_index_for_yesnospan + 1)
                    batch_index_for_yesnospan += 1

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                optimizer.step()
                scheduler.step(running_loss)  # Update learning rate schedule (per batch)

                train_bar.set_postfix(running_loss=running_loss, epoch=epoch_index)
                train_bar.update()

                _log_qa_scalars(writer, 'train', cursor_train, loss.item(),
                                ans_span_accuracy, yes_no_span_accuracy,
                                running_loss, running_ans_span_accuracy,
                                running_yes_no_span_accuracy)
                cursor_train += 1

            train_state['train_running_loss'].append(running_loss)

            # ---------------- validation pass ----------------
            dataset.set_split('val')
            batch_generator = generate_QA_batches(dataset,
                                                  batch_size=args.batch_size,
                                                  device=args.device)
            running_loss = 0.0
            running_ans_span_accuracy = 0.0
            running_yes_no_span_accuracy = 0.0

            classifier.eval()

            batch_index_for_yesnospan = 0
            batch_index_for_span = 0

            for batch_index, batch_dict in enumerate(batch_generator):
                with torch.no_grad():
                    yes_no_span = batch_dict.pop('yes_no_span')
                    res = classifier(**batch_dict)
                    start_logits, end_logits, cls_logits = res[0], res[1], res[2]

                    start_loss = loss_fct(start_logits, batch_dict['start_positions'])
                    end_loss = loss_fct(end_logits, batch_dict['end_positions'])
                    start_end_loss = (start_loss + end_loss) / 2
                    yes_no_span_loss = loss_fct(cls_logits, yes_no_span) / 2

                    ans_span_accuracy = compute_span_accuracy(start_logits, batch_dict['start_positions'],
                                                              end_logits, batch_dict['end_positions'])
                    yes_no_span_accuracy = compute_accuracy(cls_logits, yes_no_span)

                    loss = start_end_loss + yes_no_span_loss
                    running_loss += (loss.item() - running_loss) / (batch_index + 1)

                    if ans_span_accuracy:
                        running_ans_span_accuracy += \
                            (ans_span_accuracy - running_ans_span_accuracy) / (batch_index_for_span + 1)
                        batch_index_for_span += 1
                    if yes_no_span_accuracy:
                        running_yes_no_span_accuracy += \
                            (yes_no_span_accuracy - running_yes_no_span_accuracy) / (batch_index_for_yesnospan + 1)
                        batch_index_for_yesnospan += 1

                val_bar.set_postfix(running_loss=running_loss, epoch=epoch_index)
                val_bar.update()

                _log_qa_scalars(writer, 'val', cursor_val, loss.item(),
                                ans_span_accuracy, yes_no_span_accuracy,
                                running_loss, running_ans_span_accuracy,
                                running_yes_no_span_accuracy)
                cursor_val += 1

            train_state['val_running_loss'].append(running_loss)

            if not args.use_mini:
                train_state = update_train_state(args=args,
                                                 model=classifier,
                                                 optimizer=optimizer,
                                                 train_state=train_state)
            epoch_bar.update()

            if train_state['stop_early']:
                print('STOP EARLY!')
                break

    except KeyboardInterrupt:
        print("Exiting loop")
        interrupted = True
    except Exception:
        # Best-effort: report the failure location and return. Unlike a bare
        # `except:`, this lets SystemExit from the debug `exit()` calls through.
        print_exc()
        print(f"err in epoch_index {epoch_index}, batch_index {batch_index}.")
    finally:
        for bar in bars:
            bar.close()
        if writer is not None:
            writer.close()

    # Delete mini-run logs only after the writer has released the directory.
    if interrupted and args.use_mini:
        rm_rf(args.log_dir)
Esempio n. 9
0
def run():
    """Train a sequence model, evaluating and checkpointing it periodically.

    Options come from ``get_args()``. ``vocab.pkl``, ``train.pkl`` and
    ``test.pkl`` are loaded from ``args.dir``. The model is trained with
    cross-entropy (padding ignored) and gradient clipping. A
    ``ReduceLROnPlateau`` schedule is stepped on the summed epoch loss.
    Every ``args.test_epoch`` epochs it runs demo predictions and logs
    BLEU / Rouge-L. Every ``args.save_epoch`` epochs it saves a checkpoint
    under ``args.output``.

    Raises:
        ImportError: if ``args.fp16`` is set but NVIDIA apex is not installed.
    """
    args = get_args()
    fdir = Path(args.dir)
    tb = SummaryWriter(args.logdir)  # TensorBoard writer for the loss curve
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    output_dir = Path(args.output)
    output_dir.mkdir(exist_ok=True, parents=True)
    logger.info(args)
    logger.info("loading vocab...")
    tokenizer = Tokenizer.from_pretrained(fdir / 'vocab.pkl')
    logger.info("loading dataset...")
    train_dataset = torch.load(fdir / 'train.pkl')
    test_dataset = torch.load(fdir / 'test.pkl')
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size)
    logger.info("initializing model...")
    model = init_model_by_key(args, tokenizer)
    model.to(device)
    loss_function = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_id)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    if args.fp16:
        # Guard only the import: an ImportError raised later inside apex
        # must not be misreported as "apex is not installed".
        try:
            from apex import amp  # mixed-precision training
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err
        # Some functions (such as einsum) must be registered explicitly
        # before amp will run them in FP16.
        amp.register_half_function(torch, 'einsum')
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    # Lower the learning rate when the epoch loss stops improving.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
    logger.info(f"num gpu: {torch.cuda.device_count()}")
    global_step = 0
    for epoch in range(args.epochs):
        logger.info(f"***** Epoch {epoch} *****")
        model.train()
        t1 = time.time()
        accu_loss = 0.0
        for step, batch in enumerate(train_loader):
            optimizer.zero_grad()
            batch = tuple(t.to(device) for t in batch)
            input_ids, masks, lens, target_ids = batch
            logits = model(input_ids, masks)
            loss = loss_function(logits.view(-1, tokenizer.vocab_size),
                                 target_ids.view(-1))
            if torch.cuda.device_count() > 1:
                loss = loss.mean()
            accu_loss += loss.item()
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            optimizer.step()
            if step % 100 == 0:
                tb.add_scalar('loss', loss.item(), global_step)
                logger.info(
                    f"[epoch]: {epoch}, [batch]: {step}, [loss]: {loss.item()}"
                )
            global_step += 1
        # accu_loss is the summed (not averaged) training loss of the epoch.
        scheduler.step(accu_loss)
        t2 = time.time()
        logger.info(
            f"epoch time: {t2-t1:.5}, accumulation loss: {accu_loss:.6}")
        if (epoch + 1) % args.test_epoch == 0:
            predict_demos(model, tokenizer)
            bleu, rl = auto_evaluate(model, test_loader, tokenizer)
            logger.info(f"BLEU: {round(bleu, 9)}, Rouge-L: {round(rl, 8)}")
        if (epoch + 1) % args.save_epoch == 0:
            # Name the checkpoint after the real model class, not the
            # DataParallel wrapper used on multi-GPU machines.
            base_model = (model.module
                          if isinstance(model, torch.nn.DataParallel) else model)
            filename = output_dir / f"{base_model.__class__.__name__}_{epoch + 1}.bin"
            save_model(filename, model, args, tokenizer)
Esempio n. 10
0
    def __init__(
        self,
        access_mode,
        fp16,
        fp16_opt_level,
        model,
        model_name,
        device,
        myio,
        save_dir,
        n_best,
        max_answer_length,
        do_lower_case,
        verbose_logging,
        version_2_with_negative,
        null_score_diff_threshold,
        max_steps=1e5,
        log_int=1e4,
        best_int=500,
        verbose_int=1000,
        max_grad_norm=1.0,
        optimizer=None,
        weight_decay=0.0,
        lr=5e-3,
        eps=1e-8,
        warmup_steps=0,
        freeze_embeddings=False,
    ):
        """
        Object to store learning state. Used for fine-tuning.

        Data is held by the ``myio.IO`` object passed as ``myio``.

        Args:
            access_mode: stored as ``self.access_mode``.
            fp16: if true, wrap the model and optimizer with apex AMP.
            fp16_opt_level: apex opt level passed to ``amp.initialize``.
            model: module to fine-tune; it is moved to ``device``. When
                ``freeze_embeddings`` is set, its ``.model.bert`` submodule is
                frozen.
            model_name: stored as ``self.model_name``.
            device: torch device for the model.
            myio: data holder, stored as ``self.IO``.
            save_dir: directory under which ``logged/`` is created.
            n_best, max_answer_length, do_lower_case, verbose_logging,
            version_2_with_negative, null_score_diff_threshold: evaluation
                settings stored on the instance.
            max_steps, log_int, best_int, verbose_int: step counts/intervals
                stored on the instance.
            max_grad_norm, weight_decay, lr, eps, warmup_steps: optimization
                hyperparameters stored on the instance.
            optimizer: optimizer to use; when None, ``self.set_optimizer()``
                is called to build one.
            freeze_embeddings: if true, disable gradients for every parameter
                of the BERT submodule (not only its embeddings).
        """
        self.debug = False

        self.fp16 = fp16
        self.fp16_opt_level = fp16_opt_level
        self.access_mode = access_mode

        self.model = model.to(device)
        self.model_name = model_name
        self.device = device
        self.IO = myio
        self.save_dir = save_dir
        self.max_steps = max_steps
        self.log_int = log_int
        self.best_int = best_int
        self.verbose_int = verbose_int
        self.max_grad_norm = max_grad_norm
        self.weight_decay = weight_decay
        self.lr = lr
        self.eps = eps
        self.warmup_steps = warmup_steps
        self.freeze = freeze_embeddings

        # Directory for recorded weights. makedirs also creates a missing
        # save_dir and tolerates the directory already existing.
        self.log_dir = os.path.join(self.save_dir, 'logged')
        os.makedirs(self.log_dir, exist_ok=True)

        # for evaluation
        self.n_best = n_best
        self.max_answer_length = max_answer_length
        self.do_lower_case = do_lower_case
        self.verbose_logging = verbose_logging
        self.version_2_with_negative = version_2_with_negative
        self.null_score_diff_threshold = null_score_diff_threshold

        # data
        self.train_dataloader = None
        self.val_dataloader = None
        self.val_examples = None
        self.val_features = None

        # set optimizer
        self.optimizer = optimizer

        if optimizer is None:
            self.set_optimizer()

        # use mixed precision if needed
        if self.fp16:
            from apex import amp
            amp.register_half_function(torch, "einsum")
            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer, opt_level=self.fp16_opt_level)

        # If several GPUs are available, wrap the (possibly amp-initialized)
        # self.model. Wrapping the raw `model` argument would bypass amp's
        # returned model.
        if torch.cuda.is_available() and torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
            self.model.to(self.device)

        # stop embedding weight grad tracking
        if self.freeze:
            if isinstance(self.model, nn.DataParallel):
                bert = self.model.module.model.bert
            else:
                bert = self.model.model.bert

            for param in bert.parameters():
                param.requires_grad = False
            log.info("Froze BERT parameters")

            if self.debug:
                # Keep frozen copies so later code can check what was updated.
                if isinstance(self.model, nn.DataParallel):
                    qabert = self.model.module.model
                else:
                    qabert = self.model.model

                self.fixed_bert = copy.deepcopy(qabert.bert)
                self.fixed_qa = copy.deepcopy(qabert.qa_outputs)
Esempio n. 11
0
def main():
    """Train a ``TraForEncoder`` model and save a checkpoint after every epoch.

    The vocabulary is read from ``config.data_dir``/vocabs. Model sizes and
    the output directory come from ``config``; everything else comes from
    ``get_args()``. Training uses cross-entropy that ignores ``[PAD]`` targets,
    gradient clipping, optional apex mixed precision, and ``DataParallel``
    when several GPUs are visible.

    Raises:
        ImportError: if ``args.fp16`` is set but NVIDIA apex is not installed.
    """
    vocab_path = f'{config.data_dir}/vocabs'
    args = get_args()
    tb = SummaryWriter(args.logdir)
    epochs = args.epochs
    batch_size = args.batch_size
    lr = args.lr

    embed_dim = config.embed_dim
    hidden_dim = config.hidden_dim
    # NOTE(review): attribute is spelled "ouput_dir"; confirm it matches config.
    output_dir = config.ouput_dir

    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    logger.info("***** Loading vocab *****")
    word_to_ix = load_vocab(vocab_path)
    vocab_size = len(word_to_ix)
    logger.info("***** Initializing dataset *****")
    train_dataloader = init_dataset(args.dir, batch_size)
    logger.info("***** Training *****")
    model = TraForEncoder(vocab_size, embed_dim, hidden_dim)
    # Move the model to its device before building the optimizer, as the
    # PyTorch optimizer documentation recommends.
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    if args.fp16:
        # Guard only the import so apex-internal errors are not masked.
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err
        amp.register_half_function(torch, 'einsum')
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.train()
    # Padding positions do not contribute to the loss.
    loss_func = nn.CrossEntropyLoss(ignore_index=word_to_ix['[PAD]'])
    logger.info(f"Num GPU {torch.cuda.device_count()}")
    global_step = 0
    for epoch in range(epochs):
        logger.info(f"***** Epoch {epoch} *****")
        for step, batch in enumerate(train_dataloader):
            optimizer.zero_grad()
            batch = tuple(t.to(device) for t in batch)
            seq_ids, exted_att_mask, tag_ids = batch
            logits = model(seq_ids, exted_att_mask)
            loss = loss_func(logits.view(-1, vocab_size), tag_ids.view(-1))
            if torch.cuda.device_count() > 1:
                loss = loss.mean()
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            optimizer.step()
            if step % 100 == 0:
                tb.add_scalar('loss', loss.item(), global_step)
                logger.info(
                    f"[epoch]: {epoch}, [batch]: {step}, [loss]: {loss.item()}"
                )
            global_step += 1
        save_model(model, output_dir, epoch + 1)
Esempio n. 12
0
def train(args, logger, tb_writer):
    """Fine-tune ``args.model`` on ``args.task``, optionally with DDP and apex AMP.

    Runs up to ``args.num_epochs`` epochs, cycling through the
    ``epoch_<k>.jsonl.gz`` training files. When ``args.max_steps > 0``,
    training stops once that many optimizer steps have run. The model is
    evaluated every ``args.eval_every`` processed samples and, unless
    ``args.do_not_eval_after_epoch`` is set, at the end of each epoch. On
    rank -1/0 the results are passed to the checkpoint saver.

    Args:
        args: parsed command-line options.
        logger: logger for progress messages.
        tb_writer: TensorBoard writer; only rank -1/0 writes scalars.
    """
    logger.info('Args: {}'.format(json.dumps(vars(args), indent=4, sort_keys=True)))
    if args.local_rank in [-1, 0]:
        with open(os.path.join(args.save_dir, 'args.yaml'), 'w') as file:
            yaml.safe_dump(vars(args), file, sort_keys=False)

    device_id = args.local_rank if args.local_rank != -1 else 0
    device = torch.device('cuda', device_id)
    logger.warning(f'Using GPU {device_id}.')

    world_size = torch.distributed.get_world_size() if args.local_rank != -1 else 1
    logger.info(f'Total number of GPUs used: {world_size}.')
    effective_batch_size = args.batch_size * world_size * args.accumulation_steps
    logger.info(f'Effective batch size: {effective_batch_size}.')

    num_train_samples_per_epoch, num_dev_samples, num_unique_train_epochs = get_data_sizes(data_dir=args.data_dir,
                                                                                           num_epochs=args.num_epochs,
                                                                                           logger=logger)
    num_optimization_steps = sum(num_train_samples_per_epoch) // world_size // args.batch_size // \
                             args.accumulation_steps
    if args.max_steps > 0:
        num_optimization_steps = min(num_optimization_steps, args.max_steps)
    logger.info(f'Total number of optimization steps: {num_optimization_steps}.')

    # Set random seed
    logger.info(f'Using random seed {args.seed}.')
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Get model. Non-zero ranks wait so that rank 0 loads (and caches) it first.
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    logger.info(f'Loading model {args.model} for task {args.task}...')
    model = ModelRegistry.get_model(args.task).from_pretrained(args.model)

    if args.local_rank in [-1, 0]:
        with open(os.path.join(args.save_dir, 'config.json'), 'w') as file:
            json.dump(model.config.__dict__, file)

    if args.local_rank == 0:
        torch.distributed.barrier()

    model.to(device)

    # Get optimizer
    logger.info('Creating optimizer...')
    parameter_groups = get_parameter_groups(model)
    optimizer = AdamW(parameter_groups, lr=args.learning_rate, weight_decay=args.weight_decay, eps=1e-8)
    scheduler = get_lr_scheduler(optimizer, num_steps=num_optimization_steps, warmup_proportion=args.warmup_proportion)

    if args.amp:
        amp.register_half_function(torch, 'einsum')
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp_opt_level)

    if args.local_rank != -1:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)

    # Get dev data loader
    dev_data_file = os.path.join(args.data_dir, 'dev.jsonl.gz')
    logger.info(f'Creating dev dataset from {dev_data_file}...')
    dev_dataset = DatasetRegistry.get_dataset(args.task)(data_file=dev_data_file,
                                                         data_size=num_dev_samples,
                                                         local_rank=-1)
    dev_loader = DataLoader(dev_dataset,
                            batch_size=2 * args.batch_size,
                            num_workers=1,
                            collate_fn=dev_dataset.collate_fn)

    # Get evaluator
    evaluator = EvaluatorRegistry.get_evaluator(args.task)(data_loader=dev_loader,
                                                           logger=logger,
                                                           tb_writer=tb_writer,
                                                           device=device,
                                                           world_size=world_size,
                                                           args=args)

    # Get saver
    saver = CheckpointSaver(save_dir=args.save_dir,
                            max_checkpoints=args.max_checkpoints,
                            primary_metric=evaluator.primary_metric,
                            maximize_metric=evaluator.maximize_metric,
                            logger=logger)

    global_step = 0
    samples_processed = 0
    # Set once args.max_steps is reached so that the epoch loop stops too,
    # not just the batch loop.
    reached_max_steps = False

    # Train
    logger.info('Training...')
    samples_till_eval = args.eval_every
    for epoch in range(1, args.num_epochs + 1):
        # Get train data loader for current epoch
        train_data_file_num = ((epoch - 1) % num_unique_train_epochs) + 1
        train_data_file = os.path.join(args.data_dir, f'epoch_{train_data_file_num}.jsonl.gz')
        logger.info(f'Creating training dataset from {train_data_file}...')
        train_dataset = DatasetRegistry.get_dataset(args.task)(train_data_file,
                                                               data_size=num_train_samples_per_epoch[epoch - 1],
                                                               local_rank=args.local_rank,
                                                               world_size=world_size)
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=1,
                                  collate_fn=train_dataset.collate_fn)

        logger.info(f'Starting epoch {epoch}...')
        model.train()
        model.zero_grad()
        loss_values = defaultdict(float)
        samples_till_end = (num_optimization_steps - global_step) * effective_batch_size
        samples_in_cur_epoch = min([len(train_loader.dataset), samples_till_end])
        disable_progress_bar = (args.local_rank not in [-1, 0])
        with tqdm(total=samples_in_cur_epoch, disable=disable_progress_bar) as progress_bar:
            for step, batch in enumerate(train_loader, 1):
                batch = {name: tensor.to(device) for name, tensor in batch.items()}
                current_batch_size = batch['input_ids'].shape[0]

                outputs = model(**batch)
                loss, current_loss_values = outputs[:2]

                # Scale so that accumulated gradients average over micro-batches.
                loss = loss / args.accumulation_steps
                for name, value in current_loss_values.items():
                    loss_values[name] += value / args.accumulation_steps

                if args.amp:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                samples_processed += current_batch_size * world_size
                samples_till_eval -= current_batch_size * world_size
                progress_bar.update(current_batch_size * world_size)

                if step % args.accumulation_steps == 0:
                    current_lr = scheduler.get_last_lr()[0]

                    if args.amp:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()
                    global_step += 1

                    # Log info
                    progress_bar.set_postfix(epoch=epoch, step=global_step, lr=current_lr, **loss_values)
                    if args.local_rank in [-1, 0]:
                        tb_writer.add_scalar('train/LR', current_lr, global_step)
                        for name, value in loss_values.items():
                            tb_writer.add_scalar(f'train/{name}', value, global_step)
                    # Keep it a defaultdict so loss keys first seen later
                    # still accumulate.
                    loss_values = defaultdict(float)

                    if 0 < args.max_steps <= global_step:
                        logger.info('Reached maximum number of optimization steps.')
                        reached_max_steps = True
                        break

                    if samples_till_eval <= 0:
                        samples_till_eval = args.eval_every
                        eval_results = evaluator.evaluate(model, global_step)
                        if args.local_rank in [-1, 0]:
                            saver.save(model, global_step, eval_results)

            if not args.do_not_eval_after_epoch:
                eval_results = evaluator.evaluate(model, global_step)
                if args.local_rank in [-1, 0]:
                    saver.save(model, global_step, eval_results)

        if reached_max_steps:
            break
Esempio n. 13
0
def train_ts(args):
    def build_scheduler(optimizers, args):
        """Build the learning-rate scheduler selected by ``args.scheduler``.

        Args:
            optimizers: ``(optimizer, optimizer_sparse)`` pair; only the dense
                optimizer is scheduled.
            args: namespace providing ``scheduler`` plus that scheduler's
                options (``max_step``/``eta_min``, ``warmup_step``, or
                ``decay_rate``/``patience``/``lr_min``).

        Returns:
            ``(scheduler, scheduler_sparse)``. ``scheduler`` is None for
            ``"constant"``; ``scheduler_sparse`` is always None.

        Raises:
            ValueError: if ``args.scheduler`` is not recognized.
        """
        optimizer, optimizer_sparse = optimizers
        scheduler_sparse = None

        if args.scheduler == "cosine":
            # Historically eta_min defaulted to 0 rather than lr_min (1e-6),
            # so it is read from its own argument.
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, args.max_step, eta_min=args.eta_min)

        elif args.scheduler == "inv_sqrt":
            # originally used for Transformer (in Attention is all you need)
            def lr_lambda(step):
                # Return a multiplier, not a learning rate: linear warmup,
                # then 1/sqrt(step) decay.
                if step == 0 and args.warmup_step == 0:
                    return 1.0
                else:
                    return (1.0 /
                            (step**0.5) if step > args.warmup_step else step /
                            (args.warmup_step**1.5))

            scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                    lr_lambda=lr_lambda)

        elif args.scheduler == "dev_perf":
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                factor=args.decay_rate,
                patience=args.patience,
                min_lr=args.lr_min,
            )

        elif args.scheduler == "constant":
            # No scheduler object; previously `scheduler` was left unbound
            # here and the return raised UnboundLocalError.
            scheduler = None

        else:
            raise ValueError(f"scheduler type {args.scheduler} not recognized")

        return scheduler, scheduler_sparse

    ###############################################################################
    # Training code
    ###############################################################################
    def evaluate(eval_iter, model):
        """Return the ``seq_len``-weighted mean loss of ``model`` over ``eval_iter``.

        Runs in eval mode (dropout off) without gradients. Memory states
        returned by the model are fed into the next batch. Stops after
        ``args.max_eval_steps`` batches when that value is positive. The
        model is switched back to training mode afterwards, even if
        evaluation raises.

        Raises:
            ValueError: if nothing was evaluated (total length is zero).
        """
        model.eval()
        total_len, total_loss = 0, 0.0
        try:
            with torch.no_grad():
                mems = tuple()
                for i, (data, target, seq_len) in enumerate(eval_iter):
                    if i >= args.max_eval_steps > 0:
                        break
                    ret = model(data, target, *mems)
                    loss, mems = ret[0], ret[1:]
                    loss = loss.mean()
                    total_loss += seq_len * loss.float().item()
                    total_len += seq_len
        finally:
            # Switch back to training mode even if evaluation failed.
            model.train()

        if total_len == 0:
            raise ValueError("evaluate() processed no data; is eval_iter empty?")
        return total_loss / total_len

    # reverse distillation util
    def get_original_batches(model, tr_iter, integration_length):
        """Record the model's logits for the first ``integration_length`` batches.

        Used for reverse distillation. Calls ``model._forward`` under
        ``torch.no_grad()``, threading memory states from batch to batch.
        Puts ``model`` into eval mode and does not switch it back.

        Returns:
            A list of per-batch logits moved to CPU. When
            ``args.batch_chunk > 1`` it is instead one such list per chunk
            index.
        """
        model.eval()
        n_chunks = args.batch_chunk
        chunked = n_chunks > 1
        if chunked:
            mems = [None] * n_chunks
            first_logits = [[] for _ in range(n_chunks)]
        else:
            mems = None
            first_logits = []

        source = tr_iter.get_varlen_iter() if args.varlen else tr_iter
        with torch.no_grad():
            for batch_idx, (data, _target, _seq_len) in enumerate(source):
                if batch_idx == integration_length:
                    break
                if not chunked:
                    logits, mems = model._forward(data, mems=mems)
                    first_logits.append(logits.cpu())
                    continue
                pieces = torch.chunk(data, n_chunks, 1)
                for idx in range(n_chunks):
                    logits, mems[idx] = model._forward(pieces[idx].contiguous(),
                                                       mems=mems[idx])
                    first_logits[idx].append(logits.cpu())
        return first_logits

    def build_optimizer(model, args, reload=False):
        """Create the optimizer named by ``args.optim`` for ``model``.

        Args:
            model: module whose parameters are optimized.
            args: namespace providing ``optim`` ("sgd", "adam" or "adagrad",
                case-insensitive), ``lr``, ``mom`` (SGD momentum) and, when
                ``reload`` is true, ``restart_dir`` / ``restart_from``.
            reload: if true, restore optimizer state from
                ``restart_dir/optimizer.pt`` (or
                ``optimizer_<restart_from>.pt``) when that file exists.

        Returns:
            ``(optimizer, optimizer_sparse)``; ``optimizer_sparse`` is always None.

        Raises:
            ValueError: if ``args.optim`` is not a recognized optimizer name.
        """
        optimizer_sparse = None
        optim_type = args.optim.lower()
        if optim_type == "sgd":
            optimizer = optim.SGD(model.parameters(),
                                  lr=args.lr,
                                  momentum=args.mom)
        elif optim_type == "adam":
            optimizer = optim.Adam(model.parameters(), lr=args.lr)
        elif optim_type == "adagrad":
            optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        else:
            raise ValueError(f"optimizer type {args.optim} not recognized")

        if reload:
            if args.restart_from is not None:
                optim_name = f"optimizer_{args.restart_from}.pt"
            else:
                optim_name = "optimizer.pt"
            optim_file_name = os.path.join(args.restart_dir, optim_name)
            logging(f"reloading {optim_file_name}")
            if os.path.exists(optim_file_name):
                with open(optim_file_name, "rb") as optim_file:
                    opt_state_dict = torch.load(optim_file)
                try:
                    optimizer.load_state_dict(opt_state_dict)
                except ValueError:
                    # load_state_dict raises ValueError when the saved param
                    # groups differ in number or size from the current ones.
                    # Merge all saved groups into a single group and retry.
                    logging("merging optimizer param groups")
                    opt_state_dict["param_groups"][0]["params"] = [
                        param
                        for param_group in opt_state_dict["param_groups"]
                        for param in param_group["params"]
                    ]
                    opt_state_dict["param_groups"] = [
                        opt_state_dict["param_groups"][0]
                    ]
                    optimizer.load_state_dict(opt_state_dict)
            else:
                logging("Optimizer was not saved. Start from scratch.")

        return optimizer, optimizer_sparse

    def log_val(val_loss, step, compute):
        """Log a validation summary line between two separator rules.

        The line reports:
        - the eval index (``step // args.eval_interval``);
        - the step;
        - seconds elapsed since ``eval_start_time``;
        - the validation loss;
        - bits per character (``val_loss / ln 2``).

        ``compute`` is accepted but not used.
        """
        rule = "-" * 100
        logging(rule)
        template = ("| Eval {:3d} at step {:>8d} | time: {:5.2f}s "
                    "| valid loss {:5.2f}")
        eval_index = step // args.eval_interval
        elapsed = time.time() - eval_start_time
        summary = template.format(eval_index, step, elapsed, val_loss)
        bits_per_char = val_loss / math.log(2)
        summary = summary + " | bpc {:9.5f}".format(bits_per_char)
        logging(summary)
        logging(rule)

    def epoch_loop(
        epoch,
        model,
        optimizers,
        schedulers,
    ):
        nonlocal train_step

        # Turn on training mode which enables dropout.
        if isinstance(model, nn.DataParallel):
            parent_model = model.module
        else:
            parent_model = model
        optimizer, optimizer_sparse = optimizers
        scheduler, scheduler_sparse = schedulers

        # global train_step, best_val_loss, eval_start_time, log_start_time
        train_losses = []
        model.train()
        if args.batch_chunk > 1:
            mems = [tuple() for _ in range(args.batch_chunk)]
        else:
            mems = tuple()
        train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter

        log_start_time = time.time()
        best_val_loss = float("Infinity")
        for batch, (data, target, seq_len) in enumerate(train_iter):
            model.zero_grad()
            if args.batch_chunk > 1:
                data_chunks = torch.chunk(data, args.batch_chunk, 1)
                target_chunks = torch.chunk(target, args.batch_chunk, 1)
                for i in range(args.batch_chunk):
                    data_i = data_chunks[i].contiguous()
                    target_i = target_chunks[i].contiguous()
                    ret = model(data_i, target_i, *mems[i])
                    loss, mems[i] = ret[0], ret[1:]
                    loss = loss.float().mean().type_as(loss) / args.batch_chunk
                    if args.fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    train_losses.append(loss.float().item())
            else:
                ret = model(data, target, *mems)
                loss, mems = ret[0], ret[1:]
                loss = loss.float().mean().type_as(loss)
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                train_losses.append(loss.float().item())

            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.clip)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            optimizer.step()
            parent_model.compute += openai_compute(
                non_emb_param_count(parent_model, nseries), data.numel(), 1)

            # step-wise learning rate annealing
            train_step += 1
            parent_model.training_steps += 1

            # check for yet-to-thaw parameters
            if getattr(parent_model, "freeze_countdown", 0) > 0:
                parent_model.freeze_countdown -= 1

                # if this is the last step
                if parent_model.freeze_countdown == 0:
                    for parameter in parent_model.parameters():
                        parameter.requires_grad = True
                    logging("thawing all parameters")

            if args.scheduler in ["cosine", "constant", "dev_perf"]:
                # linear warmup stage
                if train_step < args.warmup_step:
                    curr_lr = args.lr * train_step / args.warmup_step
                    optimizer.param_groups = curr_lr
                else:
                    if args.scheduler == "cosine":
                        scheduler.step(train_step)
            elif args.scheduler == "inv_sqrt":
                scheduler.step(train_step)

            if train_step % args.log_interval == 0:
                cur_loss = np.mean(train_losses)
                elapsed = time.time() - log_start_time
                log_str = ("| epoch {:3d} step {:>8d} "
                           "| {:>6d} batches "
                           "| lr {:.3g} "
                           "| ms/batch {:5.2f} "
                           "| loss {:5.2f}".format(
                               epoch,
                               train_step,
                               batch + 1,
                               optimizer.param_groups[0]["lr"],
                               elapsed * 1000 / args.log_interval,
                               cur_loss,
                           ))

                log_str += " | bpc {:9.5f}".format(cur_loss / math.log(2))
                logging(log_str)

                train_losses = []
                log_start_time = time.time()

            if train_step % args.eval_interval == 0:
                val_loss = evaluate(va_iter, model)
                log_val(val_loss,
                        step=train_step,
                        compute=parent_model.compute)
                # Save the model if the validation loss is the best we've seen so
                # far.
                if not best_val_loss or val_loss < best_val_loss:
                    best_val_loss = val_loss
                    if not args.debug:
                        if args.fp16:
                            with open(
                                    os.path.join(args.work_dir,
                                                 "amp_checkpoint.pt"),
                                    "wb",
                            ) as f:
                                checkpoint = {
                                    "model": model.state_dict(),
                                    "optimizer": optimizer.state_dict(),
                                    "amp": amp.state_dict(),
                                }
                                torch.save(checkpoint, f)
                        else:
                            with open(os.path.join(args.work_dir, "model.pt"),
                                      "wb") as f:
                                torch.save(parent_model, f)
                            with open(
                                    os.path.join(args.work_dir,
                                                 "optimizer.pt"),
                                    "wb",
                            ) as f:
                                torch.save(optimizer.state_dict(), f)

                # dev-performance based learning rate annealing
                if args.scheduler == "dev_perf":
                    scheduler.step(val_loss)

                eval_start_time = time.time()

            if train_step == args.max_step:
                break

    def expand_model(
        strategy,
        integration,
        integration_length,
        n_add,
        model: MemTransformerLM,
        optimizers,
        schedulers,
        tr_iter,
        va_iter,
        epoch,
        step,
    ):
        """Deepen ``model`` by ``n_add`` layers and register them for training.

        The model is evaluated on ``va_iter`` before and after the expansion.
        New layers come from ``model.expand_layers``. They are appended to the
        dense optimizer as a new parameter group, and the dense scheduler's
        ``base_lrs`` is extended to match.

        Args:
            strategy: expansion method forwarded to ``model.expand_layers``.
            integration: optional integration options, checked by substring
                membership. ``"reverse_distil"`` fits the expanded model to
                the pre-expansion logits. ``"freeze"`` freezes all previously
                existing parameter groups. May be falsy (no integration).
            integration_length: length of the integration phase. Integration
                is skipped, with a warning, when this is missing or <= 0.
            n_add: number of layers to add.
            model: model expanded in place.
            optimizers: ``(dense, sparse)`` pair; only the dense one is updated.
            schedulers: ``(dense, sparse)`` pair; only the dense one is updated.
            tr_iter: training iterator used by reverse distillation.
            va_iter: validation iterator for the before/after evaluation.
            epoch: epoch number (logging only).
            step: training step at which validation losses are logged.
        """
        optimizer, _ = optimizers
        scheduler, _ = schedulers

        # Integration strategies only run when a positive
        # `integration_length` accompanies them, as the warning below states.
        # Gating on this flag also avoids TypeErrors: previously a None
        # length reached `integration_length > 0`, and a None `integration`
        # reached the `"..." in integration` checks.
        apply_integration = False
        if integration:
            if not integration_length or integration_length <= 0:
                warnings.warn(
                    f"integration {integration} passed but integration_length is {integration_length}"
                )
            else:
                logging(
                    f"applying integration strategy {integration} with integration length {integration_length}"
                )
                apply_integration = True
        use_reverse_distil = apply_integration and "reverse_distil" in integration
        use_freeze = apply_integration and "freeze" in integration

        # pre-expansion validation
        logging("evaluating before expanding")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, step=step, compute=model.compute)

        # Record the pre-expansion model's outputs as reverse-distillation targets.
        if use_reverse_distil:
            first_logits = get_original_batches(model, tr_iter,
                                                integration_length)

        # expansion
        logging(
            f"adding {n_add} layers before starting epoch {epoch} with method {strategy}"
        )
        new_layers = model.expand_layers(n_add,
                                         strategy=strategy,
                                         function=initialization_func)

        # Give the new layers the first group's current and initial learning
        # rates, and extend the scheduler's base_lrs for the new group.
        reference_group = optimizer.param_groups[0]
        optimizer.add_param_group({
            "params": new_layers.parameters(),
            "lr": reference_group["lr"],
            "initial_lr": reference_group["initial_lr"],
        })
        scheduler.base_lrs.append(optimizer.param_groups[-1]["initial_lr"])

        # training loop for reverse distillation
        if use_reverse_distil:
            fit_to_previous_model(model, new_layers, tr_iter, first_logits,
                                  integration)

        # Freeze every pre-existing parameter group (all but the one just
        # added). This is done after distillation so that the new layers are
        # trained there with gradients enabled.
        if use_freeze:
            for param_group in optimizer.param_groups[:-1]:
                for parameter in param_group["params"]:
                    parameter.requires_grad = False
            model.freeze_countdown = integration_length

        # post-expansion validation
        logging("reevaluating")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, step=step, compute=model.compute)

    def expand_state(param, state):
        """Tile an optimizer-state tensor so its shape matches ``param``.

        If the shapes already agree, ``state`` is returned unchanged (same
        object). Otherwise, each dimension ``k`` is repeated
        ``param.shape[k] // state.shape[k]`` times. Integer multiples are
        assumed and not checked.
        """
        if param.shape == state.shape:
            return state
        repeats = [
            target_dim // state.shape[axis]
            for axis, target_dim in enumerate(param.shape)
        ]
        return state.repeat(*repeats)

    def widen_model(
        strategy,
        ratio,
        model: MemTransformerLM,
        optimizers,
        va_iter,
        epoch,
        step,
    ):
        """Widen ``model`` via ``model.add_heads`` and resize optimizer state.

        The model is evaluated on ``va_iter`` before and after widening.
        After ``add_heads``, any Adam-style moment buffers in the dense
        optimizer's state are tiled with ``expand_state`` to match their
        (possibly grown) parameter.

        Args:
            strategy: widening method forwarded to ``model.add_heads``.
            ratio: value forwarded to ``model.add_heads``. Presumably the head
                multiplication factor (TODO confirm against ``add_heads``).
            model: model widened in place.
            optimizers: ``(dense, sparse)`` pair; only the dense one's state
                is resized.
            va_iter: validation iterator for the before/after evaluation.
            epoch: epoch number (logging only).
            step: training step at which validation losses are logged.
        """
        optimizer, _ = optimizers

        # pre-widening validation
        logging("evaluating before widening")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, compute=model.compute, step=step)

        logging(
            f"adding heads with ratio {ratio} before starting epoch {epoch} with method {strategy}"
        )
        model.add_heads(ratio, strategy=strategy, function=initialization_func)

        # Resize per-parameter moment buffers to the widened shapes. Only keys
        # that are present are touched, so params without state yet, or
        # optimizers without these buffers, no longer raise KeyError.
        # `max_exp_avg_sq` exists only with amsgrad.
        for param, states in optimizer.state.items():
            if not isinstance(param, nn.Parameter):
                continue
            for key in ("exp_avg", "exp_avg_sq", "max_exp_avg_sq"):
                if key in states:
                    states[key] = expand_state(param, states[key])

        # post-widening validation
        logging("reevaluating")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, step=step, compute=model.compute)

    # reverse distillation trainer
    def fit_to_previous_model(model, new_layers, tr_iter, first_logits,
                              integration):
        """Fit the expanded model to logits recorded before expansion.

        Iterates over training batches, stopping once ``len(first_logits)``
        batches have been consumed. For each batch it minimises the MSE
        between ``model._forward`` logits and the stored target logits, then
        clips gradients to ``args.clip`` and steps a dedicated optimizer.

        Args:
            model: the already-expanded model.
            new_layers: module with the freshly added layers. When
                ``"partial"`` is in ``integration``, only these are optimised;
                otherwise the whole model is.
            tr_iter: training iterator (``get_varlen_iter()`` is used when
                ``args.varlen`` is set).
            first_logits: target logits from ``get_original_batches``.
            integration: integration options (checked for ``"partial"``).
        """
        mse_loss = torch.nn.MSELoss()
        # Build the optimizer the same way the main one is built
        # (`build_optimizer(model, args, reload=...)`); the previous call
        # omitted `args`. The sparse optimizer is not used here.
        distil_target = new_layers if "partial" in integration else model
        distil_optimizer, _ = build_optimizer(distil_target, args, reload=False)
        if args.cuda and args.fp16:
            model, distil_optimizer = amp.initialize(model,
                                                     distil_optimizer,
                                                     opt_level=args.fp16)

        model.train()
        # One memory slot per chunk when the batch is split along dim 1.
        mems = [None] * args.batch_chunk if args.batch_chunk > 1 else None
        train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
        for batch, (data, _, _) in enumerate(train_iter):
            # NOTE(review): with batch_chunk > 1 the targets are read as
            # first_logits[chunk][batch], yet this stop condition uses
            # len(first_logits). Confirm the layout returned by
            # get_original_batches.
            if batch == len(first_logits):
                break
            model.zero_grad()
            if args.batch_chunk > 1:
                data_chunks = torch.chunk(data, args.batch_chunk, 1)
                for i in range(args.batch_chunk):
                    data_i = data_chunks[i].contiguous()
                    logits, mems[i] = model._forward(data_i, mems=mems[i])
                    target_logits = first_logits[i][batch].to(logits.device)
                    # Average over chunks so the accumulated gradient matches
                    # a single full-batch pass.
                    loss = mse_loss(logits, target_logits) / args.batch_chunk
                    if args.fp16:
                        with amp.scale_loss(loss,
                                            distil_optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
            else:
                logits, mems = model._forward(data, mems=mems)
                target_logits = first_logits[batch].to(logits.device)
                loss = mse_loss(logits, target_logits)
                if args.fp16:
                    with amp.scale_loss(loss, distil_optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

            if args.fp16:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(distil_optimizer), args.clip)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            distil_optimizer.step()

    ###################################################################################
    #
    # main()
    #
    args.tied = not args.not_tied

    if args.d_embed < 0:
        args.d_embed = args.d_model

    # Validate `--fp16` option
    if args.fp16:
        if not args.cuda:
            print("WARNING: --fp16 requires --cuda, ignoring --fp16 option")
            args.fp16 = False
        else:
            try:
                from apex import amp

                if args.fp16 == "O1":
                    amp.register_half_function(torch, "einsum")
            except:
                print("WARNING: apex not installed, ignoring --fp16 option")
                args.fp16 = False

    device = torch.device("cuda" if args.cuda else "cpu")

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        if not args.cuda:
            print(
                "WARNING: You have a CUDA device, so you should probably run "
                "with --cuda ")
        else:
            torch.cuda.manual_seed_all(args.seed)

    ############################################################################
    # Logging
    ############################################################################

    assert args.ext_len >= 0, "extended context length must be non-negative"
    assert args.d_batch % args.batch_chunk == 0

    args.work_dir = "{}-{}".format(args.work_dir, args.dataset)
    args.work_dir = os.path.join(args.work_dir, time.strftime("%Y%m%d-%H%M%S"))
    logging = create_exp_dir(
        args.work_dir,
        scripts_to_save=["train_ts.py", "mem_transformer.py"],
        debug=args.debug,
    )

    ############################################################################
    # Load data
    ############################################################################
    time_series = get_time_series(args.datadir, args.dataset)
    nseries = len(time_series.vocab)
    args.n_token = nseries

    eval_batch_size = 20
    tr_iter = time_series.get_iterator(
        "train",
        args.d_batch,
        args.tgt_len,
        device=device,
        ext_len=args.ext_len,
    )
    va_iter = time_series.get_iterator(
        "valid",
        eval_batch_size,
        args.eval_tgt_len,
        device=device,
        ext_len=args.ext_len,
    )
    te_iter = time_series.get_iterator(
        "test",
        eval_batch_size,
        args.eval_tgt_len,
        device=device,
        ext_len=args.ext_len,
    )

    cutoffs, tie_projs = [], [False]

    ############################################################################
    # Define model
    ############################################################################

    initialization_func = partial(
        weights_init,
        init=args.init,
        init_range=args.init_range,
        init_std=args.init_std,
        proj_init_std=args.proj_init_std,
    )

    if args.restart and not args.fp16:
        if args.restart_from is not None:
            model_name = f"model_{args.restart_from}.pt"
        else:
            model_name = "model.pt"
        model_file_name = os.path.join(args.restart_dir, model_name)
        logging(f"reloading {model_file_name}")
        with open(model_file_name, "rb") as f:
            model = torch.load(f)
        # backwards compatibility with older saves
        if isinstance(model, nn.DataParallel):
            model = model.module
        model.backward_compatible(tie_weight=args.tied, tie_projs=tie_projs)
        if not args.fp16:
            model = model.float()
        model.apply(update_dropout)
        model.apply(update_dropatt)

    else:
        model = MemTransformerLM(
            nseries,
            args.n_layer,
            args.n_head,
            args.d_model,
            args.d_head,
            args.d_inner,
            args.dropout,
            args.dropatt,
            tie_weight=args.tied,
            d_embed=args.d_embed,
            div_val=args.div_val,
            tie_projs=tie_projs,
            pre_lnorm=args.pre_lnorm,
            tgt_len=args.tgt_len,
            ext_len=args.ext_len,
            mem_len=args.mem_len,
            cutoffs=cutoffs,
            same_length=args.same_length,
            clamp_len=args.clamp_len,
        )
        model.apply(initialization_func)

        # debug
        # model.word_emb.apply(initialization_func)
        # ensure embedding init is not overridden by out_layer in case of
        # weight sharing
    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = non_emb_param_count(model, nseries)

    logging("=" * 100)
    for k, v in args.__dict__.items():
        logging("    - {} : {}".format(k, v))
    logging("=" * 100)
    logging("#params = {}".format(args.n_all_param))
    logging("#non emb params = {}".format(args.n_nonemb_param))

    para_model = parallelize_model(model, args)
    optimizers = build_optimizer(para_model,
                                 args,
                                 reload=args.restart and not args.fp16)
    optimizer, optimizer_sparse = optimizers
    schedulers = build_scheduler(optimizers, args)
    scheduler, scheduler_sparse = schedulers

    if args.cuda and args.fp16:
        para_model, optimizer = amp.initialize(para_model,
                                               optimizer,
                                               opt_level=args.fp16)

        if args.restart:
            if args.restart_from is not None:
                checkpoint_name = f"amp_checkpoint_{args.restart_from}.pt"
            else:
                checkpoint_name = "amp_checkpoint.pt"
            with open(os.path.join(args.work_dir, checkpoint_name), "rb") as f:
                checkpoint = torch.load(f)
                model.load_state_dict(checkpoint["model"])
                optimizer.load_state_dict(checkpoint["optimizer"])
                amp.load_state_dict(checkpoint["amp"])

    ############################################################################
    # Training loop
    ############################################################################

    # Loop over epochs.
    if args.reset_lr:
        # then they're different and we use train_step only for the new lr
        # scheduling
        train_step = 0
        optimizer.defaults["lr"] = args.lr
        for param_group in optimizer.param_groups:
            param_group["lr"] = args.lr
            param_group["initial_lr"] = args.lr
        scheduler.base_lrs = [args.lr] * len(scheduler.base_lrs)
    else:
        train_step = model.training_steps

    best_val_loss = None

    # Reload previous step number in case of default_args.restart
    if train_step > 0:
        logging(f"restarting from step {train_step}")

    log_start_time = time.time()
    eval_start_time = time.time()

    def run_training():
        nonlocal train_step

        for epoch in itertools.count(start=first_epoch):
            # we check before the training loop; expanding at epoch 0 means
            # before training (for debug purposes)
            if args.expand and str(epoch - 1) in args.expansion_dict:
                n_add = int(args.expansion_dict[str(epoch - 1)])
                expand_model(
                    args.expand,
                    args.integration,
                    args.integration_length,
                    n_add,
                    model,
                    optimizers,
                    schedulers,
                    tr_iter,
                    va_iter,
                    epoch,
                    train_step,
                )
            if args.widen and str(epoch - 1) in args.widen_dict:
                ratio = int(args.widen_dict[str(epoch - 1)])
                widen_model(
                    args.widen,
                    ratio,
                    model,
                    optimizers,
                    va_iter,
                    epoch,
                    train_step,
                )
            epoch_loop(epoch, para_model, optimizers, schedulers)
            if train_step >= args.max_step:
                logging("-" * 100)
                logging("End of training")
                break
            if not args.debug and args.log_first_epochs:
                if epoch <= args.log_first_epochs:
                    logging(f"saving model at the end of epoch {epoch}")
                    if args.fp16:
                        with open(
                                os.path.join(args.work_dir,
                                             f"amp_checkpoint_{epoch}.pt"),
                                "wb",
                        ) as f:
                            checkpoint = {
                                "model": model.state_dict(),
                                "optimizer": optimizer.state_dict(),
                                "amp": amp.state_dict(),
                            }
                            torch.save(checkpoint, f)
                    else:
                        with open(
                                os.path.join(args.work_dir,
                                             f"model_{epoch}.pt"),
                                "wb",
                        ) as f:
                            torch.save(model, f)
                        with open(
                                os.path.join(args.work_dir,
                                             f"optimizer_{epoch}.pt"),
                                "wb",
                        ) as f:
                            torch.save(optimizer.state_dict(), f)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        if args.restart_from:
            first_epoch = args.restart_from + 1
            print(f"restarting from epoch {first_epoch}")
        else:
            first_epoch = 1
        run_training()

    except KeyboardInterrupt:
        logging("-" * 100)
        logging("Exiting from training early")

    # Load the best model.
    if args.fp16:
        with open(os.path.join(args.work_dir, "amp_checkpoint.pt"), "rb") as f:
            checkpoint = torch.load(f)
            model.load_state_dict(checkpoint["model"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            amp.load_state_dict(checkpoint["amp"])
    else:
        with open(os.path.join(args.work_dir, "model.pt"), "rb") as f:
            model = torch.load(f)
    para_model = model.to(device)

    # Run on test data.
    test_loss = evaluate(te_iter, para_model)
    logging("=" * 100)
    logging("| End of training | test loss {:5.2f} | test bpc {:9.5f}".format(
        test_loss, test_loss / math.log(2)))
    logging("=" * 100)
Esempio n. 14
0
                return None, dw, db, None, None, None, None, None
        else:
            if (not ctx.needs_input_grad[1] and not ctx.needs_input_grad[0]):
                return None, None, None, None, None, None, None  
            dx, dw = NHWC.cudnn_convolution_transpose_backward_nhwc(x, grad_y, w,
                                                       ctx.padding, ctx.output_padding, ctx.stride, ctx.dilation, ctx.groups,
                                                       torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
                                                       list(ctx.needs_input_grad[0:2]))
            if (not ctx.needs_input_grad[1]):
                return None, None, None, None, None, None, None, None  
            elif ctx.needs_input_grad[0]:
                return dx, dw, None, None, None, None, None, None
            else:
                return None, dw, None, None, None, None, None, None

# Register the custom NHWC autograd functions' `apply` entry points with apex
# amp, so that amp casts their inputs to FP16 when it is active.
for _nhwc_autograd_fn in (conv2d_NHWC_impl, conv2d_transpose_NHWC_impl):
    amp.register_half_function(_nhwc_autograd_fn, 'apply')
del _nhwc_autograd_fn

class Conv2d_NHWC(_ConvNd):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(Conv2d_NHWC, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias=bias, padding_mode='zeros')
        # permute filters
        self.weight = torch.nn.Parameter(self.weight.permute(0, 2, 3, 1).contiguous())
    def forward(self, x):
Esempio n. 15
0
from joeynmt.helpers import log_data_info, load_config, log_cfg, \
    store_attention_plots, load_checkpoint, make_model_dir, \
    make_logger, set_seed, symlink_update, latest_checkpoint_update, \
    ConfigurationError
from joeynmt.model import Model, _DataParallel
from joeynmt.prediction import validate_on_data
from joeynmt.loss import XentLoss
from joeynmt.data import load_data, make_data_iter
from joeynmt.builders import build_optimizer, build_scheduler, \
    build_gradient_clipper
from joeynmt.prediction import test

# for fp16 training
try:
    from apex import amp
    amp.register_half_function(torch, "einsum")
except ImportError as no_apex:
    # error handling in TrainManager object construction
    pass

logger = logging.getLogger(__name__)


# pylint: disable=too-many-instance-attributes
class TrainManager:
    """ Manages training loop, validations, learning rate scheduling
    and early stopping."""
    def __init__(self,
                 model: Model,
                 config: dict,
                 batch_class: Batch = Batch) -> None:
Esempio n. 16
0
        'same_length': False,
        'clamp_len': -1,
        'seed': 1111,
        'max_step': 100,
        'cuda': True,
        'multi_gpu': False,
        'gpu0_bsz': -1,
        'debug': False,
        'knockknock': True,
        'tied': True
    })

    device = torch.device('cuda' if default_args.cuda else 'cpu')

    if args.fp16 == "O1":
        amp.register_half_function(torch, 'einsum')

    cutoffs, tie_projs = [], [False]
    if default_args.adaptive:
        assert default_args.dataset in ['wt103', 'lm1b']
        if default_args.dataset == 'wt103':
            cutoffs = [20000, 40000, 200000]
            tie_projs += [True] * len(cutoffs)
        elif default_args.dataset == 'lm1b':
            cutoffs = [60000, 100000, 640000]
            tie_projs += [False] * len(cutoffs)

    for n_layer, d_model, batch_size in product(args.n_layers, args.d_models,
                                                args.batch_sizes):

        n_layer, d_model, batch_size = int(n_layer), int(d_model), int(
Esempio n. 17
0
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )

    # Optimizer parameters
    parser.add_argument("--adam_epsilon",
                        default=1e-6,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--adam_betas", default="(0.9, 0.999)", type=str)
    parser.add_argument("--no_bias_correction",
                        default=False,
                        action='store_true')

    # Other parameters
    parser.add_argument("--train_file",
                        default=None,
                        type=str,
                        help="SQuAD-format json file for training.")
    parser.add_argument("--predict_file",
                        default=None,
                        type=str,
                        help="SQuAD-format json file for evaluation.")
    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_label", action='store_true')
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--predict_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for predictions.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
        "of training.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help=
        "The total number of n-best predictions to generate in the nbest_predictions.json "
        "output file.")
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.")
    parser.add_argument(
        "--verbose_logging",
        action='store_true',
        help=
        "If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        "--do_lower_case",
        default=False,
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--fp16_opt_level', default='O1', type=str)
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument(
        '--version_2_with_negative',
        action='store_true',
        help=
        'If true, the SQuAD examples contain some that do not have an answer.')
    parser.add_argument(
        '--null_score_diff_threshold',
        type=float,
        default=0.0,
        help=
        "If null_score - best_non_null is greater than the threshold predict null."
    )
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument(
        '--no_masking',
        action='store_true',
        help='If true, we do not mask the span loss for no-answer examples.')
    parser.add_argument(
        '--skip_negatives',
        action='store_true',
        help=
        'If true, we skip negative examples during training; this is mainly for ablation.'
    )
    # For Natural Questions
    parser.add_argument(
        '--max_answer_len',
        type=int,
        default=1000000,
        help=
        "maximum length of answer tokens (might be set to 5 for Natural Questions!)"
    )

    # balance the two losses.
    parser.add_argument(
        '--lambda_scale',
        type=float,
        default=1.0,
        help=
        "If you would like to change the two losses, please change the lambda scale."
    )

    # Save checkpoints more
    parser.add_argument(
        '--save_gran',
        type=str,
        default="10,3",
        help='"10,5" means saving a checkpoint every 1/10 of the total updates,'
        'but start saving from the 5th attempt')
    parser.add_argument('--oss_cache_dir', default=None, type=str)
    parser.add_argument('--cache_dir', default=None, type=str)
    parser.add_argument('--dist', default=False, action='store_true')

    args = parser.parse_args()
    print(args)

    if args.dist:
        dist.init_process_group(backend='nccl')
        print(f"local rank: {args.local_rank}")
        print(f"global rank: {dist.get_rank()}")
        print(f"world size: {dist.get_world_size()}")

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of
        # synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')

    if args.dist:
        global_rank = dist.get_rank()
        world_size = dist.get_world_size()
        if world_size > 1:
            args.local_rank = global_rank

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict:
        raise ValueError(
            "At least one of `do_train` or `do_predict` must be True.")

    if args.do_train:
        if not args.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if args.do_predict:
        if not args.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified."
            )

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory () already exists and is not empty.")
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Prepare model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    model = IterBertForQuestionAnsweringConfidence.from_pretrained(
        args.bert_model,
        num_labels=4,
        no_masking=args.no_masking,
        lambda_scale=args.lambda_scale)

    model.to(device)

    train_examples = None
    train_features = None
    num_train_optimization_steps = None
    if args.do_train:
        cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}_{4}'.format(
            model.base_model_prefix, str(args.max_seq_length),
            str(args.doc_stride), str(args.max_query_length),
            tokenizer.do_lower_case)
        cached_train_features_file_name = cached_train_features_file.split(
            '/')[-1]
        _oss_feature_save_path = os.path.join(oss_features_cache_dir,
                                              cached_train_features_file_name)

        try:
            if args.cache_dir is not None and os.path.exists(
                    os.path.join(args.cache_dir,
                                 cached_train_features_file_name)):
                logger.info(
                    f"Loading pre-processed features from {os.path.join(args.cache_dir, cached_train_features_file_name)}"
                )
                train_features = torch.load(
                    os.path.join(args.cache_dir,
                                 cached_train_features_file_name))
            else:
                logger.info(
                    f"Loading pre-processed features from oss: {_oss_feature_save_path}"
                )
                train_features = torch.load(
                    load_buffer_from_oss(_oss_feature_save_path))
        except:
            train_examples = read_squad_examples(
                input_file=args.train_file,
                is_training=True,
                version_2_with_negative=args.version_2_with_negative,
                max_answer_len=args.max_answer_len,
                skip_negatives=args.skip_negatives)
            train_features = convert_examples_to_features_yes_no(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=True)
            if args.local_rank in [-1, 0]:
                torch_save_to_oss(train_features, _oss_feature_save_path)
                logger.info(
                    f"Saving train features into oss: {_oss_feature_save_path}"
                )

        num_train_optimization_steps = int(
            len(train_features) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs

    if args.do_label:
        logger.info("finished.")
        return

    if args.do_train:
        # Prepare optimizer
        param_optimizer = list(model.named_parameters())

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'layer_norm']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        t_total = num_train_optimization_steps
        if args.local_rank != -1:
            t_total = t_total // dist.get_world_size()

        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          betas=eval(args.adam_betas),
                          eps=args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, int(t_total * args.warmup_proportion),
            num_train_optimization_steps)

        if args.fp16:
            from apex import amp

            if args.fp16_opt_level == 'O1':
                amp.register_half_function(torch, "einsum")

            if args.loss_scale == 0:
                model, optimizer = amp.initialize(
                    model, optimizer, opt_level=args.fp16_opt_level)
            else:
                model, optimizer = amp.initialize(
                    model,
                    optimizer,
                    opt_level=args.fp16_opt_level,
                    loss_scale=args.loss_scale)
        if args.local_rank != -1:
            if args.fp16_opt_level == 'O2':
                try:
                    import apex
                    model = apex.parallel.DistributedDataParallel(
                        model, delay_allreduce=True)
                except ImportError:
                    model = torch.nn.parallel.DistributedDataParallel(
                        model, find_unused_parameters=True)
            else:
                model = torch.nn.parallel.DistributedDataParallel(
                    model, find_unused_parameters=True)

        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        global_step = 0

        logger.info("***** Running training *****")
        if train_examples:
            logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Instantaneous batch size per GPU = %d",
                    args.train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            args.train_batch_size * args.gradient_accumulation_steps *
            (dist.get_world_size() if args.local_rank != -1 else 1),
        )
        logger.info("  Gradient Accumulation steps = %d",
                    args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_start_positions = torch.tensor(
            [f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor(
            [f.end_position for f in train_features], dtype=torch.long)
        all_switches = torch.tensor([f.switch for f in train_features],
                                    dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_start_positions,
                                   all_end_positions, all_switches)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size,
                                      pin_memory=True,
                                      num_workers=4)

        if args.save_gran is not None:
            save_chunk, save_start = args.save_gran.split(',')
            save_chunk = t_total // int(save_chunk)
            save_start = int(save_start)

        model.train()
        tr_loss = 0
        for _epc in trange(int(args.num_train_epochs), desc="Epoch"):
            if args.local_rank != -1:
                train_dataloader.sampler.set_epoch(_epc)
            for step, batch in enumerate(
                    tqdm(train_dataloader,
                         desc="Iteration",
                         disable=args.local_rank not in [-1, 0])):
                if n_gpu == 1:
                    # multi-gpu does scattering it-self
                    batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, start_positions, end_positions, switches = batch
                loss = model(input_ids=input_ids,
                             token_type_ids=segment_ids,
                             attention_mask=input_mask,
                             start_positions=start_positions,
                             end_positions=end_positions,
                             switch_list=switches)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                tr_loss += loss.item()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1

                    if global_step % 50 == 0:
                        logger.info(f"Training loss: {tr_loss / global_step}\t"
                                    f"Learning rate: {scheduler.get_lr()[0]}\t"
                                    f"Global step: {global_step}")

                    if args.save_gran is not None and args.local_rank in [
                            -1, 0
                    ]:
                        if (global_step % save_chunk == 0) and (
                                global_step // save_chunk >= save_start):
                            logger.info('Saving a checkpoint....')
                            output_dir_per_epoch = os.path.join(
                                args.output_dir,
                                str(global_step) + 'steps')
                            os.makedirs(output_dir_per_epoch)

                            # Save a trained model, configuration and tokenizer
                            model_to_save = model.module if hasattr(
                                model, 'module'
                            ) else model  # Only save the model it-self

                            if args.oss_cache_dir is not None:
                                _oss_model_save_path = os.path.join(
                                    args.oss_cache_dir, f"{global_step}steps")
                                torch_save_to_oss(
                                    model_to_save.state_dict(),
                                    _oss_model_save_path +
                                    "/pytorch_model.bin")
                            model_to_save.save_pretrained(output_dir_per_epoch)
                            tokenizer.save_pretrained(output_dir_per_epoch)
                            logger.info('Done')

    if args.do_train and (args.local_rank == -1 or dist.get_rank() == 0):
        # Save a trained model, configuration and tokenizer
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model it-self

        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
        torch_save_to_oss(
            model_to_save.state_dict(),
            os.path.join(args.oss_cache_dir, "pytorch_model.bin"))

        # Load a trained model and vocabulary that you have fine-tuned
        # model = IterBertForQuestionAnsweringConfidence.from_pretrained(
        #     args.output_dir, num_labels=4, no_masking=args.no_masking)
        tokenizer = AutoTokenizer.from_pretrained(
            args.output_dir, do_lower_case=args.do_lower_case)

    if args.do_train is False and args.do_predict is True:
        model = IterBertForQuestionAnsweringConfidence.from_pretrained(
            args.output_dir, num_labels=4, no_masking=args.no_masking)
        tokenizer = AutoTokenizer.from_pretrained(
            args.output_dir, do_lower_case=args.do_lower_case)
    elif args.do_train is True and args.do_predict is True:
        model = IterBertForQuestionAnsweringConfidence.from_pretrained(
            args.output_dir, num_labels=4, no_masking=args.no_masking)
        tokenizer = AutoTokenizer.from_pretrained(
            args.output_dir, do_lower_case=args.do_lower_case)
    else:
        model = IterBertForQuestionAnsweringConfidence.from_pretrained(
            args.bert_model,
            num_labels=4,
            no_masking=args.no_masking,
            lambda_scale=args.lambda_scale)

    model.to(device)

    if args.do_predict and (args.local_rank == -1 or dist.get_rank() == 0):
        eval_examples = read_squad_examples(
            input_file=args.predict_file,
            is_training=False,
            version_2_with_negative=args.version_2_with_negative,
            max_answer_len=args.max_answer_length,
            skip_negatives=args.skip_negatives)
        eval_features = convert_examples_to_features_yes_no(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=False)

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(eval_examples))
        logger.info("  Num split examples = %d", len(eval_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_example_index = torch.arange(all_input_ids.size(0),
                                         dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_example_index)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start evaluating")
        for input_ids, input_mask, segment_ids, example_indices in tqdm(
                eval_dataloader,
                desc="Evaluating",
                disable=args.local_rank not in [-1, 0]):
            if len(all_results) % 1000 == 0:
                logger.info("Processing example: %d" % (len(all_results)))
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            with torch.no_grad():
                batch_start_logits, batch_end_logits, batch_switch_logits = model(
                    input_ids, segment_ids, input_mask)
            for i, example_index in enumerate(example_indices):
                start_logits = batch_start_logits[i].detach().cpu().tolist()
                end_logits = batch_end_logits[i].detach().cpu().tolist()
                switch_logits = batch_switch_logits[i].detach().cpu().tolist()
                eval_feature = eval_features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                all_results.append(
                    RawResult(unique_id=unique_id,
                              start_logits=start_logits,
                              end_logits=end_logits,
                              switch_logits=switch_logits))
        output_prediction_file = os.path.join(args.output_dir,
                                              "predictions.json")
        output_nbest_file = os.path.join(args.output_dir,
                                         "nbest_predictions.json")
        output_null_log_odds_file = os.path.join(args.output_dir,
                                                 "null_odds.json")
        write_predictions_yes_no_no_empty_answer(
            eval_examples, eval_features, all_results, args.n_best_size,
            args.max_answer_length, args.do_lower_case, output_prediction_file,
            output_nbest_file, output_null_log_odds_file, args.verbose_logging,
            args.version_2_with_negative, args.null_score_diff_threshold,
            args.no_masking)