def forward(self, inputs):
    """Compute the classifier-discrepancy loss between two prediction heads.

    Used for MCD-style domain adaptation: the loss measures how much the
    two classifier heads disagree on the same input, combining discrepancy
    terms on the categorical outputs with smooth-L1 terms on the regressed
    centers.

    Args:
        inputs: dict with paired head outputs. Keys read here:
            'mask_logits1'/'mask_logits2'       — (B, 2, N) per-point mask logits
            'center_reg1'/'center_reg2'         — (B, 3) regressed center offsets
            'center1'/'center2'                 — (B, 3) predicted box centers
            'heading_scores1'/'heading_scores2' — (B, NH) heading-bin scores
            'size_scores1'/'size_scores2'       — (B, NS) size-template scores

    Returns:
        Scalar tensor: mask discrepancy plus `self.box_loss_weight` times the
        summed box-related terms.
    """
    # NOTE(review): the original body also unpacked the heading/size residual
    # tensors ('heading_residuals[_normalized]1/2', 'size_residuals[_normalized]1/2')
    # and built an unused `batch_id = torch.arange(batch_size)`; none of those
    # contributed to the returned loss, so the dead extractions were removed.
    mask_logits1 = inputs['mask_logits1']  # (B, 2, N)
    mask_logits2 = inputs['mask_logits2']
    center_reg1 = inputs['center_reg1']  # (B, 3)
    center_reg2 = inputs['center_reg2']
    center1 = inputs['center1']  # (B, 3)
    center2 = inputs['center2']
    heading_scores1 = inputs['heading_scores1']  # (B, NH)
    heading_scores2 = inputs['heading_scores2']
    size_scores1 = inputs['size_scores1']  # (B, NS)
    size_scores2 = inputs['size_scores2']

    # Discrepancy between the two heads on each categorical output.
    mask_loss = discrepancy_loss(mask_logits1, mask_logits2)
    heading_loss = discrepancy_loss(heading_scores1, heading_scores2)
    size_loss = discrepancy_loss(size_scores1, size_scores2)

    # Regression disagreement on the center predictions.
    center_loss = F.smooth_l1_loss(center1, center2)
    center_reg_loss = F.smooth_l1_loss(center_reg1, center_reg2)

    # Box-related terms are down/up-weighted relative to the mask term.
    return mask_loss + self.box_loss_weight * (
        center_loss + center_reg_loss + heading_loss + size_loss)
def train(model, source_loader, target_loader, criterion, optimizer_g,
          optimizer_cls, scheduler_g, scheduler_cls, current_step, writer,
          cons):
    """Run one epoch of MCD-style adversarial domain-adaptation training.

    Each iteration draws one source batch (with labels) and one target batch
    (unlabeled), then:
      1. optimizes generator + classifiers on the supervised source loss plus
         a negative classifier-discrepancy (adversarial) loss on the target;
      2. optimizes the generator alone for `configs.train.gen_num_train`
         extra steps to maximize classifier agreement on the target.

    Args:
        model: DataParallel-wrapped network — accessed via `model.module`.
        source_loader / target_loader: labeled-source / unlabeled-target data.
        criterion: supervised loss on the source predictions.
        optimizer_g / optimizer_cls: generator / classifier optimizers.
        scheduler_g / scheduler_cls: optional LR schedulers, stepped per epoch.
        current_step: global step counter for TensorBoard (advanced by the
            batch size each iteration).
        writer: TensorBoard SummaryWriter.
        cons: gradient-reversal constant forwarded to the adaptation branch.
    """
    model.train()
    loss_total = 0
    loss_adv_total = 0
    data_total = 0

    def _to_device(data):
        """Move a tensor or dict-of-tensors to the configured device.

        Returns (moved_data, batch_size); batch size is taken from dim 0 of
        the (last) tensor, matching the original inline code.
        """
        if isinstance(data, dict):
            bs = None
            for k, v in data.items():
                bs = v.size(0)
                data[k] = v.to(configs.device, non_blocking=True)
            return data, bs
        return data.to(configs.device, non_blocking=True), data.size(0)

    batch_iterator = zip(loop_iterable(source_loader),
                         loop_iterable(target_loader))
    for _ in trange(len(source_loader)):
        (inputs, targets), (inputs_t, _) = next(batch_iterator)
        inputs, batch_size = _to_device(inputs)
        inputs_t, batch_size = _to_device(inputs_t)
        targets, _unused = _to_device(targets)

        # --- Step 1: joint supervised + adversarial update -------------
        outputs = model(inputs)
        pred_t1, pred_t2 = model.module.inst_seg_net(
            {
                'features': inputs_t['features'],
                'one_hot_vectors': inputs_t['one_hot_vectors']
            },
            constant=cons, adaptation=True)
        loss_s = criterion(outputs, targets)
        # Adversarial loss: maximize head disagreement on the target domain.
        loss_adv = -1 * discrepancy_loss(pred_t1, pred_t2)
        loss = loss_s + loss_adv
        loss.backward()
        optimizer_g.step()
        optimizer_cls.step()
        optimizer_g.zero_grad()
        optimizer_cls.zero_grad()
        loss_adv_total += loss_adv.item() * batch_size

        # --- Step 2: generator-only training ---------------------------
        for _ in range(configs.train.gen_num_train):
            pred_t1, pred_t2 = model.module.inst_seg_net(
                {
                    'features': inputs_t['features'],
                    'one_hot_vectors': inputs_t['one_hot_vectors']
                },
                constant=cons, adaptation=True)
            loss_adv = -1 * discrepancy_loss(pred_t1, pred_t2)
            loss_adv.backward()
            loss_adv_total += loss_adv.item() * batch_size
            optimizer_g.step()
            optimizer_g.zero_grad()
        # BUGFIX: loss_adv.backward() above also deposits gradients into the
        # classifier parameters (the discrepancy flows through both heads),
        # but only the generator is stepped here. Without clearing them, those
        # stale gradients would be added to the next iteration's supervised
        # gradients before optimizer_cls.step().
        optimizer_cls.zero_grad()

        loss_total += loss_s.item() * batch_size
        data_total += batch_size
        writer.add_scalar('loss_s/train', loss_total / data_total,
                          current_step)
        writer.add_scalar('loss_adv/train', loss_adv_total / data_total,
                          current_step)
        current_step += batch_size

    # Per-epoch LR scheduling.
    if scheduler_g is not None:
        scheduler_g.step()
    if scheduler_cls is not None:
        scheduler_cls.step()