def _update_network(self):
    """Perform one gradient update of the actor and the critic.

    A batch of ``self.args.batch_size`` transitions is sampled from
    ``self.buffer``, pre-processed with ``self._preproc_og`` and normalised
    with ``self.o_norm`` / ``self.g_norm``.  The critic is regressed onto a
    clipped one-step target built from the target networks; the actor is
    trained to maximise the critic value, with an additional L2 penalty
    (weight ``self.args.action_l2``) on the actions scaled by
    ``self.env_params['action_max']``.  ``sync_grads`` is called on each
    network right before its optimizer step.
    """
    batch = self.buffer.sample(self.args.batch_size)

    # Both pre-processing calls receive the goal exactly as it was sampled.
    raw_obs, raw_obs_next, raw_goal = batch['obs'], batch['obs_next'], batch['g']
    batch['obs'], batch['g'] = self._preproc_og(raw_obs, raw_goal)
    batch['obs_next'], batch['g_next'] = self._preproc_og(raw_obs_next, raw_goal)

    def normalized_input(obs_key, goal_key):
        # network input = [normalised observation, normalised goal]
        obs_part = self.o_norm.normalize(batch[obs_key])
        goal_part = self.g_norm.normalize(batch[goal_key])
        return np.concatenate([obs_part, goal_part], axis=1)

    def as_float_tensor(array):
        # float32 tensor, moved to the GPU when args.cuda is set
        tensor = torch.tensor(array, dtype=torch.float32)
        if self.args.cuda:
            tensor = tensor.cuda()
        return tensor

    state = as_float_tensor(normalized_input('obs', 'g'))
    next_state = as_float_tensor(normalized_input('obs_next', 'g_next'))
    actions = as_float_tensor(batch['actions'])
    rewards = as_float_tensor(batch['r'])

    # Target Q value; nothing computed here takes part in back-propagation.
    with torch.no_grad():
        next_actions = self.actor_target_network(next_state)
        next_q = self.critic_target_network(next_state, next_actions)
        bootstrapped = rewards + self.args.gamma * next_q
        # the target is clipped to the range [-1 / (1 - gamma), 0]
        return_bound = 1 / (1 - self.args.gamma)
        target_q = torch.clamp(bootstrapped, -return_bound, 0)

    # Critic loss: mean squared error against the clipped target.
    current_q = self.critic_network(state, actions)
    critic_loss = (target_q - current_q).pow(2).mean()

    # Actor loss: negated critic value of the actor's actions plus L2 penalty.
    policy_actions = self.actor_network(state)
    policy_q = self.critic_network(state, policy_actions)
    scaled_actions = policy_actions / self.env_params['action_max']
    actor_loss = -policy_q.mean() + self.args.action_l2 * scaled_actions.pow(2).mean()

    # Actor is updated first, then the critic; the critic gradients left by
    # the actor's backward pass are wiped by the critic's zero_grad().
    update_plan = (
        (self.actor_optim, actor_loss, self.actor_network),
        (self.critic_optim, critic_loss, self.critic_network),
    )
    for optim, loss, network in update_plan:
        optim.zero_grad()
        loss.backward()
        sync_grads(network)
        optim.step()
