def get_model(tokenizer, args):
    """Build the BERT model, place it on the current GPU and optionally wrap it.

    Args:
        tokenizer: forwarded unchanged to ``BertModel``.
        args: configuration namespace. Reads ``fp16``, ``fp32_embedding``,
            ``fp32_tokentypes``, ``fp32_layernorm`` and ``world_size``.

    Returns:
        The model. It is wrapped in ``FP16_Module`` when ``args.fp16`` is set,
        and that result is wrapped in ``DDP`` when ``args.world_size > 1``.
    """
    print('building BERT model ...')
    model = BertModel(tokenizer, args)

    param_count = sum(p.nelement() for p in model.parameters())
    print(' > number of parameters: {}'.format(param_count), flush=True)

    # Move parameters and buffers onto this process's current CUDA device.
    model.cuda(torch.cuda.current_device())

    if args.fp16:
        print("fp16 mode")
        model = FP16_Module(model)
        # After wrapping, selected submodules are cast back to fp32 with
        # Module.float(). The wrapped BERT is reached through
        # model.module.model.bert.
        if args.fp32_embedding:
            embeddings = model.module.model.bert.embeddings
            embeddings.word_embeddings.float()
            embeddings.position_embeddings.float()
            embeddings.token_type_embeddings.float()
        if args.fp32_tokentypes:
            model.module.model.bert.embeddings.token_type_embeddings.float()
        if args.fp32_layernorm:
            # Any submodule whose qualified name contains 'LayerNorm'.
            layernorms = [
                submodule
                for submodule_name, submodule in model.named_modules()
                if 'LayerNorm' in submodule_name
            ]
            for submodule in layernorms:
                submodule.float()

    # Wrap for distributed training only when more than one process is used.
    if args.world_size > 1:
        model = DDP(model)
    return model
def get_model(args):
    """Build the BERT model, place it on the current GPU and wrap it for data parallelism.

    Args:
        args: configuration namespace. Reads ``fp16``, ``fp32_embedding``,
            ``ds_type``, ``fp32_tokentypes``, ``fp32_layernorm`` and
            ``DDP_impl``. Writes ``args.DDP_type``, the wrapper class that
            was applied.

    Returns:
        The model. It is wrapped in ``FP16_Module`` when ``args.fp16`` is set,
        and that result is wrapped in the data-parallel class selected by
        ``args.DDP_impl`` (``'torch'`` or ``'local'``).

    Raises:
        SystemExit: with a non-zero status (1) when ``args.DDP_impl`` is
            neither ``'torch'`` nor ``'local'``.
    """
    print_rank_0('building BERT model ...')
    model = BertModel(args)

    # Printed only where the data-parallel rank is 0. The message is tagged
    # with the model-parallel rank.
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum(p.nelement() for p in model.parameters())), flush=True)

    # Move parameters and buffers onto this process's current CUDA device.
    model.cuda(torch.cuda.current_device())

    if args.fp16:
        model = FP16_Module(model)
        # After wrapping, selected submodules are cast back to fp32 with
        # Module.float(). The wrapped BERT is reached through
        # model.module.model.bert.
        if args.fp32_embedding:
            embeddings = model.module.model.bert.embeddings
            embeddings.word_embeddings.float()
            if args.ds_type == 'BERT':
                embeddings.position_embeddings.float()
            else:
                # Other ds_type values carry three position tables
                # (token/para/sent, judging by the attribute names) instead
                # of a single position_embeddings.
                embeddings.token_position_embeddings.float()
                embeddings.para_position_embeddings.float()
                embeddings.sent_position_embeddings.float()
            embeddings.token_type_embeddings.float()
        if args.fp32_tokentypes:
            model.module.model.bert.embeddings.token_type_embeddings.float()
        if args.fp32_layernorm:
            for name, submodule in model.named_modules():
                if 'LayerNorm' in name:
                    submodule.float()

    # Wrap for distributed training. The chosen class is recorded on args so
    # other code can refer to it.
    if args.DDP_impl == 'torch':
        device = torch.cuda.current_device()
        args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
        model = args.DDP_type(model, device_ids=[device], output_device=device,
                              process_group=mpu.get_data_parallel_group())
    elif args.DDP_impl == 'local':
        args.DDP_type = LocalDDP
        model = args.DDP_type(model)
    else:
        print_rank_0('Unknown DDP implementation specified: {}. '
                     'Exiting.'.format(args.DDP_impl))
        # A bare exit() terminates with status 0, which reports a
        # misconfiguration as success, and it only exists when the `site`
        # module is loaded. Raise SystemExit with a failure status instead.
        raise SystemExit(1)
    return model
def get_model(args):
    """Build the BERT model, place it on the current GPU and wrap it with ``DDP``.

    Args:
        args: configuration namespace. Reads ``fp16``, ``fp32_embedding``,
            ``fp32_tokentypes`` and ``fp32_layernorm``.

    Returns:
        The model. It is wrapped in ``FP16_Module`` when ``args.fp16`` is set,
        and then always wrapped in ``DDP``.

    Note:
        ``USE_TORCH_DDP`` and ``DDP`` are module-level names defined outside
        this function. When ``USE_TORCH_DDP`` is truthy, ``DDP`` receives the
        device and data-parallel process-group keywords. Otherwise it gets
        only the model (presumably a local DDP implementation; confirm where
        ``DDP`` is bound).
    """
    print_rank_0('building BERT model ...')
    model = BertModel(args)

    # Printed only where the data-parallel rank is 0. The message is tagged
    # with the model-parallel rank.
    if mpu.get_data_parallel_rank() == 0:
        mp_rank = mpu.get_model_parallel_rank()
        n_params = sum(p.nelement() for p in model.parameters())
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mp_rank, n_params), flush=True)

    # Move parameters and buffers onto this process's current CUDA device.
    model.cuda(torch.cuda.current_device())

    if args.fp16:
        model = FP16_Module(model)
        # After wrapping, selected submodules are cast back to fp32 with
        # Module.float(). The wrapped BERT is reached through
        # model.module.model.bert.
        if args.fp32_embedding:
            embeddings = model.module.model.bert.embeddings
            for table_name in ('word_embeddings', 'position_embeddings',
                               'token_type_embeddings'):
                getattr(embeddings, table_name).float()
        if args.fp32_tokentypes:
            model.module.model.bert.embeddings.token_type_embeddings.float()
        if args.fp32_layernorm:
            for module_name, child in model.named_modules():
                if 'LayerNorm' not in module_name:
                    continue
                child.float()

    # Torch's DDP needs device placement and the data-parallel group. The
    # alternative wrapper is constructed from the model alone.
    if USE_TORCH_DDP:
        device_index = torch.cuda.current_device()
        ddp_kwargs = {
            'device_ids': [device_index],
            'output_device': device_index,
            'process_group': mpu.get_data_parallel_group(),
        }
    else:
        ddp_kwargs = {}
    return DDP(model, **ddp_kwargs)