def hinton_train_without_label(teacher_model, student_model, T, optimizer, device, train_loader, is_debug=False):
    """Run one epoch of Hinton knowledge distillation using only soft targets.

    For every batch, both models produce logits via ``nforward`` and the
    student is trained against the teacher's softened distribution
    (no ground-truth cross-entropy term). Labels from the loader are moved
    to ``device`` but otherwise unused.

    Args:
        teacher_model: frozen-weights source of soft targets (kept in
            ``train()`` mode, presumably for batch-norm statistics — confirm).
        student_model: model being trained; its params are in ``optimizer``.
        T: distillation temperature passed to ``hinton_distillation_wo_ce``.
        optimizer: optimizer stepping the student.
        device: torch device for the batch tensors.
        train_loader: yields ``(data, target)`` batches.
        is_debug: when True, process a single batch and stop.

    Returns:
        Sum of per-batch losses divided by ``len(train_loader)`` (note: in
        debug mode this still divides by the full loader length).
    """
    total_loss = 0.
    optimizer.zero_grad()
    teacher_model.train()
    student_model.train()
    # FIX: hoisted out of the loop — whether the models are DataParallel
    # wrappers (needing ``.module``) cannot change mid-epoch.
    multi_gpu = torch.cuda.device_count() > 1
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        if multi_gpu:
            teacher_logits = teacher_model.module.nforward(data)
            student_logits = student_model.module.nforward(data)
        else:
            teacher_logits = teacher_model.nforward(data)
            student_logits = student_model.nforward(data)
        loss = hinton_distillation_wo_ce(teacher_logits, student_logits, T)
        total_loss += float(loss.item())
        loss.backward()
        optimizer.step()
        # FIX: release graph references before a possible debug break — in the
        # original the ``del`` statements sat after ``break`` and never ran on
        # the debug path.
        del loss, teacher_logits, student_logits
        if is_debug:
            break
    return total_loss / len(train_loader)
