Example #1
def main(step, args, model_state_dict, optimizer_state_dict):
    #
    # PART 2: rebuild the model and optimizer, re-initialize ASP for pruning,
    # then restore the state dicts captured earlier in the run.
    #

    model = build_model(args).cuda()
    one_ll = next(model.children()).weight
    optimizer = FusedAdam(model.parameters())
    ASP.init_model_for_pruning(model,
                               args.pattern,
                               verbosity=args.verbosity,
                               whitelist=args.whitelist,
                               allow_recompute_mask=args.allow_recompute_mask)
    ASP.init_optimizer_for_pruning(optimizer)

    torch.manual_seed(args.seed2)
    model.load_state_dict(model_state_dict)
    optimizer.load_state_dict(optimizer_state_dict)

    print("Model sparsity is %s" %
          ("enabled" if ASP.sparsity_is_enabled() else "disabled"))

    # train for a few steps with sparse weights
    print("SPARSE :: ", one_ll)
    step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)
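
For context, a minimal sketch of how the state dicts consumed above could have been produced earlier in the run; torch.save on plain state dicts is sufficient for this restore pattern, and the key names here are illustrative only.

import torch

def save_checkpoint(path, step, model, optimizer):
    # Capture both the (possibly pruned) weights and the optimizer state so
    # that main() in this example can restore them after re-initializing ASP.
    torch.save({
        "step": step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }, path)
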
Example #2
def create_optimizers(model,
                      args,
                      lr_schedule,
                      prev_optimizer=None,
                      prev_scheduler=None):
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = FusedAdam(params, lr=args.lr)
    optimizer = FP16_Optimizer(optimizer,
                               dynamic_loss_scale=True,
                               verbose=False)

    if prev_optimizer is not None:
        optimizer.load_state_dict(prev_optimizer.state_dict())

    if args.warmup < 0:
        print('No learning rate schedule used.')
    else:
        print('Using learning rate schedule.')
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer.optimizer,
                                                      lr_schedule)
        if prev_scheduler is not None:
            # Continue LR schedule from previous scheduler
            scheduler.load_state_dict(prev_scheduler.state_dict())

    loss_model = SimpleDistributedDataParallel(model, args.world_size)
    # return the scheduler only when one was actually created above
    return loss_model, optimizer, scheduler if args.warmup >= 0 else None
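
A usage sketch for create_optimizers above, assuming it lives in the same module as FusedAdam, FP16_Optimizer and SimpleDistributedDataParallel; the args fields (lr, warmup, world_size) are taken from the snippet, everything else is illustrative.

import argparse
import torch

args = argparse.Namespace(lr=1e-4, warmup=1000, world_size=1)
model = torch.nn.Linear(512, 512).cuda()                 # placeholder module; FusedAdam expects CUDA parameters
lr_schedule = lambda step: min(1.0, step / args.warmup)  # simple linear warmup for LambdaLR

loss_model, optimizer, scheduler = create_optimizers(model, args, lr_schedule)
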
Example #3
def load_opt(H, vae, logprint):
    optimizer = AdamW(vae.parameters(), weight_decay=H.wd, lr=H.lr, betas=(H.adam_beta1, H.adam_beta2))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_warmup(H.warmup_iters))

    if H.restore_optimizer_path:
        optimizer.load_state_dict(
            torch.load(distributed_maybe_download(H.restore_optimizer_path, H.local_rank, H.mpi_size), map_location='cpu'))
    if H.restore_log_path:
        cur_eval_loss, iterate, starting_epoch = restore_log(H.restore_log_path, H.local_rank, H.mpi_size)
    else:
        cur_eval_loss, iterate, starting_epoch = float('inf'), 0, 0
    logprint('starting at epoch', starting_epoch, 'iterate', iterate, 'eval loss', cur_eval_loss)
    return optimizer, scheduler, cur_eval_loss, iterate, starting_epoch
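
linear_warmup is not shown in this snippet; a minimal sketch of what such a LambdaLR-compatible factory typically looks like (an assumption for orientation, not the repository's actual implementation):

def linear_warmup(warmup_iters):
    def schedule(iteration):
        # Scale the base lr linearly from 0 up to its full value over warmup_iters, then hold it.
        return min(1.0, iteration / max(1, warmup_iters))
    return schedule
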
Example #4
def create_optimizer(model, learning_rate, t_total, loss_scale, fp16,
                     warmup_proportion, state_dict):
    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = [
        'bias',
        'LayerNorm.bias',
        'LayerNorm.weight',
        'adapter.down_project.weight',
        'adapter.up_project.weight',
    ]
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex "
                "to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=learning_rate,
                             warmup=warmup_proportion,
                             t_total=t_total)

    if state_dict is not None:
        optimizer.load_state_dict(state_dict)
    return optimizer
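
A usage sketch for create_optimizer above; the checkpoint path, its "optimizer" key, and the stand-in model are assumptions made for illustration, not part of the snippet.

import torch

ckpt = torch.load("checkpoint.bin", map_location="cpu")   # hypothetical checkpoint file
model = torch.nn.Linear(768, 2)                            # stand-in for the real BERT model
optimizer = create_optimizer(model,
                             learning_rate=3e-5,
                             t_total=10000,
                             loss_scale=0,                  # 0 -> dynamic loss scaling when fp16 is True
                             fp16=True,
                             warmup_proportion=0.1,
                             state_dict=ckpt.get("optimizer"))
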
Example #5
def main():
    parser = argparse.ArgumentParser()

    # General
    parser.add_argument(
        "--bert_model",
        default="bert-base-cased",
        type=str,
        help=
        "Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default='tmp',
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_file",
        default="training.log",
        type=str,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument(
        "--do_train",
        action='store_true',
        help="Whether to run training. This should ALWAYS be set to True.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=30,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--global_rank",
                        type=int,
                        default=-1,
                        help="global_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 32-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--len_vis_input',
                        type=int,
                        default=100,
                        help="The length of visual token input")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=20,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='b',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=3,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=4,
                        type=int,
                        help="Number of workers for the data loader.")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")

    # Others for VLP
    parser.add_argument(
        "--src_file",
        default=['/mnt/dat/COCO/annotations/dataset_coco.json'],
        type=str,
        nargs='+',
        help="The input data file name.")
    parser.add_argument('--enable_visdom', action='store_true')
    parser.add_argument('--visdom_port', type=int, default=8888)
    # parser.add_argument('--resnet_model', type=str, default='imagenet_weights/resnet101.pth')
    parser.add_argument('--image_root',
                        type=str,
                        default='/mnt/dat/COCO/images')
    parser.add_argument('--dataset',
                        default='coco',
                        type=str,
                        help='coco | flickr30k | cc')
    parser.add_argument('--split',
                        type=str,
                        nargs='+',
                        default=['train', 'restval'])

    parser.add_argument('--world_size',
                        default=1,
                        type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url',
                        default='file://[PT_OUTPUT_DIR]/nonexistent_file',
                        type=str,
                        help='url used to set up distributed training')
    parser.add_argument(
        '--file_valid_jpgs',
        default='/mnt/dat/COCO/annotations/coco_valid_jpgs.json',
        type=str)
    parser.add_argument('--sche_mode',
                        default='warmup_linear',
                        type=str,
                        help="warmup_linear | warmup_constant | warmup_cosine")
    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--use_num_imgs', default=-1, type=int)
    parser.add_argument('--vis_mask_prob', default=0, type=float)
    parser.add_argument('--max_drop_worst_ratio', default=0, type=float)
    parser.add_argument('--drop_after', default=6, type=int)

    parser.add_argument(
        '--s2s_prob',
        default=1,
        type=float,
        help="Percentage of examples that are bi-uni-directional LM (seq2seq)."
    )
    parser.add_argument(
        '--bi_prob',
        default=0,
        type=float,
        help="Percentage of examples that are bidirectional LM.")
    parser.add_argument('--enable_butd',
                        action='store_true',
                        help='set to take in region features')
    parser.add_argument(
        '--region_bbox_file',
        default=
        'coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5',
        type=str)
    parser.add_argument(
        '--region_det_file_prefix',
        default=
        'feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval',
        type=str)
    parser.add_argument('--tasks', default='img2txt', help='img2txt | vqa2')
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--scst',
                        action='store_true',
                        help='Self-critical sequence training')

    args = parser.parse_args()

    print('global_rank: {}, local rank: {}'.format(args.global_rank,
                                                   args.local_rank))

    args.max_seq_length = args.max_len_b + args.len_vis_input + 3  # +3 for 2x[SEP] and [CLS]
    args.mask_image_regions = (args.vis_mask_prob > 0)  # whether to mask out image regions
    args.dist_url = args.dist_url.replace('[PT_OUTPUT_DIR]', args.output_dir)

    # arguments inspection
    assert (args.tasks in ('img2txt', 'vqa2'))
    assert args.enable_butd == True, 'only support region attn! featmap attn deprecated'
    assert (
        not args.scst) or args.dataset == 'coco', 'scst support on coco only!'
    if args.scst:
        assert args.dataset == 'coco', 'scst support on coco only!'
        assert args.max_pred == 0 and args.mask_prob == 0, 'no mask for scst!'
        rl_crit = RewardCriterion()

    if args.enable_butd:
        assert (args.len_vis_input == 100)
        args.region_bbox_file = os.path.join(args.image_root,
                                             args.region_bbox_file)
        args.region_det_file_prefix = os.path.join(
            args.image_root, args.region_det_file_prefix) if args.dataset in (
                'cc', 'coco') and args.region_det_file_prefix != '' else ''

    # output config
    os.makedirs(args.output_dir, exist_ok=True)
    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)

    logging.basicConfig(
        filename=os.path.join(args.output_dir, args.log_file),
        filemode='w',
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger = logging.getLogger(__name__)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://localhost:10001',  #args.dist_url,
            world_size=args.world_size,
            rank=args.global_rank)
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # plotting loss, optional
    if args.enable_visdom:
        import visdom
        vis = visdom.Visdom(port=args.visdom_port, env=args.output_dir)
        vis_window = {'iter': None, 'score': None}

    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model,
        do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir +
        '/.pretrained_model_{}'.format(args.global_rank))
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer

    if args.do_train:
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_image_regions=args.mask_image_regions,
                mode="s2s",
                len_vis_input=args.len_vis_input,
                vis_mask_prob=args.vis_mask_prob,
                enable_butd=args.enable_butd,
                region_bbox_file=args.region_bbox_file,
                region_det_file_prefix=args.region_det_file_prefix,
                local_rank=args.local_rank,
                load_vqa_ann=(args.tasks == 'vqa2'))
        ]
        bi_uni_pipeline.append(
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_image_regions=args.mask_image_regions,
                mode="bi",
                len_vis_input=args.len_vis_input,
                vis_mask_prob=args.vis_mask_prob,
                enable_butd=args.enable_butd,
                region_bbox_file=args.region_bbox_file,
                region_det_file_prefix=args.region_det_file_prefix,
                local_rank=args.local_rank,
                load_vqa_ann=(args.tasks == 'vqa2')))

        train_dataset = seq2seq_loader.Img2txtDataset(
            args.src_file,
            args.image_root,
            args.split,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_valid_jpgs=args.file_valid_jpgs,
            bi_uni_pipeline=bi_uni_pipeline,
            use_num_imgs=args.use_num_imgs,
            s2s_prob=args.s2s_prob,
            bi_prob=args.bi_prob,
            enable_butd=args.enable_butd,
            tasks=args.tasks)

        if args.world_size == 1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
        else:
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.train_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=batch_list_to_batch_tensors,
            pin_memory=True)

    # note: args.train_batch_size has already been divided by args.gradient_accumulation_steps
    t_total = int(
        len(train_dataloader) * args.num_train_epochs * 1. /
        args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    relax_projection = 4 if args.relax_projection else 0
    task_idx_proj = 3 if args.tasks == 'img2txt' else 0
    mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[PAD]"])  # index in BERT vocab: 103, 102, 0

    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        assert args.scst == False, 'must init from maximum likelihood training'
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size,
            relax_projection=relax_projection,
            config_path=args.config_path,
            task_idx=task_idx_proj,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            cache_dir=args.output_dir +
            '/.pretrained_model_{}'.format(args.global_rank),
            drop_prob=args.drop_prob,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input,
            tasks=args.tasks)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(
                os.path.join(args.output_dir,
                             "model.{0}.bin".format(recover_step)))
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total * 1. /
                                     args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path)
            global_step = 0
        if not args.scst:
            model = BertForPreTrainingLossMask.from_pretrained(
                args.bert_model,
                state_dict=model_recover,
                num_labels=cls_num_labels,
                type_vocab_size=type_vocab_size,
                relax_projection=relax_projection,
                config_path=args.config_path,
                task_idx=task_idx_proj,
                max_position_embeddings=args.max_position_embeddings,
                label_smoothing=args.label_smoothing,
                fp32_embedding=args.fp32_embedding,
                cache_dir=args.output_dir +
                '/.pretrained_model_{}'.format(args.global_rank),
                drop_prob=args.drop_prob,
                enable_butd=args.enable_butd,
                len_vis_input=args.len_vis_input,
                tasks=args.tasks)
        else:
            model = BertForSeq2SeqDecoder.from_pretrained(
                args.bert_model,
                max_position_embeddings=args.max_position_embeddings,
                config_path=args.config_path,
                state_dict=model_recover,
                num_labels=cls_num_labels,
                type_vocab_size=type_vocab_size,
                task_idx=task_idx_proj,
                mask_word_id=mask_word_id,
                search_beam_size=1,
                eos_id=eos_word_ids,
                enable_butd=args.enable_butd,
                len_vis_input=args.len_vis_input)

        del model_recover
        torch.cuda.empty_cache()

    # deprecated
    # from vlp.resnet import resnet
    # cnn = resnet(args.resnet_model, _num_layers=101, _fixed_block=4, pretrained=True) # no finetuning

    if args.fp16:
        model.half()
        # cnn.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    # cnn.to(device)
    if args.local_rank != -1:
        try:
            # from apex.parallel import DistributedDataParallel as DDP
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "DistributedDataParallel could not be imported; it is required for distributed training."
            )
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
        # cnn = DDP(cnn)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)
        # cnn = DataParallelImbalance(cnn)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             schedule=args.sche_mode,
                             t_total=t_total)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(
            os.path.join(args.output_dir,
                         "optim.{0}.bin".format(recover_step)))
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)
        logger.info("  Loader length = %d", len(train_dataloader))

        model.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch,
                              args.num_train_epochs + 1,
                              desc="Epoch"):
            if args.local_rank >= 0:
                train_sampler.set_epoch(i_epoch - 1)
            iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)')
            nbatches = len(train_dataloader)
            train_loss = []
            pretext_loss = []
            vqa2_loss = []
            scst_reward = []
            for step, batch in enumerate(iter_bar):
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, img, vis_masked_pos, vis_pe, ans_labels = batch

                if args.fp16:
                    img = img.half()
                    vis_pe = vis_pe.half()

                if args.enable_butd:
                    conv_feats = img.data  # Bx100x2048
                    vis_pe = vis_pe.data
                else:
                    conv_feats, _ = cnn(img.data)  # Bx2048x7x7
                    conv_feats = conv_feats.view(conv_feats.size(0),
                                                 conv_feats.size(1),
                                                 -1).permute(0, 2,
                                                             1).contiguous()

                if not args.scst:
                    loss_tuple = model(
                        conv_feats,
                        vis_pe,
                        input_ids,
                        segment_ids,
                        input_mask,
                        lm_label_ids,
                        ans_labels,
                        is_next,
                        masked_pos=masked_pos,
                        masked_weights=masked_weights,
                        task_idx=task_idx,
                        vis_masked_pos=vis_masked_pos,
                        mask_image_regions=args.mask_image_regions,
                        drop_worst_ratio=args.max_drop_worst_ratio
                        if i_epoch > args.drop_after else 0)
                    mean_reward = loss_tuple[0].new(1).fill_(0)
                else:
                    # scst training
                    model.eval()
                    position_ids = torch.arange(
                        input_ids.size(1),
                        dtype=input_ids.dtype,
                        device=input_ids.device).unsqueeze(0).expand_as(
                            input_ids)
                    input_dummy = input_ids[:, :args.len_vis_input +
                                            2]  # +2 for [CLS] and [SEP]
                    greedy_res = input_ids.new(
                        input_ids.size(0),
                        input_ids.size(1) - args.len_vis_input - 2).fill_(0)
                    gen_result = input_ids.new(
                        input_ids.size(0),
                        input_ids.size(1) - args.len_vis_input - 2).fill_(0)

                    with torch.no_grad():
                        greedy_res_raw, _ = model(conv_feats,
                                                  vis_pe,
                                                  input_dummy,
                                                  segment_ids,
                                                  position_ids,
                                                  input_mask,
                                                  task_idx=task_idx,
                                                  sample_mode='greedy')
                        for b in range(greedy_res_raw.size(0)):
                            for idx in range(greedy_res_raw.size(1)):
                                if greedy_res_raw[b][idx] not in [
                                        eos_word_ids, pad_word_ids
                                ]:
                                    greedy_res[b][idx] = greedy_res_raw[b][idx]
                                else:
                                    if greedy_res_raw[b][idx] == eos_word_ids:
                                        greedy_res[b][idx] = eos_word_ids
                                    break
                    model.train()
                    gen_result_raw, sample_logprobs = model(
                        conv_feats,
                        vis_pe,
                        input_dummy,
                        segment_ids,
                        position_ids,
                        input_mask,
                        task_idx=task_idx,
                        sample_mode='sample')
                    for b in range(gen_result_raw.size(0)):
                        for idx in range(gen_result_raw.size(1)):
                            if gen_result_raw[b][idx] not in [
                                    eos_word_ids, pad_word_ids
                            ]:
                                gen_result[b][idx] = gen_result_raw[b][idx]
                            else:
                                if gen_result_raw[b][idx] == eos_word_ids:
                                    gen_result[b][idx] = eos_word_ids
                                break

                    gt_ids = input_ids[:, args.len_vis_input + 2:]
                    reward = get_self_critical_reward(greedy_res,
                                                      gt_ids, gen_result,
                                                      gt_ids.size(0))
                    reward = torch.from_numpy(reward).float().to(
                        gen_result.device)
                    mean_reward = reward.mean()
                    loss = rl_crit(sample_logprobs, gen_result.data, reward)

                    loss_tuple = [
                        loss,
                        loss.new(1).fill_(0.),
                        loss.new(1).fill_(0.)
                    ]

                # disable pretext_loss_deprecated for now
                masked_lm_loss, pretext_loss_deprecated, ans_loss = loss_tuple
                if n_gpu > 1:  # mean() to average on multi-gpu. For dist, this is done through gradient addition.
                    masked_lm_loss = masked_lm_loss.mean()
                    pretext_loss_deprecated = pretext_loss_deprecated.mean()
                    ans_loss = ans_loss.mean()
                loss = masked_lm_loss + pretext_loss_deprecated + ans_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())
                train_loss.append(loss.item())
                pretext_loss.append(pretext_loss_deprecated.item())
                vqa2_loss.append(ans_loss.item())
                scst_reward.append(mean_reward.item())
                if step % 100 == 0:
                    logger.info(
                        "Epoch {}, Iter {}, Loss {:.2f}, Pretext {:.2f}, VQA2 {:.2f}, Mean R {:.3f}\n"
                        .format(i_epoch, step, np.mean(train_loss),
                                np.mean(pretext_loss), np.mean(vqa2_loss),
                                np.mean(scst_reward)))

                if args.enable_visdom:
                    if vis_window['iter'] is None:
                        vis_window['iter'] = vis.line(
                            X=np.tile(
                                np.arange((i_epoch - 1) * nbatches + step,
                                          (i_epoch - 1) * nbatches + step + 1),
                                (1, 1)).T,
                            Y=np.column_stack(
                                (np.asarray([np.mean(train_loss)]), )),
                            opts=dict(title='Training Loss',
                                      xlabel='Training Iteration',
                                      ylabel='Loss',
                                      legend=['total']))
                    else:
                        vis.line(X=np.tile(
                            np.arange((i_epoch - 1) * nbatches + step,
                                      (i_epoch - 1) * nbatches + step + 1),
                            (1, 1)).T,
                                 Y=np.column_stack(
                                     (np.asarray([np.mean(train_loss)]), )),
                                 opts=dict(title='Training Loss',
                                           xlabel='Training Iteration',
                                           ylabel='Loss',
                                           legend=['total']),
                                 win=vis_window['iter'],
                                 update='append')

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step/t_total,
                                      args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Save a trained model
            logger.info(
                "** ** * Saving fine-tuned model and optimizer ** ** * ")
            model_to_save = model.module if hasattr(
                model, 'module') else model  # Only save the model itself
            output_model_file = os.path.join(args.output_dir,
                                             "model.{0}.bin".format(i_epoch))
            output_optim_file = os.path.join(args.output_dir,
                                             "optim.{0}.bin".format(i_epoch))
            if args.global_rank in (-1, 0):  # save only on rank 0 (or when not distributed)
                torch.save(
                    copy.deepcopy(model_to_save).cpu().state_dict(),
                    output_model_file)
                # torch.save(optimizer.state_dict(), output_optim_file) # disabled for now; the state needs to be sanitized and shipped back to cpu

            logger.info("***** CUDA.empty_cache() *****")
            torch.cuda.empty_cache()

            if args.world_size > 1:
                torch.distributed.barrier()
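
The manual lr update in the fp16 branch above relies on warmup_linear; a sketch of the schedule this helper commonly implements in pytorch_pretrained_bert-style training scripts (an assumption shown for orientation, not the project's verified code):

def warmup_linear(progress, warmup=0.002):
    # Linear warmup to the peak lr for the first `warmup` fraction of training,
    # then linear decay towards zero for the remainder.
    if progress < warmup:
        return progress / warmup
    return max(0.0, 1.0 - progress)
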
Example #6
def run(args, local_rank):
    """ Distributed Synchronous """
    torch.manual_seed(1234)
    vocab = Vocab(args.vocab, min_occur_cnt=args.min_occur_cnt, specials=[])
    if args.world_size == 1 or dist.get_rank() == 0:
        print(vocab.size)
    model = BIGLM(local_rank, vocab, args.embed_dim, args.ff_embed_dim, args.num_heads, args.dropout, args.layers, args.approx)
    if args.start_from is not None:
        ckpt = torch.load(args.start_from, map_location='cpu')
        model.load_state_dict(ckpt['model'])
    model = model.cuda(local_rank)
   
    weight_decay_params = []
    no_weight_decay_params = []
    
    for name, param in model.named_parameters():
        if name.endswith('bias') or 'layer_norm' in name:
            no_weight_decay_params.append(param)
        else:
            weight_decay_params.append(param)
    grouped_params = [{'params':weight_decay_params, 'weight_decay':0.01},
                        {'params':no_weight_decay_params, 'weight_decay':0.}]
    if args.world_size > 1:
        torch.manual_seed(1234 + dist.get_rank())
        random.seed(5678 + dist.get_rank())
    
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        optimizer = FusedAdam(grouped_params,
                              lr=args.lr,
                              betas=(0.9, 0.999),
                              eps=1e-6,
                              bias_correction=False,
                              max_grad_norm=1.0)
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

    else:
        optimizer = AdamWeightDecayOptimizer(grouped_params,
                           lr=args.lr, betas=(0.9, 0.999), eps=1e-6)
    if args.start_from is not None:
        optimizer.load_state_dict(ckpt['optimizer'])

    train_data = DataLoader(vocab, args.train_data, args.batch_size, args.max_len)
    batch_acm = 0
    acc_acm, ntokens_acm, npairs_acm, loss_acm = 0., 0., 0., 0.
    while True:
        model.train()
        for truth, inp, msk in train_data:
            batch_acm += 1
            if batch_acm <= args.warmup_steps:
                update_lr(optimizer, args.lr*batch_acm/args.warmup_steps)
            truth = truth.cuda(local_rank)
            inp = inp.cuda(local_rank)
            msk = msk.cuda(local_rank)

            optimizer.zero_grad()
            res, loss, acc, ntokens, npairs = model(truth, inp, msk)
            loss_acm += loss.item()
            acc_acm += acc
            ntokens_acm += ntokens
            npairs_acm += npairs
            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()
            if args.world_size > 1:
                average_gradients(model)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            if (args.world_size == 1 or dist.get_rank() == 0) and batch_acm % args.print_every == args.print_every - 1:
                print('batch_acm %d, loss %.3f, acc %.3f, x_acm %d' % (batch_acm, loss_acm / args.print_every, acc_acm / ntokens_acm, npairs_acm))
                acc_acm, ntokens_acm, loss_acm = 0., 0., 0.
            if (args.world_size == 1 or dist.get_rank() == 0) and batch_acm % args.save_every == args.save_every - 1:
                if not os.path.exists(args.save_dir):
                    os.mkdir(args.save_dir)
                torch.save({'args':args, 'model':model.state_dict(), 'optimizer':optimizer.state_dict()}, '%s/epoch%d_batch_%d'%(args.save_dir, train_data.epoch_id, batch_acm))
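
update_lr is called during warmup above but not defined in the snippet; a minimal sketch of the usual implementation (an assumption, not the repository's code):

def update_lr(optimizer, lr):
    # Set the same learning rate on every parameter group (assumes the optimizer,
    # or its fp16 wrapper, exposes param_groups).
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
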
Example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--bert_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained BERT model. "
        "This specifies the model architecture.")
    parser.add_argument(
        "--vocab_file",
        default=None,
        type=str,
        required=True,
        help="The vocabulary file that the BERT model was trained on.")
    parser.add_argument(
        "--init_checkpoint",
        default=None,
        type=str,
        help="Initial checkpoint (usually from a pre-trained BERT model).")
    ## Required parameters
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )

    ## Other parameters
    parser.add_argument("--train_file",
                        default=None,
                        type=str,
                        help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument(
        "--train_ans_file",
        default=None,
        type=str,
        help="SQuAD answer for training. E.g., train-v1.1.json")
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json"
    )
    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--predict_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for predictions.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
        "of training.")
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        "--do_lower_case",
        default=True,
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument(
        '--fp16',
        default=False,
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--restore',
                        action='store_true',
                        help="Resume model/optimizer/amp state from amp_checkpoint.pt.")
    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        #        torch.backends.cudnn.benchmark = True
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict:
        raise ValueError(
            "At least one of `do_train` or `do_predict` must be True.")

    if args.do_train:
        if not args.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if args.do_predict:
        if not args.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified."
            )

    if not os.path.exists(args.output_dir):
        # raise ValueError("Output directory () already exists and is not empty.")
        os.makedirs(args.output_dir, exist_ok=True)

    import pickle as cPickle
    train_examples = None
    num_train_steps = None
    if args.do_train:
        raw_test_data = open(args.predict_file, mode='r')
        raw_train_data = open(args.train_file, mode='r')
        if os.path.exists("train_file_baseline.pkl") and False:
            train_examples = cPickle.load(
                open("train_file_baseline.pkl", mode='rb'))
        else:
            ans_dict = {}
            with open(args.train_ans_file) as f:
                for line in f:
                    line = line.split(',')
                    ans_dict[line[0]] = int(line[1])
            train_examples = read_chid_examples(raw_train_data,
                                                is_training=True,
                                                ans_dict=ans_dict)
            cPickle.dump(train_examples,
                         open("newtrain_file_baseline.pkl", mode='wb'))

        #tt = len(train_examples) // 2
        #train_examples = train_examples[:tt]

        logger.info("train examples {}".format(len(train_examples)))
        num_train_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    bert_config = BertConfig.from_json_file(args.bert_config_file)
    tokenizer = BertTokenizer(vocab_file=args.vocab_file,
                              do_lower_case=args.do_lower_case)
    model = BertForCloze(bert_config, num_choices=10)
    if args.init_checkpoint is not None:
        logger.info('load bert weight')
        state_dict = torch.load(args.init_checkpoint, map_location='cpu')
        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        # new_state_dict=state_dict.copy()
        # for kye ,value in state_dict.items():
        #     new_state_dict[kye.replace("bert","c_bert")]=value
        # state_dict=new_state_dict
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(
                prefix[:-1], {})

            module._load_from_state_dict(state_dict, prefix, local_metadata,
                                         True, missing_keys, unexpected_keys,
                                         error_msgs)
            for name, child in module._modules.items():
                # logger.info("name {} chile {}".format(name,child))
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
        logger.info("missing keys:{}".format(missing_keys))
        logger.info('unexpected keys:{}'.format(unexpected_keys))
        logger.info('error msgs:{}'.format(error_msgs))
    model.to(device)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    # hack to remove the pooler, which is not used and
    # thus produces None grads that break apex
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    t_total = num_train_steps
    if args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()
    if args.fp16:
        try:
            from apex import amp
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate)
        # optimizer = BertAdam(optimizer_grouped_parameters,
        #                      lr=args.learning_rate,
        #                      warmup=args.warmup_proportion,
        #                      t_total=t_total)
        # optimizer = RAdam(optimizer_grouped_parameters,
        #                      lr=args.learning_rate)
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if args.restore:
        checkpoint = torch.load('amp_checkpoint.pt')

        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        amp.load_state_dict(checkpoint['amp'])

    global_step = 0
    if args.do_train:
        cached_train_features_file = args.train_file + '_{0}_v{1}'.format(
            str(args.max_seq_length), str(4))
        try:
            with open(cached_train_features_file, "rb") as reader:
                train_features = pickle.load(reader)
        except Exception:  # no usable cache; rebuild the features from the raw examples
            train_features = convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length)

            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                logger.info("  Saving train features into cached file %s",
                            cached_train_features_file)
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)

        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_labels = torch.tensor([f.label for f in train_features],
                                  dtype=torch.long)
        all_option_ids = torch.tensor([f.option_ids for f in train_features],
                                      dtype=torch.long)
        all_positions = torch.tensor([f.position for f in train_features],
                                     dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_labels, all_option_ids,
                                   all_positions)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size,
                                      drop_last=True,
                                      pin_memory=True)
        loss_ini = 50
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
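            # Visdom dashboards: a per-epoch environment plus shared 'loss' and 'ac' (accuracy) environments.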
            vizname = 'epoch' + str(_)
            viz = Visdom(env=str(vizname))
            vis = Visdom(env='loss')
            via = Visdom(env='ac')
            model.train()
            model.zero_grad()
            epoch_iterator = tqdm(train_dataloader, disable=None)
            for step, batch in enumerate(epoch_iterator):
                if n_gpu == 1:
                    batch = tuple(
                        t.to(device)
                        for t in batch)  # multi-gpu does scattering itself
                input_ids, input_mask, segment_ids, labels, option_ids, positions = batch
                loss = model(input_ids, option_ids, segment_ids, input_mask,
                             positions, labels)
                #                print('att', loss.size())
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    #                    model, optimizer = amp.initialize(model, optimizer, opt_level= "O1")
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    # modify learning rate with special warm up BERT uses
                    lr_this_step = args.learning_rate * warmup_linear(
                        global_step / t_total, args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                    # if args.fp16:
                    #      torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.)
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                if (step + 1) % 1000 == 0:
                    logger.info("loss@{}:{}".format(step, loss.cpu().item()))
                steptotal = step + _ * int(
                    len(train_examples) / args.train_batch_size)
                if (steptotal + 1) % 50 == 0:
                    vis.line([loss.cpu().item()], [steptotal],
                             win='train_loss',
                             update='append')
                if (step + 1) % 50 == 0:
                    viz.line([loss.cpu().item()], [step],
                             win='train_loss',
                             update='append')
            loss_total = loss.cpu().item()
            loss_ini = loss_total
            logger.info("end-of-epoch loss: %f", loss_total)
            raw_test_data_pre = open(args.predict_file, mode='r')
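            # Evaluate on the dev split (args.predict_file) after each training epoch.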
            eval_examples = read_chid_examples(raw_test_data_pre,
                                               is_training=False)
            # eval_examples=eval_examples[:100]
            eval_features = convert_examples_to_features(
                examples=eval_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length)

            logger.info("***** Running predictions *****")
            logger.info("  Num orig examples = %d", len(eval_examples))
            logger.info("  Num split examples = %d", len(eval_features))
            logger.info("  Batch size = %d", args.predict_batch_size)

            all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                         dtype=torch.long)
            all_input_mask = torch.tensor(
                [f.input_mask for f in eval_features], dtype=torch.long)
            all_segment_ids = torch.tensor(
                [f.segment_ids for f in eval_features], dtype=torch.long)
            all_option_ids = torch.tensor(
                [f.option_ids for f in eval_features], dtype=torch.long)
            all_positions = torch.tensor([f.position for f in eval_features],
                                         dtype=torch.long)
            all_tags = torch.tensor([f.tag for f in eval_features],
                                    dtype=torch.long)
            eval_data = TensorDataset(all_input_ids, all_input_mask,
                                      all_segment_ids, all_option_ids,
                                      all_positions, all_tags)
            # Run prediction for full data
            eval_sampler = SequentialSampler(eval_data)
            eval_dataloader = DataLoader(eval_data,
                                         sampler=eval_sampler,
                                         batch_size=args.predict_batch_size)

            model.eval()
            reader1 = pd.read_csv('dev_answer1.csv', usecols=[1], header=None)
            total_dev_loss = 0
            all_results = {}
            logger.info("Start evaluating")
            for input_ids, input_mask, segment_ids, option_ids, positions, tags in \
                    tqdm(eval_dataloader, desc="Evaluating",disable=None):
                if len(all_results) % 1000 == 0:
                    logger.info("Processing example: %d" % (len(all_results)))
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                option_ids = option_ids.to(device)
                positions = positions.to(device)
                with torch.no_grad():
                    batch_logits, align = model(input_ids, option_ids,
                                                segment_ids, input_mask,
                                                positions)
                for i, tag in enumerate(tags):
                    logits = batch_logits[i].detach().cpu().numpy()
                    logit = [logits]
                    logit = torch.tensor(logit)
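                    # Tag ids appear to be offset by 577157 (the first dev example), so (tag - 577157) indexes dev_answer1.csv.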

                    inum = int(tag) - 577157
                    dlabel = [reader1[1][inum]]
                    dlabel = torch.tensor(dlabel)
                    # loss_dev =FocalLoss(gamma=0.25)
                    loss_dev = CrossEntropyLoss()
                    dev_loss = loss_dev(logit, dlabel)
                    total_dev_loss += dev_loss
                    #  for index1, dlabel in zip(reader1[0], reader1[1]):
                    #      if index1[6:11] == str(tag):
                    #          loss_dev =CrossEntropyLoss()
                    #          dev_loss = loss_dev(logits, dlabel)
                    #          total_dev_loss += dev_loss
                    #          continue
                    ans = np.argmax(logits)
                    all_results["#idiom%06d#" % tag] = ans

            predict_name = "ln11saprediction" + str(_) + ".csv"
            output_prediction_file = os.path.join(args.output_dir,
                                                  predict_name)
            with open(output_prediction_file, "w") as f:
                for each in all_results:
                    f.write(each + ',' + str(all_results[each]) + "\n")
            raw_test_data_pre.close()
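            # Recompute dev accuracy by comparing the predictions just written with dev_answer1.csv (23011 examples).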
            pre_ac = 0
            reader2 = pd.read_csv(output_prediction_file, usecols=[0, 1], header=None)

            for index2, ans2 in zip(reader2[0], reader2[1]):
                num = index2[6:12]
                num = int(num) - 577157
                ans1 = reader1[1][num]
                if ans1 == ans2:
                    pre_ac += 1
            print(pre_ac)
            per = (pre_ac) / 23011
            pernum = per * 100
            logger.info("accuracy:%f", pernum)
            devlossmean = total_dev_loss / (23011 / 128)
            logger.info("devloss:%f", devlossmean)
            via.line([pernum], [_], win='accuracy', update='append')
            via.line([devlossmean], [_], win='loss', update='append')
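            # Save model/optimizer/AMP state so training can later be resumed via args.restore.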
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            if args.fp16:
                checkpoint['amp'] = amp.state_dict()
            torch.save(checkpoint, 'checkpoint/amp_checkpoint.pt')

            outmodel = 'ln11samodel' + str(pernum) + '.bin'
            output_model_file = os.path.join(args.output_dir, outmodel)
            if args.do_train:
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model itself
                torch.save(model_to_save.state_dict(), output_model_file)
        raw_test_data.close()
        raw_train_data.close()


# Save a trained model
#    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
#    if args.do_train:
#        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
#        torch.save(model_to_save.state_dict(), output_model_file)

# Load a trained model that you have fine-tuned

    if args.do_predict and (args.local_rank == -1
                            or torch.distributed.get_rank() == 0):
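        # Use the most recently modified fine-tuned checkpoint in ./output_model/ for the final prediction run.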
        list1 = os.listdir('./output_model/')
        list1 = sorted(
            list1,
            key=lambda x: os.path.getmtime(os.path.join('./output_model/', x)))
        output_model_file = os.path.join(args.output_dir, list1[-1])
        # output_model_file = os.path.join(args.output_dir, 'n11samodel77.33258007040111.bin')
        model_state_dict = torch.load(output_model_file)
        model = BertForCloze(bert_config, num_choices=10)
        model.load_state_dict(model_state_dict)
        model.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)
        # raw_test_data_pre = open('./data/dev.txt', mode='r')
        raw_test_data_pre = open('./data/out.txt', mode='r')
        # raw_test_data_pre = open('new_test_data.txt', mode='r')
        eval_examples = read_chid_examples(raw_test_data_pre,
                                           is_training=False)
        # eval_examples=eval_examples[:100]
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length)

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(eval_examples))
        logger.info("  Num split examples = %d", len(eval_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_option_ids = torch.tensor([f.option_ids for f in eval_features],
                                      dtype=torch.long)
        all_positions = torch.tensor([f.position for f in eval_features],
                                     dtype=torch.long)
        all_tags = torch.tensor([f.tag for f in eval_features],
                                dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_option_ids,
                                  all_positions, all_tags)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.predict_batch_size)

        model.eval()
        all_results = {}
        all_results1 = {}
        all_results2 = {}
        # reader1 = pd.read_csv('test_ans.csv', usecols=[1], header=None)
        reader1 = pd.read_csv('./data/out_answer.csv',
                              usecols=[1],
                              header=None)  #dev_answer1.csv
        # reader1 = pd.read_csv('dev_answer1.csv', usecols=[1], header=None)
        total_dev_loss = 0
        logger.info("Start evaluating")
        for input_ids, input_mask, segment_ids, option_ids, positions, tags in \
                tqdm(eval_dataloader, desc="Evaluating",disable=None):
            if len(all_results) % 1000 == 0:
                logger.info("Processing example: %d" % (len(all_results)))
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            option_ids = option_ids.to(device)
            positions = positions.to(device)
            with torch.no_grad():
                batch_logits, align = model(input_ids, option_ids, segment_ids,
                                            input_mask, positions)
            for i, tag in enumerate(tags):
                logits = batch_logits[i].detach().cpu().numpy()

                ans = np.argmax(logits)
                all_results["#idiom%06d#" % tag] = ans

                # matric = align[i].detach().cpu().numpy()
                # all_results1["#idiom%06d#" % tag] = matric[ans]
                # gr_logic = logits[:]
                # gr_logic = sorted(gr_logic, reverse=True)
                # all_results2["#idiom%06d#" % tag] = gr_logic

        output_prediction_file = os.path.join(args.output_dir,
                                              "testprediction.csv")
        # output_m_file = os.path.join(args.output_dir, "ealign.csv")
        # output_ma_file = os.path.join(args.output_dir, "sdmv.csv")

        with open(output_prediction_file, "w") as f:
            for each in all_results:
                f.write(each + ',' + str(all_results[each]) + "\n")
        # with open(output_m_file, "w") as f:
        #     for each in all_results1:
        #         f.write(each + ',' + str(all_results1[each]) + "\n")
        # with open(output_ma_file, "w") as f:
        #     for each in all_results1:
        #         f.write(each + ',' + str(all_results2[each]) + "\n")
        raw_test_data_pre.close()
        reader2 = pd.read_csv(output_prediction_file,
                              usecols=[0, 1],
                              header=None)
        pre_ac = 0
        for index2, ans2 in zip(reader2[0], reader2[1]):
            num = index2[6:-1]
            # num = int(num)-1
            # num = re.findall(r"\d+\.?\d*",index2)
            num = int(num) - 623377
            # num = int(num) - 577157
            ans1 = reader1[1][num]
            if ans1 == ans2:
                pre_ac += 1
        print(pre_ac)
        # per = (pre_ac)/23011
        # per = (pre_ac)/24948
        per = (pre_ac) / 27704
        pernum = per * 100
        logger.info("accuracy:%f", pernum)


def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument(
        "--test_set",
        default='story',
        type=str,
        #choices=['story', 'news', 'chat', 'train'],
        help="Choose the test set.")
    parser.add_argument("--no_logit_mask",
                        action='store_true',
                        help="Whether not to use logit mask")
    parser.add_argument("--eval_every_epoch",
                        action='store_true',
                        help="Whether to evaluate for every epoch")
    parser.add_argument("--use_weight",
                        action='store_true',
                        help="Whether to use class-balancing weight")
    parser.add_argument(
        "--state_dir",
        default="",
        type=str,
        help=
        "Where to load state dict instead of using Google pre-trained model")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    processor = DataProcessor(args.test_set)
    label_list = processor.get_labels(args.data_dir)
    num_labels = len(label_list)
    logger.info("num_labels:" + str(num_labels))
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = (num_train_optimization_steps //
                                            torch.distributed.get_world_size())

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(
            args.local_rank))
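    # If the output directory already contains a saved checkpoint, resume from it instead of the pre-trained weights.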
    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        #raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
        if os.path.exists(os.path.join(args.output_dir, WEIGHTS_NAME)):
            output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
            output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
            config = BertConfig(output_config_file)
            model = PolyphonyLSTM(config, num_labels=num_labels)
            model.load_state_dict(torch.load(output_model_file))
        else:
            raise ValueError(
                "Output directory ({}) already exists but no model checkpoint was found."
                .format(args.output_dir))
    else:
        os.makedirs(args.output_dir, exist_ok=True)
        if args.state_dir and os.path.exists(args.state_dir):
            state_dict = torch.load(args.state_dir)
            print("Using my own BERT state dict.")
        elif args.state_dir and not os.path.exists(args.state_dir):
            print(
                "Warning: the state dict does not exist, using the Google pre-trained model instead."
            )
            state_dict = None
        else:
            state_dict = None
        model = PolyphonyLSTM.from_pretrained(args.bert_model,
                                              cache_dir=cache_dir,
                                              state_dict=state_dict,
                                              num_labels=num_labels)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
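    # The FP16 path uses the legacy FP16_Optimizer wrapper; otherwise BertAdam handles warmup and decay internally.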
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)
    if os.path.exists(os.path.join(args.output_dir, OPTIMIZER_NAME)):
        output_optimizer_file = os.path.join(args.output_dir, OPTIMIZER_NAME)
        optimizer.load_state_dict(torch.load(output_optimizer_file))

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
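        # masks is the per-character logit mask and weight the class-balancing weights; both are optional.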
        train_features, masks, weight = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer)
        if args.eval_every_epoch:
            eval_examples = processor.get_dev_examples(args.data_dir)
            eval_features, masks, weight = convert_examples_to_features(
                eval_examples, label_list, args.max_seq_length, tokenizer)

        if args.no_logit_mask:
            print("Remove logit mask")
            masks = None
        if not args.use_weight:
            weight = None
        print(weight)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)
        all_label_poss = torch.tensor([f.label_pos for f in train_features],
                                      dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_label_ids, all_label_poss)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        for ep in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, label_ids, label_poss = batch
                # print(masks.size())
                loss = model(input_ids,
                             input_mask,
                             label_ids,
                             logit_masks=masks,
                             weight=weight)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            if args.eval_every_epoch:
                # evaluate for every epoch
                # save model and load for a single GPU
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model itself
                output_model_file = os.path.join(args.output_dir,
                                                 WEIGHTS_NAME + '_' + str(ep))
                torch.save(model_to_save.state_dict(), output_model_file)
                output_optimizer_file = os.path.join(
                    args.output_dir, OPTIMIZER_NAME + '_' + str(ep))
                torch.save(optimizer.state_dict(), output_optimizer_file)
                output_config_file = os.path.join(args.output_dir,
                                                  CONFIG_NAME + '_' + str(ep))
                with open(output_config_file, 'w') as f:
                    f.write(model_to_save.config.to_json_string())

                # Load a trained model and config that you have fine-tuned
                config = BertConfig(output_config_file)
                model_eval = PolyphonyLSTM(config, num_labels=num_labels)
                model_eval.load_state_dict(torch.load(output_model_file))
                model_eval.to(device)

                if args.no_logit_mask:
                    print("Remove logit mask")
                    masks = None
                else:
                    masks = masks.to(device)
                chars = [f.char for f in eval_features]
                print(len(set(chars)), sorted(list(set(chars))))
                logger.info("***** Running evaluation *****")
                logger.info("  Num examples = %d", len(eval_examples))
                logger.info("  Batch size = %d", args.eval_batch_size)
                all_input_ids = torch.tensor(
                    [f.input_ids for f in eval_features], dtype=torch.long)
                all_input_mask = torch.tensor(
                    [f.input_mask for f in eval_features], dtype=torch.long)
                all_label_ids = torch.tensor(
                    [f.label_id for f in eval_features], dtype=torch.long)
                all_label_poss = torch.tensor(
                    [f.label_pos for f in eval_features], dtype=torch.long)
                eval_data = TensorDataset(all_input_ids, all_input_mask,
                                          all_label_ids, all_label_poss)
                # Run prediction for full data
                eval_sampler = SequentialSampler(eval_data)
                eval_dataloader = DataLoader(eval_data,
                                             sampler=eval_sampler,
                                             batch_size=args.eval_batch_size)

                model_eval.eval()
                eval_loss, eval_accuracy = 0, 0
                nb_eval_steps, nb_eval_examples = 0, 0

                res_list = []
                for input_ids, input_mask, label_ids, label_poss in tqdm(
                        eval_dataloader, desc="Evaluating"):
                    input_ids = input_ids.to(device)
                    input_mask = input_mask.to(device)
                    label_ids = label_ids.to(device)
                    label_poss = label_poss.to(device)
                    with torch.no_grad():
                        tmp_eval_loss = model_eval(input_ids,
                                                   input_mask,
                                                   label_ids,
                                                   logit_masks=masks)
                        logits = model_eval(input_ids,
                                            input_mask,
                                            label_ids,
                                            logit_masks=masks,
                                            cal_loss=False)
                    # print(logits.size())
                    logits = logits.detach().cpu().numpy()
                    label_ids = label_ids.to('cpu').numpy()
                    res_list += accuracy_list(logits, label_ids, label_poss)
                    eval_loss += tmp_eval_loss.mean().item()

                    nb_eval_examples += input_ids.size(0)
                    nb_eval_steps += 1

                eval_loss = eval_loss / nb_eval_steps
                loss = tr_loss / nb_tr_steps if args.do_train else None
                acc = sum(res_list) / len(res_list)
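                # Group per-example results by target character to report per-character accuracy.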
                char_count = {k: [] for k in list(set(chars))}
                for i, c in enumerate(chars):
                    char_count[c].append(res_list[i])
                char_acc = {
                    k: sum(char_count[k]) / len(char_count[k])
                    for k in char_count
                }

                result = {
                    'epoch': ep + 1,
                    'eval_loss': eval_loss,
                    'eval_accuracy': acc,
                    'global_step': global_step,
                    'loss': loss
                }
                logger.info("***** Eval results *****")
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                output_eval_file = os.path.join(
                    args.output_dir, "epoch_" + str(ep + 1) + ".txt")
                with open(output_eval_file, 'w') as f:
                    f.write(
                        json.dumps(result, ensure_ascii=False) + '\n' +
                        json.dumps(char_acc, ensure_ascii=False))

                # multi processing
                # if n_gpu > 1:
                #    model = torch.nn.DataParallel(model)

    if args.do_train:
        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_optimizer_file = os.path.join(args.output_dir, OPTIMIZER_NAME)
        torch.save(optimizer.state_dict(), output_optimizer_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = PolyphonyLSTM(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
    else:
        # model = BertForPolyphonyMulti.from_pretrained(args.bert_model, num_labels = num_labels)
        pass
    model.to(device)

    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        config = BertConfig(output_config_file)
        model = PolyphonyLSTM(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
        model.to(device)

        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features, masks, weight = convert_examples_to_features(
            eval_examples, label_list, args.max_seq_length, tokenizer)
        if args.no_logit_mask:
            print("Remove logit mask")
            masks = None
        else:
            masks = masks.to(device)
        chars = [f.char for f in eval_features]
        print(len(set(chars)), sorted(list(set(chars))))
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)
        all_label_poss = torch.tensor([f.label_pos for f in eval_features],
                                      dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_label_ids,
                                  all_label_poss)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0

        res_list = []
        # masks = masks.to(device)
        for input_ids, input_mask, label_ids, label_poss in tqdm(
                eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            label_ids = label_ids.to(device)
            label_poss = label_poss.to(device)

            with torch.no_grad():
                tmp_eval_loss = model(input_ids,
                                      input_mask,
                                      label_ids,
                                      logit_masks=masks)
                logits = model(input_ids,
                               input_mask,
                               label_ids,
                               logit_masks=masks,
                               cal_loss=False)
            # print(logits.size())
            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(logits, label_ids)
            res_list += accuracy_list(logits, label_ids, label_poss)

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        loss = tr_loss / nb_tr_steps if args.do_train else None
        acc = sum(res_list) / len(res_list)
        char_count = {k: [] for k in list(set(chars))}
        for i, c in enumerate(chars):
            char_count[c].append(res_list[i])
        char_acc = {
            k: sum(char_count[k]) / len(char_count[k])
            for k in char_count
        }

        result = {
            'eval_loss': eval_loss,
            'eval_accuracy': eval_accuracy,
            'global_step': global_step,
            'loss': loss,
            'acc': acc
        }

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
            for key in sorted(char_acc.keys()):
                logger.info("  %s = %s", key, str(char_acc[key]))
                writer.write("%s = %s\n" % (key, str(char_acc[key])))
        print("mean accuracy",
              sum(char_acc[c] for c in char_acc) / len(char_acc))
        output_acc_file = os.path.join(args.output_dir,
                                       args.test_set + ".json")
        output_reslist_file = os.path.join(args.output_dir,
                                           args.test_set + "reslist.json")
        with open(output_acc_file, "w") as f:
            f.write(json.dumps(char_acc, ensure_ascii=False))
        with open(output_reslist_file, "w") as f:
            f.write(json.dumps(res_list, ensure_ascii=False))
Beispiel #9
0
def main():

    print("IN NEW MAIN XD\n")
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--input_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain .hdf5 files  for the task.")

    parser.add_argument("--config_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The BERT model config")

    parser.add_argument(
        "--bert_model",
        default="bert-large-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )

    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument(
        "--max_predictions_per_seq",
        default=80,
        type=int,
        help="The maximum total of masked tokens in input sequence")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps",
                        default=1000,
                        type=float,
                        help="Total number of training steps to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.01,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        default=False,
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0.0,
        help=
        'Loss scaling, positive power of 2 values can improve fp16 convergence.'
    )
    parser.add_argument('--log_freq',
                        type=float,
                        default=50.0,
                        help='frequency of logging loss.')
    parser.add_argument('--checkpoint_activations',
                        default=False,
                        action='store_true',
                        help="Whether to use gradient checkpointing")
    parser.add_argument("--resume_from_checkpoint",
                        default=False,
                        action='store_true',
                        help="Whether to resume training from checkpoint.")
    parser.add_argument('--resume_step',
                        type=int,
                        default=-1,
                        help="Step to resume training from.")
    parser.add_argument(
        '--num_steps_per_checkpoint',
        type=int,
        default=2000,
        help="Number of update steps until a model checkpoint is saved to disk."
    )

    args = parser.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    assert (torch.cuda.is_available())

    if args.local_rank == -1:
        device = torch.device("cuda")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info("device %s n_gpu %d distributed training %r", device, n_gpu,
                bool(args.local_rank != -1))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))
    if args.train_batch_size % args.gradient_accumulation_steps != 0:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible"
            .format(args.gradient_accumulation_steps, args.train_batch_size))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    if not args.resume_from_checkpoint and os.path.exists(
            args.output_dir) and (os.listdir(args.output_dir) and os.listdir(
                args.output_dir) != ['logfile.txt']):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))

    if not args.resume_from_checkpoint:
        os.makedirs(args.output_dir, exist_ok=True)

    # Prepare model
    config = BertConfig.from_json_file(args.config_file)
    model = BertForPreTraining(config)
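    # A fresh run starts at step 0; when resuming, the newest ckpt_<step>.pt in output_dir sets the starting step.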

    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if args.resume_step == -1:
            model_names = [
                f for f in os.listdir(args.output_dir) if f.endswith(".pt")
            ]
            args.resume_step = max([
                int(x.split('.pt')[0].split('_')[1].strip())
                for x in model_names
            ])

        global_step = args.resume_step

        checkpoint = torch.load(os.path.join(args.output_dir,
                                             "ckpt_{}.pt".format(global_step)),
                                map_location="cpu")
        model.load_state_dict(checkpoint['model'], strict=False)

        print("resume step from ", args.resume_step)

    model.to(device)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    if args.fp16:
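        # FusedAdam wrapped by AMP "O2": FP32 master weights, mostly-FP16 compute, dynamic or static loss scaling.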

        optimizer = FusedAdam(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            #warmup=args.warmup_proportion,
            #t_total=args.max_steps,
            bias_correction=False,
            weight_decay=0.01,
            max_grad_norm=1.0)

        if args.loss_scale == 0:
            # optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              keep_batchnorm_fp32=False,
                                              loss_scale="dynamic")
        else:
            # optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              keep_batchnorm_fp32=False,
                                              loss_scale=args.loss_scale)

        scheduler = LinearWarmUpScheduler(optimizer,
                                          warmup=args.warmup_proportion,
                                          total_steps=args.max_steps)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=args.max_steps)

    if args.resume_from_checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])  # , strict=False)

    if args.local_rank != -1:
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    files = [
        os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
        if os.path.isfile(os.path.join(args.input_dir, f))
    ]
    files.sort()
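    # Each file is one pre-tokenized .hdf5 shard; one shard is loaded into memory at a time during training.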

    num_files = len(files)

    logger.info("***** Running training *****")
    # logger.info("  Num examples = %d", len(train_data))
    logger.info("  Batch size = %d", args.train_batch_size)
    print("  LR = ", args.learning_rate)

    model.train()
    print("Training. . .")

    most_recent_ckpts_paths = []

    print("Training. . .")
    tr_loss = 0.0  # total added training loss
    average_loss = 0.0  # averaged loss every args.log_freq steps
    epoch = 0
    training_steps = 0
    while True:
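        # Outer loop = one pass over all shards; training stops when global_step reaches args.max_steps.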
        if not args.resume_from_checkpoint:
            random.shuffle(files)
            f_start_id = 0
        else:
            f_start_id = checkpoint['files'][0]
            files = checkpoint['files'][1:]
            args.resume_from_checkpoint = False
        for f_id in range(f_start_id, len(files)):
            data_file = files[f_id]
            logger.info("file no %s file %s" % (f_id, data_file))
            train_data = pretraining_dataset(
                input_file=data_file,
                max_pred_length=args.max_predictions_per_seq)

            if args.local_rank == -1:
                train_sampler = RandomSampler(train_data)
                train_dataloader = DataLoader(
                    train_data,
                    sampler=train_sampler,
                    batch_size=args.train_batch_size * n_gpu,
                    num_workers=4,
                    pin_memory=True)
            else:
                train_sampler = DistributedSampler(train_data)
                train_dataloader = DataLoader(train_data,
                                              sampler=train_sampler,
                                              batch_size=args.train_batch_size,
                                              num_workers=4,
                                              pin_memory=True)

            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="File Iteration")):

                training_steps += 1
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
                loss = model(
                    input_ids=input_ids,
                    token_type_ids=segment_ids,
                    attention_mask=input_mask,
                    masked_lm_labels=masked_lm_labels,
                    next_sentence_label=next_sentence_labels,
                    checkpoint_activations=args.checkpoint_activations)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.

                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    #   optimizer.backward(loss)
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                tr_loss += loss.detach()  # keep a running total without retaining the autograd graph
                average_loss += loss.item()

                if training_steps % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        scheduler.step()
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                if training_steps == 1 * args.gradient_accumulation_steps:
                    logger.info(
                        "Step:{} Average Loss = {} Step Loss = {} LR {}".
                        format(global_step, average_loss, loss.item(),
                               optimizer.param_groups[0]['lr']))

                if training_steps % (args.log_freq *
                                     args.gradient_accumulation_steps) == 0:
                    logger.info(
                        "Step:{} Average Loss = {} Step Loss = {} LR {}".
                        format(global_step, average_loss / args.log_freq,
                               loss.item(), optimizer.param_groups[0]['lr']))
                    average_loss = 0

                if global_step >= args.max_steps or training_steps % (
                        args.num_steps_per_checkpoint *
                        args.gradient_accumulation_steps) == 0:

                    if (not torch.distributed.is_initialized()
                            or (torch.distributed.is_initialized()
                                and torch.distributed.get_rank() == 0)):
                        # Save a trained model
                        logger.info(
                            "** ** * Saving fine-tuned model ** ** * ")
                        model_to_save = model.module if hasattr(
                            model,
                            'module') else model  # Only save the model itself
                        output_save_file = os.path.join(
                            args.output_dir, "ckpt_{}.pt".format(global_step))

                        torch.save(
                            {
                                'model': model_to_save.state_dict(),
                                'optimizer': optimizer.state_dict(),
                                'files': [f_id] + files
                            }, output_save_file)

                        most_recent_ckpts_paths.append(output_save_file)
                        if len(most_recent_ckpts_paths) > 3:
                            ckpt_to_be_removed = most_recent_ckpts_paths.pop(0)
                            os.remove(ckpt_to_be_removed)

                    if global_step >= args.max_steps:
                        tr_loss = tr_loss * args.gradient_accumulation_steps / training_steps
                        if (torch.distributed.is_initialized()):
                            tr_loss /= torch.distributed.get_world_size()
                            torch.distributed.all_reduce(tr_loss)
                        logger.info("Total Steps:{} Final Loss = {}".format(
                            training_steps, tr_loss.item()))
                        return
            del train_dataloader
            del train_sampler
            del train_data
            #for obj in gc.get_objects():
            #  if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
            #    del obj

            torch.cuda.empty_cache()
        epoch += 1
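
# The resume path at the top of the loop above expects a checkpoint dict whose
# 'files' entry stores the index of the shard being processed followed by the
# remaining shard list, exactly as written by the torch.save call inside the
# loop. A minimal, hypothetical resume helper consistent with that layout is
# sketched below; the function and variable names are illustrative and not
# part of the original script.
def load_pretraining_checkpoint(path, model, optimizer):
    checkpoint = torch.load(path, map_location='cpu')
    model_to_load = model.module if hasattr(model, 'module') else model
    model_to_load.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # checkpoint['files'] == [f_id] + files, i.e. the shard to resume from
    # followed by the shards still pending when the checkpoint was saved.
    return checkpoint
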
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The input train corpus.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        "--on_memory",
        action='store_true',
        help="Whether to load train samples into memory or use disk")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--hybrid_attention",
                        action='store_true',
                        help="Whether to use hybrid attention")
    parser.add_argument("--continue_training",
                        action='store_true',
                        help="Continue training from a checkpoint")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend, which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train:
        raise ValueError(
            "Training is currently the only implemented execution option. Please set `do_train`."
        )

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and not args.continue_training:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    #train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        print("Loading Train Dataset", args.train_file)
        train_dataset = BERTDataset(args.train_file,
                                    tokenizer,
                                    seq_len=args.max_seq_length,
                                    corpus_lines=None,
                                    on_memory=args.on_memory)
        num_train_optimization_steps = int(
            len(train_dataset) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    # Prepare model
    model = BertForMaskedLM.from_pretrained(args.bert_model)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    if args.hybrid_attention:
        # Hybrid attention assigns a different pattern to each of the 12 heads:
        #   heads 0-1: left-to-right (lower-triangular) attention
        #   heads 2-3: right-to-left (upper-triangular) attention
        #   heads 4-5: local attention with a window of 3 tokens
        #   heads 6-11: full bidirectional attention (mask left as all ones)
        # The mask is finally stacked 8 times along a new leading dimension,
        # presumably the per-device batch size expected by hybrid_mask.
        max_seq_length = args.max_seq_length
        attention_mask = torch.ones(12,
                                    max_seq_length,
                                    max_seq_length,
                                    dtype=torch.long)
        # left attention
        attention_mask[:2, :, :] = torch.tril(
            torch.ones(max_seq_length, max_seq_length, dtype=torch.long))
        # right attention
        attention_mask[2:4, :, :] = torch.triu(
            torch.ones(max_seq_length, max_seq_length, dtype=torch.long))
        # local attention, window size = 3
        attention_mask[4:6, :, :] = torch.triu(
            torch.tril(
                torch.ones(max_seq_length, max_seq_length, dtype=torch.long),
                1), -1)
        attention_mask = torch.cat(
            [attention_mask.unsqueeze(0) for _ in range(8)])
        attention_mask = attention_mask.to(device)
    else:
        attention_mask = None

    global_step = 0
    epoch_start = 0
    if args.do_train:
        if args.continue_training:
            # if checkpoint file exists, find the last checkpoint
            if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
                all_cp = os.listdir(args.output_dir)
                steps = [
                    int(re.search(r'_\d+', cp).group()[1:]) for cp in all_cp
                    if re.search(r'_\d+', cp)
                ]
                if len(steps) == 0:
                    raise ValueError(
                        "No existing checkpoint. Please do not use --continue_training."
                    )
                max_step = max(steps)
                # load checkpoint
                checkpoint = torch.load(
                    os.path.join(args.output_dir,
                                 'checkpoints_' + str(max_step) + '.pt'))
                logger.info("***** Loading checkpoint *****")
                logger.info("  Num steps = %d", checkpoint['global_step'])
                logger.info("  Num epoch = %d", checkpoint['epoch'])
                logger.info("  Loss = %d, %d", checkpoint['loss'],
                            checkpoint['loss_now'])
                model.module.load_state_dict(checkpoint['model'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                global_step = checkpoint['global_step']
                epoch_start = checkpoint['epoch']
                del checkpoint
            else:
                raise ValueError(
                    "No existing checkpoint. Please do not use --continue_training."
                )

        writer = SummaryWriter(log_dir=os.environ['HOME'])
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            #TODO: check if this works with current data generator from disk that relies on next(file)
            # (it doesn't return item back by index)
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        tr_loss_1000 = 0
        for ep in trange(epoch_start, int(args.num_train_epochs),
                         desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids = batch
                loss = model(input_ids,
                             segment_ids,
                             input_mask,
                             lm_label_ids,
                             hybrid_mask=attention_mask)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                tr_loss += loss.item()
                tr_loss_1000 += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with the special warmup BERT uses;
                        # when args.fp16 is False, BertAdam handles this schedule internally
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
                # log the training loss for every 1000 steps
                if global_step % 1000 == 999:
                    writer.add_scalar('data/loss', tr_loss_1000 / 1000,
                                      global_step)
                    logger.info("training steps: %s", global_step)
                    logger.info("training loss per 1000: %s",
                                tr_loss_1000 / 1000)
                    tr_loss_1000 = 0
                # save the checkpoint for every 10000 steps
                if global_step % 10000 == 0:
                    model_to_save = model.module if hasattr(
                        model,
                        'module') else model  # Only save the model itself
                    output_file = os.path.join(
                        args.output_dir,
                        "checkpoints_" + str(global_step) + ".pt")
                    checkpoint = {
                        'model': model_to_save.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'epoch': ep,
                        'global_step': global_step,
                        'loss': tr_loss / nb_tr_steps,
                        'loss_now': tr_loss_1000
                    }
                    if args.do_train:
                        torch.save(checkpoint, output_file)
            model_to_save = model.module if hasattr(
                model, 'module') else model  # Only save the model itself
            output_model_file = os.path.join(args.output_dir,
                                             "pytorch_model.bin_" + str(ep))
            if args.do_train:
                torch.save(model_to_save.state_dict(), output_model_file)
            logger.info("training loss: %s", tr_loss / nb_tr_steps)

        # Save a trained model
        logger.info("** ** * Saving fine - tuned model ** ** * ")
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model it-self
        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
        if args.do_train:
            torch.save(model_to_save.state_dict(), output_model_file)
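
The fp16 branch of the training loop above recomputes the learning rate by hand
with warmup_linear, which this snippet does not define. A minimal sketch that is
consistent with the call site (a training-progress fraction in [0, 1] and a
warmup proportion) and with the classic BERT linear warmup/decay schedule is
shown below; the helper actually imported by the script may differ in its decay
behaviour.

def warmup_linear(x, warmup=0.002):
    # Linearly ramp the LR multiplier up to 1.0 over the first `warmup`
    # fraction of training, then decay it linearly towards zero.
    if x < warmup:
        return x / warmup
    return max(0.0, 1.0 - x)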
Beispiel #11
0
def main():
    parser = argparse.ArgumentParser()

    # General
    parser.add_argument(
        "--bert_model",
        default="bert-base-cased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased.",
    )
    parser.add_argument(
        "--config_path", default=None, type=str, help="Bert config file path."
    )
    parser.add_argument(
        "--output_dir",
        default="tmp",
        type=str,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--log_file",
        default="eval.log",
        type=str,
        help="The output directory where the log will be written.",
    )
    parser.add_argument(
        "--model_recover_path",
        default=None,
        type=str,
        help="The file of fine-tuned pretraining model.",
    )
    parser.add_argument(
        "--do_train",
        action="store_true",
        help="Whether to run training. This should ALWAYS be set to True.",
    )
    parser.add_argument(
        "--do_lower_case",
        action="store_true",
        help="Set this flag if you are using an uncased model.",
    )
    parser.add_argument(
        "--train_batch_size",
        default=64,
        type=int,
        help="Total batch size for training.",
    )
    parser.add_argument(
        "--learning_rate",
        default=3e-5,
        type=float,
        help="The initial learning rate for Adam.",
    )
    parser.add_argument(
        "--label_smoothing",
        default=0,
        type=float,
        help="The initial learning rate for Adam.",
    )
    parser.add_argument(
        "--weight_decay",
        default=0.01,
        type=float,
        help="The weight decay rate for Adam.",
    )
    parser.add_argument(
        "--finetune_decay",
        action="store_true",
        help="Weight decay to the original weights.",
    )
    parser.add_argument(
        "--num_train_epochs",
        default=30,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.",
    )
    parser.add_argument(
        "--no_cuda", action="store_true", help="Whether not to use CUDA when available"
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local_rank for distributed training on gpus",
    )
    parser.add_argument(
        "--global_rank",
        type=int,
        default=-1,
        help="global_rank for distributed training on gpus",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="random seed for initialization"
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit float precision instead of 32-bit",
    )
    parser.add_argument(
        "--fp32_embedding",
        action="store_true",
        help="Whether to use 32-bit float precision instead of 32-bit for embeddings",
    )
    parser.add_argument(
        "--loss_scale",
        type=float,
        default=0,
        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n",
    )
    parser.add_argument(
        "--amp", action="store_true", help="Whether to use amp for fp16"
    )
    parser.add_argument(
        "--from_scratch",
        action="store_true",
        help="Initialize parameters with random values (i.e., training from scratch).",
    )
    parser.add_argument(
        "--new_segment_ids",
        action="store_true",
        help="Use new segment ids for bi-uni-directional LM.",
    )
    parser.add_argument(
        "--tokenized_input", action="store_true", help="Whether the input is tokenized."
    )
    parser.add_argument(
        "--len_vis_input",
        type=int,
        default=100,
        help="The length of visual token input",
    )
    parser.add_argument(
        "--max_len_b",
        type=int,
        default=20,
        help="Truncate_config: maximum length of segment B.",
    )
    parser.add_argument(
        "--trunc_seg",
        default="b",
        help="Truncate_config: first truncate segment A/B (option: a, b).",
    )
    parser.add_argument(
        "--always_truncate_tail",
        action="store_true",
        help="Truncate_config: Whether we should always truncate tail.",
    )
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help="Number of prediction is sometimes less than max_pred when sequence is short.",
    )
    parser.add_argument(
        "--max_pred", type=int, default=3, help="Max tokens of prediction."
    )
    parser.add_argument(
        "--num_workers",
        default=4,
        type=int,
        help="Number of workers for the data loader.",
    )
    parser.add_argument(
        "--max_position_embeddings",
        type=int,
        default=None,
        help="max position embeddings",
    )

    # Others for VLP
    parser.add_argument(
        "--src_file",
        default=["/mnt/dat/COCO/annotations/dataset_coco.json"],
        type=str,
        nargs="+",
        help="The input data file name.",
    )
    parser.add_argument("--enable_visdom", action="store_true")
    parser.add_argument("--visdom_port", type=int, default=8888)
    # parser.add_argument('--resnet_model', type=str, default='imagenet_weights/resnet101.pth')
    parser.add_argument("--image_root", type=str, default="/mnt/dat/COCO/images")
    parser.add_argument(
        "--dataset", default="coco", type=str, help="coco | flickr30k | cc"
    )
    parser.add_argument("--split", type=str, nargs="+", default=["train", "restval"])

    parser.add_argument(
        "--world_size", default=1, type=int, help="number of distributed processes"
    )
    parser.add_argument(
        "--dist_url",
        default="file://[PT_OUTPUT_DIR]/nonexistent_file",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--file_valid_jpgs",
        default="/mnt/dat/COCO/annotations/coco_valid_jpgs.json",
        type=str,
    )
    parser.add_argument(
        "--sche_mode",
        default="warmup_linear",
        type=str,
        help="warmup_linear | warmup_constant | warmup_cosine",
    )
    parser.add_argument("--drop_prob", default=0.1, type=float)
    parser.add_argument("--use_num_imgs", default=-1, type=int)
    parser.add_argument("--vis_mask_prob", default=0, type=float)
    parser.add_argument("--max_drop_worst_ratio", default=0, type=float)
    parser.add_argument("--drop_after", default=6, type=int)

    parser.add_argument(
        "--s2s_prob",
        default=1,
        type=float,
        help="Percentage of examples that are bi-uni-directional LM (seq2seq).",
    )
    parser.add_argument(
        "--bi_prob",
        default=0,
        type=float,
        help="Percentage of examples that are bidirectional LM.",
    )
    parser.add_argument(
        "--enable_butd", action="store_true", help="set to take in region features"
    )
    parser.add_argument(
        "--region_bbox_file",
        default="coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5",
        type=str,
    )
    parser.add_argument(
        "--region_det_file_prefix",
        default="feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval",
        type=str,
    )
    parser.add_argument("--tasks", default="img2txt", help="img2txt | vqa2")
    parser.add_argument(
        "--relax_projection",
        action="store_true",
        help="Use different projection layers for tasks.",
    )
    parser.add_argument(
        "--scst", action="store_true", help="Self-critical sequence training"
    )

    args = parser.parse_args()

    print("global_rank: {}, local rank: {}".format(args.global_rank, args.local_rank))

    args.max_seq_length = (
        args.max_len_b + args.len_vis_input + 3
    )  # +3 for 2x[SEP] and [CLS]
    args.mask_image_regions = (
        args.vis_mask_prob > 0
    )  # whether to mask out image regions
    args.dist_url = args.dist_url.replace("[PT_OUTPUT_DIR]", args.output_dir)

    # arguments inspection
    assert args.tasks in ("img2txt", "vqa2")
    assert args.enable_butd, "only region attention is supported; feature-map attention is deprecated"
    assert (not args.scst) or args.dataset == "coco", "scst support on coco only!"
    if args.scst:
        assert args.dataset == "coco", "scst support on coco only!"
        assert args.max_pred == 0 and args.mask_prob == 0, "no mask for scst!"
        rl_crit = RewardCriterion()

    if args.enable_butd:
        assert args.len_vis_input == 100
        args.region_bbox_file = os.path.join(args.image_root, args.region_bbox_file)
        args.region_det_file_prefix = (
            os.path.join(args.image_root, args.region_det_file_prefix)
            if args.dataset in ("cc", "coco") and args.region_det_file_prefix != ""
            else ""
        )

    # output config
    os.makedirs(args.output_dir, exist_ok=True)
    json.dump(
        args.__dict__,
        open(os.path.join(args.output_dir, "eval_opt.json"), "w"),
        sort_keys=True,
        indent=2,
    )

    logging.basicConfig(
        filename=os.path.join(args.output_dir, args.log_file),
        filemode="w",
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger = logging.getLogger(__name__)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        )
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend, which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(
            backend="nccl",
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.global_rank,
        )
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
            device, n_gpu, bool(args.local_rank != -1), args.fp16
        )
    )

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                args.gradient_accumulation_steps
            )
        )

    args.train_batch_size = int(
        args.train_batch_size / args.gradient_accumulation_steps
    )

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # plotting loss, optional
    if args.enable_visdom:
        import visdom

        vis = visdom.Visdom(port=args.visdom_port, env=args.output_dir)
        vis_window = {"iter": None, "score": None}

    # preprocessing/data loader
    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model,
        do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + "/.pretrained_model_{}".format(args.global_rank),
    )
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer

    if args.do_train:
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    "max_len_b": args.max_len_b,
                    "trunc_seg": args.trunc_seg,
                    "always_truncate_tail": args.always_truncate_tail,
                },
                mask_image_regions=args.mask_image_regions,
                mode="s2s",
                len_vis_input=args.len_vis_input,
                vis_mask_prob=args.vis_mask_prob,
                enable_butd=args.enable_butd,
                region_bbox_file=args.region_bbox_file,
                region_det_file_prefix=args.region_det_file_prefix,
                local_rank=args.local_rank,
                load_vqa_ann=(args.tasks == "vqa2"),
            )
        ]
        bi_uni_pipeline.append(
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    "max_len_b": args.max_len_b,
                    "trunc_seg": args.trunc_seg,
                    "always_truncate_tail": args.always_truncate_tail,
                },
                mask_image_regions=args.mask_image_regions,
                mode="bi",
                len_vis_input=args.len_vis_input,
                vis_mask_prob=args.vis_mask_prob,
                enable_butd=args.enable_butd,
                region_bbox_file=args.region_bbox_file,
                region_det_file_prefix=args.region_det_file_prefix,
                local_rank=args.local_rank,
                load_vqa_ann=(args.tasks == "vqa2"),
            )
        )

        train_dataset = seq2seq_loader.Img2txtDataset(
            args.src_file,
            args.image_root,
            args.split,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_valid_jpgs=args.file_valid_jpgs,
            bi_uni_pipeline=bi_uni_pipeline,
            use_num_imgs=args.use_num_imgs,
            s2s_prob=args.s2s_prob,
            bi_prob=args.bi_prob,
            enable_butd=args.enable_butd,
            tasks=args.tasks,
        )

        if args.world_size == 1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
        else:
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.train_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=batch_list_to_batch_tensors,
            pin_memory=True,
        )

    # note: args.train_batch_size has already been divided by args.gradient_accumulation_steps
    t_total = int(
        len(train_dataloader)
        * args.num_train_epochs
        * 1.0
        / args.gradient_accumulation_steps
    )

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp

        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    relax_projection = 4 if args.relax_projection else 0
    task_idx_proj = 3 if args.tasks == "img2txt" else 0
    mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[PAD]"]
    )  # index in BERT vocab: 103, 102, 0

    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        assert not args.scst, "SCST must be initialized from a maximum-likelihood-trained checkpoint"
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size,
            relax_projection=relax_projection,
            config_path=args.config_path,
            task_idx=task_idx_proj,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            cache_dir=args.output_dir
            + "/.pretrained_model_{}".format(args.global_rank),
            drop_prob=args.drop_prob,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input,
            tasks=args.tasks,
        )
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(
                os.path.join(args.output_dir, "model.{0}.bin".format(recover_step))
            )
            # recover_step == number of epochs
            global_step = math.floor(
                recover_step * t_total * 1.0 / args.num_train_epochs
            )
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****", args.model_recover_path)
            model_recover = torch.load(args.model_recover_path)
            global_step = 0
        if not args.scst:
            model = BertForPreTrainingLossMask.from_pretrained(
                args.bert_model,
                state_dict=model_recover,
                num_labels=cls_num_labels,
                type_vocab_size=type_vocab_size,
                relax_projection=relax_projection,
                config_path=args.config_path,
                task_idx=task_idx_proj,
                max_position_embeddings=args.max_position_embeddings,
                label_smoothing=args.label_smoothing,
                fp32_embedding=args.fp32_embedding,
                cache_dir=args.output_dir
                + "/.pretrained_model_{}".format(args.global_rank),
                drop_prob=args.drop_prob,
                enable_butd=args.enable_butd,
                len_vis_input=args.len_vis_input,
                tasks=args.tasks,
            )
        else:
            model = BertForSeq2SeqDecoder.from_pretrained(
                args.bert_model,
                max_position_embeddings=args.max_position_embeddings,
                config_path=args.config_path,
                state_dict=model_recover,
                num_labels=cls_num_labels,
                type_vocab_size=type_vocab_size,
                task_idx=task_idx_proj,
                mask_word_id=mask_word_id,
                search_beam_size=1,
                eos_id=eos_word_ids,
                mode="s2s",
                enable_butd=args.enable_butd,
                len_vis_input=args.len_vis_input,
            )

        del model_recover
        torch.cuda.empty_cache()

    # deprecated
    # from vlp.resnet import resnet
    # cnn = resnet(args.resnet_model, _num_layers=101, _fixed_block=4, pretrained=True) # no finetuning

    if args.fp16:
        model.half()
        # cnn.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    # cnn.to(device)
    if args.local_rank != -1:
        try:
            # from apex.parallel import DistributedDataParallel as DDP
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )
        # cnn = DDP(cnn)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)
        # cnn = DataParallelImbalance(cnn)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.01,
        },
        {
            "params": [
                p for n, p in param_optimizer if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            bias_correction=False,
            max_grad_norm=1.0,
        )
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(
                optimizer, static_loss_scale=args.loss_scale
            )
    else:
        optimizer = BertAdam(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            warmup=args.warmup_proportion,
            schedule=args.sche_mode,
            t_total=t_total,
        )

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(
            os.path.join(args.output_dir, "optim.{0}.bin".format(recover_step))
        )
        if hasattr(optim_recover, "state_dict"):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        # Note: despite the do_train flag, this block only runs a forward pass
        # over the training split and reports masked-LM perplexity; no
        # parameters are updated.
        model.eval()

        losses = []
        for batch in tqdm(train_dataloader):
            # wrangle batch
            batch = [t.to(device) for t in batch]
            (
                input_ids,
                segment_ids,
                input_mask,
                lm_label_ids,
                masked_pos,
                masked_weights,
                is_next,
                task_idx,
                img,
                vis_masked_pos,
                vis_pe,
                ans_labels,
            ) = batch

            if args.fp16:
                img = img.half()
                vis_pe = vis_pe.half()

            if args.enable_butd:
                conv_feats = img.data  # Bx100x2048
                vis_pe = vis_pe.data
            else:
                conv_feats, _ = cnn(img.data)  # Bx2048x7x7
                conv_feats = (
                    conv_feats.view(conv_feats.size(0), conv_feats.size(1), -1)
                    .permute(0, 2, 1)
                    .contiguous()
                )

            # compute loss
            masked_lm_loss, _, _ = model(
                conv_feats,
                vis_pe,
                input_ids,
                segment_ids,
                input_mask,
                lm_label_ids,
                ans_labels,
                is_next,
                masked_pos=masked_pos,
                masked_weights=masked_weights,
                task_idx=task_idx,
                vis_masked_pos=vis_masked_pos,
                mask_image_regions=args.mask_image_regions,
                drop_worst_ratio=args.max_drop_worst_ratio
            )

            # average across multiple GPUs
            if n_gpu > 1:
                masked_lm_loss = masked_lm_loss.mean()

            losses.append(masked_lm_loss.item())
        
        print(args.split, 'perplexity:', np.exp(np.mean(losses)))
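
The recovery logic above calls _get_max_epoch_model, which is not shown in this
snippet. Judging from how its return value is used, it should yield the latest
epoch for which both a "model.{epoch}.bin" and an "optim.{epoch}.bin" checkpoint
exist in the output directory, or None when there is nothing to recover. A
minimal sketch under those assumptions:

import glob
import os
import re

def _get_max_epoch_model(output_dir):
    # Collect epochs that have both a model and an optimizer checkpoint.
    epochs = []
    for fn in glob.glob(os.path.join(output_dir, "model.*.bin")):
        m = re.search(r"model\.(\d+)\.bin$", os.path.basename(fn))
        if m and os.path.exists(
                os.path.join(output_dir, "optim.{0}.bin".format(m.group(1)))):
            epochs.append(int(m.group(1)))
    return max(epochs) if epochs else None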
class Seq2SeqTrainer:
    """
    Seq2SeqTrainer
    """
    def __init__(self,
                 model,
                 criterion,
                 opt_config,
                 print_freq=10,
                 save_freq=1000,
                 grad_clip=float('inf'),
                 batch_first=False,
                 save_info={},
                 save_path='.',
                 train_iterations=0,
                 checkpoint_filename='checkpoint%s.pth',
                 keep_checkpoints=5,
                 math='fp32',
                 loss_scaling={},
                 cuda=True,
                 distributed=False,
                 distributed_overlap_allreduce=False,
                 distributed_overlap_num_allreduce_streams=1,
                 distributed_overlap_allreduce_messagesize=1e7,
                 distributed_overlap_allreduce_communicators=None,
                 intra_epoch_eval=0,
                 prealloc_mode='always',
                 iter_size=1,
                 verbose=False):
        """
        Constructor for the Seq2SeqTrainer.

        :param model: model to train
        :param criterion: criterion (loss function)
        :param opt_config: dictionary with options for the optimizer
        :param print_freq: prints short summary every 'print_freq' iterations
        :param save_freq: saves checkpoint every 'save_freq' iterations
        :param grad_clip: coefficient for gradient clipping
        :param batch_first: if True the model uses (batch, seq, feature) tensors,
            if False the model uses (seq, batch, feature)
        :param save_info: dict with additional state stored in each checkpoint
        :param save_path: path to the directory for checkpoints
        :param train_iterations: total number of training iterations to execute
        :param checkpoint_filename: name of files with checkpoints
        :param keep_checkpoints: max number of checkpoints to keep
        :param math: arithmetic type
        :param loss_scaling: options for dynamic loss scaling
        :param cuda: if True use cuda, if False train on cpu
        :param distributed: if True run distributed training
        :param intra_epoch_eval: number of additional eval runs within each
            training epoch
        :param prealloc_mode: controls preallocation,
            choices=['off', 'once', 'always']
        :param iter_size: number of iterations between weight updates
        :param verbose: enables verbose logging
        """
        super(Seq2SeqTrainer, self).__init__()
        self.model = model
        self.criterion = criterion
        self.epoch = 0
        self.save_info = save_info
        self.save_path = save_path
        self.save_freq = save_freq
        self.save_counter = 0
        self.checkpoint_filename = checkpoint_filename
        self.checkpoint_counter = cycle(range(keep_checkpoints))
        self.opt_config = opt_config
        self.cuda = cuda
        self.distributed = distributed
        self.print_freq = print_freq
        self.batch_first = batch_first
        self.verbose = verbose
        self.loss = None
        self.translator = None
        self.scheduler = None
        self.intra_epoch_eval = intra_epoch_eval
        self.iter_size = iter_size
        self.prealloc_mode = prealloc_mode
        self.preallocated = False

        self.retain_allreduce_buffers = True
        self.gradient_average = False

        if cuda:
            self.model = self.model.cuda()
            self.criterion = self.criterion.cuda()

        params = self.model.parameters()
        if math == 'fp16':
            self.model = self.model.half()
            if distributed:
                self.model = DDP(
                    self.model,
                    message_size=distributed_overlap_allreduce_messagesize,
                    delay_allreduce=(not distributed_overlap_allreduce),
                    retain_allreduce_buffers=self.retain_allreduce_buffers,
                    gradient_average=self.gradient_average)
            self.fp_optimizer = Fp16Optimizer(
                self.model,
                grad_clip,
                loss_scale=loss_scaling['init_scale'],
                dls_upscale_interval=loss_scaling['upscale_interval'])
            params = [self.fp_optimizer.fp32_params]
        elif math == 'fp32':
            if distributed:
                self.model = DDP(
                    self.model,
                    message_size=distributed_overlap_allreduce_messagesize,
                    delay_allreduce=(not distributed_overlap_allreduce))
            self.fp_optimizer = Fp32Optimizer(self.model, grad_clip)
            # params = self.model.parameters()

        opt_name = opt_config.pop('optimizer')
        if opt_name == 'FusedAdam':
            if math == 'fp16' or math == 'fp32':
                self.optimizer = FusedAdam(params, **opt_config)
            else:
                self.optimizer = FusedAdam(
                    params,
                    use_mt=True,
                    max_grad_norm=grad_clip,
                    amp_scale_adjustment=get_world_size(),
                    **opt_config)
        else:
            self.optimizer = torch.optim.__dict__[opt_name](params,
                                                            **opt_config)
        if math == 'amp_fp16':
            self.model, self.optimizer = amp.initialize(
                self.model,
                self.optimizer,
                cast_model_outputs=torch.float16,
                keep_batchnorm_fp32=False,
                opt_level='O2')
            self.fp_optimizer = AMPOptimizer(
                self.model,
                grad_clip,
                loss_scale=loss_scaling['init_scale'],
                dls_upscale_interval=loss_scaling['upscale_interval'])
            if distributed:
                self.model = DDP(
                    self.model,
                    message_size=distributed_overlap_allreduce_messagesize,
                    delay_allreduce=(not distributed_overlap_allreduce),
                    num_allreduce_streams=
                    distributed_overlap_num_allreduce_streams,
                    allreduce_communicators=
                    distributed_overlap_allreduce_communicators,
                    retain_allreduce_buffers=self.retain_allreduce_buffers,
                    gradient_average=self.gradient_average)

        logging.info(f'Using optimizer: {self.optimizer}')

        mlperf_print(key=mlperf_compliance.constants.OPT_BASE_LR,
                     value=opt_config['lr'])

    def iterate(self, src, tgt, update=True, training=True):
        """
        Performs one iteration of the training/validation.

        :param src: batch of examples from the source language
        :param tgt: batch of examples from the target language
        :param update: if True, the optimizer updates the model weights
        :param training: if True, runs the backward pass and optimizer logic
        """
        src, src_length = src
        tgt, tgt_length = tgt
        src_length = torch.LongTensor(src_length)
        tgt_length = torch.LongTensor(tgt_length)

        num_toks = {}
        num_toks['tgt'] = int(sum(tgt_length - 1))
        num_toks['src'] = int(sum(src_length))

        if self.cuda:
            src = src.cuda()
            src_length = src_length.cuda()
            tgt = tgt.cuda()

        if self.batch_first:
            output = self.model(src, src_length, tgt[:, :-1])
            tgt_labels = tgt[:, 1:]
            T, B = output.size(1), output.size(0)
        else:
            output = self.model(src, src_length, tgt[:-1])
            tgt_labels = tgt[1:]
            T, B = output.size(0), output.size(1)

        loss = self.criterion(output.view(T * B, -1),
                              tgt_labels.contiguous().view(-1))

        loss_per_batch = loss.item()
        loss /= (B * self.iter_size)

        if training:
            self.fp_optimizer.step(loss, self.optimizer, self.scheduler,
                                   update)

        loss_per_token = loss_per_batch / num_toks['tgt']
        loss_per_sentence = loss_per_batch / B

        return loss_per_token, loss_per_sentence, num_toks

    def feed_data(self, data_loader, training=True):
        """
        Runs training or validation on batches from data_loader.

        :param data_loader: data loader
        :param training: if True runs training else runs validation
        """
        if training:
            assert self.optimizer is not None
            eval_fractions = np.linspace(0, 1, self.intra_epoch_eval + 2)[1:-1]
            iters_with_update = len(data_loader) // self.iter_size
            eval_iters = (eval_fractions * iters_with_update).astype(int)
            eval_iters = eval_iters * self.iter_size
            eval_iters = set(eval_iters)

        batch_time = AverageMeter(skip_first=False)
        data_time = AverageMeter(skip_first=False)
        losses_per_token = AverageMeter(skip_first=False)
        losses_per_sentence = AverageMeter(skip_first=False)

        tot_tok_time = AverageMeter(skip_first=False)
        src_tok_time = AverageMeter(skip_first=False)
        tgt_tok_time = AverageMeter(skip_first=False)

        batch_size = data_loader.batch_size

        end = time.time()
        for i, (src, tgt) in enumerate(data_loader):
            self.save_counter += 1
            # measure data loading time
            data_time.update(time.time() - end)

            update = False
            if i % self.iter_size == self.iter_size - 1:
                update = True

            # do a train/evaluate iteration
            stats = self.iterate(src, tgt, update, training=training)
            loss_per_token, loss_per_sentence, num_toks = stats

            # measure accuracy and record loss
            losses_per_token.update(loss_per_token, num_toks['tgt'])
            losses_per_sentence.update(loss_per_sentence, batch_size)

            # measure elapsed time
            elapsed = time.time() - end
            batch_time.update(elapsed)
            src_tok_time.update(num_toks['src'] / elapsed)
            tgt_tok_time.update(num_toks['tgt'] / elapsed)
            tot_num_toks = num_toks['tgt'] + num_toks['src']
            tot_tok_time.update(tot_num_toks / elapsed)
            self.loss = losses_per_token.avg

            if training and i in eval_iters:
                assert self.translator is not None
                test_bleu, _ = self.translator.run(calc_bleu=True,
                                                   epoch=self.epoch,
                                                   iteration=i)

                log = []
                log += [f'TRAIN [{self.epoch}][{i}/{len(data_loader)}]']
                log += [f'BLEU: {test_bleu:.2f}']
                log = '\t'.join(log)
                logging.info(log)

                self.model.train()
                self.preallocate(data_loader.batch_size,
                                 data_loader.dataset.max_len,
                                 training=True)

            if i % self.print_freq == 0:
                phase = 'TRAIN' if training else 'VALIDATION'
                log = []
                log += [f'{phase} [{self.epoch}][{i}/{len(data_loader)}]']
                log += [f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})']
                log += [f'Data {data_time.val:.2e} ({data_time.avg:.2e})']
                log += [
                    f'Tok/s {tot_tok_time.val:.0f} ({tot_tok_time.avg:.0f})'
                ]
                if self.verbose:
                    log += [
                        f'Src tok/s {src_tok_time.val:.0f} ({src_tok_time.avg:.0f})'
                    ]
                    log += [
                        f'Tgt tok/s {tgt_tok_time.val:.0f} ({tgt_tok_time.avg:.0f})'
                    ]
                    log += [
                        f'Loss/sentence {losses_per_sentence.val:.1f} ({losses_per_sentence.avg:.1f})'
                    ]
                log += [
                    f'Loss/tok {losses_per_token.val:.4f} ({losses_per_token.avg:.4f})'
                ]
                if training:
                    lr = self.optimizer.param_groups[0]['lr']
                    log += [f'LR {lr:.3e}']
                log = '\t'.join(log)
                logging.info(log)

            save_chkpt = (self.save_counter %
                          self.save_freq) == (self.save_freq - 1)
            if training and save_chkpt:
                self.save_counter = 0
                self.save_info['iteration'] = i
                identifier = next(self.checkpoint_counter, -1)
                if identifier != -1:
                    with sync_workers() as rank:
                        if rank == 0:
                            self.save(identifier=identifier)

            end = time.time()

        tot_tok_time.reduce('sum')
        losses_per_token.reduce('mean')

        return losses_per_token.avg, tot_tok_time.avg

    def preallocate(self, batch_size, max_length, training):
        """
        Generates a maximum-sequence-length batch and runs a forward and
        backward pass without updating model parameters.

        :param batch_size: batch size for preallocation
        :param max_length: max sequence length for preallocation
        :param training: if True preallocates memory for backward pass
        """
        if self.prealloc_mode == 'always' or (self.prealloc_mode == 'once'
                                              and not self.preallocated):
            logging.info('Executing preallocation')
            torch.cuda.empty_cache()

            src_length = [max_length] * batch_size
            tgt_length = [max_length] * batch_size

            if self.batch_first:
                shape = (batch_size, max_length)
            else:
                shape = (max_length, batch_size)

            src = torch.full(shape, 4, dtype=torch.int64)
            tgt = torch.full(shape, 4, dtype=torch.int64)
            src = src, src_length
            tgt = tgt, tgt_length
            self.iterate(src, tgt, update=False, training=training)
            self.model.zero_grad()
            self.preallocated = True

    def optimize(self, data_loader):
        """
        Sets model in training mode, preallocates memory and runs training on
        data provided by data_loader.

        :param data_loader: data loader
        """
        torch.set_grad_enabled(True)
        self.model.train()
        self.preallocate(data_loader.batch_size,
                         data_loader.dataset.max_len,
                         training=True)

        output = self.feed_data(data_loader, training=True)

        self.model.zero_grad()
        return output

    def evaluate(self, data_loader):
        """
        Sets model in eval mode, disables gradients, preallocates memory and
        runs validation on data provided by data_loader.

        :param data_loader: data loader
        """
        torch.set_grad_enabled(False)
        self.model.eval()
        self.preallocate(data_loader.batch_size,
                         data_loader.dataset.max_len,
                         training=False)

        output = self.feed_data(data_loader, training=False)

        self.model.zero_grad()
        return output

    def load(self, filename):
        """
        Loads checkpoint from filename.

        :param filename: path to the checkpoint file
        """
        if os.path.isfile(filename):
            checkpoint = torch.load(filename, map_location={'cuda:0': 'cpu'})
            if self.distributed:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            self.fp_optimizer.initialize_model(self.model)
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            assert self.scheduler is not None
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.epoch = checkpoint['epoch']
            self.loss = checkpoint['loss']
            logging.info(f'Loaded checkpoint {filename} (epoch {self.epoch})')
        else:
            logging.error(f'Invalid checkpoint: {filename}')

    def save(self, identifier=None, is_best=False, save_all=False):
        """
        Stores checkpoint to a file.

        :param identifier: identifier for periodic checkpoint
        :param is_best: if True stores checkpoint to 'model_best.pth'
        :param save_all: if True stores checkpoint after completed training
            epoch
        """
        def write_checkpoint(state, filename):
            filename = os.path.join(self.save_path, filename)
            logging.info(f'Saving model to {filename}')
            torch.save(state, filename)

        if self.distributed:
            model_state = self.model.module.state_dict()
        else:
            model_state = self.model.state_dict()

        assert self.scheduler is not None
        state = {
            'epoch': self.epoch,
            'state_dict': model_state,
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'loss': getattr(self, 'loss', None),
        }
        state = dict(list(state.items()) + list(self.save_info.items()))

        if identifier is not None:
            filename = self.checkpoint_filename % identifier
            write_checkpoint(state, filename)

        if is_best:
            filename = 'model_best.pth'
            write_checkpoint(state, filename)

        if save_all:
            filename = f'checkpoint_epoch_{self.epoch:03d}.pth'
            write_checkpoint(state, filename)
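
# A minimal, self-contained sketch of the checkpoint round-trip used by the
# trainer above: model, optimizer and scheduler state dicts are saved in one
# dict and restored with load_state_dict(). The tiny linear model and the
# file name are placeholders, not part of the original example.
import torch


def checkpoint_roundtrip_sketch(path='checkpoint_sketch.pth'):
    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

    # save, mirroring the trainer's `state` dict layout
    torch.save(
        {
            'epoch': 0,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
        }, path)

    # load on CPU first, then restore each component in place
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    return checkpoint['epoch']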
Beispiel #13
0
def main():
    parser = argparse.ArgumentParser(fromfile_prefix_chars="@")

    parser.add_argument("--pregenerated_data",
                        type=Path,
                        required=True,
                        help="The input train corpus.")

    parser.add_argument("--epochs", type=int, required=True)

    parser.add_argument("--bert_model", type=str, required=True)

    parser.add_argument("--bert_config_file",
                        type=str,
                        default="bert_config.json")
    parser.add_argument("--vocab_file", type=str, default="senti_vocab.txt")

    parser.add_argument('--output_dir', type=Path, required=True)

    parser.add_argument("--model_name", type=str, default="senti_base_model")

    parser.add_argument(
        "--reduce_memory",
        action="store_true",
        help=
        "Store training data as on-disc memmaps to massively reduce memory usage"
    )

    parser.add_argument("--world_size", type=int, default=4)
    parser.add_argument("--start_rank", type=int, default=0)
    parser.add_argument("--server", type=str, default="tcp://127.0.0.1:1234")

    parser.add_argument("--load_model", action="store_true")
    parser.add_argument("--load_model_name", type=str, default="large_model")

    parser.add_argument("--save_step", type=int, default=100000)
    parser.add_argument("--train_batch_size",
                        default=4,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=1e-4,
                        type=float,
                        help="The initial learning rate for Adam.")

    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")

    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")

    args = parser.parse_args()

    assert args.pregenerated_data.is_dir(), \
        "--pregenerated_data should point to the folder of files made by pregenerate_training_data.py!"

    print("local_rank : ", args.local_rank)

    samples_per_epoch = []
    for i in range(args.epochs):
        epoch_file = args.pregenerated_data / f"epoch_{i}.json"
        metrics_file = args.pregenerated_data / f"epoch_{i}_metrics.json"
        if epoch_file.is_file() and metrics_file.is_file():
            metrics = json.loads(metrics_file.read_text())
            samples_per_epoch.append(metrics['num_training_examples'])
        else:
            if i == 0:
                exit("No training data was found!")
            print(
                f"Warning! There are fewer epochs of pregenerated data ({i}) than training epochs ({args.epochs})."
            )
            print(
                "This script will loop over the available data, but training diversity may be negatively impacted."
            )
            num_data_epochs = i
            break
    else:
        num_data_epochs = args.epochs

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl',
                                             init_method=args.server,
                                             rank=args.local_rank +
                                             args.start_rank,
                                             world_size=args.world_size)
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if args.output_dir.is_dir() and list(args.output_dir.iterdir()):
        logger.warning(
            f"Output directory ({args.output_dir}) already exists and is not empty!"
        )
    args.output_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = Tokenizer(
        os.path.join(args.bert_model, "senti_vocab.txt"),
        os.path.join(args.bert_model, "RoBERTa_Sentiment_kor"))

    total_train_examples = 0
    for i in range(args.epochs):
        # The modulo takes into account the fact that we may loop over limited epochs of data
        total_train_examples += samples_per_epoch[i % len(samples_per_epoch)]

    num_train_optimization_steps = math.ceil(total_train_examples /
                                             args.train_batch_size /
                                             args.gradient_accumulation_steps)
    if args.local_rank != -1:
        num_train_optimization_steps = math.ceil(
            num_train_optimization_steps / torch.distributed.get_world_size())

    # Prepare model
    config = BertConfig.from_json_file(
        os.path.join(args.bert_model, args.bert_config_file))
    logger.info('{}'.format(config))
    ###############################################
    # Load Model
    if args.load_model:
        load_model_name = os.path.join(args.output_dir, args.load_model_name)
        model = BertForPreTraining.from_pretrained(
            args.bert_model,
            state_dict=torch.load(load_model_name)["state_dict"])
    else:
        model = BertForPreTraining(config)
    ###############################################

    if args.fp16:
        model.half()
    model.to(device)

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
            model = DDP(model)

        except ImportError:
            from torch.nn.parallel import DistributedDataParallel as DDP
            model = DDP(model,
                        device_ids=[args.local_rank],
                        output_device=args.local_rank)

    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)
        warmup_linear = WarmupLinearSchedule(
            warmup=args.warmup_proportion,
            t_total=num_train_optimization_steps)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)
    epoch0 = 0
    global_step = 0
    if args.load_model:
        ###############################################
        # Load Model
        logger.info(f"***** Load Model {args.load_model_name} *****")
        loaded_states = torch.load(os.path.join(args.output_dir,
                                                args.load_model_name),
                                   map_location=device)
        optimizer.load_state_dict(loaded_states["optimizer"])

        regex = re.compile(r'\d+epoch')
        epoch0 = int(
            regex.findall(args.load_model_name)[-1].replace('epoch', ''))
        logger.info('extract {} -> epoch0 : {}'.format(args.load_model_name,
                                                       epoch0))

        ###############################################

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {total_train_examples}")
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_optimization_steps)

    model.train()
    # model.eval()
    for epoch in range(epoch0, args.epochs):
        epoch_dataset = PregeneratedDataset(
            epoch=epoch,
            training_path=args.pregenerated_data,
            tokenizer=tokenizer,
            num_data_epochs=num_data_epochs,
            reduce_memory=args.reduce_memory)
        if args.local_rank == -1:
            train_sampler = RandomSampler(epoch_dataset)
        else:
            train_sampler = DistributedSampler(epoch_dataset)

        train_dataloader = DataLoader(epoch_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        with tqdm(total=len(train_dataloader), desc='training..') as pbar:
            for step, batch in enumerate(train_dataloader):

                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, lm_label_ids = batch

                loss = model(input_ids, input_mask, lm_label_ids)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                tr_loss += loss.item()

                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                pbar.update(1)
                mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps

                if (step + 1) % 50 == 0:
                    pbar.set_description(
                        "Epoch = {}, global_step = {}, loss = {:.5f}".format(
                            epoch, global_step + 1, mean_loss))
                    logger.info(
                        "Epoch = {}, global_step = {}, loss = {:.5f}".format(
                            epoch, global_step + 1, mean_loss))

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear.get_lr(
                            global_step, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                if (step + 1) % args.save_step == 0:
                    if args.local_rank == -1 or args.local_rank == 0:
                        logger.info(
                            "** ** * Saving {} - step model ** ** * ".format(
                                global_step))
                        output_model_file = os.path.join(
                            args.output_dir,
                            args.model_name + "_{}step".format(global_step))
                        model_to_save = model.module if hasattr(
                            model, 'module') else model
                        state = {
                            "state_dict": model_to_save.state_dict(),
                            "optimizer": optimizer.state_dict()
                        }
                        torch.save(state, output_model_file)

        if args.local_rank == -1 or args.local_rank == 0:
            logger.info(
                "** ** * Saving {} - epoch model ** ** * ".format(epoch))
            output_model_file = os.path.join(
                args.output_dir,
                args.model_name + "_{}epoch".format(epoch + 1))
            model_to_save = model.module if hasattr(model, 'module') else model
            state = {
                "state_dict": model_to_save.state_dict(),
                "optimizer": optimizer.state_dict()
            }
            torch.save(state, output_model_file)
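
# A minimal sketch of the fp16 resume path above, reusing the legacy apex
# objects this example imports (FusedAdam wrapped in FP16_Optimizer). It
# assumes an old apex build and a checkpoint with the example's
# {"state_dict": ..., "optimizer": ...} layout; the function name and the
# checkpoint_path argument are placeholders.
import torch
from apex.optimizers import FP16_Optimizer, FusedAdam


def resume_fp16_sketch(model, checkpoint_path, lr=1e-4):
    state = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state['state_dict'])  # restore weights first
    model.half().cuda()

    optimizer = FusedAdam(model.parameters(),
                          lr=lr,
                          bias_correction=False,
                          max_grad_norm=1.0)
    optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
    optimizer.load_state_dict(state['optimizer'])
    return model, optimizer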
def main():
    opt = parse_args()
    opt.sever_name = gethostname()

    # --- CUDA setting ---
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Random Seed setting ---
    if opt.random_seed is None:
        opt.random_seed = random.randint(1, 10000)
    random.seed(opt.random_seed)
    np.random.seed(opt.random_seed)
    torch.manual_seed(opt.random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(opt.random_seed)
        cudnn.deterministic = True
        cudnn.benchmark = opt.cudnn_benchmark

    # --- PATH setting ---
    save_result_dir = Path(__file__).parent / "results" / opt.experiment_name
    opt.save_model_dir = str(save_result_dir / "trained_models")
    opt.save_log_path = str(save_result_dir / "train.log")
    mkdirs(opt.save_model_dir)

    # --- Prepare DataLoader ---
    train_loader, valid_loader = get_train_val_loader(opt)
    opt.src_vocab_size = train_loader.dataset.src_vocab_size
    opt.tgt_vocab_size = train_loader.dataset.tgt_vocab_size

    # --- Prepare Model ---
    model = MultimodalTransformer(
        src_vocab_size=opt.src_vocab_size,
        tgt_vocab_size=opt.tgt_vocab_size,
        max_position_num=opt.max_position_num,
        d_model=opt.d_model,
        head_num=opt.head_num,
        d_k=opt.d_k,
        d_v=opt.d_v,
        d_inner=opt.d_inner,
        layer_num=opt.layer_num,
        dropout=opt.dropout,
        cnn_fine_tuning=opt.cnn_fine_tuning,
        shared_embedding=opt.shared_embedding,
        share_dec_input_output_embed=opt.share_dec_input_output_embed,
        init_weight=opt.init_weight,
        fused_layer_norm=opt.use_fused,
    ).to(device)

    # --- Prepare optimizer and scaler ---
    if opt.use_fused:
        from apex.optimizers import FusedAdam as Adam
    else:
        from torch.optim import Adam
    optimizer = Adam(filter(lambda x: x.requires_grad, model.parameters()),
                     betas=(0.9, 0.98),
                     eps=1e-09,
                     weight_decay=opt.weight_decay)
    scaler = GradScaler(init_scale=65536.0, enabled=opt.use_amp)

    # --- Restart setting ---
    start_cnt = 1
    steps_cnt = 0
    if opt.adapt_prop_MNMT is not None:
        ex_name, epoch_cnt = opt.adapt_prop_MNMT.split(',')
        saved_path = f"{pardir}/results/{ex_name}/trained_models/epoch_{epoch_cnt}.pth"
        saved_dict = torch.load(saved_path,
                                map_location=lambda storage, loc: storage)
        init_dir, init_epoch = saved_dict["settings"].MNMT.split(',')
        init_path = f"{pardir}/MNMT/results/{init_dir}/trained_models/epoch_{init_epoch}.pth"
        init_data = torch.load(init_path,
                               map_location=lambda storage, loc: storage)
        check_arguments(init_data["settings"], opt)
        model.load_state_dict(saved_dict["models"]["MNMT"])
        print(f"[Info]Loading complete ({saved_path})")
    elif opt.adapt_init_MNMT is not None:
        ex_name, epoch_cnt = opt.adapt_init_MNMT.split(',')
        saved_path = f"{Path(__file__).parent}/results/{ex_name}/trained_models/epoch_{epoch_cnt}.pth"
        saved_dict = torch.load(saved_path,
                                map_location=lambda storage, loc: storage)
        check_arguments(saved_dict["settings"], opt)
        model.load_state_dict(saved_dict["model"])
        print(f"[Info]Loading complete ({saved_path})")

    if opt.restart is not None:
        start_cnt = opt.restart + 1
        if opt.restart < 500:
            model_name = f"epoch_{opt.restart}.pth"
        else:
            model_name = f"step_{opt.restart}.pth"
        saved_path = f"{opt.save_model_dir}/{model_name}"
        saved_dict = torch.load(saved_path,
                                map_location=lambda storage, loc: storage)
        check_arguments(saved_dict["settings"], opt)
        model.load_state_dict(saved_dict["model"])
        optimizer.load_state_dict(saved_dict["optimizer"])
        scaler.load_state_dict(saved_dict["scaler"])
        steps_cnt = saved_dict["steps_cnt"]
        print(f"[Info]Loading complete ({saved_path})")

    scheduler = Scheduler(
        optimizer=optimizer,
        init_lr=0.,
        end_lr=opt.end_lr,
        warmup_steps=opt.warmup_steps,
        current_steps=steps_cnt,
    )

    # --- DataParallel setting ---
    gpus = [i for i in range(len(opt.gpu_ids.split(',')))]
    if len(gpus) > 1:
        model = nn.DataParallel(model, device_ids=gpus)

    # --- Prepare trainer and validator ---
    if valid_loader is not None:
        validator = ScoreCalculator(
            model=model,
            data_loader=valid_loader,
            references=valid_loader.dataset.tgt_insts,
            bpe=opt.bpe,
            cp_avg_num=opt.check_point_average,
        )
    else:
        validator = None

    trainer = MNMTTrainer(
        model=model,
        train_loader=train_loader,
        optimizer=optimizer,
        scaler=scaler,
        scheduler=scheduler,
        opt=opt,
        validator=validator,
    )

    # -- Train --
    if opt.max_epoch is not None:
        trainer.train_by_epoch(start_cnt)
    else:
        trainer.train_by_step(start_cnt)
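
# A minimal sketch of the restart logic above: model, optimizer and AMP
# GradScaler states travel in one dict so that loss scaling resumes where it
# left off. The keys mirror what the restart branch reads; the helper names
# and the path argument are placeholders.
import torch


def save_restart_state(model, optimizer, scaler, steps_cnt, path):
    torch.save(
        {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scaler': scaler.state_dict(),
            'steps_cnt': steps_cnt,
        }, path)


def load_restart_state(model, optimizer, scaler, path):
    saved = torch.load(path, map_location=lambda storage, loc: storage)
    model.load_state_dict(saved['model'])
    optimizer.load_state_dict(saved['optimizer'])
    scaler.load_state_dict(saved['scaler'])
    return saved['steps_cnt']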
class ExperimentBase():
    # TODO train and val as pipelines

    def __init__(self, cfg=None):
        cfg = cfg or self.cfg  # try reading from static variable

        self.datasets = {}
        self.pipelines = {}
        self.init_config(cfg)
        self.init_transforms()

    def init_config(self, cfg):
        self.cfg = cfg
        self.workdir = Path(cfg['dir_checkpoint'])

        self.workdir.mkdir(exist_ok=True, parents=True)

        (self.workdir / 'config.json').write_text(cfg_json_encode(self.cfg))

    def print_cfg(self):
        print_cfg(self.cfg)

    def init_transforms(self):
        """ Init and store transforms that take time to construct 
			The transforms will be used in self.construct_default_pipeline
		"""
        # Lifecycle of a frame
        # in dataset:
        # 	dset.tr_post_load_pre_cache
        # 	dset.tr_output
        # in experiment:
        pass

    #def sampler_for_dset(self, role, dset):
    #tr_role = self.tr_input_per_role.get(role, None)
    #tr_in = self.tr_input if tr_role is None else TrsChain(tr_role, self.tr_input)

    #collate_fn = partial(self.dataloader_collate, tr_in, self.tr_input_post_batch)

    #args = self.loader_args_for_role(role)

    #return DataLoader(dset, collate_fn=collate_fn, **args)

    def set_dataset(self, role, dset):
        """
			role "train" or "val"
		"""

        self.datasets[role] = dset

    def load_checkpoint(self, chk_name='chk_best.pth'):
        dir_chk = Path(self.cfg['dir_checkpoint'])
        path_chk = dir_chk / chk_name

        if path_chk.is_file():
            log.info(f'Loading checkpoint found at {path_chk}')
            return torch.load(path_chk, map_location='cpu')
        else:
            log.info(f'No checkpoint at {path_chk}')
            return None

    def init_net(self, role):
        """ Role: val or train - determines which checkpoint is loaded"""

        if role == 'train':
            chk = self.load_checkpoint(chk_name='chk_last.pth')
            chk_opt = self.load_checkpoint(chk_name='optimizer.pth')

            self.build_net(role, chk=chk)
            self.build_optimizer(role, chk_optimizer=chk_opt)
            self.net_mod.train()

        elif role == 'eval':
            chk = self.load_checkpoint(chk_name='chk_best.pth')
            self.build_net(role, chk=chk)
            self.net_mod.eval()

        else:
            raise NotImplementedError(f'role={role}')

        if chk is not None:
            self.state = GrumpyDict(chk['state'])
        else:
            self.state = train_state_init()

    def build_net(self, role, chk=None, chk_optimizer=None):
        """ Build net and optimizer (if we train) """
        log.info('Building net')

    @staticmethod
    def load_checkpoint_to_net(net_mod, chk_object):
        (missing_keys,
         superfluous_keys) = net_mod.load_state_dict(chk_object['weights'],
                                                     strict=False)
        if missing_keys:
            log.warning(
                f'Missing keys when loading a checkpoint: {missing_keys}')
        if superfluous_keys:
            log.warning(
                f'Unexpected keys when loading a checkpoint: {superfluous_keys}')

    def build_optimizer(self, role, chk_optimizer=None):
        log.info('Building optimizer')

        cfg_opt = self.cfg['train']['optimizer']

        network = self.net_mod
        self.optimizer = AdamOptimizer(
            [p for p in network.parameters() if p.requires_grad],
            lr=cfg_opt['learn_rate'],
            weight_decay=cfg_opt.get('weight_decay', 0),
        )
        self.learn_rate_scheduler = ReduceLROnPlateau(
            self.optimizer,
            patience=cfg_opt['lr_patience'],
            min_lr=cfg_opt['lr_min'],
        )

        if chk_optimizer is not None:
            self.optimizer.load_state_dict(chk_optimizer['optimizer'])

    def init_loss(self):
        log.info('Building loss_mod')

    def init_log(self, fids_to_display=[]):
        """
		:param fids_to_display: ids of frames to show in tensorboard
		"""
        # log for the current training run
        self.tboard = SummaryWriter(self.workdir /
                                    f"tb_{self.state['run_name']}")
        # save ground truth here to compare in tensorboard
        self.tboard_gt = SummaryWriter(self.workdir / 'tb_gt')
        self.tboard_img = SummaryWriter(self.workdir / 'tb_img')

        self.train_out_dir = self.workdir / f"imgs_{self.state['run_name']}"
        self.train_out_dir.mkdir(exist_ok=True, parents=True)

        # names of the frames to display
        def short_frame_name(fn):
            # remove directory path
            if '/' in fn:
                fn = os.path.basename(fn)
            return fn

        self.fids_to_display = set(fids_to_display)
        self.short_frame_names = {
            fid: short_frame_name(fid)
            for fid in self.fids_to_display
        }

    def log_selected_images(self, fid, frame, **_):
        if fid in self.fids_to_display:
            log.warning('log_selected_images: not implemented')

    def init_default_datasets(self):
        pass

    def init_pipelines(self):
        for role in ['train', 'val', 'test']:
            self.pipelines[role] = self.construct_default_pipeline(role)

    def get_epoch_limit(self):
        return self.cfg['train'].get('epoch_limit', None)

    def cuda_modules(self, attr_names):
        if torch.cuda.is_available():
            attr_names = [attr_names] if isinstance(attr_names,
                                                    str) else attr_names

            for an in attr_names:
                setattr(self, an, getattr(self, an).cuda())

    def training_start_batch(self, **_):
        self.optimizer.zero_grad()

    def training_backpropagate(self, loss, **_):
        #if torch.any(torch.isnan(loss)):
        #	print('Loss is NAN, cancelling backpropagation in batch')

        #raise Exception('Stopping training so we can investigate where the nan is coming from')
        #else:
        loss.backward()
        self.optimizer.step()

    def training_epoch_start(self, epoch_idx):
        self.net_mod.train()  # set train mode for dropout and batch-norm

    def training_epoch(self, epoch_idx):
        self.training_epoch_start(epoch_idx)

        out_frames = self.pipelines['train'].execute(
            dset=self.datasets['train'],
            b_grad=True,
            b_pbar=False,
            b_accumulate=True,
            log_progress_interval=self.cfg['train'].get(
                'progress_interval', None),
            short_epoch=self.cfg['train'].get('short_epoch_train', None),
        )
        gc.collect()

        results_avg = Frame({
            # the loss may be in fp16, let's average it at high precision to avoid NaN
            fn:
            np.mean(np.array([pf[fn] for pf in out_frames], dtype=np.float64))
            for fn in out_frames[0].keys() if fn.lower().startswith('loss')
        })

        self.training_epoch_finish(epoch_idx, results_avg)

        return results_avg['loss']

    def training_epoch_finish(self, epoch_idx, results_avg):
        for name, loss_avg in results_avg.items():
            self.tboard.add_scalar('train_' + name, loss_avg, epoch_idx)

    def val_epoch_start(self, epoch_idx):
        self.net_mod.eval()

    def val_epoch_finish(self, epoch_idx, results_avg):
        self.learn_rate_scheduler.step(results_avg['loss'])

        for name, loss_avg in results_avg.items():
            self.tboard.add_scalar('val_' + name, loss_avg, epoch_idx)

    def val_epoch(self, epoch_idx):
        self.val_epoch_start(epoch_idx)

        out_frames = self.pipelines['val'].execute(
            dset=self.datasets['val'],
            b_grad=False,
            b_pbar=False,
            b_accumulate=True,
            short_epoch=self.cfg['train'].get('short_epoch_val', None),
        )
        gc.collect()

        results_avg = Frame({
            fn: np.mean([pf[fn] for pf in out_frames])
            for fn in out_frames[0].keys() if fn.lower().startswith('loss')
        })

        self.val_epoch_finish(epoch_idx, results_avg)

        return results_avg['loss']

    def run_epoch(self, epoch_idx):

        gc.collect()
        epoch_limit = self.get_epoch_limit()
        log.info('E {ep:03d}{eplimit}\n	train start'.format(
            ep=epoch_idx,
            eplimit=f' / {epoch_limit}' if epoch_limit is not None else '',
        ))
        t_train_start = time.time()
        loss_train = self.training_epoch(epoch_idx)
        gc.collect()
        t_val_start = time.time()
        log.info('	train finished	t={tt}s	loss_t={ls}, val starting'.format(
            tt=t_val_start - t_train_start,
            ls=loss_train,
        ))

        gc.collect()
        loss_val = self.val_epoch(epoch_idx)
        gc.collect()
        log.info('	val finished	t={tt}s	loss_e={ls}'.format(
            tt=time.time() - t_val_start,
            ls=loss_val,
        ))

        is_best = loss_val < self.state['best_loss_val']
        if is_best:
            self.state['best_loss_val'] = loss_val
        is_chk_scheduled = epoch_idx % self.cfg['train'][
            'checkpoint_interval'] == 0

        if is_best or is_chk_scheduled:
            self.save_checkpoint(epoch_idx, is_best, is_chk_scheduled)

    def save_checkpoint(self, epoch_idx, is_best, is_scheduled):

        # TODO separate methods for saving various parts of the experiment
        chk_dict = dict()
        chk_dict['weights'] = self.net_mod.state_dict()
        chk_dict['state'] = dict(self.state)

        path_best = self.workdir / 'chk_best.pth'
        path_last = self.workdir / 'chk_last.pth'

        if is_scheduled:
            pytorch_save_atomic(chk_dict, path_last)

            pytorch_save_atomic(
                dict(epoch_idx=epoch_idx,
                     optimizer=self.optimizer.state_dict()),
                self.workdir / 'optimizer.pth',
            )

        if is_best:
            log.info('	New best checkpoint')
            if is_scheduled:
                # we already saved to chk_last.pth
                shutil.copy(path_last, path_best)
            else:
                pytorch_save_atomic(chk_dict, path_best)

    def training_run(self, b_initial_eval=True):
        name = self.cfg['name']
        log.info(f'Experiment {name} - train')

        path_stop = self.workdir / 'stop'

        if b_initial_eval:
            log.info('INIT\n	initial val')
            loss_val = self.val_epoch(self.state['epoch_idx'])

            log.info('	init loss_e={le}'.format(le=loss_val))
            self.state['best_loss_val'] = loss_val
        else:
            self.state['best_loss_val'] = 1e4

        b_continue = True

        while b_continue:
            self.state['epoch_idx'] += 1
            self.run_epoch(self.state['epoch_idx'])

            if path_stop.is_file():
                log.info('Stop file detected')
                path_stop.unlink()  # remove file
                b_continue = False

            epoch_limit = self.get_epoch_limit()
            if (epoch_limit
                    is not None) and (self.state['epoch_idx'] >= epoch_limit):
                log.info(f'Reached epoch limit {epoch_limit}')
                b_continue = False

    @classmethod
    def training_procedure(cls):
        print(f'-- Training procedure for {cls.__name__} --')
        exp = cls()
        log_config_file(exp.workdir / 'training.log')

        log.info(f'Starting training job for {cls.__name__}')
        exp.print_cfg()

        try:
            exp.init_default_datasets()

            exp.init_net("train")

            log.info(f'Name of the run: {exp.state["run_name"]}')

            exp.init_transforms()
            exp.init_loss()
            exp.init_log()
            exp.init_pipelines()
            exp.training_run()
        # if training crashes, put the exception in the log
        except Exception as e:
            log.exception(f'Exception in training procedure: {e}')

    def predict_sequence(self, dset, consumer=None, pbar=True):
        """
		If consumer is specified, it will be used for online processing:
		frames will be given to it instead of being accumulated
		"""
        self.net_mod.eval()
        out_frames = self.pipelines['test'].execute(
            dset=dset,
            b_grad=False,
            b_pbar=pbar,
            b_accumulate=True,
        )
        return out_frames

    def loader_args_for_role(self, role):
        if role == 'train':
            return dict(
                shuffle=True,
                batch_size=self.cfg['net']['batch_train'],
                num_workers=self.cfg['train'].get('num_workers', 0),
                drop_last=True,
            )
        elif role == 'val' or role == 'test':
            return dict(
                shuffle=False,
                batch_size=self.cfg['net']['batch_eval'],
                num_workers=self.cfg['train'].get('num_workers', 0),
                drop_last=False,
            )
        else:
            raise NotImplementedError("role: " + role)

    def construct_default_pipeline(self, role):
        if role == 'train':
            tr_batch = TrsChain([
                TrCUDA(),
                self.training_start_batch,
                self.net_mod,
                self.loss_mod,
                self.training_backpropagate,
            ])
            tr_output = TrsChain([
                TrKeepFieldsByPrefix('loss'),  # save loss for averaging later
                TrNP(),  # clear away the gradients if any are left
            ])

        elif role == 'val':
            tr_batch = TrsChain([
                TrCUDA(),
                self.net_mod,
                self.loss_mod,
            ])
            tr_output = TrsChain([
                self.log_selected_images,
                TrKeepFieldsByPrefix('loss'),  # save loss for averaging later
                TrNP(),  # clear away the gradients if any are left
            ])

        elif role == 'test':
            tr_batch = TrsChain([
                TrCUDA(),
                self.net_mod,
            ])
            tr_output = TrsChain([
                TrNP(),
                tr_untorch_images,
            ])

        return Pipeline(
            tr_batch=tr_batch,
            tr_output=tr_output,
            loader_args=self.loader_args_for_role(role),
        )

    def run_evaluation(self, eval_obj, dset=None, b_one_batch=False):
        pipe_test = self.construct_default_pipeline('test')
        dset = dset or eval_obj.get_dset()

        eval_obj.construct_transforms(dset)

        pipe_test.tr_batch.append(eval_obj.tr_batch)
        pipe_test.tr_output.append(eval_obj.tr_output)

        log.info(f'Test pipeline: {pipe_test}')
        pipe_test.execute(dset, b_accumulate=False, b_one_batch=b_one_batch)
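
# A minimal sketch of ExperimentBase's split checkpoint layout: network
# weights go to chk_last.pth / chk_best.pth while optimizer state lives in a
# separate optimizer.pth. Plain torch.save stands in for pytorch_save_atomic
# (assumed to be an atomic-write helper defined elsewhere in the project).
import torch
from pathlib import Path


def save_split_checkpoint(workdir, net_mod, optimizer, state, epoch_idx):
    workdir = Path(workdir)
    torch.save({'weights': net_mod.state_dict(), 'state': dict(state)},
               workdir / 'chk_last.pth')
    torch.save({'epoch_idx': epoch_idx, 'optimizer': optimizer.state_dict()},
               workdir / 'optimizer.pth')


def load_split_checkpoint(workdir, net_mod, optimizer):
    workdir = Path(workdir)
    chk = torch.load(workdir / 'chk_last.pth', map_location='cpu')
    net_mod.load_state_dict(chk['weights'], strict=False)
    chk_opt = torch.load(workdir / 'optimizer.pth', map_location='cpu')
    optimizer.load_state_dict(chk_opt['optimizer'])
    return chk['state']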
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--file_path",
        default="data/conceptual_caption/",
        type=str,
        help="The input train corpus.",
    )
    parser.add_argument(
        "--from_pretrained",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-base-uncased, roberta-base, roberta-large, ",
    )
    parser.add_argument(
        "--bert_model",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, roberta-base",
    )
    parser.add_argument(
        "--output_dir",
        default="save",
        type=str,
        # required=True,
        help=
        "The output directory where the model checkpoints will be written.",
    )
    parser.add_argument(
        "--config_file",
        type=str,
        default="config/bert_base_6layer_6conect.json",
        help="The config file which specified the model details.",
    )
    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=36,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.",
    )
    parser.add_argument(
        "--train_batch_size",
        default=512,
        type=int,
        help="Total batch size for training.",
    )
    parser.add_argument(
        "--learning_rate",
        default=1e-4,
        type=float,
        help="The initial learning rate for Adam.",
    )
    parser.add_argument(
        "--num_train_epochs",
        default=10.0,
        type=float,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--start_epoch",
        default=0,
        type=float,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.",
    )
    parser.add_argument("--img_weight",
                        default=1,
                        type=float,
                        help="weight for image loss")
    parser.add_argument("--no_cuda",
                        action="store_true",
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        "--on_memory",
        action="store_true",
        help="Whether to load train samples into memory or use disk",
    )
    parser.add_argument(
        "--do_lower_case",
        type=bool,
        default=True,
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models.",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local_rank for distributed training on gpus",
    )
    parser.add_argument("--seed",
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit float precision instead of 32-bit",
    )
    parser.add_argument(
        "--loss_scale",
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n",
    )
    parser.add_argument(
        "--dynamic_attention",
        action="store_true",
        help="whether use dynamic attention.",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=25,
        help="Number of workers in the dataloader.",
    )
    parser.add_argument("--save_name",
                        default="",
                        type=str,
                        help="save name for training.")
    parser.add_argument(
        "--baseline",
        action="store_true",
        help="Wheter to use the baseline model (single bert).",
    )
    parser.add_argument(
        "--freeze",
        default=-1,
        type=int,
        help="till which layer of textual stream of vilbert need to fixed.",
    )
    parser.add_argument(
        "--distributed",
        action="store_true",
        help="whether use chunck for parallel training.",
    )
    parser.add_argument("--without_coattention",
                        action="store_true",
                        help="whether pair loss.")
    parser.add_argument(
        "--visual_target",
        default=0,
        type=int,
        help="which target to use for visual branch. \
        0: soft label, \
        1: regress the feature, \
        2: NCE loss.",
    )

    parser.add_argument(
        "--objective",
        default=0,
        type=int,
        help="which objective to use \
        0: with ICA loss, \
        1: with ICA loss, for the not aligned pair, no masking objective, \
        2: without ICA loss, do not sample negative pair.",
    )
    parser.add_argument("--num_negative",
                        default=255,
                        type=int,
                        help="num of negative to use")

    parser.add_argument("--resume_file",
                        default="",
                        type=str,
                        help="Resume from checkpoint")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")

    args = parser.parse_args()

    if args.baseline:
        from pytorch_pretrained_bert.modeling import BertConfig
        from vilbert.basebert import BertForMultiModalPreTraining
    else:
        from vilbert.vilbert import BertForMultiModalPreTraining, BertConfig

    if args.save_name:
        prefix = "-" + args.save_name
    else:
        prefix = ""

    timeStamp = args.config_file.split("/")[1].split(".")[0] + prefix
    savePath = os.path.join(args.output_dir, timeStamp)

    bert_weight_name = json.load(
        open("config/" + args.from_pretrained + "_weight_name.json", "r"))

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend="nccl")

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    default_gpu = False
    if dist.is_available() and args.local_rank != -1:
        rank = dist.get_rank()
        if rank == 0:
            default_gpu = True
    else:
        default_gpu = True

    if default_gpu:
        if not os.path.exists(savePath):
            os.makedirs(savePath)

    config = BertConfig.from_json_file(args.config_file)

    if default_gpu:
        # save all the hidden parameters.
        with open(os.path.join(savePath, "command.txt"), "w") as f:
            print(args, file=f)  # Python 3.x
            print("\n", file=f)
            print(config, file=f)

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    cache = 5000
    if dist.is_available() and args.local_rank != -1:
        num_replicas = dist.get_world_size()
        args.train_batch_size = args.train_batch_size // num_replicas
        args.num_workers = args.num_workers // num_replicas
        cache = cache // num_replicas

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    if "roberta" in args.bert_model:
        tokenizer = RobertaTokenizer.from_pretrained(
            args.bert_model, do_lower_case=args.do_lower_case)
    else:
        tokenizer = BertTokenizer.from_pretrained(
            args.bert_model, do_lower_case=args.do_lower_case)
    num_train_optimization_steps = None
    train_dataset = ConceptCapLoaderTrain(
        args.file_path,
        tokenizer,
        args.bert_model,
        seq_len=args.max_seq_length,
        batch_size=args.train_batch_size,
        visual_target=args.visual_target,
        num_workers=args.num_workers,
        local_rank=args.local_rank,
        objective=args.objective,
        cache=cache,
    )

    validation_dataset = ConceptCapLoaderVal(
        args.file_path,
        tokenizer,
        args.bert_model,
        seq_len=args.max_seq_length,
        batch_size=args.train_batch_size,
        visual_target=args.visual_target,
        num_workers=2,
        objective=args.objective,
    )

    num_train_optimization_steps = int(
        train_dataset.num_dataset / args.train_batch_size /
        args.gradient_accumulation_steps) * (args.num_train_epochs -
                                             args.start_epoch)

    task_names = ["Conceptual_Caption"]
    task_ids = ["TASK0"]
    task_num_iters = {
        "TASK0": train_dataset.num_dataset / args.train_batch_size
    }

    logdir = os.path.join("logs", timeStamp)
    if default_gpu:
        tbLogger = utils.tbLogger(
            logdir,
            savePath,
            task_names,
            task_ids,
            task_num_iters,
            args.gradient_accumulation_steps,
        )

    if args.visual_target == 0:
        config.v_target_size = 1601
        config.visual_target = args.visual_target
    else:
        config.v_target_size = 2048
        config.visual_target = args.visual_target

    if "roberta" in args.bert_model:
        config.model = "roberta"

    if args.freeze > config.t_biattention_id[0]:
        config.fixed_t_layer = config.t_biattention_id[0]

    if args.without_coattention:
        config.with_coattention = False

    if args.dynamic_attention:
        config.dynamic_attention = True

    if args.from_pretrained:
        model = BertForMultiModalPreTraining.from_pretrained(
            args.from_pretrained, config=config, default_gpu=default_gpu)
    else:
        model = BertForMultiModalPreTraining(config)

    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]

    if args.freeze != -1:
        bert_weight_name_filtered = []
        for name in bert_weight_name:
            if "embeddings" in name:
                bert_weight_name_filtered.append(name)
            elif "encoder" in name:
                layer_num = name.split(".")[2]
                if int(layer_num) <= args.freeze:
                    bert_weight_name_filtered.append(name)

        optimizer_grouped_parameters = []
        for key, value in dict(model.named_parameters()).items():
            if key[12:] in bert_weight_name_filtered:
                value.requires_grad = False

        if default_gpu:
            print("filtered weight")
            print(bert_weight_name_filtered)

    if not args.from_pretrained:
        param_optimizer = list(model.named_parameters())
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.01,
            },
            {
                "params": [
                    p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0,
            },
        ]
    else:
        optimizer_grouped_parameters = []
        for key, value in dict(model.named_parameters()).items():
            if value.requires_grad:
                if key[12:] in bert_weight_name:
                    lr = args.learning_rate * 0.1
                else:
                    lr = args.learning_rate

                if any(nd in key for nd in no_decay):
                    optimizer_grouped_parameters += [{
                        "params": [value],
                        "lr": lr,
                        "weight_decay": 0.0
                    }]

                if not any(nd in key for nd in no_decay):
                    optimizer_grouped_parameters += [{
                        "params": [value],
                        "lr": lr,
                        "weight_decay": 0.01
                    }]

        if default_gpu:
            print(len(list(model.named_parameters())),
                  len(optimizer_grouped_parameters))

    # set different parameters for vision branch and language branch.
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            bias_correction=False,
            max_grad_norm=1.0,
        )
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:

        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            eps=args.adam_epsilon,
            betas=(0.9, 0.98),
        )

    scheduler = WarmupLinearSchedule(
        optimizer,
        warmup_steps=args.warmup_proportion * num_train_optimization_steps,
        t_total=num_train_optimization_steps,
    )

    startIterID = 0
    global_step = 0

    if args.resume_file != "" and os.path.exists(args.resume_file):
        checkpoint = torch.load(args.resume_file, map_location="cpu")
        new_dict = {}
        for attr in checkpoint["model_state_dict"]:
            if attr.startswith("module."):
                new_dict[attr.replace(
                    "module.", "", 1)] = checkpoint["model_state_dict"][attr]
            else:
                new_dict[attr] = checkpoint["model_state_dict"][attr]
        model.load_state_dict(new_dict)
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        global_step = checkpoint["global_step"]
        del checkpoint

    model.cuda()

    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.cuda()

    if args.fp16:
        model.half()
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if default_gpu:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", train_dataset.num_dataset)
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

    for epochId in range(int(args.start_epoch), int(args.num_train_epochs)):
        model.train()
        for step, batch in enumerate(train_dataset):

            iterId = startIterID + step + (epochId * len(train_dataset))
            image_ids = batch[-1]
            batch = tuple(
                t.cuda(device=device, non_blocking=True) for t in batch[:-1])

            input_ids, input_mask, segment_ids, lm_label_ids, is_next, image_feat, image_loc, image_target, image_label, image_mask = (
                batch)

            if args.objective == 1:
                image_label = image_label * (is_next == 0).long().unsqueeze(1)
                image_label[image_label == 0] = -1

                lm_label_ids = lm_label_ids * (is_next
                                               == 0).long().unsqueeze(1)
                lm_label_ids[lm_label_ids == 0] = -1

            masked_loss_t, masked_loss_v, next_sentence_loss = model(
                input_ids,
                image_feat,
                image_loc,
                segment_ids,
                input_mask,
                image_mask,
                lm_label_ids,
                image_label,
                image_target,
                is_next,
            )

            if args.objective == 2:
                next_sentence_loss = next_sentence_loss * 0

            masked_loss_v = masked_loss_v * args.img_weight
            loss = masked_loss_t + masked_loss_v + next_sentence_loss

            if n_gpu > 1:
                loss = loss.mean()
                masked_loss_t = masked_loss_t.mean()
                masked_loss_v = masked_loss_v.mean()
                next_sentence_loss = next_sentence_loss.mean()

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    lr_this_step = args.learning_rate * warmup_linear(
                        global_step / num_train_optimization_steps,
                        args.warmup_proportion,
                    )
                    for param_group in optimizer.param_groups:
                        param_group["lr"] = lr_this_step

                optimizer.step()
                scheduler.step()  # step the scheduler after the optimizer (required ordering in PyTorch >= 1.1)
                optimizer.zero_grad()
                global_step += 1

                if default_gpu:
                    tbLogger.step_train_CC(
                        epochId,
                        iterId,
                        float(masked_loss_t),
                        float(masked_loss_v),
                        float(next_sentence_loss),
                        optimizer.param_groups[0]["lr"],
                        "TASK0",
                        "train",
                    )

            if (step % (20 * args.gradient_accumulation_steps) == 0
                    and step != 0 and default_gpu):
                tbLogger.showLossTrainCC()

        # Do the evaluation
        torch.set_grad_enabled(False)
        numBatches = len(validation_dataset)

        model.eval()
        for step, batch in enumerate(validation_dataset):
            image_ids = batch[-1]
            batch = tuple(
                t.cuda(device=device, non_blocking=True) for t in batch[:-1])

            input_ids, input_mask, segment_ids, lm_label_ids, is_next, image_feat, image_loc, image_target, image_label, image_mask = (
                batch)

            batch_size = input_ids.size(0)
            masked_loss_t, masked_loss_v, next_sentence_loss = model(
                input_ids,
                image_feat,
                image_loc,
                segment_ids,
                input_mask,
                image_mask,
                lm_label_ids,
                image_label,
                image_target,
                is_next,
            )

            masked_loss_v = masked_loss_v * args.img_weight
            loss = masked_loss_t + masked_loss_v + next_sentence_loss

            if n_gpu > 1:
                loss = loss.mean()
                masked_loss_t = masked_loss_t.mean()
                masked_loss_v = masked_loss_v.mean()
                next_sentence_loss = next_sentence_loss.mean()

            if default_gpu:
                tbLogger.step_val_CC(
                    epochId,
                    float(masked_loss_t),
                    float(masked_loss_v),
                    float(next_sentence_loss),
                    "TASK0",
                    batch_size,
                    "val",
                )
                sys.stdout.write("%d / %d \r" % (step, numBatches))
                sys.stdout.flush()

        if default_gpu:
            ave_score = tbLogger.showLossValCC()

        torch.set_grad_enabled(True)

        if default_gpu:
            # Save a trained model
            logger.info("** ** * Saving fine - tuned model ** ** * ")
            model_to_save = (
                model.module if hasattr(model, "module") else model
            )  # Only save the model it-self
            output_model_file = os.path.join(
                savePath, "pytorch_model_" + str(epochId) + ".bin")
            output_checkpoint = os.path.join(
                savePath, "pytorch_ckpt_" + str(epochId) + ".tar")
            torch.save(model_to_save.state_dict(), output_model_file)
            torch.save(
                {
                    "model_state_dict": model_to_save.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "global_step": global_step,
                },
                output_checkpoint,
            )

    if default_gpu:
        tbLogger.txt_close()
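
The fp16 branch above rescales the learning rate by hand with warmup_linear because FusedAdam, unlike BertAdam, applies no warmup schedule of its own. The exact helper lives in the surrounding training script, but as a rough sketch it is typically a linear warmup followed by a linear decay over the remaining training fraction:

def warmup_linear(x, warmup=0.002):
    """Illustrative linear warmup/decay schedule (x = training progress in [0, 1]).

    A sketch of the commonly used pytorch_pretrained_bert-style helper,
    not necessarily the exact definition used in the example above.
    """
    if x < warmup:
        return x / warmup
    return max(0.0, 1.0 - x)


# e.g. peak LR 5e-5, 10% warmup, evaluated at 30% of training:
print(5e-5 * warmup_linear(0.3, warmup=0.1))  # 3.5e-05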
Beispiel #17
0
def run_train(args, hparams):
    if args.seed is not None:
        print("Setting numpy random seed to {}...".format(args.seed))
        np.random.seed(args.seed)

    seed_from_numpy = np.random.randint(2147483648)
    print("Manual seed for pytorch:", seed_from_numpy)
    torch.manual_seed(seed_from_numpy)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed_from_numpy)

    # if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
    #     raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    # os.makedirs(args.output_dir, exist_ok=True)

    print("Initializing model...")
    load_path = args.load_path
    if load_path is not None:
        print(f"Loading parameters from {load_path}")
        info = torch_load(load_path)
        model = Zmodel.Jointmodel.from_spec(info['spec'], info['state_dict'])
        hparams = model.hparams
        Ptb_dataset = PTBDataset(hparams)
        Ptb_dataset.process_PTB(args)
    else:
        hparams.set_from_args(args)
        Ptb_dataset = PTBDataset(hparams)
        Ptb_dataset.process_PTB(args)
        model = Zmodel.Jointmodel(
            Ptb_dataset.tag_vocab,
            Ptb_dataset.word_vocab,
            Ptb_dataset.label_vocab,
            Ptb_dataset.char_vocab,
            Ptb_dataset.type_vocab,
            Ptb_dataset.srl_vocab,
            hparams,
        )
    print("Hyperparameters:")
    hparams.print()
    # tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    #train_examples = None
    # num_train_steps = None
    print("Loading Train Dataset", args.train_file)
    Ptb_dataset.rand_dataset()
    # print(model.tokenizer.tokenize("Federal Paper Board sells paper and wood products ."))
    #max_seq_length = model.bert_max_len
    train_dataset = BERTDataset(args.pre_wiki_line,
                                hparams,
                                Ptb_dataset,
                                args.train_file,
                                model.tokenizer,
                                seq_len=args.max_seq_length,
                                corpus_lines=None,
                                on_memory=args.on_memory)
    task_list = [
        'dev_synconst', 'dev_srlspan', 'dev_srldep', 'test_synconst',
        'test_srlspan', 'test_srldep', 'brown_srlspan', 'brown_srldep'
    ]
    evaluator = EvalManyTask(device=1,
                             hparams=hparams,
                             ptb_dataset=Ptb_dataset,
                             task_list=task_list,
                             bert_tokenizer=model.tokenizer,
                             seq_len=args.eval_seq_length,
                             eval_batch_size=args.eval_batch_size,
                             evalb_dir=args.evalb_dir,
                             model_path_base=args.save_model_path_base,
                             log_path="{}_log".format("models_log/" +
                                                      hparams.model_name))

    num_train_steps = int(
        len(train_dataset) / args.train_batch_size /
        args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    # model = BertForPreTraining.from_pretrained(args.bert_model)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps)
    if load_path is not None:
        optimizer.load_state_dict(info['optimizer'])

    global_step = args.pre_step
    pre_step = args.pre_step
    # wiki_line = 0
    # while train_dataset.wiki_id < wiki_line:
    #     train_dataset.file.__next__().strip()
    #     train_dataset.wiki_id+=1

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_steps)

    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
    else:
        #TODO: check if this works with the current data generator that reads from disk via file.__next__
        # (it does not return items by index)
        train_sampler = DistributedSampler(train_dataset)

    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)
    hparams.model_name = args.model_name
    print("This is ", hparams.model_name)
    start_time = time.time()

    def save_args(hparams):
        arg_path = "{}_log".format("models_log/" +
                                   hparams.model_name) + '.arg.json'
        kwargs = hparams.to_dict()
        json.dump({'kwargs': kwargs}, open(arg_path, 'w'), indent=4)

    save_args(hparams)
    # test_save_path = args.save_model_path_base + "_fortest"
    # torch.save({
    #     'spec': model_to_save.spec,
    #     'state_dict': model_to_save.state_dict(),
    #     'optimizer': optimizer.state_dict(),
    # }, test_save_path + ".pt")
    # evaluator.test_model_path = test_save_path

    cur_ptb_epoch = 0
    for epoch in trange(int(args.num_train_epochs), desc="Epoch"):

        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0

        #save_model_path, is_save = evaluator.eval_multitask(start_time, cur_ptb_epoch)

        epoch_start_time = time.time()
        for step, batch in enumerate(train_dataloader):
            model.train()
            input_ids, origin_ids, input_mask, word_start_mask, word_end_mask, segment_ids, perm_mask, target_mapping, lm_label_ids, lm_label_mask, is_next, \
            synconst_list, syndep_head_list, syndep_type_list, srlspan_str_list, srldep_str_list, is_ptb = batch
            # synconst_list, syndep_head_list, syndep_type_list , srlspan_str_list, srldep_str_list = gold_list
            dis_idx = torch.arange(len(input_ids))
            batch = dis_idx, input_ids, origin_ids, input_mask, word_start_mask, word_end_mask, segment_ids, perm_mask, target_mapping, lm_label_ids, lm_label_mask, is_next
            bert_data = tuple(t.to(device) for t in batch)
            sentences = []
            gold_syntree = []
            gold_srlspans = []
            gold_srldeps = []

            # for data_dict1 in dict1:
            for synconst, syndep_head_str, syndep_type_str, srlspan_str, srldep_str in zip(
                    synconst_list, syndep_head_list, syndep_type_list,
                    srlspan_str_list, srldep_str_list):

                syndep_head = json.loads(syndep_head_str)
                syndep_type = json.loads(syndep_type_str)
                syntree = trees.load_trees(
                    synconst, [[int(head) for head in syndep_head]],
                    [syndep_type],
                    strip_top=False)[0]
                sentences.append([(leaf.tag, leaf.word)
                                  for leaf in syntree.leaves()])

                gold_syntree.append(syntree.convert())

                srlspan = {}
                srlspan_dict = json.loads(srlspan_str)
                for pred_id, argus in srlspan_dict.items():
                    srlspan[int(pred_id)] = [(int(a[0]), int(a[1]), a[2])
                                             for a in argus]

                srldep_dict = json.loads(srldep_str)
                srldep = {}
                if str(-1) in srldep_dict:
                    srldep = None
                else:
                    for pred_id, argus in srldep_dict.items():
                        srldep[int(pred_id)] = [(int(a[0]), a[1])
                                                for a in argus]

                gold_srlspans.append(srlspan)
                gold_srldeps.append(srldep)

            if global_step < pre_step:
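                # When resuming, fast-forward through batches that were already
                # consumed in the previous run so the data stream stays aligned
                # with the restored global_step before parameter updates resume.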
                if global_step % 1000 == 0:
                    print("global_step:", global_step)
                    print("pre_step:", pre_step)
                    print("Wiki line:", train_dataset.wiki_line)
                    print("total-elapsed {} ".format(
                        format_elapsed(start_time)))
                global_step += 1
                cur_ptb_epoch = train_dataset.ptb_epoch
                continue

            bert_loss, task_loss = model(sentences=sentences,
                                         gold_trees=gold_syntree,
                                         gold_srlspans=gold_srlspans,
                                         gold_srldeps=gold_srldeps,
                                         bert_data=bert_data)
            if n_gpu > 1:
                bert_loss = bert_loss.sum()
                task_loss = task_loss.sum()
            loss = bert_loss + task_loss  #* 0.1
            loss = loss / len(synconst_list)
            bert_loss = bert_loss / len(synconst_list)
            task_loss = task_loss / len(synconst_list)

            total_loss = float(loss.data.cpu().numpy())

            if bert_loss > 0:
                bert_loss = float(bert_loss.data.cpu().numpy())
            if task_loss > 0:
                task_loss = float(task_loss.data.cpu().numpy())

            # grad_norm = torch.nn.utils.clip_grad_norm_(clippable_parameters, grad_clip_threshold)

            lr_this_step = args.learning_rate * warmup_linear(
                global_step / num_train_steps, args.warmup_proportion)
            print("epoch {:,} "
                  "ptb-epoch {:,} "
                  "batch {:,}/{:,} "
                  "processed {:,} "
                  "PTB line {:,} "
                  "Wiki line {:,} "
                  "total-loss {:.4f} "
                  "bert-loss {:.4f} "
                  "task-loss {:.4f} "
                  "lr_this_step {:.12f} "
                  "epoch-elapsed {} "
                  "total-elapsed {}".format(
                      epoch,
                      cur_ptb_epoch,
                      global_step,
                      int(np.ceil(len(train_dataset) / args.train_batch_size)),
                      (global_step + 1) * args.train_batch_size,
                      train_dataset.ptb_cur_line,
                      train_dataset.wiki_line,
                      total_loss,
                      bert_loss,
                      task_loss,
                      lr_this_step,
                      format_elapsed(epoch_start_time),
                      format_elapsed(start_time),
                  ))

            # if n_gpu > 1:
            #     loss = loss.mean() # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()
            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1
            if (step + 1) % args.gradient_accumulation_steps == 0:
                # modify learning rate with special warm up BERT uses
                lr_this_step = args.learning_rate * warmup_linear(
                    global_step / num_train_steps, args.warmup_proportion)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

                #if train_dataset.ptb_epoch > cur_ptb_epoch:
                if global_step % args.pre_step_tosave == 0:
                    cur_ptb_epoch = train_dataset.ptb_epoch

                    save_path = "{}_gstep{}_wiki{}_loss={:.4f}.pt".\
                        format(args.save_model_path_base, global_step, train_dataset.wiki_line, tatal_loss)
                    model_to_save = model.module if hasattr(
                        model, 'module') else model
                    torch.save(
                        {
                            'spec': model_to_save.spec,
                            'state_dict': model_to_save.state_dict(),
                            'optimizer': optimizer.state_dict(),
                        }, save_path)

                    # evaluator.test_model_path = test_save_path
                    #
                    # save_model_path, is_save = evaluator.eval_multitask(start_time, cur_ptb_epoch)
                    # if is_save:
                    #     print("Saving new best model to {}...".format(save_model_path))
                    #     torch.save({
                    #         'spec': model_to_save.spec,
                    #         'state_dict': model_to_save.state_dict(),
                    #         'optimizer': optimizer.state_dict(),
                    #     }, save_model_path + ".pt")

    # Save a trained model
    logger.info("** ** * Saving fine - tuned model ** ** * ")
    torch.save(
        {
            'spec': model_to_save.spec,
            'state_dict': model_to_save.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, args.save_model_path_base + ".pt")
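
The update block in this example follows the usual gradient-accumulation pattern: the loss is divided by gradient_accumulation_steps, gradients accumulate over that many mini-batches, and only then is the learning rate recomputed and the optimizer stepped. A minimal self-contained sketch of the same pattern (model, data, and hyperparameters below are placeholders, not those of the example above):

import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
accum_steps = 4  # effective batch size = accum_steps * per-batch size

for step in range(16):
    x, y = torch.randn(8, 10), torch.randn(8, 1)
    loss = F.mse_loss(model(x), y)
    (loss / accum_steps).backward()   # accumulate scaled gradients
    if (step + 1) % accum_steps == 0:
        optimizer.step()              # one parameter update per accum_steps batches
        optimizer.zero_grad()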
Beispiel #18
0
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--src_file",
                        default=None,
                        type=str,
                        help="The input data file name.")
    parser.add_argument("--topic_model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining topic model.")
    parser.add_argument("--topic_model_dict_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining topic model.")
    parser.add_argument("--tgt_file",
                        default=None,
                        type=str,
                        help="The output data file name.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")
    parser.add_argument('--topic_mode',
                        default=1,
                        type=float,
                        help="1:idea1 1.1:idea1_wo_theta 2:idea2 ")
    parser.add_argument('--topic_model',
                        default=False,
                        type=bool,  # note: argparse's type=bool treats any non-empty string as True
                        help="whether to train only the topic model")
    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=192,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument(
        "--train_batch_size",
        default=32,
        type=int,
        help="Total batch size for training.")  #batch_size = batch_size/n_gpus
    parser.add_argument("--eval_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=30,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()
    assert Path(
        args.model_recover_path).exists(), "--model_recover_path doesn't exist"

    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)
    print("args.local_rank", args.local_rank)
    print("args.no_cuda", args.no_cuda)
    if args.local_rank == -1 or args.no_cuda:  #-1 False
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")  #device = cuda
        n_gpu = torch.cuda.device_count()
        print("n_gpu_1", n_gpu)
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
        print("n_gpu_1", n_gpu)
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))
    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer(
    ) if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()
    if args.do_train:
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_a': args.max_len_a,
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_source_words=args.mask_source_words,
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                mode="s2s",
                has_oracle=args.has_sentence_oracle,
                num_qkv=args.num_qkv,
                s2s_special_token=args.s2s_special_token,
                s2s_add_segment=args.s2s_add_segment,
                s2s_share_segment=args.s2s_share_segment,
                pos_shift=args.pos_shift)
        ]
        file_oracle = None
        if args.has_sentence_oracle:
            file_oracle = os.path.join(args.data_dir, 'train.oracle')
        fn_src = os.path.join(args.data_dir,
                              args.src_file if args.src_file else 'train.src')
        fn_tgt = os.path.join(args.data_dir,
                              args.tgt_file if args.tgt_file else 'train.tgt')
        train_dataset = seq2seq_loader.Seq2SeqDataset(
            fn_src,
            fn_tgt,
            args.data_dir,
            args.topic_model_dict_path,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size)
    t_total = int(
        len(train_dataloader) * args.num_train_epochs /
        args.gradient_accumulation_steps)
    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2    ### type_vocab_size=6
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        _state_dict = {} if args.from_scratch else None
        unilm = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb)
        global_step = 0

    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.output_dir, "model.{0}.bin".format(recover_step)),
                                       map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total /
                                     args.num_train_epochs)
        elif args.model_recover_path:  # here is the entrance
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path,
                                       map_location='cpu')
            global_step = 0
        unilm = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb)
    # 1. Model initialization; the topic-model entry point is defined here
    gsm = GSM(train_dataset.vocabsize)
    gsm_checkpoint = torch.load(args.topic_model_recover_path)
    gsm.load_state_dict(gsm_checkpoint["net"])

    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        unilm.half()
        gsm.half()
        if args.fp32_embedding:
            unilm.bert.embeddings.word_embeddings.float()
            unilm.bert.embeddings.position_embeddings.float()
            unilm.bert.embeddings.token_type_embeddings.float()
    unilm.to(device)
    gsm.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        unilm = DDP(unilm,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        unilm = DataParallelImbalance(unilm)
        gsm = DataParallelImbalance(gsm)
    # Prepare optimizer
    total = 0
    param_optimizer = list(unilm.named_parameters())
    param_optimizer_topic = list(gsm.named_parameters())
    for name, parameters in unilm.named_parameters():
        if "idea" in name:
            if "11" in name and "idea2" in name:
                total += np.prod(parameters.size())
                # print(name, ':', parameters.size())
        else:
            total += np.prod(parameters.size())
            # print(name, ':', parameters.size())

    print("gsm have {} paramerters in total".format(
        sum(x.numel() for x in gsm.parameters())))
    print("Number of parameter: %.6fM" % (total / 1e6))
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    if not args.topic_model:
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01,
            'topic':
            False
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0,
            'topic':
            False
        }, {
            'params': [p for n, p in param_optimizer_topic],
            'weight_decay':
            0.0,
            'lr':
            1e-3,
            'topic':
            True
        }]
    else:
        optimizer_grouped_parameters = [{
            'params': [p for n, p in param_optimizer_topic],
            'weight_decay':
            0.0,
            'lr':
            1e-3,
            'topic':
            True
        }]
    # one parameter group uses weight decay, the other does not
    # print("optimizer_grouped_parameters", optimizer_grouped_parameters)
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)
    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step)),
                                   map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        unilm.train()
        gsm.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        print("000000", args.local_rank, start_epoch,
              int(args.num_train_epochs) + 1)
        topicloss = []
        unilmloss = []
        topicloss_lst = []
        unilmloss_lst = []
        for i_epoch in trange(start_epoch,
                              int(args.num_train_epochs) + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):

            loss_sum = 0.0
            ppx_sum = 0.0
            word_count = 0.0
            doc_count = 0.0
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            iter_bar = tqdm(train_dataloader,
                            desc='Iter (loss=X.XXX)',
                            disable=args.local_rank not in (-1, 0))
            for step, batch in enumerate(iter_bar):
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                if args.has_sentence_oracle:  #false
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, oracle_pos, oracle_weights, oracle_labels = batch
                else:  # this branch additionally unpacks bows (bag-of-words counts for the topic model)
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, bows = batch
                    oracle_pos, oracle_weights, oracle_labels = None, None, None
                p_x, mus, log_vars, theta, beta, topic_embedding = gsm(bows)
                if not args.topic_model:
                    loss_tuple = unilm(input_ids,
                                       theta,
                                       beta,
                                       topic_embedding,
                                       args.topic_mode,
                                       segment_ids,
                                       input_mask,
                                       lm_label_ids,
                                       is_next,
                                       masked_pos=masked_pos,
                                       masked_weights=masked_weights,
                                       task_idx=task_idx,
                                       masked_pos_2=oracle_pos,
                                       masked_weights_2=oracle_weights,
                                       masked_labels_2=oracle_labels,
                                       mask_qkv=mask_qkv)
                    masked_lm_loss, next_sentence_loss = loss_tuple

                ## topic loss
                logsoftmax = torch.log(p_x + 1e-10)
                rec_loss = -1.0 * torch.sum(
                    bows * logsoftmax
                )  # bows * logsoftmax has shape [batch_size, |V|]; torch.sum reduces over all dimensions here, but it could also sum over a single dimension only.
                rec_loss_per = -1.0 * torch.sum(bows * logsoftmax, dim=1)
                rec_loss_per = rec_loss_per.cpu().detach().numpy()

                kl_div = -0.5 * torch.sum(1 + log_vars - mus.pow(2) -
                                          log_vars.exp())
                loss_topic = rec_loss + kl_div
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    loss_topic = loss_topic.mean()
                    if not args.topic_model:
                        masked_lm_loss = masked_lm_loss.mean()
                        next_sentence_loss = next_sentence_loss.mean()
                if not args.topic_model:
                    loss_unilm = masked_lm_loss + next_sentence_loss

                # cal perplexity
                word_count_list = []
                loss_sum += loss_topic.item()
                for bow in bows:
                    word_num = torch.sum(bow).cpu().numpy()
                    word_count_list.append(word_num)
                    word_count += word_num

                word_count_np = np.array(word_count_list)
                doc_count += len(bows)
                ppx_sum += np.sum(np.true_divide(rec_loss_per, word_count_np))
                topicloss_lst.append(loss_topic.item() / len(bows))
                if not args.topic_model:
                    unilmloss_lst.append(loss_unilm.item())
                #topic_loss end
                if not args.topic_model:
                    loss = loss_unilm + loss_topic
                else:
                    loss = loss_topic
                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:  # =1
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step/t_total,
                                      args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            if not param_group['topic']:
                                param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
                if not args.topic_model:
                    iter_bar.set_description(
                        'Iter (loss_unilm=%5.3f),Iter (ppl=%5.3f)' %
                        (loss_unilm.item(),
                         np.sum(np.true_divide(rec_loss_per, word_count_np))))
                else:
                    iter_bar.set_description(
                        'Iter (loss_topic=%5.3f), (ppl=%5.3f)' %
                        (loss_topic.item(),
                         np.sum(np.true_divide(rec_loss_per, word_count_np))))
            #Save a trained model
            ppx_word = np.exp(loss_sum / word_count)
            ppx_document = np.exp(ppx_sum / doc_count)
            print("********")
            print("word_count", word_count)
            print("ppx_word", ppx_word)
            print("ppx_document", ppx_document)
            if (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                #save unilm model
                logger.info(
                    "** ** * Saving fine-tuned model and optimizer ** ** * ")
                unilm_model_to_save = unilm.module if hasattr(
                    unilm, 'module') else unilm  # Only save the model itself
                output_unilm_model_file = os.path.join(
                    args.output_dir, "unilm.{0}.bin".format(i_epoch))
                torch.save(unilm_model_to_save.state_dict(),
                           output_unilm_model_file)
                #save topic model
                logger.info(
                    "** ** * Saving topic model and optimizer ** ** * ")
                topic_model_to_save = gsm.module if hasattr(
                    gsm, 'module') else gsm  # Only save the model itself
                output_topic_model_file = os.path.join(
                    args.output_dir, "topic.{0}.ckpt".format(i_epoch))
                torch.save(topic_model_to_save.state_dict(),
                           output_topic_model_file)

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
        smth_pts = smooth_curve(topicloss_lst)
        # plt.plot(range(len(topicloss_lst)), topicloss_lst)
        plt.plot(range(len(smth_pts)), smth_pts)
        plt.xlabel('training steps')
        plt.title('Topic Model Train Loss')
        plt.savefig(args.output_dir + '/topic_loss.png')
        plt.cla()
        plt.plot(range(len(unilmloss_lst)), unilmloss_lst)
        plt.xlabel('training steps')
        plt.title('Unilm Train Loss')
        plt.savefig(args.output_dir + '/unilm_loss.png')
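
The topic-model part of the loss above is the standard VAE objective of a neural topic model like GSM: a bag-of-words reconstruction term plus the KL divergence between the approximate posterior N(mu, exp(log_var)) and a standard normal prior, with per-document perplexity derived from the per-word reconstruction loss. A small self-contained sketch with random tensors (names and shapes are illustrative, not taken from the GSM implementation):

import numpy as np
import torch

batch_size, vocab_size, n_topics = 4, 1000, 50
bows = torch.randint(0, 5, (batch_size, vocab_size)).float()            # bag-of-words counts
p_x = torch.softmax(torch.randn(batch_size, vocab_size), dim=-1)        # reconstructed word distribution
mus = torch.randn(batch_size, n_topics)
log_vars = torch.randn(batch_size, n_topics)

log_p_x = torch.log(p_x + 1e-10)
rec_loss = -torch.sum(bows * log_p_x)                                   # reconstruction term (whole batch)
rec_loss_per = -torch.sum(bows * log_p_x, dim=1)                        # per-document reconstruction
kl_div = -0.5 * torch.sum(1 + log_vars - mus.pow(2) - log_vars.exp())   # KL(q(z|x) || N(0, I))
loss_topic = rec_loss + kl_div

# Per-document perplexity: exp of the mean per-word reconstruction loss
word_counts = bows.sum(dim=1).numpy()
ppx_document = np.exp(np.mean(rec_loss_per.detach().numpy() / word_counts))
print(float(loss_topic), float(ppx_document))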
Beispiel #19
0
def prepare_model_and_optimizer(args, device):

    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    modeling.ACT2FN["bias_gelu"] = modeling.bias_gelu_training
    model = modeling.BertForPreTraining(config)

    if args.disable_weight_tying:
        import torch.nn as nn
        print ("WARNING!!!!!!! Disabling weight tying for this run")
        print ("BEFORE ", model.cls.predictions.decoder.weight is model.bert.embeddings.word_embeddings.weight)
        model.cls.predictions.decoder.weight = torch.nn.Parameter(model.cls.predictions.decoder.weight.clone().detach())
        print ("AFTER ", model.cls.predictions.decoder.weight is model.bert.embeddings.word_embeddings.weight)
        assert (model.cls.predictions.decoder.weight is model.bert.embeddings.word_embeddings.weight) == False

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [f for f in os.listdir(args.output_dir) if f.endswith(".pt")]
            args.resume_step = max([int(x.split('.pt')[0].split('_')[1].strip()) for x in model_names])

        global_step = args.resume_step if not args.init_checkpoint else 0

        if not args.init_checkpoint:
            checkpoint = torch.load(os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step)), map_location="cpu")
        else:
            checkpoint = torch.load(args.init_checkpoint, map_location="cpu")

        model.load_state_dict(checkpoint['model'], strict=False)
        
        if args.phase2 and not args.init_checkpoint:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']
    
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    optimizer = FusedAdam(optimizer_grouped_parameters,
                          lr=args.learning_rate)
    lr_scheduler = PolyWarmUpScheduler(optimizer, 
                                       warmup=args.warmup_proportion, 
                                       total_steps=args.max_steps,
                                       degree=1)
    if args.fp16:

        if args.loss_scale == 0:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic", cast_model_outputs=torch.float16)
        else:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=args.loss_scale, cast_model_outputs=torch.float16)
        amp._amp_state.loss_scalers[0]._loss_scale = args.init_loss_scale

    model.checkpoint_activations(args.checkpoint_activations)

    if args.resume_from_checkpoint:
        if args.phase2 or args.init_checkpoint:
            keys = list(checkpoint['optimizer']['state'].keys())
            #Override hyperparameters from previous checkpoint
            for key in keys:
                checkpoint['optimizer']['state'][key]['step'] = global_step
            for group_idx, _ in enumerate(checkpoint['optimizer']['param_groups']):
                checkpoint['optimizer']['param_groups'][group_idx]['step'] = global_step
                checkpoint['optimizer']['param_groups'][group_idx]['t_total'] = args.max_steps
                checkpoint['optimizer']['param_groups'][group_idx]['warmup'] = args.warmup_proportion
                checkpoint['optimizer']['param_groups'][group_idx]['lr'] = args.learning_rate
        optimizer.load_state_dict(checkpoint['optimizer'])  # , strict=False)

        # Restore AMP master parameters          
        if args.fp16:
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint['optimizer'])
            for param, saved_param in zip(amp.master_params(optimizer), checkpoint['master params']):
                param.data.copy_(saved_param.data)

    if args.local_rank != -1:
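        # Two paths: apex DDP overlaps gradient all-reduce with backward; with allreduce_post_accumulation
        # only the initial weights are broadcast from rank 0, and gradient averaging is presumably handled
        # after accumulation elsewhere (not shown in this snippet).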
        if not args.allreduce_post_accumulation:
            model = DDP(model, message_size=250000000, gradient_predivide_factor=get_world_size())
        else:
            flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,) )
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    criterion = BertPretrainingCriterion(config.vocab_size)


    if args.disable_weight_tying:
        # Sanity check that the untied decoder weight is registered with the optimizer
        print("SANITY CHECK OPTIMIZER: ", id(model.module.cls.predictions.decoder.weight) in [id(p) for p in optimizer.param_groups[0]['params']])
        assert id(model.module.cls.predictions.decoder.weight) in [id(p) for p in optimizer.param_groups[0]['params']]

    return model, optimizer, lr_scheduler, checkpoint, global_step, criterion
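A minimal, self-contained sketch (toy modules assumed for illustration, not the BERT model above) of the untying trick this example uses: re-wrapping a cloned, detached copy of the decoder weight gives it its own storage, so it stops sharing a parameter with the word embeddings.

import torch

emb = torch.nn.Embedding(10, 4)
decoder = torch.nn.Linear(4, 10, bias=False)

decoder.weight = emb.weight  # tied: both attributes refer to the same Parameter
assert decoder.weight is emb.weight

decoder.weight = torch.nn.Parameter(decoder.weight.clone().detach())  # untied
assert decoder.weight is not emb.weight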
Beispiel #20
0
def main():
    args = process_args()
    if args.loss_type == 'mlm':
        assert args.neg_num == 0 and args.multiple_neg == 0
    elif args.loss_type == 'nsp':
        assert int(
            args.bi_prob) == 1 and args.max_pred == 0 and args.neg_num > 0

    print('global_rank: {}, local rank: {}'.format(args.global_rank,
                                                   args.local_rank))

    # Input format: [CLS] img [SEP] hist [SEP_0] ques [SEP_1] ans [SEP]
    args.max_seq_length = args.len_vis_input + 2 + args.max_len_hist_ques + 2 + args.max_len_ans + 1

    args.mask_image_regions = args.vis_mask_prob > 0  # whether to mask out image regions
    args.dist_url = args.dist_url.replace('[PT_OUTPUT_DIR]', args.output_dir)

    # arguments inspection
    assert args.enable_butd, 'only support region attn! featmap attn deprecated'

    if args.enable_butd:
        if args.visdial_v == '1.0':
            assert (args.len_vis_input == 36)
        elif args.visdial_v == '0.9':
            assert (args.len_vis_input == 100)
            args.region_bbox_file = os.path.join(args.image_root,
                                                 args.region_bbox_file)
            args.region_det_file_prefix = os.path.join(
                args.image_root,
                args.region_det_file_prefix) if args.dataset in (
                    'cc', 'coco') and args.region_det_file_prefix != '' else ''

    # output config
    os.makedirs(args.output_dir, exist_ok=True)
    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)

    logging.basicConfig(
        filename=os.path.join(args.output_dir, args.log_file),
        filemode='w',
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger = logging.getLogger(__name__)
    stdout = True
    if stdout:
        ch = logging.StreamHandler(sys.stdout)
        ch.setFormatter(
            logging.Formatter(
                '%(asctime)s - %(levelname)s - %(name)s -   %(message)s'))
        ch.setLevel(logging.INFO)
        logger.addHandler(ch)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend, which takes care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl',
                                             init_method=args.dist_url,
                                             world_size=args.world_size,
                                             rank=args.global_rank)
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # plotting loss, optional
    if args.enable_visdom:
        import visdom
        vis = visdom.Visdom(port=args.visdom_port, env=args.output_dir)
        vis_window = {'iter': None, 'score': None}

    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model,
        do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir +
        '/.pretrained_model_{}'.format(args.global_rank))
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
    assert args.do_train
    bi_uni_pipeline = [
        Preprocess4TrainVisdialRankLoss(
            args.max_pred,
            args.mask_prob,
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'len_vis_input': args.len_vis_input,
                'max_len_hist_ques': args.max_len_hist_ques,
                'max_len_ans': args.max_len_ans
            },
            mask_image_regions=args.mask_image_regions,
            mode="s2s",
            vis_mask_prob=args.vis_mask_prob,
            region_bbox_file=args.region_bbox_file,
            region_det_file_prefix=args.region_det_file_prefix,
            image_features_hdfpath=args.image_features_hdfpath,
            visdial_v=args.visdial_v,
            pad_hist=args.pad_hist,
            finetune=args.finetune,
            only_mask_ans=args.only_mask_ans,
            float_nsp_label=args.float_nsp_label),
        Preprocess4TrainVisdialRankLoss(
            args.max_pred,
            args.mask_prob,
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'len_vis_input': args.len_vis_input,
                'max_len_hist_ques': args.max_len_hist_ques,
                'max_len_ans': args.max_len_ans
            },
            mask_image_regions=args.mask_image_regions,
            mode="bi",
            vis_mask_prob=args.vis_mask_prob,
            region_bbox_file=args.region_bbox_file,
            region_det_file_prefix=args.region_det_file_prefix,
            image_features_hdfpath=args.image_features_hdfpath,
            visdial_v=args.visdial_v,
            pad_hist=args.pad_hist,
            finetune=args.finetune,
            only_mask_ans=args.only_mask_ans,
            float_nsp_label=args.float_nsp_label)
    ]

    train_dataset = VisdialDatasetRelRankLoss(args.train_src_file,
                                              args.val_src_file,
                                              args.train_rel_file,
                                              args.val_rel_file,
                                              args.train_batch_size,
                                              data_tokenizer,
                                              use_num_imgs=args.use_num_imgs,
                                              bi_uni_pipeline=bi_uni_pipeline,
                                              s2s_prob=args.s2s_prob,
                                              bi_prob=args.bi_prob,
                                              is_train=args.do_train,
                                              neg_num=args.neg_num,
                                              inc_gt_rel=args.inc_gt_rel,
                                              inc_full_hist=args.inc_full_hist)

    if args.world_size == 1:
        train_sampler = RandomSampler(train_dataset, replacement=False)
    else:
        train_sampler = DistributedSampler(train_dataset)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        sampler=train_sampler,
        num_workers=args.num_workers,
        collate_fn=batch_list_to_batch_tensors_rank_loss,
        pin_memory=True)

    # note: args.train_batch_size has already been divided by args.gradient_accumulation_steps
    t_total = int(
        len(train_dataloader) * args.num_train_epochs * 1. /
        args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    relax_projection = 4 if args.relax_projection else 0
    task_idx_proj = 3 if args.tasks == 'img2txt' else 0
    mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[PAD]"])  # index in BERT vocab: 103, 102, 0

    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        assert not args.scst, 'must init from maximum likelihood training'
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size,
            relax_projection=relax_projection,
            config_path=args.config_path,
            task_idx=task_idx_proj,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            cache_dir=args.output_dir +
            '/.pretrained_model_{}'.format(args.global_rank),
            drop_prob=args.drop_prob,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input,
            visdial_v=args.visdial_v,
            loss_type=args.loss_type,
            float_nsp_label=args.float_nsp_label,
            rank_loss=args.rank_loss)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(
                os.path.join(args.output_dir,
                             "model.{0}.bin".format(recover_step)))
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total * 1. /
                                     args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path)
            global_step = 0

        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size,
            relax_projection=relax_projection,
            config_path=args.config_path,
            task_idx=task_idx_proj,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            cache_dir=args.output_dir +
            '/.pretrained_model_{}'.format(args.global_rank),
            drop_prob=args.drop_prob,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input,
            visdial_v=args.visdial_v,
            loss_type=args.loss_type,
            float_nsp_label=args.float_nsp_label,
            rank_loss=args.rank_loss)

        del model_recover
        torch.cuda.empty_cache()

    if args.fp16:
        model.half()
        # cnn.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    # cnn.to(device)
    if args.local_rank != -1:
        try:
            # from apex.parallel import DistributedDataParallel as DDP
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
        # cnn = DDP(cnn)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)
        # cnn = DataParallelImbalance(cnn)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             schedule=args.sche_mode,
                             t_total=t_total)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(
            os.path.join(args.output_dir,
                         "optim.{0}.bin".format(recover_step)))
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)
        logger.info("  Loader length = %d", len(train_dataloader))

        model.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        logger.info("Begin training from epoch = %d", start_epoch)
        t0 = time.time()
        for i_epoch in trange(start_epoch,
                              args.num_train_epochs + 1,
                              desc="Epoch"):
            if args.multiple_neg and i_epoch != 1:
                train_dataset = VisdialDatasetRelRankLoss(
                    args.train_src_file,
                    args.val_src_file,
                    args.train_rel_file,
                    args.val_rel_file,
                    args.train_batch_size,
                    data_tokenizer,
                    use_num_imgs=args.use_num_imgs,
                    bi_uni_pipeline=bi_uni_pipeline,
                    s2s_prob=args.s2s_prob,
                    bi_prob=args.bi_prob,
                    is_train=args.do_train,
                    neg_num=args.neg_num,
                    inc_gt_rel=args.inc_gt_rel,
                    inc_full_hist=args.inc_full_hist,
                    add_val=args.add_val)

                if args.world_size == 1:
                    train_sampler = RandomSampler(train_dataset,
                                                  replacement=False)
                else:
                    train_sampler = DistributedSampler(train_dataset)

                train_dataloader = torch.utils.data.DataLoader(
                    train_dataset,
                    batch_size=args.train_batch_size,
                    sampler=train_sampler,
                    num_workers=args.num_workers,
                    collate_fn=batch_list_to_batch_tensors_rank_loss,
                    pin_memory=True)
            if args.local_rank >= 0:
                train_sampler.set_epoch(i_epoch - 1)
            iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)')
            nbatches = len(train_dataloader)
            losses = []
            pretext_loss = []
            mlm_losses = []
            nsp_losses = []
            zero_batch_cnt = 0
            for step, batch in enumerate(iter_bar):
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, is_next, \
                task_idx, vis_masked_pos, img, vis_pe = batch

                if args.fp16:
                    img = img.half()
                    vis_pe = vis_pe.half()

                if args.enable_butd:
                    conv_feats = img.data  # Bx100x2048
                    vis_pe = vis_pe.data

                loss_tuple = model(conv_feats,
                                   vis_pe,
                                   input_ids,
                                   segment_ids,
                                   input_mask,
                                   lm_label_ids,
                                   is_next,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   vis_masked_pos=vis_masked_pos,
                                   mask_image_regions=args.mask_image_regions,
                                   drop_worst_ratio=args.max_drop_worst_ratio
                                   if i_epoch > args.drop_after else 0)

                # disable pretext_loss_deprecated for now
                masked_lm_loss, pretext_loss_deprecated, nsp_loss = loss_tuple
                if n_gpu > 1:  # mean() to average on multi-gpu. For dist, this is done through gradient addition.
                    masked_lm_loss = masked_lm_loss.mean()
                    pretext_loss_deprecated = pretext_loss_deprecated.mean()
                    nsp_loss = nsp_loss.mean()
                loss = masked_lm_loss + pretext_loss_deprecated + nsp_loss
                # if loss.item() == 0:
                #     zero_batch_cnt += 1
                #     continue

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())
                losses.append(loss.item())
                mlm_losses.append(masked_lm_loss.item())
                pretext_loss.append(pretext_loss_deprecated.item())
                nsp_losses.append(nsp_loss.item())

                if step % max(1, nbatches // 10) == 0:
                    logger.info(
                        "Epoch {}, Iter {}, Loss {:.2f}, MLM {:.2f}, NSP {:.2f}, Elapse time {:.2f}\n"
                        .format(i_epoch, step, np.mean(losses),
                                np.mean(mlm_losses), np.mean(nsp_losses),
                                time.time() - t0))

                if args.enable_visdom:
                    if vis_window['iter'] is None:
                        vis_window['iter'] = vis.line(
                            X=np.tile(
                                np.arange((i_epoch - 1) * nbatches + step,
                                          (i_epoch - 1) * nbatches + step + 1),
                                (1, 1)).T,
                            Y=np.column_stack(
                                (np.asarray([np.mean(losses)]), )),
                            opts=dict(title='Training Loss',
                                      xlabel='Training Iteration',
                                      ylabel='Loss',
                                      legend=['total']))
                    else:
                        vis.line(X=np.tile(
                            np.arange((i_epoch - 1) * nbatches + step,
                                      (i_epoch - 1) * nbatches + step + 1),
                            (1, 1)).T,
                                 Y=np.column_stack(
                                     (np.asarray([np.mean(losses)]), )),
                                 opts=dict(title='Training Loss',
                                           xlabel='Training Iteration',
                                           ylabel='Loss',
                                           legend=['total']),
                                 win=vis_window['iter'],
                                 update='append')

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                                   warmup_linear(global_step / t_total,
                                                 args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            print("\nFinish one epoch, %d/%d is zero loss batch" %
                  (zero_batch_cnt, nbatches))
            # Save a trained model
            logger.info(
                "** ** * Saving fine-tuned model and optimizer ** ** * ")
            model_to_save = model.module if hasattr(
                model, 'module') else model  # Only save the model itself
            output_model_file = os.path.join(
                args.output_dir,
                "model.%d.%.3f.bin" % (i_epoch, np.mean(losses)))
            output_optim_file = os.path.join(args.output_dir,
                                             "optim.{0}.bin".format(i_epoch))
            if args.global_rank in (
                    -1, 0):  # save model if the first device or no dist
                torch.save(
                    copy.deepcopy(model_to_save).cpu().state_dict(),
                    output_model_file)
                logger.info("Save model to %s", output_model_file)
                # torch.save(optimizer.state_dict(), output_optim_file) # disabled for now; need to sanitize state and ship everything back to cpu
            logger.info(
                "Finish training epoch %d, avg loss: %.2f and takes %.2f seconds"
                % (i_epoch, np.mean(losses), time.time() - t0))
            logger.info("***** CUDA.empty_cache() *****")
            torch.cuda.empty_cache()

            if args.world_size > 1:
                torch.distributed.barrier()
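A minimal sketch of the manual schedule applied at each optimizer step above. This warmup_linear is a local stand-in written for illustration (the snippet imports its own from pytorch_pretrained_bert, whose decay shape may differ), and the numbers in the usage comment assume base_lr=5e-5, t_total=1000, warmup_proportion=0.1.

def warmup_linear(x, warmup=0.002):
    # x is the fraction of training completed (global_step / t_total)
    if x < warmup:
        return x / warmup        # linear ramp-up to the base learning rate
    return max(1.0 - x, 0.0)     # linear decay back towards zero

def lr_this_step(global_step, t_total, base_lr, warmup_proportion):
    return base_lr * warmup_linear(global_step / t_total, warmup_proportion)

# e.g. lr_this_step(50, 1000, 5e-5, 0.1)  -> 2.5e-05 (halfway through warm-up)
#      lr_this_step(500, 1000, 5e-5, 0.1) -> 2.5e-05 (halfway through the decay)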
Beispiel #21
0
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--src_file",
                        default=None,
                        type=str,
                        help="The input data file name.")
    parser.add_argument("--tgt_file",
                        default=None,
                        type=str,
                        help="The output data file name.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")

    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--local_debug",
                        action='store_true',
                        help="Whether to run training.")

    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()
    assert Path(
        args.model_recover_path).exists(), "--model_recover_path doesn't exist"

    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend, which takes care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    if args.do_train:
        print("Loading Train Dataset", args.data_dir)
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_a': args.max_len_a,
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_source_words=args.mask_source_words,
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                mode="s2s",
                has_oracle=args.has_sentence_oracle,
                num_qkv=args.num_qkv,
                s2s_special_token=args.s2s_special_token,
                s2s_add_segment=args.s2s_add_segment,
                s2s_share_segment=args.s2s_share_segment,
                pos_shift=args.pos_shift)
        ]
        file_oracle = None
        if args.has_sentence_oracle:
            file_oracle = os.path.join(args.data_dir, 'train.oracle')
        fn_src = os.path.join(args.data_dir,
                              args.src_file if args.src_file else 'train.src')
        fn_tgt = os.path.join(args.data_dir,
                              args.tgt_file if args.tgt_file else 'train.tgt')
        train_dataset = seq2seq_loader.Seq2SeqDataset(
            fn_src,
            fn_tgt,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=bi_uni_pipeline,
            corpus_preprocessors=corpus_preprocessors)
        train_dataset.initial()
        print(len(train_dataset.ex_list))
        print(train_dataset.batch_size)

        # assert 1==0

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

        # c = 0
        # for i_epoch in trange(0, int(args.num_train_epochs)+1, desc="Epoch", disable=args.local_rank not in (-1, 0)):
        #     if args.local_rank != -1:
        #         train_sampler.set_epoch(i_epoch)
        #     iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)',
        #                     disable=args.local_rank not in (-1, 0))
        #     for step, batch in enumerate(iter_bar):
        #         batch = [
        #             t.to(device) if t is not None else None for t in batch]
        #         if args.has_sentence_oracle:
        #             input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, sop_label, oracle_pos, oracle_weights, oracle_labels = batch
        #         else:
        #             input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, sop_label = batch
        #             oracle_pos, oracle_weights, oracle_labels = None, None, None
        #         c += input_ids.shape[0]
        #
        #         # print(input_ids)
        #         #
        #         # print(input_ids.shape)
        #         # print(segment_ids)
        #         # print(segment_ids.shape)
        #         # print(is_next)
        #         # print(task_idx)
        #         # print(sop_label)
        #         # print(task_idx.shape)
        #         # for i in range(input_mask.shape[0]):
        #         #     print(input_mask[i])
        #     print(c)
        #     print(train_dataset.c)

    # note: args.train_batch_size has already been divided by args.gradient_accumulation_steps
    # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size))
    t_total = int(
        len(train_dataloader) * args.num_train_epochs /
        args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            local_debug=args.local_debug)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.output_dir, "model.{0}.bin".format(recover_step)),
                                       map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total /
                                     args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path,
                                       map_location='cpu')
            global_step = 0
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            local_debug=args.local_debug)
    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        model.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step)),
                                   map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        model.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch,
                              int(args.num_train_epochs) + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            iter_bar = tqdm(train_dataloader,
                            desc='Iter (loss=X.XXX)',
                            disable=args.local_rank not in (-1, 0))
            for step, batch in enumerate(iter_bar):
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                if args.has_sentence_oracle:
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, sop_label, oracle_pos, oracle_weights, oracle_labels = batch
                else:
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, sop_label = batch
                    print(sop_label)
                    print(task_idx)
                    oracle_pos, oracle_weights, oracle_labels = None, None, None
                # loss_tuple = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next,
                #                    masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx,
                #                    masked_pos_2=oracle_pos, masked_weights_2=oracle_weights,
                #                    masked_labels_2=oracle_labels, mask_qkv=mask_qkv)
                loss_tuple = model(input_ids,
                                   segment_ids,
                                   input_mask,
                                   lm_label_ids,
                                   sop_label,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   masked_pos_2=oracle_pos,
                                   masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels,
                                   mask_qkv=mask_qkv)
                masked_lm_loss, next_sentence_loss = loss_tuple
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    # loss = loss.mean()
                    masked_lm_loss = masked_lm_loss.mean()
                    next_sentence_loss = next_sentence_loss.mean()
                print('mask_lm_loss {}'.format(masked_lm_loss))
                print('next_sentence_loss {}'.format(next_sentence_loss))
                print('----------------------------------------------')
                loss = masked_lm_loss + next_sentence_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step/t_total,
                                      args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Save a trained model
            if (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                logger.info(
                    "** ** * Saving fine-tuned model and optimizer ** ** * ")
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model itself
                output_model_file = os.path.join(
                    args.output_dir, "model.{0}.bin".format(i_epoch))
                torch.save(model_to_save.state_dict(), output_model_file)
                output_optim_file = os.path.join(
                    args.output_dir, "optim.{0}.bin".format(i_epoch))
                torch.save(optimizer.state_dict(), output_optim_file)

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
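A compact sketch of the decay/no-decay grouping that these examples build inline; group_parameters is a hypothetical helper, model can be any torch.nn.Module, and the substring list mirrors the no_decay list used above.

import torch

def group_parameters(model, weight_decay=0.01,
                     no_decay=('bias', 'LayerNorm.bias', 'LayerNorm.weight')):
    named = list(model.named_parameters())
    return [
        {'params': [p for n, p in named if not any(nd in n for nd in no_decay)],
         'weight_decay': weight_decay},
        {'params': [p for n, p in named if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]

# usage: optimizer = torch.optim.AdamW(group_parameters(model), lr=5e-5)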
Beispiel #22
0
class QA():
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.setup_logger()
        self.setup_gpu()
        self.load_data()
        self.prepare_train()
        self.setup_model()

    def setup_logger(self):
        self.log = Logger()
        self.log.open(os.path.join(self.config.checkpoint_folder, "train_log.txt"), mode='a+')

    def setup_gpu(self):
        # confirm the device which can be either cpu or gpu
        self.config.use_gpu = torch.cuda.is_available()
        self.num_device = torch.cuda.device_count()
        if self.config.use_gpu:
            self.config.device = 'cuda'
            if self.num_device <= 1:
                self.config.data_parallel = False
            elif self.config.data_parallel:
                torch.multiprocessing.set_start_method('spawn', force=True)
        else:
            self.config.device = 'cpu'
            self.config.data_parallel = False

    def load_data(self):
        self.log.write('\nLoading data...')

        get_train_val_split(data_path=self.config.data_path,
                            save_path=self.config.save_path,
                            n_splits=self.config.n_splits,
                            seed=self.config.seed,
                            split=self.config.split)

        self.test_data_loader, self.tokenizer = get_test_loader(data_path=self.config.data_path,
                                                max_seq_length=self.config.max_seq_length,
                                                model_type=self.config.model_type,
                                                batch_size=self.config.val_batch_size,
                                                num_workers=self.config.num_workers)

        self.train_data_loader, self.val_data_loader, _ = get_train_val_loaders(data_path=self.config.data_path,
                                                                  seed=self.config.seed,
                                                                  fold=self.config.fold,
                                                                  max_seq_length=self.config.max_seq_length,
                                                                  model_type=self.config.model_type,
                                                                  batch_size=self.config.batch_size,
                                                                  val_batch_size=self.config.val_batch_size,
                                                                  num_workers=self.config.num_workers,
                                                                  Datasampler=self.config.Datasampler)

    def prepare_train(self):
        # preparation for training
        self.step = 0
        self.epoch = 0
        self.finished = False
        self.valid_epoch = 0
        self.train_loss, self.valid_loss, self.valid_metric_optimal = float('-inf'), float('-inf'), float('-inf')
        self.writer = SummaryWriter()
        ############################################################################### eval setting
        self.eval_step = int(len(self.train_data_loader) * self.config.saving_rate)
        self.log_step = int(len(self.train_data_loader) * self.config.progress_rate)
        self.eval_count = 0
        self.count = 0

    def pick_model(self):
        # for switching model
        self.model = TweetBert(model_type=self.config.model_type,
                               hidden_layers=self.config.hidden_layers).to(self.config.device)

        if self.config.load_pretrain:
            checkpoint_to_load = torch.load(self.config.checkpoint_pretrain, map_location=self.config.device)
            model_state_dict = checkpoint_to_load

            if self.config.data_parallel:
                state_dict = self.model.model.state_dict()
            else:
                state_dict = self.model.state_dict()

            keys = list(state_dict.keys())

            for key in keys:
                if any(s in key for s in (self.config.skip_layers + ["qa_classifier.weight", "qa_classifier.bias"])):
                    continue
                try:
                    state_dict[key] = model_state_dict[key]
                except KeyError:
                    print("Key not found in pretrained checkpoint:", key)

            if self.config.data_parallel:
                self.model.model.load_state_dict(state_dict)
            else:
                self.model.load_state_dict(state_dict)

        # offsets for context: number of leading positions stripped from the logits before decoding
        if self.config.model_type in ("roberta-base", "roberta-large", "roberta-base-squad"):
            self.offsets = 4
        elif self.config.model_type in ("albert-base-v2", "albert-large-v2", "albert-xlarge-v2"):
            self.offsets = 3
        elif self.config.model_type in ("xlnet-base-cased", "xlnet-large-cased"):
            self.offsets = 2
        elif self.config.model_type in ("bert-base-uncased", "bert-large-uncased", "bert-base-cased",
                                        "bert-large-cased", "electra-base", "electra-large"):
            self.offsets = 3
        else:
            raise NotImplementedError


    def differential_lr(self):

        param_optimizer = list(self.model.named_parameters())

        def is_backbone(n):
            prefix = "bert"
            return prefix in n

        def is_cross_attention(n):
            cross_attention = "cross_attention"
            return cross_attention in n

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        self.optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay) and is_backbone(n)],
             'lr': self.config.min_lr,
             'weight_decay': self.config.weight_decay},
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay) and is_cross_attention(n)],
             'lr': self.config.max_lr,
             'weight_decay': self.config.weight_decay},
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay) and not is_backbone(n)
                        and not is_cross_attention(n)],
             'lr': self.config.lr,
             'weight_decay': self.config.weight_decay},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay) and is_backbone(n)],
             'lr': self.config.min_lr,
             'weight_decay': 0.0},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay) and is_cross_attention(n)],
             'lr': self.config.max_lr,
             'weight_decay': 0.0},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay) and not is_backbone(n)
                        and not is_cross_attention(n)],
             'lr': self.config.lr,
             'weight_decay': 0.0}
        ]
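        # Grouping summary: parameter names containing "bert" (the backbone) use
        # config.min_lr, names containing "cross_attention" use config.max_lr, and the
        # remaining head parameters use config.lr; each of the three groups is split
        # again so that bias/LayerNorm parameters get weight_decay 0.0.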
        # self.optimizer_grouped_parameters = [
        #     {'params': [p for n, p in param_optimizer],
        #      'lr': self.config.min_lr,
        #      'weight_decay': 0}
        # ]

    def prepare_optimizer(self):

        # differential lr for each sub module first
        self.differential_lr()

        # optimizer
        if self.config.optimizer_name == "Adam":
            self.optimizer = torch.optim.Adam(self.optimizer_grouped_parameters, eps=self.config.adam_epsilon)
        elif self.config.optimizer_name == "Ranger":
            self.optimizer = Ranger(self.optimizer_grouped_parameters)
        elif self.config.optimizer_name == "AdamW":
            self.optimizer = AdamW(self.optimizer_grouped_parameters,
                                   eps=self.config.adam_epsilon,
                                   betas=(0.9, 0.999))
        elif self.config.optimizer_name == "FusedAdam":
            self.optimizer = FusedAdam(self.optimizer_grouped_parameters,
                                       bias_correction=False)
        else:
            raise NotImplementedError

        # lr scheduler
        if self.config.lr_scheduler_name == "WarmupCosineAnealing":
            num_train_optimization_steps = self.config.num_epoch * len(self.train_data_loader) \
                                           // self.config.accumulation_steps
            self.scheduler = get_cosine_schedule_with_warmup(self.optimizer,
                                                             num_warmup_steps=self.config.warmup_steps,
                                                             num_training_steps=num_train_optimization_steps)
            self.lr_scheduler_each_iter = True
        elif self.config.lr_scheduler_name == "WarmRestart":
            self.scheduler = WarmRestart(self.optimizer, T_max=5, T_mult=1, eta_min=1e-6)
            self.lr_scheduler_each_iter = False
        elif self.config.lr_scheduler_name == "WarmupLinear":
            num_train_optimization_steps = self.config.num_epoch * len(self.train_data_loader) \
                                           // self.config.accumulation_steps
            self.scheduler = get_linear_schedule_with_warmup(self.optimizer,
                                                             num_warmup_steps=self.config.warmup_steps,
                                                             num_training_steps=num_train_optimization_steps)
            self.lr_scheduler_each_iter = True
        elif self.config.lr_scheduler_name == "ReduceLROnPlateau":
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='max', factor=0.6,
                                                                        patience=1, min_lr=1e-7)
            self.lr_scheduler_each_iter = False
        elif self.config.lr_scheduler_name == "WarmupConstant":
            self.scheduler = get_constant_schedule_with_warmup(self.optimizer,
                                                             num_warmup_steps=self.config.warmup_steps)
            self.lr_scheduler_each_iter = True
        else:
            raise NotImplementedError

        # lr scheduler step for checkpoints
        if self.lr_scheduler_each_iter:
            self.scheduler.step(self.step)
        else:
            self.scheduler.step(self.epoch)
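        # Note: when lr_scheduler_each_iter is True the scheduler is stepped once per
        # optimizer update inside train_op(); otherwise it is stepped once per epoch,
        # except for ReduceLROnPlateau, which is stepped on the validation metric in
        # evaluate_op().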

    def prepare_apex(self):
        self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O1")

    def load_check_point(self):
        self.log.write('Model loaded as {}.'.format(self.config.load_point))
        checkpoint_to_load = torch.load(self.config.load_point, map_location=self.config.device)
        self.step = checkpoint_to_load['step']
        self.epoch = checkpoint_to_load['epoch']

        model_state_dict = checkpoint_to_load['model']
        if self.config.load_from_load_from_data_parallel:
            # model_state_dict = {k[7:]: v for k, v in model_state_dict.items()}
            # "module.model"
            model_state_dict = {k[13:]: v for k, v in model_state_dict.items()}

        if self.config.data_parallel:
            state_dict = self.model.model.state_dict()
        else:
            state_dict = self.model.state_dict()

        keys = list(state_dict.keys())

        for key in keys:
            if any(s in key for s in self.config.skip_layers):
                continue
            try:
                state_dict[key] = model_state_dict[key]
            except KeyError:
                print("Key not found in checkpoint:", key)

        if self.config.data_parallel:
            self.model.model.load_state_dict(state_dict)
        else:
            self.model.load_state_dict(state_dict)

        if self.config.load_optimizer:
            self.optimizer.load_state_dict(checkpoint_to_load['optimizer'])

    def save_check_point(self):
        # save model, optimizer, and everything required to keep
        checkpoint_to_save = {
            'step': self.step,
            'epoch': self.epoch,
            'model': self.model.state_dict(),
            # 'optimizer': self.optimizer.state_dict()
            }

        save_path = self.config.save_point.format(self.step, self.epoch)
        torch.save(checkpoint_to_save, save_path)
        self.log.write('Model saved as {}.'.format(save_path))

    def setup_model(self):
        self.pick_model()

        if self.config.data_parallel:
            self.prepare_optimizer()

            if self.config.apex:
                self.prepare_apex()

            if self.config.reuse_model:
                self.load_check_point()

            self.model = torch.nn.DataParallel(self.model)

        else:
            if self.config.reuse_model:
                self.load_check_point()

            self.prepare_optimizer()

            if self.config.apex:
                self.prepare_apex()

    def count_parameters(self):
        # get total size of trainable parameters
        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)

    def show_info(self):
        # show general information before training
        self.log.write('\n*General Setting*')
        self.log.write('\nseed: {}'.format(self.config.seed))
        self.log.write('\nmodel: {}'.format(self.config.model_name))
        self.log.write('\ntrainable parameters:{:,.0f}'.format(self.count_parameters()))
        self.log.write("\nmodel's state_dict:")
        self.log.write('\ndevice: {}'.format(self.config.device))
        self.log.write('\nuse gpu: {}'.format(self.config.use_gpu))
        self.log.write('\ndevice num: {}'.format(self.num_device))
        self.log.write('\noptimizer: {}'.format(self.optimizer))
        self.log.write('\nreuse model: {}'.format(self.config.reuse_model))
        self.log.write('\nadversarial training: {}'.format(self.config.adversarial))
        if self.config.reuse_model:
            self.log.write('\nModel restored from {}.'.format(self.config.load_point))
        self.log.write('\n')

    def train_op(self):
        self.show_info()
        self.log.write('** start training here! **\n')
        self.log.write('   batch_size=%d,  accumulation_steps=%d\n' % (self.config.batch_size,
                                                                       self.config.accumulation_steps))
        self.log.write('   experiment  = %s\n' % str(__file__.split('/')[-2:]))

        while self.epoch <= self.config.num_epoch:

            self.train_metrics = []
            self.train_metrics_no_postprocessing = []
            self.train_metrics_postprocessing = []
            self.train_ans_acc = []
            self.train_noise_acc = []

            # update lr and start from start_epoch
            if (self.epoch >= 1) and (not self.lr_scheduler_each_iter) \
                    and (self.config.lr_scheduler_name != "ReduceLROnPlateau"):
                self.scheduler.step()

            self.log.write("Epoch%s\n" % self.epoch)
            self.log.write('\n')

            sum_train_loss = np.zeros_like(self.train_loss)
            sum_train = np.zeros_like(self.train_loss)

            # init optimizer
            torch.cuda.empty_cache()
            self.model.zero_grad()

            for tr_batch_i, (
                    all_input_ids, all_attention_masks, all_token_type_ids,
                    all_start_positions, all_end_positions,
                    all_onehot_sentiment_type, all_onehot_ans_type, all_onehot_noise_type,
                    all_orig_tweet, all_orig_tweet_with_extra_space, all_orig_selected,
                    all_sentiment, all_ans, all_noise,
                    all_offsets_token_level, all_offsets_word_level) in enumerate(self.train_data_loader):

                rate = 0
                for param_group in self.optimizer.param_groups:
                    rate += param_group['lr'] / len(self.optimizer.param_groups)

                # set model training mode
                self.model.train()

                # set input to cuda mode
                all_input_ids = all_input_ids.to(self.config.device)
                all_attention_masks = all_attention_masks.to(self.config.device)
                all_token_type_ids = all_token_type_ids.to(self.config.device)
                all_start_positions = all_start_positions.to(self.config.device)
                all_end_positions = all_end_positions.to(self.config.device)
                all_onehot_sentiment_type = all_onehot_sentiment_type.to(self.config.device)
                all_onehot_ans_type = all_onehot_ans_type.to(self.config.device)
                all_onehot_noise_type = all_onehot_noise_type.to(self.config.device)

                sentiment = all_sentiment
                sentiment_weight = np.array([self.config.sentiment_weight_map[sentiment_] for sentiment_ in sentiment])
                sentiment_weight = torch.tensor(sentiment_weight).float().to(self.config.device)

                ans = all_ans
                ans_weight = np.array([self.config.ans_weight_map[ans_] for ans_ in ans])
                ans_weight = torch.tensor(ans_weight).float().to(self.config.device)

                noise = all_noise
                noise_weight = np.array([self.config.noise_weight_map[noise_] for noise_ in noise])
                noise_weight = torch.tensor(noise_weight).float().to(self.config.device)


                outputs = self.model(input_ids=all_input_ids,
                                     attention_mask=all_attention_masks,
                                     token_type_ids=all_token_type_ids,
                                     start_positions=all_start_positions,
                                     end_positions=all_end_positions,
                                     onehot_sentiment_type=all_onehot_sentiment_type,
                                     onehot_ans_type=all_onehot_ans_type,
                                     onehot_noise_type=all_onehot_noise_type,
                                     sentiment_weight=sentiment_weight,
                                     ans_weight=ans_weight,
                                     noise_weight=noise_weight)

                loss, start_logits, end_logits, ans_logits, noise_logits = outputs[0], outputs[1], outputs[2], \
                                                                           outputs[3], outputs[4]

                # use apex
                if self.config.apex:
                    with amp.scale_loss(loss / self.config.accumulation_steps, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    # print(loss, all_start_positions, all_end_positions, all_orig_tweet)
                    loss.backward()

                # adversarial training
                if self.config.adversarial:
                    try:
                        with torch.autograd.detect_anomaly():
                            self.model.attack()
                            outputs_adv = self.model(input_ids=all_input_ids, attention_mask=all_attention_masks,
                                                     token_type_ids=all_token_type_ids, start_positions=all_start_positions,
                                                     end_positions=all_end_positions, onehot_ans_type=all_onehot_ans_type,
                                                 sentiment_weight=sentiment_weight)
                            loss_adv = outputs_adv[0]

                            # use apex
                            if self.config.apex:
                                with amp.scale_loss(loss_adv / self.config.accumulation_steps, self.optimizer) as scaled_loss_adv:
                                    scaled_loss_adv.backward()
                                    self.model.restore()
                            else:
                                loss_adv.backward()
                                self.model.restore()
                    except Exception:
                        print("NAN LOSS")

                if ((tr_batch_i + 1) % self.config.accumulation_steps == 0):
                    if self.config.apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.config.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                    self.optimizer.step()
                    self.model.zero_grad()
                    # adjust lr
                    if (self.lr_scheduler_each_iter):
                        self.scheduler.step()

                    self.writer.add_scalar('train_loss_' + str(self.config.fold), loss.item(),
                                           (self.epoch - 1) * len(
                                               self.train_data_loader) * self.config.batch_size + tr_batch_i *
                                           self.config.batch_size)
                    self.step += 1

                # translate to predictions
                start_logits = start_logits[:, self.offsets:]
                end_logits = end_logits[:, self.offsets:]

                start_logits = torch.softmax(start_logits, dim=-1)
                end_logits = torch.softmax(end_logits, dim=-1)
                ans_logits = torch.argmax(ans_logits, dim=-1)
                noise_logits = torch.argmax(noise_logits, dim=-1)
                all_onehot_ans_type = torch.argmax(all_onehot_ans_type, dim=-1)
                all_onehot_noise_type = torch.argmax(all_onehot_noise_type, dim=-1)

                def to_numpy(tensor):
                    return tensor.detach().cpu().numpy()

                start_logits = to_numpy(start_logits)
                end_logits = to_numpy(end_logits)
                ans_logits = to_numpy(ans_logits)
                noise_logits = to_numpy(noise_logits)
                all_onehot_ans_type = to_numpy(all_onehot_ans_type)
                all_onehot_noise_type = to_numpy(all_onehot_noise_type)

                for px, orig_tweet in enumerate(all_orig_tweet):

                    start_logits_word_level, end_logits_word_level, word_level_bbx = get_word_level_logits(
                        start_logits[px],
                        end_logits[px],
                        self.config.model_type,
                        all_offsets_word_level[px])

                    start_idx_token, end_idx_token = get_token_level_idx(start_logits[px],
                                                                         end_logits[px],
                                                                         start_logits_word_level,
                                                                         end_logits_word_level,
                                                                         word_level_bbx)

                    # start_idx_token, end_idx_token = start_logits[px].argmax(-1), end_logits[px].argmax(-1)

                    selected_tweet = all_orig_selected[px]
                    jaccard_score, final_text = calculate_jaccard_score(
                        original_tweet=orig_tweet,
                        selected_text=selected_tweet,
                        idx_start=start_idx_token,
                        idx_end=end_idx_token,
                        model_type=self.config.model_type,
                        tweet_offsets=all_offsets_token_level[px],
                    )

                    # if (sentiment[px] == "neutral" or len(all_orig_tweet[px].split()) < 3):
                    if ans_logits[px] == 0:

                        final_text = all_orig_tweet_with_extra_space[px]
                        final_text = pp_v2(all_orig_tweet_with_extra_space[px], final_text)

                        self.train_metrics_postprocessing.append(jaccard(final_text.strip(), selected_tweet.strip()))
                        self.train_metrics.append(jaccard(final_text.strip(), selected_tweet.strip()))
                    else:

                        final_text = pp_v2(all_orig_tweet_with_extra_space[px], final_text)

                        self.train_metrics_no_postprocessing.append(jaccard(final_text.strip(), selected_tweet.strip()))
                        self.train_metrics.append(jaccard(final_text.strip(), selected_tweet.strip()))

                    if ans_logits[px] == all_onehot_ans_type[px]:
                        self.train_ans_acc.append(1)
                    else:
                        self.train_ans_acc.append(0)

                    if noise_logits[px] == all_onehot_noise_type[px]:
                        self.train_noise_acc.append(1)
                    else:
                        self.train_noise_acc.append(0)

                l = np.array([loss.item() * self.config.batch_size])
                n = np.array([self.config.batch_size])
                sum_train_loss = sum_train_loss + l
                sum_train = sum_train + n

                # log for training
                if (tr_batch_i + 1) % self.log_step == 0:
                    train_loss = sum_train_loss / (sum_train + 1e-12)
                    sum_train_loss[...] = 0
                    sum_train[...] = 0
                    mean_train_metric = np.mean(self.train_metrics)
                    mean_train_metric_postprocessing = np.mean(self.train_metrics_postprocessing)
                    mean_train_metric_no_postprocessing = np.mean(self.train_metrics_no_postprocessing)
                    mean_train_ans_acc = np.mean(self.train_ans_acc)
                    mean_train_noise_acc = np.mean(self.train_noise_acc)

                    self.log.write('lr: %f train loss: %f train_jaccard: %f train_jaccard_postprocessing: %f train_jaccard_no_postprocessing: %f train_ans_acc: %f train_noise_acc: %f\n' % \
                                   (rate, train_loss[0], mean_train_metric, mean_train_metric_postprocessing,
                                    mean_train_metric_no_postprocessing, mean_train_ans_acc, mean_train_noise_acc))

                    print("Training ground truth: ", selected_tweet)
                    print("Training prediction: ", final_text)

                if (tr_batch_i + 1) % self.eval_step == 0:
                    self.evaluate_op()

            if self.count >= self.config.early_stopping:
                break

            self.epoch += 1

    def evaluate_op(self):

        self.eval_count += 1
        valid_loss = np.zeros(1, np.float32)
        valid_num = np.zeros_like(valid_loss)

        self.eval_metrics = []
        self.eval_metrics_no_postprocessing = []
        self.eval_metrics_postprocessing = []
        self.eval_ans_acc = []
        self.eval_noise_acc = []

        all_result = []

        with torch.no_grad():

            # init cache
            torch.cuda.empty_cache()

            for val_batch_i, (
                    all_input_ids, all_attention_masks, all_token_type_ids,
                    all_start_positions, all_end_positions,
                    all_onehot_sentiment_type, all_onehot_ans_type, all_onehot_noise_type,
                    all_orig_tweet, all_orig_tweet_with_extra_space, all_orig_selected,
                    all_sentiment, all_ans, all_noise,
                    all_offsets_token_level, all_offsets_word_level) in enumerate(self.val_data_loader):

                # set model to eval mode
                self.model.eval()

                # set input to cuda mode
                all_input_ids = all_input_ids.to(self.config.device)
                all_attention_masks = all_attention_masks.to(self.config.device)
                all_token_type_ids = all_token_type_ids.to(self.config.device)
                all_start_positions = all_start_positions.to(self.config.device)
                all_end_positions = all_end_positions.to(self.config.device)
                all_onehot_sentiment_type = all_onehot_sentiment_type.to(self.config.device)
                all_onehot_ans_type = all_onehot_ans_type.to(self.config.device)
                all_onehot_noise_type = all_onehot_noise_type.to(self.config.device)

                sentiment = all_sentiment

                outputs = self.model(input_ids=all_input_ids,
                                     attention_mask=all_attention_masks,
                                     token_type_ids=all_token_type_ids,
                                     start_positions=all_start_positions,
                                     end_positions=all_end_positions,
                                     onehot_sentiment_type=all_onehot_sentiment_type,
                                     onehot_ans_type=all_onehot_ans_type,
                                     onehot_noise_type=all_onehot_noise_type)

                loss, start_logits, end_logits, ans_logits, noise_logits = outputs[0], outputs[1], outputs[2], \
                                                                           outputs[3], outputs[4]

                self.writer.add_scalar('val_loss_' + str(self.config.fold), loss.item(), (self.eval_count - 1) * len(
                    self.val_data_loader) * self.config.val_batch_size + val_batch_i * self.config.val_batch_size)

                # translate to predictions
                start_logits = start_logits[:, self.offsets:]
                end_logits = end_logits[:, self.offsets:]

                start_logits = torch.softmax(start_logits, dim=-1)
                end_logits = torch.softmax(end_logits, dim=-1)
                ans_logits = torch.argmax(ans_logits, dim=-1)
                noise_logits = torch.argmax(noise_logits, dim=-1)
                all_onehot_ans_type = torch.argmax(all_onehot_ans_type, dim=-1)
                all_onehot_noise_type = torch.argmax(all_onehot_noise_type, dim=-1)

                def to_numpy(tensor):
                    return tensor.detach().cpu().numpy()

                start_logits = to_numpy(start_logits)
                end_logits = to_numpy(end_logits)
                ans_logits = to_numpy(ans_logits)
                noise_logits = to_numpy(noise_logits)
                all_onehot_ans_type = to_numpy(all_onehot_ans_type)
                all_onehot_noise_type = to_numpy(all_onehot_noise_type)

                for px, orig_tweet in enumerate(all_orig_tweet):

                    start_logits_word_level, end_logits_word_level, word_level_bbx = get_word_level_logits(
                        start_logits[px],
                        end_logits[px],
                        self.config.model_type,
                        all_offsets_word_level[px])

                    start_idx_token, end_idx_token = get_token_level_idx(start_logits[px],
                                                                         end_logits[px],
                                                                         start_logits_word_level,
                                                                         end_logits_word_level,
                                                                         word_level_bbx)

                    # start_idx_token, end_idx_token = start_logits[px].argmax(-1), end_logits[px].argmax(-1)

                    selected_tweet = all_orig_selected[px]

                    jaccard_score, final_text = calculate_jaccard_score(
                        original_tweet=orig_tweet,
                        selected_text=selected_tweet,
                        idx_start=start_idx_token,
                        idx_end=end_idx_token,
                        model_type=self.config.model_type,
                        tweet_offsets=all_offsets_token_level[px],
                    )

                    # if (sentiment[px] == "neutral" or len(all_orig_tweet[px].split()) < 3):
                    if ans_logits[px] == 0:

                        final_text = all_orig_tweet_with_extra_space[px]
                        final_text = pp_v2(all_orig_tweet_with_extra_space[px], final_text)

                        self.eval_metrics_postprocessing.append(jaccard(final_text.strip(), selected_tweet.strip()))
                        self.eval_metrics.append(jaccard(final_text.strip(), selected_tweet.strip()))
                    else:

                        final_text = pp_v2(all_orig_tweet_with_extra_space[px], final_text)

                        self.eval_metrics_no_postprocessing.append(jaccard(final_text.strip(), selected_tweet.strip()))
                        self.eval_metrics.append(jaccard(final_text.strip(), selected_tweet.strip()))

                    all_result.append(final_text)

                    if ans_logits[px] == all_onehot_ans_type[px]:
                        self.eval_ans_acc.append(1)
                    else:
                        self.eval_ans_acc.append(0)

                    if noise_logits[px] == all_onehot_noise_type[px]:
                        self.eval_noise_acc.append(1)
                    else:
                        self.eval_noise_acc.append(0)

                l = np.array([loss.item() * self.config.val_batch_size])
                n = np.array([self.config.val_batch_size])
                valid_loss = valid_loss + l
                valid_num = valid_num + n

            valid_loss = valid_loss / valid_num
            mean_eval_metric = np.mean(self.eval_metrics)
            mean_eval_metric_postprocessing = np.mean(self.eval_metrics_postprocessing)
            mean_eval_metric_no_postprocessing = np.mean(self.eval_metrics_no_postprocessing)
            mean_eval_ans_acc = np.mean(self.eval_ans_acc)
            mean_eval_noise_acc = np.mean(self.eval_noise_acc)

            self.log.write('validation loss: %f eval_jaccard: %f eval_jaccard_postprocessing: %f eval_jaccard_no_postprocessing: %f eval_ans_acc: %f eval_noise_acc: %f\n' % \
                           (valid_loss[0], mean_eval_metric, mean_eval_metric_postprocessing,
                            mean_eval_metric_no_postprocessing, mean_eval_ans_acc, mean_eval_noise_acc))
            print("Validating ground truth: ", selected_tweet)
            print("Validating prediction: ", final_text)

        if self.config.lr_scheduler_name == "ReduceLROnPlateau":
            self.scheduler.step(mean_eval_metric)

        if (mean_eval_metric >= self.valid_metric_optimal):

            self.log.write('Validation metric improved ({:.6f} --> {:.6f}).  Saving model ...'.format(
                self.valid_metric_optimal, mean_eval_metric))

            self.valid_metric_optimal = mean_eval_metric
            self.save_check_point()

            self.count = 0

        else:
            self.count += 1

        val_df = pd.DataFrame({'selected_text': all_result})
        val_df.to_csv(os.path.join(self.config.checkpoint_folder, "val_prediction_{}_{}.csv".format(self.config.seed,
                                                                                                    self.config.fold)),
                      index=False)

    def find_errors(self):

        scores = []
        bad_predictions = []
        bad_labels = []
        bad_text = []
        bad_scores = []

        for fold in range(5):
            checkpoint_folder = os.path.join(self.config.checkpoint_folder_all_fold, 'fold_' + str(fold) + '/')
            val_pred = pd.read_csv(os.path.join(checkpoint_folder, "val_prediction_{}_{}.csv".format(self.config.seed,
                                                                                                   fold)))
            val_label = pd.read_csv(os.path.join(checkpoint_folder, "val_fold_{}_seed_{}.csv".format(fold,
                                                                                                     self.config.seed)))
            pred_text = val_pred.selected_text
            label_text = val_label.selected_text
            whole_text = val_label.text

            for i, label_string in enumerate(label_text):
                pred_string = pred_text[i]

                try:
                    if pred_string[0] == " ":
                        print(pred_string[1:])
                        print(label_string)
                        print(jaccard(label_string.strip(), pred_string.strip()), jaccard(label_string.strip(), pred_string[1:].strip()))
                        print("_______________________________________")
                        pred_string = pred_string[1:]
                    jac = jaccard(label_string.strip(), pred_string.strip())
                except:
                    continue

                scores.append(jac)

                if jac < 0.5:
                    bad_scores.append(jac)
                    bad_text.append(whole_text[i])
                    bad_predictions.append(pred_string)
                    bad_labels.append(label_string)

        bad_samples = pd.DataFrame({"score": bad_scores, "prediction": bad_predictions, "label": bad_labels, "text": bad_text})
        bad_samples = bad_samples.sort_values(by=["score"])

        bad_samples.to_csv(os.path.join(self.config.checkpoint_folder_all_fold, "bad_samples.csv"))

        print(np.mean(scores))

        return
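
A minimal usage sketch for the QA wrapper above; the Config object and its attribute names are hypothetical stand-ins for whatever configuration the original project passes in:

if __name__ == "__main__":
    # hypothetical config: any object exposing the attributes read in QA.__init__
    # (checkpoint_folder, data_path, model_type, batch_size, fold, seed, apex, ...)
    config = Config()   # placeholder, not defined in this snippet
    qa = QA(config)     # builds loaders, model, optimizer and scheduler
    qa.train_op()       # trains; evaluate_op() runs every eval_step training batches
    qa.find_errors()    # optional: collect low-Jaccard predictions across the 5 folds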
Beispiel #23
0
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    #Train File
    parser.add_argument("--src_file",
                        default=None,
                        type=str,
                        help="The input data src file name.")
    parser.add_argument("--tgt_file",
                        default=None,
                        type=str,
                        help="The input data tgt file name.")
    parser.add_argument("--check_file",
                        default=None,
                        type=str,
                        help="The input check knowledge data file name")

    #KS File
    parser.add_argument("--ks_src_file",
                        default=None,
                        type=str,
                        help="The input ks data src file name.")
    parser.add_argument("--ks_tgt_file",
                        default=None,
                        type=str,
                        help="The input ks data tgt file name.")

    parser.add_argument("--predict_input_file",
                        default=None,
                        type=str,
                        help="predict_input_file")
    parser.add_argument("--predict_output_file",
                        default=None,
                        type=str,
                        help="predict_output_file")

    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")
    parser.add_argument("--predict_bleu",
                        default=0.2,
                        type=float,
                        help="The Predicted Bleu for KS Predict ")

    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        action='store_true',
                        help="Whether to run ks predict.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--train_avg_bpe_length",
                        default=25,
                        type=int,
                        help="average bpe length for train.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion_step",
        default=300,
        type=int,
        help="Number of steps over which to perform linear learning rate warmup.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=67,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()

    assert Path(
        args.model_recover_path).exists(), "--model_recover_path doesn't exist"

    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    handler = logging.FileHandler(os.path.join(args.log_dir, "train.log"),
                                  encoding='UTF-8')
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)

    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)

    logger.addHandler(handler)
    logger.addHandler(console)

    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    #Random Seed
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    #torch.backends.cudnn.enabled = False
    #torch.backends.cudnn.benchmark = False
    #torch.backends.cudnn.deterministic = True

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    #Data process pipelines
    bi_uni_pipeline = [
        seq2seq_loader.Preprocess4Seq2seq(
            args.max_pred,
            args.mask_prob,
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'max_len_a': args.max_len_a,
                'max_len_b': args.max_len_b,
                'trunc_seg': args.trunc_seg,
                'always_truncate_tail': args.always_truncate_tail
            },
            mask_source_words=args.mask_source_words,
            skipgram_prb=args.skipgram_prb,
            skipgram_size=args.skipgram_size,
            mask_whole_word=args.mask_whole_word,
            mode="s2s",
            has_oracle=args.has_sentence_oracle,
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift)
    ]
    C_bi_uni_pipeline = [
        seq2seq_loader.C_Preprocess4Seq2seq(
            args.max_pred,
            args.mask_prob,
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'max_len_a': args.max_len_a,
                'max_len_b': args.max_len_b,
                'trunc_seg': args.trunc_seg,
                'always_truncate_tail': args.always_truncate_tail
            },
            mask_source_words=args.mask_source_words,
            skipgram_prb=args.skipgram_prb,
            skipgram_size=args.skipgram_size,
            mask_whole_word=args.mask_whole_word,
            mode="s2s",
            has_oracle=args.has_sentence_oracle,
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift)
    ]
    ks_predict_bi_uni_pipeline = [
        seq2seq_loader.Preprocess4Seq2seq_predict(
            args.max_pred,
            args.mask_prob,
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'max_len_a': args.max_len_a,
                'max_len_b': args.max_len_b,
                'trunc_seg': args.trunc_seg,
                'always_truncate_tail': args.always_truncate_tail
            },
            mask_source_words=args.mask_source_words,
            skipgram_prb=args.skipgram_prb,
            skipgram_size=args.skipgram_size,
            mask_whole_word=args.mask_whole_word,
            mode="s2s",
            has_oracle=args.has_sentence_oracle,
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift)
    ]
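    # The three pipelines above share the same truncation/masking configuration and differ
    # only in the preprocessor class: Preprocess4Seq2seq for the KS training data,
    # C_Preprocess4Seq2seq for the QKR data with check knowledge, and
    # Preprocess4Seq2seq_predict for the KS prediction path.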

    if args.do_train:
        print("Loading QKR Train Dataset", args.data_dir)
        file_oracle = None
        if args.has_sentence_oracle:
            file_oracle = os.path.join(args.data_dir, 'train.oracle')
        fn_src = os.path.join(args.data_dir,
                              args.src_file if args.src_file else 'train.src')
        fn_tgt = os.path.join(args.data_dir,
                              args.tgt_file if args.tgt_file else 'train.tgt')
        fn_check = os.path.join(args.data_dir, args.check_file)

        train_dataset = seq2seq_loader.C_Seq2SeqDataset(
            fn_src,
            fn_tgt,
            fn_check,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=C_bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

        print("Loading KS Train Dataset", args.data_dir)
        ks_fn_src = os.path.join(args.data_dir, args.ks_src_file)
        ks_fn_tgt = os.path.join(args.data_dir, args.ks_tgt_file)
        ks_train_dataset = seq2seq_loader.Seq2SeqDataset(
            ks_fn_src,
            ks_fn_tgt,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            ks_train_sampler = RandomSampler(ks_train_dataset,
                                             replacement=False)
            _batch_size = args.train_batch_size
        else:
            ks_train_sampler = DistributedSampler(ks_train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        ks_train_dataloader = torch.utils.data.DataLoader(
            ks_train_dataset,
            batch_size=_batch_size,
            sampler=ks_train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

        # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
        t_total = int(
            len(train_dataloader) * args.num_train_epochs /
            args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = (6 + (1 if args.s2s_add_segment else 0)) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()

    #Recover model
    if args.model_recover_path:
        logger.info(" ** ** * Recover model: %s ** ** * ",
                    args.model_recover_path)
        model_recover = torch.load(args.model_recover_path, map_location='cpu')
        global_step = 0

    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[S2S_SOS]"])

    model = BertForPreTrainingLossMask.from_pretrained(
        args.bert_model,
        state_dict=model_recover,
        num_labels=cls_num_labels,
        num_rel=0,
        type_vocab_size=type_vocab_size,
        config_path=args.config_path,
        task_idx=3,
        num_sentlvl_labels=num_sentlvl_labels,
        max_position_embeddings=args.max_position_embeddings,
        label_smoothing=args.label_smoothing,
        fp32_embedding=args.fp32_embedding,
        relax_projection=relax_projection,
        new_pos_ids=args.new_pos_ids,
        ffn_type=args.ffn_type,
        hidden_dropout_prob=args.hidden_dropout_prob,
        attention_probs_dropout_prob=args.attention_probs_dropout_prob,
        num_qkv=args.num_qkv,
        seg_emb=args.seg_emb,
        mask_word_id=mask_word_id,
        search_beam_size=5,
        length_penalty=0,
        eos_id=eos_word_ids,
        sos_id=sos_word_id,
        forbid_duplicate_ngrams=True,
        forbid_ignore_set=None,
        mode="s2s")

    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        model.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)

    model.tmp_bert_emb.word_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.word_embeddings.weight.clone())
    model.tmp_bert_emb.token_type_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.token_type_embeddings.weight.clone())
    model.tmp_bert_emb.position_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.position_embeddings.weight.clone())
    model.mul_bert_emb.word_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.word_embeddings.weight.clone())
    model.mul_bert_emb.token_type_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.token_type_embeddings.weight.clone())
    model.mul_bert_emb.position_embeddings.weight = torch.nn.Parameter(
        model.bert.embeddings.position_embeddings.weight.clone())
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "DistributedDataParallel is not available in this PyTorch build.")
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        model = DataParallelImbalance(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
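    # With --fp16, FusedAdam is wrapped in FP16_Optimizer_State (dynamic or static
    # loss scaling); otherwise BertAdam with linear warmup is used.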
    if args.fp16:
        try:
            from pytorch_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    if args.optim_recover_path is not None:
        logger.info(" ** ** * Recover optimizer from : {} ** ** * ".format(
            args.optim_recover_path))
        optim_recover = torch.load(args.optim_recover_path, map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info(
                " ** ** * Recover optimizer: dynamic_loss_scale ** ** * ")
            optimizer.dynamic_loss_scale = True

    #logger.info(" ** ** * CUDA.empty_cache() ** ** * ")
    torch.cuda.empty_cache()

    # ################# TRAIN ############################ #
    if args.do_train:
        max_F1 = 0
        best_step = 0
        logger.info(" ** ** * Running training ** ** * ")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        model.train()
        start_epoch = 1

        for i_epoch in trange(start_epoch,
                              start_epoch + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)

            step = 0
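            # Generation batches and knowledge-selection (KS) batches are consumed
            # in lockstep from the two dataloaders.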
            for batch, ks_batch in zip(train_dataloader, ks_train_dataloader):
                # ################# E step + M step + Mutual Information Loss ############################ #
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]

                input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, tgt_pos, labels, ks_labels, check_ids = batch
                oracle_pos, oracle_weights, oracle_labels = None, None, None

                loss_tuple = model(input_ids,
                                   segment_ids,
                                   input_mask,
                                   lm_label_ids,
                                   is_next,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   masked_pos_2=oracle_pos,
                                   masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels,
                                   mask_qkv=mask_qkv,
                                   tgt_pos=tgt_pos,
                                   labels=labels.half(),
                                   ks_labels=ks_labels,
                                   check_ids=check_ids)

                masked_lm_loss, next_sentence_loss, KL_loss, Mutual_loss, Golden_loss, predict_kl_loss = loss_tuple
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    masked_lm_loss = masked_lm_loss.mean()
                    next_sentence_loss = next_sentence_loss.mean()
                    Mutual_loss = Mutual_loss.mean()
                    Golden_loss = Golden_loss.mean()
                    KL_loss = KL_loss.mean()
                    predict_kl_loss = predict_kl_loss.mean()
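                # The total loss below sums the masked-LM, next-sentence, KL,
                # mutual-information, golden-knowledge and predicted-KL terms.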

                loss = masked_lm_loss + next_sentence_loss + KL_loss + predict_kl_loss + Mutual_loss + Golden_loss
                logger.info("In{}step, masked_lm_loss:{}".format(
                    step, masked_lm_loss))
                logger.info("In{}step, KL_loss:{}".format(step, KL_loss))
                logger.info("In{}step, Mutual_loss:{}".format(
                    step, Mutual_loss))
                logger.info("In{}step, Golden_loss:{}".format(
                    step, Golden_loss))
                logger.info("In{}step, predict_kl_loss:{}".format(
                    step, predict_kl_loss))

                logger.info("******************************************* ")

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
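                # Step the optimizer only every gradient_accumulation_steps micro-batches;
                # the learning rate follows the warmup_linear schedule over t_total steps.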
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * warmup_linear(
                        global_step / t_total,
                        args.warmup_proportion_step / t_total)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                # ################# Knowledge Selection Loss ############################ #
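                # The KS objective is applied stochastically on roughly one out of
                # every five training steps.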
                if random.randint(0, 4) == 0:
                    ks_batch = [
                        t.to(device) if t is not None else None
                        for t in ks_batch
                    ]

                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, _, labels, ks_labels = ks_batch
                    oracle_pos, oracle_weights, oracle_labels = None, None, None
                    loss_tuple = model(input_ids,
                                       segment_ids,
                                       input_mask,
                                       lm_label_ids,
                                       is_next,
                                       masked_pos=masked_pos,
                                       masked_weights=masked_weights,
                                       task_idx=task_idx,
                                       masked_pos_2=oracle_pos,
                                       masked_weights_2=oracle_weights,
                                       masked_labels_2=oracle_labels,
                                       mask_qkv=mask_qkv,
                                       labels=labels,
                                       ks_labels=ks_labels,
                                       train_ks=True)

                    ks_loss, _ = loss_tuple
                    if n_gpu > 1:  # mean() to average on multi-gpu.
                        ks_loss = ks_loss.mean()
                    loss = ks_loss

                    logger.info("In{}step, ks_loss:{}".format(step, ks_loss))
                    logger.info("******************************************* ")

                    # ensure that accumulated gradients are normalized
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    if args.fp16:
                        optimizer.backward(loss)
                        if amp_handle:
                            amp_handle._clear_cache()
                    else:
                        loss.backward()
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / t_total,
                            args.warmup_proportion_step / t_total)
                        if args.fp16:
                            # modify learning rate with special warm up BERT uses
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = lr_this_step
                        optimizer.step()
                        optimizer.zero_grad()

                step += 1
                # ################# Eval Every 5000 Steps ############################ #
                if (global_step + 1) % 5000 == 0:
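                    # Periodic dev evaluation: rank knowledge candidates with the KS head,
                    # decode responses on the re-ranked sources, and early-stop on dev F1.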
                    next_i = 0
                    model.eval()

                    # Know Rank Stage
                    logger.info(" ** ** * DEV Know Selection Begin ** ** * ")
                    with open(os.path.join(args.data_dir,
                                           args.predict_input_file),
                              "r",
                              encoding="utf-8") as file:
                        src_file = file.readlines()
                    with open(os.path.join(args.data_dir,
                                           "train_tgt_pad.empty"),
                              "r",
                              encoding="utf-8") as file:
                        tgt_file = file.readlines()
                    with open(os.path.join(args.data_dir,
                                           args.predict_output_file),
                              "w",
                              encoding="utf-8") as out:
                        while next_i < len(src_file):
                            batch_src = src_file[next_i:next_i +
                                                 args.eval_batch_size]
                            batch_tgt = tgt_file[next_i:next_i +
                                                 args.eval_batch_size]

                            next_i += args.eval_batch_size

                            ex_list = []
                            for src, tgt in zip(batch_src, batch_tgt):
                                src_tk = data_tokenizer.tokenize(src.strip())
                                tgt_tk = data_tokenizer.tokenize(tgt.strip())
                                ex_list.append((src_tk, tgt_tk))

                            batch = []
                            for idx in range(len(ex_list)):
                                instance = ex_list[idx]
                                for proc in ks_predict_bi_uni_pipeline:
                                    instance = proc(instance)
                                    batch.append(instance)

                            batch_tensor = seq2seq_loader.batch_list_to_batch_tensors(
                                batch)
                            batch = [
                                t.to(device) if t is not None else None
                                for t in batch_tensor
                            ]

                            input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx = batch

                            predict_bleu = args.predict_bleu * torch.ones(
                                [input_ids.shape[0]], device=input_ids.device)
                            oracle_pos, oracle_weights, oracle_labels = None, None, None
                            with torch.no_grad():
                                logits = model(input_ids,
                                               segment_ids,
                                               input_mask,
                                               lm_label_ids,
                                               is_next,
                                               masked_pos=masked_pos,
                                               masked_weights=masked_weights,
                                               task_idx=task_idx,
                                               masked_pos_2=oracle_pos,
                                               masked_weights_2=oracle_weights,
                                               masked_labels_2=oracle_labels,
                                               mask_qkv=mask_qkv,
                                               labels=predict_bleu,
                                               train_ks=True)

                                logits = torch.nn.functional.softmax(logits,
                                                                     dim=1)
                                labels = logits[:, 1].cpu().numpy()
                                for i in range(len(labels)):
                                    line = batch_src[i].strip()
                                    line += "\t"
                                    line += str(labels[i])
                                    out.write(line)
                                    out.write("\n")

                    data_path = os.path.join(args.data_dir,
                                             "qkr_dev.ks_score.tk")
                    src_path = os.path.join(args.data_dir, "qkr_dev.src.tk")
                    src_out_path = os.path.join(args.data_dir,
                                                "rank_qkr_dev.src.tk")
                    tgt_path = os.path.join(args.data_dir, "qkr_dev.tgt")

                    knowledge_selection(data_path, src_path, src_out_path)
                    logger.info(" ** ** * DEV Know Selection End ** ** * ")

                    # Decode Stage
                    logger.info(" ** ** * Dev Decode Begin ** ** * ")
                    with open(src_out_path, encoding="utf-8") as file:
                        dev_src_lines = file.readlines()
                    with open(tgt_path, encoding="utf-8") as file:
                        golden_response_lines = file.readlines()

                    decode_result = decode_batch(model, dev_src_lines)
                    logger.info(" ** ** * Dev Decode End ** ** * ")

                    # Compute dev F1
                    assert len(decode_result) == len(golden_response_lines)
                    C_F1 = f_one(decode_result, golden_response_lines)[0]
                    logger.info(
                        "** ** * Current F1 is {} ** ** * ".format(C_F1))
                    if C_F1 < max_F1:
                        logger.info(
                            "** ** * Current F1 is lower than Previous F1. So Stop Training ** ** * "
                        )
                        logger.info(
                            "** ** * The best model is {} ** ** * ".format(
                                best_step))
                        break
                    else:
                        max_F1 = C_F1
                        best_step = step
                        logger.info(
                            "** ** * Current F1 is larger than Previous F1. So Continue Training ** ** * "
                        )

                    # Save trained model
                    if (args.local_rank == -1
                            or torch.distributed.get_rank() == 0):
                        logger.info(
                            "** ** * Saving fine-tuned model and optimizer ** ** * "
                        )
                        model_to_save = model.module if hasattr(
                            model,
                            'module') else model  # Only save the model it-self
                        output_model_file = os.path.join(
                            args.output_dir,
                            "model.{}_{}.bin".format(i_epoch, global_step))
                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                        output_optim_file = os.path.join(
                            args.output_dir, "optim.bin")
                        torch.save(optimizer.state_dict(), output_optim_file)

                        #logger.info(" ** ** * CUDA.empty_cache() ** ** * ")
                        torch.cuda.empty_cache()

    # ################# Predict ############################ #
    if args.do_predict:
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq_predict(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_a': args.max_len_a,
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_source_words=args.mask_source_words,
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                mode="s2s",
                has_oracle=args.has_sentence_oracle,
                num_qkv=args.num_qkv,
                s2s_special_token=args.s2s_special_token,
                s2s_add_segment=args.s2s_add_segment,
                s2s_share_segment=args.s2s_share_segment,
                pos_shift=args.pos_shift)
        ]

        next_i = 0
        model.eval()
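        # KS prediction: score each source line (paired with an empty padded target)
        # with the knowledge-selection head and write "<source>\t<score>" per line.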

        with open(os.path.join(args.data_dir, args.predict_input_file),
                  "r",
                  encoding="utf-8") as file:
            src_file = file.readlines()
        with open("train_tgt_pad.empty", "r", encoding="utf-8") as file:
            tgt_file = file.readlines()
        with open(os.path.join(args.data_dir, args.predict_output_file),
                  "w",
                  encoding="utf-8") as out:
            logger.info("** ** * Continue knowledge ranking ** ** * ")
            for next_i in tqdm(
                    range(len(src_file) // args.eval_batch_size + 1)):
                batch_src = src_file[next_i *
                                     args.eval_batch_size:(next_i + 1) *
                                     args.eval_batch_size]
                batch_tgt = tgt_file[next_i *
                                     args.eval_batch_size:(next_i + 1) *
                                     args.eval_batch_size]

                ex_list = []
                for src, tgt in zip(batch_src, batch_tgt):
                    src_tk = data_tokenizer.tokenize(src.strip())
                    tgt_tk = data_tokenizer.tokenize(tgt.strip())
                    ex_list.append((src_tk, tgt_tk))

                batch = []
                for idx in range(len(ex_list)):
                    instance = ex_list[idx]
                    for proc in bi_uni_pipeline:
                        instance = proc(instance)
                        batch.append(instance)

                batch_tensor = seq2seq_loader.batch_list_to_batch_tensors(
                    batch)
                batch = [
                    t.to(device) if t is not None else None
                    for t in batch_tensor
                ]

                input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx = batch

                predict_bleu = args.predict_bleu * torch.ones(
                    [input_ids.shape[0]], device=input_ids.device)
                oracle_pos, oracle_weights, oracle_labels = None, None, None
                with torch.no_grad():
                    logits = model(input_ids,
                                   segment_ids,
                                   input_mask,
                                   lm_label_ids,
                                   is_next,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   masked_pos_2=oracle_pos,
                                   masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels,
                                   mask_qkv=mask_qkv,
                                   labels=predict_bleu,
                                   train_ks=True)

                    logits = torch.nn.functional.softmax(logits, dim=1)
                    labels = logits[:, 1].cpu().numpy()
                    for i in range(len(labels)):
                        line = batch_src[i].strip()
                        line += "\t"
                        line += str(labels[i])
                        out.write(line)
                        out.write("\n")
Beispiel #24
0
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )

    parser.add_argument("--dev_src_file",
                        default=None,
                        type=str,
                        help="The input data file name.")
    parser.add_argument("--dev_tgt_file",
                        default=None,
                        type=str,
                        help="The output data file name.")
    parser.add_argument("--dev_check_file",
                        default=None,
                        type=str,
                        help="The output style response/know data file name.")
    parser.add_argument("--dev_style_file",
                        default=None,
                        type=str,
                        help="The output style response/know data file name.")

    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")

    parser.add_argument("--predict_bleu",
                        default=0.5,
                        type=float,
                        help="The Predicted Bleu for KS Predict ")

    parser.add_argument("--train_vae",
                        action='store_true',
                        help="Whether to train vae.")
    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        action='store_true',
                        help="Whether to run ks predict.")

    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion_step",
        default=300,
        type=int,
        help=
        "Number of training steps over which to perform linear learning rate warmup.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()

    assert Path(
        args.model_recover_path).exists(), "--model_recover_path doesn't exist"

    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    handler = logging.FileHandler(os.path.join(args.log_dir, "train.log"),
                                  encoding='UTF-8')
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)

    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)

    logger.addHandler(handler)
    logger.addHandler(console)

    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer(
    ) if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    C_bi_uni_pipeline = [
        seq2seq_loader.C_Preprocess4Seq2seq(
            args.max_pred,
            args.mask_prob,
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'max_len_a': args.max_len_a,
                'max_len_b': args.max_len_b,
                'trunc_seg': args.trunc_seg,
                'always_truncate_tail': args.always_truncate_tail
            },
            mask_source_words=args.mask_source_words,
            skipgram_prb=args.skipgram_prb,
            skipgram_size=args.skipgram_size,
            mask_whole_word=args.mask_whole_word,
            mode="s2s",
            has_oracle=args.has_sentence_oracle,
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift)
    ]

    logger.info("Loading Dataset from {}".format(args.data_dir))

    fn_src = os.path.join(args.data_dir, args.dev_src_file)
    fn_tgt = os.path.join(args.data_dir, args.dev_tgt_file)
    dev_reddit_dataset = seq2seq_loader.C_Seq2SeqDataset(
        fn_src,
        fn_tgt,
        args.eval_batch_size,
        data_tokenizer,
        args.max_seq_length,
        file_oracle=None,
        bi_uni_pipeline=C_bi_uni_pipeline)
    if args.local_rank == -1:
        dev_reddit_sampler = RandomSampler(dev_reddit_dataset,
                                           replacement=False)
        _batch_size = args.eval_batch_size
    else:
        dev_reddit_sampler = DistributedSampler(dev_reddit_dataset)
        _batch_size = args.eval_batch_size // dist.get_world_size()
    dev_reddit_dataloader = torch.utils.data.DataLoader(
        dev_reddit_dataset,
        batch_size=_batch_size,
        sampler=dev_reddit_sampler,
        num_workers=args.num_workers,
        collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
        pin_memory=False)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
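    # recover_step is the latest saved epoch found in the output dir (if any);
    # it is used below to resume the epoch counter.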
    cls_num_labels = 2
    type_vocab_size = 6 + (
        1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()

    if args.model_recover_path:
        logger.info("***** Recover model: %s *****", args.model_recover_path)
        model_recover = torch.load(args.model_recover_path, map_location='cpu')

    model = BertForPreTrainingLossMask.from_pretrained(
        args.bert_model,
        state_dict=model_recover,
        num_labels=cls_num_labels,
        num_rel=0,
        type_vocab_size=type_vocab_size,
        config_path=args.config_path,
        task_idx=3,
        num_sentlvl_labels=num_sentlvl_labels,
        max_position_embeddings=args.max_position_embeddings,
        label_smoothing=args.label_smoothing,
        fp32_embedding=args.fp32_embedding,
        relax_projection=relax_projection,
        new_pos_ids=args.new_pos_ids,
        ffn_type=args.ffn_type,
        hidden_dropout_prob=args.hidden_dropout_prob,
        attention_probs_dropout_prob=args.attention_probs_dropout_prob,
        num_qkv=args.num_qkv,
        seg_emb=args.seg_emb)

    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        model.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    if args.optim_recover_path is not None:
        logger.info("***** Recover optimizer from : {} *****".format(
            args.optim_recover_path))
        optim_recover = torch.load(args.optim_recover_path, map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:

        pretrain_step = -1
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)

        model.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch,
                              int(args.num_train_epochs) + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                dev_reddit_sampler.set_epoch(i_epoch)

            logger.info("***** Running QKR evaluation *****")
            logger.info("  Batch size = %d", args.eval_batch_size)

            dev_iter_bar = tqdm(dev_reddit_dataloader,
                                desc='Iter (loss=X.XXX)',
                                disable=args.local_rank not in (-1, 0))

            total_lm_loss = 0
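            # Accumulate the masked-LM loss over the dev set; mean loss and perplexity
            # are reported after the loop.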
            for qkr_dev_step, batch in enumerate(dev_iter_bar):
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                if args.has_sentence_oracle:
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, oracle_pos, oracle_weights, oracle_labels = batch
                else:
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, tgt_pos, labels, ks_labels, style_ids, style_labels, check_ids = batch
                    oracle_pos, oracle_weights, oracle_labels = None, None, None

                with torch.no_grad():

                    loss_tuple = model(input_ids,
                                       segment_ids,
                                       input_mask,
                                       lm_label_ids,
                                       is_next,
                                       masked_pos=masked_pos,
                                       masked_weights=masked_weights,
                                       task_idx=task_idx,
                                       masked_pos_2=oracle_pos,
                                       masked_weights_2=oracle_weights,
                                       masked_labels_2=oracle_labels,
                                       mask_qkv=mask_qkv,
                                       tgt_pos=tgt_pos,
                                       labels=labels,
                                       ks_labels=ks_labels,
                                       train_vae=args.train_vae,
                                       style_ids=style_ids,
                                       style_labels=style_labels,
                                       check_ids=check_ids,
                                       pretrain=None)

                    masked_lm_loss, next_sentence_loss, KL_loss, Mutual_loss, Golden_loss, cosine_similarity_loss, predict_kl_loss = loss_tuple
                    if n_gpu > 1:  # mean() to average on multi-gpu.
                        masked_lm_loss = masked_lm_loss.mean()

                    # accumulate the masked-LM loss for this dev batch
                    total_lm_loss += masked_lm_loss.item()

                    # running mean of the LM loss over the dev batches seen so far
                    total_mean_lm_loss = total_lm_loss / (qkr_dev_step + 1)
                    print(total_mean_lm_loss)

            logger.info("** ** * Evaling mean loss ** ** * ")
            logger.info("In{}epoch,dev_lm_loss:{}".format(
                i_epoch, total_mean_lm_loss))
            logger.info("ppl:{}".format(np.exp(total_mean_lm_loss)))
            logger.info("******************************************* ")
            break