# Example 1
    def forward(self, inputs):
        """Compute the discrepancy loss between two prediction heads.

        Parameters
        ----------
        inputs : dict
            Paired predictions from two heads (key suffixes ``1`` / ``2``):
            ``mask_logits*`` (B, 2, N), ``center*`` (B, 3),
            ``center_reg*`` (B, 3), ``heading_scores*`` (B, NH),
            ``size_scores*`` (B, NS).
            (Shapes taken from the original inline comments — confirm
            against the producing network.)

        Returns
        -------
        torch.Tensor
            Scalar loss: mask discrepancy plus ``self.box_loss_weight``
            times the sum of the box-related terms.
        """
        # Only the tensors that contribute to a loss term are read here;
        # the original also fetched the (unused) heading/size residuals.
        mask_logits1 = inputs['mask_logits1']  # (B, 2, N)
        mask_logits2 = inputs['mask_logits2']

        center_reg1 = inputs['center_reg1']
        center_reg2 = inputs['center_reg2']  # (B, 3)

        center1 = inputs['center1']  # (B, 3)
        center2 = inputs['center2']

        heading_scores1 = inputs['heading_scores1']  # (B, NH)
        heading_scores2 = inputs['heading_scores2']

        size_scores1 = inputs['size_scores1']  # (B, NS)
        size_scores2 = inputs['size_scores2']

        # Discrepancy between the two heads' classification-style outputs.
        mask_loss = discrepancy_loss(mask_logits1, mask_logits2)
        heading_loss = discrepancy_loss(heading_scores1, heading_scores2)
        size_loss = discrepancy_loss(size_scores1, size_scores2)

        # Regression outputs compared with smooth-L1 (Huber) distance.
        center_loss = F.smooth_l1_loss(center1, center2)
        center_reg_loss = F.smooth_l1_loss(center_reg1, center_reg2)

        # Box-related terms are scaled by a single configurable weight.
        loss = mask_loss + self.box_loss_weight * (
            center_loss + center_reg_loss + heading_loss + size_loss)

        return loss
# Example 2
    def train(model, source_loader, target_loader, criterion, optimizer_g,
              optimizer_cls, scheduler_g, scheduler_cls, current_step, writer,
              cons):
        """Run one epoch of adversarial (discrepancy-based) domain adaptation.

        Each iteration pairs one labeled source batch with one unlabeled
        target batch, then alternates:
          1. a combined step (supervised loss on source + adversarial loss
             on target) updating both ``optimizer_g`` and ``optimizer_cls``;
          2. ``configs.train.gen_num_train`` generator-only steps on the
             adversarial objective, updating ``optimizer_g`` alone.

        Running averages of both losses are logged to ``writer`` every batch;
        the LR schedulers step once at the end of the epoch.

        NOTE(review): ``current_step`` is advanced locally but never
        returned, so the caller's counter is not updated — confirm intended.
        """

        model.train()
        loss_total = 0
        loss_adv_total = 0
        data_total = 0

        # Pair source/target batches; loop_iterable presumably cycles the
        # loaders so the target stream never runs out — TODO confirm.
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))

        for _ in trange(len(source_loader)):
            # Target labels are discarded: the target domain is unlabeled.
            (inputs, targets), (inputs_t, _) = next(batch_iterator)

            # Move source inputs to the device (dict-of-tensors or tensor).
            if isinstance(inputs, dict):
                for k, v in inputs.items():
                    batch_size = v.size(0)
                    inputs[k] = v.to(configs.device, non_blocking=True)
            else:
                batch_size = inputs.size(0)
                inputs = inputs.to(configs.device, non_blocking=True)

            # Move target inputs; note this overwrites batch_size with the
            # *target* batch size, which is then used for all accounting.
            if isinstance(inputs_t, dict):
                for k, v in inputs_t.items():
                    batch_size = v.size(0)
                    inputs_t[k] = v.to(configs.device, non_blocking=True)
            else:
                batch_size = inputs_t.size(0)
                inputs_t = inputs_t.to(configs.device, non_blocking=True)

            # Move targets (labels) to the device as well.
            if isinstance(targets, dict):
                for k, v in targets.items():
                    targets[k] = v.to(configs.device, non_blocking=True)
            else:
                targets = targets.to(configs.device, non_blocking=True)

            # Supervised forward pass on the source batch.
            outputs = model(inputs)

            # Two classifier-head predictions on the target batch, run in
            # adaptation mode (gradient-reversal strength set by `cons`).
            pred_t1, pred_t2 = model.module.inst_seg_net(
                {
                    'features': inputs_t['features'],
                    'one_hot_vectors': inputs_t['one_hot_vectors']
                },
                constant=cons,
                adaptation=True)

            loss_s = criterion(outputs, targets)

            # Adversarial loss: negated so minimizing it *maximizes* the
            # discrepancy between the two heads on target data.
            loss_adv = -1 * discrepancy_loss(pred_t1, pred_t2)

            # Combined step: update generator and classifiers together.
            loss = loss_s + loss_adv
            loss.backward()
            optimizer_g.step()
            optimizer_cls.step()
            optimizer_g.zero_grad()
            optimizer_cls.zero_grad()

            loss_adv_total += loss_adv.item() * batch_size

            # Gen Training: extra generator-only steps on the adversarial
            # objective, classifiers held fixed.
            for _ in range(configs.train.gen_num_train):
                pred_t1, pred_t2 = model.module.inst_seg_net(
                    {
                        'features': inputs_t['features'],
                        'one_hot_vectors': inputs_t['one_hot_vectors']
                    },
                    constant=cons,
                    adaptation=True)
                loss_adv = -1 * discrepancy_loss(pred_t1, pred_t2)
                loss_adv.backward()
                loss_adv_total += loss_adv.item() * batch_size
                optimizer_g.step()
                optimizer_g.zero_grad()

            loss_total += loss_s.item() * batch_size
            data_total += batch_size

            # Log running averages (epoch-to-date means, not per-batch).
            writer.add_scalar('loss_s/train', loss_total / data_total,
                              current_step)
            writer.add_scalar('loss_adv/train', loss_adv_total / data_total,
                              current_step)
            current_step += batch_size

        # Per-epoch LR schedule updates.
        if scheduler_g is not None:
            scheduler_g.step()

        if scheduler_cls is not None:
            scheduler_cls.step()