def grl_multi_target_hinton_train_alt(current_ep, epochs, teacher_models, student_model, optimizer_das, optimizer_kd, device, source_dataloader, targets_dataloader, T, alpha, beta, gamma, batch_norm, is_cst, is_debug=False, **kwargs):
    """One epoch of alternating GRL domain adaptation and Hinton distillation
    against multiple target domains (one teacher + optimizer per target).

    Phase 1 (per target): train teacher ``ix`` with classification loss on the
    source batch plus gradient-reversal adversarial losses on source and
    target. Phase 2: distill the updated teacher into the shared student on
    both source and target batches.

    Args:
        current_ep, epochs: 1-based epoch index and total epochs — drive the
            GRL ramp ``delta``.
        teacher_models / optimizer_das: parallel lists, one per target domain.
        optimizer_kd: optimizer for the student distillation step.
        targets_dataloader: list of target-domain loaders, cycled when shorter
            than the source loader.
        T, alpha, beta, gamma: distillation temperature and loss weights.
        batch_norm: when True, put all models in ``train()`` mode.
        is_cst: when True, add the student's own adversarial loss to the
            target KD term (consistency variant).
        **kwargs: must contain "logger"; may contain "logger_id".

    Returns:
        Tuple of per-teacher tensors, each averaged over source batches:
        (total losses, teacher DA/GRL losses, KD losses).
    """
    logger = kwargs["logger"]  # required key; kept for parity with callers
    logger_id = kwargs.get("logger_id", "")
    if batch_norm:
        for teacher_model in teacher_models:
            teacher_model.train()
        student_model.train()
    total_losses = torch.zeros(len(teacher_models))
    teacher_da_temp_losses = torch.zeros(len(teacher_models))
    kd_temp_losses = torch.zeros(len(teacher_models))
    kd_target_loss = 0.
    kd_source_loss = 0.
    # One iterator per target domain; re-created on exhaustion below.
    iter_targets = [iter(d) for d in targets_dataloader]
    iter_source = iter(source_dataloader)
    for i in range(1, len(source_dataloader) + 1):
        # FIX: ``iterator.next()`` is Python 2 only — use builtin ``next()``.
        data_source, label_source = next(iter_source)
        data_source = data_source.to(device)
        label_source = label_source.to(device)
        for ix in range(len(iter_targets)):
            try:
                data_target, _ = next(iter_targets[ix])
            except StopIteration:
                # FIX: store the fresh iterator back into the list — the
                # original rebound a loop-local alias, so the exhausted
                # iterator was kept and rebuilt on every subsequent batch.
                iter_targets[ix] = iter(targets_dataloader[ix])
                data_target, _ = next(iter_targets[ix])
            if data_target.shape[0] != data_source.shape[0]:
                # Trim an oversized target batch to the source batch size.
                data_target = data_target[:data_source.shape[0]]
            data_target = data_target.to(device)

            # ---- Phase 1: teacher domain adaptation via GRL ----
            optimizer_das[ix].zero_grad()
            # GRL coefficient ramps smoothly from 0 toward 1 over training.
            p = float(i + (current_ep - 1) * len(source_dataloader)) / epochs / len(source_dataloader)
            delta = 2. / (1. + np.exp(-10 * p)) - 1
            teacher_label_source_pred, teacher_source_loss_adv = teacher_models[ix](data_source, delta)
            # FIX: F.cross_entropy applies log_softmax internally; the original
            # fed it log-softmaxed logits, normalizing twice. Pass raw logits.
            teacher_source_loss_cls = F.cross_entropy(teacher_label_source_pred, label_source)
            _, teacher_target_loss_adv = teacher_models[ix](data_target, delta, source=False)
            teacher_loss_adv = teacher_source_loss_adv + teacher_target_loss_adv
            teacher_da_grl_loss = (1 - beta) * (teacher_source_loss_cls + gamma * teacher_loss_adv)
            teacher_da_temp_losses[ix] += teacher_da_grl_loss.mean().item()
            teacher_da_grl_loss.mean().backward()
            optimizer_das[ix].step()
            # May need to have 2 optimizers
            optimizer_das[ix].zero_grad()

            # ---- Phase 2: distill updated teacher into the student ----
            optimizer_kd.zero_grad()
            teacher_source_logits, _ = teacher_models[ix](data_source, delta, source=True)
            teacher_target_logits, _ = teacher_models[ix](data_target, delta, source=True)
            student_source_logits, _ = student_model(data_source, delta, source=True)
            student_target_logits, student_target_loss_adv = student_model(data_target, delta, source=False)
            source_kd_loss = hinton_distillation_sw(teacher_source_logits, student_source_logits, label_source, T, alpha).abs()
            if is_cst:
                # Consistency variant: also train the student's adversarial head.
                target_kd_loss = hinton_distillation_wo_ce(teacher_target_logits, student_target_logits, T).abs() + alpha * student_target_loss_adv
            else:
                target_kd_loss = hinton_distillation_wo_ce(teacher_target_logits, student_target_logits, T).abs()
            kd_source_loss += source_kd_loss.mean().item()
            kd_target_loss += target_kd_loss.mean().item()
            kd_loss = beta * (target_kd_loss + source_kd_loss)
            kd_temp_losses[ix] += kd_loss.mean().item()
            total_losses[ix] += teacher_da_grl_loss.mean().item() + kd_loss.mean().item()
            kd_loss.mean().backward()
            optimizer_kd.step()
            optimizer_kd.zero_grad()
        if is_debug:
            break
        del kd_loss
        del teacher_da_grl_loss
    return (total_losses / len(source_dataloader),
            teacher_da_temp_losses / len(source_dataloader),
            kd_temp_losses / len(source_dataloader))
