import numpy as np
import torch
import torch.nn.functional as F

# hinton_distillation and hinton_distillation_wo_ce are assumed to be defined
# elsewhere in this module.


def hinton_train(model, student_model, T, alpha, optimizer, device, train_loader, is_debug=False):
    """One epoch of Hinton knowledge distillation from a teacher (model) to a student."""
    total_loss = 0.
    # One epoch; gradient step per batch.
    optimizer.zero_grad()
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Unwrap DataParallel before calling the custom forward pass.
        if torch.cuda.device_count() > 1:
            teacher_logits = model.module.nforward(data)
            student_logits = student_model.module.nforward(data)
        else:
            teacher_logits = model.nforward(data)
            student_logits = student_model.nforward(data)
        loss = hinton_distillation(teacher_logits, student_logits, target, T, alpha)
        total_loss += float(loss.item())
        loss.backward()
        # torch.nn.utils.clip_grad_value_(model.parameters(), 10)
        optimizer.step()
        if is_debug:
            break
    # Free the last batch's tensors before returning.
    del loss
    del teacher_logits
    del student_logits
    # torch.cuda.empty_cache()
    return total_loss / len(train_loader)
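# Minimal usage sketch (an assumption, not part of the original training scripts): one way
# hinton_train might be driven for a full run. The optimizer choice, hyperparameter values,
# and the teacher/student/dataloader objects below are hypothetical placeholders.
def _example_hinton_training_loop(teacher, student, train_loader, device,
                                  epochs=10, T=4.0, alpha=0.9, lr=1e-3):
    # Only the student's parameters are updated; the teacher supplies soft targets.
    optimizer = torch.optim.SGD(student.parameters(), lr=lr, momentum=0.9)
    teacher, student = teacher.to(device), student.to(device)
    for epoch in range(epochs):
        epoch_loss = hinton_train(teacher, student, T, alpha, optimizer, device, train_loader)
        print("epoch {}: distillation loss {:.4f}".format(epoch, epoch_loss))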
def mmd_hinton_train_alt(current_epoch, epochs, teacher_model, student_model, optimizer_da,
                         optimizer_kd, device, source_dataloader, target_dataloader, T, alpha,
                         beta, kd_loss_fn, is_debug=False, **kwargs):
    """Alternating training: a teacher domain-adaptation (MMD) step, then a knowledge-distillation step."""
    logger = kwargs["logger"]
    if "logger_id" not in kwargs:
        logger_id = ""
    else:
        logger_id = kwargs["logger_id"]
    # teacher_model.train()
    # student_model.train()
    total_loss = 0.
    teacher_da_temp_loss = 0.
    kd_temp_loss = 0.
    kd_target_loss = 0.
    kd_source_loss = 0.
    iter_source = iter(source_dataloader)
    iter_target = iter(target_dataloader)
    for i in range(1, len(source_dataloader) + 1):
        data_source, label_source = next(iter_source)
        data_target, _ = next(iter_target)
        # Keep source and target batches the same size.
        if data_source.shape[0] != data_target.shape[0]:
            if data_target.shape[0] < source_dataloader.batch_size:
                iter_target = iter(target_dataloader)
                data_target, _ = next(iter_target)
            if data_source.shape[0] < source_dataloader.batch_size:
                data_target = data_target[:data_source.shape[0]]
        data_source, label_source = data_source.to(device), label_source.to(device)
        data_target = data_target.to(device)

        # Step 1: teacher domain adaptation (source classification loss + MMD loss).
        optimizer_da.zero_grad()
        teacher_label_source_pred, teacher_loss_mmd, _ = teacher_model(data_source, data_target)
        teacher_source_loss_cls = F.nll_loss(
            F.log_softmax(teacher_label_source_pred, dim=1), label_source)
        # Ramp-up factor for the MMD term over the epoch.
        gamma = 2 / (1 + np.exp(-10 * i / len(source_dataloader))) - 1
        teacher_da_mmd_loss = (1 - beta) * (teacher_source_loss_cls + gamma * teacher_loss_mmd)
        teacher_da_temp_loss += teacher_da_mmd_loss.mean().item()
        # Possible to do end-to-end or alternating here; for now it is alternating.
        teacher_da_mmd_loss.mean().backward()
        optimizer_da.step()  # May need to have 2 optimizers
        optimizer_da.zero_grad()

        # Step 2: knowledge distillation; we only learn on the target logits now.
        optimizer_kd.zero_grad()
        teacher_source_logits, teacher_loss_mmd, teacher_target_logits = teacher_model(
            data_source, data_target)
        student_source_logits, student_loss_mmd, student_target_logits = student_model(
            data_source, data_target)
        source_kd_loss = hinton_distillation(teacher_source_logits, student_source_logits,
                                             label_source, T, alpha, kd_loss_fn).abs()
        target_kd_loss = hinton_distillation_wo_ce(teacher_target_logits, student_target_logits,
                                                   T, kd_loss_fn).abs()
        kd_source_loss += source_kd_loss.mean().item()
        kd_target_loss += target_kd_loss.mean().item()
        kd_loss = beta * (target_kd_loss + source_kd_loss)
        kd_temp_loss += kd_loss.mean().item()
        total_loss += teacher_da_mmd_loss.mean().item() + kd_loss.mean().item()
        kd_loss.mean().backward()
        optimizer_kd.step()
        optimizer_kd.zero_grad()

        if logger is not None:
            logger.log_scalar("iter_total_training_loss{}".format(logger_id),
                              teacher_da_mmd_loss.mean().item() + kd_loss.mean().item(), i)
            logger.log_scalar("iter_total_da_loss{}".format(logger_id),
                              teacher_da_mmd_loss.mean().item(), i)
            logger.log_scalar("iter_total_kd_loss{}".format(logger_id),
                              kd_loss.mean().item(), i)
        if is_debug:
            break
    # Free the last batch's loss tensors before returning.
    del kd_loss
    del teacher_da_mmd_loss
    # torch.cuda.empty_cache()
    return (total_loss / len(source_dataloader),
            teacher_da_temp_loss / len(source_dataloader),
            kd_temp_loss / len(source_dataloader),
            kd_source_loss / len(source_dataloader),
            kd_target_loss / len(source_dataloader))
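# Minimal usage sketch (an assumption): one way to drive the alternating DA/KD training above.
# The optimizers, kd_loss_fn choice, hyperparameter values, and model/dataloader objects are
# hypothetical; the two optimizers mirror the alternating updates inside mmd_hinton_train_alt.
def _example_mmd_hinton_loop(teacher_model, student_model, source_dataloader, target_dataloader,
                             device, epochs=20, T=4.0, alpha=0.9, beta=0.5):
    kd_loss_fn = torch.nn.KLDivLoss(reduction="batchmean")
    optimizer_da = torch.optim.SGD(teacher_model.parameters(), lr=1e-3, momentum=0.9)
    optimizer_kd = torch.optim.SGD(student_model.parameters(), lr=1e-3, momentum=0.9)
    for epoch in range(epochs):
        totals = mmd_hinton_train_alt(epoch, epochs, teacher_model, student_model,
                                      optimizer_da, optimizer_kd, device,
                                      source_dataloader, target_dataloader,
                                      T, alpha, beta, kd_loss_fn, logger=None)
        total_loss, da_loss, kd_loss, kd_src, kd_tgt = totals
        print("epoch {}: total {:.4f}, da {:.4f}, kd {:.4f}".format(
            epoch, total_loss, da_loss, kd_loss))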