def train(rank, args, tokenizer, train_dataset, test_dataset, model_s, model_t,
          params_to_tune, head_importance=None, loss_num=-1, tune_iter=0):
    """Train the student ``model_s`` to reproduce the hidden states of ``model_t``.

    The loss is the mean squared error between student and teacher hidden
    states (teacher layers are sub-sampled when the layer counts differ).
    Every 50 steps the averaged loss is logged and compared with the last
    "good" loss: if it is below 1.5x that value the model is saved to
    ``last_good_state.pth`` in ``args.output_dir``, otherwise the model is
    restored from that file and the optimizer statistics are reset.

    Args:
        rank: process rank; ``-1`` means single-process (non-distributed) run.
        args: namespace with the training options read by this function
            (batch sizes, learning rates, ``device``, ``output_dir``, ...).
        tokenizer: passed through to ``save_model``.
        train_dataset: training data; when ``model_s`` is a ``BertModelEMB``
            its ``q_dataset`` and ``c_dataset`` are iterated one after another.
        test_dataset: data given to ``evaluate`` after each epoch.
        model_s: student model, optimised in place.
        model_t: teacher model, only evaluated under ``torch.no_grad()``.
        params_to_tune: parameters of ``model_s`` handed to the optimizer.
        head_importance: optional tensor updated in place with a running
            average of the absolute gradient of an all-ones head mask.
        loss_num: when ``> 0`` only the first ``loss_num`` layer losses are
            used, training is shortened to 0.25 epoch and the per-epoch
            checkpoint/evaluation is skipped.
        tune_iter: when ``> 0`` the ``*_for_tune`` options take precedence,
            layer losses are weighted by ``args.loss_weight_alpha ** i`` and
            the learning rate decays linearly after warm-up.
    """
    global train_count
    train_count += 1

    world_size = 1 if rank < 0 else torch.distributed.get_world_size()
    is_main = rank in [-1, 0]

    if is_main:
        printlog("Train stage: ", train_count)
        printlog(model_s)

    if head_importance is not None:
        head_mask = torch.ones(*list(head_importance.shape)).to(args.device)
        head_mask.requires_grad_(requires_grad=True)
    else:
        head_mask = None

    num_train_epochs = args.num_train_epochs
    if loss_num > 0:
        num_train_epochs = 0.25  # short train for incremental loss

    per_gpu_train_batch_size = args.per_gpu_train_batch_size
    train_batch_size = per_gpu_train_batch_size * world_size

    # total batch size and learning rate can be overridden for tune iterations
    if tune_iter > 0 and args.total_train_batch_size_for_tune:
        total_train_batch_size = args.total_train_batch_size_for_tune
    else:
        total_train_batch_size = args.total_train_batch_size
    # Never let the accumulation factor drop to 0 (total batch smaller than
    # one step over all replicas): it is used as a divisor below.
    gradient_accumulation_steps = max(1, total_train_batch_size // train_batch_size)

    if tune_iter > 0 and args.learning_rate_for_tune:
        learning_rate = args.learning_rate_for_tune
    else:
        learning_rate = args.learning_rate

    if check_model_type(model_s, BertModelEMB):
        # use 2 datasets for embedding question and context separately
        if is_main:
            printlog("dataset_q size", len(train_dataset.q_dataset))
            printlog("dataset_c size", len(train_dataset.c_dataset))
        datasets = [train_dataset.q_dataset, train_dataset.c_dataset]
    else:
        if is_main:
            printlog("dataset size", len(train_dataset))
        datasets = [train_dataset]

    if rank > -1:
        # for distributed training each process samples only its own part
        train_dataloaders = [
            DataLoader(dataset,
                       sampler=torch.utils.data.distributed.DistributedSampler(
                           dataset, rank=rank),
                       batch_size=per_gpu_train_batch_size)
            for dataset in datasets
        ]
    else:
        train_dataloaders = [
            DataLoader(dataset,
                       sampler=RandomSampler(dataset),
                       batch_size=train_batch_size,
                       num_workers=4) for dataset in datasets
        ]

    steps_per_epoch = sum(len(d) for d in train_dataloaders)
    # At least one step: LambdaLR evaluates lr_lambda(0) on construction and
    # lr_lambda divides by steps_total.
    steps_total = max(
        1,
        int(steps_per_epoch // gradient_accumulation_steps * num_train_epochs))

    # Prepare optimizer and scheduler.
    # Parameters are matched by identity; an id-set avoids rescanning
    # params_to_tune for every model parameter.
    tune_ids = {id(p) for p in params_to_tune}
    named_params = [(n, p) for n, p in model_s.named_parameters()
                    if id(p) in tune_ids]
    if is_main:
        for n, p in named_params:
            printlog('param for tune', n)

    optimizer = AdamW([p for n, p in named_params],
                      lr=learning_rate,
                      eps=1e-08,
                      weight_decay=0.0)

    def lr_lambda(current_step):
        p = float(current_step) / float(steps_total)
        warmup = 0.01
        if p < warmup:
            return p / warmup
        p = (p - warmup) / (1 - warmup)
        return 1 if tune_iter == 0 else max(1 - p, 0)

    scheduler = LambdaLR(optimizer, lr_lambda)

    if is_main:
        printlog("epoches", num_train_epochs)
        printlog("per_gpu_train_batch_size", per_gpu_train_batch_size)
        printlog("n_gpu", args.n_gpu)
        printlog("world_size", world_size)
        printlog("gradient_accumulation_steps", gradient_accumulation_steps)
        printlog("total train batch size",
                 train_batch_size * gradient_accumulation_steps)
        printlog("steps_total", steps_total)

    restore_count = 0
    if is_main:
        os.makedirs(args.output_dir, exist_ok=True)
    restore_file = os.path.join(args.output_dir, 'last_good_state.pth')
    restore_loss = None

    losses_list = []
    global_step = 0
    for epoch in range(math.ceil(num_train_epochs)):
        switch_to_train(rank, model_t)
        switch_to_train(rank, model_s)
        model_s.zero_grad()
        utils.sync_models(rank, model_s)

        time_last = time.time()
        for train_dataloader in train_dataloaders:
            printlog("rank", rank, "len(train_dataloader)",
                     len(train_dataloader))
            if rank > -1:
                train_dataloader.sampler.set_epoch(epoch)

            if len(train_dataloaders) > 1:
                # reset last loss to avoid restore due to dataset changing
                printlog("rank", rank, "reset restore_loss")
                restore_loss = None

            for step, batch in enumerate(train_dataloader):
                epoch_fp = epoch + step / len(train_dataloader)
                if epoch_fp > num_train_epochs:
                    break

                inputs = {
                    'input_ids': batch[0].to(args.device),
                    'attention_mask': batch[1].to(args.device),
                    'token_type_ids': batch[2].to(args.device)
                }

                outputs_s = model_s(**inputs,
                                    head_mask=head_mask,
                                    output_hidden_states=True)
                losses = []
                with torch.no_grad():
                    outputs_t = model_t(**inputs, output_hidden_states=True)

                # hidden states are expected as the last element of the outputs
                out_s, out_t = outputs_s[-1], outputs_t[-1]
                assert len(
                    out_s
                ) == model_s.config.num_hidden_layers + 1, "can not find hidden states in student model outputs"
                assert len(
                    out_t
                ) == model_t.config.num_hidden_layers + 1, "can not find hidden states in teacher model outputs"
                if len(out_s) != len(out_t):
                    # the student and teacher outputs are not aligned:
                    # pick an evenly spaced teacher output for each student output
                    n_s, n_t = len(out_s), len(out_t)
                    out_t = [
                        out_t[(i * (n_t - 1)) // (n_s - 1)]
                        for i in range(n_s)
                    ]
                assert len(out_s) == len(
                    out_t
                ), "can not align number of outputs between student and teacher"
                assert all(
                    s[0] == s[1]
                    for s in zip(out_s[0].shape, out_t[0].shape)
                ), "output shapes for student and teacher are not the same"

                out_pairs = list(zip(out_s, out_t))
                if loss_num > 0:
                    out_pairs = out_pairs[:loss_num]

                losses += [(s - t.detach()).pow(2).mean()
                           for s, t in out_pairs]
                losses_list.append([l.item() for l in losses])

                if tune_iter == 0:
                    loss = sum(losses) / len(losses)
                else:
                    weights = [
                        args.loss_weight_alpha**i for i in range(len(losses))
                    ]
                    losses_w = [w * l for w, l in zip(weights, losses)]
                    loss = sum(losses_w) / sum(weights)

                if gradient_accumulation_steps > 1:
                    loss = loss / gradient_accumulation_steps

                loss.backward()
                del out_s
                del out_t
                del outputs_s
                del outputs_t

                if head_importance is not None:
                    # collect gradient statistics to find most valuable heads
                    head_mask.grad.detach_()
                    head_importance += (head_mask.grad.abs().detach() -
                                        head_importance) * 0.001
                    head_mask.grad.zero_()

                if (step + 1) % gradient_accumulation_steps == 0:
                    global_step += 1

                    # sync gradients before calc step
                    utils.sync_grads(rank, named_params, global_step == 1)

                    torch.nn.utils.clip_grad_norm_(
                        [p for n, p in named_params], 1)
                    optimizer.step()
                    scheduler.step()
                    model_s.zero_grad()

                if (step + 1) % 50 == 0:
                    str_out = "{} ep {:.2f} lrp {:.2f} rc {:02}".format(
                        train_count, epoch_fp,
                        np.log10(scheduler.get_last_lr()[0]), restore_count)
                    ll = np.array(losses_list).mean(0)

                    if rank > -1:
                        # sync indicators
                        llt = torch.tensor(ll).to(args.device)
                        torch.distributed.all_reduce(
                            llt, op=torch.distributed.ReduceOp.SUM)
                        ll = llt.cpu().numpy() / float(world_size)

                    mean_loss = ll.mean()
                    str_out += " loss {:.4f}".format(mean_loss)
                    losses_txt = ["{:.3f}".format(l) for l in ll]
                    if tune_iter > 0:
                        losses_txt = [
                            "{:.2f}x".format(w) + lt
                            for w, lt in zip(weights, losses_txt)
                        ]
                    str_out += " ll " + " ".join(losses_txt)

                    # timing estimates over the steps since the last report
                    dt_iter = (time.time() - time_last) / len(losses_list)
                    dt_ep = dt_iter * steps_per_epoch
                    str_out += " it {:.1f}s".format(dt_iter)
                    str_out += " ep {:.1f}m".format(dt_ep / (60))
                    str_out += " eta {:.1f}h".format(
                        dt_ep * (num_train_epochs - epoch_fp) / (60 * 60))
                    losses_list = []
                    time_last = time.time()
                    if is_main:
                        logger.info(str_out)

                    if rank > -1:
                        # sync losses so every process takes the same branch
                        loss_tensor = torch.tensor([mean_loss],
                                                   device=args.device)
                        torch.distributed.all_reduce(
                            loss_tensor, op=torch.distributed.ReduceOp.SUM)
                        mean_loss = loss_tensor.item() / world_size

                    if restore_loss is None or mean_loss < restore_loss * 1.5:
                        # good result lets save it
                        restore_loss = mean_loss

                        if is_main:
                            torch.save(
                                {
                                    'model_state_dict':
                                    model_s.state_dict(),
                                    'optimizer_state_dict':
                                    optimizer.state_dict()
                                }, restore_file)
                        if rank > -1:
                            torch.distributed.barrier()
                    else:
                        # bad result lets restore
                        restore_count += 1
                        logger.info(
                            "rank {} restore #{} from {} with {} loss".format(
                                rank, restore_count, restore_file,
                                restore_loss))
                        checkpoint = torch.load(restore_file)
                        model_s.load_state_dict(
                            checkpoint['model_state_dict'])
                        # Reset the optimizer statistics in place instead of
                        # building a new optimizer: the scheduler holds a
                        # reference to this instance, so replacing it would
                        # silently stop the learning-rate schedule.
                        optimizer.state.clear()
                        switch_to_train(rank, model_s)

        if loss_num <= 0:
            if is_main:
                check_point_name = 'checkpoint-{:02}'.format(train_count)
                save_model(args, model_s, tokenizer, check_point_name)
                check_point_name = check_point_name + '-{:02}'.format(epoch +
                                                                      1)
                switch_to_eval(rank, model_s)
                result_s = evaluate(args, model_s, test_dataset)
                for k, v in result_s.items():
                    logger.info("{} {} {}".format(check_point_name, k, v))
            if rank > -1:
                torch.distributed.barrier()

    if is_main:
        if os.path.exists(restore_file):
            os.remove(restore_file)
def train(rank, args, model, model_t, train_dataset_qc, test_dataset_qc,
          fq_tune_only, model_controller):
    """Train the question/context embedding ``model``.

    Two kinds of losses are combined per step:

    * when ``model_t`` is given, distillation losses between student and
      teacher hidden states (and optionally embeddings, see ``emb_loss`` in
      ``args.loss_cfg``), scaled by ``args.supervision_weight``;
    * when ``triplet_num`` (from ``args.loss_cfg``, default 1) is positive
      and ``fq_tune_only`` is false, softplus triplet losses against a
      negative context chosen by hard negative mining (HNM) or, when no HNM
      scores are available, drawn at random.

    Args:
        rank: process rank; ``-1`` means single-process (non-distributed) run.
        args: namespace with the training options read by this function.
        model: model being trained (optimised in place).
        model_t: optional teacher model; ``None`` disables distillation.
        train_dataset_qc: object exposing ``q_dataset`` (iterated) and
            ``c_dataset`` (indexed by context ids).
        test_dataset_qc: data given to ``evaluate`` after each epoch.
        fq_tune_only: when true, run one epoch without gradient accumulation
            and without triplet losses; also forwarded to
            ``utils.make_param_groups``.
        model_controller: forwarded to ``utils.make_param_groups``.
    """
    global train_count
    train_count += 1

    world_size = 1 if rank < 0 else torch.distributed.get_world_size()
    is_main = rank in [-1, 0]

    if is_main:
        printlog("Train model", train_count)
        printlog(model)

    q_dataset = train_dataset_qc.q_dataset
    c_dataset = train_dataset_qc.c_dataset

    per_gpu_train_batch_size = args.per_gpu_train_batch_size
    train_batch_size = per_gpu_train_batch_size * world_size

    if fq_tune_only:
        gradient_accumulation_steps = 1
        num_train_epochs = 1
    else:
        # never 0: the value is used as a divisor below
        gradient_accumulation_steps = max(
            1, args.total_train_batch_size // train_batch_size)
        num_train_epochs = args.num_train_epochs

    if rank < 0:
        # single process take all
        q_sampler = RandomSampler(q_dataset)
        q_dataloader = DataLoader(q_dataset,
                                  sampler=q_sampler,
                                  batch_size=train_batch_size,
                                  num_workers=4)
    else:
        # special sampler that divide samples between processes
        q_sampler = torch.utils.data.distributed.DistributedSampler(
            q_dataset, rank=rank)
        q_dataloader = DataLoader(q_dataset,
                                  sampler=q_sampler,
                                  batch_size=per_gpu_train_batch_size)

    # At least one step: LambdaLR evaluates lr_lambda(0) on construction and
    # lr_lambda divides by steps_total.
    steps_total = max(
        1,
        int(len(q_dataloader) // gradient_accumulation_steps *
            num_train_epochs))

    # Prepare optimizer and schedule
    named_params, groups = utils.make_param_groups(
        rank,
        model,
        args.freeze_list,  # list or str with subnames to define frozen parameters
        args.learning_rate,  # learning rate for no FQ parameters
        0.01,  # learning rate for FQ parameters
        fq_tune_only,  # true if only FQ parameters will be optimized
        model_controller)

    optimizer = AdamW(groups, eps=1e-08, lr=args.learning_rate, weight_decay=0)

    def lr_lambda(current_step):
        # linear decay; clamped so the factor can never become negative
        p = float(current_step) / float(steps_total)
        return max(1 - p, 0)

    scheduler = LambdaLR(optimizer, lr_lambda)

    if is_main:
        for n, p in named_params:
            printlog('param for tune', n)
        printlog("fq_tune_only", fq_tune_only)
        printlog("dataset size", len(q_dataset))
        printlog("epoches", num_train_epochs)
        printlog("per_gpu_train_batch_size", per_gpu_train_batch_size)
        printlog("n_gpu", args.n_gpu)
        printlog("world_size", world_size)
        printlog("gradient_accumulation_steps", gradient_accumulation_steps)
        printlog("total train batch size",
                 train_batch_size * gradient_accumulation_steps)
        printlog("steps_total", steps_total)

    def align_and_loss_outputs(out_s, out_t):
        """Return per-layer MSE between student and teacher outputs."""
        if len(out_s) != len(out_t):
            # the student and teacher outputs are not aligned:
            # pick an evenly spaced teacher output for each student output
            n_s, n_t = len(out_s), len(out_t)
            out_t = [out_t[(i * (n_t - 1)) // (n_s - 1)] for i in range(n_s)]
        assert len(out_s) == len(
            out_t
        ), "can not align number of outputs between student and teacher"
        assert all(
            s[0] == s[1] for s in zip(out_s[0].shape, out_t[0].shape)
        ), "output shapes for student and teacher are not the same"
        return [(s - t.detach()).pow(2).mean() for s, t in zip(out_s, out_t)]

    global_step = 1
    model.zero_grad()
    indicators = collections.defaultdict(list)

    softplus = torch.nn.Softplus()

    # "key:value,key:value" -> dict; an empty option means "no overrides"
    if args.loss_cfg:
        loss_cfg = dict(t.split(':') for t in args.loss_cfg.split(','))
    else:
        loss_cfg = dict()

    hnm_hist = {}
    score = None  # last HNM score matrix, reported in the log when available
    time_last = None  # time of the previous log report

    for epoch in range(math.ceil(num_train_epochs)):
        indicators = collections.defaultdict(list)
        model.train()
        if model_t is not None:
            model_t.train()
        if rank > -1:
            # set epoch to make different samples division between processes
            # for different epochs
            q_sampler.set_epoch(epoch)

        utils.sync_models(rank, model)
        for step, q_batch in enumerate(q_dataloader):
            epoch_fp = epoch + step / len(q_dataloader)
            if epoch_fp > num_train_epochs:
                break

            losses = []

            context_ids_pos = q_batch[3]
            q_inputs = get_inputs(q_batch, args.device)
            q_outputs = model(**q_inputs,
                              output_hidden_states=(model_t is not None))
            q_vec = q_outputs[0]

            # get positive embeddings
            c_batch = c_dataset[context_ids_pos.detach().data]
            c_inputs = get_inputs(c_batch, args.device)
            c_outputs = model(**c_inputs,
                              output_hidden_states=(model_t is not None))
            c_vec_pos = c_outputs[0]

            if model_t is not None:
                q_emb_s, q_hidden_s = q_outputs
                c_emb_s, c_hidden_s = c_outputs
                with torch.no_grad():
                    q_emb_t, q_hidden_t = model_t(**q_inputs,
                                                  output_hidden_states=True)
                    c_emb_t, c_hidden_t = model_t(**c_inputs,
                                                  output_hidden_states=True)

                l_q = align_and_loss_outputs(q_hidden_s, q_hidden_t)
                l_c = align_and_loss_outputs(c_hidden_s, c_hidden_t)

                emb_loss = loss_cfg.get('emb_loss', '')
                if emb_loss == 'L2':
                    l_q.append((q_emb_s - q_emb_t.detach()).pow(2).mean())
                    l_c.append((c_emb_s - c_emb_t.detach()).pow(2).mean())
                elif emb_loss == 'L1':
                    l_q.append((q_emb_s - q_emb_t.detach()).abs().mean())
                    l_c.append((c_emb_s - c_emb_t.detach()).abs().mean())
                elif emb_loss.lower() not in ['', 'none', '0', 'disable']:
                    raise Exception(
                        'emb_loss={} is unsupported'.format(emb_loss))

                losses.extend(
                    [args.supervision_weight * l for l in l_c + l_q])

            triplet_num = int(loss_cfg.get('triplet_num', 1))
            if fq_tune_only:
                triplet_num = 0

            if triplet_num > 0:
                # disable grad to select negatives
                with torch.no_grad():
                    hnm_scores = []
                    hnm_idxs = []

                    # check that current step has no HNM context vector
                    if global_step not in hnm_hist and args.hnm_num > 0:
                        # generate the new one
                        if world_size > 1 and (args.hnm_num % world_size) != 0:
                            # align hnm_num per each replica
                            hnm_plus = world_size - (args.hnm_num % world_size)
                            args.hnm_num += hnm_plus
                            logger.warning(
                                "rank {} args.hnm_num increased by {} from {} to {} to be the same after division by {} replicas."
                                .format(rank, hnm_plus,
                                        args.hnm_num - hnm_plus, args.hnm_num,
                                        world_size))

                        # generate random contexts to calc embedding
                        context_ids_all = torch.randint(low=0,
                                                        high=len(c_dataset),
                                                        size=[args.hnm_num])

                        if rank < 0:
                            # single process take all
                            context_ids = context_ids_all
                        else:
                            # broadcast one single set of indices to all processes
                            context_ids_all = context_ids_all.to(args.device)
                            torch.distributed.broadcast(context_ids_all, 0)
                            context_ids_all = context_ids_all.cpu()

                            # each process take only small part to calc embedding
                            s = ((rank + 0) * args.hnm_num) // world_size
                            e = ((rank + 1) * args.hnm_num) // world_size
                            context_ids = context_ids_all[s:e]

                        batch_size = min(args.hnm_batch_size,
                                         context_ids.shape[0])

                        s, e = 0, batch_size
                        c_outputs = []
                        while e > s:
                            idx = context_ids.detach()[s:e]
                            c_batch = c_dataset[idx]
                            inputs = get_inputs(c_batch, args.device)
                            outputs = model(**inputs,
                                            output_hidden_states=False)
                            c_outputs.append(outputs[0])
                            s, e = e, min(e + batch_size,
                                          context_ids.shape[0])

                        context_emb = torch.cat(c_outputs, dim=0)

                        if rank < 0:
                            # single process calculated all
                            context_emb_all = context_emb
                        else:
                            context_emb_list = [
                                torch.zeros_like(context_emb)
                                for _ in range(world_size)
                            ]
                            torch.distributed.all_gather(
                                context_emb_list, context_emb)
                            context_emb_all = torch.cat(context_emb_list,
                                                        dim=0)

                        hnm_hist[global_step] = (context_ids_all,
                                                 context_emb_all)

                        # check history size and crop the oldest one
                        if len(hnm_hist) > args.hnm_hist_num:
                            del hnm_hist[min(hnm_hist.keys())]

                    # calc HNM scores for current question batch
                    for hist_step, (c_idx, c_vec) in hnm_hist.items():
                        w = args.hnm_hist_alpha**(global_step - hist_step)
                        t1 = q_vec[:, None, :]
                        t2 = c_vec[None, :, :]
                        d = (t1 - t2)
                        hist_score = -d.norm(2, dim=-1)
                        hist_score = hist_score * w
                        hnm_scores.append(hist_score)
                        hnm_idxs.append(c_idx)

                    if hnm_scores:
                        # choose the hardest negative if we have scores
                        score = torch.cat(hnm_scores, dim=-1)
                        idx = torch.cat(hnm_idxs, dim=-1)
                        score = score.cpu()
                        pos_mask = (context_ids_pos[:, None] == idx[None, :]
                                    ).to(dtype=score.dtype,
                                         device=score.device)
                        # give the positive context the minimal score to
                        # avoid choosing it as a hard negative
                        score = (1 -
                                 pos_mask) * score + pos_mask * score.min()
                        hn_idx = score.argmax(dim=1, keepdim=True)

                        context_ids_neg = idx[hn_idx]
                    else:
                        # just random selection in case of no scores for HNM;
                        # sampling from one id less and shifting ids at or
                        # above the positive one skips the positive context
                        size = (context_ids_pos.shape[0], 1)
                        context_ids_neg = torch.randint(
                            0,
                            len(c_dataset) - 1, size)
                        shift = (context_ids_neg >= context_ids_pos[:, None])
                        context_ids_neg = context_ids_neg + shift.to(
                            dtype=context_ids_neg.dtype)

                d_pos = (q_vec - c_vec_pos).norm(2, dim=-1)
                # get negative embeddings and calc losses
                for neg_index in range(context_ids_neg.shape[1]):
                    ids = context_ids_neg[:, neg_index]
                    c_batch = c_dataset[ids.detach()]
                    inputs = get_inputs(c_batch, args.device)

                    outputs = model(**inputs, output_hidden_states=False)
                    c_vec_neg = outputs[0]

                    for triplet_index in range(triplet_num):
                        if triplet_index == 0:
                            d_neg = (q_vec - c_vec_neg).norm(2, dim=-1)
                        if triplet_index == 1:
                            d_neg = (c_vec_pos - c_vec_neg).norm(2, dim=-1)

                        d_diff = d_pos - d_neg
                        indicators['dd' + str(triplet_index)].append([
                            v.mean().item() for v in (d_pos, d_neg, d_diff)
                        ])

                        l = softplus(d_diff)
                        losses.append(l)

                    del d_neg
                del d_pos

            # average over batch
            losses = [l.mean() for l in losses]
            l = sum(losses) / len(losses)
            (l / gradient_accumulation_steps).backward()

            indicators['loss'].append(l.item())
            indicators['ll'].append([lll.item() for lll in losses])

            del l

            if (step + 1) % gradient_accumulation_steps == 0:
                utils.sync_grads(rank,
                                 named_params,
                                 report_no_grad_params=(global_step == 1))
                torch.nn.utils.clip_grad_norm_([p for n, p in named_params],
                                               1)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if global_step % 10 == 0:
                    # Log metrics
                    lrp = [
                        '{:.2f}'.format(i)
                        for i in np.log10(scheduler.get_last_lr())
                    ]
                    str_out = "{} ep {:.2f} lrp {}".format(
                        train_count, epoch_fp, " ".join(lrp))

                    for k, v in indicators.items():
                        v = np.array(v)
                        if len(v.shape) == 1:
                            v = v[:, None]

                        if rank > -1:
                            # sync indicators
                            vt = torch.tensor(v).to(args.device)
                            torch.distributed.all_reduce(
                                vt, op=torch.distributed.ReduceOp.SUM)
                            v = vt.cpu().numpy() / float(world_size)

                        str_out += " {} {}".format(
                            k,
                            " ".join(["{:.3f}".format(t) for t in v.mean(0)]))

                    if score is not None:
                        str_out += " SS {}".format(list(score.shape))

                    if time_last is not None:
                        dt_iter = (time.time() - time_last) / len(
                            indicators['loss'])
                        dt_ep = dt_iter * len(q_dataloader)
                        str_out += " it {:.1f}s".format(dt_iter)
                        str_out += " ep {:.1f}m".format(dt_ep / (60))
                        str_out += " eta {:.1f}h".format(
                            dt_ep * (num_train_epochs - epoch_fp) / (60 * 60))
                    time_last = time.time()

                    indicators = collections.defaultdict(list)
                    if is_main:
                        logger.info(str_out)

        if is_main:
            check_point_name = 'checkpoint-{:02}'.format(train_count)
            check_point_name = check_point_name + '-{:02}'.format(epoch + 1)
            result_s = evaluate(args, model.eval(), test_dataset_qc)
            for k, v in result_s.items():
                logger.info("{} {} {}".format(check_point_name, k, v))
        if rank > -1:
            torch.distributed.barrier()
def _update_network(self, transitions):
    """Run one actor/critic update on ``transitions`` and return the losses.

    ``transitions`` is modified in place: ``'obs'``, ``'g'``, ``'obs_next'``
    are replaced by their pre-processed versions and ``'g_next'`` is added.

    When ``self.bc_loss`` is set, a behaviour-cloning term is added to the
    actor loss for the last ``config['demo_batch_size']`` samples of the
    batch (assumed to be the demonstration samples -- TODO confirm against
    the sampler).  With ``self.q_filter`` the term is applied only where the
    critic rates the stored action higher than the actor's action.

    Args:
        transitions: dict with the keys ``'obs'``, ``'obs_next'``, ``'g'``,
            ``'actions'`` and ``'r'``.

    Returns:
        dict with ``'actor_loss'`` and ``'critic_loss'`` (Python floats) and,
        when behaviour cloning is active, ``'cloning_loss'``.
    """
    # pre-process the observation and goal
    o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g']
    transitions['obs'], transitions['g'] = self._preproc_og(o, g)
    transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g)

    # normalise and build the network inputs [observation, goal]
    obs_norm = self.o_norm.normalize(transitions['obs'])
    g_norm = self.g_norm.normalize(transitions['g'])
    inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
    obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
    g_next_norm = self.g_norm.normalize(transitions['g_next'])
    inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)

    # transfer them into tensors
    inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
    inputs_next_norm_tensor = torch.tensor(inputs_next_norm,
                                           dtype=torch.float32)
    actions_tensor = torch.tensor(transitions['actions'],
                                  dtype=torch.float32)
    r_tensor = torch.tensor(transitions['r'], dtype=torch.float32)
    if self.config['cuda']:
        inputs_norm_tensor = inputs_norm_tensor.cuda()
        inputs_next_norm_tensor = inputs_next_norm_tensor.cuda()
        actions_tensor = actions_tensor.cuda()
        r_tensor = r_tensor.cuda()

    # calculate the target Q value function
    with torch.no_grad():
        actions_next = self.actor_target_network(inputs_next_norm_tensor)
        q_next_value = self.critic_target_network(inputs_next_norm_tensor,
                                                  actions_next)
        q_next_value = q_next_value.detach()
        target_q_value = r_tensor + self.config['gamma'] * q_next_value
        target_q_value = target_q_value.detach()
        # clip the q value to [-1 / (1 - gamma), 0]
        clip_return = 1 / (1 - self.config['gamma'])
        target_q_value = torch.clamp(target_q_value, -clip_return, 0)

    # the q loss
    real_q_value = self.critic_network(inputs_norm_tensor, actions_tensor)
    critic_loss = (target_q_value - real_q_value).pow(2).mean()

    # the actor loss
    action_l2 = self.config['action_l2']
    actions_real = self.actor_network(inputs_norm_tensor)
    approx_q_value = self.critic_network(inputs_norm_tensor, actions_real)

    if self.bc_loss:
        # train with demonstrations using behavior cloning:
        # select only the trailing demo_b_size samples of the batch
        b_size = self.config['batch_size']
        demo_b_size = self.config['demo_batch_size']
        mask = np.concatenate(
            (np.zeros(b_size - demo_b_size), np.ones(demo_b_size)), axis=0)
        # Boolean mask: indexing with uint8 masks is deprecated in PyTorch,
        # and a bool mask combines directly with the comparison result below.
        mask = torch.tensor(mask,
                            dtype=torch.bool,
                            device=actions_real.device)

        if self.q_filter:
            # use Q-filter trick to perform BC only when needed
            with torch.no_grad():
                mask &= (real_q_value > approx_q_value).squeeze()

        prm_loss_weight = self.config['prm_loss_weight']
        cloning_loss = self.config['aux_loss_weight'] * (
            actions_real[mask] - actions_tensor[mask]).pow(2).sum()
    else:
        # train without demonstrations
        prm_loss_weight = 1.0
        cloning_loss = None

    actor_loss = -prm_loss_weight * approx_q_value.mean()
    actor_loss += prm_loss_weight * action_l2 * (
        actions_real / self.env_params['action_max']).pow(2).mean()

    if cloning_loss is not None:
        actor_loss += cloning_loss

    # update actor network
    self.actor_optim.zero_grad()
    actor_loss.backward()
    sync_grads(self.actor_network)
    self.actor_optim.step()

    # update critic network (zero_grad also drops the critic gradients
    # accumulated by the actor's backward pass)
    self.critic_optim.zero_grad()
    critic_loss.backward()
    sync_grads(self.critic_network)
    self.critic_optim.step()

    res = dict(actor_loss=actor_loss.item(), critic_loss=critic_loss.item())
    if cloning_loss is not None:
        res['cloning_loss'] = cloning_loss.item()
    return res
def train(self, epoch_start):
    """Run the training loop from ``epoch_start`` up to ``args.num_train_epochs``.

    For every batch the noise and clean signals are mixed with
    ``self.mix_signals``, the model output is cropped to line up with the
    clean target and ``self.loss`` is back-propagated.  After
    ``self.gradient_accumulation_steps`` backward passes the gradients are
    synchronised and one optimizer/scheduler step is made.  Every
    ``args.logacc`` optimizer steps the indicators are averaged, checked by
    ``self.check_loss_raise`` and logged.  ``self.save_and_eval_checkpoint``
    is called at the end of each epoch.
    """
    self.check_loss_raise = CheckLossRaise()
    optimizer_steps = 0

    for epoch in range(epoch_start, math.ceil(self.args.num_train_epochs)):
        self.indicators = collections.defaultdict(list)

        utils.sync_models(self.rank, self.model)
        self.model.train()
        self.model.zero_grad()
        pending_grads = 0

        if self.rank > -1:
            # set epoch to make different samples division between
            # processes for different epochs
            self.dataloader.sampler.set_epoch(epoch)

        for step, batch in enumerate(self.dataloader):
            # fractional epoch; allows stopping inside the last epoch
            epoch_fp = epoch + step / len(self.dataloader)
            if epoch_fp > self.args.num_train_epochs:
                break

            noise, clean = [t.to(self.args.device) for t in batch]

            # augment and mix signals
            clean, noise, mixture = self.mix_signals(clean, noise)

            # forward pass
            est_clean, est_spec, _ = self.model(mixture)

            # spectrum of the clean input signal (left-padded)
            left_pad = self.model.wnd_length - self.model.hop_length
            padded_clean = torch.nn.functional.pad(clean, (left_pad, 0))
            clean_spec = self.model.encode(padded_clean)

            # crop target and model output to align them to each other
            sample_ahead = self.model.get_sample_ahead()
            spectre_ahead = self.model.ahead
            if sample_ahead > 0:
                mixture = mixture[:, :-sample_ahead]
                clean = clean[:, :-sample_ahead]
                est_clean = est_clean[:, sample_ahead:]
            if spectre_ahead > 0:
                est_spec = est_spec[:, :, :, spectre_ahead:]
                clean_spec = clean_spec[:, :, :, :-spectre_ahead]

            loss = self.loss(epoch_fp, est_clean, est_spec, clean, clean_spec)
            self.indicators['loss'].append(loss.item())

            # calculate and accumulate gradients
            loss.backward()
            pending_grads += 1

            if pending_grads >= self.gradient_accumulation_steps:
                # all gradients are accumulated: make optimization step
                utils.sync_grads(self.rank, self.named_params,
                                 optimizer_steps == 0, pending_grads)
                self.optimizer.step()
                self.scheduler.step()  # update learning rate schedule
                optimizer_steps += 1
                self.model.zero_grad()
                pending_grads = 0

                # make logs only after several steps
                if optimizer_steps % self.args.logacc == 0:
                    # average indicators over GPUs and iterations
                    self.aver_indicators()

                    # check that negsisdr does not suddenly rise; named
                    # params and optimizer are handed over to the checker
                    self.check_loss_raise.check(
                        self.indicators_mean["negsisdr"],
                        self.named_params,
                        self.optimizer
                    )

                    self.log_indicators(epoch_fp)
                    self.indicators = collections.defaultdict(list)

        self.save_and_eval_checkpoint(epoch + 1)
def train(rank, args, model, model_t, train_dataset_qa, test_dataset_qa,
          scale_tune):
    """Train the question-answering ``model``, optionally distilling ``model_t``.

    The first model output is used as the task loss.  When ``model_t`` is
    given, a knowledge-distillation loss on the first two outputs (weight
    ``args.kd_weight``, only if it is positive) and per-layer MSE losses
    between student and teacher hidden states (weight
    ``args.supervision_weight``) are added.

    Args:
        rank: process rank; ``-1`` means single-process (non-distributed) run.
        args: namespace with the training options read by this function.
        model: model being trained (optimised in place).
        model_t: optional teacher model; ``None`` disables distillation.
        train_dataset_qa: training dataset.
        test_dataset_qa: data given to ``evaluate`` after each epoch.
        scale_tune: when true only parameters whose name contains
            ``'.scale'`` are optimised, for one epoch and without gradient
            accumulation.
    """
    global train_count
    train_count += 1

    world_size = 1 if rank < 0 else torch.distributed.get_world_size()
    is_main = rank in [-1, 0]

    if is_main:
        printlog("Train model",train_count)
        printlog(model)

    per_gpu_train_batch_size = args.per_gpu_train_batch_size
    train_batch_size = per_gpu_train_batch_size * world_size
    # never 0: the value is used as a divisor below
    gradient_accumulation_steps = max(
        1, args.total_train_batch_size // train_batch_size)
    num_train_epochs = args.num_train_epochs

    if scale_tune:
        gradient_accumulation_steps = 1
        num_train_epochs = 1

    if rank < 0:
        # single process take all samples
        sampler = RandomSampler(train_dataset_qa)
        dataloader = DataLoader(train_dataset_qa, sampler=sampler,
                                batch_size=train_batch_size, num_workers=4)
    else:
        # special sampler that divide samples between processes
        sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset_qa, rank=rank)
        dataloader = DataLoader(train_dataset_qa, sampler=sampler,
                                batch_size=per_gpu_train_batch_size)

    # At least one step: LambdaLR evaluates lr_lambda(0) on construction and
    # lr_lambda divides by steps_total.
    steps_total = max(
        1, int(len(dataloader) // gradient_accumulation_steps * num_train_epochs))

    # Prepare optimizer and schedule.
    # Empty entries (e.g. from a trailing comma) would match every parameter
    # name, and "none" is treated as "freeze nothing", so both are dropped.
    freeze_list = []
    if args.freeze_list:
        freeze_list = [fn for fn in args.freeze_list.split(',')
                       if fn and fn.lower() != "none"]

    named_params = []
    for n, p in model.named_parameters():
        if any(fn in n for fn in freeze_list):
            if is_main:
                logger.warning("rank {} {} param is frozen and excluded from tune".format(rank,n))
            continue
        named_params.append((n, p))

    # split parameters to scale and the rest
    named_params_scale = [(n, p) for n, p in named_params if '.scale' in n]
    named_params_rest = [(n, p) for n, p in named_params if '.scale' not in n]

    if scale_tune:
        # keep only scale parameters
        named_params = named_params_scale
        named_params_rest = []

    groups = []
    if named_params_scale:
        groups.append({'params': [p for n, p in named_params_scale], 'lr': 0.01})
    if named_params_rest:
        groups.append({'params': [p for n, p in named_params_rest], 'lr': args.learning_rate})

    optimizer = AdamW(
        groups,
        eps=1e-08,
        lr=args.learning_rate,
        weight_decay=0)

    def lr_lambda(current_step):
        # linear decay; clamped so the factor can never become negative
        p = float(current_step) / float(steps_total)
        return max(1 - p, 0)

    scheduler = LambdaLR(optimizer, lr_lambda)

    if is_main:
        for n, p in named_params:
            printlog('param for tune',n)
        printlog("scale_tune", scale_tune )
        printlog("dataset size", len(train_dataset_qa) )
        printlog("epoches", num_train_epochs )
        printlog("per_gpu_train_batch_size", per_gpu_train_batch_size )
        printlog("n_gpu", args.n_gpu )
        printlog("world_size", world_size )
        printlog("gradient_accumulation_steps", gradient_accumulation_steps )
        printlog("total train batch size", train_batch_size * gradient_accumulation_steps )
        printlog("steps_total",steps_total )

    def align_and_loss_outputs(out_s, out_t):
        """Return per-layer MSE between student and teacher outputs."""
        if len(out_s) != len(out_t):
            # the student and teacher outputs are not aligned:
            # pick an evenly spaced teacher output for each student output
            n_s, n_t = len(out_s), len(out_t)
            out_t = [out_t[(i*(n_t-1))//(n_s-1)] for i in range(n_s)]
        assert len(out_s) == len(out_t), "can not align number of outputs between student and teacher"
        assert all(s[0] == s[1] for s in zip(out_s[0].shape, out_t[0].shape)), "output shapes for student and teacher are not the same"
        return [(s - t.detach()).pow(2).mean() for s, t in zip(out_s, out_t)]

    global_step = 0
    model.zero_grad()
    indicators = collections.defaultdict(list)

    # NOTE(review): softplus and loss_cfg are not used further in this
    # function; they are kept so that a malformed args.loss_cfg still fails
    # here as before.
    softplus = torch.nn.Softplus()
    loss_cfg = dict([t.split(':') for t in args.loss_cfg.split(',')]) if args.loss_cfg else dict()

    time_last = None  # time of the previous log report

    for epoch in range(math.ceil(num_train_epochs)):
        indicators = collections.defaultdict(list)
        model.train()
        set_output_hidden_states(rank, model, (model_t is not None))
        utils.sync_models(rank, model)
        if model_t is not None:
            set_output_hidden_states(rank, model_t, True)
            model_t.train()
        if rank > -1:
            # set epoch to make different samples division between processes
            # for different epochs
            sampler.set_epoch(epoch)

        for step, batch in enumerate(dataloader):
            epoch_fp = epoch + step/len(dataloader)
            if epoch_fp > num_train_epochs:
                break

            losses = []

            inputs = get_inputs(batch, args.device)
            targets = get_targets(batch, args.device)
            outputs = model(**inputs, **targets, output_hidden_states=(model_t is not None))
            # the first output is the task loss, the rest are model outputs
            losses.append(outputs[0])
            outputs = outputs[1:]

            if model_t is not None:
                with torch.no_grad():
                    outputs_t = model_t(**inputs, output_hidden_states=True)
                hidden_t = outputs_t[2]
                assert isinstance(hidden_t, (tuple,list)), "hidden states output is not detected right"
                assert len(hidden_t) == model_t.config.num_hidden_layers+1, "hidden states output is not detected right"

                if args.kd_weight>0:
                    # Calculate knowledge distillation loss
                    kd_losses = []
                    for logit_s, logit_t in zip(outputs[0:2], outputs_t[0:2]):
                        T = 1
                        prob_t = torch.nn.functional.softmax(logit_t.detach() / T, dim=1)
                        logprob_s = torch.nn.functional.log_softmax(logit_s / T, dim=1)
                        kd_losses.append( -(logprob_s * prob_t).mean() * (T * T * prob_t.shape[1]) )
                    losses.append(args.kd_weight*sum(kd_losses)/len(kd_losses))

                hidden_s = outputs[2]
                assert isinstance(hidden_s, (tuple,list)), "hidden states output is not detected right"
                assert len(hidden_s) == model.config.num_hidden_layers+1, "hidden states output is not detected right"

                sw_losses = align_and_loss_outputs(hidden_s, hidden_t)
                losses.extend([args.supervision_weight*l for l in sw_losses])

            # average over batch
            losses = [l.mean() for l in losses]
            l = sum(losses)/len(losses)
            indicators['loss'].append(l.item())
            indicators['ll'].append([lll.item() for lll in losses])

            (l/gradient_accumulation_steps).backward()
            del l

            if (step + 1) % gradient_accumulation_steps == 0:
                global_step += 1

                utils.sync_grads(rank, named_params, report_no_grad_params=(global_step==1))
                torch.nn.utils.clip_grad_norm_([p for n, p in named_params], 1)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()

                if global_step % 50 == 0:
                    # Log metrics
                    lrp = " ".join(['{:.2f}'.format(t) for t in np.log10(scheduler.get_last_lr())])
                    str_out = "{} ep {:.2f} lrp {}".format(train_count, epoch_fp, lrp)

                    for k, v in indicators.items():
                        v = np.array(v)
                        if len(v.shape)==1:
                            v = v[:,None]

                        if rank>-1:
                            # sync indicators
                            vt = torch.tensor(v).to(args.device)
                            torch.distributed.all_reduce(vt, op=torch.distributed.ReduceOp.SUM)
                            v = vt.cpu().numpy() / float(world_size)

                        str_out += " {} {}".format(k," ".join(["{:.3f}".format(t) for t in v.mean(0)]))

                    if time_last is not None:
                        # estimate processing times
                        dt_iter = (time.time() - time_last) / len(indicators['loss'])
                        dt_ep = dt_iter * len(dataloader)
                        str_out += " it {:.1f}s".format(dt_iter)
                        str_out += " ep {:.1f}m".format(dt_ep / (60))
                        str_out += " eta {:.1f}h".format(dt_ep * (num_train_epochs - epoch_fp) / (60 * 60))
                    time_last = time.time()

                    indicators = collections.defaultdict(list)
                    if is_main:
                        logger.info(str_out)

        if is_main:
            check_point_name = 'checkpoint-{:02}'.format(train_count)
            check_point_name = check_point_name + '-{:02}'.format(epoch + 1)
            model.eval()
            set_output_hidden_states(rank, model, False)
            result_s = evaluate(args, model, test_dataset_qa)
            for k, v in result_s.items():
                logger.info("{} {} {}".format(check_point_name, k, v))
        if rank>-1:
            torch.distributed.barrier()