def mmd_hinton_train_alt(current_epoch, epochs, teacher_model, student_model, optimizer_da, optimizer_kd, device, source_dataloader, target_dataloader, T, alpha, beta, kd_loss_fn, is_debug=False, **kwargs):
    """One epoch of alternating MMD domain adaptation and Hinton distillation.

    Phase 1: train the teacher on source classification (NLL) plus an
    MMD alignment term between source and target, with the MMD weight
    ramped up within the epoch. Phase 2: distill the updated teacher into
    the student on both source (with labels) and target (soft-only) logits.

    Args:
        current_epoch, epochs: epoch bookkeeping (not used in the ramp here,
            which is per-batch — kept for signature parity with callers).
        teacher_model / student_model: return
            ``(source_logits, mmd_loss, target_logits)`` when called with
            ``(data_source, data_target)``.
        optimizer_da / optimizer_kd: teacher-adaptation and student
            distillation optimizers.
        T, alpha, beta: temperature and loss weights; ``kd_loss_fn`` is the
            divergence used by the distillation helpers.
        is_debug: when True, process a single batch and stop.
        **kwargs: must contain "logger" (may be None); may contain "logger_id".

    Returns:
        5-tuple averaged over source batches: (total, teacher DA, KD,
        source-KD, target-KD) losses.
    """
    logger = kwargs["logger"]
    logger_id = kwargs.get("logger_id", "")
    total_loss = 0.
    teacher_da_temp_loss = 0.
    kd_temp_loss = 0.
    kd_target_loss = 0.
    kd_source_loss = 0.
    iter_source = iter(source_dataloader)
    iter_target = iter(target_dataloader)
    for i in range(1, len(source_dataloader) + 1):
        # FIX: ``iterator.next()`` is Python 2 only — use builtin ``next()``.
        data_source, label_source = next(iter_source)
        try:
            data_target, _ = next(iter_target)
        except StopIteration:
            # ROBUSTNESS: restart the target stream when it is shorter than
            # the source stream (the original crashed with StopIteration).
            iter_target = iter(target_dataloader)
            data_target, _ = next(iter_target)
        if data_source.shape[0] != data_target.shape[0]:
            if data_target.shape[0] < source_dataloader.batch_size:
                # Short final target batch: restart the target stream.
                iter_target = iter(target_dataloader)
                data_target, _ = next(iter_target)
            if data_source.shape[0] < source_dataloader.batch_size:
                # Short final source batch: trim target to match.
                data_target = data_target[:data_source.shape[0]]
        data_source, label_source = data_source.to(device), label_source.to(device)
        data_target = data_target.to(device)

        # ---- Phase 1: teacher domain adaptation (MMD) ----
        optimizer_da.zero_grad()
        teacher_label_source_pred, teacher_loss_mmd, _ = teacher_model(data_source, data_target)
        teacher_source_loss_cls = F.nll_loss(F.log_softmax(teacher_label_source_pred, dim=1), label_source)
        # MMD weight ramps from 0 toward 1 across the epoch's batches.
        gamma = 2 / (1 + np.exp(-10 * (i) / len(source_dataloader))) - 1
        teacher_da_mmd_loss = (1 - beta) * (teacher_source_loss_cls + gamma * teacher_loss_mmd)
        teacher_da_temp_loss += teacher_da_mmd_loss.mean().item()
        # Possible to do end2end or alternative here: For now it's alternative
        teacher_da_mmd_loss.mean().backward()
        optimizer_da.step()
        # May need to have 2 optimizers
        optimizer_da.zero_grad()

        # ---- Phase 2: knowledge distillation (target logits included) ----
        optimizer_kd.zero_grad()
        teacher_source_logits, teacher_loss_mmd, teacher_target_logits = teacher_model(data_source, data_target)
        student_source_logits, student_loss_mmd, student_target_logits = student_model(data_source, data_target)
        source_kd_loss = hinton_distillation(teacher_source_logits, student_source_logits, label_source, T, alpha, kd_loss_fn).abs()
        target_kd_loss = hinton_distillation_wo_ce(teacher_target_logits, student_target_logits, T, kd_loss_fn).abs()
        kd_source_loss += source_kd_loss.mean().item()
        kd_target_loss += target_kd_loss.mean().item()
        kd_loss = beta * (target_kd_loss + source_kd_loss)
        kd_temp_loss += kd_loss.mean().item()
        total_loss += teacher_da_mmd_loss.mean().item() + kd_loss.mean().item()
        kd_loss.mean().backward()
        optimizer_kd.step()
        optimizer_kd.zero_grad()
        if logger is not None:
            # NOTE(review): these ``"...".format(logger_id)`` calls have no
            # placeholder, so logger_id is silently dropped from the metric
            # tags — presumably "{}..." was intended. Kept byte-identical to
            # avoid renaming existing metric streams; confirm with dashboards.
            logger.log_scalar("iter_total_training_loss".format(logger_id),
                              teacher_da_mmd_loss.item() + kd_loss.item(), i)
            logger.log_scalar("iter_total_da_loss".format(logger_id),
                              teacher_da_mmd_loss.item(), i)
            logger.log_scalar("iter_total_kd_loss".format(logger_id),
                              kd_loss.item(), i)
        if is_debug:
            break
        del kd_loss
        del teacher_da_mmd_loss
    return (total_loss / len(source_dataloader),
            teacher_da_temp_loss / len(source_dataloader),
            kd_temp_loss / len(source_dataloader),
            kd_source_loss / len(source_dataloader),
            kd_target_loss / len(source_dataloader))