import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_

import mosesvocab


def _train_epoch(model, epoch, tqdm_data, kl_weight, optimizer=None):
    if optimizer is None:
        model.eval()
    else:
        model.train()

    kl_loss_values = mosesvocab.CircularBuffer(1000)
    recon_loss_values = mosesvocab.CircularBuffer(1000)
    loss_values = mosesvocab.CircularBuffer(1000)
    for i, input_batch in enumerate(tqdm_data):
        input_batch = tuple(data.cuda() for data in input_batch)

        # Forward
        kl_loss, recon_loss, _ = model(input_batch)
        kl_loss = torch.sum(kl_loss, 0)
        recon_loss = torch.sum(recon_loss, 0)
        loss = kl_weight * kl_loss + recon_loss

        # Backward
        if optimizer is not None:
            optimizer.zero_grad()
            # with amp.scale_loss(loss, optimizer) as scaled_loss:
            #     scaled_loss.backward()
            loss.backward()
            clip_grad_norm_(
                (p for p in model.parameters() if p.requires_grad), 50)
            optimizer.step()

        # Log
        kl_loss_values.add(kl_loss.item())
        recon_loss_values.add(recon_loss.item())
        loss_values.add(loss.item())
        lr = (optimizer.param_groups[0]['lr']
              if optimizer is not None
              else None)

        # Update tqdm
        kl_loss_value = kl_loss_values.mean()
        recon_loss_value = recon_loss_values.mean()
        loss_value = loss_values.mean()
        lr_str = f'{lr:.5f}' if lr is not None else 'n/a'  # lr is None in eval mode
        postfix = [f'loss={loss_value:.5f}',
                   f'(kl={kl_loss_value:.5f}',
                   f'recon={recon_loss_value:.5f})',
                   f'klw={kl_weight:.5f} lr={lr_str}']
        tqdm_data.set_postfix_str(' '.join(postfix))

    postfix = {
        'epoch': epoch,
        'kl_weight': kl_weight,
        'lr': lr,
        'kl_loss': kl_loss_value,
        'recon_loss': recon_loss_value,
        'loss': loss_value,
        'mode': 'Eval' if optimizer is None else 'Train'
    }
    return postfix
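# --- Hypothetical driver for _train_epoch ----------------------------------
# A minimal sketch of how the epoch function above might be driven. The
# optimizer choice, learning rate, and the linear KL-annealing schedule are
# assumptions for illustration, not taken from this file.
from tqdm import tqdm


def fit(model, train_loader, n_epochs=100, kl_final=1.0, kl_anneal_epochs=20):
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    for epoch in range(n_epochs):
        # Ramp the KL weight linearly from 0 to kl_final over the first epochs.
        kl_weight = kl_final * min(1.0, epoch / kl_anneal_epochs)
        tqdm_data = tqdm(train_loader, desc=f'Train (epoch #{epoch})')
        log = _train_epoch(model, epoch, tqdm_data, kl_weight, optimizer)
        print(log)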
def _train_epoch_binding(model, epoch, tqdm_data, optimizer=None):
    if optimizer is None:
        model.eval()
    else:
        model.train()

    loss_values = mosesvocab.CircularBuffer(10)
    binding_loss_values = mosesvocab.CircularBuffer(10)
    for i, (input_batch, binding) in enumerate(tqdm_data):
        input_batch = tuple(data.cuda() for data in input_batch)
        binding = binding.cuda().view(-1, 1)

        # Forward
        _, _, binding_loss, _ = model(input_batch, binding)
        binding_loss = torch.sum(binding_loss, 0)
        loss = binding_loss

        # Backward
        if optimizer is not None:
            optimizer.zero_grad()
            # with amp.scale_loss(loss, optimizer) as scaled_loss:
            #     scaled_loss.backward()
            loss.backward()
            clip_grad_norm_(
                (p for p in model.parameters() if p.requires_grad), 50)
            optimizer.step()

        # Log
        loss_values.add(loss.item())
        binding_loss_values.add(binding_loss.item())
        lr = (optimizer.param_groups[0]['lr']
              if optimizer is not None
              else None)

        # Update tqdm
        loss_value = loss_values.mean()
        binding_loss_value = binding_loss_values.mean()
        postfix = [f'loss={loss_value:.5f}',
                   f'bloss={binding_loss_value:.5f}']
        tqdm_data.set_postfix_str(' '.join(postfix))

    postfix = {
        'epoch': epoch,
        'lr': lr,
        'loss': loss_value,
        'mode': 'Eval' if optimizer is None else 'Train'
    }
    return postfix
# NOTE: the source defined three different epoch functions under the single
# name `_train_epoch_binding`, so the earlier ones were silently shadowed.
# This one and the next are given distinct names so all three stay callable.
def _train_epoch_agg(model, epoch, tqdm_data, kl_weight,
                     encoder_optim, decoder_optim):
    model.train()

    kl_loss_values = mosesvocab.CircularBuffer(10)
    recon_loss_values = mosesvocab.CircularBuffer(10)
    loss_values = mosesvocab.CircularBuffer(10)
    for i, (input_batch, _) in enumerate(tqdm_data):
        if epoch < 10:
            # Aggressive encoder-only sweep over a separate loader before each
            # joint step (`train_loader_agg_tqdm` is assumed to be defined at
            # module scope).
            for (input_batch_, _) in train_loader_agg_tqdm:
                encoder_optim.zero_grad()
                decoder_optim.zero_grad()
                input_batch_ = tuple(data.cuda() for data in input_batch_)

                # Forward
                kl_loss, recon_loss, _, logvar, x, y = model(input_batch_)
                kl_loss = torch.sum(kl_loss, 0)
                recon_loss = torch.sum(recon_loss, 0)
                _, predict = torch.max(F.softmax(y, dim=-1), -1)
                correct = float(
                    (x == predict).sum().cpu().detach().item()) / float(
                        x.shape[0] * x.shape[1])
                # Decay the KL weight toward its fixed point (~1.1e-3).
                kl_weight = min(kl_weight * 1e-1 + 1e-3, 1)
                loss = kl_weight * kl_loss + recon_loss

                loss.backward()
                clip_grad_norm_(
                    (p for p in model.parameters() if p.requires_grad), 25)
                encoder_optim.step()

                loss_value = loss.item()
                kl_loss_value = kl_loss.item()
                recon_loss_value = recon_loss.item()
                postfix = [f'loss={loss_value:.5f}',
                           f'(kl={kl_loss_value:.5f}',
                           f'recon={recon_loss_value:.5f})',
                           f'correct={correct:.5f}']
                train_loader_agg_tqdm.set_postfix_str(' '.join(postfix))

        # Joint step
        encoder_optim.zero_grad()
        decoder_optim.zero_grad()
        input_batch = tuple(data.cuda() for data in input_batch)

        # Forward
        kl_loss, recon_loss, _, logvar, x, y = model(input_batch)
        _, predict = torch.max(F.softmax(y, dim=-1), -1)
        correct = float((x == predict).sum().cpu().detach().item()) / float(
            x.shape[0] * x.shape[1])
        kl_loss = torch.sum(kl_loss, 0)
        recon_loss = torch.sum(recon_loss, 0)
        kl_weight = 1
        loss = recon_loss
        # The KL term must enter the loss *before* backward so it reaches the
        # gradients (it was originally added after the backward pass, where it
        # had no effect on the update).
        if epoch >= 10:
            loss = loss + kl_weight * kl_loss

        loss.backward()
        clip_grad_norm_((p for p in model.parameters() if p.requires_grad), 50)
        encoder_optim.step()
        decoder_optim.step()

        # Log
        kl_loss_values.add(kl_loss.item())
        recon_loss_values.add(recon_loss.item())
        loss_values.add(loss.item())
        lr = (encoder_optim.param_groups[0]['lr']
              if encoder_optim is not None
              else None)

        # Update tqdm
        kl_loss_value = kl_loss_values.mean()
        recon_loss_value = recon_loss_values.mean()
        loss_value = loss_values.mean()
        postfix = [f'loss={loss_value:.5f}',
                   f'(kl={kl_loss_value:.5f}',
                   f'recon={recon_loss_value:.5f})',
                   f'klw={kl_weight:.5f} lr={lr:.5f}',
                   f'correct={correct:.5f}']
        tqdm_data.set_postfix_str(' '.join(postfix))

    postfix = {
        'epoch': epoch,
        'kl_weight': kl_weight,
        'lr': lr,
        'kl_loss': kl_loss_value,
        'recon_loss': recon_loss_value,
        'loss': loss_value
    }
    return postfix
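# --- Note on the aggressive sweep's KL schedule -----------------------------
# The update `kl_weight = min(kl_weight * 1e-1 + 1e-3, 1)` is a contraction
# toward the fixed point w* = 1e-3 / (1 - 1e-1) ~= 1.11e-3, so after a few
# inner steps the KL term is scaled by roughly 0.001 regardless of the
# starting value. A quick numeric check (illustrative only, not used in
# training):
def _kl_schedule_demo(w=1.0, steps=10):
    for step in range(steps):
        w = min(w * 1e-1 + 1e-3, 1)
        print(f'step {step}: kl_weight = {w:.6f}')
    # Converges to ~0.001111 after about five steps.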
def _train_epoch_binding_cls(model, epoch, tqdm_data, kl_weight,
                             optimizer=None):
    # The generative model stays frozen here; only the binding classifier
    # (`bindingmodel`, assumed to be defined at module scope) is trained on
    # the latent codes.
    model.eval()
    if optimizer is None:
        bindingmodel.eval()
    else:
        bindingmodel.train()

    kl_loss_values = mosesvocab.CircularBuffer(10)
    recon_loss_values = mosesvocab.CircularBuffer(10)
    loss_values = mosesvocab.CircularBuffer(10)
    binding_loss_values = mosesvocab.CircularBuffer(10)
    for i, (input_batch, binding) in enumerate(tqdm_data):
        input_batch = tuple(data.cuda() for data in input_batch)
        binding = binding.cuda().view(-1, 1)

        # Forward: z is detached so no gradient flows back into the encoder.
        kl_loss, recon_loss, _, z = model(input_batch, binding)
        b_pred = bindingmodel(z.detach())

        # Binarize the binding target at 0.35 and up-weight the positives
        # (5.0 vs 0.5), vectorizing the original per-element loop.
        targets = (binding > 0.35).float()
        class_weights = targets * 5.0 + (1.0 - targets) * 0.5
        binding_loss = F.binary_cross_entropy_with_logits(
            b_pred, targets, pos_weight=class_weights)

        kl_loss = torch.sum(kl_loss, 0)
        recon_loss = torch.sum(recon_loss, 0)
        binding_loss = torch.sum(binding_loss, 0)
        loss_weight = 0 if epoch < 5 else kl_weight  # currently unused
        loss = binding_loss

        # Backward
        if optimizer is not None:
            optimizer.zero_grad()
            # with amp.scale_loss(loss, optimizer) as scaled_loss:
            #     scaled_loss.backward()
            loss.backward()
            clip_grad_norm_(bindingmodel.parameters(), 50)
            optimizer.step()

        # Log
        kl_loss_values.add(kl_loss.item())
        recon_loss_values.add(recon_loss.item())
        loss_values.add(loss.item())
        binding_loss_values.add(binding_loss.item())
        lr = (optimizer.param_groups[0]['lr']
              if optimizer is not None
              else None)

        # Update tqdm
        kl_loss_value = kl_loss_values.mean()
        recon_loss_value = recon_loss_values.mean()
        loss_value = loss_values.mean()
        binding_loss_value = binding_loss_values.mean()
        lr_str = f'{lr:.5f}' if lr is not None else 'n/a'  # lr is None in eval mode
        postfix = [f'loss={loss_value:.5f}',
                   f'(kl={kl_loss_value:.5f}',
                   f'recon={recon_loss_value:.5f})',
                   f'klw={kl_weight:.5f} lr={lr_str}',
                   f'bloss={binding_loss_value:.5f}']
        tqdm_data.set_postfix_str(' '.join(postfix))

    postfix = {
        'epoch': epoch,
        'kl_weight': kl_weight,
        'lr': lr,
        'kl_loss': kl_loss_value,
        'recon_loss': recon_loss_value,
        'loss': loss_value,
        'mode': 'Eval' if optimizer is None else 'Train'
    }
    return postfix
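# --- Hypothetical wiring for the classifier stage ---------------------------
# `bindingmodel` is referenced as a module-level global above. A minimal
# sketch of how it might be built; the MLP shape, `latent_dim`, and the
# learning rate are assumptions, not taken from this file. The head outputs
# one logit per molecule, matching binary_cross_entropy_with_logits above.
import torch.nn as nn


def _make_binding_head(latent_dim=128):
    # Small MLP over the latent code z.
    return nn.Sequential(nn.Linear(latent_dim, 64),
                         nn.ReLU(),
                         nn.Linear(64, 1))

# Example wiring (commented out so the module imports without a GPU):
# bindingmodel = _make_binding_head().cuda()
# binding_optim = torch.optim.Adam(bindingmodel.parameters(), lr=1e-3)
# log = _train_epoch_binding_cls(model, epoch, tqdm_data, kl_weight,
#                                binding_optim)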