def main():
    parser = argparse.ArgumentParser()

    # General
    parser.add_argument("--bert_model", default="bert-base-cased", type=str,
                        help="Bert pre-trained model selected in the list: "
                             "bert-base-cased, bert-large-cased.")
    parser.add_argument("--config_path", default=None, type=str,
                        help="Bert config file path.")
    parser.add_argument("--output_dir", default='tmp', type=str,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--log_file", default="training.log", type=str,
                        help="The log file (under output_dir) where the training log will be written.")
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training. This should ALWAYS be set to True.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=64, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate", default=3e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing", default=0, type=float,
                        help="Label smoothing rate.")
    parser.add_argument("--weight_decay", default=0.01, type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay", action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs", default=30, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--global_rank", type=int, default=-1,
                        help="global_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--fp32_embedding', action='store_true',
                        help="Whether to use 32-bit float precision instead of 16-bit for embeddings")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp', action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument('--from_scratch', action='store_true',
                        help="Initialize parameters with random values (i.e., training from scratch).")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--len_vis_input', type=int, default=100,
                        help="The length of visual token input")
    parser.add_argument('--max_len_b', type=int, default=20,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument('--trunc_seg', default='b',
                        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument('--always_truncate_tail', action='store_true',
                        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument("--mask_prob", default=0.15, type=float,
                        help="Masking probability. The number of predictions can be "
                             "less than max_pred when the sequence is short.")
    parser.add_argument('--max_pred', type=int, default=3,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers", default=4, type=int,
                        help="Number of workers for the data loader.")
    parser.add_argument('--max_position_embeddings', type=int, default=None,
                        help="max position embeddings")

    # Others for VLP
    parser.add_argument("--src_file", default=['/mnt/dat/COCO/annotations/dataset_coco.json'],
                        type=str, nargs='+', help="The input data file name.")
    parser.add_argument('--enable_visdom', action='store_true')
    parser.add_argument('--visdom_port', type=int, default=8888)
    # parser.add_argument('--resnet_model', type=str, default='imagenet_weights/resnet101.pth')
    parser.add_argument('--image_root', type=str, default='/mnt/dat/COCO/images')
    parser.add_argument('--dataset', default='coco', type=str,
                        help='coco | flickr30k | cc')
    parser.add_argument('--split', type=str, nargs='+', default=['train', 'restval'])
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='file://[PT_OUTPUT_DIR]/nonexistent_file', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--file_valid_jpgs', default='/mnt/dat/COCO/annotations/coco_valid_jpgs.json', type=str)
    parser.add_argument('--sche_mode', default='warmup_linear', type=str,
                        help="warmup_linear | warmup_constant | warmup_cosine")
    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--use_num_imgs', default=-1, type=int)
    parser.add_argument('--vis_mask_prob', default=0, type=float)
    parser.add_argument('--max_drop_worst_ratio', default=0, type=float)
    parser.add_argument('--drop_after', default=6, type=int)
    parser.add_argument('--s2s_prob', default=1, type=float,
                        help="Percentage of examples that are bi-uni-directional LM (seq2seq).")
    parser.add_argument('--bi_prob', default=0, type=float,
                        help="Percentage of examples that are bidirectional LM.")
    parser.add_argument('--enable_butd', action='store_true',
                        help='set to take in region features')
    parser.add_argument('--region_bbox_file',
                        default='coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5',
                        type=str)
    parser.add_argument('--region_det_file_prefix',
                        default='feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval',
                        type=str)
    parser.add_argument('--tasks', default='img2txt',
                        help='img2txt | vqa2')
    parser.add_argument('--relax_projection', action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--scst', action='store_true',
                        help='Self-critical sequence training')

    args = parser.parse_args()

    print('global_rank: {}, local rank: {}'.format(args.global_rank, args.local_rank))

    args.max_seq_length = args.max_len_b + args.len_vis_input + 3  # +3 for 2x[SEP] and [CLS]
    args.mask_image_regions = (args.vis_mask_prob > 0)  # whether to mask out image regions
    args.dist_url = args.dist_url.replace('[PT_OUTPUT_DIR]', args.output_dir)

    # arguments inspection
    assert args.tasks in ('img2txt', 'vqa2')
    assert args.enable_butd, 'only support region attn! featmap attn deprecated'
    assert (not args.scst) or args.dataset == 'coco', 'scst support on coco only!'
    if args.scst:
        assert args.dataset == 'coco', 'scst support on coco only!'
        assert args.max_pred == 0 and args.mask_prob == 0, 'no mask for scst!'
        rl_crit = RewardCriterion()

    if args.enable_butd:
        assert args.len_vis_input == 100
        args.region_bbox_file = os.path.join(args.image_root, args.region_bbox_file)
        args.region_det_file_prefix = os.path.join(
            args.image_root, args.region_det_file_prefix) if args.dataset in (
                'cc', 'coco') and args.region_det_file_prefix != '' else ''

    # output config
    os.makedirs(args.output_dir, exist_ok=True)
    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True, indent=2)

    logging.basicConfig(
        filename=os.path.join(args.output_dir, args.log_file),
        filemode='w',
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger = logging.getLogger(__name__)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://localhost:10001',  # args.dist_url,
            world_size=args.world_size,
            rank=args.global_rank)
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
            device, n_gpu, bool(args.local_rank != -1), args.fp16))
    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # plotting loss, optional
    if args.enable_visdom:
        import visdom
        vis = visdom.Visdom(port=args.visdom_port, env=args.output_dir)
        vis_window = {'iter': None, 'score': None}

    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank))
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer

    if args.do_train:
        # one preprocessing pipeline per LM objective (seq2seq and bidirectional)
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids, args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={'max_len_b': args.max_len_b,
                                 'trunc_seg': args.trunc_seg,
                                 'always_truncate_tail': args.always_truncate_tail},
                mask_image_regions=args.mask_image_regions, mode=mode,
                len_vis_input=args.len_vis_input,
                vis_mask_prob=args.vis_mask_prob,
                enable_butd=args.enable_butd,
                region_bbox_file=args.region_bbox_file,
                region_det_file_prefix=args.region_det_file_prefix,
                local_rank=args.local_rank,
                load_vqa_ann=(args.tasks == 'vqa2'))
            for mode in ("s2s", "bi")]

        train_dataset = seq2seq_loader.Img2txtDataset(
            args.src_file, args.image_root, args.split, args.train_batch_size,
            data_tokenizer, args.max_seq_length,
            file_valid_jpgs=args.file_valid_jpgs,
            bi_uni_pipeline=bi_uni_pipeline,
            use_num_imgs=args.use_num_imgs,
            s2s_prob=args.s2s_prob, bi_prob=args.bi_prob,
            enable_butd=args.enable_butd, tasks=args.tasks)

        if args.world_size == 1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
        else:
            train_sampler = DistributedSampler(train_dataset)

        train_dataloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.train_batch_size,
            sampler=train_sampler, num_workers=args.num_workers,
            collate_fn=batch_list_to_batch_tensors, pin_memory=True)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    t_total = int(len(train_dataloader) * args.num_train_epochs * 1. /
                  args.gradient_accumulation_steps)
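    # Worked example of the step arithmetic above (illustrative numbers only):
    # with 10,000 batches per epoch, num_train_epochs=30 and
    # gradient_accumulation_steps=4, t_total = 10,000 * 30 / 4 = 75,000, i.e.
    # one optimizer update is performed for every 4 batches.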
    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    relax_projection = 4 if args.relax_projection else 0
    task_idx_proj = 3 if args.tasks == 'img2txt' else 0
    mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[PAD]"])  # index in BERT vocab: 103, 102, 0

    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        assert not args.scst, 'must init from maximum likelihood training'
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model, state_dict=_state_dict,
            num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size,
            relax_projection=relax_projection,
            config_path=args.config_path,
            task_idx=task_idx_proj,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank),
            drop_prob=args.drop_prob,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input,
            tasks=args.tasks)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(
                os.path.join(args.output_dir, "model.{0}.bin".format(recover_step)))
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total * 1. / args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****", args.model_recover_path)
            model_recover = torch.load(args.model_recover_path)
            global_step = 0
        if not args.scst:
            model = BertForPreTrainingLossMask.from_pretrained(
                args.bert_model, state_dict=model_recover,
                num_labels=cls_num_labels,
                type_vocab_size=type_vocab_size,
                relax_projection=relax_projection,
                config_path=args.config_path,
                task_idx=task_idx_proj,
                max_position_embeddings=args.max_position_embeddings,
                label_smoothing=args.label_smoothing,
                fp32_embedding=args.fp32_embedding,
                cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank),
                drop_prob=args.drop_prob,
                enable_butd=args.enable_butd,
                len_vis_input=args.len_vis_input,
                tasks=args.tasks)
        else:
            model = BertForSeq2SeqDecoder.from_pretrained(
                args.bert_model,
                max_position_embeddings=args.max_position_embeddings,
                config_path=args.config_path,
                state_dict=model_recover,
                num_labels=cls_num_labels,
                type_vocab_size=type_vocab_size,
                task_idx=task_idx_proj,
                mask_word_id=mask_word_id,
                search_beam_size=1,
                eos_id=eos_word_ids,
                enable_butd=args.enable_butd,
                len_vis_input=args.len_vis_input)
        del model_recover
        torch.cuda.empty_cache()

    # deprecated
    # from vlp.resnet import resnet
    # cnn = resnet(args.resnet_model, _num_layers=101, _fixed_block=4, pretrained=True)  # no finetuning

    if args.fp16:
        model.half()
        # cnn.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    # cnn.to(device)

    if args.local_rank != -1:
        try:
            # from apex.parallel import DistributedDataParallel as DDP
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
        model = DDP(model, device_ids=[args.local_rank],
                    output_device=args.local_rank, find_unused_parameters=True)
        # cnn = DDP(cnn)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)
        # cnn = DataParallelImbalance(cnn)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer, static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             schedule=args.sche_mode,
                             t_total=t_total)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(
            os.path.join(args.output_dir, "optim.{0}.bin".format(recover_step)))
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)
        logger.info("  Loader length = %d", len(train_dataloader))

        model.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch, args.num_train_epochs + 1, desc="Epoch"):
            if args.local_rank >= 0:
                train_sampler.set_epoch(i_epoch - 1)
            iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)')
            nbatches = len(train_dataloader)
            train_loss = []
            pretext_loss = []
            vqa2_loss = []
            scst_reward = []
            for step, batch in enumerate(iter_bar):
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, \
                    masked_weights, is_next, task_idx, img, vis_masked_pos, \
                    vis_pe, ans_labels = batch
                if args.fp16:
                    img = img.half()
                    vis_pe = vis_pe.half()

                if args.enable_butd:
                    conv_feats = img.data  # Bx100x2048
                    vis_pe = vis_pe.data
                else:
                    conv_feats, _ = cnn(img.data)  # Bx2048x7x7
                    conv_feats = conv_feats.view(conv_feats.size(0), conv_feats.size(1),
                                                 -1).permute(0, 2, 1).contiguous()

                if not args.scst:
                    loss_tuple = model(
                        conv_feats, vis_pe, input_ids, segment_ids, input_mask,
                        lm_label_ids, ans_labels, is_next,
                        masked_pos=masked_pos, masked_weights=masked_weights,
                        task_idx=task_idx, vis_masked_pos=vis_masked_pos,
                        mask_image_regions=args.mask_image_regions,
                        drop_worst_ratio=args.max_drop_worst_ratio
                        if i_epoch > args.drop_after else 0)
                    mean_reward = loss_tuple[0].new(1).fill_(0)
                else:
                    # scst training
                    model.eval()
                    position_ids = torch.arange(
                        input_ids.size(1), dtype=input_ids.dtype,
                        device=input_ids.device).unsqueeze(0).expand_as(input_ids)
                    input_dummy = input_ids[:, :args.len_vis_input + 2]  # +2 for [CLS] and [SEP]
                    greedy_res = input_ids.new(
                        input_ids.size(0),
                        input_ids.size(1) - args.len_vis_input - 2).fill_(0)
                    gen_result = input_ids.new(
                        input_ids.size(0),
                        input_ids.size(1) - args.len_vis_input - 2).fill_(0)

                    with torch.no_grad():
                        greedy_res_raw, _ = model(conv_feats, vis_pe, input_dummy,
                                                  segment_ids, position_ids, input_mask,
                                                  task_idx=task_idx, sample_mode='greedy')
                        for b in range(greedy_res_raw.size(0)):
                            for idx in range(greedy_res_raw.size(1)):
                                if greedy_res_raw[b][idx] not in [eos_word_ids, pad_word_ids]:
                                    greedy_res[b][idx] = greedy_res_raw[b][idx]
                                else:
                                    if greedy_res_raw[b][idx] == eos_word_ids:
                                        greedy_res[b][idx] = eos_word_ids
                                    break
                    model.train()
                    gen_result_raw, sample_logprobs = model(
                        conv_feats, vis_pe, input_dummy, segment_ids, position_ids,
                        input_mask, task_idx=task_idx, sample_mode='sample')
                    for b in range(gen_result_raw.size(0)):
                        for idx in range(gen_result_raw.size(1)):
                            if gen_result_raw[b][idx] not in [eos_word_ids, pad_word_ids]:
                                gen_result[b][idx] = gen_result_raw[b][idx]
                            else:
                                if gen_result_raw[b][idx] == eos_word_ids:
                                    gen_result[b][idx] = eos_word_ids
                                break

                    gt_ids = input_ids[:, args.len_vis_input + 2:]
                    reward = get_self_critical_reward(greedy_res, gt_ids, gen_result,
                                                      gt_ids.size(0))
                    reward = torch.from_numpy(reward).float().to(gen_result.device)
                    mean_reward = reward.mean()
                    loss = rl_crit(sample_logprobs, gen_result.data, reward)

                    loss_tuple = [loss, loss.new(1).fill_(0.), loss.new(1).fill_(0.)]

                # disable pretext_loss_deprecated for now
                masked_lm_loss, pretext_loss_deprecated, ans_loss = loss_tuple
                if n_gpu > 1:
                    # mean() to average on multi-gpu. For dist, this is done through gradient addition.
                    masked_lm_loss = masked_lm_loss.mean()
                    pretext_loss_deprecated = pretext_loss_deprecated.mean()
                    ans_loss = ans_loss.mean()
                loss = masked_lm_loss + pretext_loss_deprecated + ans_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())
                train_loss.append(loss.item())
                pretext_loss.append(pretext_loss_deprecated.item())
                vqa2_loss.append(ans_loss.item())
                scst_reward.append(mean_reward.item())
                if step % 100 == 0:
                    logger.info(
                        "Epoch {}, Iter {}, Loss {:.2f}, Pretext {:.2f}, VQA2 {:.2f}, Mean R {:.3f}\n".format(
                            i_epoch, step, np.mean(train_loss), np.mean(pretext_loss),
                            np.mean(vqa2_loss), np.mean(scst_reward)))

                if args.enable_visdom:
                    if vis_window['iter'] is None:
                        vis_window['iter'] = vis.line(
                            X=np.tile(np.arange((i_epoch - 1) * nbatches + step,
                                                (i_epoch - 1) * nbatches + step + 1),
                                      (1, 1)).T,
                            Y=np.column_stack((np.asarray([np.mean(train_loss)]),)),
                            opts=dict(title='Training Loss',
                                      xlabel='Training Iteration',
                                      ylabel='Loss',
                                      legend=['total']))
                    else:
                        vis.line(
                            X=np.tile(np.arange((i_epoch - 1) * nbatches + step,
                                                (i_epoch - 1) * nbatches + step + 1),
                                      (1, 1)).T,
                            Y=np.column_stack((np.asarray([np.mean(train_loss)]),)),
                            opts=dict(title='Training Loss',
                                      xlabel='Training Iteration',
                                      ylabel='Loss',
                                      legend=['total']),
                            win=vis_window['iter'],
                            update='append')

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step / t_total, args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Save a trained model
            logger.info("** ** * Saving fine-tuned model and optimizer ** ** * ")
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
            output_model_file = os.path.join(args.output_dir, "model.{0}.bin".format(i_epoch))
            output_optim_file = os.path.join(args.output_dir, "optim.{0}.bin".format(i_epoch))
            if args.global_rank in (-1, 0):  # save model if the first device or no dist
                torch.save(copy.deepcopy(model_to_save).cpu().state_dict(), output_model_file)
                # torch.save(optimizer.state_dict(), output_optim_file)
                # disable for now; need to sanitize state and ship everything back to cpu

            logger.info("***** CUDA.empty_cache() *****")
            torch.cuda.empty_cache()

            if args.world_size > 1:
                torch.distributed.barrier()
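# `warmup_linear` above comes from pytorch_pretrained_bert.optimization; it is
# not defined in this file. For reference, a minimal sketch of that schedule
# (assuming the classic 0.6.x behavior -- verify against the installed version):
def _warmup_linear_sketch(x, warmup=0.002):
    """Ramp the lr multiplier linearly up over the first `warmup` fraction of
    training, then decay it linearly to 0 at the end of training."""
    if x < warmup:
        return x / warmup
    return 1.0 - x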
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--src_file", default=None, type=str,
                        help="The input data file name.")
    parser.add_argument("--tgt_file", default=None, type=str,
                        help="The output data file name.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--config_path", default=None, type=str,
                        help="Bert config file path.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--log_dir", default='', type=str, required=True,
                        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path", default=None, type=str, required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path", default=None, type=str,
                        help="The file of pretraining optimizer.")

    # Other parameters
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--local_debug", action='store_true',
                        help="Whether to run in local debug mode.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=32, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=64, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing", default=0, type=float,
                        help="Label smoothing rate.")
    parser.add_argument("--weight_decay", default=0.01, type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay", action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--fp32_embedding', action='store_true',
                        help="Whether to use 32-bit float precision instead of 16-bit for embeddings")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp', action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument('--from_scratch', action='store_true',
                        help="Initialize parameters with random values (i.e., training from scratch).")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids', action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a', type=int, default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b', type=int, default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument('--trunc_seg', default='',
                        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument('--always_truncate_tail', action='store_true',
                        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument("--mask_prob", default=0.15, type=float,
                        help="Masking probability. The number of predictions can be "
                             "less than max_pred when the sequence is short.")
    parser.add_argument("--mask_prob_eos", default=0, type=float,
                        help="Masking probability. The number of predictions can be "
                             "less than max_pred when the sequence is short.")
    parser.add_argument('--max_pred', type=int, default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers", default=0, type=int,
                        help="Number of workers for the data loader.")
    parser.add_argument('--mask_source_words', action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb', type=float, default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size', type=int, default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word', action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training', action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument('--has_sentence_oracle', action='store_true',
                        help="Whether to have sentence level oracle for training. "
                             "Only useful for summary generation")
" "Only useful for summary generation") parser.add_argument('--max_position_embeddings', type=int, default=None, help="max position embeddings") parser.add_argument('--relax_projection', action='store_true', help="Use different projection layers for tasks.") parser.add_argument('--ffn_type', default=0, type=int, help="0: default mlp; 1: W((Wx+b) elem_prod x);") parser.add_argument('--num_qkv', default=0, type=int, help="Number of different <Q,K,V>.") parser.add_argument('--seg_emb', action='store_true', help="Using segment embedding for self-attention.") parser.add_argument( '--s2s_special_token', action='store_true', help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.") parser.add_argument('--s2s_add_segment', action='store_true', help="Additional segmental for the encoder of S2S.") parser.add_argument( '--s2s_share_segment', action='store_true', help= "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)." ) parser.add_argument('--pos_shift', action='store_true', help="Using position shift for fine-tuning.") args = parser.parse_args() assert Path( args.model_recover_path).exists(), "--model_recover_path doesn't exist" args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.log_dir, exist_ok=True) json.dump(args.__dict__, open(os.path.join(args.output_dir, 'opt.json'), 'w'), sort_keys=True, indent=2) if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() else: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs dist.init_process_group(backend='nccl') logger.info( "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}". 
    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    if args.do_train:
        print("Loading Train Dataset", args.data_dir)
        bi_uni_pipeline = [seq2seq_loader.Preprocess4Seq2seq(
            args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids, args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={'max_len_a': args.max_len_a,
                             'max_len_b': args.max_len_b,
                             'trunc_seg': args.trunc_seg,
                             'always_truncate_tail': args.always_truncate_tail},
            mask_source_words=args.mask_source_words,
            skipgram_prb=args.skipgram_prb,
            skipgram_size=args.skipgram_size,
            mask_whole_word=args.mask_whole_word,
            mode="s2s",
            has_oracle=args.has_sentence_oracle,
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift)]

        file_oracle = None
        if args.has_sentence_oracle:
            file_oracle = os.path.join(args.data_dir, 'train.oracle')
        fn_src = os.path.join(args.data_dir,
                              args.src_file if args.src_file else 'train.src')
        fn_tgt = os.path.join(args.data_dir,
                              args.tgt_file if args.tgt_file else 'train.tgt')
        # `corpus_preprocessors` is assumed to be defined elsewhere in this module
        train_dataset = seq2seq_loader.Seq2SeqDataset(
            fn_src, fn_tgt, args.train_batch_size, data_tokenizer,
            args.max_seq_length, file_oracle=file_oracle,
            bi_uni_pipeline=bi_uni_pipeline,
            corpus_preprocessors=corpus_preprocessors)
        train_dataset.initial()
        print(len(train_dataset.ex_list))
        print(train_dataset.batch_size)
        # assert 1==0
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=_batch_size,
            sampler=train_sampler, num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

    # c = 0
    # for i_epoch in trange(0, int(args.num_train_epochs)+1, desc="Epoch",
    #                       disable=args.local_rank not in (-1, 0)):
    #     if args.local_rank != -1:
    #         train_sampler.set_epoch(i_epoch)
    #     iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)',
    #                     disable=args.local_rank not in (-1, 0))
    #     for step, batch in enumerate(iter_bar):
    #         batch = [t.to(device) if t is not None else None for t in batch]
    #         if args.has_sentence_oracle:
    #             input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, sop_label, oracle_pos, oracle_weights, oracle_labels = batch
    #         else:
    #             input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, sop_label = batch
    #             oracle_pos, oracle_weights, oracle_labels = None, None, None
    #         c += input_ids.shape[0]
    #         # print(input_ids)
    #         # print(input_ids.shape)
    #         # print(segment_ids)
    #         # print(segment_ids.shape)
    #         # print(is_next)
    #         # print(task_idx)
    #         # print(sop_label)
    #         # print(task_idx.shape)
    #         # for i in range(input_mask.shape[0]):
    #         #     print(input_mask[i])
    # print(c)
    # print(train_dataset.c)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size)
    t_total = int(len(train_dataloader) * args.num_train_epochs /
                  args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model, state_dict=_state_dict,
            num_labels=cls_num_labels, num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path, task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv, seg_emb=args.seg_emb,
            local_debug=args.local_debug)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(
                os.path.join(args.output_dir, "model.{0}.bin".format(recover_step)),
                map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total / args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****", args.model_recover_path)
            model_recover = torch.load(args.model_recover_path, map_location='cpu')
            global_step = 0
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model, state_dict=model_recover,
            num_labels=cls_num_labels, num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path, task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv, seg_emb=args.seg_emb,
            local_debug=args.local_debug)
    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        model.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        model = DDP(model, device_ids=[args.local_rank],
                    output_device=args.local_rank, find_unused_parameters=True)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer, static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(
            os.path.join(args.output_dir, "optim.{0}.bin".format(recover_step)),
            map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        model.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch, int(args.num_train_epochs) + 1,
                              desc="Epoch", disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)',
                            disable=args.local_rank not in (-1, 0))
            for step, batch in enumerate(iter_bar):
                batch = [t.to(device) if t is not None else None for t in batch]
                if args.has_sentence_oracle:
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, \
                        masked_pos, masked_weights, is_next, task_idx, sop_label, \
                        oracle_pos, oracle_weights, oracle_labels = batch
                else:
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, \
                        masked_pos, masked_weights, is_next, task_idx, sop_label = batch
                    print(sop_label)
                    print(task_idx)
                    oracle_pos, oracle_weights, oracle_labels = None, None, None
                # loss_tuple = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next,
                #                    masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx,
                #                    masked_pos_2=oracle_pos, masked_weights_2=oracle_weights,
                #                    masked_labels_2=oracle_labels, mask_qkv=mask_qkv)
                loss_tuple = model(input_ids, segment_ids, input_mask,
                                   lm_label_ids, sop_label,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   masked_pos_2=oracle_pos,
                                   masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels,
                                   mask_qkv=mask_qkv)
                masked_lm_loss, next_sentence_loss = loss_tuple
                if n_gpu > 1:
                    # mean() to average on multi-gpu.
                    # loss = loss.mean()
                    masked_lm_loss = masked_lm_loss.mean()
                    next_sentence_loss = next_sentence_loss.mean()
                print('mask_lm_loss {}'.format(masked_lm_loss))
                print('next_sentence_loss {}'.format(next_sentence_loss))
                print('----------------------------------------------')
                loss = masked_lm_loss + next_sentence_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step / t_total, args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Save a trained model
            if (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                logger.info("** ** * Saving fine-tuned model and optimizer ** ** * ")
                model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
                output_model_file = os.path.join(args.output_dir,
                                                 "model.{0}.bin".format(i_epoch))
                torch.save(model_to_save.state_dict(), output_model_file)
                output_optim_file = os.path.join(args.output_dir,
                                                 "optim.{0}.bin".format(i_epoch))
                torch.save(optimizer.state_dict(), output_optim_file)

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
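# `_get_max_epoch_model` is called by both training mains above but defined
# elsewhere in this codebase. A hypothetical sketch inferred from the call
# sites (name and details assumed, not the repo's actual implementation): it
# returns the largest epoch N for which both "model.N.bin" and "optim.N.bin"
# checkpoints exist in output_dir, or None when there is nothing to recover.
def _get_max_epoch_model_sketch(output_dir):
    import glob
    import os
    fn_model_list = glob.glob(os.path.join(output_dir, "model.*.bin"))
    fn_optim_list = glob.glob(os.path.join(output_dir, "optim.*.bin"))
    if (not fn_model_list) or (not fn_optim_list):
        return None
    # "model.12.bin" -> 12
    epochs = lambda fns: set(int(os.path.basename(fn).split('.')[1]) for fn in fns)
    both = epochs(fn_model_list) & epochs(fn_optim_list)
    return max(both) if both else None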
def main():
    parser = argparse.ArgumentParser()

    # General
    parser.add_argument("--bert_model", default="bert-base-cased", type=str,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")

    # For decoding
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument('--seed', type=int, default=123,
                        help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size', type=int, default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")
    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--max_tgt_length', type=int, default=20,
                        help="maximum length of target sequence")

    # Others for VLP
    parser.add_argument("--src_file", default='/mnt/dat/COCO/annotations/dataset_coco.json',
                        type=str, help="The input data file name.")
    parser.add_argument("--ref_file", default='pythia/data/v2_mscoco_val2014_annotations.json',
                        type=str, help="The annotation reference file name.")
    parser.add_argument('--dataset', default='coco', type=str,
                        help='coco | flickr30k | cc')
    parser.add_argument('--len_vis_input', type=int, default=100)
    # parser.add_argument('--resnet_model', type=str, default='imagenet_weights/resnet101.pth')
    parser.add_argument('--image_root', type=str, default='/mnt/dat/COCO/images')
    parser.add_argument('--split', type=str, default='val')
    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--enable_butd', action='store_true',
                        help='set to take in region features')
    parser.add_argument('--region_bbox_file',
                        default='coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5',
                        type=str)
    parser.add_argument('--region_det_file_prefix',
                        default='feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval',
                        type=str)
    parser.add_argument("--output_dir", default='tmp', type=str,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument('--file_valid_jpgs', default='', type=str)

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    if args.enable_butd:
        assert args.len_vis_input == 100
        args.region_bbox_file = os.path.join(args.image_root, args.region_bbox_file)
        args.region_det_file_prefix = os.path.join(
            args.image_root, args.region_det_file_prefix) if args.dataset in (
                'cc', 'coco') and args.region_det_file_prefix != '' else ''

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case)

    args.max_seq_length = args.max_tgt_length + args.len_vis_input + 3  # +3 for 2x[SEP] and [CLS]
    tokenizer.max_len = args.max_seq_length

    bi_uni_pipeline = [seq2seq_loader.Preprocess4Seq2seq(
        0, 0, list(tokenizer.vocab.keys()),
        tokenizer.convert_tokens_to_ids, args.max_seq_length,
        new_segment_ids=args.new_segment_ids,
        truncate_config={'max_len_a': args.len_vis_input,
                         'max_len_b': args.max_tgt_length,
                         'trunc_seg': 'b',
                         'always_truncate_tail': True},
        mode="bi", len_vis_input=args.len_vis_input,
        enable_butd=args.enable_butd,
        region_bbox_file=args.region_bbox_file,
        region_det_file_prefix=args.region_det_file_prefix,
        load_vqa_ann=True)]

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2

    logger.info('Attempting to recover models from: {}'.format(args.model_recover_path))
    if 0 == len(glob.glob(args.model_recover_path.strip())):
        logger.error('There are no models to recover. The program will exit.')
        sys.exit(1)

    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model, state_dict=model_recover,
            num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size,
            task_idx=0,
            max_position_embeddings=512,
            cache_dir=args.output_dir + '/.pretrained_model_{}'.format(-1),
            drop_prob=args.drop_prob,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input,
            tasks='vqa2')
        del model_recover

        # deprecated
        # from vlp.resnet import resnet
        # cnn = resnet(args.resnet_model, _num_layers=101, _fixed_block=4, pretrained=True)  # no finetuning

        if args.fp16:
            model.half()
            # cnn.half()
        model.to(device)
        # cnn.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)
            # cnn = torch.nn.DataParallel(cnn)

        torch.cuda.empty_cache()
        model.eval()
        # cnn.eval()

        eval_lst = []
        img_dat = np.load(args.src_file, allow_pickle=True)
        img_idx = 0
        for i in range(1, img_dat.shape[0]):
            if args.enable_butd:
                src_tk = os.path.join(args.image_root,
                                      img_dat[i]['image_name'].split('_')[1],
                                      img_dat[i]['feature_path'])
            else:
                raise NotImplementedError
            tgt_tk = tokenizer.tokenize(img_dat[i]['question_str'])
            eval_lst.append((img_idx, src_tk, tgt_tk, img_dat[i]['question_id']))
            img_idx += 1
        input_lines = eval_lst

        next_i = 0
        output_lines = [""] * len(input_lines)
        score_trace_list = [None] * len(input_lines)
        total_batch = math.ceil(len(input_lines) / args.batch_size)
        predictions = []

        print('start the VQA evaluation...')
        with tqdm(total=total_batch) as pbar:
            while next_i < len(input_lines):
                _chunk = input_lines[next_i:next_i + args.batch_size]
                buf = [(x[1], x[2]) for x in _chunk]
                buf_id = [(x[0], x[3]) for x in _chunk]
                next_i += args.batch_size
                instances = []
                for instance in buf:
                    for proc in bi_uni_pipeline:
                        instances.append(proc(instance[:2] + ({'answers': ['dummy']},)))
                with torch.no_grad():
                    batch = batch_list_to_batch_tensors(instances)
                    batch = [t.to(device) for t in batch]
                    input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, \
                        masked_weights, is_next, task_idx, img, vis_masked_pos, \
                        vis_pe, _ = batch
                    if args.fp16:
                        img = img.half()
                        vis_pe = vis_pe.half()

                    if args.enable_butd:
                        conv_feats = img.data  # Bx100x2048
                        vis_pe = vis_pe.data
                    else:
                        conv_feats, _ = cnn(img.data)  # Bx2048x7x7
                        conv_feats = conv_feats.view(conv_feats.size(0), conv_feats.size(1),
                                                     -1).permute(0, 2, 1).contiguous()

                    ans_idx = model(conv_feats, vis_pe, input_ids, segment_ids,
                                    input_mask, lm_label_ids, None, is_next,
                                    masked_pos=masked_pos,
                                    masked_weights=masked_weights,
                                    task_idx=task_idx,
                                    vis_masked_pos=vis_masked_pos,
                                    drop_worst_ratio=0,
                                    vqa_inference=True)

                    for ind, (eval_idx, ques_id) in enumerate(buf_id):
                        predictions.append({
                            'question_id': ques_id,
                            'answer': bi_uni_pipeline[0].ans_proc.idx2word(ans_idx[ind])})
                pbar.update(1)

        results_file = os.path.join(
            args.output_dir,
            'vqa2-results-' + args.model_recover_path.split('/')[-2] + '-' +
            args.split + '-' +
            args.model_recover_path.split('/')[-1].split('.')[-2] + '.json')
        json.dump(predictions, open(results_file, 'w'))

        if args.split == 'test2015':
            print('*' * 80)
            print('[WARNING] Evaluation unavailable for the test set!'
                  '\nPlease submit your saved JSON file named'
                  '\n`{}`'
                  '\nto the VQA 2.0 server:'
                  '\nhttps://evalai.cloudcv.org/web/challenges/challenge-page/163/submission'.format(results_file))
            print('*' * 80)
        else:
            import subprocess
            print('Evaluating result file {}'.format(results_file))
            subprocess.Popen(['python', 'pythia/pythia/legacy/eval_model/eval_demo.py',
                              args.ref_file, results_file])
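# The dumped results file is a flat JSON list in the format expected by the
# VQA 2.0 evaluation, e.g. (illustrative values only):
#
# [
#     {"question_id": 458752000, "answer": "yes"},
#     {"question_id": 458752001, "answer": "2"}
# ]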
def main():
    args = process_args()

    if args.loss_type == 'mlm':
        assert args.neg_num == 0 and args.multiple_neg == 0
    elif args.loss_type == 'nsp':
        assert int(args.bi_prob) == 1 and args.max_pred == 0 and args.neg_num > 0
    if args.adaptive_weight == 1:
        assert args.neg_num > 1
    if args.add_boundary == 1:
        assert args.inc_full_hist

    if args.world_size > 1:
        print('global_rank: {}, local rank: {}'.format(args.global_rank, args.local_rank))

    # Input format: [CLS] img [SEP] hist [SEP_0] ques [SEP_1] ans [SEP]
    args.max_seq_length = args.len_vis_input + 2 + args.max_len_hist_ques + 2 + args.max_len_ans + 1
    args.mask_image_regions = (args.vis_mask_prob > 0)  # whether to mask out image regions
    args.dist_url = args.dist_url.replace('[PT_OUTPUT_DIR]', args.output_dir)

    # arguments inspection
    assert args.enable_butd, 'only support region attn! featmap attn deprecated'
    if args.enable_butd:
        if args.visdial_v == '1.0':
            assert (args.len_vis_input == 36) or (args.len_vis_input == 0)
        elif args.visdial_v == '0.9':
            if (args.len_vis_input == 100):
                args.region_bbox_file = os.path.join(args.image_root, args.region_bbox_file)
                args.region_det_file_prefix = os.path.join(
                    args.image_root, args.region_det_file_prefix) if args.dataset in (
                        'cc', 'coco') and args.region_det_file_prefix != '' else ''

    # output config
    os.makedirs(args.output_dir, exist_ok=True)
    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True, indent=2)

    logging.basicConfig(
        filename=os.path.join(args.output_dir, args.log_file),
        filemode='w',
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger = logging.getLogger(__name__)
    # also log to stdout
    ch = logging.StreamHandler(sys.stdout)
    ch.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s'))
    ch.setLevel(logging.INFO)
    logger.addHandler(ch)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl',
                                             init_method=args.dist_url,
                                             world_size=args.world_size,
                                             rank=args.global_rank)
    logger.info('Arguments: %s\n' % (' '.join(sys.argv[:])))
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # plotting loss, optional
    if args.enable_visdom:
        import visdom
        vis = visdom.Visdom(port=args.visdom_port, env=args.output_dir)
        vis_window = {'iter': None, 'score': None}

    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank))
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer

    assert args.do_train
    logger.info('Max seq length: %d, batch size: %d\n' % (args.max_seq_length,
                                                          args.train_batch_size))
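    # Worked example of the input-format arithmetic near the top of this
    # function (illustrative values, not defaults): with len_vis_input=36,
    # max_len_hist_ques=200 and max_len_ans=20, the sequence
    #   [CLS] img [SEP] hist [SEP_0] ques [SEP_1] ans [SEP]
    # occupies 36 + 2 + 200 + 2 + 20 + 1 = 261 positions: two special tokens
    # around the image regions, two separators inside the text, and one final [SEP].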
args.train_batch_size)) bi_uni_pipeline = [Preprocess4TrainVisdial(args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, truncate_config={'len_vis_input': args.len_vis_input, 'max_len_hist_ques': args.max_len_hist_ques, 'max_len_ans': args.max_len_ans}, mask_image_regions=args.mask_image_regions, mode="s2s", vis_mask_prob=args.vis_mask_prob, region_bbox_file=args.region_bbox_file, region_det_file_prefix=args.region_det_file_prefix, image_features_hdfpath=args.image_features_hdfpath, visdial_v=args.visdial_v, pad_hist=args.pad_hist, finetune=args.finetune, only_mask_ans=args.only_mask_ans, add_boundary=args.add_boundary, only_qa=args.only_qa), Preprocess4TrainVisdial(args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, truncate_config={'len_vis_input': args.len_vis_input, 'max_len_hist_ques': args.max_len_hist_ques, 'max_len_ans': args.max_len_ans}, mask_image_regions=args.mask_image_regions, mode="bi", vis_mask_prob=args.vis_mask_prob, region_bbox_file=args.region_bbox_file, region_det_file_prefix=args.region_det_file_prefix, image_features_hdfpath=args.image_features_hdfpath, visdial_v=args.visdial_v, pad_hist=args.pad_hist, finetune=args.finetune, only_mask_ans=args.only_mask_ans, add_boundary=args.add_boundary, only_qa=args.only_qa)] train_dataset = VisdialDataset( args.src_file, args.train_batch_size, data_tokenizer, use_num_imgs=args.use_num_imgs, bi_uni_pipeline=bi_uni_pipeline, s2s_prob=args.s2s_prob, bi_prob=args.bi_prob, is_train=args.do_train, neg_num=args.neg_num, inc_gt_rel=args.inc_gt_rel, inc_full_hist=args.inc_full_hist, just_for_pretrain=args.just_for_pretrain, sub_sample=args.sub_sample) if args.world_size == 1: train_sampler = RandomSampler(train_dataset, replacement=False) else: train_sampler = DistributedSampler(train_dataset) train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, sampler=train_sampler, num_workers=args.num_workers, collate_fn=batch_list_to_batch_tensors, pin_memory=True) # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps) t_total = int(len(train_dataloader) * args.num_train_epochs * 1. 
/ args.gradient_accumulation_steps) amp_handle = None if args.fp16 and args.amp: from apex import amp amp_handle = amp.init(enable_caching=True) logger.info("enable fp16 with amp") # Prepare model cls_num_labels = 2 type_vocab_size = 6 if args.new_segment_ids else 2 relax_projection = 4 if args.relax_projection else 0 task_idx_proj = 3 if args.tasks == 'img2txt' else 0 mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids( ["[MASK]", "[SEP]", "[PAD]"]) # index in BERT vocab: 103, 102, 0 if (args.model_recover_path is None): # if _state_dict == {}, the parameters are randomly initialized # if _state_dict == None, the parameters are initialized with bert-init assert args.scst == False, 'must init from maximum likelihood training' _state_dict = {} if args.from_scratch else None model = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=_state_dict, num_labels=cls_num_labels, type_vocab_size=type_vocab_size, relax_projection=relax_projection, config_path=args.config_path, task_idx=task_idx_proj, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank), drop_prob=args.drop_prob, enable_butd=args.enable_butd, len_vis_input=args.len_vis_input, visdial_v=args.visdial_v, loss_type=args.loss_type, neg_num=args.neg_num, adaptive_weight=args.adaptive_weight, add_attn_fuse=args.add_attn_fuse, no_h0=args.no_h0, no_vision=args.no_vision) global_step = 0 else: if args.model_recover_path: logger.info("***** Recover model: %s *****", args.model_recover_path) model_recover = torch.load(args.model_recover_path) global_step = 0 model = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=model_recover, num_labels=cls_num_labels, type_vocab_size=type_vocab_size, relax_projection=relax_projection, config_path=args.config_path, task_idx=task_idx_proj, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank), drop_prob=args.drop_prob, enable_butd=args.enable_butd, len_vis_input=args.len_vis_input, visdial_v=args.visdial_v, loss_type=args.loss_type, neg_num=args.neg_num, adaptive_weight=args.adaptive_weight, add_attn_fuse=args.add_attn_fuse, no_h0=args.no_h0, no_vision=args.no_vision) del model_recover torch.cuda.empty_cache() if args.fp16: model.half() # cnn.half() if args.fp32_embedding: model.bert.embeddings.word_embeddings.float() model.bert.embeddings.position_embeddings.float() model.bert.embeddings.token_type_embeddings.float() model.to(device) if args.local_rank != -1: try: from torch.nn.parallel import DistributedDataParallel as DDP except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) elif n_gpu > 1: model = DataParallelImbalance(model) # Prepare optimizer param_optimizer = list(model.named_parameters()) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [ {'params': [p for n, p in param_optimizer if not any( nd in n for nd in no_decay)], 'weight_decay': 0.01}, {'params': [p for n, p in param_optimizer if any( nd in n for nd in no_decay)], 'weight_decay': 0.0} ] if args.fp16: try: # from apex.optimizers import FP16_Optimizer 
from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State from apex.optimizers import FusedAdam except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.learning_rate, bias_correction=False, max_grad_norm=1.0) if args.loss_scale == 0: optimizer = FP16_Optimizer_State( optimizer, dynamic_loss_scale=True) else: optimizer = FP16_Optimizer_State( optimizer, static_loss_scale=args.loss_scale) else: optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, schedule=args.sche_mode, t_total=t_total) logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.do_train: logger.info("***** Running training *****") logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Num steps = %d", t_total) logger.info(" Loader length = %d", len(train_dataloader)) model.train() start_epoch = 1 logger.info("Begin training from epoch = %d", start_epoch) t0 = time.time() for i_epoch in trange(start_epoch, args.num_train_epochs + 1, desc="Epoch"): if args.multiple_neg and i_epoch > 1: train_dataset = VisdialDataset( args.src_file, args.train_batch_size, data_tokenizer, use_num_imgs=args.use_num_imgs, bi_uni_pipeline=bi_uni_pipeline, s2s_prob=args.s2s_prob, bi_prob=args.bi_prob, is_train=args.do_train, neg_num=args.neg_num, inc_gt_rel=args.inc_gt_rel, inc_full_hist=args.inc_full_hist, just_for_pretrain=args.just_for_pretrain, sub_sample=args.sub_sample) if args.world_size == 1: train_sampler = RandomSampler(train_dataset, replacement=False) else: train_sampler = DistributedSampler(train_dataset) train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, sampler=train_sampler, num_workers=args.num_workers, collate_fn=batch_list_to_batch_tensors, pin_memory=True) if args.local_rank >= 0: train_sampler.set_epoch(i_epoch - 1) iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)') nbatches = len(train_dataloader) losses = [] pretext_loss = [] mlm_losses = [] nsp_losses = [] for step, batch in enumerate(iter_bar): batch = [t.to(device) for t in batch] input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, is_next, \ task_idx, vis_masked_pos, img, vis_pe = batch if args.fp16: img = img.half() vis_pe = vis_pe.half() if args.enable_butd: conv_feats = img.data # Bx100x2048 vis_pe = vis_pe.data loss_tuple = model(conv_feats, vis_pe, input_ids, segment_ids, input_mask, lm_label_ids, is_next, masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx, vis_masked_pos=vis_masked_pos, mask_image_regions=args.mask_image_regions, drop_worst_ratio=args.max_drop_worst_ratio if i_epoch > args.drop_after else 0) # disable pretext_loss_deprecated for now masked_lm_loss, pretext_loss_deprecated, nsp_loss = loss_tuple if n_gpu > 1: # mean() to average on multi-gpu. For dist, this is done through gradient addition. 
masked_lm_loss = masked_lm_loss.mean() pretext_loss_deprecated = pretext_loss_deprecated.mean() nsp_loss = nsp_loss.mean() loss = masked_lm_loss + pretext_loss_deprecated + nsp_loss # logging for each step (i.e., before normalization by args.gradient_accumulation_steps) iter_bar.set_description('Iter (loss=%5.3f)' % loss.item()) losses.append(loss.item()) mlm_losses.append(masked_lm_loss.item()) pretext_loss.append(pretext_loss_deprecated.item()) nsp_losses.append(nsp_loss.item()) if step % max(1, nbatches // 10) == 0: logger.info( "Epoch {}, Iter {}, Loss {:.4f}, MLM {:.4f}, NSP {:.4f}, Elapsed time {:.2f}\n".format( i_epoch, step, np.mean(losses), np.mean(mlm_losses), np.mean(nsp_losses), time.time() - t0)) if args.enable_visdom: if vis_window['iter'] is None: vis_window['iter'] = vis.line( X=np.tile(np.arange((i_epoch - 1) * nbatches + step, (i_epoch - 1) * nbatches + step + 1), (1, 1)).T, Y=np.column_stack((np.asarray([np.mean(losses)]),)), opts=dict(title='Training Loss', xlabel='Training Iteration', ylabel='Loss', legend=['total']) ) else: vis.line( X=np.tile(np.arange((i_epoch - 1) * nbatches + step, (i_epoch - 1) * nbatches + step + 1), (1, 1)).T, Y=np.column_stack((np.asarray([np.mean(losses)]),)), opts=dict(title='Training Loss', xlabel='Training Iteration', ylabel='Loss', legend=['total']), win=vis_window['iter'], update='append' ) # ensure that accumulated gradients are normalized if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps if args.fp16: optimizer.backward(loss) if amp_handle: amp_handle._clear_cache() else: loss.backward() if (step + 1) % args.gradient_accumulation_steps == 0: lr_this_step = args.learning_rate * \ warmup_linear(global_step / t_total, args.warmup_proportion) if args.fp16: # modify learning rate with special warm up BERT uses for param_group in optimizer.param_groups: param_group['lr'] = lr_this_step optimizer.step() optimizer.zero_grad() global_step += 1 # Save a trained model logger.info( "** ** * Saving fine-tuned model and optimizer ** ** * ") model_to_save = model.module if hasattr( model, 'module') else model # Only save the model itself output_model_file = os.path.join( args.output_dir, "model.%d.%.3f.bin" % (i_epoch, np.mean(losses))) if args.global_rank in (-1, 0): # save model if the first device or no dist torch.save(copy.deepcopy(model_to_save).cpu().state_dict(), output_model_file) logger.info("Saved model to %s", output_model_file) logger.info("Finished training epoch %d, avg loss: %.2f, took %.2f seconds" % ( i_epoch, np.mean(losses), time.time() - t0)) logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.world_size > 1: torch.distributed.barrier()
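# NOTE (editor's sketch): warmup_linear() is called in the update step above but is
# not defined in this excerpt. A minimal sketch, assuming the classic
# pytorch_pretrained_bert schedule (newer releases of that library decay as
# max((x - 1.) / (warmup - 1.), 0.) instead of 1.0 - x):
def warmup_linear(x, warmup=0.002):
    # x is the fraction of training completed, i.e. global_step / t_total;
    # ramp the LR multiplier linearly from 0 to 1 during warmup, then decay it.
    if x < warmup:
        return x / warmup
    return 1.0 - x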
def main(): parser = argparse.ArgumentParser() # Required parameters parser.add_argument( "--data_dir", default=None, type=str, required=True, help= "The input data dir. Should contain the .tsv files (or other data files) for the task." ) parser.add_argument("--src_file", default=None, type=str, help="The input data file name.") parser.add_argument("--topic_model_recover_path", default=None, type=str, help="The file of fine-tuned pretraining topic model.") parser.add_argument("--topic_model_dict_path", default=None, type=str, help="The dictionary file of the pretrained topic model.") parser.add_argument("--tgt_file", default=None, type=str, help="The output data file name.") parser.add_argument( "--bert_model", default=None, type=str, required=True, help="Bert pre-trained model selected in the list: bert-base-uncased, " "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese." ) parser.add_argument("--config_path", default=None, type=str, help="Bert config file path.") parser.add_argument( "--output_dir", default=None, type=str, required=True, help= "The output directory where the model predictions and checkpoints will be written." ) parser.add_argument( "--log_dir", default='', type=str, required=True, help="The output directory where the log will be written.") parser.add_argument("--model_recover_path", default=None, type=str, required=True, help="The file of fine-tuned pretraining model.") parser.add_argument("--optim_recover_path", default=None, type=str, help="The file of pretraining optimizer.") parser.add_argument('--topic_mode', default=1, type=float, help="1: idea1; 1.1: idea1 without theta; 2: idea2") parser.add_argument('--topic_model', default=False, type=bool, help="if set, train only the topic model") # caveat: argparse type=bool treats any non-empty string (even "False") as True # Other parameters parser.add_argument( "--max_seq_length", default=192, type=int, help= "The maximum total input sequence length after WordPiece tokenization. \n" "Sequences longer than this will be truncated, and sequences shorter \n" "than this will be padded.") parser.add_argument("--do_train", action='store_true', help="Whether to run training.") parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.") parser.add_argument( "--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") parser.add_argument( "--train_batch_size", default=32, type=int, help="Total batch size for training.") #batch_size = batch_size/n_gpus parser.add_argument("--eval_batch_size", default=16, type=int, help="Total batch size for eval.") parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument("--label_smoothing", default=0, type=float, help="Label smoothing rate.") parser.add_argument("--weight_decay", default=0.01, type=float, help="The weight decay rate for Adam.") parser.add_argument("--finetune_decay", action='store_true', help="Weight decay to the original weights.") parser.add_argument("--num_train_epochs", default=30, type=float, help="Total number of training epochs to perform.") parser.add_argument( "--warmup_proportion", default=0.1, type=float, help= "Proportion of training to perform linear learning rate warmup for. 
" "E.g., 0.1 = 10%% of training.") parser.add_argument("--hidden_dropout_prob", default=0.1, type=float, help="Dropout rate for hidden states.") parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float, help="Dropout rate for attention probabilities.") parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available") parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") parser.add_argument( '--gradient_accumulation_steps', type=int, default=1, help= "Number of updates steps to accumulate before performing a backward/update pass." ) parser.add_argument( '--fp16', action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") parser.add_argument( '--fp32_embedding', action='store_true', help= "Whether to use 32-bit float precision instead of 16-bit for embeddings" ) parser.add_argument( '--loss_scale', type=float, default=0, help= "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" "0 (default value): dynamic loss scaling.\n" "Positive power of 2: static loss scaling value.\n") parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16") parser.add_argument( '--from_scratch', action='store_true', help= "Initialize parameters with random values (i.e., training from scratch)." ) parser.add_argument('--new_segment_ids', action='store_true', help="Use new segment ids for bi-uni-directional LM.") parser.add_argument('--new_pos_ids', action='store_true', help="Use new position ids for LMs.") parser.add_argument('--tokenized_input', action='store_true', help="Whether the input is tokenized.") parser.add_argument('--max_len_a', type=int, default=0, help="Truncate_config: maximum length of segment A.") parser.add_argument('--max_len_b', type=int, default=0, help="Truncate_config: maximum length of segment B.") parser.add_argument( '--trunc_seg', default='', help="Truncate_config: first truncate segment A/B (option: a, b).") parser.add_argument( '--always_truncate_tail', action='store_true', help="Truncate_config: Whether we should always truncate tail.") parser.add_argument( "--mask_prob", default=0.15, type=float, help= "Number of prediction is sometimes less than max_pred when sequence is short." ) parser.add_argument( "--mask_prob_eos", default=0, type=float, help= "Number of prediction is sometimes less than max_pred when sequence is short." ) parser.add_argument('--max_pred', type=int, default=20, help="Max tokens of prediction.") parser.add_argument("--num_workers", default=0, type=int, help="Number of workers for the data loader.") parser.add_argument('--mask_source_words', action='store_true', help="Whether to mask source words for training") parser.add_argument('--skipgram_prb', type=float, default=0.0, help='prob of ngram mask') parser.add_argument('--skipgram_size', type=int, default=1, help='the max size of ngram mask') parser.add_argument('--mask_whole_word', action='store_true', help="Whether masking a whole word.") parser.add_argument('--do_l2r_training', action='store_true', help="Whether to do left to right training") parser.add_argument( '--has_sentence_oracle', action='store_true', help="Whether to have sentence level oracle for training. 
" "Only useful for summary generation") parser.add_argument('--max_position_embeddings', type=int, default=None, help="max position embeddings") parser.add_argument('--relax_projection', action='store_true', help="Use different projection layers for tasks.") parser.add_argument('--ffn_type', default=0, type=int, help="0: default mlp; 1: W((Wx+b) elem_prod x);") parser.add_argument('--num_qkv', default=0, type=int, help="Number of different <Q,K,V>.") parser.add_argument('--seg_emb', action='store_true', help="Using segment embedding for self-attention.") parser.add_argument( '--s2s_special_token', action='store_true', help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.") parser.add_argument('--s2s_add_segment', action='store_true', help="Additional segmental for the encoder of S2S.") parser.add_argument( '--s2s_share_segment', action='store_true', help= "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)." ) parser.add_argument('--pos_shift', action='store_true', help="Using position shift for fine-tuning.") args = parser.parse_args() assert Path( args.model_recover_path).exists(), "--model_recover_path doesn't exist" args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]', os.getenv('PT_OUTPUT_DIR', '')) os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.log_dir, exist_ok=True) json.dump(args.__dict__, open(os.path.join(args.output_dir, 'opt.json'), 'w'), sort_keys=True, indent=2) print("args.local_rank", args.local_rank) print("args.no_cuda", args.no_cuda) if args.local_rank == -1 or args.no_cuda: #-1 False device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") #device = cuda n_gpu = torch.cuda.device_count() print("n_gpu_1", n_gpu) else: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs dist.init_process_group(backend='nccl') print("n_gpu_1", n_gpu) logger.info( "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}". 
format(device, n_gpu, bool(args.local_rank != -1), args.fp16)) if args.gradient_accumulation_steps < 1: raise ValueError( "Invalid gradient_accumulation_steps parameter: {}, should be >= 1" .format(args.gradient_accumulation_steps)) args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) if not args.do_train and not args.do_eval: raise ValueError( "At least one of `do_train` or `do_eval` must be True.") if args.local_rank not in (-1, 0): # Make sure only the first process in distributed training will download model & vocab dist.barrier() tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) if args.max_position_embeddings: tokenizer.max_len = args.max_position_embeddings data_tokenizer = WhitespaceTokenizer( ) if args.tokenized_input else tokenizer if args.local_rank == 0: dist.barrier() if args.do_train: bi_uni_pipeline = [ seq2seq_loader.Preprocess4Seq2seq( args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, truncate_config={ 'max_len_a': args.max_len_a, 'max_len_b': args.max_len_b, 'trunc_seg': args.trunc_seg, 'always_truncate_tail': args.always_truncate_tail }, mask_source_words=args.mask_source_words, skipgram_prb=args.skipgram_prb, skipgram_size=args.skipgram_size, mask_whole_word=args.mask_whole_word, mode="s2s", has_oracle=args.has_sentence_oracle, num_qkv=args.num_qkv, s2s_special_token=args.s2s_special_token, s2s_add_segment=args.s2s_add_segment, s2s_share_segment=args.s2s_share_segment, pos_shift=args.pos_shift) ] file_oracle = None if args.has_sentence_oracle: file_oracle = os.path.join(args.data_dir, 'train.oracle') fn_src = os.path.join(args.data_dir, args.src_file if args.src_file else 'train.src') fn_tgt = os.path.join(args.data_dir, args.tgt_file if args.tgt_file else 'train.tgt') train_dataset = seq2seq_loader.Seq2SeqDataset( fn_src, fn_tgt, args.data_dir, args.topic_model_dict_path, args.train_batch_size, data_tokenizer, args.max_seq_length, file_oracle=file_oracle, bi_uni_pipeline=bi_uni_pipeline) if args.local_rank == -1: train_sampler = RandomSampler(train_dataset, replacement=False) _batch_size = args.train_batch_size else: train_sampler = DistributedSampler(train_dataset) _batch_size = args.train_batch_size // dist.get_world_size() train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=_batch_size, sampler=train_sampler, num_workers=args.num_workers, collate_fn=seq2seq_loader.batch_list_to_batch_tensors, pin_memory=False) # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps) # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size) t_total = int( len(train_dataloader) * args.num_train_epochs / args.gradient_accumulation_steps) amp_handle = None if args.fp16 and args.amp: from apex import amp amp_handle = amp.init(enable_caching=True) logger.info("enable fp16 with amp") # Prepare model recover_step = _get_max_epoch_model(args.output_dir) cls_num_labels = 2 type_vocab_size = 6 + \ (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2 ### type_vocab_size=6 num_sentlvl_labels = 2 if args.has_sentence_oracle else 0 relax_projection = 4 if args.relax_projection else 0 if args.local_rank not in (-1, 0): # Make sure only the first process in distributed training will download model & vocab 
dist.barrier() if (recover_step is None) and (args.model_recover_path is None): # if _state_dict == {}, the parameters are randomly initialized # if _state_dict == None, the parameters are initialized with bert-init _state_dict = {} if args.from_scratch else None unilm = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=_state_dict, num_labels=cls_num_labels, num_rel=0, type_vocab_size=type_vocab_size, config_path=args.config_path, task_idx=3, num_sentlvl_labels=num_sentlvl_labels, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, relax_projection=relax_projection, new_pos_ids=args.new_pos_ids, ffn_type=args.ffn_type, hidden_dropout_prob=args.hidden_dropout_prob, attention_probs_dropout_prob=args.attention_probs_dropout_prob, num_qkv=args.num_qkv, seg_emb=args.seg_emb) global_step = 0 else: if recover_step: logger.info("***** Recover model: %d *****", recover_step) model_recover = torch.load(os.path.join( args.output_dir, "model.{0}.bin".format(recover_step)), map_location='cpu') # recover_step == number of epochs global_step = math.floor(recover_step * t_total / args.num_train_epochs) elif args.model_recover_path: # here is the entrance logger.info("***** Recover model: %s *****", args.model_recover_path) model_recover = torch.load(args.model_recover_path, map_location='cpu') global_step = 0 unilm = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=model_recover, num_labels=cls_num_labels, num_rel=0, type_vocab_size=type_vocab_size, config_path=args.config_path, task_idx=3, num_sentlvl_labels=num_sentlvl_labels, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, relax_projection=relax_projection, new_pos_ids=args.new_pos_ids, ffn_type=args.ffn_type, hidden_dropout_prob=args.hidden_dropout_prob, attention_probs_dropout_prob=args.attention_probs_dropout_prob, num_qkv=args.num_qkv, seg_emb=args.seg_emb) #1. 
Model initialization: load the GSM topic model and restore its checkpoint. gsm = GSM(train_dataset.vocabsize) gsm_checkpoint = torch.load(args.topic_model_recover_path) gsm.load_state_dict(gsm_checkpoint["net"]) if args.local_rank == 0: dist.barrier() if args.fp16: unilm.half() gsm.half() if args.fp32_embedding: unilm.bert.embeddings.word_embeddings.float() unilm.bert.embeddings.position_embeddings.float() unilm.bert.embeddings.token_type_embeddings.float() unilm.to(device) gsm.to(device) if args.local_rank != -1: try: from torch.nn.parallel import DistributedDataParallel as DDP except ImportError: raise ImportError("DistributedDataParallel is not available.") unilm = DDP(unilm, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) elif n_gpu > 1: # model = torch.nn.DataParallel(model) unilm = DataParallelImbalance(unilm) gsm = DataParallelImbalance(gsm) # Prepare optimizer total = 0 param_optimizer = list(unilm.named_parameters()) param_optimizer_topic = list(gsm.named_parameters()) for name, parameters in unilm.named_parameters(): if "idea" in name: if "11" in name and "idea2" in name: total += np.prod(parameters.size()) # print(name, ':', parameters.size()) else: total += np.prod(parameters.size()) # print(name, ':', parameters.size()) print("gsm has {} parameters in total".format( sum(x.numel() for x in gsm.parameters()))) print("Number of parameters: %.6fM" % (total / 1e6)) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] if not args.topic_model: optimizer_grouped_parameters = [{ 'params': [ p for n, p in param_optimizer if not any(nd in n for nd in no_decay) ], 'weight_decay': 0.01, 'topic': False }, { 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'topic': False }, { 'params': [p for n, p in param_optimizer_topic], 'weight_decay': 0.0, 'lr': 1e-3, 'topic': True }] else: optimizer_grouped_parameters = [{ 'params': [p for n, p in param_optimizer_topic], 'weight_decay': 0.0, 'lr': 1e-3, 'topic': True }] # one parameter group uses weight decay, the other gets no weight_decay # print("optimizer_grouped_parameters", optimizer_grouped_parameters) if args.fp16: try: # from apex.optimizers import FP16_Optimizer from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State from apex.optimizers import FusedAdam except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
) optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.learning_rate, bias_correction=False, max_grad_norm=1.0) if args.loss_scale == 0: optimizer = FP16_Optimizer_State(optimizer, dynamic_loss_scale=True) else: optimizer = FP16_Optimizer_State(optimizer, static_loss_scale=args.loss_scale) else: optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, t_total=t_total) if recover_step: logger.info("***** Recover optimizer: %d *****", recover_step) optim_recover = torch.load(os.path.join( args.output_dir, "optim.{0}.bin".format(recover_step)), map_location='cpu') if hasattr(optim_recover, 'state_dict'): optim_recover = optim_recover.state_dict() optimizer.load_state_dict(optim_recover) if args.loss_scale == 0: logger.info("***** Recover optimizer: dynamic_loss_scale *****") optimizer.dynamic_loss_scale = True logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.do_train: logger.info("***** Running training *****") logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Num steps = %d", t_total) unilm.train() gsm.train() if recover_step: start_epoch = recover_step + 1 else: start_epoch = 1 print("local_rank / start_epoch / last_epoch:", args.local_rank, start_epoch, int(args.num_train_epochs) + 1) topicloss = [] unilmloss = [] topicloss_lst = [] unilmloss_lst = [] for i_epoch in trange(start_epoch, int(args.num_train_epochs) + 1, desc="Epoch", disable=args.local_rank not in (-1, 0)): loss_sum = 0.0 ppx_sum = 0.0 word_count = 0.0 doc_count = 0.0 if args.local_rank != -1: train_sampler.set_epoch(i_epoch) iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)', disable=args.local_rank not in (-1, 0)) for step, batch in enumerate(iter_bar): batch = [ t.to(device) if t is not None else None for t in batch ] if args.has_sentence_oracle: # False in this setup input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, oracle_pos, oracle_weights, oracle_labels = batch else: # this branch additionally unpacks bows (bag-of-words vectors) input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, bows = batch oracle_pos, oracle_weights, oracle_labels = None, None, None p_x, mus, log_vars, theta, beta, topic_embedding = gsm(bows) if not args.topic_model: loss_tuple = unilm(input_ids, theta, beta, topic_embedding, args.topic_mode, segment_ids, input_mask, lm_label_ids, is_next, masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx, masked_pos_2=oracle_pos, masked_weights_2=oracle_weights, masked_labels_2=oracle_labels, mask_qkv=mask_qkv) masked_lm_loss, next_sentence_loss = loss_tuple ## topic loss logsoftmax = torch.log(p_x + 1e-10) rec_loss = -1.0 * torch.sum( bows * logsoftmax ) # bows * logsoftmax has shape [batch_size, |V|]; torch.sum here adds up the loss over all elements (one could also sum over a single dimension only) rec_loss_per = -1.0 * torch.sum(bows * logsoftmax, dim=1) rec_loss_per = rec_loss_per.cpu().detach().numpy() kl_div = -0.5 * torch.sum(1 + log_vars - mus.pow(2) - log_vars.exp()) loss_topic = rec_loss + kl_div if n_gpu > 1: # mean() to average on multi-gpu.
loss_topic = loss_topic.mean() if not args.topic_model: masked_lm_loss = masked_lm_loss.mean() next_sentence_loss = next_sentence_loss.mean() if not args.topic_model: loss_unilm = masked_lm_loss + next_sentence_loss # cal perplexity word_count_list = [] loss_sum += loss_topic.item() for bow in bows: word_num = torch.sum(bow).cpu().numpy() word_count_list.append(word_num) word_count += word_num word_count_np = np.array(word_count_list) doc_count += len(bows) ppx_sum += np.sum(np.true_divide(rec_loss_per, word_count_np)) topicloss_lst.append(loss_topic.item() / len(bows)) if not args.topic_model: unilmloss_lst.append(loss_unilm.item()) #topic_loss end if not args.topic_model: loss = loss_unilm + loss_topic else: loss = loss_topic # ensure that accumlated gradients are normalized if args.gradient_accumulation_steps > 1: # =1 loss = loss / args.gradient_accumulation_steps if args.fp16: optimizer.backward(loss) if amp_handle: amp_handle._clear_cache() else: loss.backward() if (step + 1) % args.gradient_accumulation_steps == 0: lr_this_step = args.learning_rate * \ warmup_linear(global_step/t_total, args.warmup_proportion) if args.fp16: # modify learning rate with special warm up BERT uses for param_group in optimizer.param_groups: if not param_group['topic']: param_group['lr'] = lr_this_step optimizer.step() optimizer.zero_grad() global_step += 1 if not args.topic_model: iter_bar.set_description( 'Iter (loss_unilm=%5.3f),Iter (ppl=%5.3f)' % (loss_unilm.item(), np.sum(np.true_divide(rec_loss_per, word_count_np)))) else: iter_bar.set_description( 'Iter (loss_topic=%5.3f), (ppl=%5.3f)' % (loss_topic.item(), np.sum(np.true_divide(rec_loss_per, word_count_np)))) #Save a trained model ppx_word = np.exp(loss_sum / word_count) ppx_document = np.exp(ppx_sum / doc_count) print("********") print("word_count", word_count) print("ppx_word", ppx_word) print("ppx_document", ppx_document) if (args.local_rank == -1 or torch.distributed.get_rank() == 0): #save unilm model logger.info( "** ** * Saving fine-tuned model and optimizer ** ** * ") unilm_model_to_save = unilm.module if hasattr( unilm, 'module') else unilm # Only save the model it-self output_unilm_model_file = os.path.join( args.output_dir, "unilm.{0}.bin".format(i_epoch)) torch.save(unilm_model_to_save.state_dict(), output_unilm_model_file) #save topic model logger.info( "** ** * Saving topic model and optimizer ** ** * ") topic_model_to_save = gsm.module if hasattr( gsm, 'module') else gsm # Only save the model it-self output_topic_model_file = os.path.join( args.output_dir, "topic.{0}.ckpt".format(i_epoch)) torch.save(topic_model_to_save.state_dict(), output_topic_model_file) logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() smth_pts = smooth_curve(topicloss_lst) # plt.plot(range(len(topicloss_lst)), topicloss_lst) plt.plot(range(len(smth_pts)), smth_pts) plt.xlabel('epochs') plt.title('Topic Model Train Loss') plt.savefig(args.output_dir + '/topic_loss.png') plt.cla() plt.plot(range(len(unilmloss_lst)), unilmloss_lst) plt.xlabel('epochs') plt.title('Unilm Train Loss') plt.savefig(args.output_dir + '/unilm_loss.png')
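# NOTE (editor's sketch): smooth_curve() is used for the topic-loss plot above but
# is not defined in this excerpt. A minimal sketch, assuming the usual
# exponential-moving-average smoothing helper:
def smooth_curve(points, factor=0.8):
    smoothed = []
    for point in points:
        if smoothed:
            # blend each new point with the running average to damp noise
            smoothed.append(smoothed[-1] * factor + point * (1 - factor))
        else:
            smoothed.append(point)
    return smoothed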
def main(): parser = argparse.ArgumentParser() # General parser.add_argument( "--bert_model", default="bert-base-cased", type=str, help="Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased.", ) parser.add_argument( "--config_path", default=None, type=str, help="Bert config file path." ) parser.add_argument( "--output_dir", default="tmp", type=str, help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( "--log_file", default="eval.log", type=str, help="The output directory where the log will be written.", ) parser.add_argument( "--model_recover_path", default=None, type=str, help="The file of fine-tuned pretraining model.", ) parser.add_argument( "--do_train", action="store_true", help="Whether to run training. This should ALWAYS be set to True.", ) parser.add_argument( "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.", ) parser.add_argument( "--train_batch_size", default=64, type=int, help="Total batch size for training.", ) parser.add_argument( "--learning_rate", default=3e-5, type=float, help="The initial learning rate for Adam.", ) parser.add_argument( "--label_smoothing", default=0, type=float, help="Label smoothing rate.", ) parser.add_argument( "--weight_decay", default=0.01, type=float, help="The weight decay rate for Adam.", ) parser.add_argument( "--finetune_decay", action="store_true", help="Weight decay to the original weights.", ) parser.add_argument( "--num_train_epochs", default=30, type=int, help="Total number of training epochs to perform.", ) parser.add_argument( "--warmup_proportion", default=0.1, type=float, help="Proportion of training to perform linear learning rate warmup for. " "E.g., 0.1 = 10%% of training.", ) parser.add_argument( "--no_cuda", action="store_true", help="Whether not to use CUDA when available" ) parser.add_argument( "--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus", ) parser.add_argument( "--global_rank", type=int, default=-1, help="global_rank for distributed training on gpus", ) parser.add_argument( "--seed", type=int, default=42, help="random seed for initialization" ) parser.add_argument( "--gradient_accumulation_steps", type=int, default=1, help="Number of update steps to accumulate before performing a backward/update pass.", ) parser.add_argument( "--fp16", action="store_true", help="Whether to use 16-bit float precision instead of 32-bit", ) parser.add_argument( "--fp32_embedding", action="store_true", help="Whether to use 32-bit float precision instead of 16-bit for embeddings", ) parser.add_argument( "--loss_scale", type=float, default=0, help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" "0 (default value): dynamic loss scaling.\n" "Positive power of 2: static loss scaling value.\n", ) parser.add_argument( "--amp", action="store_true", help="Whether to use amp for fp16" ) parser.add_argument( "--from_scratch", action="store_true", help="Initialize parameters with random values (i.e., training from scratch).", ) parser.add_argument( "--new_segment_ids", action="store_true", help="Use new segment ids for bi-uni-directional LM.", ) parser.add_argument( "--tokenized_input", action="store_true", help="Whether the input is tokenized."
) parser.add_argument( "--len_vis_input", type=int, default=100, help="The length of visual token input", ) parser.add_argument( "--max_len_b", type=int, default=20, help="Truncate_config: maximum length of segment B.", ) parser.add_argument( "--trunc_seg", default="b", help="Truncate_config: first truncate segment A/B (option: a, b).", ) parser.add_argument( "--always_truncate_tail", action="store_true", help="Truncate_config: Whether we should always truncate tail.", ) parser.add_argument( "--mask_prob", default=0.15, type=float, help="Number of prediction is sometimes less than max_pred when sequence is short.", ) parser.add_argument( "--max_pred", type=int, default=3, help="Max tokens of prediction." ) parser.add_argument( "--num_workers", default=4, type=int, help="Number of workers for the data loader.", ) parser.add_argument( "--max_position_embeddings", type=int, default=None, help="max position embeddings", ) # Others for VLP parser.add_argument( "--src_file", default=["/mnt/dat/COCO/annotations/dataset_coco.json"], type=str, nargs="+", help="The input data file name.", ) parser.add_argument("--enable_visdom", action="store_true") parser.add_argument("--visdom_port", type=int, default=8888) # parser.add_argument('--resnet_model', type=str, default='imagenet_weights/resnet101.pth') parser.add_argument("--image_root", type=str, default="/mnt/dat/COCO/images") parser.add_argument( "--dataset", default="coco", type=str, help="coco | flickr30k | cc" ) parser.add_argument("--split", type=str, nargs="+", default=["train", "restval"]) parser.add_argument( "--world_size", default=1, type=int, help="number of distributed processes" ) parser.add_argument( "--dist_url", default="file://[PT_OUTPUT_DIR]/nonexistent_file", type=str, help="url used to set up distributed training", ) parser.add_argument( "--file_valid_jpgs", default="/mnt/dat/COCO/annotations/coco_valid_jpgs.json", type=str, ) parser.add_argument( "--sche_mode", default="warmup_linear", type=str, help="warmup_linear | warmup_constant | warmup_cosine", ) parser.add_argument("--drop_prob", default=0.1, type=float) parser.add_argument("--use_num_imgs", default=-1, type=int) parser.add_argument("--vis_mask_prob", default=0, type=float) parser.add_argument("--max_drop_worst_ratio", default=0, type=float) parser.add_argument("--drop_after", default=6, type=int) parser.add_argument( "--s2s_prob", default=1, type=float, help="Percentage of examples that are bi-uni-directional LM (seq2seq).", ) parser.add_argument( "--bi_prob", default=0, type=float, help="Percentage of examples that are bidirectional LM.", ) parser.add_argument( "--enable_butd", action="store_true", help="set to take in region features" ) parser.add_argument( "--region_bbox_file", default="coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5", type=str, ) parser.add_argument( "--region_det_file_prefix", default="feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval", type=str, ) parser.add_argument("--tasks", default="img2txt", help="img2txt | vqa2") parser.add_argument( "--relax_projection", action="store_true", help="Use different projection layers for tasks.", ) parser.add_argument( "--scst", action="store_true", help="Self-critical sequence training" ) args = parser.parse_args() print("global_rank: {}, local rank: {}".format(args.global_rank, args.local_rank)) args.max_seq_length = ( args.max_len_b + args.len_vis_input + 3 ) # +3 for 2x[SEP] and [CLS] args.mask_image_regions = ( args.vis_mask_prob > 0 ) # whether to mask out image regions 
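# Resulting input layout, per the arithmetic above: [CLS] + len_vis_input visual
# tokens + [SEP] + up to max_len_b text tokens + [SEP], i.e.
# max_len_b + len_vis_input + 3 positions in total.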
args.dist_url = args.dist_url.replace("[PT_OUTPUT_DIR]", args.output_dir) # arguments inspection assert args.tasks in ("img2txt", "vqa2") assert args.enable_butd == True, "only support region attn! featmap attn deprecated" assert (not args.scst) or args.dataset == "coco", "scst support on coco only!" if args.scst: assert args.dataset == "coco", "scst support on coco only!" assert args.max_pred == 0 and args.mask_prob == 0, "no mask for scst!" rl_crit = RewardCriterion() if args.enable_butd: assert args.len_vis_input == 100 args.region_bbox_file = os.path.join(args.image_root, args.region_bbox_file) args.region_det_file_prefix = ( os.path.join(args.image_root, args.region_det_file_prefix) if args.dataset in ("cc", "coco") and args.region_det_file_prefix != "" else "" ) # output config os.makedirs(args.output_dir, exist_ok=True) json.dump( args.__dict__, open(os.path.join(args.output_dir, "eval_opt.json"), "w"), sort_keys=True, indent=2, ) logging.basicConfig( filename=os.path.join(args.output_dir, args.log_file), filemode="w", format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger = logging.getLogger(__name__) if args.local_rank == -1 or args.no_cuda: device = torch.device( "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" ) n_gpu = torch.cuda.device_count() else: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs torch.distributed.init_process_group( backend="nccl", init_method=args.dist_url, world_size=args.world_size, rank=args.global_rank, ) logger.info( "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( device, n_gpu, bool(args.local_rank != -1), args.fp16 ) ) if args.gradient_accumulation_steps < 1: raise ValueError( "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( args.gradient_accumulation_steps ) ) args.train_batch_size = int( args.train_batch_size / args.gradient_accumulation_steps ) # fix random seed random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) # plotting loss, optional if args.enable_visdom: import visdom vis = visdom.Visdom(port=args.visdom_port, env=args.output_dir) vis_window = {"iter": None, "score": None} # preprocessing/data loader tokenizer = BertTokenizer.from_pretrained( args.bert_model, do_lower_case=args.do_lower_case, cache_dir=args.output_dir + "/.pretrained_model_{}".format(args.global_rank), ) if args.max_position_embeddings: tokenizer.max_len = args.max_position_embeddings data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer if args.do_train: bi_uni_pipeline = [ seq2seq_loader.Preprocess4Seq2seq( args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, truncate_config={ "max_len_b": args.max_len_b, "trunc_seg": args.trunc_seg, "always_truncate_tail": args.always_truncate_tail, }, mask_image_regions=args.mask_image_regions, mode="s2s", len_vis_input=args.len_vis_input, vis_mask_prob=args.vis_mask_prob, enable_butd=args.enable_butd, region_bbox_file=args.region_bbox_file, region_det_file_prefix=args.region_det_file_prefix, local_rank=args.local_rank, load_vqa_ann=(args.tasks == "vqa2"), ) ] bi_uni_pipeline.append( seq2seq_loader.Preprocess4Seq2seq( args.max_pred, args.mask_prob, 
list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, truncate_config={ "max_len_b": args.max_len_b, "trunc_seg": args.trunc_seg, "always_truncate_tail": args.always_truncate_tail, }, mask_image_regions=args.mask_image_regions, mode="bi", len_vis_input=args.len_vis_input, vis_mask_prob=args.vis_mask_prob, enable_butd=args.enable_butd, region_bbox_file=args.region_bbox_file, region_det_file_prefix=args.region_det_file_prefix, local_rank=args.local_rank, load_vqa_ann=(args.tasks == "vqa2"), ) ) train_dataset = seq2seq_loader.Img2txtDataset( args.src_file, args.image_root, args.split, args.train_batch_size, data_tokenizer, args.max_seq_length, file_valid_jpgs=args.file_valid_jpgs, bi_uni_pipeline=bi_uni_pipeline, use_num_imgs=args.use_num_imgs, s2s_prob=args.s2s_prob, bi_prob=args.bi_prob, enable_butd=args.enable_butd, tasks=args.tasks, ) if args.world_size == 1: train_sampler = RandomSampler(train_dataset, replacement=False) else: train_sampler = DistributedSampler(train_dataset) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=args.train_batch_size, sampler=train_sampler, num_workers=args.num_workers, collate_fn=batch_list_to_batch_tensors, pin_memory=True, ) # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps) t_total = int( len(train_dataloader) * args.num_train_epochs * 1.0 / args.gradient_accumulation_steps ) amp_handle = None if args.fp16 and args.amp: from apex import amp amp_handle = amp.init(enable_caching=True) logger.info("enable fp16 with amp") # Prepare model recover_step = _get_max_epoch_model(args.output_dir) cls_num_labels = 2 type_vocab_size = 6 if args.new_segment_ids else 2 relax_projection = 4 if args.relax_projection else 0 task_idx_proj = 3 if args.tasks == "img2txt" else 0 mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids( ["[MASK]", "[SEP]", "[PAD]"] ) # index in BERT vocab: 103, 102, 0 if (recover_step is None) and (args.model_recover_path is None): # if _state_dict == {}, the parameters are randomly initialized # if _state_dict == None, the parameters are initialized with bert-init assert args.scst == False, "must init from maximum likelihood training" _state_dict = {} if args.from_scratch else None model = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=_state_dict, num_labels=cls_num_labels, type_vocab_size=type_vocab_size, relax_projection=relax_projection, config_path=args.config_path, task_idx=task_idx_proj, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, cache_dir=args.output_dir + "/.pretrained_model_{}".format(args.global_rank), drop_prob=args.drop_prob, enable_butd=args.enable_butd, len_vis_input=args.len_vis_input, tasks=args.tasks, ) global_step = 0 else: if recover_step: logger.info("***** Recover model: %d *****", recover_step) model_recover = torch.load( os.path.join(args.output_dir, "model.{0}.bin".format(recover_step)) ) # recover_step == number of epochs global_step = math.floor( recover_step * t_total * 1.0 / args.num_train_epochs ) elif args.model_recover_path: logger.info("***** Recover model: %s *****", args.model_recover_path) model_recover = torch.load(args.model_recover_path) global_step = 0 if not args.scst: model = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=model_recover, num_labels=cls_num_labels, type_vocab_size=type_vocab_size, 
relax_projection=relax_projection, config_path=args.config_path, task_idx=task_idx_proj, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, cache_dir=args.output_dir + "/.pretrained_model_{}".format(args.global_rank), drop_prob=args.drop_prob, enable_butd=args.enable_butd, len_vis_input=args.len_vis_input, tasks=args.tasks, ) else: model = BertForSeq2SeqDecoder.from_pretrained( args.bert_model, max_position_embeddings=args.max_position_embeddings, config_path=args.config_path, state_dict=model_recover, num_labels=cls_num_labels, type_vocab_size=type_vocab_size, task_idx=task_idx_proj, mask_word_id=mask_word_id, search_beam_size=1, eos_id=eos_word_ids, mode="s2s", enable_butd=args.enable_butd, len_vis_input=args.len_vis_input, ) del model_recover torch.cuda.empty_cache() # deprecated # from vlp.resnet import resnet # cnn = resnet(args.resnet_model, _num_layers=101, _fixed_block=4, pretrained=True) # no finetuning if args.fp16: model.half() # cnn.half() if args.fp32_embedding: model.bert.embeddings.word_embeddings.float() model.bert.embeddings.position_embeddings.float() model.bert.embeddings.token_type_embeddings.float() model.to(device) # cnn.to(device) if args.local_rank != -1: try: # from apex.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training." ) model = DDP( model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True, ) # cnn = DDP(cnn) elif n_gpu > 1: # model = torch.nn.DataParallel(model) model = DataParallelImbalance(model) # cnn = DataParallelImbalance(cnn) # Prepare optimizer param_optimizer = list(model.named_parameters()) no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { "params": [ p for n, p in param_optimizer if not any(nd in n for nd in no_decay) ], "weight_decay": 0.01, }, { "params": [ p for n, p in param_optimizer if any(nd in n for nd in no_decay) ], "weight_decay": 0.0, }, ] if args.fp16: try: # from apex.optimizers import FP16_Optimizer from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State from apex.optimizers import FusedAdam except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training." 
) optimizer = FusedAdam( optimizer_grouped_parameters, lr=args.learning_rate, bias_correction=False, max_grad_norm=1.0, ) if args.loss_scale == 0: optimizer = FP16_Optimizer_State(optimizer, dynamic_loss_scale=True) else: optimizer = FP16_Optimizer_State( optimizer, static_loss_scale=args.loss_scale ) else: optimizer = BertAdam( optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, schedule=args.sche_mode, t_total=t_total, ) if recover_step: logger.info("***** Recover optimizer: %d *****", recover_step) optim_recover = torch.load( os.path.join(args.output_dir, "optim.{0}.bin".format(recover_step)) ) if hasattr(optim_recover, "state_dict"): optim_recover = optim_recover.state_dict() optimizer.load_state_dict(optim_recover) if args.loss_scale == 0: logger.info("***** Recover optimizer: dynamic_loss_scale *****") optimizer.dynamic_loss_scale = True logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.do_train: model.eval() losses = [] for batch in tqdm(train_dataloader): # wrangle batch batch = [t.to(device) for t in batch] ( input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, img, vis_masked_pos, vis_pe, ans_labels, ) = batch if args.fp16: img = img.half() vis_pe = vis_pe.half() if args.enable_butd: conv_feats = img.data # Bx100x2048 vis_pe = vis_pe.data else: conv_feats, _ = cnn(img.data) # Bx2048x7x7 conv_feats = ( conv_feats.view(conv_feats.size(0), conv_feats.size(1), -1) .permute(0, 2, 1) .contiguous() ) # compute loss masked_lm_loss, _, _ = model( conv_feats, vis_pe, input_ids, segment_ids, input_mask, lm_label_ids, ans_labels, is_next, masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx, vis_masked_pos=vis_masked_pos, mask_image_regions=args.mask_image_regions, drop_worst_ratio=args.max_drop_worst_ratio ) # average across multiple GPUs if n_gpu > 1: masked_lm_loss = masked_lm_loss.mean() losses.append(masked_lm_loss.item()) print(args.split, 'perplexity:', np.exp(np.mean(losses)))
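# NOTE (editor's sketch): _get_max_epoch_model() is called when preparing the model
# but is not defined in this excerpt. A minimal sketch, assuming it looks for paired
# "model.<epoch>.bin" / "optim.<epoch>.bin" checkpoints in output_dir and returns
# the largest recoverable epoch index (or None if there is nothing to recover):
import glob
import os
from pathlib import Path


def _get_max_epoch_model(output_dir):
    fn_model_list = glob.glob(os.path.join(output_dir, "model.*.bin"))
    fn_optim_list = glob.glob(os.path.join(output_dir, "optim.*.bin"))
    if (not fn_model_list) or (not fn_optim_list):
        return None
    # keep only epochs for which both the model and the optimizer were saved
    epochs = set(int(Path(fn).stem.split(".")[-1]) for fn in fn_model_list) & set(
        int(Path(fn).stem.split(".")[-1]) for fn in fn_optim_list)
    return max(epochs) if epochs else None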
def main(): parser = argparse.ArgumentParser() # Path parameters parser.add_argument("--data_dir", default=None, type=str, required=True, help="The raw data dir.") parser.add_argument("--vocab_path", default=None, type=str, required=True, help="BERT vocab path") parser.add_argument("--config_path", default=None, type=str, help="Bert config file path.") parser.add_argument( "--model_output_dir", default=None, type=str, required=True, help= "The output directory where the model predictions and checkpoints will be written." ) parser.add_argument( "--log_dir", default='', type=str, required=True, help="The output directory where the log will be written.") parser.add_argument("--model_recover_path", default=None, type=str, help="Parameter initialization for pretraining or fine-tuning.") parser.add_argument("--optim_recover_path", default=None, type=str, help="The file of pretraining optimizer.") # Data Process Parameters parser.add_argument( "--max_seq_length", default=128, type=int, help= "The maximum total input sequence length after WordPiece tokenization. \n" "Sequences longer than this will be truncated, and sequences shorter \n" "than this will be padded.") parser.add_argument('--max_position_embeddings', type=int, default=None, help="max position embeddings") parser.add_argument( "--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") parser.add_argument('--new_segment_ids', action='store_true', help="Use new segment ids for bi-uni-directional LM.") parser.add_argument('--new_pos_ids', action='store_true', help="Use new position ids for LMs.") parser.add_argument('--max_len_a', type=int, default=0, help="Truncate_config: maximum length of segment A.") parser.add_argument('--max_len_b', type=int, default=0, help="Truncate_config: maximum length of segment B.") parser.add_argument( '--trunc_seg', default='', help="Truncate_config: first truncate segment A/B (option: a, b).") parser.add_argument( '--always_truncate_tail', action='store_true', help="Truncate_config: Whether we should always truncate tail.") parser.add_argument( "--mask_prob", default=0.15, type=float, help= "Number of prediction is sometimes less than max_pred when sequence is short." ) parser.add_argument( "--mask_prob_eos", default=0, type=float, help= "Number of prediction is sometimes less than max_pred when sequence is short." ) parser.add_argument('--max_pred', type=int, default=20, help="Max tokens of prediction.") parser.add_argument('--mask_source_words', action='store_true', help="Whether to mask source words for training") parser.add_argument('--skipgram_prb', type=float, default=0.0, help='prob of ngram mask') parser.add_argument('--skipgram_size', type=int, default=1, help='the max size of ngram mask') parser.add_argument('--mask_whole_word', action='store_true', help="Whether masking a whole word.") parser.add_argument('--do_l2r_training', action='store_true', help="Whether to do left to right training") parser.add_argument( '--has_sentence_oracle', action='store_true', help="Whether to have sentence level oracle for training. " "Only useful for summary generation") parser.add_argument('--seg_emb', action='store_true', help="Using segment embedding for self-attention.") parser.add_argument( '--s2s_special_token', action='store_true', help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.") parser.add_argument('--s2s_add_segment', action='store_true', help="Additional segment embeddings for the encoder of S2S.") parser.add_argument( '--s2s_share_segment', action='store_true', help= "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)." ) parser.add_argument('--pos_shift', action='store_true', help="Using position shift for fine-tuning.") parser.add_argument("--num_workers", default=0, type=int, help="Number of workers for the data loader.") # Model Parameters parser.add_argument("--sop", action='store_true', help="whether to use the SOP task.") parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.") parser.add_argument("--eval_batch_size", default=64, type=int, help="Total batch size for eval.") parser.add_argument("--hidden_dropout_prob", default=0.1, type=float, help="Dropout rate for hidden states.") parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float, help="Dropout rate for attention probabilities.") parser.add_argument('--relax_projection', action='store_true', help="Use different projection layers for tasks.") parser.add_argument('--ffn_type', default=0, type=int, help="0: default mlp; 1: W((Wx+b) elem_prod x);") parser.add_argument('--num_qkv', default=0, type=int, help="Number of different <Q,K,V>.") # Train/Eval/Test Parameters parser.add_argument("--checkpoint_steps", required=True, type=int, help="save a checkpoint every checkpoint_steps steps") parser.add_argument("--total_steps", required=True, type=int, help="total number of training steps") parser.add_argument("--max_checkpoint", required=True, type=int, help="maximum number of checkpoints kept in model_output_dir") parser.add_argument( "--examples_size_once", type=int, default=1000, help="how many examples to read at a time during pretraining or fine-tuning") parser.add_argument("--local_rank", type=int, default=-1, help="local process rank for distributed training") parser.add_argument("--local_debug", action='store_true', help="whether to run in local debug mode") parser.add_argument("--do_train", action='store_true', help="Whether to run training.") parser.add_argument("--fine_tune", action='store_true', help="Whether to run fine-tuning.") parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.") parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument("--label_smoothing", default=0, type=float, help="Label smoothing rate.") parser.add_argument("--weight_decay", default=0.01, type=float, help="The weight decay rate for Adam.") parser.add_argument("--finetune_decay", action='store_true', help="Weight decay to the original weights.") parser.add_argument("--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.") parser.add_argument( "--warmup_proportion", default=0.1, type=float, help= "Proportion of training to perform linear learning rate warmup for. " "E.g., 0.1 = 10%% of training.") parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available") parser.add_argument( '--gradient_accumulation_steps', type=int, default=1, help= "Number of update steps to accumulate before performing a backward/update pass."
) parser.add_argument( '--fp16', action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") parser.add_argument( '--fp32_embedding', action='store_true', help= "Whether to use 32-bit float precision instead of 16-bit for embeddings" ) parser.add_argument( '--loss_scale', type=str, default='dynamic', help= '(float or str, optional, default=None): Optional property override. ' 'If passed as a string,must be a string representing a number, e.g., "128.0", or the string "dynamic".' ) parser.add_argument( '--opt_level', type=str, default='O1', help= ' (str, optional, default="O1"): Pure or mixed precision optimization level. ' 'Accepted values are "O0", "O1", "O2", and "O3", explained in detail above.' ) parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16") parser.add_argument( '--from_scratch', action='store_true', help= "Initialize parameters with random values (i.e., training from scratch)." ) # Other Patameters parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") parser.add_argument('--rank', type=int, default=0, help="global rank of current process") parser.add_argument("--world_size", default=2, type=int, help="Number of process(显卡)") args = parser.parse_args() cur_env = os.environ args.rank = int(cur_env.get('RANK', -1)) args.world_size = int(cur_env.get('WORLD_SIZE', -1)) if args.gradient_accumulation_steps < 1: raise ValueError( "Invalid gradient_accumulation_steps parameter: {}, should be >= 1" .format(args.gradient_accumulation_steps)) args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) assert args.train_batch_size >= 1, 'batch_size < 1 ' # 更新一次模型参数需要多少个样本 examples_per_update = args.world_size * args.train_batch_size * args.gradient_accumulation_steps args.examples_size_once = args.examples_size_once // examples_per_update * examples_per_update if args.fine_tune: args.examples_size_once = examples_per_update os.makedirs(args.model_output_dir, exist_ok=True) os.makedirs(args.log_dir, exist_ok=True) json.dump(args.__dict__, open(os.path.join(args.model_output_dir, 'unilm_config.json'), 'w'), sort_keys=True, indent=2) if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() else: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) n_gpu = torch.cuda.device_count() dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=args.rank) logger.info( "world_size:{}, rank:{}, device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}" .format(args.world_size, args.rank, device, n_gpu, bool(args.world_size > 1), args.fp16)) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) if not args.fine_tune and not args.do_train and not args.do_eval: raise ValueError( "At least one of `do_train` or `do_eval` must be True.") if args.local_rank not in (-1, 0): # Make sure only the first process in distributed training will download model & vocab dist.barrier() tokenizer = BertTokenizer.from_pretrained(args.vocab_path, do_lower_case=args.do_lower_case) if args.max_position_embeddings: tokenizer.max_len = args.max_position_embeddings if args.local_rank == 0: dist.barrier() bi_uni_pipeline = [ Preprocess4Seq2seq(args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, 
args.max_seq_length, new_segment_ids=args.new_segment_ids, truncate_config={ 'max_len_a': args.max_len_a, 'max_len_b': args.max_len_b, 'trunc_seg': args.trunc_seg, 'always_truncate_tail': args.always_truncate_tail }, mask_source_words=args.mask_source_words, skipgram_prb=args.skipgram_prb, skipgram_size=args.skipgram_size, mask_whole_word=args.mask_whole_word, mode="s2s", has_oracle=args.has_sentence_oracle, num_qkv=args.num_qkv, s2s_special_token=args.s2s_special_token, s2s_add_segment=args.s2s_add_segment, s2s_share_segment=args.s2s_share_segment, pos_shift=args.pos_shift, fine_tune=args.fine_tune) ] file_oracle = None if args.has_sentence_oracle: file_oracle = os.path.join(args.data_dir, 'train.oracle') # t_total表示模型参数更新的次数 # t_total = args.train_steps # Prepare model recover_step = _get_max_epoch_model(args.model_output_dir) cls_num_labels = 2 type_vocab_size = 6 + \ (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2 num_sentlvl_labels = 2 if args.has_sentence_oracle else 0 relax_projection = 4 if args.relax_projection else 0 if args.local_rank not in (-1, 0): # Make sure only the first process in distributed training will download model & vocab dist.barrier() if (recover_step is None) and (args.model_recover_path is None): # if _state_dict == {}, the parameters are randomly initialized # if _state_dict == None, the parameters are initialized with bert-init _state_dict = {} if args.from_scratch else None model = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=_state_dict, num_labels=cls_num_labels, num_rel=0, type_vocab_size=type_vocab_size, config_path=args.config_path, task_idx=3, num_sentlvl_labels=num_sentlvl_labels, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, relax_projection=relax_projection, new_pos_ids=args.new_pos_ids, ffn_type=args.ffn_type, hidden_dropout_prob=args.hidden_dropout_prob, attention_probs_dropout_prob=args.attention_probs_dropout_prob, num_qkv=args.num_qkv, seg_emb=args.seg_emb, local_debug=args.local_debug) global_step = 0 else: if recover_step: logger.info("***** Recover model: %d *****", recover_step) model_recover = torch.load(os.path.join( args.output_model_dir, "model.{0}.bin".format(recover_step)), map_location='cpu') # recover_step == number of epochs global_step = math.floor(recover_step * args.checkpoint_step) # 预训练时模型的参数初始化,比如使用chinese-bert-base的模型参数进行初始化 elif args.model_recover_path: logger.info("***** Recover model: %s *****", args.model_recover_path) model_recover = torch.load(args.model_recover_path, map_location='cpu') global_step = 0 model = BertForPreTrainingLossMask.from_pretrained( state_dict=model_recover, num_labels=cls_num_labels, num_rel=0, type_vocab_size=type_vocab_size, config_path=args.config_path, task_idx=3, num_sentlvl_labels=num_sentlvl_labels, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, relax_projection=relax_projection, new_pos_ids=args.new_pos_ids, ffn_type=args.ffn_type, hidden_dropout_prob=args.hidden_dropout_prob, attention_probs_dropout_prob=args.attention_probs_dropout_prob, num_qkv=args.num_qkv, seg_emb=args.seg_emb, local_debug=args.local_debug) total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) logger.info("模型参数: {}".format(total_trainable_params)) if args.local_rank == 0: dist.barrier() model.to(device) param_optimizer = list(model.named_parameters()) no_decay = ['bias', 
'LayerNorm.bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [{ 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01 }, { 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0 }] optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, t_total=args.total_steps) if args.amp and args.fp16: from apex import amp model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level, loss_scale=args.loss_scale) from apex.parallel import DistributedDataParallel as DDP model = DDP(model) else: from torch.nn.parallel import DistributedDataParallel as DDP model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) if recover_step: logger.info("** ** * Recover optimizer: %d * ** **", recover_step) optim_recover = torch.load(os.path.join( args.model_output_dir, "optim.{0}.bin".format(recover_step)), map_location='cpu') if hasattr(optim_recover, 'state_dict'): optim_recover = optim_recover.state_dict() optimizer.load_state_dict(optim_recover) if args.fp16 and args.amp: amp_recover = torch.load(os.path.join( args.model_output_dir, "amp.{0}.bin".format(recover_step)), map_location='cpu') logger.info("** ** * Recover amp: %d * ** **", recover_step) amp.load_state_dict(amp_recover) logger.info("** ** * CUDA.empty_cache() * ** **") torch.cuda.empty_cache() if args.rank == 0: writer = SummaryWriter(log_dir=args.log_dir) logger.info("***** Running training *****") logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Param Update Num = %d", args.total_steps) model.train() PRE = "rank{},local_rank {},".format(args.rank, args.local_rank) step = 1 start = time.time() train_data_loader = TrainDataLoader( bi_uni_pipline=bi_uni_pipeline, examples_size_once=args.examples_size_once, world_size=args.world_size, train_batch_size=args.train_batch_size, num_workers=args.num_workers, data_dir=args.data_dir, tokenizer=tokenizer, max_len=args.max_seq_length) best_result = -float('inf') for global_step, batch in enumerate(train_data_loader, start=global_step): batch = [t.to(device) if t is not None else None for t in batch] if args.has_sentence_oracle: input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, task_idx, sop_label, oracle_pos, oracle_weights, oracle_labels = batch else: input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, task_idx, sop_label = batch oracle_pos, oracle_weights, oracle_labels = None, None, None if not args.sop: # 不使用sop训练任务 sop_label = None loss_tuple = model(input_ids, segment_ids, input_mask, masked_lm_labels=lm_label_ids, next_sentence_label=sop_label, masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx, masked_pos_2=oracle_pos, masked_weights_2=oracle_weights, masked_labels_2=oracle_labels, mask_qkv=mask_qkv) masked_lm_loss, next_sentence_loss = loss_tuple # mean() to average on multi-gpu. 
if n_gpu > 1: masked_lm_loss = masked_lm_loss.mean() next_sentence_loss = next_sentence_loss.mean() # ensure that accumlated gradients are normalized if args.gradient_accumulation_steps > 1: masked_lm_loss = masked_lm_loss / args.gradient_accumulation_steps next_sentence_loss = next_sentence_loss / args.gradient_accumulation_steps if not args.sop: loss = masked_lm_loss else: loss = masked_lm_loss + next_sentence_loss if args.fp16 and args.amp: with amp.scale_loss(loss, optimizer) as scale_loss: scale_loss.backward() else: loss.backward() if (global_step + 1) % args.gradient_accumulation_steps == 0: if args.rank == 0: writer.add_scalar('unilm/mlm_loss', masked_lm_loss, global_step) writer.add_scalar('unilm/sop_loss', next_sentence_loss, global_step) lr_this_step = args.learning_rate * warmup_linear( global_step / args.total_steps, args.warmup_proportion) if args.fp16: # modify learning rate with special warm up BERT uses for param_group in optimizer.param_groups: param_group['lr'] = lr_this_step optimizer.step() optimizer.zero_grad() #global_step += 1 #更新一次模型参数花费的时间,单位:秒 cost_time_per_update = time.time() - start # 更新完所有参数花费的时间,单位:小时 need_time = cost_time_per_update * (args.total_steps - global_step) / 3600.0 cost_time_per_chectpoint = cost_time_per_update * args.checkpoint_steps / 3600.0 start = time.time() if args.local_rank in [-1, 0]: INFO = PRE + '当前/chcklpoint_steps/total:{}/{}/{},loss{}/{},更新一次参数{}秒,checkpoint_steps {}小时,' \ '训练完成{}小时\n'.format(global_step, args.checkpoint_steps, args.total_steps, round(masked_lm_loss.item(), 5), round(next_sentence_loss.item(), 5), round(cost_time_per_update, 4), round(cost_time_per_chectpoint, 3), round(need_time, 3)) print(INFO) # Save a trained model if (global_step + 1) % args.checkpoint_steps == 0: checkpoint_index = (global_step + 1) % args.checkpoint_steps if args.rank >= 0: train_data_loader.train_sampler.set_epoch(checkpoint_index) # if args.eval: # # 如果是pretrain,验证MLM;如果微调,验证评价指标 # result = None #if best_result < result and _get_checkpont_num(args.model_output_num): if args.rank in [0, -1]: logger.info("** ** * Saving model and optimizer * ** **") model_to_save = model.module if hasattr( model, 'module') else model # Only save the model it-self output_model_file = os.path.join( args.model_output_dir, "model.{0}.bin".format(checkpoint_index)) torch.save(model_to_save.state_dict(), output_model_file) output_optim_file = os.path.join( args.model_output_dir, "optim.{0}.bin".format(checkpoint_index)) torch.save(optimizer.state_dict(), output_optim_file) if args.fp16 and args.amp: logger.info("** ** * Saving amp state * ** **") output_amp_file = os.path.join( args.model_output_dir, "amp.{0}.bin".format(checkpoint_index)) torch.save(amp.state_dict(), output_amp_file) logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.rank == 0: writer.close() print('** ** * train finished * ** **')
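The loop above scales the learning rate by warmup_linear(), which is not defined in this file. A minimal sketch, assuming the classic pytorch-pretrained-bert schedule (linear warmup to the peak rate over the first warmup fraction of training, then linear decay to zero); the exact shape used by this repo's BertAdam may differ:

def warmup_linear(x, warmup=0.002):
    # x: fraction of training completed (global_step / total_steps)
    if x < warmup:
        # ramp up linearly during the warmup phase
        return x / warmup
    # decay linearly to zero for the rest of training
    return 1.0 - x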
def main():
    args = process_args()

    os.makedirs(args.output_dir, exist_ok=True)

    if args.enable_butd:
        if args.visdial_v == '1.0':
            assert (args.len_vis_input == 36)
        elif args.visdial_v == '0.9':
            assert (args.len_vis_input == 100)
        args.region_bbox_file = os.path.join(args.image_root,
                                             args.region_bbox_file)
        args.region_det_file_prefix = os.path.join(
            args.image_root, args.region_det_file_prefix) if args.dataset in (
                'cc', 'coco') and args.region_det_file_prefix != '' else ''

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    args.max_seq_length = args.len_vis_input + 2 + args.max_len_hist_ques + 2 + args.max_len_ans + 1
    tokenizer.max_len = args.max_seq_length

    bi_uni_pipeline = [
        Preprocess4TestVisdialDiscTest(
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'len_vis_input': args.len_vis_input,
                'max_len_hist_ques': args.max_len_hist_ques,
                'max_len_ans': args.max_len_ans
            },
            mode="bi",
            region_bbox_file=args.region_bbox_file,
            region_det_file_prefix=args.region_det_file_prefix,
            image_features_hdfpath=args.image_features_hdfpath,
            visdial_v=args.visdial_v,
            pad_hist=args.pad_hist,
            inc_full_hist=args.inc_full_hist)
    ]

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    logger.info('Attempting to recover models from: {}'.format(
        args.model_recover_path))
    if 0 == len(glob.glob(args.model_recover_path.strip())):
        logger.error('There are no models to recover. The program will exit.')
        sys.exit(1)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size,
            task_idx=0,
            max_position_embeddings=512,
            cache_dir=args.output_dir + '/.pretrained_model_{}'.format(-1),
            drop_prob=args.drop_prob,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input,
            visdial_v=args.visdial_v,
            loss_type=args.loss_type,
            eval_disc=True)
        del model_recover

    if args.fp16:
        model.half()
        # cnn.half()
    model.to(device)
    # cnn.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
        # cnn = torch.nn.DataParallel(cnn)
    torch.cuda.empty_cache()
    model.eval()

    def read_data(src_file):
        eval_lst = []
        with open(src_file, "r", encoding='utf-8') as f_src:
            data = json.load(f_src)['data']
            dialogs = data['dialogs']
            questions = data['questions']
            answers = data['answers']
            img_idx = 0
            for dialog in tqdm(dialogs):
                if img_idx < args.use_num_imgs or args.use_num_imgs == -1:
                    img_id = dialog['image_id']
                    cap_tokens = tokenizer.tokenize(dialog['caption'])
                    ques_id = [item['question'] for item in dialog['dialog']]
                    ques_tokens = [
                        tokenizer.tokenize(questions[id] + '?')
                        for id in ques_id
                    ]
                    turn_num = len(ques_id)
                    ans_id = [
                        item['answer'] for item in dialog['dialog']
                        if 'answer' in item.keys()
                    ]
                    ans_tokens = [
                        tokenizer.tokenize(answers[id]) for id in ans_id
                    ]
                    assert len(ques_id) == len(ans_id) + 1
                    ans_opts = dialog['dialog'][len(ans_id)]['answer_options']
                    ans_opts_tokens = [
                        tokenizer.tokenize(answers[id]) for id in ans_opts
                    ]
                    eval_lst.append((img_id, cap_tokens, ques_tokens,
                                     ans_tokens, ans_opts_tokens, turn_num))
                    img_idx += 1
        return eval_lst

    input_lines = read_data(args.src_file)
    next_i = 0
    total_batch = math.ceil(len(input_lines) / args.batch_size)

    print('Start the VisDial decode evaluation...')
    ranks_json = []
    scores_json = []
    score_fn = args.save_ranks_path.replace('.json', '_score.json')
    with tqdm(total=total_batch) as pbar:
        while next_i < len(input_lines):
            _chunk = input_lines[next_i:next_i + args.batch_size]
            buf_id = [x[0] for x in _chunk]
            buf = [x[:-1] for x in _chunk]
            turn_id = [x[-1] for x in _chunk]
            next_i += args.batch_size
            instances = []
            for instance in buf:
                instances.append(bi_uni_pipeline[0](instance))
            with torch.no_grad():
                batch_data = list(zip(*instances))
                img, vis_pe = (torch.stack(x).to(device)
                               for x in batch_data[-2:])
                task_idx = torch.tensor(batch_data[-3],
                                        dtype=torch.long).to(device)
                conv_feats = img.data  # Bx100x2048
                vis_pe = vis_pe.data
                input_ids = torch.tensor(batch_data[0],
                                         dtype=torch.long).to(device)
                segment_ids = torch.tensor(batch_data[1],
                                           dtype=torch.long).to(device)
                input_mask = torch.stack(batch_data[2]).to(device)
                output_scores = model(conv_feats, vis_pe, input_ids,
                                      segment_ids, input_mask,
                                      task_idx=task_idx)
                output_scores = output_scores[:, :, 1]  # [batch_size, num_options]
                ranks = scores_to_ranks(
                    output_scores.unsqueeze(1))  # [batch_size, num_rounds, num_options]
                for i in range(len(buf_id)):
                    # Cast into types explicitly to ensure no errors in schema.
                    # Round ids are 1-10, not 0-9
                    if args.split == "test":
                        ranks_json.append({
                            "image_id": buf_id[i],
                            "round_id": turn_id[i],
                            "ranks": [rank.item() for rank in ranks[i][0]]
                        })
                        scores_json.append({
                            "image_id": buf_id[i],
                            "round_id": turn_id[i],
                            "ranks": [rank.item() for rank in ranks[i][0]],
                            "scores": [score.item() for score in output_scores[i]]
                        })
            pbar.update(1)

    json.dump(ranks_json, open(args.save_ranks_path, "w"))
    json.dump(scores_json, open(score_fn, "w"))
    logger.info("Finished writing rankings into %s" % args.save_ranks_path)
    logger.info("Finished writing scores into %s" % score_fn)
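scores_to_ranks() is likewise used here without a local definition. A minimal sketch under the shapes noted in the comments above ([batch_size, num_rounds, num_options]); the vectorized scatter_ inversion is an assumption on my part — the VisDial starter code computes the same permutation inverse with an explicit loop:

import torch

def scores_to_ranks(scores):
    # scores: [batch_size, num_rounds, num_options]
    batch_size, num_rounds, num_options = scores.size()
    flat = scores.view(-1, num_options)
    # sort descending: the best-scoring option should receive rank 1
    _, ranked_idx = flat.sort(1, descending=True)
    # ranked_idx[i][j] is the option placed at position j; invert that
    # permutation so ranks[i][opt] is the position of option opt
    ranks = torch.zeros_like(ranked_idx)
    positions = torch.arange(num_options,
                             device=scores.device).repeat(flat.size(0), 1)
    ranks.scatter_(1, ranked_idx, positions)
    # shift from 0-based positions to 1-based ranks
    return (ranks + 1).view(batch_size, num_rounds, num_options)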
def main():
    args = process_args()

    os.makedirs(args.output_dir, exist_ok=True)

    if args.enable_butd:
        if args.visdial_v == '1.0' and not args.no_vision:
            assert (args.len_vis_input == 36)
        elif args.visdial_v == '0.9':
            assert (args.len_vis_input == 100)
        args.region_bbox_file = os.path.join(args.image_root,
                                             args.region_bbox_file)
        args.region_det_file_prefix = os.path.join(
            args.image_root, args.region_det_file_prefix) if args.dataset in (
                'cc', 'coco') and args.region_det_file_prefix != '' else ''

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    args.max_seq_length = args.len_vis_input + 2 + args.max_len_hist_ques + 2 + args.max_len_ans + 1
    tokenizer.max_len = args.max_seq_length

    bi_uni_pipeline = [
        Preprocess4TestVisdialDisc(
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={
                'len_vis_input': args.len_vis_input,
                'max_len_hist_ques': args.max_len_hist_ques,
                'max_len_ans': args.max_len_ans
            },
            mode="bi",
            region_bbox_file=args.region_bbox_file,
            region_det_file_prefix=args.region_det_file_prefix,
            image_features_hdfpath=args.image_features_hdfpath,
            visdial_v=args.visdial_v,
            pad_hist=args.pad_hist,
            inc_full_hist=args.inc_full_hist,
            only_qa=args.only_qa)
    ]

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    logger.info('Attempting to recover models from: {}'.format(
        args.model_recover_path))
    if 0 == len(glob.glob(args.model_recover_path.strip())):
        logger.error('There are no models to recover. The program will exit.')
        sys.exit(1)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size,
            task_idx=0,
            max_position_embeddings=512,
            cache_dir=args.output_dir + '/.pretrained_model_{}'.format(-1),
            drop_prob=args.drop_prob,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input,
            visdial_v=args.visdial_v,
            loss_type=args.loss_type,
            eval_disc=True,
            add_attn_fuse=args.add_attn_fuse,
            no_vision=args.no_vision)
        del model_recover

    if args.fp16:
        model.half()
        # cnn.half()
    model.to(device)
    # cnn.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
        # cnn = torch.nn.DataParallel(cnn)
    torch.cuda.empty_cache()
    model.eval()

    def read_data(src_file):
        eval_lst = []
        with open(src_file, "r", encoding='utf-8') as f_src:
            data = json.load(f_src)['data']
            dialogs = data['dialogs']
            questions = data['questions']
            answers = data['answers']
            img_idx = 0
            for dialog in tqdm(dialogs):
                if img_idx < args.use_num_imgs or args.use_num_imgs == -1:
                    img_id = dialog['image_id']
                    cap_tokens = tokenizer.tokenize(dialog['caption'])
                    ques_id = [item['question'] for item in dialog['dialog']]
                    ques_tokens = [
                        tokenizer.tokenize(questions[id] + '?')
                        for id in ques_id
                    ]
                    ans_id = [item['answer'] for item in dialog['dialog']]
                    ans_tokens = [
                        tokenizer.tokenize(answers[id]) for id in ans_id
                    ]
                    gt_id = [item['gt_index'] for item in dialog['dialog']]
                    ans_opts = [
                        item['answer_options'] for item in dialog['dialog']
                    ]
                    ans_opts_tokens = [[
                        tokenizer.tokenize(answers[id]) for id in ans
                    ] for ans in ans_opts]
                    assert len(ques_tokens) == len(ans_tokens) == len(ans_opts_tokens) == 10, \
                        "ques num: %d, ans num: %d, ans opt num: %d" % (
                            len(ques_tokens), len(ans_tokens), len(ans_opts_tokens))
                    assert all([
                        len(ans_opt) == 100 for ans_opt in ans_opts_tokens
                    ]), "every answer should have 100 options"
                    eval_lst.append((img_id, cap_tokens, ques_tokens,
                                     ans_tokens, ans_opts_tokens, gt_id))
                    img_idx += 1
        return eval_lst

    def get_gt_rel_dict(fname):
        gt_rel_dict = {}
        gt_rel_data = json.load(open(fname))
        for item in gt_rel_data:
            image_id = item['image_id']
            round_id = item['round_id']
            gt_relevance = item['gt_relevance']
            # each image has at most one turn with dense annotation
            if image_id not in gt_rel_dict:
                gt_rel_dict[image_id] = (round_id, gt_relevance)
        return gt_rel_dict

    if args.gt_rel_file != '':
        gt_rel_dict = get_gt_rel_dict(args.gt_rel_file)

    input_lines = read_data(args.src_file)
    next_i = 0
    total_batch = math.ceil(len(input_lines) / args.batch_size)

    print('Start the VisDial decode evaluation...')
    t0 = time.time()
    ranks_json = []
    sparse_metrics = SparseGTMetrics()
    ndcg = NDCG()
    with tqdm(total=total_batch) as pbar:
        while next_i < len(input_lines):
            _chunk = input_lines[next_i:next_i + args.batch_size]
            buf_id = [x[0] for x in _chunk]
            buf = [x[:-1] for x in _chunk]
            buf_gt_id = [x[-1] for x in _chunk]
            next_i += args.batch_size
            instances = []
            for instance in buf:
                instances.append(bi_uni_pipeline[0](instance))
            with torch.no_grad():
                buf_gt_id = torch.tensor(buf_gt_id).long().to(device)
                batch_data = list(zip(*instances))
                task_idx = torch.tensor(batch_data[-3],
                                        dtype=torch.long).to(device)
                if args.no_vision:
                    conv_feats = []
                    vis_pe = []
                else:
                    img, vis_pe = (torch.stack(x).to(device)
                                   for x in batch_data[-2:])
                    conv_feats = img.data  # Bx100x2048
                    vis_pe = vis_pe.data
                output_scores_turn = []
                input_ids_turns = [[x[turn_i] for x in batch_data[0]]
                                   for turn_i in range(10)]
                segment_ids_turns = [[x[turn_i] for x in batch_data[1]]
                                     for turn_i in range(10)]
                input_mask_turns = [[x[turn_i] for x in batch_data[2]]
                                    for turn_i in range(10)]
                for turn_i in range(10):
                    input_ids = torch.tensor(input_ids_turns[turn_i],
                                             dtype=torch.long).to(device)
                    segment_ids = torch.tensor(segment_ids_turns[turn_i],
                                               dtype=torch.long).to(device)
                    input_mask = torch.stack(
                        input_mask_turns[turn_i]).to(device)
                    output_scores = model(conv_feats, vis_pe, input_ids,
                                          segment_ids, input_mask,
                                          task_idx=task_idx)
                    output_scores = output_scores[:, :, 1]  # [batch_size, num_options]
                    output_scores_turn.append(output_scores)
                output_scores_turn = torch.stack(
                    output_scores_turn, 1)  # [batch_size, num_rounds, num_options]
                ranks = scores_to_ranks(output_scores_turn)
                # output_scores_turn_cheat = output_scores_turn.scatter_(2, buf_gt_id.unsqueeze(2), 100.0)
                sparse_metrics.observe(output_scores_turn, buf_gt_id)
                for i in range(len(buf_id)):
                    # Cast into types explicitly to ensure no errors in schema.
                    # Round ids are 1-10, not 0-9
                    if args.split == "val":
                        for j in range(10):
                            ranks_json.append({
                                "image_id": buf_id[i],
                                "round_id": int(j + 1),
                                "ranks": [rank.item() for rank in ranks[i][j]],
                            })
                if args.gt_rel_file:
                    scores = []
                    gt_rels = []
                    for i in range(len(buf_id)):
                        if buf_id[i] in gt_rel_dict:
                            turn_idx, gt_rel = gt_rel_dict[buf_id[i]]
                            scores.append(output_scores_turn[i, turn_idx - 1, :])
                            gt_rels.append(
                                torch.tensor(gt_rel,
                                             dtype=torch.float32).to(device))
                    # guard: a batch may contain no densely annotated images
                    if scores:
                        scores = torch.stack(scores)
                        gt_rels = torch.stack(gt_rels)
                        ndcg.observe(scores, gt_rels)
            pbar.update(1)

    json.dump(ranks_json, open(args.save_ranks_path, "w"))
    logger.info("Finished writing rankings into %s" % args.save_ranks_path)

    if args.split == "val":
        fw = open(args.save_ranks_path.replace('.json', '_results.txt'), "w")
        all_metrics = {}
        all_metrics.update(sparse_metrics.retrieve(reset=True))
        if args.gt_rel_file:
            all_metrics.update(ndcg.retrieve(reset=True))
        for metric_name, metric_value in all_metrics.items():
            print(f"{metric_name}: {metric_value}")
            fw.write("%s: %.6f\n" % (metric_name, metric_value))
        fw.close()
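The NDCG metric observed above is computed only on the one turn per image that carries dense relevance annotations. A minimal sketch of the underlying formula for a single turn, assuming the usual VisDial convention (DCG over the top-k ranked options, where k is the number of options with nonzero relevance, normalized by the ideal DCG); ndcg_for_turn is a hypothetical helper for illustration, not the NDCG class used above:

import torch

def ndcg_for_turn(scores, relevance):
    # scores, relevance: [num_options] float tensors for one dense-annotated turn
    k = int((relevance > 0).sum().item())
    _, ranking = scores.sort(descending=True)
    # DCG@k = sum over positions i = 1..k of rel(option at i) / log2(i + 1)
    discounts = torch.log2(
        torch.arange(k, dtype=torch.float32, device=scores.device) + 2.0)
    dcg = (relevance[ranking[:k]] / discounts).sum()
    # ideal DCG: the same sum with options sorted by true relevance
    ideal, _ = relevance.sort(descending=True)
    ideal_dcg = (ideal[:k] / discounts).sum()
    return (dcg / ideal_dcg).item()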