def cleit_train_step(model, reference_encoder, transmitter, s_batch, t_batch, device, optimizer, alpha, history,
                     scheduler=None):
    model.zero_grad()
    transmitter.zero_grad()
    reference_encoder.zero_grad()
    model.train()
    transmitter.train()
    reference_encoder.eval()

    s_x = s_batch[0].to(device)
    t_x = t_batch[0].to(device)
    t_y = t_batch[1].to(device)

    # Map the target encoding through the transmitter and align it with the
    # frozen reference encoder's embedding of the source sample.
    x_m_code = transmitter(model.encoder(t_x))
    x_g_code = reference_encoder(s_x)
    code_loss = contrastive_loss(y_true=x_g_code, y_pred=x_m_code, device=device)

    # Supervised (masked) regression loss plus the alpha-weighted alignment term.
    loss = masked_simse(preds=model(t_x), labels=t_y) + alpha * code_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    history['loss'].append(loss.cpu().detach().item())
    history['code_loss'].append(code_loss.cpu().detach().item())
    return history
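# Neither `contrastive_loss` nor `masked_simse` is defined in this file. Below are
# minimal, hedged sketches inferred only from how they are called above: an
# embedding-alignment loss over (y_true, y_pred) code matrices, and a masked
# scale-invariant MSE (in the spirit of Domain Separation Networks) over predictions
# with possibly-missing (NaN) labels. Treat both as assumptions, not the project's
# actual implementations.
import torch
import torch.nn.functional as F


def contrastive_loss(y_true, y_pred, device):
    # Sketch: negative cosine similarity between L2-normalized codes, averaged over
    # the batch, so minimizing it pulls paired embeddings together.
    # `device` is kept only for signature compatibility with the call sites above.
    y_true = F.normalize(y_true, dim=1)
    y_pred = F.normalize(y_pred, dim=1)
    return 1.0 - (y_true * y_pred).sum(dim=1).mean()


def masked_simse(preds, labels):
    # Sketch: scale-invariant MSE computed only over entries whose labels are not NaN:
    # mean(d^2) - (sum(d))^2 / n^2 over the masked residuals d.
    mask = ~torch.isnan(labels)
    diff = preds[mask] - labels[mask]
    n = diff.numel()
    return (diff ** 2).sum() / n - (diff.sum() ** 2) / (n ** 2)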
# Variant without a transmitter: the autoencoder's own code is aligned directly
# with the frozen reference embedding.
def cleit_train_step(ae, reference_encoder, batch, device, optimizer, history, scheduler=None):
    ae.zero_grad()
    reference_encoder.zero_grad()
    ae.train()
    reference_encoder.eval()

    x_m = batch[0].to(device)
    x_g = batch[1].to(device)

    # Reconstruction loss of the autoencoder on the mutation profile.
    loss_dict = ae.loss_function(*ae(x_m))

    # Align the mutation code with the frozen reference (gene-expression) code.
    x_m_code = ae.encoder(x_m)
    x_g_code = reference_encoder(x_g)
    code_loss = contrastive_loss(y_true=x_g_code, y_pred=x_m_code, device=device)
    loss = loss_dict['loss'] + code_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    for k, v in loss_dict.items():
        history[k].append(v)
    history['code_loss'].append(code_loss.cpu().detach().item())
    return history
# Variant with a transmitter mapping the autoencoder's code into the reference
# embedding space before alignment.
def cleit_train_step(ae, reference_encoder, transmitter, batch, device, optimizer, history, scheduler=None):
    ae.zero_grad()
    transmitter.zero_grad()
    reference_encoder.zero_grad()
    ae.train()
    transmitter.train()
    reference_encoder.eval()

    x_m = batch[0].to(device)
    x_g = batch[1].to(device)

    loss_dict = ae.loss_function(*ae(x_m))

    x_m_code = transmitter(ae.encoder(x_m))
    x_g_code = reference_encoder(x_g)
    code_loss = contrastive_loss(y_true=x_g_code, y_pred=x_m_code, device=device)
    loss = loss_dict['loss'] + code_loss

    optimizer.zero_grad()
    loss.backward()
    # Optional safeguard if the codes go NaN and training diverges: clip the joint
    # gradient norm of ae and transmitter before stepping, e.g.
    #   torch.nn.utils.clip_grad_norm_(chain(ae.parameters(), transmitter.parameters()), 0.1)
    # or enable torch.autograd.set_detect_anomaly(True) while debugging.
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    for k, v in loss_dict.items():
        history[k].append(v)
    history['code_loss'].append(code_loss.cpu().detach().item())
    return history
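# A small helper sketching the gradient-clipping safeguard referenced in the comment
# above. The 0.1 max-norm comes from that commented-out snippet; the helper name and
# its placement (between loss.backward() and optimizer.step()) are assumptions.
from itertools import chain

import torch


def clip_cleit_grads(ae, transmitter, max_norm=0.1):
    # chain() iterates both parameter sets lazily so a single norm is computed
    # and clipped over the combined autoencoder + transmitter gradients.
    torch.nn.utils.clip_grad_norm_(chain(ae.parameters(), transmitter.parameters()), max_norm)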
def ceae_train_step(p_ae, t_ae, transmitter, batch, device, optimizer, history, scheduler=None):
    p_ae.zero_grad()
    t_ae.zero_grad()
    transmitter.zero_grad()
    p_ae.train()
    t_ae.train()
    transmitter.train()

    x_p = batch[0].to(device)
    x_t = batch[1].to(device)

    # Reconstruction losses for both autoencoders.
    p_loss_dict = p_ae.loss_function(*p_ae(x_p))
    t_loss_dict = t_ae.loss_function(*t_ae(x_t))

    # Align the transmitted target code with the source code. Both encodings are
    # needed here; leaving x_p_code unassigned would raise a NameError below.
    x_t_code = t_ae.encode(x_t)
    x_p_code = p_ae.encode(x_p)
    code_loss = contrastive_loss(y_true=transmitter(x_t_code), y_pred=x_p_code, device=device)
    loss = p_loss_dict['loss'] + t_loss_dict['loss'] + code_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    for k, v in p_loss_dict.items():
        history[f'p_{k}'].append(v)
    for k, v in t_loss_dict.items():
        history[f't_{k}'].append(v)
    history['code_loss'].append(code_loss.cpu().detach().item())
    return history
# Variant with a single autoencoder: the target code is taken directly from the
# batch (a precomputed embedding is assumed), and the source code is mapped
# through the transmitter before alignment.
def ceae_train_step(ae, transmitter, batch, device, optimizer, history, scheduler=None):
    ae.zero_grad()
    transmitter.zero_grad()
    ae.train()
    transmitter.train()

    x_p = batch[0].to(device)
    x_t = batch[1].to(device)

    loss_dict = ae.loss_function(*ae(x_p))

    x_t_code = x_t
    x_p_code = ae.encode(x_p)
    code_loss = contrastive_loss(y_true=x_t_code, y_pred=transmitter(x_p_code), device=device)
    loss = loss_dict['loss'] + code_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    for k, v in loss_dict.items():
        history[k].append(v)
    history['code_loss'].append(code_loss.cpu().detach().item())
    return history
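# A hedged usage sketch showing how one of the step functions above would typically be
# driven. The optimizer choice, learning rate, epoch count, and dataloader are all
# placeholders, not the project's actual training setup; `history` is a defaultdict so
# each loss key can accumulate without initialization.
from collections import defaultdict
from itertools import chain

import torch


def train_ceae(ae, transmitter, dataloader, device, n_epochs=100, lr=1e-4):
    ae.to(device)
    transmitter.to(device)
    # One optimizer over both modules, matching how the steps above update jointly.
    optimizer = torch.optim.AdamW(chain(ae.parameters(), transmitter.parameters()), lr=lr)
    history = defaultdict(list)
    for _ in range(n_epochs):
        for batch in dataloader:
            history = ceae_train_step(ae=ae, transmitter=transmitter, batch=batch,
                                      device=device, optimizer=optimizer, history=history)
    return history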