def get_model(args, version=None):
    """Build the BERT mixture-of-experts model and wrap it for training.

    Args:
        args: Parsed training arguments. Fields read here: ``num_layers``,
            ``vocab_size``, ``hidden_size``, ``num_attention_heads``,
            ``hidden_dropout``, ``attention_dropout``, ``layernorm_epsilon``,
            ``max_position_embeddings``, ``checkpoint_activations``,
            ``checkpoint_num_layers``, ``num_experts``, ``deepspeed`` and
            ``fp16``.
        version: Which implementation to build. ``None`` selects
            ``BertMixtureModel``; ``"v0"`` selects ``BertMixtureModel_v0``.

    Returns:
        The model placed on the current CUDA device, wrapped in
        ``FP16_Module`` when ``args.fp16`` is set, and finally wrapped in
        ``DDP``.

    Raises:
        ValueError: If ``version`` is neither ``None`` nor ``"v0"``.
    """
    # Resolve the model class before doing any work so an unsupported
    # ``version`` fails fast with a clear message. Without the ``else``
    # branch, ``model`` would be left unbound and the function would crash
    # later with an UnboundLocalError.
    if version is None:
        model_cls = BertMixtureModel
    elif version == "v0":
        model_cls = BertMixtureModel_v0
    else:
        raise ValueError(
            "Unsupported model version {!r}; expected None or 'v0'".format(
                version))

    print_rank_0('building Bert model ...')
    # Both implementations are constructed with exactly the same arguments,
    # so a single call covers them.
    model = model_cls(num_layers=args.num_layers,
                      vocab_size=args.vocab_size,
                      hidden_size=args.hidden_size,
                      num_attention_heads=args.num_attention_heads,
                      embedding_dropout_prob=args.hidden_dropout,
                      attention_dropout_prob=args.attention_dropout,
                      output_dropout_prob=args.hidden_dropout,
                      layernorm_epsilon=args.layernorm_epsilon,
                      max_sequence_length=args.max_position_embeddings,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=True,
                      num_experts=args.num_experts,
                      type_vocab_size=2)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum(p.nelement() for p in model.parameters())), flush=True)

    # To prevent OOM for model sizes that cannot fit in GPU memory in full
    # precision, cast to half *before* the model is moved to the GPU below.
    if args.deepspeed and args.fp16:
        model.half()

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training. USE_TORCH_DDP is a module-level
    # flag defined outside this function; only the torch branch is given
    # explicit device ids and the data-parallel process group.
    if USE_TORCH_DDP:
        i = torch.cuda.current_device()
        model = DDP(model, device_ids=[i], output_device=i,
                    process_group=mpu.get_data_parallel_group())
    else:
        model = DDP(model)

    return model
def get_model(args):
    """Construct the BERT MoE model, place it on the GPU and wrap it in DDP.

    The model is created with ``parallel_output=False`` and
    ``type_vocab_size=2``. When ``args.fp16`` is set it is wrapped in
    ``FP16_Module``; it is then always wrapped in ``DDP``.

    Args:
        args: Parsed arguments supplying the model hyperparameters
            (``num_layers``, ``vocab_size``, ``hidden_size``,
            ``num_attention_heads``, ``hidden_dropout``,
            ``attention_dropout``, ``layernorm_epsilon``,
            ``max_position_embeddings``, ``checkpoint_activations``,
            ``checkpoint_num_layers``, ``num_experts``) and the ``fp16`` flag.

    Returns:
        The model on the current CUDA device, wrapped in ``DDP``.
    """
    print_rank_0('building Bert MoE model ...')

    # Collect constructor arguments first (same evaluation order as a direct
    # keyword call), then build the network in one step.
    model_kwargs = dict(
        num_layers=args.num_layers,
        vocab_size=args.vocab_size,
        hidden_size=args.hidden_size,
        num_attention_heads=args.num_attention_heads,
        embedding_dropout_prob=args.hidden_dropout,
        attention_dropout_prob=args.attention_dropout,
        output_dropout_prob=args.hidden_dropout,
        layernorm_epsilon=args.layernorm_epsilon,
        max_sequence_length=args.max_position_embeddings,
        checkpoint_activations=args.checkpoint_activations,
        checkpoint_num_layers=args.checkpoint_num_layers,
        parallel_output=False,
        num_experts=args.num_experts,
        type_vocab_size=2,
    )
    net = BertMixtureModel(**model_kwargs)

    # Only the first data-parallel rank reports the parameter count.
    if mpu.get_data_parallel_rank() == 0:
        mp_rank = mpu.get_model_parallel_rank()
        n_params = sum(p.nelement() for p in net.parameters())
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mp_rank, n_params), flush=True)

    # Move parameters onto the active CUDA device.
    net.cuda(torch.cuda.current_device())

    # Optional half-precision wrapper.
    if args.fp16:
        net = FP16_Module(net)

    # Always wrap for distributed training.
    return DDP(net)