def prepare_model_and_optimizer(self):
    # Prepare model
    self.config = BertConfig.from_json_file(self.args.config_file)

    # Pad the vocabulary for divisibility by 8
    if self.config.vocab_size % 8 != 0:
        self.config.vocab_size += 8 - (self.config.vocab_size % 8)

    self.model = BertForPreTraining(self.config)
    self.another_model = BertForPreTraining(self.config)

    self.model.to(self.device)
    self.another_model.to(self.device)

    # Split parameters into weight-decay and no-weight-decay groups
    param_optimizer = list(self.model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

    optimizer_grouped_parameters = []
    names = []

    for n, p in param_optimizer:
        if not any(nd in n for nd in no_decay):
            optimizer_grouped_parameters.append({'params': [p], 'weight_decay': 0.01, 'name': n})
            names.append({'params': [n], 'weight_decay': 0.01})
        else:
            optimizer_grouped_parameters.append({'params': [p], 'weight_decay': 0.00, 'name': n})
            names.append({'params': [n], 'weight_decay': 0.00})

    # Adjust the schedule length and starting learning rate when resuming for phase 2
    if self.args.phase2:
        max_steps = self.args.max_steps
        tmp = max_steps * 10
        r = self.args.phase1_end_step / tmp
        lr = self.args.learning_rate * (1 - r)
    else:
        max_steps = int(self.args.max_steps / 9 * 10)
        lr = self.args.learning_rate

    if self.args.optimizer == "lamb":
        self.optimizer = BertLAMB(optimizer_grouped_parameters,
                                  lr=lr,
                                  warmup=self.args.warmup_proportion if not self.args.phase2 else -1,
                                  t_total=max_steps)
    elif self.args.optimizer == "adam":
        self.optimizer = BertAdam(optimizer_grouped_parameters,
                                  lr=lr,
                                  warmup=self.args.warmup_proportion if not self.args.phase2 else -1,
                                  t_total=max_steps)
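# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original code): how the no_decay split
# above behaves on a toy module. The toy module and its attribute names are
# made up for demonstration; the substring match works in the real model
# because its LayerNorm submodules are attributes literally named "LayerNorm".
import torch.nn as nn


class _ToyBlock(nn.Module):  # hypothetical module, mirrors the naming scheme
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(4, 4)
        self.LayerNorm = nn.LayerNorm(4)


_no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']
for _n, _p in _ToyBlock().named_parameters():
    _wd = 0.00 if any(nd in _n for nd in _no_decay) else 0.01
    print(_n, _wd)
# dense.weight 0.01 | dense.bias 0.0 | LayerNorm.weight 0.0 | LayerNorm.bias 0.0
# ---------------------------------------------------------------------------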
def prepare_model_and_optimizer(args, device):
    # Prepare model
    config = BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    model = BertForPreTraining(config)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        # Resume from the most recent checkpoint unless a specific step was requested
        if args.resume_step == -1:
            model_names = [f for f in os.listdir(args.output_dir) if f.endswith(".pt")]
            args.resume_step = max(int(x.split(".pt")[0].split("_")[1].strip()) for x in model_names)
        global_step = args.resume_step

        checkpoint = torch.load(os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step)),
                                map_location="cpu")
        model.load_state_dict(checkpoint["model"], strict=False)
        if args.phase2:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)

    # Split parameters into weight-decay and no-weight-decay groups
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "gamma", "beta", "LayerNorm"]

    optimizer_grouped_parameters = []
    names = []

    for n, p in param_optimizer:
        if not any(nd in n for nd in no_decay):
            optimizer_grouped_parameters.append({"params": [p], "weight_decay": 0.01, "name": n})
            names.append({"params": [n], "weight_decay": 0.01})
        else:
            optimizer_grouped_parameters.append({"params": [p], "weight_decay": 0.00, "name": n})
            names.append({"params": [n], "weight_decay": 0.00})

    optimizer = BertLAMB(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=args.max_steps)

    if args.fp16:
        # Mixed precision via Apex AMP (O2), with dynamic or static loss scaling
        if args.loss_scale == 0:
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level="O2",
                loss_scale="dynamic",
                master_weights=False if args.accumulate_into_fp16 else True)
        else:
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level="O2",
                loss_scale=args.loss_scale,
                master_weights=False if args.accumulate_into_fp16 else True)
        amp._amp_state.loss_scalers[0]._loss_scale = 2 ** 20

    if args.resume_from_checkpoint:
        if args.phase2:
            # Override scheduler hyperparameters carried over from phase 1
            for key in checkpoint["optimizer"]["state"]:
                checkpoint["optimizer"]["state"][key]["step"] = global_step
            for group in checkpoint["optimizer"]["param_groups"]:
                group["t_total"] = args.max_steps
                group["warmup"] = args.warmup_proportion
                group["lr"] = args.learning_rate
        optimizer.load_state_dict(checkpoint["optimizer"])

        # Restore AMP master parameters
        if args.fp16:
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint["optimizer"])
            for param, saved_param in zip(amp.master_params(optimizer), checkpoint["master params"]):
                param.data.copy_(saved_param.data)

    if args.local_rank != -1:
        if not args.allreduce_post_accumulation:
            # Apex DistributedDataParallel; gradients are pre-divided by the world size
            model = DDP(model,
                        message_size=250000000,
                        gradient_predivide_factor=torch.distributed.get_world_size())
        else:
            # Gradients are all-reduced after accumulation, so only broadcast the
            # initial parameters from rank 0 here
            flat_dist_call([param.data for param in model.parameters()],
                           torch.distributed.broadcast, (0,))
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    return model, optimizer, checkpoint, global_step
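# ---------------------------------------------------------------------------
# Minimal calling sketch (illustrative only): the fields below are the ones
# read by prepare_model_and_optimizer above; the concrete values are
# placeholders, and the real script is assumed to populate args via argparse.
import argparse
import torch

_example_args = argparse.Namespace(
    config_file="bert_config.json",      # placeholder path
    resume_from_checkpoint=False,
    resume_step=-1,
    output_dir="./results",              # placeholder path
    phase2=False,
    phase1_end_step=0,                   # placeholder value
    learning_rate=6e-3,                  # placeholder value
    warmup_proportion=0.28,              # placeholder value
    max_steps=1000,                      # placeholder value
    fp16=False,
    loss_scale=0,
    accumulate_into_fp16=False,
    allreduce_post_accumulation=False,
    local_rank=-1,
    n_gpu=1,
)
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, optimizer, checkpoint, global_step = prepare_model_and_optimizer(_example_args, _device)
# ---------------------------------------------------------------------------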