Example #1
    def train(model, source_loader, target_loader, criterion, optimizer_g,
              optimizer_cls, scheduler_g, scheduler_cls, current_step, writer,
              cons):
        """Run one epoch of adversarial domain adaptation.

        Source batches receive a supervised loss from `criterion`; target
        batches are pushed through the instance-segmentation head, and the
        discrepancy between its two predictions drives the adversarial term
        (gradient handling for `adaptation=True` is assumed to live inside
        the model).
        """

        model.train()
        loss_total = 0
        loss_adv_total = 0
        data_total = 0

        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))

        for _ in trange(len(source_loader)):
            (inputs, targets), (inputs_t, _) = next(batch_iterator)

            # Move source inputs, target inputs, and targets (plain tensors
            # or dicts of tensors) onto the training device.
            if isinstance(inputs, dict):
                for k, v in inputs.items():
                    batch_size = v.size(0)
                    inputs[k] = v.to(configs.device, non_blocking=True)
            else:
                batch_size = inputs.size(0)
                inputs = inputs.to(configs.device, non_blocking=True)

            if isinstance(inputs_t, dict):
                for k, v in inputs_t.items():
                    batch_size = v.size(0)
                    inputs_t[k] = v.to(configs.device, non_blocking=True)
            else:
                batch_size = inputs_t.size(0)
                inputs_t = inputs_t.to(configs.device, non_blocking=True)

            if isinstance(targets, dict):
                for k, v in targets.items():
                    targets[k] = v.to(configs.device, non_blocking=True)
            else:
                targets = targets.to(configs.device, non_blocking=True)

            outputs = model(inputs)

            # `.module` implies the model is assumed to be wrapped in
            # nn.DataParallel / DistributedDataParallel. With adaptation=True
            # the instance-segmentation head returns two classifier outputs
            # for the target batch.
            pred_t1, pred_t2 = model.module.inst_seg_net(
                {
                    'features': inputs_t['features'],
                    'one_hot_vectors': inputs_t['one_hot_vectors']
                },
                constant=cons,
                adaptation=True)

            loss_s = criterion(outputs, targets)

            # Adversarial loss: the negated discrepancy, so that minimizing
            # it maximizes the disagreement between the two heads.
            loss_adv = -discrepancy_loss(pred_t1, pred_t2)

            # Joint update of the feature generator and both classifiers.
            loss = loss_s + loss_adv
            loss.backward()
            optimizer_g.step()
            optimizer_cls.step()
            optimizer_g.zero_grad()
            optimizer_cls.zero_grad()

            loss_adv_total += loss_adv.item() * batch_size

            # Generator training: refine the feature generator alone for
            # `gen_num_train` extra steps on the adversarial objective.
            for _ in range(configs.train.gen_num_train):
                pred_t1, pred_t2 = model.module.inst_seg_net(
                    {
                        'features': inputs_t['features'],
                        'one_hot_vectors': inputs_t['one_hot_vectors']
                    },
                    constant=cons,
                    adaptation=True)
                loss_adv = -discrepancy_loss(pred_t1, pred_t2)
                loss_adv.backward()
                loss_adv_total += loss_adv.item() * batch_size
                optimizer_g.step()
                optimizer_g.zero_grad()

            loss_total += loss_s.item() * batch_size
            data_total += batch_size

            writer.add_scalar('loss_s/train', loss_total / data_total,
                              current_step)
            writer.add_scalar('loss_adv/train', loss_adv_total / data_total,
                              current_step)
            current_step += batch_size

        if scheduler_g is not None:
            scheduler_g.step()

        if scheduler_cls is not None:
            scheduler_cls.step()
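
The example assumes a few project-level helpers that are not shown on this page: `configs` (the global configuration object), `trange` (from `tqdm`), `loop_iterable`, and `discrepancy_loss`. A minimal sketch of the latter two, inferred from how they are called above (the real project versions may differ):

    import torch

    def loop_iterable(iterable):
        # Restart the loader whenever it is exhausted, so that source and
        # target loaders of different lengths can be zipped indefinitely.
        while True:
            yield from iterable

    def discrepancy_loss(pred1, pred2):
        # MCD-style L1 discrepancy between the softmax outputs of two
        # classifier heads; larger values mean the heads disagree more.
        return torch.mean(torch.abs(torch.softmax(pred1, dim=-1) -
                                    torch.softmax(pred2, dim=-1)))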
Example #2

    def train(model, source_loader, target_loader, criterion, discrepancy,
              optimizer_g, optimizer_cls, optimizer_dis, scheduler_g,
              scheduler_cls, current_step, writer, cons):
        """Run one epoch of adversarial adaptation with local alignment.

        Extends Example #1 with a discrepancy callable, a third optimizer
        (`optimizer_dis`), and an MMD-based alignment of node features
        between the source and target domains.
        """

        model.train()
        loss_total = 0
        loss_adv_total = 0
        loss_node_total = 0
        data_total = 0

        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))

        for _ in trange(len(source_loader)):
            (inputs, targets), (inputs_t, _) = next(batch_iterator)

            if isinstance(inputs, dict):
                for k, v in inputs.items():
                    batch_size = v.size(0)
                    inputs[k] = v.to(configs.device, non_blocking=True)
            else:
                batch_size = inputs.size(0)
                inputs = inputs.to(configs.device, non_blocking=True)

            if isinstance(inputs_t, dict):
                for k, v in inputs_t.items():
                    batch_size = v.size(0)
                    inputs_t[k] = v.to(configs.device, non_blocking=True)
            else:
                batch_size = inputs_t.size(0)
                inputs_t = inputs_t.to(configs.device, non_blocking=True)

            if isinstance(targets, dict):
                for k, v in targets.items():
                    targets[k] = v.to(configs.device, non_blocking=True)
            else:
                targets = targets.to(configs.device, non_blocking=True)

            # Clear stale gradients on all three optimizers before the
            # forward passes.
            optimizer_g.zero_grad()
            optimizer_cls.zero_grad()
            optimizer_dis.zero_grad()

            outputs = model(inputs)

            outputs_target = model(inputs_t, cons, True)

            loss_s = criterion(outputs, targets)

            # Adversarial loss, as in Example #1.
            loss_adv = -discrepancy(outputs_target)

            loss = loss_s + loss_adv
            loss.backward()
            optimizer_g.step()
            optimizer_cls.step()
            optimizer_g.zero_grad()
            optimizer_cls.zero_grad()

            # Local alignment: extract node-level features from the source
            # and target branches for MMD matching.
            feat_node_s = model(inputs, adaptation_s=True)
            feat_node_t = model(inputs_t, adaptation_t=True)
            # Multi-kernel RBF MMD over segmentation, center, and box
            # features; the bandwidths span several orders of magnitude.
            sigma_list = [0.01, 0.1, 1, 10, 100]
            loss_node_adv = (mmd.mix_rbf_mmd2(feat_node_s["seg_mmd_feat"],
                                              feat_node_t["seg_mmd_feat"],
                                              sigma_list) +
                             mmd.mix_rbf_mmd2(feat_node_s["cen_mmd_feat"],
                                              feat_node_t["cen_mmd_feat"],
                                              sigma_list) +
                             mmd.mix_rbf_mmd2(feat_node_s["box_mmd_feat"],
                                              feat_node_t["box_mmd_feat"],
                                              sigma_list))
            loss_node_adv.backward()
            optimizer_dis.step()
            optimizer_dis.zero_grad()

            loss_total += loss_s.item() * batch_size
            loss_adv_total += loss_adv.item() * batch_size
            loss_node_total += loss_node_adv.item() * batch_size
            data_total += batch_size

            writer.add_scalar('loss_s/train', loss_total / data_total,
                              current_step)
            writer.add_scalar('loss_adv/train', loss_adv_total / data_total,
                              current_step)
            writer.add_scalar('loss_node/train', loss_node_total / data_total,
                              current_step)
            current_step += batch_size

        if scheduler_g is not None:
            scheduler_g.step()

        if scheduler_cls is not None:
            scheduler_cls.step()
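
Example #2 additionally calls `mmd.mix_rbf_mmd2`, a maximum mean discrepancy estimate under a mixture of RBF kernels. A minimal sketch of one common formulation (a biased MMD² estimator; the project's `mmd` module may differ in details such as biasedness or kernel parametrization):

    import torch

    def mix_rbf_mmd2(X, Y, sigma_list):
        # Biased MMD^2 between feature matrices X (n, d) and Y (m, d),
        # summed over RBF kernels with the given bandwidths sigma.
        Z = torch.cat([X, Y], dim=0)
        ZZ = Z @ Z.t()
        sq_norms = ZZ.diag().unsqueeze(0)
        # Pairwise squared Euclidean distances between rows of Z.
        dists = sq_norms + sq_norms.t() - 2.0 * ZZ
        K = sum(torch.exp(-dists / (2.0 * s ** 2)) for s in sigma_list)
        n = X.size(0)
        K_xx, K_yy, K_xy = K[:n, :n], K[n:, n:], K[:n, n:]
        return K_xx.mean() + K_yy.mean() - 2.0 * K_xy.mean()

Per batch this yields two phases: a joint supervised-plus-adversarial update of the generator and classifiers, followed by a separate MMD alignment step that only updates the parameters behind `optimizer_dis`.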