def train(rank, args, tokenizer, train_dataset, test_dataset, model_s, model_t, params_to_tune, head_importance=None, loss_num=-1, tune_iter=0): """ Train the model """ global train_count train_count += 1 world_size = 1 if rank < 0 else torch.distributed.get_world_size() if rank in [-1, 0]: printlog("Train stage: ", train_count) printlog(model_s) if head_importance is not None: head_mask = torch.ones(*list(head_importance.shape)).to(args.device) head_mask.requires_grad_(requires_grad=True) else: head_mask = None num_train_epochs = args.num_train_epochs if loss_num > 0: num_train_epochs = 0.25 #short train for incremental loss per_gpu_train_batch_size = args.per_gpu_train_batch_size train_batch_size = per_gpu_train_batch_size * world_size #get total batch size and if tune_iter > 0 and args.total_train_batch_size_for_tune: total_train_batch_size = args.total_train_batch_size_for_tune else: total_train_batch_size = args.total_train_batch_size gradient_accumulation_steps = total_train_batch_size // train_batch_size if tune_iter > 0 and args.learning_rate_for_tune: learning_rate = args.learning_rate_for_tune else: learning_rate = args.learning_rate if check_model_type(model_s, BertModelEMB): #use 2 datasets for embedding question and context separatly if rank in [-1, 0]: printlog("dataset_q size", len(train_dataset.q_dataset)) printlog("dataset_c size", len(train_dataset.c_dataset)) datasets = [train_dataset.q_dataset, train_dataset.c_dataset] else: if rank in [-1, 0]: printlog("dataset size", len(train_dataset)) datasets = [train_dataset] if rank > -1: #for distributed train use sample that take only part of samples for each process train_dataloaders = [ DataLoader(dataset, sampler=torch.utils.data.distributed.DistributedSampler( dataset, rank=rank), batch_size=per_gpu_train_batch_size) for dataset in datasets ] else: train_dataloaders = [ DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=train_batch_size, num_workers=4) for dataset in datasets ] steps_per_epoch = sum(len(d) for d in train_dataloaders) steps_total = int(steps_per_epoch // gradient_accumulation_steps * num_train_epochs) # Prepare optimizer and scheduler name_set = set() for n, p in model_s.named_parameters(): if any(p is pp for pp in params_to_tune): name_set.add(n) named_params = [(n, p) for n, p in model_s.named_parameters() if n in name_set] if rank in [-1, 0]: for n, p in named_params: printlog('param for tune', n) def new_optimizer(): return AdamW([p for n, p in named_params], lr=learning_rate, eps=1e-08, weight_decay=0.0) optimizer = new_optimizer() def lr_lambda(current_step): p = float(current_step) / float(steps_total) warmup = 0.01 if p < warmup: return p / warmup p = (p - warmup) / (1 - warmup) return 1 if tune_iter == 0 else max(1 - p, 0) scheduler = LambdaLR(optimizer, lr_lambda) if rank in [-1, 0]: printlog("epoches", num_train_epochs) printlog("per_gpu_train_batch_size", per_gpu_train_batch_size) printlog("n_gpu", args.n_gpu) printlog("world_size", world_size) printlog("gradient_accumulation_steps", gradient_accumulation_steps) printlog("total train batch size", train_batch_size * gradient_accumulation_steps) printlog("steps_total", steps_total) restore_count = 0 if rank in [-1, 0]: if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) restore_file = os.path.join(args.output_dir, 'last_good_state.pth') restore_loss = None losses_list = [] global_step = 0 for epoch in range(math.ceil(num_train_epochs)): switch_to_train(rank, model_t) switch_to_train(rank, model_s) model_s.zero_grad() 
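        # The loop below runs the knowledge-distillation stage: every batch is passed
        # through the student (optionally with a differentiable head_mask) and through
        # the frozen teacher, one MSE term is built per aligned pair of hidden states,
        # gradients are accumulated over `gradient_accumulation_steps` sub-batches
        # before each optimizer step, and |grad| on head_mask is folded into
        # head_importance as a running estimate of per-head relevance.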
utils.sync_models(rank, model_s) time_last = time.time() for train_dataloader in train_dataloaders: printlog("rank", rank, "len(train_dataloader)", len(train_dataloader)) if rank > -1: train_dataloader.sampler.set_epoch(epoch) if len(train_dataloaders) > 1: # reset last loss to avoid restore due to dataset changing printlog("rank", rank, "reset restore_loss") restore_loss = None for step, batch in enumerate(train_dataloader): epoch_fp = epoch + step / len(train_dataloader) if epoch_fp > num_train_epochs: break inputs = { 'input_ids': batch[0].to(args.device), 'attention_mask': batch[1].to(args.device), 'token_type_ids': batch[2].to(args.device) } outputs_s = model_s(**inputs, head_mask=head_mask, output_hidden_states=True) losses = [] with torch.no_grad(): outputs_t = model_t(**inputs, output_hidden_states=True) out_s, out_t = outputs_s[-1], outputs_t[-1] assert len( out_s ) == model_s.config.num_hidden_layers + 1, "can not find hidden states in student model outputs" assert len( out_t ) == model_t.config.num_hidden_layers + 1, "can not find hidden states in teacher model outputs" if len(out_s) != len(out_t): #the student and teacher outputs are not aligned. try to find teacher output for each student output n_s, n_t = len(out_s), len(out_t) out_t = [ out_t[(i * (n_t - 1)) // (n_s - 1)] for i in range(n_s) ] assert len(out_s) == len( out_t ), "can not align number of outputs between student and teacher" assert all( s[0] == s[1] for s in zip(out_s[0].shape, out_t[0].shape) ), "output shapes for student and teacher are not the same" out_pairs = list(zip(out_s, out_t)) if loss_num > 0: out_pairs = out_pairs[:loss_num] losses += [(s - t.detach()).pow(2).mean() for s, t in out_pairs] losses_list.append([l.item() for l in losses]) if tune_iter == 0: loss = sum(losses) / len(losses) else: weights = [ args.loss_weight_alpha**i for i in range(len(losses)) ] losses_w = [w * l for w, l in zip(weights, losses)] loss = sum(losses_w) / sum(weights) if gradient_accumulation_steps > 1: loss = loss / gradient_accumulation_steps loss.backward() del out_s del out_t del outputs_s del outputs_t if head_importance is not None: #collect gradient statistics to find most valuable heads head_mask.grad.detach_() head_importance += (head_mask.grad.abs().detach() - head_importance) * 0.001 head_mask.grad.zero_() if (step + 1) % gradient_accumulation_steps == 0: global_step += 1 #sync gradients before calc step utils.sync_grads(rank, named_params, global_step == 1) torch.nn.utils.clip_grad_norm_( [p for n, p in named_params], 1) optimizer.step() scheduler.step() model_s.zero_grad() if (step + 1) % 50 == 0: str_out = "{} ep {:.2f} lrp {:.2f} rc {:02}".format( train_count, epoch_fp, np.log10(scheduler.get_last_lr()[0]), restore_count) ll = np.array(losses_list).mean(0) if rank > -1: #sync indicators llt = torch.tensor(ll).to(args.device) torch.distributed.all_reduce( llt, op=torch.distributed.ReduceOp.SUM) ll = llt.cpu().numpy() / float(world_size) loss = ll.mean() str_out += " loss {:.4f}".format(loss) losses_txt = ["{:.3f}".format(l) for l in ll] if tune_iter > 0: losses_txt = [ "{:.2f}x".format(w) + lt for w, lt in zip(weights, losses_txt) ] str_out += " ll " + " ".join(losses_txt) if time_last: dt_iter = (time.time() - time_last) / len(losses_list) dt_ep = dt_iter * steps_per_epoch str_out += " it {:.1f}s".format(dt_iter) str_out += " ep {:.1f}m".format(dt_ep / (60)) str_out += " eta {:.1f}h".format( dt_ep * (num_train_epochs - epoch_fp) / (60 * 60)) losses_list = [] time_last = time.time() if rank in [-1, 0]: 
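                            # log the accumulated training indicators; right below, the averaged
                            # loss is synced across processes and compared with the last good value:
                            # if it grew by more than 1.5x, the saved "last good" state is restored
                            # and the optimizer is re-created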
                            logger.info(str_out)

                        if rank > -1:
                            # sync losses between processes
                            loss_tensor = torch.tensor([loss], device=args.device)
                            torch.distributed.all_reduce(loss_tensor, op=torch.distributed.ReduceOp.SUM)
                            loss = loss_tensor.item() / world_size

                        if restore_loss is None or loss < restore_loss * 1.5:
                            # good result: remember it as the last good state
                            restore_loss = loss
                            if rank in [-1, 0]:
                                torch.save(
                                    {
                                        'model_state_dict': model_s.state_dict(),
                                        'optimizer_state_dict': optimizer.state_dict()
                                    }, restore_file)
                            if rank > -1:
                                torch.distributed.barrier()
                        else:
                            # bad result: restore the last good state and reset the optimizer
                            restore_count += 1
                            logger.info("rank {} restore #{} from {} with {} loss".format(
                                rank, restore_count, restore_file, restore_loss))
                            checkpoint = torch.load(restore_file)
                            model_s.load_state_dict(checkpoint['model_state_dict'])
                            #optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                            optimizer = new_optimizer()
                            switch_to_train(rank, model_s)

        if loss_num <= 0:
            if rank in [-1, 0]:
                check_point_name = 'checkpoint-{:02}'.format(train_count)
                save_model(args, model_s, tokenizer, check_point_name)
                check_point_name = check_point_name + '-{:02}'.format(epoch + 1)
                switch_to_eval(rank, model_s)
                result_s = evaluate(args, model_s, test_dataset)
                for k, v in result_s.items():
                    logger.info("{} {} {}".format(check_point_name, k, v))
            if rank > -1:
                torch.distributed.barrier()

    if rank in [-1, 0]:
        if os.path.exists(restore_file):
            os.remove(restore_file)
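# ----------------------------------------------------------------------------------
# Illustrative sketch (not part of the original sources): a minimal, self-contained
# version of the per-layer hidden-state distillation loss used in train() above.
# Teacher layers are sub-sampled with the same index formula as the training loop so
# that a shallower student can be matched against a deeper teacher.  The function
# name and arguments below are assumptions made for this sketch only.
import torch

def hidden_state_distill_losses(hidden_s, hidden_t):
    """Return one MSE loss per student hidden state, aligned against the teacher."""
    n_s, n_t = len(hidden_s), len(hidden_t)
    if n_s != n_t:
        # for student output i take teacher output (i * (n_t - 1)) // (n_s - 1),
        # exactly as done inside train()
        hidden_t = [hidden_t[(i * (n_t - 1)) // (n_s - 1)] for i in range(n_s)]
    assert all(s.shape == t.shape for s, t in zip(hidden_s, hidden_t)), \
        "student and teacher hidden states must have the same shape"
    return [(s - t.detach()).pow(2).mean() for s, t in zip(hidden_s, hidden_t)]
# ----------------------------------------------------------------------------------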
def process(rank, args, port):
    # init multiprocessing
    if rank < 0:
        args.device = torch.device("cpu" if args.n_gpu < 1 else "cuda")
    else:
        # create default process group
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = str(port)
        torch.distributed.init_process_group("nccl", rank=rank, world_size=args.n_gpu)
        args.device = torch.device("cuda:{}".format(rank))
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed_all(args.seed)

    # set seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if rank > 0:
        # wait while process 0 loads the models
        torch.distributed.barrier()

    printlog("rank", rank, "load tokenizer", args.model_student)
    tokenizer = BertTokenizer.from_pretrained(args.model_student)

    config = AutoConfig.from_pretrained(args.model_student)
    if hasattr(config, 'pack_cfg') and 'base_class_name' in config.pack_cfg:
        # get the model class from pack_cfg
        base_class_name = config.pack_cfg['base_class_name']
        printlog("rank", rank, "base_class_name to pack", base_class_name)
        Model = globals()[base_class_name]
    else:
        # get the model class from the 'architectures' field of the config
        if config.architectures:
            assert len(config.architectures) == 1, \
                "only a single model is supported but {} has {}".format(
                    args.model_student, config.architectures)
            Model = globals()[config.architectures[0]]
        else:
            Model = BertForQuestionAnswering

    printlog("rank", rank, "load teacher {} model from {}".format(Model.__name__, args.model_teacher))
    model_t = Model.from_pretrained(args.model_teacher)

    printlog("rank", rank, "load student {} model from {}".format(Model.__name__, args.model_student))
    model_s = BertPacked(Model).from_pretrained(args.model_student)

    if rank == 0:
        # release the other processes that are waiting
        torch.distributed.barrier()
    if rank > -1:
        # sync processes
        torch.distributed.barrier()

    params_packed = []
    if hasattr(model_s.config, 'pack_cfg'):
        logger.warning("rank {} !!!model already packed!!!".format(rank))
        logger.warning(
            "rank {} !!!just continue distill the already packed model!!!".
format(rank)) else: pack_cfg = dict([t.split(':') for t in args.pack_cfg.split(',')]) pack_cfg['pack_emb'] = True if eval(pack_cfg['pack_emb']) else False printlog("rank", rank, "pack model by", pack_cfg) params_packed = model_s.pack_(pack_cfg) model_s.to(args.device) model_t.to(args.device) utils.sync_models(rank, model_s) if rank in [-1, 0]: save_model(args, model_s, tokenizer) def wrap_dropout(net): #remove dropout class PASS(torch.nn.Module): def __init__(self, dropout): super().__init__() self.dropout = dropout self.dropout_enable = False def forward(self, x): return x def __repr__(self): return "PASS( dropout_enable {} for {} )".format( self.dropout_enable, self.dropout.__repr__()) dropout_list = [(n, m, nn, mm) for n, m in net.named_modules() for nn, mm in m._modules.items() if isinstance(mm, torch.nn.Dropout)] for n, m, nn, mm in dropout_list: m._modules[nn] = PASS(mm) logger.info('rank {} {}.{} Dropout in warped by PASS'.format( rank, n, nn)) logger.info('rank {} warp dropout for teacher model'.format(rank)) wrap_dropout(model_t) logger.info('rank {} warp dropout for student model'.format(rank)) wrap_dropout(model_s) #calculate current number of heads in student model bert_s = model_s.get_bert() n_layers, n_heads = bert_s.config.num_hidden_layers, bert_s.config.num_attention_heads if hasattr(bert_s.config, 'pruned_heads'): pruned_nums = [len(v) for v in model_s.config.pruned_heads.values()] if pruned_nums: n_heads -= min(pruned_nums) #load train and evaluation datasets if check_model_type(model_s, BertModelEMB): train_dataset = create_squad_qcemb_dataset(rank, args.device, args.squad_train_data, tokenizer, args.max_seq_length_q, args.max_seq_length_c) test_dataset = create_squad_qcemb_dataset(rank, args.device, args.squad_dev_data, tokenizer, args.max_seq_length_q, args.max_seq_length_c) else: train_dataset = create_squad_qa_dataset(rank, args.device, args.squad_train_data, tokenizer, args.max_seq_length_q, args.max_seq_length_c) test_dataset = create_squad_qa_dataset(rank, args.device, args.squad_dev_data, tokenizer, args.max_seq_length_q, args.max_seq_length_c) if rank in [-1, 0]: switch_to_eval(rank, model_t) result_t = evaluate(args, model_t, test_dataset) for k, v in result_t.items(): logger.info("{} teacher {}".format(k, v)) if rank > -1: torch.distributed.barrier() params_emb = [] for n, p in model_s.named_parameters(): if any(p is pp for pp in params_packed) and 'embedding' in n: params_emb.append(p) if params_emb: params_inp = [ p for n, p in model_s.named_parameters() if 'input_transform' in n ] #tune embeddings transformation params_tune = params_emb + params_inp loss_num = 1 train(rank, args, tokenizer, train_dataset, test_dataset, model_s, model_t, params_tune, head_importance=None, loss_num=loss_num) #iterative add bert encoder blocks encoder = model_s.get_bert().encoder for l, t in zip(encoder.layer, encoder.output_transforms): params_tune.extend(l.parameters()) params_tune.extend(t.parameters()) loss_num += 1 train(rank, args, tokenizer, train_dataset, test_dataset, model_s, model_t, params_tune, head_importance=None, loss_num=loss_num) if params_packed: #on the first stage the FF block only reduced and tuned #the number of self attention heads is the same #check that head prune is needed and run second train to tune the rest heads pack_head_num = int( model_s.config.pack_cfg.get('num_attention_heads', n_heads)) pack_heads_flag = (pack_head_num < n_heads) head_importance = torch.zeros(n_layers, n_heads).to( args.device) if pack_heads_flag else None params_ff = [ p 
for n, p in model_s.named_parameters() if 'encoder.' in n and 'attention.' not in n ] train(rank, args, tokenizer, train_dataset, test_dataset, model_s, model_t, params_packed + params_ff, head_importance=head_importance) if head_importance is not None and rank > -1: torch.distributed.all_reduce(head_importance.data, op=torch.distributed.ReduceOp.SUM) if pack_heads_flag: #reduce number of heads before move to the second stage and tune all model if rank in [-1, 0]: logger.info('head_importance') logger.info(head_importance) logger.info('heads_to_prune') #prune heads heads_to_prune = {} for l in range(n_layers): imp = head_importance[l].tolist() idx = list(sorted(range(n_heads), key=lambda x: imp[x])) heads_to_prune[l] = idx[:-pack_head_num] if rank in [-1, 0]: logger.info("layer {} heads_to_prune {}".format( l, heads_to_prune[l])) model_s.prune_heads(heads_to_prune) utils.sync_models(rank, model_s) params_encoder = [ p for n, p in model_s.named_parameters() if 'encoder.' in n ] params_emb = [ p for n, p in model_s.named_parameters() if 'embedding' in n and 'linear' in n ] if params_emb: # if has linear then LayerNorm was trained params_emb += [ p for n, p in model_s.named_parameters() if 'embedding' in n and 'LayerNorm' in n ] train(rank, args, tokenizer, train_dataset, test_dataset, model_s, model_t, params_emb + params_encoder) params_encoder = [ p for n, p in model_s.named_parameters() if 'encoder.' in n ] params_emb = [ p for n, p in model_s.named_parameters() if 'embedding' in n and 'linear' in n ] if params_emb: #if has linear then LayerNorm was trained params_emb += [ p for n, p in model_s.named_parameters() if 'embedding' in n and 'LayerNorm' in n ] #final tune train(rank, args, tokenizer, train_dataset, test_dataset, model_s, model_t, params_emb + params_encoder, tune_iter=1) if rank in [-1, 0]: save_model(args, model_s, tokenizer) logger.info('Evaluate student model') logger.info('Model for evaluation') logger.info(model_s) switch_to_eval(rank, model_s) result_s = evaluate(args, model_s, test_dataset) for k, v in result_s.items(): logger.info("{} student {} teacher {}".format(k, v, result_t[k])) #merge some linear transformations into filters model_s.merge_() logger.info("student model") logger.info(model_s) result_s = evaluate(args, model_s, test_dataset) for k, v in result_s.items(): logger.info( "{} student {} after some operations are merged".format(k, v)) #save to onnx if check_model_type(model_s, BertModelEMB): output_names = ['embedding'] else: output_names = ['output_s', 'output_e'] inputs = tuple( torch.zeros(args.max_seq_length_q, dtype=torch.long) for t in range(4)) inputs = tuple(t.unsqueeze(0).to(args.device) for t in inputs) torch.onnx.export(model_s, inputs, os.path.join(args.output_dir, "model.onnx"), verbose=False, input_names=[ 'input_ids', 'attention_mask', 'token_type_ids', 'position_ids' ], output_names=output_names)
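# ----------------------------------------------------------------------------------
# Illustrative sketch (not part of the original sources): the head-selection rule
# used above, reduced to a standalone helper.  For every layer the `keep_num` heads
# with the largest accumulated |gradient| importance are kept and the rest are marked
# for pruning.  `head_importance` is assumed to be an [n_layers, n_heads] tensor; the
# helper name is an assumption for this sketch only.
import torch

def select_heads_to_prune(head_importance, keep_num):
    heads_to_prune = {}
    n_layers, n_heads = head_importance.shape
    for layer in range(n_layers):
        imp = head_importance[layer].tolist()
        idx = sorted(range(n_heads), key=lambda h: imp[h])  # least important first
        heads_to_prune[layer] = idx[:-keep_num]             # drop everything but keep_num heads
    return heads_to_prune
# ----------------------------------------------------------------------------------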
def process(rank, args, port): #init multiprocess if rank < 0: args.device = torch.device("cpu" if args.n_gpu < 1 else "cuda") else: # create default process group os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = str(port) torch.distributed.init_process_group("nccl", rank=rank, world_size=args.n_gpu) args.device = torch.device("cuda:{}".format(rank)) torch.cuda.set_device(rank) if rank > 0: #wait while process 0 load models torch.distributed.barrier() printlog("rank", rank, "load tokenizer", args.model_student) tokenizer = BertTokenizer.from_pretrained(args.model_student) printlog("rank", rank, "load model", args.model_student) config = AutoConfig.from_pretrained(args.model_student) if config.architectures and 'BertBasedClassPacked' in config.architectures: model = BertPacked(BertModelEMB).from_pretrained( args.model_student).to(args.device) else: model = BertModelEMB.from_pretrained(args.model_student).to( args.device) if args.supervision_weight > 0: model_t = BertModelEMB.from_pretrained(args.model_teacher).to( args.device) else: model_t = None if rank == 0: #wait while other processes load models torch.distributed.barrier() #create train and evaluate datasets train_dataset_qc = create_squad_qcemb_dataset(rank, args.device, args.squad_train_data, tokenizer, args.max_seq_length_q, args.max_seq_length_c) test_dataset_qc = create_squad_qcemb_dataset(rank, args.device, args.squad_dev_data, tokenizer, args.max_seq_length_q, args.max_seq_length_c) if rank >= 0: #lets sync after data loaded torch.distributed.barrier() model_controller = None if QUANTIZATION: if hasattr(model, 'merge_'): #if model is packed, then merge some linera transformations before quantization model.merge_() if rank in [0, -1]: #evaluate before quntization model.eval() result = evaluate(args, model, test_dataset_qc) for n, v in result.items(): logger.info("original {} - {}".format(n, v)) if rank >= 0: torch.distributed.barrier() nncf_config = nncf.NNCFConfig.from_json(args.nncf_config) class SquadInitializingDataloader( nncf.initialization.InitializingDataLoader): def get_inputs(self, batch): return [], get_inputs(batch, args.device) train_dataloader = DataLoader(train_dataset_qc.c_dataset, sampler=RandomSampler( train_dataset_qc.c_dataset), batch_size=args.per_gpu_train_batch_size) initializing_data_loader = SquadInitializingDataloader( train_dataloader) init_range = nncf.initialization.QuantizationRangeInitArgs( initializing_data_loader) nncf_config.register_extra_structs([init_range]) model_controller, model = nncf.create_compressed_model( model, nncf_config, dump_graphs=True) if rank > -1: model_controller.distributed() utils.sync_models(rank, model) if rank in [-1, 0]: #evaluate pure initialized int8 model model.eval() result = evaluate(args, model, test_dataset_qc) for n, v in result.items(): logger.info("int8 {} - {}".format(n, v)) if rank > -1: #lets sync after quantization torch.distributed.barrier() #tune FQ parameters only train(rank, args, model, model_t, train_dataset_qc, test_dataset_qc, fq_tune_only=True, model_controller=model_controller) #tune whole quantized model train(rank, args, model, model_t, train_dataset_qc, test_dataset_qc, fq_tune_only=False, model_controller=model_controller) if rank in [-1, 0]: #save and evaluate result os.makedirs(args.output_dir, exist_ok=True) model.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir) model.eval() #get sample to pass for onnx generation with torch.no_grad(): torch.onnx.export(model, tuple( torch.zeros((1, 
                                           args.max_seq_length_c),
                                          dtype=torch.long,
                                          device=args.device) for t in range(4)),
                os.path.join(args.output_dir, "model.onnx"),
                verbose=False,
                enable_onnx_checker=False,
                opset_version=10,
                input_names=['input_ids', 'attention_mask', 'token_type_ids', 'position_ids'],
                output_names=['embedding'])

        # Evaluate final model
        result = evaluate(args, model, test_dataset_qc)
        for n, v in result.items():
            logger.info("{} - {}".format(n, v))
        logger.info("checkpoint final result {}".format(result))
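# ----------------------------------------------------------------------------------
# Illustrative sketch (not from the original sources): the core of the triplet
# objective used by the embedding train() below.  The distance between a question
# embedding and its positive context should be smaller than the distance to a
# negative context; softplus(d_pos - d_neg) is the per-sample penalty.  Function and
# argument names are assumptions for this sketch only.
import torch

def triplet_softplus_loss(q_vec, c_vec_pos, c_vec_neg):
    """q_vec, c_vec_pos, c_vec_neg: [batch, emb_dim] tensors."""
    d_pos = (q_vec - c_vec_pos).norm(2, dim=-1)
    d_neg = (q_vec - c_vec_neg).norm(2, dim=-1)
    return torch.nn.functional.softplus(d_pos - d_neg).mean()
# ----------------------------------------------------------------------------------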
def train(rank, args, model, model_t, train_dataset_qc, test_dataset_qc, fq_tune_only, model_controller): """ Train the model """ global train_count train_count += 1 world_size = 1 if rank < 0 else torch.distributed.get_world_size() if rank in [-1, 0]: printlog("Train model", train_count) printlog(model) q_dataset = train_dataset_qc.q_dataset per_gpu_train_batch_size = args.per_gpu_train_batch_size train_batch_size = per_gpu_train_batch_size * world_size if fq_tune_only: gradient_accumulation_steps = 1 num_train_epochs = 1 else: gradient_accumulation_steps = args.total_train_batch_size // train_batch_size num_train_epochs = args.num_train_epochs if rank < 0: #single process take all q_sampler = RandomSampler(q_dataset) q_dataloader = DataLoader(q_dataset, sampler=q_sampler, batch_size=train_batch_size, num_workers=4) else: #special sampler that divide samples between processes q_sampler = torch.utils.data.distributed.DistributedSampler(q_dataset, rank=rank) q_dataloader = DataLoader(q_dataset, sampler=q_sampler, batch_size=per_gpu_train_batch_size) steps_total = int( len(q_dataloader) // gradient_accumulation_steps * num_train_epochs) # Prepare optimizer and schedule named_params, groups = utils.make_param_groups( rank, model, args. freeze_list, #list or str with subnames to define frozen parameters args.learning_rate, #learning rate for no FQ parameters 0.01, # learning rate for FQ parameters fq_tune_only, #true if only FQ parameters will be optimized model_controller) optimizer = AdamW(groups, eps=1e-08, lr=args.learning_rate, weight_decay=0) def lr_lambda(current_step): p = float(current_step) / float(steps_total) return 1 - p scheduler = LambdaLR(optimizer, lr_lambda) if rank in [-1, 0]: for n, p in named_params: printlog('param for tune', n) printlog("fq_tune_only", fq_tune_only) printlog("dataset size", len(q_dataset)) printlog("epoches", num_train_epochs) printlog("per_gpu_train_batch_size", per_gpu_train_batch_size) printlog("n_gpu", args.n_gpu) printlog("world_size", world_size) printlog("gradient_accumulation_steps", gradient_accumulation_steps) printlog("total train batch size", train_batch_size * gradient_accumulation_steps) printlog("steps_total", steps_total) global_step = 1 model.zero_grad() indicators = collections.defaultdict(list) softplus = torch.nn.Softplus() loss_cfg = dict([t.split(':') for t in args.loss_cfg.split(',')]) hnm_hist = {} for epoch in range(math.ceil(num_train_epochs)): indicators = collections.defaultdict(list) model.train() if model_t: model_t.train() if rank > -1: #set epoch to make different samples division betwen process for different epoches q_sampler.set_epoch(epoch) utils.sync_models(rank, model) for step, q_batch in enumerate(q_dataloader): epoch_fp = epoch + step / len(q_dataloader) if epoch_fp > num_train_epochs: break losses = [] context_ids_pos = q_batch[3] q_inputs = get_inputs(q_batch, args.device) q_outputs = model(**q_inputs, output_hidden_states=(model_t is not None)) q_vec = q_outputs[0] #get positive embeddings c_batch = train_dataset_qc.c_dataset[context_ids_pos.detach().data] c_inputs = get_inputs(c_batch, args.device) c_outputs = model(**c_inputs, output_hidden_states=(model_t is not None)) c_vec_pos = c_outputs[0] if model_t is not None: q_emb_s, q_hidden_s = q_outputs c_emb_s, c_hidden_s = c_outputs with torch.no_grad(): q_emb_t, q_hidden_t = model_t(**q_inputs, output_hidden_states=True) c_emb_t, c_hidden_t = model_t(**c_inputs, output_hidden_states=True) def align_and_loss_outputs(out_s, out_t): if len(out_s) != len(out_t): 
#the student and teacher outputs are not aligned. try to find teacher output for each student output n_s, n_t = len(out_s), len(out_t) out_t = [ out_t[(i * (n_t - 1)) // (n_s - 1)] for i in range(n_s) ] assert len(out_s) == len( out_t ), "can not align number of outputs between student and teacher" assert all( s[0] == s[1] for s in zip(out_s[0].shape, out_t[0].shape) ), "output shapes for student and teacher are not the same" return [(s - t.detach()).pow(2).mean() for s, t in zip(out_s, out_t)] l_q = align_and_loss_outputs(q_hidden_s, q_hidden_t) l_c = align_and_loss_outputs(c_hidden_s, c_hidden_t) emb_loss = loss_cfg.get('emb_loss', '') if emb_loss == 'L2': l_q.append((q_emb_s - q_emb_t.detach()).pow(2).mean()) l_c.append((c_emb_s - c_emb_t.detach()).pow(2).mean()) elif emb_loss == 'L1': l_q.append((q_emb_s - q_emb_t.detach()).abs().mean()) l_c.append((c_emb_s - c_emb_t.detach()).abs().mean()) elif emb_loss.lower() not in ['', 'none', '0', 'disable']: raise Exception( 'emb_loss={} is unsupported'.format(emb_loss)) losses.extend([args.supervision_weight * l for l in l_c + l_q]) triplet_num = int(loss_cfg.get('triplet_num', 1)) if fq_tune_only: triplet_num = 0 if triplet_num > 0: #disable grad to select negatives with torch.no_grad(): hnm_scores = [] hnm_idxs = [] #check that current step has no HNM conext vector if global_step not in hnm_hist and args.hnm_num > 0: #generate the new one if world_size > 1 and (args.hnm_num % world_size) != 0: #aligh hnm_num per each replica hnm_plus = world_size - (args.hnm_num % world_size) args.hnm_num += hnm_plus logger.warning( "rank {} args.hnm_num increased by {} from {} to {} to be the same after division by {} replicas." .format(rank, hnm_plus, args.hnm_num - hnm_plus, args.hnm_num, world_size)) # generate random contexts to calc embedding context_ids_all = torch.randint( low=0, high=len(train_dataset_qc.c_dataset), size=[args.hnm_num]) if rank < 0: #single process take all context_ids = context_ids_all else: #broadcast one sigle indicies to all processes context_ids_all = context_ids_all.to(args.device) torch.distributed.broadcast(context_ids_all, 0) context_ids_all = context_ids_all.cpu() #each process take only small part to calc embedding s = ((rank + 0) * args.hnm_num) // world_size e = ((rank + 1) * args.hnm_num) // world_size context_ids = context_ids_all[s:e] batch_size = min(args.hnm_batch_size, context_ids.shape[0]) s, e = 0, batch_size c_outputs = [] while e > s: idx = context_ids.detach()[s:e] c_batch = train_dataset_qc.c_dataset[idx] inputs = get_inputs(c_batch, args.device) outputs = model(**inputs, output_hidden_states=False) c_outputs.append(outputs[0]) s, e = e, min(e + batch_size, context_ids.shape[0]) context_emb = torch.cat(c_outputs, dim=0) if rank < 0: # single process calculated all context_emb_all = context_emb else: context_emb_list = [ torch.zeros_like(context_emb) for _ in range(world_size) ] torch.distributed.all_gather( context_emb_list, context_emb) context_emb_all = torch.cat(context_emb_list, dim=0) hnm_hist[global_step] = (context_ids_all, context_emb_all) #check history size and crop the oldest one if len(hnm_hist) > args.hnm_hist_num: del hnm_hist[min(hnm_hist.keys())] #calc HNM scores for current question batch for hist_step, (c_idx, c_vec) in hnm_hist.items(): w = args.hnm_hist_alpha**(global_step - hist_step) t1 = q_vec[:, None, :] t2 = c_vec[None, :, :] d = (t1 - t2) score = -d.norm(2, dim=-1) score = score * w hnm_scores.append(score) hnm_idxs.append(c_idx) if hnm_scores: #choose the hardest negative if we have 
scores score = torch.cat(hnm_scores, dim=-1) idx = torch.cat(hnm_idxs, dim=-1) score = score.cpu() pos_mask = (context_ids_pos[:, None] == idx[None, :]).to( dtype=score.dtype, device=score.device) score = (1 - pos_mask) * score + pos_mask * score.min( ) #make positive context with small score to avoid chose it as hard neg hn_idx = score.argmax(dim=1, keepdim=True) context_ids_neg = idx[hn_idx] else: #just random selection in case of no scores for HNM size = (context_ids_pos.shape[0], 1) context_ids_neg = torch.randint( 0, len(train_dataset_qc.c_dataset) - 1, size) shift = (context_ids_neg >= context_ids_pos[:, None]) context_ids_neg = context_ids_neg + shift.to( dtype=context_ids_neg.dtype) d_pos = (q_vec - c_vec_pos).norm(2, dim=-1) # get negative embeddings and calc losses for neg_index in range(context_ids_neg.shape[1]): ids = context_ids_neg[:, neg_index] c_batch = train_dataset_qc.c_dataset[ids.detach()] inputs = get_inputs(c_batch, args.device) outputs = model(**inputs, output_hidden_states=False) c_vec_neg = outputs[0] for triplet_index in range(triplet_num): if triplet_index == 0: d_neg = (q_vec - c_vec_neg).norm(2, dim=-1) if triplet_index == 1: d_neg = (c_vec_pos - c_vec_neg).norm(2, dim=-1) d_diff = d_pos - d_neg indicators['dd' + str(triplet_index)].append( [v.mean().item() for v in (d_pos, d_neg, d_diff)]) l = softplus(d_diff) losses.append(l) del d_neg del d_pos #average over batch losses = [l.mean() for l in losses] l = sum(losses) / len(losses) (l / gradient_accumulation_steps).backward() indicators['loss'].append(l.item()) indicators['ll'].append([lll.item() for lll in losses]) #del losses del l if (step + 1) % gradient_accumulation_steps == 0: utils.sync_grads(rank, named_params, report_no_grad_params=(global_step == 1)) torch.nn.utils.clip_grad_norm_([p for n, p in named_params], 1) optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad() global_step += 1 if global_step % 10 == 0: # Log metrics wall_time = epoch + step / len(q_dataloader) lrp = [ '{:.2f}'.format(i) for i in np.log10(scheduler.get_last_lr()) ] str_out = "{} ep {:.2f} lrp {}".format( train_count, epoch_fp, " ".join(lrp)) for k, v in indicators.items(): v = np.array(v) if len(v.shape) == 1: v = v[:, None] if rank > -1: #sync indicators vt = torch.tensor(v).to(args.device) torch.distributed.all_reduce( vt, op=torch.distributed.ReduceOp.SUM) v = vt.cpu().numpy() / float(world_size) str_out += " {} {}".format( k, " ".join(["{:.3f}".format(t) for t in v.mean(0)])) if 'score' in locals(): str_out += " SS {}".format(list(score.shape)) if 'time_last' in locals(): dt_iter = (time.time() - time_last) / len( indicators['loss']) dt_ep = dt_iter * len(q_dataloader) str_out += " it {:.1f}s".format(dt_iter) str_out += " ep {:.1f}m".format(dt_ep / (60)) str_out += " eta {:.1f}h".format( dt_ep * (num_train_epochs - epoch_fp) / (60 * 60)) time_last = time.time() indicators = collections.defaultdict(list) if rank in [-1, 0]: logger.info(str_out) if rank in [-1, 0]: check_point_name = 'checkpoint-{:02}'.format(train_count) check_point_name = check_point_name + '-{:02}'.format(epoch + 1) result_s = evaluate(args, model.eval(), test_dataset_qc) for k, v in result_s.items(): logger.info("{} {} {}".format(check_point_name, k, v)) if rank > -1: torch.distributed.barrier()
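# ----------------------------------------------------------------------------------
# Illustrative sketch (not from the original sources): how the hardest negative is
# picked in train() above.  score[i, j] is the (history-weighted) negative distance
# between question i and candidate context j; candidates that coincide with the
# positive context receive the minimum score so they cannot be selected.  Names below
# are assumptions for this sketch only.
import torch

def pick_hard_negatives(score, candidate_ids, positive_ids):
    """score: [batch, num_candidates]; candidate_ids: [num_candidates]; positive_ids: [batch]."""
    pos_mask = (positive_ids[:, None] == candidate_ids[None, :]).to(score.dtype)
    masked = (1 - pos_mask) * score + pos_mask * score.min()
    hn_idx = masked.argmax(dim=1, keepdim=True)   # hardest remaining candidate per question
    return candidate_ids[hn_idx]                  # [batch, 1] context ids
# ----------------------------------------------------------------------------------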
def train(self, epoch_start):
    global_step = 0
    self.check_loss_raise = CheckLossRaise()
    for epoch in range(epoch_start, math.ceil(self.args.num_train_epochs)):
        self.indicators = collections.defaultdict(list)
        utils.sync_models(self.rank, self.model)
        self.model.train()
        self.model.zero_grad()
        grad_count = 0
        if self.rank > -1:
            # set epoch so that the sample division between processes differs from epoch to epoch
            self.dataloader.sampler.set_epoch(epoch)

        for step, batch in enumerate(self.dataloader):
            epoch_fp = epoch + step / len(self.dataloader)
            if epoch_fp > self.args.num_train_epochs:
                break

            x_noise, x_clean = [t.to(self.args.device) for t in batch]

            # augment and mix signals
            x_clean, x_noise, x = self.mix_signals(x_clean, x_noise)

            # forward pass
            y_clean, Y_clean, _ = self.model(x)

            # calc spectrum for the clean input signal
            tail_size = self.model.wnd_length - self.model.hop_length
            X_clean = self.model.encode(torch.nn.functional.pad(x_clean, (tail_size, 0)))

            # crop target and model output to align them to each other
            sample_ahead = self.model.get_sample_ahead()
            spectre_ahead = self.model.ahead
            if sample_ahead > 0:
                x = x[:, :-sample_ahead]
                x_clean = x_clean[:, :-sample_ahead]
                y_clean = y_clean[:, sample_ahead:]
            if spectre_ahead > 0:
                Y_clean = Y_clean[:, :, :, spectre_ahead:]
                X_clean = X_clean[:, :, :, :-spectre_ahead]

            loss = self.loss(epoch_fp, y_clean, Y_clean, x_clean, X_clean)
            self.indicators['loss'].append(loss.item())

            # calculate and accumulate gradients
            loss.backward()
            grad_count += 1

            # continue if not all gradients were accumulated yet
            if grad_count < self.gradient_accumulation_steps:
                continue

            # make optimization step
            utils.sync_grads(self.rank, self.named_params, global_step == 0, grad_count)
            self.optimizer.step()
            self.scheduler.step()  # update learning rate schedule
            global_step += 1
            self.model.zero_grad()
            grad_count = 0

            # make logs only after several steps
            if global_step % self.args.logacc != 0:
                continue

            # average indicators over GPUs and iterations
            self.aver_indicators()

            # check whether negsisdr suddenly rises;
            # if a large rise is detected, the model parameters are restored and the optimizer is reset
            self.check_loss_raise.check(
                self.indicators_mean["negsisdr"],
                self.named_params,
                self.optimizer)

            self.log_indicators(epoch_fp)
            self.indicators = collections.defaultdict(list)

        self.save_and_eval_checkpoint(epoch + 1)
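# ----------------------------------------------------------------------------------
# Illustrative sketch (not from the original sources): the gradient-accumulation
# pattern shared by the training loops in these scripts, reduced to its essentials.
# Here the loss is pre-scaled by accum_steps, as the distillation train() functions
# do; the trainer above instead tracks grad_count and hands it to utils.sync_grads.
# Names below are assumptions for this sketch only.
import torch

def accumulate_and_step(model, optimizer, loss_fn, batches, accum_steps):
    model.zero_grad()
    for i, batch in enumerate(batches):
        loss = loss_fn(model, batch) / accum_steps   # scale so the summed grads match one big batch
        loss.backward()                              # accumulate gradients
        if (i + 1) % accum_steps == 0:
            optimizer.step()                         # one optimizer step per accumulated group
            model.zero_grad()
# ----------------------------------------------------------------------------------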
def train(rank, args, model, model_t, train_dataset_qa, test_dataset_qa, scale_tune): """ Train the model """ global train_count train_count += 1 world_size = 1 if rank < 0 else torch.distributed.get_world_size() if rank in [-1, 0]: printlog("Train model",train_count) printlog(model) per_gpu_train_batch_size = args.per_gpu_train_batch_size train_batch_size = per_gpu_train_batch_size * world_size gradient_accumulation_steps = args.total_train_batch_size // train_batch_size num_train_epochs = args.num_train_epochs if scale_tune: gradient_accumulation_steps = 1 num_train_epochs = 1 if rank < 0: #single process take all samples sampler = RandomSampler(train_dataset_qa) dataloader = DataLoader(train_dataset_qa, sampler=sampler, batch_size=train_batch_size, num_workers=4) else: #special sampler that divide samples beween processes sampler = torch.utils.data.distributed.DistributedSampler(train_dataset_qa, rank=rank) dataloader = DataLoader(train_dataset_qa, sampler=sampler, batch_size=per_gpu_train_batch_size) steps_total = int(len(dataloader) // gradient_accumulation_steps * num_train_epochs) # Prepare optimizer and schedule freeze_list = args.freeze_list.split(',') if args.freeze_list else [] named_params = [] for n, p in model.named_parameters(): if n.lower()!="none" and any(fn in n for fn in freeze_list): if rank in [-1, 0]: logger.warning("rank {} {} param is frozen and excluded from tune".format(rank,n)) continue named_params.append( (n, p) ) # split parameters to scale and the rest named_params_scale = [(n, p) for n, p in named_params if '.scale' in n] named_params_rest = [(n, p) for n, p in named_params if '.scale' not in n] if scale_tune: #keep only scale parameters named_params = named_params_scale named_params_rest = [] groups = [] if named_params_scale: groups.append({'params': [p for n, p in named_params_scale], 'lr': 0.01}) if named_params_rest: groups.append({'params': [p for n, p in named_params_rest], 'lr': args.learning_rate}) optimizer = AdamW( groups, eps=1e-08, lr=args.learning_rate, weight_decay=0) def lr_lambda(current_step): p = float(current_step) / float(steps_total) return 1 - p scheduler = LambdaLR(optimizer, lr_lambda) if rank in [-1, 0]: for n,p in named_params: printlog('param for tune',n) printlog("scale_tune", scale_tune ) printlog("dataset size", len(train_dataset_qa) ) printlog("epoches", num_train_epochs ) printlog("per_gpu_train_batch_size", per_gpu_train_batch_size ) printlog("n_gpu", args.n_gpu ) printlog("world_size", world_size ) printlog("gradient_accumulation_steps", gradient_accumulation_steps ) printlog("total train batch size", train_batch_size * gradient_accumulation_steps ) printlog("steps_total",steps_total ) global_step = 0 model.zero_grad() indicators = collections.defaultdict(list) softplus = torch.nn.Softplus() loss_cfg = dict([t.split(':') for t in args.loss_cfg.split(',')]) if args.loss_cfg else dict() for epoch in range(math.ceil(num_train_epochs)): indicators = collections.defaultdict(list) model.train() set_output_hidden_states(rank, model, (model_t is not None)) utils.sync_models(rank, model) if model_t is not None: set_output_hidden_states(rank, model_t, True) model_t.train() if rank > -1: #set epoch to make different samples division betwen process for different epoches sampler.set_epoch(epoch) for step, batch in enumerate(dataloader): epoch_fp = epoch + step/len(dataloader) if epoch_fp > num_train_epochs: break epoch_fp = epoch + step/len(dataloader) losses = [] inputs = get_inputs(batch, args.device) targets = get_targets(batch, 
args.device) outputs = model(**inputs, **targets, output_hidden_states=(model_t is not None)) losses.append(outputs[0]) outputs = outputs[1:] if model_t is not None: with torch.no_grad(): outputs_t = model_t(**inputs, output_hidden_states=True) hidden_t = outputs_t[2] assert isinstance(hidden_t, (tuple,list)), "hidden states output is not detected right" assert len(hidden_t) == model_t.config.num_hidden_layers+1, "hidden states output is not detected right" if args.kd_weight>0: # Calculate knowladge distilation loss kd_losses = [] for logit_s,logit_t in zip(outputs[0:2],outputs_t[0:2]): T = 1 prob_t = torch.nn.functional.softmax(logit_t.detach() / T, dim=1) logprob_s = torch.nn.functional.log_softmax(logit_s / T, dim=1) kd_losses.append( -(logprob_s * prob_t).mean() * (T * T * prob_t.shape[1]) ) losses.append(args.kd_weight*sum(kd_losses)/len(kd_losses)) hidden_s = outputs[2] assert isinstance(hidden_s, (tuple,list)), "hidden states output is not detected right" assert len(hidden_s) == model.config.num_hidden_layers+1, "hidden states output is not detected right" def align_and_loss_outputs(out_s, out_t): if len(out_s) != len(out_t): #the student and teacher outputs are not aligned. try to find teacher output for each student output n_s, n_t = len(out_s), len(out_t) out_t = [out_t[(i*(n_t-1))//(n_s-1)] for i in range(n_s)] assert len(out_s) == len(out_t), "can not align number of outputs between student and teacher" assert all(s[0] == s[1] for s in zip(out_s[0].shape, out_t[0].shape)), "output shapes for student and teacher are not the same" return [(s - t.detach()).pow(2).mean() for s,t in zip(out_s, out_t)] sw_losses = align_and_loss_outputs(hidden_s,hidden_t) losses.extend([args.supervision_weight*l for l in sw_losses]) #average over batch losses = [l.mean() for l in losses] l = sum(losses)/len(losses) indicators['loss'].append(l.item()) indicators['ll'].append([lll.item() for lll in losses]) (l/gradient_accumulation_steps).backward() del l if (step + 1) % gradient_accumulation_steps == 0: global_step += 1 utils.sync_grads(rank, named_params, report_no_grad_params=(global_step==1)) torch.nn.utils.clip_grad_norm_([p for n, p in named_params], 1) optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad() if global_step % 50 == 0: # Log metrics wall_time = epoch + step / len(dataloader) lrp = " ".join(['{:.2f}'.format(t) for t in np.log10(scheduler.get_last_lr())]) str_out = "{} ep {:.2f} lrp {}".format(train_count, epoch_fp, lrp) for k,v in indicators.items(): v = np.array(v) if len(v.shape)==1: v = v[:,None] if rank>-1: #sync indicators vt = torch.tensor(v).to(args.device) torch.distributed.all_reduce(vt, op=torch.distributed.ReduceOp.SUM) v = vt.cpu().numpy() / float(world_size) str_out += " {} {}".format(k," ".join(["{:.3f}".format(t) for t in v.mean(0)])) if 'time_last' in locals(): #estimate processing times dt_iter = (time.time() - time_last) / len(indicators['loss']) dt_ep = dt_iter * len(dataloader) str_out += " it {:.1f}s".format(dt_iter) str_out += " ep {:.1f}m".format(dt_ep / (60)) str_out += " eta {:.1f}h".format(dt_ep * (num_train_epochs - epoch_fp) / (60 * 60)) time_last = time.time() indicators = collections.defaultdict(list) if rank in [-1, 0]: logger.info(str_out) if rank in [-1, 0]: check_point_name = 'checkpoint-{:02}'.format(train_count) check_point_name = check_point_name + '-{:02}'.format(epoch + 1) model.eval() set_output_hidden_states(rank, model, False) result_s = evaluate(args, model, test_dataset_qa) for k,v in result_s.items(): 
logger.info("{} {} {}".format(check_point_name, k, result_s[k])) if rank>-1: torch.distributed.barrier()