def test_ofa(self):
    ofa_model = OFA(self.model,
                    self.run_config,
                    distill_config=self.distill_config,
                    elastic_order=self.elastic_order)

    start_epoch = 0
    for idx in range(len(self.run_config.n_epochs)):
        cur_idx = self.run_config.n_epochs[idx]
        for ph_idx in range(len(cur_idx)):
            cur_lr = self.run_config.init_learning_rate[idx][ph_idx]
            adam = paddle.optimizer.Adam(
                learning_rate=cur_lr,
                parameters=(ofa_model.parameters() + ofa_model.netAs_param))
            for epoch_id in range(start_epoch,
                                  self.run_config.n_epochs[idx][ph_idx]):
                if epoch_id == 0:
                    ofa_model.set_epoch(epoch_id)
                for model_no in range(self.run_config.dynamic_batch_size[idx]):
                    output, _ = ofa_model(self.data)
                    loss = paddle.mean(output)
                    if self.distill_config.mapping_layers is not None:
                        dis_loss = ofa_model.calc_distill_loss()
                        loss += dis_loss
                        dis_loss = dis_loss.numpy()[0]
                    else:
                        dis_loss = 0
                    print('epoch: {}, loss: {}, distill loss: {}'.format(
                        epoch_id, loss.numpy()[0], dis_loss))
                    loss.backward()
                    adam.minimize(loss)
                    adam.clear_gradients()
            start_epoch = self.run_config.n_epochs[idx][ph_idx]
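# test_ofa() above reads self.run_config, self.distill_config and
# self.elastic_order from a setUp() that is not part of this excerpt. The
# helper below is a minimal sketch (hypothetical name and values) of how such
# configs could be assembled; it only uses RunConfig/DistillConfig keys that
# already appear in this file (n_epochs, init_learning_rate, dynamic_batch_size,
# lambda_distill, teacher_model, mapping_layers) and assumes
# `from paddleslim.nas.ofa import RunConfig, DistillConfig` at the top of the file.
def _example_ofa_test_configs(teacher_model):
    example_run_config = {
        # one inner list per elastic task, one entry per training phase
        'n_epochs': [[1], [2, 3]],
        'init_learning_rate': [[0.001], [0.003, 0.001]],
        'dynamic_batch_size': [1, 1],
    }
    example_distill_config = {
        'lambda_distill': 0.01,
        'teacher_model': teacher_model,
        'mapping_layers': ['models.0.fn'],  # hypothetical sublayer name
    }
    return (RunConfig(**example_run_config),
            DistillConfig(**example_distill_config))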
class TestOFAV2(unittest.TestCase):
    def setUp(self):
        model = ModelV1()
        sp_net_config = supernet(expand_ratio=[0.25, 0.5, 1.0])
        self.model = Convert(sp_net_config).convert(model)
        self.images = paddle.randn(shape=[2, 3, 32, 32], dtype='float32')

    def test_ofa(self):
        self.ofa_model = OFA(self.model)
        self.ofa_model.set_epoch(0)
        self.ofa_model.set_task('expand_ratio')
        out, _ = self.ofa_model(self.images)
class TestExportCase1(unittest.TestCase):
    def setUp(self):
        model = ModelLinear1()
        data_np = np.random.random((3, 64)).astype(np.int64)
        self.data = paddle.to_tensor(data_np)
        self.ofa_model = OFA(model)
        self.ofa_model.set_epoch(0)
        outs, _ = self.ofa_model(self.data)
        self.config = self.ofa_model.current_config

    def test_export_model(self):
        self.ofa_model.export(
            self.config, input_shapes=[[3, 64]], input_dtypes=['int64'])
        assert len(self.ofa_model.ofa_layers) == 4
class Testelementwise(unittest.TestCase):
    def setUp(self):
        model = ModelElementwise()
        sp_net_config = supernet(expand_ratio=[0.25, 0.5, 1.0])
        self.model = Convert(sp_net_config).convert(model)
        self.images = paddle.randn(shape=[2, 3, 32, 32], dtype='float32')

    def test_elementwise(self):
        self.ofa_model = OFA(self.model)
        self.ofa_model.set_epoch(0)
        self.ofa_model.set_task('expand_ratio')
        out, _ = self.ofa_model(self.images)
        assert list(
            self.ofa_model._ofa_layers.keys()) == ['conv2.0', 'conv3.0']
class TestMultiExit(unittest.TestCase):
    def setUp(self):
        self.images = paddle.randn(shape=[1, 3, 224, 224], dtype='float32')
        model = ModelMultiExit()
        sp_net_config = supernet(expand_ratio=[0.25, 0.5, 1.0])
        self.model = Convert(sp_net_config).convert(model)

    def test_multiexit(self):
        self.ofa_model = OFA(self.model)
        self.ofa_model.set_epoch(0)
        self.ofa_model.set_task('expand_ratio')
        out, _ = self.ofa_model(self.images)
        assert list(self.ofa_model._ofa_layers.keys()) == [
            'conv1.0', 'block1.0', 'block1.4', 'block2.0', 'out2.0'
        ]
class TestExportCase2(unittest.TestCase):
    def setUp(self):
        model = ModelLinear()
        data_np = np.random.random((3, 64)).astype(np.int64)
        self.data = paddle.to_tensor(data_np)
        self.ofa_model = OFA(model)
        self.ofa_model.set_epoch(0)
        outs, _ = self.ofa_model(self.data)
        self.config = self.ofa_model.current_config

    def test_export_model_linear2(self):
        config = self.ofa_model._sample_config(
            task='expand_ratio', phase=None, sample_type='smallest')
        ex_model = self.ofa_model.export(
            config, input_shapes=[[3, 64]], input_dtypes=['int64'])
        ex_model(self.data)
        assert len(self.ofa_model.ofa_layers) == 3
class TestShortCut(unittest.TestCase):
    def setUp(self):
        model = resnet50()
        sp_net_config = supernet(expand_ratio=[0.25, 0.5, 1.0])
        self.model = Convert(sp_net_config).convert(model)
        self.images = paddle.randn(shape=[2, 3, 224, 224], dtype='float32')
        self._test_clear_search_space()

    def _test_clear_search_space(self):
        self.ofa_model = OFA(self.model)
        self.ofa_model.set_epoch(0)
        outs, _ = self.ofa_model(self.images)
        self.config = self.ofa_model.current_config

    def test_export_model(self):
        self.ofa_model.export(
            self.config,
            input_shapes=[[2, 3, 224, 224]],
            input_dtypes=['float32'])
        assert len(self.ofa_model.ofa_layers) == 37
class TestShortcutSkiplayers(unittest.TestCase):
    def setUp(self):
        model = ModelShortcut()
        sp_net_config = supernet(expand_ratio=[0.5, 1.0])
        self.model = Convert(sp_net_config).convert(model)
        self.images = paddle.randn(shape=[2, 3, 32, 32], dtype='float32')
        self.init_config()
        self.ofa_model = OFA(self.model, run_config=self.run_config)
        self.ofa_model._clear_search_space(self.images)

    def init_config(self):
        default_run_config = {'skip_layers': ['branch1.6']}
        self.run_config = RunConfig(**default_run_config)

    def test_shortcut(self):
        self.ofa_model.set_epoch(0)
        self.ofa_model.set_task('expand_ratio')
        for i in range(5):
            self.ofa_model(self.images)
        assert list(self.ofa_model._ofa_layers.keys()) == ['branch2.0']
def do_train(args): paddle.set_device("gpu" if args.n_gpu else "cpu") if paddle.distributed.get_world_size() > 1: paddle.distributed.init_parallel_env() set_seed(args) args.task_name = args.task_name.lower() dataset_class, metric_class = TASK_CLASSES[args.task_name] args.model_type = args.model_type.lower() model_class, tokenizer_class = MODEL_CLASSES[args.model_type] train_ds = dataset_class.get_datasets(['train']) tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) trans_func = partial(convert_example, tokenizer=tokenizer, label_list=train_ds.get_labels(), max_seq_length=args.max_seq_length) train_ds = train_ds.apply(trans_func, lazy=True) train_batch_sampler = paddle.io.DistributedBatchSampler( train_ds, batch_size=args.batch_size, shuffle=True) batchify_fn = lambda samples, fn=Tuple( Pad(axis=0, pad_val=tokenizer.pad_token_id), # input Pad(axis=0, pad_val=tokenizer.pad_token_id), # segment Stack(), # length Stack(dtype="int64" if train_ds.get_labels() else "float32") # label ): [data for i, data in enumerate(fn(samples)) if i != 2] train_data_loader = DataLoader(dataset=train_ds, batch_sampler=train_batch_sampler, collate_fn=batchify_fn, num_workers=0, return_list=True) if args.task_name == "mnli": dev_dataset_matched, dev_dataset_mismatched = dataset_class.get_datasets( ["dev_matched", "dev_mismatched"]) dev_dataset_matched = dev_dataset_matched.apply(trans_func, lazy=True) dev_dataset_mismatched = dev_dataset_mismatched.apply(trans_func, lazy=True) dev_batch_sampler_matched = paddle.io.BatchSampler( dev_dataset_matched, batch_size=args.batch_size, shuffle=False) dev_data_loader_matched = DataLoader( dataset=dev_dataset_matched, batch_sampler=dev_batch_sampler_matched, collate_fn=batchify_fn, num_workers=0, return_list=True) dev_batch_sampler_mismatched = paddle.io.BatchSampler( dev_dataset_mismatched, batch_size=args.batch_size, shuffle=False) dev_data_loader_mismatched = DataLoader( dataset=dev_dataset_mismatched, batch_sampler=dev_batch_sampler_mismatched, collate_fn=batchify_fn, num_workers=0, return_list=True) else: dev_dataset = dataset_class.get_datasets(["dev"]) dev_dataset = dev_dataset.apply(trans_func, lazy=True) dev_batch_sampler = paddle.io.BatchSampler(dev_dataset, batch_size=args.batch_size, shuffle=False) dev_data_loader = DataLoader(dataset=dev_dataset, batch_sampler=dev_batch_sampler, collate_fn=batchify_fn, num_workers=0, return_list=True) num_labels = 1 if train_ds.get_labels() == None else len( train_ds.get_labels()) model = model_class.from_pretrained(args.model_name_or_path, num_classes=num_labels) if paddle.distributed.get_world_size() > 1: model = paddle.DataParallel(model) # Step1: Initialize a dictionary to save the weights from the origin BERT model. origin_weights = {} for name, param in model.named_parameters(): origin_weights[name] = param # Step2: Convert origin model to supernet. sp_config = supernet(expand_ratio=args.width_mult_list) model = Convert(sp_config).convert(model) # Use weights saved in the dictionary to initialize supernet. utils.set_state_dict(model, origin_weights) del origin_weights # Step3: Define teacher model. teacher_model = model_class.from_pretrained(args.model_name_or_path, num_classes=num_labels) # Step4: Config about distillation. 
    mapping_layers = ['bert.embeddings']
    for idx in range(model.bert.config['num_hidden_layers']):
        mapping_layers.append('bert.encoder.layers.{}'.format(idx))

    default_distill_config = {
        'lambda_distill': 0.1,
        'teacher_model': teacher_model,
        'mapping_layers': mapping_layers,
    }
    distill_config = DistillConfig(**default_distill_config)

    # Step5: Config in supernet training.
    ofa_model = OFA(model,
                    distill_config=distill_config,
                    elastic_order=['width'])

    criterion = paddle.nn.loss.CrossEntropyLoss() if train_ds.get_labels(
    ) else paddle.nn.loss.MSELoss()

    metric = metric_class()

    if args.task_name == "mnli":
        dev_data_loader = (dev_data_loader_matched, dev_data_loader_mismatched)

    # Step6: Calculate the importance of neurons and head,
    # and then reorder them according to the importance.
    head_importance, neuron_importance = utils.compute_neuron_head_importance(
        args.task_name,
        ofa_model.model,
        dev_data_loader,
        loss_fct=criterion,
        num_layers=model.bert.config['num_hidden_layers'],
        num_heads=model.bert.config['num_attention_heads'])
    reorder_neuron_head(ofa_model.model, head_importance, neuron_importance)

    lr_scheduler = paddle.optimizer.lr.LambdaDecay(
        args.learning_rate,
        lambda current_step, num_warmup_steps=args.warmup_steps,
        num_training_steps=args.max_steps if args.max_steps > 0 else
        (len(train_data_loader) * args.num_train_epochs): float(
            current_step) / float(max(1, num_warmup_steps))
        if current_step < num_warmup_steps else max(
            0.0,
            float(num_training_steps - current_step) / float(
                max(1, num_training_steps - num_warmup_steps))))

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=ofa_model.model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in [
            p.name for n, p in ofa_model.model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ])

    global_step = 0
    tic_train = time.time()
    for epoch in range(args.num_train_epochs):
        # Step7: Set current epoch and task.
        ofa_model.set_epoch(epoch)
        ofa_model.set_task('width')

        for step, batch in enumerate(train_data_loader):
            global_step += 1
            input_ids, segment_ids, labels = batch

            for width_mult in args.width_mult_list:
                # Step8: Broadcast supernet config from width_mult,
                # and use this config in supernet training.
                net_config = apply_config(ofa_model, width_mult)
                ofa_model.set_net_config(net_config)
                logits, teacher_logits = ofa_model(input_ids,
                                                   segment_ids,
                                                   attention_mask=[None, None])
                rep_loss = ofa_model.calc_distill_loss()
                if args.task_name == 'sts-b':
                    logit_loss = 0.0
                else:
                    logit_loss = soft_cross_entropy(logits,
                                                    teacher_logits.detach())
                loss = rep_loss + args.lambda_logit * logit_loss
                loss.backward()
            optimizer.step()
            lr_scheduler.step()
            ofa_model.model.clear_gradients()

            if global_step % args.logging_steps == 0:
                if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0:
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                        % (global_step, epoch, step, loss,
                           args.logging_steps / (time.time() - tic_train)))
                tic_train = time.time()

            if global_step % args.save_steps == 0:
                if args.task_name == "mnli":
                    evaluate(teacher_model,
                             criterion,
                             metric,
                             dev_data_loader_matched,
                             width_mult=100)
                    evaluate(teacher_model,
                             criterion,
                             metric,
                             dev_data_loader_mismatched,
                             width_mult=100)
                else:
                    evaluate(teacher_model,
                             criterion,
                             metric,
                             dev_data_loader,
                             width_mult=100)
                for idx, width_mult in enumerate(args.width_mult_list):
                    net_config = apply_config(ofa_model, width_mult)
                    ofa_model.set_net_config(net_config)
                    tic_eval = time.time()
                    if args.task_name == "mnli":
                        acc = evaluate(ofa_model, criterion, metric,
                                       dev_data_loader_matched, width_mult)
                        evaluate(ofa_model, criterion, metric,
                                 dev_data_loader_mismatched, width_mult)
                        print("eval done total : %s s" %
                              (time.time() - tic_eval))
                    else:
                        acc = evaluate(ofa_model, criterion, metric,
                                       dev_data_loader, width_mult)
                        print("eval done total : %s s" %
                              (time.time() - tic_eval))
                    if (not args.n_gpu > 1
                        ) or paddle.distributed.get_rank() == 0:
                        output_dir = os.path.join(args.output_dir,
                                                  "model_%d" % global_step)
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        # need better way to get inner model of DataParallel
                        model_to_save = model._layers if isinstance(
                            model, paddle.DataParallel) else model
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)
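# soft_cross_entropy() is called in the training loops above and below but is
# not defined in this excerpt (the same is true of apply_config() and
# evaluate()). The sketch below shows the usual soft-label distillation loss
# such a helper is expected to compute; it assumes `import paddle` and
# `import paddle.nn.functional as F`, and is an illustration rather than the
# original implementation.
def soft_cross_entropy(inp, target):
    # student log-probabilities against teacher probabilities, averaged over the batch
    inp_likelihood = F.log_softmax(inp, axis=-1)
    target_prob = F.softmax(target, axis=-1)
    return -1. * paddle.mean(paddle.sum(inp_likelihood * target_prob, axis=-1))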
g_clip = F.clip.GradientClipByGlobalNorm(1.0)  # experimental
if args.use_lr_decay:
    opt = AdamW(learning_rate=LinearDecay(
        args.lr, int(args.warmup_proportion * args.max_steps),
        args.max_steps),
                parameter_list=ofa_model.model.parameters(),
                weight_decay=args.wd,
                grad_clip=g_clip)
else:
    opt = AdamW(args.lr,
                parameter_list=ofa_model.model.parameters(),
                weight_decay=args.wd,
                grad_clip=g_clip)

for epoch in range(max(run_config.n_epochs[-1])):
    ofa_model.set_epoch(epoch)
    if epoch <= int(max(run_config.n_epochs[0])):
        ofa_model.set_task('width')
        depth_mult_list = [1.0]
    else:
        ofa_model.set_task('depth')
        depth_mult_list = run_config.elastic_depth
    for step, d in enumerate(
            tqdm(train_ds.start(place), desc='training')):
        ids, sids, label = d
        accumulate_gradients = dict()
        for param in opt._parameter_list:
            accumulate_gradients[param.name] = 0.0
        for depth_mult in depth_mult_list:
def do_train(args):
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args)

    args.task_name = args.task_name.lower()
    metric_class = METRIC_CLASSES[args.task_name]
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    train_ds = load_dataset('clue', args.task_name, splits='train')
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    trans_func = partial(convert_example,
                         label_list=train_ds.label_list,
                         tokenizer=tokenizer,
                         max_seq_length=args.max_seq_length)
    train_ds = train_ds.map(trans_func, lazy=True)
    train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_ds, batch_size=args.batch_size, shuffle=True)
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment
        Stack(dtype="int64" if train_ds.label_list else "float32")  # label
    ): fn(samples)
    train_data_loader = DataLoader(dataset=train_ds,
                                   batch_sampler=train_batch_sampler,
                                   collate_fn=batchify_fn,
                                   num_workers=0,
                                   return_list=True)

    dev_ds = load_dataset('clue', args.task_name, splits='dev')
    dev_ds = dev_ds.map(trans_func, lazy=True)
    dev_batch_sampler = paddle.io.BatchSampler(dev_ds,
                                               batch_size=args.batch_size,
                                               shuffle=False)
    dev_data_loader = DataLoader(dataset=dev_ds,
                                 batch_sampler=dev_batch_sampler,
                                 collate_fn=batchify_fn,
                                 num_workers=0,
                                 return_list=True)

    num_labels = 1 if train_ds.label_list is None else len(train_ds.label_list)

    model = model_class.from_pretrained(args.model_name_or_path,
                                        num_classes=num_labels)

    # Step1: Initialize a dictionary to save the weights from the origin PPMiniLM model.
    origin_weights = model.state_dict()

    # Step2: Convert origin model to supernet.
    sp_config = supernet(expand_ratio=[1.0])
    model = Convert(sp_config).convert(model)
    # Use weights saved in the dictionary to initialize supernet.
    utils.set_state_dict(model, origin_weights)
    del origin_weights

    super_sd = paddle.load(
        os.path.join(args.model_name_or_path, 'model_state.pdparams'))
    model.set_state_dict(super_sd)

    # Step3: Define teacher model.
    teacher_model = model_class.from_pretrained(args.model_name_or_path,
                                                num_classes=num_labels)

    # Step4: Config about distillation.
    mapping_layers = ['ppminilm.embeddings']
    for idx in range(model.ppminilm.config['num_hidden_layers']):
        mapping_layers.append('ppminilm.encoder.layers.{}'.format(idx))

    default_distill_config = {
        'lambda_distill': 0.1,
        'teacher_model': teacher_model,
        'mapping_layers': mapping_layers,
    }
    distill_config = DistillConfig(**default_distill_config)

    # Step5: Config in supernet training.
    ofa_model = OFA(model,
                    distill_config=distill_config,
                    elastic_order=['width'])

    criterion = paddle.nn.loss.CrossEntropyLoss(
    ) if train_ds.label_list else paddle.nn.loss.MSELoss()

    metric = metric_class()

    # Step6: Calculate the importance of neurons and head,
    # and then reorder them according to the importance.
    head_importance, neuron_importance = nlp_utils.compute_neuron_head_importance(
        args.task_name,
        ofa_model.model,
        dev_data_loader,
        loss_fct=criterion,
        num_layers=model.ppminilm.config['num_hidden_layers'],
        num_heads=model.ppminilm.config['num_attention_heads'])
    reorder_neuron_head(ofa_model.model, head_importance, neuron_importance)

    if paddle.distributed.get_world_size() > 1:
        ofa_model.model = paddle.DataParallel(ofa_model.model)

    if args.max_steps > 0:
        num_training_steps = args.max_steps
        num_train_epochs = math.ceil(num_training_steps /
                                     len(train_data_loader))
    else:
        num_training_steps = len(train_data_loader) * args.num_train_epochs
        num_train_epochs = args.num_train_epochs

    warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate,
                                         num_training_steps, warmup)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        beta1=0.9,
        beta2=0.999,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params,
        grad_clip=nn.ClipGradByGlobalNorm(args.max_grad_norm))

    global_step = 0
    tic_train = time.time()
    best_res = 0.0
    for epoch in range(num_train_epochs):
        # Step7: Set current epoch and task.
        ofa_model.set_epoch(epoch)
        ofa_model.set_task('width')

        for step, batch in enumerate(train_data_loader):
            global_step += 1
            input_ids, segment_ids, labels = batch

            for width_mult in args.width_mult_list:
                # Step8: Broadcast supernet config from width_mult,
                # and use this config in supernet training.
                net_config = utils.dynabert_config(ofa_model, width_mult)
                ofa_model.set_net_config(net_config)
                logits, teacher_logits = ofa_model(input_ids,
                                                   segment_ids,
                                                   attention_mask=[None, None])
                rep_loss = ofa_model.calc_distill_loss()
                logit_loss = soft_cross_entropy(logits,
                                                teacher_logits.detach())
                loss = rep_loss + args.lambda_logit * logit_loss
                loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

            if global_step % args.logging_steps == 0:
                logger.info(
                    "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                    % (global_step, epoch, step, loss,
                       args.logging_steps / (time.time() - tic_train)))
                tic_train = time.time()

            if global_step % args.save_steps == 0 or global_step == num_training_steps:
                tic_eval = time.time()
                evaluate(teacher_model, metric, dev_data_loader, width_mult=100)
                print("eval done total : %s s" % (time.time() - tic_eval))

                for idx, width_mult in enumerate(args.width_mult_list):
                    net_config = utils.dynabert_config(ofa_model, width_mult)
                    ofa_model.set_net_config(net_config)
                    tic_eval = time.time()
                    res = evaluate(ofa_model, metric, dev_data_loader,
                                   width_mult)
                    print("eval done total : %s s" % (time.time() - tic_eval))

                    if best_res < res:
                        output_dir = args.output_dir
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        # need better way to get inner model of DataParallel
                        model_to_save = model._layers if isinstance(
                            model, paddle.DataParallel) else model
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)
                        best_res = res

            if global_step >= num_training_steps:
                print("best_res: ", best_res)
                return
    print("best_res: ", best_res)
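# evaluate() is referenced above but not defined in this excerpt. Below is a
# hedged sketch matching the call sites evaluate(model, metric, data_loader,
# width_mult): run the dev set without gradients, feed the metric, and return
# the accumulated result. It assumes the teacher-model call with width_mult=100
# goes through the same path and that width_mult is used only for logging; for
# composite metrics a real helper would reduce accumulate() to a single score.
@paddle.no_grad()
def evaluate(model, metric, data_loader, width_mult=1.0):
    model.eval()
    metric.reset()
    for input_ids, segment_ids, labels in data_loader:
        logits = model(input_ids, segment_ids, attention_mask=[None, None])
        if isinstance(logits, tuple):
            # the OFA-wrapped model returns (student_logits, teacher_logits)
            logits = logits[0]
        metric.update(metric.compute(logits, labels))
    res = metric.accumulate()
    print("width_mult: %s, eval result: %s" % (str(width_mult), res))
    model.train()
    return res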
def do_train(args):
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args)

    args.task_name = args.task_name.lower()
    metric_class = METRIC_CLASSES[args.task_name]
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    train_ds = load_dataset('glue', args.task_name, splits="train")
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    trans_func = partial(
        convert_example,
        tokenizer=tokenizer,
        label_list=train_ds.label_list,
        max_seq_length=args.max_seq_length)
    train_ds = train_ds.map(trans_func, lazy=True)
    train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_ds, batch_size=args.batch_size, shuffle=True)
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment
        Stack(dtype="int64" if train_ds.label_list else "float32")  # label
    ): fn(samples)
    train_data_loader = DataLoader(
        dataset=train_ds,
        batch_sampler=train_batch_sampler,
        collate_fn=batchify_fn,
        num_workers=0,
        return_list=True)

    if args.task_name == "mnli":
        dev_ds_matched, dev_ds_mismatched = load_dataset(
            'glue', args.task_name, splits=["dev_matched", "dev_mismatched"])
        dev_ds_matched = dev_ds_matched.map(trans_func, lazy=True)
        dev_ds_mismatched = dev_ds_mismatched.map(trans_func, lazy=True)
        dev_batch_sampler_matched = paddle.io.BatchSampler(
            dev_ds_matched, batch_size=args.batch_size, shuffle=False)
        dev_data_loader_matched = DataLoader(
            dataset=dev_ds_matched,
            batch_sampler=dev_batch_sampler_matched,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)
        dev_batch_sampler_mismatched = paddle.io.BatchSampler(
            dev_ds_mismatched, batch_size=args.batch_size, shuffle=False)
        dev_data_loader_mismatched = DataLoader(
            dataset=dev_ds_mismatched,
            batch_sampler=dev_batch_sampler_mismatched,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)
    else:
        dev_ds = load_dataset('glue', args.task_name, splits='dev')
        dev_ds = dev_ds.map(trans_func, lazy=True)
        dev_batch_sampler = paddle.io.BatchSampler(
            dev_ds, batch_size=args.batch_size, shuffle=False)
        dev_data_loader = DataLoader(
            dataset=dev_ds,
            batch_sampler=dev_batch_sampler,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)

    num_labels = 1 if train_ds.label_list is None else len(train_ds.label_list)

    # Step1: Initialize the origin BERT model.
    model = model_class.from_pretrained(
        args.model_name_or_path, num_classes=num_labels)
    origin_weights = model.state_dict()

    # Step2: Convert origin model to supernet.
    sp_config = supernet(expand_ratio=args.width_mult_list)
    model = Convert(sp_config).convert(model)
    # Use weights saved in the dictionary to initialize supernet.
    utils.set_state_dict(model, origin_weights)

    # Step3: Define teacher model.
    teacher_model = model_class.from_pretrained(
        args.model_name_or_path, num_classes=num_labels)
    new_dict = utils.utils.remove_model_fn(teacher_model, origin_weights)
    teacher_model.set_state_dict(new_dict)
    del origin_weights, new_dict

    default_run_config = {'elastic_depth': args.depth_mult_list}
    run_config = RunConfig(**default_run_config)

    # Step4: Config about distillation.
    mapping_layers = ['bert.embeddings']
    for idx in range(model.bert.config['num_hidden_layers']):
        mapping_layers.append('bert.encoder.layers.{}'.format(idx))

    default_distill_config = {
        'lambda_distill': args.lambda_rep,
        'teacher_model': teacher_model,
        'mapping_layers': mapping_layers,
    }
    distill_config = DistillConfig(**default_distill_config)

    # Step5: Config in supernet training.
    ofa_model = OFA(model,
                    run_config=run_config,
                    distill_config=distill_config,
                    elastic_order=['depth'])
    # elastic_order=['width'])

    criterion = paddle.nn.CrossEntropyLoss(
    ) if train_ds.label_list else paddle.nn.MSELoss()

    metric = metric_class()

    if args.task_name == "mnli":
        dev_data_loader = (dev_data_loader_matched, dev_data_loader_mismatched)

    if paddle.distributed.get_world_size() > 1:
        ofa_model.model = paddle.DataParallel(
            ofa_model.model, find_unused_parameters=True)

    if args.max_steps > 0:
        num_training_steps = args.max_steps
        num_train_epochs = math.ceil(num_training_steps /
                                     len(train_data_loader))
    else:
        num_training_steps = len(train_data_loader) * args.num_train_epochs
        num_train_epochs = args.num_train_epochs

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate,
                                         num_training_steps,
                                         args.warmup_steps)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=ofa_model.model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    global_step = 0
    tic_train = time.time()
    for epoch in range(num_train_epochs):
        # Step6: Set current epoch and task.
        ofa_model.set_epoch(epoch)
        ofa_model.set_task('depth')

        for step, batch in enumerate(train_data_loader):
            global_step += 1
            input_ids, segment_ids, labels = batch

            for depth_mult in args.depth_mult_list:
                for width_mult in args.width_mult_list:
                    # Step7: Broadcast supernet config from width_mult,
                    # and use this config in supernet training.
                    net_config = utils.dynabert_config(ofa_model, width_mult,
                                                       depth_mult)
                    ofa_model.set_net_config(net_config)
                    logits, teacher_logits = ofa_model(
                        input_ids, segment_ids, attention_mask=[None, None])
                    rep_loss = ofa_model.calc_distill_loss()
                    if args.task_name == 'sts-b':
                        logit_loss = 0.0
                    else:
                        logit_loss = soft_cross_entropy(
                            logits, teacher_logits.detach())
                    loss = rep_loss + args.lambda_logit * logit_loss
                    loss.backward()
            optimizer.step()
            lr_scheduler.step()
            ofa_model.model.clear_gradients()

            if global_step % args.logging_steps == 0:
                if paddle.distributed.get_rank() == 0:
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                        % (global_step, epoch, step, loss,
                           args.logging_steps / (time.time() - tic_train)))
                tic_train = time.time()

            if global_step % args.save_steps == 0:
                if args.task_name == "mnli":
                    evaluate(
                        teacher_model,
                        criterion,
                        metric,
                        dev_data_loader_matched,
                        width_mult=100)
                    evaluate(
                        teacher_model,
                        criterion,
                        metric,
                        dev_data_loader_mismatched,
                        width_mult=100)
                else:
                    evaluate(
                        teacher_model,
                        criterion,
                        metric,
                        dev_data_loader,
                        width_mult=100)
                for depth_mult in args.depth_mult_list:
                    for width_mult in args.width_mult_list:
                        net_config = utils.dynabert_config(
                            ofa_model, width_mult, depth_mult)
                        ofa_model.set_net_config(net_config)
                        tic_eval = time.time()
                        if args.task_name == "mnli":
                            acc = evaluate(ofa_model, criterion, metric,
                                           dev_data_loader_matched, width_mult,
                                           depth_mult)
                            evaluate(ofa_model, criterion, metric,
                                     dev_data_loader_mismatched, width_mult,
                                     depth_mult)
                            print("eval done total : %s s" %
                                  (time.time() - tic_eval))
                        else:
                            acc = evaluate(ofa_model, criterion, metric,
                                           dev_data_loader, width_mult,
                                           depth_mult)
                            print("eval done total : %s s" %
                                  (time.time() - tic_eval))
                if paddle.distributed.get_rank() == 0:
                    output_dir = os.path.join(args.output_dir,
                                              "model_%d" % global_step)
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # need better way to get inner model of DataParallel
                    model_to_save = model._layers if isinstance(
                        model, paddle.DataParallel) else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

            if global_step >= num_training_steps:
                return
def __call__(self, model, param_state_dict):
    paddleslim = try_import('paddleslim')
    from paddleslim.nas.ofa import OFA, RunConfig, utils
    from paddleslim.nas.ofa.convert_super import Convert, supernet

    task = self.ofa_config['task']
    expand_ratio = self.ofa_config['expand_ratio']

    skip_neck = self.ofa_config['skip_neck']
    skip_head = self.ofa_config['skip_head']

    run_config = self.ofa_config['RunConfig']
    if 'skip_layers' in run_config:
        skip_layers = run_config['skip_layers']
    else:
        skip_layers = []

    # supernet config
    sp_config = supernet(expand_ratio=expand_ratio)
    # convert to supernet
    model = Convert(sp_config).convert(model)

    skip_names = []
    if skip_neck:
        skip_names.append('neck.')
    if skip_head:
        skip_names.append('head.')

    for name, sublayer in model.named_sublayers():
        for n in skip_names:
            if n in name:
                skip_layers.append(name)

    run_config['skip_layers'] = skip_layers
    run_config = RunConfig(**run_config)

    # build ofa model
    ofa_model = OFA(model, run_config=run_config)

    ofa_model.set_epoch(0)
    ofa_model.set_task(task)

    input_spec = [{
        "image": paddle.ones(
            shape=[1, 3, 640, 640], dtype='float32'),
        "im_shape": paddle.full(
            [1, 2], 640, dtype='float32'),
        "scale_factor": paddle.ones(
            shape=[1, 2], dtype='float32')
    }]

    ofa_model._clear_search_space(input_spec=input_spec)
    ofa_model._build_ss = True
    check_ss = ofa_model._sample_config('expand_ratio', phase=None)
    # tokenize the search space
    ofa_model.tokenize()
    # check token map, search cands and search space
    logger.info('Token map is {}'.format(ofa_model.token_map))
    logger.info('Search candidates is {}'.format(ofa_model.search_cands))
    logger.info('The length of search_space is {}, search_space is {}'.
                format(len(ofa_model._ofa_layers), ofa_model._ofa_layers))

    # set model state dict into ofa model
    utils.set_state_dict(ofa_model.model, param_state_dict)
    return ofa_model
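# For reference, a minimal example of the `self.ofa_config` dictionary that
# __call__() above consumes. The key names (task, expand_ratio, skip_neck,
# skip_head, RunConfig, skip_layers) are taken from the code; the concrete
# values are illustrative assumptions.
example_ofa_config = {
    'task': 'expand_ratio',
    'expand_ratio': [0.5, 1.0],
    'skip_neck': True,
    'skip_head': True,
    'RunConfig': {
        'skip_layers': [],  # optional; neck/head sublayers are appended above
    },
}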